Try   HackMD

Linux 核心專題: 井字遊戲改進

執行人: Daniel-0224

任務簡介

重做第三次作業,並彙整其他學員的成果。

TODO: 以定點數實作 Monte Carlo tree search (MCTS)

MCST(Monte Carlo Search Tree) 是一種搜尋樹的算法。它通常應用於決策問題的求解,特別是在棋類等遊戲中。MCST 通過隨機模擬和搜尋樹的擴展來評估每個可能的決策,以找到最佳的行動策略。
MCTS 在每一次的疊代一共會有 4 步,分別為:

Selection :
在搜索樹中從根節點開始,根據某種策略選擇下一步要擴展的節點,以找到潛在最佳的行動。
Expansion :
對於選擇的節點,擴展子節點,即生成可能的下一步行動或狀態。這些子節點尚未被完全探索,需要進一步評估其價值。
Simulation :
對每個擴展的子節點進行模擬或評估。通常使用蒙特卡羅模擬來模擬隨機的遊戲或決策過程,以估計每個子節點的潛在價值或勝率。
Backpropagation :
將模擬結果向上反向傳播到搜索樹的根節點,更新每個節點的統計信息,如訪問次數和累計獎勵。這有助於調整每個節點的價值估計,以改進未來的選擇策略。

透過不斷迭代的選擇、擴展、模擬和反向傳播,MCST能夠在復雜的決策問題中尋找到較佳的解決方案。

定點數運算:

定點數運算中, Q notation 是一種指定二進制定點數參數的方法。

  • Qm.f:例如 Q23.8 表示該數有 23 個整數位、8 個小數位,是 32 位二補數。

假設我們將 fraction bit 設為

d,則一定點數
N
的實際值為
Nd
因此當我們在做頂點數的加減時可以直接相加。

(N1+N2)÷d=N1+N2d

(N1N2)÷d=N1N2d

但是如果對於定點數的乘除不多加處理的話,結果會變為:

(N1N2)÷d=N1N2d

(N1/N2)÷d=N1/N2d

顯然與預期的結果不同,因此在計算完定點樹的乘除之,需要額外對積數與商數右移或左移

d

(N1N2)d÷d=N1N2d2=N1dN2d

(N1)N2dd=N1N2=N1d/N2d

了解定點數,將 MCTS 當中浮點數運算進行以下更動。
首先在 console.h' 定義 SCALING_BITS`。

#define SCALING_BITS 8

calculate_win_value ,將 1.0, 0.0, 0.5 全部更換成自定義定點數的形式。

Q23_8 calculate_win_value(char win, char player)
{
    if (win == player)
-        return 1.0;
+        return 1U << SCALING_BITS;
    if (win == (player ^ 'O' ^ 'X'))
-        return 0.0;
+        return 0U;
-    return 0.5;
+    return 1U << (SCALING_BITS - 1);
}

另外一處重點變更在於 uct_score 的計算函式。
公式如下

wini+cln(Ni)ni

  • wi
    代表第 i 次決定後該節點贏的次數
  • ni
    代表該節點在第 i 次決定後共做幾次模擬
  • Ni
    代表該節點的父節點在第 i 次決定後共做幾次模擬
  • c
    代表 exploration parameter ,通常定為
    2

因為在 EXPLORATION_FACTOR * sqrt(log(n_total) / n_visits) 這一項, sqrt(), log() 本身就使用浮點數運算,因此我們要設計對應的定點數運算來取代這兩個函式。

static inline Q23_8 uct_score(int n_total, int n_visits, Q23_8 score)
{
    if (n_visits == 0)
        return INT32_MAX;
    Q23_8 result = score << Q / (Q23_8) (n_visits << Q);
    int64_t tmp =
        EXPLORATION_FACTOR * fixed_sqrt(fixed_log(n_total) / n_visits);
    Q23_8 resultN = tmp >> Q;
    return result + resultN;
}

Fixed point Square root

此處,利用 第三周測驗一 的方法實作,並且轉為定點數。

int i_sqrt(int x)
{
    if (x <= 1)
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}
Q23_8 fixed_sqrt(Q23_8 x)
{
    if (x <= 1 << Q)
        return x;

    Q23_8 z = 0;
    for (Q23_8 m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;
    }
    z = z << Q / 2;
    return z;
}

Fixed point log

泰勒展開式

嘗試使用泰勒展開實作自然對數的計算,公式如下:

ln(1+x)=xx22+x33x44+......

以下是我的實作:

