Try   HackMD

2025q1 Homework2 (quiz1+2)

contributed by < BrotherHong >

第 1 週測驗題

測驗 1

參考答案

  • AAAA = &l->head
  • BBBB = before
  • CCCC = &(*p)->next
  • DDDD = (*p)->next

思路

  1. 題目為單向非環狀鏈結串列
  2. 想插入節點 item 到某個節點的前方,就必須先找到目標節點 before 的前一個節點 prev
  3. 假設我們已經找到節點 prev,簡單進行兩個操作即可完成連接
    • item->next = prev->next
    • prev->next = item

如果依照上述思路去實作,在尋找 before 的時候會需要兩個節點指標 currprev 分別記錄目前尋找到的節點和其前一個節點

當執行思路的第 3 步時,prev 可能為 NULL,這時候就要設條件去判斷,為了簡化判斷這個步驟,我們就可以利用指標的指標 p 來簡化 (p = &prev->next)。

list_item_t **p;
for (p = &l->head; *p != before; p = &(*p)->next)
    ;
*p = item;
(*p)->next = before;

延伸問題

解釋程式碼原理

在測試程式碼上方定義了兩個 macro ,下面這個是其中一個,可以發現它們用了 do while(0) 來包住要執行的程式碼,用意是為了讓呼叫 macro 的語句結尾可以加上分號,寫起來會像是呼叫方法的樣子。

#define my_run_test(test)       \
    do {                        \
        char *message = test(); \
        tests_run++;            \
        if (message)            \
            return message;     \
    } while (0)

參考
Why use apparently meaningless do-while and if-else statements in macros? Stack Overflow

除此之外,也能避免在其他使用情況造成語意錯誤
在上方參考的論壇的內文有問到:

常看到使用 do while(0) 和 if(1) { } else

#define FOO(X) do { f(X); g(X); } while (0)
#define FOO(X) if (1) { f(X); g(X); } else

為什麼不像下面這樣寫就好?

#define FOO(X) f(X); g(X)

情況舉例

if (stmt)
    FOO(X);
else
    func();

在此展開會變成

if (stmt)
    f(X); g(X);
else
    func();

由於沒有用大括號括住,會造成語法上的錯誤。

用 if (1) { } else 這種寫法,後面加上 else 也是為了避免在此情況下出現 dangling else

如果只用 if (1) { }

if (stmt)
    if (1) { f(X); g(X); }
else
    func();

這個 else 就會被判定給 if (1),而不是上方的 if (stmt)


在現有基礎上加入合併排序

測驗 2

參考答案

  • EEEE = (*pred_ptr)->r
  • FFFF = &(*pred_ptr)->r

思路

    /* Find the in-order predecessor:
     * This is the rightmost node in the left subtree.
     */
    block_t **pred_ptr = &(*node_ptr)->l;
    while (EEEE)
        pred_ptr = FFFF;

從註解中可以知道我們的目標是要找到左子樹的最右節點,所以只要有右子樹存在最右節點必定存在在右子樹中,若不存在右子樹代表目前訪問到的節點就是我們要找的最右節點。

因此 while 的條件就設為 (*pred_ptr)->r,若存在右子樹 (not NULL),更新 pred_ptr = &(*pred_ptr)->r 往右繼續找。

測驗 3

參考答案

  • GGGG = head->prev=prev
  • HHHH = list_entry(pivot,node_t,list)->value
  • IIII = list_entry(n,node_t,list)->value
  • JJJJ = pivot
  • KKKK = right

思路

/* GGGG */
rebuild_list_link 這個方法裡面的程式碼可以觀察出來他是在把單向非環狀鏈結串列重新建構成雙向環狀鏈結串列,因此在迴圈下方兩行的程式碼是把尾端的節點和 head 連接起來,因此 GGGG 的答案就會是 head->prev=prev

