# 從實務上來講 CNN Pruning 到底在做什麼 #### 從 Fully Connected 講起 在一般的 FC layer 中, model pruning 並不是一件很難理解的事情,當我們想採用 structured pruning 時,直接針對 layer 中的 node 執行就好;unstructured pruning 時,我們就針對 weight 就好。 . 使用程式中的狀況來講,當我們有一個這樣子的網路: ```python= fc1 = torch.nn.Linear(784, 32) fc2 = torch.nn.Linear(32, 16) classifier = torch.nn.Linear(16, 10) ``` 我們將他的 weight 的形狀表示出來,分別是: ```python fc1.weight.shape : (32, 784) fc1.bias.shape : (32) fc2.weight.shape : (16, 32) fc2.bias.shape : (16) classifier.weight.shape : (10, 16) classifier.bias.shape : (10) ``` 考慮到 structured prunin,我們將會同時處理掉 weight 中 dim=1 的部分、 bias 中 dim=0 的部分,如果我們 prune fc1 50% 的 weight,她會變成這樣: ```python fc1.weight.shape : (16, 784) fc1.bias.shape : (16) ``` 當我們 fc1 的輸出變成 ==16== 維,下一層的輸入也會受到影響: ```python fc2.weight.shape : (16, 16) fc2.bias.shape : (16) ``` 在 fc2 中,由 fc1 被 pruned 掉的部分所對應的 weight 也會被刪減掉。 在這個狀況下,我們會發現到:直接對結構處理是一件很麻煩的事,在 pyTorch 中,我們可以給 layer 設定一個 mask 來控制他最後出去的樣子 被 mask 掉的輸出將會被「歸零」,由此來模擬這個 weight / node 沒有輸出,我們重新考慮一個形狀為 (4, 3) 的 linear layer,我們分別以 structured pruning 與 unstrucutred pruning 來刪減掉 50% ```python # structured pruning module.weight.mask = [ [1, 1, 1], [1, 1, 1], [0, 0, 0], [0, 0, 0], ] # mask[2, :] = 0 # mask[3, :] = 0 # unstructured pruning module.weight.mask = [ [1, 0, 1], [0, 1, 0], [1, 0, 1], [0, 1, 0], ] ``` #### CNN 的 pruning 在 CNN 中,pruning 的方式又更多元了,主要有這三種方式: 1. Filter pruning 2. Channel pruning 3. Shape pruning 以及 unstructured pruning 不過,在提到 pruning 之前,要先回顧一下 CNN 是怎麼運作的 ![image](https://miro.medium.com/v2/resize:fit:1100/format:webp/1*52JnBGMb29SuwbtS8DQlAw.png)[1] 可以看到我們的輸入是 (4, 4) 個 pixel,每個 pixel 有三個 channel。 第一層網路有 2 個 filter,每個 filter 有對應前一層輸入的 3 個 channel,每個 channel 的 kernel 為 (3, 3), - 由此,我們的 conv2d 的形狀為 (2, 3, 3, 3) - 第 0 維: filter 數量 - 第 1 維: channel 數量 - 第2 3維: kernel 的高寬 - 第 1 2 3 維就等於是一個 filter 我們在計算的時候,會先對每個 channel 進行捲機操作,然後把同個 filter 的 channel 捲機結果以 pixel-wise 的方式加起來。 - 套上例子,每個 filter 有 3 個 kernel,分別對前層(輸入層)的紅藍綠 3 個 channel 進行計算,接著將這三個捲積輸出加起來 - 由於我們有兩個 filter,最後會產出 2 張conv output,會作為下一層的 input channel 使用 當我們把他具象化出來後,我們就能很好的理解一開始提及的三種 pruning 方式在做什麼了 ##### 1. Filter pruning:針對conv output 進行刪減 套用到上述例子 - 就是一次刪除 3 個 channel 的 kernel (在圖中表示為 kernel map) - 當我們 prune 掉 1 個 filter,conv 的形狀為 (==1==, 3, 3, 3) 而在實務上,我們有兩個做法,與 FC 層的狀況相同,可以調整結構,變成(==1==, 3, 3, 3) ```python # Before pruning conv2d = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3) conv2d.shape : (2, 3, 3, 3) # Apply pruning # After pruning conv2d.shape : (1, 3, 3, 3) ``` 由於我們刪減了filter的數量,下一層的 in channel 也需要改變: ```python # Before pruning conv2d_2 = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3) conv2d_2.shape : (3, 2, 3, 3) # After pruning conv2d_2.shape : (3, 1, 3, 3) ``` 第二個則是給他一個 mask ```python mask : torch.ones( (2, 3, 3, 3) ) # Assume that we puring the filter index = 1 mask[1, :, :, :] = 0 Apply mask to conv2d ``` 由於我們是加 mask,不再需要調整結構。 如此一來我們就能成功刪減 filter 了。 ##### 2. Channel pruning Channel 比較有趣,我們在上面提到了很多次 filter 中的 channel 數量,所以 channel pruning 就是刪減每個 filter 中固定的 channel 一樣拿上面的例子講,假設我們要刪除 1 個 channel,具體來說是希望刪除掉紅色 channel,我們的 conv 形狀會變成 (2, 2, 3, 3) - 有 2 個 filter,每個 filter 有 2 個 channel (藍綠),每個 channel 的 kernel shape 為 (3, 3) - 在計算 filter 時,我們針對 藍綠 兩個 channel 進行捲積計算,算出兩個捲積輸出後進行 pixel-wise 的加總 在實務上,如果要調整結構,整個操作會有點從後往前來。由於對輸入 $x$ 的處理比較特別,可以參考 2017 的 channel pruning 開山始祖 [2] 這裡讓我們刪除 conv2d_2 中的 index = 1 的 channel ```python # Before pruning conv2d = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3) conv2d.shape : (2, 3, 3, 3) conv2d_2 = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3) conv2d_2.shape : (3, 2, 3, 3) # Apply pruning # After pruning conv2d_2.shape : (3, 1, 3, 3) # Adjust previous layer conv2d.shape : (1, 3, 3, 3) ``` 一樣我們可以採用 mask 的方式: ```python mask : torch.ones( (3, 2, 3, 3) ) # Assume that we puring the channel index = 1 mask[:, 1, :, :] = 0 Apply mask conv2d_2 ``` 其實就只是套用的維度不同而已 ##### 3. Shape pruning Shape pruning 我覺得其實蠻少用的,不過簡單來講就是對 kernel 的特定 weight 調整 (其實這個我不太熟,不過以上面的來講,就是調整 kernel_size,套用mask 的話就是對 2 3 維設0) ##### Unstructure pruning 這個跟 FC layer 的時候一樣,對 mask 四維中任意元素設0 #### 總結 其實整體的困難在於 filter 與 channel 套用在具體的神經網路中的差異到底是什麼,由於計算涉及 pixel-wise sum,使用圖片示意刪減範圍可匯難以理解,所以這次是以神經網路實做的角度來講這件事,大概是這樣 [1] [卷積神經網路(Convolutional neural network, CNN): 1×1卷積計算在做什麼](https://chih-sheng-huang821.medium.com/%E5%8D%B7%E7%A9%8D%E7%A5%9E%E7%B6%93%E7%B6%B2%E8%B7%AF-convolutional-neural-network-cnn-1-1%E5%8D%B7%E7%A9%8D%E8%A8%88%E7%AE%97%E5%9C%A8%E5%81%9A%E4%BB%80%E9%BA%BC-7d7ebfe34b8) [2] He, Yihui, Xiangyu Zhang, and Jian Sun. "Channel pruning for accelerating very deep neural networks." Proceedings of the IEEE international conference on computer vision. 2017.