# 2025q1 Homework2 (quiz1+2) contributed by < `BrotherHong` > ## [第 1 週測驗題](https://hackmd.io/@sysprog/linux2025-quiz1) ### 測驗 1 #### 參考答案 * AAAA = `&l->head` * BBBB = `before` * CCCC = `&(*p)->next` * DDDD = `(*p)->next` #### 思路 1. 題目為**單向非環狀**鏈結串列 2. 想插入節點 `item` 到某個節點的前方,就必須先找到目標節點 `before` 的前一個節點 `prev` 3. 假設我們已經找到節點 `prev`,簡單進行兩個操作即可完成連接 * `item->next = prev->next` * `prev->next = item` 如果依照上述思路去實作,在尋找 `before` 的時候會需要兩個節點指標 `curr` 和 `prev` 分別記錄**目前尋找到的節點**和其**前一個節點**。 當執行思路的第 3 步時,`prev` 可能為 `NULL`,這時候就要設條件去判斷,為了簡化判斷這個步驟,我們就可以利用指標的指標 `p` 來簡化 (`p = &prev->next`)。 <!-- 當 `l->head` 為 `NULL` 時,`p = &l->head`,`*p = NULL` ```graphviz digraph { rankdir=LR; node[shape="record"] headA [label="{ <data> head | <next> }"] nullA [label="NULL"] pA [label="p" shape=plaintext] headA:next:c -> nullA [tailclip=false] pA -> headA:next } ``` --> ```c list_item_t **p; for (p = &l->head; *p != before; p = &(*p)->next) ; *p = item; (*p)->next = before; ``` #### 延伸問題 ##### 解釋程式碼原理 <!-- 首先看到測試程式的 main 方法 ```c int main(void) { printf("---=[ List tests\n"); char *result = test_suite(); if (result) printf("ERROR: %s\n", result); else printf("ALL TESTS PASSED\n"); printf("Tests run: %d\n", tests_run); return !!result; } ``` 我們可以發現 `result` 會儲存測試程式 `test_suite()` 的回傳結果,若不為 `NULL` 代表有發生錯誤並印出錯誤訊息 --- 深入往 `test_suite()` 看 ```c static char *test_suite(void) { my_run_test(test_list); return NULL; } ``` 它使用了一個 macro `my_run_test` 來執行傳入的 `test` 方法,更新測試次數 `tests_run`,若有錯誤訊息直接回傳。 所以實際上會執行到 `test_list()` 這個方法。 --> 在測試程式碼上方定義了兩個 macro ,下面這個是其中一個,可以發現它們用了 do ... while(0) 來包住要執行的程式碼,用意是為了讓呼叫 macro 的語句結尾可以加上分號,寫起來會像是呼叫方法的樣子。 ```c #define my_run_test(test) \ do { \ char *message = test(); \ tests_run++; \ if (message) \ return message; \ } while (0) ``` > 參考 > [Why use apparently meaningless do-while and if-else statements in macros? -- Stack Overflow](https://stackoverflow.com/questions/154136/why-use-apparently-meaningless-do-while-and-if-else-statements-in-macros) 除此之外,也能避免在其他使用情況造成語意錯誤 在上方參考的論壇的內文有問到: 常看到使用 do ... while(0) 和 if(1) { ... } else ```c #define FOO(X) do { f(X); g(X); } while (0) #define FOO(X) if (1) { f(X); g(X); } else ``` 為什麼不像下面這樣寫就好? ```c #define FOO(X) f(X); g(X) ``` **情況舉例** ```c if (stmt) FOO(X); else func(); ``` 在此展開會變成 ```c if (stmt) f(X); g(X); else func(); ``` 由於沒有用大括號括住,會造成語法上的錯誤。 用 if (1) { ... } else 這種寫法,後面加上 else 也是為了避免在此情況下出現 [dangling else](https://en.wikipedia.org/wiki/Dangling_else) 如果只用 if (1) { ... } ```c if (stmt) if (1) { f(X); g(X); } else func(); ``` 這個 else 就會被判定給 if (1),而不是上方的 if (stmt) --- ##### 在現有基礎上加入合併排序 --- ### 測驗 2 #### 參考答案 * EEEE = `(*pred_ptr)->r` * FFFF = `&(*pred_ptr)->r` #### 思路 ```c /* Find the in-order predecessor: * This is the rightmost node in the left subtree. */ block_t **pred_ptr = &(*node_ptr)->l; while (EEEE) pred_ptr = FFFF; ``` 從註解中可以知道我們的目標是要找到左子樹的最右節點,所以只要有右子樹存在最右節點必定存在在右子樹中,若不存在右子樹代表目前訪問到的節點就是我們要找的最右節點。 因此 `while` 的條件就設為 `(*pred_ptr)->r`,若存在右子樹 (not NULL),更新 `pred_ptr = &(*pred_ptr)->r` 往右繼續找。 ### 測驗 3 #### 參考答案 * GGGG = `head->prev=prev` * HHHH = `list_entry(pivot,node_t,list)->value` * IIII = `list_entry(n,node_t,list)->value` * JJJJ = `pivot` * KKKK = `right` #### 思路 ==/* GGGG */== 從 `rebuild_list_link` 這個方法裡面的程式碼可以觀察出來他是在把**單向非環狀鏈結串列**重新建構成**雙向環狀鏈結串列**,因此在迴圈下方兩行的程式碼是把尾端的節點和 `head` 連接起來,因此 `GGGG` 的答案就會是 `head->prev=prev`。 ==/* HHHH */== 觀察填空區域附近的程式碼可以知道這部分是在處理 pivot,再來從變數名稱 `value` 可以推測它是用來存 pivot 節點的值,因此 `HHHH` 會是 `list_entry(pivot,node_t,list)->value`。 ```c struct list_head *pivot = L; value = /* HHHH */ struct list_head *p = pivot->next; pivot->next = NULL; /* break the list */ ``` ==/* IIII */== 觀察這部分的程式碼可以知道這邊是在快速排序的 partition 過程,`n` 就是目前判斷到的節點。從變數名稱 `n_value` 就可以推測 `IIII` 要填的就是取 `n` 的值的操作 `list_entry(n,node_t,list)->value`。 ```c while (p) { struct list_head *n = p; p = p->next; int n_value = /* IIII */; if (n_value > value) { n->next = right; right = n; } else { n->next = left; left = n; } } ``` ==/* JJJJ */== ==/* KKKK */== 這部分的程式碼是在模擬遞迴版本的 push stack 操作,由於在合併操作的時候是從 head 插入 result ,因此必須按照 left -> pivot -> right 的順序放入 stack。 ```c begin[i] = left; begin[i + 1] = /* JJJJ */; begin[i + 2] = /* KKKK */; left = right = NULL; i += 2; ``` * 合併操作部分 ```c if (L) { L->next = result; result = L; } i--; ``` ## [第 2 週測驗題](https://hackmd.io/@sysprog/linux2025-quiz2) ### 測驗 1 #### 參考答案 * AAAA = `list_first_entry` * BBBB = `list_del` * CCCC = `list_move_tail` * DDDD = `list_add` * EEEE = `list_splice` * FFFF = `list_splice_tail` #### 思路 ==/* AAAA */== 這邊從 pivot 的型別 (`listitem *`) 加上參數的數量及內容,還有變數本身的名字 pivot,可以知道 `AAAA` 是要取得 list 的第一個 entry 因此答案會是 `list_first_entry`。 ==/* BBBB */== 快速排序在得到 pivot 的資訊後第一件事就是把 pivot 從 list 中分離,因此 `BBBB` 就會是 `list_del`。 ```c INIT_LIST_HEAD(&list_less); INIT_LIST_HEAD(&list_greater); pivot = AAAA(head, struct listitem, list); BBBB(&pivot->list); ``` ==/* CCCC */== 這部分迴圈是在做快速排序的 partition 操作,它會把**小於** pivot 的值放入 list_less ,**大於等於** pivot 的放入 list_greater。由於題目要實作的事 **stable sorting** ,為維持原本的順序,插入其他 list 時都要從尾端插入,因此 `CCCC` 會是 `list_move_tail`。 ```c list_for_each_entry_safe (item, is, head, list) { if (cmpint(&item->i, &pivot->i) < 0) list_move_tail(&item->list, &list_less); else CCCC(&item->list, &list_greater); } ``` ==/* DDDD */== ==/* EEEE */== ==/* FFFF */== 從這部分程式碼可以看出來這是快速排序的合併操作,要把排序完的 list_less、list_greater 和 pivot 合併成一個 sorted list。由於題目要求的是 **stable sorting** 並依照 pivot、list_less、list_greater 的順序放入。首先 pivot 只有一個節點,因此 `DDDD` 直接用 `list_add` 插入 pivot,再來把一整串 list_less 放在 pivot 前方,所以 `EEEE` 使用 `list_splice`,最後 list_greater 要放在 pivot 後方必須從尾端插入,因此 `FFFF` 使用 `list_splice_tail`。 > 這邊可以使用 list_splice_tail 的原因可以參考下方部分程式碼,它是將要插入的 list 的頭部接到目標 head 的尾端,讓 list 的尾端成為 head 的新的尾端。 ```c list_quicksort(&list_less); list_quicksort(&list_greater); DDDD(&pivot->list, head); EEEE(&list_less, head); FFFF(&list_greater, head); ``` * `list_splice_tail` 部分程式碼 ```c /* list_splice_tail */ head->prev = list_last; list_last->next = head; list_first->prev = head_last; head_last->next = list_first; ``` ### 測驗 2 #### 參考答案 * GGGG = `14` * HHHH = `2` * IIII = `0` * JJJJ = `3` * KKKK = `2` * LLLL = `1` * MMMM = `~1` * NNNN = `1` * PPPP = `2` #### 思路 `clz2` 是用遞迴來實作分治演算法計算前綴 `bit 0` 有多少個。 因此觀察函數的參數 `x` 就是我們要計算的目標,`c` 可以理解為遞迴的深度。 把每一層 `c` 對應的數值樣子列出來會像以下這樣: ``` c = 0 upper lower 0000 0000 0000 0000 | 0000 0000 0000 0000 c = 1 upper lower 0000 0000 | 0000 0000 c = 2 upper lower 0000 | 0000 c = 3 upper lower 00 | 00 ``` ==/* GGGG */== 在計算 lower 部分的時候是利用 mask 的方式來取得右半部分,當在 `c = 0` 這層時,我們要的 mask 就是 `0xFFFF`;在 `c = 1` 時,我們要的是右邊 8 個 bits 也就是 mask 是 `0xFF`,若使用 `0xFFFF` 位移來計算的話會是 `0xFFFF >> 8`,回去看 `mask` 陣列宣告也確實是 8;`c = 2` 時同理。 因此就可以推測 `GGGG` 的值會是要讓 `0xFFFF >> mask[3]` 能符合當 `c = 3` 時的 mask 要求。當 `c = 3` 時,lower 部分會需要取右邊 2 個 bits,需要的 mask 就是 `0x3 = 0b0011 = 0xFFFF >> 14` ,因此 `GGGG` 該填入 `14`。 ```c uint32_t upper = (x >> (16 >> c)); uint32_t lower = (x & (0xFFFF >> mask[c])); ``` ==/* HHHH */== ==/* IIII */== ==/* JJJJ */== 我們都知道遞迴函式都必須有一個終止條件,所以可以推測在此的 `JJJJ` 就是終止值,觀察程式碼可以發現 `c` 大部分都放在 `mask` 和 `magic` 陣列的索引值內,它們兩個的長度又只有 4,因此可以推測這邊的終止值會是 `3`。 接下來討論終止回傳值應該會是多少? 由於已經到了 `c = 3` 這層,`upper` 只會有 4 種可能 0、1、2、3,從這個三元運算式可以知道當 `upper` 為非 0,回傳的就會是 `magic[upper]`。 此時回頭去想我們對於這個函式的定義是什麼?是要取前綴 0 的數量,所以當 `upper = 1 = 0b01` 我們應該回傳 1、`upper = 2 = 0b10` 應回傳 0、`upper = 3 = 0b11` 應回傳 0。所以我們可以知道 `magic` 的每一項會是 0、1、2、3 對應的前綴 0 個數,因此 `magic[] = {2, 1, 0, 0}`。 所以 `HHHH` = `2`、`IIII` = `0`。 ==/* KKKK */== 當 `upper` 為 0 代表前綴 0 的個數就會是 `upper` 的 2 個 0 加上 `lower` 的前綴 0 個數,也就是 `2 + clz2(lower, c + 1)`,而這邊因為已經是終止條件回傳了,所以是用 `2 + magic[lower]`,因此 `KKKK` = `2`。 ```c if (c == JJJJ) return upper ? magic[upper] : KKKK + magic[lower]; ``` ==/* LLLL */== 若目前還不是最深層,就遞迴下去解,往更深一層遞迴的時候就必須更新當前的深度,因此 `LLLL` = `1`。 ```c return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + LLLL); ``` * `clz2` 補完後 ```c static const int mask[] = {0, 8, 12, 14}; static const int magic[] = {2, 1, 0, 0}; unsigned clz2(uint32_t x, int c) { if (!x && !c) return 32; uint32_t upper = (x >> (16 >> c)); uint32_t lower = (x & (0xFFFF >> mask[c])); if (c == 3) return upper ? magic[upper] : 2 + magic[lower]; return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + 1); } ``` ==/* MMMM */== ==/* NNNN */== ==/* PPPP */== 這部分填空的程式碼是要實作一個 `sqrti` :對一個整數開根號並向下取整。假設對 $x$ 開根號,我們要求 $y=\lfloor\sqrt{x}\rfloor$。 我們已知要求的答案 $y$ 是一個整數,如果我們用二進制來看 $y$ ,它會是一堆 $1 \ll b_i$ 的總和,例如 $y=0b00101011=(1\ll5)+(1\ll3)+(1\ll1)+(1\ll0)$,因此 $y^2=[(1\ll5)+(1\ll3)+(1\ll1)+(1\ll0)]^2$。 此時我們可以用代數的形式來看這個式子,$y^2=(a+b+c+d)^2$,我們只要找到 $a$ $b$ $c$ $d$ 分別是多少就能知道 $y$ 了。 我們知道我們要找的 $a$ $b$ $c$ $d$ 都一定會是 2 的冪,因此我們可以從最高位的 bit 往低位的 bit 去試,如此可以讓 $y^2$ 慢慢逼近 $x$。 >就像是在建構一個浮點數逼近某個值,首先嘗試加上範圍最大的值 `0.5`,如果把它加進去當前逼近的值會大於目標的話,那就不要把它加進當前逼近的值,接下來把範圍縮小成 `0.25`,若把它加進目前逼近的值還是小於目標值那就把它加進去,再來 `0.125`、`0.0625`、`...` ,以此類推。 所以我們可以先找到最高的可能的 bit 是哪一個,假設 $x$ 的最高位 bit 為 $b$,那麼 $x$ 一定會大於等於 $(2^{\lfloor\frac{b}{2}\rfloor})^2$ 並且小於 $(2^{\lfloor\frac{b}{2}\rfloor+1})^2$,因此我們就可以從第 $\lfloor\frac{b}{2}\rfloor$ 個 bit 開始往低位嘗試。 就目前的理解可以寫成下面的程式碼: ```c uint64_t sqrti(uint64_t x) { uint64_t m, y = 0; if (x <= 1) return x; int total_bits = 64; int shift = (total_bits - 1 - clz64(x)) >> 1; m = 1ULL << shift; while (m) { uint64_t b = y + m; if (x >= b*b) y += m; m >>= 1; } return y; } ``` 不過題目要求只能用加減法和位移運算。由於乘法出現在 $(a+b+c+d)^2$ 這個式子,因此我們可以嘗試從這個式子下手去看。 先從最簡單的 $(a+b)^2$ 來看: 這個式子可以把 $b$ 提出來寫成最下面那個形式。 $$ \begin{split} (a+b)^2&=a^2 + 2ab + b^2 \\&=a^2 + b(2a+b) \end{split} $$ 接下來看 $$ (a+b+c)^2=a^2+b^2+c^2+2ab+2bc+2ac $$ 整理後可以得到 $$ \begin{split} (a+b+c)^2&=(a^2+2ab+b^2)+c^2+2bc+2ac\\ &=(a+b)^2+c(2(a+b)+c) \end{split} $$ 同理在 $(a+b+c+d)^2$ 也可以展開成 $$ (a+b+c+d)^2=(a+b+c)^2 + d(2(a+b+c)+d) $$ 寫成一般式 $$ (a_1+a_2+\dots+a_n)^2=(\sum_{i=1}^{n-1}a_i)^2+a_n(2(\sum_{i=1}^{n-1}a_i)+a_n) $$ 這樣感覺還看不太出來如何不用到乘法運算,我們可以把後面展開變成以下 $$ (a_1+a_2+\dots+a_n)^2=(\sum_{i=1}^{n-1}a_i)^2+a_n^2+2a_n(\sum_{i=1}^{n-1}a_i) $$ 用我們目前的題目環境去理解這個式子的話: * 等號左邊的 $(a_1+a_2+\dots+a_n)^2$ 就是我們的 $x$ * 等號右邊左方的 $(\sum_{i=1}^{n-1}a_i)^2$ 就是我們目前逼近的 $y^2$ 值 * 我們就令 $(y')^2 = (\sum_{i=1}^{n-1}a_i)^2$,也就是 $y' = \sum_{i=1}^{n-1}a_i$ * $a_n$ 就是我們目前嘗試到的 `m` 所以式子就可以寫成: $$ x = (y')^2 + m^2 + 2my' $$ 由於我們的 $m$ 都會是 2 的冪,因此在做 $m^2$ 的時候只需要做 `m << shift` 就好 (需同時維護 `shift` 的值)。 同理,$2my'$ 也就是對 $y'$ 左移 `shift+1` 就能計算完成。 $(y')^2$ 的部分則可以在算出來時直接讓 $x$ 減掉,等同於是把它移到等號左邊,並另該結果為 $x'$,變成 $x'=x - (y')^2 = m^2 + 2my'$,這邊的 $x'$ 具體一點就是我們目前逼近的值與目標值之間的差值。 準確來講,只要我們嘗試算出來的結果比差值 $x'$ 還要小那這個 `m` 就是我們要找的值,也就是: $$ x' \ge m^2+2my' $$ 理解到這裡,就能寫出以下程式碼: ```c uint64_t sqrti(uint64_t x) { uint64_t m, y = 0; if (x <= 1) return x; int total_bits = 64; int shift = (total_bits - 1 - clz64(x)) >> 1; m = 1ULL << shift; while (m) { uint64_t b = (m << shift) + (y << (shift + 1)); if (x >= b) { x -= b; y += m; } m >>= 1; shift--; } return y; } ``` 這樣的寫法已經不會用到乘法運算了,不過還有地方可以再精簡。 如果我們對 `m` 預先再左移 `shift` 次就可以省去計算 `b` 時所需要的左移操作,因此不用再對 `shift` 除以 2,只需要確保他是偶數 (因為$(\lfloor\frac{b}{2}\rfloor)^2$),且每次迴圈結束都要右移 2,這是為了讓迴圈跑的次數與原先相同。所以這邊 `MMMM` 填入 `~1`、`PPPP` = `2`。 由於 `y` 是由 `m` 來更新的,所以 `y` 的值會是會是我們之前的 `y` 再左移 `shift` 的樣子,所以我們每次迴圈做完都要把 `y` 右移 1。`NNNN` = `1`。 而在計算 `b` 會需要 `y << (shift+1)`,也就是前一次迴圈的 `y`,因此右移 1 的操作必須在計算 `b` 之後,且在下一次更新,也就是 `y += m` 之前。 簡化過後,就能變成和題目一樣的程式碼了: ```c uint64_t sqrti(uint64_t x) { uint64_t m, y = 0; if (x <= 1) return x; int total_bits = 64; int shift = (total_bits - 1 - clz64(x)) & ~1; m = 1ULL << shift; while (m) { uint64_t b = y + m; y >>= 1; if (x >= b) { x -= b; y += m; } m >>= 2; } return y; } ``` #### 延伸問題 ##### 擴充為向上取整 $\lceil(\sqrt{x})\rceil$ 若計算結束後 `x` 的值,也就是與目標的差值,為非 0 的話,代表原先傳入的 `x` 不是平方數,因此算出來的 `y` 需要加一;若計算結束後的 `x` 為 0,代表是平方數結果就是 `y`。 因此只需要更改回傳的那行程式碼就能達成: ```c return y + !!x; ``` ### 測驗 3 #### 參考答案 * AAAA = `map->bits` * BBBB = `map->bits` * CCCC = `first->pprev` * DDDD = `n->pprev` * EEEE = `n->pprev` #### 思路 * 雜湊函數: ```c #define GOLDEN_RATIO_32 0x61C88647 static inline unsigned int hash(unsigned int val, unsigned int bits) { /* High bits are more random, so use them. */ return (val * GOLDEN_RATIO_32) >> (32 - bits); } ``` ==/* AAAA */== ==/* BBBB */== 從雜湊函數可以看出來第二個參數 `bits` 就是需要傳入先前在初始化設定的 `map->bits`。 * `find_key` ```c struct hlist_head *head = &(map->ht)[hash(key, AAAA)]; ``` * `map_add` ```c struct hlist_head *h = &map->ht[hash(key, BBBB)]; ``` ==/* CCCC */== 已知 `kn` 是我們要插入的 hash_key,所以`n` 就會是我們要插入的節點,而 `first` 是目前同個 hash_key 的 list 的開頭。從 `n->next = first` 可以知道是從 list 的頭部端插入,所以如果已存在其他相同 hash_key 的節點 (`if (first) 成立`),我們就要更新 first 的 pprev 所以 `CCCC` 是 `first->pprev`。 ==/* DDDD */== 承 C ,把 `n` 插入頭部端後也需要更新其 pprev,因此 `DDDD` = `n->pprev`。 * `map_add` ```c struct hlist_node *n = &kn->node, *first = h->first; n->next = first; if (first) CCCC = &n->next; h->first = n; DDDD = &h->first; ``` ==/* EEEE */== 從 `map_deinit` 完整程式碼可以知道 `n` 是我們要從 list 中移除並釋放掉的點,下方部分程式碼就是在將 `n` 從 list 中分離,並且從後續操作可以知道 next 就是 `n` 的下個節點,pprev 就是 `n` 的 pprev。因此 `EEEE` = `n->pprev`。 * `map_deinit` ```c struct hlist_node *next = n->next, **pprev = EEEE; *pprev = next; if (next) next->pprev = pprev; n->next = NULL, n->pprev = NULL; ```