# 論文閱讀 : YaRN: Efficient Context Window Extension of Large Language Models ## 摘要 - RoPE(Rotary Position Embedding) 是一種在 Transformer 中編碼位置資訊的有效方法。 - 但 RoPE 有一個主要限制:無法泛化超過訓練時的序列長度(context length)。 問題 - 目前使用 RoPE 的語言模型,在超出原始訓練的 context 長度時表現不佳。 - 以往的延伸方法成本高昂(需要大量 token 和訓練步數)。 ### 核心貢獻 提出了一種新的方法: - YaRN(Yet another RoPE extensioN method) **YaRN 的特點:** - 更有效率的 context window 擴展方法 - 相較於先前方法,訓練所需 token 數少了 10 倍。 - 訓練步數少了 2.5 倍。 - 能讓 LLaMA 模型有效利用比原始訓練更長的序列長度。 - 在 context 擴展任務上超越過去的 SOTA 方法。 - 可以在 fine-tune 數據集本身 context 較短的情況下,仍能泛化到更長序列。 ### 實驗成果 - 利用 YaRN,模型能泛化至 最多 128k 的 context length。 - 官方已開源模型與程式碼: - https://github.com/jquesnelle/yarn YaRN 是一個針對 RoPE 的 高效 context 延伸方法,不但節省資源,也實證能在極長文本中維持模型效能,是目前 延長 Transformer context window 的 SOTA 方法之一。 ## Introduction ### Transformer-based LLM 的限制 * Transformer 是目前主流的 NLP 模型架構,具備強大的 **長距依賴建模能力**(如 In-Context Learning)。 * **最大 context 長度**(context window)會受限於模型訓練期間設定的上限,這會限制模型的應用能力。 * 如:無法處理長篇文本、無法進行長範例的 in-context learning。 ### 延伸 context window 的需求 * 希望能以「**少量微調**或**免微調**」的方式,**動態擴展 context window**。 * 為達成這目標,**位置編碼(Positional Encoding)** 是關鍵。 ## 相關研究回顧 ### Transformer 的位置編碼演進 * 最早是 **絕對正弦位置編碼(sinusoidal)**。 * 後來出現 **可學習的絕對編碼**。 * 更進一步發展為 **相對位置編碼(relative positional encoding)**,提升效能。 * 幾個流行的相對位置編碼方法: * **T5 Relative Bias** * **RoPE**(Rotary Position Embedding) * **XPos** * **ALiBi** ### 共同限制 * 多數位置編碼無法泛化到比訓練階段更長的序列。 * ALiBi 雖有部分泛化能力,但仍無法應對「大幅」延長的情境。 ## 既有方法與發展 ### Position Interpolation (PI) * 如 \[Chen et al., 2023] 提出對 RoPE 進行 **位置插值** 並搭配少量資料進行微調。 ### NTK-aware 方法系列(考慮頻率資訊的插值) * 解決高頻 RoPE 扭曲問題的插值方法: | 方法名稱 | 特點 | 是否需要微調 | | ----------------------- | -------------------- | ------ | | NTK-aware interpolation | 基本版本,考慮頻率縮放 | 是 | | **Dynamic NTK** | 無需微調即可套用 | 否 | | **NTK-by-parts** | 搭配少量長 context 微調效果最好 | 是 | > 許多開源模型已採用這些技術,如: > > * **Code Llama** 使用 NTK-aware interpolation。 > * **Qwen 7B** 使用 Dynamic NTK。 --- ## 本文貢獻:YaRN 方法 ### YaRN(Yet another RoPE extensioN) * 一種 **針對 RoPE 的 context 延伸方法**,可套用於: * LLaMA * GPT-NeoX * PaLM 等模型族群 ### YaRN 的優勢 * **訓練資料使用極少**(< 0.1% 的 pretraining 資料量) * **延伸後效果達 SOTA** * 可結合 **Dynamic Scaling**(只在推理時調整),形成: * **Dynamic-YaRN**:**無需微調即可將 context 長度擴大 2 倍以上** ## 小結整理 | 類別 | 方法 | 需要微調 | 特點 | | --------- | ---------------------- | ----- | ------------------------------- | | 基礎方法 | RoPE | 是 | 預設用於 LLaMA / GPT 等 | | 改進方法 | Position Interpolation | 是 | 微調 RoPE 參數 | | NTK 系列 | NTK-aware | 是 | 頻率縮放以解扭曲 | | NTK 系列 | Dynamic NTK | 否 | 可直接用於 pre-trained 模型 | | NTK 系列 | NTK-by-parts | 是 | 搭配少量長 context 效果佳 | | **本論文方法** | **YaRN** | 是 / 否 | 微調少、效能 SOTA,可搭配 Dynamic Scaling | ## Background and Related Work ## RoPE 背後的數學直觀 ### 基本流程 1. **輸入向量 $x\_1, \dots, x\_L \in \mathbb{R}^{|D|}$:** * 每個位置 $m$ 的輸入會轉換成: * 查詢向量 $q\_m$ * 鍵向量 $k\_n$ 2. **基本 attention score:** $$ \text{softmax} \left( \frac{q_m^T k_n}{\sqrt{|D|}} \right) $$ * 這是標準的 self-attention 中的點積公式。 --- ### RoPE 的核心:複數旋轉嵌入(Rotary Embedding) #### 核心設計 * 假設 hidden dim $|D|$ 是偶數,則: $$ \mathbb{R}^{|D|} \cong \mathbb{C}^{|D|/2} $$ * 將實數向量視為複數對表示,例如: $$ (x_1, x_2, \dots) \mapsto (x_1 + ix_2, \dots) $$ * 此設計讓每對連續的維度構成一個 **複數數值(實部 + 虛部)**。 #### 複數旋轉 * 為每個位置 $m$ 乘上一個 **位置相關的複數旋轉相位因子**: $$ f_q(x_m, m) = e^{m{\theta}} mW_q x_m $$ $$ f_k(x_n, n) = e^{m{\theta}} mW_k x_n $$ * 其中 $m{\theta} = \text{diag}(\theta\_1, \dots, \theta\_{|D|/2})$ 是對角頻率矩陣,每一對維度都有一個不同頻率。 --- ## RoPE 的關鍵特性 ### 相對位置不變性 RoPE 保證: $$ \text{Re} \left( \langle f_q(x_m, m), f_k(x_n, n) \rangle_{\mathbb{C}} \right) = \text{function of } (m - n) $$ * dot product 僅與 **相對位置差距 \$m - n\$ 有關**。 * 相對位置編碼的特性:**能跨序列位置平移,模型更穩定泛化。** --- ## 對應的實數矩陣形式(方便實作) 轉為 real-valued 實作時,每對維度構成一個 2x2 的旋轉矩陣,整體矩陣形式如下: ```text 每個位置 m 對應一組 2x2 旋轉矩陣,頻率為 θ_d: [ cos(mθ_d), -sin(mθ_d) ] [ sin(mθ_d), cos(mθ_d) ] ``` 拼成大矩陣後作用於輸入向量 $x\_m$,形成查詢向量: $$ f_{mW}(x_m, m, \theta_d) = R(m, m{\theta}) \cdot mW x_m $$ --- ## 物理意涵補充(推論) * 每個頻率 $\theta\_d$ 決定旋轉速度(高頻:快旋轉)。 * 不同維度感知不同粒度的相對位置資訊。 * **越靠後的維度對位置的感知越精細,越敏感於小變化。** --- | 特性 | 描述 | | ---------- | -------------------------------------------- | | 複數旋轉 | 將實向量轉為複數形式以應用旋轉 | | 頻率分布 | 每對維度一個專屬頻率 $\theta\_d =b^{-2d$ | | 相對位置性 | attention score 僅依賴 $m - n$(相對位置) | | 高效實作 | 可對應實數版旋轉矩陣,便於編碼實作 | | 結構泛化性 | 適合延伸 context 長度的研究 | ## 小節標題:Position Interpolation(PI)位置插值法 --- ### 背景動機 * **問題:** * 大型語言模型(LLM)通常在固定的 context 長度 $L$ 下預訓練。 * 若直接將測試序列長度延長到 $L' > L$(即 extrapolation),模型效果會明顯下降。 * **目標:** * 希望能透過 **少量資料進行微調**,使模型泛化到更長的 context。 --- ### 方法:位置插值(Position Interpolation) #### 提案者 * 由兩篇工作幾乎同時提出: * \[Chen et al., 2023] * \[Kaiokendev, concurrent work] #### 適用對象 * 專門針對 **使用 RoPE 的語言模型**,如 LLaMA、GPT-NeoX 等。 --- ### 核心公式與操作邏輯 **原本 RoPE 的位置輸入是 $m$,但當序列長度變長到 \$L' > L\$,我們將位置進行插值處理如下:** $$ f'_mW(x_m, m, \theta_d) = f_mW\left(x_m, \dfrac{mL}{L'}, \theta_d\right) $$ #### 說明: * $L$:模型訓練時的最大 context 長度。 * $L'$:希望延伸到的新 context 長度。 * 原本 $m \in [0, L']$,現在透過位置比例縮放,**將 $m$ 轉換為 $m \cdot L / L'$**。 * 這樣做的效果是:**把長序列「壓縮」回模型熟悉的位置範圍內**,再餵進原本的 RoPE 機制。 #### 注意: * 這不是 RoPE 的內部改動,而是修改「位置輸入」。 * 也就是:**RoPE 還是原本的權重與設計,只是換了一個位置輸入的計算方式**。 --- ### 實驗與結果 * 作者進行了少量資料微調(例如幾十億 tokens,相對 pretraining 是幾個數量級的下降),便能: * 成功延伸 context window。 * 在新長度 $L'$ 下達到良好性能。 --- ## 小結:Position Interpolation 方法精華 | 項目 | 說明 | | ---- | ------------------------------ | | 適用模型 | 使用 RoPE 的 Transformer(如 LLaMA) | | 技術核心 | 將位置 $m$ 按比例縮放回預訓練區間 | | 公式 | $m \to \dfrac{mL}{L'}$ | | 優點 | 不需修改模型結構,只要換 RoPE 輸入位置即可 | | 成果 | 透過少量 fine-tuning 成功擴展 context | | 限制 | 仍需 fine-tuning,且效果可能受頻率影響 | ## 小節標題:補充符號(Additional Notation) --- ## 1. **延伸比例:scale factor $s$** ### 定義: $$ s = \frac{L'}{L} $$ * $L$:模型原本訓練的最大 context 長度。 * $L'$:希望擴展到的新 context 長度。 * $s$ 表示「放大倍率」,是所有 RoPE 擴展方法中最重要的超參數之一。 --- ## 2. **統一公式表示 RoPE 擴展方法** 為了概括各種 RoPE 變形方式,引入以下統一形式: $$ f'_{mW}(x_m, m, \theta_d) = f_{mW}(x_m, \underbrace{g(m)}_{\text{新位置}}, \underbrace{h(\theta_d)}_{\text{新頻率}}) $$ ### 對應: * $g(m)$:修改後的位置函數。 * $h(\theta\_d)$:修改後的頻率函數。 ### 舉例:Position Interpolation (PI) $$ g(m) = \frac{m}{s}, \quad h(\theta_d) = \theta_d $$ → 只改位置,不改頻率。 --- ## 3. **RoPE 頻率對應波長:$\lambda\_d$** ### 定義: $$ \lambda_d = \frac{2\pi}{\theta_d} = 2\pi \cdot b^{\frac{2d}{|D|}} $$ * $b$ 為 RoPE 中的基數(通常是 10000)。 * 每個維度 $d$ 對應一個 **旋轉頻率 $\theta\_d$**,其週期為 \$\lambda\_d\$。 * 換言之,**RoPE 在該維度每走 $\lambda\_d$ token,旋轉一圈($2\pi$)**。 --- ## 4. 方法分類依據:是否考慮波長 作者根據方法是否考慮波長($\lambda\_d$)進行分類: | 分類 | 描述 | 範例 | | --------------------------------- | ----------------- | ---------- | | **盲目插值(Blind Interpolation)** | 只改位置,不管每個維度的頻率或波長 | PI | | **目標式插值(Targeted Interpolation)** | 根據頻率/波長特性調整位置與頻率 | YaRN 等進階方法 | --- ## 小結:符號統一與未來對照基礎 | 符號 | 定義 | 說明 | | ---------------------------------------- | ---------------- | -------------- | | $s = \dfrac{L'}{L}$ | Scale factor | 延伸倍率 | | $g(m)$ | 新位置函數 | 決定輸入位置如何被調整 | | $h(\theta\_d)$ | 新頻率函數 | 決定頻率是否縮放 | | $\lambda\_d = \dfrac{2\pi}{\theta\_d}$ | RoPE 波長 | 走幾個 token 才轉一圈 | | 方法分類 | Blind / Targeted | 有無考慮頻率結構(波長) | ## 方法論 ### 改進 PI * **PI 的核心問題**:對所有 RoPE 維度一視同仁地進行線性縮放(位置除以 $s$)。 * 作者發現:這種 **統一比例插值無法捕捉 RoPE 與模型內部嵌入的複雜動態**。 * 因此,YaRN 的方法組合解決這些關鍵缺陷,第一個要處理的是: ### Loss of High Frequency information – 「NTK-aware」插值 ### 背景與動機 ### 問題: * RoPE 本質上類似 **1D Fourier Features**(見 Tancik et al., 2020),嵌入中包含了多頻率成分(高頻 + 低頻)。 * PI 方法會 **等比例縮放所有頻率 $\theta\_d$**,導致: * **高頻資訊被壓縮得過度**,模型在區分「很接近」或「相似」的 token 上失效。 * 實驗觀察也發現:PI 雖能延長 context,但對短序列效能反而下降,可能與此高頻信息損失有關。 ### NTK-aware 插值法 * 與其對所有頻率都做一樣的伸縮,我們可以 **對不同維度(頻率)使用不同的縮放倍率**。 * **保留高頻不變,只縮放低頻** → 分散插值壓力,避免失真。 #### 數學定義 $$ g(m) = m \quad (\text{不改變位置}) \\ h(\theta_d) = b'^{-2d / |D|} \quad (\text{改變頻率}) $$ 其中: $$ b' = b \cdot s^{|D| / (|D| - 2)} $$ * $b$:原本 RoPE 中的 base,預設值為 10000。 * $s = L'/L$:scale factor。 ### 意涵: * **高頻維度(d 大)變動小;低頻維度(d 小)縮放多**。 * 整體調整頻率分佈以適應長序列,同時保留精細位置區辨能力。 ## 實驗與效果 | 面向 | 表現 | | -------- | -------------------------------------------- | | 無需微調情境 | NTK-aware 明顯優於 PI | | 有微調情境 | 效果反而比 PI 差,因為部分頻率被 extrapolate(超出原始範圍) | | 開源應用 | Code LLaMA 採用了 NTK-aware 方法,並手動設定 $b=10^6$ | ## 缺點與限制 1. **部分頻率超出訓練範圍(Out-of-bound)**: * 理論上的 $s$ 失去精確描述力。 * 實際使用時,常常需要設定 **比實際延伸比例更大的 \$s\$**。 2. **微調表現差**: * 在需微調的任務中,不如 PI 直接且穩定。 ### Loss of Relative Local Distances – “NTK-by-parts” interpolation ### 問題背景 ### 1. Blind interpolation 的侷限 * PI 和 NTK-aware 都屬於 **blind interpolation**(雖然 NTK-aware 有頻率調整,但仍對所有維度統一套用某種變換規則),假設所有 RoPE 維度對模型的重要性相同。 * 事實上,RoPE 各維度的 **波長**($\lambda\_d$)差異極大,模型可能對不同頻率有不同用途。 --- ### 2. 從波長觀察的啟示 * 在原始 context size $L$ 下: * 若某維度的波長 $\lambda > L$ → 該維度的旋轉在整個 context 內不會繞一圈,位置是唯一的 → 提供**絕對位置資訊**。 * 若 $\lambda \ll L$ → 該維度會重複多次,無法唯一標識位置 → 只能提供**相對位置資訊**。 ### 3. 為盲目插值會傷害模型 * 無論是 PI 還是 NTK-aware,對所有維度做縮放(位置或頻率),都會讓 token 之間的旋轉角度變小 → **局部距離被壓縮**。 * 壓縮後: * 局部 token 的內積變大 * 模型對相鄰 token 的區分能力下降 * 容易混淆 token 順序,影響短距依存關係的理解 ### 解法:NTK-by-parts 插值法 **核心想法**:根據每個維度的波長 $\lambda\_d$(或其比例 $r$)決定插值程度,而非一刀切。 ### 步驟 1. 定義比例: $$ r(d) = \frac{L}{\lambda_d} = \frac{L}{2\pi b'^{\frac{2d}{|D|}}} $$ * $r < 1$ → 波長比 context 長 → 絕對位置資訊 * $r \gg 1$ → 波長短 → 相對位置資訊 2. 設定兩個閾值 $\alpha, \beta$: * $r(d) < \alpha$ → 波長長 → 完全插值(像 PI),但避免外推 * $r(d) > \beta$ → 波長短 → 不插值(保留高頻局部資訊) * $\alpha \le r(d) \le \beta$ → 線性過渡(部分插值) 3. 定義 ramp function $\gamma(r)$: $$ \gamma(r) = \begin{cases} 0, & r < \alpha \\ 1, & r > \beta \\ \frac{r - \alpha}{\beta - \alpha}, & \text{其他情況} \end{cases} $$ * $\gamma=0$ → 完全插值 * $\gamma=1$ → 不插值 * 中間值 → 按比例混合 4. 插值公式(對頻率 $\theta\_d$): $$ g(m) = m $$ $$ h(\theta_d) = (1 - \gamma(r(d))) \cdot \frac{\theta_d}{s} + \gamma(r(d)) \cdot \theta_d $$ * 當 $\gamma=0$:$\theta\_d \to \theta\_d/s$(縮放頻率) * 當 $\gamma=1$:$\theta\_d$ 不變 * 中間:線性混合兩者 ## 實驗建議參數 * 在 LLaMA 模型中: * $\alpha = 1$ * $\beta = 32$ * 這組設定在非微調與微調場景中,都優於 PI 與 NTK-aware。 ### 方法比較 | 方法 | 插值依據 | 高頻維度 | 低頻維度 | 微調相容性 | | ---------------- | ------- | ---- | ------- | ----- | | PI | 統一位置縮放 | 壓縮 | 壓縮 | 好 | | NTK-aware | 按頻率縮放 | 較少壓縮 | 較多壓縮 | 微調較差 | | **NTK-by-parts** | 按波長分區縮放 | 保留 | 插值或部分插值 | 好 | ### **Dynamic Scaling – “Dynamic NTK” interpolation** ### 推理場景特性 * 很多任務在推理時會進行多次 forward pass,且序列長度從 1 一直增長到最大 context 長度。 * 典型例子:自回歸生成(autoregressive generation),每生成 1 個 token,序列長度就 +1。 ### 插值方法的兩種應用方式 1. **固定 scale factor** * 在整個推理過程中,RoPE 的 scale factor 固定為 $$ s = \frac{L'}{L} $$ 其中 $L'$ = 預設延伸後的 context 長度。 2. **動態 scale factor(Dynamic Scaling)** * 每次 forward pass 根據當前序列長度 $l'$ 動態調整: $$ s = \max\left(1, \frac{l'}{L}\right) $$ * $l'$ = 當前序列長度 * 當 $l' < L$ → $s=1$(不縮放) * 當 $l' > L$ → 按比例縮放 ### 固定 scale factor 的問題 * 在 $l' < L$ 的階段也被縮放 → 短序列性能下降。 * 當 $l'$ 超過 $L'$ 時,性能會**突然崩壞**(abrupt degradation)。 --- ### Dynamic Scaling 的優點 * **在短序列時不動作**,保留原有性能。 * **在超過 $L$ 後才逐步縮放**,讓性能「平滑退化」(graceful degradation),不會突然掉下去。 * 適合推理階段使用,不需要改動模型結構。 ### Dynamic NTK * **Dynamic Scaling + NTK-aware** = **Dynamic NTK interpolation**。 * 最早由一篇 reddit 貼文提出 * 特別適合用在 **$L'=L$ 且無任何 fine-tuning** 的模型延伸 context: * 在不改訓練資料的情況下,效果異常好。 * 實驗證據見 Appendix ### 與 KVCache 相容性 * 推理時常使用 **KV Caching**(儲存過去的 key/value 來節省計算)。 * 但 Dynamic Scaling 有個坑: * RoPE 的旋轉會依據 $s$ 改變 → **同一 token 在不同 $s$ 下旋轉結果不同**。 * 如果把「已旋轉過的 embedding」直接 cache 起來,後續調整 $s$ 會出錯。 * **正確作法**: * Cache **RoPE 之前**的 KV embedding * 每次需要時用當前 $s$ 重新套用 RoPE。 ## 加溫度 * 延長 context 會造成 Attention 分布**熵變化**(分數變得更尖或更鈍) * 調整溫度可以在 Softmax 前重新校準權重分布,使模型在長 context 下保持穩定的關係建模能力 * 實作成 **RoPE embedding 整體縮放** 更簡潔,也可和任何 Attention 優化(如 Flash Attention 2)相容 --- ## 優勢 | 特性 | 說明 | | --- | ------------------------------------------------------------ | | 效果 | 在 fine-tune 與 non-fine-tune 情境下均優於 PI、NTK-aware、NTK-by-parts | | 成本 | 0 額外計算量,僅需改 RoPE embedding 生成步驟 | | 相容性 | 可直接配合 Flash Attention 2 等高效注意力庫 | | 通用性 | $t$ 的選擇對不同規模的 LLaMA/LLaMA 2 都適用 | ## 實驗 * **YaRN 成功將使用 RoPE 的模型延長 context window**。 * 訓練成本極低: * 僅需 **400 步訓練**(約佔原預訓練語料的 0.1%) * 相比: * 比 Code LLaMA 減少 **10 倍** token * 比 Chen et al. 2023 減少 **2.5 倍** 訓練步數 * 無額外推理成本,推理效能不受影響。 * 在長文困惑度(perplexity)與標準基準測試上均**超越所有其他 context 延伸方法**。 --- ## 訓練設置 ### 模型 * **LLaMA 2**:7B 與 13B 參數版本 * 唯一改動:嵌入頻率的計算(依 YaRN 方法,§\ref{sec\:yarn}) * 設定兩種 scale factor: * $s=16$ * $s=32$ ### 超參數 * 學習率:$2\times 10^{-5}$ * 無 weight decay * Linear warmup:20 steps * AdamW 優化器:$\beta\_1 = 0.9, \beta\_2 = 0.95$ * 全域 batch size:64 * 分布式訓練:PyTorch Fully Sharded Data Parallel (FSDP) + Flash Attention 2 ### 訓練資料 * **PG19 dataset**(長文小說資料集) * 切成 64k tokens segment(每段加 BOS、EOS token) ### 訓練流程 * **$s=16$**: * 從原始 LLaMA 2 開始 * 訓練 400 steps * **$s=32$**: * 從已完成的 $s=16$ checkpoint 繼續 * 再訓練 200 steps --- ## 外推與遷移學習(Extrapolation & Transfer Learning) ### 背景對比 * **Code LLaMA**: * 使用 16k context 的資料,scale factor $s \approx 88.6$ → 相當於 355k context * 實驗顯示:模型可外推到 100k context,雖然訓練時沒見過那麼長的序列 ### YaRN 的測試 * 支援 **訓練 scale factor 大於資料長度** * 在實驗中: * **$s=32$** 模型只用 **64k context** 訓練 * 從 $s=16$ 模型繼續微調 200 steps * 成功外推到 **128k context** * 與 blind interpolation 不同: * YaRN 在遷移學習($s=16 \to s=32$)時非常高效 * 不需重新學習插值嵌入 * $s=32$ 模型在全部 context 長度內行為等同 $s=16$ 模型 ### 評估面向 論文主要從三方面評估: 1. **延伸 context 後的困惑度(perplexity)** 2. **passkey retrieval 任務表現** 3. **常見 LLM 基準測試分數** 這一小節專注在第 1 項(長序列困惑度)。 --- ### 延伸 context 後的困惑度(perplexity) ![image](https://hackmd.io/_uploads/Bk0cOz7dxe.png) ![image](https://hackmd.io/_uploads/SysYuGmuxx.png) ### 評測資料與方法 * **資料集**: * **GovReport**(長篇政府報告) * **Proof-pile**(包含許多超長文本) * **測試集**:僅用官方 test split * **困惑度計算方法**:Sliding window(\$S=256\$)方法【Press et al., 2022】 --- ### 實驗設計 1. **遞增 context 評測** * 從 Proof-pile 中挑選 10 個至少 128k tokens 的樣本 * 評測序列長度從 2k 到 128k(每次 +2k) * 比較 PI、NTK-aware、YaRN 三種延伸方法(全部由 LLaMA-2 7B 延伸到 8k) * PI & NTK-aware:依 Chen et al. (2023) 訓練流程 * YaRN:同流程,但訓練步數與資料量只有 1/2.5 **結果(表 \ref{tab:8k-comparison})**: * 短 context(2k\~8k):YaRN 和 PI 在困惑度上接近、優於 NTK-aware * 超過訓練長度(10k):YaRN 表現最佳(6.04 vs 8.07 / 6.24) --- 2. **更大 scale factor 測試** * 模型:YaRN (\$s=16\$) 與 YaRN (\$s=32\$) * 對照:Together.ai(PI,32k)、Code LLaMA(NTK-aware,100k) * 評測長度:8k、32k、64k、98k、131k tokens * 其中 YaRN(\$s=16\$) 與 (\$s=32\$) 訓練資料長度僅 64k **結果(表 \ref{tab\:proofpile-long-small})**: * **7B 模型**: * \$s=16\$:在 64k 長度內表現佳,但 98k 以上失效(困惑度爆高) * \$s=32\$:128k 仍維持低困惑度(2.37),成功泛化到未見過的長度 * **13B 模型**: * \$s=16\$:64k 內表現佳,98k 以上失效 * \$s=32\$:128k 仍穩定(2.24) * 與其他開源模型相比: * Together.ai(PI):超過 32k 後困惑度急劇惡化 * Code LLaMA(NTK-aware):100k 內表現穩定,但困惑度略高於 YaRN(\$s=32\$) --- ### 觀察 * **YaRN 是第一個成功將 LLaMA 2 的有效 context 擴展到 128k 的方法**。 * \$s=32\$ 模型在訓練只見過 64k tokens 的情況下,能泛化到 128k,顯示強外推能力。 * 相比 blind interpolation(PI, NTK-aware),YaRN 在 transfer learning(s=16 → s=32)與未見長度表現更穩定。 * GovReport 長文結果與 Proof-pile 一致,證明 YaRN 適合長序列任務。 ### Passkey Retrieval * **測試目的**:檢查模型能否在大量無意義文本中,正確找到一個隨機位置的簡單 passkey(五位數字)。 * **測試方法**: * 在不同 context window(8k~128k)中隨機放置 passkey(位置均勻分佈)。 * 每個設定跑 10 次測試。 * **結果**: * YaRN fine-tuned 於 128k 的 7B 與 13B 模型,在全範圍(8k~128k)都能 **>99%** 準確找出 passkey。 * 顯示 YaRN 延長後的模型,在長上下文的檢索能力非常穩定,沒有隨長度退化。 ### 標準化 Benchmarks(Hugging Face Open LLM Leaderboard)評測任務: ![image](https://hackmd.io/_uploads/SyJdOMmOge.png) * **評測任務**: * ARC-Challenge (25-shot) * HellaSwag (10-shot) * MMLU (5-shot) * TruthfulQA (0-shot) * **測試目的**:檢查長上下文擴展後,模型在一般 NLP 任務上的性能是否下降。 * **比較對象**: * LLaMA 2 原始模型(4k) * PI(Together 32k) * NTK-aware(Code Llama 100k) * YaRN (\$s=16\$, 64k) * YaRN (\$s=32\$, 128k) * **結果觀察**: * YaRN ($s=16$) 與 ($s=32$) 幾乎保留原始 LLaMA 2 的性能,性能下降極小。 * YaRN 與 PI、NTK-aware 相比,在多數任務上明顯更接近原始分數,尤其 HellaSwag 與 ARC-Challenge。 * $s=32$ 相比 $s=16$,平均僅 **0.49%** 分數下降,表示從 64k 擴到 128k 幾乎沒有額外精度損失。 * NTK-aware 在部分任務(例如 MMLU)有明顯性能下降,顯示外推策略不同對一般任務影響很大。 ## 結論整理 1. **全面優於現有 RoPE 插值方法* * 在延長上下文長度的任務上,YaRN 對比現有的 Positional Interpolation(PI)等方法都有更佳效果。 * 可直接作為 PI 的替代方案,**無任何缺點**,且實作成本低。 2. **維持原模型能力** * 在多項標準化 LLM 基準測試中,微調後的模型表現幾乎與原模型一致,原有能力不受影響。 * 同時能將可處理的上下文長度大幅提升。 3. **高效外推能力** * 支援「**短訓練 → 長推理**」:只需在短序列資料上微調,即可在推理時處理更長序列。 * 適用於運算資源受限情境,可利用遷移學習加速收斂,降低計算成本。 4. **實務價值** * 在訓練資源有限的情況下,仍能獲得大幅延伸上下文的能力,並在多種應用中直接落地。 ## 附註 ### 先看 RoPE 原始頻率公式 在 RoPE 中,每個維度 $d$ 的旋轉頻率是: $$ \theta_d = b^{-2d / |D|} $$ * $b$:base(預設 10000) * $|D|$:hidden size(例如 4096) * $d$ 越大 → 指數 $-2d/|D|$ 越負 → $\theta\_d$ 越小(頻率更低,波長更長) * **所以:** * **小 $d$ → 高頻(短波長)** * **大 $d$ → 低頻(長波長)** --- ### NTK-aware 如何改頻率 NTK-aware 把原本的 $b$ 換成 $b'$: $$ h(\theta_d) = b'^{-2d/|D|}, \quad b' = b \cdot s^{|D| / (|D| - 2)} $$ 因為 $s > 1$(延長 context),所以: * $b' > b$ → 整體頻率變小(波長變長)。 --- ### 為什麼高頻變動小? 我們來看變化倍率: $$ \frac{h(\theta_d)}{\theta_d} = \frac{b'^{-2d/|D|}}{b^{-2d/|D|}} = \left( \frac{b'}{b} \right)^{-2d/|D|} = \left( s^{|D| / (|D|-2)} \right)^{-2d/|D|} = s^{-2d / (|D|-2)} $$ --- ### 觀察 $d$ 對倍率的影響 * 當 $d$ **很小**(高頻維度): * 指數 $-2d / (|D|-2)$ 接近 0 * $s^0 \approx 1$ → 幾乎不改動頻率 * **高頻維度被保留** * 當 $d$ **很大**(低頻維度): * 指數 $-2d / (|D|-2)$ 是顯著的負數 * $s$ 被提升到較大負指數 → 頻率下降很多(波長變長很多) * **低頻維度被拉得更多** --- ### 直觀解釋 RoPE 的頻率分佈是從高頻到低頻排列的。 NTK-aware 透過公式 $s^{-2d/(|D|-2)}$,讓靠前的高頻維度幾乎不變,而後段低頻維度被大量「拉長」波長,這樣: * **高頻維度 → 保持細粒度位置分辨能力**(不會在短距離混淆 token) * **低頻維度 → 延長感知範圍**(適應長 context) | 維度 \$d\$ | 原波長比例 | 新波長比例(NTK-aware) | 變化量 | | -------- | ----- | ---------------- | --- | | 1 (高頻) | 小 | 幾乎一樣 | 小 | | 4 (中頻) | 中 | 長很多 | 中 | | 8 (低頻) | 大 | 長非常多 | 大 | --- **總結** NTK-aware 的數學設計保證了: * $d$ 小 → 頻率幾乎不變(保高頻) * $d$ 大 → 頻率大幅下降(延長低頻) * 達到「分散插值壓力」的目的,避免像 PI 那樣同時削弱全部頻率。 ### RoPE 與「旋轉角度」的關係 在 RoPE 中,每個 token 在第 $d$ 維的表示,會根據位置 $m$ 被旋轉一個角度: $$ \text{角度} = m \cdot \theta_d $$ * 兩個 token $m$ 和 $n$ 在該維的**相對角度差**是: $$ \Delta\phi_d = (m-n) \cdot \theta_d $$ * 這個角度差越大,表示它們在複數平面上的位置差越明顯 → 內積越小 → 模型容易區分它們。 --- ### PI / NTK-aware 的縮放效應 PI 或 NTK-aware 都會對 RoPE 的位置 \$m\$ 或頻率 \$\theta\_d\$ 做縮放,例如: * PI:$m \to m/s$ * NTK-aware:$\theta\_d \to \theta\_d'$(頻率降低) 無論是縮小 $m$ 還是降低 $\theta\_d$,結果都是讓: $$ \Delta\phi_d' < \Delta\phi_d $$ → **token 間旋轉角度變小**。 --- ### 「局部距離壓縮」 假設原本相鄰 token 的角度差是 **10°**,縮放後可能變成 **5°**。 在高維空間中,這等於: * 它們在該維度的向量變得更接近 * 多個相鄰 token 的表示分佈更擠 → **局部距離壓縮** --- ### 內積變大 內積公式: $$ \langle u, v \rangle = \|u\| \|v\| \cos(\Delta\phi) $$ * 當 $\Delta\phi$ 減小 → $\cos(\Delta\phi)$ 變大 → 內積變大 * 內積大 → 模型覺得它們更相似 --- ### 影響模型的局部辨識能力 * **高頻維度**通常負責區分相鄰 token(短距依存關係)。 * 如果這些高頻維度的角度差被壓縮: * 模型很難分清 token 的順序 * 短距關係(如詞組內的語法結構)辨識會變差 * 在長 context 延伸後,這種影響可能讓短距任務(如精準的局部檢索、短句推理)表現下降。 --- ### 先看 RoPE 旋轉的本質 在第 $d$ 維,RoPE 對位置 $m$ 的旋轉角度是: $$ \phi_{m,d} = m \cdot \theta_d $$ 波長 $\lambda\_d$ 與 $\theta\_d$ 的關係: $$ \lambda_d = \frac{2\pi}{\theta_d} $$ * **$\lambda\_d$ 大** → 一整段 context 旋轉相位不會重複(低頻) * **$\lambda\_d$ 小** → 多次重複相位(高頻) --- ## 「絕對位置區」可以安全插值 ### 絕對位置區特性 * $\lambda\_d > L$ → 旋轉相位對每個位置 $m$ 都是唯一的 (不會出現不同位置映射到相同相位的情況) * 模型靠這些維度可以直接判斷「我是第幾個位置」 ### 插值的影響 * 插值時,位置 $m$ 被縮放($m \to m/s$)或頻率被縮放,整體只是「拉伸」或「壓縮」相位。 * 因為 $\lambda\_d$ 已經遠大於 $L$,插值後仍然不會產生相位重疊 → **不會破壞原本的唯一性**。 * 所以可以安全地把它「拉長」去支撐更大的 context。 --- ## 為什麼「相對位置區」不宜插值 ### 相對位置區特性 * $\lambda\_d \ll L$ → 在原本的 context 長度內,旋轉相位會重複很多次。 * 模型對這些維度的使用更像是「局部座標系」,專門辨別相鄰 token 的相對位置。 ### 插值的影響 * 如果對這些高頻維度做縮放: * 相鄰 token 的旋轉角度差 $\Delta\phi$ 會變小(局部距離壓縮) * 內積變大 → 相鄰 token 的向量更相似 * 造成模型對短距依存的敏感度下降,順序辨識力變差 * 因為它們本來就不是為了長距定位設計的,硬要插值只會破壞它的原本功能。 ### KV Cache 的一般做法 在推理時,為了省計算量,我們會: * 把已計算好的 **Key/Value**(通常已經加上 RoPE 旋轉)存到 cache * 新 token 只需跟這些 cached Key 做 attention 這樣可以避免對舊 token 重複做線性變換和 RoPE 計算。 --- ### 問題出現的原因 Dynamic Scaling 的 $s$ **在推理過程中會變**: * 前幾步 $l'$ 還小,$s=1$ * 後面 $l'$ 超過 $L$,$s$ 開始變大 如果你 **在 cache 裡存的是「已經套過 RoPE 的 Key」**: * 舊 token 的 Key 是用舊的 $s\_{\text{old}}$ 旋轉的 * 新 token 的 Key 是用新的 $s\_{\text{new}}$ 旋轉的 * Attention 計算會不一致,破壞原本的相對位置對齊 → 模型性能崩壞 --- ### 正確做法 * **不要**在 cache 裡存「已套 RoPE」的 Key * 應該存 **linear projection 後、尚未加 RoPE 的 Key/Value** * 每次計算 attention 時: * 根據當下的 $s\_{\text{dyn}}$ * 對當前所有需要的 Query/Key 重新套用 RoPE * 這樣所有 Key/Query 都是同一個 $s$ 下的結果 → 保證一致性 --- ### 小結 KV Cache 的相容性問題,本質上就是: > RoPE 的旋轉與 $s$ 有關,而 $s$ 在 Dynamic Scaling 中會變 → 如果 cache 裡存的是已旋轉版本,就會混用不同 $s$ 的結果,破壞 attention 幾何對齊。 ## 附錄 ### 1. NTK-aware interpolation 的數學推導 **目標**: * 讓插值壓力在所有 hidden dimensions 均勻分布 * 最低頻率與線性位置縮放(scale factor = s)效果一致 * 最高頻率保持不變 **關鍵做法**: * 不是直接把頻率乘上固定 s,而是**改變 RoPE 的基底 b → b'** * 在 RoPE 中,最後一個維度 $d \in D$ 為 $|D|-2$(因為 RoPE 把 cos 與 sin 維度交錯存放) * 設定條件: $$ {b'}^{\frac{|D|-2}{|D|}} = s \cdot b^{\frac{|D|-2}{|D|}} $$ * 解出新基底: $$ b' = b \cdot s^{\frac{|D|}{|D|-2}} $$ --- ### 2. YaRN 在 Attention Pre-softmax Scaling 的影響 ![image](https://hackmd.io/_uploads/rkJkF7QOgg.png) **背景**: * 在 Attention 的 softmax 前,加入一個比例因子 $1/\sqrt{t}$ * 目的是**調整延長 context 後 Attention 分佈的尖銳/平緩程度** * 推薦公式(以 $s=8$ 為例): $$ \sqrt{\frac{1}{t}} = 0.1 \ln(s) + 1 \approx 1.208 $$ **實驗設計**: * 資料:896 篇 RedPajama 16k-token 文本 * 測試不同 $1/\sqrt{t}$ 對 perplexity 的影響 * 額外將每篇文件切成 2048-token 段,分析不同位置的影響 **觀察結果**: 1. 選擇適當 $t$ 時,延長 context 後的 perplexity 可全面下降 2. 最佳 $t$ 在不同樣本與位置區段間**高度一致** 3. 在不同 s 下,最佳 $t$ 都接近公式計算值 ### **3. Dynamic Scaling(無微調模型)** ![image](https://hackmd.io/_uploads/BJ4pt7Quex.png) * **方法**:推理時動態調整插值縮放因子 $s$(適用於 PI、NTK-by-parts、YaRN) * **測試**:在原始 LLaMA 2(max context 4k)上,對長 GovReport 進行 sliding window PPL 測試 * **結果**: * Dynamic Scaling 可有效防止超過訓練長度後 PPL 爆炸 * **Dynamic-YaRN > Dynamic-PI** 在長距 PPL 表現更好 * 適合直接延長未微調模型的推理長度 ### **1. Mistral 長上下文擴展實驗** ### **訓練設定** * **基礎模型**:Mistral 7B v0.1(與 LLaMA 架構類似) * **訓練步驟**: 1. 先以 $s=8$ 訓練 1000 steps → **64k context** 2. 再以 $s=16$ 繼續 500 steps → **128k context** * **資料集**:Together Computer Long-Data Collections(pre-train + fine-tune split) * **特殊設定**:關閉 sliding window attention(注意力範圍 = context window size) ### **測試方法** * 與 Mistral v0.1(base)與 MistralLite(NTK-aware, $\theta=1M$)比較 * 測資:Proof-pile(128k 文檔,截斷到不同測試長度) * 評估指標:Sliding window perplexity(S=256) ### **主要結果** | 模型 | Context | 方法 | 4096 | 8192 | 16384 | 65536 | 131072 | | ------------ | ------- | ---- | -------- | -------- | -------- | -------- | -------- | | Mistral v0.1 | 8k | - | **3.09** | **2.96** | 36.8 | >1e3 | >1e3 | | MistralLite | 16k | NTK | 3.26 | 3.13 | 47.3 | >1e3 | >1e3 | | YaRN $s=8$ | 64k | YaRN | 3.18 | 3.04 | **2.65** | **2.20** | 57.4 | | YaRN $s=16$ | 128k | YaRN | 3.21 | 3.08 | 2.68 | 2.24 | **2.19** | **結論**: * YaRN 在長上下文下的困惑度表現顯著優於基礎與 NTK-aware 版本 * $s=8$ 在 64k 時最優,$s=16$ 在 128k 時最優 * 整體趨勢與 LLaMA 家族模型一致