# FlashAttention-2 : Faster Attention with Better Parallelism and Work Partitioning
論文 : [FlashAttention-2](https://arxiv.org/abs/2205.14135)
## abstract
### 核心問題:Transformer 長序列處理的瓶頸
* **痛點與潛力:**
* **問題:** 將 **Transformer** 擴展到更長的序列長度是一個主要問題。
* **潛力:** 解決此問題可望提升語言模型、高解析度圖像理解的性能,並開啟程式碼、音訊和影片生成等新應用。
* **瓶頸根源:**
* **元兇:** **注意力層 (Attention Layer)**。
* **具體影響:** 其運行時間(runtime)和記憶體(memory)消耗會隨著序列長度呈 **二次方 ($O(N^2)$)** 增長。
## 現有解決方案及其局限性
### 1\. 現有工作:FlashAttention
| 方面 | 描述 | 效能表現 | 局限性/待解決問題 |
| :--- | :--- | :--- | :--- |
| **方法** | 利用 **GPU 記憶體階層的不對稱性**。 | 1. 顯著的記憶體節省(**線性** $O(N)$ 而非二次方 $O(N^2)$)。<br>2. 運行時間加速(比優化過的基線快 **2-4倍**)。<br>3. **無近似** (no approximation)。 | **效能低落:** 僅達到理論最大 **FLOPs/s 的 25-40%**,遠不及優化後的矩陣乘法 (GEMM) 運算效率。 |
### 2\. 效率低落的原因
* **根本原因:** GPU 上不同 **執行緒區塊 (thread blocks)** 和 **Warp** 之間的工作劃分不夠理想。
* **導致結果:**
* **低佔用率 (low-occupancy)**。
* **不必要的共享記憶體讀/寫 (unnecessary shared memory reads/writes)**。
## 提出的新方法:FlashAttention-2
| 方面 | 描述 |
| :--- | :--- |
| **目標** | 解決 FlashAttention 的效率問題,提高 FLOPs/s 利用率。 |
| **核心策略** | 更好的 **工作劃分 (work partitioning)**。 |
### 關鍵改進點 (3項具體優化)
| 序號 | 優化項目 | 目的/效果 |
| :--- | :--- | :--- |
| (1) | **演算法微調 (Tweak the algorithm)** | 減少 **非矩陣乘法 (non-matmul) 的 FLOPs 數量**。 |
| (2) | **跨執行緒區塊平行化 (Parallelize across thread blocks)** | 即使是單個注意力頭 (single head) 的計算,也分散到不同的執行緒區塊上,以 **增加佔用率 (increase occupancy)**。 |
| (3) | **執行緒區塊內的 Warp 工作分配** | 在每個執行緒區塊內,合理分配 Warp 的工作,以 **減少透過共享記憶體的通訊 (reduce communication)**。 |
### 效能成果
* **加速比:** 相較於 FlashAttention 帶來約 **2倍** 的加速。
* **理論峰值:** 在 A100 上達到理論最大 **FLOPs/s 的 50-73%**,逼近 GEMM 運算的效率。
* **端到端訓練:** 用於訓練 GPT 類模型時,達到每個 A100 GPU **225 TFLOPs/s** 的訓練速度(相當於 **72% 的模型 FLOPs 利用率**)。
## Introduction
### 1. 研究背景與核心問題
* **基礎模型:** **Transformer** ($\sim$2017)
* **挑戰:** 擴展 Transformer 的**上下文長度 (context length)**。
* **瓶頸:** **注意力層 (Attention Layer)** 的運行時間和記憶體需求與輸入序列長度呈**二次方 ($O(N^2)$)** 關係。
* **理想目標:** 超越標準的 2k 序列長度限制,訓練模型以理解書籍、高解析度圖像和長篇影片。
* **趨勢證明:** 近期大型語言模型(LLMs)的上下文長度顯著增長:
* **GPT-4:** 32k
* **MPT:** 65k
* **Claude:** 100k
* **應用需求:** 長文件查詢、故事寫作等新興用例需要長上下文模型。
### 2. 現有解決方案與 (FlashAttention)
* **近似方法 (Approximation Methods):**
* **目標:** 減少長上下文注意力計算的需求。
* **例子:** Reformer, Linformer, Longformer, BigBird 等。
* **現狀:** 儘管存在,但**大多數大規模訓練仍使用標準注意力**。
* **優化方法:\sysnameone (FlashAttention)**
* **作者:** \citet{dao2022flashattention}
* **核心思想:** 重排注意力計算,並利用**平鋪 (tiling)** 和**重新計算 (recomputation)** 的經典技術。
* **優勢:**
1. 記憶體使用從二次方降至**線性 ($O(N)$)**。
2. 運行時間比優化後的基線快 **2-4倍** (wall-clock time)。
3. 記憶體節省達 **10-20倍**。
4. **無近似 (no approximation)**。
* **結果:** \sysnameone 已在 Transformer 的大規模訓練和推理中獲得廣泛採用。
### 3. FlashAttention 的局限性與效率瓶頸
* **問題:** 儘管 FlashAttention 已經很快,但**效率仍遠不如矩陣乘法 (GEMM)** 等其他基本運算。
* **具體數據:**
* **FlashAttention 順向傳播 (Forward Pass):** 僅達到理論最大 **FLOPs/s 的 30-50\%**。
* **FlashAttention 逆向傳播 (Backward Pass):** 僅達到理論最大 **FLOPs/s 的 25-35\%** (A100 GPU)。
* **對比 GEMM:** 優化後的 GEMM 可達到理論最大設備吞吐量的 **80-90\%**。
* **根本原因:** **次優的工作劃分 (suboptimal work partitioning)**
* 在 GPU 上,不同**執行緒區塊 (thread blocks)** 和 **Warp** 之間的工作劃分不佳。
* 導致:**低佔用率 (low-occupancy)** 或**不必要的共享記憶體讀/寫**。
### 4. 提出的新方法:FlashAttention2 (貢獻點)
新方法 **FlashAttention2** 建立在 FlashAttention 的基礎上,通過更好的**並行化 (parallelism)** 和**工作劃分 (work partitioning)** 來解決效率挑戰。
| 序號 | 貢獻點/優化內容 | 目的/原因 | 實作位置 |
| :--- | :--- | :--- | :--- |
| **1.** | **微調演算法以減少非 Matmul FLOPs** | **目的:** 減少不依賴矩陣乘法單元的運算量。**原因:** GPU 上的 Matmul 吞吐量比非 Matmul 吞吐量高達 **16倍**,應盡量將時間花在 Matmul 運算上。 | \cref{subsec:algo} |
| **2.** | **沿序列長度維度平行化** | **目的:** 增加 GPU 資源的**佔用率 (occupancy)**。**應用情境:** 當序列很長(通常批次大小 batch size 較小時),可提高資源利用。 | (方法論部分) |
| **3.** | **執行緒區塊內 Warp 的工作劃分** | **目的:** 減少**通訊 (communication)** 和**共享記憶體讀/寫**。 | (方法論部分) |
### 5. 實驗結果
* **核心成果:** 相較於 FlashAttention 實現了顯著加速。
* **加速比:** 相較於 FlashAttention,實現約 **2倍** 的加速。
* **理論峰值:**
* 順向傳播:高達理論最大吞吐量的 **73\%**。
* 逆向傳播:高達理論最大吞吐量的 **63\%**。
* **端到端訓練:** 訓練 GPT 類模型時,每個 A100 GPU 達到 **225 TFLOPs/s** 的訓練速度。
## Background
### 2.1 硬體特性 (Hardware characteristics)
#### GPU 性能特徵
* **結構:** 由計算元件(如浮點運算單元)和**記憶體階層 (memory hierarchy)** 組成。
* **專用單元:** 現代 GPU 包含用於加速低精度矩陣乘法(如 FP16/BF16)的**專用單元**(例如:Nvidia GPU 上的 **Tensor Cores**)。
* **記憶體階層:**
* **高頻寬記憶體 (HBM, High Bandwidth Memory):** 頻寬較低,容量大。
* *A100 範例:* 40-80GB, 頻寬 $\sim 1.5-2.0 \text{TB/s}$。
* **片上 SRAM (On-chip SRAM) / 共享記憶體 (Shared Memory):** 頻寬高,容量小
* *A100 範例:* 每個 **Streaming Multiprocessor (SM)** 有 192KB,頻寬估計 $\sim 19 \text{TB/s}$。
#### 執行模型 (Execution Model)
* **執行單位:**
* **Kernel:** 執行的操作。
* **Threads (執行緒):** 數量龐大。
* **組織結構:**
* **Thread Blocks (執行緒區塊):** 執行緒被組織成區塊,分配給 SM 執行。
* **Warps:** 每個 Thread Block 內,執行緒被分組為 Warps(**32 個執行緒**一組)。
* **通訊機制:**
* **Warp 內部:** 可透過快速的 **Shuffle Instructions** 或合作進行矩陣乘法來通訊。
* **Thread Block 內部(Warps 之間):** 透過讀/寫**共享記憶體 (Shared Memory)** 來通訊。
* **計算流程:** Kernel 將輸入從 HBM 載入到**暫存器 (registers)** 和 SRAM,進行計算,然後將輸出寫回 HBM。
### 2.2 標準注意力實作 (Standard Attention Implementation)
#### 順向傳播 (Forward Pass)
* **輸入:** 查詢 $Q$、鍵 $K$、值 $V$ (皆為 $\mathbb{R}^{N \times d}$)。
* $N$: 序列長度 (Sequence Length)。
* $d$: 頭維度 (Head Dimension)。
* **計算步驟 (三個主要步驟):**
1. **相似度計算:** $S = Q K^\top \in \mathbb{R}^{N \times N}$
2. **機率計算:** $P = softmax(S) \in \mathbb{R}^{N \times N}$ (逐行 $softmax$)
3. **輸出計算:** $O = P V \in \mathbb{R}^{N \times d}$
* *備註:* MHA (Multi-Head Attention) 是在 Batch 和 Heads 維度上平行執行此計算。
#### 逆向傳播 (Backward Pass)
* **輸入:** $O$ 的梯度 $dO \in \mathbb{R}^{N \times d}$。
* **梯度計算 (Chain Rule):**
$$dV = P^\top dO$$
$$dP = dO V^\top$$
$$dS = dsoftmax (dP)$$
$$dQ = dS K$$
$$dK = Q dS^\top$$
* 其中 $dsoftmax$ 是 $softmax$ 的梯度計算,其計算涉及到 $P$ 和 $dP$。
#### 實作局限性 (Standard Implementation Bottlenecks)
* **記憶體問題:** 矩陣 $S$ 和 $P$ 被**實體化 (materialize)** 到 **HBM**。
* **記憶體需求:** $O(N^2)$ (二次方)。
* **瓶頸:** 需要儲存 $O(N^2)$ 的 $P$ 矩陣用於逆向傳播。
* **序列長度比較:** 通常 $N \gg d$ (例如 $N \sim 1\text{k-8k}$,而 $d \sim 64\text{-}128$),使 $N^2$ 成為主要負擔。
* **速度問題:** 標準實作通常涉及三個分開的步驟,且中間結果 $(S, P)$ 寫入 HBM。
1. 呼叫 GEMM ($Q K^\top$),結果寫入 HBM。
2. 從 HBM 載入 $S$,計算 $softmax$,結果 $P$ 寫入 HBM。
3. 呼叫 GEMM ($P V$)。
* **瓶頸:** 大量的記憶體存取導致**記憶體頻寬受限 (memory bandwidth bounded)**,使得 wall-clock time 變慢。
### FlashAttention

### 核心思想:減少記憶體 I/O,保持輸出不變
* **目標:** 在硬體加速器(如 GPU)上加速注意力計算,通過減少**記憶體讀/寫次數 (memory reads/writes)** 來實現,同時**不使用近似 (without approximation)**,保持輸出相同。
* **關鍵技術:**
1. **分塊平鋪 (Tiling):** 將計算分解成塊。
2. **在線 Softmax (Online Softmax):** 允許按塊計算注意力並在最後進行校正。
---
### 順向傳播 (Forward Pass)
#### 1. 分塊與計算流程
\sysnameone 利用平鋪技術實現流程:
1. **載入:** 將輸入的塊 (blocks of inputs) 從 **HBM** 載入到 **SRAM**。
2. **計算:** 針對該塊計算注意力。
3. **更新:** **更新輸出 (output)**,而**不將大的中間矩陣 $S$ 和 $P$ 寫入 HBM**。
#### 2. 在線 Softmax 機制
* **必要性:** 標準 Softmax 涉及整行或整塊的耦合計算。**在線 Softmax** 允許將注意力計算分割成塊,並通過**重新縮放 (rescale)** 每一塊的輸出,最終得到正確的結果(無近似)。
* **Online Softmax 數學原理:**
* 對於包含 $S^{(1)}$ 和 $S^{(2)}$ 兩個塊的一行注意力分數:
* **標準 Softmax** 需要計算全局的最大值 $m$ 和歸一化因子 $\ell$。
* **在線 Softmax** 則採取迭代更新:
1. 先計算第一個塊的局部最大值 $m^{(1)}$ 和歸一化因子 $\ell^{(1)}$,並計算局部輸出 $O^{(1)}$。
2. 計算第二個塊時,更新全局最大值 $m^{(2)} = \max(m^{(1)}, \mathrm{rowmax}(S^{(2)}))$。
3. 使用 $m^{(2)}$ 和 $\ell^{(1)}$ 更新全局歸一化因子 $\ell^{(2)}$。
4. 利用 $\ell^{(2)}$ 重新縮放 $O^{(1)}$,並加上第二塊的輸出,得到最終輸出 $O$。
* **圖示:** \cref{fig:flash_attention_diagram} 說明了通過對 $K$ 和 $V$ 的分塊,實現對每個塊計算注意力並重新縮放輸出,從而避免昂貴的 $S$ 和 $P$ 記憶體讀寫。
#### 3. 效能提升
* **加速比:** 相較於優化後的基線注意力實作,實現 **2-4倍** 的 wall-clock time 加速。
---
### 逆向傳播 (Backward Pass)
#### 1. 記憶體節省與重新計算
* **策略:** 通過**重新計算 (re-computing)** 注意力矩陣 $S$ 和 $P$ 的值。
* **時機:** 當輸入塊 $Q, K, V$ 已載入到 SRAM 時,即時重新計算所需的中間值。
* **效果:** 避免了將 $O(N^2)$ 的 $S$ 和 $P$ 矩陣儲存到 HBM,實現:
* **記憶體節省:** 視序列長度,可達 **10-20倍**。
* **記憶體複雜度:** 從二次方 $O(N^2)$ 降至**線性 $O(N)$**。
#### 2. 複雜性
* **計算量:** 逆向傳播概念上較簡單(無 Softmax 重新縮放),但實作複雜得多。
* **原因:** 逆向傳播需要執行 **5 個矩陣乘法**($dV, dP, dQ, dK$ 的計算),而順向傳播只有 2 個,因此需要**在 SRAM 中保留更多數值**。
## FlashAttention2: Algorithm, Parallelism, and Work Partitioning

### 概述
* **目標:** 實現比 FlashAttention 更高的效率,特別是提高 FLOPs/s 利用率。
* **主要改進:**
1. **演算法調整:** 減少非矩陣乘法 (non-matmul) FLOPs。
2. **平行化:** 利用不同的執行緒區塊 (thread blocks) 實現全 GPU 資源利用。
3. **工作劃分:** 在單一執行緒區塊內,劃分 Warp 的工作,減少共享記憶體存取。
* **預期成果:** 實驗結果顯示可實現 **2-3 倍** 的加速。
---
### 3.1 演算法 (Algorithm)
* **動機:** 現代 GPU 對矩陣乘法有專門的計算單元(如 Nvidia 的 **Tensor Cores**),使其速度遠超非 Matmul 運算。
* **效率差距:** 在 A100 上,FP16/BF16 Matmul 的理論最大吞吐量是 **$312 \text{TFLOPs/s}$**,而非 Matmul FP32 僅 **$19.5 \text{TFLOPs/s}$**。
* **成本比:** 每個非 Matmul FLOP 的成本比 Matmul FLOP **高 16 倍**。
* **核心策略:** 盡可能將計算時間用於 Matmul FLOPs,以維持高吞吐量。
#### 3.1.1 順向傳播 (Forward Pass) 的兩處微調
\sysname 在 \sysnameone 的基礎上,對 **在線 Softmax (Online Softmax)** 技巧進行了兩處微調:
1. **延遲輸出縮放 (Delaying Output Scaling):**
* **FlashAttention 方式:** 在每次迭代中,對輸出更新的兩個項都進行 $diag(\ell^{(2)})^{-1}$ 的縮放。
* **FlashAttention2 方式 (優化):** 維護一個 **「未縮放」** 的輸出版本 $\tilde{O}^{(j)}$,並只在**循環結束時**,用最終的 $diag(\ell^{(\text{last})})^{-1}$ 因子進行一次最終縮放。
* **好處:** 減少了循環內的非 Matmul 運算(即對 $O^{(1)}$ 的重新縮放操作 $diag(e^{m^{(1)} - m^{(2)}})^{-1}$)。
2. **僅儲存 Logsumexp (Storing only Logsumexp):**
* **FlashAttention 方式:** 需要儲存 $\max (m^{(j)})$ 和 $\sum \exp (\ell^{(j)})$ 兩組統計量用於逆向傳播。
* **FlashAttention2 方式 (優化):** 僅儲存 **Logsumexp** $L^{(j)} = m^{(j)} + \log(\ell^{(j)})$。
* **好處:** 減少了需要儲存和處理的統計量數量,從而減少了非 Matmul FLOPs。
* **更新後的 Online Softmax 數學表達:** 論文中提供了 2 塊情況下的數學推導,確認了保持 $\tilde{O}^{(2)}$ 未縮放的計算,與最終縮放的結果 $O^{(2)}$ 等效。
#### 3.1.2 FlashAttention2 順向傳播演算法

1. **初始化:** 將 $Q$ 分成 $T_r$ 塊,$K, V$ 分成 $T_c$ 塊。初始化 $O$ 和 Logsumexp $L$ 的塊。
2. **外部循環(Row Block):** 迭代 $Q$ 的塊 $Q_i$ ($i = 1$ to $T_r$)。
* **載入與初始化:** 載入 $Q_i$ 到 SRAM。初始化**未縮放輸出 $O_{i}^{(0)}$**、**歸一化因子 $\ell_{i}^{(0)}$** 和 **最大值 $m_{i}^{(0)}$**。
* **內部循環(Column Block):** 迭代 $K, V$ 的塊 $K_j, V_j$ ($j = 1$ to $T_c$)。
* **載入與計算:** 載入 $K_j, V_j$。計算注意力分數 $S_{i}^{(j)} = Q_i K_j^T$。
* **統計量更新 (\ref{alg:stream_attn_statistics}):** 更新 $m_{i}^{(j)}$、計算 $\tilde{P}_{i}^{(j)} = \exp(S_{i}^{(j)} - m_{i}^{(j)})$,並更新 $\ell_{i}^{(j)}$。
* **輸出更新 (\ref{alg:stream_attn_update}):** 使用簡化的更新公式 $O_{i}^{(j)} = diag(e^{m_{i}^{(j-1)} - m_{i}^{(j)}})^{-1} O_{i}^{(j-1)} + \tilde{P}_{i}^{(j)} V_j$ 來維護未縮放的 $O_{i}^{(j)}$。
* **最終縮放與寫回:** 在內部循環結束後,對最終的 $O_{i}^{(T_c)}$ 進行一次性縮放 $O_{i} = diag(\ell_{i}^{(T_c)})^{-1} O_{i}^{(T_c)}$,並計算 Logsumexp $L_{i}$。將 $O_{i}$ 和 $L_{i}$ 寫回 HBM。
#### 3.1.3 因果遮罩 (Causal Masking)
* **應用場景:** 自迴歸語言模型(Auto-regressive language modeling)。
* **優化點:**
1. 對於**完全處於因果遮罩區**的塊(行索引 $i$ 的所有列索引 $j$ 都滿足 $j > i$),可以**跳過計算**,獲得約 **$1.7-1.8$ 倍**的加速。
2. 對於其他塊,只需要在**與對角線相交**的那一個塊上應用因果遮罩,減少非必要的遮罩操作。
#### 3.1.4 正確性、運行時間和記憶體
* **結果:** 保持了與 FlashAttention 相同的特性:
* **正確性:** 返回正確的輸出 $O$(無近似)。
* **運行時間:** $O(N^2 d)$ FLOPs。
* **記憶體需求:** $O(N)$ 額外記憶體(用於儲存 Logsumexp $L$)。
### 3.1.5 逆向傳播 (Backward Pass)

* **核心相似性:** FlashAttention 的逆向傳播與 FlashAttention2 **大致相同**。
* **演算法調整 (Tweak):** 唯一的微調是:**只使用逐行 Logsumexp $L$**,而不是同時使用 Softmax 的逐行最大值 $m$ 和逐行指數和 $\ell$。
* **意義:** 這與順向傳播的優化 (2) 相呼應,進一步減少了需要在 HBM 中讀取和處理的統計量數量(非 Matmul FLOPs 減少)。
#### FlashAttention2 逆向傳播演算法 (Algorithm \ref{alg:flash_bwd})
該演算法利用分塊和重新計算來高效地計算 $dQ, dK, dV$ 的梯度。
1. **分塊與初始化:**
* 將所有輸入和輸出的矩陣($Q, K, V, O, dO$)和 Logsumexp $L$ 分成 $T_r$ 或 $T_c$ 塊。
* 初始化梯度 $dQ, dK, dV$ 的塊($dQ$ 初始化在 HBM,$dK_j, dV_j$ 在 SRAM)。
* **中間計算:** 首先計算 $D = \mathrm{rowsum}(dO \circ O)$(梯度調節因子的一部分),並寫回 HBM。
2. **外部循環(Column Block):** 迭代 $K, V$ 的塊 $K_j, V_j$ ($j = 1$ to $T_c$)。
* **載入:** 載入 $K_j, V_j$ 到 SRAM。
* **內部循環(Row Block):** 迭代 $Q, O,dO, L$ 的塊 ($i = 1$ to $T_r$)。
* **載入:** 載入 $Q_i, O_i, dO_i, L_i, D_i$ 等到 SRAM。
* **重新計算 $S, P$:** 在 SRAM 上計算 $S_{i}^{(j)} = Q_i K_j^T$ 和 $P_{i}^{(j)} = \exp(S_{ij} - L_{i})$。
* **梯度計算與累積 (5個矩陣乘法/更新):**
* **$dV$ 更新:** $dV_j \leftarrow dV_j + (P_{i}^{(j)})^\top dO_i$
* **$dP$ 計算:** $dP_{i}^{(j)} = dO_{i} V_j^\top$
* **$dS$ 計算:** $dS_{i}^{(j)} = P_{i}^{(j)} \circ (dP_{i}^{(j)} - D_i)$ (這包含 Softmax 梯度的計算)
* **$dQ$ 更新:** $dQ_{i} \leftarrow dQ_i + dS_{i}^{(j)} K_j$(**注意:** $dQ_i$ 是唯一在內部循環中被載入、更新並**寫回 HBM** 的梯度塊)
* **$dK$ 更新:** $dK_{j} \leftarrow dK_j + {dS_{i}^{(j)}}^\top Q_i$
* **寫回:** 在外部循環結束時,將 $dK_j, dV_j$ 寫回 HBM。
---
### 多查詢與分組查詢注意力 (MQA and GQA)
* **概念:** 這是注意力的變體,其中多個查詢頭 (Query heads) 共享相同的鍵和值頭 (Key and Value heads)。
* **目的:** 主要用於**推理 (inference)** 期間,以減少 KV 緩存 (KV cache) 的大小。
* **FlashAttention2 的實作:** 透過**隱式地操作頭的索引**來執行相同的計算,而不是實際複製鍵和值頭。
* **逆向傳播需求:** 在逆向傳播中,需要將**隱式複製**的鍵 $dK$ 和值 $dV$ 梯度在不同的頭上**求和 (sum)**。
### Parallelism
### 核心問題與動機
* **FlashAttention 的平行化:** 僅沿 **批次大小 (Batch Size)** 和 **頭數 (Number of Heads)** 維度進行平行化。
* **執行單元分配:** 每個注意力頭由 **1 個執行緒區塊 (Thread Block)** 處理。總共有 $\text{Batch Size} \times \text{Number of Heads}$ 個 Thread Blocks。
* **效率瓶頸:**
* 當 $\text{Batch Size} \times \text{Number of Heads}$ 數量**大**(例如 $\geq 80$)時,GPU 資源(如 A100 的 108 個 SMs)可被有效利用。
* 在**長序列 (Long Sequences)** 的情況下,通常會導致 **小批次大小** 或 **小頭數**。此時 Thread Blocks 數量變少,導致 GPU **多重處理器 (SMs) 的利用率低 (low occupancy)**。
* **FlashAttention2 的目標:** 額外沿**序列長度維度 (Sequence Length dimension)** 增加平行化,以提高在長序列情況下的 SM 利用率和速度。
---
### 順向傳播 (Forward Pass) 的平行化
* **機會:** 在 \cref{alg:flash2_fwd} 中,**外部循環(沿序列長度維度的行塊 $T_r$ 迭代)** 具有**高度平行性 (embarrassingly parallel)**。
* **實作:** 將這些外部循環的計算(即每個 $Q_i$ 塊的計算)調度給**不同的執行緒區塊**。
* **優勢:** 這些執行緒區塊彼此**無需通訊 (do not need to communicate)**。
* **效果:** 當 Batch Size 和 Heads 數量小時,增加沿序列長度維度的平行化,可有效**提高佔用率 (occupancy)**,從而加速。
* **迴圈順序:** 這種將外部迴圈設為行塊、內部迴圈設為列塊的做法,以及沿序列長度維度平行化的想法,最早由 Phil Tillet 在 Triton 實作中提出。
* **示意圖(\cref{fig:parallelism} 左圖):** 每個工作單元(Thread Block)負責注意力矩陣的**一行塊 (Block of Rows)**。
---
### 逆向傳播 (Backward Pass) 的平行化
* **挑戰:** 逆向傳播的計算邏輯較為複雜,不同塊之間的**共享計算**主要發生在 $dQ$ 的更新上。
* 在 $dQ$ 更新中,需要將 $dQ_i$ 從 HBM 載入 SRAM,更新 $dQ_{i} \leftarrow dQ_i + dS_{i}^{(j)} K_j$,然後寫回 HBM。
* **實作:** 沿序列長度維度進行平行化。
* **調度:** 每個**列塊 (Column Block)** 的計算(即 \cref{alg:flash_bwd} 中的外部循環 $j$)由 **1 個執行緒區塊**處理。
* **通訊機制:** 由於多個 Thread Blocks 可能需要更新同一個 $dQ_i$,因此使用 **原子加法 (Atomic Adds)** 來協調不同執行緒區塊之間的 $dQ$ 梯度更新。
* **示意圖(\cref{fig:parallelism} 右圖):** 每個工作單元(Thread Block)負責注意力矩陣的**一列塊 (Block of Columns)**。
### Work Partitioning Between Warps
### 核心問題:FlashAttention 的 Warp 通訊瓶頸
* **背景:** 每個執行緒區塊 (Thread Block) 內,通常使用 4 或 8 個 Warps 來進一步劃分工作。
* **FlashAttention 的方法(Split-K Scheme):**
* **劃分方式:** 將 $K$ 和 $V$ **劃分 (Split)** 給 4 個 Warps,而 $Q$ 保持所有 Warps 都可存取。
* **計算流程:**
1. 每個 Warp 計算 $Q K^\top$ 的一個切片 ($S$)。
2. 每個 Warp 用 $S$ 乘上 $V$ 的切片,得到中間結果。
3. **問題:** 所有 Warps 需要將它們的中間結果**寫入共享記憶體 (Shared Memory)**。
4. **問題:** Warps 必須**同步 (synchronize)**,然後將這些中間結果**加總 (add up)** 得到最終的輸出。
* **效率低落原因:** **共享記憶體的讀/寫和同步操作** 減慢了 \sysnameone 的順向傳播速度。
### FlashAttention2 的優化方法:減少通訊
* **FlashAttention2 的方法 (Split-Q Scheme):**
* **劃分方式:** 將 $Q$ **劃分 (Split)** 給 4 個 Warps,而 $K$ 和 $V$ 保持所有 Warps 都可存取(即共享)。
* **計算流程:**
1. 每個 Warp 計算 $Q K^\top$ 的一個切片 ($S_{slice}$)。
2. 每個 Warp 使用**共享**的 $V$ 切片來計算其對應的**輸出切片 (slice of the output)**。
* **效果:** **Warps 之間無需通訊**。
* **結論:** 減少了共享記憶體的讀/寫操作,帶來顯著的速度提升。

#### 逆向傳播 (Backward Pass) 的工作劃分
* **策略:** 逆向傳播也同樣**避免採用 Split-K 方案**,從而減少共享記憶體的讀寫。
* **複雜性:** 由於涉及更多的輸入和梯度($Q, K, V, O, dO, dQ, dK, dV$)之間更複雜的依賴關係,逆向傳播**仍需要一些同步**。
* **效果:** 儘管複雜,避免 Split-K 仍然有效減少了共享記憶體 I/O,帶來速度提升。
### 區塊大小調優 (Tuning Block Sizes)
* **影響因素:**
* **增大區塊大小的好處:** 通常可以**減少共享記憶體的載入/儲存 (loads/stores)**。
* **增大區塊大小的壞處:** 會**增加所需的暫存器 (registers)** 數量和總共享記憶體需求。
* **臨界限制:** 當區塊大小超過臨界值時,可能導致:
1. **暫存器溢出 (Register Spilling):** 造成顯著的減速。
2. **共享記憶體不足:** 所需記憶體超過 GPU 限制,導致 Kernel 無法運行。
* **選擇:** 通常手動選擇 $\{64, 128\} \times \{64, 128\}$ 範圍內的區塊大小,具體取決於頭維度 $d$ 和設備的共享記憶體大小。
* **未來工作:** 論文提到這是一個可以透過**自動調優 (auto-tuning)** 來避免手動勞動的領域。
## 4 Empirical Validation
### 總體概括
* **評估範圍:** 評估 FlashAttention2 在訓練 Transformer 模型時的影響。
* **主要結果:**
* **注意力基準測試:** FlashAttention2 比 FlashAttention 快 **$1.7-3.0$ 倍**。在 A100 上達到高達 **$230 \text{TFLOPs/s}$** (理論峰值的 **73\%**)。
* **端到端訓練速度:** 相較於 FlashAttention 提速高達 **$1.3$ 倍**。在 A100 上達到高達 **$225 \text{TFLOPs/s}$** (模型 FLOPs 利用率的 **72\%**)。
---
## 4.1 注意力基準測試 (Benchmarking Attention)
#### 實驗設定
* **硬體:** A100 80GB SXM4 GPU。
* **變量:** 序列長度 (512 到 16k)、是否帶因果遮罩 (Causal Mask)、頭維度 ($d=64$ 或 $d=128$)。
* **固定參數:** 隱藏維度 $\text{hidden dim} = 2048$,總 $\text{tokens} = 16\text{k}$ (通過調整 Batch Size 實現)。
#### FLOPs 計算公式
* **順向傳播 FLOPs:** $4 \cdot \text{seqlen}^2 \cdot \text{head dimension} \cdot \text{number of heads}$。
* *因果遮罩:* 由於計算量約減半,需將上述數值除以 2。
* **逆向傳播 FLOPs:** 順向傳播 FLOPs $\times 2.5$(因為順向傳播有 2 次 Matmul,逆向傳播有 5 次 Matmul)。
#### 效能對比(A100 GPU)
| 對比對象 | 加速比 FlashAttention2 vs. 對象) | 額外說明 |
| :--- | :--- | :--- |
| **FlashAttention** | **$1.7-3.0 \times$** (約 $2 \times$) | **主要優化目標**。比 xformers (Cutlass 實作) 的 FlashAttention 也快約 $2 \times$。 |
| **FlashAttention in Triton** | 順向:$1.3-1.5 \times$;逆向:約 $2 \times$ | 證明 FlashAttention2 的優化策略比 Triton 的實作更有效。 |
| **標準 PyTorch 實作** | 高達 $10 \times$ | |
| **最高吞吐量 (A100)** | **$230 \text{TFLOPs/s}$** | 達到理論最大吞吐量的 **73\%**。 |
#### H100 GPU 基準測試
* **測試方式:** 在 H100 GPU 上運行相同的 FlashAttention2 實作(未利用 H100 的新特性,如 TMA 和第四代 Tensor Cores)。
* **結果:** 達到高達 **$335 \text{TFLOPs/s}$**。
* **未來展望:** 預計利用 H100 的新指令集可實現額外 $1.5 \times - 2 \times$ 的加速。
---
## 4.2 端到端效能 (End-to-end Performance)





