Try   HackMD

N04: kxo

主講人: jserv / 課程討論區: 2025 年系統軟體課程
:mega: 返回「Linux 核心設計」課程進度表

:memo: 預期目標

  1. 既然人工智慧已融入你我的生活,本課程呼應杜威博士提出在實踐中學習 (亦譯作「做中學」) 的教育思想,引導學員接觸 (古典) 人工智慧的基本概念。
  2. 藉由改寫井字遊戲來熟悉數值系統bitwise 操作排程器原理和 Linux 核心的程式開發介面

:rocket: 井字遊戲

在 Google 網頁搜尋 "tictactoe",會發現網頁出現井字遊戲 (tic-tac-toe),可以在網頁直接玩。

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

jserv/ttt 專案提供可跟電腦對弈的互動環境,其對弈的演算法包含以下:

編譯並執行:

$ make
$ ./ttt
 1 |             
 2 |             
 3 |             
 4 |             
---+-------------
      A  B  C  D
X> 

輸入 b2 (即橫向 ABCD 的 "B" 和縱向 1234 的 "2") 並按下 Enter 鍵,表示人類下 × 棋子,電腦程式會以 棋子做出回應,例如:

 1 |             
 2 |     ×  ○    
 3 |             
 4 |             
---+-------------
      A  B  C  D

最終賽局的參考輸出:

 1 |             
 2 |     ×  ○    
 3 |        ×  ○ 
 4 |           × 
---+-------------
      A  B  C  D
X won!
Moves: B2 -> C2 -> C3 -> D3 -> D4

閱讀 ttt 專案的 "Game Rules" 段落以得知勝負評斷規則。

關於 MCTS 演算法應用於井字遊戲,參閱〈Tic Tac Toe at the Monte Carlo〉,搭配以下解說影片:

馬里蘭大學的網站 Monte Carlo Tree Search Tic-Tac-Toe AI 提供互動的網頁,可理解 MCTS 演算法。

「蒙地卡羅方法」是利用隨機抽樣技術進行模擬以解決數學問題的策略,在第二次世界大戰期間首次被系統化應用於科學研究,促成 MANIAC(Mathematical Analyzer, Numerical Integrator and Computer)的誕生。由 Stanislaw Ulam, John von Neumann, Nicholas Metropolis、Enrico Fermi 等學者開發,這種以抽樣統計為基礎的方法,主要用於解決原子彈設計中的中子隨機擴散問題及 Schrödinger 方程特徵值的估算問題。最初由 Stanislaw Ulam 提出概念,後經 John von Neumann 深入研究,並於 1949 年以 The Monte Carlo method 一文公諸於世,隨著電腦時代的來臨,該手法從原始的手動產生隨機數解題轉變為現今的數值方法。

「蒙地卡羅方法」的命名源自 Nicholas Metropolis,靈感來自於其隨機性質與賭博的關聯,尤其是與北非蒙地卡羅城市中華麗的賭場生活息息相關。這方法適用於任何包含隨機性的過程,能夠透過大量模擬單一事件並統計平均值,來獲得在特定條件下的最可能結果。廣泛應用於自然現象如布朗運動、電波噪聲、基因突變、即時交通狀況等,顯示其多樣的適用性。又因該方法的高度可平行處理特性,廣泛應用於電腦圖形學中的光線追蹤技術以及分子動力學模擬等領域。

延伸閱讀: 亂灑一地的針其實有意義!布豐實驗與蒙地卡羅方法

:calling: Linux 核心的浮點數運算

Robert Love 在 Linux Kernel Development 一書論及浮點運算:

No (Easy) Use of Floating Point
When using floating-point instructions kernel normally catches a trap and then initiates the transition from integer to floating point mode. Unlike user-space, the kernel does not have the luxury of seamless support for floating point because it cannot easily trap itself. Using a floating point inside the kernel requires manually saving and restoring the floating point registers. Except in the rare cases, no floating-point operations are in the kernel.

Rusty Russell 在 Unreliable Guide To Hacking The Linux Kernel 則說:

The FPU context is not saved; even in user context the FPU state probably won't correspond with the current process: you would mess with some user process' FPU state. If you really want to do this, you would have to explicitly save/restore the full FPU state (and avoid context switches). It is generally a bad idea; use fixed point arithmetic first.

Lazy FP state restore

CVE-2018-3665 存在於 Intel Core 系列微處理器中,因為 speculative execution(推測執行)技術中的一些缺陷加上特定作業系統中對 FPU(Floating point unit)進行 context switch 所產生的漏洞,允許一個本地端的程式洩漏其他程序的 FPU 暫存器內容。

