# FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 論文連結 : [FlashAttention](https://arxiv.org/abs/2205.14135) ## abstract * Transformer 的 **自注意力(self-attention)計算複雜度為 (O(n^2))**,當序列長度變長時,**時間與記憶體需求快速成長** * 這使得 Transformer 在 **長序列(long sequence)任務上非常慢且吃記憶體** * 過去已有許多 **近似式注意力(approximate attention)** 企圖改善此問題,方法通常藉由: * 降低計算量 * 允許模型品質下降(trade-off) * 但這些方法多數 **沒有真正帶來速度提升**(理論變快但 GPU 實跑不快) | 目標 | 說明 | | --------------------------- | --------------------------- | | 突破 GPU 記憶體讀寫瓶頸 | 將 HBM ↔ SRAM 的 IO 成本納入演算法設計 | | 保持注意力「精確」 | 不採用近似法也能加速 | | 提供可延伸到 sparse attention 的版本 | 在需要近似時,也能取得比現有方法更快的結果 | ### 方法核心 | FlashAttention 核心概念 | 說明 | | -------------------------- | ---------------------------- | | **IO-aware 設計** | 將注意力計算法視為 GPU IO 問題,而不是純計算問題 | | **利用 tiling 設計** | 分塊運算避免頻繁讀寫 HBM(高延遲) | | **最大化使用 GPU on-chip SRAM** | 讓 attention 過程在小記憶體上完成更多運算 | 結果:**減少 HBM 存取次數 → 真正提升速度與節省記憶體** ### 實驗結果 | 任務 | seq length | 加速比(vs baseline) | | ---------------- | ---------- | ---------------------- | | BERT-large | 512 | **15% 更快(end-to-end)** | | GPT-2 | 1K | **3× faster** | | Long Range Arena | 1K–4K | **2.4× faster** | ## Introduction ### 現有解法的侷限 * 過去許多研究提出 **近似注意力(approximate attention)**,包含: * sparse-based * low-rank-based * hybrid-based * 儘管這些方法能將計算降到近似線性,但 **大多沒有在 wall-clock 上真正加速**,也 **未被主流採用**。 * 原因是它們: 1. **聚焦於 FLOPs 降低,而非整體 runtime** 2. **忽略記憶體存取成本(IO)是主要瓶頸** ### **本文核心觀點** * 「缺失的核心原則」是 **IO-aware attention**: * 需考慮「不同階層 GPU 記憶體間的讀寫成本」(SRAM ↔ HBM)。 * 在現代 GPU 中,**計算速度遠快於記憶體速度**,Transformer 許多運算 **受限於 IO 而非 compute**。 * 但現有框架(PyTorch/TensorFlow)無法提供對記憶體存取的細緻控制。 ### 提出的方法:FlashAttention FlashAttention 的目標:**計算 exact attention,但將 HBM 存取降到最少。** | 技術 | 目的 | | -------------------------------- | --------------------------------------------- | | **Tiling / incremental softmax** | 不需存整個 attention matrix,即可完成 softmax reduction | | **Recompute in backward** | 只存 normalization,而不是整個中間矩陣 | * 全部使用 **單一 CUDA kernel** 完成 attention → 減少不必要讀寫 * 即使 backward 多做計算(recompute),仍 **更快且更省記憶體** #### **理論貢獻** * FlashAttention 的 **HBM IO complexity**:$$O(N^2 d^2 / M)$$ * 標準 attention IO complexity:$$O(Nd + N^2)$$ * 作者證明: 1. **FlashAttention HBM 存取更少,可達 9× 減量** 2. **不存在更優的 exact attention(在 IO 複雜度上)** → 為最佳上界 #### **延伸:Block-sparse FlashAttention** * 可進一步加速 **2–4×**,序列長度可達 **64K tokens** * IO 複雜度優於 FlashAttention,且與 sparsity 成正相關(越 sparse 越快) #### **實證結果** FlashAttention 與 block-sparse 版本帶來三大成果: | 成果 | 證據 | | ---------- | --------------------------------------------- | | **速度提升** | BERT-large:+15%;GPT-2:3×;LRA:2.4× | | **模型品質提升** | 更長 context → GPT-2 perplexity +0.7、長文本分類 +6.4 | | **解鎖新任務** | 首次 > chance on Path-X(16K)與 Path-256(64K) | ## Background ## GPU 的記憶體階層與效能瓶頸 ### GPU 記憶體階層(GPU Memory Hierarchy) ![image](https://hackmd.io/_uploads/ByPKehiAlx.png) ![image](https://hackmd.io/_uploads/ryNv-6iAxx.png) GPU 具有不同層級的記憶體,**容量越小 → 速度越快** | 記憶體層級 | 位置 | 特性 | 速度 | 容量 | | ------------------------------------------- | ----------- | ------- | ------------------- | -------- | | **On-chip SRAM / Shared Memory / Register** | GPU 核心內 | 超高速、低延遲 | 非常快(TB/s 級) | 很小(KB 級) | | **HBM(High Bandwidth Memory)** | GPU 外掛 DRAM | 大容量但慢 | 較慢(比 SRAM 慢一個數量級以上) | 很大(GB 級) | 以 **NVIDIA A100 為例**: | 項目 | A100 指標 | | ------------------ | ------------ | | HBM 容量 | 40–80 GB | | HBM 速度 | 1.5–2.0 TB/s | | On-chip SRAM(每 SM) | 192 KB | | On-chip SRAM 速度 | ~19 TB/s | → **SRAM 比 HBM 快至少一個數量級,但容量小很多** ### 重點 * GPU 計算速度(FLOPs)成長 **遠快於** HBM 記憶體速度 * **越來越多深度學習運算變成 IO-bound,而不是 compute-bound** * Transformer attention: * 需要「反覆讀寫 Q/K/V 以及 attention matrix」 * 這些動作 **主要由 HBM → SRAM → HBM 溝通組成** * 瓶頸在 **HBM 存取次數,而非矩陣乘法 FLOPs** ### Execution Model * GPU 在執行一個運算時會啟動大量平行 thread,該運算稱為 **kernel** * kernel 的運作流程固定為: 1. **從 HBM 載入資料到 SRAM / registers** 2. **在 on-chip 記憶體上進行計算** 3. **再將結果寫回 HBM** → **每個 kernel 必然涉及至少一次 HBM → SRAM → HBM 的資料流** ### **Compute-bound vs Memory-bound** 運算效能依「計算 vs 記憶體存取」平衡被分成兩類: | 類型 | 主導時間因素 | 特徵 | 範例 | | ----------------- | ---------- | -------------- | --------------------------------------------------------------------------- | | **Compute-bound** | 計算量(FLOPs) | memory cost 很小 | 大維度 matmul、conv(大量 channel) | | **Memory-bound** | HBM 存取次數 | 計算成本很低但 IO 大 | elementwise(dropout, activation)、reduction(softmax, batch norm, layer norm) | * 衡量指標:**arithmetic intensity** = (運算次數)/(讀寫記憶體 byte 數) * arithmetic intensity 越低 → 越容易 **memory-bound** * **許多 Transformer 模組(尤其 softmax 與 elementwise)是 memory-bound** ### **Kernel Fusion** * **概念**:若多個 operation 作用在同一資料,將它們合併到同一 kernel → 只需從 HBM 讀一次,而非多次 * compilers(如 PyTorch JIT, XLA)已能自動 fuse 多個 elementwise ops * **限制**:在 *訓練* 中,backward 需要中間值,因此 intermediate 仍要寫回 HBM → kernel fusion ### Standard Attention Implementation 給定: | 符號 | 尺寸 | 意義 | | ------------------------------------- | --------------------- | ------------ | | $Q, K, V \in \mathbb{R}^{N \times d}$ | N = 序列長度,d = head dim | attention 輸入 | | $O \in \mathbb{R}^{N \times d}$ | output | 注意力結果 | 標準 attention公式: $$ S = QK^\top \in \mathbb{R}^{N \times N}, \quad P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, \quad O = PV $$ * softmax **逐 row** 計算 * 會產生兩個 (O(N^2)) 級別中間矩陣:**注意力分數 (S)** 與 **attention 機率矩陣 (P)** #### **主要效能問題** * 這些矩陣全部 **materialize 到 HBM → 需 (O(N^2)) memory** * 在實務中常常 (N \gg d)(例如 GPT2: (N=1024, d=64)) * 計算過程中包含多個 memory-bound operations(如 softmax、mask、dropout) → **大量 HBM 存取** * 為 $P$ 或 $S$ 的 elementwise ops(mask / dropout / scaling)雖可 fuse * 但 **仍要把 $S$ 與 $P$ 寫回 HBM** → 本質仍是 $O(N^2)$ IO,瓶頸不會消失 ![image](https://hackmd.io/_uploads/HJd9Qhi0lx.png) ## Algorithm, Analysis, and Extensions #### **目標與核心成果** * 本節提出一個 **能計算 exact attention、但需更少 HBM 存取與記憶體的演算法** * 降低 IO → **速度更快、記憶體更省(backward也不需儲存 $O(N^2)$ 中間矩陣)** * 內容包含: 1. forward 的高效率 attention 演算法 2. IO complexity 分析(與 standard attention 比較) 3. block-sparse attention 擴展 * 本段先聚焦 **forward pass**(backward 置於 appendix) ### **實現** 作者為達到「sub-quadratic HBM access」採用兩項既有但少被 attention 正式結合的概念: | 技術 | 設計目的 | | --------------------- | ------------------------------------------------- | | **Tiling(分塊)** | 逐區塊計算 softmax 以避免 materialize $N \times N$ matrix | | **Recomputation(重算)** | backward 時不儲存 $S$ 與 $P$,僅存 softmax 統計量 | ### **Tiling 的核心想法(Forward)** * 將 $Q, K, V$ **切成多個 block**(row-block 與 column-block) * 每次只將少量 block **從 HBM 載入 SRAM → 計算 → 更新輸出** * softmax 被拆成 block-wise 累積計算,需要額外維護統計量: | 需累積統計 | 作用 | | --------------------------- | ------------------ | | $m$(row-wise max) | 用來穩定 softmax | | $\ell$(row-wise sum of exp) | 用來完成 normalization | * 可根據 softmax 性質做到 **增量式 aggregation**,因此最終結果與一次算完整 softmax 相同 ### Recomputation * backward 通常需要 $\mathbf{S}$ 與 $\mathbf{P}$(兩者為 $O(N^2)$ matrix) * FlashAttention 的做法: * forward **只儲存 $O$、$m$、$\ell$**(皆為 $O(N)$ 量級) * backward 時再 **block-wise 重算 $S$ 與 $P$** | 結果 | 意義 | | ---------------------------------- | ----------------- | | 不需儲存 $N^2$ 中間矩陣 | 記憶體降至線性 | | 以少量 recompute FLOPs 換取大量 HBM IO 減量 | backward 反而更快而非更慢 | ### **Kernel Fusion** * 因為 tiling 讓所有動作都在 block 中進行 * 可用 **單一 CUDA kernel** 完成: * matmul → softmax → masking/dropout → matmul * → **避免多次讀寫 HBM**,進一步降低 IO cost ![image](https://hackmd.io/_uploads/S1_Mv3jAle.png) ### IO Complexity Analysis * 分析 FlashAttention 的 **HBM IO 複雜度** * 與 standard attention 做理論比較 * 並提出 **下界證明(lower bound)**,說明 **沒有 exact attention 能在所有 SRAM 大小下比 FlashAttention 更快(在 IO 次數上)** ### **主要理論結論(Theorem 2)** 令: | 符號 | 含義 | | --- | ---------------------------- | | $N$ | sequence length | | $d$ | head dimension | | $M$ | SRAM 容量(假設 $d \le M \le Nd$) | 則: | 演算法 | HBM IO 複雜度 | | ---------------------- | -------------------------------------- | | **Standard Attention** | $\Theta(Nd + N^2$) | | **FlashAttention** | $\Theta\left(\frac{N^2 d^2}{M}\right)$ | → 在一般情況(e.g., $d = 64 \sim 128$, $M \approx 100KB$)下: * $d^2 \ll M$ * 因此 **FlashAttention 的 IO 次數遠少於 standard attention** | 成果 | 原因 | | --------------------- | --------- | | Wall-clock runtime 更快 | 因 IO 減少 | | 記憶體用量更小 | 不需寫回大中間矩陣 | ### **證明直覺(Proof Sketch)** FlashAttention 的 IO 來源可由 block 設計推導: | 步驟 | 來源 | | ---------------------------------------------- | -- | | 每次載入 $K_j, V_j$ block,可放入 SRAM 大小 $\Theta(M)$ | | | 對每個 $K_j, V_j$ block,需要 iterate 所有 $Q_i$ block | | | 因此需 **$\Theta(N d / M)$ 次 pass over Q** | | | 每次 pass 載入 $\Theta(N d)$ 量資料 | | | → 得到 **$\Theta(N^2 d^2 / M)$ HBM loads** | | Backward 的 IO 複雜度與 forward 相同(同為 $\Theta(N^2 d^2 / M)$) ### **Lower Bound(Proposition 3)** 命題內容: * **不存在一種 exact attention,可以在所有 $M \in [d, Nd]$ 範圍內達到 $o(N^2 d^2 / M)$ 的 HBM IO 複雜度** * 直覺來源: * 若 $M = \Theta(Nd)$(也就是 SRAM 可一次容納全部 Q/K/V) * 則 IO 下界為 $\Omega(Nd)$ * 因此 $\Omega(N^2 d^2 / M)$ 即為不可突破的 IO 極限 * 此下界符合 streaming 理論中常見的 memory–IO tradeoff 設計 --- ### **實驗驗證對 IO 的依賴性** FlashAttention 的 runtime 觀察: | 實驗結果 | 意義 | | ---------------------------------------------------------------------- | ---------------------- | | 即使 FLOPs 較高(因 backward recomputation),FlashAttention 仍比標準 attention 更快 | → IO 是主導 runtime 的關鍵因素 | | 改變 block size $B_c$ → IO 減少則 runtime 下降 | → 直接驗證 IO 決定速度 | | 當 block size 大到某程度後,瓶頸轉為 compute | → SRAM 容量限制 block size | ## Experiments 驗證 FlashAttention(含 block-sparse 版本)是否在 **(1)訓練速度、(2)模型品質、(3)注意力運算效能** 上明顯優於現有方法。 | 評估面向 | FlashAttention的結果 | | -------------------------- | ------------------------------------------------------------------------------------------ | | **Training Speed** | BERT 速度 +15%,GPT-2 最多 **3×** 加速,LRA **2.4×** 加速 | | **Model Quality** | 能訓練更長序列 → GPT-2 perplexity 改善 **0.7**、長文本分類提升 **6.4**、首個 Transformer 能通過 Path-X / Path-256 | | **Benchmarking Attention** | runtime 最快、記憶體最省,block-sparse 版本可線性 scaling 到 64K seq | ## **4.2 Faster Models with FlashAttention** ### **(A) BERT Training Speed** ![image](https://hackmd.io/_uploads/rk7WR3s0gx.png) ### **(B) GPT-2 Training Speed** ![image](https://hackmd.io/_uploads/HJ_SR2jCle.png) | GPT-2 系統 | Speed | | ------------------------ | ------------------------------------------------------ | | HuggingFace | baseline | | Megatron-LM | ~2× faster | | **FlashAttention(ours)** | **up to 3.5× faster(small)** / **3.0× faster(medium)** | * Perplexity 完全一致 → 表示 **僅加速,無品質損失** ### **(C) Long-Range Arena (LRA) Benchmark** * 序列長度 1K–4K * 任務:ListOps, Text, Retrieval, Image, Pathfinder ![image](https://hackmd.io/_uploads/HJT41aiClx.png) ### **(A) Language Modeling — Long Context GPT-2** * FlashAttention 允許將 GPT-2 context length 提高到 **4K(原 1K 的 4×)** * 仍比 Megatron-LM(context=1K)更快 * 成果(Table 4.4): ![image](https://hackmd.io/_uploads/Hkn_ypsAlg.png) | Model | ctx len | ppl | speed | | ------------------ | ------- | -------- | --------------- | | Megatron-LM | 1K | 18.2 | 1.0× | | **FlashAttention** | 4K | **17.5** | **1.3× faster** | → 高速 + 更佳 perplexity(改善 0.7) ### **4.4 Benchmarking Attention (Runtime & Memory)** 環境:A100 40GB,帶 dropout + padding mask ### **Runtime** * FlashAttention:對 exact attention **最多 3× 加速** * Approximate attention 雖然具線性 scaling,但 **在 512–1024 之前 FlashAttention 仍更快** * Block-sparse FlashAttention:**全序列範圍下最快** ### **Memory Footprint** ![image](https://hackmd.io/_uploads/SJeTJpo0xl.png) * FlashAttention & Block-sparse: * 記憶體 **線性成長** * 比 exact attention **最多節省 20×** * 甚至比 approximate attention(如 Performer、Reformer)更省 * 能跑到 **64K seq**,其他 baseline 多數 OOM ## Limitations and Future Directions ### **1. Limitation: CUDA 工程成本高(Compiling to CUDA)** * 目前 FlashAttention 的 IO-aware 設計 **必須手寫 CUDA kernel** * 缺點: | 問題 | 說明 | | ------------- | -------------------------- | | 開發難度高 | 需要以比 PyTorch 低階許多的語言撰寫與最佳化 | | 工程成本大 | 每種注意力變種都需重寫 kernel | | 跨 GPU 架構可移植性差 | 不同 GPU、不同記憶體層級架構 → 需維護多版本 | * **未來方向**:開發一種 **高階 → IO-aware CUDA 的自動轉譯方法** 類似 Halide(在影像領域提供高層 DSL + 自動記憶體最佳化) --- ### **2. Future Direction: IO-Aware Beyond Attention** * Attention 是 Transformer **最記憶體密集的 Layer** * 但「IO-bound 問題」**在所有層都存在**(每一層都需頻繁讀寫 HBM) * 作者認為 IO-aware 的概念: * **並非只適用於 attention** * 可擴展至 **更多 Deep Learning modules** * 論文附錄另有展望,包括: multi-head attention variants、kernel regression、block-sparse matmul 等 --- ### **3. Future Direction: Multi-GPU IO-aware Strategies** * FlashAttention 在 **單 GPU 上已達到 IO-optimal(up to constants)** * 但 Transformer 訓練常使用 **multi-GPU / distributed training** * 未來研究方向: | 多 GPU 新挑戰 | 說明 | | ---------------- | ------------------------- | | GPU-to-GPU 通訊 | 引入額外 IO 層級(NVLink / PCIe) | | 新的 IO bottleneck | 必須同時最佳化 GPU內 + GPU間 訪存 | * 可能與 parallel algorithm 研究結合(文中引用 Recht 2013)