/* HHHH */
觀察填空區域附近的程式碼可以知道這部分是在處理 pivot,再來從變數名稱 value 可以推測它是用來存 pivot 節點的值,因此 HHHH 會是 list_entry(pivot,node_t,list)->value

    struct list_head *pivot = L;
    value = /* HHHH */
    struct list_head *p = pivot->next;
    pivot->next = NULL; /* break the list */

/* IIII */
觀察這部分的程式碼可以知道這邊是在快速排序的 partition 過程,n 就是目前判斷到的節點。從變數名稱 n_value 就可以推測 IIII 要填的就是取 n 的值的操作 list_entry(n,node_t,list)->value

    while (p) {
        struct list_head *n = p;
        p = p->next;
        int n_value = /* IIII */;
        if (n_value > value) {
            n->next = right;
            right = n;
        } else {
            n->next = left;
            left = n;
        }
    }

/* JJJJ */ /* KKKK */
這部分的程式碼是在模擬遞迴版本的 push stack 操作,由於在合併操作的時候是從 head 插入 result ,因此必須按照 left -> pivot -> right 的順序放入 stack。

    begin[i] = left;
    begin[i + 1] = /* JJJJ */;
    begin[i + 2] = /* KKKK */;
    left = right = NULL;
    i += 2;
  • 合併操作部分
    if (L) {
        L->next = result;
        result = L;
    }
    i--;

第 2 週測驗題

測驗 1

參考答案

  • AAAA = list_first_entry
  • BBBB = list_del
  • CCCC = list_move_tail
  • DDDD = list_add
  • EEEE = list_splice
  • FFFF = list_splice_tail

思路

/* AAAA */
這邊從 pivot 的型別 (listitem *) 加上參數的數量及內容,還有變數本身的名字 pivot,可以知道 AAAA 是要取得 list 的第一個 entry 因此答案會是 list_first_entry

/* BBBB */
快速排序在得到 pivot 的資訊後第一件事就是把 pivot 從 list 中分離,因此 BBBB 就會是 list_del

    INIT_LIST_HEAD(&list_less);
    INIT_LIST_HEAD(&list_greater);

    pivot = AAAA(head, struct listitem, list);
    BBBB(&pivot->list);

/* CCCC */
這部分迴圈是在做快速排序的 partition 操作,它會把小於 pivot 的值放入 list_less ,大於等於 pivot 的放入 list_greater。由於題目要實作的事 stable sorting ,為維持原本的順序,插入其他 list 時都要從尾端插入,因此 CCCC 會是 list_move_tail

    list_for_each_entry_safe (item, is, head, list) {
        if (cmpint(&item->i, &pivot->i) < 0)
            list_move_tail(&item->list, &list_less);
        else
            CCCC(&item->list, &list_greater);
    }

/* DDDD */ /* EEEE */ /* FFFF */
從這部分程式碼可以看出來這是快速排序的合併操作,要把排序完的 list_less、list_greater 和 pivot 合併成一個 sorted list。由於題目要求的是 stable sorting 並依照 pivot、list_less、list_greater 的順序放入。首先 pivot 只有一個節點,因此 DDDD 直接用 list_add 插入 pivot,再來把一整串 list_less 放在 pivot 前方,所以 EEEE 使用 list_splice,最後 list_greater 要放在 pivot 後方必須從尾端插入,因此 FFFF 使用 list_splice_tail

這邊可以使用 list_splice_tail 的原因可以參考下方部分程式碼,它是將要插入的 list 的頭部接到目標 head 的尾端,讓 list 的尾端成為 head 的新的尾端。

    list_quicksort(&list_less);
    list_quicksort(&list_greater);

    DDDD(&pivot->list, head);
    EEEE(&list_less, head);
    FFFF(&list_greater, head);
  • list_splice_tail 部分程式碼
    /* list_splice_tail */
    head->prev = list_last;
    list_last->next = head;

    list_first->prev = head_last;
    head_last->next = list_first;

測驗 2

