Lab 1 - AI Model Design and Quantization

1. Model Architecture (15%)

下圖顯示了我所使用的模型架構圖,該架構圖是利用神經網絡架構可視化工具 PlotNeuralNet 所實現。PlotNeuralNet 主要基於 LaTeX(TikZ)來繪製網絡結構,使用者可以透過 pycore.tikzeng 提供的簡單 API 來描述模型的各層結構,並將其轉換為 .tex 文件,最終生成對應的圖像格式(如 .png)。這種方法適合用於產生高品質的可視化圖表,特別是在學術論文或報告中表達網絡架構時能夠提供清晰直觀的呈現。

在本次實作中,我利用 PlotNeuralNet 來手動定義卷積層(Conv)、池化層(MaxPool)、全連接層(FC)等基本組件,並調整圖像排版,使其更符合模型的實際結構。此外,透過適當調整 LaTeX 參數(如圖層間距、顏色、標籤等),可以進一步提升可讀性,確保不同層之間的關係能夠清楚呈現。

模型的各層參數均依照作業需求進行設計。在網絡的前端,我使用了連續兩次 MaxPool 層來快速降低特徵圖的尺度,從而減少計算量。除了計算上的優化外,MaxPool 層本身還能提高模型對平移變換的魯棒性,對於前層特徵圖所提取的低層次特徵進行池化,更有助於強化模型對基本特徵的識別能力。

由於 CIFAR-10 資料集的分類難度相對較低,且根據作業規範的要求,我的模型在結構上並未過度加深。只使用作業規範中所要求的各層一次,即可達到不錯的表現。此外,考慮到模型經過量化後的大小需控制在 4MB 以下,儘管增加網絡深度有助於進一步提升模型性能,但在實作上並未選擇這種方案。

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

2. Loss/Epoch and Accuract/Epoch Plotting (15%)

Loss and Accuracy/Epoch plot of training and validation

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

Test Performance Summary

  • Before Quantization

    ​​​​Test: loss=0.4438, accuracy=0.8517
    ​​​​Model size: 13.37 MB
    ​​​​Plot saved at figure/weight_fp32.png
    ​​​​Time: 5.54s
    
  • After Quantization

    ​​​​Test: loss=0.4608
    ​​​​accuracy=0.8473( -0.44 % )
    ​​​​size=3.351496MB ( < 4MB )
    

3. Accuracy Tuning (20%)

Data Preprocessing

資料未經過任何額外前處理,僅依照原有程式碼實作,故訓練集所使用的資料增強僅包括以下

  • RandomCrop(32, padding=4)

    這個方法會在原始影像四周填充 4 個像素,然後隨機裁剪出 32×32 的區塊,這樣可以模擬視角的微小變化,使模型對影像的局部偏移更加穩定。

  • RandomHorizontalFlip()

    這個方法會隨機以 50% 機率水平翻轉影像,這樣可以讓模型學習到物體的對稱性特徵,增強模型的泛化能力。

此外,無論是訓練集、驗證集還是測試集,數據皆經過標準化(Normalization) 處理:

  • ToTensor()

    將圖片轉換為 PyTorch 的 Tensor 格式,並將像素值從 [0, 255] 映射到 [0,1]。

  • Normalize(mean, std)

    針對 CIFAR-10 數據集的平均值 (0.4914, 0.4822, 0.4465) 以及標準差 (0.247, 0.243, 0.261) 進行標準化,使得數據分佈更均勻,加速模型收斂。

Hyperparameters

