# 2025q1 Homework3 (kxo) contributed by < `rota1001` > {%hackmd NrmQUGbRQWemgwPfhzXj6g %} :::danger 注意書寫規範! ::: :::info 以下本文中的圖(非引用的部份),只要涉及定點數的運算,縱軸都是乘上 `1<<Q` 的結果。 ::: ## 定點數運算 ### `fixed_sqrt` 改進 這是原本的 `fixed_sqrt` 函式: ```c int fixed_sqrt(int x) { if (x <= 1 << Q) return x; int z = 0; for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) { int b = z + m; z >>= 1; if (x >= b) x -= b, z += m; } z = z << Q / 2; return z; } ``` 會發現,它會先開完根號後再左移 `Q / 2`,可以發現二進位制下最低的 `Q / 2` 位永遠是 0,這些是喪失的精度。 所以可以改成在開根號前先左移 `Q` 位再進行根號運算,會發現: $$ \sqrt{(x\times2^Q)\times 2^Q}=\sqrt{x}\times 2^Q $$ 所以這樣能更準確的求出根號。然後前面的 `x <= 1 << Q` 也是沒有必要的,只要判斷 `x` 是否是 `0` 就好了(因為 `x` 會做左移,所以絕對不會是 1,然後開根號時要特判的應是被左移後為 1 或 0 的值)。修改後程式碼為以下: ```c int fixed_sqrt_modified(int x) { if (x <= 0) return 0; int64_t k = (x << Q); int z = 0; for (int m = 1UL << ((63 - __builtin_clzll(k)) & ~1UL); m; m >>= 2) { int b = z + m; z >>= 1; if (k >= b) k -= b, z += m; } return z; } ``` ### `Fixed point log` 改進 我在 [求 log 近似值](https://hackmd.io/@rota1001/fix-log-research) 說明了[作業說明](https://hackmd.io/@sysprog/linux2025-kxo/%2F%40sysprog%2Flinux2025-kxo-a#Fixed-point-log)中使用的第一種 `fixed_log` 函數的實作,接下來討論其中的第二種實作: 他是使用 $C=\sqrt{AB}$ 則 $\log{C}=\frac{\log{A}+\log{B}}{2}$,所以找到好計算的 $A$ 和 $B$ 就好。而好計算的 $A$ 和 $B$ 就是找到把 $C$ 夾住的兩個 2 的羃,也就是 $2^m \le C \le 2^{m+1}$,文中畫出來的圖長這樣: ![image](https://hackmd.io/_uploads/Sy6Hnsi21e.png) 這看起來很棒,然而實際上我們去畫圖會畫出這樣的圖: ![Q15_16](https://hackmd.io/_uploads/ryspvQhhye.png) 於是我們去看一下文中的程式碼(我把它調整成能動的版本,要注意的一件事是 `fixed_sqrt` 要支援 `int64_t` 的輸入),會發現,它做的事情是先取兩個 2 的羃來將 $C$ 夾住,並且在指數上進行二分搜尋。這裡的問題是在最前面: ```c int y = input << Q; // int to Q15_16 ``` 這裡將 `input` 由 `Q23_8` 轉成 `Q15_16`,然而後續的運算卻仍將它視為 `Q23_8` 來運算,所以算出來的結果會比預期結果還大,這裡統一用 `Q23_8` 來計算: ```diff - int y = input << Q; // int to Q15_16 + int y = input; // Use Q23_8 ``` 會發現完美的貼合了: ![Q23_8](https://hackmd.io/_uploads/SyfdYQnhJe.png) 然而,這個程式碼還有一些可以改進的空間,首先是需要一個迴圈,裡面有條件判斷會造成很多分支,而且每次都要進行一個 64 位元的根號運算,是很花時間的。然後如果輸入一個小於 1 的數的話,`(31 - __builtin_clz(y) - Q)` 會變成負數,造成左移操作變成未定義行為。 於是我去進行了以下的改進,完整的程式碼放在 [gist](https://gist.github.com/rota1001/6a13112a8ab8d85a1f95070077de8e38#file-fixed-log-c)。 參考 [picolibc/newlib/libm/math/s_log.c](https://github.com/picolibc/picolibc/blob/main/newlib/libm/math/s_log.c) 中的做法,並且改成使用定點數。因為上述程式碼實作是針對浮點數的,所以這裡不看原始程式碼,而是直接用數學分析此方法: 首先,這裡先寫一個 $\ln$ 的泰勒展開式備用: $$ \displaystyle \ln x = x - \frac{x^2}{2} + \frac{x^3}{3} - \frac{x^4}{4}+... $$ 我們現在要求 $\ln(1+f)$,然而,我們可以消除 $x$ 奇數羃的項,讓我們可以用相同的成本計算到更高的羃。 我們可以找到一個 $s$ 使得: $$ \displaystyle \ln(1+f) = \ln(\frac{1+s}{1-s})= \ln(1+s)-\ln(1-s) = 2\sum_{i=0}^{\infty}\frac{1}{2i+1}s^{2i+1} $$ 由 $\displaystyle 1+f = \frac{1+s}{1-s}$ 可以求得 $\displaystyle s=\frac{f}{2+f}$,而我們會把 $f$ 控制在 $[1, 2)$ 這個區間中,所以可以把 $s$ 視為接近 0 的數(當然我們在高次項還是要做處理) 我們可以對 $\ln(1+f)$ 做以下處理: $$ \displaystyle \ln(1+f) = 2\sum_{i=0}^{\infty}\frac{1}{2i+1}s^{2i+1}=2s + sR $$ 其中,$R$ 的部份我們會用一個 14 次的多項式去逼近。雖說是 14 次,但它只有偶數項,原文中是使用一個叫做 [Reme algorithm](https://en.wikipedia.org/wiki/Remez_algorithm) 的演算法來算出來的,它可以拿來對一個函數做多項式的近似。 ```c Lg1 = _F_64(6.666666666666735130e-01), /* 3FE55555 55555593 */ Lg2 = _F_64(3.999999999940941908e-01), /* 3FD99999 9997FA04 */ Lg3 = _F_64(2.857142874366239149e-01), /* 3FD24924 94229359 */ Lg4 = _F_64(2.222219843214978396e-01), /* 3FCC71C5 1D8E78AF */ Lg5 = _F_64(1.818357216161805012e-01), /* 3FC74664 96CB03DE */ Lg6 = _F_64(1.531383769920937332e-01), /* 3FC39A09 D078C69F */ Lg7 = _F_64(1.479819860511658591e-01); /* 3FC2F112 DF3E5244 */ ``` 然而,可以進一步的去改良 $2s$ 的精確度。因為 $s$ 是 $f$ 經過除法計算得到的,會含有精度的遺失,所以更好的方法是讓結果中大部份的值由 $f$ 表示出。可以發現: $$ \displaystyle (1-s)f = \frac{2}{1+f}\times f = 2s $$ 所以可以得到 $2s = f-sf=f-\frac{1}{2}f^2(1-s)$ $\displaystyle \text{Let }hfsq = \frac{f^2}{2}$,則: $$ \displaystyle \ln(1+f)=f-hfsq(1-s)+sR = f-(hfsq-s(hfsq+R)) $$ 最後,因為我們要限制 $f$ 的範圍,所以會乘上一個 2 的羃,最後要把它用對數律補回來,就完成了。接下來看看怎麼納入定點數的考量。 首先對於一個輸入的 `x`,我會先排除小於等於 0 的數,也就是說剩下的數都不會是 0。接下來我會先轉成 `uint64_t`,再把最高位的 1 一到 `1<<31` 的位置,並且紀錄移動的位數 `shift`。然後因為 $y=1+f$,所以把它減掉 `1<<31`,也就是以 31 位元的定點數調整後的 1。接下來的這些計算都會以 31 位元的定點數來計算。 ```c int shift = __builtin_clzll(y) - 32; y <<= shift; f = y - 0x80000000u; ``` 接下來就是求出 `s` 和 `hfsq`: ```c s = (f << 31) / ((2ull << 31) + f); hfsq = (f * f) >> 32; // (f * f / 2) >> 31 ``` 然後接下來用 python 把那些多項式中的數值用符合 31 位元定點數的方式計算出來(順便算一個 `ln2`,等下會用到): ```c const uint64_t a1 = 1431655765, a2 = 858993459, a3 = 613566760, a4 = 477218078, a5 = 390489239, a6 = 328862160, a7 = 317788895; const uint64_t ln2 = 1488522236; ``` 接下來就是照著代公式: ```c uint64_t s2 = (s * s) >> 31, s4 = (s2 * s2) >> 31; y = (s4 * (a2 + ((s4 * (a4 + ((s4 * a6) >> 31))) >> 31))) >> 31; y += (s2 * (a1 + ((s4 * (a3 + ((s4 * (a5 + ((s4 * a7) >> 31))) >> 31))) >> 31))) >> 31; y = (y * s) >> 31; y = f - (hfsq - ((s * (hfsq + y)) >> 31)); ``` 最後,依照位數調整結果: ```c return ((23 - shift) * ln2 + y) >> 23; ``` 以下是這個程式的結果,可以幾乎精準的計算 $\ln$ 到定點數的位數 ![better_log2](https://hackmd.io/_uploads/S1FQk3jnkx.png) 然後我們去看一下他在輸入小於 1 的時候的表現: ![yee](https://hackmd.io/_uploads/SyCYX162ke.png) 可以發現,他們是重合的。 ## Linux 核心的並行處理 ### 對 `open_cnt` 的錯誤判斷修正 > Pull request: [Fix incorrect open count check in release function](https://github.com/sysprog21/simrupt/pull/4) 這是針對 [simrupt](https://github.com/sysprog21/simrupt) 的討論。 如果開啟 3 個終端機 A、B、C,其中 C 使用以下命令來觀察訊息: ```bash sudo dmesg --follow ``` 接下來,A、B 都使用以下命令來和核心模組互動: ```bash sudo cat /dev/simrupt ``` 會發現,A、B 都會持續的輸出字元。 然後我們把終端機 B 用 Ctrl+C 將使用者程式中止,會發現,這個時候連終端機 A 都停止輸出字元了,然而預期行為應該是核心模組要持續的產生新的字元直到對那個裝置的引用數(`open_cnt`)變為 0。 會造成這個問題是在 `simrupt_release` 中,當 `atomic_dec_and_test(&open_cnt)` 為 0 的時候會讓生產者停止產生新的字元: ```cpp static int simrupt_release(struct inode *inode, struct file *filp) { pr_debug("simrupt: %s\n", __func__); if (atomic_dec_and_test(&open_cnt) == 0) { del_timer_sync(&timer); flush_workqueue(simrupt_workqueue); fast_buf_clear(); } pr_info("release, current cnt: %d\n", atomic_read(&open_cnt)); return 0; } ``` 但是,觀察 `atomic_dec_and_test` 的註解,會發現當 `open_cnt` 被操作完後的結果是 0 的時候,它會回傳 `true`: ```cpp /** * atomic_dec_and_test() - atomic decrement and test if zero with full ordering ... * Return: @true if the resulting value of @v is zero, @false otherwise. */ ``` 所以,我做了以下的修正: ```diff - if (atomic_dec_and_test(&open_cnt) == 0) { + if (atomic_dec_and_test(&open_cnt)) { ``` 而另外也會發現 [kxo](https://github.com/sysprog21/kxo) 中也有出現一樣的問題,會造成在使用者程式被中止後,核心模組仍不斷產生新的棋盤,於是提交了以下 pull request: > Pull request: [Fix incorrect open count check in release function](https://github.com/sysprog21/kxo/pull/4) ## 對奕的核心模組 ### 降低核心模組與使用者程式的溝通成本 首先,觀察到核心模組中會使用 `draw_board` 將棋盤畫出來,並且把整個棋盤(共 66 位元組)傳送給使用者程式,成本太高,於是做的第一個改進是讓核心模組傳送的是原始的棋盤(16 位元組),將 `draw_board` 改成以下: ```cpp static int draw_board(char *table) { int i = 0, k = 0; while (i < N_GRIDS) { draw_buffer[i++] = table[k++]; smp_wmb(); } return 0; } ``` 然後變成在使用者程式中將棋盤畫出來: > commit [b73de5f](https://github.com/sysprog21/kxo/commit/b73de5f1a1bf6914c4c7cba97e1a6e7fab46807e) 更進一步,我們會發現 `draw_buffer` 原本是用 16 格的字元陣列去存 `' '`、`'O'`、`'X'`,然而,要存下這個狀態不用用那麼多的空間,只要每格存 3 種狀態存 16 格就好。然而,3 不是 2 的羃,於是決定了用兩個 bit 來存一個格子的狀態,這樣就能用簡單的位元運算來處理。另外,觀察這三種字元的二進位制表示: - `'O'`:0b1001111 - `'X'`:0b1011000 - `' '`:0b0100000 會發現右移 2 再取最低的兩位可以直接表示出 3 種狀態,於是將 `draw_board` 改為以下: ```cpp static int draw_board(char *table) { int i = 0, k = 0; draw_buffer = 0; smp_wmb(); while (i < N_GRIDS) { draw_buffer |= ((table[k++] >> 2) & 3) << (i << 1); smp_wmb(); i++; } return 0; } ``` 然後在使用者程式做對應的修改,就完成了: > commit [00f0b76](https://github.com/sysprog21/kxo/commit/00f0b76004303e49a0abb573a424f781c47b5a20) ### 程式狀態控制 #### 壓縮傳遞狀態 > commit [4a04e1c](https://github.com/sysprog21/kxo/commit/4a04e1cff57e4f80d1a52c70592449e95f63ff2e) 首先會發現,程式的狀態儲存在 `attr_obj` 裡面,而使用者與它交流的方式是使用 `/sys/class/kxo/kxo/kxo_state`,而在核心模組中則定義了讀寫的函式 `kxo_state_show` 和 `kxo_state_store`。第一件改進的事情是它傳輸資料的方式是用一個長度為 6 的字串去傳送狀態,然而實際上只需要有 3 個位元就可以表示所有狀態了,所以先用位元運算去降低傳輸成本(至於儲存狀態的部份,因為不想去改變結構體就沒有做改動,如果有需要之後會補上) :::danger 提供數學分析 ::: 這裡的數學分析直接窮舉所有情況,去計算出實際的資訊量。 ```cpp int main() { int ans = 0; for (unsigned int i = 0; i < 43046721; i++) { int n = 0, m = 0; int tmp = i; for (int j = 0; j < 16; j++) { table[j] = charset[tmp % 3]; tmp /= 3; if (table[j] == 'O') n++; if (table[j] == 'X') m++; table[j] = ' '; } if (n - m != 1 && n - m != 0) continue; if (check_win_mod(table, 'O') == 'O' && check_win_mod(table, 'X') == 'X') continue; if (n == m && check_win_mod(table, 'O') == 'O') continue; if (n - m == 1 && check_win_mod(table, 'X') == 'X') continue; ans++; } printf("%d\n%f\n", ans, log2(ans)); } ``` `check_win_mod` 可以判斷有沒有由後面那個字元連線。這裡是窮舉所有可能的情況中,合法的情況。首先 `O` 的數量要比 `X` 的數量多 0 到 1 個。然後兩者不能同時贏,還有數量相同的時候 `O` 不能贏,數量差 1 的時候 `X` 不能贏。最後算出來總共有 10165779 種可能的結果,用 24 個位元可以存下所有狀態。 #### 接收 Ctrl+Q 鍵盤輸入 > commit [4a6128f](https://github.com/sysprog21/kxo/commit/4a6128f5578bbd5da3473722d704a0fc7b77df55) 接下來是要讓它可以接受 Ctrl+Q 的鍵盤輸入,並且停止核心模組中的遊戲運行。去讀了 [mazu-editor](https://github.com/jserv/mazu-editor) 的程式碼,會發現如果要用 `read` 街收到 Ctrl+Q 的鍵盤輸入的話,那必須要把 `c_iflag` 去除掉 `IXON` 這個 flag,於是加上了這個更改: ```diff static void raw_mode_enable(void) { ... + raw.c_iflag &= ~IXON; ... } ``` 另外,我讓核心模組在每次 `kxo_open` 被呼叫的時候都去檢查 `attr_obj.end`,如果他是 `'1'` 的話,那就把狀態定為初始狀態。因為如果一個使用者去中止遊戲,另一個使用者再去和核心模組進行互動的話,那麼它應該從新開始進行一場遊戲。 這裡與 [weiso131](https://github.com/weiso131/) 共同向 kxo 提交了 pull request: > Pull request: [Enable Control+Q capture by setting input mode](https://github.com/sysprog21/kxo/pull/10) #### 顯示過去棋盤紀錄 > commit [05e349d](https://github.com/sysprog21/kxo/commit/05e349de9e705591268a90a965305e06b251145c) 我另外開了一個 `record.c` 來紀錄過去的棋盤。 對於一個移動來說,有 4 * 4 共 16 種可能,而最差的情況是動 16 步(也就是把所有格子都填滿)。所以可以用每 4 個位元紀錄一個移動(將每個移動表示成 0 到 15 中的其中一個數字),這樣就可以用 64 位元的無號整數來紀錄一場遊戲的所有移動。而這裡使用一個環狀的佇列來紀錄棋盤,共有 16 格,最多的歷史紀錄就是 15 個,剩下會由比較早的優先丟棄。 然後我在 `record.h` 中定下了以下的界面: - `void record_init(void)` 在 `kxo_init` 中被呼叫,拿來初始化佇列 - `void record_board_init(void)` 初始化一個新的棋盤 - `void record_board_update(int move)` 將現在的棋盤紀錄更新一個新的移動 - `void record_append_board(void)` 將現在的棋盤放進佇列中,如果沒有呼叫這個函數的話,那麼棋盤的紀錄就不會算,這樣可以讓還沒有完成的遊戲不會進入紀錄中 - `uint64_t record_get_board(unsigned int index)` 拿到從現在佇列中有的紀錄中,以插入時間來說第 `index` 個棋盤紀錄 - `int record_get_size(void)` 得到現在的佇列大小 由於要在使用者程式的地方輸出結果,所以要有個方式讓使用者知道一個棋盤紀錄的總共步數。我們會發現,合法的遊戲步驟不會到 $2^{64}$ 種狀態,一個 `uint64_t` 能儲存的狀態是比遊戲的總狀態還要多的。基於這個想法,我們可以去找到一種方式去同時表示出遊戲狀態和總步數。 可以發現,如果把 0 到 15 的所有數字用 bitwise xor 做計算會得到 0,而不會出現同一個格子 2 次,根據鴿籠原理可以知道,如果總步數是 16 的話,那麼所有移動的 bitwise xor 會是 0。基於這個原理,我讓每個總步數不為 16 的棋盤紀錄在最高的 4 位元補一個數字,讓每 4 個位元為一組的數字用 bitwise xor 計算完後,結果會是他的步數。而因為總步數不可能是 0,所以如果出現 0 就代表他的總步數是 16。 實作是這樣: ```cpp uint64_t record_size = record; record_size ^= (record_size >> 4); record_size ^= (record_size >> 8); record_size ^= (record_size >> 16); record_size ^= (record_size >> 32); record_size &= 0xf; if (!record_size) record_size = 16; ``` 另外,在核心模組與使用者程式交流的部份我是使用 `ioctl`。設定了兩種模式 `IOCTL_READ_SIZE` 和 `IOCTL_READ_LIST`。第一種模式會回傳佇列大小,第二種模式會去向使用者緩衝區寫入它要求的 `index` 對應的棋盤紀錄。這個 `index` 是存在 `cmd` 的第 1 到 4 個位元。 ```cpp /** * kxo_ioctl - Get the size of board record or the specific board * @cmd: the opcode * @arg: the user buffer * * The lowest bit of cmd represent the operation number, which is * IOCTL_READ_SIZE or IOCTL_READ_LIST. If it is read list mode, * the 1st to 4th bit represent the index of the board user wants * to get * * Return: * It will return -ENOTTY if the mode is invalid * If the mode is IOCTL_READ_SIZE, it will return the size of record * queue. If the mode is IOCTL_READ_LIST, it will return the number of * bytes it copies to the user. */ static long kxo_ioctl(struct file *flip, unsigned int cmd, unsigned long arg) { int ret; switch (cmd & 1) { case IOCTL_READ_SIZE: ret = record_get_size(); pr_info("kxo_ioctl: the size is %d\n", ret); break; case IOCTL_READ_LIST: uint64_t record = record_get_board(cmd >> 1); ret = copy_to_user((void *) arg, &record, 8); pr_info("kxo_ioctl: read list\n"); break; default: ret = -ENOTTY; } return ret; } ``` 現在在輸入 Ctrl+Q 之後會出現棋盤紀錄了: ```bash Stopping the kernel space tic-tac-toe game... Moves: A3 -> B2 -> B3 -> C3 -> B0 -> C2 -> A1 -> C1 Moves: C0 -> B3 -> A0 -> C3 -> B0 Moves: C2 -> A3 -> B2 -> A0 -> B1 -> A1 -> A2 Moves: A0 -> B1 -> A1 -> A2 -> C1 -> B2 -> C0 -> C2 ```