owned this note
owned this note
Published
Linked with GitHub
# Linux 核心專題: 改進 fibdrv
> 執行人: ctfish7063
> [GitHub](https://github.com/ctfish7063/fibdrv)
> [專題解說影片](https://youtu.be/svybO4_juaM)
:::success
:question: 提問清單
* ?
:::
## 任務簡述
依據 [fibdrv](https://hackmd.io/@sysprog/linux2023-fibdrv) 作業規範,繼續投入 Linux 核心模組和相關程式的開發。
### 大數處理
原先的 `fibdrv` 因為 `uint64_t` 的限制,僅能計算至 $fib(92)$,需要提供新的資料結構以進行計算和儲存之用。
#### 資料結構
參考[作業說明 - 基於 list_head 的大數運算](https://hackmd.io/@sysprog/linux2023-fibdrv-d#%E5%9F%BA%E6%96%BC-list_head-%E7%9A%84%E5%A4%A7%E6%95%B8%E9%81%8B%E7%AE%97), 資料結構使用 linked-list,並以 linux 的 `list.h` 進行實作:
```cpp
/**
* bn_head - store the head of bn list
* @size: size of the list
* @sign: sign of the bn
* @list: list_head of the list
*/
typedef struct {
size_t size;
struct list_head list;
} bn_head;
/**
* bn_node - store a node of bn list
* The value should be within 10^19
* @val: value of the node
* @list: list_head of the node
*/
typedef struct {
uint64_t val;
struct list_head list;
} bn_node;
```
`bn_head` 僅儲存該鏈結串列的長度 $size$,而 `bn_node` 則以 `uint64_t` 的格式儲存資料。若以一個鏈結串列儲存一個大數,則可以將該鏈結串列看成一個有 $64 * size$ bits 的數。
經過測試,此資料結構可計算至 $fib(1000000)$ (以 [wolfarmalpha](https://www.wolframalpha.com/input/?i=fibonacci+10000) 作為基準)
```shell
fibdrv$ sudo ./client
malloc size: 10848
lseek to 1000000
str size: 208989
fib[1000000] in 1839909461 ns:
195328212870775773163201494759625633244354299659187339695340519457162525788701569476664198763415014612887952433522023608462551091201956023374401543811519663615691996
...
```
#### `bn_new`
以 [Binet formula](https://en.wikipedia.org/wiki/Fibonacci_sequence#Relation_to_the_golden_ratio) 可以近似出 $fib(n)$ 的值,將其取 $\frac{log2(fib(n))}{64}$ 即可計算所需要的 `bn_node` 的數量,可在一開始便配置好記憶體。
```c
#define DIVISOR 100000
#define LOG2PHI 69424
#define LOG2SQRT5 116096
static inline struct list_head *bn_new(size_t n)
{
unsigned int list_len = n > 1 ? (n * LOG2PHI - LOG2SQRT5) / DIVISOR / 64 + 1 : 1;
struct list_head *head = bn_alloc();
for (; list_len; list_len--) {
bn_newnode(head, 0);
}
return head;
}
```
#### `bn_add`
`bn` 的加法`bn_add` 會將第二個 `bn` 的值加至第一個 `bn` ,以 `__bn_add` 進行操作:
```c
void bn_add(struct list_head *a, struct list_head *b)
{
__bn_add(a, b);
}
```
`__bn_add` 則為 `bn` 加法的實作:
```c
#define bn_node_val(node) (list_entry(node, bn_node, list)->val)
void __bn_add(struct list_head *shorter, struct list_head *longer)
{
int carry = 0;
bn_node *node;
struct list_head *longer_cur = longer->next;
list_for_each_entry (node, shorter, list) {
uint64_t tmp = node->val;
node->val += bn_node_val(longer_cur) + carry;
carry = U64_MAX - tmp >= bn_node_val(longer_cur) + carry ? 0 : 1;
longer_cur = longer_cur->next;
if (longer_cur == longer) {
break;
}
}
while (longer_cur != longer) {
uint64_t tmp = bn_node_val(longer_cur);
bn_newnode(shorter, bn_node_val(longer_cur) + carry);
carry = U64_MAX - tmp >= carry ? 0 : 1;
longer_cur = longer_cur->next;
}
while (carry) {
if (bn_size(shorter) > bn_size(longer)) {
uint64_t tmp = bn_node_val(node->list.next);
bn_node_val(node->list.next) += carry;
carry = U64_MAX - tmp >= carry ? 0 : 1;
} else {
bn_newnode(shorter, carry);
break;
}
}
}
```
#### `bn_sub`
由於在計算 Fibonacci 數時不會出現負數,`bn_sub` 的實作是將兩數比較後將較大的 `bn` 減去較小的 `bn`:
```c
void bn_sub(struct list_head *a, struct list_head *b)
{
int cmp = bn_cmp(a, b);
if (cmp >= 0) {
__bn_sub(a, b);
} else {
__bn_sub(b, a);
}
}
void __bn_sub(struct list_head *more, struct list_head *less)
{
int carry = 0;
bn_node *node;
struct list_head *less_cur = less->next;
list_for_each_entry (node, more, list) {
uint64_t tmp =
(less_cur == less) ? carry : bn_node_val(less_cur) + carry;
if (node->val >= tmp && likely(bn_node_val(less_cur) != U64_MAX - 1)) {
node->val -= tmp;
carry = 0;
} else {
node->val += (U64_MAX - tmp) + 1;
carry = 1;
}
if (less_cur != less) {
less_cur = less_cur->next;
}
}
bn_node *last = list_last_entry(more, bn_node, list);
if (last->val == 0 && likely(!list_is_singular(more))) {
bn_size(more)--;
list_del(&last->list);
kfree(last);
}
}
```
#### `bn_mul`
`bn_mul` 會將兩個 `bn` 得相乘結果儲存在另外的 `bn` 中。此實作僅將每個 `bn_node` 分別作相乘並將結果加至對應的位置中,時間複雜度為 $O(m \times n)$:
```c
void bn_mul(struct list_head *a, struct list_head *b, struct list_head *c)
{
bn_node *node;
// zeroing c
list_for_each_entry (node, c, list) {
node->val = 0;
}
bn_node *node_a, *node_b;
struct list_head *base = c->next;
list_for_each_entry (node_a, a, list) {
uint64_t carry = 0;
struct list_head *cur = base;
list_for_each_entry (node_b, b, list) {
uint128_t tmp = (uint128_t) node_b->val * (uint128_t) node_a->val;
uint64_t n_carry = tmp >> 64;
if (U64_MAX - bn_node_val(cur) < tmp << 64 >> 64)
n_carry++;
bn_node_val(cur) += tmp;
if (U64_MAX - bn_node_val(cur) < carry)
n_carry++;
bn_node_val(cur) += carry;
carry = n_carry;
cur = cur->next;
}
while (carry) {
if (cur == c) {
bn_newnode(c, carry);
break;
}
uint64_t tmp = bn_node_val(cur);
bn_node_val(cur) += carry;
carry = U64_MAX - tmp >= carry ? 0 : 1;
cur = cur->next;
}
base = base->next;
}
}
```
#### `bn_lshift` 和 `bn_rshift`
由於 `bn_node` 中所儲存的資料為 `uint64_t`, 一次左移或右移的操作最多僅能移動 63 bits,實作中會暫存所要移動的 bits 數量,以每 63 bits 作為單位分次移動:
```cpp
void bn_lshift(struct list_head *head, int bit)
{
int tmp = bit;
for (; tmp > 64; tmp -= 63) {
__bn_lshift(head, 63);
}
__bn_lshift(head, tmp);
}
void __bn_lshift(struct list_head *head, int bit)
{
uint64_t carry = 0;
bn_node *node;
list_for_each_entry (node, head, list) {
uint64_t tmp = node->val;
node->val <<= bit;
node->val |= carry;
carry = tmp >> (64 - bit);
}
if (carry) {
bn_newnode(head, carry);
}
}
void bn_rshift(struct list_head *head, int bit)
{
int tmp = bit;
for (; tmp > 64; tmp -= 63) {
__bn_rshift(head, 63);
}
__bn_rshift(head, tmp);
}
void __bn_rshift(struct list_head *head, int bit)
{
uint64_t carry = 0;
bn_node *node;
list_for_each_entry_reverse(node, head, list)
{
uint64_t tmp = node->val;
node->val >>= bit;
node->val |= carry;
carry = tmp << (64 - bit);
}
if (bn_last_val(head) == 0) {
bn_pop(head);
}
}
```
#### `bn_to_array`
為了減少 `copy_to_user` 所需複製的大小,在複製前會將 `bn` 中的資料儲存至陣列:
```cpp
uint64_t *bn_to_array(struct list_head *head)
{
bn_clean(head);
uint64_t *res = kmalloc(sizeof(uint64_t) * bn_size(head), GFP_KERNEL);
int i = 0;
bn_node *node;
list_for_each_entry (node, head, list) {
res[i++] = node->val;
}
return res;
}
```
### 改進 `fibdrv` 效能
#### 加速運算 -- 使用 fast doubling
Fibonacci 數的定義為:
$F(0)=0, F(1)=1$
$F(n) = F(n-1) + F(n-2)$
其運算之時間複雜度為 $O(n)$,程式實作如下:
```cpp
static inline size_t fib_sequence_naive(long long k, uint64_t **fib)
{
if (unlikely(k < 0)) {
return 0;
}
// return fib[n] without calculation for n <= 2
if (unlikely(k <= 2)) {
*fib = kmalloc(sizeof(uint64_t), GFP_KERNEL);
(*fib)[0] = !!k;
return 1;
}
BN_INIT_VAL(a, 1, 0);
BN_INIT_VAL(b, 1, 1);
for (int i = 2; i <= k; i++) {
bn_add(a, b);
XOR_SWAP(a, b);
}
*fib = bn_to_array(b);
size_t ret = bn_size(b);
bn_free(a);
bn_free(b);
return ret;
}
```
而[`fast doubling`](https://www.nayuki.io/page/fast-fibonacci-algorithms) 的方法,可以將時間複雜度減少至 $O(log(n))$,以下則參考了作業解說中的 [Bottom-up 方法](https://hackmd.io/@sysprog/linux2023-fibdrv/%2F%40sysprog%2Flinux2023-fibdrv-d#Bottom-up-Fast-Doubling)進行實作:
```cpp
// fast doubling
static inline void fast_doubling(struct list_head *fib_n0,
struct list_head *fib_n1,
struct list_head *fib_2n0,
struct list_head *fib_2n1)
{
// fib(2n+1) = fib(n)^2 + fib(n+1)^2
// use fib_2n0 to store the result temporarily
bn_mul(fib_n0, fib_n0, fib_2n1);
bn_mul(fib_n1, fib_n1, fib_2n0);
bn_add(fib_2n1, fib_2n0);
// fib(2n) = fib(n) * (2 * fib(n+1) - fib(n))
bn_lshift(fib_n1, 1);
bn_sub(fib_n1, fib_n0);
bn_mul(fib_n1, fib_n0, fib_2n0);
}
static inline size_t fib_sequence(long long k, uint64_t **fib)
{
if (unlikely(k < 0)) {
return 0;
}
// return fib[n] without calculation for n <= 2
if (unlikely(k <= 2)) {
*fib = kmalloc(sizeof(uint64_t), GFP_KERNEL);
(*fib)[0] = !!k;
return 1;
}
// starting from n = 1, fib[n] = 1, fib [n+1] = 1
uint8_t count = 63 - CLZ(k);
BN_INIT_VAL(a, 0, 1);
BN_INIT_VAL(b, 1, 1);
BN_INIT(c, 0);
BN_INIT(d, 0);
int n = 1;
for (uint8_t i = count; i-- > 0;) {
fast_doubling(a, b, c, d);
if (k & (1LL << i)) {
bn_add(c, d);
XOR_SWAP(a, d);
XOR_SWAP(b, c);
n = 2 * n + 1;
} else {
XOR_SWAP(a, c);
XOR_SWAP(b, d);
n = 2 * n;
}
}
*fib = bn_to_array(a);
size_t res = bn_size(a);
bn_free(a);
bn_free(b);
bn_free(c);
bn_free(d);
return res;
}
```
將輸出時間使用 gnuplot 作圖,可以看出兩者之間的耗時相差甚巨(僅使用 `ktime` 測量上述函式所用時間):
![](https://hackmd.io/_uploads/HJGzNeLLh.png)
#### 減少 `copy_to_user` 的傳送量
`bn_to_array` 會將 `bn` 轉換為 `uint64_t` 的陣列,若單純使用 `copy_to_user` 複製該陣列, 在其儲存的元素所使用的空間小於 64 bit 的情況下將會多複製了不必要的空間。
針對 little-endian 架構,非零的位元組會被存在較低的記憶體位址。以 $fib(100)$ 為例,需要兩個 `uint64_t` 來儲存,非零的位元組數為 9 個:
```shell
$ sudo ./client
| 00 | 01 | 02 | 03 | 04 | 05 | 06 | 07 | 08 | 09 | 10 | 11 | 12 | 13 | 14 | 15 |
| c3 | bf | 94 | c5 | a7 | 76 | db | 33 | 13 | 00 | 00 | 00 | 00 | 00 | 00 | 00 |
fib[100]: 354224848179261915075
```
參考作業說明中的方法,使用 `gcc` 內建的 `__builtin_clzll` 計算陣列最後一元素的 leading zeros 之後僅從陣列複製剩餘的非零位元組(以上述例子來說便是 9 個):
```cpp
static size_t my_copy_to_user(char *buf, uint64_t *src, size_t size)
{
size_t lbytes = src[size - 1] ? CLZ(src[size - 1]) >> 3 : 7;
size_t i = size * sizeof(uint64_t) - lbytes;
printk(KERN_INFO "fibdrv: total %zu bytes, copy_to_user %zu bytes",
size * sizeof(uint64_t), i);
return copy_to_user(buf, src, i);
}
```
其中 `lbytes` 為避免 `src` 為僅一元素 0 的狀況,須額外判斷並保留至少一位元組。
可以用 `dmesg` 指令確認所計算之位元組數結果:
```shell
$ dmesg | grep fibdrv
[1306390.359899] fibdrv: reading on offset 100
[1306390.359925] fibdrv: total 16 bytes, copy_to_user 9 bytes
```
在 `client.c` 中可初始化 `buf` 為 uint64_t 的陣列並傳入 `read` 作為複製的目的地,在輸出時將其轉換為字串即可,程式碼如下:
```cpp
char *bn_2_string(uint64_t *head, int head_size, uint64_t n)
{
//log10(fib(n)) = nlog10(phi) - log10(5)/2
double tmp = n * 0.20898764025 - 0.34948500216;
size_t size = n > 1 ? (size_t)tmp + 2 : 2;
printf("str size: %zu\n", size);
char *res = malloc(sizeof(char) * size);
res[--size] = '\0';
if (n < 3) {
res[0] = !!head[0] + '0';
return res;
}
for (int i = size; --i >= 0;) {
uint128_t tmp = 0;
for (int j = head_size; --j >= 0;) {
tmp <<= 64;
tmp |= head[j];
head[j] = tmp / 10;
tmp %= 10;
}
res[i] = tmp + '0';
}
return res;
}
```
### 實驗
#### 實驗環境設定
參考 [yanjiew](https://hackmd.io/@yanjiew/linux2023q1-fibdrv#%E7%B3%BB%E7%B5%B1%E7%92%B0%E5%A2%83%E8%A8%AD%E5%AE%9A) 探討系統環境的設定,使用 `cset` 進行 cpu 的獨立:
```shell
$ sudo cset set -c 0-1 isolated
$ sudo cset set -c 2-7 others
$ sudo sh -c "echo 0 > /cpusets/isolated/sched_load_balance"
$ sudo cset proc -m root others
```
將執行緒指定於獨立出來的 cpu 執行:
```shell
$ sudo cset proc -e isolated -- sh -c './test > data.txt'
```
#### 統計方法
![](https://hackmd.io/_uploads/BJvPcPPI2.png)
參考作業說明中的 [python script](https://hackmd.io/@sysprog/linux2023-fibdrv/%2F%40sysprog%2Flinux2023-fibdrv-c#%E7%94%A8%E7%B5%B1%E8%A8%88%E6%89%8B%E6%B3%95%E5%8E%BB%E9%99%A4%E6%A5%B5%E7%AB%AF%E5%80%BC) 並引入 [scripts/preprocess.py](https://github.com/ctfish7063/fibdrv/blob/master/scripts/preprocess.py)中,假設資料分佈為自然分佈,將兩個標準差之外(即 $95%$ 的信賴區間)的數據去除後計算其平均值並作圖, 程式碼如下:
```python
def outlier_filter(datas, threshold = 2):
datas = np.array(datas)
if datas.std() == 0:
return datas
z = np.abs((datas - datas.mean()) / datas.std())
return datas[z < threshold]
def data_processing(data_set, n):
catgories = data_set[0].shape[0]
samples = data_set[0].shape[1]
final = np.zeros((catgories, samples))
if np.isnan(data_set).any():
print("Warning: NaN detected in data set")
for c in range(catgories):
for s in range(samples):
final[c][s] = \
outlier_filter([data_set[i][c][s] for i in range(n)]).mean()
return final
```
#### 效能分析
## TODO: 紀錄閱讀作業說明中所有的疑惑
閱讀 [fibdrv](https://hackmd.io/@sysprog/linux2023-fibdrv) 作業規範,包含「作業說明錄影」和「Code Review 錄影」,本著「誠實面對自己」的心態,在本頁紀錄所有的疑惑,並與授課教師預約討論。
過程中,彙整 [Homework3](https://hackmd.io/@sysprog/linux2023-homework3) 學員的成果,挑選至少三份開發紀錄,提出值得借鏡之處,並重現相關實驗。
### Schönhage–Strassen Algorithm
**Q:** 依據作業說明中的解釋, 此演算法是將大數分成小數字後將小數們構成的向量線性捲積最後將其進位,此算法看似跟長乘法相似,不知差異為何?
**A:** 線性卷積可使用 FFT 和 iFFT 進行計算;在考量定義域的情況下,也可以使用[數論轉換](https://en.wikipedia.org/wiki/Discrete_Fourier_transform_over_a_ring)在整數環上計算。
## TODO: 回覆「自我檢查清單」
回答「自我檢查清單」的所有問題,需要附上對應的參考資料和必要的程式碼,以第一手材料 (包含自己設計的實驗) 為佳
## TODO: 以 [sysprog21/bignum](https://github.com/sysprog21/bignum) 為範本,實作有效的大數運算
理解其中的技巧並導入到 fibdrv 中,並留意以下:
* 在 Linux 核心模組中,可用 ktime 系列的 API;
* 在 userspace 可用 clock_gettime 相關 API;
* 善用統計模型,除去極端數值,過程中應詳述你的手法
* 分別用 gnuplot 製圖,分析 Fibonacci 數列在核心計算和傳遞到 userspace 的時間開銷,單位需要用 us 或 ns (自行斟酌)
## TODO: 實作更快速的乘法運算
參照 [Schönhage–Strassen algorithm](https://en.wikipedia.org/wiki/Sch%C3%B6nhage%E2%80%93Strassen_algorithm),在上述大數運算的程式碼基礎之上,改進乘法運算,確保在大多數的案例均可加速,需要有對應的驗證機制。
### 演算法原理
#### 基於捲積的直式乘法
以下為一個 $base$ 進位的直式乘法 $123_{base} \times 456 _{base}$:
```
1 2 3
x 4 5 6
---------------------------
6 12 18
5 10 15
4 8 12
---------------------------
4 13 28 27 18
```
假設 $base = 10$,可以將計算的結果進行進位以得到 $123_{10} \times 456_{10} = 56088_{10}$:
```
1 2 3
x 4 5 6
---------------------------
6 12 18
5 10 15
4 8 12
---------------------------
4 13 28 27 18
--------------------------- carry in base = 10
5 6 0 8 8
```
若將$123_{base}$ 視為一長度 $N_1 = 3$ 的序列 $x[n] = \{1, 2, 3\}$, $456_{base}$ 視為長度 $N_2 = 3$ 的有列 $y[n]=\{4, 5, 6\}$,上述計算便可以視為兩序列的 [linear convolution](https://en.wikipedia.org/wiki/Convolution):
$$
(x*y)[n] = \sum_{m=-\infty}^{\infty} x[m]y[n-m]
$$
其結果則為一長度 $N=N_1+N_2-1=5$ 的序列 $(x*y)[n]=\{4,13,28,27,18\}$
在此介紹另一種卷積 [circular convolution](https://en.wikipedia.org/wiki/Circular_convolution),定義如下:
$$
(f*g_N)[n]\equiv\sum_{m=0}^{N-1}(\sum_{k=-\infty}^\infty f[m+kN])\ g_N[n-m]
$$
可以發現關係式與 `linear convolution` 相似,主要的差別在於序列 $g_N$ 有一週期 $N$。若對 $x[n]$ 和 $y[n]$ 作 `circular convolution` 可得到長度 $N=3$ 的序列 $\{28, 31, 31\}$:
```
1 2 3
x 4 5 6
-----------------------
6 12 18
10 15 *5
12 *4 *8
-----------------------
28 31 31
```
可以發現原先在 `linear convolution` 中超過 $N=3$ 的項( `*` 號項目) 繞回了序列的尾端。
由此可知若將 $x[n]$ 和 $y[n]$ 兩序列補上 $2$ 個 $0$ 延伸成長度 $N=5$ 的序列並對他們作 `circular convolution` (週期 $N=5$):
```
0 0 1 2 3
x 0 0 4 5 6
-------------------------------------
0 0 6 12 18
0 5 10 15 *0
4 8 12 *0 *0
-------------------------------------
4 13 28 27 18
```
其運算結果等同於對兩序列作 `linear convolution` (須將長度補到至少 $N_1+N_2-1$,補 $0$ 的動作稱為 `zero padding`)。
根據 [convolution theorem](https://en.wikipedia.org/wiki/Convolution_theorem),兩序列的 circular convolution 會等於兩序列的 [discrete fourier transform]((https://en.wikipedia.org/wiki/Discrete_Fourier_transform)) (DFT) 進行 `element-wise multiplication` 後再進行 `inverse DFT`,即:
```
CircularConvolution(X, Y) = IDFT(DFT(X) · DFT(Y))
```
而 `DFT` 可由 `FFT` 演算法進行加速。
若要計算 $x \times y = z$,其流程為:
1. 分割被乘數 $x$ 與乘數 $y$ 為序列 $x[n]$ 和 $y[n]$ 並進行 `zero padding`
2. 使用 `FFT` 計算 $X[k] = DFT(x[n])$ 和 $Y[k]=DFT(x[n])$
3. 利用 `Schönhage–Strassen algorithm` 遞迴地計算 $Z[k]=X[k]*Y[k]$
4. 使用 `FFT` 計算 $z[n]=IDFT(Z[k])$
5. 對 $z[n]$ 進行進位操作得到 $z$
#### 時間複雜度比較
標準的直式乘法是一項一項相乘, 因此時間複雜度為 $O(n^2)$。
若假設將一長度 $n=2^k$ 位元的數分成 $B$ 個 $L$ 位元的段落,使用 `Schönhage–Strassen algorithm` 計算時間複雜度為: $O(n\ log\ n\ log\ log\ n)$,推導如下:
設計算時間為
$$
M(n) = BM(L) + O(B\ log B)\ M(L) +O(n\ logB)
$$
其中:
* $BM(L)$ 為 $Z[k]=X[k]*Y[k]$ 的計算時間(遞迴計算)
* $O(B\ log B)\ M(L)$ 計算 `ntt` 和 `intt` 中的乘法的運算時間(遞迴計算)
* $O(n\ logB)$ 為計算`ntt` 和 `intt` 中的加減法的運算時間
假設 $B=n^\alpha$, $L=n^{1-\alpha}$,可得:
$$
M(n) = n^\alpha M(2n^{1-\alpha})+ O(n\ logn)
$$
可以分為三種情況:
1. $\alpha < \frac{1}{2}$, 此時$M(n)=O(n\ log^2n)$
2. $\alpha > \frac{1}{2}$, 此時$M(n)=O(n\ logn)$
3. $\alpha = \frac{1}{2}$, 此時$M(n)=O(n\ log\ n\ log\ log\ n)$
因此在 $B=L=\sqrt n$ 時為最佳解;在 $k$ 為奇數時可選擇 $B=\sqrt{\frac{n}{2}}, L = \sqrt{2n}$ 或 $B=\sqrt{2n}, L = \sqrt{\frac{n}{2}}$,並不會影響最終的時間複雜度。
### 實作原理
#### 多項式乘法
參考〈[快速傅立葉轉換](https://hackmd.io/@8dSak6oVTweMeAe9fXWCPA/H1y3L57Yd#%E5%BF%AB%E9%80%9F%E5%82%85%E7%AB%8B%E8%91%89%E8%BD%89%E6%8F%9B-FFT)〉一文關於求解多項式函數相乘的方法:
設一多項式 $A(x)=\sum_{j=0}^{N-1}a_jx^j$,可以用 $N$ 個點 $\{(x_0,A_0)...(x_{N},A_{N})\}$ 來表示這個 $N-1$ 階多項式;反之若給定 $N$ 個 $x_j$, 可以用計算出的的 $N$ 個 $A_j$ 反推回原式(此時稱該多項式 $degree\ bound$ 為 $N$)。
若要計算兩個 $N$ 階多項式相乘 $C(x) = A(x) \times B(x)$, 我們可以在 $A(x)$ 和$B(x)$ 上分別找出 $2N-1$ 個點 $\{(x_0,A_0)...(x_{2N-1},A_{2N-1})\}$ 和 $\{(x_0,B_0)...(x_{2N-1},B_{2N-1})\}$ ,如此一來 $C(x)$ 便可以用 $\{(x_0,A_0B_0)...(x_{2N-1},A_{2N-1}B_{2N-1})\}$ 表示,計算的時間複雜度可以由將$A(x)$ 和 $B(x)$ 各項係數相乘的 $O(n^2)$ 縮減為 $O(n)$。
我們可以將一個數字的二進位表示拆解成一個多項式 $A(x=2^k)$,例如:
$$
\begin{split}
321_{10}&=101000001_2\\&=1*2^8+1*2^6+1=A(x=2)\\&=1*(2^4)^2+4*(2^4)+1=B(x=2^4)
\end{split}
$$
因此大數相乘問題其實可以看成多項式相乘問題,在計算完後將 $x$ 代入適合的 $2^k$ 即可 (實作上以 bit shift 進行)。
#### [FFT](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)
:::info
以下將以 radix-2 DIT([Cooley–Tukey FFT algorithm](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm)) 為主
:::
`DFT` 的公式為:
$$
\require{ams}
\begin{equation}
\tag{1}\label{eq:eq_1}
X_k=\sum_{j=0}^{N-1}x_j\ e^{-\frac{i2\pi}{N}kj}=\sum_{j=0}^{N-1}x_j\ w_N^{kj}
\end{equation}
$$
取$A(x)$ 在 $x_k=w_N^k=e^{-2\pi ki /N}$ 上的點 $y_k$ 便可以化成 `DFT` 型式:
$$
y_k=\sum_{j=0}^{N-1}a_jw_N^{kj}
$$
其中 $w_N^0...w_N^{N-1}$ 為 $x^N=1$ 的 $N$ 個根,並具有以下性質:
1. 對於所有整數 $n,d,k \geq 0,\ w_{dn}^{dk}=w_n^k$
2. 對於所有偶數 $n>0,\ w_n^{n/2}=w_2=-1$
3. 對於所有偶數 $n>0,(w_n^k)^2=(w_n^{k+n/2})^2=w_{n/2}^k$
4. 對於所有整數 $n,k\geq0,\ w_n^{k+n/2}=-w_n^k$
假設 $N$ 為 $2$ 的冪,若將 $A(x)$ 以次方的奇偶數分成兩部份:
$$
\begin{split}
A(x)&=a_0x^0+a_1x^1+a_2x^2+a_3x^3+\ldots+a_{N-1}x^{N-1}\\
&=(a_0x^0+a_2x^2+\ldots+a_{N-2}x^{N-2})+x(a_1x^0+a_3x^2+\ldots+a_{N-1}x^{N-2})
\end{split}
$$
令:
$$
A^{[0]}(x)=a_0+a_2x+a_4x^2+\ldots+a_{N-2}x^{N/2-1}\\
A^{[1]}(x)=a_1+a_3x+a_5x^2+\ldots+a_{N-1}x^{N/2-1}
$$
則 $A(x)$ 可表示為:
$$
\require{ams}
\begin{equation}
\tag{2}\label{eq:eq_2}
A(x)=A^{[0]}(x^2)+xA^{[1]}(x^2)
\end{equation}
$$
由上述性質3可知 $(w_N^k)^2=w_{N/2}^k$:
$$
A(w_N^k)=A^{[0]}((w_N^k)^2)+xA^{[1]}((w_N^k)^2)=A^{[0]}(w_{N/2}^k)+xA^{[1]}(w_{N/2}^k)
$$
可以發現 $A(x)$ 被拆成了 $A^{[0]}(x)$ 和 $A^{[1]}(x)$ 兩個子問題,但其 $degree\ bound$ 縮減為 $\frac{N}{2}$,因此可以將他一路分解到$degree\ bound$ 為 $1$ (此時直接回傳係數即可)。
在合併時,假設已計算出大小為 $\frac N2$的傅立葉轉換 $y_k^{[0]}=A^{[0]}(w_{N/2}^k)$ 和 $y_k^{[1]}=A^{[1]}(w_{N/2}^k)$,則合併後的結果 $y_{k=0,...,N}$ 為:
$$
y_k=y_k^{[0]}+w_n^k\ y_k^{[1]}
$$
而透過性質4.可以發現:
$$
\begin{split}
y_{k+\frac N2}&=A(w_N^{k+\frac N2})\\
&\overset{\eqref{eq:eq_2}}=A^{[0]}(w_N^{2k+N})+w_N^{k+\frac N2}A^{[1]}(w_N^{2k+N})\\
&=A^{[0]}(w_N^{2k})+w_N^{k+\frac N2}A^{[1]}(w_N^{2k})\\
&=A^{[0]}(w_{N/2}^k)-w_N^k\ A^{[1]}(w_{N/2}^k)\\
&=y_k^{[0]}-w_n^k\ y_k^{[1]}
\end{split}
$$
每次迭代時計算$y_k^{[0]}$ 和 $y_k^{[1]}$ 便可以同時計算$y_k$ 與 $y_{k+N/2}$ ,須計算的 $y_k$ 數量變成了一半,因此時間複雜度為$O(nlogn)$。
假設 $N=8$,演算法的遞迴關係如下面的樹狀圖所示:
![](https://hackmd.io/_uploads/Hyrv73SPn.png)
可以發現所有 children 是按照特定順序排列的 (將該 child 的 index 的位元反轉,如 $011_2 \implies 110_2$),若事先將其排列好,可以將其變成 bottom-up 的實作,參考〈[Cooley–Tukey FFT algorithm: Data reordering, bit reversal, and in-place algorithms](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#Data_reordering,_bit_reversal,_and_in-place_algorithms)〉的虛擬碼:
```c
algorithm iterative-fft is
input: Array a of n complex values where n is a power of 2.
output: Array A the DFT of a.
bit-reverse-copy(a, A) //copy the reversed array
n ← a.length
for s = 1 to log(n) do
m ← 2s
ωm ← exp(−2πi/m)
for k = 0 to n-1 by m do
ω ← 1
for j = 0 to m/2 – 1 do
t ← ω A[k + j + m/2]
u ← A[k + j]
A[k + j] ← u + t
A[k + j + m/2] ← u – t
ω ← ω ωm
return A
```
兩者的比較如下圖:
![](https://hackmd.io/_uploads/rkIv19BD3.png)
#### [Number Theoretic Transform(NTT)](https://en.wikipedia.org/wiki/Discrete_Fourier_transform_over_a_ring)
>The discrete Fourier transform is an abstract operation that can be performed in any algebraic ring; typically it's performed in the complex numbers, but actually performing complex arithmetic to sufficient precision to ensure accurate results for multiplication is slow and error-prone.
>
> -- <[Strassen algorithm - Choice of ring](https://en.wikipedia.org/wiki/Sch%C3%B6nhage%E2%80%93Strassen_algorithm#Choice_of_ring)>
`FFT` 中對複數的操作需要大量浮點數計算,若需要計算到一定的精準度將會消耗極大的運算資源,另外在[作業說明](https://hackmd.io/@sysprog/linux2023-fibdrv/%2F%40sysprog%2Flinux2023-fibdrv-b#-Linux-%E6%A0%B8%E5%BF%83%E7%9A%84%E6%B5%AE%E9%BB%9E%E6%95%B8%E9%81%8B%E7%AE%97)中也提到了核心文件中不建議使用浮點運算。由於 `DFT` 的性質大部分是基於 $w_n$ ($s.t. \ w_n^n=1$) 的性質,因此若是能在體之下能找到單位根也能夠進行運算。`NTT` 是 `DFT` 的特例,指的是在有限體之下所進行的 `DFT`。
參考<[數論轉換](https://observablehq.com/@andy0130tw/ntt)>一文,考慮一有限體 $\mathbb{F}_p$ ,$P=c\times 2^k+1$ 是質數,首先尋找 $r$ 滿足:
$$
\{r \ mod\ P,r^2 \ mod\ P,\ldots,r^{P-1} \ mod\ P\} = \{1,2,\ldots,(P-1)\}
$$
在 $r$ 和 $P$ 互質的情況下,根據[費馬小定理](https://en.wikipedia.org/wiki/Fermat%27s_little_theorem) $r^{P-1}\equiv 1(mod\ P)$,可以推論:
$$
r^m\not\equiv 1\ (mod\ P),\quad m\in \{1,2,\ldots,(P-2)\}
$$
也就是 $r$ 為 $\mathbb{F}_p$ 的[原根](https://en.wikipedia.org/wiki/Primitive_root_modulo_n),如此一來有限體中的每個數皆可用 $r$ 的冪表示。$\mathbb{F}_p$ 裡的 $2^k$ 階單位根 $x$ ($x^{2^k}=1$)可以用 $r^c$ 表示,將上面式子中的 $w_N$ 替換成 $r^c$ 即可,:
$$
(r^c)^{2k}=r^{c\times 2k}=r^{P-1}\equiv 1\ (mod \ P)
$$
### 程式碼實作
#### `NTT.h` (commit [d3385bc](https://github.com/ctfish7063/fibdrv/commit/d3385bcb7fac7e4050f30d369d91375ebf1ea4f0#diff-888808dbe784e97bcf05ebe170d3817159d1baffe98722507390e259653d0057))
##### `ntt`
`ntt` 使用了 `iterative FFT` 演算法,參考了上面的虛擬碼進行實作,在求解 `wm` 時使用快速冪。
```c
static inline int reverse_bits(int x, int n)
{
int result = 0;
for (int i = 0; i < n; i++) {
result <<= 1;
result |= (x & 1);
x >>= 1;
}
return result;
}
static inline uint64_t fast_pow(uint64_t x, uint64_t n, uint64_t p)
{
uint64_t result = 1;
while (n) {
if (n & 1) {
result = result * x % p;
}
x = x * x % p;
n >>= 1;
}
return result;
}
static inline void ntt(uint64_t *a, int n, uint64_t p, uint64_t g)
{
uint64_t len = 64 - CLZ(n - 1);
for (int i = 0; i < n; i++) {
if (i < reverse_bits(i, len)) {
a[reverse_bits(i, len)] ^= a[i];
a[i] ^= a[reverse_bits(i, len)];
a[reverse_bits(i, len)] ^= a[i];
}
}
for (int m = 2; m <= n; m <<= 1) {
uint64_t wm = fast_pow(g, (p - 1) / m, p);
for (int k = 0; k < n; k += m) {
uint64_t w = 1;
for (uint64_t j = 0; j < m / 2; j++) {
uint64_t t = w * a[k + j + m / 2] % p;
uint64_t u = a[k + j];
a[k + j] = (u + t) % p;
a[k + j + m / 2] = (u - t + p) % p;
w = w * wm % p;
}
}
}
}
```
##### `intt`
`IDFT` 的公式為:
$$
x_j=\sum_{k=0}^{N-1}X_k\ e^{\frac{i2\pi}{N}kj}=\sum_{k=0}^{N-1}X_k\ w_N^{kj}
$$
公式與 `DFT` 非常類似,僅有兩點需要更改:
1. 其中的 $w_N^{k}$ 與 $\eqref{eq:eq_1}$ 中的次方數差了負號(即倒數關係),在有限體中倒數可以[模反元素](https://en.wikipedia.org/wiki/Modular_multiplicative_inverse)代替,在實作中使用費馬小定理配合快速冪求解。
2. 在加總後須再乘以 $\frac{1}{N}$,一樣可以用費馬小定理求出模反元素。
除此兩點外其餘皆和 `ntt` 相同,程式碼如下:
```c
static inline void intt(uint64_t *a, int n, uint64_t p, uint64_t g)
{
uint64_t len = 64 - CLZ(n - 1);
for (int i = 0; i < n; i++) {
if (i < reverse_bits(i, len)) {
a[reverse_bits(i, len)] ^= a[i];
a[i] ^= a[reverse_bits(i, len)];
a[reverse_bits(i, len)] ^= a[i];
}
}
for (int m = 2; m <= n; m <<= 1) {
uint64_t wm = fast_pow(g, (p - 1) / m, p);
// modular inverse
wm = fast_pow(wm, p - 2, p);
for (int k = 0; k < n; k += m) {
uint64_t w = 1;
for (int j = 0; j < m / 2; j++) {
uint64_t t = w * a[k + j + m / 2] % p;
uint64_t u = a[k + j];
a[k + j] = (u + t) % p;
a[k + j + m / 2] = (u - t + p) % p;
w = w * wm % p;
}
}
}
// inv by Fermat's little theorem
uint64_t inv = fast_pow(n, p - 2, p);
for (int i = 0; i < n; i++) {
a[i] = a[i] * inv % p;
}
}
```
#### `bn_strassen`
由於數論轉換的限制,計算出的結果將落在有限體之內。為避免模數溢出,這裡將分割的大小 $L$ 固定為 $8$ 位元,同時不使用遞迴來計算 $Z[k]=X[k]\times Y[k]$,因此時間複雜度將提昇為 $O(nlogn)$,實作程式碼如下:
```c
void bn_strassen(struct list_head *a, struct list_head *b, struct list_head *c)
{
int a_size = bn_size(a) * per_size - CLZ(bn_last_val(a)) / chunck_size;
int b_size = bn_size(b) * per_size - CLZ(bn_last_val(b)) / chunck_size;
// could not do ntt if size is too small
if (a_size < 2 || b_size < 2) {
bn_mul(a, b, c);
return;
}
// zero padding
int size = nextpow2((uint64_t)(a_size + b_size - 1));
uint64_t *a_array = bn_split(a, size);
uint64_t *b_array = bn_split(b, size);
// number theoretic transform
ntt(a_array, size, mod, rou);
ntt(b_array, size, mod, rou);
// pointwise multiplication
for (int i = 0; i < size; i++) {
a_array[i] *= b_array[i] % mod;
}
// inverse ntt
intt(a_array, size, mod, rou);
// carrying
uint64_t carry = 0;
for (int i = 0; i < size; i++) {
a_array[i] += carry;
carry = a_array[i] >> chunck_size;
a_array[i] &= chunk_mask;
}
// convert to bn
for (; bn_size(c) < size / per_size;) {
bn_newnode(c, 0);
}
bn_node *node;
int i = 0;
list_for_each_entry (node, c, list) {
// four at a time : 8 bits to uint64_t
uint64_t val = 0;
for (int j = 0; j < val_size; j += chunck_size) {
if (i < size) {
val |= a_array[i++] << j;
} else if (carry) {
val |= (carry & chunk_mask) << j;
carry >>= chunck_size;
}
}
node->val = val;
}
if (carry) {
bn_newnode(c, carry);
}
bn_clean(c);
kfree(a_array);
kfree(b_array);
}
```
### 實驗
由於乘法的時間複雜度為其所佔位元數的函數,可以將乘數和被乘數初始化為 1 ,透過左右移的方式控制乘數和被乘數的位元數,並計算其 `trailing zeros` 來驗證乘法結果的準確性。測試程式如下:
```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "bn2.h"
#include "ntt2.h"
void bn_print(struct list_head *head)
{
uint64_t *res = bn_to_array(head);
char *ret = bn_2_string(res, bn_size(head), 100);
puts(ret);
free(ret);
}
long long getnanosec()
{
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
return ts.tv_sec * 1000000000L + ts.tv_nsec;
}
long long bench(struct list_head *a,
struct list_head *b,
struct list_head *c,
void (*func_ptr)(struct list_head *,
struct list_head *,
struct list_head *))
{
long long st = getnanosec();
func_ptr(a, b, c);
long long ut = getnanosec();
return ut - st;
}
int main()
{
// test for fib
BN_INIT_VAL(a, 0, 1);
BN_INIT_VAL(b, 0, 1);
BN_INIT_VAL(c, 0, 1);
int a_ctz, b_ctz, c_ctz;
for (int i = 0; i <= 500000; i++) {
long long mul = bench(a, b, c, bn_mul);
a_ctz = CTZ(bn_last_val(a)) + 64 * (bn_size(a) - 1);
b_ctz = CTZ(bn_last_val(b)) + 64 * (bn_size(b) - 1);
c_ctz = CTZ(bn_last_val(c)) + 64 * (bn_size(c) - 1);
assert(a_ctz + b_ctz == c_ctz && "mul error");
long long strassen = bench(a, b, c, bn_strassen);
a_ctz = CTZ(bn_last_val(a)) + 64 * (bn_size(a) - 1);
b_ctz = CTZ(bn_last_val(b)) + 64 * (bn_size(b) - 1);
c_ctz = CTZ(bn_last_val(c)) + 64 * (bn_size(c) - 1);
assert(a_ctz + b_ctz == c_ctz && "strassen error");
printf("%i %lld %lld\n", i, mul, strassen);
bn_lshift(a, 1);
bn_lshift(b, 1);
}
return 0;
}
```
測試結果如下:
![](https://hackmd.io/_uploads/S1hpvIXd3.png)
可以觀察到 `bn_strassen` 在計算超過大約 170000 位元的數之後表現比起直式乘法有顯著的改善,而根據 Binet Formula:
$$
170000 = nlog_2(\phi)-log_2(\sqrt5)\\
n=244869.743094
$$
因此可以在需要計算至 $fib(500000)$ 時將乘法計算用 `bn_strassen`代替。