近期,笔者深入研究了大模型中的位置编码工作。其中,Sinusoidal位置编码作为一篇基石性的研究,在"Attention Is All You Need"中首次被提出,为深度学习领域带来了革命性的Transformer架构。这种编码方法旨在解决Transformer在自然捕捉序列位置信息上的挑战。深入了解Sinusoidal位置编码不仅有助于领会其核心思想,还为我们理解其他如RoPE等位置编码算法提供了坚实的基础。
1. Sinusoidal位置编码的推导
在深度学习中,Transformer模型因其纯Attention机制而著称。由于这种机制的全对称性,模型天然满足恒等式f(x,y) = f(y,x) ,导致Transformer无法有效识别输入序列的位置。这种对称性意味着,无论如何调换输入顺序,输出都保持不变。
为解决这一问题,研究者引入了位置编码,为每个输入位置加入一个独特的向量。当每个位置的编码向量都不同,这种对称性便被打破。
\begin{equation}\tilde{f}(\cdots,\boldsymbol{x}_m,\cdots,\boldsymbol{x}_n,\cdots)=f(\cdots,\boldsymbol{x}_m + \boldsymbol{p}_m,\cdots,\boldsymbol{x}_n + \boldsymbol{p}_n,\cdots)\end{equation}
为了深入探讨位置编码的性质,我们进一步将其简化为m和n两个位置上的编码,并采用泰勒展开到二阶进行近似。
\begin{equation}\tilde{f}\approx f + \boldsymbol{p}_m^{\top} \frac{\partial f}{\partial \boldsymbol{x}_m} + \boldsymbol{p}_n^{\top} \frac{\partial f}{\partial \boldsymbol{x}_n} + \frac{1}{2}\boldsymbol{p}_m^{\top} \frac{\partial^2 f}{\partial \boldsymbol{x}_m^2}\boldsymbol{p}_m + \frac{1}{2}\boldsymbol{p}_n^{\top} \frac{\partial^2 f}{\partial \boldsymbol{x}_n^2}\boldsymbol{p}_n + \underbrace{\boldsymbol{p}_m^{\top} \frac{\partial^2 f}{\partial \boldsymbol{x}_m \partial \boldsymbol{x}_n}\boldsymbol{p}_n}_{\boldsymbol{p}_m^{\top} \boldsymbol{\mathcal{H}} \boldsymbol{p}_n}\end{equation}
可以看到,最后一项是交互项,我们将它记为\boldsymbol{p}_m^{\top} \boldsymbol{\mathcal{H}} \boldsymbol{p}_n,希望它能表达一定的相对位置信息。
我们先假设\boldsymbol{\mathcal{H}}=\boldsymbol{I}是单位矩阵,此时\boldsymbol{p}_m^{\top} \boldsymbol{\mathcal{H}} \boldsymbol{p}_n = \boldsymbol{p}_m^{\top} \boldsymbol{p}_n = \langle\boldsymbol{p}_m, \boldsymbol{p}_n\rangle是两个位置编码的内积,我们希望在这个简单的例子中该项表达的是相对位置信息,即存在某个函数g使得
\begin{equation}\langle\boldsymbol{p}_m, \boldsymbol{p}_n\rangle = g(m-n)\end{equation}
这里的\boldsymbol{p}_m, \boldsymbol{p}_n是d维向量,这里我们从最简单d=2入手。对于2维向量,我们借助复数来推导,即将向量[x,y]视为复数x + y\text{i},根据复数乘法的运算法则,我们不难得到:
\begin{equation}\langle\boldsymbol{p}_m, \boldsymbol{p}_n\rangle = \text{Re}[\boldsymbol{p}_m \boldsymbol{p}_n^*]\end{equation}其中\boldsymbol{p}_n^*是\boldsymbol{p}_n的共轭复数,\text{Re}[]代表复数的实部。
为了满足式子 (3),我们假设存在复数 \boldsymbol{q}_{m-n} 使得 \boldsymbol{p}_m \boldsymbol{p}_n^* = \boldsymbol{q}_{m-n}。进一步使用复数的指数形式,我们设
\begin{equation}\boldsymbol{p}_m=r_m e^{\text{i}\phi_m}\end{equation} \begin{equation}\boldsymbol{p}_n^*=r_n e^{-\text{i}\phi_n}\end{equation} \begin{equation}\boldsymbol{q}_{m-n}=R_{m-n} e^{\text{i}\Phi_{m-n}}\end{equation}从这些等式中,我们得到 r_m r_n e^{\text{i}(\phi_m - \phi_n)} = R_{m-n} e^{\text{i}\Phi_{m-n}}。这给出了两个结论:r_m r_n = R_{m-n} 和 \phi_m - \phi_n=\Phi_{m-n}。
对于等式r_m r_n = R_{m-n} ,当 n = m 时,我们得到 r_m^2 = R_0 。为了简化,假设 r_m 为 1,这使得所有的位置编码都位于复平面的单位圆上。这样的选择使得位置编码仅由其相位角 \phi_m 确定,而不是其长度或模。
对于等式\phi_m - \phi_n = \phi_{m-n},我们从中可以看到,当 n 增加 1 时,\phi_m - \phi_n 的差值是常数,即 \phi_1。为了简化,我们可以设置一个常数 \theta 使得 \phi_1 = \theta。因此,对于任何整数 m ,我们有:
\begin{equation}\phi_m = m\theta\end{equation}这意味着 \{\phi_m\} 是一个等差数列,其通解为 m\theta。
所以,结合上述推导和Euler's formula,我们可以表示2维的位置编码为:
\begin{equation}\boldsymbol{p}_m = e^{\text{i}m\theta} \quad \Leftrightarrow \quad \boldsymbol{p}_m = \begin{pmatrix} \cos m\theta \\ \sin m\theta \end{pmatrix}\end{equation}这样,我们得到了位置编码 \boldsymbol{p}_m 的形式。
\begin{align} \langle\boldsymbol{p}_m, \boldsymbol{p}_n\rangle &= \begin{bmatrix} \cos m\theta \\\sin m\theta \end{bmatrix}^{\top}\begin{bmatrix} \cos n\theta \\\sin n\theta \end{bmatrix}\\ & = \cos m \theta \cdot \cos n \theta + \sin m \theta \cdot \sin n \theta \\ & = \cos((m-n)\theta) \\ \end{align}于是函数 g 的形式为
\begin{equation}g(m-n) = \cos((m-n)\theta)\end{equation}
向量的内积满足线性叠加,所以可以将任意偶数维的向量用多个二维向量来表示:
\begin{equation}\boldsymbol{p}_m=\begin{bmatrix} \cos m\theta_0 \\ \sin m\theta_0 \\ \cos m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \sin m\theta_{d/2-1} \end{bmatrix}\end{equation}2. 远程衰减
Sinusoidal位置编码使用了\theta_i = 10000^{-2i/d},这个形式有一个良好的性质:它使得随着|m-n|的增大,\langle\boldsymbol{p}_m, \boldsymbol{p}_n\rangle有着趋于零的趋势。
3. 一般情况
上述的推导是基于假设\boldsymbol{\mathcal{H}}=\boldsymbol{I}是单位矩阵,对于一般情况,\boldsymbol{\mathcal{H}}由模型学习得到,也就是说至于具体需要什么位置信息,则由模型的训练自行决定。本编码只是提供了一种可能的相对位置编码的实现。
参考资料
- Transformer升级之路:1、Sinusoidal位置编码追根溯源
- https://www.zhihu.com/question/307293465/answer/1028613658