# FlashAttention-2 : Faster Attention with Better Parallelism and Work Partitioning 論文 : [FlashAttention-2](https://arxiv.org/abs/2205.14135) ## abstract ### 核心問題:Transformer 長序列處理的瓶頸 * **痛點與潛力:** * **問題:** 將 **Transformer** 擴展到更長的序列長度是一個主要問題。 * **潛力:** 解決此問題可望提升語言模型、高解析度圖像理解的性能,並開啟程式碼、音訊和影片生成等新應用。 * **瓶頸根源:** * **元兇:** **注意力層 (Attention Layer)**。 * **具體影響:** 其運行時間(runtime)和記憶體(memory)消耗會隨著序列長度呈 **二次方 ($O(N^2)$)** 增長。 ## 現有解決方案及其局限性 ### 1\. 現有工作:FlashAttention | 方面 | 描述 | 效能表現 | 局限性/待解決問題 | | :--- | :--- | :--- | :--- | | **方法** | 利用 **GPU 記憶體階層的不對稱性**。 | 1. 顯著的記憶體節省(**線性** $O(N)$ 而非二次方 $O(N^2)$)。<br>2. 運行時間加速(比優化過的基線快 **2-4倍**)。<br>3. **無近似** (no approximation)。 | **效能低落:** 僅達到理論最大 **FLOPs/s 的 25-40%**,遠不及優化後的矩陣乘法 (GEMM) 運算效率。 | ### 2\. 效率低落的原因 * **根本原因:** GPU 上不同 **執行緒區塊 (thread blocks)** 和 **Warp** 之間的工作劃分不夠理想。 * **導致結果:** * **低佔用率 (low-occupancy)**。 * **不必要的共享記憶體讀/寫 (unnecessary shared memory reads/writes)**。 ## 提出的新方法:FlashAttention-2 | 方面 | 描述 | | :--- | :--- | | **目標** | 解決 FlashAttention 的效率問題,提高 FLOPs/s 利用率。 | | **核心策略** | 更好的 **工作劃分 (work partitioning)**。 | ### 關鍵改進點 (3項具體優化) | 序號 | 優化項目 | 目的/效果 | | :--- | :--- | :--- | | (1) | **演算法微調 (Tweak the algorithm)** | 減少 **非矩陣乘法 (non-matmul) 的 FLOPs 數量**。 | | (2) | **跨執行緒區塊平行化 (Parallelize across thread blocks)** | 即使是單個注意力頭 (single head) 的計算,也分散到不同的執行緒區塊上,以 **增加佔用率 (increase occupancy)**。 | | (3) | **執行緒區塊內的 Warp 工作分配** | 在每個執行緒區塊內,合理分配 Warp 的工作,以 **減少透過共享記憶體的通訊 (reduce communication)**。 | ### 效能成果 * **加速比:** 相較於 FlashAttention 帶來約 **2倍** 的加速。 * **理論峰值:** 在 A100 上達到理論最大 **FLOPs/s 的 50-73%**,逼近 GEMM 運算的效率。 * **端到端訓練:** 用於訓練 GPT 類模型時,達到每個 A100 GPU **225 TFLOPs/s** 的訓練速度(相當於 **72% 的模型 FLOPs 利用率**)。 ## Introduction ### 1. 研究背景與核心問題 * **基礎模型:** **Transformer** ($\sim$2017) * **挑戰:** 擴展 Transformer 的**上下文長度 (context length)**。 * **瓶頸:** **注意力層 (Attention Layer)** 的運行時間和記憶體需求與輸入序列長度呈**二次方 ($O(N^2)$)** 關係。 * **理想目標:** 超越標準的 2k 序列長度限制,訓練模型以理解書籍、高解析度圖像和長篇影片。 * **趨勢證明:** 近期大型語言模型(LLMs)的上下文長度顯著增長: * **GPT-4:** 32k * **MPT:** 65k * **Claude:** 100k * **應用需求:** 長文件查詢、故事寫作等新興用例需要長上下文模型。 ### 2. 現有解決方案與 (FlashAttention) * **近似方法 (Approximation Methods):** * **目標:** 減少長上下文注意力計算的需求。 * **例子:** Reformer, Linformer, Longformer, BigBird 等。 * **現狀:** 儘管存在,但**大多數大規模訓練仍使用標準注意力**。 * **優化方法:\sysnameone (FlashAttention)** * **作者:** \citet{dao2022flashattention} * **核心思想:** 重排注意力計算,並利用**平鋪 (tiling)** 和**重新計算 (recomputation)** 的經典技術。 * **優勢:** 1. 記憶體使用從二次方降至**線性 ($O(N)$)**。 2. 運行時間比優化後的基線快 **2-4倍** (wall-clock time)。 3. 記憶體節省達 **10-20倍**。 4. **無近似 (no approximation)**。 * **結果:** \sysnameone 已在 Transformer 的大規模訓練和推理中獲得廣泛採用。 ### 3. FlashAttention 的局限性與效率瓶頸 * **問題:** 儘管 FlashAttention 已經很快,但**效率仍遠不如矩陣乘法 (GEMM)** 等其他基本運算。 * **具體數據:** * **FlashAttention 順向傳播 (Forward Pass):** 僅達到理論最大 **FLOPs/s 的 30-50\%**。 * **FlashAttention 逆向傳播 (Backward Pass):** 僅達到理論最大 **FLOPs/s 的 25-35\%** (A100 GPU)。 * **對比 GEMM:** 優化後的 GEMM 可達到理論最大設備吞吐量的 **80-90\%**。 * **根本原因:** **次優的工作劃分 (suboptimal work partitioning)** * 在 GPU 上,不同**執行緒區塊 (thread blocks)** 和 **Warp** 之間的工作劃分不佳。 * 導致:**低佔用率 (low-occupancy)** 或**不必要的共享記憶體讀/寫**。 ### 4. 提出的新方法:FlashAttention2 (貢獻點) 新方法 **FlashAttention2** 建立在 FlashAttention 的基礎上,通過更好的**並行化 (parallelism)** 和**工作劃分 (work partitioning)** 來解決效率挑戰。 | 序號 | 貢獻點/優化內容 | 目的/原因 | 實作位置 | | :--- | :--- | :--- | :--- | | **1.** | **微調演算法以減少非 Matmul FLOPs** | **目的:** 減少不依賴矩陣乘法單元的運算量。**原因:** GPU 上的 Matmul 吞吐量比非 Matmul 吞吐量高達 **16倍**,應盡量將時間花在 Matmul 運算上。 | \cref{subsec:algo} | | **2.** | **沿序列長度維度平行化** | **目的:** 增加 GPU 資源的**佔用率 (occupancy)**。**應用情境:** 當序列很長(通常批次大小 batch size 較小時),可提高資源利用。 | (方法論部分) | | **3.** | **執行緒區塊內 Warp 的工作劃分** | **目的:** 減少**通訊 (communication)** 和**共享記憶體讀/寫**。 | (方法論部分) | ### 5. 實驗結果 * **核心成果:** 相較於 FlashAttention 實現了顯著加速。 * **加速比:** 相較於 FlashAttention,實現約 **2倍** 的加速。 * **理論峰值:** * 順向傳播:高達理論最大吞吐量的 **73\%**。 * 逆向傳播:高達理論最大吞吐量的 **63\%**。 * **端到端訓練:** 訓練 GPT 類模型時,每個 A100 GPU 達到 **225 TFLOPs/s** 的訓練速度。 ## Background ### 2.1 硬體特性 (Hardware characteristics) #### GPU 性能特徵 * **結構:** 由計算元件(如浮點運算單元)和**記憶體階層 (memory hierarchy)** 組成。 * **專用單元:** 現代 GPU 包含用於加速低精度矩陣乘法(如 FP16/BF16)的**專用單元**(例如:Nvidia GPU 上的 **Tensor Cores**)。 * **記憶體階層:** * **高頻寬記憶體 (HBM, High Bandwidth Memory):** 頻寬較低,容量大。 * *A100 範例:* 40-80GB, 頻寬 $\sim 1.5-2.0 \text{TB/s}$。 * **片上 SRAM (On-chip SRAM) / 共享記憶體 (Shared Memory):** 頻寬高,容量小 * *A100 範例:* 每個 **Streaming Multiprocessor (SM)** 有 192KB,頻寬估計 $\sim 19 \text{TB/s}$。 #### 執行模型 (Execution Model) * **執行單位:** * **Kernel:** 執行的操作。 * **Threads (執行緒):** 數量龐大。 * **組織結構:** * **Thread Blocks (執行緒區塊):** 執行緒被組織成區塊,分配給 SM 執行。 * **Warps:** 每個 Thread Block 內,執行緒被分組為 Warps(**32 個執行緒**一組)。 * **通訊機制:** * **Warp 內部:** 可透過快速的 **Shuffle Instructions** 或合作進行矩陣乘法來通訊。 * **Thread Block 內部(Warps 之間):** 透過讀/寫**共享記憶體 (Shared Memory)** 來通訊。 * **計算流程:** Kernel 將輸入從 HBM 載入到**暫存器 (registers)** 和 SRAM,進行計算,然後將輸出寫回 HBM。 ### 2.2 標準注意力實作 (Standard Attention Implementation) #### 順向傳播 (Forward Pass) * **輸入:** 查詢 $Q$、鍵 $K$、值 $V$ (皆為 $\mathbb{R}^{N \times d}$)。 * $N$: 序列長度 (Sequence Length)。 * $d$: 頭維度 (Head Dimension)。 * **計算步驟 (三個主要步驟):** 1. **相似度計算:** $S = Q K^\top \in \mathbb{R}^{N \times N}$ 2. **機率計算:** $P = softmax(S) \in \mathbb{R}^{N \times N}$ (逐行 $softmax$) 3. **輸出計算:** $O = P V \in \mathbb{R}^{N \times d}$ * *備註:* MHA (Multi-Head Attention) 是在 Batch 和 Heads 維度上平行執行此計算。 #### 逆向傳播 (Backward Pass) * **輸入:** $O$ 的梯度 $dO \in \mathbb{R}^{N \times d}$。 * **梯度計算 (Chain Rule):** $$dV = P^\top dO$$ $$dP = dO V^\top$$ $$dS = dsoftmax (dP)$$ $$dQ = dS K$$ $$dK = Q dS^\top$$ * 其中 $dsoftmax$ 是 $softmax$ 的梯度計算,其計算涉及到 $P$ 和 $dP$。 #### 實作局限性 (Standard Implementation Bottlenecks) * **記憶體問題:** 矩陣 $S$ 和 $P$ 被**實體化 (materialize)** 到 **HBM**。 * **記憶體需求:** $O(N^2)$ (二次方)。 * **瓶頸:** 需要儲存 $O(N^2)$ 的 $P$ 矩陣用於逆向傳播。 * **序列長度比較:** 通常 $N \gg d$ (例如 $N \sim 1\text{k-8k}$,而 $d \sim 64\text{-}128$),使 $N^2$ 成為主要負擔。 * **速度問題:** 標準實作通常涉及三個分開的步驟,且中間結果 $(S, P)$ 寫入 HBM。 1. 呼叫 GEMM ($Q K^\top$),結果寫入 HBM。 2. 從 HBM 載入 $S$,計算 $softmax$,結果 $P$ 寫入 HBM。 3. 呼叫 GEMM ($P V$)。 * **瓶頸:** 大量的記憶體存取導致**記憶體頻寬受限 (memory bandwidth bounded)**,使得 wall-clock time 變慢。 ### FlashAttention ![image](https://hackmd.io/_uploads/S15S7AsAeg.png) ### 核心思想:減少記憶體 I/O,保持輸出不變 * **目標:** 在硬體加速器(如 GPU)上加速注意力計算,通過減少**記憶體讀/寫次數 (memory reads/writes)** 來實現,同時**不使用近似 (without approximation)**,保持輸出相同。 * **關鍵技術:** 1. **分塊平鋪 (Tiling):** 將計算分解成塊。 2. **在線 Softmax (Online Softmax):** 允許按塊計算注意力並在最後進行校正。 --- ### 順向傳播 (Forward Pass) #### 1. 分塊與計算流程 \sysnameone 利用平鋪技術實現流程: 1. **載入:** 將輸入的塊 (blocks of inputs) 從 **HBM** 載入到 **SRAM**。 2. **計算:** 針對該塊計算注意力。 3. **更新:** **更新輸出 (output)**,而**不將大的中間矩陣 $S$ 和 $P$ 寫入 HBM**。 #### 2. 在線 Softmax 機制 * **必要性:** 標準 Softmax 涉及整行或整塊的耦合計算。**在線 Softmax** 允許將注意力計算分割成塊,並通過**重新縮放 (rescale)** 每一塊的輸出,最終得到正確的結果(無近似)。 * **Online Softmax 數學原理:** * 對於包含 $S^{(1)}$ 和 $S^{(2)}$ 兩個塊的一行注意力分數: * **標準 Softmax** 需要計算全局的最大值 $m$ 和歸一化因子 $\ell$。 * **在線 Softmax** 則採取迭代更新: 1. 先計算第一個塊的局部最大值 $m^{(1)}$ 和歸一化因子 $\ell^{(1)}$,並計算局部輸出 $O^{(1)}$。 2. 計算第二個塊時,更新全局最大值 $m^{(2)} = \max(m^{(1)}, \mathrm{rowmax}(S^{(2)}))$。 3. 使用 $m^{(2)}$ 和 $\ell^{(1)}$ 更新全局歸一化因子 $\ell^{(2)}$。 4. 利用 $\ell^{(2)}$ 重新縮放 $O^{(1)}$,並加上第二塊的輸出,得到最終輸出 $O$。 * **圖示:** \cref{fig:flash_attention_diagram} 說明了通過對 $K$ 和 $V$ 的分塊,實現對每個塊計算注意力並重新縮放輸出,從而避免昂貴的 $S$ 和 $P$ 記憶體讀寫。 #### 3. 效能提升 * **加速比:** 相較於優化後的基線注意力實作,實現 **2-4倍** 的 wall-clock time 加速。 --- ### 逆向傳播 (Backward Pass) #### 1. 記憶體節省與重新計算 * **策略:** 通過**重新計算 (re-computing)** 注意力矩陣 $S$ 和 $P$ 的值。 * **時機:** 當輸入塊 $Q, K, V$ 已載入到 SRAM 時,即時重新計算所需的中間值。 * **效果:** 避免了將 $O(N^2)$ 的 $S$ 和 $P$ 矩陣儲存到 HBM,實現: * **記憶體節省:** 視序列長度,可達 **10-20倍**。 * **記憶體複雜度:** 從二次方 $O(N^2)$ 降至**線性 $O(N)$**。 #### 2. 複雜性 * **計算量:** 逆向傳播概念上較簡單(無 Softmax 重新縮放),但實作複雜得多。 * **原因:** 逆向傳播需要執行 **5 個矩陣乘法**($dV, dP, dQ, dK$ 的計算),而順向傳播只有 2 個,因此需要**在 SRAM 中保留更多數值**。 ## FlashAttention2: Algorithm, Parallelism, and Work Partitioning ![image](https://hackmd.io/_uploads/HyDjFAiAxe.png) ### 概述 * **目標:** 實現比 FlashAttention 更高的效率,特別是提高 FLOPs/s 利用率。 * **主要改進:** 1. **演算法調整:** 減少非矩陣乘法 (non-matmul) FLOPs。 2. **平行化:** 利用不同的執行緒區塊 (thread blocks) 實現全 GPU 資源利用。 3. **工作劃分:** 在單一執行緒區塊內,劃分 Warp 的工作,減少共享記憶體存取。 * **預期成果:** 實驗結果顯示可實現 **2-3 倍** 的加速。 --- ### 3.1 演算法 (Algorithm) * **動機:** 現代 GPU 對矩陣乘法有專門的計算單元(如 Nvidia 的 **Tensor Cores**),使其速度遠超非 Matmul 運算。 * **效率差距:** 在 A100 上,FP16/BF16 Matmul 的理論最大吞吐量是 **$312 \text{TFLOPs/s}$**,而非 Matmul FP32 僅 **$19.5 \text{TFLOPs/s}$**。 * **成本比:** 每個非 Matmul FLOP 的成本比 Matmul FLOP **高 16 倍**。 * **核心策略:** 盡可能將計算時間用於 Matmul FLOPs,以維持高吞吐量。 #### 3.1.1 順向傳播 (Forward Pass) 的兩處微調 \sysname 在 \sysnameone 的基礎上,對 **在線 Softmax (Online Softmax)** 技巧進行了兩處微調: 1. **延遲輸出縮放 (Delaying Output Scaling):** * **FlashAttention 方式:** 在每次迭代中,對輸出更新的兩個項都進行 $diag(\ell^{(2)})^{-1}$ 的縮放。 * **FlashAttention2 方式 (優化):** 維護一個 **「未縮放」** 的輸出版本 $\tilde{O}^{(j)}$,並只在**循環結束時**,用最終的 $diag(\ell^{(\text{last})})^{-1}$ 因子進行一次最終縮放。 * **好處:** 減少了循環內的非 Matmul 運算(即對 $O^{(1)}$ 的重新縮放操作 $diag(e^{m^{(1)} - m^{(2)}})^{-1}$)。 2. **僅儲存 Logsumexp (Storing only Logsumexp):** * **FlashAttention 方式:** 需要儲存 $\max (m^{(j)})$ 和 $\sum \exp (\ell^{(j)})$ 兩組統計量用於逆向傳播。 * **FlashAttention2 方式 (優化):** 僅儲存 **Logsumexp** $L^{(j)} = m^{(j)} + \log(\ell^{(j)})$。 * **好處:** 減少了需要儲存和處理的統計量數量,從而減少了非 Matmul FLOPs。 * **更新後的 Online Softmax 數學表達:** 論文中提供了 2 塊情況下的數學推導,確認了保持 $\tilde{O}^{(2)}$ 未縮放的計算,與最終縮放的結果 $O^{(2)}$ 等效。 #### 3.1.2 FlashAttention2 順向傳播演算法 ![image](https://hackmd.io/_uploads/BJqYXk2Cle.png) 1. **初始化:** 將 $Q$ 分成 $T_r$ 塊,$K, V$ 分成 $T_c$ 塊。初始化 $O$ 和 Logsumexp $L$ 的塊。 2. **外部循環(Row Block):** 迭代 $Q$ 的塊 $Q_i$ ($i = 1$ to $T_r$)。 * **載入與初始化:** 載入 $Q_i$ 到 SRAM。初始化**未縮放輸出 $O_{i}^{(0)}$**、**歸一化因子 $\ell_{i}^{(0)}$** 和 **最大值 $m_{i}^{(0)}$**。 * **內部循環(Column Block):** 迭代 $K, V$ 的塊 $K_j, V_j$ ($j = 1$ to $T_c$)。 * **載入與計算:** 載入 $K_j, V_j$。計算注意力分數 $S_{i}^{(j)} = Q_i K_j^T$。 * **統計量更新 (\ref{alg:stream_attn_statistics}):** 更新 $m_{i}^{(j)}$、計算 $\tilde{P}_{i}^{(j)} = \exp(S_{i}^{(j)} - m_{i}^{(j)})$,並更新 $\ell_{i}^{(j)}$。 * **輸出更新 (\ref{alg:stream_attn_update}):** 使用簡化的更新公式 $O_{i}^{(j)} = diag(e^{m_{i}^{(j-1)} - m_{i}^{(j)}})^{-1} O_{i}^{(j-1)} + \tilde{P}_{i}^{(j)} V_j$ 來維護未縮放的 $O_{i}^{(j)}$。 * **最終縮放與寫回:** 在內部循環結束後,對最終的 $O_{i}^{(T_c)}$ 進行一次性縮放 $O_{i} = diag(\ell_{i}^{(T_c)})^{-1} O_{i}^{(T_c)}$,並計算 Logsumexp $L_{i}$。將 $O_{i}$ 和 $L_{i}$ 寫回 HBM。 #### 3.1.3 因果遮罩 (Causal Masking) * **應用場景:** 自迴歸語言模型(Auto-regressive language modeling)。 * **優化點:** 1. 對於**完全處於因果遮罩區**的塊(行索引 $i$ 的所有列索引 $j$ 都滿足 $j > i$),可以**跳過計算**,獲得約 **$1.7-1.8$ 倍**的加速。 2. 對於其他塊,只需要在**與對角線相交**的那一個塊上應用因果遮罩,減少非必要的遮罩操作。 #### 3.1.4 正確性、運行時間和記憶體 * **結果:** 保持了與 FlashAttention 相同的特性: * **正確性:** 返回正確的輸出 $O$(無近似)。 * **運行時間:** $O(N^2 d)$ FLOPs。 * **記憶體需求:** $O(N)$ 額外記憶體(用於儲存 Logsumexp $L$)。 ### 3.1.5 逆向傳播 (Backward Pass) ![image](https://hackmd.io/_uploads/H1LkoRiRxx.png) * **核心相似性:** FlashAttention 的逆向傳播與 FlashAttention2 **大致相同**。 * **演算法調整 (Tweak):** 唯一的微調是:**只使用逐行 Logsumexp $L$**,而不是同時使用 Softmax 的逐行最大值 $m$ 和逐行指數和 $\ell$。 * **意義:** 這與順向傳播的優化 (2) 相呼應,進一步減少了需要在 HBM 中讀取和處理的統計量數量(非 Matmul FLOPs 減少)。 #### FlashAttention2 逆向傳播演算法 (Algorithm \ref{alg:flash_bwd}) 該演算法利用分塊和重新計算來高效地計算 $dQ, dK, dV$ 的梯度。 1. **分塊與初始化:** * 將所有輸入和輸出的矩陣($Q, K, V, O, dO$)和 Logsumexp $L$ 分成 $T_r$ 或 $T_c$ 塊。 * 初始化梯度 $dQ, dK, dV$ 的塊($dQ$ 初始化在 HBM,$dK_j, dV_j$ 在 SRAM)。 * **中間計算:** 首先計算 $D = \mathrm{rowsum}(dO \circ O)$(梯度調節因子的一部分),並寫回 HBM。 2. **外部循環(Column Block):** 迭代 $K, V$ 的塊 $K_j, V_j$ ($j = 1$ to $T_c$)。 * **載入:** 載入 $K_j, V_j$ 到 SRAM。 * **內部循環(Row Block):** 迭代 $Q, O,dO, L$ 的塊 ($i = 1$ to $T_r$)。 * **載入:** 載入 $Q_i, O_i, dO_i, L_i, D_i$ 等到 SRAM。 * **重新計算 $S, P$:** 在 SRAM 上計算 $S_{i}^{(j)} = Q_i K_j^T$ 和 $P_{i}^{(j)} = \exp(S_{ij} - L_{i})$。 * **梯度計算與累積 (5個矩陣乘法/更新):** * **$dV$ 更新:** $dV_j \leftarrow dV_j + (P_{i}^{(j)})^\top dO_i$ * **$dP$ 計算:** $dP_{i}^{(j)} = dO_{i} V_j^\top$ * **$dS$ 計算:** $dS_{i}^{(j)} = P_{i}^{(j)} \circ (dP_{i}^{(j)} - D_i)$ (這包含 Softmax 梯度的計算) * **$dQ$ 更新:** $dQ_{i} \leftarrow dQ_i + dS_{i}^{(j)} K_j$(**注意:** $dQ_i$ 是唯一在內部循環中被載入、更新並**寫回 HBM** 的梯度塊) * **$dK$ 更新:** $dK_{j} \leftarrow dK_j + {dS_{i}^{(j)}}^\top Q_i$ * **寫回:** 在外部循環結束時,將 $dK_j, dV_j$ 寫回 HBM。 --- ### 多查詢與分組查詢注意力 (MQA and GQA) * **概念:** 這是注意力的變體,其中多個查詢頭 (Query heads) 共享相同的鍵和值頭 (Key and Value heads)。 * **目的:** 主要用於**推理 (inference)** 期間,以減少 KV 緩存 (KV cache) 的大小。 * **FlashAttention2 的實作:** 透過**隱式地操作頭的索引**來執行相同的計算,而不是實際複製鍵和值頭。 * **逆向傳播需求:** 在逆向傳播中,需要將**隱式複製**的鍵 $dK$ 和值 $dV$ 梯度在不同的頭上**求和 (sum)**。 ### Parallelism ### 核心問題與動機 * **FlashAttention 的平行化:** 僅沿 **批次大小 (Batch Size)** 和 **頭數 (Number of Heads)** 維度進行平行化。 * **執行單元分配:** 每個注意力頭由 **1 個執行緒區塊 (Thread Block)** 處理。總共有 $\text{Batch Size} \times \text{Number of Heads}$ 個 Thread Blocks。 * **效率瓶頸:** * 當 $\text{Batch Size} \times \text{Number of Heads}$ 數量**大**(例如 $\geq 80$)時,GPU 資源(如 A100 的 108 個 SMs)可被有效利用。 * 在**長序列 (Long Sequences)** 的情況下,通常會導致 **小批次大小** 或 **小頭數**。此時 Thread Blocks 數量變少,導致 GPU **多重處理器 (SMs) 的利用率低 (low occupancy)**。 * **FlashAttention2 的目標:** 額外沿**序列長度維度 (Sequence Length dimension)** 增加平行化,以提高在長序列情況下的 SM 利用率和速度。 --- ### 順向傳播 (Forward Pass) 的平行化 * **機會:** 在 \cref{alg:flash2_fwd} 中,**外部循環(沿序列長度維度的行塊 $T_r$ 迭代)** 具有**高度平行性 (embarrassingly parallel)**。 * **實作:** 將這些外部循環的計算(即每個 $Q_i$ 塊的計算)調度給**不同的執行緒區塊**。 * **優勢:** 這些執行緒區塊彼此**無需通訊 (do not need to communicate)**。 * **效果:** 當 Batch Size 和 Heads 數量小時,增加沿序列長度維度的平行化,可有效**提高佔用率 (occupancy)**,從而加速。 * **迴圈順序:** 這種將外部迴圈設為行塊、內部迴圈設為列塊的做法,以及沿序列長度維度平行化的想法,最早由 Phil Tillet 在 Triton 實作中提出。 * **示意圖(\cref{fig:parallelism} 左圖):** 每個工作單元(Thread Block)負責注意力矩陣的**一行塊 (Block of Rows)**。 --- ### 逆向傳播 (Backward Pass) 的平行化 * **挑戰:** 逆向傳播的計算邏輯較為複雜,不同塊之間的**共享計算**主要發生在 $dQ$ 的更新上。 * 在 $dQ$ 更新中,需要將 $dQ_i$ 從 HBM 載入 SRAM,更新 $dQ_{i} \leftarrow dQ_i + dS_{i}^{(j)} K_j$,然後寫回 HBM。 * **實作:** 沿序列長度維度進行平行化。 * **調度:** 每個**列塊 (Column Block)** 的計算(即 \cref{alg:flash_bwd} 中的外部循環 $j$)由 **1 個執行緒區塊**處理。 * **通訊機制:** 由於多個 Thread Blocks 可能需要更新同一個 $dQ_i$,因此使用 **原子加法 (Atomic Adds)** 來協調不同執行緒區塊之間的 $dQ$ 梯度更新。 * **示意圖(\cref{fig:parallelism} 右圖):** 每個工作單元(Thread Block)負責注意力矩陣的**一列塊 (Block of Columns)**。 ### Work Partitioning Between Warps ### 核心問題:FlashAttention 的 Warp 通訊瓶頸 * **背景:** 每個執行緒區塊 (Thread Block) 內,通常使用 4 或 8 個 Warps 來進一步劃分工作。 * **FlashAttention 的方法(Split-K Scheme):** * **劃分方式:** 將 $K$ 和 $V$ **劃分 (Split)** 給 4 個 Warps,而 $Q$ 保持所有 Warps 都可存取。 * **計算流程:** 1. 每個 Warp 計算 $Q K^\top$ 的一個切片 ($S$)。 2. 每個 Warp 用 $S$ 乘上 $V$ 的切片,得到中間結果。 3. **問題:** 所有 Warps 需要將它們的中間結果**寫入共享記憶體 (Shared Memory)**。 4. **問題:** Warps 必須**同步 (synchronize)**,然後將這些中間結果**加總 (add up)** 得到最終的輸出。 * **效率低落原因:** **共享記憶體的讀/寫和同步操作** 減慢了 \sysnameone 的順向傳播速度。 ### FlashAttention2 的優化方法:減少通訊 * **FlashAttention2 的方法 (Split-Q Scheme):** * **劃分方式:** 將 $Q$ **劃分 (Split)** 給 4 個 Warps,而 $K$ 和 $V$ 保持所有 Warps 都可存取(即共享)。 * **計算流程:** 1. 每個 Warp 計算 $Q K^\top$ 的一個切片 ($S_{slice}$)。 2. 每個 Warp 使用**共享**的 $V$ 切片來計算其對應的**輸出切片 (slice of the output)**。 * **效果:** **Warps 之間無需通訊**。 * **結論:** 減少了共享記憶體的讀/寫操作,帶來顯著的速度提升。 ![image](https://hackmd.io/_uploads/HywA1Jh0lg.png) #### 逆向傳播 (Backward Pass) 的工作劃分 * **策略:** 逆向傳播也同樣**避免採用 Split-K 方案**,從而減少共享記憶體的讀寫。 * **複雜性:** 由於涉及更多的輸入和梯度($Q, K, V, O, dO, dQ, dK, dV$)之間更複雜的依賴關係,逆向傳播**仍需要一些同步**。 * **效果:** 儘管複雜,避免 Split-K 仍然有效減少了共享記憶體 I/O,帶來速度提升。 ### 區塊大小調優 (Tuning Block Sizes) * **影響因素:** * **增大區塊大小的好處:** 通常可以**減少共享記憶體的載入/儲存 (loads/stores)**。 * **增大區塊大小的壞處:** 會**增加所需的暫存器 (registers)** 數量和總共享記憶體需求。 * **臨界限制:** 當區塊大小超過臨界值時,可能導致: 1. **暫存器溢出 (Register Spilling):** 造成顯著的減速。 2. **共享記憶體不足:** 所需記憶體超過 GPU 限制,導致 Kernel 無法運行。 * **選擇:** 通常手動選擇 $\{64, 128\} \times \{64, 128\}$ 範圍內的區塊大小,具體取決於頭維度 $d$ 和設備的共享記憶體大小。 * **未來工作:** 論文提到這是一個可以透過**自動調優 (auto-tuning)** 來避免手動勞動的領域。 ## 4 Empirical Validation ### 總體概括 * **評估範圍:** 評估 FlashAttention2 在訓練 Transformer 模型時的影響。 * **主要結果:** * **注意力基準測試:** FlashAttention2 比 FlashAttention 快 **$1.7-3.0$ 倍**。在 A100 上達到高達 **$230 \text{TFLOPs/s}$** (理論峰值的 **73\%**)。 * **端到端訓練速度:** 相較於 FlashAttention 提速高達 **$1.3$ 倍**。在 A100 上達到高達 **$225 \text{TFLOPs/s}$** (模型 FLOPs 利用率的 **72\%**)。 --- ## 4.1 注意力基準測試 (Benchmarking Attention) #### 實驗設定 * **硬體:** A100 80GB SXM4 GPU。 * **變量:** 序列長度 (512 到 16k)、是否帶因果遮罩 (Causal Mask)、頭維度 ($d=64$ 或 $d=128$)。 * **固定參數:** 隱藏維度 $\text{hidden dim} = 2048$,總 $\text{tokens} = 16\text{k}$ (通過調整 Batch Size 實現)。 #### FLOPs 計算公式 * **順向傳播 FLOPs:** $4 \cdot \text{seqlen}^2 \cdot \text{head dimension} \cdot \text{number of heads}$。 * *因果遮罩:* 由於計算量約減半,需將上述數值除以 2。 * **逆向傳播 FLOPs:** 順向傳播 FLOPs $\times 2.5$(因為順向傳播有 2 次 Matmul,逆向傳播有 5 次 Matmul)。 #### 效能對比(A100 GPU) | 對比對象 | 加速比 FlashAttention2 vs. 對象) | 額外說明 | | :--- | :--- | :--- | | **FlashAttention** | **$1.7-3.0 \times$** (約 $2 \times$) | **主要優化目標**。比 xformers (Cutlass 實作) 的 FlashAttention 也快約 $2 \times$。 | | **FlashAttention in Triton** | 順向:$1.3-1.5 \times$;逆向:約 $2 \times$ | 證明 FlashAttention2 的優化策略比 Triton 的實作更有效。 | | **標準 PyTorch 實作** | 高達 $10 \times$ | | | **最高吞吐量 (A100)** | **$230 \text{TFLOPs/s}$** | 達到理論最大吞吐量的 **73\%**。 | #### H100 GPU 基準測試 * **測試方式:** 在 H100 GPU 上運行相同的 FlashAttention2 實作(未利用 H100 的新特性,如 TMA 和第四代 Tensor Cores)。 * **結果:** 達到高達 **$335 \text{TFLOPs/s}$**。 * **未來展望:** 預計利用 H100 的新指令集可實現額外 $1.5 \times - 2 \times$ 的加速。 --- ## 4.2 端到端效能 (End-to-end Performance) ![image](https://hackmd.io/_uploads/BkF_Gy2Cel.png) ![image](https://hackmd.io/_uploads/B1bqMynCeg.png) ![image](https://hackmd.io/_uploads/rkVozk2Cgl.png) ![image](https://hackmd.io/_uploads/H1bnzkhRle.png) ![image](https://hackmd.io/_uploads/Hk00GynCxe.png) #### 實驗設定 * **模型:** GPT 類模型 (1.3B 和 2.7B 參數)。 * **硬體:** $8 \times \text{A100 80GB SXM}$。 * **上下文長度:** $2\text{k}$ 或 $8\text{k}$。 #### 訓練 FLOPs 計算公式 遵循 Megatron-LM 等文獻的公式,計算總訓練 FLOPs (順向+逆向): $$\text{Total FLOPs} = (6 \cdot \text{seqlen} \cdot \text{number of params}) + (12 \cdot \text{number of layers} \cdot \text{hidden dim} \cdot \text{seqlen}^2)$$ * *爭議說明:* 論文中提到,第二項(注意力 FLOPs)在有因果遮罩時可能應該減半,但為了與文獻保持一致,選擇**不減半**。 #### 效能結果 (TFLOPs/s/GPU) | 模型設定 | Without FlashAttention | FlashAttention | FlashAttention2 | FlashAttention2 加速比 (相較 FlashAttention) | FlashAttention2 加速比 (相較 Baseline) | | :--- | :--- | :--- | :--- | :--- | :--- | | GPT3-1.3B 2k | 142 | 189 | 196 | $1.04 \times$ | $1.38 \times$ | | GPT3-1.3B **8k** | 72 | 170 | **220** | **$1.29 \times$** | $3.05 \times$ | | GPT3-2.7B 2k | 149 | 189 | 205 | $1.08 \times$ | $1.38 \times$ | | GPT3-2.7B **8k** | 80 | 175 | **225** | **$1.29 \times$** | **$2.81 \times$** | * **最大速度:** 在 **8k** 上下文長度時達到最大速度 **$225 \text{TFLOPs/s}$**,相較於 baseline(Without FlashAttention)提速高達 **$2.8 \times$**,相較於 FlashAttention 提速約 **$1.3 \times$**。 * **利用率:** 實現了 **72\%** 的模型 FLOPs 利用率。 ## 5 Discussion and Future Directions ### **主要影響與價值** * **成本效益:** **FlashAttention2** 比 **FlashAttention** 快 **$2 \times$**。 * **實際意義:** 可以用與訓練 $8\text{k}$ 上下文模型**相同的成本**,訓練具有 $16\text{k}$ 或更長上下文的模型。 * **應用潛力:** * **理解能力提升:** 用於理解長篇書籍、報告、高解析度圖像、音訊和影片。 * **效率提升:** 加速現有模型的訓練 (Training)、微調 (Finetuning) 和推理 (Inference)。