
Week 5: Model Pruning Implementation

tags: Technical Seminar

Github: https://github.com/Eric-mingjie/rethinking-network-pruning

Dataset: cifar-10
Method: l1-norm-pruning

Python code

  • Baseline
  • Prune
  • Fine-tune
    • main_finetune.py: fine-tune the pruned model (continue training from the best previous checkpoint)
  • Retrain
    • main.py: retrain the pruned model from scratch
    • main_E.py (Scratch-E): retrain from scratch with the same number of epochs as the original model
    • main_B.py (Scratch-B): retrain from scratch with the same training time and compute budget as the original model

Implementation

Pruning Method

Review link: 3.1 How do we decide which filters to prune?

1. VGG16

Experimental Results

| Model | # Weights | Accuracy (after / before retrain) | Training epochs | Paper accuracy | Paper epochs |
|---|---|---|---|---|---|
| VGG16 | 14,987,722 | 93.8% | 146 | 93.3% | 160 |
| VGG16 (pruned) | 5,397,034 | 93.8% / 64.7% | 34 | 93.4% | 40 |

VGG16 Model Architecture

  • Architecture diagram from the VGG16 paper (image not available)


Pruning Implementation

1. Model pruning configuration

### vggprune.py ###

         cfg = [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256]   # pruned configuration
original cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
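A quick sanity check (not part of vggprune.py): comparing the two cfg lists above shows how much each conv layer is pruned; the names pruned_cfg and original_cfg are just illustrative.

# 'M' entries are max-pool layers, so there is nothing to prune there
pruned_cfg   = [32, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M', 256, 256, 256]
original_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]

for i, (p, o) in enumerate(zip(pruned_cfg, original_cfg)):
    if p == 'M':
        continue
    print(f"conv layer {i}: keep {p}/{o} filters ({1 - p / o:.0%} pruned)")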

2. Use the L1 norm to decide which filters to prune

### vggprune.py ###

cfg_mask = [] # per-layer masks marking which filters to keep
layer_id = 0
for m in model.modules():
    if isinstance(m, nn.Conv2d): # dimension: [output, input, k, k]
        out_channels = m.weight.data.shape[0] # dimension 0 is output_channels
        if out_channels == cfg[layer_id]: # the layer already matches the pruning config, so nothing is pruned here
            cfg_mask.append(torch.ones(out_channels)) # keep every filter in this layer (all 1s)
            layer_id += 1
            continue
        weight_copy = m.weight.data.abs().clone() # absolute values of the filter weights
        weight_copy = weight_copy.cpu().numpy()
        L1_norm = np.sum(weight_copy, axis=(1, 2, 3)) # sum over dims 1-3, leaving one L1 norm per output channel
        arg_max = np.argsort(L1_norm) # ascending sort
        arg_max_rev = arg_max[::-1][:cfg[layer_id]] # 1. reverse the array 2. take the positions of the cfg[layer_id] filters with the largest L1 norm
        assert arg_max_rev.size == cfg[layer_id], "size of arg_max_rev not correct"
        mask = torch.zeros(out_channels) # by default every output channel is masked (pruned)
        mask[arg_max_rev.tolist()] = 1
        cfg_mask.append(mask)
        layer_id += 1
    elif isinstance(m, nn.MaxPool2d):
        layer_id += 1
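
To see the ranking step in isolation, here is a self-contained toy sketch (not from the repo) that keeps the 8 filters with the largest L1 norm out of 16, exactly as the loop above does per layer:

import numpy as np
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3)          # toy layer: 16 filters, each of shape 3x3x3
weights = conv.weight.data.abs().cpu().numpy()  # |weights|, shape (16, 3, 3, 3)
L1_norm = np.sum(weights, axis=(1, 2, 3))       # one L1 norm per output filter
keep = np.argsort(L1_norm)[::-1][:8]            # indices of the 8 filters with the largest norm

mask = torch.zeros(16)
mask[keep.tolist()] = 1                         # 1 = keep, 0 = prune
print(mask)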

3. Copy the surviving weights into the new model

### vggprune.py ###

start_mask = torch.ones(3) # initial input channels, i.e. the 3 RGB channels
layer_id_in_cfg = 0
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d): # note: each BatchNorm2d follows a Conv2d, so it is pruned with that layer's output mask (end_mask)
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) # collect all positions where the mask is 1 into an array
        if idx1.size == 1:
            idx1 = np.resize(idx1,(1,))
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask
        if layer_id_in_cfg < len(cfg_mask):  # if we are not yet at the end of cfg_mask, move on to the next mask
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy()))) # input channels: positions where the previous layer's mask is 1
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy()))) # output channels: positions where this layer's mask is 1
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone() # first select the surviving input channels
        w1 = w1[idx1.tolist(), :, :, :].clone() # then select the surviving output filters
        m1.weight.data = w1.clone() # copy the pruned filter weights into the new model
    elif isinstance(m0, nn.Linear):
        if layer_id_in_cfg == len(cfg_mask): # the first linear layer
            idx0 = np.squeeze(np.argwhere(np.asarray(cfg_mask[-1].cpu().numpy()))) # the last conv layer may have been pruned, so keep only its surviving channels
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0].clone() # [out_channel, in_channel]
            m1.bias.data = m0.bias.data.clone()
            layer_id_in_cfg += 1
            continue
        m1.weight.data = m0.weight.data.clone() # the later linear layers are not affected by pruning
        m1.bias.data = m0.bias.data.clone()
    elif isinstance(m0, nn.BatchNorm1d): # the 1-D batch norm after the linear layer
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
        m1.running_mean = m0.running_mean.clone()
        m1.running_var = m0.running_var.clone()
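
After the copy loop, the pruned model is saved together with its cfg so the retraining scripts can rebuild the same (smaller) architecture (main_B.py above loads checkpoint['cfg']); a rough sketch continuing from the code above, with an illustrative file name:

import os
import torch

# Save the pruned weights together with the cfg describing the smaller architecture,
# so main_finetune.py / main_B.py can rebuild exactly the same model later.
torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},
           os.path.join(args.save, 'pruned.pth.tar'))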

2. ResNet56

Experimental Results

All experiments below are retrained with the Scratch-B method.

| Model | # Weights | Accuracy (after / before retrain) | Training epochs | Paper accuracy | Paper epochs |
|---|---|---|---|---|---|
| ResNet56 | 853,018 | 93.0% | 160 | 93.0% | 160 |
| ResNet56 (pruned-A) | 773,336 | 93.2% / 91.4% | 220 | 93.1% | 220 |
| ResNet56 (pruned-B) | 735,712 | 93.0% / 67.0% | 220 | 93.0% | 220 |

How do we retrain after pruning with the same training time and compute budget?

Introduction to FLOPs:


Take a 5x5 image convolved with a 3x3 kernel as an example: each application of the kernel costs (3x3) multiplications + (3x3-1) additions = 17 operations, and the kernel is applied 9 times in total,
so FLOPs = 17 * 9 = 153.
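
The same arithmetic as a small sketch (single conv layer, stride 1, no padding, one input and one output channel as in the example; the function name is just illustrative):

def conv_flops(img_size, k, in_channels=1, out_channels=1):
    """FLOPs of a k x k convolution over an img_size x img_size input, stride 1, no padding."""
    ops_per_position = (k * k * in_channels) + (k * k * in_channels - 1)  # multiplications + additions
    positions = (img_size - k + 1) ** 2                                   # how many times the kernel is applied
    return ops_per_position * positions * out_channels

print(conv_flops(5, 3))  # (9 + 8) * 9 = 153, matching the example above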

Implementation

### main_B.py (Scratch-B) ###

# if --scratch is given, load the newly pruned model
if args.scratch:
    checkpoint = torch.load(args.scratch)
    # build the model with the pruned architecture (cfg) stored in the checkpoint
    model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])

# the reference (unpruned) model architecture
model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)

flops_std = print_model_param_flops(model_ref, 32)  # FLOPs of the reference model
flops_small = print_model_param_flops(model, 32)  # FLOPs of the newly pruned model
args.epochs = int(160 * (flops_std / flops_small))  # scale up the epochs so the total training budget matches the original
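
As a worked example with made-up FLOP counts: a pruned model that costs half as much per epoch is given twice as many epochs, so the total compute stays roughly the same.

# Hypothetical numbers, only to illustrate the scaling rule used above.
flops_std = 2.5e8     # per-epoch cost of the unpruned reference model
flops_small = 1.25e8  # per-epoch cost of the pruned model (half the FLOPs)
print(int(160 * (flops_std / flops_small)))  # 320 epochs: half the cost per epoch, twice the epochs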

3. ResNet110

Model architecture:

| Item | Value |
|---|---|
| Depth | 110 |
| Stages | 3 (each stage has the same number of residual blocks) |
| Residual block | conv1-bn1-relu-conv2-bn2 |
| Kernel size | 3x3 |
| Channels | 16 (stage 1), 32 (stage 2), 64 (stage 3) |
| Blocks per stage | 18 |
| Conv layers per block | 2 |

How the depth of 110 is computed:
conv1 layer + conv layers inside the blocks + FC
= 1 + (3 stages * 18 blocks * 2 conv layers) + 1
= 1 + 108 + 1
= 110

# Inspect the model structure
ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
.
.
.
    (17): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
.
.
.
    (17): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
.
.
.
    (17): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=8, stride=8, padding=0)
  (fc): Linear(in_features=64, out_features=10, bias=True)
)
# Inspect the parameters
from torchsummary import summary
model = model.cuda()
summary(model, input_size=(3,32,32))

Parameters per conv layer: k * k * InputChannel * OutputChannel (no bias term, since bias=False)
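
Plugging the formula into the first few layers reproduces the Param # column of the summary below:

# k*k*InputChannel*OutputChannel (bias=False, so no bias term)
print(3 * 3 * 3 * 16)   # 432   -> Conv2d-1
print(3 * 3 * 16 * 16)  # 2304  -> Conv2d-4
print(16 * 2)           # 32    -> BatchNorm2d-2 (one gamma and one beta per channel)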

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             432 (3*3*3*16)
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]           2,304 (3*3*16*16)
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
       BasicBlock-10           [-1, 16, 32, 32]               0
           Conv2d-11           [-1, 16, 32, 32]           2,304
      BatchNorm2d-12           [-1, 16, 32, 32]              32
             ReLU-13           [-1, 16, 32, 32]               0
           Conv2d-14           [-1, 16, 32, 32]           2,304
      BatchNorm2d-15           [-1, 16, 32, 32]              32
             ReLU-16           [-1, 16, 32, 32]               0
       BasicBlock-17           [-1, 16, 32, 32]               0
           Conv2d-18           [-1, 16, 32, 32]           2,304
      BatchNorm2d-19           [-1, 16, 32, 32]              32
             ReLU-20           [-1, 16, 32, 32]               0

After pruning:


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4            [-1, 8, 32, 32]           1,152
       BatchNorm2d-5            [-1, 8, 32, 32]              16
              ReLU-6            [-1, 8, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           1,152
       BatchNorm2d-8           [-1, 16, 32, 32]              32
              ReLU-9           [-1, 16, 32, 32]               0
       BasicBlock-10           [-1, 16, 32, 32]               0
           Conv2d-11            [-1, 8, 32, 32]           1,152
      BatchNorm2d-12            [-1, 8, 32, 32]              16
             ReLU-13            [-1, 8, 32, 32]               0
           Conv2d-14           [-1, 16, 32, 32]           1,152
      BatchNorm2d-15           [-1, 16, 32, 32]              32
             ReLU-16           [-1, 16, 32, 32]               0
       BasicBlock-17           [-1, 16, 32, 32]               0
           Conv2d-18            [-1, 8, 32, 32]           1,152
      BatchNorm2d-19            [-1, 8, 32, 32]              16
             ReLU-20            [-1, 8, 32, 32]               0
| Version | Skip layers | Pruning rate [stage1, stage2, stage3] |
|---|---|---|
| A | [36] | [0.5, 0, 0] |
| B | [36, 38, 74] | [0.5, 0.4, 0.3] |

Pruning steps:
1. Record which filters to keep (and which to prune):
(1) If the layer is a skip layer, skip it.
(2) Prune the first conv layer of every residual block.
(3) Decide how many filters to keep (1 - pruning rate).
(4) Compute each filter's L1-norm sum and sort from small to large.
(5) Take the filters to keep from the end of the sorted list (the largest norms).
(6) Mark the filters to keep with 1 and the filters to prune with 0.
2. Copy the surviving weights into newmodel:
(1) If it is the first conv layer: skip pruning and copy all weights into newmodel unchanged.
(2) If it is conv layer 1 of a block: copy only the surviving filters' weights into newmodel (pruned filters are dropped).
(3) If it is conv layer 2 of a block: copy the input channels matching conv layer 1's surviving filters into newmodel (pruned ones are dropped). (link)
(4) If it is a BN layer: keep the corresponding entries and copy only those into newmodel (pruned ones are dropped).
(5) If it is the final FC layer: skip pruning and copy all weights into newmodel unchanged.
Done! The code below implements these steps.

Figure 1: https://stackoverflow.com/questions/40857930/how-does-numpy-sum-with-axis-work
Figure 2: (image not available)

skip = {
    'A': [36],
    'B': [36, 38, 74],
}
prune_prob = {
    'A': [0.5, 0.0, 0.0],
    'B': [0.5, 0.4, 0.3],
}

layer_id = 1
cfg = []
cfg_mask = []

# 1. Record which filters to keep and which to prune
for m in model.modules():
    # (the ResNet structure being iterated over is the one printed above)
    if isinstance(m, nn.Conv2d):
        # e.g. Conv2d(3, 16, ...)  -> weight shape torch.Size([16, 3, 3, 3])
        #      Conv2d(16, 16, ...) -> weight shape torch.Size([16, 16, 3, 3])
        out_channels = m.weight.data.shape[0]  # e.g. 16

        # (1) Skip layers are not pruned: keep every weight, so the whole mask is 1
        if layer_id in skip[args.v]:  # args.v: A or B
            cfg_mask.append(torch.ones(out_channels))  # a length-out_channels vector of all 1s
            cfg.append(out_channels)  # record the channel count: [16, 16, ..., 32, 32, ..., 64, 64, ...]
            layer_id += 1
            continue

        # (2) Prune the first conv layer of every residual block
        if layer_id % 2 == 0:
            # which of the 3 stages this layer belongs to (one stage = 18 blocks = 36 conv layers)
            stage = layer_id // 36
            if layer_id <= 36:
                stage = 0
            elif layer_id <= 72:
                stage = 1
            elif layer_id <= 108:
                stage = 2
            prune_prob_stage = prune_prob[args.v][stage]  # A: 0.5/0/0 ; B: 0.5/0.4/0.3

            # torch notes:
            # .clone(): copy an identical tensor
            # .cpu():   move a GPU tensor to the CPU
            # .numpy(): convert a tensor to a numpy array
            weight_copy = m.weight.data.abs().clone().cpu().numpy()  # |weights|

            # (4) Sum the L1 norm per filter (sorted from small to large below)
            L1_norm = np.sum(weight_copy, axis=(1,2,3))  # e.g. 16x3x3x3 -> 16 values, see Figure 1

            # (3) How many filters to keep: (1 - pruning rate)
            num_keep = int(out_channels * (1 - prune_prob_stage))  # e.g. 16 * (1 - 0.5) = 8

            # np.argsort sorts ascending and returns the indices,
            # e.g. x = np.array([1, 4, 3, -1, 6, 9]) -> argsort(x) = [3, 0, 2, 1, 4, 5]
            arg_max = np.argsort(L1_norm)  # e.g. [11 0 8 15 1 12 9 6 14 3 13 4 7 10 5 2]

            # [::-1] reverses the order, [:num_keep] keeps the num_keep largest filters
            arg_max_rev = arg_max[::-1][:num_keep]  # e.g. [ 2  5 10  7  4 13  3 14]

            mask = torch.zeros(out_channels)  # a length-out_channels vector of all 0s
            mask[arg_max_rev.tolist()] = 1  # mark the filters to keep with 1
            cfg_mask.append(mask)  # e.g. tensor([0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0.])
            cfg.append(num_keep)  # record how many filters survive, e.g. [8, 8, ...]
            layer_id += 1
            continue

        layer_id += 1
# newmodel: a ResNet with the same depth but the pruned channel counts recorded in cfg
newmodel = resnet(dataset=args.dataset, depth=args.depth, cfg=cfg)
if args.cuda:
    newmodel.cuda()
layer_id_in_cfg = 0
conv_count = 1
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    '''
        ResNet(
      (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        '''
    # CONV
    if isinstance(m0, nn.Conv2d):
        '''(conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)'''
        #(1) If it is the first conv layer: skip pruning and copy all of its weights into newmodel unchanged
        if conv_count == 1:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        
        '''
        (stage 1, Block 0, conv 1)
        (stage 1, Block 1, conv 1)
        (stage 1, Block 2, conv 1)
        .
        .
        .
        '''
        # see Figure 2
        # (2) If it is conv layer 1 of a block: copy only the surviving filters into newmodel (pruned filters are dropped)
        if conv_count % 2 == 0:
            #cfg_mask: tensor([0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0.], tensor([])....)
            mask = cfg_mask[layer_id_in_cfg] # this layer's mask, e.g. tensor([0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0.])
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy()))) # positions that are non-zero, i.e. equal to 1: [ 2  3  4  5  7 10 13 14]
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[idx.tolist(), :, :, :].clone() # keep only filters [ 2  3  4  5  7 10 13 14] along the output dimension (8 filters survive)
            m1.weight.data = w.clone() # copy into newmodel
            layer_id_in_cfg += 1
            conv_count += 1
            continue

        '''
        (stage 1, Block 0, conv 2)
        (stage 1, Block 1, conv 2)
        (stage 1, Block 2, conv 2)
        .
        .
        .
        '''
        #(3) If it is conv layer 2 of a block: keep the input channels matching conv layer 1's surviving filters (pruned ones are dropped)
        if conv_count % 2 == 1:
            #cfg_mask: tensor([0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 0.], tensor([])....)
            mask = cfg_mask[layer_id_in_cfg-1] # reuse the same mask as conv layer 1 above
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[:, idx.tolist(), :, :].clone() # select the input channels that survived conv layer 1's pruning
            m1.weight.data = w.clone()
            conv_count += 1
            continue
    # BatchNorm
    #(4) If it is a BN layer: keep the entries matching the surviving filters and copy only those into newmodel
    elif isinstance(m0, nn.BatchNorm2d):
        # prune the corresponding BatchNorm channels as well
        if conv_count % 2 == 1:
            mask = cfg_mask[layer_id_in_cfg-1]
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            # BN formula: Y = (X - running_mean) / sqrt(running_var + eps) * gamma + beta
            m1.weight.data = m0.weight.data[idx.tolist()].clone()
            m1.bias.data = m0.bias.data[idx.tolist()].clone()
            m1.running_mean = m0.running_mean[idx.tolist()].clone()
            m1.running_var = m0.running_var[idx.tolist()].clone()
            continue
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
        m1.running_mean = m0.running_mean.clone()
        m1.running_var = m0.running_var.clone()
    # Linear
    #(5) If it is the final FC layer: skip pruning and copy all of its weights into newmodel unchanged
    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
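
The weight counts reported in the results tables can be reproduced by summing the number of elements in every parameter tensor; a minimal sketch using the model and newmodel from above (the helper name is illustrative):

def count_parameters(net):
    """Total number of weights, matching the '# Weights' column in the tables."""
    return sum(p.numel() for p in net.parameters())

print(count_parameters(model))     # original ResNet110
print(count_parameters(newmodel))  # pruned ResNet110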

Experimental Results

| Model | # Weights | Accuracy (after / before retrain) | Training epochs | Paper accuracy | Paper epochs |
|---|---|---|---|---|---|
| ResNet110 | 1,727,962 | 93.6% | 129 | 93.6% | 160 |
| ResNet110 (pruned-A) | 1,688,522 | 93.3% / 87.6% | 40 | 93.5% | 40 |
| ResNet110 (pruned-B) | 1,168,424 | 93.2% / 79.2% | 40 | 93.3% | 40 |

Visualization

Github: https://github.com/tyui592/Pruning_filters_for_efficient_convnets

Graph 1: Absolute sum of filter weights for each layer

Paper example vs. implementation (images not available)

This plot is the easiest to draw: just collect each layer's filters, compute each filter's absolute weight sum, sort them, and normalize.

colors = ['r', 'g', 'b', 'k', 'y', 'm', 'c']
lines = ['-', '--', '-.']

plt.figure(figsize=(7,5))
conv_count = 0
for layer in network.features: # plug in your own model here
    if isinstance(layer, torch.nn.Conv2d):
        line_style = colors[conv_count%len(colors)] + lines[conv_count//len(colors)]
        
        # get this layer's filters
        fw = layer.weight.data.cpu().numpy()
        
        # absolute value, sum per filter, then sort (largest to smallest)
        sorted_abs_sum = np.sort(np.sum(np.abs(fw.reshape(fw.shape[0], -1)), axis=1))[::-1]
        
        # normalize the sums into the 0~1 range
        normalized_abs_sum = sorted_abs_sum/sorted_abs_sum[0]
        conv_count += 1
        plt.plot(np.linspace(0, 100, normalized_abs_sum.shape[0]), normalized_abs_sum, line_style, label='conv %d'%conv_count)
        
plt.title("Data: %s, Model: %s"%(args.data_set, args.vgg))        
plt.ylabel("normalized abs sum of filter weight")
plt.xlabel("filter index / # filters (%)")
plt.legend(loc='upper right')
plt.xlim([0, 140])
plt.grid()
plt.savefig("figure1.png", dpi=150, bbox_inches='tight')
plt.show()

Graph 2: Pruning filters with the lowest absolute weights sum and their corresponding test accuracies

Paper example vs. implementation (images not available)

This plot takes a trained model, prunes it, and measures the accuracy on the testing data for each layer at each pruning ratio.

# Setup
prune_step_ratio = 1/8
max_channel_ratio = 0.90 

prune_channels = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]
prune_layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'conv7', 'conv8', 'conv9', 'conv10', 'conv11', 'conv12', 'conv13']
top1_accuracies = {}
top5_accuracies = {}

for conv, channel in zip(prune_layers, prune_channels):    
    top1_accuracies[conv] = []
    top5_accuracies[conv] = []
    
    # measure the initial top-1 / top-5 accuracy
    network, _, _ = test_network(args, data_set=test_set)
        
    # compute how many filters to prune at each step
    step = np.linspace(0, int(channel*max_channel_ratio), int(1/prune_step_ratio), dtype=np.int)
    ## output example: array([ 0,  8, 16, 24, 32, 40, 48, 57])
    steps = (step[1:] - step[:-1]).tolist()
    ## output example: [8, 8, 8, 8, 8, 8, 9]
    
    for i in range(len(steps)):
        print("\n%s: %s Layer, %d Channels pruned"%(time.ctime(), conv, sum(steps[:i+1])))
        
        # set the current layer and the number of filters to prune
        args.prune_layers = [conv]
        args.prune_channels =[steps[i]]
        
        # prune the network
        network = prune_network(args, network)
        
        # measure top-1 / top-5 accuracy
        network, _, (top1, top5) = test_network(args, network, test_set)
            
        top1_accuracies[conv].append(top1)
        top5_accuracies[conv].append(top5)

output:

-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 15:58:31 2019: Test information, Data(s): 1.112, Forward(s): 0.169, Top1: 93.470, Top5: 99.430, 

Wed Oct  2 15:58:31 2019: conv1 Layer, 8 Channels pruned
-*--*--*--*--*--*--*--*--*--*-
	Prune network
-*--*--*--*--*--*--*--*--*--*-
-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 15:58:33 2019: Test information, Data(s): 1.115, Forward(s): 0.168, Top1: 93.470, Top5: 99.430, 

Wed Oct  2 15:58:33 2019: conv1 Layer, 16 Channels pruned
-*--*--*--*--*--*--*--*--*--*-
	Prune network
-*--*--*--*--*--*--*--*--*--*-
-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 15:58:36 2019: Test information, Data(s): 1.115, Forward(s): 0.170, Top1: 93.470, Top5: 99.430, 

Wed Oct  2 15:58:36 2019: conv1 Layer, 24 Channels pruned
-*--*--*--*--*--*--*--*--*--*-
	Prune network
-*--*--*--*--*--*--*--*--*--*-
            .
            .
            .
-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 16:02:44 2019: Test information, Data(s): 1.117, Forward(s): 0.167, Top1: 93.380, Top5: 99.240, 

Wed Oct  2 16:02:44 2019: conv13 Layer, 394 Channels pruned
-*--*--*--*--*--*--*--*--*--*-
	Prune network
-*--*--*--*--*--*--*--*--*--*-
-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 16:02:46 2019: Test information, Data(s): 1.128, Forward(s): 0.169, Top1: 66.960, Top5: 98.280, 

Wed Oct  2 16:02:46 2019: conv13 Layer, 460 Channels pruned
-*--*--*--*--*--*--*--*--*--*-
	Prune network
-*--*--*--*--*--*--*--*--*--*-
-*--*--*--*--*--*--*--*--*--*-
	Evalute network
-*--*--*--*--*--*--*--*--*--*-
Wed Oct  2 16:02:48 2019: Test information, Data(s): 1.124, Forward(s): 0.170, Top1: 44.280, Top5: 62.030, 
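
To turn the collected top1_accuracies into the Graph 2 curves, a plotting sketch in the same style as Graph 1; it reuses colors, lines, max_channel_ratio and top1_accuracies from the code above, and the x positions are only approximate cumulative pruning ratios (the steps above are nearly uniform).

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(7, 5))
for i, (conv, accs) in enumerate(top1_accuracies.items()):
    # approximate cumulative fraction of filters pruned after each step
    pruned_pct = np.linspace(0, 100 * max_channel_ratio, len(accs) + 1)[1:]
    line_style = colors[i % len(colors)] + lines[i // len(colors)]
    plt.plot(pruned_pct, accs, line_style, label=conv)

plt.ylabel("top-1 accuracy (%)")
plt.xlabel("filters pruned away (%)")
plt.legend(loc='lower left')
plt.grid()
plt.savefig("figure2.png", dpi=150, bbox_inches='tight')
plt.show()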

Graph 3: Prune and retrain for each single layer

Paper example vs. implementation (skipped: it takes too much time to run QQ)

This plot is the most troublesome: again start from a trained model, but after every pruning step retrain the network and then compute the accuracy on the testing data.

# Setup
args.retrain_flag = True
args.retrain_epoch = 10
args.independent_prune_flag = False
args.retrain_lr = 0.001
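
A sketch of how the Graph 2 loop could be extended for Graph 3, assuming (as the flags above suggest) that prune_network fine-tunes for args.retrain_epoch epochs after each pruning step when args.retrain_flag is True; this was not actually run and the exact behavior depends on the repo's implementation.

# Sketch only (not run): every pruning step is followed by retraining before evaluation.
top1_after_retrain = {}

for conv, channel in zip(prune_layers, prune_channels):
    top1_after_retrain[conv] = []
    network, _, _ = test_network(args, data_set=test_set)  # start from a freshly trained network

    step = np.linspace(0, int(channel * max_channel_ratio), int(1 / prune_step_ratio), dtype=np.int)
    steps = (step[1:] - step[:-1]).tolist()

    for i in range(len(steps)):
        args.prune_layers = [conv]
        args.prune_channels = [steps[i]]

        network = prune_network(args, network)  # prune + retrain (per the flags set above)
        network, _, (top1, top5) = test_network(args, network, test_set)
        top1_after_retrain[conv].append(top1)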