Hyperparameter Loss function Optimizer Scheduler Weight decay or Momentum Epoch
Value Cross Entropy Adam stepLR Default 15
  • Loss Function

    選擇使用 Cross Entropy 作為損失函數。對於多類別分類問題來說,Cross Entropy 是最常用的選擇,因為它可以有效地處理多類別的預測和實際標籤之間的差異。由於 CIFAR10 資料集具有均勻的類別分布,並不存在類別不平衡的問題,因此使用 Cross Entropy 不會出現過度偏向某些類別的情況,並且能夠穩定地進行模型訓練,不需要使用到 Focal Loss 這類可以處理類別不平衡資料集的損失函數。

  • Optimizer

    我使用了 Adam 優化器,其會根據每個參數的歷史一階矩和二階矩來動態調整學習率,並且能夠自動適應不同參數的更新需求。由於在此任務中使用的是較為簡單的模型,並且在經過實驗後發現 Adam 表現穩定且收斂速度較快,因此最終選擇 Adam 優化器。

    沒有額外調整 Adammomentum 參數,而是採用其預設值

  • Scheduler

    我使用了 StepLR 學習率調度器,並設置每 5 個 epoch 減少學習率至原來的 0.1。這樣的設置有助於防止在訓練後期學習率過大而引發震盪,從而幫助模型更加穩定地收斂。這也是為了避免模型在接近最優解時,學習率過大而無法精細地調整。由於本次實驗的 EPOCH = 15,因此設置其 step_size 為 5,使其在訓練過程可以經過三次的學習率下降,到達最終 1e-6 的小學習率。

  • Epoch

    在本次實驗中,我設定了 Epoch = 15,這是根據在訓練過程中的觀察所做的決定。當訓練進行到一定階段時,我發現出現了過擬合現象。過擬合通常表現為訓練集的損失持續下降,而驗證集的損失卻開始上升。為了有效防止過度擬合,我選擇減少 epoch 數量,這樣可以避免模型在訓練集上過度優化,從而保持較好的泛化能力。通過這種方式,我能夠確保模型在未來的資料上也能保持良好的表現。

Ablation Study

4. Explain how the Power-of-Two Observer in your QConfig is implemented. (25%)

在本次 Lab 中,任務是實作了一個自定義的 Observer,因此首先需要理解 Observer 在整個量化流程中的角色。我將量化流程分為以下三個階段,並說明 Observer 在各階段的行為:

  • Prepare

    在此階段,Observer 會被插入到模型中,並修改模型架構,使其具備監控數據範圍的能力。然而,在這個階段不會進行任何統計計算,Observer 只是準備就緒,以便在後續過程中收集數據。在初始化時,max_valmin_val 的初始值分別設置為正無限大和負無限大。

  • Calibrate

    此階段的主要目標是收集數據範圍資訊。在前向傳播過程中,Observer 會監測各層張量的數值範圍,並更新對應的 max_valmin_val。值得注意的是,在這個階段仍然不會計算量化參數(如 scalezero_point),而是單純更新數據範圍,以確保後續的量化過程基於正確的範圍資訊。

  • Convert

    在此階段,Observer 會基於 max_valmin_val 計算量化參數(scalezero_point),並將模型轉換為 INT8 格式。此外,Observer 會從模型中移除,以減少不必要的計算和存儲開銷。這樣的轉換使得模型能夠在低精度格式下運行,同時保持數值範圍內的表示精度。

釐清各階段 Observer 的行為可以避免錯誤。例如,當量化模型已轉換完成並儲存於 .pth 檔案後,若後續需要進行推論,必須注意模型重新載入時的行為。由於重新載入模型時,系統只會經過 prepare 以及 convert 階段,而不會再經過 calibrate 階段(即不會執行前向傳播)。這意味著 max_valmin_val 仍維持預設值(正負無限大),導致 convert 階段的計算出現錯誤。因此,在計算量化參數的階段,應加入適當的檢驗機制,以確保其能正常運作。值得注意的是,在我的實作中,最終計算量化參數時,會使用繼承自 parent class_calculate_qparams 函式。該函式內部會自動檢查 min_valmax_val 是否有效,若無效,則會將量化參數設為固定值(如 0 或 1),以避免錯誤發生。

1. Explain how to caluclate scale and zero-point

PowerOfTwoObserver 類別中,calculate_qparams 函數負責計算量化參數 scalezero_point。其程式碼如下:

def calculate_qparams(self):
    """Calculates the quantization parameters with scale as power of two."""
    scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
    scale = self.scale_approximate(scale)
    return scale, zero_point
  • scale

    首先,調用父類 MinMaxObserver_calculate_qparams 函數,根據觀察到的 min_valmax_val 計算初始的 scale 值。這是基於標準的對稱量化方案(例如 qscheme=torch.per_tensor_symmetric)得出的結果。

    接著,將這個初始的 scale 值傳入 scale_approximate 函數,將其調整為最接近的 2 的冪次(power of two),以滿足 PowerOfTwoObserver 的設計要求。

  • zero_point

    zero_point 是由父類的 _calculate_qparams 函數直接計算得出的。在對稱量化模式下,對於 qint8(有符號整數),zero_point 通常為 0;而對於 quint8(無符號整數),zero_point 可能為 128,具體取決於量化範圍。