Lazy FP state leak 的原理是透過 FPU 的 lazy state switching 機制達成。因為 FPU 和 SIMD 暫存器不一定會在每個任務持續被使用到,因此作業系統排程器可將不需要使用到 FPU 的任務,標註為不可使用 FPU,而不更改裡面的內容。發生 context switch 時,核心先對 FPU/SIMD 做標記,假如在待會新的 context 都沒有用到 FPU/SIMD 相關運算則其 state 將持續保存上個 context 的內容,反之,則會在新的 context 進行相關 ops 前對 kernel 發 trap,以使其保存上個 context 的內容。有鑑於此特性,資訊安全人員發現其可能被濫用的缺陷,因為 FPU/SIMD 相關暫存器不只保存浮點數運算相關資料,其也被用來暫存加密相關資料,例如 AES 指令集。

作業系統相關工程師為了修復此缺陷讓作業系統效能下降 (其中一個解決方法是每次 context switch 都強制 FPU/SIMD state 一起切換,進而增加不少切換時間開銷),而麻省理工的研究員在 2019 年底提出同時能解決此缺陷又不失效能的途徑 A better approach to preventing Meltdown/Spectre attacks

然而,在現今的亂序執行 CPU 中,lazy state switching 裡會設定的 "FPU not available" 可能沒辦法馬上被偵測到,導致我們在 process B,但仍然可以存取到 process A 的 FPU 暫存器內容,進而達到攻擊的目的。

基於上述原因,儘管我們在 Linux 核心模式中仍可使用浮點運算,但這不僅會拖慢執行速度,還需要手動儲存/取回相關的 FPU 暫存器,因此核心文件不建議在核心中使用到浮點運算。

afcidk 透過開發簡單的 Linux 核心模組,來測試在單純的浮點數運算及整數運算花費的時間差異。

程式碼可見 floating.ko

Image Not Showing Possible Reasons
  • The image was uploaded to a note which you don't have access to
  • The note which the image was originally uploaded to has been deleted
Learn More →

可見到,以單純的運算來說,核心模式中的浮點數運算,時間成本較整數運算高。

定點數

並非所有 Linux 支援的硬體都具備 FPU,且 Linux 核心也不建議進行浮點數計算,於是實務上要改用定點數計算 (Fixed-point arithmetic)。相較於 IEEE 754 浮點數標準,定點數運算沒有一致的格式,常見的處理方式如下:

  • 小數點的位置固定
  • 以十進位為例,精確度到小數點後 3 位,number10×103=10
  • 以 2 進位為例,精確度到小數點後 4 位,number2×24=2

定點數的運算存在以下常見規格:

  • 定點數加法與減法 - 可直接進行
  • 定點數乘法 - 結果須再 bn (b 為進制,n 為小數位數)
  • 定點數除法 - 結果須再 bn (b 為進制,n 為小數位數)

uptime 命令可顯示過去 1 / 5 / 15 分鐘的系統平均負載,類似以下輸出:

$ uptime
11:40:45 up 49 days,  2:47,  3 users,  load average: 0.01, 0.06, 0.02

Linux 核心一類的主流作業系統提供 load average,顧名思義就是要查看近期的負載平均,近期可分成三種時間間隔,計算公式如下:
loadt=n×α+(1α)×loadt1

這個公式很常見,在作業系統決定下一個 cpu burst 時就會運用 Exponential average。這個公式的好處是簡單且省記憶體的把過去的歷史資料都記住,而且會 aging。而很明顯的,這個公式的計算絕對牽涉的浮點數運算,而 Linux 不建議使用浮點運算,於是定點運算就派上用場。

fixed_power_int

研讀 Linux 核心原始程式碼的 Load Average 處理機制,可發現 linux/kernel/sched/loadavg.c 有一個用來計算定點數乘冪的函式 fixed_power_int,適合作為理解定點數計算的範例

/**
 * fixed_power_int - compute: x^n, in O(log n) time
 *
 * @x:         base of the power
 * @frac_bits: fractional bits of @x
 * @n:         power to raise @x to.
 *
 * By exploiting the relation between the definition of the natural power
 * function: x^n := x*x*...*x (x multiplied by itself for n times), and
 * the binary encoding of numbers used by computers: n := \Sum n_i * 2^i,
 * (where: n_i \elem {0, 1}, the binary vector representing n),
 * we find: x^n := x^(\Sum n_i * 2^i) := \Prod x^(n_i * 2^i), which is
 * of course trivially computable in O(log_2 n), the length of our binary
 * vector.
 */
