# FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
論文連結 : [FlashAttention](https://arxiv.org/abs/2205.14135)
## abstract
* Transformer 的 **自注意力(self-attention)計算複雜度為 (O(n^2))**,當序列長度變長時,**時間與記憶體需求快速成長**
* 這使得 Transformer 在 **長序列(long sequence)任務上非常慢且吃記憶體**
* 過去已有許多 **近似式注意力(approximate attention)** 企圖改善此問題,方法通常藉由:
* 降低計算量
* 允許模型品質下降(trade-off)
* 但這些方法多數 **沒有真正帶來速度提升**(理論變快但 GPU 實跑不快)
| 目標 | 說明 |
| --------------------------- | --------------------------- |
| 突破 GPU 記憶體讀寫瓶頸 | 將 HBM ↔ SRAM 的 IO 成本納入演算法設計 |
| 保持注意力「精確」 | 不採用近似法也能加速 |
| 提供可延伸到 sparse attention 的版本 | 在需要近似時,也能取得比現有方法更快的結果 |
### 方法核心
| FlashAttention 核心概念 | 說明 |
| -------------------------- | ---------------------------- |
| **IO-aware 設計** | 將注意力計算法視為 GPU IO 問題,而不是純計算問題 |
| **利用 tiling 設計** | 分塊運算避免頻繁讀寫 HBM(高延遲) |
| **最大化使用 GPU on-chip SRAM** | 讓 attention 過程在小記憶體上完成更多運算 |
結果:**減少 HBM 存取次數 → 真正提升速度與節省記憶體**
### 實驗結果
| 任務 | seq length | 加速比(vs baseline) |
| ---------------- | ---------- | ---------------------- |
| BERT-large | 512 | **15% 更快(end-to-end)** |
| GPT-2 | 1K | **3× faster** |
| Long Range Arena | 1K–4K | **2.4× faster** |
## Introduction
### 現有解法的侷限
* 過去許多研究提出 **近似注意力(approximate attention)**,包含:
* sparse-based
* low-rank-based
* hybrid-based
* 儘管這些方法能將計算降到近似線性,但 **大多沒有在 wall-clock 上真正加速**,也 **未被主流採用**。
* 原因是它們:
1. **聚焦於 FLOPs 降低,而非整體 runtime**
2. **忽略記憶體存取成本(IO)是主要瓶頸**
### **本文核心觀點**
* 「缺失的核心原則」是 **IO-aware attention**:
* 需考慮「不同階層 GPU 記憶體間的讀寫成本」(SRAM ↔ HBM)。
* 在現代 GPU 中,**計算速度遠快於記憶體速度**,Transformer 許多運算 **受限於 IO 而非 compute**。
* 但現有框架(PyTorch/TensorFlow)無法提供對記憶體存取的細緻控制。
### 提出的方法:FlashAttention
FlashAttention 的目標:**計算 exact attention,但將 HBM 存取降到最少。**
| 技術 | 目的 |
| -------------------------------- | --------------------------------------------- |
| **Tiling / incremental softmax** | 不需存整個 attention matrix,即可完成 softmax reduction |
| **Recompute in backward** | 只存 normalization,而不是整個中間矩陣 |
* 全部使用 **單一 CUDA kernel** 完成 attention → 減少不必要讀寫
* 即使 backward 多做計算(recompute),仍 **更快且更省記憶體**
#### **理論貢獻**
* FlashAttention 的 **HBM IO complexity**:$$O(N^2 d^2 / M)$$
* 標準 attention IO complexity:$$O(Nd + N^2)$$
* 作者證明:
1. **FlashAttention HBM 存取更少,可達 9× 減量**
2. **不存在更優的 exact attention(在 IO 複雜度上)** → 為最佳上界
#### **延伸:Block-sparse FlashAttention**
* 可進一步加速 **2–4×**,序列長度可達 **64K tokens**
* IO 複雜度優於 FlashAttention,且與 sparsity 成正相關(越 sparse 越快)
#### **實證結果**
FlashAttention 與 block-sparse 版本帶來三大成果:
| 成果 | 證據 |
| ---------- | --------------------------------------------- |
| **速度提升** | BERT-large:+15%;GPT-2:3×;LRA:2.4× |
| **模型品質提升** | 更長 context → GPT-2 perplexity +0.7、長文本分類 +6.4 |
| **解鎖新任務** | 首次 > chance on Path-X(16K)與 Path-256(64K) |
## Background
## GPU 的記憶體階層與效能瓶頸
### GPU 記憶體階層(GPU Memory Hierarchy)


GPU 具有不同層級的記憶體,**容量越小 → 速度越快**
| 記憶體層級 | 位置 | 特性 | 速度 | 容量 |
| ------------------------------------------- | ----------- | ------- | ------------------- | -------- |
| **On-chip SRAM / Shared Memory / Register** | GPU 核心內 | 超高速、低延遲 | 非常快(TB/s 級) | 很小(KB 級) |
| **HBM(High Bandwidth Memory)** | GPU 外掛 DRAM | 大容量但慢 | 較慢(比 SRAM 慢一個數量級以上) | 很大(GB 級) |
以 **NVIDIA A100 為例**:
| 項目 | A100 指標 |
| ------------------ | ------------ |
| HBM 容量 | 40–80 GB |
| HBM 速度 | 1.5–2.0 TB/s |
| On-chip SRAM(每 SM) | 192 KB |
| On-chip SRAM 速度 | ~19 TB/s |
→ **SRAM 比 HBM 快至少一個數量級,但容量小很多**
### 重點
* GPU 計算速度(FLOPs)成長 **遠快於** HBM 記憶體速度
* **越來越多深度學習運算變成 IO-bound,而不是 compute-bound**
* Transformer attention:
* 需要「反覆讀寫 Q/K/V 以及 attention matrix」
* 這些動作 **主要由 HBM → SRAM → HBM 溝通組成**
* 瓶頸在 **HBM 存取次數,而非矩陣乘法 FLOPs**
### Execution Model
* GPU 在執行一個運算時會啟動大量平行 thread,該運算稱為 **kernel**
* kernel 的運作流程固定為:
1. **從 HBM 載入資料到 SRAM / registers**
2. **在 on-chip 記憶體上進行計算**
3. **再將結果寫回 HBM**
→ **每個 kernel 必然涉及至少一次 HBM → SRAM → HBM 的資料流**
### **Compute-bound vs Memory-bound**
運算效能依「計算 vs 記憶體存取」平衡被分成兩類:
| 類型 | 主導時間因素 | 特徵 | 範例 |
| ----------------- | ---------- | -------------- | --------------------------------------------------------------------------- |
| **Compute-bound** | 計算量(FLOPs) | memory cost 很小 | 大維度 matmul、conv(大量 channel) |
| **Memory-bound** | HBM 存取次數 | 計算成本很低但 IO 大 | elementwise(dropout, activation)、reduction(softmax, batch norm, layer norm) |
* 衡量指標:**arithmetic intensity** = (運算次數)/(讀寫記憶體 byte 數)
* arithmetic intensity 越低 → 越容易 **memory-bound**
* **許多 Transformer 模組(尤其 softmax 與 elementwise)是 memory-bound**
### **Kernel Fusion**
* **概念**:若多個 operation 作用在同一資料,將它們合併到同一 kernel → 只需從 HBM 讀一次,而非多次
* compilers(如 PyTorch JIT, XLA)已能自動 fuse 多個 elementwise ops
* **限制**:在 *訓練* 中,backward 需要中間值,因此 intermediate 仍要寫回 HBM → kernel fusion
### Standard Attention Implementation
給定:
| 符號 | 尺寸 | 意義 |
| ------------------------------------- | --------------------- | ------------ |
| $Q, K, V \in \mathbb{R}^{N \times d}$ | N = 序列長度,d = head dim | attention 輸入 |
| $O \in \mathbb{R}^{N \times d}$ | output | 注意力結果 |
標準 attention公式:
$$
S = QK^\top \in \mathbb{R}^{N \times N}, \quad
P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, \quad
O = PV
$$
* softmax **逐 row** 計算
* 會產生兩個 (O(N^2)) 級別中間矩陣:**注意力分數 (S)** 與 **attention 機率矩陣 (P)**
#### **主要效能問題**
* 這些矩陣全部 **materialize 到 HBM → 需 (O(N^2)) memory**
* 在實務中常常 (N \gg d)(例如 GPT2: (N=1024, d=64))
* 計算過程中包含多個 memory-bound operations(如 softmax、mask、dropout)
→ **大量 HBM 存取**
* 為 $P$ 或 $S$ 的 elementwise ops(mask / dropout / scaling)雖可 fuse
* 但 **仍要把 $S$ 與 $P$ 寫回 HBM**
→ 本質仍是 $O(N^2)$ IO,瓶頸不會消失

## Algorithm, Analysis, and Extensions
#### **目標與核心成果**
* 本節提出一個 **能計算 exact attention、但需更少 HBM 存取與記憶體的演算法**
* 降低 IO → **速度更快、記憶體更省(backward也不需儲存 $O(N^2)$ 中間矩陣)**
* 內容包含:
1. forward 的高效率 attention 演算法
2. IO complexity 分析(與 standard attention 比較)
3. block-sparse attention 擴展
* 本段先聚焦 **forward pass**(backward 置於 appendix)
### **實現**
作者為達到「sub-quadratic HBM access」採用兩項既有但少被 attention 正式結合的概念:
| 技術 | 設計目的 |
| --------------------- | ------------------------------------------------- |
| **Tiling(分塊)** | 逐區塊計算 softmax 以避免 materialize $N \times N$ matrix |
| **Recomputation(重算)** | backward 時不儲存 $S$ 與 $P$,僅存 softmax 統計量 |
### **Tiling 的核心想法(Forward)**
* 將 $Q, K, V$ **切成多個 block**(row-block 與 column-block)
* 每次只將少量 block **從 HBM 載入 SRAM → 計算 → 更新輸出**
* softmax 被拆成 block-wise 累積計算,需要額外維護統計量:
| 需累積統計 | 作用 |
| --------------------------- | ------------------ |
| $m$(row-wise max) | 用來穩定 softmax |
| $\ell$(row-wise sum of exp) | 用來完成 normalization |
* 可根據 softmax 性質做到 **增量式 aggregation**,因此最終結果與一次算完整 softmax 相同
### Recomputation
* backward 通常需要 $\mathbf{S}$ 與 $\mathbf{P}$(兩者為 $O(N^2)$ matrix)
* FlashAttention 的做法:
* forward **只儲存 $O$、$m$、$\ell$**(皆為 $O(N)$ 量級)
* backward 時再 **block-wise 重算 $S$ 與 $P$**
| 結果 | 意義 |
| ---------------------------------- | ----------------- |
| 不需儲存 $N^2$ 中間矩陣 | 記憶體降至線性 |
| 以少量 recompute FLOPs 換取大量 HBM IO 減量 | backward 反而更快而非更慢 |
### **Kernel Fusion**
* 因為 tiling 讓所有動作都在 block 中進行
* 可用 **單一 CUDA kernel** 完成:
* matmul → softmax → masking/dropout → matmul
* → **避免多次讀寫 HBM**,進一步降低 IO cost

### IO Complexity Analysis
* 分析 FlashAttention 的 **HBM IO 複雜度**
* 與 standard attention 做理論比較
* 並提出 **下界證明(lower bound)**,說明 **沒有 exact attention 能在所有 SRAM 大小下比 FlashAttention 更快(在 IO 次數上)**
### **主要理論結論(Theorem 2)**
令:
| 符號 | 含義 |
| --- | ---------------------------- |
| $N$ | sequence length |
| $d$ | head dimension |
| $M$ | SRAM 容量(假設 $d \le M \le Nd$) |
則:
| 演算法 | HBM IO 複雜度 |
| ---------------------- | -------------------------------------- |
| **Standard Attention** | $\Theta(Nd + N^2$) |
| **FlashAttention** | $\Theta\left(\frac{N^2 d^2}{M}\right)$ |
→ 在一般情況(e.g., $d = 64 \sim 128$, $M \approx 100KB$)下:
* $d^2 \ll M$
* 因此 **FlashAttention 的 IO 次數遠少於 standard attention**
| 成果 | 原因 |
| --------------------- | --------- |
| Wall-clock runtime 更快 | 因 IO 減少 |
| 記憶體用量更小 | 不需寫回大中間矩陣 |
### **證明直覺(Proof Sketch)**
FlashAttention 的 IO 來源可由 block 設計推導:
| 步驟 | 來源 |
| ---------------------------------------------- | -- |
| 每次載入 $K_j, V_j$ block,可放入 SRAM 大小 $\Theta(M)$ | |
| 對每個 $K_j, V_j$ block,需要 iterate 所有 $Q_i$ block | |
| 因此需 **$\Theta(N d / M)$ 次 pass over Q** | |
| 每次 pass 載入 $\Theta(N d)$ 量資料 | |
| → 得到 **$\Theta(N^2 d^2 / M)$ HBM loads** | |
Backward 的 IO 複雜度與 forward 相同(同為 $\Theta(N^2 d^2 / M)$)
### **Lower Bound(Proposition 3)**
命題內容:
* **不存在一種 exact attention,可以在所有 $M \in [d, Nd]$ 範圍內達到 $o(N^2 d^2 / M)$ 的 HBM IO 複雜度**
* 直覺來源:
* 若 $M = \Theta(Nd)$(也就是 SRAM 可一次容納全部 Q/K/V)
* 則 IO 下界為 $\Omega(Nd)$
* 因此 $\Omega(N^2 d^2 / M)$ 即為不可突破的 IO 極限
* 此下界符合 streaming 理論中常見的 memory–IO tradeoff 設計
---
### **實驗驗證對 IO 的依賴性**
FlashAttention 的 runtime 觀察:
| 實驗結果 | 意義 |
| ---------------------------------------------------------------------- | ---------------------- |
| 即使 FLOPs 較高(因 backward recomputation),FlashAttention 仍比標準 attention 更快 | → IO 是主導 runtime 的關鍵因素 |
| 改變 block size $B_c$ → IO 減少則 runtime 下降 | → 直接驗證 IO 決定速度 |
| 當 block size 大到某程度後,瓶頸轉為 compute | → SRAM 容量限制 block size |
## Experiments
驗證 FlashAttention(含 block-sparse 版本)是否在 **(1)訓練速度、(2)模型品質、(3)注意力運算效能** 上明顯優於現有方法。
| 評估面向 | FlashAttention的結果 |
| -------------------------- | ------------------------------------------------------------------------------------------ |
| **Training Speed** | BERT 速度 +15%,GPT-2 最多 **3×** 加速,LRA **2.4×** 加速 |
| **Model Quality** | 能訓練更長序列 → GPT-2 perplexity 改善 **0.7**、長文本分類提升 **6.4**、首個 Transformer 能通過 Path-X / Path-256 |
| **Benchmarking Attention** | runtime 最快、記憶體最省,block-sparse 版本可線性 scaling 到 64K seq |
## **4.2 Faster Models with FlashAttention**
### **(A) BERT Training Speed**

### **(B) GPT-2 Training Speed**

| GPT-2 系統 | Speed |
| ------------------------ | ------------------------------------------------------ |
| HuggingFace | baseline |
| Megatron-LM | ~2× faster |
| **FlashAttention(ours)** | **up to 3.5× faster(small)** / **3.0× faster(medium)** |
* Perplexity 完全一致 → 表示 **僅加速,無品質損失**
### **(C) Long-Range Arena (LRA) Benchmark**
* 序列長度 1K–4K
* 任務:ListOps, Text, Retrieval, Image, Pathfinder

### **(A) Language Modeling — Long Context GPT-2**
* FlashAttention 允許將 GPT-2 context length 提高到 **4K(原 1K 的 4×)**
* 仍比 Megatron-LM(context=1K)更快
* 成果(Table 4.4):

| Model | ctx len | ppl | speed |
| ------------------ | ------- | -------- | --------------- |
| Megatron-LM | 1K | 18.2 | 1.0× |
| **FlashAttention** | 4K | **17.5** | **1.3× faster** |
→ 高速 + 更佳 perplexity(改善 0.7)
### **4.4 Benchmarking Attention (Runtime & Memory)**
環境:A100 40GB,帶 dropout + padding mask
### **Runtime**
* FlashAttention:對 exact attention **最多 3× 加速**
* Approximate attention 雖然具線性 scaling,但 **在 512–1024 之前 FlashAttention 仍更快**
* Block-sparse FlashAttention:**全序列範圍下最快**
### **Memory Footprint**

* FlashAttention & Block-sparse:
* 記憶體 **線性成長**
* 比 exact attention **最多節省 20×**
* 甚至比 approximate attention(如 Performer、Reformer)更省
* 能跑到 **64K seq**,其他 baseline 多數 OOM
## Limitations and Future Directions
### **1. Limitation: CUDA 工程成本高(Compiling to CUDA)**
* 目前 FlashAttention 的 IO-aware 設計 **必須手寫 CUDA kernel**
* 缺點:
| 問題 | 說明 |
| ------------- | -------------------------- |
| 開發難度高 | 需要以比 PyTorch 低階許多的語言撰寫與最佳化 |
| 工程成本大 | 每種注意力變種都需重寫 kernel |
| 跨 GPU 架構可移植性差 | 不同 GPU、不同記憶體層級架構 → 需維護多版本 |
* **未來方向**:開發一種 **高階 → IO-aware CUDA 的自動轉譯方法**
類似 Halide(在影像領域提供高層 DSL + 自動記憶體最佳化)
---
### **2. Future Direction: IO-Aware Beyond Attention**
* Attention 是 Transformer **最記憶體密集的 Layer**
* 但「IO-bound 問題」**在所有層都存在**(每一層都需頻繁讀寫 HBM)
* 作者認為 IO-aware 的概念:
* **並非只適用於 attention**
* 可擴展至 **更多 Deep Learning modules**
* 論文附錄另有展望,包括:
multi-head attention variants、kernel regression、block-sparse matmul 等
---
### **3. Future Direction: Multi-GPU IO-aware Strategies**
* FlashAttention 在 **單 GPU 上已達到 IO-optimal(up to constants)**
* 但 Transformer 訓練常使用 **multi-GPU / distributed training**
* 未來研究方向:
| 多 GPU 新挑戰 | 說明 |
| ---------------- | ------------------------- |
| GPU-to-GPU 通訊 | 引入額外 IO 層級(NVLink / PCIe) |
| 新的 IO bottleneck | 必須同時最佳化 GPU內 + GPU間 訪存 |
* 可能與 parallel algorithm 研究結合(文中引用 Recht 2013)