本次期末專題的題目為嘗試在 pynq-z2 上實做出簡單的 single head attentation 的運算加速器。
詳細的運算過程可參考李弘毅教授的 講解 或這篇包含詳細計算過程的範例 文章 ,本文只簡單列出計算過程公式以便分析系統所需功能。
single head attentation 運算方式如下:
以上列出了所有需要運算的公式,其中
也就是說,我們系統需要的核心功能有三個:
另外為了增加傳送資料的效率,我們的 data type 設為 int8 ,如此透過 AXI 傳資料到 bram 時(一次只能傳 32 + 4 bits, 4 for parity bits),每一次能傳 32/8 = 4 筆資料。我們將 vector size 定為 4 * 1 ,如此一次傳送便能傳一個 vector 。
由於我們的 data type 為 int8 ,但經過 softmax 後的 output 為 32 bits ,因此我們還需要實現 quantize 的功能,因此總共需要實現的功能有:
除法運算和開根號這點,由於我們將 vector size 定為 4,因此除法 + 開根號相當於計算
矩陣乘法運算則是用 1 個 cycle 的組合電路達成,要注意的是為了能以 col 為單位取一個 vector 出來,我們所有矩陣乘法的 output matrix 都會是 transpose 過得結果(原本一次取 32 bits 取的會是 row )
softmax 經過我們的評估後最後決定將它交由軟體處理,一來是因為時間不夠,但 softmax 要以硬體實做,而且是浮點數的運算,但板子上的 dsp 只支援整數,用 LUT 實做的話可能會吃掉太多資源,二來是板子上的 arm cpu 有 FPU ,將很適合將浮點運算交給它處理。
而關於 quantize ,一般來說 qunatize 的運算如下:
而我們的 quantize ,是直接以硬體實做以下運算:
即我們系統中的 S =
將 S 定為
而 Z 會對應到 quantize 前的零點,因此:
我們來看看 IEEE 754 中對 float 儲存的規範如下:
由於 softmax 的 output 為 0 ~ 1 ,因此 sign bit 永遠為 0 ,而乘上 256 =
因此可得:
因此我們要做的計算是
如此便能很簡單的實做出 quantize 硬體:
assign E = floatNum[30:23];
assign Mantisa = floatNum[22:0];
always @(*) begin
shift_delta = E + 8 - 127;
// Clac the shift delta based on IEEE 754
// Reasonable range of E: 120 ~ 126
// shift_delta = 1 ~ 7
// rare -> E = 127 (output of softmax = 1)
// shift_delta = 8, int_res is sat to 255
int_res = 0;
case (shift_delta)
0: int_res = 32'b1;
1: int_res = {30'b0 ,1'b1, Mantisa[22]};
2: int_res = {29'b0 ,1'b1, Mantisa[22:21]};
3: int_res = {28'b0 ,1'b1, Mantisa[22:20]};
4: int_res = {27'b0 ,1'b1, Mantisa[22:19]};
5: int_res = {26'b0 ,1'b1, Mantisa[22:18]};
6: int_res = {25'b0 ,1'b1, Mantisa[22:17]};
7: int_res = {24'b0 ,1'b1, Mantisa[22:16]};
8: int_res = 32'd255; // int_res is sat to 255
default: int_res = 32'b0; // if shift_delta < 0
endcase
result = int_res[7:0];
end
再來要處理的便是 PS 和 PL 要如何溝通,我們原先想用兩塊 bram ,一塊是 PS 寫 PL 讀,另一塊則是 PL 寫 PS 讀,但由於我們不熟悉 vivado 的 bram IP ,因此最後只有 PS 寫 PL 讀的 bram 有正常運作,後者我們最後改成拉 gpio 出來。
最終接線圖如下:
基本上就是用軟體實做簡單的 softmax ,唯一要注意的地方是,由於我們輸入給 Softmax 的 input 是 int8 ,每個數值間的離散程度很大,現在看一下 softmax 的運算:
可以看到,當每個數值間的散程度很大時,會有一個最大值
為了避免這個現象,我們在將資料送入 softmax 時會除上一個 softmax temperature T ,這個想法是源自進行 Knowledge distillation 時為了避免 int8 導致一樣的問題(只有一個數值為 1 其他為 0) 時的作法。
因此我們的 softmax 變成:
而我們的 T 設為 64 。
最後試了幾筆測資,結果正確,但是我們發現一個嚴重的問題,就是我們用硬體做計算的計算速度沒有比用軟體還要快,甚至稍慢。
我們便著手研究問題出在哪裡,我們嘗試從分析 AXI 和 cpu 上 clk 的速度差異開始切入,首先, pynq-z2 上的 cpu clk rate 是 650MHz ,且有兩個 core ,而 axi bus 上的 clk rate 僅為 100MHz 。
最終我們發現最大的瓶頸是卡在 gpio 的傳輸上:
我們嘗試做了 10000 次連續的讀寫並計時花了多少時間,發現一次 gpio 的讀或寫會有相當於 2050 個 cpu clk 的 overhead 。
由於我們所有硬體上所有操作都需要透過 gpio 給指令,而一個 gpio 傳輸的 overhead 會用掉 ~ 2050 個 cpu clk ,因此會造成很大的時間浪費。
我們嘗試計算了兩者的時間差,卻發現算出來的結果和我們實際量出來的差了 32381 個 axi clk ,明顯不能算做誤差,但我們認為可能會有這樣情況的原因有:
而老師在這個基礎上又告訴我們幾個可能的原因:
以上種種原因造成我們算出來的 clk 和真實情況有所誤差。
最後附上我們專案的 網址 ,歡迎有興趣的人嘗試重現實驗。
當有一個字串