static unsigned long fixed_power_int(unsigned long x,
                                     unsigned int frac_bits,
                                     unsigned int n)
{
    unsigned long result = 1UL << frac_bits;

    if (n) {
        for (;;) {
            if (n & 1) {
                result *= x;
                result += 1UL << (frac_bits - 1);
                result >>= frac_bits;
            }
            n >>= 1;
            if (!n)
                break;
            x *= x;
            x += 1UL << (frac_bits - 1);
            x >>= frac_bits;
        }
    }

    return result;                                                                                                                                            
}

以二進位來思考:
xn=xni2i, ni0,1

x11=x120×x121×x022×x123 為例:

  • if (n & 1) 是否成立,即對應到 ni 是否為 1
  • 1UL << frac_bits 代表定點數 1 (二進位,小數點後有 frac_bits 位)
  • x += 1UL << (frac_bits - 1) 會讓 x 進行無條件進位

因為是進行定點數乘法,還要再除 2frac_bits,亦即 x >>= frac_bits

loadavg_proc_show

man proc(5) 可知 /proc/loadavg 記錄著 load average 相關的資訊

$ cat /proc/loadavg

觀察 fs/proc/loadavg.c 會發現 /proc/loadavg 是個 pseudo 檔案系統,對照實際輸出資料的函式 loadavg_proc_show

static int loadavg_proc_show(struct seq_file *m, void *v)
{
    unsigned long avnrun[3];

    get_avenrun(avnrun, FIXED_1 / 200, 0);

    seq_printf(m, "%lu.%02lu %lu.%02lu %lu.%02lu %ld/%d %d\n",
               LOAD_INT(avnrun[0]), LOAD_FRAC(avnrun[0]), LOAD_INT(avnrun[1]),
               LOAD_FRAC(avnrun[1]), LOAD_INT(avnrun[2]), LOAD_FRAC(avnrun[2]),
               nr_running(), nr_threads,
               idr_get_cursor(&task_active_pid_ns(current)->idr) - 1);
    return 0;
}

觀察 seq_printf 使用的格式 %lu.%02lu,再由 include/linux/sched/loadavg.h 理解 LOAD_INTLOAD_FRAC 的實作

#define LOAD_INT(x) ((x) >> FSHIFT)
#define LOAD_FRAC(x) LOAD_INT(((x) & (FIXED_1-1)) * 100)

%lu.%02lu 分別使用 LOAD_INTLOAD_FRAC 來組成完整的十進位數值,其中 LOAD_INT 直接取用整數部分,而 LOAD_FRAC 會取用小數點 (fraction) 部分,注意 FIXED_1-1 可視為小數點部分的 mask,乘上 100 的目的是保留十進位下的小數點後 2 位,而 FIXED_1/200 可理解為是十進位下的定點數 0.005,由於結果會保留小數點後 2 位,加上此數值目的是進行四捨五入

load average 的資料保存於 avnrun

Load Average

linux/kernel/sched/loadavg.c 可見負責更新 avnrun 的函式

/*
 * calc_load - update the avenrun load estimates 10 ticks after the
 * CPUs have updated calc_load_tasks.
 *
 * Called from the global timer code.
 */
void calc_global_load(unsigned long ticks)
{
    unsigned long sample_window;
    long active, delta;

    sample_window = READ_ONCE(calc_load_update);
    if (time_before(jiffies, sample_window + 10))
        return;

    /* Fold the 'old' NO_HZ-delta to include all NO_HZ CPUs. */
    delta = calc_load_nohz_fold();
    if (delta)
        atomic_long_add(delta, &calc_load_tasks);

    active = atomic_long_read(&calc_load_tasks);
    active = active > 0 ? active * FIXED_1 : 0;

    avenrun[0] = calc_load(avenrun[0], EXP_1, active);
    avenrun[1] = calc_load(avenrun[1], EXP_5, active);
    avenrun[2] = calc_load(avenrun[2], EXP_15, active);

    WRITE_ONCE(calc_load_update, sample_window + LOAD_FREQ);

    /*
     * In case we went to NO_HZ for multiple LOAD_FREQ intervals
     * catch up in bulk.
     */
    calc_global_nohz();
}

calc_load 是實際計算 load average 的函式,在 include/linux/sched/loadavg.h 可見其函式實作

/*
 * a1 = a0 * e + a * (1 - e)
 */
static inline unsigned long calc_load(unsigned long load,
                                      unsigned long exp,
                                      unsigned long active)
{
    unsigned long newload;

    newload = load * exp + active * (FIXED_1 - exp);
    if (active >= load)
        newload += FIXED_1 - 1;

    return newload / FIXED_1;
}

計算公式使用 Exponential smoothing,可參閱〈Linux Kernel Load Average 計算分析〉。其公式如下:
St=α×Xt1+(1α)×St1,where0<α<1

  • St : 時間 t 的平滑均值。
  • St1 : 時間 t-1 的平滑均值。
  • Xt1 : 時間 t-1 的實際值。
  • α : 平滑因子 (smoothing factor)。

