Branchformer - Parallel MLP-Attention Architecture for Speech Recognition and Understanding
==
###### tags: `碩` `ML` `ASR`
## 概要
先前 Conformer 的提出證明了 local relationship 有助於提升 speech processing 任務的效能,透過 self-attention 提取 global context、convolution 提取 local relationship,讓模型可以專注於不同面向的資訊以達到更好的效果。
受到 Conformer 的啟發,本篇論文提出了將提取 global 與 local 資訊的 module 分開為兩個平行的分支 (branch),並命名為 *Branchformer*。Branchformer 是針對 **encoder** 去設計,目的是讓模型可以提取更多範圍 (various ranged) 之間的相關性,分支的設計讓 Branchformer 有以下優點:
1. 模型設計有彈性
2. 可根據目標客製分枝
3. 模型較好解釋
分枝的設計想法是將提取 global context 與 local relationship 的module 分別拆開為 attention 與 cgMLP (convolution gating MLP) module,並以可學習的權重來合併分枝,權重的值表示 global 和 local 對於當前的層的重要性,這個設計讓模型更有彈性好模改也可容易解釋。
本篇論文的最後也實際提出了如何修改分之以減少複雜度和客製化的例子。
## Introduction
Conformer 使用 convolution 和 self-attention 提取 local 和 global contextual 資訊,並使用序列的方式串聯兩個 module,這種做法有以下缺點
- 不易於模型的解釋
- Squential 的組合 global 與 local module 會犧牲一些準確度
- 模型不易修改
- 每一層的 global & local 之間的交互關係不清楚
Conformer 是依照固定的權重去看 local 和 global 資訊,但應該要讓模型決定要看哪個,尤其在初始層的時候更為重要,再來看兩個 module 的順序可以發現 local 的資訊是從 global 的部分抽的,很明顯會比直接從原始資料抽準確性稍差,且 global 和 local 的資訊在每一層中並不一定是同等重要的。
<center>
##### Figure 1. Conformer Architecture

