Try   HackMD

Linux 核心專題: 井字遊戲改進

執行人: Daniel-0224
期末專題錄影

Reviewed by chloe0919

能否增加圖表顯示 Fixed point Square root 的實作和浮點數 sqrt 之間的誤差

Reviewed by fennecJ

  • 圖片存取權限未正確設定

  • 問題一
    針對 Fixed point Square root 的實做方法,除了第三周測驗一的方法外,也可使用牛頓法進行,可否針對兩者實做方法進行比較,針對 ttt 的應用場景探討兩個作法的優劣

  • 問題二
    專題中的定點數採用 Q23.8 的二補數型態,然而本次實做中 mcts 使用到定點數運算的場合應不涉及負數, log 以及開根號也不能針對負數進行運算。是否有機會善用因不會遇到負數而多出來的 1 bit 增加運算精度?

  • 討論一
    我嘗試執行了同學提供於 github 上的 專案 ,commit = e4fceab27 ,卻發生 Segmentation fault 。這邊提供發生 Segmentation fault 時我輸入的命令以便同學重現:

$ git clone git@github.com:Daniel-0224/lab0-c.git 
$ cd lab0-c
$ make
$ ./qtest
cmd> ttt
X> A1
computer move
X> A2
computer move
X> B3
computer move
X> B2
computer move
X> B1
cmd> option AI_vs_AI 1
cmd> ttt
... (computer moves)
cmd> quit

之後便發生 Segmentation fault ,請找出導致 Segmentation fault 的原因並改善。

Reviewed by yu-hsiennn

圖片存取權限失敗,以及使用 lab0 規範的程式碼書寫風格,務必用 clang-format 確認一致。

Reviewed by dockyu

如果想以 Linux Kernel 的標準來撰寫程式,要遵循以下的規範,使用 \** *\

The opening comment mark /** is used for kernel-doc comments. The kernel-doc tool will extract comments marked this way. The rest of the comment is formatted like a normal multi-line comment with a column of asterisks on the left side, closing with */ on a line by itself.

Reviewed by ChenFuhuangKye

如何證明歐拉方法的結果幾乎和浮點數 log 相等。

Reviewed by eleanorLYJ

逼近求近似 裡的公式推導 少一個 ")"

Reviewed by randyuncle

能否精簡的呈現您的程式碼,而非直接張貼所有的程式?並對呈現出來的程式碼做說明?

Reviewed by jujuegg

請問你在電腦 vs 電腦的對弈中會有每次測試的棋譜都一樣的問題嗎,只要第一步下的是一樣的位置,接著後面所有的步驟都會完全相同。

任務簡介

重做第三次作業,並彙整其他學員的成果。

TODO: 以定點數實作 Monte Carlo tree search (MCTS)

MCTS(Monte Carlo Tree Search) 是一種搜尋樹的算法。它通常應用於決策問題的求解,特別是在棋類等遊戲中。MCST 通過隨機模擬和搜尋樹的擴展來評估每個可能的決策,以找到最佳的行動策略。
MCTS 在每一次的疊代一共會有 4 步,分別為:

Selection :
在搜尋樹中從根節點開始,根據某種策略選擇下一步要擴展的節點,以找到潛在最佳的行動。
Expansion :
對於選擇的節點,擴展子節點,即生成可能的下一步行動或狀態。這些子節點尚未被完全探索,需要進一步評估其價值。
Simulation :
對每個擴展的子節點進行模擬或評估。通常使用蒙特卡羅模擬來模擬隨機的遊戲或決策過程,以估計每個子節點的潛在價值或勝率。
Backpropagation :
將模擬結果向上反向傳播到搜尋樹的根節點,更新每個節點的統計信息,如訪問次數和累計獎勵。這有助於調整每個節點的價值估計,以改進未來的選擇策略。

透過不斷迭代的選擇、擴展、模擬和反向傳播,MCTS 能夠在復雜的決策問題中尋找到較佳的解決方案。

定點數運算:

定點數運算中, Q notation 是一種指定二進制定點數參數的方法。

  • Qm.f:例如 Q23.8 表示該數有 23 個整數位、8 個小數位,是 32 位二補數。

假設我們將 fraction bit 設為

d,則一定點數
N
的實際值為
Nd
因此當我們在做定點數的加減時可以直接相加。

