Automatic Joint Structured Pruning and Quantization for Efficient Neural Network Training and Compression ========================================================================================================= > **Paper**:Qu _et al._ (2025) – GETA > **目的**:在 _一次_ 訓練流程中,同步完成 **結構化剪枝 (Structured Pruning)** \+ **混合精度量化 (Mixed-Precision QAT)**,最大化壓縮率且維持精度。 1 | GETA 框架概觀 ------------- | 特色 | 說明 | | --- | --- | | **G**eneral | 架構不綁定 CNN / Transformer,可自動構建依賴圖 | | **E**fficient | 單一訓練流程;無需「搜尋 ➜ 再訓練」的兩階段 | | **T**raining Automates | 透過 **QASSO** 優化器同時決定剪枝比例與 bitwidth | | White-box | γ, d 皆顯式可控,壓縮率可預估 | 論文提供的code ```python import GETA # 1️⃣ 建立框架(任何 PyTorch 模型皆可) geta = GETA(model) # 2️⃣ 取得 QASSO optimizer,開始一站式訓練 opt = geta.qasso() for data, target in loader: loss = criterion(model(data), target) loss.backward() opt.step() # 3️⃣ 訓練結束後,一行指令產出剪枝+量化子網 pruned_quant_model = geta.construct_subnet() ``` --- 2 | Quantization-Aware Dependency Graph (QADG) ---------------------------------------------- > **關鍵**:量化會在 computational graph 內插入 _attached / inserted branches_,若不額外處理,剪枝時易破壞拓撲。 - **Attached branch**:權重量化 (weight Q) 會在原 op 外掛一組參數化 quantizer。 - **Inserted branch**:Activation Q 則會插入多層 op(round/clip/scale)。 - **QADG**:演算法先 **合併** 這些 branch → 再跑依賴分析,得到 _可安全移除_ 的最小結構單位。 ![螢幕擷取畫面 2025-06-18 104557](https://hackmd.io/_uploads/rkskYsJVxl.png) | 子圖 | 內容 | 想傳達的重點 | | --- | --- | --- | | **(a) Weight quantization** | \- 藍色 _Conv_ 被加上一條 **attached branch**(虛線框內) \- 這支線包含 _weight quantizer_ 及與主圖共用權重的 **weight-sharing nodes** \- 灰色六邊形代表「Ambiguous Shape Operator」(例如 `view/reshape`),在依賴分析時會使張量形狀難以追蹤 | ➜ **權重量化** 會在原算子旁外掛整個子圖,若不特別處理,剪枝時無法判斷哪些權重可以一起移除 | | **(b) Activation quantization** | \- 綠色 _Linear_(或其他 op)與上一層輸出之間,被插入一段 **inserted branch**(黑色箭頭圈起) \- 分別處理 clip、round、scale 等量化步驟 | ➜ **Activation Q** 會在層與層之間插入多層小 op,導致依賴關係被「拉長」 | | **(c) Quantization-aware dependency graph** | \- 左右兩顆 _QConv_ 經 QADG 分析後,整條 attached branch 被**合併成單一節點** \- 同理,activation 的 inserted branch 也被折疊,最後剩下一顆 _QLinear_ \- 整體拓撲變得簡潔,可直接判定「如果砍掉上方 _Relu_ 一側,就必須同步移除下方哪一組節點」 | ➜ **QADG 的目標**:把 weight / activation 量化帶來的額外節點整併,得到最小可剪結構,方便之後做 _structured pruning_ | --- 3 | QASSO:四階段優化流程 ----------------- | Stage | 目的 | 核心操作 | | --- | --- | --- | | **Warm-up** | 良好初始化 | 全參數 SGD / Adam 幾輪 | | **Projection** | 逐步收斂到 _目標 bit 範圍_ | **PPSG** 只投射 `d` → 避免梯度爆炸 | | **Joint** | 同步剪枝+量化 | 分群 (`GI`, `GR`)、計算 γ 與 d | | **Cool-down** | 微調知識回收 | 固定 bitwidth / mask,再訓練至收斂 | ### 3.1 | 為何選 PPSG? - 傳統 **Penalty method** 需手動調 λ,難調參。 - 全量投射 (d, t, qₘ) 容易因 $(q_m)^t$ 非線性導致梯度爆/消失。 - 只對 **step size d** 投射,可穩定控制 bitwidth,且對指數項獨立。 $$b = \log_2\!\Bigl(\frac{(q_m)^t}{d}+1\Bigr)+1$$ --- 4 | Joint Stage:γ 與 d 的條件式更新 ---------------------------- 以下以您指定的格式呈現三種情況(γ)與兩種情況(d)。符號同論文。 ### 4.1 | Forget Rate γ #### **✅ 條件 1:資訊量極小,可忽略** 條件: $\text{clip} \le \varepsilon$ 設定: $\gamma = 0$ 解釋:權重量化後已近似 0,直接歸零省計算。 --- #### **✅ 條件 2:方向一致,需漸進削弱** 條件: $\cos\theta_\gamma \ge 0,\; \text{clip} > \varepsilon$ 設定: $\gamma = 1-\frac{K_p-k-1}{K_p-k}$ 解釋:隨 pruning step k 增大,平滑降低該 group 影響力。 --- #### **❗條件 3:方向相反,需強制遺忘** 條件: $\cos\theta_\gamma < 0,\; \text{clip} > \varepsilon$ 設定: $\gamma = -\frac{(1-\eta)\alpha\|\nabla_x f\|}{\cos\theta_\gamma\;\|\text{sgn}(x)\,\text{clip}\| }$ 解釋:資訊方向與梯度衝突,藉由正 γ 反向抑制,加速清除。 --- ### 4.2 | Step Size d | Case | 條件 | 設定 | 意涵 | | --- | --- | --- | --- | | **正常** | $\cos\theta_d \ge 0$ | $d=\dfrac{(q\_m)^t}{2^{b\_l-1}-1}$ | 直接壓到最低 bitwidth | | **衝突** | $\cos\theta_d < 0$ | $d=-\dfrac{\xi\eta\alpha\nabla\_xf}{\gamma\cos\theta\_d\text{sgn}(x)R(x)}$ | 使用梯度導向步長,避免量化誤導 | >實務中兩者會再經β-factor 微調,確保最終 `b∈[b_l,b_u]`。 --- 5 | 實驗亮點 -------- | Model / Dataset | Top-1 Acc ↓ | BOPs ↓ | 備註 | | --- | --- | --- | --- | | ResNet-20 / CIFAR-10 | **-0.28 %** | **-95.5 %** | weight-only Q | | VGG-7 / CIFAR-10 | **-0.48 %** | **-99.6 %** | weight+act Q | | ResNet-50 / ImageNet | **-1.03 %** | **-93 %** | 40 % sparsity | | BERT / SQuAD (50 %) | **+4.0 F1** | **-85.6 %** | prune-then-QAT baseline | > 當 sparsity > 60 % 時,需抬高 bitwidth(≥ 4-bit)才不掉分。 --- 6 | 小結 & Takeaways ------------------ 1. **一次性訓練** \> 「先剪後量」傳統流程;省時且壓縮率更可控。 2. **PPSG** 投射 `d` → 穩定學習 bitwidth,不易梯度爆炸。 3. **γ / d 條件式設計** 兼顧 _方向一致性_ \+ _資訊量_,有效協調剪枝與量化衝突。 --- **Reference** Qu X., Aponte D., Banbury C. _et al._, “Automatic Joint Structured Pruning and Quantization for Efficient Neural Network Training and Compression”, _arXiv:2502.16638_ (2025). ---