# Swin Transformer ###### tags: `CV` ![](https://i.imgur.com/4Ansji4.png) ## 介紹 **ViT** 將注意力機制引進了 **CV** 領域,取得很好的結果,不過,**VIT** 計算開銷太大,且無法應用於物件偵測、影像切割等任務, **SWIN Transformer** 是 **ViT** 的改進,使用 **shift window** 方法降低模型的計算複雜度,模型的應用也變得更加廣泛,可用於影像辨識,影像偵測,影像切割等任務。 <center> <img src=https://i.imgur.com/e0qJTbg.png width=500 height=500> </center> ## 模型 ![](https://i.imgur.com/HzNvNYB.png) 1. **Patch Partition**: 將圖片劃分成多個 **patch**,如同 **Vision Transformer**。 2. **Linear Embedding**: 將 **patch** 做線性轉換,其中,**patch** 共用同一個線性映射函數。 :::info **注意**: 這裡只單純做線性映射,沒有使用 **position encoding**。 ::: 3. **SWIN Transformer Block**: 分為 **W-MSA**, **SW-MSA** 兩個步驟,每個步驟都會對 **window** 內的 **patch** 做(多頭)自注意力。自注意力的公式如下: $$ Attention(Q, K, V) = SoftMax(QK^T/\sqrt{d} + B)V $$ 其中,$Q, K, V \in \mathbb{R}^{M^2 \times d}$,分別代表 $query, key, value$, $M^2$ 為一個 **window** 中 **patch** 的數量,$B \in \mathbb{R}^{M^2 \times M^2}$ 為 **relative position bias**,提供兩兩 **patch** 間的相對位置訊息。 + **W-MSA** 對 **window** 內的 **patch** 做 **self-attention**。 + **SW-MSA**: 先對 **windows** 做平移,再將新的 **window** 內的 **patch** 做 **self-attention**。 $$ \begin{align} &\hat{z}^{l} = WMSA(LN(z^{l-1})) + z^{l-1},\\ &z^l = MLP(LN(\hat{z}^{l})) + \hat{z}^{l},\\ &\hat{z}^{l+1} = SWMSA(LN(z^{l})) + z^{l},\\ &z^{l+1} = MLP(LN(\hat{z}^{l+1})) + \hat{z}^{l+1}. \end{align} $$ <center> <img src=https://i.imgur.com/VdfVkKy.png =300x200 width=300> </center> :::info 1. **W-MSA** 可類比成 **CNN** 中 **stride=1** 的 **convolution**。 2. **SWIN Transformer Block** 的輸入與輸出形狀相同。 ::: 4. **Patch Merging**: 將周圍的 **patch** 做 **concatenate**,再對融合特徵做線性轉換,如下圖所示。 ![](https://i.imgur.com/Ps248m1.png) :::info **Patch Merging** 可類比成 **CNN** 中 **max pooling**。 ::: ## Relative Position Bias & Mask Attention 1. **Relative Position Bias**: 每個維度上的相對位置取值範圍為 $[-M + 1, M-1]$,因此 $B$ 有 $(2M-1) \times (2M-1)$ 種不同的取值。相比而言,**Vision Transformer** 使用的 **Position Encoding** 是全局的,且只在 **Linear Embedding** 部分使用。 ![](https://i.imgur.com/fQO5Nvq.png) 2. **Mask Attention**: **SW-MSA** 會先對 **feature map** 做 **cyclic shift**,如下圖所示,右、下、右下區域 **window** 中的一些 **patch** 原本是不相鄰的,因此,這些 **patch** 兩兩之間應該沒有注意力權重,要進行 **Masked Attention**。 ![](https://i.imgur.com/fYWrH9j.png) ## 實驗 1. **模型參數** + **Window size** $M=7$ + **Query dimension** $d=32$ + **Expansion layer of each parameter** $\alpha = 4$: | name | C (Channel) | layer number| | ---- | ---- | ----| | **Swin-T** | $96$ | $\{2, 2, 6, 2\}$| | **Swin-S** | $96$ | $\{2, 2, 18, 2\}$| | **Swin-B** | $128$ | $\{2, 2, 6, 2\}$| | **Swin-L** | $192$ | $\{2, 2, 6, 2\}$| 2. **Classification** + 直接在 **ImageNet-1K** 上訓練 ![](https://i.imgur.com/zllBCXt.png) + 在 **ImageNet-22K** 上 pretrain,**ImageNet-1K** 上 fine-tune ![](https://i.imgur.com/2NP7UQY.png) ## 總結 1. **SWIN** 通過 **Window Attention & Patch Merging** 來降低計算量。 2. **SWIN** 通過 **W-MSA & SW-MSA & Patch Merging** 來增加感受域;隨著層數增加,感受域不斷擴大。 3. **SWIN** 架構與 **CNN** 類似,且可應用於多種的影像任務 4. **CNN & ViT** 系列模型有收斂的趨勢。 5. **SWIN** 最好先在大資料上 pre-train,且結果要比 **ViT** 好 (**SAM optimizer?**)。 ## 程式 1. **PatchEmbed**: input shape (B, C, H, W), output shape (B, Ph*Pw, C) ```mermaid graph LR A["Conv2d"] --> B["Layer Norm"] ``` 2. **SwinTransformerBlock**: input shape (B, L, C), output shape (B, L, C) ```mermaid graph LR A["Layer Norm"] --> B["Roll (if shift)"] --> C["Window Partition"] --> D["self attention"] --> E["Window Reverse"] --> F["Roll (if shift)"] --> G["skip connection"] --> H["MLP"] --> I["skip connection"] ``` 3. **WindowAttention**: input shape (num_windows*B, N, C), output shape (num_windows*B, N, C) 4. **PatchMerging**: input shape (B, L, C), output shape (B, L/4, 2*C). Share fully connected weight during patch merging. ```mermaid graph LR A["reshape to (B, H, W, C)"] --> B["concate neighbor"] --> C["fully connected layer (share weight)"] ``` 5. **BasicLayer** ```mermaid graph LR A[" SwinTransformerBlock"] --> B["SwinTransformerBlock (shift)"] B --> A B --> C["Downsample (Patch Merging)"] ``` ## 參考資料 1. [Github](https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py) 2. [論文](https://arxiv.org/pdf/2103.14030.pdf) 3. [B 站](https://www.bilibili.com/video/BV1bq4y1r75w?spm_id_from=333.337.search-card.all.click) 4. [中文解說](https://its201.com/article/qq_45893319/121207967)