2. Explain how scale_approximate() function in class PowerOfTwoObserver() is implemented.

scale_approximate 函數的目的是將給定的 scale 值近似到最接近的 2 的冪次。其程式碼如下:

def scale_approximate(self, scale: float, max_shift_amount=8):
    n = torch.round(torch.log2(scale))
    scale = torch.pow(2, n)
    return scale
  1. 計算對數

    首先使用 torch.log2(scale) 計算 scale 以 2 為底的對數,得到一個浮點數。然後使用 torch.round 將這個值四捨五入到最接近的整數 n。

  2. 計算 2 的冪次

    根據四捨五入後的 n,計算 2^n(即 torch.pow(2, n)),並將其作為最終的 scale 值返回。

  3. 參數max_shift_amount

    在程式碼中雖然被定義為參數(預設值為 8),但在當前實現中並未直接使用。它未來可以用來限制 n 的範圍,但目前僅作為占位符。

3. When writing scale_approximate(), is there a possibility of overflow? If so, how can it be handled?

因為我們討論的是 scale_approximate() 是否有可能發生溢位,故先假設其輸入參數 scale 是滿足 IEEE754 定義的合法浮點數,以此為假設條件去判斷是否有發生溢位的可能,即經過 scale_approximate() 後是否會輸出不符合 IEEE754 的浮點數數值。下圖是 IEEE754 所定義之浮點數,即其表示範圍

  • IEEE754

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

  • 表示範圍

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

  • 表示方法

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

根據以上三者,可以得出判斷是否會溢位之方法,即經過 round(log2()) 後發生 n = 128 的狀況,因此時會造成回傳的 scaleinf。接著根據以下計算分析輸入參數 scale 之數值超過多少時會發生此狀況。

  • 計算分析

log2(scale)127.5

scale2127.5

2127.51.701×1038

由此可知,傳入 scale 若是超過上述提及之數值便有可能發生溢位,而為避免此狀況發生可以提前設定 n 的上下限以避免此問題,可同理處理下溢位 (通常是只會有下溢位,發生溢位可能性非常低) 的狀況。

對於神經網路領域而言,判斷發生 sclae 大於該數值的情況是否是有可能的,故再考慮傳入之 scale 的計算方式 (因講義已提出這邊不再貼出) 可以發現其值取決於 Observer 所觀察的數值上下限,但要發生權重或是激活值出現如此大(小)之數值機率非常低,故我其實在實作過程並未加入任何檢測溢位並做出對應解決的邏輯控制。

重新分析 scale_approximate 函式中的參數 max_shift_amount=8,其目的或許是限制其位移量不能超過 8,但目前我不能明白設置其值為 8 之原因,因根據報告範本中的 Hint 提及累加時會以 INT32 進行,故我認為 max_shift_amount 應可以設置到 32 ( 即最多右移或左移 32 位元 ),我在我繳交的程式碼有提供應對不同 max_shift_amount 的功能 ( 已註解 ),雖確實有可能發生 n 超過 8 的狀況,但因根據目前的位移量超過 8 不會造成影響故該功能目前我未將其開啟,開啟仍可使準確度下降不超過 1 %。

def scale_approximate(self, scale: float, max_shift_amount=8):
        #########Implement your code here##########
        n = torch.round(torch.log2(scale))
        # n = torch.clamp(n, min=-max_shift_amount, max=max_shift_amount)

        scale = torch.pow(2, n)
        return scale

5. Comparison of Quantization Schemes (25%)

Given a linear layer (128 → 10) with an input shape of 1×128 and an output shape of 1×10, along with the energy costs for different data types, we will use the provided table to estimate the total energy consumption for executing such a fully connected layer during inference under the following two scenarios:

  1. Full precision (FP32)
  2. 8-bit integer, power-of-2, static, uniform symmetric quantization
    • activation: UINT8
    • weight: INT8
Operation Energy consumption (pJ)
FP32 Multiply 3.7
FP32 Add 0.9
INT32 Add 0.1
INT8 / UINT8 Multiply 0.2
INT8 / UINT8 Add 0.03
Bit Shift 0.01

Hint

  • The energy consumption of INT32 addition should also be considered. Each INT32 addition consumes 0.1 pJ of energy, as depicted in the figure of the lab hanout.
  • Since we are using static quantization in this lab, the power-of-two scaling factors for input, weight, and output can fused into ONE integer before the inference.
  • The summation is computed under INT32 rather than INT16.

