Try   HackMD

Linux 核心專題: 並行的 fibdrv

執行人: ericlai1021
專題解說錄影

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →
提問清單

  • ?

任務簡述

擴充 fibdrv,強化其並行處理能力,預計達成:

  • 有效運算 Fibonacci 數 (至少能算到第一百萬個) 並降低記憶體開銷
  • 藉由 hashtable 或 cache 一類的機制,儲存已計算的 Fibonacci 數
  • 引入 workqueue,將運算要求分派給個別 CPU 核,並確保降低非必要的同步處理成本
  • 修訂 fibdrv 和應用程式之間的 API,使其適合用於同步處理

TODO: 落實 Fibonacci 數的計算效率

  1. 彙整教材和學員成果,可延用現有程式碼,但應清楚標示出處並持續改進。
  2. 掌握「加速 Fibonacci 運算」及「sysprog21/bignum 程式碼分析」,落實於 fibdrv 內部實作中
  3. 提供驗證機制,確保 fibdrv 至少能算到第一百萬個 Fibonacci 數
  4. 修訂 fibdrv 和應用程式之間的 API,使資料傳輸和操作更有效

TODO: 儲存已計算的 Fibonacci 數

  1. 考慮到大數運算的特性,當以 key-value 形式保存時,不是儲存單純的整數值,而是指向特定結構的指標,於是當 fibdrv 嘗試釋放佔用的記憶體空間時,應有對應的操作
  2. 考慮到 fast doubling 和 Fibonacci 數的特性,不用保存連續數值,而是關注第 N 個和第 2N 個 Fibonacci 數的關聯,儘量降低記憶體開銷
  3. 應當善用 Linux 核心的 hashtable 或相關的資料結構

引入 workqueue,確保並行處理的效益

  1. 學習 ktcp,引入 kthread 和 CMWQ 到 fibdrv,確保 Fibonacci 數的運算可發揮硬體能力
  2. 確保並行處理的效益,不僅要確認結果正確,還要讓並行的 fibdrv 得以更有效的運算

改善大數運算

基於先前作業已完成部份,請參閱 ericlai1021-fibdrv ,先以 perf stat 分析程式碼,作為後續比對的基準

此部份為了方便後續做實驗比較,程式皆在 user space 執行

input a number: 100000

 Performance counter stats for './test':

     685,1248,4134      instructions              #    1.73  insn per cycle           (83.31%)
     395,8769,6486      cycles                                                        (83.31%)
           14,5768      cache-misses              #   12.170 % of all cache refs      (83.34%)
          119,7772      cache-references                                              (83.35%)
       3,8748,8927      branch-misses             #    8.87% of all branches          (83.35%)
      43,6638,9672      branch-instructions                                           (66.65%)

      15.557167671 seconds time elapsed

      11.697195000 seconds user
       0.000000000 seconds sys

接下來使用 perf record 量測 call graph (省略部份內容)

$ perf record -g --call-graph dwarf ./test
$ perf report --stdio -g graph,0.5,caller

# Children      Self  Command  Shared Object         Symbol                             
# ........  ........  .......  ....................  ...................
#    
    99.98%     0.00%  test     test                  [.] main
            |
            ---main
               |          
               |--86.94%--bn_to_string
               |          
                --13.04%--bn_mult
                          |          
                          |--6.65%--bn_add
                          |          
                           --6.33%--bn_lshift

    86.94%    86.62%  test     test                  [.] bn_to_string
            |          
             --86.62%--_start
                       __libc_start_main_impl (inlined)
                       __libc_start_call_main
                       main
                       bn_to_string

    13.04%     0.02%  test     test                  [.] bn_mult
            |          
             --13.02%--bn_mult
                       |          
                       |--6.65%--bn_add
                       |          
                        --6.33%--bn_lshift

     6.65%     6.53%  test     test                  [.] bn_add
            |          
             --6.53%--_start
                       __libc_start_main_impl (inlined)
                       __libc_start_call_main
                       main
                       bn_mult
                       bn_add

     6.33%     6.24%  test     test                  [.] bn_lshift
            |          
             --6.24%--_start
                       __libc_start_main_impl (inlined)
                       __libc_start_call_main
                       main
                       |          
                        --6.24%--bn_mult
                                  bn_lshift
  • 有 86.94% 的時間 (準確來說是樣本數) 落在 bn_to_string ,由此可見大數由二進制轉換成十進制的成本非常高,更不用說考慮到執行在 kernel space 時 copy_to_user 的成本,因此改善此部份勢必具有明顯的效能增益
  • 有 13.04% 的時間落在 bn_mult ,這部份的實作為參考傳統乘法器原理,因此會有大量的加法以及左移運算,需要提出一個更高效能的計算方法

改善方案 1: 運用 Q-Matrix 改進 fast doubling 的實作

主函式內採用 fast_doubling 演算法實作大數運算,然而 fast_doubling 的特色為會紀錄第 2n 項以及第 2n + 1 項的結果,若要進一步加速此運算則可以採用 Q-Matrix 搭配 Exponentiation by squaring 的技巧

所謂的 Q-Matrix 可以進一步將費氏數列改寫成以下形式:

Q=(1110)=(F2F1F1F0)Qn=(Fn+1FnFnFn1)

Exponentiation by squaring 則可以進一步將