(N1+N2)÷d=N1+N2d

(N1N2)÷d=N1N2d

但是如果對於定點數的乘除不多加處理的話,結果會變為:

(N1N2)÷d=N1N2d

(N1/N2)÷d=N1/N2d

顯然與預期的結果不同,因此在計算完定點樹的乘除之,需要額外對積數與商數右移或左移

d

(N1N2)d÷d=N1N2d2=N1dN2d

(N1)N2dd=N1N2=N1d/N2d

了解定點數,將 MCTS 當中浮點數運算進行以下更動。

首先在 console.h 定義 SCALING_BITS

#define SCALING_BITS 8

calculate_win_value ,將 1.0, 0.0, 0.5 全部更換成自定義定點數的形式。

Q23_8 calculate_win_value(char win, char player)
{
    if (win == player)
-        return 1.0;
+        return 1U << SCALING_BITS;
    if (win == (player ^ 'O' ^ 'X'))
-        return 0.0;
+        return 0U;
-    return 0.5;
+    return 1U << (SCALING_BITS - 1);
}

另外一處重點變更在於 uct_score 的計算函式。
公式如下

wini+cln(Ni)ni

  • wi
    代表第 i 次決定後該節點贏的次數
  • ni
    代表該節點在第 i 次決定後共做幾次模擬
  • Ni
    代表該節點的父節點在第 i 次決定後共做幾次模擬
  • c
    代表 exploration parameter ,通常定為
    2

因為在 EXPLORATION_FACTOR * sqrt(log(n_total) / n_visits) 這一項, sqrt(), log() 本身就使用浮點數運算,因此我們要設計對應的定點數運算來取代這兩個函式。

static inline Q23_8 uct_score(int n_total, int n_visits, Q23_8 score)
{
    if (n_visits == 0)
        return INT32_MAX;
    Q23_8 result = score << Q / (Q23_8) (n_visits << Q);
    int64_t tmp =
        EXPLORATION_FACTOR * fixed_sqrt(fixed_log(n_total) / n_visits);
    Q23_8 resultN = tmp >> Q;
    return result + resultN;
}

Fixed point Square root

此處,利用 第三周測驗一 的方法實作,並且轉為定點數。

int i_sqrt(int x)
{
    if (x <= 1)
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}
Q23_8 fixed_sqrt(Q23_8 x)
{
    if (x <= 1 << Q)
        return x;

    Q23_8 z = 0;
    for (Q23_8 m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;
    }
    z = z << Q / 2;
    return z;
}

Fixed point log

泰勒展開式

嘗試使用泰勒展開實作自然對數的計算,公式如下:

ln(1+x)=xx22+x33x44+......

以下是我的實作:

Q23_8 fixed_log(Q23_8 x)
{
    Q23_8 fixed_x = x << SCALING_BITS;

    if (x == 0) return UINT32_MAX;

    if (x == 1) return 0ULL;

   Q23_8 result = 0;

    for (int i = 1; i <= 20; ++i) {
        if (i % 2 == 0) {
            result -= fixed_power_int(fixed_x, FIXED_SCALING_BITS, i) / (i << SCALING_BITS);
        } else {
            result += fixed_power_int(fixed_x, FIXED_SCALING_BITS, i) / (i << SCALING_BITS);
        }
        result = result >> SCALING_BITS;
    }
    return result;
}

但自己實作時遇到了一些問題,目前還沒解決。

因此參考了 marvin0102 同學的實作

Q23_8 fixed_log(int input)
{
    if (!input || input == (1U << Q))
        return 0;

    int64_t y = input << Q;
    y = ((y - (1U << Q)) << (Q)) / (y + (1U << Q));
    int64_t ans = 0U;
    for (unsigned i = 1; i < 20; i += 2) {
        int64_t z = (1U << Q);
        for (int j = 0; j < i; j++) {
            z *= y;
            z >>= Q;
        }
        z <<= Q;
        z /= (i << Q);

        ans += z;
    }
    ans <<= 1;
    return (Q23_8) ans;
}

comp

上圖是我利用同學的方法與浮點數 log 比較,可以從實驗結果發現使用泰勒展開,當數字越大時,誤差越大,且需要越多的項次才能做精準的估算。