</center>0
Branchformer 的提出改善了上述的問題,將 Conformer 拆成了兩枝:
1. **Global**:使用 self-attention 或其他變體,捕捉 global 資訊。
2. **Local**:使用 convolution,此處採用 Multi-layer perceptron (MLP),並加入 gating 機制 (gMLP) 來抽取 local 關係。
因為雙分枝的緣故,branchformer 可以根據需求更換個別分枝所使用的架構,例如為降低複雜度在 attention module 使用 fastformer,這樣的設計讓兩個分枝在合併的時候可以被可學習的權重控制,讓模型學習到在不同的狀況下,哪個分枝的資訊較為重要,使模型更加彈性,另外雙分枝可以在 inference 時將 attention 端關閉來加速處理。
## Related Work
### 各種架構
| 架構 | 優點 | 缺點 |
| ----------- | ------------------------------------------------ | ------------------------------ |
| RNN | 可以 model 時間上的關係 | 無法平行化、關係隨著距離會衰減 |
| CNN | 可以有效 model 近距離的相關性、shift-invariant。 | NA |
| Transformer | 有效捕捉長距的相關性,不受距離影響 | 複雜度隨著序列長度承平方上升 |
| MLP | 內度較簡單,模型不深 | 只接受固定長度的輸入 |
其中 MLP 為較簡單的架構,近年來又開始將它重新應用回語言處理與電腦視覺領域,可被其他模型如 convolution 取代,根據 MLP-based 模型的研究,已經可以跟其他更複雜的模型比較了,此外在其中一篇研究中發現 MLP 加上 convolution 作為 gating 可以取得不錯的效果,因此本篇也使用 cgMLP (convolution gating MLP) 取代原本 Conformer 的 convolutionl module。
> Branchformer 使用 convoltion、MLP、self-attention 各自做為單一的模塊。
### Modeling Both Local and Global Context
Local 與 global 的資訊對語言處理來說都是重要的,Conformer 使用序列的方式做結合並探討了兩個模塊的擺放順序的影響,但序列的方式很難去做分析,且以固定的模式混和不會永遠都是最佳解,因此如何擺放與結合兩種資訊變成重要的問題。
> Convolution 對於語音處理有顯著的提升,在連續的語音資料中可以幫助 local 相關性的建模。
這邊總結一下,Branchformer 注重 local 和 global 模塊的擺放設計與兩種資訊如何混和,主要要達成**好訓練**、**易分析**、**彈性化**、**inference 可加速**。
## 模型架構
看到 Figure 3.,raw audio 訊號會經過 frontend 的 module 處理並抽成 Mel 參數,接著會使用 **convolutionl subsampling** 來降取樣,再送入 Branchformer 之前會再加入 position encoding,最後經過 N 次的 Branchformer encoder block 來抽取 local 和 global 資訊。
<center>
##### Figure 2. Architecture of Branchformer encoder block & Figure 3. Overall architecture of the encoder
<img src="https://i.imgur.com/lCVSVZB.png" width = "500"/> <img src="https://hackmd.io/_uploads/r1YXUS84n.png" width = "200"/>
</center>
### Attention Branch - Global Context
此分支用於建模輸入序列的 global 資訊的相關性,在進入 Branchformer encoder block 之前會加上 relative position encoding。
- #### Multi-headed Self-Attention
輸入 $X \in \mathbb{R}^{T \times d}$,$T$ 是序列長度,$d$ 是特徵維度,MHSA 首先會將輸入轉換成 $Q,K,V \in \mathbb{R}^{T \times d}$ (*query*、*key*、*value*) 三個矩陣,且投影內的參數為可學習的,$Q$ 與 $K$ 會做內積並經過 $softmax$ 得到一組權重代表每個位置資訊的重要性,最後跟 $V$ 相乘得到輸出。
$$
{\rm Attention(Q,K,V)} = {\rm softmax}(\frac{{\rm QK^T}}{\sqrt{d}}){\rm V}
$$
在 attention 的數學式中,${\rm softmax}(\frac{{\rm QK^T}}{\sqrt{d}})$ 的矩陣相乘,由於 $d$ 是常數,而 $T$ 則為輸入長度,而複雜度與 $T$ 承平方關係,當輸入越長,輸出的速度越慢。
MHSA 實際上輸入會經過 $h$ 次的投影,這些投影是平行化的,因此在最後需要將每個 attention head 的輸出組合起來再投影成原本的大小,才是 MHSA 最後的輸出。
$$
{\rm MultiHead}(Q,K,V)={\rm concat(head_1,...,head_h)W}^O \\
{\rm head}_i={\rm Attention}({\rm QW}_i^Q, {\rm KW}_i^K, {\rm VW}_i^V)
$$
此處 ${\rm W}_i^Q, {\rm W}_i^K, {\rm W}_i^V \in \mathbb{R}^{d \times d/h}$,為投影矩陣,將 $Q,K,V$ 投影到較低維度;${\rm W}^O \in \mathbb{R}^{d \times d}$,將各個 head 組合的結果投影轉換為最後的輸出。
- #### 其他 Attention
為了改善 Transformer 的複雜度,這邊參考了 Fastformer 在差不多的表現下更能降低到線性的複雜度,在此處的關鍵是 attention-based pooling,概念上是根據 global context 將整個序列總結成一個單一的向量,首先定義 pooling 的輸入序列 $q_1,q_2,...,q_T \in \mathbb{R}^d$ 接著輸出一組 weighted sum $q=\sum_{i=1}^T \alpha_i q_i$,其中 $\alpha_i$ 是 attention 的權重,是透過一組可學習的參數 $w$ 將輸入轉換一個值,然過過 $softmax$ 得到。
$$
\alpha_i = \frac{{\rm exp}(w^T q_i/\sqrt{d})}{\sum_{j=1}^T{\rm exp}(w^T q_j/\sqrt{d})}
$$
從數學式可以看出來,attention based pooling 的複雜度跟序列長度 $T$ 呈線性關係,且有辦法可以捕捉 global 的關係,並達到可與 self-attention 比擬的效果。
- #### Fastformer
<center>
##### Figure 4. Architecture of Fastformer block
<img src="https://i.imgur.com/DfcdCg5.png" width = "300"/>
</center>
Fastformer 一樣先將輸入經過投影轉換為一堆 $Q,K,V$ 向量,但會將原本 attention 的 $Q,K,V$ 值先經過 attention pooling 總結為一個向量,運作順序為 $Q$ 會先總結為一個向量,這個向量包含了 global contextual 資訊,接著會去跟 $K$ 的每個向量相乘 $p_i=q*k_i$ 為一 element wise 的相乘,同樣的步驟再對 $K,V$ 做一次,但最後的輸出會經過一次轉換作為 Fastformer block 的輸出。
---
### MLP Branch - Local Context
<center>
##### Figure 5. Architecture of cgMLP

