Try   HackMD

Linux 核心專題: 改進 fibdrv

執行人: ctfish7063
GitHub
專題解說影片

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →
提問清單

  • ?

任務簡述

依據 fibdrv 作業規範,繼續投入 Linux 核心模組和相關程式的開發。

大數處理

原先的 fibdrv 因為 uint64_t 的限制,僅能計算至

fib(92),需要提供新的資料結構以進行計算和儲存之用。

資料結構

參考作業說明 - 基於 list_head 的大數運算, 資料結構使用 linked-list,並以 linux 的 list.h 進行實作:

/**
 * bn_head - store the head of bn list
 * @size: size of the list
 * @sign: sign of the bn
 * @list: list_head of the list
 */
typedef struct {
    size_t size;
    struct list_head list;
} bn_head;

/**
 * bn_node - store a node of bn list
 * The value should be within 10^19
 * @val: value of the node
 * @list: list_head of the node
 */
typedef struct {
    uint64_t val;
    struct list_head list;
} bn_node;

bn_head 僅儲存該鏈結串列的長度

size,而 bn_node 則以 uint64_t 的格式儲存資料。若以一個鏈結串列儲存一個大數,則可以將該鏈結串列看成一個有
64size
bits 的數。
經過測試,此資料結構可計算至
fib(1000000)
(以 wolfarmalpha 作為基準)

fibdrv$ sudo ./client
malloc size: 10848
lseek to 1000000
str size: 208989
fib[1000000] in 1839909461 ns:
195328212870775773163201494759625633244354299659187339695340519457162525788701569476664198763415014612887952433522023608462551091201956023374401543811519663615691996
...

bn_new

Binet formula 可以近似出

fib(n) 的值,將其取
log2(fib(n))64
即可計算所需要的 bn_node 的數量,可在一開始便配置好記憶體。

#define DIVISOR 100000
#define LOG2PHI 69424
#define LOG2SQRT5 116096

static inline struct list_head *bn_new(size_t n)
{
    unsigned int list_len = n > 1 ? (n * LOG2PHI - LOG2SQRT5) / DIVISOR / 64 + 1 : 1;
    struct list_head *head = bn_alloc();
    for (; list_len; list_len--) {
        bn_newnode(head, 0);
    }
    return head;
}

bn_add

bn 的加法bn_add 會將第二個 bn 的值加至第一個 bn ,以 __bn_add 進行操作:

void bn_add(struct list_head *a, struct list_head *b)
{
    __bn_add(a, b);
}

__bn_add 則為 bn 加法的實作:

#define bn_node_val(node) (list_entry(node, bn_node, list)->val)

void __bn_add(struct list_head *shorter, struct list_head *longer)
{
    int carry = 0;
    bn_node *node;
    struct list_head *longer_cur = longer->next;
    list_for_each_entry (node, shorter, list) {
        uint64_t tmp = node->val;
        node->val += bn_node_val(longer_cur) + carry;
        carry = U64_MAX - tmp >= bn_node_val(longer_cur) + carry ? 0 : 1;
        longer_cur = longer_cur->next;
        if (longer_cur == longer) {
            break;
        }
    }
    while (longer_cur != longer) {
        uint64_t tmp = bn_node_val(longer_cur);
        bn_newnode(shorter, bn_node_val(longer_cur) + carry);
        carry = U64_MAX - tmp >= carry ? 0 : 1;
        longer_cur = longer_cur->next;
    }
    while (carry) {
        if (bn_size(shorter) > bn_size(longer)) {
            uint64_t tmp = bn_node_val(node->list.next);
            bn_node_val(node->list.next) += carry;
            carry = U64_MAX - tmp >= carry ? 0 : 1;
        } else {
            bn_newnode(shorter, carry);
            break;
        }
    }
}

bn_sub

由於在計算 Fibonacci 數時不會出現負數,bn_sub 的實作是將兩數比較後將較大的 bn 減去較小的 bn:

void bn_sub(struct list_head *a, struct list_head *b)
{
    int cmp = bn_cmp(a, b);
    if (cmp >= 0) {
        __bn_sub(a, b);
    } else {
        __bn_sub(b, a);
    }
}