Q23_8 fixed_log(Q23_8 x)
{
    Q23_8 fixed_x = x << SCALING_BITS;

    if (x == 0) return UINT32_MAX;

    if (x == 1) return 0ULL;

   Q23_8 result = 0;

    for (int i = 1; i <= 20; ++i) {
        if (i % 2 == 0) {
            result -= fixed_power_int(fixed_x, FIXED_SCALING_BITS, i) / (i << SCALING_BITS);
        } else {
            result += fixed_power_int(fixed_x, FIXED_SCALING_BITS, i) / (i << SCALING_BITS);
        }
        result = result >> SCALING_BITS;
    }
    return result;
}

但自己實作時遇到了一些問題,目前還沒解決。

因此參考了 marvin0102 同學的實作

Q23_8 fixed_log(int input)
{
    if (!input || input == (1U << Q))
        return 0;

    int64_t y = input << Q;
    y = ((y - (1U << Q)) << (Q)) / (y + (1U << Q));
    int64_t ans = 0U;
    for (unsigned i = 1; i < 20; i += 2) {
        int64_t z = (1U << Q);
        for (int j = 0; j < i; j++) {
            z *= y;
            z >>= Q;
        }
        z <<= Q;
        z /= (i << Q);

        ans += z;
    }
    ans <<= 1;
    return (Q23_8) ans;
}

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

上圖是我利用同學的方法與浮點數 log 比較,可以從實驗結果發現使用泰勒展開,當數字越大時,誤差越大,且需要越多的項次才能做精準的估算。

逼近求近似

為了改善泰勒展開數字越大時誤差越大的問題,嘗試利用逼近法求 log 的近似值,假設

A<=x<=B ,求其近似值,方法如下:

