---
# System prepended metadata

title: 'FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness'

---

# FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
論文連結 : [FlashAttention](https://arxiv.org/abs/2205.14135)
## abstract
* Transformer 的 **自注意力（self-attention）計算複雜度為 (O(n^2))**，當序列長度變長時，**時間與記憶體需求快速成長**
* 這使得 Transformer 在 **長序列（long sequence）任務上非常慢且吃記憶體**
* 過去已有許多 **近似式注意力（approximate attention）** 企圖改善此問題，方法通常藉由：

  * 降低計算量
  * 允許模型品質下降（trade-off）
* 但這些方法多數 **沒有真正帶來速度提升**（理論變快但 GPU 實跑不快）

| 目標                          | 說明                          |
| --------------------------- | --------------------------- |
| 突破 GPU 記憶體讀寫瓶頸              | 將 HBM ↔ SRAM 的 IO 成本納入演算法設計 |
| 保持注意力「精確」                   | 不採用近似法也能加速                  |
| 提供可延伸到 sparse attention 的版本 | 在需要近似時，也能取得比現有方法更快的結果       |

### 方法核心
| FlashAttention 核心概念        | 說明                           |
| -------------------------- | ---------------------------- |
| **IO-aware 設計**            | 將注意力計算法視為 GPU IO 問題，而不是純計算問題 |
| **利用 tiling 設計**           | 分塊運算避免頻繁讀寫 HBM（高延遲）          |
| **最大化使用 GPU on-chip SRAM** | 讓 attention 過程在小記憶體上完成更多運算   |
結果：**減少 HBM 存取次數 → 真正提升速度與節省記憶體**

### 實驗結果
| 任務               | seq length | 加速比（vs baseline）       |
| ---------------- | ---------- | ---------------------- |
| BERT-large       | 512        | **15% 更快（end-to-end）** |
| GPT-2            | 1K         | **3× faster**          |
| Long Range Arena | 1K–4K      | **2.4× faster**        |

## Introduction
### 現有解法的侷限
* 過去許多研究提出 **近似注意力（approximate attention）**，包含：
  * sparse-based
  * low-rank-based
  * hybrid-based
* 儘管這些方法能將計算降到近似線性，但 **大多沒有在 wall-clock 上真正加速**，也 **未被主流採用**。
* 原因是它們：
  1. **聚焦於 FLOPs 降低，而非整體 runtime**
  2. **忽略記憶體存取成本（IO）是主要瓶頸**

### **本文核心觀點**

* 「缺失的核心原則」是 **IO-aware attention**：
  * 需考慮「不同階層 GPU 記憶體間的讀寫成本」(SRAM ↔ HBM)。
  * 在現代 GPU 中，**計算速度遠快於記憶體速度**，Transformer 許多運算 **受限於 IO 而非 compute**。
* 但現有框架（PyTorch/TensorFlow）無法提供對記憶體存取的細緻控制。

### 提出的方法：FlashAttention
FlashAttention 的目標：**計算 exact attention，但將 HBM 存取降到最少。**
| 技術                               | 目的                                            |
| -------------------------------- | --------------------------------------------- |
| **Tiling / incremental softmax** | 不需存整個 attention matrix，即可完成 softmax reduction |
| **Recompute in backward**        | 只存 normalization，而不是整個中間矩陣                    |
* 全部使用 **單一 CUDA kernel** 完成 attention → 減少不必要讀寫
* 即使 backward 多做計算（recompute），仍 **更快且更省記憶體**

#### **理論貢獻**
* FlashAttention 的 **HBM IO complexity**：$$O(N^2 d^2 / M)$$
* 標準 attention IO complexity：$$O(Nd + N^2)$$
* 作者證明：
  1. **FlashAttention HBM 存取更少，可達 9× 減量**
  2. **不存在更優的 exact attention（在 IO 複雜度上）** → 為最佳上界

#### **延伸：Block-sparse FlashAttention**

* 可進一步加速 **2–4×**，序列長度可達 **64K tokens**
* IO 複雜度優於 FlashAttention，且與 sparsity 成正相關（越 sparse 越快）

#### **實證結果**

FlashAttention 與 block-sparse 版本帶來三大成果：

| 成果         | 證據                                            |
| ---------- | --------------------------------------------- |
| **速度提升**   | BERT-large：+15%；GPT-2：3×；LRA：2.4×             |
| **模型品質提升** | 更長 context → GPT-2 perplexity +0.7、長文本分類 +6.4 |
| **解鎖新任務**  | 首次 > chance on Path-X（16K）與 Path-256（64K）     |


## Background
## GPU 的記憶體階層與效能瓶頸

### GPU 記憶體階層（GPU Memory Hierarchy）
![image](https://hackmd.io/_uploads/ByPKehiAlx.png)
![image](https://hackmd.io/_uploads/ryNv-6iAxx.png)

GPU 具有不同層級的記憶體，**容量越小 → 速度越快**

| 記憶體層級                                       | 位置          | 特性      | 速度                  | 容量       |
| ------------------------------------------- | ----------- | ------- | ------------------- | -------- |
| **On-chip SRAM / Shared Memory / Register** | GPU 核心內     | 超高速、低延遲 | 非常快（TB/s 級）         | 很小（KB 級） |
| **HBM（High Bandwidth Memory）**              | GPU 外掛 DRAM | 大容量但慢   | 較慢（比 SRAM 慢一個數量級以上） | 很大（GB 級） |

以 **NVIDIA A100 為例**：
| 項目                 | A100 指標      |
| ------------------ | ------------ |
| HBM 容量             | 40–80 GB     |
| HBM 速度             | 1.5–2.0 TB/s |
| On-chip SRAM（每 SM） | 192 KB       |
| On-chip SRAM 速度    | ~19 TB/s     |

→ **SRAM 比 HBM 快至少一個數量級，但容量小很多**

### 重點
* GPU 計算速度（FLOPs）成長 **遠快於** HBM 記憶體速度
* **越來越多深度學習運算變成 IO-bound，而不是 compute-bound**
* Transformer attention：
  * 需要「反覆讀寫 Q/K/V 以及 attention matrix」
  * 這些動作 **主要由 HBM → SRAM → HBM 溝通組成**
  * 瓶頸在 **HBM 存取次數，而非矩陣乘法 FLOPs**

### Execution Model
* GPU 在執行一個運算時會啟動大量平行 thread，該運算稱為 **kernel**
* kernel 的運作流程固定為：
  1. **從 HBM 載入資料到 SRAM / registers**
  2. **在 on-chip 記憶體上進行計算**
  3. **再將結果寫回 HBM**
→ **每個 kernel 必然涉及至少一次 HBM → SRAM → HBM 的資料流**

### **Compute-bound vs Memory-bound**

運算效能依「計算 vs 記憶體存取」平衡被分成兩類：

| 類型                | 主導時間因素     | 特徵             | 範例                                                                          |
| ----------------- | ---------- | -------------- | --------------------------------------------------------------------------- |
| **Compute-bound** | 計算量（FLOPs） | memory cost 很小 | 大維度 matmul、conv（大量 channel）                                                 |
| **Memory-bound**  | HBM 存取次數   | 計算成本很低但 IO 大   | elementwise（dropout, activation）、reduction（softmax, batch norm, layer norm） |

* 衡量指標：**arithmetic intensity** = （運算次數）/（讀寫記憶體 byte 數）
* arithmetic intensity 越低 → 越容易 **memory-bound**
* **許多 Transformer 模組（尤其 softmax 與 elementwise）是 memory-bound**

### **Kernel Fusion**

* **概念**：若多個 operation 作用在同一資料，將它們合併到同一 kernel → 只需從 HBM 讀一次，而非多次
* compilers（如 PyTorch JIT, XLA）已能自動 fuse 多個 elementwise ops
* **限制**：在 *訓練* 中，backward 需要中間值，因此 intermediate 仍要寫回 HBM → kernel fusion

### Standard Attention Implementation
給定：

| 符號                                    | 尺寸                    | 意義           |
| ------------------------------------- | --------------------- | ------------ |
| $Q, K, V \in \mathbb{R}^{N \times d}$ | N = 序列長度，d = head dim | attention 輸入 |
| $O \in \mathbb{R}^{N \times d}$      | output                | 注意力結果        |

標準 attention公式：

$$
S = QK^\top \in \mathbb{R}^{N \times N}, \quad
P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, \quad
O = PV
$$
* softmax **逐 row** 計算
* 會產生兩個 (O(N^2)) 級別中間矩陣：**注意力分數 (S)** 與 **attention 機率矩陣 (P)**

#### **主要效能問題**

* 這些矩陣全部 **materialize 到 HBM → 需 (O(N^2)) memory**
* 在實務中常常 (N \gg d)（例如 GPT2: (N=1024, d=64)）
* 計算過程中包含多個 memory-bound operations（如 softmax、mask、dropout）
  → **大量 HBM 存取**
* 為 $P$ 或 $S$ 的 elementwise ops（mask / dropout / scaling）雖可 fuse
* 但 **仍要把 $S$ 與 $P$ 寫回 HBM**
  → 本質仍是 $O(N^2)$ IO，瓶頸不會消失
![image](https://hackmd.io/_uploads/HJd9Qhi0lx.png)
  

## Algorithm, Analysis, and Extensions
#### **目標與核心成果**
* 本節提出一個 **能計算 exact attention、但需更少 HBM 存取與記憶體的演算法**
* 降低 IO → **速度更快、記憶體更省（backward也不需儲存 $O(N^2)$ 中間矩陣）**
* 內容包含：
  1. forward 的高效率 attention 演算法
  2. IO complexity 分析（與 standard attention 比較）
  3. block-sparse attention 擴展
* 本段先聚焦 **forward pass**（backward 置於 appendix）

### **實現**

作者為達到「sub-quadratic HBM access」採用兩項既有但少被 attention 正式結合的概念：

| 技術                    | 設計目的                                              |
| --------------------- | ------------------------------------------------- |
| **Tiling（分塊）**        | 逐區塊計算 softmax 以避免 materialize $N \times N$ matrix |
| **Recomputation（重算）** | backward 時不儲存 $S$ 與 $P$，僅存 softmax 統計量            |

### **Tiling 的核心想法（Forward）**

* 將 $Q, K, V$ **切成多個 block**（row-block 與 column-block）
* 每次只將少量 block **從 HBM 載入 SRAM → 計算 → 更新輸出**
* softmax 被拆成 block-wise 累積計算，需要額外維護統計量：

| 需累積統計                       | 作用                 |
| --------------------------- | ------------------ |
| $m$（row-wise max）           | 用來穩定 softmax       |
| $\ell$（row-wise sum of exp） | 用來完成 normalization |

* 可根據 softmax 性質做到 **增量式 aggregation**，因此最終結果與一次算完整 softmax 相同

### Recomputation
* backward 通常需要 $\mathbf{S}$ 與 $\mathbf{P}$（兩者為 $O(N^2)$ matrix）
* FlashAttention 的做法：

  * forward **只儲存 $O$、$m$、$\ell$**（皆為 $O(N)$ 量級）
  * backward 時再 **block-wise 重算 $S$ 與 $P$**

| 結果                                 | 意義                |
| ---------------------------------- | ----------------- |
| 不需儲存 $N^2$ 中間矩陣                    | 記憶體降至線性           |
| 以少量 recompute FLOPs 換取大量 HBM IO 減量 | backward 反而更快而非更慢 |

### **Kernel Fusion**

* 因為 tiling 讓所有動作都在 block 中進行
* 可用 **單一 CUDA kernel** 完成：
  * matmul → softmax → masking/dropout → matmul
* → **避免多次讀寫 HBM**，進一步降低 IO cost
![image](https://hackmd.io/_uploads/S1_Mv3jAle.png)


### IO Complexity Analysis
* 分析 FlashAttention 的 **HBM IO 複雜度**
* 與 standard attention 做理論比較
* 並提出 **下界證明（lower bound）**，說明 **沒有 exact attention 能在所有 SRAM 大小下比 FlashAttention 更快（在 IO 次數上）**

### **主要理論結論（Theorem 2）**

令：

| 符號  | 含義                           |
| --- | ---------------------------- |
| $N$ | sequence length              |
| $d$ | head dimension               |
| $M$ | SRAM 容量（假設 $d \le M \le Nd$） |

則：

| 演算法                    | HBM IO 複雜度                             |
| ---------------------- | -------------------------------------- |
| **Standard Attention** | $\Theta(Nd + N^2$)                     |
| **FlashAttention**     | $\Theta\left(\frac{N^2 d^2}{M}\right)$ |

→ 在一般情況（e.g., $d = 64 \sim 128$, $M \approx 100KB$）下：

* $d^2 \ll M$
* 因此 **FlashAttention 的 IO 次數遠少於 standard attention**
| 成果                    | 原因        |
| --------------------- | --------- |
| Wall-clock runtime 更快 | 因 IO 減少   |
| 記憶體用量更小               | 不需寫回大中間矩陣 |

### **證明直覺（Proof Sketch）**
FlashAttention 的 IO 來源可由 block 設計推導：
| 步驟                                             | 來源 |
| ---------------------------------------------- | -- |
| 每次載入 $K_j, V_j$ block，可放入 SRAM 大小 $\Theta(M)$  |    |
| 對每個 $K_j, V_j$ block，需要 iterate 所有 $Q_i$ block |    |
| 因此需 **$\Theta(N d / M)$ 次 pass over Q**        |    |
| 每次 pass 載入 $\Theta(N d)$ 量資料                   |    |
| → 得到 **$\Theta(N^2 d^2 / M)$ HBM loads**       |    |

Backward 的 IO 複雜度與 forward 相同（同為 $\Theta(N^2 d^2 / M)$）


### **Lower Bound（Proposition 3）**

命題內容：

* **不存在一種 exact attention，可以在所有 $M \in [d, Nd]$ 範圍內達到 $o(N^2 d^2 / M)$ 的 HBM IO 複雜度**
* 直覺來源：

  * 若 $M = \Theta(Nd)$（也就是 SRAM 可一次容納全部 Q/K/V）
  * 則 IO 下界為 $\Omega(Nd)$
  * 因此 $\Omega(N^2 d^2 / M)$ 即為不可突破的 IO 極限
* 此下界符合 streaming 理論中常見的 memory–IO tradeoff 設計

---

### **實驗驗證對 IO 的依賴性**

FlashAttention 的 runtime 觀察：

| 實驗結果                                                                   | 意義                     |
| ---------------------------------------------------------------------- | ---------------------- |
| 即使 FLOPs 較高（因 backward recomputation），FlashAttention 仍比標準 attention 更快 | → IO 是主導 runtime 的關鍵因素 |
| 改變 block size $B_c$ → IO 減少則 runtime 下降                                | → 直接驗證 IO 決定速度         |
| 當 block size 大到某程度後，瓶頸轉為 compute                                       | → SRAM 容量限制 block size |

## Experiments
驗證 FlashAttention（含 block-sparse 版本）是否在 **(1)訓練速度、(2)模型品質、(3)注意力運算效能** 上明顯優於現有方法。
| 評估面向                       | FlashAttention的結果                                                                          |
| -------------------------- | ------------------------------------------------------------------------------------------ |
| **Training Speed**         | BERT 速度 +15%，GPT-2 最多 **3×** 加速，LRA **2.4×** 加速                                            |
| **Model Quality**          | 能訓練更長序列 → GPT-2 perplexity 改善 **0.7**、長文本分類提升 **6.4**、首個 Transformer 能通過 Path-X / Path-256 |
| **Benchmarking Attention** | runtime 最快、記憶體最省，block-sparse 版本可線性 scaling 到 64K seq                                      |

## **4.2 Faster Models with FlashAttention**
### **(A) BERT Training Speed**
![image](https://hackmd.io/_uploads/rk7WR3s0gx.png)
### **(B) GPT-2 Training Speed**
![image](https://hackmd.io/_uploads/HJ_SR2jCle.png)
| GPT-2 系統                 | Speed                                                  |
| ------------------------ | ------------------------------------------------------ |
| HuggingFace              | baseline                                               |
| Megatron-LM              | ~2× faster                                             |
| **FlashAttention（ours）** | **up to 3.5× faster（small）** / **3.0× faster（medium）** |

* Perplexity 完全一致 → 表示 **僅加速，無品質損失**

### **(C) Long-Range Arena (LRA) Benchmark**

* 序列長度 1K–4K
* 任務：ListOps, Text, Retrieval, Image, Pathfinder
![image](https://hackmd.io/_uploads/HJT41aiClx.png)

### **(A) Language Modeling — Long Context GPT-2**

* FlashAttention 允許將 GPT-2 context length 提高到 **4K（原 1K 的 4×）**
* 仍比 Megatron-LM（context=1K）更快
* 成果（Table 4.4）：
![image](https://hackmd.io/_uploads/Hkn_ypsAlg.png)

| Model              | ctx len | ppl      | speed           |
| ------------------ | ------- | -------- | --------------- |
| Megatron-LM        | 1K      | 18.2     | 1.0×            |
| **FlashAttention** | 4K      | **17.5** | **1.3× faster** |

→ 高速 + 更佳 perplexity（改善 0.7）

### **4.4 Benchmarking Attention (Runtime & Memory)**

環境：A100 40GB，帶 dropout + padding mask

### **Runtime**

* FlashAttention：對 exact attention **最多 3× 加速**
* Approximate attention 雖然具線性 scaling，但 **在 512–1024 之前 FlashAttention 仍更快**
* Block-sparse FlashAttention：**全序列範圍下最快**

### **Memory Footprint**
![image](https://hackmd.io/_uploads/SJeTJpo0xl.png)

* FlashAttention & Block-sparse：

  * 記憶體 **線性成長**
  * 比 exact attention **最多節省 20×**
  * 甚至比 approximate attention（如 Performer、Reformer）更省
  * 能跑到 **64K seq**，其他 baseline 多數 OOM


## Limitations and Future Directions
### **1. Limitation: CUDA 工程成本高（Compiling to CUDA）**

* 目前 FlashAttention 的 IO-aware 設計 **必須手寫 CUDA kernel**

* 缺點：

  | 問題            | 說明                         |
  | ------------- | -------------------------- |
  | 開發難度高         | 需要以比 PyTorch 低階許多的語言撰寫與最佳化 |
  | 工程成本大         | 每種注意力變種都需重寫 kernel         |
  | 跨 GPU 架構可移植性差 | 不同 GPU、不同記憶體層級架構 → 需維護多版本  |

* **未來方向**：開發一種 **高階 → IO-aware CUDA 的自動轉譯方法**
  類似 Halide（在影像領域提供高層 DSL + 自動記憶體最佳化）

---

### **2. Future Direction: IO-Aware Beyond Attention**

* Attention 是 Transformer **最記憶體密集的 Layer**
* 但「IO-bound 問題」**在所有層都存在**（每一層都需頻繁讀寫 HBM）
* 作者認為 IO-aware 的概念：

  * **並非只適用於 attention**
  * 可擴展至 **更多 Deep Learning modules**
* 論文附錄另有展望，包括：
  multi-head attention variants、kernel regression、block-sparse matmul 等

---

### **3. Future Direction: Multi-GPU IO-aware Strategies**

* FlashAttention 在 **單 GPU 上已達到 IO-optimal（up to constants）**
* 但 Transformer 訓練常使用 **multi-GPU / distributed training**
* 未來研究方向：

  | 多 GPU 新挑戰        | 說明                        |
  | ---------------- | ------------------------- |
  | GPU-to-GPU 通訊    | 引入額外 IO 層級（NVLink / PCIe） |
  | 新的 IO bottleneck | 必須同時最佳化 GPU內 + GPU間 訪存    |
* 可能與 parallel algorithm 研究結合（文中引用 Recht 2013）