# 旋轉位置嵌入 (Rotary Position Embedding, RoPE)筆記 ## 介紹 RoPE 是一種通過絕對位置編碼的方式,引入相對位置的資訊給自注意力機制(Self-Attention Mechanism)的位置嵌入。 簡單來說,原始 Transformer 架構所使用的是 Sinusoidal 函數是作為位置編碼,**與線性變換後的 QKV 相加**;而 RoPE 則是計算出旋轉位置編碼,**與線性變換後的 QK 相乘**。 以結論來說,給定 m, n 分別為當前計算的 Q, K 之位置,則可以將其視為: $(R_{m}q)^T(R_{n}k) = q^{T}R^{T}_{m}R_{n}k = q^{T}R_{n-m}k$ 其 Q 和 K 之間相對的位置資訊會在做內積時明確被引入,跟讓模型自行學習位置關係的絕對位置編碼不同。 不是上式是為了數學式顯式地表示相對位置編碼的用處,實際應用場景我們仍然是遵守既定的將 K 轉置的方法: $Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V$ 並且,在 RoPE 的使用場景中,我們只會將 Q 和 K 乘上旋轉位置嵌入,而 V 則沒有進行這一處理。 <br/> --- ## 推導 以下的推導來自於原作者苏神的網站,主要參考《[让研究人员绞尽脑汁的Transformer位置编码](https://spaces.ac.cn/archives/8130)》、《[Transformer升级之路:2、博采众长的旋转式位置编码](https://spaces.ac.cn/archives/8265/comment-page-1)》。 假設 $q_m$, $k_m$ 是位置於 m, n 的二維向量,我們將其轉為複數進行內積計算。 在複數域的向量計算中,兩複數向量的內積可以表示成一個向量和另外一個向量的共軛(complex conjugate)相乘的**實部**。 $q_m = a + bi$ $k_n = c + di$ a, b, c, d 為實數, i 為虛數單位。 $k_m$ 的共軛 $k_m^*=c+di$。 $\left \langle q_m, k_m \right \rangle = q_m \cdot k_n^*$ $\rightarrow \left \langle q_m, k_n^* \right \rangle = (a+bi) \cdot (c-di)$ $\rightarrow \left \langle q_m, k_n^* \right \rangle = ac+bd+(bc-ad)i$ 然後我們只取實部: $\left \langle q_m, k_m \right \rangle = Re[q_m, k_n^*] = ac+bd$ 可以視為這樣的一個計算過程。然而,推導還沒有結束。 如果我們把 $q_m$, $k_n$ 分別乘上 $e^{im\theta}$, $e^{in\theta}$,便可視為透過 n, m 加入了絕對位置的資訊。 $\left \langle q_{m}e^{im\theta}, k_{m}e^{in\theta} \right \rangle = Re[(q_{m}e^{im\theta}), (k_{n}e^{in\theta})^*] = Re[q_{m}k_{n}^{*}e^{i(m-n)\theta}]$ > 透過複數的 Euler 公式,我們可以把 $e^{ix}$ 表示成: > > $e^{ix} = cos(x) + isin(x)$ > > 所以在 $e^{im\theta}$ 和 $e^{in\theta}$ 的理解上,可以視為其表示在複數平面上的『旋轉』,擁有週期性與順序。 > ![image](https://hackmd.io/_uploads/SkVdrNBAT.png) > 引用自 https://en.wikipedia.org/wiki/Euler%27s_formula $e^{in\theta} = cos(n\theta)+isin(n\theta)$ $\rightarrow (e^{in\theta})^* = cos(n\theta)-isin(n\theta) = e^{-in\theta}$ 所以根據定義,我們可以推導出: $\begin{align} \langle q_{m}e^{im\theta}, k_{n}e^{in\theta} \rangle &= \text{Re}\left[ (q_{m}e^{im\theta}) \cdot (k_{n}e^{in\theta})^* \right] \\ &= \text{Re}\left[ q_{m}e^{im\theta} \cdot \overline{k_{n}e^{in\theta}} \right] \\ &= \text{Re}\left[ q_{m}e^{im\theta} \cdot (k_{n}^*e^{-in\theta}) \right] \\ &= \text{Re}\left[ q_{m}k_{n}^*e^{i(m-n)\theta} \right]\end{align}$ 既然我們確認了可以透過內積計算得到 m-n 的旋轉關係,即對於位置 m 的 $q_m$ 二維向量來說,乘上 $e^{im\theta}$ 便可以藉由後續計算得到與 $k_n$ 之間的相對資訊,我們可以將其旋轉矩陣的形式表達成: $qe^{im\theta}= \bigl(\begin{smallmatrix} cos(m\theta) & -sin(m\theta) \\ sin(m\theta) & cos(m\theta) \end{smallmatrix}\bigr)\binom{q_0}{q_1}$ 以上是在二維向量的情況。如果是在任務偶數維度 d 維的情況下,我們可以將旋轉矩陣拼接成,給定位置為 m 的向量 q 乘上矩陣 ${R_m}$: $R_mq$ 將其展開為: $\begin{bmatrix} cos(m\theta_{0}) & -sin(m\theta_{0}) & 0 & 0 & ... & 0 & 0\\ sin(m\theta_{0}) & cos(m\theta_{0}) & 0 & 0 & ... & 0 & 0 \\ 0 & 0 & cos(m\theta_{1}) & -sin(m\theta_{1}) & ... & 0 & 0\\ 0 & 0 & sin(m\theta_{1}) & cos(m\theta_{1}) & ... & 0 & 0\\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots\\ 0 & 0 & 0 & 0 & 0 & cos(m\theta_{d/2-1}) & -sin(m\theta_{d/2-1})\\ 0 & 0 & 0 & 0 & 0 & sin(m\theta_{d/2-1}) & cos(m\theta_{d/2-1}) \end{bmatrix}\begin{bmatrix}q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1}\end{bmatrix}$ 然而由於 $R_m$ 過於稀疏,所以在工程實現上,可以將其視為另一等價表示: $\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \\ \vdots \\ q_{d-2} \\ q_{d-1} \end{pmatrix}\bigotimes\begin{pmatrix} cos(m\theta_{0}) \\ cos(m\theta_{0}) \\ cos(m\theta_{1}) \\ cos(m\theta_{1}) \\ \vdots \\ cos(m\theta_{d/2-1}) \\ cos(m\theta_{d/2-1}) \\ \end{pmatrix}+\begin{pmatrix} -q_1 \\ q_0 \\ -q_3 \\ q_2 \\ \vdots \\ -q_{d-1} \\ q_{d-2}\end{pmatrix}\bigotimes\begin{pmatrix} sin(m\theta_0) \\ sin(m\theta_0) \\ sin(m\theta_1) \\ sin(m\theta_1) \\ \vdots \\ sin(m\theta_{d/2-1}) \\ sin(m\theta_{d/2-1}) \end{pmatrix}$ 註:以上 $\bigotimes$ 符號為對位相乘。 <br/> --- ## 實作 按照以上苏神給出的推導結果,我們可以直接透過 PyTorch 進行實作。 唯一需要注意的是,在上方推導中,我們是假設輸入的位置資訊為一指定位置 m。但在實際計算中,我們是將 (0...max_len) 的全數位置一同進行計算。 ```python= class RoPEPositionEmbedding(torch.nn.Module): def __init__(self, dim: int, max_len: int = 512, base: int = 10000) -> None: super().__init__() self.theta = 1 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.theta = self.theta.repeat_interleave(2) self.position_ids = torch.arange(0, max_len) def forward(self, x: torch.Tensor): position_matrix = torch.outer(self.position_ids, self.theta) cos = torch.cos(position_matrix) sin = torch.sin(position_matrix) _x = torch.empty_like(x) _x[..., 0::2] = -x[..., 1::2] _x[..., 1::2] = x[..., 0::2] _x = _x * sin x = x * cos out = x + _x return out ``` 實作完成後,我有與 transformers 中的 RoPE 做過比較,顯然與那邊實現的方式不同。而在與開源專案:https://github.com/lucidrains/rotary-embedding-torch 的比較上,我的實現與其: ```python= import torch from rotary_embedding_torch import RotaryEmbedding # instantiate the positional embedding in your transformer and pass to all your attention layers rotary_emb = RotaryEmbedding(dim = 32) # mock queries and keys - dimensions should end with (seq_len, feature dimension), and any number of preceding dimensions (batch, heads, etc) q = torch.randn(1, 8, 1024, 64) # queries - (batch, heads, seq len, dimension of head) k = torch.randn(1, 8, 1024, 64) # keys # apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention) q = rotary_emb.rotate_queries_or_keys(q) k = rotary_emb.rotate_queries_or_keys(k) # then do your attention with your queries (q) and keys (k) as usual ``` 的輸出結果一模一樣。應可斟酌參考。 至於 transformers 中 Mistral / Llama-2 的 RoPE 實現跟原始版本哪裡不同,具體來說差異體現在: 原本的實現: $\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \end{pmatrix}\bigotimes\begin{pmatrix} cos(m\theta_{0}) \\ cos(m\theta_{0}) \\ cos(m\theta_{1}) \\ cos(m\theta_{1}) \end{pmatrix}+\begin{pmatrix} -q_1 \\ q_0 \\ -q_3 \\ q_2 \end{pmatrix}\bigotimes\begin{pmatrix} sin(m\theta_0) \\ sin(m\theta_0) \\ sin(m\theta_1) \\ sin(m\theta_1) \end{pmatrix}$ 但在 transformers 的實現中,卻可以看成: $\begin{pmatrix} q_0 \\ q_1 \\ q_2 \\ q_3 \end{pmatrix}\bigotimes\begin{pmatrix} cos(m\theta_{0}) \\ cos(m\theta_{1}) \\ cos(m\theta_{0}) \\ cos(m\theta_{1}) \end{pmatrix}+\begin{pmatrix} -q_2 \\ -q_3 \\ q_0 \\ q_1 \end{pmatrix}\bigotimes\begin{pmatrix} sin(m\theta_0) \\ sin(m\theta_1) \\ sin(m\theta_0) \\ sin(m\theta_1) \end{pmatrix}$ 所以若是要實現與 transformers 中的 RoPE 同樣的版本,則需要寫成: ```python= class HFRoPEPositionEmbedding(torch.nn.Module): def __init__(self, dim: int, max_len: int = 512, base: int = 10000) -> None: super().__init__() self.theta = 1 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.theta = torch.cat([self.theta, self.theta], dim=-1) self.position_ids = torch.arange(0, max_len) def forward(self, x: torch.Tensor): position_matrix = torch.outer(self.position_ids, self.theta) cos = torch.cos(position_matrix) sin = torch.sin(position_matrix) x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] _x = torch.cat([-x2, x1], dim=-1) x = x * cos _x = _x * sin out = x + _x return out ``` 雖然與原始的複數旋轉不同,但仍然是一種試圖應用旋轉的概念捕捉相對位置編碼。以上是一些個人看原始碼後的淺見,若有誤還請各方大神不吝指出。 --- ## References - [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) - 《[让研究人员绞尽脑汁的Transformer位置编码](https://spaces.ac.cn/archives/8130)》 - 《[Transformer升级之路:2、博采众长的旋转式位置编码](https://spaces.ac.cn/archives/8265/comment-page-1)》。