</center>
MLP with convolution gating 是提取 local context 資訊的關鍵,其藉由 depth-with convolution 和 linear gating 的強大效能來實現,cgMLP 比起 Conformer 的 convolution module 效能要好,其主要組成是由 ***Gaussian error Linear Unit (GeLU)***、***convolution spatial gating unit (CSGU)*** 與***投影轉換層***所組成。
cgMLP 首先將輸入 $X \in \mathbb{R}^{T \times d}$ 通過 layernorm,之後經過一系列 layer 到到最後的輸出:
$$
% after \\: \hline or \cline{col1-col2} \cline{col3-col4}
{\rm Z = GeLU(XU)} \in \mathbb{R}^{T \times d_{hidden}} \\
{\rm \tilde{Z} = CSGU(Z)} \in \mathbb{R}^{T \times d_{hidden/2}} \\
{\rm Y = \tilde{Z}V} \in \mathbb{R}^{T \times d}
$$
其中 ${\rm U} \in \mathbb{R}^{d \times d_{hidden}}$,${\rm V} \in \mathbb{R}^{d_{hidden}/2 \times d}$,為兩個 channel projection,隱藏層的維度通常會大於輸入的維度,這樣的設計相似於 position-wise 的 feed-forward 層。
cgMLP 的另外一個要件為 **CSGU**,它包含了一個 linear gating 並採用了 **depth-wise convolution** 來捕捉 local 關係,他的輸入 ${\rm Z} \in \mathbb{R}^{T \times d_{hidden}}$ 會在特徵維度被均等分成 ${\rm Z_1, Z_2} \in \mathbb{R}^{T \times d_{hidden/2}}$,之後只有 ${\rm Z_2}$ 會沿著時間維度做 depth-wise convolution:
$$
{\rm Z_2' = DWConv(LayerNorm(Z_2))}
$$
最後的輸出是將 ${\rm Z_1, Z_2'}$ 做 element-wise 的相乘,得到 ${\rm \tilde{{Z}} = Z_1 \bigotimes Z_2'}$,這實際上是一種 linear gating,因為在相乘之前不會經過非線性激活層。
> #### 複雜度分析
> 在 cgMLP module 主要有兩個 channel projection 和 CSGU,其複雜度分別為 $O(Tdd_{hidden}),O(Tdd_{hidden}/2),O(TKd_{hidden}/2)$,其中 $K$ 是 kernel size 為一常數,全部看下來 cgMLP 的複雜度只跟序列長度 $T$ 成線性關係。
<center>
##### Figure 6. Convolution module of Conformer

</center>
> #### GeLU - Gaussian Error Linerar Units
> GeLU 為一機活函數,從 ReLU 變體而來,目的為改善 ReLU 在負半軸不可微的缺點,同時加入了正規化的特性與保留 ReLU 的線性特點,設計的概念為**根據輸入自身的機率分布決定是否輸出**,應用上常用於語言的預訓練模型中,GeLU 數學式如下:
> $$
> GeLU(x) = xP(X \leq x)=x\Phi(x)=x\cdot \frac{1}{2}[1+erf(x/\sqrt{2})] \\
> \Phi(x) = \int_{-\infty}^x \frac{e^{-t^2/2}}{\sqrt{2\pi}}dt = \frac{1}{2}[1+erf(x/\sqrt{2})]
> $$
>
> 透過給輸入乘上高斯分布累積函數,讓大的輸入越容易被記得,小的輸入越容易被忘記,同時也帶有正態分布的一些特性,在實驗上也證明比 ReLU 的效果更好。
> <center>
>
> ###### Figure 7. CDF of Gaussian Distribution & Figure 8. Different ELUs
> <img src="https://i.imgur.com/pw36wYg.png" height = "230"/><img src="https://i.imgur.com/SYe6bDVm.png" height = "230"/>
> </center>
### 分支的合併
Branchformer 的最後一個重點是要怎麼把兩個分支的資訊合併,這邊採用兩種作法,**直接連接 (Concatenation)**、**權重相加 (Weight average)**,預設是採用直接連接的作法,但權重相加的方式更易於解釋。
#### Concatenation
直接連接的方式易於實作,將兩個分支的輸出 ${\rm Y_{att}, Y_{mlp}} \in \mathbb{R}^{T \times d}$ 直接沿著特徵的維度相接成 ${\rm Y_{concat}} \in \mathbb{R}^{T \times 2d}$,接著乘上一個轉換矩陣投影到原本的維度:
$$
{\rm Y_{merge} = concat(Y_{att}, Y_{mlp})W_{merge}} \in \mathbb{R}^{T \times d}
$$
其中轉換矩陣 ${\rm W_{merge}} \in \mathbb{R}^{2d \times d}$ 是可學習的參數。
#### Weighted Average
為了讓模型更好解釋與增加修改性,提出了 weighted average 的做法:
1. 使用 attention pooling 總結每一分支的輸出成單一的向量
$$
{\rm y_{att}=AttPooling(Y_{att})} \\
{\rm y_{mlp}=AttPooling(Y_{mlp})} \\
where \ {\rm y_{att}, y_{mlp}} \in \mathbb{R}^d
$$
2. 將兩個分支總結的向量乘上線性轉換矩陣成單一的數值
3. 將上一步的數值經過 $softmax$ 得到分支權重
$$
w_{\rm att}, w_{\rm mlp} = {\rm softmax(W_{att}y_{att}, W_{mlp}y_{mlp})} \\
where \ {\rm W_{att}, W_{mlp}} \in \mathbb{R}^{1 \times d}
$$
4. 將分支呈上權重後相加即為最後合併的輸出
$$
{\rm Y_{merge}'}=w_{\rm att}{\rm Y_{att}}+w_{\rm mlp}{\rm Y_{mlp}} \in \mathbb{R}^{T \times d}
$$
為了讓模型 inference 速度加快,在 inference 的時候剪掉 attention branch,因此訓練的時候使用 **branch dropout** 的技巧,讓 attenion branch 以一定的機率權重為 0。
## 複雜度分析
這邊總結一下 Branchformer 各分支的複雜度,$T$ 為序列長度,$d$ 為特徵維度。
- ##### Attention Branch
如果使用原本的 self-attention 由於每個位置的特徵都要去跟其他位置的特徵做一個 $Q,K,V$ 的計算,因此複雜度會隨著序列長度平方上升;但若使用 Fastformer,由於其採用 attention pooling 技術,每個位置只會經過轉換後統一做 $softmax$,因此複雜度只會沿著 $T$ 線性增長增長,另外在線性投影轉換其複雜度由於是對特徵維度做因此複雜度與特維維度成平方關係。
| 模型 | 複雜度 | 線性投影 |
|:-------------- |:---------:|:---------:|
| Self-attention | $O(T^2d)$ | $O(Td^2)$ |
| Fastformer | $O(Td)$ | $O(Td^2)$ |
- ##### MLP Branch
MLP 的複雜度就相對簡單,由於是執行 convolution 因此複雜度只跟序列長度呈現性增長 $O(Td)$。
## 實驗結果
### 資料集
1. Aishell:170 小時中文語料
2. Switchboard (SWBD):300 小時英文電話對話
3. LibriSpeech:960 小時英文有聲書
### 設定
- 使用 ESPnet 做資料準備 (data preparation) 與模型訓練。
- 原始語音訊號會抽成 80-dim Mel filterbank
- Window 和 hop 的長度隨不同資料及調整
- 資料加強使用 SpecAugment 和語速調整
- Subsampling:$3 \times 3$ convolution,stride=2
- **Encoder**:不同的資料使用不同 $d,d_{hidden},h$
- **Decoder**:皆為 6 層 Transformer decoder
- CTC weight=0.3 with attention decoder training
- 使用 joint CTC-attention decoding
<center>
##### Table 1. Parameter details for each dataset

</center>
### Main Result
採用標準 self-attention 和 concatenation-based 分支合併,下表為與其他模型在 Aishell 下的表現:
<center>
##### Table 2. CER in different model

</center>
### 結果分析
- #### 模型大小
<center>
##### Figure 9. Performance vs. model size on SLU task

</center>
- #### 訓練穩定度
與 Conformer 比較, Branchforemr 的訓練更穩定,特別是在短句與有限制的資料上,猜測原因可能是 Conformer 在 batch norm 上的問題或是模型深度導致收斂困難,另外也發現與其他模型相比在達到同樣的 performance 下,Branchformer 更容易訓練。
- #### 不同 attention 的差異
<center>
##### Figure 10. Encoder forward time in different attention model

</center>
- #### 合併方式
<center>
##### Table 3. CER for concatenation vs. weighted ave

</center>
實驗結果發現 weighted average 的效果略輸 concatenation。
- #### Global & Local 分支的重要性
<center>
##### Figure 11. Visualization of branch weights in Branchformer with different sizes

</center>
根據上述的實驗結果可以發現,一開始模型對兩個分支的重要性不會差太多,但到了中段會出現一連串由 attention 分支主宰的情況表示這段比較需要 global 的資訊,而到後段開始會觀察到許多 MLP 分支連續出現,表示在輸出前 local 資訊比較重要,這個結果跟 Conformer 和其他研究的結果吻合,Conformer 實驗發現將 convolution module 放在 attention module 後效果較好。
- #### Attention 分析
這邊針對不同架構的 self-attention 做分析,去看 attention 能抓到多少 global 資訊,測量標準使用 attention weight 的 diagonality metric,diagonality 越高代表 attention 更專注在 local context。
實驗結果發現 Branchformer 的 diagonality 小於 Transformer,代表 Branchformer 的分支結構更適合抓取 global context。
<center>
##### Figure 12. Diagonality of self-attention in each encoder layer

</center>
下表為 Branchformer 和 Transformer 的 attention 值,有對角線的代表抓取較多 local context,較均勻分布的則代表更專注於 global context。
| Branchformer | Transformer |
|:-------------------------------------------------------------------:|:-------------------------------------------------------------------:|
| <img src="https://hackmd.io/_uploads/HylGgk9V2.png" width = "400"/> | <img src="https://hackmd.io/_uploads/H1-mxkc4h.png" width = "400"/> |
> ##### Diagonality
> 根據出處的定義,${\rm A} \in \mathbb{R}^{T\times T}$ 為 attention 權重矩陣,$a_{ij}$ 為 $i$th 對 $j$th 個特徵的 attention 值,這邊定義 $i$th row 的 centrality $C_i$:
> $$
> C_i=1-\frac{\sum_{i=1}^T a_{ij}|i-j|}{max_{1\leq j \leq T} |i-j|}
> $$
> 表示靠近的程度,上面的代表離越遠的放越大,下面表示離 $i$ 的最遠距離。那 diagonality 就是整個矩陣的平均 centrality:
> $$
> D=\frac{1}{T}\sum_{i=1}^T C_i
> $$
## 參考資料
https://zhuanlan.zhihu.com/p/349492378
https://arxiv.org/abs/2005.08100
https://arxiv.org/abs/2108.09084