
2022q1 Homework3 (fibdrv)

contributed by < eric88525 >

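Assignment description

Self-check list

  • Study the description in the Linux performance analysis hints above, run GNU/Linux on your own physical machine, and complete the necessary settings and preparation \(\to\) in doing so, understand why running the experiments in a virtual machine is undesirable;
  • Study the Fibonacci-related materials above (including the papers), summarize the key techniques, and consider how instructions such as clz / ctz help Fibonacci computation. List the key code and explain it.
  • Review C numeric systems and bitwise operations, and consider how to reduce the cost of multiplication in an implementation of the fast Fibonacci algorithm;
  • Study KYG-yaya573142's report and point out which measures speed up big-number arithmetic and reduce the cost of memory operations.
  • The output of lsmod has a column named Used by, which is "each module's use count and a list of referring modules". How is it implemented? How does the Linux kernel track dependencies between modules and the actual use count (reference counting)?

    Read alongside The Linux driver implementer’s API guide » Driver Basics

  • fibdrv.c contains DEFINE_MUTEX, mutex_trylock, mutex_init, mutex_unlock, and mutex_destroy. In what scenarios are they needed? Write a multi-threaded userspace program to test this, and observe what goes wrong if the Linux kernel module does not use a mutex. Try writing code that uses POSIX Threads to confirm.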

Topic 1

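Study the description in the Linux performance analysis hints above, run GNU/Linux on your own physical machine, and complete the necessary settings and preparation

Hardware configuration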

Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          16
On-line CPU(s) list:             0-15
Thread(s) per core:              2
Core(s) per socket:              8
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      25
Model:                           33
Model name:                      AMD Ryzen 7 5800X 8-Core Processor
Environment setup
// check the number of CPUs
$ taskset -cp 1
pid 1's current affinity list: 0-15

// isolate cpu 15
$ sudo vim /etc/default/grub
GRUB_CMDLINE_LINUX_DEFAULT="quiet splash isolcpus=15"

// update
$ sudo update-grub

// after reboot...
$ taskset -cp 1
pid 1's current affinity list: 0-14

// disable address space layout randomization
$ sudo sh -c "echo 0 > /proc/sys/kernel/randomize_va_space"

Set scaling_governor to performance (in performance.sh)

for i in /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
do
    echo performance > ${i}
done

smp_affinity setup: on this 16-thread CPU, keep IRQs away from cpu 15 (mask 7fff selects CPUs 0-14)

#!/bin/bash
for file in `find /proc/irq -name "smp_affinity"`
do
    sudo bash -c "echo 7fff > ${file}"
done
sudo bash -c "echo 7fff > /proc/irq/default_smp_affinity"
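The value 7fff written above is the bitmask with bits 0 to 14 set and bit 15 cleared. A minimal sketch of how such a mask could be derived follows; the helper name `irq_mask_excluding` is hypothetical and not part of the original scripts.

```c
#include <stdint.h>

/* Hypothetical helper: bitmask of CPUs 0..ncpus-1 with one CPU cleared.
 * Assumes ncpus <= 32 and excluded < ncpus. */
static uint32_t irq_mask_excluding(unsigned ncpus, unsigned excluded)
{
    /* Shifting a 32-bit value by 32 is undefined, so handle that case. */
    uint32_t all = (ncpus >= 32) ? UINT32_MAX : ((1u << ncpus) - 1u);
    return all & ~(1u << excluded);
}
```

For the machine described here, `irq_mask_excluding(16, 15)` gives 0x7fff, the value echoed into each smp_affinity file.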

Topic 2

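Study the Fibonacci-related materials above (including the papers), summarize the key techniques, and consider how instructions such as clz / ctz help Fibonacci computation. List the key code and explain it.

Fast doubling

The derivation shows that from \(F(k)\) and \(F(k+1)\) we can obtain \(F(2k)\), \(F(2k+1)\), and \(F(2k+2)\):
\[ \begin{split} F(2k) &= F(k)[2F(k+1) - F(k)] \\ F(2k+1) &= F(k+1)^2+F(k)^2 \end{split} \]

By repeatedly halving the target index N, we learn how to get from \(\begin{pmatrix}F_0 \\ F_1\end{pmatrix}\) to \(\begin{pmatrix}F_N \\ F_{N+1}\end{pmatrix}\).

Writing N in binary, the most significant bit gives the parity of the index in the earliest state, and the least significant bit gives the parity of the index in the final state. For example, when computing \(F(10)\), whether 10 is odd or even is decided by the least significant bit, while the parity at the \(F(1)\) stage is decided by the MSB.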
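As a worked example of the bit-by-bit process (a sketch that applies only the two identities above), take \(N = 10 = 1010_2\) and read the bits from the MSB down, starting from \((F_0, F_1) = (0, 1)\):

\[ \begin{split} \text{bit } 1 &: (F_0, F_1) \to (F_1, F_2) = (1, 1) \\ \text{bit } 0 &: (F_1, F_2) \to (F_2, F_3) = (1, 2) \\ \text{bit } 1 &: (F_2, F_3) \to (F_5, F_6) = (5, 8) \\ \text{bit } 0 &: (F_5, F_6) \to (F_{10}, F_{11}) = (55, 89) \end{split} \]

A 1 bit moves from index \(k\) to \(2k+1\) and a 0 bit moves from \(k\) to \(2k\), so four steps are enough to reach \(F(10) = 55\).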

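  • Version without clz, based on the fast doubling article
    1. Find the bit position of the most significant bit (MSB)
    2. Start mask at the MSB and shift it right on every iteration, checking bit by bit whether the next state's index is odd or even, and update a and b accordingly
      • Odd: compute \(F(2n+1)\) and \(F(2n+2)\)
      • Even: compute \(F(2n)\) and \(F(2n+1)\)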
static long long fib_fastdoubling(long long n)
{
    if (n == 0)
        return 0;
    else if (n <= 2)
        return 1;
    // The position of the highest bit of n.
    // So we need to loop `h` times to get the answer.
    // Example: n = (Dec)50 = (Bin)00110010, then h = 6.
    //                               ^ 6th bit from right side
    unsigned int h = 0;
    for (unsigned int i = n; i; ++h, i >>= 1)
        ;

    uint64_t a = 0;  // F(0) = 0
    uint64_t b = 1;  // F(1) = 1
    // There is only one `1` in the bits of `mask`. The `1`'s position is same
    // as the highest bit of n(mask = 2^(h-1) at first), and it will be shifted
    // right iteratively to do `AND` operation with `n` to check `n_j` is odd or
    // even, where n_j is defined below.
    for (unsigned int mask = 1 << (h - 1); mask; mask >>= 1) {  // Run h times!
        // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) = n
        // >> j (n_j = n when j = 0), k = floor(n_j / 2), then a = F(k), b =
        // F(k+1) now.
        uint64_t c = a * (2 * b - a);  // F(2k) = F(k) * [ 2 * F(k+1) – F(k) ]
        uint64_t d = a * a + b * b;    // F(2k+1) = F(k)^2 + F(k+1)^2

        if (mask & n) {  // n_j is odd: k = (n_j-1)/2 => n_j = 2k + 1
            a = d;       //   F(n_j) = F(2k + 1)
            b = c + d;   //   F(n_j + 1) = F(2k + 2) = F(2k) + F(2k + 1)
        } else {         // n_j is even: k = n_j/2 => n_j = 2k
            a = c;       //   F(n_j) = F(2k)
            b = d;       //   F(n_j + 1) = F(2k + 1)
        }
    }

    return a;
}

  • After adding __builtin_clz (count leading zeros):
    • The number of leading zeros is known directly, so the MSB position is found without a loop
// fast doubling, using __builtin_clz to locate the MSB
static long long fib_clz_fastdoubling(long long k)
{
    if (k == 0)
        return 0;
    else if (k <= 2)
        return 1;

    uint64_t a = 0;  // F(0) = 0
    uint64_t b = 1;  // F(1) = 1
    // `mask` has exactly one bit set. It starts at the highest set bit of k,
    // found with __builtin_clz (which takes a 32-bit unsigned int, so k must
    // fit in 32 bits), and is shifted right on every iteration to test the
    // next bit of k. Starting at 32 - clz would begin one bit too high.
    for (unsigned int mask = 1U << (31 - __builtin_clz(k)); mask;
         mask >>= 1) {  // Runs once per significant bit of k
        // Let j = h-i (looping from i = 1 to i = h), n_j = floor(n / 2^j) =
        // n >> j (n_j = n when j = 0), k = floor(n_j / 2), then a = F(k),
        // b = F(k+1) now.
        uint64_t c = a * (2 * b - a);  // F(2k) = F(k) * [ 2 * F(k+1) – F(k) ]
        uint64_t d = a * a + b * b;    // F(2k+1) = F(k)^2 + F(k+1)^2

        if (mask & k) {  // n_j is odd: k = (n_j-1)/2 => n_j = 2k + 1
            a = d;       //   F(n_j) = F(2k + 1)
            b = c + d;   //   F(n_j + 1) = F(2k + 2) = F(2k) + F(2k + 1)
        } else {         // n_j is even: k = n_j/2 => n_j = 2k
            a = c;       //   F(n_j) = F(2k)
            b = d;       //   F(n_j + 1) = F(2k + 1)
        }
    }
    return a;
}
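The relationship between __builtin_clz and the MSB position can be checked in userspace. The following is a minimal sketch, assuming GCC or Clang (for the builtin) and a 32-bit argument; the names `msb_index` and `fib_fast` are illustrative and do not appear in fibdrv.

```c
#include <stdint.h>

/* Index (0-based, from the right) of the highest set bit of x.
 * x must be non-zero: __builtin_clz(0) is undefined. */
static unsigned msb_index(uint32_t x)
{
    return 31u - (unsigned) __builtin_clz(x);
}

/* Fast doubling driven by the bits of n, from the MSB down.
 * The result fits in uint64_t for n <= 92. */
static uint64_t fib_fast(uint32_t n)
{
    if (n == 0)
        return 0;

    uint64_t a = 0; /* F(k)   */
    uint64_t b = 1; /* F(k+1) */
    for (uint32_t mask = 1u << msb_index(n); mask; mask >>= 1) {
        uint64_t c = a * (2 * b - a); /* F(2k)   */
        uint64_t d = a * a + b * b;   /* F(2k+1) */
        if (n & mask) { /* next index is 2k+1 */
            a = d;
            b = c + d;
        } else {        /* next index is 2k */
            a = c;
            b = d;
        }
    }
    return a;
}
```

For n = 50 (binary 110010), `msb_index` returns 5, so the loop runs six times, once per significant bit.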
  • Iterative version: an extra variable temp holds the result of \(F(N+2) = F(N) + F(N+1)\)
static long long fib_sequence(long long k)
{

    if (k == 0)
        return 0;
    else if (k <= 2)
        return 1;

    unsigned long long a, b, temp;
    a=0;
    b=1;

    for (int i = 2; i <= k; i++) {
        temp = a + b;
        a = b;
        b = temp;
    }

    return b;
}

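Performance comparison

Default iterative version vs fast doubling vs fast doubling with clz

  • Each \(F(N)\) is sampled 50 times; outliers are removed before averaging

  • Ignoring overflow for now and measuring \(F(N)\) up to N = 500 makes the difference in complexity easier to see

How time is measured

  • Client side: argv selects which mode to use, and the client writes to the fib device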
int main(int argc, char *argv[])
{
    /* MODE
     * 0: normal
     * 1: fast doubling
     * 2: clz fast doubling
     * 3: bn
     * 4: bn + fast doubling
     */
    int mode = 0;
    if (argc == 2) {
        mode = *argv[1] - '0';
    }

    long long sz;
    char buf[BUFF_SIZE];
    char write_buf[] = "testing writing";
    int offset = 93;

    int fd = open(FIB_DEV, O_RDWR);
    if (fd < 0) {
        perror("Failed to open character device");
        exit(1);
    }
    for (int i = 0; i <= offset; i++) {
        lseek(fd, i, SEEK_SET);
        sz = write(fd, buf, mode);
        printf("%lld ", sz);
    }
    close(fd);
    return 0;
}
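  • Behavior of fib_write:
    • The mode parameter decides which computation to run
    • A macro expands the repeated timing code
      • ktime_get() returns the current time
      • ktime_sub(a, b) returns a-b
      • ktime_to_ns() converts the result to ns
    • As a referenced report points out, the compiler may optimize away values that are never used, so __asm__ volatile is added to force the compiler to keep the original code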
#define TIME_PROXY(fib_f, result, k, timer)             \
    ({                                                  \
        timer = ktime_get();                            \
        result = fib_f(k);                              \
        timer = (size_t) ktime_sub(ktime_get(), timer); \
    });

#define BN_TIME_PROXY(fib_f, result, k, timer)          \
    ({                                                  \
        timer = ktime_get();                            \
        fib_f(result, *offset);                         \
        timer = (size_t) ktime_sub(ktime_get(), timer); \
    });

static void escape(void *p)
{
    __asm__ volatile("" : : "g"(p) : "memory");
}

static ssize_t fib_write(struct file *file,
                         const char *buf,
                         size_t mode,
                         loff_t *offset)
{
    long long result = 0;
    bignum *fib = bn_init(1);

    escape(fib);
    escape(&result);

    switch (mode) {
    case 0: /* normal */
        TIME_PROXY(fib_sequence, result, *offset, timer)
        break;
    case 1: /* fast doubling */
        TIME_PROXY(fib_fastdoubling, result, *offset, timer)
        break;
    case 2: /* clz + fast doubling */
        TIME_PROXY(fib_clz_fastdoubling, result, *offset, timer)
        break;
    case 3: /* big num normal */
        BN_TIME_PROXY(bn_fib_sequence, fib, *offset, timer)
        break;
    default:
        break;
    }

    bn_free(fib);
    return (ssize_t) ktime_to_ns(timer);
}
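The same measure-around-the-call pattern can be tried in userspace. This is only a sketch under stated assumptions: it uses clock_gettime(CLOCK_MONOTONIC) in place of the kernel's ktime_get(), and `timespec_diff_ns`, `time_call_ns`, and `fib_iter` are hypothetical names introduced for illustration.

```c
#define _POSIX_C_SOURCE 200809L
#include <stdint.h>
#include <time.h>

/* Nanoseconds elapsed from `start` to `end`. */
static int64_t timespec_diff_ns(struct timespec start, struct timespec end)
{
    return (int64_t) (end.tv_sec - start.tv_sec) * 1000000000LL +
           (end.tv_nsec - start.tv_nsec);
}

/* Same idea as escape() above: tell the compiler that *p may be observed,
 * so the computation that produced it cannot be dropped. */
static void escape(void *p)
{
    __asm__ volatile("" : : "g"(p) : "memory");
}

/* Simple function to be timed (iterative Fibonacci). */
static long long fib_iter(long long k)
{
    long long a = 0, b = 1;
    for (long long i = 0; i < k; i++) {
        long long t = a + b;
        a = b;
        b = t;
    }
    return a;
}

/* Time one call of f(k); the result is stored so the caller can check it. */
static int64_t time_call_ns(long long (*f)(long long),
                            long long k,
                            long long *result)
{
    struct timespec t0, t1;
    clock_gettime(CLOCK_MONOTONIC, &t0);
    *result = f(k);
    escape(result);
    clock_gettime(CLOCK_MONOTONIC, &t1);
    return timespec_diff_ns(t0, t1);
}
```

Userspace timing includes scheduling noise that the in-kernel ktime measurement avoids, so the numbers are not directly comparable with the ones reported here.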

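  • Automated test script:
    • outlier_filter follows the approach provided by the instructor and removes values more than two standard deviations from the mean
    • A subprocess writes the timing results to a file, which the Python script then reads. Measuring a different mode only requires passing a different argv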
import numpy as np
import subprocess
import matplotlib.pyplot as plt

np.seterr(divide='ignore', invalid='ignore')

# remove outer value
def outlier_filter(datas, threshold=2):
    datas = np.array(datas)

    z = np.abs((datas - datas.mean() + 1e-7) / (datas.std() + 1e-7))
    return datas[z < threshold]

# for every f(n), remove outer value
def data_processing(data):
    data = np.array(data)
    _, test_samples = data.shape
    result = np.zeros(test_samples)

    for i in range(test_samples):
        result[i] = outlier_filter(data[:, i]).mean()

    return result


def main():

    datas = []
    runtime = 50
    fib_modes = ["iteration", "fast_doubling",
                 "clz_fast_doubling", "bn_normal", "bn_fast_doubling"]

    # run program for runtime
    modes = 3
    # run test for every mode
    for mode in range(modes):
        temp = []
        # run test on each mode for runtime
        for _ in range(runtime):
            subprocess.run(
                f"sudo taskset -c 5 ./client_test {mode} > runtime.txt", shell=True)
            _data = np.loadtxt("runtime.txt", dtype=float)
            temp.append(_data)
        temp = data_processing(temp)
        datas.append(temp)

    # plot
    _, ax = plt.subplots()
    plt.grid()
    for i, data in enumerate(datas):
        ax.plot(np.arange(data.shape[-1]),
                data, marker='+',  markersize=3, label=fib_modes[i])
    plt.legend(loc='upper left')
    plt.savefig("runtime.png")

if __name__ == "__main__":
    main()
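The filtering step in outlier_filter can also be expressed in C. The sketch below is an assumption-laden illustration rather than a port of the script: it uses the population variance (as NumPy's std() does by default), compares squared deviations so that sqrt() is not needed, and omits the 1e-7 terms; `filtered_mean` is a hypothetical name.

```c
#include <stddef.h>

/* Mean of the samples that lie within `threshold` standard deviations of
 * the overall mean. Falls back to the plain mean if nothing is kept
 * (for example when all samples are equal). */
static double filtered_mean(const double *x, size_t n, double threshold)
{
    if (n == 0)
        return 0.0;

    double mean = 0.0;
    for (size_t i = 0; i < n; i++)
        mean += x[i];
    mean /= (double) n;

    double var = 0.0; /* population variance */
    for (size_t i = 0; i < n; i++)
        var += (x[i] - mean) * (x[i] - mean);
    var /= (double) n;

    /* |x - mean| / std < threshold  <=>  (x - mean)^2 < threshold^2 * var */
    const double limit = threshold * threshold * var;
    double sum = 0.0;
    size_t kept = 0;
    for (size_t i = 0; i < n; i++) {
        const double d = x[i] - mean;
        if (d * d < limit) {
            sum += x[i];
            kept++;
        }
    }
    return kept ? sum / (double) kept : mean;
}
```

With nine samples of 10 and one sample of 100, the mean is 19 and the standard deviation is 27, so 100 lies three standard deviations away and is dropped, leaving a filtered mean of 10.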

Topic 5

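The output of lsmod has a column named Used by, which is “each module’s use count and a list of referring modules”. How is it implemented? How does the Linux kernel track dependencies between modules and the actual use count (reference counting)?

In the Linux kernel, a reference count is kept in a variable of type atomic_t, and an API is provided to initialize, increment, and decrement it.

The counter has to be atomic because several CPUs may update the same reference count at the same time.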

// defined in <linux/kref.h>
struct kref {
    atomic_t refcount;
};
// Before using a kref, you must initialize it via kref_init()
void kref_init(struct kref *kref)
{
    atomic_set(&kref->refcount, 1);
}
// Incrementing the reference count is done via kref_get():
void kref_get(struct kref *kref)
{
    WARN_ON(!atomic_read(&kref->refcount));
    atomic_inc(&kref->refcount);
}

// Decrementing the reference count is done via kref_put():
int kref_put(struct kref *kref, void (*release)(struct kref *kref))
{
    WARN_ON(release == NULL);
    WARN_ON(release == (void (*)(struct kref *)) kfree);

    if ((atomic_read(&kref->refcount) == 1) ||
        (atomic_dec_and_test(&kref->refcount))) {
        release(kref);
        return 1;
    }
    return 0;
}
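The get/put pattern can be tried outside the kernel with C11 atomics. This is a userspace sketch only, assuming a compiler with <stdatomic.h>; `struct obj` and the `obj_*` functions are hypothetical stand-ins and are not the kernel's kref API.

```c
#include <stdatomic.h>
#include <stdbool.h>

struct obj {
    atomic_int refcount;
    bool released; /* set by the release callback, for illustration only */
};

/* A new object starts with one reference, held by its creator. */
static void obj_init(struct obj *o)
{
    atomic_init(&o->refcount, 1);
    o->released = false;
}

static void obj_get(struct obj *o)
{
    atomic_fetch_add(&o->refcount, 1);
}

/* Drops one reference. Returns true if this call dropped the last one,
 * in which case `release` has been invoked. */
static bool obj_put(struct obj *o, void (*release)(struct obj *))
{
    /* atomic_fetch_sub returns the value before the subtraction. */
    if (atomic_fetch_sub(&o->refcount, 1) == 1) {
        release(o);
        return true;
    }
    return false;
}

static void obj_release(struct obj *o)
{
    o->released = true;
}
```

Because the decrement and the test for the last reference happen in one atomic operation, two CPUs dropping references at the same time cannot both conclude that they hold the last one.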


Topic 6

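fibdrv.c contains DEFINE_MUTEX, mutex_trylock, mutex_init, mutex_unlock, and mutex_destroy. In what scenarios are they needed? Write a multi-threaded userspace program to test this, and observe what goes wrong if the Linux kernel module does not use a mutex. Try writing code that uses POSIX Threads to confirm.

To examine what happens when several threads interact with the fib device at the same time, I wrote a multi-threaded test program (adapted from the wiki example).

It creates three threads. A thread that fails to open the device waits 0.5 seconds and tries again; after each read from the device it waits 0.1 seconds before the next read.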

#include <assert.h>
#include <fcntl.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <unistd.h>

#define NUM_THREADS 3
#define FIB_DEV "/dev/fibonacci"


void *perform_work(void *arguments)
{
    int fd = open(FIB_DEV, O_RDWR);
    while (fd < 0) {
        printf("Process %d Fail to open device\n", *((int *) arguments));
        /* 0.5 s. sleep() takes whole seconds, so sleep(0.5) is sleep(0). */
        usleep(500000);
        fd = open(FIB_DEV, O_RDWR);
    }

    char buf[1];
    for (int i = 5; i <= 8; i++) {
        lseek(fd, i, SEEK_SET);
        long long sz = read(fd, buf, 0);
        printf("process %d: f(%d) = %lld\n", *((int *) arguments), i, sz);
        usleep(100000); /* 0.1 s; sleep(0.1) would truncate to sleep(0) */
    }
    close(fd);
    return NULL;
}

int main(void)
{
    pthread_t threads[NUM_THREADS];
    int thread_args[NUM_THREADS];
    int i;
    int result_code;

    // create all threads one by one
    for (i = 0; i < NUM_THREADS; i++) {
        printf("IN MAIN: Creating thread %d.\n", i);
        thread_args[i] = i;
        result_code =
            pthread_create(&threads[i], NULL, perform_work, &thread_args[i]);
        assert(!result_code);
    }

    printf("IN MAIN: All threads are created.\n");

    // wait for each thread to complete
    for (i = 0; i < NUM_THREADS; i++) {
        result_code = pthread_join(threads[i], NULL);
        assert(!result_code);
        printf("IN MAIN: Thread %d has ended.\n", i);
    }

    printf("MAIN program has ended.\n");
    return 0;
}

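With mutex

Thread 0 locks the device on open. The other threads fail to open it and keep retrying until they succeed, so the execution order is thread \(0 \rightarrow 1 \rightarrow 2\) (printed as "process" in the log).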

IN MAIN: Creating thread 0.
IN MAIN: Creating thread 1.
IN MAIN: Creating thread 2.
process 0: f(5) = 5
Process 1 Fail to open device
IN MAIN: All threads are created.
Process 2 Fail to open device
process 0: f(6) = 8
Process 1 Fail to open device
Process 2 Fail to open device
process 0: f(7) = 13
Process 1 Fail to open device
Process 2 Fail to open device
process 0: f(8) = 21
Process 1 Fail to open device
Process 2 Fail to open device
process 1: f(5) = 5
IN MAIN: Thread 0 has ended.
Process 2 Fail to open device
process 1: f(6) = 8
Process 2 Fail to open device
process 1: f(7) = 13
Process 2 Fail to open device
process 1: f(8) = 21
Process 2 Fail to open device
IN MAIN: Thread 1 has ended.
process 2: f(5) = 5
process 2: f(6) = 8
process 2: f(7) = 13
process 2: f(8) = 21
IN MAIN: Thread 2 has ended.
MAIN program has ended.

Without mutex

With the mutex removed the threads interleave, but the results are still correct. This is because read only uses offset as the input to \(F(n)\) and shares no variables.

IN MAIN: Creating thread 0.
IN MAIN: Creating thread 1.
process 0: f(5) = 5
IN MAIN: Creating thread 2.
process 1: f(5) = 5
IN MAIN: All threads are created.
process 2: f(5) = 5
process 0: f(6) = 8
process 1: f(6) = 8
process 2: f(6) = 8
process 0: f(7) = 13
process 1: f(7) = 13
process 2: f(7) = 13
process 0: f(8) = 21
process 1: f(8) = 21
process 2: f(8) = 21
IN MAIN: Thread 0 has ended.
IN MAIN: Thread 1 has ended.
IN MAIN: Thread 2 has ended.
MAIN program has ended.
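  • To see the effect of omitting the mutex, a shared variable ratio is added to fibdrv.c and fib_read is changed:
    • fib_read computes \(F(n)\), multiplies it by ratio, and returns the product
    • Each read increments ratio by 1
    • A successful fib_open and every fib_release reset ratio to zero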
static int ratio=0;

/* calculate the fibonacci number at given offset */
static ssize_t fib_read(struct file *file,
                        char *buf,
                        size_t size,
                        loff_t *offset)
{
    ratio += 1;
    if (NUM_MODE == 0) {
        return (ssize_t) (fib_sequence(*offset) * ratio);
    } else {
        bignum *fib = bn_init(1);
        bn_fib_sequence(fib, *offset);
        char *p = bn_to_str(fib);

        size_t len = strlen(p) + 1;
        size_t left = copy_to_user(buf, p, len);
        bn_free(fib);
        kfree(p);
        return left;
    }
}

The expected results of a correct run are:

F(5) = 5 * 1 = 5
F(6) = 8 * 2 = 16
F(7) = 13 * 3 = 39
F(8) = 21 * 4 = 84

With the mutex lock, the output matches the expectation

IN MAIN: Creating thread 0.
IN MAIN: Creating thread 1.
process 0: f(5) = 5
IN MAIN: Creating thread 2.
IN MAIN: All threads are created.
Process 1 Fail to open device
Process 2 Fail to open device
process 0: f(6) = 16
Process 1 Fail to open device
Process 2 Fail to open device
process 0: f(7) = 39
Process 1 Fail to open device
Process 2 Fail to open device
process 0: f(8) = 84
Process 1 Fail to open device
Process 2 Fail to open device
IN MAIN: Thread 0 has ended.
process 1: f(5) = 5
Process 2 Fail to open device
process 1: f(6) = 16
Process 2 Fail to open device
process 1: f(7) = 39
Process 2 Fail to open device
process 1: f(8) = 84
Process 2 Fail to open device
IN MAIN: Thread 1 has ended.
process 2: f(5) = 5
process 2: f(6) = 16
process 2: f(7) = 39
process 2: f(8) = 84
IN MAIN: Thread 2 has ended.
MAIN program has ended.

Without the mutex, ratio changes with the read / open / release calls of the other threads, and the results become unpredictable

IN MAIN: Creating thread 0.
IN MAIN: Creating thread 1.
process 0: f(5) = 5
IN MAIN: Creating thread 2.
process 1: f(5) = 5
IN MAIN: All threads are created.
process 2: f(5) = 5
process 0: f(6) = 16
process 1: f(6) = 24
process 2: f(6) = 32
process 0: f(7) = 65
process 1: f(7) = 78
process 2: f(7) = 91
process 0: f(8) = 168
process 1: f(8) = 189
process 2: f(8) = 210
IN MAIN: Thread 0 has ended.
IN MAIN: Thread 1 has ended.
IN MAIN: Thread 2 has ended.
MAIN program has ended.
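The numbers in this log are consistent with one particular interleaving: each thread opens the device (resetting ratio) and performs its first read before the next thread opens, after which the reads alternate between the three threads while ratio keeps growing. The sketch below replays that assumed ordering deterministically, without real threads; `replay_without_mutex` and `fib_small` are hypothetical names, and the real scheduling may differ from run to run.

```c
#include <stddef.h>

/* Iterative Fibonacci for small n. */
static long long fib_small(int n)
{
    long long a = 0, b = 1;
    for (int i = 0; i < n; i++) {
        long long t = a + b;
        a = b;
        b = t;
    }
    return a;
}

/* Replays the assumed interleaving and writes the 12 read results,
 * in log order, into out. */
static void replay_without_mutex(long long out[12])
{
    int ratio = 0; /* the shared variable in the driver */
    size_t pos = 0;

    /* Each thread: fib_open resets ratio, then the first read of f(5). */
    for (int t = 0; t < 3; t++) {
        ratio = 0;
        ratio += 1;
        out[pos++] = fib_small(5) * ratio;
    }

    /* Remaining reads alternate between threads; nobody resets ratio. */
    for (int n = 6; n <= 8; n++) {
        for (int t = 0; t < 3; t++) {
            ratio += 1;
            out[pos++] = fib_small(n) * ratio;
        }
    }
}
```

Under this ordering ratio reaches 2, 3, 4 for f(6), then 5, 6, 7 for f(7), and 8, 9, 10 for f(8), which reproduces 16/24/32, 65/78/91, and 168/189/210.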

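Implementing big-number arithmetic

Based on KYG-yaya573142's implementation

Data structure

  • Each decimal digit is stored in 1 byte, which is large enough to hold values up to 9*9
  • size records how many digits are currently stored, and sign records the sign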
typedef struct __bignum {
    unsigned char *number;  // one decimal digit per byte, lowest digit first
    unsigned int size;      // number of digits
    int sign;               // 0: positive, 1: negative (as used by bn_add)
} bignum;

Initialization

bn_init allocates a bignum with size digits and sets every digit to 0

bignum *bn_init(size_t size)
{
    // create bn obj
    bignum *new_bn = kmalloc(sizeof(bignum), GFP_KERNEL);

    // init data
    new_bn->number = kmalloc(size, GFP_KERNEL);
    memset(new_bn->number, 0, size);

    // init size and sign
    new_bn->size = size;
    new_bn->sign = 0;

    return new_bn;
}

For flexibility, a way to initialize from an int is added

// create bignum from int
bignum *bn_from_int(int x)
{
    if (x < 10) {
        bignum *result = bn_init(1);
        result->number[0] = x;
        return result;
    }

    bignum *result = bn_init(10);

    size_t size = 0;

    while (x) {
        result->number[size] = x % 10;
        x /= 10;
        size++;
    }
    bn_resize(result, size);
    return result;
}

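Big-number addition

  • Check the signs of both operands first; if exactly one is negative, the operation becomes a subtraction (not implemented in the code below, which simply returns)
  • For addition, first reserve MAX(a->size, b->size) + 1 digits to store the result
  • If a digit sum exceeds 9, carry into the next digit; carry records this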
// add two bignum and store at result
void bn_add(bignum *a, bignum *b, bignum *result)
{
    if (a->sign && !b->sign) {  // a neg, b pos, do b-a
        return;
    } else if (!a->sign && b->sign) {  // a pos, b neg, do a-b
        return;
    }

    // pre caculate how many digits that result need
    int digit_width = MAX(a->size, b->size) + 1;
    bn_resize(result, digit_width);

    unsigned char carry = 0;  // store add carry

    for (int i = 0; i < digit_width - 1; i++) {
        unsigned char temp_a = i < a->size ? a->number[i] : 0;
        unsigned char temp_b = i < b->size ? b->number[i] : 0;

        carry += temp_a + temp_b;

        result->number[i] = carry - (10 & -(carry > 9));
        carry = carry > 9;
    }

    if (carry) {
        result->number[digit_width - 1] = 1;
    } else {
        bn_resize(result, digit_width - 1);
    }

    result->sign = a->sign;
}

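Big number to string

Allocate a string one byte longer than the number of digits (for the terminating NUL) and fill in the digits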

// bn to string
char *bn_to_str(bignum *src)
{
    size_t len = sizeof(char) * src->size + 1;
    char *p = kmalloc(len,GFP_KERNEL);
    memset(p, 0, len);  // zero the whole buffer, including the terminating NUL
    int n = src->size - 1;
    for (int i = 0; i < src->size; i++) {
        p[i] = src->number[n - i] + '0';
    }
    return p;
}

Big-number memory operations

These include swapping, free, and copy

// free bignum
int bn_free(bignum *src)
{
    if (src == NULL)
        return -1;
    kfree(src->number);
    kfree(src);
    return 0;
}

void bn_swap(bignum *a, bignum *b)
{
    bignum tmp = *a;
    *a = *b;
    *b = tmp;
}

int bn_cpy(bignum *dest, bignum *src)
{
    if (bn_resize(dest, src->size) < 0)
        return -1;
    dest->sign = src->sign;
    memcpy(dest->number, src->number, src->size);  // one byte per digit
    return 0;
}

Big-number Fibonacci sequence

dest is the target to write to; swapping is used to shift the results along, so no digits are copied

void bn_fib_sequence(bignum *dest, long long k)
{
    bn_resize(dest, 1);

    if (k < 2) {
        dest->number[0] = k;
        return;
    }
    
    bignum *a = bn_from_int(0);
    bignum *b = bn_from_int(0);
    dest->number[0] = 1;

    for (int i = 2; i <= k; i++) {
        bn_swap(b,dest);
        bn_add(a, b, dest);
        bn_swap(a,b);
    }

    bn_free(a);
    bn_free(b);

}

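Comparison with 0xff07/bignum

Try to study the implementation of bignum (fibdrv branch), understand its techniques, compare its performance with your own implementation, and discuss what is worth learning from it.

This compares my own bignum with the bignum provided by the instructor (my is my implementation, ref is the instructor's).

Even for the iterative version alone the running times differ greatly, which means the data structure used to store the numbers is itself a serious problem. Working in units of one decimal digit requires many more carry-propagating additions than working in units of apm_digit (uint64_t or uint32_t).

Comparison of the big-number data structures: for a 6-digit operation, the decimal representation needs many additions and carry computations, whereas the uint64_t / uint32_t representation needs only one.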
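To make the difference concrete, the sketch below adds 999999 + 1 with one decimal digit per byte and counts the digit additions; the same sum is a single machine addition on a uint32_t. The function name `decimal_add` is hypothetical and the code is a simplified stand-in for bn_add, not the implementation itself.

```c
#include <stddef.h>
#include <stdint.h>

/* Adds two little-endian decimal digit arrays of length n.
 * `out` must have room for n + 1 digits. Returns the number of
 * digit additions performed. */
static size_t decimal_add(const unsigned char *a,
                          const unsigned char *b,
                          unsigned char *out,
                          size_t n)
{
    unsigned carry = 0;
    size_t steps = 0;
    for (size_t i = 0; i < n; i++) {
        unsigned s = a[i] + b[i] + carry;
        out[i] = (unsigned char) (s % 10);
        carry = s / 10;
        steps++;
    }
    out[n] = (unsigned char) carry; /* final carry becomes the top digit */
    return steps;
}
```

For 999999 + 1 the loop runs six times and propagates a carry through every digit, while `999999u + 1u` on a 32-bit word is one add instruction.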







[Figure: a single uint_64_t / uint_32_t digit (node a) compared with a chain of six decimal digits (nodes d, e, f, x, y, z), each holding 0~9]




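Analysis: big-number data structure

  • bn is the big-number data structure. It contains
    • an apm_digit pointer to the data; depending on the system, apm_digit is uint32_t or uint64_t
    • size and alloc, of type apm_size (uint32_t), which record the number of digits in use and how many apm_digit-sized slots have been allocated
    • sign, which records whether the number is negative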
/* bn.h */
typedef struct {
    apm_digit *digits; /* Digits of number. */
    apm_size size;     /* Length of number. */
    apm_size alloc;    /* Size of allocation. */
    unsigned sign : 1; /* Sign bit. */
} bn, bn_t[1];

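Analysis: addition, bn_add

c = a+b

  1. Check whether a or b is 0; if so, c can be set directly
  2. Check whether a and b are the same object
    • If a == b == c, c can simply be shifted: c <<= 1
    • cy records the carry; if there is a carry, c is enlarged and the carry is stored
    • The shift is done by apm_lshifti (in place) or apm_lshift
  3. Check whether a and b have the same sign
    • Same: ordinary addition. c is first given room for MAX(a->size, b->size) + 1 digits so that the carry can be stored
    • Different: subtraction. Pointer a is made to refer to the positive number and b to the negative one, and the subtraction performed depends on which absolute value is larger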
code
/* bignum.c */
// add two bignum and store at result
void bn_add(const bn *a, const bn *b, bn *c)
{
    if (a->size == 0) {
        if (b->size == 0)
            c->size = 0;
        else
            bn_set(c, b);
        return;
    } else if (b->size == 0) {
        bn_set(c, a);
        return;
    }

    if (a == b) {
        apm_digit cy;
        if (a == c) {
            cy = apm_lshifti(c->digits, c->size, 1);
        } else {
            BN_SIZE(c, a->size);
            cy = apm_lshift(a->digits, a->size, 1, c->digits);
        }
        if (cy) {
            BN_MIN_ALLOC(c, c->size + 1);
            c->digits[c->size++] = cy;
        }
        return;
    }

    /* Note: it should work for A == C or B == C */
    apm_size size;
    if (a->sign == b->sign) { /* Both positive or negative. */
        size = MAX(a->size, b->size);
        BN_MIN_ALLOC(c, size + 1);
        apm_digit cy =
            apm_add(a->digits, a->size, b->digits, b->size, c->digits);
        if (cy)
            c->digits[size++] = cy;
        else
            APM_NORMALIZE(c->digits, size);
        c->sign = a->sign;
    } else { /* Differing signs. */
        if (a->sign)
            SWAP(a, b);

        ASSERT(a->sign == 0);
        ASSERT(b->sign == 1);

        int cmp = apm_cmp(a->digits, a->size, b->digits, b->size);
        if (cmp > 0) { /* |A| > |B| */
            /* If B < 0 and |A| > |B|, then C = A - |B| */
            BN_MIN_ALLOC(c, a->size);
            ASSERT(apm_sub(a->digits, a->size, b->digits, b->size, c->digits) ==
                   0);
            c->sign = 0;
            size = apm_rsize(c->digits, a->size);
        } else if (cmp < 0) { /* |A| < |B| */
            /* If B < 0 and |A| < |B|, then C = -(|B| - |A|) */
            BN_MIN_ALLOC(c, b->size);
            ASSERT(apm_sub(b->digits, b->size, a->digits, a->size, c->digits) ==
                   0);
            c->sign = 1;
            size = apm_rsize(c->digits, b->size);
        } else { /* |A| = |B| */
            c->sign = 0;
            size = 0;
        }
    }
    c->size = size;
}

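Before every addition or subtraction, BN_MIN_ALLOC makes sure that the space allocated for c can hold the result. The macro body is wrapped in do...while so that it expands to a single statement and cannot produce a dangling else at the call site.

The expression __n->alloc = ((__s + 3) & ~3U)) rounds \(s\) up to the next multiple of 4, so the allocation always grows in units of four apm_digit (each 4 or 8 bytes depending on the system). The allocated size is never smaller than the requested size, which matches the intent of a minimum allocation.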

/* bignum.c */
#define BN_MIN_ALLOC(n, s)                                               \
    do {                                                                 \
        bn *const __n = (n);                                             \
        const apm_size __s = (s);                                        \
        if (__n->alloc < __s) {                                          \
            __n->digits =                                                \
                apm_resize(__n->digits, __n->alloc = ((__s + 3) & ~3U)); \
        }                                                                \
    } while (0)
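The rounding expression can be checked in isolation. A minimal sketch, with `round_up4` as a hypothetical name for the expression used inside BN_MIN_ALLOC:

```c
#include <stdint.h>

/* Rounds s up to the next multiple of 4: adding 3 pushes any value that is
 * not already a multiple of 4 past the next boundary, and clearing the two
 * low bits drops it back onto that boundary. */
static uint32_t round_up4(uint32_t s)
{
    return (s + 3) & ~3U;
}
```

A request for 5 digits therefore allocates 8, and the next three size increases need no reallocation.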

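The actual addition and subtraction are done by apm_add and apm_sub

  1. apm_add first runs checks to make sure the operation can proceed normally
  2. apm_add_n performs the addition; its third argument is the smaller of the two sizes
  3. After the addition, the remaining higher digits of the longer operand only need to be copied (and incremented if there was a carry)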
apm_add
apm_digit apm_add(const apm_digit *u,
                  apm_size usize,
                  const apm_digit *v,
                  apm_size vsize,
                  apm_digit *w)
{
    ASSERT(u != NULL);
    ASSERT(usize > 0);
    ASSERT(u[usize - 1]);
    ASSERT(v != NULL);
    ASSERT(vsize > 0);
    ASSERT(v[vsize - 1]);

    if (usize < vsize) {
        apm_digit cy = apm_add_n(u, v, usize, w);
        if (v != w)
            apm_copy(v + usize, vsize - usize, w + usize);
        return cy ? apm_inc(w + usize, vsize - usize) : 0;
    } else if (usize > vsize) {
        apm_digit cy = apm_add_n(u, v, vsize, w);
        if (u != w)
            apm_copy(u + vsize, usize - vsize, w + vsize);
        return cy ? apm_inc(w + vsize, usize - vsize) : 0;
    }
    /* usize == vsize */
    return apm_add_n(u, v, usize, w);
}

apm_add_n

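apm_digit *u and apm_digit *v point at the least significant digit of each operand, and the loop works from the least significant digit upward.

The two statements cy = (ud += cy) < cy; and cy += (*w = ud + vd) < vd; together

  1. add the incoming carry to u
  2. store the result of u + v in w
  3. compute the carry for the next digit

This is a very compact way to write it. The final carry is returned so that the caller, apm_add, can decide whether to enlarge the storage to hold it.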

apm_add_n
/* Set w[size] = u[size] + v[size] and return the carry. */
apm_digit apm_add_n(const apm_digit *u,
                    const apm_digit *v,
                    apm_size size,
                    apm_digit *w)
{
    ASSERT(u != NULL);
    ASSERT(v != NULL);
    ASSERT(w != NULL);

    apm_digit cy = 0;
    while (size--) {
        apm_digit ud = *u++;
        const apm_digit vd = *v++;
        cy = (ud += cy) < cy;
        cy += (*w = ud + vd) < vd;
        ++w;
    }
    return cy;
}
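The carry trick relies on unsigned wrap-around: if a sum is smaller than one of its operands, the addition overflowed. The sketch below repeats the same loop on plain uint32_t limbs so it can be run in userspace; `limb_add_n` is a hypothetical name and the ASSERT checks are omitted.

```c
#include <stddef.h>
#include <stdint.h>

/* w[0..size-1] = u[0..size-1] + v[0..size-1], least significant limb
 * first. Returns the carry out of the top limb (0 or 1). */
static uint32_t limb_add_n(const uint32_t *u,
                           const uint32_t *v,
                           size_t size,
                           uint32_t *w)
{
    uint32_t cy = 0;
    while (size--) {
        uint32_t ud = *u++;
        const uint32_t vd = *v++;
        cy = (ud += cy) < cy;      /* carry from adding the incoming carry */
        cy += (*w = ud + vd) < vd; /* carry from adding the two limbs */
        ++w;
    }
    return cy;
}
```

Adding 1 to the two-limb value 0xFFFFFFFF'FFFFFFFF clears both limbs and returns a carry of 1, which is the case where the caller must grow the result by one limb.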

apm_sub_n

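Here u is always larger than v. The technique is the same as in add, except that the value returned is the borrow. The borrow can be computed with (*w = ud - vd) > ud because an unsigned integer that cannot cover the subtraction wraps around from the maximum value, for example

(uint32_t) 10 - (uint32_t) 11 = 4294967295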

apm_sub_n
/* Set w[size] = u[size] - v[size] and return the borrow. */
apm_digit apm_sub_n(const apm_digit *u,
                    const apm_digit *v,
                    apm_size size,
                    apm_digit *w)
{
    ASSERT(u != NULL);
    ASSERT(v != NULL);
    ASSERT(w != NULL);

    apm_digit cy = 0;
    while (size--) {
        const apm_digit ud = *u++;
        apm_digit vd = *v++;
        cy = (vd += cy) < cy;
        cy += (*w = ud - vd) > ud;
        ++w;
    }
    return cy;
}
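The borrow detection can be exercised the same way. This sketch mirrors the loop above on uint32_t limbs; `limb_sub_n` is a hypothetical name and the ASSERT checks are omitted.

```c
#include <stddef.h>
#include <stdint.h>

/* w[0..size-1] = u[0..size-1] - v[0..size-1], least significant limb
 * first. Returns the borrow out of the top limb (0 or 1). */
static uint32_t limb_sub_n(const uint32_t *u,
                           const uint32_t *v,
                           size_t size,
                           uint32_t *w)
{
    uint32_t cy = 0;
    while (size--) {
        const uint32_t ud = *u++;
        uint32_t vd = *v++;
        cy = (vd += cy) < cy;      /* borrow folded into the subtrahend */
        cy += (*w = ud - vd) > ud; /* wrapped result means we borrowed */
        ++w;
    }
    return cy;
}
```

With single limbs 10 and 11 the stored result is 4294967295 and the returned borrow is 1, matching the wrap-around example above.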

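Analysis: multiplication, bn_mul

  1. Check whether either operand is 0, or whether a and b are the same object (in which case bn_sqr is used)
  2. Check whether a or b refers to the same variable as c. If so, a separate buffer prod is needed for the result, so that the computation does not read from and write to the same storage.
  3. The statement apm_digit *prod = APM_TMP_ALLOC(csize); declares prod to hold the result of \(a*b\) computed on the following line. The memory is allocated with APM_TMP_ALLOC(csize), which expands to xmalloc(size * APM_DIGIT_SIZE) (APM_DIGIT_SIZE is 4 or 8, the number of bytes occupied by one apm_digit)
  4. The actual computation is done by apm_mul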
bn_mul
void bn_mul(const bn *a, const bn *b, bn *c)
{
    if (a->size == 0 || b->size == 0) {
        bn_zero(c);
        return;
    }

    if (a == b) {
        bn_sqr(a, c);
        return;
    }

    apm_size csize = a->size + b->size;
    if (a == c || b == c) {
        apm_digit *prod = APM_TMP_ALLOC(csize);
        apm_mul(a->digits, a->size, b->digits, b->size, prod);
        csize -= (prod[csize - 1] == 0);
        BN_SIZE(c, csize);
        apm_copy(prod, csize, c->digits);
        APM_TMP_FREE(prod);
    } else {
        ASSERT(a->digits[a->size - 1] != 0);
        ASSERT(b->digits[b->size - 1] != 0);
        BN_MIN_ALLOC(c, csize);
        apm_mul(a->digits, a->size, b->digits, b->size, c->digits);
        c->size = csize - (c->digits[csize - 1] == 0);
    }
    c->sign = a->sign ^ b->sign;
}

apm_mul

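  1. Before multiplying, u and v are resized with apm_rsize, which strips leading zeros and returns the actual number of digits
  2. If either u or v is 0, the result is 0
  3. The "Zero digits which won't be set" step sets the digits that will not be written to 0. Space for usize + vsize digits was allocated before entering apm_mul, so the unused part is (usize+vsize)-(ul + vl), where (ul + vl) is the digit count after leading zeros are removed
  4. Make sure u is the operand with more digits
  5. The test on vsize is an optimization: when vsize < KARATSUBA_MUL_THRESHOLD the \(O(N^2)\) multiplication is used, and once the threshold is reached apm_mul_n (Karatsuba fast multiplication) is used instead.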
apm_mul
void apm_mul(const apm_digit *u,
             apm_size usize,
             const apm_digit *v,
             apm_size vsize,
             apm_digit *w)
{
    {
        const apm_size ul = apm_rsize(u, usize);
        const apm_size vl = apm_rsize(v, vsize);
        if (!ul || !vl) {
            apm_zero(w, usize + vsize);
            return;
        }
        /* Zero digits which won't be set. */
        if (ul + vl != usize + vsize)
            apm_zero(w + (ul + vl), (usize + vsize) - (ul + vl));

        /* Wanted: USIZE >= VSIZE. */
        if (ul < vl) {
            SWAP(u, v);
            usize = vl;
            vsize = ul;
        } else {
            usize = ul;
            vsize = vl;
        }
    }

    ASSERT(usize >= vsize);

    if (vsize < KARATSUBA_MUL_THRESHOLD) {
        _apm_mul_base(u, usize, v, vsize, w);
        return;
    }

    apm_mul_n(u, v, vsize, w);
    if (usize == vsize)
        return;

    apm_size wsize = usize + vsize;
    apm_zero(w + (vsize * 2), wsize - (vsize * 2));
    w += vsize;
    wsize -= vsize;
    u += vsize;
    usize -= vsize;

    apm_digit *tmp = NULL;
    if (usize >= vsize) {
        tmp = APM_TMP_ALLOC(vsize * 2);
        do {
            apm_mul_n(u, v, vsize, tmp);
            ASSERT(apm_addi(w, wsize, tmp, vsize * 2) == 0);
            w += vsize;
            wsize -= vsize;
            u += vsize;
            usize -= vsize;
        } while (usize >= vsize);
    }

    if (usize) { /* Size of U isn't a multiple of size of V. */
        if (!tmp)
            tmp = APM_TMP_ALLOC(usize + vsize);
        /* Now usize < vsize. Rearrange operands. */
        if (usize < KARATSUBA_MUL_THRESHOLD)
            _apm_mul_base(v, vsize, u, usize, tmp);
        else
            apm_mul(v, vsize, u, usize, tmp);
        ASSERT(apm_addi(w, wsize, tmp, usize + vsize) == 0);
    }
    APM_TMP_FREE(tmp);
}

_apm_mul_base

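A multiplication with \(O(N^2)\) complexity (the "simple quadratic-time algorithm" in the code comment): each digit of v is multiplied by the whole of u in turn.

The single-digit multiplication is optimized with inline asm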

#define digit_mul(u, v, hi, lo) \
    __asm__("mulq %3" : "=a"(lo), "=d"(hi) : "%0"(u), "rm"(v))

Extended assembly has the following format, from which we can tell that lo and hi are output operands

asm ( assembler template 
           : output operands                  /* optional */
           : input operands                   /* optional */
           : list of clobbered registers      /* optional */
           );

Register constraint table

+---+--------------------+
| r |    Register(s)     |
+---+--------------------+
| a |   %eax, %ax, %al   |
| b |   %ebx, %bx, %bl   |
| c |   %ecx, %cx, %cl   |
| d |   %edx, %dx, %dl   |
| S |   %esi, %si        |
| D |   %edi, %di        |
+---+--------------------+

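Back to the code. The following explains what %0 and %3 mean

// operands, in order: [lo (%rax), hi (%rdx), u (same register as operand 0), v]
// %3 in the template = operand 3 = v
// In the constraint "%0"(u), 0 is a matching constraint (u is placed in the
// same location as operand 0, i.e. rax) and % marks u and v as commutative
__asm__("mulq %3" : "=a"(lo), "=d"(hi) : "%0"(u), "rm"(v))

As for the constraints, =a requires the rax register and =d requires the rdx register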

// output operand
"=a"(lo), "=d"(hi)
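  • Explanation of rm:
    • r means the operand may be placed in a general-purpose register
    • m means the operand may be used directly from memory (without moving it into a register)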
"%0"(u), "rm"(v)

What mulq does: multiply the operand S by the rax register and store the result in the 128-bit oct-word formed by R[%rdx]:R[%rax]

mulq S | R[%rdx]:R[%rax] ← S × R[%rax]
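  • In our case this means
    1. Multiply %3 (variable v) by the rax register (variable u)
    2. The result is placed in the %rdx and %rax registers (both 64-bit): %rdx holds the high 64 bits and %rax the low 64 bits
    3. The two halves are then moved to the locations we specified, the variables lo and hi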
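The same high/low split can be written portably enough to test the expected values. This sketch assumes a compiler that provides unsigned __int128 (GCC or Clang on a 64-bit target); `mul_64x64` is a hypothetical name and is not part of the bignum code.

```c
#include <stdint.h>

/* 64 x 64 -> 128-bit multiplication. *hi receives the upper 64 bits of
 * the product (what mulq leaves in rdx) and *lo the lower 64 bits (rax). */
static void mul_64x64(uint64_t u, uint64_t v, uint64_t *hi, uint64_t *lo)
{
    unsigned __int128 p = (unsigned __int128) u * v;
    *lo = (uint64_t) p;
    *hi = (uint64_t) (p >> 64);
}
```

On x86-64 a compiler will typically lower this to a single mul instruction, which is what the digit_mul macro requests explicitly.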


_apm_mul_base
/* Multiply u[usize] by v[vsize] and store the result in w[usize + vsize],
 * using the simple quadratic-time algorithm.
 */
void _apm_mul_base(const apm_digit *u,
                   apm_size usize,
                   const apm_digit *v,
                   apm_size vsize,
                   apm_digit *w)
{
    ASSERT(usize >= vsize);

    /* Find real sizes and zero any part of answer which will not be set. */
    apm_size ul = apm_rsize(u, usize);
    apm_size vl = apm_rsize(v, vsize);
    /* Zero digits which will not be set in multiply-and-add loop. */
    if (ul + vl != usize + vsize)
        apm_zero(w + (ul + vl), usize + vsize - (ul + vl));
    /* One or both are zero. */
    if (!ul || !vl)
        return;

    /* Now multiply by forming partial products and adding them to the result
     * so far. Rather than zero the low ul digits of w before starting, we
     * store, rather than add, the first partial product.
     */
    apm_digit *wp = w + ul;
    *wp = apm_dmul(u, ul, *v, w);
    while (--vl) {
        apm_digit vd = *++v;
        *++wp = apm_dmul_add(u, ul, vd, ++w);
    }
}

Karatsuba 快速乘法

Karatsuba 為一種快速乘法方式,可將原先 \(n 位數 \times n位數\) ,所需 \(n^2\) 次的乘法次數減少為 \(3n^{\log{2}^{3}}\)

影片介紹

假要做 \(x=56\)\(y=12\) 的乘法,可以先把 x 和 y 個別拆成 (0~u-1)位元 和 (u~2u)位元,\(u = max(digit(x),digit(y)) /2\) ,在這為1

\[ x = 56, \ a = 5, \ b = 6 \\ y = 12, \ c = 1, \ d =2 \\ \]
x and y can then be written as
\[ x = a * 10^{u} + b =56 \\ y = c * 10^{u} + d =12 \]

Multiplying the two gives
\[ x * y = ac*10^{2u} + 10^{u}(ad+bc) + bd \]
so we need the three quantities
\[ ac, \quad ad+bc, \quad bd \]

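Karatsuba proceeds as follows:

\(Step \ 1.\) Compute \(a*c\) (multiplication count +1)
\(Step \ 2.\) Compute \(b*d\) (multiplication count +1)
\(Step \ 3.\) Compute \((a+b)(c+d)\) (multiplication count +1)
\(Step \ 4.\) Compute step3 - step2 - step1 = ad + bc
\(Step \ 5.\) Combine the results: \(x * y = ac*10^{2u} + 10^{u}(ad+bc) + bd\)

  • For the multiplications in the steps above, Karatsuba is applied again whenever either operand has more digits than a threshold; otherwise ordinary multiplication is used.
  • The saving comes from step 4: computing \(ad + bc\) directly takes two multiplications, but reusing the stored results of step 1 and step 2 saves one of them.
  • Multiplying by a power of 10 only requires shifting digits, not an actual multiplication (in a binary implementation the base is a power of 2 and the shift is a bit or word shift).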
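
The five steps can be sketched in C on decimal digits. This is an illustrative sketch only, not the bignum library's code: ndigits, pow10u, and karatsuba are hypothetical names, the threshold is fixed at a single digit, and the product is assumed to fit in 64 bits.

```c
#include <assert.h>
#include <stdint.h>

/* Number of decimal digits in n (0 counts as one digit). */
static unsigned ndigits(uint64_t n)
{
    unsigned d = 1;
    while (n >= 10) {
        n /= 10;
        d++;
    }
    return d;
}

/* 10^e, assuming the result fits in 64 bits. */
static uint64_t pow10u(unsigned e)
{
    uint64_t p = 1;
    while (e--)
        p *= 10;
    return p;
}

/* Karatsuba on decimal digits; assumes x * y fits in 64 bits. */
static uint64_t karatsuba(uint64_t x, uint64_t y)
{
    if (x < 10 || y < 10) /* below the threshold: ordinary multiplication */
        return x * y;

    unsigned dx = ndigits(x), dy = ndigits(y);
    unsigned u = (dx > dy ? dx : dy) / 2;
    assert(u >= 1);
    uint64_t base = pow10u(u);

    uint64_t a = x / base, b = x % base; /* x = a * 10^u + b */
    uint64_t c = y / base, d = y % base; /* y = c * 10^u + d */

    uint64_t ac = karatsuba(a, c);                    /* step 1 */
    uint64_t bd = karatsuba(b, d);                    /* step 2 */
    uint64_t mid = karatsuba(a + b, c + d) - ac - bd; /* steps 3 and 4 */

    return ac * base * base + mid * base + bd;        /* step 5 */
}
```

For \(x=56\), \(y=12\): \(ac = 5\), \(bd = 12\), \((a+b)(c+d) = 11 \times 3 = 33\), so \(ad+bc = 33 - 12 - 5 = 16\) and the result is \(5 \times 100 + 16 \times 10 + 12 = 672\).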

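apm_mul_n - fast multiplication based on Karatsuba

With \(U = U_1 2^N + U_0\) and \(V = V_1 2^N + V_0\), the goal is to compute

\((2^{2N} + 2^N)U_1V_1 + (2^N)(U_1-U_0)(V_0-V_1) + (2^N + 1)U_0V_0\)

The middle term uses (U1-U0)(V0-V1) to avoid the extra work caused by carries: the sums (U1+U0) and (V1+V0) can grow to N+1 bits, whereas the differences never exceed N bits.

  1. Lines 12~46 split the digit arrays by pointing u1, u0, v1, v0 at the start of each half (these correspond to the variables a, b, c, d in the algorithm introduced above), then recurse to compute \(U_1V_1\) and \(U_0V_0\)
  2. Lines 52~81
    • first add the results of \(U_0V_0\) and \(U_1V_1\) into w; these are the coefficients that multiply \(2^N\) in the formula
    • cy records the carry produced by the additions for the \(2^N\) term above
    • tmp holds the result of (U1-U0)(V0-V1)
    • prod_neg records whether (U1-U0)(V0-V1) is negative
  3. Compute the coefficient of \(2^N\): \(cy + (U_1-U_0)(V_0-V_1)\)
  4. Combine all terms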
apm_mul_n
/* Karatsuba multiplication [cf. Knuth 4.3.3, vol.2, 3rd ed, pp.294-295]
 * Given U = U1*2^N + U0 and V = V1*2^N + V0,
 * we can recursively compute U*V with
 * (2^2N + 2^N)U1*V1 + (2^N)(U1-U0)(V0-V1) + (2^N + 1)U0*V0
 *
 * We might otherwise use
 * (2^2N - 2^N)U1*V1 + (2^N)(U1+U0)(V1+V0) + (1 - 2^N)U0*V0
 * except that (U1+U0) or (V1+V0) may become N+1 bit numbers if there is carry
 * in the additions, and this will slow down the routine. However, if we use
 * the first formula the middle terms will not grow larger than N bits.
 */
static void apm_mul_n(const apm_digit *u,
                      const apm_digit *v,
                      apm_size size,
                      apm_digit *w)
{
    /* TODO: Only allocate a temporary buffer which is large enough for all
     * following recursive calls, rather than allocating at each call.
     */
    if (u == v) {
        apm_sqr(u, size, w);
        return;
    }

    if (size < KARATSUBA_MUL_THRESHOLD) {
        _apm_mul_base(u, size, v, size, w);
        return;
    }

    const bool odd = size & 1;
    const apm_size even_size = size - odd;
    const apm_size half_size = even_size / 2;

    const apm_digit *u0 = u, *u1 = u + half_size;
    const apm_digit *v0 = v, *v1 = v + half_size;
    apm_digit *w0 = w, *w1 = w + even_size;

    /* U0 * V0 => w[0..even_size-1]; */
    /* U1 * V1 => w[even_size..2*even_size-1]. */
    if (half_size >= KARATSUBA_MUL_THRESHOLD) {
        apm_mul_n(u0, v0, half_size, w0);
        apm_mul_n(u1, v1, half_size, w1);
    } else {
        _apm_mul_base(u0, half_size, v0, half_size, w0);
        _apm_mul_base(u1, half_size, v1, half_size, w1);
    }

    /* Since we cannot add w[0..even_size-1] to w[half_size ...
     * half_size+even_size-1] in place, we have to make a copy of it now.
     * This later gets used to store U1-U0 and V0-V1.
     */
    apm_digit *tmp = APM_TMP_COPY(w0, even_size);

    apm_digit cy;
    /* w[half_size..half_size+even_size-1] += U1*V1. */
    cy = apm_addi_n(w + half_size, w1, even_size);
    /* w[half_size..half_size+even_size-1] += U0*V0. */
    cy += apm_addi_n(w + half_size, tmp, even_size);

    /* Get absolute value of U1-U0. */
    apm_digit *u_tmp = tmp;
    bool prod_neg = apm_cmp_n(u1, u0, half_size) < 0;
    if (prod_neg)
        apm_sub_n(u0, u1, half_size, u_tmp);
    else
        apm_sub_n(u1, u0, half_size, u_tmp);

    /* Get absolute value of V0-V1. */
    apm_digit *v_tmp = tmp + half_size;
    if (apm_cmp_n(v0, v1, half_size) < 0)
        apm_sub_n(v1, v0, half_size, v_tmp), prod_neg ^= 1;
    else
        apm_sub_n(v0, v1, half_size, v_tmp);

    /* tmp = (U1-U0)*(V0-V1). */
    tmp = APM_TMP_ALLOC(even_size);
    if (half_size >= KARATSUBA_MUL_THRESHOLD)
        apm_mul_n(u_tmp, v_tmp, half_size, tmp);
    else
        _apm_mul_base(u_tmp, half_size, v_tmp, half_size, tmp);
    APM_TMP_FREE(u_tmp);

    /* Now add / subtract (U1-U0)*(V0-V1) from
     * w[half_size..half_size+even_size-1] based on whether it is negative or
     * positive.
     */
    if (prod_neg)
        cy -= apm_subi_n(w + half_size, tmp, even_size);
    else
        cy += apm_addi_n(w + half_size, tmp, even_size);
    APM_TMP_FREE(tmp);

    /* Now if there was any carry from the middle digits (which is at most 2),
     * add that to w[even_size+half_size..2*even_size-1]. */
    if (cy) {
        ASSERT(apm_daddi(w + even_size + half_size, half_size, cy) == 0);
    }

    if (odd) {
        /* We have the product U[0..even_size-1] * V[0..even_size-1] in
         * W[0..2*even_size-1]. We need to add the following to it:
         * V[size-1] * U[0..size-2]
         * U[size-1] * V[0..size-1]
         */
        w[even_size * 2] =
            apm_dmul_add(u, even_size, v[even_size], &w[even_size]);
        w[even_size * 2 + 1] =
            apm_dmul_add(v, size, u[even_size], &w[even_size]);
    }
}
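
The subtractive formula used by apm_mul_n can be checked at a single level with machine words. The sketch below is an assumption-laden illustration, not the library's code: karatsuba64 is a hypothetical name, N is fixed at 32, and the accumulation relies on the unsigned __int128 extension of GCC and Clang.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

typedef unsigned __int128 u128; /* GCC/Clang extension, assumed available */

/* Hypothetical one-level Karatsuba: 64 x 64 -> 128 bits using three
 * 32 x 32 -> 64 multiplications and the (U1-U0)(V0-V1) middle term.
 */
static u128 karatsuba64(uint64_t u, uint64_t v)
{
    uint64_t u1 = u >> 32, u0 = (uint32_t) u;
    uint64_t v1 = v >> 32, v0 = (uint32_t) v;

    uint64_t hi = u1 * v1; /* U1*V1 */
    uint64_t lo = u0 * v0; /* U0*V0 */

    /* |U1-U0| and |V0-V1| stay below 2^32; track the sign separately,
     * in the same spirit as prod_neg in apm_mul_n. */
    bool neg = u1 < u0;
    uint64_t du = neg ? u0 - u1 : u1 - u0;
    uint64_t dv;
    if (v0 < v1) {
        dv = v1 - v0;
        neg = !neg;
    } else {
        dv = v0 - v1;
    }
    uint64_t mid_abs = du * dv; /* third multiplication, below 2^64 */

    /* mid = U1*V1 + U0*V0 +/- |...| = U1*V0 + U0*V1, never negative. */
    u128 mid = (u128) hi + lo;
    assert(!neg || mid >= mid_abs);
    if (neg)
        mid -= mid_abs;
    else
        mid += mid_abs;

    return ((u128) hi << 64) + (mid << 32) + lo;
}
```

Because both differences fit in 32 bits, the third product always fits in 64 bits, which is the property the comment in apm_mul_n relies on when it prefers this formula over the one with sums.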

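Interview questions

/dev/fibonacci ⇒ client.c (possibly several processes) / lseek
What is the purpose of the lseek operation in the VFS?

Figure / data source

The virtual file system (VFS) is an intermediate layer built by the kernel. It provides abstract operations such as open, stat, read, write, and chmod through which user space interacts with the kernel. A file system developed against the VFS interface can be mounted into the Linux kernel.

The file_operations object in fibdrv.c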

const struct file_operations fib_fops = {
    .owner = THIS_MODULE,
    .read = fib_read,
    .write = fib_write,
    .open = fib_open,
    .release = fib_release,
    .llseek = fib_device_lseek,
};

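The file_operations structure is defined in <linux/fs.h>. Once a module that provides this structure is loaded into the kernel, user space can interact with it through the VFS interface.

The purpose of lseek is to change the offset (read/write position) stored in the open file table entry that a process's file descriptor refers to. Each open() of a file gets its own entry, so the offsets behind separately opened file descriptors are independent and do not interfere with one another.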
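
The independence of offsets can be observed from user space. The sketch below is a hypothetical check, not part of client.c: it uses a temporary regular file instead of /dev/fibonacci so that it runs without the module, and the name check_offsets and the path template are made up for illustration. It also shows the one case where an offset is shared: a descriptor created with dup() refers to the same open file table entry.

```c
#define _POSIX_C_SOURCE 200809L
#include <assert.h>
#include <fcntl.h>
#include <stdlib.h>
#include <unistd.h>

/* Returns 1 if two open() calls keep separate offsets while a dup()ed
 * descriptor shares its offset, 0 if not, and -1 on a setup error.
 */
static int check_offsets(void)
{
    char path[] = "/tmp/lseek_demo_XXXXXX"; /* hypothetical temp file */
    int fd1 = mkstemp(path);
    if (fd1 < 0)
        return -1;
    int fd2 = open(path, O_RDONLY); /* second, independent open */
    int fd3 = dup(fd1);             /* shares fd1's open file entry */
    unlink(path);
    if (fd2 < 0 || fd3 < 0) {
        close(fd1);
        if (fd2 >= 0)
            close(fd2);
        if (fd3 >= 0)
            close(fd3);
        return -1;
    }

    off_t pos1 = lseek(fd1, 10, SEEK_SET); /* move only fd1 */
    off_t pos2 = lseek(fd2, 0, SEEK_CUR);  /* still at 0 */
    off_t pos3 = lseek(fd3, 0, SEEK_CUR);  /* follows fd1: 10 */

    close(fd1);
    close(fd2);
    close(fd3);
    return pos1 == 10 && pos2 == 0 && pos3 == 10;
}
```

For fibdrv this suggests that several client processes which each open /dev/fibonacci can seek to different positions without disturbing one another's offsets.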