void __bn_sub(struct list_head *more, struct list_head *less)
{
    int carry = 0;
    bn_node *node;
    struct list_head *less_cur = less->next;
    list_for_each_entry (node, more, list) {
        uint64_t tmp =
            (less_cur == less) ? carry : bn_node_val(less_cur) + carry;
        if (node->val >= tmp && likely(bn_node_val(less_cur) != U64_MAX - 1)) {
            node->val -= tmp;
            carry = 0;
        } else {
            node->val += (U64_MAX - tmp) + 1;
            carry = 1;
        }
        if (less_cur != less) {
            less_cur = less_cur->next;
        }
    }
    bn_node *last = list_last_entry(more, bn_node, list);
    if (last->val == 0 && likely(!list_is_singular(more))) {
        bn_size(more)--;
        list_del(&last->list);
        kfree(last);
    }
}

bn_mul

bn_mul 會將兩個 bn 得相乘結果儲存在另外的 bn 中。此實作僅將每個 bn_node 分別作相乘並將結果加至對應的位置中,時間複雜度為

O(m×n)

void bn_mul(struct list_head *a, struct list_head *b, struct list_head *c)
{
    bn_node *node;
    // zeroing c
    list_for_each_entry (node, c, list) {
        node->val = 0;
    }
    bn_node *node_a, *node_b;
    struct list_head *base = c->next;
    list_for_each_entry (node_a, a, list) {
        uint64_t carry = 0;
        struct list_head *cur = base;
        list_for_each_entry (node_b, b, list) {
            uint128_t tmp = (uint128_t) node_b->val * (uint128_t) node_a->val;
            uint64_t n_carry = tmp >> 64;
            if (U64_MAX - bn_node_val(cur) < tmp << 64 >> 64)
                n_carry++;
            bn_node_val(cur) += tmp;
            if (U64_MAX - bn_node_val(cur) < carry)
                n_carry++;
            bn_node_val(cur) += carry;
            carry = n_carry;
            cur = cur->next;
        }
        while (carry) {
            if (cur == c) {
                bn_newnode(c, carry);
                break;
            }
            uint64_t tmp = bn_node_val(cur);
            bn_node_val(cur) += carry;
            carry = U64_MAX - tmp >= carry ? 0 : 1;
            cur = cur->next;
        }
        base = base->next;
    }
}

bn_lshiftbn_rshift

由於 bn_node 中所儲存的資料為 uint64_t, 一次左移或右移的操作最多僅能移動 63 bits,實作中會暫存所要移動的 bits 數量,以每 63 bits 作為單位分次移動:

void bn_lshift(struct list_head *head, int bit)
{
    int tmp = bit;
    for (; tmp > 64; tmp -= 63) {
        __bn_lshift(head, 63);
    }
    __bn_lshift(head, tmp);
}

void __bn_lshift(struct list_head *head, int bit)
{
    uint64_t carry = 0;
    bn_node *node;
    list_for_each_entry (node, head, list) {
        uint64_t tmp = node->val;
        node->val <<= bit;
        node->val |= carry;
        carry = tmp >> (64 - bit);
    }
    if (carry) {
        bn_newnode(head, carry);
    }
}

void bn_rshift(struct list_head *head, int bit)
{
    int tmp = bit;
    for (; tmp > 64; tmp -= 63) {
        __bn_rshift(head, 63);
    }
    __bn_rshift(head, tmp);
}

void __bn_rshift(struct list_head *head, int bit)
{
    uint64_t carry = 0;
    bn_node *node;
    list_for_each_entry_reverse(node, head, list)
    {
        uint64_t tmp = node->val;
        node->val >>= bit;
        node->val |= carry;
        carry = tmp << (64 - bit);
    }
    if (bn_last_val(head) == 0) {
        bn_pop(head);
    }
}

bn_to_array

為了減少 copy_to_user 所需複製的大小,在複製前會將 bn 中的資料儲存至陣列:

uint64_t *bn_to_array(struct list_head *head)
{
    bn_clean(head);
    uint64_t *res = kmalloc(sizeof(uint64_t) * bn_size(head), GFP_KERNEL);
    int i = 0;
    bn_node *node;
    list_for_each_entry (node, head, list) {
        res[i++] = node->val;
    }
    return res;
}

改進 fibdrv 效能

加速運算 使用 fast doubling

Fibonacci 數的定義為:

F(0)=0,F(1)=1
F(n)=F(n1)+F(n2)

其運算之時間複雜度為

O(n),程式實作如下:

static inline size_t fib_sequence_naive(long long k, uint64_t **fib)
{
    if (unlikely(k < 0)) {
        return 0;
    }
    // return fib[n] without calculation for n <= 2
    if (unlikely(k <= 2)) {
        *fib = kmalloc(sizeof(uint64_t), GFP_KERNEL);
        (*fib)[0] = !!k;
        return 1;
    }
    BN_INIT_VAL(a, 1, 0);
    BN_INIT_VAL(b, 1, 1);
    for (int i = 2; i <= k; i++) {
        bn_add(a, b);
        XOR_SWAP(a, b);
    }
    *fib = bn_to_array(b);
    size_t ret = bn_size(b);
    bn_free(a);
    bn_free(b);
    return ret;
}

