
Linux Kernel Term Project: Accelerating BitNet via System-Level Techniques

Contributor: Denny0097

Task Overview

Against the backdrop of ever-growing resource demands for large language models, Microsoft Research released BitNet b1.58 2B4T, an open-source model trained with 1.58-bit quantization. During inference, BitNet's matrix multiplications reduce to additions, subtractions, and skips; floating-point multiplication is eliminated entirely. In practical deployments this lowers latency, reduces memory-access pressure, and further cuts energy consumption. Such a minimalist compute flow is also well suited to hardware-accelerator design.
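As a concrete illustration of why ternary weights eliminate multiplication, a weight row drawn from {-1, 0, 1} turns the dot product into a signed sum:

$$\begin{pmatrix} 1 & 0 & -1 \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \\ x_3 \end{pmatrix} = x_1 - x_3$$

Each weight either adds its activation, subtracts it, or skips it.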

BitNet b1.58 adopts a heavily simplified and redesigned Transformer architecture: biases are removed from the linear and normalization layers, traditional fully connected layers are replaced with BitLinear layers, and the model uses rotary position embeddings (RoPE), the ReLU² activation function, and subLN normalization.
The BitNet b1.58 2B4T model has about 2 billion parameters and was trained on 4 trillion tokens of natural language (primarily English). It uses the same tokenizer as Llama 3, with a vocabulary size of 128,256, and supports context lengths up to 4096 tokens. Training proceeds in three stages: pretraining, supervised fine-tuning (SFT), and preference alignment (DPO), striking a good balance between capability and conversational quality.

BitNet b1.58 runs very efficiently. On a general-purpose CPU such as the Apple M2 it reaches 29 ms latency with a memory footprint of only 0.4 GB, far below, for example, the 1.4 GB required by Gemma 3 1B. In Microsoft's benchmarks against full-precision models such as Llama 3.2 1B, Gemma 3 1B, and Qwen 2.5 1.5B, BitNet maintained stable results on MMLU, GSM8K, MATH, and other tasks despite fewer parameters and lower weight precision, and took the best score on GSM8K. At scales of 3B parameters and above its quality approaches FP16 models, and at the 7B scale it delivers up to 4x faster inference and 7x memory savings, demonstrating strong scaling potential.

This project applies the profiling tools provided by the Linux kernel to locate BitNet's runtime performance bottlenecks, then leverages Transparent Hugepage Support, event-driven I/O models (such as io_uring), and other techniques covered in the course to accelerate BitNet.

Alan's experiments

Outline

  • TODO
  • BitNet Experiment
    • Env
  • SW
    • Model
      • Training data
    • Quantization scheme (QAT)
    • Code architecture
    • Analysis
      • Inference time
      • Accuracy
      • Size (parameter count, total bits, memory)
  • HW simulation

TODO

Referring to BitNet and BitNetMCU, build a VGG8 for the MNIST dataset with BitNet quantization (model.h) on Linux. Use the profiling tools provided by the Linux kernel to analyze performance bottlenecks during training and inference, design CPU and memory-usage optimizations based on the analysis, simulate the hardware computation with SystemC, and additionally attempt hardware acceleration.

Study BitNet [1]:

  • Run BitNet b1.58 2B4T on a GNU/Linux system and reproduce the paper's experiments

Performance improvements:

  • Using perf and related tools, measure the 20 functions consuming the most compute resources during inference and examine their roles
  • Analyze memory usage, in particular statistics such as page faults and TLB misses. Mining programs like XMRig [2] gain measurable speedups by exploiting huge pages (or THP); a minimal sketch of requesting THP for a model buffer follows this list
  • Evaluate T-MAC [3] [5], in particular the benefit of its lookup tables combined with BitNet, and record perf event statistics during the process
  • Examine the model-loading mechanism and whether it can be accelerated with mechanisms such as splice [4]
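As a minimal sketch of the huge-page idea above (my own illustration, not project code; the arena size and names are hypothetical): allocate a weight buffer aligned to 2 MiB and mark it with madvise(MADV_HUGEPAGE), which is how THP is requested when the system runs in madvise mode.

#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>

/* Allocate `size` bytes aligned to the 2 MiB huge-page size and advise the
 * kernel to back the region with transparent huge pages. In THP's "madvise"
 * mode, only regions marked like this are eligible for huge pages. */
static void *alloc_thp(size_t size)
{
    const size_t huge = 2UL << 20;                 /* 2 MiB */
    size_t rounded = (size + huge - 1) & ~(huge - 1);
    void *p = NULL;

    if (posix_memalign(&p, huge, rounded) != 0)
        return NULL;
    if (madvise(p, rounded, MADV_HUGEPAGE) != 0)   /* advisory only */
        perror("madvise(MADV_HUGEPAGE)");
    return p;
}

int main(void)
{
    void *weights = alloc_thp(64UL << 20);         /* hypothetical 64 MiB arena */
    if (!weights)
        return 1;
    /* ... load packed model weights here: large sequential scans incur
     * far fewer dTLB misses when the kernel grants 2 MiB pages ... */
    free(weights);
    return 0;
}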

Modify BitNetMCU [6]:

  • Introduce VGG8 and try a more complex training dataset (e.g., CIFAR-10)
  • Add padding and maxpool
  • Complete the unfinished ternary QuantTypes (I2_S, TL1, TL2 kernels)
  • Debugging (incorrect export format, hard-coded tests)
  • Implement a LUT

HW simulation

  • C sim or Verilator

[1] BitNet: https://github.com/microsoft/BitNet
[2] XMRig huge pages: https://xmrig.com/docs/miner/hugepages
[3] T-MAC: https://github.com/microsoft/T-MAC
[4] Linux zero-copy (splice): https://hackmd.io/@sysprog/linux-zerocopy
[5] BitNet's LUT kernels: https://github.com/microsoft/BitNet/tree/main/src
[6] BitNetMCU: https://github.com/cpldcpu/BitNetMCU
[7] Transparent Hugepage Support: https://www.kernel.org/doc/html/next/admin-guide/mm/transhuge.html

BitNet Experiment

Env

Experimental environment
(base) denny0097:~/linux2025$ cat /etc/os-release
PRETTY_NAME="Ubuntu 24.04.2 LTS"
NAME="Ubuntu"
VERSION_ID="24.04"
VERSION="24.04.2 LTS (Noble Numbat)"
VERSION_CODENAME=noble
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=noble
LOGO=ubuntu-logo

(base) denny0097:~/linux2025$ lscpu
Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          39 bits physical, 48 bits virtual
  Byte Order:             Little Endian
CPU(s):                   16
  On-line CPU(s) list:    0-15
Vendor ID:                GenuineIntel
  Model name:             Intel(R) Core(TM) i7-10700 CPU @ 2.90GHz
  
(base) denny0097:~/linux2025$ gcc --version
gcc (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Copyright (C) 2023 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ conda info

     active environment : bitnet-cpp
    active env location : /home/denny0097/miniconda3/envs/bitnet-cpp
            shell level : 2
       user config file : /home/denny0097/.condarc
 populated config files : /home/denny0097/miniconda3/.condarc
          conda version : 25.3.1
    conda-build version : not installed
         python version : 3.13.2.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=skylake
                          __conda=25.3.1=0
                          __cuda=12.4=0
                          __glibc=2.39=0
                          __linux=6.11.0=0
                          __unix=0=0
       base environment : /home/denny0097/miniconda3  (writable)

Run instruction:

python run_inference.py -m models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf -p "You are a helpful assistant" -cnv

Result:

> User: Tell me about the architecture of BitNet.
BitNet is a software communication network that connects devices and provides a common architecture for the Internet, but it's not a place in Barcelona, Spain, although it might seem like that from the name. It's actually a network protocol developed by the University of California, Berkeley in the 197

Benchmark:

python utils/e2e_benchmark.py -m /path/to/model -n 200 -p 256 -t 4  

This command runs the inference benchmark with the model at /path/to/model, generating 200 tokens from a 256-token prompt using 4 threads; pp256 measures prompt-processing (prefill) throughput and tg200 measures token-generation (decode) throughput, both in tokens per second.

Result:

model                                  size      params  backend  threads  n_batch  test   t/s
bitnet-b1.58 2B I2_S - 2 bpw ternary   1.71 GiB  2.74 B  CPU      4        1        pp256  15.86 ± 0.04
bitnet-b1.58 2B I2_S - 2 bpw ternary   1.71 GiB  2.74 B  CPU      4        1        tg200  15.75 ± 0.10

SW

Model

Model summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
         BitConv2d-1           [-1, 64, 32, 32]             576
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
         BitConv2d-4          [-1, 192, 16, 16]         110,592
              ReLU-5          [-1, 192, 16, 16]               0
         MaxPool2d-6            [-1, 192, 8, 8]               0
         BitConv2d-7            [-1, 384, 8, 8]         663,552
              ReLU-8            [-1, 384, 8, 8]               0
         BitConv2d-9            [-1, 256, 8, 8]         884,736
             ReLU-10            [-1, 256, 8, 8]               0
        BitConv2d-11            [-1, 256, 8, 8]         589,824
             ReLU-12            [-1, 256, 8, 8]               0
        MaxPool2d-13            [-1, 256, 4, 4]               0
          Flatten-14                 [-1, 4096]               0
        BitLinear-15                  [-1, 256]       1,048,576
             ReLU-16                  [-1, 256]               0
        BitLinear-17                  [-1, 128]          32,768
             ReLU-18                  [-1, 128]               0
        BitLinear-19                   [-1, 10]           1,280
================================================================
Total params: 3,331,904
Trainable params: 3,331,904
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.91
Params size (MB): 12.71
Estimated Total Size (MB): 15.63
----------------------------------------------------------------


trainingparameters.yaml
# Description: Training parameters for the training script

# Model selection
model: 'VGG' # 'FCMNIST' or 'CNNMNIST' This is the class name of the model as defined in models.py.

# Quantization settings
QuantType: 'I2_S' # 'Ternary', 'Binary', 'BinaryBalanced', '2bitsym', '4bit', '4bitsym', '8bit', 'None', 'FP130', 'NF4', 'I2_S'
NormType: 'RMS' # 'RMS', 'Lin', 'BatchNorm'
WScale: 'PerTensor' # 'PerTensor', 'PerOutput'

# Clipping parameters - only used for 2 bit and higher quantization
maxw_algo: 'octav' # 'octav', 'prop' Algorithm used to calculate the clipping parameters (maximum weight)
maxw_update_until_epoch: 50 # Update clipping parameters until this epoch, they are frozen afterwards
maxw_quantscale: 0.25  # Used only for clipping_algo='prop'. Determines the relation between stddev of weights and max_weight

# Learning parameters
num_epochs: 50
batch_size: 32
scheduler: "Cosine" # "StepLR", "Cosine"
learning_rate: 0.001
lr_decay: 0.1     # lr_decay and step size are not used with cosine scheduler
step_size: 10
# halve_lr_epoch: 30  # Epoch at which to halve the learning rate

# Data augmentation
augmentation: True
rotation1: 10  # rotation1 and rotation2 are used for data augmentation
rotation2: 10

# Model parameters
network_width1: 256
network_width2: 128
network_width3: 0

# name
runtag: "octav" # runtag is prefix for runname

Output quantized model

BitConv2d and BitLinear are the layer classes provided by BitNetMCU; they support normalization (RMS) and a QAT forward pass. Note the straight-through-estimator pattern in the quantized path below: x_norm + (x_int / x_scale - x_norm).detach() makes the forward pass use the quantized values while gradients still flow through the full-precision tensors.

class BitConv2d(nn.Conv2d, BitQuant):
    """
    2D convolution layer with quantization aware training and normalization.
    Configurable quantization and normalization types.

    Normalization Types:
    - RMS : Root Mean Square
    - None : No normalization

    @cpldcpu 2024-June-2
    """
    # def __init__ ...

    def forward(self, x):
        """
        Args:
            x: an input tensor with shape [n, d]
        Returns:
            y: an output tensor with shape [n, k]
        """
        w = self.weight  # a weight tensor with shape [d, k]
        x_norm = self.Normalize(x)
        if self.QuantType == 'None':
            y = F.conv2d(x_norm, w, stride=self.stride, padding=self.padding, groups=self.groups)
        else:
            x_int, x_scale = self.activation_quant(x_norm)
            x_quant = x_norm + (x_int / x_scale - x_norm).detach()
            w_int, w_scale, _ = self.weight_quant(w)
            w_quant = w + (w_int / w_scale - w).detach()
            y = F.conv2d(x_quant, w_quant, groups=self.groups, stride=self.stride, padding=self.padding, bias=None)
        return y

Training data set:

MNIST

Quantization scheme (QAT)

Weight

w_scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
w_int = (w * w_scale).round().clamp_(-1, 1)

$$s = \frac{1}{\max(\mathrm{mean}(|w|),\,10^{-5})}, \qquad q(w) = \mathrm{clamp}(\mathrm{round}(w \cdot s),\,-1,\,1)$$
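A quick worked example with made-up numbers: for $w = (0.4,\,-0.2,\,0.05)$, $\mathrm{mean}(|w|) \approx 0.217$, so $s \approx 4.62$ and

$$\mathrm{round}(w \cdot s) = (2,\,-1,\,0) \;\xrightarrow{\ \mathrm{clamp}\ }\; q(w) = (1,\,-1,\,0)$$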

Activation

Training

scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127)

Inference

scale = 127.0 / np.maximum(np.abs(input_data).max(axis=-1, keepdims=True), 1e-5)
current_data = np.round(input_data * scale).clip(-128, 127)

$$s = \frac{127}{\max(\max(|x|),\,10^{-5})}, \qquad q(x) = \mathrm{clamp}(\mathrm{round}(x \cdot s),\,-128,\,127)$$

HW friendly:

s = max(|x|) >> 7
while (s > 0) { s >>= 1; shift++; }
rounding = (1 << shift) >> 1
q(x) = (x + rounding) >> shift

Compute the maximum absolute value of the input, $\max(|x|)$, and find the power of two $2^{\text{shift}+7}$ that just covers it; this serves as the quantization range. Dividing by $2^{\text{shift}}$ (a right shift) then maps values into [-128, 127], but a bare shift discards the fractional part unconditionally. For better accuracy a rounding term is added first: $(x + 2^{\text{shift}-1}) / 2^{\text{shift}}$. The whole procedure uses neither multiplication nor floating-point arithmetic.
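A minimal C sketch of this shift-only quantizer (my own illustration of the scheme above; the function and variable names are hypothetical, and an arithmetic right shift for negative values is assumed, as on gcc/x86):

#include <stdint.h>

/* Quantize int32 activations to [-128, 127] using only shifts and adds:
 * pick shift so that 2^(shift+7) covers max|x|, then round-and-shift. */
static void quant_act_shift(const int32_t *x, int8_t *q, int n)
{
    int32_t maxabs = 0;
    for (int i = 0; i < n; i++) {
        int32_t a = x[i] < 0 ? -x[i] : x[i];
        if (a > maxabs)
            maxabs = a;
    }

    int shift = 0;
    int32_t s = maxabs >> 7;               /* how far max|x| exceeds 7 bits */
    while (s > 0) {
        s >>= 1;
        shift++;
    }

    int32_t rounding = (1 << shift) >> 1;  /* 2^(shift-1); 0 when shift == 0 */
    for (int i = 0; i < n; i++) {
        int32_t v = (x[i] + rounding) >> shift;
        if (v > 127) v = 127;              /* clamp guards rounding overflow */
        if (v < -128) v = -128;
        q[i] = (int8_t)v;
    }
}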

BitNet I2_S:

QAT is performed on top of the BitLinear and BitConv2d layers provided by BitNetMCU, so the model learns during training to cope with low bit-width inference.

However, BitNetMCU currently supports neither padding, max pooling, nor BatchNorm (a model trained with BatchNorm shows abnormally low accuracy, close to random guessing, once exported to a .c file), and, most importantly, the original export.py does not yet support exporting any BitNet format.

I therefore added maxpool and padding support, trained with CIFAR-10, and first implemented exporting an I2_S (2-bit) model, so that the VGG8 model used in the lab can be trained in this project and emitted as a BitNet model.

Export result
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ python exportquant.py
Load parameters from file: trainingparameters.yaml
octav_VGG_Aug_BitMnist_I2_S_width256_128_0_epochs10
Loading model...
Inference using the original model...
Accuracy/Test of trained model: 99.67 %
Quantizing model...
0 VGG
1 Sequential
2 BitConv2d
3 ReLU
4 MaxPool2d
5 BitConv2d
6 ReLU
7 MaxPool2d
8 BitConv2d
9 ReLU
10 BitConv2d
11 ReLU
12 BitConv2d
13 ReLU
14 MaxPool2d
15 Flatten
16 BitLinear
17 ReLU
18 BitLinear
19 ReLU
20 BitLinear

Layer: 2, Max: 1.0, Min: -1.0, Mean: -0.11458333333333333, Std: 0.8317041365974641
Values: [-1.  0.  1.]
Percent: [40.97222222 29.51388889 29.51388889]
Entropy: 1.57 bits. Code capacity used: 78.33160789015268 %

Layer: 5, Max: 1.0, Min: -1.0, Mean: -0.2090747974537037, Std: 0.7201310989476462
Values: [-1.  0.  1.]
Percent: [38.5687934  43.76989294 17.66131366]
Entropy: 1.49 bits. Code capacity used: 74.68134472269595 %

Layer: 8, Max: 1.0, Min: -1.0, Mean: -0.14563289689429013, Std: 0.6785549412400004
Values: [-1.  0.  1.]
Percent: [31.36393229 51.83542511 16.8006426 ]
Entropy: 1.45 bits. Code capacity used: 72.42034190610372 %

Layer: 10, Max: 1.0, Min: -1.0, Mean: -0.17259724934895834, Std: 0.6684536886212908
Values: [-1.  0.  1.]
Percent: [32.46086968 52.33798557 15.20114475]
Entropy: 1.43 bits. Code capacity used: 71.44577546801597 %

Layer: 12, Max: 1.0, Min: -1.0, Mean: -0.13642713758680555, Std: 0.6916967437120225
Values: [-1.  0.  1.]
Percent: [31.67419434 50.29432509 18.03148058]
Entropy: 1.47 bits. Code capacity used: 73.48354983372585 %

Layer: 16, Max: 1.0, Min: -1.0, Mean: -0.06987667083740234, Std: 0.6562856511567088
Values: [-1. -0.  1.]
Percent: [25.27351379 56.4406395  18.28584671]
Entropy: 1.42 bits. Code capacity used: 70.77351150852402 %

Layer: 18, Max: 1.0, Min: -1.0, Mean: -0.019989013671875, Std: 0.7926493968240628
Values: [-1.  0.  1.]
Percent: [32.43408203 37.1307373  30.43518066]
Entropy: 1.58 bits. Code capacity used: 78.99523722907314 %

Layer: 20, Max: 1.0, Min: -1.0, Mean: -0.31484375, Std: 0.7746561579732891
Values: [-1. -0.  1.]
Percent: [50.703125 30.078125 19.21875 ]
Entropy: 1.48 bits. Code capacity used: 73.77139839499056 %
Total number of bits: 6663808 (813.453125 kbytes)
inference of quantized model
layer: ('BitConv2d', 2)
layer: ('ReLU', 3)
layer: ('MaxPool2d', 4)
layer: ('BitConv2d', 5)
layer: ('ReLU', 6)
layer: ('MaxPool2d', 7)
layer: ('BitConv2d', 8)
layer: ('ReLU', 9)
layer: ('BitConv2d', 10)
layer: ('ReLU', 11)
layer: ('BitConv2d', 12)
layer: ('ReLU', 13)
layer: ('MaxPool2d', 14)
layer: ('BitLinear', 16)
layer: ('ReLU', 17)
layer: ('BitLinear', 18)
layer: ('ReLU', 19)

Weight Distribution

(Figure: I2_S weight distribution)

Looking at the weight distribution, -1 and 0 occur noticeably more often than +1. My guess: the model has no biases and every conv/FC layer is followed by ReLU, so each conv layer's inputs are all non-negative, while its outputs tend toward a roughly normal, zero-centered distribution; pulling positive inputs back toward zero therefore requires negative weights to outnumber positive ones.

Code architecture

Adding VGG
class VGG(nn.Module):
    def __init__(self,network_width1=256,network_width2=128,network_width3=0,QuantType='Binary',WScale='PerTensor',NormType='BatchNorm', in_channels=1, in_size=32, num_classes=10):
        super(VGG, self).__init__()
        # Conv1
        self.conv1 = nn.Sequential(
            BitConv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, groups=1,QuantType=QuantType,NormType='None', WScale=WScale),
            # nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 32×32 -> 16×16
        )
        
        # Conv2
        self.conv2 = nn.Sequential(
            BitConv2d(64, 192, kernel_size=3, stride=1, padding=1, groups=1,QuantType=QuantType,NormType='None', WScale=WScale),
            # nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 16×16 -> 8×8
        )
        
        # Conv3
        self.conv3 = nn.Sequential(
            BitConv2d(192, 384, kernel_size=3, stride=1, padding=1, groups=1,QuantType=QuantType,NormType='None', WScale=WScale), 
            # nn.BatchNorm2d(384),
            nn.ReLU(inplace=True)
            # nn.MaxPool2d(kernel_size=2, stride=2)  # 8×8 -> 4×4
        )
        
        # Conv4 (1 layer)
        self.conv4 = nn.Sequential(
            BitConv2d(384, 256, kernel_size=3, stride=1, padding=1, groups=1,QuantType=QuantType,NormType='None', WScale=WScale),
            # nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # Conv5 (1 layer)
        self.conv5 = nn.Sequential(
            BitConv2d(256, 256, kernel_size=3, stride=1, padding=1, groups=1,QuantType=QuantType,NormType='None', WScale=WScale),
            # nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 8×8 -> 4×4
        )
        
        # Fully Connected Layers
        fmap_size = in_size // 8   # 32 -> 16 -> 8 -> 4 (three MaxPools)
        self.fc6 = nn.Sequential(
            nn.Flatten(),
            BitLinear(256 * fmap_size * fmap_size, network_width1,QuantType=QuantType,NormType=NormType, WScale=WScale), # 256×4×4 = 4096
            nn.ReLU()
        )
        self.fc7 = nn.Sequential(
            # nn.Flatten(),
            BitLinear(network_width1, network_width2,QuantType=QuantType,NormType=NormType, WScale=WScale),
            nn.ReLU()
        )        
       
        # Final classifier
        self.fc8 = BitLinear(network_width2, 10,QuantType=QuantType,NormType=NormType, WScale=WScale)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.fc6(x)
        x = self.fc7(x)
        x = self.fc8(x)
        return x
Adding padding
for layer_info in self.quantized_model[:-1]:  # For all layers except the last one
    print(f'layer: {layer_info["layer_type"], layer_info["layer_order"]}')
    # .....

    elif layer_info['layer_type'] == 'BitConv2d':
        # .....

        padding = layer_info['padding']
        groups = layer_info['groups']
        in_channels = layer_info['in_channels']
        out_channels = layer_info['out_channels']

        # Apply padding
        if padding > 0:
            current_data = np.pad(
                current_data,
                pad_width=((0, 0), (0, 0), (padding, padding), (padding, padding)),
                mode='constant',
                constant_values=(0, 0)
            )
                
Adding a maxpool layer
# ...
            elif layer_info['layer_type'] == 'MaxPool2d':
                kernel_size = layer_info['kernel_size'] # Assuming square kernel
                stride = layer_info['stride']

                # Extract input dimensions
                batch_size, channels, height, width = current_data.shape

                out_height = (height - kernel_size) // stride + 1
                out_width = (width - kernel_size) // stride + 1

                # Initialize output
                output = np.zeros((batch_size, channels, out_height, out_width), dtype=current_data.dtype)

                # Perform max pooling
                for i in range(out_height):
                    for j in range(out_width):
                        h_start = i * stride
                        h_end = h_start + kernel_size
                        w_start = j * stride
                        w_end = w_start + kernel_size

                        patch = current_data[:, :, h_start:h_end, w_start:w_end]
                        output[:, :, i, j] = np.max(patch, axis=(2, 3))

                current_data = output
Fixing exportquant's incorrect packing of the quantized model
elif layer_info['layer_type'] == 'BitConv2d':
                in_channels = layer_info['in_channels']
                out_channels = layer_info['out_channels']
                incoming_x = layer_info['incoming_x']
                incoming_y = layer_info['incoming_y']
                outgoing_x = layer_info['outgoing_x']
                outgoing_y = layer_info['outgoing_y']
                
                padding = layer_info['padding']
                groups = layer_info['groups']
                kernel_size = layer_info['kernel_size'][0]  # Assuming square kernel
                bpw = layer_info['bpw']
                quantization_type = layer_info['quantization_type']
                weights = np.array(layer_info['quantized_weights'])
                bias = layer_info.get('bias', None)

                f.write(f'// Layer: {layer} (Convolutional)\n')
                f.write(f'#define {layer}_active\n')
                f.write(f'#define {layer}_type BitConv2d\n')
                f.write(f'#define {layer}_in_channels {in_channels}\n')
                f.write(f'#define {layer}_out_channels {out_channels}\n')
                f.write(f'#define {layer}_incoming_x {incoming_x}\n')
                f.write(f'#define {layer}_incoming_y {incoming_y}\n')
                f.write(f'#define {layer}_outgoing_x {outgoing_x}\n')
                f.write(f'#define {layer}_outgoing_y {outgoing_y}\n')
                f.write(f'#define {layer}_kernel_size {kernel_size}\n')
                f.write(f'#define {layer}_stride 1\n')
                f.write(f'#define {layer}_padding {padding}\n')
                f.write(f'#define {layer}_groups {groups}\n')
                f.write(f'#define {layer}_bitperweight {bpw}\n')

                # if (bpw*incoming_weights%32) != 0:
                #     raise ValueError(f"Size mismatch: Incoming weights must be packed to 32bit boundary. Incoming weights: {incoming_weights} Bit per weight: {bpw} Total bits: {bpw*incoming_weights}")

                data_type = np.uint32

                if quantization_type == 'Binary':
                    encoded_weights = np.where(weights == -1, 0, 1)
                    QuantID = 1
                elif quantization_type == '2bitsym': # encoding -1.5 -> 11, -0.5 -> 10, 0.5 -> 00, 1.5 -> 01 (one complement with offset)
                    encoded_weights = ((weights < 0).astype(data_type) << 1) | (np.floor(np.abs(weights))).astype(data_type)  # use bitwise operations to encode the weights
                    QuantID = 2
                # I2_S
                elif quantization_type == 'I2_S': # encoding -1 -> 00, 0 -> 01, 1 -> 10
                    encoded_weights = (weights + 1).astype(data_type)  # shift {-1,0,1} to {0,1,2} before the unsigned cast
                    QuantID = 2
                elif quantization_type == '4bitsym':
                    encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.abs(weights))).astype(data_type)  # use bitwise operations to encode the weights
                    QuantID = 4
                elif quantization_type == '4bit':
                    encoded_weights = np.floor(weights).astype(data_type) & 15  # twos complement encoding
                    QuantID =  8 + 4
                elif quantization_type == 'NF4':
                    levels = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
                                   0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.723, 1.0])
                    encoded_weights = np.argmin(np.abs(weights[:, :, np.newaxis] - levels), axis=2)
                    QuantID = 32 + 4
                elif quantization_type == '8bit':
                    encoded_weights = np.floor(weights).astype(data_type) & 255  # twos complement encoding
                    QuantID =  8
                elif quantization_type == 'FP130': # FP1.3.0 encoding (sign * 2^exp)
                    encoded_weights = ((weights < 0).astype(data_type) << 3) | (np.floor(np.log2(np.abs(weights)))).astype(data_type)
                    QuantID = 16 + 4
                else:
                    print(f'Skipping layer {layer} with quantization type {quantization_type} and {bpw} bits per weight. Quantization type not supported.')
                    continue  # skip packing for unsupported types

                # pack bits into 32 bit words
                weight_per_word = 32 // bpw
                reshaped_array = encoded_weights.reshape(-1, weight_per_word)

                # reverse arange to match C language LSB first reading order
                bit_positions = np.arange(weight_per_word, dtype=data_type) * bpw
                packed_weights = np.bitwise_or.reduce(reshaped_array << bit_positions, axis=1).view(data_type)

                f.write(f'const uint32_t {layer}_packed_weights[] = {{')
                for i, data in enumerate(packed_weights.flatten()):
                    if i % 32 == 0:
                        f.write('\n\t')
                    f.write(f'{data},')
                f.write('\n};\n\n')

                if 'bias' in layer_info and layer_info['bias'] is not None:
                    bias = np.array(layer_info['bias']).astype(data_type)
                    f.write(f'const int32_t {layer}_bias[] = {{')
                    for i, data in enumerate(bias.flatten()):
                        if i % 8 == 0:
                            f.write('\n\t')
                        f.write(f'{data},')
                    f.write('\n};\n\n')
                else:
                    f.write(f'// No bias for layer {layer}\n')

                print(f'Layer: {layer} Conv2d bpw: {bpw} {in_channels} -> {out_channels} groups:{groups} Kernel: {kernel_size}x{kernel_size}')
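To sanity-check the packing order, here is a hedged C sketch (my own illustration, not project code) of how the inference side would unpack weights that exportquant.py packed LSB-first, 16 two-bit fields per uint32_t, with the I2_S encoding -1 → 00, 0 → 01, 1 → 10:

#include <stdint.h>

/* Decode the i-th 2-bit I2_S weight from an LSB-first packed array.
 * Encoding: 00 -> -1, 01 -> 0, 10 -> +1 (stored value = weight + 1). */
static inline int8_t unpack_i2s(const uint32_t *packed, int i)
{
    uint32_t word = packed[i / 16];          /* 16 weights per 32-bit word */
    uint32_t code = (word >> ((i % 16) * 2)) & 0x3;
    return (int8_t)code - 1;                 /* undo the +1 offset */
}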
Adding a test inference for I2_S
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ gcc BitNetMCU_MNIST_test.c  -o mnist_test -std=c99 -lm
(bitnet-cpp) denny0097:~/linux2025/BitNetMCU-main$ ./mnist_test 
fc3_out: -1076 -749 -462 1650 -1109 -424 -1317 -438 -478 -750
label: 3 predicted: 3
fc3_out: -615 -556 1597 -337 -600 -778 -444 -513 -436 -672
label: 2 predicted: 2
fc3_out: 1255 -731 -646 -759 -801 -594 -397 -831 -626 -436
label: 0 predicted: 0
fc3_out: -761 -948 -753 -724 -102 -707 -1141 -317 -495 1685
label: 9 predicted: 9
fc3_out: 1346 -758 -718 -820 -860 -650 -506 -871 -678 -394
label: 0 predicted: 0
fc3_out: -366 -527 -800 -965 -315 -345 1135 -1150 -538 -847
label: 6 predicted: 6
fc3_out: -635 -701 -549 -599 88 -633 -884 -311 -362 1210
label: 9 predicted: 9
fc3_out: -474 -314 803 -133 -286 -708 -675 250 -339 -416
label: 2 predicted: 2
fc3_out: -526 -532 -134 -344 -628 -730 -1272 1069 -400 -110
label: 7 predicted: 7
fc3_out: -858 -318 -519 -481 -672 -995 -1577 1423 -506 -64
label: 7 predicted: 7

Analysis

Exec time
Record clock() before and after the code under test, comparing two conv kernels: one replacing the multiplications with if/else, one keeping the multiplications.

branch:
$$\text{sum} \leftarrow \begin{cases} \text{sum} - \text{act}, & \text{if } w = -1 \\ \text{sum} + \text{act}, & \text{if } w = +1 \end{cases}$$

mul:
$$\text{sum} \leftarrow \text{sum} + \text{act} \cdot w$$

Layer            Branch (ms)   Mul (ms)
L2 (Conv)        16.103        14.303
L3 (ReLUNorm)    0.322         0.355
L4 (Maxp)        0.329         0.316
L5 (Conv)        245.586       213.306
L6 (ReLUNorm)    0.201         0.201
L7 (Maxp)        0.198         0.196
L8 (Conv)        342.022       298.320
L9 (ReLUNorm)    0.105         0.105
L10 (Conv)       456.651       401.056
L11 (ReLUNorm)   0.069         0.069
L12 (Conv)       304.385       278.764
L13 (ReLUNorm)   0.065         0.067
L14 (Maxp)       0.066         0.067
L16 (FC)         5.622         4.595
L17 (ReLUNorm)   0.002         0.002
L18 (FC)         0.172         0.129
L19 (ReLUNorm)   0.001         0.001
L20 (FC)         0.007         0.006
Total            1371.706      1211.458

Although the original hope was that removing the multiplication would speed up the computation, deciding between add and subtract with if/else introduces extra branches, which cost more than the multiplications they replace.

A bitwise formulation avoids both the multiplication and the branch:

Weight   Packed bits
-1       00
 0       01
 1       10

int8_t delta = (-((int8_t)(weight == 0x2)) & act) |
               (-((int8_t)(weight == 0x0)) & (-act));
sum += delta;
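Putting the unpacking and the branchless update together, the resulting ternary dot-product loop might look like this (my own sketch under the encoding above; names are hypothetical):

#include <stdint.h>

/* Branchless ternary dot product over I2_S-packed weights:
 * 16 weights per uint32_t, LSB-first, encoded 00 -> -1, 01 -> 0, 10 -> +1. */
static int32_t dot_i2s(const uint32_t *packed, const int8_t *act, int n)
{
    int32_t sum = 0;
    for (int i = 0; i < n; i++) {
        uint32_t w = (packed[i / 16] >> ((i % 16) * 2)) & 0x3;
        int32_t a = act[i];
        /* each mask is all-ones when its comparison holds, zero otherwise;
         * at most one of the two is nonzero, so OR-ing them is safe */
        int32_t delta = ((-(int32_t)(w == 0x2)) & a) |
                        ((-(int32_t)(w == 0x0)) & -a);
        sum += delta;
    }
    return sum;
}

With -O2, the comparisons typically lower to branch-free sete/neg sequences, so the loop body contains no data-dependent branches.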

perf

  1. CPU
perf record ./mnist_test
perf report 

(Screenshot: perf report output)

Inference time is almost entirely spent in the convolution kernel processcvlayer_I2_S (candidate optimizations: loop unrolling, etc.).

  2. Mem

Layer   Entropy (bits)   Code capacity used
L2      1.57             78.33 %
L5      1.49             74.68 %
L8      1.45             72.42 %
L10     1.43             71.45 %
L12     1.47             73.48 %
L16     1.42             70.77 %
L18     1.58             79.00 %
L20     1.48             73.77 %

Since the 2-bit code only ever stores the three values {-1, 0, 1}, the capacity utilization is naturally poor: a ternary weight carries at most log2(3) ≈ 1.585 bits of the 2 bits spent.
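As a worked check using the layer 2 frequencies reported by exportquant.py above, $p \approx (0.4097,\,0.2951,\,0.2951)$:

$$H = -\sum_i p_i \log_2 p_i = -\bigl(0.4097 \log_2 0.4097 + 2 \cdot 0.2951 \log_2 0.2951\bigr) \approx 1.57\ \text{bits}$$

and the code capacity used is $H / 2\,\text{bits} \approx 78.3\%$, matching the table.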

perf stat -e cache-misses,cache-references,cycles ./mnist_test

(Screenshot: perf stat output)

The cache-miss ratio is on the high side. TODO: optimize the memory layout (huge pages).
Plain loop unrolling reduces branches and speeds up the computation, but the larger code footprint raises the cache miss rate (I-cache).

(Screenshot: perf stat output after loop unrolling)

Use gcc's optimizations (additional instruction sets, loop unrolling, etc.):

gcc -O3 -march=native -funroll-loops -o mnist_test BitNetMCU_MNIST_test.c -lm

(Screenshot: perf stat of the optimized build)

Although loop unrolling raises the miss rate, execution speed improves dramatically (more than 8x).
At the same time, loop unrolling reduces conv2's share of the runtime (so small that perf barely samples it?).

(Screenshot: perf report of the optimized build)

valgrind

valgrind --tool=massif ./cifar10_test 
 ms_print massif.out.<pid>

    MB
5.199^                                   ::::::::::::::::::::::::::::::::::::#
     |                                ::::                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |                                @  :                                   #
     |@:::::::::::::::::::::::::::::::@  :                                   #
     |@                            :::@  :                                   #
     |@                            :::@  :                                   #
     |@                            :::@  :                                   #
   0 +----------------------------------------------------------------------->ki
     0                                                                   161.6

Number of snapshots: 30
 Detailed snapshots: [9, 13, 23, 28 (peak)]
--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
  0              0            4,096            4,096             0            0
  1              0           12,288           12,288             0            0
  2              0          847,872          847,872             0            0
  3              0          888,832          888,832             0            0
  4              0          892,928          892,928             0            0
  5              0        1,069,056        1,069,056             0            0
  6              0        1,110,016        1,110,016             0            0
  7              0        1,126,400        1,126,400             0            0
  8              0        1,130,496        1,130,496             0            0
  9              0        1,134,592        1,134,592             0            0
100.00% (1,134,592B) (page allocation syscalls) mmap/mremap/brk, --alloc-fns, etc.
->100.00% (1,134,592B) 0x0: ???
--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 10              0        1,146,880        1,146,880             0            0
 11              0        1,150,976        1,150,976             0            0
 12              0        1,155,072        1,155,072             0            0
 13              0        1,155,072        1,155,072             0            0
100.00% (1,155,072B) (page allocation syscalls) mmap/mremap/brk, --alloc-fns, etc.
->99.65% (1,150,976B) 0x0: ???
| 
->00.35% (4,096B) in 1+ places, all below ms_print's threshold (01.00%)
--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 14              0        1,150,976        1,150,976             0            0
 15              0        1,150,976        1,150,976             0            0
 16         68,331        1,159,168        1,159,168             0            0
 17         69,240        1,179,648        1,179,648             0            0
 18         69,317        1,183,744        1,183,744             0            0
 19         69,365        1,187,840        1,187,840             0            0
 20         69,413        1,196,032        1,196,032             0            0
 21         71,697        1,261,568        1,261,568             0            0
 22         74,157        3,432,448        3,432,448             0            0
 23         74,217        5,038,080        5,038,080             0            0
100.00% (5,038,080B) (page allocation syscalls) mmap/mremap/brk, --alloc-fns, etc.
->77.15% (3,887,104B) 0x4025D2C: __mmap64 (mmap64.c:58)
| ->77.15% (3,887,104B) 0x4025D2C: mmap (mmap64.c:46)
|   ->43.50% (2,191,360B) 0x4007E17: _dl_map_segment (dl-map-segments.h:29)
|   | ->43.50% (2,191,360B) 0x4007E17: _dl_map_segments (dl-map-segments.h:101)
|   |   ->43.50% (2,191,360B) 0x4007E17: _dl_map_object_from_fd (dl-load.c:1258)
|   |     ->43.50% (2,191,360B) 0x4009528: _dl_map_object (dl-load.c:2268)
|   |       ->43.09% (2,170,880B) 0x4002A2C: openaux (dl-deps.c:64)
|   |       | ->43.09% (2,170,880B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|   |       |   ->43.09% (2,170,880B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|   |       |     ->43.09% (2,170,880B) 0x402241B: dl_main (rtld.c:1965)
|   |       |       ->43.09% (2,170,880B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|   |       |         ->43.09% (2,170,880B) 0x402075D: _dl_start_final (rtld.c:494)
|   |       |           ->43.09% (2,170,880B) 0x402075D: _dl_start (rtld.c:581)
|   |       |             ->43.09% (2,170,880B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|   |       |               
|   |       ->00.41% (20,480B) in 1+ places, all below ms_print's threshold (01.00%)
|   |       
|   ->32.20% (1,622,016B) 0x4007F78: _dl_map_segments (dl-map-segments.h:139)
|   | ->32.20% (1,622,016B) 0x4007F78: _dl_map_object_from_fd (dl-load.c:1258)
|   |   ->32.20% (1,622,016B) 0x4009528: _dl_map_object (dl-load.c:2268)
|   |     ->31.87% (1,605,632B) 0x4002A2C: openaux (dl-deps.c:64)
|   |     | ->31.87% (1,605,632B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|   |     |   ->31.87% (1,605,632B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|   |     |     ->31.87% (1,605,632B) 0x402241B: dl_main (rtld.c:1965)
|   |     |       ->31.87% (1,605,632B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|   |     |         ->31.87% (1,605,632B) 0x402075D: _dl_start_final (rtld.c:494)
|   |     |           ->31.87% (1,605,632B) 0x402075D: _dl_start (rtld.c:581)
|   |     |             ->31.87% (1,605,632B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|   |     |               
|   |     ->00.33% (16,384B) in 1+ places, all below ms_print's threshold (01.00%)
|   |     
|   ->01.30% (65,536B) 0x400C36C: _dl_sysdep_read_whole_file (dl-misc.c:49)
|   | ->01.30% (65,536B) 0x4016C27: _dl_load_cache_lookup (dl-cache.c:411)
|   |   ->01.30% (65,536B) 0x40097CA: _dl_map_object (dl-load.c:2135)
|   |     ->01.30% (65,536B) 0x4002A2C: openaux (dl-deps.c:64)
|   |       ->01.30% (65,536B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|   |         ->01.30% (65,536B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|   |           ->01.30% (65,536B) 0x402241B: dl_main (rtld.c:1965)
|   |             ->01.30% (65,536B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|   |               ->01.30% (65,536B) 0x402075D: _dl_start_final (rtld.c:494)
|   |                 ->01.30% (65,536B) 0x402075D: _dl_start (rtld.c:581)
|   |                   ->01.30% (65,536B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|   |                     
|   ->00.16% (8,192B) in 1+ places, all below ms_print's threshold (01.00%)
|   
->22.85% (1,150,976B) 0x0: ???
| 
->00.00% (0B) in 1+ places, all below ms_print's threshold (01.00%)
--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 24         74,265        5,361,664        5,361,664             0            0
 25         74,313        5,386,240        5,386,240             0            0
 26         74,647        5,439,488        5,439,488             0            0
 27         81,700        5,451,776        5,451,776             0            0
 28        165,472        5,451,776        5,451,776             0            0
100.00% (5,451,776B) (page allocation syscalls) mmap/mremap/brk, --alloc-fns, etc.
->78.89% (4,300,800B) 0x4025D2C: __mmap64 (mmap64.c:58)
| ->78.89% (4,300,800B) 0x4025D2C: mmap (mmap64.c:46)
|   ->40.20% (2,191,360B) 0x4007E17: _dl_map_segment (dl-map-segments.h:29)
|   | ->40.20% (2,191,360B) 0x4007E17: _dl_map_segments (dl-map-segments.h:101)
|   |   ->40.20% (2,191,360B) 0x4007E17: _dl_map_object_from_fd (dl-load.c:1258)
|   |     ->40.20% (2,191,360B) 0x4009528: _dl_map_object (dl-load.c:2268)
|   |       ->39.82% (2,170,880B) 0x4002A2C: openaux (dl-deps.c:64)
|   |       | ->39.82% (2,170,880B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|   |       |   ->39.82% (2,170,880B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|   |       |     ->39.82% (2,170,880B) 0x402241B: dl_main (rtld.c:1965)
|   |       |       ->39.82% (2,170,880B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|   |       |         ->39.82% (2,170,880B) 0x402075D: _dl_start_final (rtld.c:494)
|   |       |           ->39.82% (2,170,880B) 0x402075D: _dl_start (rtld.c:581)
|   |       |             ->39.82% (2,170,880B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|   |       |               
|   |       ->00.38% (20,480B) in 1+ places, all below ms_print's threshold (01.00%)
|   |       
|   ->36.14% (1,970,176B) 0x4007F78: _dl_map_segments (dl-map-segments.h:139)
|   | ->36.14% (1,970,176B) 0x4007F78: _dl_map_object_from_fd (dl-load.c:1258)
|   |   ->36.14% (1,970,176B) 0x4009528: _dl_map_object (dl-load.c:2268)
|   |     ->35.84% (1,953,792B) 0x4002A2C: openaux (dl-deps.c:64)
|   |     | ->35.84% (1,953,792B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|   |     |   ->35.84% (1,953,792B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|   |     |     ->35.84% (1,953,792B) 0x402241B: dl_main (rtld.c:1965)
|   |     |       ->35.84% (1,953,792B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|   |     |         ->35.84% (1,953,792B) 0x402075D: _dl_start_final (rtld.c:494)
|   |     |           ->35.84% (1,953,792B) 0x402075D: _dl_start (rtld.c:581)
|   |     |             ->35.84% (1,953,792B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|   |     |               
|   |     ->00.30% (16,384B) in 1+ places, all below ms_print's threshold (01.00%)
|   |     
|   ->01.35% (73,728B) in 2 places, all below massif's threshold (1.00%)
|   | 
|   ->01.20% (65,536B) 0x400C36C: _dl_sysdep_read_whole_file (dl-misc.c:49)
|     ->01.20% (65,536B) 0x4016C27: _dl_load_cache_lookup (dl-cache.c:411)
|       ->01.20% (65,536B) 0x40097CA: _dl_map_object (dl-load.c:2135)
|         ->01.20% (65,536B) 0x4002A2C: openaux (dl-deps.c:64)
|           ->01.20% (65,536B) 0x400151B: _dl_catch_exception (dl-catch.c:237)
|             ->01.20% (65,536B) 0x4002E66: _dl_map_object_deps (dl-deps.c:232)
|               ->01.20% (65,536B) 0x402241B: dl_main (rtld.c:1965)
|                 ->01.20% (65,536B) 0x401EF45: _dl_sysdep_start (dl-sysdep.c:140)
|                   ->01.20% (65,536B) 0x402075D: _dl_start_final (rtld.c:494)
|                     ->01.20% (65,536B) 0x402075D: _dl_start (rtld.c:581)
|                       ->01.20% (65,536B) 0x401F547: ??? (in /usr/lib/x86_64-linux-gnu/ld-linux-x86-64.so.2)
|                         
->21.11% (1,150,976B) 0x0: ???
| 
->00.00% (0B) in 1+ places, all below ms_print's threshold (01.00%)

--------------------------------------------------------------------------------
  n        time(i)         total(B)   useful-heap(B) extra-heap(B)    stacks(B)
--------------------------------------------------------------------------------
 29        165,472        5,447,680        5,447,680             0            0

Do not present textual information as images!

T-MAC

Optimization

HW simulation