--- tags: linux2022 --- # 2022q1 Homework2 (quiz2) contributed by < `blueskyson` > [題目連結](https://hackmd.io/@sysprog/linux2022-quiz2) ## 測驗 1 考慮以下對二個無號整數取平均值的程式碼: ```c #include <stdint.h> uint32_t average(uint32_t a, uint32_t b) { return (a + b) / 2; } ``` 這個直覺的解法會有 overflow 的問題,若我們已知 a, b 數值的大小,可用下方程式避免 overflow: ```c #include <stdint.h> uint32_t average(uint32_t low, uint32_t high) { return low + (high - low) / 2; } ``` 接著我們可改寫為以下等價的實作: ```c #include <stdint.h> uint32_t average(uint32_t a, uint32_t b) { return (a >> 1) + (b >> 1) + (EXP1); } ``` 我們再次改寫為以下等價的實作: ```c uint32_t average(uint32_t a, uint32_t b) { return (EXP2) + ((EXP3) >> 1); } ``` ### 解題 在第一種改寫的實作中,當 `a` 與 `b` 皆為奇數時時,在 `a >> 1` 與 `b >> 1` 後 `a/2` 與 `a/b` 各會損失 1,需要把 1 加回去。因此 ==EXP1== 為 ==a & b & 1==。 在第二種改寫的實作則是使用加法器的概念,`a & b` 是 `a + b` 的進位值,`a ^ b` 則是 `a + b` 的和。因此 ==EXP2== 為 ==a & b==、==EXP3== 為 ==a ^ b==。 ## 測驗 2 改寫〈[解讀計算機編碼](https://hackmd.io/@sysprog/binary-representation)〉一文的「不需要分支的設計」一節提供的程式碼 `min`,我們得到以下實作 (`max`): ```c #include <stdint.h> uint32_t max(uint32_t a, uint32_t b) { return a ^ ((EXP4) & -(EXP5)); } ``` 延伸閱讀: - [That XOR Trick](https://florian.github.io/xor-trick/) - [Branchless Programming: Why “If” is Sloowww… and what we can do about it!](https://youtu.be/bVJ-mWWL7cE) - [Making Your Code Faster by Taming Branches](https://www.infoq.com/articles/making-code-faster-taming-branches/) ### 解題 利用自己 xor 自己等於 0 的特殊性質,我預期當 `a > b` 時,`max(a, b)` 會執行 `(a ^ 0)` 以回傳 `a`,反之則執行 `(a ^ a ^ b)`。因此很明顯的得到 ==EXP4== 為 ==a ^ b==。 當 `a > b` 時 `(a ^ b) & -(EXP5)` 必須為 `0`,才能使得 `max(a, b)` 回傳 `a ^ 0`;反之 `(a ^ b) & -(EXP5)` 必須為 `a ^ b`,才能使得 `max(a, b)` 回傳 ` a ^ a ^ b`。故 ==EXP5== 為 ==a < b== 時,恰好可以製造出 `(a ^ b) & 0` 與 `(a ^ b) & 0xffff` 來控制 `max` 的回傳值。 ## 測驗 3 考慮以下 64 位元 GCD (greatest common divisor, 最大公因數) 求值函式: ```c #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; while (v) { uint64_t t = v; v = u % v; u = t; } return u; } ``` 改寫為以下等價實作: ```c #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; int shift; for (shift = 0; !((u | v) & 1); shift++) { u /= 2, v /= 2; } while (!(u & 1)) u /= 2; do { while (!(v & 1)) v /= 2; if (u < v) { v -= u; } else { uint64_t t = u - v; u = v; v = t; } } while (COND); return RET; } ``` ### 解題 第 1 步: `if (!u || !v) return u | v;` 判斷 `u`, `v` 是否是 0 ,若其中一個是 0 就回傳 0。 第 2 步: ```c for (shift = 0; !((u | v) & 1); shift++) { u /= 2, v /= 2; } ``` 若 `u`, `v` 同時可被 2 整除,就將 `u`, `v` 同除以 2 ,並且讓 `shift` 加 1 ,由此可知 `u`, `v` 同為 `(0x1 << shift)` 的倍數,也就是將 $2^{shift}$ 作為公因數提出來。 第 3 步: ```c while (!(u & 1)) u /= 2; ``` 在第 2 步已經將 $2^{shift}$ 提出來了,代表接下來 gcd 的過程不會再萃取出偶數公因數,但是 `u` 或 `v` 可能還是偶數,繼續將 `u` 除以 2 直到 `u` 不是偶數。 第 4 步: ```c do { while (!(v & 1)) v /= 2; if (u < v) { v -= u; } else { uint64_t t = u - v; u = v; v = t; } } while (COND); ``` 這個 do while 迴圈持續相減過程就是輾轉相除。與第 3 步同理,每一輪迭代都將 `v` 除以 2 直到 `v` 不是偶數。 - 當 `v` 大於 `u`,`v - u` 可以視為 `v ÷ u` 的餘數,將 `v` 減去 `u` 之後執行下一輪迭代。 - 當 `v` 小於 `u`,`u - v` 可以視為 `u ÷ v` 的餘數,將 `u` 減去 `v` 之後執行下一輪迭代。 - 當 `v` 等於 `u` 時,`v` 即為所求公因數。由此可以推斷 `do while` 的條件 ==COND== 即為 ==v==,當 `u == v` 時 `v` 會與 `u` 相減變成 0 以跳出迴圈。 第 5 步 ```c return RET; ``` 回傳時要將原本同除的 $2^{shift}$ 乘回去,所以 ==RET== 為 ==u << shift==。 ## 測驗 4 在影像處理中,[bit array](https://en.wikipedia.org/wiki/Bit_array) (也稱 bitset) 廣泛使用,考慮以下程式碼: ```c #include <stddef.h> size_t naive(uint64_t *bitmap, size_t bitmapsize, uint32_t *out) { size_t pos = 0; for (size_t k = 0; k < bitmapsize; ++k) { uint64_t bitset = bitmap[k]; size_t p = k * 64; for (int i = 0; i < 64; i++) { if ((bitset >> i) & 0x1) out[pos++] = p + i; } } return pos; } ``` 考慮 GNU extension 的 [__builtin_ctzll](https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html) 的行為是回傳由低位往高位遇上連續多少個 `0` 才碰到 `1`。 > 範例: 當 `a = 16` > 16 這個十進位數值的二進位表示法為 00000000 00000000 00000000 00010000 > 從低位元 (即右側) 往高位元,我們可發現 0 → 0 → 0 → 0 → 1,於是 ctz 就為 4,表示最低位元往高位元有 4 個 `0`。 用以改寫的程式碼如下: ```c= #include <stddef.h> size_t improved(uint64_t *bitmap, size_t bitmapsize, uint32_t *out) { size_t pos = 0; uint64_t bitset; for (size_t k = 0; k < bitmapsize; ++k) { bitset = bitmap[k]; while (bitset != 0) { uint64_t t = EXP6; int r = __builtin_ctzll(bitset); out[pos++] = k * 64 + r; bitset ^= t; } } return pos; } ``` 其中第 9 行的作用是找出目前最低位元的 `1`,並紀錄到 `t` 變數。若 bitmap 越鬆散 (即 `1` 越少),於是 `improved` 的效益就更高。 ### 解題 `improved` 改寫程式碼根據 trailing zero 的數量來判斷最靠近 LSB 的 `1` 的 bit,所以每次紀錄完最靠近 LSB 的 `1` 的 bit,都必須將該 bit 的值變為 `0` 且其他 bit 保持不變。EXP6 的用途就是把 LSB 單獨提取出來,並賦值給 `t`,故 ==EXP6== 為 ==bitset & -bitset==。 `a & -a` 的特性: ``` bitset = xxxx 1000 & -bitset = yyyy 1000 -------------------------- t = 0000 1000 ``` 我們可以看到在 `bitset & -bitset` 之後,`t` 變成只剩 LSB 的數值。之後再讓 `bitset ^ t` 就能把 `bitset` 的 LSB,進行下一輪計算: ``` bitset = xxxx 1000 ^ t = 0000 1000 -------------------------- bitset = xxxx 0000 ``` ## 測驗 5 以下是 [LeetCode 166. Fraction to Recurring Decimal](https://leetcode.com/problems/fraction-to-recurring-decimal/) 的可能實作: ```c #include <stdbool.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include "list.h" struct rem_node { int key; int index; struct list_head link; }; static int find(struct list_head *heads, int size, int key) { struct rem_node *node; int hash = key % size; list_for_each_entry (node, &heads[hash], link) { if (key == node->key) return node->index; } return -1; } char *fractionToDecimal(int numerator, int denominator) { int size = 1024; char *result = malloc(size); char *p = result; if (denominator == 0) { result[0] = '\0'; return result; } if (numerator == 0) { result[0] = '0'; result[1] = '\0'; return result; } /* using long long type make sure there has no integer overflow */ long long n = numerator; long long d = denominator; /* deal with negtive cases */ if (n < 0) n = -n; if (d < 0) d = -d; bool sign = (float) numerator / denominator >= 0; if (!sign) *p++ = '-'; long long remainder = n % d; long long division = n / d; sprintf(p, "%ld", division > 0 ? (long) division : (long) -division); if (remainder == 0) return result; p = result + strlen(result); *p++ = '.'; /* Using a map to record all of reminders and their position. * if the reminder appeared before, which means the repeated loop begin, */ char *decimal = malloc(size); memset(decimal, 0, size); char *q = decimal; size = 1333; struct list_head *heads = malloc(size * sizeof(*heads)); for (int i = 0; i < size; i++) INIT_LIST_HEAD(&heads[i]); for (int i = 0; remainder; i++) { int pos = find(heads, size, remainder); if (pos >= 0) { while (PPP > 0) *p++ = *decimal++; *p++ = '('; while (*decimal != '\0') *p++ = *decimal++; *p++ = ')'; *p = '\0'; return result; } struct rem_node *node = malloc(sizeof(*node)); node->key = remainder; node->index = i; MMM(&node->link, EEE); *q++ = (remainder * 10) / d + '0'; remainder = (remainder * 10) % d; } strcpy(p, decimal); return result; } ``` ### 解題 為了揣摩這個程式的邏輯,舉循環小數 `12 / 11` 為例: **step 0:** 首先計算 `12 / 11` 的商數 `division = 1` 與餘數 `remainder = 1`,此時 `result = "1."`。接下來初始化 hash table,然後進入 for 迴圈計算小數部份。此時的狀態如下: ```graphviz digraph G { node[shape=record]; result[label="<0>1|<1>.|||" shape=record] p [label="p", style=dashed, color=grey]; r [label="result", style=dashed, color=grey]; r -> result:0:w [style=dashed, color=grey]; p -> result:2 [style=dashed, color=grey]; decimal [label="<0>|<1>|||" shape=record] q [label="q", style=dashed, color=grey]; d [label="decimal", style=dashed, color=grey]; d -> decimal:0:w [style=dashed, color=grey]; q -> decimal:0 [style=dashed, color=grey]; subgraph cluster_0 { label="hash table"; divide_pad[label="----------", style=invis] } } ``` **step 1:** 進到 for 迴圈後第一件事就是透過 `find`,從 hash table 中尋找過去是否除過當前的餘數,若是就代表陷入循環小數,回傳發生第一次循環小數的位數。因為此時 hash table 還是空的,所以 `find` 回傳 `-1`,並賦值給 `pos`。因為沒有陷入循環小數,所以不會進到 if 區塊中。 把當前的小數位數 `i == 0` 以及餘數 `remainder == 1` 裝進 `rem_node` 放入 hash table 中。接下來進行長除法,將餘數 `1` 乘以 10,再除以除數 `11`,也就是計算 `10 / 11`。將 `10 / 11` 的商數 `0` 轉為字元存到 `decimal` 中,然後把 `remainder` 為更新為 `10 / 11` 的餘數 `10`,此時狀態如下: ```graphviz digraph G { node[shape=record]; result[label="<0>1|<1>.|||" shape=record] p [label="p", style=dashed, color=grey]; r [label="result", style=dashed, color=grey]; r -> result:0:w [style=dashed, color=grey]; p -> result:2 [style=dashed, color=grey]; decimal [label="<0>0|<1>|||" shape=record] q [label="q", style=dashed, color=grey]; d [label="decimal", style=dashed, color=grey]; d -> decimal:0:w [style=dashed, color=grey]; q -> decimal:1 [style=dashed, color=grey]; subgraph cluster_0 { label="hash table"; n1 [label="{index = 0|remainder = 1}"] } } ``` **step 2:** 此時 `remainder == 10`,hash table 中並沒有 `remainder == 10` 的元素,所以 `find` 回傳 `-1`,不會進到 if 區塊中。 把當前的小數位數 `i == 1` 以及餘數 `remainder == 10` 裝進 `rem_node` 放入 hash table 中。接下來進行長除法,將 `10` 乘以 10,再除以 `11`,也就是計算 `100 / 11`。將 `100 / 11` 的商數 `9` 轉為字元存到 `decimal` 中,然後把 `remainder` 為更新為 `100 / 11` 的餘數 `1`,此時狀態如下: ```graphviz digraph G { node[shape=record]; result[label="<0>1|<1>.|||" shape=record] p [label="p", style=dashed, color=grey]; r [label="result", style=dashed, color=grey]; r -> result:0:w [style=dashed, color=grey]; p -> result:2 [style=dashed, color=grey]; decimal [label="<0>0|<1>9|<2>||" shape=record] q [label="q", style=dashed, color=grey]; d [label="decimal", style=dashed, color=grey]; d -> decimal:0:w [style=dashed, color=grey]; q -> decimal:2 [style=dashed, color=grey]; subgraph cluster_0 { label="hash table"; n1 [label="{index = 1|remainder = 10}"] n2 [label="{index = 0|remainder = 1}"] } } ``` **step 3:** 此時 `remainder == 1`,恰好 hash table 中存在 `index == 0, remainder == 1` 的元素,代表已經陷入循環小數了,回傳 `index` 的值 `0` 給 `pos`,然後進入 if 區塊。 在 if 區塊中,`PPP` 的 while 是將未循環的位數填到 `result` 中(例如 `0.12(34)` 的 `12`),這個例子中從小數後第 0 位開始都是循環小數,所以不會執行這個 while 迴圈。接下來在 `result` 填入左括弧、填入 `decimal` 循環的部份、填入右括弧。最後回傳 `result`。 ```graphviz digraph G { node[shape=record]; result[label="<0>1|.|(|0|9|)|<6>\\0" shape=record] p [label="p", style=dashed, color=grey]; r [label="result", style=dashed, color=grey]; r -> result:0:w [style=dashed, color=grey]; p -> result:6 [style=dashed, color=grey]; decimal [label="<0>0|<1>9|<2>||" shape=record] q [label="q", style=dashed, color=grey]; d [label="decimal", style=dashed, color=grey]; d -> decimal:0:w [style=dashed, color=grey]; q -> decimal:2 [style=dashed, color=grey]; subgraph cluster_0 { label="hash table"; n1 [label="{index = 1|remainder = 10}"] n2 [label="{index = 0|remainder = 1}", color=red, fontcolor=red] } } ``` 理清楚程式的邏輯後,很明顯的得出 ==PPP== 為 ==pos\-\-== 以填入未循環的位數;==MMM== 為 ==list_add== 把元素放入 hash table;==EEE== 為 ==&heads[remainder % size]==,用以找到對應的 hash 的 entry。 ## 測驗 6 [\_\_alignof\_\_](https://gcc.gnu.org/onlinedocs/gcc/Alignment.html) 是 GNU extension,以下是其可能的實作方式: ```c /* * ALIGNOF - get the alignment of a type * @t: the type to test * * This returns a safe alignment for the given type. */ #define ALIGNOF(t) \ ((char *)(&((struct { char c; t _h; } *)0)->M) - (char *)X) ``` ### 解題 慢慢剖析這個巨集,首先 `(struct { char c; t _h; } *) 0`,是將 `0x0` (nil) 這個位址開頭的記憶體視為一個 `struct { char c; t _h; }` 物件。 `(char *)(&((struct { char c; t _h; } *)0)->M)` 則是以 `0x0` 作為此物件的起始點,取成員 `M` 的位址,再將其轉形為 `char *`,待會便能以 1 byte 為單位計算位址的差距。因為 `ALIGNOF` 是用來計算 `t` 的 alignment,很明顯 ==M== 為 ==\_h==。 取得 `_h` 的位址後,我們只要將其減去 `0x0` 就能得到型態 `t` 的位移量,所以 ==X== 為 ==0==。以下測試常用型態的位移量: ```c // test.c #include <stdio.h> #define ALIGNOF(t) \ ((char *)(&((struct { char c; t _h; } *)0)->_h) - (char *)0) int main(void) { printf("alignof char: %ld\n", ALIGNOF(char)); printf("alignof short: %ld\n", ALIGNOF(short)); printf("alignof int: %ld\n", ALIGNOF(int)); printf("alignof double: %ld\n", ALIGNOF(double)); printf("alignof long: %ld\n", ALIGNOF(long)); printf("alignof long long: %ld\n", ALIGNOF(long long)); return 0; } ``` ```bash $ gcc test.c $ ./a.out alignof char: 1 alignof short: 2 alignof int: 4 alignof double: 8 alignof long: 8 alignof long long: 8 ``` ## 測驗 7 考慮貌似簡單卻蘊含實作深度的 [FizzBuzz](https://en.wikipedia.org/wiki/Fizz_buzz) 題目: - 從 1 數到 n,如果是 3的倍數,印出 “Fizz” - 如果是 5 的倍數,印出 “Buzz” - 如果是 15 的倍數,印出 “FizzBuzz” - 如果都不是,就印出數字本身 直覺的實作程式碼如下: (`naive.c`) ```c #include <stdio.h> int main() { for (unsigned int i = 1; i < 100; i++) { if (i % 3 == 0) printf("Fizz"); if (i % 5 == 0) printf("Buzz"); if (i % 15 == 0) printf("FizzBuzz"); if ((i % 3) && (i % 5)) printf("%u", i); printf("\n"); } return 0; } ``` 觀察 `printf` 的(格式)字串,可分類為以下三種: 1. 整數格式字串 "%d" : 長度為 2 B 2. “Fizz” 或 “Buzz” : 長度為 4 B 3. “FizzBuzz” : 長度為 8 B 考慮下方程式碼: ```c #define MSG_LEN 8 char fmt[MSG_LEN + 1]; strncpy(fmt, &"FizzBuzz%u"[start], length); fmt[length] = '\0'; printf(fmt, i); printf("\n"); ``` 我們若能精準從給定輸入 `i` 的規律去控制 `start` 及 `length` ,即可符合 FizzBuzz 題意: ```c string literal: "fizzbuzz%u" offset: 0 4 8 ``` 以下是利用 bitwise 和上述技巧實作的 FizzBuzz 程式碼: (`bitwise.c`) ```c static inline bool is_divisible(uint32_t n, uint64_t M) { return n * M <= M - 1; } static uint64_t M3 = UINT64_C(0xFFFFFFFFFFFFFFFF) / 3 + 1; static uint64_t M5 = UINT64_C(0xFFFFFFFFFFFFFFFF) / 5 + 1; int main(int argc, char **argv) { for (size_t i = 1; i <= 100; i++) { uint8_t div3 = is_divisible(i, M3); uint8_t div5 = is_divisible(i, M5); unsigned int length = (2 << KK1) << KK2; char fmt[9]; strncpy(fmt, &"FizzBuzz%u"[(9 >> div5) >> (KK3)], length); fmt[length] = '\0'; printf(fmt, i); printf("\n"); } return 0; } ``` 其中 `is_divisible` 函式技巧來自 [Faster remainders when the divisor is a constant: beating compilers and libdivide](https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/),甚至 gcc-9 還內建了 [FizzBuzz optimization](https://gcc.gnu.org/bugzilla/show_bug.cgi?id=82853) (Bug 82853 - Optimize x % 3 == 0 without modulo)。 請補完。 對於處理器來說,每個運算所花的成本是不同的,比如 `add`, `sub` 就低於 `mul`,而這些運算的 cost 又低於 `div` 。依據〈[Infographics: Operation Costs in CPU Clock Cycles](http://ithare.com/infographics-operation-costs-in-cpu-clock-cycles/)〉,可發現整數除法的成本幾乎是整數加法的 50 倍。 ![](https://hackmd.io/_uploads/rJp3Lfce9.png) ### 解題 `length` 作為 `strncpy` 複製的字串長度,當 $i|3$ 或 $i|5$ 時 `length` 預期為 4,即 `"Fizz"` 或 `"Buzz"` 的長度;當 $i|15$ 時 `length` 預期為 8,即 `"FizzBuzz"` 的長度。因此 ==KK1== 為 ==div3==、==KK2== 為 ==div5==。 `&"FizzBuzz%u"[(9 >> div5) >> (KK3)]` 代表要從字元陣列 `"FizzBuzz%u"` 的哪個位址開始複製。以下是在各種情況期望的起始位址與對應的字串: - $i|3$: `(9 >> div5) >> (KK3) == 0`,i.e. `"Fizz"` - $i|5$: `(9 >> div5) >> (KK3) == 4`,i.e. `"Buzz"` - $i|15$: `(9 >> div5) >> (KK3) == 0`,i.e. `"FizzBuzz"` - default: `(9 >> div5) >> (KK3) == 8`,i.e. `"%u"` 當 ==KK3== 為 ==div3 << 2== 時可以達成上述期望。 :::info (9 >> div5) >> (KK3) 應改成 (8 >> div5) >> (KK3),否則會複製到 "u" 而非 "%u"。 :::