# Swin Transformer
###### tags: `CV`
![](https://i.imgur.com/4Ansji4.png)
## 介紹
**ViT** 將注意力機制引進了 **CV** 領域,取得很好的結果,不過,**VIT** 計算開銷太大,且無法應用於物件偵測、影像切割等任務, **SWIN Transformer** 是 **ViT** 的改進,使用 **shift window** 方法降低模型的計算複雜度,模型的應用也變得更加廣泛,可用於影像辨識,影像偵測,影像切割等任務。
<center>
<img src=https://i.imgur.com/e0qJTbg.png width=500 height=500>
</center>
## 模型
![](https://i.imgur.com/HzNvNYB.png)
1. **Patch Partition**: 將圖片劃分成多個 **patch**,如同 **Vision Transformer**。
2. **Linear Embedding**: 將 **patch** 做線性轉換,其中,**patch** 共用同一個線性映射函數。
:::info
**注意**: 這裡只單純做線性映射,沒有使用 **position encoding**。
:::
3. **SWIN Transformer Block**: 分為 **W-MSA**, **SW-MSA** 兩個步驟,每個步驟都會對 **window** 內的 **patch** 做(多頭)自注意力。自注意力的公式如下:
$$
Attention(Q, K, V) = SoftMax(QK^T/\sqrt{d} + B)V
$$ 其中,$Q, K, V \in \mathbb{R}^{M^2 \times d}$,分別代表 $query, key, value$, $M^2$ 為一個 **window** 中 **patch** 的數量,$B \in \mathbb{R}^{M^2 \times M^2}$ 為 **relative position bias**,提供兩兩 **patch** 間的相對位置訊息。
+ **W-MSA** 對 **window** 內的 **patch** 做 **self-attention**。
+ **SW-MSA**: 先對 **windows** 做平移,再將新的 **window** 內的 **patch** 做 **self-attention**。
$$
\begin{align}
&\hat{z}^{l} = WMSA(LN(z^{l-1})) + z^{l-1},\\
&z^l = MLP(LN(\hat{z}^{l})) + \hat{z}^{l},\\
&\hat{z}^{l+1} = SWMSA(LN(z^{l})) + z^{l},\\
&z^{l+1} = MLP(LN(\hat{z}^{l+1})) + \hat{z}^{l+1}.
\end{align}
$$
<center>
<img src=https://i.imgur.com/VdfVkKy.png =300x200
width=300>
</center>
:::info
1. **W-MSA** 可類比成 **CNN** 中 **stride=1** 的 **convolution**。
2. **SWIN Transformer Block** 的輸入與輸出形狀相同。
:::
4. **Patch Merging**: 將周圍的 **patch** 做 **concatenate**,再對融合特徵做線性轉換,如下圖所示。
![](https://i.imgur.com/Ps248m1.png)
:::info
**Patch Merging** 可類比成 **CNN** 中 **max pooling**。
:::
## Relative Position Bias & Mask Attention
1. **Relative Position Bias**: 每個維度上的相對位置取值範圍為 $[-M + 1, M-1]$,因此 $B$ 有 $(2M-1) \times (2M-1)$ 種不同的取值。相比而言,**Vision Transformer** 使用的 **Position Encoding** 是全局的,且只在 **Linear Embedding** 部分使用。
![](https://i.imgur.com/fQO5Nvq.png)
2. **Mask Attention**: **SW-MSA** 會先對 **feature map** 做 **cyclic shift**,如下圖所示,右、下、右下區域 **window** 中的一些 **patch** 原本是不相鄰的,因此,這些 **patch** 兩兩之間應該沒有注意力權重,要進行 **Masked Attention**。
![](https://i.imgur.com/fYWrH9j.png)
## 實驗
1. **模型參數**
+ **Window size** $M=7$
+ **Query dimension** $d=32$
+ **Expansion layer of each parameter** $\alpha = 4$:
| name | C (Channel) | layer number|
| ---- | ---- | ----|
| **Swin-T** | $96$ | $\{2, 2, 6, 2\}$|
| **Swin-S** | $96$ | $\{2, 2, 18, 2\}$|
| **Swin-B** | $128$ | $\{2, 2, 6, 2\}$|
| **Swin-L** | $192$ | $\{2, 2, 6, 2\}$|
2. **Classification**
+ 直接在 **ImageNet-1K** 上訓練
![](https://i.imgur.com/zllBCXt.png)
+ 在 **ImageNet-22K** 上 pretrain,**ImageNet-1K** 上 fine-tune
![](https://i.imgur.com/2NP7UQY.png)
## 總結
1. **SWIN** 通過 **Window Attention & Patch Merging** 來降低計算量。
2. **SWIN** 通過 **W-MSA & SW-MSA & Patch Merging** 來增加感受域;隨著層數增加,感受域不斷擴大。
3. **SWIN** 架構與 **CNN** 類似,且可應用於多種的影像任務
4. **CNN & ViT** 系列模型有收斂的趨勢。
5. **SWIN** 最好先在大資料上 pre-train,且結果要比 **ViT** 好 (**SAM optimizer?**)。
## 程式
1. **PatchEmbed**: input shape (B, C, H, W), output shape (B, Ph*Pw, C)
```mermaid
graph LR
A["Conv2d"] --> B["Layer Norm"]
```
2. **SwinTransformerBlock**: input shape (B, L, C), output shape (B, L, C)
```mermaid
graph LR
A["Layer Norm"] --> B["Roll (if shift)"] --> C["Window Partition"] --> D["self attention"] --> E["Window Reverse"] --> F["Roll (if shift)"] --> G["skip connection"] --> H["MLP"] --> I["skip connection"]
```
3. **WindowAttention**: input shape (num_windows*B, N, C), output shape (num_windows*B, N, C)
4. **PatchMerging**: input shape (B, L, C), output shape (B, L/4, 2*C). Share fully connected weight during patch merging.
```mermaid
graph LR
A["reshape to (B, H, W, C)"] --> B["concate neighbor"] --> C["fully connected layer (share weight)"]
```
5. **BasicLayer**
```mermaid
graph LR
A[" SwinTransformerBlock"] --> B["SwinTransformerBlock (shift)"]
B --> A
B --> C["Downsample (Patch Merging)"]
```
## 參考資料
1. [Github](https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py)
2. [論文](https://arxiv.org/pdf/2103.14030.pdf)
3. [B 站](https://www.bilibili.com/video/BV1bq4y1r75w?spm_id_from=333.337.search-card.all.click)
4. [中文解說](https://its201.com/article/qq_45893319/121207967)