Try   HackMD

L04: fibdrv

主講人: jserv / 課程討論區: 2023 年系統軟體課程

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →
返回「Linux 核心設計/實作」課程進度表

加速 Fibonacci 運算

應用 Fast Doubling 手法,搭配 bitwise 運算開發遞迴版的程式碼:

static inline uint64_t fast_doubling(uint32_t target)
{
    if (target <= 2)
        return !!target;

    // fib(2n) = fib(n) * (2 * fib(n+1) − fib(n))
    // fib(2n+1) = fib(n) * fib(n) + fib(n+1) * fib(n+1)
    uint64_t n = fast_doubling(target >> 1);
    uint64_t n1 = fast_doubling((target >> 1) + 1);

    // check 2n or 2n+1
    if (target & 1)
        return n * n + n1 * n1;
    return n * ((n1 << 1) - n);
}

運用 fast doubling 計算,可降低運算成本:

  • iterative 方法的時間複雜度為
    O(n)
  • fast doubling 的時間複雜度降為
    O(logn)

在費式數列的計算上,原本使用迭代方式計算,迴圈迭代次數與欲求費式數成正比,時間複雜度爲

O(n)。運用 fast doubling 後,至多只要迭代 64 (或 32,依設定有所不同)次,實際上去除 MSB 為 0 的 bit 不用做迭代,時間複雜度為
O(logn)

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

不過,上述方法雖減少計算量,但仍有重複計算的部份:







FIB



3a

3



2a

2



3a->2a





1a

1



3a->1a





3b

3



2c

2



3b->2c





1b

1



3b->1b





2b

2



6

6



6->3a





4

4



6->4





4->3b





4->2b





target 數值越大,重複的計算會效能衝擊越顯著。

Bottom-up Fast Doubling

觀察數值的 2 進制表示,可發現該數是如何產生,以

8710 為例:

                87 =  1010111            (87 >> i+1)
        i = 0 : 43 = (1010111 - 1) >> 1 = 101011
        i = 1 : 21 = ( 101011 - 1) >> 1 = 10101
        i = 2 : 10 = (  10101 - 1) >> 1 = 1010 
        i = 3 :  5 = (   1010 - 0) >> 1 = 101 
        i = 4 :  2 = (    101 - 1) >> 1 = 10 
        i = 5 :  1 = (     10 - 0) >> 1 = 1
        i = 6 :  0 = (      1 - 1) >> 1 = 0
                      	     ^
                         87 的第 i 個位元

若是進行移項並反過來看的話會變成:

                   (87 >> i+1)
        i = 6 :  1 =        0 << 1 + 1 = 1
        i = 5 :  2 =        1 << 1 + 0 = 10
        i = 4 :  5 =       10 << 1 + 1 = 101
        i = 3 : 10 =      101 << 1 + 0 = 1010
        i = 2 : 21 =     1010 << 1 + 1 = 10101
        i = 1 : 43 =    10101 << 1 + 1 = 101011
        i = 0 : 87 =   101011 << 1 + 1 = 1010111
                                     ^
                              87 的第 i 個位元

n=0 開始看,可發現每次位移後,只要檢查目標數對應的位元,即可知曉下次應以
fib(2n)
還是
fib(2n+1)
為基礎進行右移。

整理前述觀察,可知:

  1. 從最高位元的 1 開始,此時
    n=1
    ,而:
    • fib(n)=1
    • fib(n+1)=1
  2. 若下一個位元不存在的話跳到第 3 步,否則(假設目前為
    n=k
    ):
    • 透過
      fib(k)
      以及
      fib(k+1)
      計算
      fib(2k)
      以及
      fib(2k+1)
    • 檢查下一個位元:
      • 0:
        n=2k
      • 1:
        n=2k+1
        ,此時需要
        fib(n+1)
        讓下一迭代能夠計算
        fib(2n)
        以及
        fib(2n+1)
    • 回到第 2 步
  3. 此時
    n
    為目標數,回傳
    fib(n)

對應的程式碼:

static inline uint64_t fast_doubling_iter(uint64_t target)
{
    if (target <= 2)
        return !!target;

    // find first 1
    uint8_t count = 63 - __builtin_clzll(target);
    uint64_t fib_n0 = 1, fib_n1 = 1;
    
    for (uint64_t i = count, fib_2n0, fib_2n1; i-- > 0;) {
        fib_2n0 = fib_n0 * ((fib_n1 << 1) - fib_n0);
        fib_2n1 = fib_n0 * fib_n0 + fib_n1 * fib_n1;

        if (target & (1UL << i)) {
            fib_n0 = fib_2n1;
            fib_n1 = fib_2n0 + fib_2n1;

        } else {
            fib_n0 = fib_2n0;
            fib_n1 = fib_2n1;
        }
    }
    return fib_n0;
}

而迴圈內 if...else... 的部份可用 -!!(target & (1 << i)) 作為 mask 的技巧簡化成:

static inline uint64_t fast_doubling_iter(uint64_t target)
{
    if (target <= 2)
        return !!target;

    // find first 1
    uint8_t count = 63 - __builtin_clzll(target);
    uint64_t fib_n0 = 1, fib_n1 = 1;

    for (uint64_t i = count, fib_2n0, fib_2n1, mask; i-- > 0;) {
        fib_2n0 = fib_n0 * ((fib_n1 << 1) - fib_n0);
        fib_2n1 = fib_n0 * fib_n0 + fib_n1 * fib_n1;

        mask = -!!(target & (1UL << i));
        fib_n0 = (fib_2n0 & ~mask) + (fib_2n1 & mask);
        fib_n1 = (fib_2n0 & mask) + fib_2n1;
    }
    return fib_n0;
}

F(92)
以後的數值錯誤的原因

初次執行 client 會發現從

F(92) 之後輸出的數值都一樣,這是因為 fibdrv 中預設限制最大項目為
92

/* MAX_LENGTH is set to 92 because
 * ssize_t can't fit the number > 92
 */
#define MAX_LENGTH 92

fib_read 返回值的型態為 long long,即 64 位元有號整數,其有效範圍是

26411
264
之間,比對費氏數列的正確值,可確認
F(93)
會超出此範圍,這也是預設限制最大可用項目為 92 的原因

