# 2025q1 Homework2 (quiz1+2)
contributed by < `BrotherHong` >
## [第 1 週測驗題](https://hackmd.io/@sysprog/linux2025-quiz1)
### 測驗 1
#### 參考答案
* AAAA = `&l->head`
* BBBB = `before`
* CCCC = `&(*p)->next`
* DDDD = `(*p)->next`
#### 思路
1. 題目為**單向非環狀**鏈結串列
2. 想插入節點 `item` 到某個節點的前方,就必須先找到目標節點 `before` 的前一個節點 `prev`
3. 假設我們已經找到節點 `prev`,簡單進行兩個操作即可完成連接
* `item->next = prev->next`
* `prev->next = item`
如果依照上述思路去實作,在尋找 `before` 的時候會需要兩個節點指標 `curr` 和 `prev` 分別記錄**目前尋找到的節點**和其**前一個節點**。
當執行思路的第 3 步時,`prev` 可能為 `NULL`,這時候就要設條件去判斷,為了簡化判斷這個步驟,我們就可以利用指標的指標 `p` 來簡化 (`p = &prev->next`)。
<!-- 當 `l->head` 為 `NULL` 時,`p = &l->head`,`*p = NULL`
```graphviz
digraph {
rankdir=LR;
node[shape="record"]
headA [label="{ <data> head | <next> }"]
nullA [label="NULL"]
pA [label="p" shape=plaintext]
headA:next:c -> nullA [tailclip=false]
pA -> headA:next
}
``` -->
```c
list_item_t **p;
for (p = &l->head; *p != before; p = &(*p)->next)
;
*p = item;
(*p)->next = before;
```
#### 延伸問題
##### 解釋程式碼原理
<!-- 首先看到測試程式的 main 方法
```c
int main(void)
{
printf("---=[ List tests\n");
char *result = test_suite();
if (result)
printf("ERROR: %s\n", result);
else
printf("ALL TESTS PASSED\n");
printf("Tests run: %d\n", tests_run);
return !!result;
}
```
我們可以發現 `result` 會儲存測試程式 `test_suite()` 的回傳結果,若不為 `NULL` 代表有發生錯誤並印出錯誤訊息
---
深入往 `test_suite()` 看
```c
static char *test_suite(void)
{
my_run_test(test_list);
return NULL;
}
```
它使用了一個 macro `my_run_test` 來執行傳入的 `test` 方法,更新測試次數 `tests_run`,若有錯誤訊息直接回傳。
所以實際上會執行到 `test_list()` 這個方法。 -->
在測試程式碼上方定義了兩個 macro ,下面這個是其中一個,可以發現它們用了 do ... while(0) 來包住要執行的程式碼,用意是為了讓呼叫 macro 的語句結尾可以加上分號,寫起來會像是呼叫方法的樣子。
```c
#define my_run_test(test) \
do { \
char *message = test(); \
tests_run++; \
if (message) \
return message; \
} while (0)
```
> 參考
> [Why use apparently meaningless do-while and if-else statements in macros? -- Stack Overflow](https://stackoverflow.com/questions/154136/why-use-apparently-meaningless-do-while-and-if-else-statements-in-macros)
除此之外,也能避免在其他使用情況造成語意錯誤
在上方參考的論壇的內文有問到:
常看到使用 do ... while(0) 和 if(1) { ... } else
```c
#define FOO(X) do { f(X); g(X); } while (0)
#define FOO(X) if (1) { f(X); g(X); } else
```
為什麼不像下面這樣寫就好?
```c
#define FOO(X) f(X); g(X)
```
**情況舉例**
```c
if (stmt)
FOO(X);
else
func();
```
在此展開會變成
```c
if (stmt)
f(X); g(X);
else
func();
```
由於沒有用大括號括住,會造成語法上的錯誤。
用 if (1) { ... } else 這種寫法,後面加上 else 也是為了避免在此情況下出現 [dangling else](https://en.wikipedia.org/wiki/Dangling_else)
如果只用 if (1) { ... }
```c
if (stmt)
if (1) { f(X); g(X); }
else
func();
```
這個 else 就會被判定給 if (1),而不是上方的 if (stmt)
---
##### 在現有基礎上加入合併排序
---
### 測驗 2
#### 參考答案
* EEEE = `(*pred_ptr)->r`
* FFFF = `&(*pred_ptr)->r`
#### 思路
```c
/* Find the in-order predecessor:
* This is the rightmost node in the left subtree.
*/
block_t **pred_ptr = &(*node_ptr)->l;
while (EEEE)
pred_ptr = FFFF;
```
從註解中可以知道我們的目標是要找到左子樹的最右節點,所以只要有右子樹存在最右節點必定存在在右子樹中,若不存在右子樹代表目前訪問到的節點就是我們要找的最右節點。
因此 `while` 的條件就設為 `(*pred_ptr)->r`,若存在右子樹 (not NULL),更新 `pred_ptr = &(*pred_ptr)->r` 往右繼續找。
### 測驗 3
#### 參考答案
* GGGG = `head->prev=prev`
* HHHH = `list_entry(pivot,node_t,list)->value`
* IIII = `list_entry(n,node_t,list)->value`
* JJJJ = `pivot`
* KKKK = `right`
#### 思路
==/* GGGG */==
從 `rebuild_list_link` 這個方法裡面的程式碼可以觀察出來他是在把**單向非環狀鏈結串列**重新建構成**雙向環狀鏈結串列**,因此在迴圈下方兩行的程式碼是把尾端的節點和 `head` 連接起來,因此 `GGGG` 的答案就會是 `head->prev=prev`。
==/* HHHH */==
觀察填空區域附近的程式碼可以知道這部分是在處理 pivot,再來從變數名稱 `value` 可以推測它是用來存 pivot 節點的值,因此 `HHHH` 會是 `list_entry(pivot,node_t,list)->value`。
```c
struct list_head *pivot = L;
value = /* HHHH */
struct list_head *p = pivot->next;
pivot->next = NULL; /* break the list */
```
==/* IIII */==
觀察這部分的程式碼可以知道這邊是在快速排序的 partition 過程,`n` 就是目前判斷到的節點。從變數名稱 `n_value` 就可以推測 `IIII` 要填的就是取 `n` 的值的操作 `list_entry(n,node_t,list)->value`。
```c
while (p) {
struct list_head *n = p;
p = p->next;
int n_value = /* IIII */;
if (n_value > value) {
n->next = right;
right = n;
} else {
n->next = left;
left = n;
}
}
```
==/* JJJJ */== ==/* KKKK */==
這部分的程式碼是在模擬遞迴版本的 push stack 操作,由於在合併操作的時候是從 head 插入 result ,因此必須按照 left -> pivot -> right 的順序放入 stack。
```c
begin[i] = left;
begin[i + 1] = /* JJJJ */;
begin[i + 2] = /* KKKK */;
left = right = NULL;
i += 2;
```
* 合併操作部分
```c
if (L) {
L->next = result;
result = L;
}
i--;
```
## [第 2 週測驗題](https://hackmd.io/@sysprog/linux2025-quiz2)
### 測驗 1
#### 參考答案
* AAAA = `list_first_entry`
* BBBB = `list_del`
* CCCC = `list_move_tail`
* DDDD = `list_add`
* EEEE = `list_splice`
* FFFF = `list_splice_tail`
#### 思路
==/* AAAA */==
這邊從 pivot 的型別 (`listitem *`) 加上參數的數量及內容,還有變數本身的名字 pivot,可以知道 `AAAA` 是要取得 list 的第一個 entry 因此答案會是 `list_first_entry`。
==/* BBBB */==
快速排序在得到 pivot 的資訊後第一件事就是把 pivot 從 list 中分離,因此 `BBBB` 就會是 `list_del`。
```c
INIT_LIST_HEAD(&list_less);
INIT_LIST_HEAD(&list_greater);
pivot = AAAA(head, struct listitem, list);
BBBB(&pivot->list);
```
==/* CCCC */==
這部分迴圈是在做快速排序的 partition 操作,它會把**小於** pivot 的值放入 list_less ,**大於等於** pivot 的放入 list_greater。由於題目要實作的事 **stable sorting** ,為維持原本的順序,插入其他 list 時都要從尾端插入,因此 `CCCC` 會是 `list_move_tail`。
```c
list_for_each_entry_safe (item, is, head, list) {
if (cmpint(&item->i, &pivot->i) < 0)
list_move_tail(&item->list, &list_less);
else
CCCC(&item->list, &list_greater);
}
```
==/* DDDD */== ==/* EEEE */== ==/* FFFF */==
從這部分程式碼可以看出來這是快速排序的合併操作,要把排序完的 list_less、list_greater 和 pivot 合併成一個 sorted list。由於題目要求的是 **stable sorting** 並依照 pivot、list_less、list_greater 的順序放入。首先 pivot 只有一個節點,因此 `DDDD` 直接用 `list_add` 插入 pivot,再來把一整串 list_less 放在 pivot 前方,所以 `EEEE` 使用 `list_splice`,最後 list_greater 要放在 pivot 後方必須從尾端插入,因此 `FFFF` 使用 `list_splice_tail`。
> 這邊可以使用 list_splice_tail 的原因可以參考下方部分程式碼,它是將要插入的 list 的頭部接到目標 head 的尾端,讓 list 的尾端成為 head 的新的尾端。
```c
list_quicksort(&list_less);
list_quicksort(&list_greater);
DDDD(&pivot->list, head);
EEEE(&list_less, head);
FFFF(&list_greater, head);
```
* `list_splice_tail` 部分程式碼
```c
/* list_splice_tail */
head->prev = list_last;
list_last->next = head;
list_first->prev = head_last;
head_last->next = list_first;
```
### 測驗 2
#### 參考答案
* GGGG = `14`
* HHHH = `2`
* IIII = `0`
* JJJJ = `3`
* KKKK = `2`
* LLLL = `1`
* MMMM = `~1`
* NNNN = `1`
* PPPP = `2`
#### 思路
`clz2` 是用遞迴來實作分治演算法計算前綴 `bit 0` 有多少個。
因此觀察函數的參數 `x` 就是我們要計算的目標,`c` 可以理解為遞迴的深度。
把每一層 `c` 對應的數值樣子列出來會像以下這樣:
```
c = 0
upper lower
0000 0000 0000 0000 | 0000 0000 0000 0000
c = 1
upper lower
0000 0000 | 0000 0000
c = 2
upper lower
0000 | 0000
c = 3
upper lower
00 | 00
```
==/* GGGG */==
在計算 lower 部分的時候是利用 mask 的方式來取得右半部分,當在 `c = 0` 這層時,我們要的 mask 就是 `0xFFFF`;在 `c = 1` 時,我們要的是右邊 8 個 bits 也就是 mask 是 `0xFF`,若使用 `0xFFFF` 位移來計算的話會是 `0xFFFF >> 8`,回去看 `mask` 陣列宣告也確實是 8;`c = 2` 時同理。
因此就可以推測 `GGGG` 的值會是要讓 `0xFFFF >> mask[3]` 能符合當 `c = 3` 時的 mask 要求。當 `c = 3` 時,lower 部分會需要取右邊 2 個 bits,需要的 mask 就是 `0x3 = 0b0011 = 0xFFFF >> 14` ,因此 `GGGG` 該填入 `14`。
```c
uint32_t upper = (x >> (16 >> c));
uint32_t lower = (x & (0xFFFF >> mask[c]));
```
==/* HHHH */== ==/* IIII */== ==/* JJJJ */==
我們都知道遞迴函式都必須有一個終止條件,所以可以推測在此的 `JJJJ` 就是終止值,觀察程式碼可以發現 `c` 大部分都放在 `mask` 和 `magic` 陣列的索引值內,它們兩個的長度又只有 4,因此可以推測這邊的終止值會是 `3`。
接下來討論終止回傳值應該會是多少?
由於已經到了 `c = 3` 這層,`upper` 只會有 4 種可能 0、1、2、3,從這個三元運算式可以知道當 `upper` 為非 0,回傳的就會是 `magic[upper]`。
此時回頭去想我們對於這個函式的定義是什麼?是要取前綴 0 的數量,所以當 `upper = 1 = 0b01` 我們應該回傳 1、`upper = 2 = 0b10` 應回傳 0、`upper = 3 = 0b11` 應回傳 0。所以我們可以知道 `magic` 的每一項會是 0、1、2、3 對應的前綴 0 個數,因此 `magic[] = {2, 1, 0, 0}`。
所以 `HHHH` = `2`、`IIII` = `0`。
==/* KKKK */==
當 `upper` 為 0 代表前綴 0 的個數就會是 `upper` 的 2 個 0 加上 `lower` 的前綴 0 個數,也就是 `2 + clz2(lower, c + 1)`,而這邊因為已經是終止條件回傳了,所以是用 `2 + magic[lower]`,因此 `KKKK` = `2`。
```c
if (c == JJJJ)
return upper ? magic[upper] : KKKK + magic[lower];
```
==/* LLLL */==
若目前還不是最深層,就遞迴下去解,往更深一層遞迴的時候就必須更新當前的深度,因此 `LLLL` = `1`。
```c
return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + LLLL);
```
* `clz2` 補完後
```c
static const int mask[] = {0, 8, 12, 14};
static const int magic[] = {2, 1, 0, 0};
unsigned clz2(uint32_t x, int c)
{
if (!x && !c)
return 32;
uint32_t upper = (x >> (16 >> c));
uint32_t lower = (x & (0xFFFF >> mask[c]));
if (c == 3)
return upper ? magic[upper] : 2 + magic[lower];
return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + 1);
}
```
==/* MMMM */== ==/* NNNN */== ==/* PPPP */==
這部分填空的程式碼是要實作一個 `sqrti` :對一個整數開根號並向下取整。假設對 $x$ 開根號,我們要求 $y=\lfloor\sqrt{x}\rfloor$。
我們已知要求的答案 $y$ 是一個整數,如果我們用二進制來看 $y$ ,它會是一堆 $1 \ll b_i$ 的總和,例如 $y=0b00101011=(1\ll5)+(1\ll3)+(1\ll1)+(1\ll0)$,因此 $y^2=[(1\ll5)+(1\ll3)+(1\ll1)+(1\ll0)]^2$。
此時我們可以用代數的形式來看這個式子,$y^2=(a+b+c+d)^2$,我們只要找到 $a$ $b$ $c$ $d$ 分別是多少就能知道 $y$ 了。
我們知道我們要找的 $a$ $b$ $c$ $d$ 都一定會是 2 的冪,因此我們可以從最高位的 bit 往低位的 bit 去試,如此可以讓 $y^2$ 慢慢逼近 $x$。
>就像是在建構一個浮點數逼近某個值,首先嘗試加上範圍最大的值 `0.5`,如果把它加進去當前逼近的值會大於目標的話,那就不要把它加進當前逼近的值,接下來把範圍縮小成 `0.25`,若把它加進目前逼近的值還是小於目標值那就把它加進去,再來 `0.125`、`0.0625`、`...` ,以此類推。
所以我們可以先找到最高的可能的 bit 是哪一個,假設 $x$ 的最高位 bit 為 $b$,那麼 $x$ 一定會大於等於 $(2^{\lfloor\frac{b}{2}\rfloor})^2$ 並且小於 $(2^{\lfloor\frac{b}{2}\rfloor+1})^2$,因此我們就可以從第 $\lfloor\frac{b}{2}\rfloor$ 個 bit 開始往低位嘗試。
就目前的理解可以寫成下面的程式碼:
```c
uint64_t sqrti(uint64_t x)
{
uint64_t m, y = 0;
if (x <= 1)
return x;
int total_bits = 64;
int shift = (total_bits - 1 - clz64(x)) >> 1;
m = 1ULL << shift;
while (m) {
uint64_t b = y + m;
if (x >= b*b)
y += m;
m >>= 1;
}
return y;
}
```
不過題目要求只能用加減法和位移運算。由於乘法出現在 $(a+b+c+d)^2$ 這個式子,因此我們可以嘗試從這個式子下手去看。
先從最簡單的 $(a+b)^2$ 來看:
這個式子可以把 $b$ 提出來寫成最下面那個形式。
$$
\begin{split}
(a+b)^2&=a^2 + 2ab + b^2
\\&=a^2 + b(2a+b)
\end{split}
$$
接下來看
$$
(a+b+c)^2=a^2+b^2+c^2+2ab+2bc+2ac
$$
整理後可以得到
$$
\begin{split}
(a+b+c)^2&=(a^2+2ab+b^2)+c^2+2bc+2ac\\
&=(a+b)^2+c(2(a+b)+c)
\end{split}
$$
同理在 $(a+b+c+d)^2$ 也可以展開成
$$
(a+b+c+d)^2=(a+b+c)^2 + d(2(a+b+c)+d)
$$
寫成一般式
$$
(a_1+a_2+\dots+a_n)^2=(\sum_{i=1}^{n-1}a_i)^2+a_n(2(\sum_{i=1}^{n-1}a_i)+a_n)
$$
這樣感覺還看不太出來如何不用到乘法運算,我們可以把後面展開變成以下
$$
(a_1+a_2+\dots+a_n)^2=(\sum_{i=1}^{n-1}a_i)^2+a_n^2+2a_n(\sum_{i=1}^{n-1}a_i)
$$
用我們目前的題目環境去理解這個式子的話:
* 等號左邊的 $(a_1+a_2+\dots+a_n)^2$ 就是我們的 $x$
* 等號右邊左方的 $(\sum_{i=1}^{n-1}a_i)^2$ 就是我們目前逼近的 $y^2$ 值
* 我們就令 $(y')^2 = (\sum_{i=1}^{n-1}a_i)^2$,也就是 $y' = \sum_{i=1}^{n-1}a_i$
* $a_n$ 就是我們目前嘗試到的 `m`
所以式子就可以寫成:
$$
x = (y')^2 + m^2 + 2my'
$$
由於我們的 $m$ 都會是 2 的冪,因此在做 $m^2$ 的時候只需要做 `m << shift` 就好 (需同時維護 `shift` 的值)。
同理,$2my'$ 也就是對 $y'$ 左移 `shift+1` 就能計算完成。
$(y')^2$ 的部分則可以在算出來時直接讓 $x$ 減掉,等同於是把它移到等號左邊,並另該結果為 $x'$,變成 $x'=x - (y')^2 = m^2 + 2my'$,這邊的 $x'$ 具體一點就是我們目前逼近的值與目標值之間的差值。
準確來講,只要我們嘗試算出來的結果比差值 $x'$ 還要小那這個 `m` 就是我們要找的值,也就是:
$$
x' \ge m^2+2my'
$$
理解到這裡,就能寫出以下程式碼:
```c
uint64_t sqrti(uint64_t x)
{
uint64_t m, y = 0;
if (x <= 1)
return x;
int total_bits = 64;
int shift = (total_bits - 1 - clz64(x)) >> 1;
m = 1ULL << shift;
while (m) {
uint64_t b = (m << shift) + (y << (shift + 1));
if (x >= b) {
x -= b;
y += m;
}
m >>= 1;
shift--;
}
return y;
}
```
這樣的寫法已經不會用到乘法運算了,不過還有地方可以再精簡。
如果我們對 `m` 預先再左移 `shift` 次就可以省去計算 `b` 時所需要的左移操作,因此不用再對 `shift` 除以 2,只需要確保他是偶數 (因為$(\lfloor\frac{b}{2}\rfloor)^2$),且每次迴圈結束都要右移 2,這是為了讓迴圈跑的次數與原先相同。所以這邊 `MMMM` 填入 `~1`、`PPPP` = `2`。
由於 `y` 是由 `m` 來更新的,所以 `y` 的值會是會是我們之前的 `y` 再左移 `shift` 的樣子,所以我們每次迴圈做完都要把 `y` 右移 1。`NNNN` = `1`。
而在計算 `b` 會需要 `y << (shift+1)`,也就是前一次迴圈的 `y`,因此右移 1 的操作必須在計算 `b` 之後,且在下一次更新,也就是 `y += m` 之前。
簡化過後,就能變成和題目一樣的程式碼了:
```c
uint64_t sqrti(uint64_t x)
{
uint64_t m, y = 0;
if (x <= 1)
return x;
int total_bits = 64;
int shift = (total_bits - 1 - clz64(x)) & ~1;
m = 1ULL << shift;
while (m) {
uint64_t b = y + m;
y >>= 1;
if (x >= b) {
x -= b;
y += m;
}
m >>= 2;
}
return y;
}
```
#### 延伸問題
##### 擴充為向上取整 $\lceil(\sqrt{x})\rceil$
若計算結束後 `x` 的值,也就是與目標的差值,為非 0 的話,代表原先傳入的 `x` 不是平方數,因此算出來的 `y` 需要加一;若計算結束後的 `x` 為 0,代表是平方數結果就是 `y`。
因此只需要更改回傳的那行程式碼就能達成:
```c
return y + !!x;
```
### 測驗 3
#### 參考答案
* AAAA = `map->bits`
* BBBB = `map->bits`
* CCCC = `first->pprev`
* DDDD = `n->pprev`
* EEEE = `n->pprev`
#### 思路
* 雜湊函數:
```c
#define GOLDEN_RATIO_32 0x61C88647
static inline unsigned int hash(unsigned int val, unsigned int bits)
{
/* High bits are more random, so use them. */
return (val * GOLDEN_RATIO_32) >> (32 - bits);
}
```
==/* AAAA */== ==/* BBBB */==
從雜湊函數可以看出來第二個參數 `bits` 就是需要傳入先前在初始化設定的 `map->bits`。
* `find_key`
```c
struct hlist_head *head = &(map->ht)[hash(key, AAAA)];
```
* `map_add`
```c
struct hlist_head *h = &map->ht[hash(key, BBBB)];
```
==/* CCCC */==
已知 `kn` 是我們要插入的 hash_key,所以`n` 就會是我們要插入的節點,而 `first` 是目前同個 hash_key 的 list 的開頭。從 `n->next = first` 可以知道是從 list 的頭部端插入,所以如果已存在其他相同 hash_key 的節點 (`if (first) 成立`),我們就要更新 first 的 pprev 所以 `CCCC` 是 `first->pprev`。
==/* DDDD */==
承 C ,把 `n` 插入頭部端後也需要更新其 pprev,因此 `DDDD` = `n->pprev`。
* `map_add`
```c
struct hlist_node *n = &kn->node, *first = h->first;
n->next = first;
if (first)
CCCC = &n->next;
h->first = n;
DDDD = &h->first;
```
==/* EEEE */==
從 `map_deinit` 完整程式碼可以知道 `n` 是我們要從 list 中移除並釋放掉的點,下方部分程式碼就是在將 `n` 從 list 中分離,並且從後續操作可以知道 next 就是 `n` 的下個節點,pprev 就是 `n` 的 pprev。因此 `EEEE` = `n->pprev`。
* `map_deinit`
```c
struct hlist_node *next = n->next, **pprev = EEEE;
*pprev = next;
if (next)
next->pprev = pprev;
n->next = NULL, n->pprev = NULL;
```