逼近求近似

為了改善泰勒展開數字越大時誤差越大的問題,嘗試利用逼近法求 log 的近似值,假設

A<=x<=B ,求其近似值,方法如下:

CAB,因為
log(C)log(AB)=12(log(A)+log(B)
,因此我們可以得到
log(C)
log(A)
log(B)
的平均。

接著比較

x
C
,如果
x=C
及得到我們所求,如果
x<C
,我們則將
B
替換成
C
,如果
x>C
,我們則將
A
替換成
C
,繼續下一輪的求近似。

因為求近似值時,須先知道

log(A)
log(B)
的實際值,因此實作時,將以
log2(x)
做計算,且
A
B
分別為
2m
2m+1
,其中
2m<=x<=2m+1

      A   x   B
------>---|---<------

實作如下:

Q23_8 fixed_log(int input)
{
    if (!input || input == 1)
        return 0;

    Q23_8 y = input << Q;  // int to Q15_16
    Q23_8 L = 1L << ((31 - __builtin_clz(y))), R = L << 1;
    Q23_8 Llog = (31 - __builtin_clz(y) - Q) << Q, Rlog = Llog + (1 << Q), log;

    for (int i = 1; i < 20; i++) {
        if (y == L)
            return Llog;
        else if (y == R)
            return Rlog;
        log = fixed_div(Llog + Rlog, 2 << Q);

        int64_t tmp = ((int64_t) L * (int64_t) R) >> Q;
        tmp = fixed_sqrt((Q23_8) tmp);

        if (y >= tmp) {
            L = tmp;
            Llog = log;
        } else {
            R = tmp;
            Rlog = log;
        }
    }

    return (Q23_8) log;
}

comp

歐拉方法的結果幾乎和浮點數 log 相等。

TODO: 實作「人 vs.電腦」和「電腦 vs. 電腦」的對弈

整合 linux/list.hjserv/ttt

首先將 Linux 核心的 linux/list.h 標頭檔與 hlist 相關的程式碼抽出,成為 hlist.h 整合進 lab-0,再將 jserv/ttt 專案的程式碼部分整合進 lab0-c 。其中包含 consolegamemt19937mctsnegamex 等檔案。

新增 ttt 命令

jserv/tttmain 檔案修改整合入 console.c ,新增ADD_COMMMAND ,修改 Makefile 讓程式能夠執行。

ADD_COMMAND(ttt, "Play Tic-Tac-Toe Game", "");

commit 95a09bc

人 vs.電腦

做完上述方法就能執行 qtest 再使用 do_ttt 開始玩 Tic-Tac-Toe

tseng@tseng-System-Product-Name:~/linux2024/lab0-c$ ./qtest
cmd> ttt
 1 |             
 2 |             
 3 |             
 4 |             
---+-------------
      A  B  C  D
X> b2
 1 |             
 2 |     ×     ○ 
 3 |             
 4 |             
---+-------------
      A  B  C  D

電腦 vs.電腦

增加 option 指令 AI_vs_AI,讓 AIAI 對弈,使用的演算法是 negamax

static int ai_vs_ai = 0;
else if (ai_vs_ai) {
            int move = negamax_predict(table, ai).move;
            if (move != -1) {
                table[move] = turn;
                record_move(move);
            }
}

預設值為 0,欲更改時需要設定 AI_vs_AI 的參數。在 qtest 要下的命令是:option AI_vs_AI 1

螢幕顯示當下的時間 (含秒數):

static void show_time()
{
    struct timeval currentTime;
    gettimeofday(&currentTime, NULL);
    time_t rawtime;
    time(&rawtime);
    struct tm *timeinfo = localtime(&rawtime);
    printf("Current Time: %02d:%02d:%02d:%02ld\n", timeinfo->tm_hour,
           timeinfo->tm_min, timeinfo->tm_sec, currentTime.tv_usec / 10000);
}

部份結果:

tseng@tseng-System-Product-Name:~/linux2024/lab0-c$ ./qtest
cmd> option AI_vs_AI 1
cmd> ttt
 1 |             
 2 |     ×       
 3 |             
 4 |             
---+-------------
      A  B  C  D
Current Time: 05:28:11:07
 1 |             
 2 |     ×       
 3 |             
 4 |        ○    
---+-------------
      A  B  C  D
Current Time: 05:28:11:55
 1 |             
 2 |     ×       
 3 |             
 4 |     ×  ○    
---+-------------
      A  B  C  D

TODO: 引入 coroutine 來處理對弈

參考 排程器原理
參考 coro
< HenryChaing > 同學

使用 coroutine 的方式實作「電腦 vs 電腦」的對弈模式,其中電腦 A 是使用 negamax 演算法,而電腦 B 是使用 MCTS 演算法。而電腦 A、B 分別對應到 task0 及 task1。至於任務之間的切換,是採用 setjmp + longjmp 的方法。

task

struct task {
    jmp_buf env;
    struct list_head list;
    char task_name[10];
    char *table;
    char turn;
};

首先可以看到 struct task 有 2 個成員,分別為用於 setjmp()jmp_buf env 以及用於排程的 list。第一個成員變數 env 即是用來儲存這次進入任務前的程式執行狀態。

setjmp/longjmp

這兩個函式可以轉換程式執行的順序,其中 setjmp 可以利用 jum_buf 儲存目前程式的狀態,並且在遇到 longjmp(jum_buf) 後跳回 setjmp 並恢復程式的儲存狀態,這樣的函式設計可以方便程式執行時在不同任務間轉換。

task_add/task_switch

這是主要切換任務的函式,我們用 list_head 構成的循環雙向鏈結串列存放即將執行的任務,也就是存放 jmp_buf 。其中 task_add 可以將任務加到串列當中, task_switch 可以切換到我們紀錄的下一個任務執行。

流程設計參照下方程式碼, schedule 函式會將兩個任務放到佇列中,而任務執行完的當下會再將這個任務加到佇列當中,若此對局勝負揭曉則不會再將加到佇列當中,佇列為空也就代表並行程式結束執行。

static void task_add(struct task *task)
{
    list_add_tail(&task->list, &tasklist);
}

static void task_switch()
{
    if (!list_empty(&tasklist)) {
        struct task *t = list_first_entry(&tasklist, struct task, list);
        list_del(&t->list);
        cur_task = t;
        longjmp(t->env, 1);
    }
}

void schedule(void)
{
    static int i;
    i = 0;
    setjmp(sched);
    while (ntasks-- > 0) {
        struct arg arg = args[i];
        tasks[i++](&arg);
        printf("Never reached\n");
    }
    task_switch();
}
task0 , task1
/*negamax*/
void task0(void *arg)
{
    if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }
    while (1) {
        task = cur_task;
        if (setjmp(task->env) == 0) {
            
            char win = check_win(task->table);
            if (win == 'D') {
                draw_board(task->table);
                printf("It is a draw!\n");
                break;
            } 
            
            draw_board(task->table);
            int move = negamax_predict(task->table, task->turn).move;
            if (move != -1) {
                task->table[move] = task->turn;
                record_move(move);
            }

            task_add(task);
            task_switch();
        }
    }
    print_all_moves();
    printf("%s: complete\n", task->task_name);
    longjmp(sched, 1);
}