You can ignore the energy consumption of type casting, memory movement, and other operations not listed in the table.

You can refer to the following formula previously-mentioned in the lab handout:

(6)y¯i=(ReLU(b¯i+j(x¯j128)w¯ji)(cx+cwcy)pre-computed offline)+128

Write down your calculation process and answer in detail. Answers without the process will only get partial credit.

Your Answer

Before quantization (FP32) After quantization
Energy consumption (pJ) 5879 449.5

計算過程

1. 全精度 (FP32) 的能量消耗

在全精度情境中,線性層使用浮點乘法和加法進行計算:

  • 乘法:每個輸出節點需要將 128 個輸入值與權重相乘,使用 FP32。
  • 加法:每個輸出節點需要 127 次 FP32 加法來將 128 個乘法結果相加。

計算步驟

  • 乘法次數:10 × 128 = 1280
  • 加法次數:10 × 127 = 1270

能量消耗

  • 乘法能量:1280 × 3.7 pJ = 4736 pJ
  • 加法能量:1270 × 0.9 pJ = 1143 pJ
  • 總能量:4736 + 1143 = 5879 pJ

2. 8-bit 量化的能量消耗

在 8-bit 量化情境中:

  • 激活值:UINT8
  • 權重:INT8
  • 量化計算公式為:

y¯i=(ReLU(b¯i+j(x¯j128)w¯ji)(cx+cwcy))+128

由於採用靜態量化,縮放因子 ( c_x + c_w - c_y ) 被融合為單一整數,推理過程包括:

  • (x - 128) 操作:這是每個輸入節點的額外加法運算。
  • INT8/UINT8 乘法
  • INT32 加法(累加使用 INT32)
  • 位元移位
  • 最後的 (+128) 操作:每個輸出節點的加法運算。

計算步驟

  • (x - 128) 的加法次數:10 × 128 = 1280(INT8 加法)
  • INT8/UINT8 乘法次數:10 × 128 = 1280(INT8/UINT8 乘法)
  • INT32 加法次數:10 × 127 = 1270(INT32 加法)
  • 位元移位次數:10 × 1 = 10
  • (+128) 的加法次數:1 × 10 = 10(INT32 加法)

能量消耗

  • (x - 128) 的加法能量:1280 × 0.03 pJ = 38.4 pJ
  • 乘法能量:1280 × 0.2 pJ = 256 pJ
  • 加法能量:1270 × 0.1 pJ = 127 pJ
  • 位元移位能量:10 × 0.01 pJ = 0.1 pJ
  • (+128) 的加法能量:10 × 0.1 pJ = 1 pJ

總能量

38.4pJ+256pJ+127pJ+0.1pJ+1pJ=449.5pJ

6. Questions

在完成此次實驗後,我有一些未解的疑問。已知在深度學習模型中,權重和激活值通常是小數,但根據 Power-of-Two 量化方法的公式,這些浮點數將被轉換為整數,並且通過位元移位來取代原本的浮點數乘法運算。這樣一來,最終的數值不再是浮點數,而是整數型態。然而,在這種情況下,如何使用整數來表示小數值呢?

具體來說,當我們將數值轉換為整數時,如何處理小數部分?是否可以通過提前進行統計分析來決定小數點的位置,以便準確地將 INT(Fixed-point number)用來表示實際的小數?這是我目前的疑惑。

In Equation (4), the only non-integer is the multiplier

M.
As a constant depending only on the quantization scales
S1
,
S2
,
S3
, it can be computed offline. We empirically find
it to always be in the interval
(0,1)
, and can therefore express it in the normalized form
(6)M=2nM0

where
M0
is in the interval
[0.5,1)
and
n
is a non-negative
integer. The normalized multiplier
M0
now lends itself well
to being expressed as a fixed-point multiplier (e.g., int16 or
int32 depending on hardware capability). For example, if
int32 is used, the integer representing
M0
is the int32 value
nearest to
231M0
. Since
M0>0.5
, this value is always at
least
230
and will therefore always have at least 30 bits of
relative accuracy. Multiplication by
M0
can thus be implemented as a fixed-point multiplication.
Meanwhile, multiplication by
2n

can be implemented with an efficient bitshift, albeit one that needs to have correct round-to-nearest
behavior, an issue that we return to in Appendix B.