Qn 的計算改寫成:
Qn={Q(Q2)(n1)2,ifnisodd(Q2)n2,ifniseven

對應的程式碼如下:

def multiply_matrices(a, b):
    # Matrix multiplication helper function
    result = [[0, 0], [0, 0]]
    for i in range(2):
        for j in range(2):
            for k in range(2):
                result[i][j] += a[i][k] * b[k][j]
    return result

def power_matrix(matrix, n):
    # Matrix exponentiation helper function
    result = [[1, 0], [0, 1]]
    while n > 0:
        if n % 2 == 1:
            result = multiply_matrices(result, matrix)
        matrix = multiply_matrices(matrix, matrix)
        n //= 2
    return result

def fibonacci(n):
    if n <= 0:
        return "Input should be a positive integer."

    matrix = [[1, 1], [1, 0]]
    power = power_matrix(matrix, n - 1)
    fib_n = power[0][0]

    return fib_n

實驗分析

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

Q-Matrix 的效能明顯比 Fast Doubling 來的差,其實不難理解是因為 Fast Doubling 其實就是 Q-Matrix 的改良,而直接使用 Q-Matrix 計算的話反而會額外多出許多計算量

  • 改進 fast doubling 實作
    參考 KYG-yaya573142 的作法稍微調整 fast doubling 的實作方法,式子推導如下
    [F(2n1)F(2n)]=[0111]2n[F(0)F(1)]=[F(n1)F(n)F(n)F(n+1)][F(n1)F(n)F(n)F(n+1)][10]=[F(n)2+F(n1)2F(n)F(n)+F(n)F(n1)]

整理後可得

F(2k1)=F(k)2+F(k1)2F(2k)=F(k)[2F(k1)+F(k)]

使用上述式子實作比起作業說明的 範例實作 可以減少一次迴圈的計算以及省去掉減法的運算

  • 在實作的過程中還發現原先的寫法都是使用 bn_cpy 來更新變數的數值,其實可以藉由 bn_swap 以及改變各函式儲存結果的位置來達到同樣的目的,因此就將所有的 bn_cpy 去除改用 bn_swap 以降低複製資料造成的成本
  • bn_swap 的實作如下

一開始的想法是將兩個 bn * 型態的位址互換即可

/* swap bn ptr */
void bn_swap(bn *a, bn *b)
{
    bn *tmp = a;
    a = b;
    b = tmp;
}

int main()
{
    bn *ptrA, *ptrB;
    bn_swap(ptrA, ptrB);
    return 0;
}

結果發現根本沒有互換成功,於是先將 ab 的位址印出來看,發現在呼叫 bn_swap 函式後 ab 的位址並無交換,但在 bn_swap 函式內兩個指標位址確實有交換,研讀課程教材 〈你所不知道的C語言:指標篇〉後才了解到原來 C 語言在函式呼叫皆是 call-by-value ,上述程式的行為是呼叫 bn_swap 時會傳遞 ptrAptrB 的位址,而 bn_swap 函式會產生 bn * 型態變數 ab 分別將 ptrAptrB 的位址存在其中,示意圖如下







structs



structa

ptrA內部的數值



structp

ptrA



structp:p->structa:nw





structptr

a



structptr:ptr->structa:nw





structpb

ptrB



structb

ptrB內部的數值



structpb:pb->structb:nw





structptrb

b



structptrb:ptr->structb:nw





執行 bn_swap 函式後會變成下圖所示







structs



structa

ptrA內部的數值



structp

ptrA



structp:p->structa:nw





structptr

a



structb

ptrB內部的數值



structptr:ptr->structb:nw





structpb

ptrB



structpb:pb->structb:nw





structptrb

b



structptrb:ptr->structa:nw





若要正確交換,則必須要使用「指標的指標」技巧,因為 bn 資料結構中 number 紀錄的是指標,所以可以透過以下方式將兩個 bn 型態變數的內容互換,而不會改變儲存在 heap 中的數值

C 語言沒有 "call-by-reference",只有數值傳遞 (包含指標在內,都是數值),請研讀〈你所不知道的C語言:指標篇〉,用精準的描述。

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →
jserv

已修正認知

/* swap bn ptr */
void bn_swap(bn *a, bn *b)
{
    bn tmp = *a;
    *a = *b;
    *b = tmp;
}

示意圖如下







structs



structa

ptrA內部的數值



structp

ptrA



structp:p->structa:nw





structptr

a



structptr:ptr->structa:nw





structpb

ptrB



structb

ptrB內部的數值



structpb:pb->structb:nw





structptrb

b



structptrb:ptr->structb:nw





執行 bn_swap 函式後會變成下圖所示







structs



structa

ptrB內部的數值



structp

ptrA



structp:p->structa:nw





structptr

a



structptr:ptr->structa:nw





structpb

ptrB



structb

ptrA內部的數值



structpb:pb->structb:nw





structptrb

b



structptrb:ptr->structb:nw





實驗結果如下 (v1 綠線為修改的 fast doubling 並將所有 bn_cpy 換成 bn_swap)

完全出乎意料之外的竟然沒有改善,為此我反覆測驗程式並檢查程式碼是否正確,但最終的結果確實如此,於是我就去查看兩個實驗的 call graph

  • 原始實作的 call graph
# Children      Self  Command  Shared Object         Symbol                                    
# ........  ........  .......  ....................  ..........................................
#
    95.50%     0.00%  test     test                  [.] _start
            |
            ---_start
               __libc_start_main_impl (inlined)
               __libc_start_call_main
               main
               |          
                --94.81%--fast_doubling
                          |          
                          |--93.29%--bn_mult
                          |          |          
                          |          |--46.77%--bn_add
                          |          |          |          
                          |          |          |--12.52%--bn_resize
                          |          |          |          |          
                          |          |          |          |--4.95%--__GI___libc_realloc (inlined)
                          |          |          |          |          |          
                          |          |          |          |           --1.28%--_int_realloc
                          |          |          |          |          
                          |          |          |           --1.85%--__memset_avx2_unaligned_erms
                          |          |          |          
                          |          |           --2.01%--_GLOBAL_OFFSET_TABLE_
                          |          |                     __GI___libc_realloc (inlined)
                          |          |          
                          |          |--34.63%--bn_lshift
                          |          |          |          
                          |          |          |--4.72%--bn_clz
                          |          |          |          
                          |          |           --1.34%--bn_resize
                          |          |          
                          |          |--5.75%--bn_digit
                          |          |          |          
                          |          |           --3.85%--bn_clz
                          |          |          
                          |          |--0.60%--bn_alloc
                          |          |          
                          |           --0.50%--bn_resize
                          |          
                           --0.66%--bn_add
  • 修改後的 call graph
# Children      Self  Command  Shared Object         Symbol                                    
# ........  ........  .......  ....................  ..........................................
#
    95.79%     0.00%  test     test                  [.] _start
            |
            ---_start
               __libc_start_main_impl (inlined)
               __libc_start_call_main
               |          
                --95.78%--main
                          |          
                          |--94.94%--fast_doubling
                          |          |          
                          |          |--94.04%--bn_mult
                          |          |          |          
                          |          |          |--47.49%--bn_add
                          |          |          |          |          
                          |          |          |          |--12.71%--bn_resize
                          |          |          |          |          |          
                          |          |          |          |          |--4.95%--__GI___libc_realloc (inlined)
                          |          |          |          |          |          |          
                          |          |          |          |          |           --1.19%--_int_realloc
                          |          |          |          |          |          
                          |          |          |          |           --2.17%--__memset_avx2_unaligned_erms
                          |          |          |          |          
                          |          |          |           --2.11%--_GLOBAL_OFFSET_TABLE_
                          |          |          |                     __GI___libc_realloc (inlined)
                          |          |          |          
                          |          |          |--34.43%--bn_lshift
                          |          |          |          |          
                          |          |          |          |--4.76%--bn_clz
                          |          |          |          |          
                          |          |          |           --1.27%--bn_resize
                          |          |          |          
                          |          |           --5.74%--bn_digit
                          |          |                     |          
                          |          |                      --3.66%--bn_clz
                          |          |          
                          |           --0.64%--bn_add
                          |          
                           --0.55%--__printf (inlined)
                                     |          
                                      --0.54%--__vfprintf_internal

從 call graph 就可以看出來原來是因為大數的乘法運算太花費時間,兩個實驗都有 94% 左右的時間在執行 bn_mult ,所以才會導致上述所作的改善微乎其微

改善方案 2: 改進 bn_mult 的效能

原本的 bn_mult 實作是參考傳統硬體乘法器,如此次來假設乘數有 x 個位元,則 worst case 就會需要執行 x 次的 bn_add 以及 x-1 次的 bn_lshift ,可想而知這對大數運算非常不適合,因此初步參考 KYG-yaya573142 的作法,概念即為直式乘法,將兩個大數的數字陣列依序兩兩相乘,接著將結果直接疊加到輸出的變數

void bn_mult(const bn *a, const bn *b, bn *c)
{
    // max digits = sizeof(a) + sizeof(b))
    int d = bn_digit(a) + bn_digit(b);
    d = DIV_ROUNDUP(d, 32) + !d;  // round up, min size = 1
    bn *tmp;
    /* make it work properly when c == a or c == b */
    if (c == a || c == b) {
        tmp = c;  // save c
        c = bn_alloc(d);
    } else {
        tmp = NULL;
        for (int i = 0; i < c->size; i++)
        	c->number[i] = 0U;
        bn_resize(c, d);
    }

    for (int i = 0; i < a->size; i++) {
        for (int j = 0; j < b->size; j++) {
            unsigned long long int carry = 0U;
            carry = (unsigned long long int) a->number[i] * b->number[j];
            unsigned long long int tmp = 0;
            for (int k = i + j; k < c->size; k++) {
                tmp += c->number[k] + (carry & 0xFFFFFFFF);
                c->number[k] = tmp;
                tmp >>= 32;
                carry >>= 32;
                if (!carry && !tmp)  // done
                    break;
            }
        }
    }
    
    if (tmp) {
    	bn_swap(tmp, c);
    	bn_free(c);
    }
}

實驗結果 (v2 綠線)

接著進一步發現第三層迴圈的用意為將每一輪相乘的結果疊加到輸出的變數,其實只需要將相乘的結果與輸出變數相加,再用一個變數 carry 儲存溢位的部份加到下一輪的相加結果就好

void bn_mult(const bn *a, const bn *b, bn *c)
{
    ...

    unsigned long long product;
    unsigned int carry = 0U;
    for (int i = 0; i < a->size; i++) {
        for (int j = 0; j < b->size; j++) {
        	product = (unsigned long long) a->number[i] * b->number[j] + 
                          carry + c->number[i + j];
        	carry = product >> 32;
        	c->number[i + j] = product & 0xFFFFFFFF;
        }
        if (carry) {
        	c->number[i + b->size] = carry;
        	carry = 0U;
        }
    }
    ...
}

實驗結果 (v3 綠線)

改善方案 3: 善用 64 位元 CPU 特性

原先 bn 結構體中數字陣列 number 的資料型態是 unsigned int ,為了能充分利用 64 位元處理器的特性改使用 uint64_t 確保每次的記憶體存取都是一個 word 的大小。

#include <stdint.h>

#if defined(__LP64__) || defined(__x86_64__) || defined(__amd64__) || defined(__aarch64__)
#define BN_WSIZE 8
#else
#define BN_WSIZE 4
#endif

#if BN_WSIZE == 8
typedef uint64_t bn_data;
typedef unsigned __int128 bn_data_tmp; // gcc support __int128
#define DATA_BITS 64U
#define builtin_clz(x) __builtin_clzll(x)
#elif BN_WSIZE == 4
typedef uint32_t bn_data;
typedef uint64_t bn_data_tmp;
#define DATA_BITS 32U
#define builtin_clz(x) __builtin_clz(x)
#else
#error "BN_WSIZE must be 4 or 8"
#endif

typedef struct _bn {
    bn_data *number;
    bn_data size;
} bn;
  • 參考 bignum/apm.h 當中的方式來定義 bn 結構體的資料型態,以便於根據不同 word 大小切換定義
  • 加法及乘法運算時會用到 2 倍大小的的暫存變數,直接使用 gcc 提供的 __int128 實作

實驗結果如下 (v4 綠線)

改善方案 4: 內嵌組合語言

參考 bignum/apm_internal.h 當中乘法運算使用內嵌組合語言 (inline assembly) 來直接取得乘法運算的高位與低位,直接使用一樣的方式實作乘法,取代原本使用的 __int128

void bn_mult(const bn *a, const bn *b, bn *c)
{
    ...

    bn_data_tmp carry = 0UL;
    bn_data *numA = a->number;
    for (int j = 0; j < b->size; j++) {
    	bn_data multiplier = b->number[j];
        for (int i = 0; i < a->size; i++) {
        	bn_data high, low;
        	__asm__("mulq %3"
        			 : "=a"(low), "=d"(high)
        			 : "%0"(numA[i]), "rm"(multiplier)
        		   );
        	carry +=(bn_data_tmp) low + c->number[i + j];
        	c->number[i + j] = carry;
        	carry = high + (carry >> DATA_BITS);
        }
        if (carry) {
        	c->number[j + a->size] = carry;
        	carry = 0UL;
        }
    }
    
    ...
}

由於我目前還無法完全理解 bignum/apm.c 的實作,所以就初步按照自己目前對程式碼的理解進行實作

實驗結果如下 (v5 綠線)

結果看起來差異不大,將範圍放大至10000項來看

結果顯示我的實作效能較差,所以看起來差異真的是在處理如何將乘積疊加到輸出變數那邊,於是我先做實驗驗證看看

void bn_mult(const bn *a, const bn *b, bn *c)
{
    ...

    bn_data *numA = a->number;
    for (int j = 0; j < b->size; j++) {
    	bn_data multiplier = b->number[j];
    	bn_data carry = 0;
        for (int i = 0; i < a->size; i++) {
        	bn_data high, low;
        	__asm__("mulq %3"
        			 : "=a"(low), "=d"(high)
        			 : "%0"(numA[i]), "rm"(multiplier)
        		   );
        	carry = high + ((low += carry) < carry);
        	carry += ((c->number[i + j] += low) < low);
        }
        c->number[j + a->size] = carry;
    }
    
    ...
}

實驗結果如下 (v6 綠線)

上述實作是利用無號整數不會 overflow 的特性,舉例假設有兩個 4 bits 的整數 ab 且兩個數皆為 4 bits 可表示最大數值 (即 15, 二進制表示成 1111),則 a + b 等於 30, 二進制表示成 11110 ,但因為是無號整數,所以只會保留 0~3 bit 的值,也就是 1110,這裡就可以看出若兩數值相加後小於其中任一數值,就表示發生 overflow 且只要 overflow 溢位的數值一定是 1 。

藉由上述特性就可以避免使用 __int128 (bn_data_tmp) 進行計算以節省多餘的記憶體開銷以及資料型態轉換成本。

改善方案 5: 改進 bn_add 的效能

原先的實作會在每次迴圈判斷需要相加的數值,這麼做的優點是只需寫一個迴圈就能完成計算,但缺點是每次迴圈都有兩個 branch 要判斷。參考 KYG-yaya573142 的實作改為使用兩個迴圈進行計算,第一個迴圈先計算兩者皆有資料的範圍,再於第二個迴圈處理 carry 與剩餘的資料範圍。另外,利用上一個改善方案所提及的無號整數不會 overflow 的特性,可以進一步可以避免使用 __int128 (bn_data_tmp) 進行計算

/* c = a + b */
void bn_add(const bn *a, const bn *b, bn *c)
{
    if (a->size < b->size) {
        SWAP(a, b);
    }
    int asize = a->size, bsize = b->size;
    bn_resize(c, a->size + 1);

    bn_data carry = 0;
    for (int i = 0; i < bsize; i++) {
    	bn_data tmp1 = a->number[i];
        bn_data tmp2 = b->number[i];
        carry = (tmp1 += carry) < carry;
        carry += (c->number[i] = tmp1 + tmp2) < tmp2;
    }
    if (asize != bsize) {  // deal with the remaining part if asize > bsize
        for (int i = bsize; i < asize; i++) {
            bn_data tmp1 = a->number[i];
            carry = (tmp1 += carry) < carry;
            c->number[i] = tmp1;
        }
    }
    if (carry) {
        c->number[asize] = carry;
    }

    if (!c->number[c->size - 1] && c->size > 1)
        bn_resize(c, c->size - 1);
}

因為要先計算兩者皆有的資料範圍,所以要先找出兩者當中範圍較小者,但這裡的做法是假設 a 的範圍比 b大,所以若遇到 a 的範圍比 b 小的情況就必須要將兩者互換,這裡值得注意的地方是原先實作的 bn_swap 函式為交換兩者的內容,但加法運算為了確保兩個輸入變數 (即 ab) 的內容不會被更改,因此將兩者皆宣告成 const ,如此一來就不能使用原先的 bn_swap,參考 bignum/apm_internal.h 當中的做法,透過定義一個巨集,交換指定的二個變數的數值。

#ifndef SWAP
#define SWAP(x, y)           \
    do {                     \
        typeof(x) __tmp = x; \
        x = y;               \
        y = __tmp;           \
    } while (0)
#endif

為了讓加法運算遇到 a == cb == c 依舊能夠正確計算,必須要在 bn_resize 之前將 ab 的大小 (size) 暫存起來。

為了凸顯 bn_add 對效能的影響,這裡改為使用迭代的方法計算費氏數列

實驗結果如下 (v1 綠線)

改善方案 6: 引入 bn_sqr

         a   b   c
      x  a   b   c
 -------------------
        ac  bc  cc
    ab  bb  bc
aa  ab  ac

考慮上述

(abc)2 的計算過程,會發現數值
ab
ac
bc
各會重複一次,利用此特性,先計算對角線任一邊的數值,接著再將數值總和乘二,最後再加上對角線上的
aa
bb
cc
。藉由此法,平方運算的成本可由本來的
n2
次乘法降為
n2n2
次乘法

實作參考 KYG-yaya573142bignum/sqr.c

void bn_sqr(bn *dest, const bn *src)
{
    ...
        
    const bn_data *sp = src->number;
    bn_data *dp = dest->number + 1;
    bn_data size = src->size - 1;
    for (int i = 0; i < size; i++) {
    	bn_data carry = 0;
        for (int j = 0; j < size - i; j++) {
            bn_data high, low;
            __asm__("mulq %3" 
                     : "=a"(low), "=d"(high)
                     : "%0"(sp[i + 1 + j]), "rm"(sp[i])
                   );
            carry = high + ((low += carry) < carry);
            carry += ((dp[j] += low) < low);
        }
        dp[size - i] = carry;
        dp += 2;
    }
    
    /* Double it */
    for (int i = 2 * src->size - 1; i > 0; i--)
        dest->number[i] = dest->number[i] << 1 | 
        dest->number[i - 1] >> (DATA_BITS - 1);
    dest->number[0] <<= 1;
    
    /* add the (aa bb cc) part at diagonal line */
    dp = dest->number;
    sp = src->number;
    size = src->size;
    bn_data carry = 0;
    for (int i = 0; i < size; i++) {
        bn_data high, low;
        __asm__("mulq %3" 
                 : "=a"(low), "=d"(high)
                 : "%0"(sp[i]), "rm"(sp[i])
               );
        high += (low += carry) < carry;
        high += (dp[0] += low) < low;
        carry = (dp[1] += high) < high;
        dp += 2;
    }
    
    ...
}

實驗結果如下 (v7 綠線)

將範圍擴大至第 20000 項就可以明顯看出改善

改善方案 7: 實作 Karatsuba algorithm

觀察 bignum/mul.cbignum/sqr.c 皆有使用 Karatsuba 演算法來加速乘法與平方運算,因此接下來一樣實作該演算法來提升效能
先放上 v7 版本與 bignum 的效能差異來觀察後續改進的成效


v7 版本的 call graph (只擷取部份內容)

#
# Children  Self      Command  Shared Object         Symbol                            
# ........  ........  .......  ....................  ..................................
#
    99.42%     0.00%  test     test                  [.] main
            |
            ---main
               fast_doubling
               |          
               |--49.78%--bn_sqr
               |          
                --49.54%--bn_mult
                          do_mult_base

Karatsuba 的概念是將

a
b
以第
n
位數為界,拆成兩半
a1
a0
b1
b0
,把這他們視為較小的數相乘,然後再透過左移補回
a1
b1
損失的位數,以二進位為例:
a=a1×2n+a0b=b1×2n+b0

a×b
可以化為:

a1b1z2×22n+(a1b0+b1a0)z1×2n+a0b0z0

上述算法計算

z2
z1
z0
需要 4 次乘法,我們還可以透過以下技巧縮減為 3 次乘法:

觀察

(a1+a0)(b1+b0) 展開的結果
(a1+a0)(b1+b0)=a1b1z2+a1b0+a0b1z1+a0b0z0

移項之後,我們就能利用

(a1+a0)(b1+b0)
z0
z2
來計算
z1

z2=a1b1
z0=a0b0

z1=(a1+a0)(b1+b0)z0z2

最後計算

z2×22n+z1×2n+z0 便能得到
a
b
相乘的結果,且
×2n
可以用左移運算代替。

再舉個例子,假設所採用的處理器只支援 8 位元乘法,當

x
y
超過 8 位元時,可以透過分治法實作 Karatsuba。
x1
x0
y1
y0
的位元數超出處理器的乘法的位數時,就把他們再切為
x11
x10
x01
x00
,再使用 Karatsuba 計算。以下以兩個 16 位元數值相乘變成 32 位元來演示

由上圖可以看出計算
z2
z1
z0
時,透過分治法將
x1
x0
y1
y0
切成更小的數字執行乘法運算。最後再用左移與加法計算
z2×216+z1×28+z0
即可求得結果。

至此可透過分治法,運用 Karatsuba 演算法計算任意位數的大數。

實作 Karatsuba 乘法

實作參考 KYG-yaya573142bignum/mul.c

程式碼解析

首先來看 bn_mult 函式的修改

/*
 * c = a x b
 * Note: work for c == a or c == b
 * using the simple quadratic-time algorithm (long multiplication)
 */
void bn_mult(const bn *a, const bn *b, bn *c)
{
    if (a->size < b->size)  // need asize > bsize
        SWAP(a, b);
    // max digits = sizeof(a) + sizeof(b))
    bn_data asize = a->size, bsize = b->size;
    int csize = asize + bsize;
    bn *tmp;
    /* make it work properly when c == a or c == b */
    if (c == a || c == b) {
        tmp = c;  // save c
        c = bn_alloc(csize);
    } else {
        tmp = NULL;
        for (int i = 0; i < c->size; i++)
            c->number[i] = 0;  // clean up c
        bn_resize(c, csize);
    }

    bn_data *ap = a->number;
    bn_data *bp = b->number;
    bn_data *cp = c->number;
    if (b->size < KARATSUBA_MUL_THRESHOLD) {
        do_mult_base(ap, asize, bp, bsize, cp);
    } else {
        do_mult_karatsuba(ap, bp, bsize, cp);
        /* it's assumed that a and b are equally length in
         * Karatsuba multiplication, therefore we have to
         * deal with the remaining part after hand */
        if (asize == bsize)
            goto end;
        /* we have to calc a[bsize ~ asize-1] * b */
        cp += bsize;
        csize -= bsize;
        ap += bsize;
        asize -= bsize;
        bn_data *_tmp = NULL;
        /* if asize = n * bsize, multiply it with same method */
        if (asize >= bsize) {
            _tmp = (bn_data *) calloc(2 * bsize, sizeof(bn_data));
            do {
                do_mult_karatsuba(ap, bp, bsize, _tmp);
                bn_data carry;
                carry = _add_partial(cp, _tmp, bsize * 2, cp);
                for (int i = bsize * 2; i < csize; i++) {
                    bn_data tmp1 = cp[i];
                    carry = (tmp1 += carry) < carry;
                    cp[i] = tmp1;
                }
                cp += bsize;
                csize -= bsize;
                ap += bsize;
                asize -= bsize;
                assert(carry == 0);
            } while (asize >= bsize);
        }
        /* if asize != n * bsize, simply calculate the remaining part */
        if (asize) {
            if (!_tmp)
                _tmp = (bn_data *) calloc(asize + bsize, sizeof(bn_data));
            do_mult_base(bp, bsize, ap, asize, _tmp);
            bn_data carry;
            carry = _add_partial(cp, _tmp, asize + bsize, cp);
            for (int i = asize + bsize; i < csize; i++) {
                bn_data tmp1 = cp[i];
                carry = (tmp1 += carry) < carry;
                cp[i] = tmp1;
            }
            assert(carry == 0);
        }
        if (_tmp)
            free(_tmp);
    }

end:
    if (!c->number[c->size - 1] && c->size > 1) // trim
        bn_resize(c, c->size - 1);
    if (tmp) {
        bn_swap(tmp, c);  // restore c
        bn_free(c);
    }
}
  • 將原先乘法運算部份改寫成一個函式 do_mult_base 並定義切分界線 KARATSUBA_MUL_THRESHOLD (bignum 範例程式定為 32),因為 do_mult_karatsuba 函式假設
    a
    b
    為相同 size ,因此判斷
    b
    size 若小於 KARATSUBA_MUL_THRESHOLD 則執行一般的乘法運算 (即 do_mult_base 函式),否則執行 do_mult_karatsuba 函式
  • 執行 do_mult_karatsuba 後若
    a
    size 大於
    b
    則要將
    a[bsize..asize1]×b
    加到
    c
    ,實作概念為判斷 asizebsize 的差若大於等於 bsize ,則使用 Karatsuba 乘法計算,直到兩者的差小於 bsize 則使用一般乘法計算

接著看 do_mult_karatsuba 函式

  • 一些初始化設置,將
    a
    b
    各自分為
    a1
    a0
    b1
    b0
void do_mult_karatsuba(const bn_data *a,
                       const bn_data *b,
                       bn_data size,
                       bn_data *c)
{
    const int odd = size & 1;
    const int even_size = size - odd;
    const int half_size = even_size / 2;

    const bn_data *a0 = a, *a1 = a + half_size;
    const bn_data *b0 = b, *b1 = b + half_size;
    bn_data *c0 = c, *c1 = c + even_size;
    ...
}

  • 計算
    a0×b0
    以及
    (a1×b1)×22n
    並加到
    c
    ,這裡用遞迴方式實作上述提到的分治法
/* c[0 ~ even_size-1] = a0*b0, c = 1*a0*b0 */
/* c[even_size ~ 2*even_size-1] += a1*b1, c += (2^2n)*a1*b1 */
if (half_size >= KARATSUBA_MUL_THRESHOLD) {
    do_mult_karatsuba(a0, b0, half_size, c0);
    do_mult_karatsuba(a1, b1, half_size, c1);
} else {
    do_mult_base(a0, half_size, b0, half_size, c0);
    do_mult_base(a1, half_size, b1, half_size, c1);
}
  • 接著來計算
    2n
    項係數 (即
    z1
    ),上述提到
    z1=(a1+a0)(b1+b0)z0z2
    , 因為
    (a1+a0)
    (b1+b0)
    為了解決溢位問題各自都要用 (half_size + 1) 的空間存放,相乘後的 size 會來到 (even_size + 2) 也就是 (size + 1) , 可想而知這樣對空間的使用效率不佳,因此可以進一步將
    z1
    的計算改寫成
    z1=|a1a0||b0b1|+z0+z2
  • (z0+z2)×2n
    加到
    c
    ,因為
    z0
    z2
    前面已經算過了,分別放在 c[0..even_size-1]c[even_size..2*even_size-1],因此只需要將該部份取出並累加到
    c
    即可
/* since we have to add a0*b0 and a1*b1 to
 * c[half_size ~ half_size+even_size-1] to obtain
 * c = (2^2n + 2^n)a1*b1 + (2^n + 1)a0*b0,
 * we have to make a copy of either a0*b0 or a1*b1 */
bn_data *tmp = (bn_data *) malloc(sizeof(bn_data) * even_size);
for (int i = 0; i < even_size; i++)
    tmp[i] = c0[i];

/* c[half_size ~ half_size + even_size-1] += a1*b1 + a0*b0
 * c += (2^n)*(a1*b1 + a0*b0)
 * now c = (2^2n)a1*b1 + (2^n)*(a1*b1 + a0*b0) + a0*b0 */
bn_data carry = 0;
for (int i = 0; i < even_size; i++) {
    bn_data in1 = c[half_size + i];
    bn_data in2 = c1[i];
    bn_data in3 = tmp[i];

    carry = (in1 += carry) < carry;
    carry += (c[half_size + i] = in1 + in2) < in2;
    carry += (c[half_size + i] += in3) < in3;
}
  • 計算
    |a1a0|
    |b0b1|
    ,可以注意到前面宣告用來暫存 c[0..even_size-1] 的變數 tmp 已經不需要使用,因此這邊可以將
    |a1a0|
    存放到 tmp[0..half_size-1]
    |b0b1|
    存放到 tmp[half_size..even_size-1], 減少了額外配置空間的成本
/* calc |a1-a0| */
bn_data *a_tmp = tmp;
bool neg = bn_cmp(a1, half_size, a0, half_size) < 0;
if (neg)
    _sub_partial(a0, a1, half_size, a_tmp);
else
    _sub_partial(a1, a0, half_size, a_tmp);

/* calc |b0-b1| */
bn_data *b_tmp = tmp + half_size;
if (bn_cmp(b0, half_size, b1, half_size) < 0) {
    _sub_partial(b1, b0, half_size, b_tmp);
    neg ^= 1;
} else {
    _sub_partial(b0, b1, half_size, b_tmp);
}
  • 計算
    |a1a0||b0b1|
    的方法與 Karatsuba 乘法相同
/* tmp = |a1-a0||b0-b1| */
tmp = (bn_data *) calloc(even_size, sizeof(bn_data));
if (half_size >= KARATSUBA_MUL_THRESHOLD)
    do_mult_karatsuba(a_tmp, b_tmp, half_size, tmp);
else
    do_mult_base(a_tmp, half_size, b_tmp, half_size, tmp);
free(a_tmp);
  • |a1a0||b0b1|×2n
    加到
    c
/* Now add / subtract (a1-a0)*(b0-b1) from
 * c[half_size..half_size+even_size-1] based on whether it is negative or
 * positive.
 */
if (neg)
    carry -= _sub_partial(c + half_size, tmp, even_size, c + half_size);
else
    carry += _add_partial(c + half_size, tmp, even_size, c + half_size);
free(tmp);
  • 將上面產生的 carry 加到
    c
/* add carry to c[even_size+half_size ~ 2*even_size-1] */
for (int i = even_size + half_size; i < even_size << 1; i++) {
    bn_data tmp1 = c[i];
    carry = (tmp1 += carry) < carry;
    c[i] = tmp1;
}  // carry should be zero now!
  • 現在已經計算好 a[0..even_size-1]
    ×
    b[0..even_size-1],但若
    a
    b
    皆具奇數 size, 舉例來說
    a=a2a1a0
    b=b2b1b0
    ,則
    a×b

    a2a1a0×b2b1b0a2b0\colorReda1b0\colorReda0b0a2b1\colorReda1b1\colorReda0b1+a2b2a1b2a0b2

    紅色圈起來的部份即為 a[0..even_size-1]
    ×
    b[0..even_size-1] ,接下來要將剩餘部份加回去
    c
    ,因此我們需要加上 a[size-1]
    ×
    b[0..size-2] 以及
    b[size-1]
    ×
    a[0..size-1]
if (odd) {
    /* We have the product a[0..even_size-1] * b[0..even_size-1] in
     * c[0..2*even_size-1].  We need to add the following to it:
     * a[size-1] * b[0..size-2]
     * b[size-1] * a[0..size-1] */
    c[even_size * 2] =
        _mult_partial(b, even_size, a[even_size], c + even_size);
    c[even_size * 2 + 1] =
        _mult_partial(a, size, b[even_size], c + even_size);
}

實驗結果如下 (v8 藍線)

實作 Karatsuba 平方運算

實作參考 KYG-yaya573142bignum/sqr.c

程式碼解析

基本上跟 karatsuba 乘法運算一樣,差別就在平方運算的乘數與被乘數一樣,因此不需要去額外處理兩者 size 不同的情況。

  • bn_sqr 函式內判斷 srcsize 是否小於 KARATSUBA_SQR_THRESHOLD (bignum bignum 範例程式定為 64),小於則執行 do_sqr_base 函式(一般的平方運算),否則執行 do_sqr_karatsuba 函式
/* c = a^2 */
void bn_sqr(bn *dest, const bn *src)
{
    // int d = a->size * 2;
    bn *tmp;
    /* make it work properly when c == a */
    if (dest == src) {
        tmp = dest;  // save c
        dest = bn_alloc(src->size * 2);
    } else {
        tmp = NULL;
        for (int i = 0; i < dest->size; i++)
            dest->number[i] = 0;  // clean up c
        bn_resize(dest, src->size * 2);
    }

    if (src->size < KARATSUBA_SQR_THRESHOLD) {
        do_sqr_base(src->number, src->size, dest->number);
    } else {
        do_sqr_karatsuba(src->number, src->size, dest->number);
    }

    if (!dest->number[dest->size - 1] && dest->size > 1) // trim
        bn_resize(dest, dest->size - 1);
    if (tmp) {
    	bn_swap(tmp, dest);
    	bn_free(dest);
    }
}
void do_sqr_base(const bn_data *src, bn_data ssize, bn_data *dest)
{
    bn_data *dp = dest + 1;
    const bn_data *sp = src;
    bn_data size = ssize - 1;
    for (int i = 0; i < size; i++) {
        /* calc the (ab bc bc) part */
        dp[size - i] = _mult_partial(&sp[i + 1], size - i, sp[i], dp);
        dp += 2;
    }

    /* Double it */
    for (int i = 2 * ssize - 1; i > 0; i--)
        dest[i] = dest[i] << 1 | dest[i - 1] >> (DATA_BITS - 1);
    dest[0] <<= 1;

    /* add the (aa bb cc) part at diagonal line */
    dp = dest;
    sp = src;
    size = ssize;
    bn_data carry = 0;
    for (int i = 0; i < size; i++) {
        bn_data high, low;
        __asm__("mulq %3" : "=a"(low), "=d"(high) : "%0"(sp[i]), "rm"(sp[i]));
        high += (low += carry) < carry;
        high += (dp[0] += low) < low;
        carry = (dp[1] += high) < high;
        dp += 2;
    }
}
  • do_sqr_karatsuba 函式實作與 do_mult_karatsuba 基本一樣,因此以下只講不同的地方
  • 計算好
    dest=(22n+2n)sp12+(2n+1)sp02
    之後要計算
    (sp1sp0)(sp0sp1)=(sp1sp0)2
/* (sp1-sp0)(sp0-sp1)  = -|sp1-sp0|^2 */
if (bn_cmp(sp1, half_size, sp0, half_size) < 0)
    _sub_partial(sp0, sp1, half_size, tmp);
else
    _sub_partial(sp1, sp0, half_size, tmp);
  • 計算好
    (sp1sp0)2
    後要計算
    dest=dest(sp1sp0)2
/* dest[half_size ~ half_size+even_size-1] += -(sp1-sp0)^2 */
carry -= _sub_partial(dest + half_size, tmp1, even_size, dest + half_size);

實驗結果如下 (v9 藍線)

圖中設定的閾值與 bignum 一樣,經實驗驗證放大或縮小閾值並不會顯著提升效能

v9 版本的 call graph (只擷取部份內容)

# Children      Self  Command  Shared Object         Symbol                            
# ........  ........  .......  ....................  ..................................
#
    95.39%     0.00%  test     test                  [.] main
            |
            ---main
               fast_doubling
               |          
               |--60.19%--bn_sqr
               |          do_sqr_karatsuba
               |          |          
               |           --59.59%--do_sqr_karatsuba
               |                     |          
               |                      --59.00%--do_sqr_karatsuba
               |                                |          
               |                                 --57.37%--do_sqr_karatsuba
               |                                           |          
               |                                           |--53.33%--do_sqr_karatsuba
               |                                           |          |          
               |                                           |          |--48.07%--do_sqr_karatsuba
               |                                           |          |          |          
               |                                           |          |          |--32.67%--do_sqr_karatsuba
               |                                           |          |          |          |          
               |                                           |          |          |          |--30.30%--do_sqr_base
               |                                           |          |          |          |          
               |                                           |          |          |           --0.59%--__memset_avx2_unaligned_erms
               |                                           |          |          |          
               |                                           |          |           --10.05%--do_sqr_base
               |                                           |          |          
               |                                           |           --3.49%--do_sqr_base
               |                                           |          
               |                                            --1.08%--do_sqr_base
               |          
                --35.21%--bn_mult
                          do_mult_karatsuba
                          |          
                          |--34.61%--do_mult_karatsuba
                          |          |          
                          |           --34.02%--do_mult_karatsuba
                          |                     |          
                          |                      --33.43%--do_mult_karatsuba
                          |                                |          
                          |                                 --31.17%--do_mult_karatsuba
                          |                                           |          
                          |                                           |--28.95%--do_mult_karatsuba
                          |                                           |          |          
                          |                                           |          |--26.10%--do_mult_karatsuba
                          |                                           |          |          |          
                          |                                           |          |          |--20.18%--do_mult_karatsuba
                          |                                           |          |          |          |          
                          |                                           |          |          |          |--17.21%--do_mult_base
                          |                                           |          |          |          |          
                          |                                           |          |          |          |--0.59%--__GI___libc_free (inlined)
                          |                                           |          |          |          |          
                          |                                           |          |          |           --0.59%--__libc_calloc
                          |                                           |          |          |                     _int_malloc
                          |                                           |          |          |          
                          |                                           |          |          |--2.96%--do_mult_base
                          |                                           |          |          |          
                          |                                           |          |           --0.59%--__GI___libc_free (inlined)
                          |                                           |          |                     _int_free
                          |                                           |          |          
                          |                                           |          |--2.26%--do_mult_base
                          |                                           |          |          
                          |                                           |           --0.59%--__libc_calloc
                          |                                           |                     _int_malloc
                          |                                           |          
                          |                                            --1.04%--do_mult_base
                          |          
                           --0.59%--__memset_avx2_unaligned_erms
                                     asm_exc_page_fault
                                     exc_page_fault
                                     do_user_addr_fault
                                     handle_mm_fault
                                     __handle_mm_fault
                                     handle_pte_fault

可見使用 Karatsuba 演算法後乘法與平方運算的時間佔比合計從 v7 版本的 99.32% 降低為 95.4% ,但仍然是程式執行時間佔比最高的運算,其中可以看到使用 Karatsuba 演算法後 call graph 出現多次相同函式的遞迴呼叫,如此可以看出 Karatsuba 演算法的實作主要依賴遞迴方式實現

改進 bn_to_string

原始 bn_to_string 的實作原理是不斷將大數除以 10 取餘數,並將大數更新為商數,直到大數為零,此作法時間複雜度為

O((digit+size)×log10(x)) , digit 為大數的二進制位元數、size 為大數的大小、
x
為大數,參考 KYG-yaya573142 的實作方式從大數的 MSB 起逐位元將數值累加到字串當中,此作法的時間複雜度為
O(digit×log10(x))

/*
 * output bn to decimal string
 * Note: the returned string should be freed with the free()
 */
char *bn_to_string(const bn *src)
{
    // log10(x) = log2(x) / log2(10) ~= log2(x) / 3.322
    size_t len = (8 * sizeof(bn_data) * src->size) / 3 + 2;
    char *s = (char *) malloc(len);
    char *p = s;

    memset(p, '0', len - 1);
    p[len - 1] = '\0';

    /* src.number[0] contains least significant bits 
     * s[len - 2] contains least significant digit
     */
    for (int i = src->size - 1; i >= 0; i--) {
        for (bn_data d = MSB_MASK; d; d >>= 1) {
            /* binary -> decimal string based on binary presentation */
            int carry = !!(d & src->number[i]);
            // add carry to p[len-2 .. 0]
            for (int j = len - 2; j >= 0; j--) {
                p[j] += p[j] - '0' + carry;
                carry = (p[j] > '9');
                if (carry)
                    p[j] -= 10;
            }
        }
    }
    // skip leading zero
    while (p[0] == '0' && p[1] != '\0') {
        p++;
    }
    memmove(s, p, strlen(p) + 1);
    return s;
}

實驗結果如下 (v1 綠線)

參考 bignum/format.c 的實作方式,先定義 max_radix

1019 ,相當於 uint64_t 可表示的數值範圍中最大的 10 的冪的值,每次迴圈藉由將大數除以 max_radix 獲得一個 uint64_t 可表示的 10 進制的值,再用一般 2 進制轉 10 進制的方法將數值存放到字串當中

目前看到 bignum/format.c 裡計算

log10(x) 的部份 (即 apm_string_size 函式),當中如果 radix 不為 2 的冪時,結果是回傳 (radix_sizes[radix] * (size * APM_DIGIT_SIZE)) + 2 ,可以理解為了要無條件進位所以 +1 ,但為什麼會是 +2 呢?

目前的猜測是因為要彌補乘以 radix_sizes[radix] 所產生的誤差ericlai1021

程式碼解析
void bn_fprint(bn_data *sp, bn_data size, FILE *fp)
{
	const size_t len = ((radix_size * (size * BN_WSIZE)) + 2) + 1;
	char *str = (char *) malloc(len);
	char *p = bn_to_string(sp, size, str);
	fprintf(fp, "%s\n", p);
	free(str);
}
  • 配置
    log10(x)
    大小的空間存放轉換後的值,radix_size 為表示 1 Byte 的數值所需的 10 進制位數, BN_WSIZE 為一個 word 的大小 (單位為 Byte ),最後加上一個字元的大小存放字串結尾符號 \0
  • bn_to_string 函式傳入大數的數字陣列開頭指標、大數的 size 及輸出的字串的開頭指標,並回傳轉換後的字串開頭指標

bn_to_string 函式

  • 分成多精度運算與單精度運算,多精度運算會將大數除以 MAX_RADIX (即
    1019
    ) 取得餘數為 uint64_t 可表示的大數轉換成 10 進制的數值,將大數更新為商數;接著將餘數透過單精度運算取得 10 進制的個別位數
do {
    /* Multi-precision: divide U by largest power of RADIX to fit in
     * one apm_digit and extract remainder.
     */
    bn_data remainder = bn_ddivi(sp, size, MAX_RADIX);
    size -= (sp[size - 1] == 0U);

    /* Single-precision: extract K remainders from that remainder,
     * where K is the largest integer such that RADIX^K < 2^BITS.
     */
    unsigned int i = 0;
    do {
        bn_data rq = remainder / 10;
        bn_data rr = remainder % 10;
        *outp++ = radix_chars[rr];
        remainder = rq;
        if (size == 0 && remainder == 0) /* Eliminate any leading zeroes */
            break;
    } while (++i < MAX_POWER);
    /* Loop until TMP = 0. */
} while (size != 0);

bn_ddivi 函式

  • 此函式傳入大數的數字陣列、大數的 sizeMAX_RADIX ,將大數除以 MAX_RADIX ,商數更新為新的大數,並回傳餘數
  • 為了有效解決除法結果溢位的問題,使用內嵌組合語言 (inline assembly) 指令 divq ,此指令的輸入為被除數的高位與低位以及除數,並輸出商數與餘數
static bn_data bn_ddivi(bn_data *sp, bn_data size, bn_data div)
{
    if (div == 1)
        return 0;
    if (!size)
        return 0;

    bn_data s1 = 0;
    sp += size;
    do {
        bn_data s0 = *--sp;
        bn_data q, r;
        if (s1 == 0) {
            q = s0 / div;
            r = s0 % div;
        } else {
            digit_div(s1, s0, div, q, r); // use inline assembly
        }
        *sp = q;
        s1 = r;
    } while (--size);

    return s1;
}

實驗結果如圖 (v2 綠線)

放上與 bignum 的比較

使用 python 撰寫一驗證程式 verify.py ,確保程式至少可以正確計算到第一百萬項

$ time ./test 1000000
Fib[1000000]

real	0m3.331s
user	0m3.331s
sys	0m0.000s

$ python3 verify.py
Please input a number: 1000000    
Fib[1000000] -Pass
congratulations, you pass all test!!!

先附上最終的 程式碼 ,後續會逐步更新至 GitHub

減少 copy_to_user 傳送的資料量

原先對 fibdrv 呼叫 read(fd, buf, size) 時,會在 kernel space 將計算好的大數結構體 bn 轉換成十進制表示存放於字串當中,並透過 copy_to_user 將該字串從 kernel space 複製到 user space,但由於轉換後的字串每個字元只會存放 0~9 其中一個數值,因此光是傳遞一個字元就會浪費 4 位元的空間,為了減少空間的浪費,可以將大數轉換成十進制的操作 (即 bn_to_string 函式) 搬到 user space 來執行,讓 copy_to_user 直接傳遞大數結構體當中的二進制數值。
參考 作業說明 當中的實作,先計算大數的 leading zeros ,接著呼叫 copy_to_user 時不傳送全為 0 的位元組。

static size_t my_copy_to_user(const bn *src, char __user *buf)
{
    int lzbyte = bn_clz(src) >> 3;
    size_t size = sizeof(bn_data) * src->size - lzbyte;

    kt = ktime_get();
    size_t sz = copy_to_user(buf, src->number, size);
    kt = ktime_sub(ktime_get(), kt);

    return size;
}

bn_clz 函式搭配 GCC 內建函式 __builtin_clzll 來計算大數的 leading zero bits 數量,其中 >> 3 右移操作計算 leading zeros 的位元組數量。針對 little-endian 架構,非零的位元組會被存在較低的記憶體位址,因此呼叫 copy_to_user 時只需要傳送 數字陣列總 byte 數 - leading zero byte 就可以不傳送全為 0 的位元組。

將複製的 byte 數量作為 read 的回傳值傳回 user

/* calculate the fibonacci number at given offset */
static ssize_t fib_read(struct file *file,
                        char *buf,
                        size_t size,
                        loff_t *offset)
{
    bn *fib = bn_alloc(1);
    bn_fib_fast(fib, *offset);
 
    size_t sz = my_copy_to_user(fib, buf);
    
    bn_free(fib);
    return sz;  // return number of bytes that could not be copied
}

在 user space 中使用 memcpybuf 字串內容複製到數字陣列後就可以執行 bn_to_string 函式將此數字陣列轉換成十進制字串表示

...
lseek(fd, i, SEEK_SET);
size_t sz = read(fd, buf, 20900);

size_t size = (sz >> 3) + ((sz << 61) > 0);
uint64_t *number = malloc(sizeof(uint64_t) * size);
memcpy(number, buf, sz);
char *p = bn_to_string(number, size);
...

實驗結果如下 (計算到第 10 萬項,時間為 copy_to_user 函式執行的時間)

  • kernel 表示 bn_to_string 函式執行在 kernel space ,因此 copy_to_user 會傳送轉換後的字串
  • user 表示 bn_to_string 函式執行在 user space , copy_to_user 會直接傳送大數的數字陣列
  • 結果看出 copy_to_user 直接傳送大數的數字陣列確實可以有效節省空間

使用 hashtable 儲存已計算的 Fibonacci 數

參考資料: Linux 核心的 hash table 實作chiangkd 同學的共筆

初步引入 hashtable

預期引入 Linux 核心的 hlist 系列 API 儲存已經計算過的值,目前實作以

Fib(n) 中的 n 作為 key

Linux 核心的 hash table 實作中,用以處理 hash 數值碰撞的 hlist_node:

struct hlist_node {
    struct hlist_node *next, **pprev;
};

示意圖如下 :







G


cluster_3

hash_key 3


cluster_1

hash_key 1



map

hlist_head.first

 

 

 

 

 

 

 

 



hn1

hlist_node

pprev

next



map:ht1->hn1





hn3

hlist_node

pprev

next



map:ht5->hn3





null1
NULL



null2
NULL



hn1:s->map:ht1





NULL
NULL



hn1:next->NULL





hn3:s->map:ht5





hn3:next->null2





新增一個自定義結構 hdata_node 嵌入 hlist_node 並包含一個指向大數結構體 bn 的指標,用以儲存 value

typedef struct _hdata_node {
    bn *data;
    struct hlist_node list;
} hdata_node;






G


cluster_A

hdata_node


cluster_bn

bn


cluster_1

hash_key 1


cluster_3

hash_key 3



map

hlist_head.first

 

 

 

 

 

 

 

 



hn1

hlist_node

pprev

next



map:ht1->hn1





hn3

hlist_node

pprev

next



map:ht5->hn3





null1
NULL



null2
NULL



hn1:s->map:ht1





hn1:next->null1





bn1

number

size



hn3:s->map:ht5





hn3:next->null2





Fib(n) 當中的 n 作為 hashtable 的 key ,value 則指向計算後的 bn 結構體

/* calculate the fibonacci number at given offset */
static ssize_t fib_read(struct file *file,
                        char *buf,
                        size_t size,
                        loff_t *offset)
{
    bn *fib = NULL;
    int key = (int) *offset;
    /* hashtable method*/
    kt = ktime_get();
    if (is_in_ht(offset)) {
        printk(KERN_INFO "find offset = %d\n", key);
        fib = hlist_entry(htable[key].first, hdata_node, list)->data;
    } else {
        fib = bn_alloc(1);
        dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
        if (dnode == NULL)
            printk("kcalloc failed \n");
        bn_fib_fast(fib, *offset);
        dnode->data = fib;
        INIT_HLIST_NODE(&dnode->list);
        hlist_add_head(&dnode->list, &htable[key]);  // add to hash table
    }
    kt = ktime_sub(ktime_get(), kt);

    size_t sz = my_copy_to_user(fib, buf);
    
    return sz;
}

呼叫 fib_read 函式時先判斷 hashtable 該 key 值是否有值,若有值則直接從 hashtable 中取用,因為這裡 hashtable 設計為將

Fib(n) 的 n 作為 key 值,所以不會發生 collision。
is_in_ht(offset) 函式判斷 hashtable 中的第 offset 個位址當中是否有值

static int is_in_ht(loff_t *offset)
{
    int key = (int) *(offset);
    if (hlist_empty(&htable[key])) {
        printk(KERN_INFO "No find in hash table\n");
        return 0; /* no in hash table */
    }
    return 1;
}
  • hlist_empty 函式判斷該 key 值對應到的 list 是否有值

執行 client ,測試程式為先計算

Fib(0)
Fib(100)
,接著再反著計算回來,使用 printk 測試是否有正確運行

[ 1305.396063] No find in hash table
[ 1305.396097] No find in hash table
[ 1305.396107] No find in hash table
...
[ 1305.397508] No find in hash table
[ 1305.397525] find offset = 100
[ 1305.397532] find offset = 99
[ 1305.397539] find offset = 98
...
[ 1305.398260] find offset = 0

在使用 rmmod 卸載 fibdrv 模組時會呼叫帶有 __exit macro 的函式, kernel 會將這個函式放入 read-only 的 __exit section 中,可參閱 linux/init.h
因此,需要在帶有 __exit macro 的 exit_fib_dev 函式當中對 hashtable 進行記憶體釋放

static void __exit exit_fib_dev(void)
{
    release_memory();
    printk(KERN_INFO "successful release memory.\n");
    ...
}
...
[ 1305.398254] find offset = 1
[ 1305.398260] find offset = 0
[ 1331.330755] successful release memory.

release_memory 函式當中使用 hlist_for_each_entry_safe 走訪hashtable 當中的 list 的每個節點,將該節點中的大數結構體釋放並將該節點從 list 當中移除

static void release_memory(void)
{
    struct hlist_node *n = NULL;
    /* go through and free hashtable */
    for (int i = 0; i < MAX_LENGTH; i++) {
        hlist_for_each_entry_safe(dnode, n, &htable[i], list)
        {
            bn_free(dnode->data);
            hlist_del(&dnode->list);
            kfree(dnode);
        }
    }
}

實驗結果如下

量測計算

Fib(0)
Fib(100000)
,並從
Fib(100000)
Fib(0)
的時間

  • 圖中的 x 座標不是
    Fib(n)
    ,而是第 n 次的時間量測,x 座標 100001200002 為由
    Fib(100000)
    Fib(0)
    的時間測量

在前半部份,也就是計算

Fib(0)
Fib(100000)
時,因為 hashtable 當中都沒有值,所以會如同原先的實作一樣,而後半部份因為 hashtable 中都已經有儲存計算過的值,所以會直接從 hashtable 中取用。

將後半部份的資料獨立出來看,可以看到整體趨勢已經是常數時間了

引入 hashtable 機制至 fast doubling 演算法加速大數運算

概念如同上述作法,將 hashtable 機制引入到 bn_fib_fast 函式內,考慮 fast doubling 的特性,紀錄第 N 個和第 2N 個 Fibonacci 數

  • 若 n 小於 2 ,則直接讓大數等於 n 並將其加入到 hashtable 當中, key 與 value 皆為 n
if (n < 2) {  // Fib(0) = 0, Fib(1) = 1
    dest->number[0] = n;
    dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
    if (dnode == NULL)
        printk("kcalloc failed \n");
    dnode->data = dest;
    INIT_HLIST_NODE(&dnode->list);
    key = n;
    hlist_add_head(&dnode->list, &htable[key]);  // add to hash table
}
  • 若 n 大於等於 2 ,則執行 fast doubling 演算法,主要可以分為兩部份的計算,一部分為計算
    F(2n)
    ,另一部份為計算
    F(2n+1)
    ,並將其計算結果存於 hashtable 中
else {
    bn *tmp = NULL;
    bn *b = bn_alloc(1);
    tmp = hlist_entry(htable[0].first, hdata_node, list)->data; // extrct F(0)
    printk(KERN_INFO "find offset = %d\n", 0);
    bn_cpy(b, tmp); // copy F(0) to b

    tmp = hlist_entry(htable[1].first, hdata_node, list)->data; // extrct F(1)
    printk(KERN_INFO "find offset = %d\n", 1);
    bn_cpy(dest, tmp); // copy F(1) to dest
    /* F(2n - 1) = F(n)^2 + F(n - 1)^2
     * F(2n) = F(n) * (2F(n - 1) + F(n))
     */
    bn *t1 = bn_alloc(1);
    int nbits = 32 - __builtin_clz(n);
    key = 1;
    for (int i = nbits - 2; i >= 0; i--) {
        key <<= 1; // key = F(2n)
        if (is_in_ht(key)) {
            printk(KERN_INFO "find offset = %d\n", key);
            tmp = hlist_entry(htable[key].first, hdata_node, list)->data; // extract F(2n)
            bn_cpy(dest, tmp); // copy F(2n) to dest
            tmp = hlist_entry(htable[key - 1].first, hdata_node, list)->data; // extract F(2n - 1)
            bn_cpy(b, tmp); // copy F(2n - 1) to b
        } else {
            bn_lshift(t1, b, 1); // t1 = F(n - 1) * 2
            bn_add(t1, dest, t1); // t1 = 2F(n - 1) + F(n)	
            bn_mult(dest, t1, t1); // t1 = F(n) * (2F(n - 1) + F(n)), now is F(2n)
            bn_sqr(b, b); // b = F(n - 1)^2
            bn_sqr(dest, dest); // dest = F(n)^2
            bn_add(dest, b, b); // b = F(n)^2 + F(n - 1)^2, now is F(2n - 1)

            bn_swap(dest, t1); // dest = F(2n)

            /* add F(2n) to hashtable */
            bn *tmp1 = bn_alloc(1);
            bn_cpy(tmp1, dest);
            dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
            if (dnode == NULL)
                printk("kcalloc failed \n");
            dnode->data = tmp1;
            INIT_HLIST_NODE(&dnode->list);
            hlist_add_head(&dnode->list, &htable[key]);  // add to hash table
        }
        if (n & (1U << i)) {
            key++;
            if (is_in_ht(key)) {
                printk(KERN_INFO "find offset = %d\n", key);
                bn_cpy(b, dest); // copy F(2n) to b
                tmp = hlist_entry(htable[key].first, hdata_node, list)->data; // extract F(2n + 1)
                bn_cpy(dest, tmp); // copy F(2n + 1) to dest
            } else {
                bn_swap(dest, b); // b = F(2n)
                bn_add(dest, b, dest); // dest = F(2n + 1)

                /* add F(2n + 1) to hashtable */
                bn *tmp2 = bn_alloc(1);
                bn_cpy(tmp2, dest);
                dnode = kcalloc(1, sizeof(hdata_node), GFP_KERNEL);
                if (dnode == NULL)
                    printk("kcalloc failed \n");
                dnode->data = tmp2;
                INIT_HLIST_NODE(&dnode->list);
                hlist_add_head(&dnode->list, &htable[key]);  // add to hash table
            }
        }
    }
    dest = hlist_entry(htable[key].first, hdata_node, list)->data;

    bn_free(t1);
    bn_free(b);
}
  • 迴圈每一輪都會先判斷 key 為
    2n
    2n+1
    是否已存在 hashtable 裡,若已有值,則直接從 hashtable 當中取出,否則才會執行一般 fast doubling 演算法對
    F(2n)
    F(2n+1)
    的計算

先用 dmesg 觀察執行過程

[90529.842755] find offset = 0
[90529.842763] find offset = 1
[90529.842767] No find in hash table // calculate F(2)
[90529.842782] find offset = 0
[90529.842785] find offset = 1
[90529.842788] find offset = 2
[90529.842790] No find in hash table // calculate F(3)
[90529.842804] find offset = 0
[90529.842807] find offset = 1
[90529.842810] find offset = 2
[90529.842813] No find in hash table // calculate F(4)
[90529.842827] find offset = 0
[90529.842830] find offset = 1
[90529.842832] find offset = 2
[90529.842835] find offset = 4
[90529.842838] No find in hash table // calculate F(5)

...

[90529.845965] find offset = 0
[90529.845968] find offset = 1
[90529.845971] find offset = 2
[90529.845973] find offset = 3
[90529.845976] find offset = 6
[90529.845979] find offset = 12
[90529.845982] find offset = 24
[90529.845984] find offset = 25
[90529.845987] find offset = 50
[90529.845990] No find in hash table // calculate F(100)

實驗結果如下 (計算至第十萬項)