執行人: Denny0097
在大型語言模型資源需求日益高漲的背景下,微軟研究院提出 BitNet b1.58 2B4T,以 1.58 位元量化訓練的開放原始碼模型,在其推論過程中,BitNet 的矩陣乘法可化約為加法、減法與忽略,浮點乘法完全被移除。這使模型在實際部署中不僅延遲降低、記憶體存取負擔減輕,也進一步降低能源消耗。這種極簡的運算流程非常適合硬體加速器設計。
BitNet b1.58 採用 Transformer 架構進行大幅簡化與重設,包括移除線性層與正規化層中的 bias,使用 BitLinear 層取代傳統全連接層,搭配旋轉位置編碼 (RoPE)、ReLU² 激勵函數與 subLN 正規化。
BitNet b1.58 2B4T 模型包含約 20 億參數,訓練資料量達 4 兆個 token,涵蓋自然語言 (以英語為主)。它使用與 Llama 3 相同的分詞器,字彙表大小為 128,256,並支援最多 4096 token 的上下文長度。整體訓練過程分為三個階段:預訓練 (pretraining)、監督式微調 (SFT) 與偏好對齊 (DPO),使模型在效能與對話表現間取得良好平衡。
BitNet b1.58 的運行效率極高。在 Apple M2 這類通用 CPU 上可達 29 毫秒延遲,且記憶體佔用僅為 0.4GB,遠小於例如 Gemma 3 1B 所需的 1.4GB。微軟團隊在基準測試中將其與 Llama 3.2 1B、Gemma 3 1B 與 Qwen 2.5 1.5B 等全精度模型對比,發現 BitNet 即使在參數較少、權重精度較低的情況下,仍能在 MMLU、GSM8K、MATH 等任務中維持穩定表現,並在 GSM8K 上奪得最佳成績。BitNet 於 30 億參數以上的規模下,其性能已能接近 FP16 模型,並於 70 億參數規模達成高達 4 倍的推理速度提升與 7 倍的記憶體節省,顯示此技術具備高度擴展潛力。
本任務嘗試運用 Linux 核心提供的效能分析工具,定位出 BitNet 運行時期的效能瓶頸,並善用 Transparent Hugepage Support、針對事件驅動的 I/O 模型 (如 io_uring
),和課程所及的手法,加速 BitNet。
參考 BitNet 以及 BitNetMCU 在 Linux 環境利用 BitNet quantiztion 的 VGG8 for MNIST dataset (model.h),使用 Linux 核心提供的效能分析工具分析 training & inference 運行時期的效能瓶頸,分析後設計 CPU & mem usage 最佳化,並用 SystemC 模擬硬體運算,另嘗試硬體加速。
研究 BitNet [1]:
效能改進:
修改 BitNetMCU [6]:
HW simulation
[1] https://github.com/microsoft/BitNet
[2] https://xmrig.com/docs/miner/hugepages
[3] https://github.com/microsoft/T-MAC
[4] https://hackmd.io/@sysprog/linux-zerocopy
[5] BitNet 有 LUT: https://github.com/microsoft/BitNet/tree/main/src
https://www.kernel.org/doc/html/next/admin-guide/mm/transhuge.html
run instruction:
result :
This command would run the inference benchmark using the model located at /path/to/model, generating 200 tokens from a 256 token prompt, utilizing 4 threads.
model | size | params | backend | threads | n_batch | test | t/s |
---|---|---|---|---|---|---|---|
bitnet-b1.58 2B I2_S - 2 bpw ternary | 1.71 GiB | 2.74 B | CPU | 4 | 1 | pp256 | 15.86 ± 0.04 |
bitnet-b1.58 2B I2_S - 2 bpw ternary | 1.71 GiB | 2.74 B | CPU | 4 | 1 | tg200 | 15.75 ± 0.10 |
…
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
BitConv2d-1 [-1, 64, 32, 32] 576
ReLU-2 [-1, 64, 32, 32] 0
MaxPool2d-3 [-1, 64, 16, 16] 0
BitConv2d-4 [-1, 192, 16, 16] 110,592
ReLU-5 [-1, 192, 16, 16] 0
MaxPool2d-6 [-1, 192, 8, 8] 0
BitConv2d-7 [-1, 384, 8, 8] 663,552
ReLU-8 [-1, 384, 8, 8] 0
BitConv2d-9 [-1, 256, 8, 8] 884,736
ReLU-10 [-1, 256, 8, 8] 0
BitConv2d-11 [-1, 256, 8, 8] 589,824
ReLU-12 [-1, 256, 8, 8] 0
MaxPool2d-13 [-1, 256, 4, 4] 0
Flatten-14 [-1, 4096] 0
BitLinear-15 [-1, 256] 1,048,576
ReLU-16 [-1, 256] 0
BitLinear-17 [-1, 128] 32,768
ReLU-18 [-1, 128] 0
BitLinear-19 [-1, 10] 1,280
================================================================
Total params: 3,331,904
Trainable params: 3,331,904
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.91
Params size (MB): 12.71
Estimated Total Size (MB): 15.63
----------------------------------------------------------------
# Description: Training parameters for the training script
# Model selection
model: 'VGG' # 'FCMNIST' or 'CNNMNIST' This is the class name of the model as defined in models.py.
# Quantization settings
QuantType: 'I2_S' # 'Ternary', 'Binary', 'BinaryBalanced', '2bitsym', '4bit', '4bitsym', '8bit', 'None", 'FP130', 'NF4', 'I2_S'
NormType: 'RMS' # 'RMS', 'Lin', 'BatchNorm'
WScale: 'PerTensor' # 'PerTensor', 'PerOutput'
# Clipping parameters - only used for 2 bit and higher quantization
maxw_algo: 'octav' # 'octav', 'prop' Algorithm used to calculate the clipping parameters (maximum weight)
maxw_update_until_epoch: 50 # Update clipping parameters until this epoch, they are frozen afterwards
maxw_quantscale: 0.25 # Used only for clipping_algo='prop'. Determines the relation between stddev of weights and max_weight
# Learning parameters
num_epochs: 50
batch_size: 32
scheduler: "Cosine" # "StepLR", "Cosine"
learning_rate: 0.001
lr_decay: 0.1 # lr_decay and step size are not used with cosine scheduler
step_size: 10
# halve_lr_epoch: 30 # Epoch at which to halve the learning rate
# Data augmentation
augmentation: True
rotation1: 10 # rotation1 and rotation2 are used for data augmentation
rotation2: 10
# Model parameters
network_width1: 256
network_width2: 128
network_width3: 0
# name
runtag: "octav" # runtag is prefix for runname
BitConv2d & BitLinear 是 BitNetMCU 提供的 layer 框架,其中支持包含 Normalize(RMS) 跟 QAT foward 的運作
MNIST
w_scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
w_int = (w * w_scale ).round().clamp_(-1, 1)
training
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127)
inference
scale = 127.0 / np.maximum(np.abs(input_data).max(axis=-1, keepdims=True), 1e-5)
current_data = np.round(input_data * scale).clip(-128, 127)
HW friendly:
計算 input 的絕對值的最大值 max(|x|),並計算最接近 max(|x|) 的 2 的冪次 ,用此值作為 quantization 的最大範圍, 來得到 range [-128, 127] quanted value,但這樣會無條件捨棄小數,為了更精準,加上 round 的計算:,過程中沒有乘法也沒有浮點數運算。
基於BitMCU提供的 BitLinear & BitConv2d 進行 QAT 讓模型在訓練時就學會適應低 bit-width inference。
而由於 BitnetMCU 目前不支援 padding, maxpooling, BatchNorm, (使用 BatchNorm 訓練的模型在 export 成 .c file 時,accuracy 會異常低,接近隨機推演),以及最重要的,原始的 export.py 還沒有支援任何 BitNet 儲存。
增加 Maxpool 以及 Padding ,最後用 CIFAR10 訓練並先實現 export I2_S model (2bits),讓 lab 中用到的 VGG8 模型能夠在該專案中訓練並生成 bitnet model。
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ python exportquant.py
Load parameters from file: trainingparameters.yaml
octav_VGG_Aug_BitMnist_I2_S_width256_128_0_epochs10
Loading model...
Inference using the original model...
Accuracy/Test of trained model: 99.67 %
Quantizing model...
0 VGG
1 Sequential
2 BitConv2d
3 ReLU
4 MaxPool2d
5 BitConv2d
6 ReLU
7 MaxPool2d
8 BitConv2d
9 ReLU
10 BitConv2d
11 ReLU
12 BitConv2d
13 ReLU
14 MaxPool2d
15 Flatten
16 BitLinear
17 ReLU
18 BitLinear
19 ReLU
20 BitLinear
Layer: 2, Max: 1.0, Min: -1.0, Mean: -0.11458333333333333, Std: 0.8317041365974641
Values: [-1. 0. 1.]
Percent: [40.97222222 29.51388889 29.51388889]
Entropy: 1.57 bits. Code capacity used: 78.33160789015268 %
Layer: 5, Max: 1.0, Min: -1.0, Mean: -0.2090747974537037, Std: 0.7201310989476462
Values: [-1. 0. 1.]
Percent: [38.5687934 43.76989294 17.66131366]
Entropy: 1.49 bits. Code capacity used: 74.68134472269595 %
Layer: 8, Max: 1.0, Min: -1.0, Mean: -0.14563289689429013, Std: 0.6785549412400004
Values: [-1. 0. 1.]
Percent: [31.36393229 51.83542511 16.8006426 ]
Entropy: 1.45 bits. Code capacity used: 72.42034190610372 %
Layer: 10, Max: 1.0, Min: -1.0, Mean: -0.17259724934895834, Std: 0.6684536886212908
Values: [-1. 0. 1.]
Percent: [32.46086968 52.33798557 15.20114475]
Entropy: 1.43 bits. Code capacity used: 71.44577546801597 %
Layer: 12, Max: 1.0, Min: -1.0, Mean: -0.13642713758680555, Std: 0.6916967437120225
Values: [-1. 0. 1.]
Percent: [31.67419434 50.29432509 18.03148058]
Entropy: 1.47 bits. Code capacity used: 73.48354983372585 %
Layer: 16, Max: 1.0, Min: -1.0, Mean: -0.06987667083740234, Std: 0.6562856511567088
Values: [-1. -0. 1.]
Percent: [25.27351379 56.4406395 18.28584671]
Entropy: 1.42 bits. Code capacity used: 70.77351150852402 %
Layer: 18, Max: 1.0, Min: -1.0, Mean: -0.019989013671875, Std: 0.7926493968240628
Values: [-1. 0. 1.]
Percent: [32.43408203 37.1307373 30.43518066]
Entropy: 1.58 bits. Code capacity used: 78.99523722907314 %
Layer: 20, Max: 1.0, Min: -1.0, Mean: -0.31484375, Std: 0.7746561579732891
Values: [-1. -0. 1.]
Percent: [50.703125 30.078125 19.21875 ]
Entropy: 1.48 bits. Code capacity used: 73.77139839499056 %
Total number of bits: 6663808 (813.453125 kbytes)
inference of quantized model
layer: ('BitConv2d', 2)
layer: ('ReLU', 3)
layer: ('MaxPool2d', 4)
layer: ('BitConv2d', 5)
layer: ('ReLU', 6)
layer: ('MaxPool2d', 7)
layer: ('BitConv2d', 8)
layer: ('ReLU', 9)
layer: ('BitConv2d', 10)
layer: ('ReLU', 11)
layer: ('BitConv2d', 12)
layer: ('ReLU', 13)
layer: ('MaxPool2d', 14)
layer: ('BitLinear', 16)
layer: ('ReLU', 17)
layer: ('BitLinear', 18)
layer: ('ReLU', 19)
從 weight 的分佈來看 -1 & 0 是相對較高的,推測因為 model 沒有 bias 且每個 conv & fc 又經過 ReLU,所以會導致 conv 的輸入都是正數,而 conv 輸出會有盡量常態分布 (向 0 集中)的傾向,所以導致複數的 weights 大過正數。
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ gcc BitNetMCU_MNIST_test.c -o mnist_test -std=c99 -lm
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ ./mnist_test
fc3_out: -1076 -749 -462 1650 -1109 -424 -1317 -438 -478 -750
label: 3 predicted: 3
fc3_out: -615 -556 1597 -337 -600 -778 -444 -513 -436 -672
label: 2 predicted: 2
fc3_out: 1255 -731 -646 -759 -801 -594 -397 -831 -626 -436
label: 0 predicted: 0
fc3_out: -761 -948 -753 -724 -102 -707 -1141 -317 -495 1685
label: 9 predicted: 9
fc3_out: 1346 -758 -718 -820 -860 -650 -506 -871 -678 -394
label: 0 predicted: 0
fc3_out: -366 -527 -800 -965 -315 -345 1135 -1150 -538 -847
label: 6 predicted: 6
fc3_out: -635 -701 -549 -599 88 -633 -884 -311 -362 1210
label: 9 predicted: 9
fc3_out: -474 -314 803 -133 -286 -708 -675 250 -339 -416
label: 2 predicted: 2
fc3_out: -526 -532 -134 -344 -628 -730 -1272 1069 -400 -110
label: 7 predicted: 7
fc3_out: -858 -318 -519 -481 -672 -995 -1577 1423 -506 -64
label: 7 predicted: 7
Exec time
在程式碼前後紀錄 clock。
(用 if/else 取代 conv 的 mul 計算 vs 保持 conv 的 mul)
branch:
mul:
Layer | Branch (ms) | Mul (ms) |
---|---|---|
L2 (Conv) | 16.103 | 14.303 |
L3 (ReLUNorm) | 0.322 | 0.355 |
L4 (Maxp) | 0.329 | 0.316 |
L5 (Conv) | 245.586 | 213.306 |
L6 (ReLUNorm) | 0.201 | 0.201 |
L7 (Maxp) | 0.198 | 0.196 |
L8 (Conv) | 342.022 | 298.320 |
L9 (ReLUNorm) | 0.105 | 0.105 |
L10 (Conv) | 456.651 | 401.056 |
L11 (ReLUNorm) | 0.069 | 0.069 |
L12 (Conv) | 304.385 | 278.764 |
L13 (ReLUNorm) | 0.065 | 0.067 |
L14 (Maxp) | 0.066 | 0.067 |
L16 (FC) | 5.622 | 4.595 |
L17 (ReLUNorm) | 0.002 | 0.002 |
L18 (FC) | 0.172 | 0.129 |
L19 (ReLUNorm) | 0.001 | 0.001 |
L20 (FC) | 0.007 | 0.006 |
total | 1371.706 ms | 1211.458 ms |
雖然原本希望能藉由省去 mult 讓計算加速,但實際上如果用 if/else 來判斷加減會有更多的 branch 導致更高的 cost。
bitwise 的計算來同時避免乘法及 branch:
Unpack | Pack |
---|---|
-1 | 00 |
0 | 01 |
1 | 10 |
模型推論時間幾乎 全部卡在卷積層 processcvlayer_I2_S(可以做 loop unrolling、
Layer | Entropy(bits) | Code capacity used |
---|---|---|
L2 | 1.57 | 78.33160789015268 % |
L5 | 1.49 | 74.68134472269595 % |
L8 | 1.45 | 72.42034190610372 % |
L10 | 1.43 | 71.44577546801597 % |
L12 | 1.47 | 73.48354983372585 % |
L16 | 1.42 | 70.77351150852402 % |
L18 | 1.58 | 78.99523722907314 % |
L20 | 1.48 | 73.77139839499056 % |
畢竟 2bits 只用來儲存 [-1, 1] 的 weights ,理所當然 capacity 不好。
Cache miss 比例偏高,TODO : 最佳化 memory layout (huge page)
如果單純做 unroll-loops 雖然減少 branch 提高計算速度,但會因為指令邊多導致 cache miss rate 提高 (I-cache)
利用 gcc 最佳化編譯( additional instruction set, loop unrolling..)
雖然因為 loop unrolling 導致 miss rate 提升,但執行速度大幅提升(8倍以上)
同時 loop unrolling 也使 conv2 的負擔減少(少到抓不到?)
文字訊息不要用圖片展現!