# Binarized Neural Networks: Training Neural Networks with Weights and Activations Constrained to +1 or −1
作者:Matthieu Courbariaux, Itay Hubara, Daniel Soudry, Ran El-Yaniv, Yoshua Bengio
論文連結:https://arxiv.org/abs/1602.02830
整理by: [chewei](https://hackmd.io/@WTuIbJANSB26DiAX-WL4Sg)
- - - - -
## 前言: 二元神經網路
這是2016一篇把quantize做到1bit的論文,基於Torch7 和 Theano的框架,
作到了當時對MNIST, CIFAR-10 和 SVHN datasets接近的SOTA效果,並且撰寫了
一個GPU kernel,可再對MNIST任務加速7倍。
- - - - -
## 0. 論文貢獻
儘管模型精度降低了非常多,但是在訓練效果卻不比全精度(32bit)的網路差,有的時候二值化後的訓練效果甚至會超越全精度網路,因為二值化過程给神經網路帶來了noise,就像dropout一樣,反而是一種regularization,可以部分避免網路的overfitting。
透過:
* 將權重及激活函數二元化
$\Rightarrow$ 降低60%時間複雜度
$\Rightarrow$ 在MNIST任務上獲得7倍的加速
---
## 1. 首先我們來講解二元化函式的一些細節
#### 1.1. 比較Deterministic與Stochastic的Binarization
***Deterministic function(AKA Sign function):***
$$x^b =
\text{Sign}(x) = \begin{cases}
+1 & \text{if } x \geq 0, \\
-1 & \text{otherwise},
\end{cases} $$
$x^b$ 是binarized後的變數
x 是轉換前的真實變數
***Stochastic function:***
$$x^b =
\begin{cases}
+1 & with \, probability \,\, p=\sigma(x),\\
-1 & with \, probability \,\, 1-p,
\end{cases}
$$
***$\sigma$是"hard sigmoid" 函式:***
$$
\sigma(x)=\text{clip}(
\frac{x+1}{2} \quad,0,1)=
\text{max}(0,\text{min}(1,\frac{x+1}{2}))
$$
$\Rightarrow$ 作者認為雖然stochastic較吸引人,但考慮到在硬體上較難以實現在量化時隨機生成bits,
因此最終決定大部分皆採用Sign函式做實驗
#### 1.2. 梯度的計算和累計
* 雖然是使用二元化的權重及啟發函數做訓練,梯度仍然是在真實參數下運作,因此在訓練時還是需要Stochasic Gradient Descent(SGD),它能幫助在累積參數時減少雜訊。
* 此外作者表示在訓練時加入noise也能增加模型表現。
* 另外作者也提到,BNN這種作法也可以看做是一種dropout。
<details>
<summary>計算參數梯度演算法</summary>
#### 訓練一個 BNN
**輸入**: 輸入和目標的一個小批次 \(($a_0$, $a^*$)\),先前的權重 \(W\),先前的BatchNorm參數 $\theta$,權重初始化係數 $\gamma$,以及先前的學習率 $\eta$。
**輸出**: 更新的權重 $W^{t+1}$,更新的BatchNorm參數 $\theta^{t+1}$ 和更新的學習率 $\eta^{t+1}$。
**計算參數梯度**:
* 前向傳播 :
$$
\begin{align*}
& \text{for } k = 1 \text{ to } L \text{ do} \\
& \quad W_b^k \leftarrow \text{Binarize}(W^k) \\
& \quad s_k \leftarrow a_b^{k-1} W_b^k \\
& \quad a_k \leftarrow \text{BatchNorm}(s_k, \theta_k) \\
& \quad \text{if } k < L \text{ then} \\
& \quad \quad a_b^k \leftarrow \text{Binarize}(a_k) \\
& \text{end if} \\
& \text{end for}
\end{align*}
$$
* 後向傳播(並非binary):
$$
\begin{align*}
& \text{計算 } g_{a_L} = \frac{\partial C}{\partial a_L} \text{ 已知 } a_L \text{ 和 } a^* \\
& \text{for } k = L \text{ to } 1 \text{ do} \\
& \quad \text{if } k < L \text{ then} \\
& \quad \quad g_{a_k} \leftarrow g_{a_b} \odot 1_{|a_k| \leq 1} \\
& \quad \text{end if} \\
& \quad (g_{s_k}, g_{\theta_k}) \leftarrow \text{BackBatchNorm}(g_{a_k}, s_k, \theta_k) \\
& \quad g_{a_b^{k-1}} \leftarrow g_{s_k} W_b^k \\
& \quad g_{W_b^k} \leftarrow g_{s_k}^T a_b^{k-1} \\
& \text{end for} \\
\end{align*}
$$
* 累計gradient參數 :
$$
\begin{align*}
& \text{for } k = 1 \text{ to } L \text{ do} \\
& \quad \theta_k^{t+1} \leftarrow \text{Update}(\theta_k, \eta, g_{\theta_k}) \\
& \quad W_k^{t+1} \leftarrow \text{Clip}(\text{Update}(W_k; \gamma\eta, g_{W_b^k}), -1, 1) \\
& \quad \eta^{t+1} \leftarrow \lambda\eta \\
& \text{end for}
\end{align*}
$$
</details>
#### 1.3. 梯度離散化(STE)
由於在經過sign函式後所有直接會變為1,-1,導致在backpropagation時會發生無法微分的問題,
因此作者透過在backpropagation改以“straight-through estimator”的方式有效解決不能微分的問題。
當sign函式為:
$$
q=\text{Sign}(r)
$$
則STE可以下列公式表示:
$$
g_r=g_q 1_{|r|\leq1}
$$
**註:** 當 |r|(即 r 的絕對值)小於1時,函數值為1;當 |r| 大於或等於1時,函數值為0。
所以,當 r 的絕對值小於1時,$g_r$ 將會等於 $g_q$;當 r 的絕對值大於或等於1時,$g_r$
將會是0,這種方法用於保留梯度信息,而當 r 過大時則取消梯度,這樣做可以避免性能顯著惡化,
其中,$1_{|r|\leq1}$的計算公式就是Htanh:
$$
Htanh(x)=Clip(x,-1,1)=max(-1,min(1,x))
$$

#### 1.4. Shift-based Batch Normalization
作者認為一般的Batch Normalization(BN)需要的乘法元素太多,因此採用相對運算量較少
(幾乎不須做乘法)的shift-based batch normalization(SBN)。
並且Shift-based BN 的好處是:
1. 可以減少運算量
2. 可通過動態調整標準化過程中的位移參數,來更精細地控制數據的均值和方差,從而達到更好的訓練效果和模型穩定性。
<details>
<summary>SBN演算法</summary>
應用於一個Mini Batch的激活函數(x),其中 AP2 是取二的次方之近似值,而" <<>> "代表左右二進位位移。
**Require:** Mini Batch的 $x$ 值: $B$ = {$x_{1...m}$};
要學習的參數:( $\gamma$, $\beta$ )
**Ensure:** { $y_i$= BN($x_i$, $\gamma$, $\beta$) }
$$
\begin{align*}
& \mu_B \leftarrow \frac{1}{m} \sum_{i=1}^m x_i & {mini-batch \,mean} \\
& C(x_i) \leftarrow (x_i - \mu_B) & {centered \,input} \\
& \sigma_B^2 \leftarrow \frac{1}{m} \sum_{i=1}^m (C(x_i) <<>> AP2(C(x_i))) & {apx \,variance} \\
& \hat{x}_i \leftarrow C(x_i) <<>> AP2((\frac{\sigma_B^2 + \epsilon}{B})^{-1}) & {normalize} \\
& y_i \leftarrow AP2(\gamma) <<>> \hat{x}_i & {scale \,and \,shift} \\
\end{align*}
$$
</details>
#### 1.5. Shift-based AdaMax
和Shift-based Batch Normalization相同,使用shift-based AdaMax相較於一般的ADAM能減少
所需的運算量,使模型更加輕量化。
<details>
<summary>Shift-based AdaMax演算法</summary>
**Require:** 先前的參數 $\theta_{t-1}$ 及其梯度 $g_t$,和學習率 $\alpha$。
**Ensure:** 更新後的參數 $\theta_t$。
$\alpha$=$2^{-10}$,1-$\beta_1$=$2^{-3}$,1-$\beta_2$=$2^{-10}$
$$
\begin{align*}
& m_t \leftarrow \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t & biased動量估計 \\
& v_t \leftarrow \max(\beta_2 \cdot v_{t-1}, |g_t|) & biased動量估計 \\
& \theta_t \leftarrow \theta_{t-1} - (\alpha \ll \gg (1 - \beta_1)) \cdot \hat{m} \ll \gg v_t^{-1} & 更新參數 \\
\end{align*}
$$
</details>
#### 1.6.架構中第一層的問題
在BNN架構中,除了第一層的原始輸入外,其他層的輸入及輸出皆為二元化後的值,
然而假設第一層輸入的是一個rgb的圖像,若直接轉換成二元化作者認為會損失過多的訊息,
因此在第一層使用不同的演算法來行轉換。
> 舉例來說:如果我們有一個8bits的數字來表示一個像素的紅色通道的強度,我們可以將這個8位的數字與一組1bit的二進位權重(只有1和-1)相乘,來計算這個象素對應於某個過濾器的激活值。這樣就可以在保持計算簡單的同時,也能夠處理比較複雜的輸入數據。
$$
\begin{align*}
& s=x\cdot w^b ,\\
& s=\sum_{i=1}^m x_i
\end{align*}
$$
## 2.SWAR : SIMD (single instruction, multiple data) within a register
這個方法使得在Run-Time時達到了7倍的加速:
* 使用**XNOR-Count**代替原本32-bit floating point的乘法
$\Rightarrow$ 大量降低運算量
* **XNOR-popcount**
$\Rightarrow$ 把做完XNOR後的1數量加起來,用以代替乘法
$$
\begin{align*}
& a_1+=popcount(xnor(a_0^{32b},w_1^{32b})) \\
& a_1為相加後的累計結果,a_0^{32b}為輸入的總和,w_1^{32b}為權重的總和
\end{align*}
$$
依照上面公式理論上來說應該至少會加速:
$1 + 4 + 1 = 6$ clock cycles
$\Rightarrow 32/6\approx5.3$倍
然而實際上可以達到6倍以上!
## 3.實驗數據與結果
作者使用Torch7和Theano做實驗:
**Dataset:** MNIST,CIFAR-10,SVHN
1. Torch7
* **Activation:** stochastically binarized
* **Batch Normalization:** shift-based BN
* **Learning rule:** Shift-based AdaMax
2. Theano
* **Activation:** deterministically binarized
* **Batch Normalization:** vanilla BN
* **Learning rule:** vanilla ADAM
**實驗結果(Error Rates)如下圖:**
* 雖然Error Rates跟其他模型相比看似並無明顯下降,但在時間複雜度方面可以降低60%左右。

* 而在記憶體耗能方面,也能減少相當大的記憶體使用率。


## Reference
[1] [BinaryConnect: Training Deep Neural Networks with binary weights during propagations ](https://hackmd.io/Jj3GJmi4QgieqBqI-k9fhw?view)
[2] EBP