CAB,因為
log(C)log(AB)=12(log(A)+log(B)
,因此我們可以得到
log(C)
log(A)
log(B)
的平均。

接著比較

x
C
,如果
x=C
及得到我們所求,如果
x<C
,我們則將
B
替換成
C
,如果
x>C
,我們則將
A
替換成
C
,繼續下一輪的求近似。

因為求近似值時,須先知道

log(A)
log(B)
的實際值,因此實作時,將以
log2(x)
做計算,且
A
B
分別為
2m
2m+1
,其中
2m<=x<=2m+1

      A   x   B
------>---|---<------

實作如下:

Q23_8 fixed_log(int input)
{
    if (!input || input == 1)
        return 0;

    Q23_8 y = input << Q;  // int to Q15_16
    Q23_8 L = 1L << ((31 - __builtin_clz(y))), R = L << 1;
    Q23_8 Llog = (31 - __builtin_clz(y) - Q) << Q, Rlog = Llog + (1 << Q), log;

    for (int i = 1; i < 20; i++) {
        if (y == L)
            return Llog;
        else if (y == R)
            return Rlog;
        log = fixed_div(Llog + Rlog, 2 << Q);

        int64_t tmp = ((int64_t) L * (int64_t) R) >> Q;
        tmp = fixed_sqrt((Q23_8) tmp);

        if (y >= tmp) {
            L = tmp;
            Llog = log;
        } else {
            R = tmp;
            Rlog = log;
        }
    }

    return (Q23_8) log;
}

comp

歐拉方法的結果幾乎和浮點數 log 相等,但這個方法我目前不了解它是如何運作。

TODO: 實作「人 vs.電腦」和「電腦 vs. 電腦」的對弈

整合 linux/list.h 和 'jserv/ttt`

首先將 Linux 核心的 linux/list.h 標頭檔與 hlist 相關的程式碼抽出,成為 hlist.h 整合進 lab-0,再將 jserv/ttt 專案的程式碼部分整合進 lab0-c 。其中包含 consolegamemt19937mctsnegamex 等檔案。

新增 ttt 命令

jserv/tttmain 檔案修改整合入 console.c ,新增ADD_COMMMAND ,修改 Makefile 讓程式能夠執行。

ADD_COMMAND(ttt, "Play Tic-Tac-Toe Game", "");

commit 95a09bc

人 vs.電腦

做完上述方法就能執行 qtest 再使用 do_ttt 開始玩 Tic-Tac-Toe

tseng@tseng-System-Product-Name:~/linux2024/lab0-c$ ./qtest
cmd> ttt
 1 |             
 2 |             
 3 |             
 4 |             
---+-------------
      A  B  C  D
X> b2
 1 |             
 2 |     ×     ○ 
 3 |             
 4 |             
---+-------------
      A  B  C  D

電腦 vs.電腦

增加 option 指令 AI_vs_AI,讓 AIAI 對弈,使用的演算法是 negamax

static int ai_vs_ai = 0;
else if (ai_vs_ai) {
            int move = negamax_predict(table, ai).move;
            if (move != -1) {
                table[move] = turn;
                record_move(move);
            }
}

預設值為 0,欲更改時需要設定 AI_vs_AI 的參數。在 qtest 要下的命令是:option AI_vs_AI 1

螢幕顯示當下的時間 (含秒數):

static void show_time()
{
    struct timeval currentTime;
    gettimeofday(&currentTime, NULL);
    time_t rawtime;
    time(&rawtime);
    struct tm *timeinfo = localtime(&rawtime);
    printf("Current Time: %02d:%02d:%02d:%02ld\n", timeinfo->tm_hour,
           timeinfo->tm_min, timeinfo->tm_sec, currentTime.tv_usec / 10000);
}

部份結果:

tseng@tseng-System-Product-Name:~/linux2024/lab0-c$ ./qtest
cmd> option AI_vs_AI 1
cmd> ttt
 1 |             
 2 |     ×       
 3 |             
 4 |             
---+-------------
      A  B  C  D
Current Time: 05:28:11:07
 1 |             
 2 |     ×       
 3 |             
 4 |        ○    
---+-------------
      A  B  C  D
Current Time: 05:28:11:55
 1 |             
 2 |     ×       
 3 |             
 4 |     ×  ○    
---+-------------
      A  B  C  D

TODO: 引入 coroutine 來處理對弈

參考 排程器原理
參考 coro
< HenryChaing > 同學

我使用 coroutine 的方式實作「電腦 vs 電腦」的對弈模式,其中電腦 A 是使用 negamax 演算法,而電腦 B 是使用 MCTS 演算法。而電腦 A、B 分別對應到任務一及任務二。至於任務之間的切換,是採用 setjmp + longjmp 的方法。

task

struct task {
    jmp_buf env;
    struct list_head list;
    char task_name[10];
    char *table;
    char turn;
};

首先可以看到 struct task 有 2 個成員,分別為用於 setjmp()jmp_buf env 以及用於排程的 list。第一個成員變數 env 即是用來儲存這次進入任務前的程式執行狀態。

setjmp/longjmp

這兩個函式可以轉換程式執行的順序,其中 setjmp 可以利用 jum_buf 儲存目前程式的狀態,並且在遇到 longjmp(jum_buf) 後跳回 setjmp 並恢復程式的儲存狀態,這樣的函式設計可以方便程式執行時在不同任務間轉換。

task_add/task_switch

這是主要切換任務的函式,我們用 list_head 構成的循環雙向鏈結串列存放即將執行的任務,也就是存放 jmp_buf 。其中 task_add 可以將任務加到串列當中, task_switch 可以切換到我們紀錄的下一個任務執行。

流程設計參照下方程式碼, schedule 函式會將兩個任務放到佇列中,而任務執行完的當下會再將這個任務加到佇列當中,若此對局勝負揭曉則不會再將加到佇列當中,佇列為空也就代表並行程式結束執行。

static void task_add(struct task *task)
{
    list_add_tail(&task->list, &tasklist);
}

static void task_switch()
{
    if (!list_empty(&tasklist)) {
        struct task *t = list_first_entry(&tasklist, struct task, list);
        list_del(&t->list);
        cur_task = t;
        longjmp(t->env, 1);
    }
}

void schedule(void)
{
    static int i;
    i = 0;
    setjmp(sched);
    while (ntasks-- > 0) {
        struct arg arg = args[i];
        tasks[i++](&arg);
        printf("Never reached\n");
    }
    task_switch();
}
task0 , task1
/*negamax*/
void task0(void *arg)
{
    if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }
    while (1) {
        task = cur_task;
        if (setjmp(task->env) == 0) {
            
            char win = check_win(task->table);
            if (win == 'D') {
                draw_board(task->table);
                printf("It is a draw!\n");
                break;
            } 
            
            draw_board(task->table);
            int move = negamax_predict(task->table, task->turn).move;
            if (move != -1) {
                task->table[move] = task->turn;
                record_move(move);
            }

            task_add(task);
            task_switch();
        }
    }
    print_all_moves();
    printf("%s: complete\n", task->task_name);
    longjmp(sched, 1);
}

/*mcts*/
void task1(void *arg)
{
    if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }
    while (1) {
        task = cur_task;
        if (setjmp(task->env) == 0) {
            
            char win = check_win(task->table);
            if (win == 'D') {
                draw_board(task->table);
                printf("It is a draw!\n");
                break;
            } 
            
            draw_board(task->table);
            int move = mcts(task->table, task->turn);
            if (move != -1) {
                task->table[move] = task->turn;
                record_move(move);
            }

            task_add(task);
            task_switch();
        }
    }
    printf("%s: complete\n", task->task_name);
    longjmp(sched, 1);
}

