20251011 筆記,內容可能有錯,請參考來源影片。
[李宏毅機器學習2022影片](https://www.youtube.com/playlist?list=PLJV_el3uVTsPM2mM-OQzJXziCGJa8nJL8)
今天影片內容
[【機器學習 2022】各式各樣神奇的自注意力機制 (Self-attention) 變型](https://www.youtube.com/watch?v=yHoAq1IT_og&list=PLJV_el3uVTsPM2mM-OQzJXziCGJa8nJL8&index=5)
### 【機器學習 2022】各式各樣神奇的自注意力機制 (Self-attention) 變型 大綱
#### **一、 自注意力機制 (Self-Attention) 的基礎與挑戰**
* **目的與機制**: 處理序列輸入,透過 Q(Query)與 K(Key)的點積計算 N x N 的注意力矩陣,對 V(Value)加權求和。
* **主要痛點**: 運算量為 **O(N²)**,當序列長度 N 非常長(如影像處理)時,計算難以負荷。
* **加速目標**: 將時間複雜度從 O(N²) 降低。
#### **二、 加速自注意力機制的方法 (O(N²) ➝ O(N) 或更低)**
#### **A. 減少注意力矩陣的計算量 (稀疏化)**
1. **基於人類設計的注意力模式 (Human-designed)**
* **局部注意力 (Local Attention)**:只關注左右鄰居的小範圍資訊。
* **跨步注意力 (Strided Attention)**:間隔性地關注更遠的鄰居,擴大receptive field。
* **全域注意力 (Global Attention)**:引入「特殊全域符號」作為資訊樞紐,收集並分發全局資訊。
* **結合模式 (Hybrid)**:多個注意力頭 (Heads) 執行不同模式。
* *實例*: **Longformer** (局部 + 跨步 + 全域), **Big Bird** (Longformer + 隨機)。
2. **資料驅動的注意力稀疏化 (Data-driven Pruning)**
* **基於分群的注意力 (Clustering-based)**:對 Q 和 K 進行快速分群,只計算「同群」之間的注意力值。
* *實例*: **Reformer**, **Routing Transformer**。
* **學習式注意力模式 (Learned Patterns)**:使用一個額外的「稀疏網路」來學習和生成 N x N 的二元稀疏矩陣。
#### **B. 減少注意力矩陣的維度 (Low-Rank Approximation)**
* **低秩注意力 (Low-Rank Attention / Linformer)**:觀察到注意力矩陣是低秩的,存在冗餘。
* **機制**: 只從 N 個 Key 向量中選取 **K** 個有代表性的 Key 向量(K << N),將運算量降為 N x K。
#### **C. 改變矩陣乘法順序 (Matrix Multiplication Reordering)**
* **Linear Transformer**:
* **核心思想**: 利用線性代數特性,在沒有 Softmax 的情況下,將運算順序從 $O = V \cdot (K^T \cdot Q)$ 改為 $O = (V \cdot K^T) \cdot Q$。
* **運算量**: 從 $O(D \cdot N^2)$ 降為 $O(D^2 \cdot N)$,實現與 N 成**線性正比 (O(N))**。
* **Softmax 處理**: 透過將點積 $\exp(Q \cdot K)$ 拆解為 $\phi(Q) \cdot \phi(K)$ 的形式來實現相同的 O(N) 運算結果。
#### **D. 重新思考注意力機制 (Rethinking Attention)**
* **Synthesizer**: 挑戰 Q-K 互動的必要性,將注意力矩陣本身視為網路參數(固定權重)。
* **無注意力機制 (Attention-Free)**:嘗試完全捨棄注意力,改用全連接網路 (FFN) 或 MLP 來處理序列。
#### **三、 總結與性能比較 (Summary)**
* **評估標準**: LRA Score (性能,越高越好)、Speed (速度,越右越快)、圓圈大小 (記憶體用量,越小越好)。
* **觀察重點**:
* **Transformer** (原始版) 為基準。
* **Local Attention** 速度快但性能差。
* **Linformer** (低秩) 和 **Performer / Linear Transformer** (改變順序) 在速度上有顯著提升,但性能相較於原始 Transformer 略有下降。
* **Big Bird** 在性能上略優於原始 Transformer。
#### **一、 自注意力機制的基礎與挑戰**

1. **目的與機制**:
* **目的**:處理序列型態的輸入 (input sequence),例如句子中的詞彙、語音訊號中的幀、圖片中的像素。
* **機制**:
* 針對輸入序列(長度為 N),產生 N 個 Key (K) 向量、N 個 Query (Q) 向量和 N 個 Value (V) 向量。
* Q 向量與 K 向量之間進行點積 (dot product) 計算,形成一個 N x N 的「注意力矩陣 (Attention Matrix)」。
* 根據注意力矩陣對 V 向量進行加權求和 (weighted sum)。
* **運算量**:注意力矩陣的計算量與序列長度 N 的平方成正比 (O(N^2))。
2. **主要痛點**:
* N x N 的注意力矩陣運算量非常大,尤其當輸入序列 N 非常長時,會變得難以承受。
* 這種 O(N^2) 的計算量是加速自注意力機制的首要目標。
3. **加速的時機與背景**:

* 自注意力機制通常是大型網路(如 Transformer)的一部分。
* 只有當輸入序列 N **非常長**,且自注意力機制在整個網路的運算中佔主導地位時,加速其運算才有顯著幫助。
* **範例**:影像處理 (Image Processing)。
* 若處理 256x256 的圖片,每個像素視為一個單位,則 N = 256 * 256。
* 注意力機制計算量將是 (256 * 256)^2 = 256^4,運算量極為驚人。
* 這也是為什麼許多自注意力變形技術最早應用於影像處理領域。
#### **二、 加速自注意力機制的方法:減少注意力矩陣的計算量**
人們主要從兩個方向來減少 N x N 注意力矩陣的計算:

**A. 基於人類理解設計注意力模式 (Human-designed Attention Patterns)**
此類方法透過先驗知識,直接設定注意力矩陣中的某些位置為零(即不計算),只計算重要的部分。

1. **局部注意力 (Local Attention / Truncated Attention)**:
* **想法**:在計算注意力時,每個位置只關注其「左右鄰居」的資訊。
* **機制**:直接將注意力矩陣中,超出預設小範圍以外的位置設為零(圖中灰色部分不計算)。
* **缺點**:
* 每次只看到小範圍內的資訊。
* 與卷積神經網路 (CNN) 類似,可能無法捕捉遠距離的關聯,效果不一定好。
2. **跨步注意力 (Strided Attention)**:

* **想法**:類似局部注意力,但可以跳過一些位置,關注更遠的鄰居,以擴大感受。
* **機制**:注意力矩陣中,計算位置與位置之間存在間隔(例如跳兩格看第三格之外的資訊)。跨步的間隔可以自訂。
3. **全域注意力 (Global Attention)**:

* **想法**:為了捕捉整個序列的全局資訊。
* **機制**:
* 在原始序列中添加一或多個「特殊全域符號 (Special Global Tokens)」。
* 這些特殊符號會「關注 (attend)」到序列中的所有其他符號,以收集整個序列的資訊。
* 同時,序列中的所有其他符號也會「關注」這些特殊符號,以獲取收集到的全局資訊。
* **比喻**:特殊符號如同「里長」,負責串聯所有資訊,而其他符號不直接互動,但都認識里長並透過里長傳遞資訊。
* **實現方式**:
* 從現有輸入序列中選取某些符號作為特殊符號 (e.g., BERT 中的 `[CLS]` 符號,句號)。
* 額外加入新的符號作為特殊符號。
* **注意力矩陣**:特殊符號對所有符號有注意力連結,所有符號對特殊符號有注意力連結。但**其他非特殊符號之間則沒有直接注意力連結**(矩陣中對應位置為零)。
4. **結合多種注意力模式 (Hybrid Approaches)**:

* **理念**:「小孩子才做選擇,真正好的結果是全部都要」。
* **機制**:在一個自注意力模組中設置多個注意力頭 (multiple heads),每個頭執行不同的注意力模式(例如,有些頭做局部注意力,有些做跨步注意力,有些做全域注意力)。
* **實例**:
* **Longformer**:結合了局部注意力 (Local Attention) + 跨步注意力 (Strided Attention) + 全域注意力 (Global Attention)。
* **Big Bird**:在 Longformer 的基礎上,額外加入了「隨機注意力 (Random Attention)」,即隨機選擇一些位置進行注意力計算。
**B. 資料驅動的注意力稀疏化 (Data-driven Attention Pruning)**

此類方法不預先指定注意力模式,而是嘗試讓模型學習或估計哪些注意力值是重要的。
1. **基於分群的注意力 (Clustering-based Attention)**:


* **想法**:注意力矩陣中可能有很多值很小,接近零的位置,可以直接設為零,而不影響結果。透過分群來估計哪些 Q 和 K 之間會有較大的注意力值。
* **機制**:
* 對 Query (Q) 和 Key (K) 向量進行快速分群 (clustering)。
* 只計算屬於「同一群」的 Q 和 K 之間的注意力值。
* 不同群之間的 Q 和 K 的注意力值直接設為零。
* **挑戰**:分群本身也可能耗費大量運算量。
* **解決方案**:採用快速且近似的分群方法。
* **實例**:Reformer 和 Routing Transformer 採用了不同但類似的分群方法。
2. **學習式注意力模式 (Learned Attention Patterns / Sparsity Network)**:

* **想法**:透過另一個神經網路來決定哪些位置需要計算注意力,而不是依賴人類的先驗知識。
* **機制**:
* 輸入序列先經過一個「稀疏網路 (Sparsity Network)」,生成一個 N x N 的二元矩陣 (Binary Matrix),其中「1」代表需要計算注意力,「0」代表直接設為零。
* 這個從連續輸出轉化為二元矩陣的過程是可微分的,使得稀疏網路可以與主網路一同訓練。
* **細節**:為了避免稀疏網路本身的計算量過大,多個輸入區塊可能會「共用」同一個稀疏模式(例如,稀疏網路產生一個較小的稀疏矩陣,再放大給多個區塊使用)。
* **後續研究**:Sympathizer 提出可以直接將注意力矩陣作為網路的參數,而無需兩步生成。
**C. 減少注意力矩陣的維度 (Low-Rank Approximation)**
此類方法基於注意力矩陣的特性,認為其存在冗餘資訊,可以簡化。


1. **低秩注意力 (Low-Rank Attention / Lformer)**:

* **觀察**:研究發現注意力矩陣通常是「低秩矩陣 (low-rank matrix)」,即許多 columns 是重複或相互線性相關的,包含了冗餘資訊。
* **想法**:不需要計算完整的 N x N 矩陣,而是計算一個較小的 N x K 矩陣,其中 K 遠小於 N。
* **機制**:
* 從 N 個 Key 向量中,只選取 K 個「有代表性 (representative)」的 Key 向量。
* 這樣注意力矩陣就從 N x N 變為 N x K。
* **為何不減少 Query 數量?**:
* Query 數量決定了輸出序列的長度。
* 對於需要為每個輸入位置都輸出一個標籤的任務(例如語音辨識、音素分類),減少 Query 數量會導致輸出長度不符,產生問題。
* 但對於只輸出一個總結性標籤的任務,則可能可以減少 Query 數量。
* **如何選擇代表性 Key 向量**:
* **Compressive Attention**:使用 CNN 掃描長序列的 Key 向量,壓縮成短序列作為代表性 Key 向量。
* **Linformer**:將 Key 向量集合視為一個 D x N 矩陣,乘以一個 N x K 矩陣,得到 D x K 矩陣。這 K 個代表性 Key 向量是原始 N 個 Key 向量的線性組合。
**D. 改變矩陣乘法順序 (Matrix Multiplication Reordering)**
這個方法基於線性代數的特性,不改變計算結果,但大幅改變計算成本。



1. **Linear Transformer**:
* **前置簡化**:假設暫時沒有 Softmax 函式,自注意力機制的輸出 O 可以表示為 O = V * K^T * Q。
* **傳統運算順序**:先計算 (K^T * Q),再乘以 V。
* K^T (N x D) 乘以 Q (D x N) 得到一個 N x N 的矩陣。
* 這個 N x N 的矩陣再乘以 V (D x N)。
* **運算量**:(N * D * N) + (N * N * D) = (D + D) * N^2,與 N 的平方成正比。
* **線性化運算順序**:先計算 V * K^T,再乘以 Q。
* V (D x N) 乘以 K^T (N x D) 得到一個 D x D 的矩陣。
* 這個 D x D 的矩陣再乘以 Q (D x N)。
* **運算量**:(D * N * D) + (D * D * N) = (D + D) * D * N = 2 * D^2 * N,與 N 成**線性正比**。
* 由於序列長度 N 通常遠大於向量維度 D,因此 2 * D^2 * N 遠小於 2 * D * N^2,實現了大幅加速。
沒有 Softmax 數學推導部分先跳過,有興趣可以看李宏毅影片
* **重新引入 Softmax**:



* Softmax (exp) 函式無法直接交換順序。
* **解決方案**:將 Q 與 K 的點積 `exp(Q dot K)` 拆解為 `phi(Q) dot phi(K)` 的形式,其中 `phi` 是一個非線性轉換函式。
* **計算流程**:
1. 對所有 Key (K) 向量和 Value (V) 向量,先進行互動處理。將所有 K 透過 `phi` 函式變成 phi(K),然後將 phi(K) 和 V 相乘,產生 M 個「模板向量 (Template Vectors)」。這個步驟只需計算一次。
2. 對於每個 Query (Qi) 向量,透過 `phi` 函式變成 phi(Qi)。
3. phi(Qi) 再與之前計算好的 M 個模板向量進行 weighted sum ,得到最終輸出 Bi。
4. Softmax 的分母部分也以類似的高效方式計算。
* **結果**:這種方法與傳統自注意力機制的計算結果**完全相同**,但運算量從 O(N^2) 降低到 O(N)。
* **不同之處**:Performer、Linear Transformer 等不同論文的主要區別在於其選擇的 `phi` 函式。
**E. 重新思考注意力機制 (Rethinking Attention)**
挑戰傳統自注意力機制的根本假設。

1. **Synthesizer**:
* **想法**:注意力矩陣不一定要由 Q 和 K 向量的互動產生。
* **機制**:
* 將注意力矩陣本身直接視為「神經網路的參數 (parameters)」。
* 這意味著注意力權重是固定的,不隨輸入序列變化而不同。
* **結果**:儘管注意力矩陣不再是動態的,但實際表現並沒有顯著下降。這引發了對注意力機制核心價值(必須是動態適應序列)的重新思考。
2. **無注意力機制 (Attention-Free Methods)**:

* **想法**:能否完全捨棄注意力機制,就像過去捨棄循環神經網路 (Recurrent Networks) 一樣?。
* **方法**:直接使用全連接網路 (Fully Connected Network, FFN) 或多層感知器 (MLP) 來處理序列。
#### **三、 Summary**

* **目的**:評估不同自注意力方法的性能。
* **圖表說明**:
* **Y 軸**:LRA Score,越高越好。
* **X 軸**:Speed,每秒處理的序列數量,越往右越快。
* **圓圈大小**:代表方法所需的記憶體 (Memory) 量,越大代表記憶體使用越多。
* **各方法表現 (從圖中觀察)**:
* **Transformer** (藍色圈):原始自注意力機制,作為基準點。
* **Big Bird**:性能略優於 Transformer,速度稍快。
* **Synthesizer**:性能略有下降,但速度明顯快於 Transformer。
* **Reformer**:性能不一定比 Transformer 好。
* **Local Attention**:速度非常快,但性能表現較差。
* **Sparsity Network / Sparse Transformer**:性能有所下降,但速度比原始 Transformer 快一個檔次。
* **Linformer**:性能有所下降,但速度相對 Transformer 快得多。透過選擇代表性 Key 向量來加速。
* **Linear Transformer**:速度最快的方法,性能相較於 Transformer 也有所下降。透過改變矩陣乘法順序實現線性時間複雜度。
---
其他課程
[【機器學習 2022】01~04 機器學習原理介紹](https://hackmd.io/@JuitingChen/Sk_VtIJaeg)
[【機器學習 2022】05 各式各樣神奇的自注意力機制](https://hackmd.io/@JuitingChen/rJeNpFIpxl)
[【機器學習 2022】06 如何有效的使用自監督式模型](https://hackmd.io/@JuitingChen/BJXeLKD6xx)
[【機器學習 2022】07 語音與影像上的神奇自監督式學習](https://hackmd.io/@JuitingChen/r1q-N1uagg)
[【機器學習 2022】08-09 自然語言處理上的對抗式攻擊-1](https://hackmd.io/@JuitingChen/B14i61uTxx)
[【機器學習 2022】10~11 自然語言處理上的對抗式攻擊-2](https://hackmd.io/@JuitingChen/HkLRoFOTgx)
[【機器學習 2022】12~13 Bert 三個故事 和 各種 Meta Learning 用法](https://hackmd.io/@JuitingChen/HyjfTptTel)