20260329筆記 內容可能有錯誤,請參考原始影片 [李宏毅 加快語言模型生成速度 (1/2):Flash Attention](https://www.youtube.com/watch?v=vXb2QYOUzl4) [李宏毅 加快語言模型生成速度 (2/2):KV Cache](https://www.youtube.com/watch?v=fDQaadKysSA&t=68s) ### 加快語言模型生成速度 (1/2):Flash Attention 大綱 1. **LLM推論加速的挑戰與 Flash Attention 之定位** 2. **GPU 記憶體階層與 I/O 瓶頸** 3. **傳統 Attention 機制的運算困境:數值穩定與反覆讀寫** 4. **Flash Attention 的核心數學重構:動態修正與單趟掃描** 5. **極致記憶體最佳化:跳過注意力權重矩陣 ($\hat{A}$) 的計算** 6. **實驗和物理極限** --- ### 1. LLM推論加速的挑戰與 Flash Attention 之定位 在探討 LLM 的推論加速時,往往需要面臨效能與準確度之間的取捨。許多加速技術通常伴隨著代價,例如採用近似演算法 導致計算結果出現偏差,或是必須對特定模型進行客製化綁定與重新訓練。但是發表於 2022 年的 Flash Attention 突破了這個限制,它是一項隨插即用的技術,而且不會改變 Self-Attention 的計算結果,能在不犧牲任何精確度的前提下加速。 論文連結: https://arxiv.org/abs/2205.14135 ### 2. GPU 記憶體階層與 I/O 瓶頸 Flash Attention 考慮到 GPU 運算的底層邏輯。在 GPU 架構中,運算單元具備很快的處理速度,但設置的「工作台」——晶片內建的 SRAM 容量非常小。 而且語言模型推論時所需的大量資料 (如 Query, Key, Value 矩陣) 只能存放在容量龐大但讀寫速度較慢的 HBM (High Bandwidth Memory) 中。所以我們發現運算的真正瓶頸並非算力不足,而是頻繁將資料於 HBM 與 SRAM 之間來回搬運所產生的 I/O 延遲。 ![image](https://hackmd.io/_uploads/B1gDkSriWl.png) ### 3. 傳統 Attention 機制的運算困境:數值穩定與反覆讀寫 在傳統的 Self-Attention 運算中,為了防止 Softmax 函數在進行指數運算 ($e^x$) 時發生數值溢位 (Overflow),必須找出整排數值中的最大值 ($A_{max}$),並將所有數值減去該最大值。 由於 SRAM 無法一次容納與序列長度 ($L$) 呈正相關的龐大矩陣,系統必須將資料切分成多個區塊 (Chunk) 分批載入。這導致了傳統演算法必須進行多次的 HBM 讀寫循環: 1. **第一趟讀取**:掃描所有區塊以找出全域最大值 $A_{max}$。 ![image](https://hackmd.io/_uploads/BksK1Hri-l.png) 3. **第二趟讀取**:再次載入資料,計算每個數值減去 $A_{max}$ 後的指數總和 (即 Softmax 的分母)。 ![image](https://hackmd.io/_uploads/H1-sJrSj-x.png) 5. **第三趟讀取**:將算好的標準化注意力權重與 Value ($V$) 矩陣進行乘加運算 (Weighted Sum),得出最終輸出。 這種反覆讀取與寫入嚴重拖慢了推論效能。 ![image](https://hackmd.io/_uploads/ByL21HHj-x.png) ### 4. Flash Attention 的核心數學重構:動態修正與單趟掃描 Flash Attention 透過數學等價轉換,將多次讀寫簡化。其核心概念為「將錯就錯,動態修正」: 當系統讀入第一個 Chunk 時,會先將該區塊的區域最大值 $d_1$ 假定為全域最大值,並計算出一個暫時的指數總和 $S_1$。當讀入下一個 Chunk 且發現更大的數值 $d_2$ 時,演算法無需回頭重新讀取舊資料,而是直接對舊有的總和 $S_1$ 乘上一個修正項 $e^{d_1 - d_2}$。 透過這個一巧妙的數學補償機制,舊有的計算結果便能瞬間轉換為以 $d_2$ 為基準的數值,從而在單向掃描資料的過程中,同時確保最大值與分母總和的正確性。 ![image](https://hackmd.io/_uploads/BkCp1HSsbx.png) ### 5. 極致記憶體最佳化:跳過注意力權重矩陣 ($\hat{A}$) 計算 Flash Attention 最具突破性的設計,在於它選擇直接跳過計算出完整的注意力權重矩陣 ($\hat{A}$)。 演算法會在處理每個 Chunk 時,利用前面提到的動態修正技巧,一起對中間的輸出結果 $O_1$ 進行數學修正:將舊的輸出乘上 $S_1/S_2$ 以及指數差 $e^{d_1 - d_2}$ 來彌平誤差,隨即疊加新 Chunk 的計算結果,太聰明了吧。 系統不斷在更新並累加最終的輸出 $O$,在整個計算過程中,那張龐大的 Attention Matrix 從來在記憶體中被真正建立。所以在實作中試圖讀取 Attention Weights,系統將會報錯,因為這個矩陣根本不存在。 ![image](https://hackmd.io/_uploads/ByjceBHiWe.png) ### 6. 實驗與物理極限 李宏毅老師在實驗中,比較 Flash Attention 計算出的數值與傳統方法之間的差異極小 (約為 $10^{-7}$ ),證實了無損精確度的特性。在執行時間上,當輸入序列長度達到 4096 時,Flash Attention 可提供高達約 9 倍的加速效能。 目前這個技術已成為像是 PyTorch 中 `scaled_dot_product_attention` 模組的預設選項。 但這技術在輸入很短的狀況下,由於推論時間多耗費在 Embedding 轉換等其他運算上,Flash Attention 帶來的加速效益並不明顯。這個外就算 Flash Attention 解決了 SRAM 的工作瓶頸,當面對超長的輸入,龐大的快取仍會撐爆 HBM 倉庫的容量極限,引發 Out of Memory (OOM) 錯誤。 ![image](https://hackmd.io/_uploads/H13BZHHo-l.png) ![image](https://hackmd.io/_uploads/rysGZHHjZg.png) ![image](https://hackmd.io/_uploads/B11VbBHoWx.png) ### 加快語言模型生成速度 (2/2):KV Cache 大綱 1. **語言模型推論的兩階段和 KV Cache ** - Prefill 與 Decode 階段之差異 - 避免冗餘運算:儲存 Key (K) 與 Value (V) 的必要性 2. **KV Cache 引發的 OOM (Out of Memory) 危機** - 記憶體消耗的數學實證(以 LLaMA-2 27B 為例) 3. **架構層級的記憶體妥協:從 MHA、MQA 到 GQA** - 多頭注意力 (MHA) 的空間代價 - 共享機制:Multi-Query Attention 與 Group-Query Attention 的權衡 4. **極致的空間壓縮演算法:Multi-Head Latent Attention (MLA)** - 潛在空間 (Latent Space) 的降維投影 - 數學等價轉換:無須解壓縮的注意力計算 5. **Sliding Window 與 Streaming LLM** - 滑動視窗注意力 (Sliding Window Attention) 的侷限 - 注意力沉澱 (Attention Sink) 現象與 Streaming LLM 的解法 6. **KV Cache Pruning** - 稀疏注意力分佈的觀察 (如 Scissorhands, H2O) - 捨棄邊緣記憶以換取空間 7. **Prompt Caching (前綴匹配機制)** - 系統提示詞 (System Prompt) 工程學:靜態置頂、動態置底 - 降低 API 呼叫成本的實證 8. **總結:推論加速技術的權衡** --- ### 1. 語言模型推論的兩階段與 KV Cache LLM 的文字生成過程,在系統架構上可以分成為兩個階段:首先是處理人類輸入 Prompt 的 **Prefill 階段**,這個階段會平行運算所有 Token 的 Query (Q)、Key (K) 與 Value (V);接著進入逐字生成的 **Decode 階段**。 在 Decode 階段,每生成一個新 Token,傳統作法需要與歷史序列重新計算注意力權重,這會引發很大的冗餘運算。所以我們需要**KV Cache** ,系統只需要幫新生成的 Token 計算自身的 Q、K、V,再把 Q 與過去「已經存下來的 K」進行運算,再對「已經存下來的 V」進行加權總和。在這個機制中,只有 K 與 V 需要被寫入記憶體留存,因為 Q 的任務在完成當下的注意力配對後便已結束,不需被儲存。 ### 2. 硬體記憶體瓶頸:KV Cache 引發的 OOM 危機 雖然 KV Cache 大幅減少了算力消耗,但卻對 GPU 的 HBM 造成了極大的儲存壓力。 以 LLaMA-2 27B 網路架構為例,若模型具有 46 層網路、32 個 Attention Head,且每個 Head 維度為 128,在使用 FP16 精度 (每個數值佔 2 Bytes) 的情況下,每生成一個 Token,系統需消耗的 KV Cache 容量約為 0.72 MB ($46 \times 32 \times 128 \times 2 \times 2$)。若在配備 80GB 記憶體的 A100 GPU 上運行,其實際僅能容納約 11.4 萬個 Token。面對當前動輒數十萬 Token 的長文本需求,龐大的 KV Cache 會輕易撐爆硬體,導致 Out of Memory (OOM) 錯誤。 ![image](https://hackmd.io/_uploads/HkYnGBBsZg.png) ### 3. 架構層級的記憶體妥協:從 MHA、MQA 到 GQA 為了降低 KV Cache 的體積,需要對傳統的多頭注意力機制 (Multi-Head Attention, MHA) 進行了架構上的調整。 * **Multi-Query Attention (MQA)**:強制所有的 Query 共用「同一組」K 與 V,以極大化壓縮記憶體空間,但這個會嚴重限縮模型的多樣性表達,導致效能下降。 * **Group-Query Attention (GQA)**:作為 MHA 與 MQA 的折衷方案,將多個 Query 劃分為一組,同組內的 Query 共用一組 K 與 V。這個方法在縮減快取體積的同時,維持了與 MHA 相當的推論品質,目前已廣泛應用於 LLaMA 等主流模型中。 ![image](https://hackmd.io/_uploads/rJZU7BSsWg.png) ### 4. 極致的空間壓縮演算法:Multi-Head Latent Attention (MLA) DeepSeek 等模型採用了更為精妙的降維壓縮技術——MLA。這個技術透過 Bottleneck Layer ,將龐大的 K 與 V 壓縮成一個極低維度的潛在向量 $C$ 存入 KV Cache。 但 MLA 最具突破性的設計在於其**無須解壓縮**的數學特性。在傳統直覺中,若要在注意力計算時還原特徵,必須耗費龐大算力將 $C$ 解壓縮回高維矩陣。但是透過線性代數的結合律轉換,系統可將解壓縮矩陣 ($W_k$) 的轉置操作直接與新輸入的 Query 相乘 ($Q \cdot W_k^T$),隨後再與壓縮狀態的 $C$ 進行內積。同理,在 Value 的加權總和階段,也可直接在潛在空間 ($C$) 中進行運算,最後僅需進行一次解壓縮即可得出結果。用最少的空間與算力完成等價的 Attention 運算。 ![image](https://hackmd.io/_uploads/SyvnQHribl.png) ![image](https://hackmd.io/_uploads/B1R27rrs-x.png) ![image](https://hackmd.io/_uploads/rk107rBjbx.png) ![image](https://hackmd.io/_uploads/r13C7BSsWx.png) ### 5.Sliding Window 與 Streaming LLM * **Sliding Window Attention**:強制模型僅對距離最近的固定數量 Token (例如 4096 個) 進行注意力計算,確保 KV Cache 的大小有物理上限。然而,當序列長度超出訓練時的視窗大小時,模型的困惑度 (Perplexity) 會瞬間飆高,導致生成品質崩壞。 ![image](https://hackmd.io/_uploads/rJB7NHBiWg.png) * **Streaming LLM**:研究發現,模型在運算 Attention 時,若沒有特別需要關注的內容,會將注意力權重預設傾注於「序列中最開頭的第一個 Token」( Attention Sink)。如果 Sliding Window 無意間丟棄了第一個 Token,模型的運作邏輯便會崩潰。所以 Streaming LLM 主張在滑動視窗之外,永遠強制保留最開頭的數個 Token,這個無須重新訓練模型,也可以讓模型穩定處理長文本輸入。 ![image](https://hackmd.io/_uploads/SkfV4Brsbg.png) ### 6. 動態記憶體修剪技術:KV Cache Pruning 由於多數序列中的 Token 在後續生成中極少被再次關注,學界提出了 KV Pruning 技術 (如 Scissorhands 與 H2O)。透過分析注意力分佈,研究證實多數 Token 的 K 與 V 只是佔用空間的無效資訊。藉由演算法動態剔除這些未被關注的歷史紀錄,系統甚至能在捨棄高達 80% K 與 V 的極端情況下 (壓縮 5 倍),維持原本的任務表現。 ![image](https://hackmd.io/_uploads/ryjwVrSiWx.png) ### 7. 跨對話的商業實務:Prompt Caching (前綴匹配機制) KV Cache 不僅應用於單一生成任務,更可擴展至**跨對話 (Cross-session)** 的快取共用。 當兩個不同的對話擁有完全相同的「前綴 (Prefix)」時,它們所對應算出的 K 與 V 也是完全一致的,可直接轉移使用。然而,一旦前綴有任何微小差異 (如更換單詞),後續的 KV Cache 就會被破壞且無法共用。 ![image](https://hackmd.io/_uploads/rJ0_VSBobx.png) ![image](https://hackmd.io/_uploads/Sknt4rroZe.png) 在使用 AI Agent 時,**系統提示詞 (System Prompt) 的排列工程**很重要:可以把絕對靜態的內容 (如角色設定、可用工具清單) 放置於 Prompt 最前端,而將動態變數 (如當下時間、特定地點查詢) 放在末端。這個策略可最大化命中快取 (Cache Hit)。 ![image](https://hackmd.io/_uploads/S1AjEHHjZe.png) ### 8. 總結:推論加速技術的代價光譜 語言模型推論加速是一場空間、時間與模型準確度的權衡: * **Flash Attention**:無損而且不用重頭訓練,透過優化 GPU I/O 換取時間。 * **KV Cache**:無損且無須重頭訓練,透過佔用龐大記憶體空間以換取時間。 * **MQA / GQA / MLA**:架構層級變更,會改變原始 Attention 機制並需要從頭訓練模型,來換取 KV Cache 的空間縮減。 * **Sliding Window / Streaming LLM / KV Pruning**:改變了注意力計算範圍或捨棄部分記憶,部分方法可以不用重頭訓練直接應用,但極端情況下可能損害模型表現。 * **Speculative Decoding (投機解碼)**:使用小模型輔助大模型生成,無損而且不用重新訓練大模型,但代價是需消耗執行小模型的額外算力。 ![image](https://hackmd.io/_uploads/BJmbSSHsbl.png)