對應 Linux 核心程式碼如下所示: (取自 include/linux/sched/loadavg.h)

newload = load * exp + active * (FIXED_1 - exp);
  • St = newload
  • St-1 = load
  • Xt-1 = active (目前 'RUNNABLE + TASK_UNINTERRUPTIBLE' Process 總數: 全域變數 calc_load_tasks)
  • α = exp

使用定點數:

  • LSB 11 bits : mantissa
  • MSB 53 bits : exponent

因此定點數 1.0(1 << 11),也就是 FIXED_1 (第 19 行),這就是為何 (1α) 成為 (FIXED_1 - exp)。

  • include/linux/sched/loadavg.h
#define FSHIFT 11 /* nr of bits of precision */ #define FIXED_1 (1<<FSHIFT) /* 1.0 as fixed-point */ #define LOAD_FREQ (5*HZ+1) /* 5 sec intervals */ #define EXP_1 1884 /* 1/exp(5sec/1min) as fixed-point */ #define EXP_5 2014 /* 1/exp(5sec/5min) */ #define EXP_15 2037 /* 1/exp(5sec/15min) */ /* * a1 = a0 * e + a * (1 - e) */ static inline unsigned long calc_load(unsigned long load, unsigned long exp, unsigned long active) { unsigned long newload; newload = load * exp + active * (FIXED_1 - exp); if (active >= load) newload += FIXED_1-1; return newload / FIXED_1; } extern unsigned long calc_load_n(unsigned long load, unsigned long exp, unsigned long active, unsigned int n); #define LOAD_INT(x) ((x) >> FSHIFT) #define LOAD_FRAC(x) LOAD_INT(((x) & (FIXED_1-1)) * 100)
  • calc_load 函式:

第 20 行: 每隔 5 秒計算 1 / 5/ 15 分鐘的系統平均負載。
第 21-23 行: 定義 1 / 5 / 15 分鐘的 α 值。

  • 1 分鐘 α 值: 1/exp(5sec/1min) = 0.92004441462,其定點數 0.92004441462 * FIXED_1 = 1884
  • 5 分鐘 α 值: 1/exp(5sec/5min) = 0.98347145382,其定點數 0.98347145382 * FIXED_1 = 2014
  • 15 分鐘 α 值: 1/exp(5sec/15min) = 0.994459848,其定點數 0.994459848 * FIXED_1 = 2037
    第 34-35 行: 該判斷式成立的話,代表目前 'RUNNABLE + TASK_UNINTERRUPTIBLE' 行程 (process) 總數大於上次的 load average,則 newload 無條件進位。
    第 37 行: 由於在第 33 行進行定點數乘法運算,所以其結果需要還原 (除以 FIXED_1)。
  • calc_global_load 函式

第 355-356 行: 每隔 5 秒 (其實是,5秒 + 10 ticks) 計算 1/5/15 分鐘的系統平均負載。
第 362-367 行: 計算當下 RUNNABLE + TASK_UNINTERRUPTIBLE Process/Task 總數。為何需要加入 TASK_UNINTERRUPTIBLE task 呢?可參閱 Brendan Gregg 的文章 Linux Load Averages: Solving the Mystery,詳盡地解說其來龍去脈,Brendan 一開始想透過 git log -p kernel/sched/loadavg.c 歷史紀錄找出編修紀錄,但對應修改歷史更久遠,回溯至 v0.99.13 到 v0.99.15 (1993 年),而 git 要在 2005 年才出現。如下所示 (節錄自 Linux Load Averages: Solving the Mystery):

From: Matthias Urlichs <urlichs@smurf.sub.org>
Subject: Load average broken ?
Date: Fri, 29 Oct 1993 11:37:23 +0200

The kernel only counts "runnable" processes when computing the load average.
I don't like that; the problem is that processes which are swapping or waiting on "fast", i.e. noninterruptible, I/O, also consume resources.

It seems somewhat nonintuitive that the load average goes down when you replace your fast swap disk with a slow swap disk

Anyway, the following patch seems to make the load average much more consistent WRT the subjective speed of the system. And, most important, the load is still zero when nobody is doing anything. ;-)

--- kernel/sched.c.orig Fri Oct 29 10:31:11 1993
+++ kernel/sched.c  Fri Oct 29 10:32:51 1993
@@ -414,7 +414,9 @@
unsigned long nr = 0;

for(p = &LAST_TASK; p > &FIRST_TASK; --p)
-       if (*p && (*p)->state == TASK_RUNNING)
+       if (*p && ((*p)->state == TASK_RUNNING) ||
+                  (*p)->state == TASK_UNINTERRUPTIBLE) ||
+                  (*p)->state == TASK_SWAPPING))
        nr += FIXED_1;
    return nr;
 }