#### 實驗設定
* **模型:** GPT 類模型 (1.3B 和 2.7B 參數)。
* **硬體:** $8 \times \text{A100 80GB SXM}$。
* **上下文長度:** $2\text{k}$ 或 $8\text{k}$。
#### 訓練 FLOPs 計算公式
遵循 Megatron-LM 等文獻的公式,計算總訓練 FLOPs (順向+逆向):
$$\text{Total FLOPs} = (6 \cdot \text{seqlen} \cdot \text{number of params}) + (12 \cdot \text{number of layers} \cdot \text{hidden dim} \cdot \text{seqlen}^2)$$
* *爭議說明:* 論文中提到,第二項(注意力 FLOPs)在有因果遮罩時可能應該減半,但為了與文獻保持一致,選擇**不減半**。
#### 效能結果 (TFLOPs/s/GPU)
| 模型設定 | Without FlashAttention | FlashAttention | FlashAttention2 | FlashAttention2 加速比 (相較 FlashAttention) | FlashAttention2 加速比 (相較 Baseline) |
| :--- | :--- | :--- | :--- | :--- | :--- |
| GPT3-1.3B 2k | 142 | 189 | 196 | $1.04 \times$ | $1.38 \times$ |
| GPT3-1.3B **8k** | 72 | 170 | **220** | **$1.29 \times$** | $3.05 \times$ |
| GPT3-2.7B 2k | 149 | 189 | 205 | $1.08 \times$ | $1.38 \times$ |
| GPT3-2.7B **8k** | 80 | 175 | **225** | **$1.29 \times$** | **$2.81 \times$** |
* **最大速度:** 在 **8k** 上下文長度時達到最大速度 **$225 \text{TFLOPs/s}$**,相較於 baseline(Without FlashAttention)提速高達 **$2.8 \times$**,相較於 FlashAttention 提速約 **$1.3 \times$**。
* **利用率:** 實現了 **72\%** 的模型 FLOPs 利用率。
## 5 Discussion and Future Directions
### **主要影響與價值**
* **成本效益:** **FlashAttention2** 比 **FlashAttention** 快 **$2 \times$**。
* **實際意義:** 可以用與訓練 $8\text{k}$ 上下文模型**相同的成本**,訓練具有 $16\text{k}$ 或更長上下文的模型。
* **應用潛力:**
* **理解能力提升:** 用於理解長篇書籍、報告、高解析度圖像、音訊和影片。
* **效率提升:** 加速現有模型的訓練 (Training)、微調 (Finetuning) 和推理 (Inference)。