# DeiT 模型剪枝後之微調流程報告 [TOC] --- ## 目標 針對 6 種 DeiT 系列模型(Tiny / Small / Base × 是否蒸餾)進行: 🔹 結構化剪枝(N:M Sparsity) 🔹 兩種微調策略比較(Full、LoRA) 將剪枝與微調過程中之準確率與時間進行記錄,輸出結果統整為 .csv。 ## 模型組合 | Model Type | Model Name | Distilled | | ---------- | ---------------------------------- | --------- | | Tiny | `deit_tiny_patch16_224` | False | | Tiny | `deit_tiny_distilled_patch16_224` | True | | Small | `deit_small_patch16_224` | False | | Small | `deit_small_distilled_patch16_224` | True | | Base | `deit_base_patch16_224` | False | | Base | `deit_base_distilled_patch16_224` | True | ## 一、N:M 稀疏剪枝流程(Pruning) ### 實作策略 使用 torch.ao.pruning.WeightNormSparsifier >參考 >[Accelerating BERT with semi-structured (2:4) sparsity — PyTorch Tutorials 2.7.0+cu126 documentation](https://docs.pytorch.org/tutorials/advanced/semi_structured_sparse.html) #### 適用稀疏樣式: - 2:4 → 每 4 個元素保留 2 個非零 - 4:8 → 每 8 個元素保留 4 個非零 僅針對 nn.Linear 層進行剪枝 #### 忽略模組: - model.head - model.patch_embed.proj ### 剪枝步驟與儲存 1. 載入 timm 預訓練模型 2. 計算原始準確率 3. 進行剪枝(依據設定 N:M pattern) 4. 計算剪枝後準確率 5. 分別儲存原始與剪枝模型 .pth ```mermaid flowchart TD A[Load pretrained DeiT model] --> B[Eval original accuracy] B --> C[Apply N:M pruning] C --> D[Eval pruned accuracy] D --> E[Save original / pruned weights] ``` ### 模型儲存結構 ```swift /homes/nfs/ben/project/lab2/ ├── models/ │ ├── original/{model}_orig.pth │ └── pruned/{model}_{sparsity}_pruned.pth ├── prune_results.csv ``` ![image](https://hackmd.io/_uploads/rkaFj8x-ll.png) ## 二、微調(Fine-tuning)流程 ### 策略總覽 | 策略 | 描述 | 預估訓練參數量 | | ----------- | -------------------------- | ----- | | Full | 全模型參數解凍訓練 | 全部 | | LoRA | 僅在 `qkv` / `proj` 插入低秩參數訓練 | < 1% | 訓練參數量還未做計算 ### LoRA 設定參數 - rank $𝑟=8$ - alpha = 32 - dropout = 0.1 ### 微調輸出路徑 ```swift /homes/nfs/ben/project/lab2/models/finetuned/ ├── full/{model}_{sparsity}/model_full.pth ├── lora/{model}_{sparsity}/model_lora.pth ``` ## 三、評估與統計記錄 ### 評估指標 - `original_acc`:剪枝前準確率 - `pruned_acc`:剪枝後準確率 - `fin_*_acc`:微調後準確率($full, lora$) - `fin_*_time`:微調耗時(秒) ### 統整結果輸出(CSV) ``` /homes/nfs/ben/project/lab2/prune_results.csv /homes/nfs/ben/project/lab2/finetune_single.csv ``` ## 四、Pruning 效果總結 ### 1️⃣ 剪枝準備與模型組合 ![image](https://hackmd.io/_uploads/rybnyDeZel.png) 本實驗共針對 **6 個 DeiT 模型組合**進行剪枝測試,依據模型大小(Tiny / Small / Base)與是否蒸餾(Distilled / Non-distilled)進行排列,並套用兩種結構化稀疏度設定: - `2:4`:每 4 個權重中保留 2 個 - `4:8`:每 8 個權重中保留 4 個 此表即為全體剪枝組合設定,共計 12 組,後續所有分析皆圍繞這些組合展開。 ### 2️⃣ 不同稀疏度對準確率影響 ![image](https://hackmd.io/_uploads/S15FJveZgx.png) 各模型在不同剪枝稀疏度下(`2:4` 與 `4:8`)的準確率下降情形。主要觀察如下: - **2:4 稀疏度損失較大**:尤其是 Tiny 模型,平均損失接近 0.47,顯示較小模型無法承受過度剪枝。 - **4:8 剪枝較溫和**:Base 與 Small 模型在 4:8 下的精度下降顯著減少,顯示高階模型對剪枝更具韌性。 **較大的模型結構在 4:8 剪枝下仍能保留良好效能**,為後續微調與壓縮提供較佳起點。 ### 3️⃣ 蒸餾模型對剪枝後穩定性的影響 ![image](https://hackmd.io/_uploads/Sy-i1PgZgl.png) 此圖比較是否使用 **Knowledge Distillation(蒸餾)** 對於剪枝穩定性的影響,橫軸為蒸餾標記(True/False),縱軸為剪枝導致的精度下降幅度。 - 在 **Small 與 Base 模型中**,蒸餾組略為穩定,精度下降幅度稍低。 - 然而在 **Tiny 模型中**,蒸餾反而導致更高的精度損失,顯示 distillation 不總是能對抗剪枝帶來的退化。 **蒸餾效果因模型大小與結構而異**,不一定能一體適用於抗剪枝的策略。 ### 4️⃣個別組合結果分析(逐點視覺化) ![image](https://hackmd.io/_uploads/HySokveWeg.png) 此圖依據 `model_type + distill + sparsity` 三種屬性組合,逐點呈現剪枝前後準確率: - 🔴 **紅色點**代表剪枝前的 `original_acc` - 🔵 **藍色點**為剪枝後的 `pruned_acc` - ➤ 每一對點之間以**虛線箭頭**連接,視覺化展示模型精度下滑幅度 從圖中可觀察: - Tiny 模型箭頭最長(精度損失最多) - Base 模型大多箭頭較短,表示剪枝後仍維持良好性能 - Distilled Small 模型相對穩定,具備壓縮後續應用潛力 ### 5️⃣ `_nm_prune` 實作方式說明 ![image](https://hackmd.io/_uploads/BJ2okDl-gg.png) ## 微調策略實作說明 ### 1️⃣ Full Finetuning ![code](https://hackmd.io/_uploads/B1bAXvxWlx.png) 在 `finetune_full()` 中,我們對 **整個模型參數進行完整微調(Full Finetuning)**: - 所有參數 `requires_grad=True` - 使用 `AdamW` 優化器,搭配標準 CrossEntropy Loss - 每個 epoch 後存 checkpoint,最後儲存最終模型 `model_full.pth` 由於全參數皆可更新,此方式訓練成本高、記憶體消耗大,但也是效能最充分的微調方式,適合在資源允許下使用。 **tiny每一個epoch需要50分鐘 small每一個epoch需要2個小時 base每一個epoch需要6個小時 實際訓練時間將會在csv檔案產出後更新** ### 2️⃣ LoRA Finetuning ![code](https://hackmd.io/_uploads/B1FNVvgZxx.png) ![code](https://hackmd.io/_uploads/HyxONPlZeg.png) - 呼叫 `_inject_lora()` 將模型中的特定 `nn.Linear` 替換為 `loralib.Linear` - 僅針對 `qkv` 和 `proj` 等 Attention 關鍵層注入 LoRA 模組 - 設定 `requires_grad`,**僅讓 LoRA 權重可訓練** - 其他參數保持凍結狀態