F(0)     = 0
F(1)     = 1
...
F(91)    = 4660046610375530309 
F(92)    = 7540113804746346429

2^63 - 1 = 9223372036854775808

F(93)    = 12200160415121876738

移除限制並重新觀察輸出,會從 F(93) 開始 overflow

F(92)    = 7540113804746346429
F(93)    = -6246583658587674878
F(94)    = 1293530146158671551

雖然結果 overflow,但可根據二補數,算出 overflow 後為何是這個數值

if(A+B)>TMax(overflow)result=A+B264=F(91)+F(92)264=6246583658587674878

將使用的資料由 long long 更改為 uint64_t,可多計算出一項正確的數值

F(93),不過從
F(94)
開始仍會 overflow

F(92)    = 7540113804746346429
F(93)    = 12200160415121876738
F(94)    = 1293530146158671551
F(95)    = 13493690561280548289

一樣可以檢驗 overflow 後為何是這個數值

if(A+B)>UMax(overflow)result=A+B(mod2w1)=A+B264=F(92)+F(93)264=1293530146158671551

初步支援大數運算

bn 結構體

為了計算 92 項以後的費氏數列,我們引入長度可變動的數值表示法,動態配置不同大小的記憶體來呈現更大範圍的整數,定義的資料結構如下

/* number[size - 1] = msb, number[0] = lsb */
typedef struct _bn {
    unsigned int *number;
    unsigned int size;
    int sign;
} bn;
  • number - 指向儲存的數值,之後會以 array 的形式來取用
  • size - 配置的記憶體大小,單位為 sizeof(unsigned int)
  • sign - 0 為正數、1 為負數

由於大數沒辦法直接以數值的形式列出,這裡改用字串來呈現,轉換的部分利用 ASCII 的特性並根據 fast doubling 的邏輯來「組合」出 10 進位字串

/* 
 * output bn to decimal string
 * Note: the returned string should be freed with kfree()
 */
char *bn_to_string(bn src)
{
    // log10(x) = log2(x) / log2(10) ~= log2(x) / 3.322
    size_t len = (8 * sizeof(int) * src.size) / 3 + 2 + src.sign;
    char *s = kmalloc(len, GFP_KERNEL);
    char *p = s;

    memset(s, '0', len - 1);
    s[len - 1] = '\0';

    for (int i = src.size - 1; i >= 0; i--) {
        for (unsigned int d = 1U << 31; d; d >>= 1) {
            /* binary -> decimal string */
            int carry = !!(d & src.number[i]);
            for (int j = len - 2; j >= 0; j--) {
                s[j] += s[j] - '0' + carry; // double it
                carry = (s[j] > '9');
                if (carry)
                    s[j] -= 10;
            }
        }
    }
    // skip leading zero
    while (p[0] == '0' && p[1] != '\0') { 
        p++;
    }
    if (src.sign)
        *(--p) = '-';
    memmove(s, p, strlen(p) + 1);
    return s;
}

加法與減法

加法與減法由於需要考慮數值的正負號,因此分為兩個步驟,先由 bn_addbn_sub 判斷結果的正負號,再使用輔助函數 bn_do_addbn_do_sub 進行無號整數的計算

/* c = a + b 
 * Note: work for c == a or c == b
 */
void bn_add(const bn *a, const bn *b, bn *c)
{
    if (a->sign == b->sign) { // both positive or negative
        bn_do_add(a, b, c);
        c->sign = a->sign;
    } else { // different sign
        if (a->sign)  // let a > 0, b < 0
            SWAP(a, b);
        int cmp = bn_cmp(a, b);
        if (cmp > 0) {
            /* |a| > |b| and b < 0, hence c = a - |b| */
            bn_do_sub(a, b, c);
            c->sign = 0;
        } else if (cmp < 0) {
            /* |a| < |b| and b < 0, hence c = -(|b| - |a|) */
            bn_do_sub(b, a, c);
            c->sign = 1;
        } else {
            /* |a| == |b| */
            bn_resize(c, 1);
            c->number[0] = 0;
            c->sign = 0;
        }
    }
}

/* c = a - b 
 * Note: work for c == a or c == b
 */
void bn_sub(const bn *a, const bn *b, bn *c)
{
    /* xor the sign bit of b and let bn_add handle it */
    bn tmp = *b;
    tmp.sign ^= 1; // a - b = a + (-b)
    bn_add(a, &tmp, c);
}
  • 分類的方法參考 bignum
  • bn_add 負責所有正負號的判斷,所以 bn_sub 只是改變 b 的正負號後,再直接交給 bn_add 判斷
    • 但不能直接改變 b 的數值,所以這裡使用 tmp 來暫時的賦予不同的正負號
  • bn_cmp 負責比對兩個 bn 物件開絕對值後的大小,邏輯類似 strcmp
/* |c| = |a| + |b| */
static void bn_do_add(const bn *a, const bn *b, bn *c)
{
    // max digits = max(sizeof(a), sizeof(b)) + 1
    int d = MAX(bn_msb(a), bn_msb(b)) + 1;
    d = DIV_ROUNDUP(d, 32) + !d;
    bn_resize(c, d); // round up, min size = 1

    unsigned long long int carry = 0;
    for (int i = 0; i < c->size; i++) {
        unsigned int tmp1 = (i < a->size) ? a->number[i] : 0;
        unsigned int tmp2 = (i < b->size) ? b->number[i] : 0;
        carry += (unsigned long long int) tmp1 + tmp2;
        c->number[i] = carry;
        carry >>= 32;
    }

    if (!c->number[c->size - 1] && c->size > 1)
        bn_resize(c, c->size - 1);
}
  • 加法的部分比較簡單,只須確保 c 的大小足以儲存計算結果
  • DIV_ROUNDUP 的用法參考自 /arch/um/drivers/cow_user.c
  • 使用 8 bytes 大小的 carry 來實行兩個 4 bytes 項目的加法來避免 overflow
    • 等號右方記得要先將其中一方進行 integer promotion,不然會先被 truncated 然後才 implicit integer promotion
  • bn_msbbn_clz 是 bn 版本的 clz,詳見 bn_kernel.c
/* 
 * |c| = |a| - |b| 
 * Note: |a| > |b| must be true
 */