一言以蔽之,考慮 system load average,而非 CPU load average。

信件發文者 Matthias Urlichs 認為 load average 代表以使用者角度觀察系統忙碌程度,也就是 system load average,而不是 CPU load average。TASK_UNINTERRUPTIBLE 表示 process 正在等待某特定事件 (例如: Task swapping、Disk I/O等等),這也需要算在 system load。底下場景說明為何需要考慮 TASK_UNINTERRUPTIBLE (節錄自 Linux Load Averages: Solving the Mystery):

A heavily disk-bound system might be extremely sluggish but only have a TASK_RUNNING average of 0.1, which doesn't help anybody.

第 368-370 行: 計算 1/5/15 分鐘的系統平均負載。全域變數 avenrun 用定點數格式儲存,所以使用 printk 輸出時,需做對應轉換。
第 372 行: 更新下次計算系統平均負載時間,即 n + 5 秒

  • kernel/sched/loadavg.c
/* * calc_load - update the avenrun load estimates 10 ticks after the * CPUs have updated calc_load_tasks. * * Called from the global timer code. */ void calc_global_load(unsigned long ticks) { unsigned long sample_window; long active, delta; sample_window = READ_ONCE(calc_load_update); if (time_before(jiffies, sample_window + 10)) return; /* * Fold the 'old' NO_HZ-delta to include all NO_HZ CPUs. */ delta = calc_load_nohz_read(); if (delta) atomic_long_add(delta, &calc_load_tasks); active = atomic_long_read(&calc_load_tasks); active = active > 0 ? active * FIXED_1 : 0; avenrun[0] = calc_load(avenrun[0], EXP_1, active); avenrun[1] = calc_load(avenrun[1], EXP_5, active); avenrun[2] = calc_load(avenrun[2], EXP_15, active); WRITE_ONCE(calc_load_update, sample_window + LOAD_FREQ); /* * In case we went to NO_HZ for multiple LOAD_FREQ intervals * catch up in bulk. */ calc_global_nohz(); }

loadavg_proc_show 輸出 1 / 5 / 15 分鐘的系統平均負載:

  • LOAD_INT: 取得整數部分。
  • LOAD_FRAC: 取得小數點部分。因為只取小數點兩位,所以此巨集裡可看到乘以 100。
  • fs/proc/loadavg.c
static int loadavg_proc_show(struct seq_file *m, void *v) { unsigned long avnrun[3]; get_avenrun(avnrun, FIXED_1/200, 0); seq_printf(m, "%lu.%02lu %lu.%02lu %lu.%02lu %ld/%d %d\n", LOAD_INT(avnrun[0]), LOAD_FRAC(avnrun[0]), LOAD_INT(avnrun[1]), LOAD_FRAC(avnrun[1]), LOAD_INT(avnrun[2]), LOAD_FRAC(avnrun[2]), nr_running(), nr_threads, idr_get_cursor(&task_active_pid_ns(current)->idr) - 1); return 0; }

強化學習(Reinforcement Learning)

最基本的強化學習(RL)可利用馬可夫決策過程(Markov Decision Process, MDP)來描述,即有一個狀態集合 S 和一個動作集合 A,透過動作 a 從狀態 s 轉移到狀態 s 的機率為:
Pa(s,s)=Pr(St+1=s|St=s,At=a)

並定義透過 as 轉移到 s 的獎勵(Reward)為 Ra(s,s)

機器學習可以大致分為監督式學習(supervised learning)、非監督式學習(unsupervised learning)及強化學習(RL)三大類,其中 RL 特別適合用來解決某些難以標註標籤(label)且人類無法提供正確答案的問題。

馬可夫決策過程(Markov Decision Process)

馬可夫性質描述隨機過程的無記憶性(memoryless property),即未來的狀態演變僅依賴於目前狀態,而與過去的歷史無關。數學表述如下:
Pr(Xn=xn|Xn1=xn1,,X0=x0)=Pr(Xn=xn|Xn1=xn1)

  • 離散時間馬可夫鏈(Discrete-time Markov Chain)

一個滿足馬可夫性質的隨機變數序列 X1,X2, 可稱為馬可夫鏈(Markov Chain)。若
Pr(X1=x1,,Xn=xn)>0

則滿足
Pr(Xn+1=x|X1=x1,X2=x2,,Xn=xn)=Pr(Xn+1=x|Xn=xn)

亦即,從目前狀態轉移到下一個狀態的機率只依賴目前狀態,而與過去的狀態無關。

