# TOKEN MERGING: YOUR VIT BUT FASTER ###### tags: `paper notes` `deep learning` > 這是一篇 2022/10/17 發表的論文,來自 Meta,而這篇方法之簡潔確實是非常驚喜 [paper link](https://arxiv.org/abs/2210.09461) # Intro - 目標是透過在 ViT 中增加一個 module 來融合冗余(redundant) 的 token,好增加 training / inference throughput - 融合 Token 並不是什麼新提出的 trick,例如 [Token Pooling, 2021](https://arxiv.org/abs/2110.03860) 就也是在做類似的事情 - 宣稱不論是否有加進去重新訓練都可以運作,非常之神奇,也是這一特點讓我選擇這篇來報 - 目前提升 transformers 效率的領域有 Efficient transformer, Token Reduction, Combine Token,但都沒有提出可以不需要訓練就獲得提升的方法,而且這些方法不僅又耗費資源又慢又缺乏泛化性 - 他們所謂的不訓練就能獲得提升未必能適用在所有案例之中,有些案例可能還是需要訓練 ![](https://i.imgur.com/s4kmPav.png) - 這個 module 是會掉精度的,並不是完全沒有代價 - 有提供[實作](https://github.com/facebookresearch/ToMe),而且是寫成 timm module,或許可以直接使用 # **TOKEN MERGING** ![](https://i.imgur.com/EGYVMHy.png) - 很多方法都是把對原始 ViT 的操作放在輸入端和 Attention 之間,但他們是放在 Attention 之後 - 這或許也是他們能夠做到不訓練就能有用的原因之一 - 融合的機制是漸漸(gradually)的融合,而不是一次就融合 ## 設計理念 1. 融合 “相似” 的 token 2. 每一層都逐漸減少**固定大小**的 token,不會動態調整 - 這個機制是完全不論你的圖片內容如何都減少一樣數量的 token - 舉例來說,input token = 768,過了第一層 768-10 = 758,第二層又 758-10 = 748,至此就有 2*10 = 20 (rL) 個 token 被融合掉了,假如網路有 24 層數就總共會少掉 24*10 = 240 個 token ## 定義 Token Similarity - 首先,原始的 ViT 架構基本上一定是 overparameterized,舉例來說,ViT-B/16 共有 16*16*3=768 個特徵來編碼一張 RGB 圖片的 pixel, 這很有可能造成裡面包含了很多跟達到目的無關的 noise 難道距離越近的 token 代表相似嗎? 不如來點消融實驗吧 ![](https://i.imgur.com/0bUuLVa.png) - 從實驗結果可以得到,使用 consine 相似度作為距離函數在速度與精度上會是最好的選擇,而且把不同 head 的 Key 的特徵取**平均**作為融合之後的代表是最好的 - $X_{pre}$ 是 input token 的 feature,這個是把 merging 放在做 attention 之前,而 $X$ 則是做完 attention 之後的 token feature ![](https://i.imgur.com/0iUgrLm.png) ## **Bipartite Soft Matching** 有了計算距離的方式以及拿什麼來計算,那實際上要怎麼 **”選擇”** token 來減少? - K-means ? 那太慢了,網路中可能會有上千個 token,而且 clustering 方法沒辦法去限制每一個類別中的具體數量,所以他們選擇了 matching 方法 所以他們基於 1. 可平行化 2. 可做到“逐漸”融合 這兩個需求設計了下面的計算流程 ![](https://i.imgur.com/f1vCNv7.png) - 示意圖就在上面的 Figure 1. - 這其實就是所謂的 bipartite graph matching,所有的 edge 都會橫跨兩個不同類別 (A or B) 的 token ![](https://i.imgur.com/sXfwDae.png) - 關於這些 token 要如何去分為 A 和 B,他們也做了消融實驗 - 結論就是 token_1 → A ,token_2 →B, token_3 → A… 這種交替的方式最好 ![](https://i.imgur.com/Irg6irT.png) - 第二個步驟會造成**每一個 A 類別中的 token 都僅有一個 edge**,所以其實第四步驟並不會算太久,因為你只要一直去找 A 裡面還沒被選到的就好 - 計算複雜度隨著層數而遞減,因為 A 會越來越少 - 第三個步驟的參數 r 就是每一層要減少的固定 token 數量 - 論文裡面還有附這個流程的 code ![](https://i.imgur.com/ZAFdT9U.png) ## **Tracking Token Size** 在融合 token 之後,就表示我們會把一個 token 映射到多個輸入,這會造成 softmax 的結果受到影響 - 當有多個 token 的 key 完全一樣,這時候這個 key 對於 softmax 輸出的影響就會減弱 - 因為 ViT 的一個 token 代表一個輸入,Attnetion矩陣的維度也是根據這個來建立,在融合之後就會變成 (N-r)*(N-r),而此時 softmax 仍然只會把融合後的 token 當成一個正常的 token 來算,但這個 token 實際上卻代表了多個 token,理論上應該要有更高的權重才對 他們透過將原始的 QKV 運算修改成 proportional attention - s is a row vector containing the size of each token (number of patches the token represents) - 這個 s 就相當於是人為的去增加某些 token 的 weight,使其影響更大 ![](https://i.imgur.com/jRiDTHy.png) 這個方法的選擇當然也是消融實驗做出來的結果 - keep one 是只留下 B 中的 token ![](https://i.imgur.com/Ks1j6rr.png) ## Training with merging - 本質上這一整個 module 可以看成是一個 pooling,而且是 averging pooling,所以其實在做後傳導上並不需要特別做甚麼 ## 其他的消融實驗 - 比較其他不同的 matching 演算法和聚類 ![](https://i.imgur.com/TYCM3aI.png) - 這個算法的速度其實跟隨機丟掉tokens的速度差不多 - 這張圖是他們暴搜了 15000 種融合的策略的結果,就是為了測出在每一層要減少多少 token 才會是比較好的數值,但結果發現以每一層減少固定數量的方式其實跟暴搜出來的結果沒有差很多 ![](https://i.imgur.com/J5t7B10.png) - 另外他們也定義了另一個 decreasing schedule 來做實驗 ![](https://i.imgur.com/05Tb7Vf.png) # Result ## Comparison to SoTA models trained only on ImageNet-1k ![](https://i.imgur.com/IU7a2o3.png) ## Compare with pruning methods ![](https://i.imgur.com/m09tVa9.png) ## 結果展示圖 ![](https://i.imgur.com/WBI6UUE.png) ![](https://i.imgur.com/CXJU0DA.jpg)