static void bn_do_sub(const bn *a, const bn *b, bn *c)
{
    // max digits = max(sizeof(a), sizeof(b))
    int d = MAX(a->size, b->size);
    bn_resize(c, d);

    long long int carry = 0;
    for (int i = 0; i < c->size; i++) {
        unsigned int tmp1 = (i < a->size) ? a->number[i] : 0;
        unsigned int tmp2 = (i < b->size) ? b->number[i] : 0;
        carry = (long long int) tmp1 - tmp2 - carry;
        if (carry < 0) {
            c->number[i] = carry + (1LL << 32);
            carry = 1;
        } else {
            c->number[i] = carry;
            carry = 0;
        }
    }
    
    d = bn_clz(c) / 32;
    if (d == c->size)
        --d;
    bn_resize(c, c->size - d);
}
  • 實際上使用無號整數進行計算,因此若絕對值相減會小於 0,需先對調 ab,並於計算完成後再再補上負號
  • 計算的邏輯和 bn_do_add 一樣,不過此時 carry 是作為借位使用

乘法

/* 
 * c = a x b
 * Note: work for c == a or c == b
 * using the simple quadratic-time algorithm (long multiplication)
 */
void bn_mult(const bn *a, const bn *b, bn *c)
{
    // max digits = sizeof(a) + sizeof(b))
    int d = bn_msb(a) + bn_msb(b);
    d = DIV_ROUNDUP(d, 32) + !d; // round up, min size = 1
    bn *tmp;
    /* make it work properly when c == a or c == b */
    if (c == a || c == b) {
        tmp = c; // save c
        c = bn_alloc(d);
    } else {
        tmp = NULL;
        for (int i = 0; i < c->size; i++)
            c->number[i] = 0;
        bn_resize(c, d);
    }
    
    for (int i = 0; i < a->size; i++) {
        for (int j = 0; j < b->size; j++) {
            unsigned long long int carry = 0;
            carry = (unsigned long long int) a->number[i] * b->number[j];
            bn_mult_add(c, i + j, carry);
        }
    }
    c->sign = a->sign ^ b->sign;

    if (tmp) {
        bn_swap(tmp, c); // restore c
        bn_free(c);
    }
}
  • 目前採用最簡單的 long multiplication,就像手算乘法一樣疊加上去
  • 與加減法不同,若 c == a || c == b,就必須配置記憶體來儲存計算結果,避免 ab 在計算途中就被改變
  • 輔助函式 bn_mult_add 負責將每一行的計算結果疊加上去,如下
/* c += x, starting from offset */
static void bn_mult_add(bn*c, int offset, unsigned long long int x)
{
    unsigned long long int carry = 0;
    for (int i = offset; i < c->size; i++) {
        carry += c->number[i] + (x & 0xFFFFFFFF);
        c->number[i] = carry;
        carry >>= 32;
        x >>= 32;
        if (!x && ! carry) //done
            return;
    }
}

位移操作

/* left bit shift on bn (maximun shift 31) */
void bn_lshift(const bn *src, size_t shift, bn *dest)
{
    size_t z = bn_clz(src);
    shift %= 32;  // only handle shift within 32 bits atm
    if (!shift)
        return;

    if (shift > z) {
        bn_resize(dest, src->size + 1);
    } else {
        bn_resize(dest, src->size);
    }
    /* bit shift */
    for (int i = src->size - 1; i > 0; i--)
        dest->number[i] =
            src->number[i] << shift | src->number[i - 1] >> (32 - shift);
    dest->number[0] = src->number[0] << shift;
}
  • 如果要移動超過 32 bits 會比較麻煩,考量目前不會有這種需求,先以較簡單的方式實作

swap

void bn_swap(bn *a, bn *b)
{
    bn tmp = *a;
    *a = *b;
    *b = tmp;
}
  • bn 資料結構中 number 紀錄的是指標,因此這麼做可以確實的互換兩個 bn 的數值,但不用更動儲存在 heap 中的數值

正確計算
F(92)
以後的數值

使用實作的大數運算來計算第

92 項以後的費氏數列,以下展示迭代演算法

/* calc n-th Fibonacci number and save into dest */
void bn_fib(bn *dest, unsigned int n)
{
    bn_resize(dest, 1);
    if (n < 2) { //Fib(0) = 0, Fib(1) = 1
        dest->number[0] = n;
        return;
    }

    bn *a = bn_alloc(1);
    bn *b = bn_alloc(1);
    dest->number[0] = 1;

    for (unsigned int i = 1; i < n; i++) {
        bn_swap(b, dest);
        bn_add(a, b, dest);
        bn_swap(a, b);
    }
    bn_free(a);
    bn_free(b);
}

接著是 fast doubling 的實作

void bn_fib_fdoubling(bn *dest, unsigned int n)
{
    bn_resize(dest, 1);
    if (n < 2) { //Fib(0) = 0, Fib(1) = 1
        dest->number[0] = n;
        return;
    }

    bn *f1 = dest; /* F(k) */
    bn *f2 = bn_alloc(1); /* F(k+1) */
    f1->number[0] = 0;
    f2->number[0] = 1;
    bn *k1 = bn_alloc(1);
    bn *k2 = bn_alloc(1);

    for (unsigned int i = 1U << 31; i; i >>= 1) {
        /* F(2k) = F(k) * [ 2 * F(k+1) – F(k) ] */
        bn_cpy(k1, f2);
        bn_lshift(k1, 1);
        bn_sub(k1, f1, k1);
        bn_mult(k1, f1, k1);
        /* F(2k+1) = F(k)^2 + F(k+1)^2 */
        bn_mult(f1, f1, f1);
        bn_mult(f2, f2, f2);
        bn_cpy(k2, f1);
        bn_add(k2, f2, k2);
        if (n & i) {
            bn_cpy(f1, k2);
            bn_cpy(f2, k1);
            bn_add(f2, k2, f2);
        } else {
            bn_cpy(f1, k1);
            bn_cpy(f2, k2);
        }
    }
    bn_free(f2);
    bn_free(k1);
    bn_free(k2);
}

我們可使用 Python 程式碼進行驗證,至少能正確計算至第

100000

def read_file(filename):
    f = open(filename, 'r')
    a = int(f.readline().strip())
    b = int(f.readline().strip())
    for target in f:
        target = int(target.strip())
        a, b = b, a + b  # a <- b, b <- (a + b)
        if b != target:
            print("wrong answer with value %d" % (target))
            return
    print("validation passed!")
    
