# Linux 核心專題: 並行的 fibdrv
> 執行人: ericlai1021
> [專題解說錄影](https://youtu.be/6kW0vLJVo9Q)
:::success
:question: 提問清單
* ?
:::
## 任務簡述
擴充 [fibdrv](https://hackmd.io/@sysprog/linux2023-fibdrv),強化其並行處理能力,預計達成:
* 有效運算 Fibonacci 數 (至少能算到第一百萬個) 並降低記憶體開銷
* 藉由 hashtable 或 cache 一類的機制,儲存已計算的 Fibonacci 數
* 引入 workqueue,將運算要求分派給個別 CPU 核,並確保降低非必要的同步處理成本
* 修訂 fibdrv 和應用程式之間的 API,使其適合用於同步處理
## TODO: 落實 Fibonacci 數的計算效率
1. 彙整教材和學員成果,可延用現有程式碼,但應清楚標示出處並持續改進。
2. 掌握「加速 Fibonacci 運算」及「[sysprog21/bignum](https://github.com/sysprog21/bignum) 程式碼分析」,落實於 fibdrv 內部實作中
3. 提供驗證機制,確保 fibdrv 至少能算到第一百萬個 Fibonacci 數
4. 修訂 fibdrv 和應用程式之間的 API,使資料傳輸和操作更有效
## TODO: 儲存已計算的 Fibonacci 數
1. 考慮到大數運算的特性,當以 key-value 形式保存時,不是儲存單純的整數值,而是指向特定結構的指標,於是當 fibdrv 嘗試釋放佔用的記憶體空間時,應有對應的操作
2. 考慮到 fast doubling 和 Fibonacci 數的特性,不用保存連續數值,而是關注第 N 個和第 2N 個 Fibonacci 數的關聯,儘量降低記憶體開銷
3. 應當善用 Linux 核心的 hashtable 或相關的資料結構
## 引入 workqueue,確保並行處理的效益
1. 學習 [ktcp](https://hackmd.io/@sysprog/linux2023-ktcp),引入 kthread 和 CMWQ 到 fibdrv,確保 Fibonacci 數的運算可發揮硬體能力
2. 確保並行處理的效益,不僅要確認結果正確,還要讓並行的 fibdrv 得以更有效的運算
## 改善大數運算
基於先前作業已完成部份,請參閱 [ericlai1021-fibdrv](https://hackmd.io/@ericlai1021/linux2023q1-fibdrv) ,先以 `perf stat` 分析程式碼,作為後續比對的基準
:::info
此部份為了方便後續做實驗比較,程式皆在 user space 執行
:::
```shell
input a number: 100000
Performance counter stats for './test':
685,1248,4134 instructions # 1.73 insn per cycle (83.31%)
395,8769,6486 cycles (83.31%)
14,5768 cache-misses # 12.170 % of all cache refs (83.34%)
119,7772 cache-references (83.35%)
3,8748,8927 branch-misses # 8.87% of all branches (83.35%)
43,6638,9672 branch-instructions (66.65%)
15.557167671 seconds time elapsed
11.697195000 seconds user
0.000000000 seconds sys
```
接下來使用 `perf record` 量測 call graph (省略部份內容)
```shell
$ perf record -g --call-graph dwarf ./test
$ perf report --stdio -g graph,0.5,caller
# Children Self Command Shared Object Symbol
# ........ ........ ....... .................... ...................
#
99.98% 0.00% test test [.] main
|
---main
|
|--86.94%--bn_to_string
|
--13.04%--bn_mult
|
|--6.65%--bn_add
|
--6.33%--bn_lshift
86.94% 86.62% test test [.] bn_to_string
|
--86.62%--_start
__libc_start_main_impl (inlined)
__libc_start_call_main
main
bn_to_string
13.04% 0.02% test test [.] bn_mult
|
--13.02%--bn_mult
|
|--6.65%--bn_add
|
--6.33%--bn_lshift
6.65% 6.53% test test [.] bn_add
|
--6.53%--_start
__libc_start_main_impl (inlined)
__libc_start_call_main
main
bn_mult
bn_add
6.33% 6.24% test test [.] bn_lshift
|
--6.24%--_start
__libc_start_main_impl (inlined)
__libc_start_call_main
main
|
--6.24%--bn_mult
bn_lshift
```
* 有 86.94% 的時間 (準確來說是樣本數) 落在 `bn_to_string` ,由此可見大數由二進制轉換成十進制的成本非常高,更不用說考慮到執行在 kernel space 時 `copy_to_user` 的成本,因此改善此部份勢必具有明顯的效能增益
* 有 13.04% 的時間落在 `bn_mult` ,這部份的實作為參考傳統乘法器原理,因此會有大量的加法以及左移運算,需要提出一個更高效能的計算方法
### 改善方案 1: 運用 Q-Matrix 改進 fast doubling 的實作
主函式內採用 `fast_doubling` 演算法實作大數運算,然而 `fast_doubling` 的特色為會紀錄第 `2n` 項以及第 `2n + 1` 項的結果,若要進一步加速此運算則可以採用 Q-Matrix 搭配 [`Exponentiation by squaring`](https://en.wikipedia.org/wiki/Exponentiation_by_squaring) 的技巧
所謂的 Q-Matrix 可以進一步將費氏數列改寫成以下形式:
$$
Q =
\begin{pmatrix}
1 & 1 \\
1 & 0
\end{pmatrix}
=
\begin{pmatrix}
F_2 & F_1 \\
F_1 & F_0
\end{pmatrix}
\\
Q^n =
\begin{pmatrix}
F_{n+1} & F_n \\
F_n & F_{n-1}
\end{pmatrix}
$$
Exponentiation by squaring 則可以進一步將 $Q^n$ 的計算改寫成:
$$
Q^n=\left\{\begin{array}{l}Q\left(Q^2\right)^\frac{(n-1)}2,\;\;if\;n\;is\;odd\\\left(Q^2\right)^\frac n2,\;\;if\;n\;is\;even\end{array}\right.
$$
對應的程式碼如下:
```python
def multiply_matrices(a, b):
# Matrix multiplication helper function
result = [[0, 0], [0, 0]]
for i in range(2):
for j in range(2):
for k in range(2):
result[i][j] += a[i][k] * b[k][j]
return result
def power_matrix(matrix, n):
# Matrix exponentiation helper function
result = [[1, 0], [0, 1]]
while n > 0:
if n % 2 == 1:
result = multiply_matrices(result, matrix)
matrix = multiply_matrices(matrix, matrix)
n //= 2
return result
def fibonacci(n):
if n <= 0:
return "Input should be a positive integer."
matrix = [[1, 1], [1, 0]]
power = power_matrix(matrix, n - 1)
fib_n = power[0][0]
return fib_n
```
實驗分析
![](https://hackmd.io/_uploads/Syn0Gmsr2.png)
Q-Matrix 的效能明顯比 Fast Doubling 來的差,其實不難理解是因為 Fast Doubling 其實就是 Q-Matrix 的改良,而直接使用 Q-Matrix 計算的話反而會額外多出許多計算量
* 改進 fast doubling 實作
參考 [KYG-yaya573142](https://hackmd.io/@KYWeng/rkGdultSU#%E6%94%B9%E5%96%84%E6%96%B9%E6%A1%88-2---%E4%BD%BF%E7%94%A8%E4%B8%8D%E5%90%8C%E7%9A%84-Q-Matrix-%E5%AF%A6%E4%BD%9C-bn_fib_fdoubling) 的作法稍微調整 fast doubling 的實作方法,式子推導如下
$$
\begin{split}
\begin{bmatrix}
F(2n-1) \\
F(2n)
\end{bmatrix} &=
\begin{bmatrix}
0 & 1 \\
1 & 1
\end{bmatrix}^{2n}
\begin{bmatrix}
F(0) \\
F(1)
\end{bmatrix}\\ \\ &=
\begin{bmatrix}
F(n-1) & F(n) \\
F(n) & F(n+1)
\end{bmatrix}
\begin{bmatrix}
F(n-1) & F(n) \\
F(n) & F(n+1)
\end{bmatrix}
\begin{bmatrix}
1 \\
0
\end{bmatrix}\\ \\ &=
\begin{bmatrix}
F(n)^2 + F(n-1)^2\\
F(n)F(n) + F(n)F(n-1)
\end{bmatrix}
\end{split}
$$
整理後可得
$$
\begin{split}
F(2k-1) &= F(k)^2+F(k-1)^2 \\
F(2k) &= F(k)[2F(k-1) + F(k)]
\end{split}
$$
使用上述式子實作比起作業說明的 [範例實作](https://hackmd.io/@sysprog/linux2023-fibdrv/%2F%40sysprog%2Flinux2023-fibdrv-a#-%E8%B2%BB%E6%B0%8F%E6%95%B8%E5%88%97) 可以減少一次迴圈的計算以及省去掉減法的運算
* 在實作的過程中還發現原先的寫法都是使用 `bn_cpy` 來更新變數的數值,其實可以藉由 `bn_swap` 以及改變各函式儲存結果的位置來達到同樣的目的,因此就將所有的 `bn_cpy` 去除改用 `bn_swap` 以降低複製資料造成的成本
* `bn_swap` 的實作如下
一開始的想法是將兩個 `bn *` 型態的位址互換即可
```c
/* swap bn ptr */
void bn_swap(bn *a, bn *b)
{
bn *tmp = a;
a = b;
b = tmp;
}
int main()
{
bn *ptrA, *ptrB;
bn_swap(ptrA, ptrB);
return 0;
}
```
結果發現根本沒有互換成功,於是先將 `a` 與 `b` 的位址印出來看,發現在呼叫 `bn_swap` 函式後 `a` 與 `b` 的位址並無交換,但在 `bn_swap` 函式內兩個指標位址確實有交換,研讀課程教材 〈[你所不知道的C語言:指標篇](https://hackmd.io/@sysprog/c-pointer#%E6%B2%92%E6%9C%89%E3%80%8C%E9%9B%99%E6%8C%87%E6%A8%99%E3%80%8D%E5%8F%AA%E6%9C%89%E3%80%8C%E6%8C%87%E6%A8%99%E7%9A%84%E6%8C%87%E6%A8%99%E3%80%8D)〉後才了解到原來 C 語言在函式呼叫皆是 call-by-value ,上述程式的行為是呼叫 `bn_swap` 時會傳遞 `ptrA` 與 `ptrB` 的位址,而 `bn_swap` 函式會產生 `bn *` 型態變數 `a` 與 `b` 分別將 `ptrA` 與 `ptrB` 的位址存在其中,示意圖如下
```graphviz
digraph structs {
node[shape=record]
{rank=same; structa}
structp [label="ptrA"]
structptr [label="<name_ptr> a"];
structa [label="ptrA內部的數值"];
structptr:ptr -> structa:A:nw
structp:p -> structa:A:nw
structpb [label="ptrB"]
structptrb [label="<name_ptr> b"];
structb [label="ptrB內部的數值"];
structptrb:ptr -> structb:A:nw
structpb:pb -> structb:A:nw
}
```
執行 `bn_swap` 函式後會變成下圖所示
```graphviz
digraph structs {
node[shape=record]
{rank=same; structa}
structp [label="ptrA"]
structptr [label="<name_ptr> a"];
structa [label="ptrA內部的數值"];
structptr:ptr -> structb:A:nw
structp:p -> structa:A:nw
structpb [label="ptrB"]
structptrb [label="<name_ptr> b"];
structb [label="ptrB內部的數值"];
structptrb:ptr -> structa:A:nw
structpb:pb -> structb:A:nw
}
```
若要正確交換,則必須要使用「指標的指標」技巧,因為 `bn` 資料結構中 `number` 紀錄的是指標,所以可以透過以下方式將兩個 `bn` 型態變數的內容互換,而不會改變儲存在 heap 中的數值
:::danger
C 語言沒有 "call-by-reference",只有數值傳遞 (包含指標在內,都是數值),請研讀〈[你所不知道的C語言:指標篇](https://hackmd.io/@sysprog/c-pointer)〉,用精準的描述。
:notes: jserv
> 已修正認知
:::
```c
/* swap bn ptr */
void bn_swap(bn *a, bn *b)
{
bn tmp = *a;
*a = *b;
*b = tmp;
}
```
示意圖如下
```graphviz
digraph structs {
node[shape=record]
{rank=same; structa}
structp [label="ptrA"]
structptr [label="<name_ptr> a"];
structa [label="ptrA內部的數值"];
structptr:ptr -> structa:A:nw
structp:p -> structa:A:nw
structpb [label="ptrB"]
structptrb [label="<name_ptr> b"];
structb [label="ptrB內部的數值"];
structptrb:ptr -> structb:A:nw
structpb:pb -> structb:A:nw
}
```
執行 `bn_swap` 函式後會變成下圖所示
```graphviz
digraph structs {
node[shape=record]
{rank=same; structa}
structp [label="ptrA"]
structptr [label="<name_ptr> a"];
structa [label="ptrB內部的數值"];
structptr:ptr -> structa:A:nw
structp:p -> structa:A:nw
structpb [label="ptrB"]
structptrb [label="<name_ptr> b"];
structb [label="ptrA內部的數值"];
structptrb:ptr -> structb:A:nw
structpb:pb -> structb:A:nw
}
```
實驗結果如下 (v1 綠線為修改的 fast doubling 並將所有 `bn_cpy` 換成 `bn_swap`)
![](https://hackmd.io/_uploads/BJ_UwLorh.png)
完全出乎意料之外的竟然沒有改善,為此我反覆測驗程式並檢查程式碼是否正確,但最終的結果確實如此,於是我就去查看兩個實驗的 call graph
* 原始實作的 call graph
```shell
# Children Self Command Shared Object Symbol
# ........ ........ ....... .................... ..........................................
#
95.50% 0.00% test test [.] _start
|
---_start
__libc_start_main_impl (inlined)
__libc_start_call_main
main
|
--94.81%--fast_doubling
|
|--93.29%--bn_mult
| |
| |--46.77%--bn_add
| | |
| | |--12.52%--bn_resize
| | | |
| | | |--4.95%--__GI___libc_realloc (inlined)
| | | | |
| | | | --1.28%--_int_realloc
| | | |
| | | --1.85%--__memset_avx2_unaligned_erms
| | |
| | --2.01%--_GLOBAL_OFFSET_TABLE_
| | __GI___libc_realloc (inlined)
| |
| |--34.63%--bn_lshift
| | |
| | |--4.72%--bn_clz
| | |
| | --1.34%--bn_resize
| |
| |--5.75%--bn_digit
| | |
| | --3.85%--bn_clz
| |
| |--0.60%--bn_alloc
| |
| --0.50%--bn_resize
|
--0.66%--bn_add
```
* 修改後的 call graph
```shell
# Children Self Command Shared Object Symbol
# ........ ........ ....... .................... ..........................................
#
95.79% 0.00% test test [.] _start
|
---_start
__libc_start_main_impl (inlined)
__libc_start_call_main
|
--95.78%--main
|
|--94.94%--fast_doubling
| |
| |--94.04%--bn_mult
| | |
| | |--47.49%--bn_add
| | | |
| | | |--12.71%--bn_resize
| | | | |
| | | | |--4.95%--__GI___libc_realloc (inlined)
| | | | | |
| | | | | --1.19%--_int_realloc
| | | | |
| | | | --2.17%--__memset_avx2_unaligned_erms
| | | |
| | | --2.11%--_GLOBAL_OFFSET_TABLE_
| | | __GI___libc_realloc (inlined)
| | |
| | |--34.43%--bn_lshift
| | | |
| | | |--4.76%--bn_clz
| | | |
| | | --1.27%--bn_resize
| | |
| | --5.74%--bn_digit
| | |
| | --3.66%--bn_clz
| |
| --0.64%--bn_add
|
--0.55%--__printf (inlined)
|
--0.54%--__vfprintf_internal
```
從 call graph 就可以看出來原來是因為大數的乘法運算太花費時間,兩個實驗都有 94% 左右的時間在執行 `bn_mult` ,所以才會導致上述所作的改善微乎其微
### 改善方案 2: 改進 `bn_mult` 的效能
原本的 `bn_mult` 實作是參考傳統硬體乘法器,如此次來假設乘數有 x 個位元,則 worst case 就會需要執行 x 次的 `bn_add` 以及 x-1 次的 `bn_lshift` ,可想而知這對大數運算非常不適合,因此初步參考 [KYG-yaya573142](https://hackmd.io/@KYWeng/rkGdultSU#%E6%94%B9%E5%96%84%E6%96%B9%E6%A1%88-2---%E4%BD%BF%E7%94%A8%E4%B8%8D%E5%90%8C%E7%9A%84-Q-Matrix-%E5%AF%A6%E4%BD%9C-bn_fib_fdoubling) 的作法,概念即為直式乘法,將兩個大數的數字陣列依序兩兩相乘,接著將結果直接疊加到輸出的變數
```c
void bn_mult(const bn *a, const bn *b, bn *c)
{
// max digits = sizeof(a) + sizeof(b))
int d = bn_digit(a) + bn_digit(b);
d = DIV_ROUNDUP(d, 32) + !d; // round up, min size = 1
bn *tmp;
/* make it work properly when c == a or c == b */
if (c == a || c == b) {
tmp = c; // save c
c = bn_alloc(d);
} else {
tmp = NULL;
for (int i = 0; i < c->size; i++)
c->number[i] = 0U;
bn_resize(c, d);
}
for (int i = 0; i < a->size; i++) {
for (int j = 0; j < b->size; j++) {
unsigned long long int carry = 0U;
carry = (unsigned long long int) a->number[i] * b->number[j];
unsigned long long int tmp = 0;
for (int k = i + j; k < c->size; k++) {
tmp += c->number[k] + (carry & 0xFFFFFFFF);
c->number[k] = tmp;
tmp >>= 32;
carry >>= 32;
if (!carry && !tmp) // done
break;
}
}
}
if (tmp) {
bn_swap(tmp, c);
bn_free(c);
}
}
```
實驗結果 (v2 綠線)
![](https://hackmd.io/_uploads/Bkk899oSn.png)
接著進一步發現第三層迴圈的用意為將每一輪相乘的結果疊加到輸出的變數,其實只需要將相乘的結果與輸出變數相加,再用一個變數 `carry` 儲存溢位的部份加到下一輪的相加結果就好
```c
void bn_mult(const bn *a, const bn *b, bn *c)
{
...
unsigned long long product;
unsigned int carry = 0U;
for (int i = 0; i < a->size; i++) {
for (int j = 0; j < b->size; j++) {
product = (unsigned long long) a->number[i] * b->number[j] +
carry + c->number[i + j];
carry = product >> 32;
c->number[i + j] = product & 0xFFFFFFFF;
}
if (carry) {
c->number[i + b->size] = carry;
carry = 0U;
}
}
...
}
```
實驗結果 (v3 綠線)
![](https://hackmd.io/_uploads/BynxR3jrh.png)
### 改善方案 3: 善用 64 位元 CPU 特性
原先 `bn` 結構體中數字陣列 `number` 的資料型態是 `unsigned int` ,為了能充分利用 64 位元處理器的特性改使用 `uint64_t` 確保每次的記憶體存取都是一個 word 的大小。
```c
#include <stdint.h>
#if defined(__LP64__) || defined(__x86_64__) || defined(__amd64__) || defined(__aarch64__)
#define BN_WSIZE 8
#else
#define BN_WSIZE 4
#endif
#if BN_WSIZE == 8
typedef uint64_t bn_data;
typedef unsigned __int128 bn_data_tmp; // gcc support __int128
#define DATA_BITS 64U
#define builtin_clz(x) __builtin_clzll(x)
#elif BN_WSIZE == 4
typedef uint32_t bn_data;
typedef uint64_t bn_data_tmp;
#define DATA_BITS 32U
#define builtin_clz(x) __builtin_clz(x)
#else
#error "BN_WSIZE must be 4 or 8"
#endif
typedef struct _bn {
bn_data *number;
bn_data size;
} bn;
```
* 參考 [bignum/apm.h](https://github.com/sysprog21/bignum/blob/master/apm.h) 當中的方式來定義 `bn` 結構體的資料型態,以便於根據不同 word 大小切換定義
* 加法及乘法運算時會用到 2 倍大小的的暫存變數,直接使用 gcc 提供的 `__int128` 實作
實驗結果如下 (v4 綠線)
![](https://hackmd.io/_uploads/H1csnTb82.png)
### 改善方案 4: 內嵌組合語言
參考 [bignum/apm_internal.h](https://github.com/sysprog21/bignum/blob/master/apm_internal.h#L26) 當中乘法運算使用內嵌組合語言 (inline assembly) 來直接取得乘法運算的高位與低位,直接使用一樣的方式實作乘法,取代原本使用的 `__int128`
```c
void bn_mult(const bn *a, const bn *b, bn *c)
{
...
bn_data_tmp carry = 0UL;
bn_data *numA = a->number;
for (int j = 0; j < b->size; j++) {
bn_data multiplier = b->number[j];
for (int i = 0; i < a->size; i++) {
bn_data high, low;
__asm__("mulq %3"
: "=a"(low), "=d"(high)
: "%0"(numA[i]), "rm"(multiplier)
);
carry +=(bn_data_tmp) low + c->number[i + j];
c->number[i + j] = carry;
carry = high + (carry >> DATA_BITS);
}
if (carry) {
c->number[j + a->size] = carry;
carry = 0UL;
}
}
...
}
```
由於我目前還無法完全理解 [bignum/apm.c](https://github.com/sysprog21/bignum/blob/master/apm.c#L206) 的實作,所以就初步按照自己目前對程式碼的理解進行實作
實驗結果如下 (v5 綠線)
![](https://hackmd.io/_uploads/HyIwhbXUn.png)
結果看起來差異不大,將範圍放大至10000項來看
![](https://hackmd.io/_uploads/S1rnhbQUh.png)
結果顯示我的實作效能較差,所以看起來差異真的是在處理如何將乘積疊加到輸出變數那邊,於是我先做實驗驗證看看
```c
void bn_mult(const bn *a, const bn *b, bn *c)
{
...
bn_data *numA = a->number;
for (int j = 0; j < b->size; j++) {
bn_data multiplier = b->number[j];
bn_data carry = 0;
for (int i = 0; i < a->size; i++) {
bn_data high, low;
__asm__("mulq %3"
: "=a"(low), "=d"(high)
: "%0"(numA[i]), "rm"(multiplier)
);
carry = high + ((low += carry) < carry);
carry += ((c->number[i + j] += low) < low);
}
c->number[j + a->size] = carry;
}
...
}
```
實驗結果如下 (v6 綠線)
![](https://hackmd.io/_uploads/HJTvVLV8n.png)
上述實作是利用無號整數不會 overflow 的特性,舉例假設有兩個 4 bits 的整數 `a` 與 `b` 且兩個數皆為 4 bits 可表示最大數值 (即 15, 二進制表示成 1111),則 `a + b` 等於 30, 二進制表示成 11110 ,但因為是無號整數,所以只會保留 0~3 bit 的值,也就是 1110,這裡就可以看出若兩數值相加後小於其中任一數值,就表示發生 overflow 且只要 overflow 溢位的數值一定是 1 。
藉由上述特性就可以避免使用 `__int128` (`bn_data_tmp`) 進行計算以節省多餘的記憶體開銷以及資料型態轉換成本。
### 改善方案 5: 改進 `bn_add` 的效能
原先的實作會在每次迴圈判斷需要相加的數值,這麼做的優點是只需寫一個迴圈就能完成計算,但缺點是每次迴圈都有兩個 branch 要判斷。參考 [KYG-yaya573142](https://hackmd.io/@KYWeng/rkGdultSU#%E6%94%B9%E5%96%84%E6%96%B9%E6%A1%88-I---%E6%94%B9%E5%AF%AB-bn_add-%E7%9A%84%E5%AF%A6%E4%BD%9C%E6%B3%95) 的實作改為使用兩個迴圈進行計算,第一個迴圈先計算兩者皆有資料的範圍,再於第二個迴圈處理 carry 與剩餘的資料範圍。另外,利用上一個改善方案所提及的無號整數不會 overflow 的特性,可以進一步可以避免使用 `__int128` (`bn_data_tmp`) 進行計算
```c
/* c = a + b */
void bn_add(const bn *a, const bn *b, bn *c)
{
if (a->size < b->size) {
SWAP(a, b);
}
int asize = a->size, bsize = b->size;
bn_resize(c, a->size + 1);
bn_data carry = 0;
for (int i = 0; i < bsize; i++) {
bn_data tmp1 = a->number[i];
bn_data tmp2 = b->number[i];
carry = (tmp1 += carry) < carry;
carry += (c->number[i] = tmp1 + tmp2) < tmp2;
}
if (asize != bsize) { // deal with the remaining part if asize > bsize
for (int i = bsize; i < asize; i++) {
bn_data tmp1 = a->number[i];
carry = (tmp1 += carry) < carry;
c->number[i] = tmp1;
}
}
if (carry) {
c->number[asize] = carry;
}
if (!c->number[c->size - 1] && c->size > 1)
bn_resize(c, c->size - 1);
}
```
因為要先計算兩者皆有的資料範圍,所以要先找出兩者當中範圍較小者,但這裡的做法是假設 `a` 的範圍比 `b`大,所以若遇到 `a` 的範圍比 `b` 小的情況就必須要將兩者互換,這裡值得注意的地方是原先實作的 `bn_swap` 函式為交換兩者的內容,但加法運算為了確保兩個輸入變數 (即 `a` 與 `b`) 的內容不會被更改,因此將兩者皆宣告成 `const` ,如此一來就不能使用原先的 `bn_swap`,參考 [bignum/apm_internal.h](https://github.com/sysprog21/bignum/blob/master/apm_internal.h#L111) 當中的做法,透過定義一個巨集,交換指定的二個變數的數值。
```c
#ifndef SWAP
#define SWAP(x, y) \
do { \
typeof(x) __tmp = x; \
x = y; \
y = __tmp; \
} while (0)
#endif
```
為了讓加法運算遇到 `a == c` 或 `b == c` 依舊能夠正確計算,必須要在 `bn_resize` 之前將 `a` 跟 `b` 的大小 (size) 暫存起來。
為了凸顯 `bn_add` 對效能的影響,這裡改為使用迭代的方法計算費氏數列
實驗結果如下 (v1 綠線)
![](https://hackmd.io/_uploads/B1KXRCBLh.png)
### 改善方案 6: 引入 `bn_sqr`
```
a b c
x a b c
-------------------
ac bc cc
ab bb bc
aa ab ac
```
考慮上述 $(abc)^2$ 的計算過程,會發現數值 $ab$ 、 $ac$ 與 $bc$ 各會重複一次,利用此特性,先計算對角線任一邊的數值,接著再將數值總和乘二,最後再加上對角線上的 $aa$ 、 $bb$ 與 $cc$。藉由此法,平方運算的成本可由本來的 $n^2$ 次乘法降為 $\dfrac{n^2 - n}{2}$ 次乘法
> 實作參考 [KYG-yaya573142](https://hackmd.io/@KYWeng/rkGdultSU#%E5%BC%95%E5%85%A5-bn_sqr) 及 [bignum/sqr.c](https://github.com/sysprog21/bignum/blob/master/sqr.c)
```c
void bn_sqr(bn *dest, const bn *src)
{
...
const bn_data *sp = src->number;
bn_data *dp = dest->number + 1;
bn_data size = src->size - 1;
for (int i = 0; i < size; i++) {
bn_data carry = 0;
for (int j = 0; j < size - i; j++) {
bn_data high, low;
__asm__("mulq %3"
: "=a"(low), "=d"(high)
: "%0"(sp[i + 1 + j]), "rm"(sp[i])
);
carry = high + ((low += carry) < carry);
carry += ((dp[j] += low) < low);
}
dp[size - i] = carry;
dp += 2;
}
/* Double it */
for (int i = 2 * src->size - 1; i > 0; i--)
dest->number[i] = dest->number[i] << 1 |
dest->number[i - 1] >> (DATA_BITS - 1);
dest->number[0] <<= 1;
/* add the (aa bb cc) part at diagonal line */
dp = dest->number;
sp = src->number;
size = src->size;
bn_data carry = 0;
for (int i = 0; i < size; i++) {
bn_data high, low;
__asm__("mulq %3"
: "=a"(low), "=d"(high)
: "%0"(sp[i]), "rm"(sp[i])
);
high += (low += carry) < carry;
high += (dp[0] += low) < low;
carry = (dp[1] += high) < high;
dp += 2;
}
...
}
```
實驗結果如下 (v7 綠線)
![](https://hackmd.io/_uploads/r1oiO9O83.png)
將範圍擴大至第 20000 項就可以明顯看出改善
![](https://hackmd.io/_uploads/rJdXY5_I2.png)
### 改善方案 7: 實作 [Karatsuba algorithm](https://en.wikipedia.org/wiki/Karatsuba_algorithm)
觀察 [bignum/mul.c](https://github.com/sysprog21/bignum/blob/master/mul.c) 及 [bignum/sqr.c](https://github.com/sysprog21/bignum/blob/master/sqr.c) 皆有使用 Karatsuba 演算法來加速乘法與平方運算,因此接下來一樣實作該演算法來提升效能
先放上 v7 版本與 bignum 的效能差異來觀察後續改進的成效
![](https://hackmd.io/_uploads/SJdHcYT8n.png)
v7 版本的 call graph (只擷取部份內容)
```shell
#
# Children Self Command Shared Object Symbol
# ........ ........ ....... .................... ..................................
#
99.42% 0.00% test test [.] main
|
---main
fast_doubling
|
|--49.78%--bn_sqr
|
--49.54%--bn_mult
do_mult_base
```
Karatsuba 的概念是將 $a$、$b$ 以第 $n$ 位數為界,拆成兩半 $a_1$、$a_0$、$b_1$、$b_0$,把這他們視為較小的數相乘,然後再透過左移補回 $a_1$、$b_1$ 損失的位數,以二進位為例:
$a = a_1 \times 2^n+ a_0 \\ b=b_1 \times 2^n+b_0$
則 $a \times b$ 可以化為:
$\underbrace{a_1b_1}_{z_2} \times 2^{2n}+\underbrace{(a_1b_0+b_1a_0)}_{z_1} \times 2^n+\underbrace{a_0b_0}_{z_0}$
上述算法計算 $z_2$、$z_1$、$z_0$ 需要 4 次乘法,我們還可以透過以下技巧縮減為 3 次乘法:
觀察 $(a_1+a_0)(b_1+b_0)$ 展開的結果
$(a_1+a_0)(b_1+b_0)=\underbrace{a_1b_1}_{z_2}+\underbrace{a_1b_0+a_0b_1}_{z_1}+\underbrace{a_0b_0}_{z_0}$
移項之後,我們就能利用 $(a_1+a_0)(b_1+b_0)$、$z_0$、$z_2$ 來計算 $z_1$
$z_2=a_1b_1$
$z_0 = a_0b_0$
$z_1=(a_1+a_0)(b_1+b_0)-z_0-z_2$
最後計算 $z_2\times2^{2n}+z_1\times2^n+z_0$ 便能得到 $a$ $b$ 相乘的結果,且 $\times 2^n$ 可以用左移運算代替。
再舉個例子,假設所採用的處理器只支援 8 位元乘法,當 $x$、$y$ 超過 8 位元時,可以透過分治法實作 Karatsuba。$x_1$、$x_0$、$y_1$、$y_0$ 的位元數超出處理器的乘法的位數時,就把他們再切為 $x_{11}$、$x_{10}$、$x_{01}$、$x_{00}$、...,再使用 Karatsuba 計算。以下以兩個 16 位元數值相乘變成 32 位元來演示
![](https://hackmd.io/_uploads/SJaGdu2Ri.png)
由上圖可以看出計算 $z_2$、$z_1$、$z_0$ 時,透過分治法將 $x_1$、$x_0$、$y_1$、$y_0$ 切成更小的數字執行乘法運算。最後再用左移與加法計算 $z_2 \times2^{16} + z_1 \times2^8 + z_0$ 即可求得結果。
至此可透過分治法,運用 Karatsuba 演算法計算任意位數的大數。
#### 實作 Karatsuba 乘法
> 實作參考 [KYG-yaya573142](https://github.com/KYG-yaya573142/fibdrv/blob/d8bbd795b11fa8a473f03a8bcb42c4ce2f1f8d62/bn.c#L411) 與 [bignum/mul.c](https://github.com/sysprog21/bignum/blob/master/mul.c)
:::spoiler 程式碼解析
首先來看 `bn_mult` 函式的修改
```c
/*
* c = a x b
* Note: work for c == a or c == b
* using the simple quadratic-time algorithm (long multiplication)
*/
void bn_mult(const bn *a, const bn *b, bn *c)
{
if (a->size < b->size) // need asize > bsize
SWAP(a, b);
// max digits = sizeof(a) + sizeof(b))
bn_data asize = a->size, bsize = b->size;
int csize = asize + bsize;
bn *tmp;
/* make it work properly when c == a or c == b */
if (c == a || c == b) {
tmp = c; // save c
c = bn_alloc(csize);
} else {
tmp = NULL;
for (int i = 0; i < c->size; i++)
c->number[i] = 0; // clean up c
bn_resize(c, csize);
}
bn_data *ap = a->number;
bn_data *bp = b->number;
bn_data *cp = c->number;
if (b->size < KARATSUBA_MUL_THRESHOLD) {
do_mult_base(ap, asize, bp, bsize, cp);
} else {
do_mult_karatsuba(ap, bp, bsize, cp);
/* it's assumed that a and b are equally length in
* Karatsuba multiplication, therefore we have to
* deal with the remaining part after hand */
if (asize == bsize)
goto end;
/* we have to calc a[bsize ~ asize-1] * b */
cp += bsize;
csize -= bsize;
ap += bsize;
asize -= bsize;
bn_data *_tmp = NULL;
/* if asize = n * bsize, multiply it with same method */
if (asize >= bsize) {
_tmp = (bn_data *) calloc(2 * bsize, sizeof(bn_data));
do {
do_mult_karatsuba(ap, bp, bsize, _tmp);
bn_data carry;
carry = _add_partial(cp, _tmp, bsize * 2, cp);
for (int i = bsize * 2; i < csize; i++) {
bn_data tmp1 = cp[i];
carry = (tmp1 += carry) < carry;
cp[i] = tmp1;
}
cp += bsize;
csize -= bsize;
ap += bsize;
asize -= bsize;
assert(carry == 0);
} while (asize >= bsize);
}
/* if asize != n * bsize, simply calculate the remaining part */
if (asize) {
if (!_tmp)
_tmp = (bn_data *) calloc(asize + bsize, sizeof(bn_data));
do_mult_base(bp, bsize, ap, asize, _tmp);
bn_data carry;
carry = _add_partial(cp, _tmp, asize + bsize, cp);
for (int i = asize + bsize; i < csize; i++) {
bn_data tmp1 = cp[i];
carry = (tmp1 += carry) < carry;
cp[i] = tmp1;
}
assert(carry == 0);
}
if (_tmp)
free(_tmp);
}
end:
if (!c->number[c->size - 1] && c->size > 1) // trim
bn_resize(c, c->size - 1);
if (tmp) {
bn_swap(tmp, c); // restore c
bn_free(c);
}
}
```
* 將原先乘法運算部份改寫成一個函式 `do_mult_base` 並定義切分界線 `KARATSUBA_MUL_THRESHOLD` (bignum 範例程式定為 32),因為 `do_mult_karatsuba` 函式假設 $a$ 與 $b$ 為相同 `size` ,因此判斷 $b$ 的 `size` 若小於 `KARATSUBA_MUL_THRESHOLD` 則執行一般的乘法運算 (即 `do_mult_base` 函式),否則執行 `do_mult_karatsuba` 函式
* 執行 `do_mult_karatsuba` 後若 $a$ 的 `size` 大於 $b$ 則要將 $a[bsize .. asize-1] \times b$ 加到 $c$ ,實作概念為判斷 `asize` 與 `bsize` 的差若大於等於 `bsize` ,則使用 Karatsuba 乘法計算,直到兩者的差小於 `bsize` 則使用一般乘法計算
接著看 `do_mult_karatsuba` 函式
* 一些初始化設置,將 $a$ 與 $b$ 各自分為 $a_1$、$a_0$、$b_1$、$b_0$
```c
void do_mult_karatsuba(const bn_data *a,
const bn_data *b,
bn_data size,
bn_data *c)
{
const int odd = size & 1;
const int even_size = size - odd;
const int half_size = even_size / 2;
const bn_data *a0 = a, *a1 = a + half_size;
const bn_data *b0 = b, *b1 = b + half_size;
bn_data *c0 = c, *c1 = c + even_size;
...
}
```
* 計算 $a_0 \times b_0$ 以及 $(a_1 \times b_1) \times 2^{2n}$ 並加到 $c$ ,這裡用遞迴方式實作上述提到的分治法
```c
/* c[0 ~ even_size-1] = a0*b0, c = 1*a0*b0 */
/* c[even_size ~ 2*even_size-1] += a1*b1, c += (2^2n)*a1*b1 */
if (half_size >= KARATSUBA_MUL_THRESHOLD) {
do_mult_karatsuba(a0, b0, half_size, c0);
do_mult_karatsuba(a1, b1, half_size, c1);
} else {
do_mult_base(a0, half_size, b0, half_size, c0);
do_mult_base(a1, half_size, b1, half_size, c1);
}
```
* 接著來計算 $2^n$ 項係數 (即 $z_1$),上述提到 $z_1=(a_1+a_0)(b_1+b_0)-z_0-z_2$ , 因為 $(a_1 + a_0)$ 及 $(b_1 + b_0)$ 為了解決溢位問題各自都要用 `(half_size + 1)` 的空間存放,相乘後的 `size` 會來到 `(even_size + 2)` 也就是 `(size + 1)` , 可想而知這樣對空間的使用效率不佳,因此可以進一步將 $z_1$ 的計算改寫成 $z_1 = |a_1 - a_0||b_0 - b_1| + z_0 + z_2$
* 將 $(z_0 + z_2) \times 2^n$ 加到 $c$,因為 $z_0$ 及 $z_2$ 前面已經算過了,分別放在 `c[0..even_size-1]` 及 `c[even_size..2*even_size-1]`,因此只需要將該部份取出並累加到 $c$ 即可
```c
/* since we have to add a0*b0 and a1*b1 to
* c[half_size ~ half_size+even_size-1] to obtain
* c = (2^2n + 2^n)a1*b1 + (2^n + 1)a0*b0,
* we have to make a copy of either a0*b0 or a1*b1 */
bn_data *tmp = (bn_data *) malloc(sizeof(bn_data) * even_size);
for (int i = 0; i < even_size; i++)
tmp[i] = c0[i];
/* c[half_size ~ half_size + even_size-1] += a1*b1 + a0*b0
* c += (2^n)*(a1*b1 + a0*b0)
* now c = (2^2n)a1*b1 + (2^n)*(a1*b1 + a0*b0) + a0*b0 */
bn_data carry = 0;
for (int i = 0; i < even_size; i++) {
bn_data in1 = c[half_size + i];
bn_data in2 = c1[i];
bn_data in3 = tmp[i];
carry = (in1 += carry) < carry;
carry += (c[half_size + i] = in1 + in2) < in2;
carry += (c[half_size + i] += in3) < in3;
}
```
* 計算 $|a_1 - a_0|$ 及 $|b_0 - b_1|$ ,可以注意到前面宣告用來暫存 `c[0..even_size-1]` 的變數 `tmp` 已經不需要使用,因此這邊可以將 $|a_1 - a_0|$ 存放到 `tmp[0..half_size-1]` $|b_0 - b_1|$ 存放到 `tmp[half_size..even_size-1]`, 減少了額外配置空間的成本
```c
/* calc |a1-a0| */
bn_data *a_tmp = tmp;
bool neg = bn_cmp(a1, half_size, a0, half_size) < 0;
if (neg)
_sub_partial(a0, a1, half_size, a_tmp);
else
_sub_partial(a1, a0, half_size, a_tmp);
/* calc |b0-b1| */
bn_data *b_tmp = tmp + half_size;
if (bn_cmp(b0, half_size, b1, half_size) < 0) {
_sub_partial(b1, b0, half_size, b_tmp);
neg ^= 1;
} else {
_sub_partial(b0, b1, half_size, b_tmp);
}
```
* 計算 $|a_1 - a_0||b_0 - b_1|$ 的方法與 Karatsuba 乘法相同
```c
/* tmp = |a1-a0||b0-b1| */
tmp = (bn_data *) calloc(even_size, sizeof(bn_data));
if (half_size >= KARATSUBA_MUL_THRESHOLD)
do_mult_karatsuba(a_tmp, b_tmp, half_size, tmp);
else
do_mult_base(a_tmp, half_size, b_tmp, half_size, tmp);
free(a_tmp);
```
* 將 $|a_1 - a_0||b_0 - b_1| \times 2^n$ 加到 $c$
```c
/* Now add / subtract (a1-a0)*(b0-b1) from
* c[half_size..half_size+even_size-1] based on whether it is negative or
* positive.
*/
if (neg)
carry -= _sub_partial(c + half_size, tmp, even_size, c + half_size);
else
carry += _add_partial(c + half_size, tmp, even_size, c + half_size);
free(tmp);
```
* 將上面產生的 `carry` 加到 $c$
```c
/* add carry to c[even_size+half_size ~ 2*even_size-1] */
for (int i = even_size + half_size; i < even_size << 1; i++) {
bn_data tmp1 = c[i];
carry = (tmp1 += carry) < carry;
c[i] = tmp1;
} // carry should be zero now!
```
* 現在已經計算好 `a[0..even_size-1]` $\times$ `b[0..even_size-1]`,但若 $a$ 與 $b$ 皆具奇數 `size`, 舉例來說 $a = a_2 a_1 a_0$ 、 $b = b_2 b_1 b_0$ ,則 $a \times b$ 為
$$
\begin{array}{r}
&&& a_2 & a_1 & a_0\\
\times
&&& b_2 & b_1 & b_0\\\hline
&&& a_2b_0 & {\color{Red} {\boxed {a_1b_0}}} & {\color{Red}{\boxed {a_0b_0}}}\\
&& a_2b_1 & {\color{Red} {\boxed {a_1b_1}}} & {\color{Red} {\boxed {a_0b_1}}}\\
+&a_2b_2 & a_1b_2 & a_0b_2\\\hline
\end{array}
$$
紅色圈起來的部份即為 `a[0..even_size-1]`$\times$`b[0..even_size-1]` ,接下來要將剩餘部份加回去 $c$ ,因此我們需要加上 `a[size-1]`$\times$`b[0..size-2]` 以及
`b[size-1]`$\times$`a[0..size-1]`
```c
if (odd) {
/* We have the product a[0..even_size-1] * b[0..even_size-1] in
* c[0..2*even_size-1]. We need to add the following to it:
* a[size-1] * b[0..size-2]
* b[size-1] * a[0..size-1] */
c[even_size * 2] =
_mult_partial(b, even_size, a[even_size], c + even_size);
c[even_size * 2 + 1] =
_mult_partial(a, size, b[even_size], c + even_size);
}
```
:::
實驗結果如下 (v8 藍線)
![](https://hackmd.io/_uploads/ryptWAlv3.png)
#### 實作 Karatsuba 平方運算
> 實作參考 [KYG-yaya573142](https://github.com/KYG-yaya573142/fibdrv/blob/d8bbd795b11fa8a473f03a8bcb42c4ce2f1f8d62/bn.c#L634) 與 [bignum/sqr.c](https://github.com/sysprog21/bignum/blob/master/sqr.c)
:::spoiler 程式碼解析
基本上跟 karatsuba 乘法運算一樣,差別就在平方運算的乘數與被乘數一樣,因此不需要去額外處理兩者 `size` 不同的情況。
* `bn_sqr` 函式內判斷 `src` 的 `size` 是否小於 `KARATSUBA_SQR_THRESHOLD` (bignum bignum 範例程式定為 64),小於則執行 `do_sqr_base` 函式(一般的平方運算),否則執行 `do_sqr_karatsuba` 函式
```c
/* c = a^2 */
void bn_sqr(bn *dest, const bn *src)
{
// int d = a->size * 2;
bn *tmp;
/* make it work properly when c == a */
if (dest == src) {
tmp = dest; // save c
dest = bn_alloc(src->size * 2);
} else {
tmp = NULL;
for (int i = 0; i < dest->size; i++)
dest->number[i] = 0; // clean up c
bn_resize(dest, src->size * 2);
}
if (src->size < KARATSUBA_SQR_THRESHOLD) {
do_sqr_base(src->number, src->size, dest->number);
} else {
do_sqr_karatsuba(src->number, src->size, dest->number);
}
if (!dest->number[dest->size - 1] && dest->size > 1) // trim
bn_resize(dest, dest->size - 1);
if (tmp) {
bn_swap(tmp, dest);
bn_free(dest);
}
}
```
```c
void do_sqr_base(const bn_data *src, bn_data ssize, bn_data *dest)
{
bn_data *dp = dest + 1;
const bn_data *sp = src;
bn_data size = ssize - 1;
for (int i = 0; i < size; i++) {
/* calc the (ab bc bc) part */
dp[size - i] = _mult_partial(&sp[i + 1], size - i, sp[i], dp);
dp += 2;
}
/* Double it */
for (int i = 2 * ssize - 1; i > 0; i--)
dest[i] = dest[i] << 1 | dest[i - 1] >> (DATA_BITS - 1);
dest[0] <<= 1;
/* add the (aa bb cc) part at diagonal line */
dp = dest;
sp = src;
size = ssize;
bn_data carry = 0;
for (int i = 0; i < size; i++) {
bn_data high, low;
__asm__("mulq %3" : "=a"(low), "=d"(high) : "%0"(sp[i]), "rm"(sp[i]));
high += (low += carry) < carry;
high += (dp[0] += low) < low;
carry = (dp[1] += high) < high;
dp += 2;
}
}
```
* `do_sqr_karatsuba` 函式實作與 `do_mult_karatsuba` 基本一樣,因此以下只講不同的地方
* 計算好 $dest = (2^{2n}+2^n)sp1^2 + (2^n + 1)sp0^2$ 之後要計算 $(sp1 - sp0)(sp0 - sp1) = -(sp1 - sp0)^2$
```c
/* (sp1-sp0)(sp0-sp1) = -|sp1-sp0|^2 */
if (bn_cmp(sp1, half_size, sp0, half_size) < 0)
_sub_partial(sp0, sp1, half_size, tmp);
else
_sub_partial(sp1, sp0, half_size, tmp);
```
* 計算好 $(sp1 - sp0)^2$ 後要計算 $dest = dest - (sp1 - sp0)^2$
```c
/* dest[half_size ~ half_size+even_size-1] += -(sp1-sp0)^2 */
carry -= _sub_partial(dest + half_size, tmp1, even_size, dest + half_size);
```
:::
實驗結果如下 (v9 藍線)
![](https://hackmd.io/_uploads/Skl_tN4vh.png)
> 圖中設定的閾值與 bignum 一樣,經實驗驗證放大或縮小閾值並不會顯著提升效能
v9 版本的 call graph (只擷取部份內容)
```shell
# Children Self Command Shared Object Symbol
# ........ ........ ....... .................... ..................................
#
95.39% 0.00% test test [.] main
|
---main
fast_doubling
|
|--60.19%--bn_sqr
| do_sqr_karatsuba
| |
| --59.59%--do_sqr_karatsuba
| |
| --59.00%--do_sqr_karatsuba
| |
| --57.37%--do_sqr_karatsuba
| |
| |--53.33%--do_sqr_karatsuba
| | |
| | |--48.07%--do_sqr_karatsuba
| | | |
| | | |--32.67%--do_sqr_karatsuba
| | | | |
| | | | |--30.30%--do_sqr_base
| | | | |
| | | | --0.59%--__memset_avx2_unaligned_erms
| | | |
| | | --10.05%--do_sqr_base
| | |
| | --3.49%--do_sqr_base
| |
| --1.08%--do_sqr_base
|
--35.21%--bn_mult
do_mult_karatsuba
|
|--34.61%--do_mult_karatsuba
| |
| --34.02%--do_mult_karatsuba
| |
| --33.43%--do_mult_karatsuba
| |
| --31.17%--do_mult_karatsuba
| |
| |--28.95%--do_mult_karatsuba
| | |
| | |--26.10%--do_mult_karatsuba
| | | |
| | | |--20.18%--do_mult_karatsuba
| | | | |
| | | | |--17.21%--do_mult_base
| | | | |
| | | | |--0.59%--__GI___libc_free (inlined)
| | | | |
| | | | --0.59%--__libc_calloc
| | | | _int_malloc
| | | |
| | | |--2.96%--do_mult_base
| | | |
| | | --0.59%--__GI___libc_free (inlined)
| | | _int_free
| | |
| | |--2.26%--do_mult_base
| | |
| | --0.59%--__libc_calloc
| | _int_malloc
| |
| --1.04%--do_mult_base
|
--0.59%--__memset_avx2_unaligned_erms
asm_exc_page_fault
exc_page_fault
do_user_addr_fault
handle_mm_fault
__handle_mm_fault
handle_pte_fault
```
可見使用 Karatsuba 演算法後乘法與平方運算的時間佔比合計從 v7 版本的 99.32% 降低為 95.4% ,但仍然是程式執行時間佔比最高的運算,其中可以看到使用 Karatsuba 演算法後 call graph 出現多次相同函式的遞迴呼叫,如此可以看出 Karatsuba 演算法的實作主要依賴遞迴方式實現
### 改進 `bn_to_string`
原始 `bn_to_string` 的實作原理是不斷將大數除以 10 取餘數,並將大數更新為商數,直到大數為零,此作法時間複雜度為 $O((digit + size) \times log_{10}(x))$ , digit 為大數的二進制位元數、size 為大數的大小、$x$ 為大數,參考 [KYG-yaya573142](https://hackmd.io/@KYWeng/rkGdultSU#bn-%E8%B3%87%E6%96%99%E7%B5%90%E6%A7%8B) 的實作方式從大數的 MSB 起逐位元將數值累加到字串當中,此作法的時間複雜度為 $O(digit \times log_{10}(x))$
```c
/*
* output bn to decimal string
* Note: the returned string should be freed with the free()
*/
char *bn_to_string(const bn *src)
{
// log10(x) = log2(x) / log2(10) ~= log2(x) / 3.322
size_t len = (8 * sizeof(bn_data) * src->size) / 3 + 2;
char *s = (char *) malloc(len);
char *p = s;
memset(p, '0', len - 1);
p[len - 1] = '\0';
/* src.number[0] contains least significant bits
* s[len - 2] contains least significant digit
*/
for (int i = src->size - 1; i >= 0; i--) {
for (bn_data d = MSB_MASK; d; d >>= 1) {
/* binary -> decimal string based on binary presentation */
int carry = !!(d & src->number[i]);
// add carry to p[len-2 .. 0]
for (int j = len - 2; j >= 0; j--) {
p[j] += p[j] - '0' + carry;
carry = (p[j] > '9');
if (carry)
p[j] -= 10;
}
}
}
// skip leading zero
while (p[0] == '0' && p[1] != '\0') {
p++;
}
memmove(s, p, strlen(p) + 1);
return s;
}
```
實驗結果如下 (v1 綠線)
![](https://hackmd.io/_uploads/S17qco_v2.png)
參考 [bignum/format.c](https://github.com/sysprog21/bignum/blob/master/format.c) 的實作方式,先定義 `max_radix` 為 $10^{19}$ ,相當於 `uint64_t` 可表示的數值範圍中最大的 10 的冪的值,每次迴圈藉由將大數除以 `max_radix` 獲得一個 `uint64_t` 可表示的 10 進制的值,再用一般 2 進制轉 10 進制的方法將數值存放到字串當中
:::info
目前看到 [bignum/format.c](https://github.com/sysprog21/bignum/blob/master/format.c) 裡計算 $log_{10}(x)$ 的部份 (即 apm_string_size 函式),當中如果 radix 不為 2 的冪時,結果是回傳 `(radix_sizes[radix] * (size * APM_DIGIT_SIZE)) + 2` ,可以理解為了要無條件進位所以 +1 ,但為什麼會是 +2 呢?
> 目前的猜測是因為要彌補乘以 `radix_sizes[radix]` 所產生的誤差[name=ericlai1021]
:::
:::spoiler 程式碼解析
```c
void bn_fprint(bn_data *sp, bn_data size, FILE *fp)
{
const size_t len = ((radix_size * (size * BN_WSIZE)) + 2) + 1;
char *str = (char *) malloc(len);
char *p = bn_to_string(sp, size, str);
fprintf(fp, "%s\n", p);
free(str);
}
```
* 配置 $log_{10}(x)$ 大小的空間存放轉換後的值,`radix_size` 為表示 1 Byte 的數值所需的 10 進制位數, `BN_WSIZE` 為一個 word 的大小 (單位為 `Byte` ),最後加上一個字元的大小存放字串結尾符號 `\0`
* `bn_to_string` 函式傳入大數的數字陣列開頭指標、大數的 `size` 及輸出的字串的開頭指標,並回傳轉換後的字串開頭指標
`bn_to_string` 函式
* 分成多精度運算與單精度運算,多精度運算會將大數除以 `MAX_RADIX` (即 $10^{19}$) 取得餘數為 `uint64_t` 可表示的大數轉換成 10 進制的數值,將大數更新為商數;接著將餘數透過單精度運算取得 10 進制的個別位數
```c
do {
/* Multi-precision: divide U by largest power of RADIX to fit in
* one apm_digit and extract remainder.
*/
bn_data remainder = bn_ddivi(sp, size, MAX_RADIX);
size -= (sp[size - 1] == 0U);
/* Single-precision: extract K remainders from that remainder,
* where K is the largest integer such that RADIX^K < 2^BITS.
*/
unsigned int i = 0;
do {
bn_data rq = remainder / 10;
bn_data rr = remainder % 10;
*outp++ = radix_chars[rr];
remainder = rq;
if (size == 0 && remainder == 0) /* Eliminate any leading zeroes */
break;
} while (++i < MAX_POWER);
/* Loop until TMP = 0. */
} while (size != 0);
```
`bn_ddivi` 函式
* 此函式傳入大數的數字陣列、大數的 `size` 及 `MAX_RADIX` ,將大數除以 `MAX_RADIX` ,商數更新為新的大數,並回傳餘數
* 為了有效解決除法結果溢位的問題,使用內嵌組合語言 (inline assembly) 指令 `divq` ,此指令的輸入為被除數的高位與低位以及除數,並輸出商數與餘數
```c
static bn_data bn_ddivi(bn_data *sp, bn_data size, bn_data div)
{
if (div == 1)
return 0;
if (!size)
return 0;
bn_data s1 = 0;
sp += size;
do {
bn_data s0 = *--sp;
bn_data q, r;
if (s1 == 0) {
q = s0 / div;
r = s0 % div;
} else {
digit_div(s1, s0, div, q, r); // use inline assembly
}
*sp = q;
s1 = r;
} while (--size);
return s1;
}
```
:::
實驗結果如圖 (v2 綠線)
![](https://hackmd.io/_uploads/ryedY2tv2.png)
放上與 bignum 的比較
![](https://hackmd.io/_uploads/r1E19hKvn.png)
> 使用 python 撰寫一驗證程式 [verify.py](https://gist.github.com/ericlai1021/c8ae9787ab97d29c41cd81401ea4683c) ,確保程式至少可以正確計算到第一百萬項
```shell
$ time ./test 1000000
Fib[1000000]
real 0m3.331s
user 0m3.331s
sys 0m0.000s
$ python3 verify.py
Please input a number: 1000000
Fib[1000000] -Pass
congratulations, you pass all test!!!
```
> 先附上最終的 [程式碼](https://gist.github.com/ericlai1021/551dc851c8bc80a63b1fda547ea75116) ,後續會逐步更新至 GitHub
## 減少 `copy_to_user` 傳送的資料量
原先對 fibdrv 呼叫 `read(fd, buf, size)` 時,會在 kernel space 將計算好的大數結構體 `bn` 轉換成十進制表示存放於字串當中,並透過 `copy_to_user` 將該字串從 kernel space 複製到 user space,但由於轉換後的字串每個字元只會存放 0~9 其中一個數值,因此光是傳遞一個字元就會浪費 4 位元的空間,為了減少空間的浪費,可以將大數轉換成十進制的操作 (即 `bn_to_string` 函式) 搬到 user space 來執行,讓 `copy_to_user` 直接傳遞大數結構體當中的二進制數值。
參考 [作業說明](https://hackmd.io/@sysprog/linux2023-fibdrv/%2F%40sysprog%2Flinux2023-fibdrv-b#%E8%A8%88%E7%AE%97%E4%BB%A5-2-%E7%82%BA%E5%BA%95%E6%95%B8%E7%9A%84%E5%B0%8D%E6%95%B8) 當中的實作,先計算大數的 leading zeros ,接著呼叫 `copy_to_user` 時不傳送全為 0 的位元組。
```c
static size_t my_copy_to_user(const bn *src, char __user *buf)
{
int lzbyte = bn_clz(src) >> 3;
size_t size = sizeof(bn_data) * src->size - lzbyte;
kt = ktime_get();
size_t sz = copy_to_user(buf, src->number, size);
kt = ktime_sub(ktime_get(), kt);
return size;
}
```
`bn_clz` 函式搭配 GCC 內建函式 `__builtin_clzll` 來計算大數的 leading zero bits 數量,其中 `>> 3` 右移操作計算 leading zeros 的位元組數量。針對 little-endian 架構,非零的位元組會被存在較低的記憶體位址,因此呼叫 `copy_to_user` 時只需要傳送 `數字陣列總 byte 數 - leading zero byte` 就可以不傳送全為 0 的位元組。
將複製的 byte 數量作為 `read` 的回傳值傳回 user
```c
/* calculate the fibonacci number at given offset */
static ssize_t fib_read(struct file *file,
char *buf,
size_t size,
loff_t *offset)
{
bn *fib = bn_alloc(1);
bn_fib_fast(fib, *offset);
size_t sz = my_copy_to_user(fib, buf);
bn_free(fib);
return sz; // return number of bytes that could not be copied
}
```
在 user space 中使用 `memcpy` 將 `buf` 字串內容複製到數字陣列後就可以執行 `bn_to_string` 函式將此數字陣列轉換成十進制字串表示
```c
...
lseek(fd, i, SEEK_SET);
size_t sz = read(fd, buf, 20900);
size_t size = (sz >> 3) + ((sz << 61) > 0);
uint64_t *number = malloc(sizeof(uint64_t) * size);
memcpy(number, buf, sz);
char *p = bn_to_string(number, size);
...
```
實驗結果如下 (計算到第 10 萬項,時間為 `copy_to_user` 函式執行的時間)
![](https://hackmd.io/_uploads/SkcFQyMOn.png)
* kernel 表示 `bn_to_string` 函式執行在 kernel space ,因此 `copy_to_user` 會傳送轉換後的字串
* user 表示 `bn_to_string` 函式執行在 user space , `copy_to_user` 會直接傳送大數的數字陣列
* 結果看出 `copy_to_user` 直接傳送大數的數字陣列確實可以有效節省空間
## 使用 hashtable 儲存已計算的 Fibonacci 數
> 參考資料: [Linux 核心的 hash table 實作](https://hackmd.io/@sysprog/linux-hashtable) 、 [chiangkd 同學的共筆](https://hackmd.io/@chiangkd/2023spring-fibdrv#%E4%BD%BF%E7%94%A8-hashtable-%E7%B4%80%E9%8C%84%E4%BB%A5%E8%A8%88%E7%AE%97%E9%81%8E%E7%9A%84%E5%80%BC)
### 初步引入 hashtable
預期引入 Linux 核心的 `hlist` 系列 API 儲存已經計算過的值,目前實作以 $Fib(n)$ 中的 n 作為 key
Linux 核心的 hash table 實作中,用以處理 hash 數值碰撞的 `hlist_node`:
```c
struct hlist_node {
struct hlist_node *next, **pprev;
};
```
* `pprev` 宣告成指標的指標為了方便之後刪除節點的操作,詳細解說請參閱 [Linux 核心的 hash table 實作](https://hackmd.io/@sysprog/linux-hashtable)
示意圖如下 :
```graphviz
digraph G {
rankdir=LR;
node[shape=record];
map [label="hlist_head.first |<ht0> |<ht1> |<ht2> |<ht3> |<ht4> |<ht5> |<ht7> |<ht8> "];
node[shape=none]
null1 [label=NULL]
null2 [label=NULL]
subgraph cluster_1 {
style=filled;
color=lightgrey;
node [shape=record];
hn1 [label="hlist_node | {<prev>pprev | <next>next}"];
label="hash_key 1"
}
/*
subgraph cluster_2 {
style=filled;
color=lightgrey;
node [shape=record];
hn2 [label="hlist_node | {<prev>pprev | <next>next}"];
label="hash_key 2"
}
*/
subgraph cluster_3 {
style=filled;
color=lightgrey;
node [shape=record];
hn3 [label="hlist_node | {<prev>pprev | <next>next}"];
label="hash_key 3"
}
map:ht1 -> hn1
hn1:next -> NULL
// hn2:next -> null1
// hn2:prev:s -> hn1:next:s
map:ht5 -> hn3
hn3:next -> null2
hn1:prev:s -> map:ht1
hn3:prev:s -> map:ht5
}
```
新增一個自定義結構 `hdata_node` 嵌入 `hlist_node` 並包含一個指向大數結構體 `bn` 的指標,用以儲存 value
```c
typedef struct _hdata_node {
bn *data;
struct hlist_node list;
} hdata_node;
```
```graphviz
digraph G {
rankdir=LR;
node[shape=record];
map [label="hlist_head.first |<ht0> |<ht1> |<ht2> |<ht3> |<ht4> |<ht5> |<ht7> |<ht8> "];
node[shape=none]
null1 [label=NULL]
null2 [label=NULL]
subgraph cluster_A {
subgraph cluster_1 {
style=filled;
color=lightgrey;
node [shape=record];
hn1 [label="hlist_node | {<prev>pprev | <next>next}"];
label="hash_key 1"
}
subgraph cluster_bn {
style=filled;
color=yellow;
node [shape=record]
bn1 [label="{number|size}"]
label=bn
}
label="hdata_node"
}
subgraph cluster_3 {
style=filled;
color=lightgrey;
node [shape=record];
hn3 [label="hlist_node | {<prev>pprev | <next>next}"];
label="hash_key 3"
}
map:ht1 -> hn1
hn1:next -> null1
// hn2:next -> null1
// hn2:prev:s -> hn1:next:s
map:ht5 -> hn3
hn3:next -> null2
hn1:prev:s -> map:ht1
hn3:prev:s -> map:ht5
}
```
將 $Fib(n)$ 當中的 n 作為 hashtable 的 key ,value 則指向計算後的 `bn` 結構體
```c
/* calculate the fibonacci number at given offset */
static ssize_t fib_read(struct file *file,
char *buf,
size_t size,
loff_t *offset)
{
bn *fib = NULL;
int key = (int) *offset;
/* hashtable method*/
kt = ktime_get();
if (is_in_ht(offset)) {
printk(KERN_INFO "find offset = %d\n", key);
fib = hlist_entry(htable[key].first, hdata_node, list)->data;
} else {
fib = bn_alloc(1);
dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
if (dnode == NULL)
printk("kcalloc failed \n");
bn_fib_fast(fib, *offset);
dnode->data = fib;
INIT_HLIST_NODE(&dnode->list);
hlist_add_head(&dnode->list, &htable[key]); // add to hash table
}
kt = ktime_sub(ktime_get(), kt);
size_t sz = my_copy_to_user(fib, buf);
return sz;
}
```
呼叫 `fib_read` 函式時先判斷 hashtable 該 key 值是否有值,若有值則直接從 hashtable 中取用,因為這裡 hashtable 設計為將 $Fib(n)$ 的 n 作為 key 值,所以不會發生 collision。
`is_in_ht(offset)` 函式判斷 hashtable 中的第 offset 個位址當中是否有值
```c
static int is_in_ht(loff_t *offset)
{
int key = (int) *(offset);
if (hlist_empty(&htable[key])) {
printk(KERN_INFO "No find in hash table\n");
return 0; /* no in hash table */
}
return 1;
}
```
* `hlist_empty` 函式判斷該 key 值對應到的 list 是否有值
執行 `client` ,測試程式為先計算 $Fib(0)$ 至 $Fib(100)$,接著再反著計算回來,使用 `printk` 測試是否有正確運行
```shell
[ 1305.396063] No find in hash table
[ 1305.396097] No find in hash table
[ 1305.396107] No find in hash table
...
[ 1305.397508] No find in hash table
[ 1305.397525] find offset = 100
[ 1305.397532] find offset = 99
[ 1305.397539] find offset = 98
...
[ 1305.398260] find offset = 0
```
在使用 `rmmod` 卸載 `fibdrv` 模組時會呼叫帶有 `__exit` macro 的函式, kernel 會將這個函式放入 read-only 的 `__exit` section 中,可參閱 [linux/init.h](https://github.com/torvalds/linux/blob/master/include/linux/init.h)
因此,需要在帶有 `__exit` macro 的 `exit_fib_dev` 函式當中對 hashtable 進行記憶體釋放
```c
static void __exit exit_fib_dev(void)
{
release_memory();
printk(KERN_INFO "successful release memory.\n");
...
}
```
```shell
...
[ 1305.398254] find offset = 1
[ 1305.398260] find offset = 0
[ 1331.330755] successful release memory.
```
`release_memory` 函式當中使用 `hlist_for_each_entry_safe` 走訪hashtable 當中的 list 的每個節點,將該節點中的大數結構體釋放並將該節點從 list 當中移除
```c
static void release_memory(void)
{
struct hlist_node *n = NULL;
/* go through and free hashtable */
for (int i = 0; i < MAX_LENGTH; i++) {
hlist_for_each_entry_safe(dnode, n, &htable[i], list)
{
bn_free(dnode->data);
hlist_del(&dnode->list);
kfree(dnode);
}
}
}
```
實驗結果如下
> 量測計算 $Fib(0)$ 至 $Fib(100000)$,並從 $Fib(100000)$ 至 $Fib(0)$ 的時間
![](https://hackmd.io/_uploads/rkwMPkGdh.png)
* 圖中的 x 座標不是 $Fib(n)$ ,而是第 n 次的時間量測,x 座標 `100001` 至 `200002` 為由 $Fib(100000)$ 至 $Fib(0)$ 的時間測量
在前半部份,也就是計算 $Fib(0)$ 至 $Fib(100000)$ 時,因為 hashtable 當中都沒有值,所以會如同原先的實作一樣,而後半部份因為 hashtable 中都已經有儲存計算過的值,所以會直接從 hashtable 中取用。
將後半部份的資料獨立出來看,可以看到整體趨勢已經是常數時間了
![](https://hackmd.io/_uploads/HkNmDJMOn.png)
### 引入 hashtable 機制至 fast doubling 演算法加速大數運算
概念如同上述作法,將 hashtable 機制引入到 `bn_fib_fast` 函式內,考慮 fast doubling 的特性,紀錄第 N 個和第 2N 個 Fibonacci 數
* 若 n 小於 2 ,則直接讓大數等於 n 並將其加入到 hashtable 當中, key 與 value 皆為 n
```c
if (n < 2) { // Fib(0) = 0, Fib(1) = 1
dest->number[0] = n;
dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
if (dnode == NULL)
printk("kcalloc failed \n");
dnode->data = dest;
INIT_HLIST_NODE(&dnode->list);
key = n;
hlist_add_head(&dnode->list, &htable[key]); // add to hash table
}
```
* 若 n 大於等於 2 ,則執行 fast doubling 演算法,主要可以分為兩部份的計算,一部分為計算 $F(2n)$ ,另一部份為計算 $F(2n + 1)$ ,並將其計算結果存於 hashtable 中
```c
else {
bn *tmp = NULL;
bn *b = bn_alloc(1);
tmp = hlist_entry(htable[0].first, hdata_node, list)->data; // extrct F(0)
printk(KERN_INFO "find offset = %d\n", 0);
bn_cpy(b, tmp); // copy F(0) to b
tmp = hlist_entry(htable[1].first, hdata_node, list)->data; // extrct F(1)
printk(KERN_INFO "find offset = %d\n", 1);
bn_cpy(dest, tmp); // copy F(1) to dest
/* F(2n - 1) = F(n)^2 + F(n - 1)^2
* F(2n) = F(n) * (2F(n - 1) + F(n))
*/
bn *t1 = bn_alloc(1);
int nbits = 32 - __builtin_clz(n);
key = 1;
for (int i = nbits - 2; i >= 0; i--) {
key <<= 1; // key = F(2n)
if (is_in_ht(key)) {
printk(KERN_INFO "find offset = %d\n", key);
tmp = hlist_entry(htable[key].first, hdata_node, list)->data; // extract F(2n)
bn_cpy(dest, tmp); // copy F(2n) to dest
tmp = hlist_entry(htable[key - 1].first, hdata_node, list)->data; // extract F(2n - 1)
bn_cpy(b, tmp); // copy F(2n - 1) to b
} else {
bn_lshift(t1, b, 1); // t1 = F(n - 1) * 2
bn_add(t1, dest, t1); // t1 = 2F(n - 1) + F(n)
bn_mult(dest, t1, t1); // t1 = F(n) * (2F(n - 1) + F(n)), now is F(2n)
bn_sqr(b, b); // b = F(n - 1)^2
bn_sqr(dest, dest); // dest = F(n)^2
bn_add(dest, b, b); // b = F(n)^2 + F(n - 1)^2, now is F(2n - 1)
bn_swap(dest, t1); // dest = F(2n)
/* add F(2n) to hashtable */
bn *tmp1 = bn_alloc(1);
bn_cpy(tmp1, dest);
dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
if (dnode == NULL)
printk("kcalloc failed \n");
dnode->data = tmp1;
INIT_HLIST_NODE(&dnode->list);
hlist_add_head(&dnode->list, &htable[key]); // add to hash table
}
if (n & (1U << i)) {
key++;
if (is_in_ht(key)) {
printk(KERN_INFO "find offset = %d\n", key);
bn_cpy(b, dest); // copy F(2n) to b
tmp = hlist_entry(htable[key].first, hdata_node, list)->data; // extract F(2n + 1)
bn_cpy(dest, tmp); // copy F(2n + 1) to dest
} else {
bn_swap(dest, b); // b = F(2n)
bn_add(dest, b, dest); // dest = F(2n + 1)
/* add F(2n + 1) to hashtable */
bn *tmp2 = bn_alloc(1);
bn_cpy(tmp2, dest);
dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
if (dnode == NULL)
printk("kcalloc failed \n");
dnode->data = tmp2;
INIT_HLIST_NODE(&dnode->list);
hlist_add_head(&dnode->list, &htable[key]); // add to hash table
}
}
}
dest = hlist_entry(htable[key].first, hdata_node, list)->data;
bn_free(t1);
bn_free(b);
}
```
* 迴圈每一輪都會先判斷 key 為 $2n$ 與 $2n + 1$ 是否已存在 hashtable 裡,若已有值,則直接從 hashtable 當中取出,否則才會執行一般 fast doubling 演算法對 $F(2n)$ 與 $F(2n + 1)$ 的計算
先用 `dmesg` 觀察執行過程
```shell
[90529.842755] find offset = 0
[90529.842763] find offset = 1
[90529.842767] No find in hash table // calculate F(2)
[90529.842782] find offset = 0
[90529.842785] find offset = 1
[90529.842788] find offset = 2
[90529.842790] No find in hash table // calculate F(3)
[90529.842804] find offset = 0
[90529.842807] find offset = 1
[90529.842810] find offset = 2
[90529.842813] No find in hash table // calculate F(4)
[90529.842827] find offset = 0
[90529.842830] find offset = 1
[90529.842832] find offset = 2
[90529.842835] find offset = 4
[90529.842838] No find in hash table // calculate F(5)
...
[90529.845965] find offset = 0
[90529.845968] find offset = 1
[90529.845971] find offset = 2
[90529.845973] find offset = 3
[90529.845976] find offset = 6
[90529.845979] find offset = 12
[90529.845982] find offset = 24
[90529.845984] find offset = 25
[90529.845987] find offset = 50
[90529.845990] No find in hash table // calculate F(100)
```
實驗結果如下 (計算至第十萬項)
![](https://hackmd.io/_uploads/SJ3N2gm_n.png)