這些隨機變數的集合 S 也稱為狀態空間(state space),並且是可數集 (countable set)。當 S 為有限集合時,轉移機率分佈可定義為轉移矩陣(Transition Matrix)P,其中:
pij=Pr(Xn+1=j|Xn=i)

每列的總和為 1,且 P 是一個右隨機矩陣(right stochastic matrix)。

馬可夫決策過程(Markov Decision Process, MDP)可用四元組(4-tuple)(S,A,Pa,Ra) 來描述,其中:

  • S:狀態空間(state space)
  • A:動作空間(action space)(As 代表從 s 開始的可用動作集合)
  • Pa(s,s)=Pr(st+1=s|st=s,at=a):狀態轉移機率
  • Ra(s,s):獎勵函式(Reward function

若對於每個狀態轉移,我們固定採取某個動作,則 MDP 會收斂成一個馬可夫鏈。此時,每條馬可夫鏈都可計算出一個折扣總和(discount sum),定義如下:
E[t=0γtRat(st,st+1)],0γ1

MDP 的目標是找到一個策略函式(policy function) π:SA,使得對應的馬可夫鏈能夠最大化折扣回報。其中 π(s)=a 代表在狀態 s 時選擇動作 a

其中一種解法是動態規劃(Dynamic Programming, DP),定義兩個陣列 V(s)π(s),分別為:

  • V(s):依照該策略一路演變至狀態 s 的折扣回報
  • π(s):對應的最佳動作選擇

遞迴關係如下:
V(s)=sPπ(s)(s,s)(Rπ(s)(s,s)+γV(s))

π(s)=argmaxa{sPa(s,s)(Ra(s,s)+γV(s))}

強化學習演算法(RL Algorithm)

強化學習即是 MDP 的一種,在機率分佈和獎勵函式皆未知的情況下,試圖最大化以下函數:
Q(s,a)=sPa(s,s)(Ra(s,s)+γV(s))

透過學習每個狀態-動作對(state-action pair)(s,a) 來更新估計值。

強化學習的核心精神在於利用過往經驗來決定最佳的動作選擇,以獲得最高的回報。

策略函式 π 定義為:
π:A×S[0,1]

π(a,s)=Pr(At=a|St=s)

即在狀態 s 下選擇動作 a 的機率。

狀態價值函數(State-value Function Vπ(s) 定義為從狀態 s 出發時的期望折扣回報(expected discounted return):
Vπ(s)=E[G|S0=s]=E[t=0γtRt+1|S0=s]

其中,
G=t=0γtRt+1

代表從目前狀態開始,未來所能獲得的累積報酬。

強化學習的核心目標是學習最佳的 π,以最大化 Vπ(s),並透過探索與利用(exploration vs. exploitation)策略來達成最優解。

時間差分學習(Temporal Difference Learning, TD Learning)

時間差分學習(TD Learning)是一種強化學習(Reinforcement Learning, RL)中的核心方法,介於動態規劃(Dynamic Programming, DP)與蒙地卡羅方法(Monte Carlo Methods)之間。

  • TD Learning 不需要完整的環境模型(model-free),而動態規劃方法則需要已知的轉移機率 Pa(s,s) 和獎勵函式 Ra(s,s)
  • 蒙地卡羅方法需要等到整個回合(episode)結束後,才能計算回報並進行學習;而 TD Learning 則是在每步更新學習結果,不需要等待回合結束。

TD Learning 是一種在線學習(online learning)方法,可在每步觀察到新狀態後即時更新價值函式,這使其適用於持續性任務(continuous tasks)。

最基本的時間差分方法是 TD(0),其目標是學習狀態價值函式(state-value function):
V(s)=E[Gt|St=s]

透過 Bellman Equation,我們可遞迴估計 V(s),TD(0) 利用以下的更新規則:
V(St)V(St)+α[Rt+1+γV(St+1)V(St)]

其中:

  • V(St):當前狀態 St 的價值估計
  • Rt+1:當前時間步驟 t 執行動作後獲得的獎勵
  • γ:折扣因子(discount factor),控制未來獎勵的重要性
  • α:學習率(learning rate),控制更新的步伐

這個更新公式的關鍵想法是:目前狀態的價值應該向「即時獎勵 + 折扣後的下一狀態價值」靠攏。

TD(0) 的特點:

  1. 不需要等待回合結束:每個時間步驟都可以立即學習。
  2. 模型無關(model-free):不需要已知的轉移機率 Pa(s,s) 或獎勵函式 Ra(s,s)
  3. 具備隨機逼近性(stochastic approximation):當 α 適當選擇時,V(s) 可收斂到真實價值函式。

jserv/ttt 專案提供基本的 TD learning 實作。


以定點數實作 MCTS

Monte Carlo tree search (MCTS) 演算法是種搜尋樹的演算法,它通常應用於決策問題的求解,特別是在棋類等遊戲中。MCTS 透過隨機模擬和搜尋樹的擴展來評估每個可能的決策,以找到最佳的行動策略。
MCTS 在每次疊代會有 4 步,分別為:

  • Selection: 在搜尋樹中從根節點開始,根據某種策略選擇下一步要擴展的節點,以找到潛在最佳的行動。
  • Expansion: 對於選擇的節點,擴展子節點,即生成可能的下一步行動或狀態。這些子節點尚未被完全探索,需要進一步評估其價值。
  • Simulation: 對每個擴展的子節點進行模擬或評估。通常使用蒙特卡羅模擬來模擬隨機的遊戲或決策過程,以估計每個子節點的潛在價值或勝率。
  • Backpropagation: 將模擬結果向上反向傳播到搜尋樹的根節點,更新每個節點的統計資訊,如存取次數和累計獎勵。這有助於調整每個節點的價值估計,以改進未來的選擇策略。

每步的最終選擇即是所有模擬結果當中做最多次的選擇。

但在 Selection 的時候最大的問題是在一開始只有很少量的 simulation result 時,有可能會過度偏向某個不正確的選則,從而影響後面模擬的選擇(因為 selection 會從過去的模擬結果給的權重做選擇,會往錯的選擇越鑽越深),所以我們必須在 exploration 和 exploitation 之間做平衡。
其中一個作法是利用 UCT (Upper Confidence Bound 1 applied to trees) 來給節點權重,公式如下
wini+cln(Ni)ni

  • wi 代表第 i 次決定後該節點贏的次數
  • ni 代表該節點在第 i 次決定後共做幾次模擬
  • Ni 代表該節點的父節點在第 i 次決定後共做幾次模擬
  • c 代表 exploration parameter ,通常定為 2

此公式的前後兩項分別對應到 exploitation 和 exploration ,當某個節點贏的次數越多,該節點的 exploitation power 越高,而在模擬次數很少時, 後面那一項容易比較高。

使用 UCT 的 MCTS 已經被證明會收斂成 minimax ,但實際上最基本的 MCTS 只有在 Monte Carlo Perfect games 才會收斂。不過 MCTS 有個優勢,它不需要特別給定 evaluation function ,所以幾乎可以適用於任何遊戲規則。特別是在高 branching factor 的遊戲中表現較傳統演算法更優(因此圍棋、西洋棋這類 high branching factor 的遊戲非常適合用 MCTS)。

在 high branching factor 遊戲當中,要建構出包含所有選擇所產生的搜尋樹太耗費成本, MCTS 事實上並不會推導出所有可能結果,而是根據每一輪的 selection, expansion, simulation 與 backpropagation 來建構這棵樹。所以決定節點權重的方法非常重要,因為每一次 selection 都是選擇目前節點的子節點中最高權重的節點往下走,例如上述提到的 UCT 是一種選擇,但這也有機會讓這棵樹往完全錯誤的方向長下去。

定點數運算

定點數運算中, Q notation 是一種指定二進位定點數參數的方法。

  • Qm.f:例如 Q23.8 表示該數有 23 個整數位、8 個小數位,是 32 位二補數。

假設我們將 fraction bit 設為 d,則一定點數 N 的實際值為 Nd 因此當我們在做定點數的加減時可以直接相加。
(N1+N2)÷d=N1+N2d
(N1N2)÷d=N1N2d

但若對於定點數的乘除不多加處理的話,結果會變為:
(N1N2)÷d=N1N2d
(N1/N2)÷d=N1/N2d

顯然與預期的結果不同,因此在計算完定點樹的乘除之,需要額外對積數與商數右移或左移 d
(N1N2)d÷d=N1N2d2=N1dN2d
(N1)N2dd=N1N2=N1d/N2d

了解定點數,將 MCTS 當中浮點數運算進行以下更動。
首先在 console.h 定義 SCALING_BITS

#define SCALING_BITS 8

calculate_win_value ,將 1.0, 0.0, 0.5 全部更換成客製化的定點數的形式。

Q23_8 calculate_win_value(char win, char player)
{
    if (win == player)
-        return 1.0;
+        return 1U << SCALING_BITS;
    if (win == (player ^ 'O' ^ 'X'))
-        return 0.0;
+        return 0U;
-    return 0.5;
+    return 1U << (SCALING_BITS - 1);
}

另一處重點變更在於 uct_score 的計算函式。
公式如下
wini+cln(Ni)ni

  • wi 代表第 i 次決定後該節點贏的次數
  • ni 代表該節點在第 i 次決定後共做幾次模擬
  • Ni 代表該節點的親代節點在第 i 次決定後共做幾次模擬
  • c 代表 exploration parameter ,通常定為 2

因為在 EXPLORATION_FACTOR * sqrt(log(n_total) / n_visits) 這一項, sqrt()log() 使用浮點數運算,因此我們要設計對應的定點數運算來取代這兩個函式。

static inline Q23_8 uct_score(int n_total, int n_visits, Q23_8 score)
{
    if (n_visits == 0)
        return INT32_MAX;
    Q23_8 result = score << Q / (Q23_8) (n_visits << Q);
    int64_t tmp =
        EXPLORATION_FACTOR * fixed_sqrt(fixed_log(n_total) / n_visits);
    Q23_8 resultN = tmp >> Q;
    return result + resultN;
}

Fixed point Square root

利用 第三周測驗一 的方法實作,並且轉為定點數。

int i_sqrt(int x)
{
    if (x <= 1)
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}
Q23_8 fixed_sqrt(Q23_8 x)
{
    if (x <= 1 << Q)
        return x;

    Q23_8 z = 0;
    for (Q23_8 m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;
    }
    z = z << Q / 2;
    return z;
}

Fixed point log

嘗試使用泰勒展開實作自然對數的計算,公式如下:
ln(1+x)=xx22+x33x44+...
這裡補充說明下面的實作方法,如果單純以數學的觀點(先不加入 fixed point 的討論),那麼它實際求的是以下的東西:

2i=012i+1(y1y+1)2i+1

接下來會用兩種觀點切入,分別是使用 tanh1ln 的泰勒展開式。

首先是 tanh1,會發現 tanh1x=i=012i+1x2i+1,然後我們把上式帶入:

2i=012i+1(y1y+1)2i+1=2tanh1(y1y+1)

Let x=y1y+1, k=2tanh1x

tanhk2=x

x=ek1ek+1

ekx+x=ek1

ek=1+x1x=1+y1y+11y1y+1=y

k=lny

所以 lny=2i=012i+1(y1y+1)2i+1

接下來是使用 ln 的觀點:

2i=012i+1(y1y+1)2i+1

i=11i(y1y+1)ii=11i(y1y+1)i

=ln(1+y1y+1)ln(1y1y+1)

=ln(2y)ln2

=lny

可以發現,我們也求出一樣的結論了

參考實作:

Q23_8 fixed_log(int input)
{
    if (!input || input == (1U << Q))
        return 0;

    int64_t y = input << Q;
    y = ((y - (1U << Q)) << (Q)) / (y + (1U << Q));
    int64_t ans = 0U;
    for (unsigned i = 1; i < 20; i += 2) {
        int64_t z = (1U << Q);
        for (int j = 0; j < i; j++) {
            z *= y;
            z >>= Q;
        }
        z <<= Q;
        z /= (i << Q);

        ans += z;
    }
    ans <<= 1;
    return (Q23_8) ans;
}

comp

與浮點數 log 比較,可從實驗結果發現,使用泰勒展開,當數值越大時,誤差越大,且需要越多的項次才能做精準的估算。

為了改善泰勒展開數字越大時誤差越大的問題,嘗試利用逼近法求 log 的近似值,假設 A<=x<=B ,求其近似值,方法如下:
CAB,因為 log(C)log(AB)=12(log(A)+log(B)),因此我們可以得到 log(C)log(A)log(B) 的平均。

接著比較 xC,如果 x=C 及得到我們所求,如果 x<C ,我們則將 B 替換成 C ,如果 x>C ,我們則將 A 替換成 C,繼續下一輪的求近似。

因為求近似值時,須先知道 log(A)log(B) 的實際值,因此實作時,將以 log2(x) 做計算,且 A B 分別為 2m2m+1 ,其中 2m<=x<=2m+1

      A   x   B
------>---|---<------

實作如下:

Q23_8 fixed_log(int input)
{
    if (!input || input == 1)
        return 0;

    Q23_8 y = input << Q;  // int to Q15_16
    Q23_8 L = 1L << ((31 - __builtin_clz(y))), R = L << 1;
    Q23_8 Llog = (31 - __builtin_clz(y) - Q) << Q, Rlog = Llog + (1 << Q), log;

    for (int i = 1; i < 20; i++) {
        if (y == L)
            return Llog;
        else if (y == R)
            return Rlog;
        log = fixed_div(Llog + Rlog, 2 << Q);

        int64_t tmp = ((int64_t) L * (int64_t) R) >> Q;
        tmp = fixed_sqrt((Q23_8) tmp);

        if (y >= tmp) {
            L = tmp;
            Llog = log;
        } else {
            R = tmp;
            Rlog = log;
        }
    }

    return (Q23_8) log;
}

comp

結果和浮點數 log 相當接近。