parser = argparse.ArgumentParser(description='Validate the correctness of fibonacci numbers.')
parser.add_argument('--file', '-f', metavar='file_name', type=str, required=True, help='file for testing')
args = parser.parse_args()
read_file(args.file)

基於 list_head 的大數運算

typedef struct {
    size_t len;
    struct list_head link;
} bignum_head;

typedef struct {
    uint64_t value;
    struct list_head link;
} bignum_node;

#define NEW_NODE(head, val)                              \
    ({                                                   \
        bignum_node *node = malloc(sizeof(bignum_node)); \
        if (node) {                                      \
            node->value = val;                           \
            list_add_tail(&node->link, head);            \
        }                                                \
    })

解讀:

  • 使用鏈結串列避免像字串操作須重複 malloc 以及 free 的成本
  • 利用內建的 64 位元整數型別進行運算
  • 減少解碼成本(僅須 printf 類的格式化輸出函式)
static inline void bignum_add_to_smaller(struct list_head *lgr,
                                         struct list_head *slr)
{
    struct list_head **l = &lgr->next, **s = &slr->next;

    for (bool carry = 0;; l = &(*l)->next, s = &(*s)->next) {
        if (*s == slr) {
            if (*l == lgr) {
                if (carry)
                    NEW_NODE(slr, 1);
                break;
            }

            NEW_NODE(slr, 0);
        }

        bignum_node *lentry = list_entry(*l, bignum_node, link),
                    *sentry = list_entry(*s, bignum_node, link);

        carry = FULL_ADDER_64(lentry->value, sentry->value, carry);
    }
}

程式碼中的 l 以及 s 這二個間接指標 (indirect pointer) 雖看似多餘,但是其實不可被簡化,否則會影響答案,因為第 14 行的 NEW_NODE 會新增一節點並透過 list_add_tail 加到 slr 的尾端,因此若是單純使用指向 list_head 的指標的話,則需要在 NEW_NODE 後重新將 s 指向新的節點,否則第 18 行的 list_entry 相當於對鏈結串列的 head 使用 list_entry,造成存取到預期外位址。

因此運用間接指標 *s 來省略將 s 指向新的尾端節點的步驟。

#define MAX_DIGITS 18
#define BOUND64 1000000000000000000UL

#define FULL_ADDER_64(a, b, carry)               \
    ({                                           \
        uint64_t res = (a) + (b) + (carry);      \
        uint64_t overflow = -!!(res >= BOUND64); \
        (b) = res - (BOUND64 & overflow);        \
        overflow;                                \
    })

為了盡可能的減少節點數量,原本規劃以 UINT64_MAX 作為上限,但解碼時仍須進行大數運算,因此改以 10 的冪(

1018=1000000000000000000)作為上限值,來減少解碼成本。

而不選擇以

1019 作為上限的原因則是考慮到兩個
10191
相加時會超過 UINT64_MAX,造成需要額外判斷是否有 overflow 的發生,因此選擇少一位數的
1018
作為上限配合
進行簡單的判斷。

初步測試

int main(int argc, char const *argv[])
{
    struct list_head *h1 = bignum_new(1);
    struct list_head *h2 = bignum_new(0);

    for (int i = 0; i < atoi(argv[1]); ++i) {
        bignum_add_to_smaller(h1, h2);
        swap(h1, h2);
    }
    bignum_to_string(h2, NULL, 0);
    bignum_free(h1);
    bignum_free(h2);
    return 0;
}

由於目前只有實作加法的功能,因此先以最基本的遞迴呼叫測試:

$ taskset -cp 1
pid 1's current affinity list: 0-9

$ time taskset -c 10 ./a.out 100000 | wc
      0       1   20899
taskset -c 10 ./a.out 100000  0.25s user 0.00s system 99% cpu 0.257 total
wc  0.00s user 0.00s system 0% cpu 0.256 total

fib(100000) 大約在
14
秒左右計算完成,且與 WolframAlpha 上的結果相同。

$ time taskset -c 10 ./a.out 1000000 | wc
      0       1  208988
taskset -c 10 ./a.out 1000000  29.70s user 0.00s system 99% cpu 29.707 total
wc  0.00s user 0.00s system 0% cpu 29.707 total

fib(1000000) 大約在 30 秒內計算完成,但因位數過多無法與 WolframAlpha 上的結果相比。


改善大數運算

先以 perf stat 分析前述程式碼,作為後續比對的基準

    63,453,850,327      cycles                                                        ( +-  0.03% )  (66.65%)
   182,785,094,108      instructions              #    2.88  insn per cycle           ( +-  0.00% )  (83.33%)
            15,795      cache-misses              #    1.375 % of all cache refs      ( +- 19.12% )  (83.33%)
         1,148,592      cache-references                                              ( +- 11.66% )  (83.34%)
    36,448,212,424      branch-instructions                                           ( +-  0.00% )  (83.34%)
       117,825,450      branch-misses             #    0.32% of all branches          ( +-  0.56% )  (83.33%)

          18.73770 +- 0.00638 seconds time elapsed  ( +-  0.03% )

接下來使用 perf record 量測 call graph (省略部分內容)

$ sudo perf record -g --call-graph dwarf ./fib
$ sudo perf report --stdio -g graph,0.5,caller