/*mcts*/
void task1(void *arg)
{
    if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }
    while (1) {
        task = cur_task;
        if (setjmp(task->env) == 0) {
            
            char win = check_win(task->table);
            if (win == 'D') {
                draw_board(task->table);
                printf("It is a draw!\n");
                break;
            } 
            
            draw_board(task->table);
            int move = mcts(task->table, task->turn);
            if (move != -1) {
                task->table[move] = task->turn;
                record_move(move);
            }

            task_add(task);
            task_switch();
        }
    }
    printf("%s: complete\n", task->task_name);
    longjmp(sched, 1);
}

task0task1 的結構一樣,有三大部份,第一部份根據 schedule() 設定的參數決定迴圈次數,將 task 加入排程後呼叫 longjmp(sched, 1),讓 schedule() 可繼續將新任務加進排程。當所有任務被加入排程後,schedule() 會呼叫 task_join(&tasklist),其中,會根據 list 的 first entry longjmp 回該 task 的 setjmp 位置:

if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }

第二部份是兩個 task 交互排程,每次執行一個 loop 後,呼叫 task_add() 重新將 task 加入 list 的尾端,並呼叫 task_switch 指定 list 的 first task 執行:

while(1) {
    if (setjmp(env) == 0) {
        task_add(task);
        task_switch();
    }
}