fast doubling 的方法,可以將時間複雜度減少至

O(log(n)),以下則參考了作業解說中的 Bottom-up 方法進行實作:

// fast doubling
static inline void fast_doubling(struct list_head *fib_n0,
                   struct list_head *fib_n1,
                   struct list_head *fib_2n0,
                   struct list_head *fib_2n1)
{
    // fib(2n+1) = fib(n)^2 + fib(n+1)^2
    // use fib_2n0 to store the result temporarily
    bn_mul(fib_n0, fib_n0, fib_2n1);
    bn_mul(fib_n1, fib_n1, fib_2n0);
    bn_add(fib_2n1, fib_2n0);
    // fib(2n) = fib(n) * (2 * fib(n+1) - fib(n))
    bn_lshift(fib_n1, 1);
    bn_sub(fib_n1, fib_n0);
    bn_mul(fib_n1, fib_n0, fib_2n0);
}

static inline size_t fib_sequence(long long k, uint64_t **fib)
{
    if (unlikely(k < 0)) {
        return 0;
    }
    // return fib[n] without calculation for n <= 2
    if (unlikely(k <= 2)) {
        *fib = kmalloc(sizeof(uint64_t), GFP_KERNEL);
        (*fib)[0] = !!k;
        return 1;
    }
    // starting from n = 1, fib[n] = 1, fib [n+1] = 1
    uint8_t count = 63 - CLZ(k);
    BN_INIT_VAL(a, 0, 1);
    BN_INIT_VAL(b, 1, 1);
    BN_INIT(c, 0);
    BN_INIT(d, 0);
    int n = 1;
    for (uint8_t i = count; i-- > 0;) {
        fast_doubling(a, b, c, d);
        if (k & (1LL << i)) {
            bn_add(c, d);
            XOR_SWAP(a, d);
            XOR_SWAP(b, c);
            n = 2 * n + 1;
        } else {
            XOR_SWAP(a, c);
            XOR_SWAP(b, d);
            n = 2 * n;
        }
    }
    *fib = bn_to_array(a);
    size_t res = bn_size(a);
    bn_free(a);
    bn_free(b);
    bn_free(c);
    bn_free(d);
    return res;
}

將輸出時間使用 gnuplot 作圖,可以看出兩者之間的耗時相差甚巨(僅使用 ktime 測量上述函式所用時間):

減少 copy_to_user 的傳送量

bn_to_array 會將 bn 轉換為 uint64_t 的陣列,若單純使用 copy_to_user 複製該陣列, 在其儲存的元素所使用的空間小於 64 bit 的情況下將會多複製了不必要的空間。

針對 little-endian 架構,非零的位元組會被存在較低的記憶體位址。以

fib(100) 為例,需要兩個 uint64_t 來儲存,非零的位元組數為 9 個:

$ sudo ./client 
| 00 | 01 | 02 | 03 | 04 | 05 | 06 | 07 | 08 | 09 | 10 | 11 | 12 | 13 | 14 | 15 |
| c3 | bf | 94 | c5 | a7 | 76 | db | 33 | 13 | 00 | 00 | 00 | 00 | 00 | 00 | 00 |
fib[100]: 354224848179261915075

參考作業說明中的方法,使用 gcc 內建的 __builtin_clzll 計算陣列最後一元素的 leading zeros 之後僅從陣列複製剩餘的非零位元組(以上述例子來說便是 9 個):

static size_t my_copy_to_user(char *buf, uint64_t *src, size_t size)
{
    size_t lbytes = src[size - 1] ? CLZ(src[size - 1]) >> 3 : 7;
    size_t i = size * sizeof(uint64_t) - lbytes;
    printk(KERN_INFO "fibdrv: total %zu bytes, copy_to_user %zu bytes",
           size * sizeof(uint64_t), i);
    return copy_to_user(buf, src, i);
}

其中 lbytes 為避免 src 為僅一元素 0 的狀況,須額外判斷並保留至少一位元組。
可以用 dmesg 指令確認所計算之位元組數結果:

$ dmesg | grep fibdrv
[1306390.359899] fibdrv: reading on offset 100 
[1306390.359925] fibdrv: total 16 bytes, copy_to_user 9 bytes

