# Swin Transformer

Source: [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/pdf/2103.14030)

![image](https://hackmd.io/_uploads/rJlicK-ZWg.png)

## Principles

### Hierarchical Architecture

> Swin Transformer, which constructs hierarchical feature maps and has linear computational complexity to image size.

Swin Transformer builds hierarchical feature maps, so that, like classic convolutional networks (CNNs such as VGG and ResNet), it produces feature representations at multiple resolutions.

![image](https://hackmd.io/_uploads/HyRNoKZ-Zl.png)

### Patch Partition and Embedding

> It first splits an input RGB image into non-overlapping patches by a patch splitting module, like ViT. Each patch is treated as a "token" and its feature is set as a concatenation of the raw pixel RGB values.

- The model first splits the input RGB image into non-overlapping patches with a patch splitting module.
- In the implementation, each 4×4 patch is treated as one "token", and its feature is the concatenation of the raw pixel RGB values, giving a dimension of 4×4×3 = 48.
- A linear embedding layer then projects this raw feature to an arbitrary dimension C.

![image](https://hackmd.io/_uploads/ryeghtbWbl.png)

### Patch Merging

- As the network gets deeper, the model reduces the number of tokens with patch merging layers, building up the hierarchical representation.

> To produce a hierarchical representation, the number of tokens is reduced by patch merging layers as the network gets deeper.

- The first merging layer concatenates the features of each group of 2×2 neighboring patches into a 4C-dimensional feature.

> The first patch merging layer concatenates the features of each group of 2 × 2 neighboring patches, and applies a linear layer on the 4C-dimensional concatenated features.

- A linear layer is then applied, reducing the number of tokens by a factor of 2×2 = 4 (2× downsampling of resolution) while raising the output dimension to 2C.

> This reduces the number of tokens by a multiple of 2×2 = 4 (2× downsampling of resolution), and the output dimension is set to 2C.

reference: [Swin Transformer解读](https://datawhalechina.github.io/thorough-pytorch/%E7%AC%AC%E5%8D%81%E7%AB%A0/Swin-Transformer%E8%A7%A3%E8%AF%BB.html)

![image](https://hackmd.io/_uploads/BkM-85-bbg.png)
![image](https://hackmd.io/_uploads/ry3BI9-Z-l.png)

### Shifted Window based Self-Attention (SW-MSA)

![image](https://hackmd.io/_uploads/r1-SJsb-bg.png)

> To introduce cross-window connections while maintaining the efficient computation of non-overlapping windows, we propose a shifted window partitioning approach which alternates between two partitioning configurations in consecutive Swin Transformer blocks.

Computing attention only inside fixed windows leaves the windows disconnected from one another, which limits the model's modeling power.

![image](https://ask.qcloudimg.com/http-save/yehe-7220647/3e9687789154a6138a9fbe5b87910468.gif)

So after MSA (Multi-Head Self-Attention) is computed within each window, a Shifted Window MSA follows to give the model cross-window connections.

reference: [使用动图深入解释微软的Swin Transformer](https://cloud.tencent.com/developer/article/2015888)

![image](https://ask.qcloudimg.com/http-save/yehe-7220647/1f201eee7e0d61403ea7714c306bffe5.gif)

## Source Code

reference: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py

### `Mlp`

```python
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```

1. Why not use just a single `fc`?

**The math behind the MLP**

$$y = W_2 \cdot \sigma(W_1 \cdot x + b_1) + b_2$$

$W_1$ and $W_2$ are two different matrices, with an activation function $\sigma$ in between. If a single `fc` were shared,

$$y = W \cdot \sigma(W \cdot x)$$

the expressive power drops, which can cause underfitting; moreover, the two linear transforms serve different purposes.

```
x = fc1(x)   # in_features → hidden_features
x = GELU(x)
x = fc2(x)   # hidden_features → out_features
```

First linear transform: spreads out the features and changes the coordinate system (feature projection).

Activation function: breaks the linearity constraint.

![image](https://hackmd.io/_uploads/BJv4DJ-mbx.png)

Second linear transform: selects among the high-dimensional features that have been processed nonlinearly (feature mixing) and aligns them with the output format the model needs.
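To see why the activation matters, here is a quick sanity check (our own sketch, not from the Swin repo): without a nonlinearity in between, two stacked `nn.Linear` layers collapse into a single linear map, so the second layer adds no expressive power.

```python
import torch
import torch.nn as nn

# Without an activation, fc2(fc1(x)) is just one linear layer:
# W2 (W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2)
fc1, fc2 = nn.Linear(4, 8), nn.Linear(8, 3)
x = torch.randn(5, 4)

stacked = fc2(fc1(x))                     # two layers, no activation
W = fc2.weight @ fc1.weight               # collapsed weight: (3, 4)
b = fc2.weight @ fc1.bias + fc2.bias      # collapsed bias:   (3,)
collapsed = x @ W.T + b

print(torch.allclose(stacked, collapsed, atol=1e-6))  # True
```

With `GELU` between `fc1` and `fc2` this collapse is impossible, which is exactly what the hidden layer buys.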
### `window_partition`

```python
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows
```

Comparing the different dimension operations:

| Op | What it does | Changes data order? | Needs `contiguous`? | Typical use |
| - | - | - | - | - |
| `view` | change shape | ❌ no | ✅ required | fastest, pure reshape |
| `reshape` | change shape | ❌ no* | ❌ (copies when necessary) | safe version of `view` |
| `permute` | reorder dimensions | ❌ (only reindexes) | ❌ | NCHW ↔ NHWC |
| `transpose` | swap two dimensions | ❌ | ❌ | matrix transpose |

So, to stay out of trouble, use `reshape`; otherwise you must insert `contiguous()`:

```python
x = x.permute(0, 2, 1).contiguous().view(...)
# or
x = x.permute(0, 2, 1).reshape(...)
```

This is why `window_partition` cannot simply do

```python
x = x.view(B, H // window_size, W // window_size, window_size, window_size, C)
```

because `view` cannot change how the data is arranged in memory; it can only reinterpret the linear, contiguous 1-D buffer.

### `window_reverse`

```python
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```

Converts the partitioned windows back to the original shape `(B, H, W, C)`.
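As a quick round-trip check (our own sketch, with hypothetical sizes), partitioning and then reversing must reproduce the input exactly, since both functions only reindex elements:

```python
import torch

# Round trip with hypothetical sizes B=2, H=W=8, C=3, window_size=4,
# reusing window_partition / window_reverse defined above.
B, H, W, C, ws = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)

windows = window_partition(x, ws)   # (B * H//ws * W//ws, ws, ws, C)
restored = window_reverse(windows, ws, H, W)

print(windows.shape)                # torch.Size([8, 4, 4, 3])
print(torch.equal(x, restored))     # True
```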
### `WindowAttention`

```python
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
```

This is the W-MSA from the paper.

![image](https://ask.qcloudimg.com/http-save/yehe-7220647/3e9687789154a6138a9fbe5b87910468.gif)

#### `relative_position`

See: https://datawhalechina.github.io/thorough-pytorch/%E7%AC%AC%E5%8D%81%E7%AB%A0/Swin-Transformer%E8%A7%A3%E8%AF%BB.html#id3

The relative position bias lets the model learn **distance-aware** weights.

- Why not absolute positions? They would tie a bias to each token's fixed location.
- Preserves relative spatial information: the model knows how far apart two patches are, above/below/left/right.
- Translation invariance: this matters for the shifted windows (SW-MSA) applied later.
- Sharper attention: the model can allocate attention more sensibly to spatially close or related patches.
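The indexing trick is easiest to see on a tiny window. Below is a minimal sketch (ours) reproducing the `relative_position_index` computation from `__init__` for a 2×2 window; the resulting indices address the $(2W_h-1)(2W_w-1) = 9$ rows of the bias table:

```python
import torch

# relative_position_index for Wh = Ww = 2 (indexing="ij" silences the
# newer-PyTorch meshgrid warning; the upstream code omits it)
Wh, Ww = 2, 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Wh - 1      # row offsets now start at 0
relative_coords[:, :, 1] += Ww - 1      # col offsets now start at 0
relative_coords[:, :, 0] *= 2 * Ww - 1  # fold (row, col) offsets into one index
print(relative_coords.sum(-1))
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Each entry `(i, j)` picks the bias for "token `j` as seen from token `i`"; equal offsets share a table entry, which is what makes the bias translation invariant.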
#### `QKV`

1. `q`, `k`, `v` : `(B_, nH, N, d)`
   - `B_` : $B \times \cfrac{H}{\text{window\_size}} \times \cfrac{W}{\text{window\_size}}$ (batch size × number of windows)
   - `nH` : number of heads
   - `N` : $\text{window\_size}^2$ (tokens in one window, i.e. $W_h \times W_w$)
   - `d` : $\cfrac{C}{nH}$
2. `q = q * self.scale`
   `(B_, nH, N, d)` $\to$ `(B_, nH, N, d)`
   $$scale = \cfrac{1}{\sqrt{d}}$$
3. `attn = (q @ k.transpose(-2, -1))`
   `k.transpose(-2, -1)` : `(B_, nH, N, d)` $\to$ `(B_, nH, d, N)`
   `attn` : `(B_, nH, N, d)` $\otimes$ `(B_, nH, d, N)` $=$ `(B_, nH, N, N)`
4. `attn = attn + relative_position_bias.unsqueeze(0)`
   `relative_position_bias` : $(N \times N, nH)$
   reshape $\to$ `(N, N, nH)`
   permute $\to$ `(nH, N, N)`
   unsqueeze $\to$ `(1, nH, N, N)`
   so `attn` stays `(B_, nH, N, N)`
5. `x = (attn @ v).transpose(1, 2).reshape(B_, N, C)`
   `(attn @ v)` : `(B_, nH, N, N)` $\otimes$ `(B_, nH, N, d)` $\to$ `(B_, nH, N, d)`
   transpose $\to$ `(B_, N, nH, d)`
   reshape $\to$ `(B_, N, C)`
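A toy shape walkthrough of the steps above (our own sketch with hypothetical sizes: 8 windows, 3 heads, 4 tokens per window, head dim 32, so C = 96):

```python
import torch
import torch.nn as nn

B_, nH, N, d = 8, 3, 4, 32   # hypothetical sizes
C = nH * d

x = torch.randn(B_, N, C)
qkv = nn.Linear(C, C * 3)(x).reshape(B_, N, 3, nH, d).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]              # each (B_, nH, N, d)

attn = (q * d ** -0.5) @ k.transpose(-2, -1)  # (B_, nH, N, N)
attn = attn.softmax(dim=-1)                   # rows sum to 1 per query token
out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
print(out.shape)                              # torch.Size([8, 4, 96])
```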
### `SwinTransformerBlock`

```python
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)
        self.fused_window_process = fused_window_process

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
                # partition windows
                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            else:
                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
        else:
            shifted_x = x
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C

        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)

        # reverse cyclic shift
        if self.shift_size > 0:
            if not self.fused_window_process:
                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
        else:
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
```

![image](https://hackmd.io/_uploads/HyK2Kp7QZx.png)

Input: `x` $\to$ `(B, L, C)` where $L = H \times W$, then reshaped $\to$ `(B, H, W, C)`.

#### Cyclic Shift

```python
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
```

`torch.roll` shifts the elements of a tensor cyclically along the given dimensions.

1-D roll:

```
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.roll(x, shifts=2)
print(y)  # tensor([4, 5, 1, 2, 3])
```

2-D roll:

```
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
y = torch.roll(x, shifts=1, dims=0)
print(y)
# tensor([[4, 5, 6],
#         [1, 2, 3]])
```

In the Swin Transformer it works like this:

![image](https://hackmd.io/_uploads/r1-SJsb-bg.png)

The roll is applied along dimensions 1 and 2, i.e. `H` (height) and `W` (width):

```python
dims=(1, 2)
```

with shift amounts

```python
shifts=(-self.shift_size, -self.shift_size)
```

That is, if `shift_size` is 1:

```python
# original matrix x:
# [[1, 2, 3],
#  [4, 5, 6],
#  [7, 8, 9]]
torch.roll(x, shifts=(-1, -1), dims=(0, 1))
# => [[5, 6, 4],
#     [8, 9, 7],
#     [2, 3, 1]]
```

- The height axis is shifted by `-shift_size` and the width axis by `-shift_size` as well.
- The negative sign means the shift goes up and to the left.

Shift up:

```python
[[4, 5, 6],
 [7, 8, 9],
 [1, 2, 3]]
```

then shift left:

```python
[[5, 6, 4],
 [8, 9, 7],
 [2, 3, 1]]
```

- The move is cyclic: elements pushed past one edge wrap around to the other side.
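After the shift, one window can contain pixels that were not neighbors in the original feature map, so `__init__` precomputes `attn_mask` to forbid attention between sub-regions that came from different places. A small sketch of that computation (ours, with hypothetical sizes H = W = 4, `window_size = 2`, `shift_size = 1`, reusing `window_partition` from above):

```python
import torch

H = W = 4
window_size, shift_size = 2, 1

# label the regions of the shifted feature map
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask[0, :, :, 0])
# tensor([[0., 0., 1., 2.],
#         [0., 0., 1., 2.],
#         [3., 3., 4., 5.],
#         [6., 6., 7., 8.]])

mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
# pairs from different regions get -100 (close to -inf before softmax), same region gets 0
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # torch.Size([4, 4, 4]) -> (nW, ws*ws, ws*ws)
```

The `-100` entries are what `WindowAttention.forward` adds to `attn` before the softmax, zeroing out cross-region attention weights.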
#### Partition windows

Cut the whole feature map into windows:

`(B, H, W, C)` $\to$ `(B * H/window_size * W/window_size, window_size, window_size, C)`

![image](https://hackmd.io/_uploads/B1domp7m-l.png)

Note the order: the shift happens first, and only then the partition.

#### Attention

```python
attn_windows = self.attn(x_windows, mask=self.attn_mask)
```

Uses [WindowAttention](https://hackmd.io/@clh/BkQ0dg2Nkl/https%3A%2F%2Fhackmd.io%2F%40clh%2FBypDcKbZ-g#WindowAttention) to compute attention inside each window.

Current shape: `(B * W/window_size * H/window_size, window_size * window_size, C)`

#### Merge Windows $\to$ Reverse Cyclic Shift

Merge:

```python
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
```

This is really just a shape change: `(B * W/window_size * H/window_size, window_size * window_size, C)` $\to$ `(nW, window_size, window_size, C)` with `nW` = `B * W/window_size * H/window_size`.

Then reverse the partition:

```python
shifted_x = window_reverse(attn_windows, self.window_size, H, W)
```

$\to$ `(B, H, W, C)`

The cyclic shift also has to be rolled back:

```python
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
```

#### Residual Connection

![image](https://hackmd.io/_uploads/Hkcl9677be.png)

```python
shortcut = x
...
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
```

The original input `x` of shape `(B, H*W, C)` has the `MSA` attention output added onto it. The benefits of the residual connection:

- Smooth gradient flow: the shortcut lets gradients propagate straight back, easing the training of deep networks.
- Preserves the original signal: the input is never fully overwritten; attention only supplements the features.
- Stability: it keeps the attention output from dominating the network's behavior.

#### Block

![image](https://hackmd.io/_uploads/SyM9z0QXbx.png)

Finally, the MLP:

```python
x = x + self.drop_path(self.mlp(self.norm2(x)))
```

Shape: `(B, H*W, C)`

### `PatchMerging`

![image](https://hackmd.io/_uploads/ryH6SpQ7-l.png)

```python
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
```

It plays a role similar to a pooling layer in a CNN: it lowers the resolution (downsampling) and increases the channel count, letting the model capture features over a larger area (a larger receptive field).

- x: `(B, H*W, C)`
- view $\to$ `(B, H, W, C)`
- interleaved slicing + concat $\to$ `(B, H/2, W/2, 4C)`, which cuts the computation (fewer tokens) while keeping as much of the image information as possible and enlarging the feature dimension
- view $\to$ `(B, H/2 * W/2, 4C)`
- `norm` + `reduction` (linear) $\to$ `(B, H/2 * W/2, 2C)`
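A quick shape check of the interleaved 2×2 sampling (our sketch, hypothetical sizes B = 1, H = W = 4, C = 8), showing the 4× token reduction and the 4C $\to$ 2C projection:

```python
import torch

# PatchMerging (defined above) halves H and W and doubles C
B, H, W, C = 1, 4, 4, 8
merge = PatchMerging(input_resolution=(H, W), dim=C)

x = torch.randn(B, H * W, C)
y = merge(x)
print(x.shape)  # torch.Size([1, 16, 8])  -> 16 tokens of dim C
print(y.shape)  # torch.Size([1, 4, 16])  -> 4 tokens of dim 2C
```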
### `BasicLayer`

```python
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 fused_window_process=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 fused_window_process=fused_window_process)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
```

It corresponds to one of these stages:

![image](https://hackmd.io/_uploads/BJNcBRmQWg.png)

Note how `shift_size` alternates between `0` (W-MSA) and `window_size // 2` (SW-MSA) for consecutive blocks, which is exactly the alternation of the two partitioning configurations described in the paper.

### `PatchEmbed`

```python
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x
```

Performs the patch embedding: it turns a continuous image (pixels) into a discrete sequence of tokens that the Transformer can process.

Suppose the input image is $224 \times 224 \times 3$, with patch_size=4 and embed_dim=96:

| Step | Code | Shape | Notes |
|-|-|-|-|
| input | `x` | $(B, 3, 224, 224)$ | raw image $(C, H, W)$ |
| conv projection | `self.proj(x)` | $(B, 96, 56, 56)$ | $224/4 = 56$; the image shrinks spatially and gets deeper |
| flatten | `.flatten(2)` | $(B, 96, 3136)$ | flattens the $56 \times 56$ spatial grid into $3136$ positions |
| transpose | `.transpose(1, 2)` | $(B, 3136, 96)$ | final $(B, L, C)$ format expected by the Transformer |
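The same trace as a runnable check (our sketch; the upstream file pulls `to_2tuple` from `timm`, so `PatchEmbed` must already be in scope as defined above):

```python
import torch

# default PatchEmbed: img_size=224, patch_size=4, in_chans=3, embed_dim=96
embed = PatchEmbed()
x = torch.randn(2, 3, 224, 224)   # (B, C, H, W)
tokens = embed(x)
print(embed.patches_resolution)   # [56, 56]
print(tokens.shape)               # torch.Size([2, 3136, 96]) -> (B, L, C)
```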
### `SwinTransformer`
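The upstream file ends with the `SwinTransformer` class, which wires all of the pieces above together. Below is a heavily condensed sketch (ours, not the verbatim upstream code; see the repo link above for the full class): patch embedding, four `BasicLayer` stages whose dimension doubles while the resolution halves, then a final norm, global average pooling, and a linear classification head.

```python
import torch
import torch.nn as nn

# Condensed composition sketch, assuming the PatchEmbed / BasicLayer /
# PatchMerging classes defined above (omits APE, dropout, weight init).
class TinySwin(nn.Module):
    def __init__(self, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, num_classes=1000):
        super().__init__()
        self.patch_embed = PatchEmbed(embed_dim=embed_dim)
        res = self.patch_embed.patches_resolution            # [56, 56]
        num_features = embed_dim * 2 ** (len(depths) - 1)    # 768
        self.layers = nn.ModuleList([
            BasicLayer(dim=embed_dim * 2 ** i,
                       input_resolution=(res[0] // 2 ** i, res[1] // 2 ** i),
                       depth=depths[i], num_heads=num_heads[i], window_size=window_size,
                       # PatchMerging after every stage except the last
                       downsample=PatchMerging if i < len(depths) - 1 else None)
            for i in range(len(depths))])
        self.norm = nn.LayerNorm(num_features)
        self.head = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)          # (B, 56*56, 96)
        for layer in self.layers:
            x = layer(x)                 # tokens /4 and dim *2 at each merge
        x = self.norm(x)                 # (B, 7*7, 768)
        return self.head(x.mean(dim=1))  # average over tokens, then classify
```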