task0task1 的結構一樣,有三大部份,第一部份根據 schedule() 設定的參數決定迴圈次數,將 task 加入排程後呼叫 longjmp(sched, 1),讓 schedule() 可繼續將新任務加進排程。當所有任務被加入排程後,schedule() 會呼叫 task_join(&tasklist),其中,會根據 list 的 first entry longjmp 回該 task 的 setjmp 位置:

if (setjmp(task->env) == 0) {
        task_add(task);
        longjmp(sched, 1);
    }

第二部份是兩個 task 交互排程,每次執行一個 loop 後,呼叫 task_add() 重新將 task 加入 list 的尾端,並呼叫 task_switch 指定 list 的 first task 執行:

while(1) {
    if (setjmp(env) == 0) {
        task_add(task);
        task_switch();
    }
}

第三部份完成該 task,會呼叫 longjmp(sched, 1) 跳到 schedule(),接著會執行 task_join(&tasklist) 執行尚未執行完的 task:

printf("Task 1: complete\n");
longjmp(sched, 1);

TODO: 引入其他快速的 PRNG 實作並比較 MCTS 實作獲得的效益

MCTS 中,利用 PRNG 來產生亂數,以隨機挑選每次模擬中下一步的位置。原本的實作使用 glibcrand() 函式來產生亂數。以 4x4 的棋盤為例,每次最多可能有 15 種不同的選項,我用 014 來代表這些選項,並使用 rand() 進行了一億次的抽樣,接著利用 perf stat 執行五次 ttt 來測量程式的效能表現,分佈結果如下。

rand()

rand

Performance counter stats for './qtest':

           4091.34 msec task-clock                       #    0.153 CPUs utilized             
                40      context-switches                 #    9.777 /sec                      
                 6      cpu-migrations                   #    1.467 /sec                      
            4,9967      page-faults                      #   12.213 K/sec                     
     197,2995,5128      cycles                           #    4.822 GHz                       
     348,7598,8301      instructions                     #    1.77  insn per cycle            
      62,0083,0615      branches                         #    1.516 G/sec                     
       1,4453,6081      branch-misses                    #    2.33% of all branches           

      26.799816027 seconds time elapsed

       4.045053000 seconds user
       0.046658000 seconds sys

mt19937_rand()

mt

Performance counter stats for './qtest':

           3850.85 msec task-clock                       #    0.274 CPUs utilized             
                27      context-switches                 #    7.011 /sec                      
                 7      cpu-migrations                   #    1.818 /sec                      
            4,8811      page-faults                      #   12.675 K/sec                     
     185,1063,2378      cycles                           #    4.807 GHz                       
     324,6618,8678      instructions                     #    1.75  insn per cycle            
      57,2558,3620      branches                         #    1.487 G/sec                     
       1,4270,6467      branch-misses                    #    2.49% of all branches           

      14.057938599 seconds time elapsed

       3.811129000 seconds user
       0.040032000 seconds sys

SplitMix64()

split

Performance counter stats for './qtest':

           3906.89 msec task-clock                       #    0.178 CPUs utilized             
                19      context-switches                 #    4.863 /sec                      
                 4      cpu-migrations                   #    1.024 /sec                      
            4,3332      page-faults                      #   11.091 K/sec                     
     187,3365,9566      cycles                           #    4.795 GHz                       
     330,9484,1594      instructions                     #    1.77  insn per cycle            
      58,3272,5701      branches                         #    1.493 G/sec                     
       1,3575,9274      branch-misses                    #    2.33% of all branches           

      21.998475802 seconds time elapsed

       3.859266000 seconds user
       0.048090000 seconds sys

wyhash()

wy

Performance counter stats for './qtest':

           3896.78 msec task-clock                       #    0.360 CPUs utilized             
                24      context-switches                 #    6.159 /sec                      
                 0      cpu-migrations                   #    0.000 /sec                      
            4,7946      page-faults                      #   12.304 K/sec                     
     187,4050,0413      cycles                           #    4.809 GHz                       
     331,1468,0080      instructions                     #    1.77  insn per cycle            
      58,1496,6440      branches                         #    1.492 G/sec                     
       1,4438,8132      branch-misses                    #    2.48% of all branches           

      10.825147098 seconds time elapsed

       3.835545000 seconds user
       0.061725000 seconds sys

結論:
亂數分佈除了 mt19937_rand() 相對平均一些,其他分佈都有較明顯的差異。
雖然每個方法都大約差了0.05 至 0.2秒,和差幾億個 cycle,但是沒有像 vax-r 同學的實驗結果差異那麽明顯,目前不確定是否是實驗方法有細節的不同。

TODO: 改善定點數

改寫 lab0-c 既有的 shannon_entropy.c 和 log2_lshift16.h,採用更有效、準確度 (accuracy) 更高的定點數運算實作,需要有對應的數學統計分析和實際執行的討論。