# Children      Self  Command  Shared Object      Symbol
# ........  ........  .......  .................  .................................
#
    84.92%     1.89%  fib      fib                [.] bn_fib_fdoubling
            |
            |--83.03%--bn_fib_fdoubling
            |          |
            |          |--48.43%--bn_mult
            |          |          |
            |          |          |--20.74%--bn_alloc
            |          |          |          |
            |          |          |          |--14.45%--__libc_calloc
            |          |          |          |
            |          |          |           --4.93%--__GI___libc_malloc (inlined)
            |          |          |
            |          |          |--13.00%--bn_mult_add (inlined)
            |          |          |
            |          |          |--3.43%--bn_msb (inlined)
            |          |          |
            |          |           --1.17%--bn_swap (inlined)
            |          |
            |          |--16.18%--bn_free
            |          |          |
            |          |          |--14.70%--__GI___libc_free (inlined)
            |          |          |
            |          |           --0.81%--free@plt
            |          |
            |          |--6.31%--bn_cpy
            |          |          |
            |          |          |--3.67%--memcpy (inlined)
            |          |          |
            |          |           --1.61%--bn_resize
            |          |                     |
            |          |                      --0.99%--__GI___libc_realloc (inlined)
            |          |--4.69%--bn_add
            |          |          |
            |          |           --4.52%--bn_do_add (inlined)
            |          |                     |
            |          |                      --1.69%--bn_msb (inlined)
            |          |
            |          |--4.31%--bn_sub (inlined)
            |          |          |
            |          |           --4.25%--bn_add
            |          |                     |
            |          |                     |--2.35%--bn_do_sub
            |          |                     |
            |          |                      --0.58%--bn_cmp
            |          |--1.93%--bn_lshift
            |          |          |
            |          |           --0.84%--bn_clz (inlined)
            |          |
            |           --0.55%--bn_alloc
            |
             --1.07%--_start
                       __libc_start_main
                       main
                       bn_fib_fdoubling
  • 有 84.92% 的時間 (準確來說是樣本數) 落在 bn_fib_fdoubling 內,其中有 83.03% 的時間會再呼叫其他函式
  • bn_mult 佔整體 48.43% 的時間,因此優化乘法會帶來明顯的效能增益
  • bn_fib_fdoubling 內有接近一半的時間在管理動態記憶體與複製資料,顯然需要相關的策略來降低這部分的成本
  • bn_addbn_sub 共佔 9% 的時間,需要再單獨使用 iterative 版本的 bn_fib 來進行分析與優化,否則很難在 bn_fib_fdoubling 內觀察到效能增益
  • bn_free 占有高比例的原因不明,目前先猜測可能是因為 bn_mult 過度頻繁的呼叫 bn_allocbn_free

改善方案 1: 改寫 bn_fib_fdoubling

原本的實作局限於使用 bn_cpy 來更新暫存變數 k1k2 的數值,其實可以藉由 bn_swap 以及改變各函式儲存結果的位置來達成一樣的目的,將所有的 bn_cpy 去除來降低複製資料造成的成本

當資料來源與目的重疊時 (c == a || c == b),bn_mult 必須先配置暫存的記憶體空間來儲存計算結果,因此可以進一步確保呼叫 bn_mult 時不要發生此狀況,降低使用 mallocmemcpy 的次數。

void bn_fib_fdoubling(bn *dest, unsigned int n)
{
    ...
    for (unsigned int i = 1U << (31 - __builtin_clz(n)); i; i >>= 1) {
        /* F(2k) = F(k) * [ 2 * F(k+1) – F(k) ] */
        /* F(2k+1) = F(k)^2 + F(k+1)^2 */
        bn_lshift(f2, 1, k1);// k1 = 2 * F(k+1)
        bn_sub(k1, f1, k1);  // k1 = 2 * F(k+1) – F(k)
        bn_mult(k1, f1, k2); // k2 = k1 * f1 = F(2k)
        bn_mult(f1, f1, k1); // k1 = F(k)^2
        bn_swap(f1, k2);     // f1 <-> k2, f1 = F(2k) now
        bn_mult(f2, f2, k2); // k2 = F(k+1)^2
        bn_add(k1, k2, f2);  // f2 = f1^2 + f2^2 = F(2k+1) now
        if (n & i) {
            bn_swap(f1, f2);    // f1 = F(2k+1)
            bn_add(f1, f2, f2); // f2 = F(2k+2)
        }
    }
    ...
}

結果如下 (v1 綠線)

    24,770,616,442      cycles                                                        ( +-  0.05% )  (66.63%)
    71,462,180,892      instructions              #    2.88  insn per cycle           ( +-  0.00% )  (83.32%)
             8,406      cache-misses              #    1.048 % of all cache refs      ( +-  4.19% )  (83.33%)
           802,258      cache-references                                              ( +-  9.39% )  (83.34%)
    12,105,857,981      branch-instructions                                           ( +-  0.00% )  (83.36%)
        39,389,038      branch-misses             #    0.33% of all branches          ( +-  1.16% )  (83.33%)

           7.31640 +- 0.00362 seconds time elapsed  ( +-  0.05% )
  • 效能大幅度改善,時間從 18.73770s 降到 7.31640s
  • 複製資料的成本真的很大,不難想像為何會有 COW 等策略來降低成本

改善方案 2: 運用 Q-Matrix

觀察 sysprog21/bignum 中費氏數列的實作函式 fibonacci,會發現雖看似採用 fast doubling,但實際是 Q-matrix 這樣的變形,推導如下:

[F(2n1)F(2n)]=[0111]2n[F(0)F(1)]=[F(n1)F(n)F(n)F(n+1)][F(n1)F(n)F(n)F(n+1)][10]=[F(n)2+F(n1)2F(n)F(n)+2F(n)F(n1)]

整理後可得

F(2k1)=F(k)2+F(k1)2F(2k)=F(k)[2F(k1)+F(k)]

使用上述公式改寫 bn_fib_fdoubling,搭配使用 clz ,後者可讓 n 值越小的時候,減少越多次迴圈運算,從而達成加速。

void bn_fib_fdoubling(bn *dest, unsigned int n)
{
    bn_resize(dest, 1);
    if (n < 2) {  // Fib(0) = 0, Fib(1) = 1
        dest->number[0] = n;
        return;
    }

    bn *f1 = bn_alloc(1); // f1 = F(k-1)
    bn *f2 = dest;        // f2 = F(k) = dest
    f1->number[0] = 0;
    f2->number[0] = 1;
    bn *k1 = bn_alloc(1);
    bn *k2 = bn_alloc(1);

    for (unsigned int i = 1U << (30 - __builtin_clz(n)); i; i >>= 1) {
        /* F(2k-1) = F(k)^2 + F(k-1)^2 */
        /* F(2k) = F(k) * [ 2 * F(k-1) + F(k) ] */
        bn_lshift(f1, 1, k1);// k1 = 2 * F(k-1)
        bn_add(k1, f2, k1);  // k1 = 2 * F(k-1) + F(k)
        bn_mult(k1, f2, k2); // k2 = k1 * f2 = F(2k)
        bn_mult(f2, f2, k1); // k1 = F(k)^2
        bn_swap(f2, k2);     // f2 <-> k2, f2 = F(2k) now
        bn_mult(f1, f1, k2); // k2 = F(k-1)^2
        bn_add(k2, k1, f1);  // f1 = k1 + k2 = F(2k-1) now
        if (n & i) {
            bn_swap(f1, f2);    // f1 = F(2k+1)
            bn_add(f1, f2, f2); // f2 = F(2k+2)
        }
    }
    bn_free(f1);
    bn_free(k1);
    bn_free(k2);
}