client.c 中可初始化 buf 為 uint64_t 的陣列並傳入 read 作為複製的目的地,在輸出時將其轉換為字串即可,程式碼如下:

char *bn_2_string(uint64_t *head, int head_size, uint64_t n)
{
    //log10(fib(n)) = nlog10(phi) - log10(5)/2
    double tmp = n * 0.20898764025 - 0.34948500216;
    size_t size = n > 1 ? (size_t)tmp + 2 : 2;
    printf("str size: %zu\n", size);
    char *res = malloc(sizeof(char) * size);
    res[--size] = '\0';
    if (n < 3) {
        res[0] = !!head[0] + '0';
        return res;
    }
    for (int i = size; --i >= 0;) {
        uint128_t tmp = 0;
        for (int j = head_size; --j >= 0;) {
            tmp <<= 64;
            tmp |= head[j];
            head[j] = tmp / 10;
            tmp %= 10;
        }
        res[i] = tmp + '0';
    }
    return res;
}

實驗

實驗環境設定

參考 yanjiew 探討系統環境的設定,使用 cset 進行 cpu 的獨立:

$ sudo cset set -c 0-1 isolated
$ sudo cset set -c 2-7 others
$ sudo sh -c "echo 0 > /cpusets/isolated/sched_load_balance"
$ sudo cset proc -m root others

將執行緒指定於獨立出來的 cpu 執行:

$ sudo cset proc -e isolated -- sh -c './test > data.txt'

統計方法


參考作業說明中的 python script 並引入 scripts/preprocess.py中,假設資料分佈為自然分佈,將兩個標準差之外(即
95
的信賴區間)的數據去除後計算其平均值並作圖, 程式碼如下:

def outlier_filter(datas, threshold = 2):
    datas = np.array(datas)
    if datas.std() == 0:
        return datas
    z = np.abs((datas - datas.mean()) / datas.std())
    return datas[z < threshold]

def data_processing(data_set, n):
    catgories = data_set[0].shape[0]
    samples = data_set[0].shape[1]
    final = np.zeros((catgories, samples))
    if np.isnan(data_set).any():
        print("Warning: NaN detected in data set")
    for c in range(catgories):        
        for s in range(samples):
            final[c][s] =                                                    \
                outlier_filter([data_set[i][c][s] for i in range(n)]).mean()
    return final

效能分析

TODO: 紀錄閱讀作業說明中所有的疑惑

閱讀 fibdrv 作業規範,包含「作業說明錄影」和「Code Review 錄影」,本著「誠實面對自己」的心態,在本頁紀錄所有的疑惑,並與授課教師預約討論。

過程中,彙整 Homework3 學員的成果,挑選至少三份開發紀錄,提出值得借鏡之處,並重現相關實驗。

Schönhage–Strassen Algorithm

Q: 依據作業說明中的解釋, 此演算法是將大數分成小數字後將小數們構成的向量線性捲積最後將其進位,此算法看似跟長乘法相似,不知差異為何?

A: 線性卷積可使用 FFT 和 iFFT 進行計算;在考量定義域的情況下,也可以使用數論轉換在整數環上計算。

TODO: 回覆「自我檢查清單」

回答「自我檢查清單」的所有問題,需要附上對應的參考資料和必要的程式碼,以第一手材料 (包含自己設計的實驗) 為佳

TODO: 以 sysprog21/bignum 為範本,實作有效的大數運算

理解其中的技巧並導入到 fibdrv 中,並留意以下:

  • 在 Linux 核心模組中,可用 ktime 系列的 API;
  • 在 userspace 可用 clock_gettime 相關 API;
  • 善用統計模型,除去極端數值,過程中應詳述你的手法
  • 分別用 gnuplot 製圖,分析 Fibonacci 數列在核心計算和傳遞到 userspace 的時間開銷,單位需要用 us 或 ns (自行斟酌)

TODO: 實作更快速的乘法運算

參照 Schönhage–Strassen algorithm,在上述大數運算的程式碼基礎之上,改進乘法運算,確保在大多數的案例均可加速,需要有對應的驗證機制。

演算法原理

基於捲積的直式乘法

以下為一個

base 進位的直式乘法
123base×456base


                 1    2    3
    x            4    5    6
    ---------------------------
                 6   12   18
           5    10   15
     4     8    12
    ---------------------------
     4    13    28   27   18

假設