參考答案

  • GGGG = 14
  • HHHH = 2
  • IIII = 0
  • JJJJ = 3
  • KKKK = 2
  • LLLL = 1
  • MMMM = ~1
  • NNNN = 1
  • PPPP = 2

思路

clz2 是用遞迴來實作分治演算法計算前綴 bit 0 有多少個。
因此觀察函數的參數 x 就是我們要計算的目標,c 可以理解為遞迴的深度。
把每一層 c 對應的數值樣子列出來會像以下這樣:

c = 0
            upper                lower
    0000 0000 0000 0000 | 0000 0000 0000 0000
    
c = 1
                upper       lower
              0000 0000 | 0000 0000
              
c = 2
                  upper   lower
                   0000 | 0000
                   
c = 3
                   upper lower
                     00 | 00

/* GGGG */
在計算 lower 部分的時候是利用 mask 的方式來取得右半部分,當在 c = 0 這層時,我們要的 mask 就是 0xFFFF;在 c = 1 時,我們要的是右邊 8 個 bits 也就是 mask 是 0xFF,若使用 0xFFFF 位移來計算的話會是 0xFFFF >> 8,回去看 mask 陣列宣告也確實是 8;c = 2 時同理。

因此就可以推測 GGGG 的值會是要讓 0xFFFF >> mask[3] 能符合當 c = 3 時的 mask 要求。當 c = 3 時,lower 部分會需要取右邊 2 個 bits,需要的 mask 就是 0x3 = 0b0011 = 0xFFFF >> 14 ,因此 GGGG 該填入 14

uint32_t upper = (x >> (16 >> c));
uint32_t lower = (x & (0xFFFF >> mask[c]));

/* HHHH */ /* IIII */ /* JJJJ */
我們都知道遞迴函式都必須有一個終止條件,所以可以推測在此的 JJJJ 就是終止值,觀察程式碼可以發現 c 大部分都放在 maskmagic 陣列的索引值內,它們兩個的長度又只有 4,因此可以推測這邊的終止值會是 3

接下來討論終止回傳值應該會是多少?
由於已經到了 c = 3 這層,upper 只會有 4 種可能 0、1、2、3,從這個三元運算式可以知道當 upper 為非 0,回傳的就會是 magic[upper]

此時回頭去想我們對於這個函式的定義是什麼?是要取前綴 0 的數量,所以當 upper = 1 = 0b01 我們應該回傳 1、upper = 2 = 0b10 應回傳 0、upper = 3 = 0b11 應回傳 0。所以我們可以知道 magic 的每一項會是 0、1、2、3 對應的前綴 0 個數,因此 magic[] = {2, 1, 0, 0}
所以 HHHH = 2IIII = 0

/* KKKK */
upper 為 0 代表前綴 0 的個數就會是 upper 的 2 個 0 加上 lower 的前綴 0 個數,也就是 2 + clz2(lower, c + 1),而這邊因為已經是終止條件回傳了,所以是用 2 + magic[lower],因此 KKKK = 2

if (c == JJJJ)
        return upper ? magic[upper] : KKKK + magic[lower];

/* LLLL */
若目前還不是最深層,就遞迴下去解,往更深一層遞迴的時候就必須更新當前的深度,因此 LLLL = 1

return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + LLLL);
  • clz2 補完後
static const int mask[] = {0, 8, 12, 14};
static const int magic[] = {2, 1, 0, 0};

unsigned clz2(uint32_t x, int c)
{
    if (!x && !c)
        return 32;

    uint32_t upper = (x >> (16 >> c));
    uint32_t lower = (x & (0xFFFF >> mask[c]));
    if (c == 3)
        return upper ? magic[upper] : 2 + magic[lower];
    return upper ? clz2(upper, c + 1) : (16 >> (c)) + clz2(lower, c + 1);
}

/* MMMM */ /* NNNN */ /* PPPP */
這部分填空的程式碼是要實作一個 sqrti :對一個整數開根號並向下取整。假設對

x 開根號,我們要求
y=x

我們已知要求的答案