結果如下 (v2 紅線)

    23,928,237,220      cycles                                                        ( +-  0.06% )  (66.64%)
    69,570,862,420      instructions              #    2.91  insn per cycle           ( +-  0.00% )  (83.33%)
             8,401      cache-misses              #    1.001 % of all cache refs      ( +-  5.17% )  (83.33%)
           839,163      cache-references                                              ( +-  9.65% )  (83.33%)
    11,641,338,644      branch-instructions                                           ( +-  0.00% )  (83.35%)
        41,101,058      branch-misses             #    0.35% of all branches          ( +-  1.42% )  (83.34%)

           7.06808 +- 0.00453 seconds time elapsed  ( +-  0.06% )
  • 時間從 7.31640 s 降低至 7.06808 s,小幅度減少約 3% 時間

改善方案 3: 引入 memory pool

原本實作的大數運算會在計算前先使用 bn_resize (realloc),確保有足夠大的空間來儲存計算結果,再於計算結束後檢查是否有多餘的空間 (msb 所在的 array 數值為 0) 並進行修剪 (trim),避免造成 memory leak 與增加後續計算的成本 (因為要存取的空間會越來越長),然而頻繁使用 realloc 可能會造成降低效能。

引入 memory pool,以 capacity 的方式管理 bn 實際可用的記憶體大小,降低 bn_resize 實際呼叫 realloc 的次數

typedef struct _bn {
    unsigned int *number;  /* ptr to number */
    unsigned int size;     /* length of number */
+   unsigned int capacity; /* total allocated length, size <= capacity */
    int sign;
} bn;
#define INIT_ALLOC_SIZE 4
#define ALLOC_CHUNK_SIZE 4

bn *bn_alloc(size_t size)
{
    bn *new = (bn *) malloc(sizeof(bn));
    new->size = size;
    new->capacity = size > INIT_ALLOC_SIZE ? size : INIT_ALLOC_SIZE;
    new->number = (unsigned int *) malloc(sizeof(int) * new->capacity);
    for (int i = 0; i <size; i++)
        new->number[i] = 0;
    new->sign = 0;
    return new;
}

static int bn_resize(bn *src, size_t size)
{
    ...
    if (size > src->capacity) { /* need to allocate larger capacity */
        src->capacity = (size + (ALLOC_CHUNK_SIZE - 1)) & ~(ALLOC_CHUNK_SIZE - 1); // ceil to 4*n
        src->number = realloc(src->number, sizeof(int) * src->capacity);
    }
    if (size > src->size) { /* memset(src, 0, size) */
        for (int i = src->size; i < size; i++)
            src->number[i] = 0;
    }
    src->size = size;
}
  • 只有當 size 超過 capacity 時才會 realloc,並以 4 為單位配置更大的空間
  • 所有計算仍以 size 作為計算的範圍,不會因為有多餘的空間而增加運算成本
  • trim 時只需要縮小 size,不需要實際 realloc 來縮小空間

結果如下 (v3 紅線)

    19,765,435,588      cycles                                                        ( +-  0.06% )  (66.64%)
    61,180,908,879      instructions              #    3.10  insn per cycle           ( +-  0.00% )  (83.33%)
             4,849      cache-misses              #    7.935 % of all cache refs      ( +-  5.97% )  (83.34%)
            61,110      cache-references                                              ( +-  5.79% )  (83.35%)
    10,612,740,290      branch-instructions                                           ( +-  0.00% )  (83.36%)
        32,583,167      branch-misses             #    0.31% of all branches          ( +-  1.54% )  (83.32%)

           5.83800 +- 0.00350 seconds time elapsed  ( +-  0.06% )
  • 時間從 7.06808 s 減少至 5.83800 s,減少約 17% 時間
  • cache-references 從 839,163 大幅度降低至 61,110,顯示原本頻繁呼叫 realloc 造成的成本非常可觀

改善方案 4: 善用 64 位元微處理器特性

bn 結構體中原本每個陣列的資料型態使用 unsigned int,在 64 位元微處理器下改為使用 uint64_t,搭配合適的 alignment,可確保每次記憶體存取都是 word size 寬度,從而提高資料存取的速度。

#include <stdint.h>

#if defined(__LP64__) || defined(__x86_64__) || defined(__amd64__) || defined(__aarch64__)
#define BN_WSIZE 8
#else
#define BN_WSIZE 4
#endif

#if BN_WSIZE == 8
typedef uint64_t bn_data;
typedef unsigned __int128 bn_data_tmp; // gcc support __int128
#elif BN_WSIZE == 4
typedef uint32_t bn_data;
typedef uint64_t bn_data_tmp;
#else
#error "BN_WSIZE must be 4 or 8"
#endif

typedef struct _bn {
    bn_data *number;  /* ptr to number */
    bn_data size;     /* length of number */
    bn_data capacity; /* total allocated length, size <= capacity */
    int sign;
} bn;
  • 使用 bignum/apm.h 中的方式來定義 bn 的資料型態,以便於根據不同的 word size 切換定義
  • 乘法運算時會用到 2 倍大小的的暫存變數,直接使用 gcc 提供的 __int128 實作

結果如下 (v4)

    12,669,256,697      cycles                                                        ( +-  0.07% )  (66.64%)
    38,320,121,559      instructions              #    3.02  insn per cycle           ( +-  0.00% )  (83.32%)
             5,867      cache-misses              #   11.048 % of all cache refs      ( +- 14.55% )  (83.32%)
            53,104      cache-references                                              ( +- 12.56% )  (83.33%)
     5,274,117,456      branch-instructions                                           ( +-  0.00% )  (83.35%)
         2,174,668      branch-misses             #    0.04% of all branches          ( +-  0.28% )  (83.36%)

           3.74384 +- 0.00279 seconds time elapsed  ( +-  0.07% )
  • 時間從 5.83800 s 減少至 3.74384 s,減少約 36% 時間
  • instructions 的數量降低約 37%,顯示使用 uint64_t 更能發揮 64 位元 CPU 的優勢