base=10,可以將計算的結果進行進位以得到
12310×45610=5608810


                 1    2    3
    x            4    5    6
    ---------------------------
                 6   12   18
           5    10   15
     4     8    12
    ---------------------------
     4    13    28   27    18
    --------------------------- carry in base = 10
     5     6     0    8     8

若將

123base 視為一長度
N1=3
的序列
x[n]={1,2,3}
456base
視為長度
N2=3
的有列
y[n]={4,5,6}
,上述計算便可以視為兩序列的 linear convolution
(xy)[n]=m=x[m]y[nm]

其結果則為一長度

N=N1+N21=5 的序列
(xy)[n]={4,13,28,27,18}

在此介紹另一種卷積 circular convolution,定義如下:

(fgN)[n]m=0N1(k=f[m+kN]) gN[nm]

可以發現關係式與 linear convolution 相似,主要的差別在於序列

gN 有一週期
N
。若對
x[n]
y[n]
circular convolution 可得到長度
N=3
的序列
{28,31,31}


            1     2     3
    x       4     5     6
    -----------------------
            6    12    18
           10    15    *5
           12    *4    *8
    -----------------------
           28    31    31

可以發現原先在 linear convolution 中超過

N=3 的項( * 號項目) 繞回了序列的尾端。
由此可知若將
x[n]
y[n]
兩序列補上
2
0
延伸成長度
N=5
的序列並對他們作 circular convolution (週期
N=5
):


             0    0    1    2    3
    x        0    0    4    5    6
    -------------------------------------
             0    0    6   12   18
             0    5   10   15   *0
             4    8   12   *0   *0
    -------------------------------------
             4   13   28   27   18

其運算結果等同於對兩序列作 linear convolution (須將長度補到至少

N1+N21,補
0
的動作稱為 zero padding)。

根據 convolution theorem,兩序列的 circular convolution 會等於兩序列的 discrete fourier transform (DFT) 進行 element-wise multiplication 後再進行 inverse DFT,即:

CircularConvolution(X, Y) = IDFT(DFT(X) · DFT(Y))

DFT 可由 FFT 演算法進行加速。
若要計算

x×y=z,其流程為:

  1. 分割被乘數
    x
    與乘數
    y
    為序列
    x[n]
    y[n]
    並進行 zero padding
  2. 使用 FFT 計算
    X[k]=DFT(x[n])
    Y[k]=DFT(x[n])
  3. 利用 Schönhage–Strassen algorithm 遞迴地計算
    Z[k]=X[k]Y[k]
  4. 使用 FFT 計算
    z[n]=IDFT(Z[k])
  5. z[n]
    進行進位操作得到
    z

時間複雜度比較

標準的直式乘法是一項一項相乘, 因此時間複雜度為

O(n2)
若假設將一長度
n=2k
位元的數分成
B
L
位元的段落,使用 Schönhage–Strassen algorithm 計算時間複雜度為:
O(n log n log log n)
,推導如下:
設計算時間為
M(n)=BM(L)+O(B logB) M(L)+O(n logB)

其中:

  • BM(L)
    Z[k]=X[k]Y[k]
    的計算時間(遞迴計算)
  • O(B logB) M(L)
    計算 nttintt 中的乘法的運算時間(遞迴計算)
  • O(n logB)
    為計算nttintt 中的加減法的運算時間

假設

B=nα,
L=n1α
,可得:
M(n)=nαM(2n1α)+O(n logn)

可以分為三種情況:

  1. α<12
    , 此時
    M(n)=O(n log2n)
  2. α>12
    , 此時
    M(n)=O(n logn)
  3. α=12
    , 此時
    M(n)=O(n log n log log n)

因此在

B=L=n 時為最佳解;在
k
為奇數時可選擇
B=n2,L=2n
B=2n,L=n2
,並不會影響最終的時間複雜度。

實作原理

多項式乘法

參考〈快速傅立葉轉換〉一文關於求解多項式函數相乘的方法:

設一多項式

A(x)=j=0N1ajxj,可以用
N
個點
{(x0,A0)...(xN,AN)}
來表示這個
N1
階多項式;反之若給定
N
xj
, 可以用計算出的的
N
Aj
反推回原式(此時稱該多項式
degree bound
N
)。

若要計算兩個

N 階多項式相乘
C(x)=A(x)×B(x)
, 我們可以在
A(x)
B(x)
上分別找出
2N1
個點
{(x0,A0)...(x2N1,A2N1)}
{(x0,B0)...(x2N1,B2N1)}
,如此一來
C(x)
便可以用
{(x0,A0B0)...(x2N1,A2N1B2N1)}
表示,計算的時間複雜度可以由將
A(x)
B(x)
各項係數相乘的
O(n2)
縮減為
O(n)