y 是一個整數,如果我們用二進制來看
y
,它會是一堆
1bi
的總和,例如
y=0b00101011=(15)+(13)+(11)+(10)
,因此
y2=[(15)+(13)+(11)+(10)]2

此時我們可以用代數的形式來看這個式子,

y2=(a+b+c+d)2,我們只要找到
a
b
c
d
分別是多少就能知道
y
了。

我們知道我們要找的

a
b
c
d
都一定會是 2 的冪,因此我們可以從最高位的 bit 往低位的 bit 去試,如此可以讓
y2
慢慢逼近
x

就像是在建構一個浮點數逼近某個值,首先嘗試加上範圍最大的值 0.5,如果把它加進去當前逼近的值會大於目標的話,那就不要把它加進當前逼近的值,接下來把範圍縮小成 0.25,若把它加進目前逼近的值還是小於目標值那就把它加進去,再來 0.1250.0625... ,以此類推。

所以我們可以先找到最高的可能的 bit 是哪一個,假設

x 的最高位 bit 為
b
,那麼
x
一定會大於等於
(2b2)2
並且小於
(2b2+1)2
,因此我們就可以從第
b2
個 bit 開始往低位嘗試。

就目前的理解可以寫成下面的程式碼:

uint64_t sqrti(uint64_t x)
{
    uint64_t m, y = 0;
    if (x <= 1)
        return x;

    int total_bits = 64;
    
    int shift = (total_bits - 1 - clz64(x)) >> 1;
    m = 1ULL << shift;

    while (m) {
        uint64_t b = y + m;
        if (x >= b*b)
            y += m;
        m >>= 1;
    }
    
    return y;
}

不過題目要求只能用加減法和位移運算。由於乘法出現在

(a+b+c+d)2 這個式子,因此我們可以嘗試從這個式子下手去看。

先從最簡單的

(a+b)2 來看:
這個式子可以把
b
提出來寫成最下面那個形式。
(a+b)2=a2+2ab+b2=a2+b(2a+b)

接下來看

(a+b+c)2=a2+b2+c2+2ab+2bc+2ac
整理後可以得到
(a+b+c)2=(a2+2ab+b2)+c2+2bc+2ac=(a+b)2+c(2(a+b)+c)

同理在
(a+b+c+d)2
也可以展開成
(a+b+c+d)2=(a+b+c)2+d(2(a+b+c)+d)

寫成一般式
(a1+a2++an)2=(i=1n1ai)2+an(2(i=1n1ai)+an)

這樣感覺還看不太出來如何不用到乘法運算,我們可以把後面展開變成以下

(a1+a2++an)2=(i=1n1ai)2+an2+2an(i=1n1ai)

用我們目前的題目環境去理解這個式子的話:

  • 等號左邊的
    (a1+a2++an)2
    就是我們的
    x
  • 等號右邊左方的
    (i=1n1ai)2
    就是我們目前逼近的
    y2
    • 我們就令
      (y)2=(i=1n1ai)2
      ,也就是
      y=i=1n1ai
  • an
    就是我們目前嘗試到的 m

所以式子就可以寫成:

x=(y)2+m2+2my

由於我們的

m 都會是 2 的冪,因此在做
m2
的時候只需要做 m << shift 就好 (需同時維護 shift 的值)。
同理,
2my
也就是對
y
左移 shift+1 就能計算完成。

(y)2 的部分則可以在算出來時直接讓
x
減掉,等同於是把它移到等號左邊,並另該結果為
x
,變成
x=x(y)2=m2+2my
,這邊的
x
具體一點就是我們目前逼近的值與目標值之間的差值。

準確來講,只要我們嘗試算出來的結果比差值

x 還要小那這個 m 就是我們要找的值,也就是:
xm2+2my

理解到這裡,就能寫出以下程式碼:

uint64_t sqrti(uint64_t x)
{
    uint64_t m, y = 0;
    if (x <= 1)
        return x;

    int total_bits = 64;
    
    int shift = (total_bits - 1 - clz64(x)) >> 1;
    m = 1ULL << shift;

    while (m) {
        uint64_t b = (m << shift) + (y << (shift + 1));
        if (x >= b) {
            x -= b;
            y += m;
        }
        m >>= 1;
        shift--;
    }
    
    return y;
}

這樣的寫法已經不會用到乘法運算了,不過還有地方可以再精簡。

如果我們對 m 預先再左移 shift 次就可以省去計算 b 時所需要的左移操作,因此不用再對 shift 除以 2,只需要確保他是偶數 (因為

(b2)2),且每次迴圈結束都要右移 2,這是為了讓迴圈跑的次數與原先相同。所以這邊 MMMM 填入 ~1PPPP = 2

由於 y 是由 m 來更新的,所以 y 的值會是會是我們之前的 y 再左移 shift 的樣子,所以我們每次迴圈做完都要把 y 右移 1。NNNN = 1

而在計算 b 會需要 y << (shift+1),也就是前一次迴圈的 y,因此右移 1 的操作必須在計算 b 之後,且在下一次更新,也就是 y += m 之前。

簡化過後,就能變成和題目一樣的程式碼了:

uint64_t sqrti(uint64_t x)
{
    uint64_t m, y = 0;
    if (x <= 1)
        return x;

    int total_bits = 64;

    int shift = (total_bits - 1 - clz64(x)) & ~1;
    m = 1ULL << shift;

    while (m) {
        uint64_t b = y + m;
        y >>= 1;
        if (x >= b) {
            x -= b;
            y += m;
        }
        m >>= 2;
    }
    return y;
}

延伸問題

擴充為向上取整
(x)

若計算結束後 x 的值,也就是與目標的差值,為非 0 的話,代表原先傳入的 x 不是平方數,因此算出來的 y 需要加一;若計算結束後的 x 為 0,代表是平方數結果就是 y

因此只需要更改回傳的那行程式碼就能達成:

return y + !!x;

測驗 3

參考答案

  • AAAA = map->bits
  • BBBB = map->bits
  • CCCC = first->pprev
  • DDDD = n->pprev
  • EEEE = n->pprev

思路

  • 雜湊函數:
#define GOLDEN_RATIO_32 0x61C88647
static inline unsigned int hash(unsigned int val, unsigned int bits)
{
    /* High bits are more random, so use them. */
    return (val * GOLDEN_RATIO_32) >> (32 - bits);
}

/* AAAA */ /* BBBB */
從雜湊函數可以看出來第二個參數 bits 就是需要傳入先前在初始化設定的 map->bits

  • find_key
struct hlist_head *head = &(map->ht)[hash(key, AAAA)];
  • map_add
struct hlist_head *h = &map->ht[hash(key, BBBB)];

/* CCCC */
已知 kn 是我們要插入的 hash_key,所以n 就會是我們要插入的節點,而 first 是目前同個 hash_key 的 list 的開頭。從 n->next = first 可以知道是從 list 的頭部端插入,所以如果已存在其他相同 hash_key 的節點 (if (first) 成立),我們就要更新 first 的 pprev 所以 CCCCfirst->pprev

/* DDDD */
承 C ,把 n 插入頭部端後也需要更新其 pprev,因此 DDDD = n->pprev

  • map_add
    struct hlist_node *n = &kn->node, *first = h->first;

    n->next = first;
    if (first)
        CCCC = &n->next;
    h->first = n;
    DDDD = &h->first;

/* EEEE */
map_deinit 完整程式碼可以知道 n 是我們要從 list 中移除並釋放掉的點,下方部分程式碼就是在將 n 從 list 中分離,並且從後續操作可以知道 next 就是 n 的下個節點,pprev 就是 n 的 pprev。因此 EEEE = n->pprev

  • map_deinit
    struct hlist_node *next = n->next, **pprev = EEEE;
    *pprev = next;
    if (next)
        next->pprev = pprev;
    n->next = NULL, n->pprev = NULL;