改善方案 5: 改寫 bn_addbn_mult

改善 bn_add 的效能

為了凸顯 bn_add 對效能的影響,這個章節改為量測 bn_fib (iterative) 作為判斷依據,並將量測的範圍提高到 F(10000)。由於上述幾個改善策略也會提升 bn_add 的效能,因此先重新量測現有的效能,結果如下 (v1 紅線)

    87,130,307,524      cycles                                                        ( +-  0.01% )  (66.66%)
   262,062,098,878      instructions              #    3.01  insn per cycle           ( +-  0.00% )  (83.33%)
            11,863      cache-misses              #    2.193 % of all cache refs      ( +-  4.13% )  (83.33%)
           540,853      cache-references                                              ( +- 10.00% )  (83.33%)
    33,988,594,050      branch-instructions                                           ( +-  0.00% )  (83.34%)
       243,724,292      branch-misses             #    0.72% of all branches          ( +-  0.00% )  (83.33%)

          25.73128 +- 0.00279 seconds time elapsed  ( +-  0.01% )

原本的實作會在每次迴圈判斷需要相加的數值,這麼做的優點是只需寫一個迴圈就能完成計算,但缺點是每次迴圈都有兩個 branch 要判斷。為了改善這點,改為使用兩個迴圈進行計算,第一個迴圈先計算兩者皆有資料的範圍,再於第二個迴圈處理 carry 與剩餘的資料範圍。另外,藉由無號整數不會 overflow 的特性 (modulo),可以進一步避免使用 __int128 (bn_data_tmp) 進行計算

/* |c| = |a| + |b| */
static void bn_do_add(const bn *a, const bn *b, bn *c)
{
    ...
-   bn_data_tmp carry = 0;
-   for (int i = 0; i < c->size; i++) {
-       bn_data tmp1 = (i < asize) ? a->number[i] : 0;
-       bn_data tmp2 = (i < bsize) ? b->number[i] : 0;
-       carry += (bn_data_tmp) tmp1 + tmp2;
-       c->number[i] = carry;
-       carry >>= DATA_BITS;
-   }
    
+   bn_data carry = 0;
+   for (int i = 0; i < bsize; i++) {
+       bn_data tmp1 = a->number[i];
+       bn_data tmp2 = b->number[i];
+       carry = (tmp1 += carry) < carry;
+       carry += (c->number[i] = tmp1 + tmp2) < tmp2;
+   }
+   if (asize != bsize) {  // deal with the remaining part if asize > bsize
+       for (int i = bsize; i < asize; i++) {
+           bn_data tmp1 = a->number[i];
+           carry = (tmp1 += carry) < carry;
+           c->number[i] = tmp1;
+       }
+   }

    if (carry) {
        c->number[asize] = carry;
        ++(c->size);
    }
}

 Performance counter stats for './fib' (10 runs):

    42,111,360,506      cycles                                                        ( +-  0.41% )  (66.66%)
   125,087,664,564      instructions              #    2.97  insn per cycle           ( +-  0.00% )  (83.33%)
             9,037      cache-misses              #    5.927 % of all cache refs      ( +-  7.14% )  (83.33%)
           152,468      cache-references                                              ( +-  8.29% )  (83.34%)
    12,833,863,666      branch-instructions                                           ( +-  0.00% )  (83.34%)
       147,335,826      branch-misses             #    1.15% of all branches          ( +-  0.02% )  (83.34%)

           12.4361 +- 0.0512 seconds time elapsed  ( +-  0.41% )
  • branch-instructions 減少約 63%,branch-misses 也減少約 40%
  • cache-references 減少約 72%,顯示我本來的實作法有多餘的執行步驟,使 CPU 不斷重複讀取某些數值
  • 時間從 25.73128 s 減少至 12.4361 s,減少約 52% 時間

改善 bn_mult 的效能

改回量測 bn_fib_fdoubling 作為判斷依據,並接續上述 fast doubling v4 版本,將測試範圍提高至 F(10000),會發現 bn_mult 的效能明顯低於對照組

     7,208,970,350      cycles                                                        ( +-  0.11% )  (66.54%)
    15,804,723,358      instructions              #    2.19  insn per cycle           ( +-  0.00% )  (83.27%)
             3,826      cache-misses              #    9.269 % of all cache refs      ( +-  4.17% )  (83.29%)
            41,280      cache-references                                              ( +-  8.72% )  (83.37%)
     1,667,790,605      branch-instructions                                           ( +-  0.01% )  (83.44%)
        58,185,471      branch-misses             #    3.49% of all branches          ( +-  0.06% )  (83.36%)

           2.13072 +- 0.00229 seconds time elapsed  ( +-  0.11% )

原本實作 bn_mult 的方式為依序將兩格陣列相乘,再將結果直接疊加至輸出的變數,然而這會導致每行乘法被拆分成 2 個步驟 (相乘後先將 carry 疊加至下個 array,下次迴圈又再從該 array 取出數值來進行乘法),降低計算的速度。接下來參考 bignum/mul.c 來改寫 bn_mult,改為一次將乘積與 carry 疊加至輸出的變數來提升效能

/* c[size] += a[size] * k, and return the carry */
static bn_data _mult_partial(const bn_data *a, bn_data asize, const bn_data k, bn_data *c)
{
    if (k == 0)
        return 0;
    
    bn_data carry = 0;
    for (int i = 0; i < asize; i++) {
        bn_data high, low;
        bn_data_tmp prod = (bn_data_tmp) a[i] * k;
        low = prod;
        high = prod >> DATA_BITS;
        carry = high + ((low += carry) < carry);
        carry += ((c[i] += low) < low);
    }
    return carry;
}

