20251011 筆記,內容可能有錯,請參考來源影片。 [李宏毅機器學習2022影片](https://www.youtube.com/playlist?list=PLJV_el3uVTsPM2mM-OQzJXziCGJa8nJL8) 今天影片內容 [【機器學習 2022】各式各樣神奇的自注意力機制 (Self-attention) 變型](https://www.youtube.com/watch?v=yHoAq1IT_og&list=PLJV_el3uVTsPM2mM-OQzJXziCGJa8nJL8&index=5) ### 【機器學習 2022】各式各樣神奇的自注意力機制 (Self-attention) 變型 大綱 #### **一、 自注意力機制 (Self-Attention) 的基礎與挑戰** * **目的與機制**: 處理序列輸入,透過 Q(Query)與 K(Key)的點積計算 N x N 的注意力矩陣,對 V(Value)加權求和。 * **主要痛點**: 運算量為 **O(N²)**,當序列長度 N 非常長(如影像處理)時,計算難以負荷。 * **加速目標**: 將時間複雜度從 O(N²) 降低。 #### **二、 加速自注意力機制的方法 (O(N²) ➝ O(N) 或更低)** #### **A. 減少注意力矩陣的計算量 (稀疏化)** 1. **基於人類設計的注意力模式 (Human-designed)** * **局部注意力 (Local Attention)**:只關注左右鄰居的小範圍資訊。 * **跨步注意力 (Strided Attention)**:間隔性地關注更遠的鄰居,擴大receptive field。 * **全域注意力 (Global Attention)**:引入「特殊全域符號」作為資訊樞紐,收集並分發全局資訊。 * **結合模式 (Hybrid)**:多個注意力頭 (Heads) 執行不同模式。 * *實例*: **Longformer** (局部 + 跨步 + 全域), **Big Bird** (Longformer + 隨機)。 2. **資料驅動的注意力稀疏化 (Data-driven Pruning)** * **基於分群的注意力 (Clustering-based)**:對 Q 和 K 進行快速分群,只計算「同群」之間的注意力值。 * *實例*: **Reformer**, **Routing Transformer**。 * **學習式注意力模式 (Learned Patterns)**:使用一個額外的「稀疏網路」來學習和生成 N x N 的二元稀疏矩陣。 #### **B. 減少注意力矩陣的維度 (Low-Rank Approximation)** * **低秩注意力 (Low-Rank Attention / Linformer)**:觀察到注意力矩陣是低秩的,存在冗餘。 * **機制**: 只從 N 個 Key 向量中選取 **K** 個有代表性的 Key 向量(K << N),將運算量降為 N x K。 #### **C. 改變矩陣乘法順序 (Matrix Multiplication Reordering)** * **Linear Transformer**: * **核心思想**: 利用線性代數特性,在沒有 Softmax 的情況下,將運算順序從 $O = V \cdot (K^T \cdot Q)$ 改為 $O = (V \cdot K^T) \cdot Q$。 * **運算量**: 從 $O(D \cdot N^2)$ 降為 $O(D^2 \cdot N)$,實現與 N 成**線性正比 (O(N))**。 * **Softmax 處理**: 透過將點積 $\exp(Q \cdot K)$ 拆解為 $\phi(Q) \cdot \phi(K)$ 的形式來實現相同的 O(N) 運算結果。 #### **D. 重新思考注意力機制 (Rethinking Attention)** * **Synthesizer**: 挑戰 Q-K 互動的必要性,將注意力矩陣本身視為網路參數(固定權重)。 * **無注意力機制 (Attention-Free)**:嘗試完全捨棄注意力,改用全連接網路 (FFN) 或 MLP 來處理序列。 #### **三、 總結與性能比較 (Summary)** * **評估標準**: LRA Score (性能,越高越好)、Speed (速度,越右越快)、圓圈大小 (記憶體用量,越小越好)。 * **觀察重點**: * **Transformer** (原始版) 為基準。 * **Local Attention** 速度快但性能差。 * **Linformer** (低秩) 和 **Performer / Linear Transformer** (改變順序) 在速度上有顯著提升,但性能相較於原始 Transformer 略有下降。 * **Big Bird** 在性能上略優於原始 Transformer。 #### **一、 自注意力機制的基礎與挑戰**  1. **目的與機制**: * **目的**:處理序列型態的輸入 (input sequence),例如句子中的詞彙、語音訊號中的幀、圖片中的像素。 * **機制**: * 針對輸入序列(長度為 N),產生 N 個 Key (K) 向量、N 個 Query (Q) 向量和 N 個 Value (V) 向量。 * Q 向量與 K 向量之間進行點積 (dot product) 計算,形成一個 N x N 的「注意力矩陣 (Attention Matrix)」。 * 根據注意力矩陣對 V 向量進行加權求和 (weighted sum)。 * **運算量**:注意力矩陣的計算量與序列長度 N 的平方成正比 (O(N^2))。 2. **主要痛點**: * N x N 的注意力矩陣運算量非常大,尤其當輸入序列 N 非常長時,會變得難以承受。 * 這種 O(N^2) 的計算量是加速自注意力機制的首要目標。 3. **加速的時機與背景**:  * 自注意力機制通常是大型網路(如 Transformer)的一部分。 * 只有當輸入序列 N **非常長**,且自注意力機制在整個網路的運算中佔主導地位時,加速其運算才有顯著幫助。 * **範例**:影像處理 (Image Processing)。 * 若處理 256x256 的圖片,每個像素視為一個單位,則 N = 256 * 256。 * 注意力機制計算量將是 (256 * 256)^2 = 256^4,運算量極為驚人。 * 這也是為什麼許多自注意力變形技術最早應用於影像處理領域。 #### **二、 加速自注意力機制的方法:減少注意力矩陣的計算量** 人們主要從兩個方向來減少 N x N 注意力矩陣的計算:  **A. 基於人類理解設計注意力模式 (Human-designed Attention Patterns)** 此類方法透過先驗知識,直接設定注意力矩陣中的某些位置為零(即不計算),只計算重要的部分。  1. **局部注意力 (Local Attention / Truncated Attention)**: * **想法**:在計算注意力時,每個位置只關注其「左右鄰居」的資訊。 * **機制**:直接將注意力矩陣中,超出預設小範圍以外的位置設為零(圖中灰色部分不計算)。 * **缺點**: * 每次只看到小範圍內的資訊。 * 與卷積神經網路 (CNN) 類似,可能無法捕捉遠距離的關聯,效果不一定好。 2. **跨步注意力 (Strided Attention)**:  * **想法**:類似局部注意力,但可以跳過一些位置,關注更遠的鄰居,以擴大感受。 * **機制**:注意力矩陣中,計算位置與位置之間存在間隔(例如跳兩格看第三格之外的資訊)。跨步的間隔可以自訂。 3. **全域注意力 (Global Attention)**:  * **想法**:為了捕捉整個序列的全局資訊。 * **機制**: * 在原始序列中添加一或多個「特殊全域符號 (Special Global Tokens)」。 * 這些特殊符號會「關注 (attend)」到序列中的所有其他符號,以收集整個序列的資訊。 * 同時,序列中的所有其他符號也會「關注」這些特殊符號,以獲取收集到的全局資訊。 * **比喻**:特殊符號如同「里長」,負責串聯所有資訊,而其他符號不直接互動,但都認識里長並透過里長傳遞資訊。 * **實現方式**: * 從現有輸入序列中選取某些符號作為特殊符號 (e.g., BERT 中的 `[CLS]` 符號,句號)。 * 額外加入新的符號作為特殊符號。 * **注意力矩陣**:特殊符號對所有符號有注意力連結,所有符號對特殊符號有注意力連結。但**其他非特殊符號之間則沒有直接注意力連結**(矩陣中對應位置為零)。 4. **結合多種注意力模式 (Hybrid Approaches)**:  * **理念**:「小孩子才做選擇,真正好的結果是全部都要」。 * **機制**:在一個自注意力模組中設置多個注意力頭 (multiple heads),每個頭執行不同的注意力模式(例如,有些頭做局部注意力,有些做跨步注意力,有些做全域注意力)。 * **實例**: * **Longformer**:結合了局部注意力 (Local Attention) + 跨步注意力 (Strided Attention) + 全域注意力 (Global Attention)。 * **Big Bird**:在 Longformer 的基礎上,額外加入了「隨機注意力 (Random Attention)」,即隨機選擇一些位置進行注意力計算。 **B. 資料驅動的注意力稀疏化 (Data-driven Attention Pruning)**  此類方法不預先指定注意力模式,而是嘗試讓模型學習或估計哪些注意力值是重要的。 1. **基於分群的注意力 (Clustering-based Attention)**:   * **想法**:注意力矩陣中可能有很多值很小,接近零的位置,可以直接設為零,而不影響結果。透過分群來估計哪些 Q 和 K 之間會有較大的注意力值。 * **機制**: * 對 Query (Q) 和 Key (K) 向量進行快速分群 (clustering)。 * 只計算屬於「同一群」的 Q 和 K 之間的注意力值。 * 不同群之間的 Q 和 K 的注意力值直接設為零。 * **挑戰**:分群本身也可能耗費大量運算量。 * **解決方案**:採用快速且近似的分群方法。 * **實例**:Reformer 和 Routing Transformer 採用了不同但類似的分群方法。 2. **學習式注意力模式 (Learned Attention Patterns / Sparsity Network)**:  * **想法**:透過另一個神經網路來決定哪些位置需要計算注意力,而不是依賴人類的先驗知識。 * **機制**: * 輸入序列先經過一個「稀疏網路 (Sparsity Network)」,生成一個 N x N 的二元矩陣 (Binary Matrix),其中「1」代表需要計算注意力,「0」代表直接設為零。 * 這個從連續輸出轉化為二元矩陣的過程是可微分的,使得稀疏網路可以與主網路一同訓練。 * **細節**:為了避免稀疏網路本身的計算量過大,多個輸入區塊可能會「共用」同一個稀疏模式(例如,稀疏網路產生一個較小的稀疏矩陣,再放大給多個區塊使用)。 * **後續研究**:Sympathizer 提出可以直接將注意力矩陣作為網路的參數,而無需兩步生成。 **C. 減少注意力矩陣的維度 (Low-Rank Approximation)** 此類方法基於注意力矩陣的特性,認為其存在冗餘資訊,可以簡化。   1. **低秩注意力 (Low-Rank Attention / Lformer)**:  * **觀察**:研究發現注意力矩陣通常是「低秩矩陣 (low-rank matrix)」,即許多 columns 是重複或相互線性相關的,包含了冗餘資訊。 * **想法**:不需要計算完整的 N x N 矩陣,而是計算一個較小的 N x K 矩陣,其中 K 遠小於 N。 * **機制**: * 從 N 個 Key 向量中,只選取 K 個「有代表性 (representative)」的 Key 向量。 * 這樣注意力矩陣就從 N x N 變為 N x K。 * **為何不減少 Query 數量?**: * Query 數量決定了輸出序列的長度。 * 對於需要為每個輸入位置都輸出一個標籤的任務(例如語音辨識、音素分類),減少 Query 數量會導致輸出長度不符,產生問題。 * 但對於只輸出一個總結性標籤的任務,則可能可以減少 Query 數量。 * **如何選擇代表性 Key 向量**: * **Compressive Attention**:使用 CNN 掃描長序列的 Key 向量,壓縮成短序列作為代表性 Key 向量。 * **Linformer**:將 Key 向量集合視為一個 D x N 矩陣,乘以一個 N x K 矩陣,得到 D x K 矩陣。這 K 個代表性 Key 向量是原始 N 個 Key 向量的線性組合。 **D. 改變矩陣乘法順序 (Matrix Multiplication Reordering)** 這個方法基於線性代數的特性,不改變計算結果,但大幅改變計算成本。    1. **Linear Transformer**: * **前置簡化**:假設暫時沒有 Softmax 函式,自注意力機制的輸出 O 可以表示為 O = V * K^T * Q。 * **傳統運算順序**:先計算 (K^T * Q),再乘以 V。 * K^T (N x D) 乘以 Q (D x N) 得到一個 N x N 的矩陣。 * 這個 N x N 的矩陣再乘以 V (D x N)。 * **運算量**:(N * D * N) + (N * N * D) = (D + D) * N^2,與 N 的平方成正比。 * **線性化運算順序**:先計算 V * K^T,再乘以 Q。 * V (D x N) 乘以 K^T (N x D) 得到一個 D x D 的矩陣。 * 這個 D x D 的矩陣再乘以 Q (D x N)。 * **運算量**:(D * N * D) + (D * D * N) = (D + D) * D * N = 2 * D^2 * N,與 N 成**線性正比**。 * 由於序列長度 N 通常遠大於向量維度 D,因此 2 * D^2 * N 遠小於 2 * D * N^2,實現了大幅加速。 沒有 Softmax 數學推導部分先跳過,有興趣可以看李宏毅影片 * **重新引入 Softmax**:    * Softmax (exp) 函式無法直接交換順序。 * **解決方案**:將 Q 與 K 的點積 `exp(Q dot K)` 拆解為 `phi(Q) dot phi(K)` 的形式,其中 `phi` 是一個非線性轉換函式。 * **計算流程**: 1. 對所有 Key (K) 向量和 Value (V) 向量,先進行互動處理。將所有 K 透過 `phi` 函式變成 phi(K),然後將 phi(K) 和 V 相乘,產生 M 個「模板向量 (Template Vectors)」。這個步驟只需計算一次。 2. 對於每個 Query (Qi) 向量,透過 `phi` 函式變成 phi(Qi)。 3. phi(Qi) 再與之前計算好的 M 個模板向量進行 weighted sum ,得到最終輸出 Bi。 4. Softmax 的分母部分也以類似的高效方式計算。 * **結果**:這種方法與傳統自注意力機制的計算結果**完全相同**,但運算量從 O(N^2) 降低到 O(N)。 * **不同之處**:Performer、Linear Transformer 等不同論文的主要區別在於其選擇的 `phi` 函式。 **E. 重新思考注意力機制 (Rethinking Attention)** 挑戰傳統自注意力機制的根本假設。  1. **Synthesizer**: * **想法**:注意力矩陣不一定要由 Q 和 K 向量的互動產生。 * **機制**: * 將注意力矩陣本身直接視為「神經網路的參數 (parameters)」。 * 這意味著注意力權重是固定的,不隨輸入序列變化而不同。 * **結果**:儘管注意力矩陣不再是動態的,但實際表現並沒有顯著下降。這引發了對注意力機制核心價值(必須是動態適應序列)的重新思考。 2. **無注意力機制 (Attention-Free Methods)**:  * **想法**:能否完全捨棄注意力機制,就像過去捨棄循環神經網路 (Recurrent Networks) 一樣?。 * **方法**:直接使用全連接網路 (Fully Connected Network, FFN) 或多層感知器 (MLP) 來處理序列。 #### **三、 Summary**  * **目的**:評估不同自注意力方法的性能。 * **圖表說明**: * **Y 軸**:LRA Score,越高越好。 * **X 軸**:Speed,每秒處理的序列數量,越往右越快。 * **圓圈大小**:代表方法所需的記憶體 (Memory) 量,越大代表記憶體使用越多。 * **各方法表現 (從圖中觀察)**: * **Transformer** (藍色圈):原始自注意力機制,作為基準點。 * **Big Bird**:性能略優於 Transformer,速度稍快。 * **Synthesizer**:性能略有下降,但速度明顯快於 Transformer。 * **Reformer**:性能不一定比 Transformer 好。 * **Local Attention**:速度非常快,但性能表現較差。 * **Sparsity Network / Sparse Transformer**:性能有所下降,但速度比原始 Transformer 快一個檔次。 * **Linformer**:性能有所下降,但速度相對 Transformer 快得多。透過選擇代表性 Key 向量來加速。 * **Linear Transformer**:速度最快的方法,性能相較於 Transformer 也有所下降。透過改變矩陣乘法順序實現線性時間複雜度。 --- 其他課程 [【機器學習 2022】01~04 機器學習原理介紹](https://hackmd.io/@JuitingChen/Sk_VtIJaeg) [【機器學習 2022】05 各式各樣神奇的自注意力機制](https://hackmd.io/@JuitingChen/rJeNpFIpxl) [【機器學習 2022】06 如何有效的使用自監督式模型](https://hackmd.io/@JuitingChen/BJXeLKD6xx) [【機器學習 2022】07 語音與影像上的神奇自監督式學習](https://hackmd.io/@JuitingChen/r1q-N1uagg) [【機器學習 2022】08-09 自然語言處理上的對抗式攻擊-1](https://hackmd.io/@JuitingChen/B14i61uTxx) [【機器學習 2022】10~11 自然語言處理上的對抗式攻擊-2](https://hackmd.io/@JuitingChen/HkLRoFOTgx) [【機器學習 2022】12~13 Bert 三個故事 和 各種 Meta Learning 用法](https://hackmd.io/@JuitingChen/HyjfTptTel)
×
Sign in
Email
Password
Forgot password
or
By clicking below, you agree to our
terms of service
.
Sign in via Facebook
Sign in via Twitter
Sign in via GitHub
Sign in via Dropbox
Sign in with Wallet
Wallet (
)
Connect another wallet
New to HackMD?
Sign up