我們可以將一個數字的二進位表示拆解成一個多項式

A(x=2k),例如:
32110=1010000012=128+126+1=A(x=2)=1(24)2+4(24)+1=B(x=24)

因此大數相乘問題其實可以看成多項式相乘問題,在計算完後將

x 代入適合的
2k
即可 (實作上以 bit shift 進行)。

FFT

以下將以 radix-2 DIT(Cooley–Tukey FFT algorithm) 為主

DFT 的公式為:

(1)Xk=j=0N1xj ei2πNkj=j=0N1xj wNkj
A(x)
xk=wNk=e2πki/N
上的點
yk
便可以化成 DFT 型式:
yk=j=0N1ajwNkj

其中
wN0...wNN1
xN=1
N
個根,並具有以下性質:

  1. 對於所有整數
    n,d,k0, wdndk=wnk
  2. 對於所有偶數
    n>0, wnn/2=w2=1
  3. 對於所有偶數
    n>0,(wnk)2=(wnk+n/2)2=wn/2k
  4. 對於所有整數
    n,k0, wnk+n/2=wnk

假設

N
2
的冪,若將
A(x)
以次方的奇偶數分成兩部份:
A(x)=a0x0+a1x1+a2x2+a3x3++aN1xN1=(a0x0+a2x2++aN2xN2)+x(a1x0+a3x2++aN1xN2)

令:

A[0](x)=a0+a2x+a4x2++aN2xN/21A[1](x)=a1+a3x+a5x2++aN1xN/21

A(x) 可表示為:
(2)A(x)=A[0](x2)+xA[1](x2)

由上述性質3可知

(wNk)2=wN/2k
A(wNk)=A[0]((wNk)2)+xA[1]((wNk)2)=A[0](wN/2k)+xA[1](wN/2k)

可以發現

A(x) 被拆成了
A[0](x)
A[1](x)
兩個子問題,但其
degree bound
縮減為
N2
,因此可以將他一路分解到
degree bound
1
(此時直接回傳係數即可)。

在合併時,假設已計算出大小為

N2的傅立葉轉換
yk[0]=A[0](wN/2k)
yk[1]=A[1](wN/2k)
,則合併後的結果
yk=0,...,N
為:
yk=yk[0]+wnk yk[1]

而透過性質4.可以發現:

yk+N2=A(wNk+N2)=(2)A[0](wN2k+N)+wNk+N2A[1](wN2k+N)=A[0](wN2k)+wNk+N2A[1](wN2k)=A[0](wN/2k)wNk A[1](wN/2k)=yk[0]wnk yk[1]

每次迭代時計算

yk[0]
yk[1]
便可以同時計算
yk
yk+N/2
,須計算的
yk
數量變成了一半,因此時間複雜度為
O(nlogn)

假設

N=8,演算法的遞迴關係如下面的樹狀圖所示:

可以發現所有 children 是按照特定順序排列的 (將該 child 的 index 的位元反轉,如

01121102),若事先將其排列好,可以將其變成 bottom-up 的實作,參考〈Cooley–Tukey FFT algorithm: Data reordering, bit reversal, and in-place algorithms〉的虛擬碼:

algorithm iterative-fft is
    input: Array a of n complex values where n is a power of 2.
    output: Array A the DFT of a.
 
    bit-reverse-copy(a, A) //copy the reversed array
    n ← a.length 
    for s = 1 to log(n) do
        m ← 2s
        ωm ← exp(2πi/m) 
        for k = 0 to n-1 by m do
            ω ← 1
            for j = 0 to m/21 do
                t ← ω A[k + j + m/2]
                u ← A[k + j]
                A[k + j] ← u + t
                A[k + j + m/2] ← u – t
                ω ← ω ωm
   
    return A

兩者的比較如下圖:

Number Theoretic Transform(NTT)

The discrete Fourier transform is an abstract operation that can be performed in any algebraic ring; typically it's performed in the complex numbers, but actually performing complex arithmetic to sufficient precision to ensure accurate results for multiplication is slow and error-prone.

<Strassen algorithm - Choice of ring>

FFT 中對複數的操作需要大量浮點數計算,若需要計算到一定的精準度將會消耗極大的運算資源,另外在作業說明中也提到了核心文件中不建議使用浮點運算。由於 DFT 的性質大部分是基於