void bn_mult(const bn *a, const bn *b, bn *c)
{
    ...
    bn_data *cp = c->number + a->size;
    for (int j = 0; j < b->size; j++) {
        c->number[a->size + j] =
            _mult_partial(a->number, a->size, b->number[j], c->number + j);
    }
    ...
}

     2,288,892,189      cycles                                                        ( +-  0.05% )  (66.40%)
     7,563,269,285      instructions              #    3.30  insn per cycle           ( +-  0.00% )  (83.35%)
             3,584      cache-misses              #   13.507 % of all cache refs      ( +- 29.45% )  (83.48%)
            26,534      cache-references                                              ( +-  9.79% )  (83.48%)
       602,658,857      branch-instructions                                           ( +-  0.02% )  (83.48%)
         3,857,937      branch-misses             #    0.64% of all branches          ( +-  0.15% )  (83.15%)

          0.678312 +- 0.000674 seconds time elapsed  ( +-  0.10% )
  • 時間從 2.13072 s 減少至 0.678312 s,減少約 68% 時間

改善方案 6: 內嵌組合語言

sysprog21/bignum 中使用內嵌組合語言 (inline assembly) 來直接取得乘法運算的高位與低位,直接使用一樣的方式實作乘法,取代原本使用的 __int128 (bn_data_tmp)

static bn_data _mult_partial(const bn_data *a, bn_data asize, const bn_data k, bn_data *c)
{
    if (k == 0)
        return 0;
    
    bn_data carry = 0;
    for (int i = 0; i < asize; i++) {
        bn_data high, low;
-       bn_data_tmp prod = (bn_data_tmp) a[i] * k;
-       low = prod;
-       high = prod >> DATA_BITS;
+       __asm__("mulq %3" : "=a"(low), "=d"(high) : "%0"(a[i]), "rm"(k));
        carry = high + ((low += carry) < carry);
        carry += ((c[i] += low) < low);
    }
    return carry;
}

     1,412,000,613      cycles                                                        ( +-  0.07% )  (65.71%)
     3,782,233,502      instructions              #    2.68  insn per cycle           ( +-  0.02% )  (82.91%)
             1,816      cache-misses              #    9.135 % of all cache refs      ( +- 17.30% )  (83.56%)
            19,878      cache-references                                              ( +-  1.72% )  (83.76%)
       357,455,000      branch-instructions                                           ( +-  0.02% )  (83.76%)
         3,862,706      branch-misses             #    1.08% of all branches          ( +-  0.15% )  (83.21%)

          0.418849 +- 0.000460 seconds time elapsed  ( +-  0.11% )
  • 時間從 0.678312 s 減少至 0.418849 s,減少約 38% 時間
  • 使用內嵌組合語言,之所以會得到比 __int128 好的效能表現,主因是沒辦法藉由使用 __int128 直接把乘積的高位與低位儲存至指定的空間

改善方案 7: 引入 bn_sqr

         a   b   c
      x  a   b   c
 -------------------
        ac  bc  cc
    ab  bb  bc
aa  ab  ac

考慮上述

abc2 的計算過程,會發現數值
ab
ac
bc
各會重複一次,因此可先計算對角線其中一邊的數值,將數值的總和直接乘二,最終再加上對角線上的
aa
bb
cc
。藉由這種方式,平方運算的成本可由本來的
n2
次乘法降為
(n2n)/2
次乘法

void do_sqr_base(const bn_data *a, bn_data size, bn_data *c)
{
    bn_data *cp = c + 1;
    const bn_data *ap = a;
    bn_data asize = size - 1;
    for (int i = 0; i < asize; i++) {
        /* calc the (ab bc bc) part */
        cp[asize - i] = _mult_partial(&ap[i + 1], asize - i, ap[i], cp);
        cp += 2;
    }

    /* Double it */
    for (int i = 2 * size - 1; i > 0; i--)
        c[i] = c[i] << 1 | c[i - 1] >> (DATA_BITS - 1);
    c[0] <<= 1;

    /* add the (aa bb cc) part at diagonal line */
    cp = c;
    ap = a;
    asize = size;
    bn_data carry = 0;
    for (int i = 0; i < asize; i++) {
        bn_data high, low;
        __asm__("mulq %3" : "=a"(low), "=d"(high) : "%0"(ap[i]), "rm"(ap[i]));
        high += (low += carry) < carry;
        high += (cp[0] += low) < low;
        carry = (cp[1] += high) < high;
        cp += 2;
    }
}

結果如下 (v7 藍線)

     1,057,685,945      cycles                                                        ( +-  0.14% )  (66.56%)
     2,744,641,149      instructions              #    2.59  insn per cycle           ( +-  0.02% )  (83.39%)
             1,304      cache-misses              #    6.200 % of all cache refs      ( +- 19.30% )  (83.46%)
            21,032      cache-references                                              ( +-  3.28% )  (83.46%)
       292,210,120      branch-instructions                                           ( +-  0.05% )  (83.46%)
         3,400,028      branch-misses             #    1.16% of all branches          ( +-  1.65% )  (83.05%)

          0.314624 +- 0.000825 seconds time elapsed  ( +-  0.26% )
  • 時間從 0.418849 s 減少至 0.314624 s,減少約 25% 時間
  • 資料長度越長,節省的時間越明顯

改善方案 8: 實作 Karatsuba algorithm

雖然上述 v7 版本所花的時間已略低於參考組,但若將量測範圍逐漸提高,會發現效能仍不及參考組,至

F(100000) 時差距約有 1 倍,觀察 sysprog21/bignum 的原始碼會發現使用 Karatsuba algorithm 來加速乘法與平方運算,因此接下來一樣實作該演算法來提升效能

Karatsuba algorithm 的核心概念是將 a 與 b 拆分為高位與低位再進行計算,考量計算

a×b,且 a 與 b 的位數皆為
N=2n
位 (2 進位下的位數,不過 10 進位時邏輯相同),可將 a 與 b 表示如下

a=a0+a1×2n
b=b0+b1×2n

因此

a×b 可進一步整理為

(2n+22n)(a1b1)+2n(a1a0)(b0b1)+(2n+1)(a0b0)

由於

2n 可藉由 bit shift 達成,因此實際使用乘法的部分只剩 3 項,遠少於直接使用乘法的
N2
項,可大幅度降低乘法運算的成本

將 Karatsuba multiplication 應用至 bn_multbn_sqr 後,效能如下 (v8 藍線)

圖中設定的閾值與參考組一樣,但縮小閾值,在數值長度較小時,不會顯著提升效能