# Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or −1 作者:Matthieu Courbariaux, Itay Hubara, Daniel Soudry, Ran El-Yaniv, Yoshua Bengio 論文連結:https://arxiv.org/abs/1602.02830 整理by: [chewei](https://hackmd.io/@WTuIbJANSB26DiAX-WL4Sg) - - - - - ## 前言: 二元神經網路 這是2016一篇把quantize做到1bit的論文,基於Torch7 和 Theano的框架, 作到了當時對MNIST, CIFAR-10 和 SVHN datasets接近的SOTA效果,並且撰寫了 一個GPU kernel,可再對MNIST任務加速7倍。 - - - - - ## 0. 論文貢獻 儘管模型精度降低了非常多,但是在訓練效果卻不比全精度(32bit)的網路差,有的時候二值化後的訓練效果甚至會超越全精度網路,因為二值化過程给神經網路帶來了noise,就像dropout一樣,反而是一種regularization,可以部分避免網路的overfitting。 透過: * 將權重及激活函數二元化 $\Rightarrow$ 降低60%時間複雜度 $\Rightarrow$ 在MNIST任務上獲得7倍的加速 --- ## 1. 首先我們來講解二元化函式的一些細節 #### 1.1. 比較Deterministic與Stochastic的Binarization ***Deterministic function(AKA Sign function):*** $$x^b = \text{Sign}(x) = \begin{cases} +1 & \text{if } x \geq 0, \\ -1 & \text{otherwise}, \end{cases} $$ $x^b$ 是binarized後的變數 x 是轉換前的真實變數 ***Stochastic function:*** $$x^b = \begin{cases} +1 & with \, probability \,\, p=\sigma(x),\\ -1 & with \, probability \,\, 1-p, \end{cases} $$ ***$\sigma$是"hard sigmoid" 函式:*** $$ \sigma(x)=\text{clip}( \frac{x+1}{2} \quad,0,1)= \text{max}(0,\text{min}(1,\frac{x+1}{2})) $$ $\Rightarrow$ 作者認為雖然stochastic較吸引人,但考慮到在硬體上較難以實現在量化時隨機生成bits, 因此最終決定大部分皆採用Sign函式做實驗 #### 1.2. 梯度的計算和累計 * 雖然是使用二元化的權重及啟發函數做訓練,梯度仍然是在真實參數下運作,因此在訓練時還是需要Stochasic Gradient Descent(SGD),它能幫助在累積參數時減少雜訊。 * 此外作者表示在訓練時加入noise也能增加模型表現。 * 另外作者也提到,BNN這種作法也可以看做是一種dropout。 <details> <summary>計算參數梯度演算法</summary> #### 訓練一個 BNN **輸入**: 輸入和目標的一個小批次 \(($a_0$, $a^*$)\),先前的權重 \(W\),先前的BatchNorm參數 $\theta$,權重初始化係數 $\gamma$,以及先前的學習率 $\eta$。 **輸出**: 更新的權重 $W^{t+1}$,更新的BatchNorm參數 $\theta^{t+1}$ 和更新的學習率 $\eta^{t+1}$。 **計算參數梯度**: * 前向傳播 : $$ \begin{align*} & \text{for } k = 1 \text{ to } L \text{ do} \\ & \quad W_b^k \leftarrow \text{Binarize}(W^k) \\ & \quad s_k \leftarrow a_b^{k-1} W_b^k \\ & \quad a_k \leftarrow \text{BatchNorm}(s_k, \theta_k) \\ & \quad \text{if } k < L \text{ then} \\ & \quad \quad a_b^k \leftarrow \text{Binarize}(a_k) \\ & \text{end if} \\ & \text{end for} \end{align*} $$ * 後向傳播(並非binary): $$ \begin{align*} & \text{計算 } g_{a_L} = \frac{\partial C}{\partial a_L} \text{ 已知 } a_L \text{ 和 } a^* \\ & \text{for } k = L \text{ to } 1 \text{ do} \\ & \quad \text{if } k < L \text{ then} \\ & \quad \quad g_{a_k} \leftarrow g_{a_b} \odot 1_{|a_k| \leq 1} \\ & \quad \text{end if} \\ & \quad (g_{s_k}, g_{\theta_k}) \leftarrow \text{BackBatchNorm}(g_{a_k}, s_k, \theta_k) \\ & \quad g_{a_b^{k-1}} \leftarrow g_{s_k} W_b^k \\ & \quad g_{W_b^k} \leftarrow g_{s_k}^T a_b^{k-1} \\ & \text{end for} \\ \end{align*} $$ * 累計gradient參數 : $$ \begin{align*} & \text{for } k = 1 \text{ to } L \text{ do} \\ & \quad \theta_k^{t+1} \leftarrow \text{Update}(\theta_k, \eta, g_{\theta_k}) \\ & \quad W_k^{t+1} \leftarrow \text{Clip}(\text{Update}(W_k; \gamma\eta, g_{W_b^k}), -1, 1) \\ & \quad \eta^{t+1} \leftarrow \lambda\eta \\ & \text{end for} \end{align*} $$ </details> #### 1.3. 梯度離散化(STE) 由於在經過sign函式後所有直接會變為1,-1,導致在backpropagation時會發生無法微分的問題, 因此作者透過在backpropagation改以“straight-through estimator”的方式有效解決不能微分的問題。 當sign函式為: $$ q=\text{Sign}(r) $$ 則STE可以下列公式表示: $$ g_r=g_q 1_{|r|\leq1} $$ **註:** 當 |r|(即 r 的絕對值)小於1時,函數值為1;當 |r| 大於或等於1時,函數值為0。 所以,當 r 的絕對值小於1時,$g_r$ 將會等於 $g_q$;當 r 的絕對值大於或等於1時,$g_r$ 將會是0,這種方法用於保留梯度信息,而當 r 過大時則取消梯度,這樣做可以避免性能顯著惡化, 其中,$1_{|r|\leq1}$的計算公式就是Htanh: $$ Htanh(x)=Clip(x,-1,1)=max(-1,min(1,x)) $$ ![image](https://hackmd.io/_uploads/SkWPLYea6.png) #### 1.4. Shift-based Batch Normalization 作者認為一般的Batch Normalization(BN)需要的乘法元素太多,因此採用相對運算量較少 (幾乎不須做乘法)的shift-based batch normalization(SBN)。 並且Shift-based BN 的好處是: 1. 可以減少運算量 2. 可通過動態調整標準化過程中的位移參數,來更精細地控制數據的均值和方差,從而達到更好的訓練效果和模型穩定性。 <details> <summary>SBN演算法</summary> 應用於一個Mini Batch的激活函數(x),其中 AP2 是取二的次方之近似值,而" <<>> "代表左右二進位位移。 **Require:** Mini Batch的 $x$ 值: $B$ = {$x_{1...m}$}; 要學習的參數:( $\gamma$, $\beta$ ) **Ensure:** { $y_i$= BN($x_i$, $\gamma$, $\beta$) } $$ \begin{align*} & \mu_B \leftarrow \frac{1}{m} \sum_{i=1}^m x_i & {mini-batch \,mean} \\ & C(x_i) \leftarrow (x_i - \mu_B) & {centered \,input} \\ & \sigma_B^2 \leftarrow \frac{1}{m} \sum_{i=1}^m (C(x_i) <<>> AP2(C(x_i))) & {apx \,variance} \\ & \hat{x}_i \leftarrow C(x_i) <<>> AP2((\frac{\sigma_B^2 + \epsilon}{B})^{-1}) & {normalize} \\ & y_i \leftarrow AP2(\gamma) <<>> \hat{x}_i & {scale \,and \,shift} \\ \end{align*} $$ </details> #### 1.5. Shift-based AdaMax 和Shift-based Batch Normalization相同,使用shift-based AdaMax相較於一般的ADAM能減少 所需的運算量,使模型更加輕量化。 <details> <summary>Shift-based AdaMax演算法</summary> **Require:** 先前的參數 $\theta_{t-1}$ 及其梯度 $g_t$,和學習率 $\alpha$。 **Ensure:** 更新後的參數 $\theta_t$。 $\alpha$=$2^{-10}$,1-$\beta_1$=$2^{-3}$,1-$\beta_2$=$2^{-10}$ $$ \begin{align*} & m_t \leftarrow \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t & biased動量估計 \\ & v_t \leftarrow \max(\beta_2 \cdot v_{t-1}, |g_t|) & biased動量估計 \\ & \theta_t \leftarrow \theta_{t-1} - (\alpha \ll \gg (1 - \beta_1)) \cdot \hat{m} \ll \gg v_t^{-1} & 更新參數 \\ \end{align*} $$ </details> #### 1.6.架構中第一層的問題 在BNN架構中,除了第一層的原始輸入外,其他層的輸入及輸出皆為二元化後的值, 然而假設第一層輸入的是一個rgb的圖像,若直接轉換成二元化作者認為會損失過多的訊息, 因此在第一層使用不同的演算法來行轉換。 > 舉例來說:如果我們有一個8bits的數字來表示一個像素的紅色通道的強度,我們可以將這個8位的數字與一組1bit的二進位權重(只有1和-1)相乘,來計算這個象素對應於某個過濾器的激活值。這樣就可以在保持計算簡單的同時,也能夠處理比較複雜的輸入數據。 $$ \begin{align*} & s=x\cdot w^b ,\\ & s=\sum_{i=1}^m x_i \end{align*} $$ ## 2.SWAR : SIMD (single instruction, multiple data) within a register 這個方法使得在Run-Time時達到了7倍的加速: * 使用**XNOR-Count**代替原本32-bit floating point的乘法 $\Rightarrow$ 大量降低運算量 * **XNOR-popcount** $\Rightarrow$ 把做完XNOR後的1數量加起來,用以代替乘法 $$ \begin{align*} & a_1+=popcount(xnor(a_0^{32b},w_1^{32b})) \\ & a_1為相加後的累計結果,a_0^{32b}為輸入的總和,w_1^{32b}為權重的總和 \end{align*} $$ 依照上面公式理論上來說應該至少會加速: $1 + 4 + 1 = 6$ clock cycles $\Rightarrow 32/6\approx5.3$倍 然而實際上可以達到6倍以上! ## 3.實驗數據與結果 作者使用Torch7和Theano做實驗: **Dataset:** MNIST,CIFAR-10,SVHN 1. Torch7 * **Activation:** stochastically binarized * **Batch Normalization:** shift-based BN * **Learning rule:** Shift-based AdaMax 2. Theano * **Activation:** deterministically binarized * **Batch Normalization:** vanilla BN * **Learning rule:** vanilla ADAM **實驗結果(Error Rates)如下圖:** * 雖然Error Rates跟其他模型相比看似並無明顯下降,但在時間複雜度方面可以降低60%左右。 ![image](https://hackmd.io/_uploads/Hy3yTde6T.png) * 而在記憶體耗能方面,也能減少相當大的記憶體使用率。 ![image](https://hackmd.io/_uploads/Byg5mkFga6.png) ![image](https://hackmd.io/_uploads/rk8VkFe6a.png) ## Reference [1] [BinaryConnect: Training Deep Neural Networks with binary weights during propagations ](https://hackmd.io/Jj3GJmi4QgieqBqI-k9fhw?view) [2] EBP