wn (
s.t. wnn=1
) 的性質,因此若是能在體之下能找到單位根也能夠進行運算。NTTDFT 的特例,指的是在有限體之下所進行的 DFT

參考<數論轉換>一文,考慮一有限體

Fp
P=c×2k+1
是質數,首先尋找
r
滿足:
{r mod P,r2 mod P,,rP1 mod P}={1,2,,(P1)}

r
P
互質的情況下,根據費馬小定理
rP11(mod P)
,可以推論:
rm1 (mod P),m{1,2,,(P2)}

也就是
r
Fp
原根,如此一來有限體中的每個數皆可用
r
的冪表示。
Fp
裡的
2k
階單位根
x
(
x2k=1
)可以用
rc
表示,將上面式子中的
wN
替換成
rc
即可,:
(rc)2k=rc×2k=rP11 (mod P)

程式碼實作

NTT.h (commit d3385bc)

ntt

ntt 使用了 iterative FFT 演算法,參考了上面的虛擬碼進行實作,在求解 wm 時使用快速冪。

static inline int reverse_bits(int x, int n)
{
    int result = 0;
    for (int i = 0; i < n; i++) {
        result <<= 1;
        result |= (x & 1);
        x >>= 1;
    }
    return result;
}

static inline uint64_t fast_pow(uint64_t x, uint64_t n, uint64_t p)
{
    uint64_t result = 1;
    while (n) {
        if (n & 1) {
            result = result * x % p;
        }
        x = x * x % p;
        n >>= 1;
    }
    return result;
}

static inline void ntt(uint64_t *a, int n, uint64_t p, uint64_t g)
{
    uint64_t len = 64 - CLZ(n - 1);
    for (int i = 0; i < n; i++) {
        if (i < reverse_bits(i, len)) {
            a[reverse_bits(i, len)] ^= a[i];
            a[i] ^= a[reverse_bits(i, len)];
            a[reverse_bits(i, len)] ^= a[i];
        }
    }
    for (int m = 2; m <= n; m <<= 1) {
        uint64_t wm = fast_pow(g, (p - 1) / m, p);
        for (int k = 0; k < n; k += m) {
            uint64_t w = 1;
            for (uint64_t j = 0; j < m / 2; j++) {
                uint64_t t = w * a[k + j + m / 2] % p;
                uint64_t u = a[k + j];
                a[k + j] = (u + t) % p;
                a[k + j + m / 2] = (u - t + p) % p;
                w = w * wm % p;
            }
        }
    }
}
intt

IDFT 的公式為:

xj=k=0N1Xk ei2πNkj=k=0N1Xk wNkj
公式與 DFT 非常類似,僅有兩點需要更改:

  1. 其中的
    wNk
    (1)
    中的次方數差了負號(即倒數關係),在有限體中倒數可以模反元素代替,在實作中使用費馬小定理配合快速冪求解。
  2. 在加總後須再乘以
    1N
    ,一樣可以用費馬小定理求出模反元素。
    除此兩點外其餘皆和 ntt 相同,程式碼如下:
static inline void intt(uint64_t *a, int n, uint64_t p, uint64_t g)
{
    uint64_t len = 64 - CLZ(n - 1);
    for (int i = 0; i < n; i++) {
        if (i < reverse_bits(i, len)) {
            a[reverse_bits(i, len)] ^= a[i];
            a[i] ^= a[reverse_bits(i, len)];
            a[reverse_bits(i, len)] ^= a[i];
        }
    }
    for (int m = 2; m <= n; m <<= 1) {
        uint64_t wm = fast_pow(g, (p - 1) / m, p);
        // modular inverse
        wm = fast_pow(wm, p - 2, p);
        for (int k = 0; k < n; k += m) {
            uint64_t w = 1;
            for (int j = 0; j < m / 2; j++) {
                uint64_t t = w * a[k + j + m / 2] % p;
                uint64_t u = a[k + j];
                a[k + j] = (u + t) % p;
                a[k + j + m / 2] = (u - t + p) % p;
                w = w * wm % p;
            }
        }
    }
    // inv by Fermat's little theorem
    uint64_t inv = fast_pow(n, p - 2, p);
    for (int i = 0; i < n; i++) {
        a[i] = a[i] * inv % p;
    }
}

bn_strassen

由於數論轉換的限制,計算出的結果將落在有限體之內。為避免模數溢出,這裡將分割的大小

L 固定為
8
位元,同時不使用遞迴來計算
Z[k]=X[k]×Y[k]
,因此時間複雜度將提昇為
O(nlogn)
,實作程式碼如下:

void bn_strassen(struct list_head *a, struct list_head *b, struct list_head *c)
{
    int a_size = bn_size(a) * per_size - CLZ(bn_last_val(a)) / chunck_size;
    int b_size = bn_size(b) * per_size - CLZ(bn_last_val(b)) / chunck_size;
    // could not do ntt if size is too small
    if (a_size < 2 || b_size < 2) {
        bn_mul(a, b, c);
        return;
    }
    // zero padding
    int size = nextpow2((uint64_t)(a_size + b_size - 1));
    uint64_t *a_array = bn_split(a, size);
    uint64_t *b_array = bn_split(b, size);
    // number theoretic transform
    ntt(a_array, size, mod, rou);
    ntt(b_array, size, mod, rou);
    // pointwise multiplication
    for (int i = 0; i < size; i++) {
        a_array[i] *= b_array[i] % mod;
    }
    // inverse ntt
    intt(a_array, size, mod, rou);
    // carrying
    uint64_t carry = 0;
    for (int i = 0; i < size; i++) {
        a_array[i] += carry;
        carry = a_array[i] >> chunck_size;
        a_array[i] &= chunk_mask;
    }
    // convert to bn
    for (; bn_size(c) < size / per_size;) {
        bn_newnode(c, 0);
    }
    bn_node *node;
    int i = 0;
    list_for_each_entry (node, c, list) {
        // four at a time : 8 bits to uint64_t
        uint64_t val = 0;
        for (int j = 0; j < val_size; j += chunck_size) {
            if (i < size) {
                val |= a_array[i++] << j;
            } else if (carry) {
                val |= (carry & chunk_mask) << j;
                carry >>= chunck_size;
            }
        }
        node->val = val;
    }
    if (carry) {
        bn_newnode(c, carry);
    }
    bn_clean(c);
    kfree(a_array);
    kfree(b_array);
}

實驗

由於乘法的時間複雜度為其所佔位元數的函數,可以將乘數和被乘數初始化為 1 ,透過左右移的方式控制乘數和被乘數的位元數,並計算其 trailing zeros 來驗證乘法結果的準確性。測試程式如下:

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "bn2.h"
#include "ntt2.h"

void bn_print(struct list_head *head)
{
    uint64_t *res = bn_to_array(head);
    char *ret = bn_2_string(res, bn_size(head), 100);
    puts(ret);
    free(ret);
}

long long getnanosec()
{
    struct timespec ts;
    clock_gettime(CLOCK_MONOTONIC, &ts);
    return ts.tv_sec * 1000000000L + ts.tv_nsec;
}

long long bench(struct list_head *a,
                struct list_head *b,
                struct list_head *c,
                void (*func_ptr)(struct list_head *,
                                 struct list_head *,
                                 struct list_head *))
{
    long long st = getnanosec();
    func_ptr(a, b, c);
    long long ut = getnanosec();
    return ut - st;
}

int main()
{
    // test for fib
    BN_INIT_VAL(a, 0, 1);
    BN_INIT_VAL(b, 0, 1);
    BN_INIT_VAL(c, 0, 1);
    int a_ctz, b_ctz, c_ctz;
    for (int i = 0; i <= 500000; i++) {
        long long mul = bench(a, b, c, bn_mul);
        a_ctz = CTZ(bn_last_val(a)) + 64 * (bn_size(a) - 1);
        b_ctz = CTZ(bn_last_val(b)) + 64 * (bn_size(b) - 1);
        c_ctz = CTZ(bn_last_val(c)) + 64 * (bn_size(c) - 1);
        assert(a_ctz + b_ctz == c_ctz && "mul error");
        long long strassen = bench(a, b, c, bn_strassen);
        a_ctz = CTZ(bn_last_val(a)) + 64 * (bn_size(a) - 1);
        b_ctz = CTZ(bn_last_val(b)) + 64 * (bn_size(b) - 1);
        c_ctz = CTZ(bn_last_val(c)) + 64 * (bn_size(c) - 1);
        assert(a_ctz + b_ctz == c_ctz && "strassen error");
        printf("%i %lld %lld\n", i, mul, strassen);
        bn_lshift(a, 1);
        bn_lshift(b, 1);
    }

    return 0;
}

測試結果如下:

可以觀察到 bn_strassen 在計算超過大約 170000 位元的數之後表現比起直式乘法有顯著的改善,而根據 Binet Formula:

170000=nlog2(ϕ)log2(5)n=244869.743094
因此可以在需要計算至
fib(500000)
時將乘法計算用 bn_strassen代替。