第三部份完成該 task,會呼叫 longjmp(sched, 1) 跳到 schedule(),接著會執行 task_join(&tasklist) 執行尚未執行完的 task:

printf("Task 1: complete\n");
longjmp(sched, 1);

TODO: 引入其他快速的 PRNG 實作並比較 MCTS 實作獲得的效益

MCTS 中,利用 PRNG 來產生亂數,以隨機挑選每次模擬中下一步的位置。原本的實作使用 glibcrand() 函式來產生亂數。以 4x4 的棋盤為例,每次最多可能有 15 種不同的選項,我用 014 來代表這些選項,並使用 rand() 進行了一億次的抽樣,接著利用 perf stat 執行五次 ttt 來測量程式的效能表現,分佈結果如下。

int move = moves[mt19937_rand() % n_moves];

rand()

rand

Performance counter stats for './qtest':

           4091.34 msec task-clock                       #    0.153 CPUs utilized             
                40      context-switches                 #    9.777 /sec                      
                 6      cpu-migrations                   #    1.467 /sec                      
            4,9967      page-faults                      #   12.213 K/sec                     
     197,2995,5128      cycles                           #    4.822 GHz                       
     348,7598,8301      instructions                     #    1.77  insn per cycle            
      62,0083,0615      branches                         #    1.516 G/sec                     
       1,4453,6081      branch-misses                    #    2.33% of all branches           

      26.799816027 seconds time elapsed

       4.045053000 seconds user
       0.046658000 seconds sys

mt19937_rand()

mt

Performance counter stats for './qtest':

           3850.85 msec task-clock                       #    0.274 CPUs utilized             
                27      context-switches                 #    7.011 /sec                      
                 7      cpu-migrations                   #    1.818 /sec                      
            4,8811      page-faults                      #   12.675 K/sec                     
     185,1063,2378      cycles                           #    4.807 GHz                       
     324,6618,8678      instructions                     #    1.75  insn per cycle            
      57,2558,3620      branches                         #    1.487 G/sec                     
       1,4270,6467      branch-misses                    #    2.49% of all branches           

      14.057938599 seconds time elapsed

       3.811129000 seconds user
       0.040032000 seconds sys

SplitMix64()

split

Performance counter stats for './qtest':

           3906.89 msec task-clock                       #    0.178 CPUs utilized             
                19      context-switches                 #    4.863 /sec                      
                 4      cpu-migrations                   #    1.024 /sec                      
            4,3332      page-faults                      #   11.091 K/sec                     
     187,3365,9566      cycles                           #    4.795 GHz                       
     330,9484,1594      instructions                     #    1.77  insn per cycle            
      58,3272,5701      branches                         #    1.493 G/sec                     
       1,3575,9274      branch-misses                    #    2.33% of all branches           

      21.998475802 seconds time elapsed

       3.859266000 seconds user
       0.048090000 seconds sys

wyhash()

wy

Performance counter stats for './qtest':

           3896.78 msec task-clock                       #    0.360 CPUs utilized             
                24      context-switches                 #    6.159 /sec                      
                 0      cpu-migrations                   #    0.000 /sec                      
            4,7946      page-faults                      #   12.304 K/sec                     
     187,4050,0413      cycles                           #    4.809 GHz                       
     331,1468,0080      instructions                     #    1.77  insn per cycle            
      58,1496,6440      branches                         #    1.492 G/sec                     
       1,4438,8132      branch-misses                    #    2.48% of all branches           

      10.825147098 seconds time elapsed

       3.835545000 seconds user
       0.061725000 seconds sys

結論:
亂數分佈除了 mt19937_rand() 相對平均一些,其他分佈都有較明顯的差異。
雖然每個方法都大約差了0.05 至 0.2秒,和差幾億個 cycle,但是沒有像 vax-r 同學的實驗結果差異那麽明顯,目前不確定是否是實驗方法有細節的不同。

TODO: 改善定點數

改寫 lab0-c 既有的 shannon_entropy.c 和 log2_lshift16.h,採用更有效、準確度 (accuracy) 更高的定點數運算實作,需要有對應的數學統計分析和實際執行的討論。