# 2019q3 Homework2 (quiz2) contributed by < `colinyoyo26` > ### 測驗 `1` 考慮下方檔案 `4thought.c` 是 ACM-ICPC 題目 [4 thought](https://open.kattis.com/problems/4thought) 的一個解法,假設程式的輸入符合 [4 thought](https://open.kattis.com/problems/4thought) 的描述,請補完程式碼: ```cpp #include <stdbool.h> #include <stdio.h> enum { opType1 = 0x1 << 0, opType2 = 0x1 << 1, opType3 = 0x1 << 4, opType4 = 0x1 << 5, }; static int operate(int op, int a, int b) { switch (op) { case opType1: return a + b; case opType2: return a - b; case opType3: return a * b; case opType4: return (int) a / b; } return 0; } static char op_to_char(int op) { return "+-*/?"[op - 1]; } static int op_to_prio(int op) { return ((int[]){opType1, opType2, opType3, opType4, -1})[op - 1]; } static int calc(int op1, int op2, int op3) { op1 = op_to_prio(op1); op2 = op_to_prio(op2); op3 = op_to_prio(op3); bool p1 = (op1 & 0x0F) == 0; // = 1 for * or / bool p2 = (op2 & 0x0F) == 0; // else = 0 bool p3 = (op3 & 0x0F) == 0; // (4 + 4 + 4 + 4) or (4 / 4 / 4 / 4) if ((p1 == p2) && (p2 == p3)) return operate(op3, operate(op2, operate(op1, 4, 4), 4), 4); /* Write your code here */ return 0; } int main(void) { int n; scanf("%d", &n); int sol[n]; for (int i = 0; i < n; i++) scanf("%d", &sol[i]); bool validSolution = false; for (int i = 0; i < n; i++) { for (int op1 = 4; op1 > 0; op1--) { for (int op2 = 4; op2 > 0; op2--) { for (int op3 = 4; op3 > 0; op3--) { int sol_checked = calc(op1, op2, op3); if (sol_checked == sol[i]) { validSolution = true; char op1char = op_to_char(op1); char op2char = op_to_char(op2); char op3char = op_to_char(op3); printf("4 %c 4 %c 4 %c 4 = %d\n", op1char, op2char, op3char, sol[i]); op1 = -1; op2 = -1; op3 = -1; break; } } } } if (!validSolution) printf("no solution\n"); validSolution = false; } return 0; } ``` 注意: 你應該要實作 `calc` 函式中標註 `/* Write your code here */` 之後的程式碼。除了撰寫程式,你應該提供對應的程式碼註解。 :::success 延伸問題: 1. 解釋程式運作的原理和推敲背後的思路; 2. 探討 [4 thought](https://open.kattis.com/problems/4thought) 組合出來的數值分佈,並且透過數論解釋; 3. 提出得以改善上述程式碼執行效率的方案,著手分析和實作; ::: ==解題== ```cpp static int calc(int op1, int op2, int op3) { op1 = op_to_prio(op1); op2 = op_to_prio(op2); op3 = op_to_prio(op3); bool p1 = (op1 & 0x0F) == 0; // = 1 for * or / bool p2 = (op2 & 0x0F) == 0; // else = 0 bool p3 = (op3 & 0x0F) == 0; // (4 + 4 + 4 + 4) or (4 / 4 / 4 / 4) or (4 / 4 / 4 + 4) or (4 / 4 + 4 + 4) if ((p1 == p2) && (p2 == p3) || p1 && p2 || p1 && !p2 && !p3) return operate(op3, operate(op2, operate(op1, 4, 4), 4), 4); // (4 + 4 / 4 / 4) if (p2 && p3) return operate(op1, 4, operate(op3, operate(op2, 4, 4), 4)); // (4 + 4 / 4 + 4) if (!p1 && !p3) return operate(op3, operate(op1, operate(op2, 4, 4), 4), 4); // (4 + 4 + 4 / 4) if (p3) return operate(op2, operate(op1, 4, 4), operate(op3, 4, 4)); return 0; } ``` - 這題其實只是苦工題而已,把可能的排列組合都列出來 :::danger 不要輕易說「苦工題」,你的數學推理素養去哪了?如何 generalization 呢? :notes: jserv ::: #### 延伸問題 - 程式運作原理 - 用三層迴圈嘗試所有可能的組合 - 待更新 --- ### 測驗 `2` 考慮以下程式碼 (`fitbits.c`) 可檢驗輸入的整數 `x` 是否可用 `n` 個位元來表示,例如 (x = 4, n = 9) 要回傳 `true`, 當 (x = 4, n = 2) 回傳 `false`。 ```cpp #include <stdbool.h> bool fit_bits(int x, int n) { /* Write your code here */ return (bool) x; } ``` 實作的程式碼不能有任何邏輯條件判斷 (如 `if`, `else`, `?`) 或迴圈 (如 `for`, `while`, `goto`, `switch`, `case`, `break`, `continue`),可用的運算子是 `>>`, `<<`, `-`, `+`, `!`, `~`, `&`, `|` 請補完程式碼,作答時需要包含函式宣告及定義。除了撰寫程式,你應該提供對應的程式碼註解。 :::success 延伸問題: 在重視資訊安全的專案中,找出類似用法的程式碼,予以解說並進行相關 information leaks 的實驗 ::: :::warning 是不是覺得有點挫折呢?原因很簡單:你沒有如期寫作業 [review](https://hackmd.io/@sysprog/rJM4SPw8S),趕快動手! ::: ==解題== ```cpp #include <stdbool.h> bool fit_bits(int x, int n) { x >>= n - 1; x++; x >>= 1; x = !x; return (bool) x; } ``` - 考慮 n bits 能表達的有號數 - 因為留一個 sign bit 所以一開始先 shift right n - 1 個 bits - 如果能用 n bits 表示,右移完會有兩個結果 - 本來是正數 -> 0 - 本來是負數 -> -1 - 所以加一後分別變成 0 和 1 接著我們只有看除了 LSB 以外是否全為 0 就好了 #### 延伸問題 - 待更新 --- ### 測驗 `3` 考慮以下程式碼 (`is-less-equal.c`) 可檢驗輸入的整數 `x` 和 `y`,是否存在 $x <= y$ 的關係。例如 (x = 4, n = 4) 要回傳 `true`, 當 (x = 14, n = 9) 回傳 `false`。 ```cpp #include <stdbool.h> bool is_leq(int x, int y) { int s; /* Write your code here */ return (bool) s; } ``` 實作的程式碼不能有任何邏輯條件判斷 (如 `if`, `else`, `?`) 或迴圈 (如 `for`, `while`, `goto`, `switch`, `case`, `break`, `continue`),當然也不能用 `>=`, `>`, `<`, `<=`, `-` 等運算子。可用的運算子是 `>>`, `<<`, `+`, `~` 請補完程式碼,作答時需要包含函式宣告及定義。除了撰寫程式,你應該提供對應的程式碼註解。 :::success 延伸問題: 在重視資訊安全的專案中,找出類似用法的程式碼,予以解說並進行相關 information leaks 的實驗 ::: ==解題== ```cpp #include <stdbool.h> bool is_leq(int x, int y) { int s; s = ((y + ~x + 1) >> 31) + 1; return (bool) s; } ``` - `(y + ~x + 1) >> 31` 會有兩種結果 - y >= x -> 0 - y < x -> -1 - 最後 + 1 就能回傳正確結果了 #### 延伸問題 - 因為目前沒有方向所以先做實驗測試看看 - 執行時間實驗 - 想法: 隨機給定數值測試執行時間變異度 ```cpp #include <stdio.h> #include <stdbool.h> #include <stdlib.h> #include <time.h> #include <assert.h> bool is_leq(int x, int y) { int s; s = ((y + ~x + 1) >> 31) + 1; return (bool) s; } bool is_leq_norm(int x, int y) { return x <= y; } int main(void){ struct timespec start1, end1, start2, end2; int a, b; bool ans1, ans2; for(int iter = 0; iter < 100; iter++){ a = random(); b = random(); clock_gettime(CLOCK_REALTIME, &start1); ans1 = is_leq(a, b); clock_gettime(CLOCK_REALTIME, &end1); clock_gettime(CLOCK_REALTIME, &start2); ans2 = is_leq_norm(a, b); clock_gettime(CLOCK_REALTIME, &end2); assert(!ans1 ^ ans2); printf("%d %lu %lu\n", iter, end1.tv_nsec - start1.tv_nsec, end2.tv_nsec - start2.tv_nsec); } return 0; } ``` - 結果 - 從結果看來直接做 ``<=`` 運算效能比較好且執行時間變異度也很平穩,看來執行時間不是造成 information leak 的原因 ![](https://i.imgur.com/JwcwYna.png) :::warning [lab0 作業](https://hackmd.io/@sysprog/HyFQpqgPB)要求提到 [dudect 工具](https://github.com/oreparaz/dudect),請善用 :notes: jserv ::: - 待更新 --- ### 測驗 `4` 考慮一種針對短字串的最佳化操作,假設字串總是小於等於 8 bytes,標的硬體是像 x86_64 這樣的 64-bit 架構而且是 [little endian](https://en.wikipedia.org/wiki/Endianness),於是我們可實作類似以下的程式碼 (`ministr.c`): ```cpp #include <stdint.h> #include <stdio.h> #include <string.h> typedef union { uint64_t integer; char array[8]; } mini_str; static unsigned BitScanReverse(uint64_t x) { return 63 - __builtin_clzll(x); } /** * Find the length of the given mini_str. * @param str string to find length of * @return length of the given string */ unsigned mini_strlen(mini_str str) { // Special case for empty string. if (str.integer == 0) return 0; // Otherwise find first non-zero bit (which will be in the first non-zero // byte), and find which byte it is in. // FIXME: Assumes little-endian. unsigned msb = BitScanReverse(str.integer); return msb / 8 + 1; } /** * Create a new mini_str with length 0. * @return newly created mini_str */ mini_str mini_str_new(void) { // Create string of all null bytes. mini_str str = {.integer = 0}; return str; } /** * Append str2 to the end of str1 and return the reult. * @param str1 first string * @param str2 second string * @return combined string */ mini_str mini_strcat(mini_str str1, mini_str str2) { // Shift str2 along by str1Len characters to move it into position. unsigned str1Len = mini_strlen(str1); str2.integer <<= str1Len * 8; // FIXME: Assumes little-endian. /* Write your code here */ return str1; } #define mini_str_to_c(mini_str) ((const char *) (mini_str).array) #define mini_str_to_cNoConst(mini_str) ((char *) (mini_str).array) /** * Create a mini_str from a standard C character array. * @param cstr Null-terminated C-string to use as input * @return newly created mini_str */ mini_str mini_str_from_c(const char *cstr) { // Create empty string. mini_str mini_str = mini_str_new(); // Copy string. strncpy(mini_str_to_cNoConst(mini_str), cstr, 7); return mini_str; } int main(int argc, char **argv) { mini_str all = mini_str_from_c("All "); mini_str red = mini_str_from_c("red"); mini_str cat = mini_strcat(all, red); printf("%s\n", mini_str_to_c(cat)); return 0; } ``` 這裡的 `__builtin_clzll` 是 [GCC builtin function](https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html),作用是 [bit scan](https://www.chessprogramming.org/BitScan),程式預期輸出為: ``` All red ``` 你應該要實作 `calc` 函式中標註 `/* Write your code here */` 之後的程式碼。除了撰寫程式,你應該提供對應的程式碼註解。 注意: 實作的程式碼不能有任何邏輯條件判斷 (如 `if`, `else`, `?`) 或迴圈 (如 `for`, `while`, `goto`, `switch`, `case`, `break`, `continue`),也不能用 `>=`, `>`, `<`, `<=`, `-` 等運算子。 :::success 延伸問題: 1. 指出這樣針對短字串的最佳化效益,並嘗試量化; 2. 什麼樣的情境會出現大量的短字串?請舉例並分析; 3. 程式碼該如何修改,才能適用 big/little-endian 呢? 4. 考慮到現代的處理器架構支援 [SIMD](https://en.wikipedia.org/wiki/SIMD),得以一次處理 128-bit, 256-bit, 甚至是 512-bit,請評估這樣最佳化策略的可用性,應當有對應的實驗; ::: ==解題== ```cpp mini_str mini_strcat(mini_str str1, mini_str str2) { // Shift str2 along by str1Len characters to move it into position. unsigned str1Len = mini_strlen(str1); str2.integer <<= str1Len * 8; // FIXME: Assumes little-endian. str1.integer |= str2.integer; return str1; } ``` - 每個 char 單位是 8 bytes 所以往左位移 str1Len * 8 bits - 再來做 or 運算就能把 str2 擺在較高位址 #### 延伸問題 - 因為對齊 8 bytes 在 64 位元的處理器下 strcat 可以一次處理 8 bytes 比起一次搬 1 byte 有效率 - 適用 big/little-endian 版本 - little endian MSB (Most Significand Byte) 一定是 0 - big endian LSB 一定是 0 - 由這個特性製造出 mask 在不用 branch 情況下判斷 ```cpp mini_str mini_strcat(mini_str str1, mini_str str2) { // Shift str2 along by str1Len characters to move it into position. unsigned str1Len = mini_strlen(str1); // mask will be 0xffffffff if its little endian int64_t mask = (int64_t) ((str1.integer >> 56) -1) >> 63; str2.integer <<= str1Len * 8 & mask; // FIXME: Assumes little-endian. str2.integer >>= str1Len * 8 & ~mask; str1.integer |= str2.integer; return str1; } ``` - 待更新 - SSE - [#include <x86intrin.h>](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#) --- ### 測驗 `5` [population count](https://en.wikichip.org/wiki/population_count) 簡稱 popcount 或叫 sideways sum,是計算數值的二進位表示中,有多少位元是 `1`,在一些場合下很有用,例如計算 0-1 稀疏矩陣 (sparse matrix)或 bit array 中非 `0` 元素個數、計算兩個字串的 [Hamming distance](https://en.wikipedia.org/wiki/Hamming_weight)。Intel 在 2008 年 11 月 Nehalem 架構的處理器 Core i7 引入 SSE4.2 指令集,其中就有 `CRC32` 和 `POPCNT` 指令,`POPCNT` 可處理 16-bit, 32-bit, 64-bit 整數。 GCC 提供對應的 builtin function: ${\_\_builtin\_popcount}(x)$: `x` 總共有幾個 `1`。使用示範: ```cpp int x = 5328; // 00000000000000000001010011010000 printf("%d\n", __builtin_popcount(x)); // 5 ``` 以下是個存在實作缺陷的版本: ```cpp int popcnt_naive(int n) { int count = 0; while (n) { if (n & 1) ++count; n = n >> 1; } return count; } ``` 呼叫 `popcnt_naive(-1)` 時,會造成無窮迴圈,請指出錯誤所在,並且重寫為正確的版本。 ==解題== - 原本的瑕疵在於 `>>` 為算術右移,會自動補上 sign bit 所以用 unsigned int 就沒問題了 ```cpp int popcnt_naive(unsigned int n) { int count = 0; while (n) { if (n & 1) ++count; n = n >> 1; } return count; } ``` --- ### 測驗 `6` 延伸測驗 `5`,實作 branchless 的 `popcnt` 並附上對應的程式原理解說。 :::success 延伸問題: 1. 指出 `popcnt` 的應用場景; 2. 在 Linux 核心程式碼找出具體用法並解析; ::: :::warning 挫折到此就不是只有一點了吧?原因很簡單:你沒有如期寫作業 [review](https://hackmd.io/@sysprog/rJM4SPw8S),趕快動手! ::: ==解題== ```cpp= unsigned int countBits(unsigned int x) { // count bits of each 2-bit chunk x = x - ((x >> 1) & 0x55555555); // count bits of each 4-bit chunk x = (x & 0x33333333) + ((x >> 2) & 0x33333333); // count bits of each 8-bit chunk x = x + (x >> 4); // mask out junk x &= 0xF0F0F0F; // add all four 8-bit chunks return (x * 0x01010101) >> 24; } ``` - 第 4 行,我們先把 32 bits 的資料拆成 16 個 2 bits (以下記作 A)來看,該行運算等同於是 A - A >> 1 因為 0x5 == 0101b (以下就不對 mask 部份贅述了) 可以分成以下四種 case: 1. A = 00b A - A >> 1 = 00b 2. A = 01b A - A >> 1 = 01b 3. A = 10b A - A >> 1 = 01b 4. A = 11b A - A >> 1 = 10b - 看出計算結果剛好是對每個 A 的 popcnt - 再來第 6 行把 32 bits 的資料拆成 8 個 4 bits (再分成左右各 2 bits 分別記作 B1 B2) - 可以發現該行運算的結果是 B1 + B2 剛好把剛剛第四行的結果相加,所以現在 32 bits 的資料變成 8 個 4 bits 一組的 popcnt - 繼續看第 8 行把第 i 堆加到第 i - 1 堆 for all i > 1 (這裡一樣 4 個 bits 一堆) - 再來第 10 行做完 and 後第 j 堆的數字剛好是第 j / 2 + 1 byte 的 popcnt for all j belong to odd (共四堆) - 最後第 12 行做完乘法,剛好可以把四堆相加的結果存在第 25 個 bit 開始的地方,然後再向右 shift 得到結果 #### 延伸問題 - [linux/tools/include/linux/bitmap.h](https://github.com/torvalds/linux/blob/dad4f140edaa3f6bb452b6913d41af1ffd672e45/tools/include/linux/bitmap.h) 找到以下 function - 計算 `*src` 右邊 `nbits` 個 bits 的 hamming weight ```cpp static inline int bitmap_weight(const unsigned long *src, int nbits) { if (small_const_nbits(nbits)) return hweight_long(*src & BITMAP_LAST_WORD_MASK(nbits)); return __bitmap_weight(src, nbits); } ``` - BITMAP_LAST_WORD_MASK 巨集 - 判斷 mask 長度, `(nbits) % BITS_PER_LONG` 防止超過 `BITS_PER_LONG` 個 bits ```cpp #define BITMAP_LAST_WORD_MASK(nbits) \ ( \ ((nbits) % BITS_PER_LONG) ? \ (1UL<<((nbits) % BITS_PER_LONG))-1 : ~0UL \ ) ``` - hweight_long 定義在 [tools/include/linux/bitops.h](https://github.com/torvalds/linux/blob/7e67a859997aad47727aff9c5a32e160da079ce3/tools/include/linux/bitops.h) - hweight 實作就是 popcnt - 從 long 的長度判斷要用哪個版本的 popcnt ```cpp static inline unsigned long hweight_long(unsigned long w) { return sizeof(w) == 4 ? hweight32(w) : hweight64(w); } ``` - 待更新 --- ### 測驗 `7` 考慮到以下程式 (`alloc.c`) 是 [aligned_alloc](https://linux.die.net/man/3/posix_memalign) 的一種簡易實作: ```cpp #include <stdlib.h> // Number of bytes used for storing the aligned pointer offset. // up to 64KB alignment, a size which is already unlikely to be // used for alignment. typedef uint16_t offset_t; #define PTR_OFFSET_SIZE sizeof(offset_t) #define align_up(num, align) \ (((num) + ((align) - 1)) & ~((align) - 1)) void *aligned_malloc(size_t align, size_t size) { void *ptr = NULL; // size must be a power of two. if (!((align & (align - 1)) == 0)) return ptr; if (align && size) { // allocate extra bytes to meet the alignment uint32_t header_size = PTR_OFFSET_SIZE + (align - 1); void *p = malloc(size + header_size); /* Write your code here */ } return ptr; } ``` 其作用是配置針對 `align` 個 bytes 對齊的記憶體空間,可對照閱讀 [Introduction & Allocators](http://stevenlr.com/posts/handmade-rust-1-allocators/) 以掌握原理。你應該要實作 `aligned_malloc` 函式中標註 `/* Write your code here */` 之後的程式碼。除了撰寫程式,你應該提供對應的程式碼註解。 注意: 輸入的 `align` 應該要是 2^N^ (power of 2),否則就回傳 `NULL`。 :::success 延伸問題: 1. 解釋程式運作的原理和推敲背後的思路; 2. 在開放原始碼的專案中,找尋類似的程式碼,解說並量化具體效益; ::: ==解題== ```cpp #include <stdlib.h> #include <stdint.h> // Number of bytes used for storing the aligned pointer offset. // up to 64KB alignment, a size which is already unlikely to be // used for alignment. typedef uint16_t offset_t; #define PTR_OFFSET_SIZE sizeof(offset_t) #define align_up(num, align) \ (((num) + ((align) - 1)) & ~((align) - 1)) void *aligned_malloc(size_t align, size_t size) { void *ptr = NULL; // size must be a power of two. if (!((align & (align - 1)) == 0)) return ptr; if (align && size) { // allocate extra bytes to meet the alignment uint32_t header_size = PTR_OFFSET_SIZE + (align - 1); void *p = malloc(size + header_size); ptr = (void *) align_up((size_t)(p + PTR_OFFSET_SIZE), align); *((offset_t *) (ptr - PTR_OFFSET_SIZE)) = (offset_t) (ptr - p); } return ptr; } ``` - gdb 測試結果 ``` (gdb) info local header_size = 9 p = 0x602010 ptr = 0x602018 # 順利存到 offset (gdb) x ptr-sizeof(uint16_t) 0x602016: 0x00000008 ``` #### 延伸問題 - 因為確保找到真正的 base address 所以 `malloc` 時多塞了 `offset_` 的空間 - 且要確保有空間可以對齊 align-bytes 所以也多塞了 `(align - 1)` bytes 的空間 - `ptr = (void *) align_up((size_t)(p + PTR_OFFSET_SIZE), align)` 得到回傳給 user 的 base address - `*((offset_t *) (ptr - PTR_OFFSET_SIZE)) = (offset_t) (ptr - p);`算出離真正 base address 的 offset 存到 `ptr - PTR_OFFSET_SIZE` 的地方 --- ### 測驗 `8` 延伸測驗 `7`,實作 `aligned_free`,其原型宣告如下: ```cpp void aligned_free(void *ptr); ``` 除了撰寫程式,你應該提供對應的程式碼註解。 ==解題== ```cpp void aligned_free(void *ptr){ offset_t off = *((offset_t *) (ptr - PTR_OFFSET_SIZE));; free(ptr - off); } ``` - 從 ptr 往回推 `PTR_OFFSET_SIZE` bytes 拿到 offset - ptr - offset 得到真正的 base address - gdb 測試結果 (承上題) ``` # 順利拿到 base address! (gdb) p off $2 = 8 (gdb) p ptr-off $3 = (void *) 0x602010 ``` --- ### 測驗 `9` 考慮以下 64-bit GCD (greatest common divisor, 最大公因數) 求值函式: ```cpp #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; while (v) { uint64_t t = v; v = u % v; u = t; } return u; } ``` 改寫為以下等價實作: ```cpp #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; int shift; for (shift = 0; !((u | v) & 1); shift++) { u /= 2, v /= 2; } while (!(u & 1)) u /= 2; do { while (!(v & 1)) v /= 2; if (u < v) { v -= u; } else { uint64_t t = u - v; u = v; v = t; } } while (/* Write your code here */); return /* Write your code here */; } ``` 補完以上程式碼,即標注 `/* Write your code here */` 的部分,需要抄寫 `while` 和 `return` 所在的程式碼。 ==解題== ```cpp #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; int shift; for (shift = 0; !((u | v) & 1); shift++) { u /= 2, v /= 2; } while (!(u & 1)) u /= 2; do { while (!(v & 1)) v /= 2; if (u < v) { v -= u; } else { uint64_t t = u - v; u = v; v = t; } } while (v); return u << shift; } ``` - 一開始的迴圈用來檢查最大公因數的 2 的次方數 - 因為最大公因數的 2 已經找完了,所以用以下迴圈把 2 除掉讓數字縮小 ```cpp while (!(u & 1)) u /= 2; ``` - 以下 code 相當於再取餘數的動作,相當於數學式 gcd(v, u) = gcd(u, v (mod) u) - 所以執行到最後變成 gcd(u, 0) 可得知跳脫條件 - 其中 u 在乘回 2 的 shift 次方就是 gcd ```cpp if (u < v) { v -= u; } else { uint64_t t = u - v; u = v; v = t; } ``` --- ### 測驗 `10` 承測驗 `9`, 透過 gcc 內建的 [\_\_builtin_ctz](https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html) (Returns the number of trailing 0-bits in x, starting at the least significant bit position) 改寫程式碼如下: ```clike #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; int shift = __builtin_ctzll(/* Write your code here */); u >>= __builtin_ctzll(u); while (v) { v >>= __builtin_ctzll(v); if (u < v) { /* Write your code here */ } else { uint64_t t = u - v; u = v, v = t; } } return /* Write your code here */; } ``` 請補完程式碼,作答時需要一併包含原本函式內容。除了撰寫程式,你應該提供對應的程式碼註解。 :::success 延伸問題: 解釋上述程式程式運作原理,以及在 x86_64 上透過 [\_\_builtin_ctz](https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html) 改寫 GCD 對效能的提升 ::: ==解題== ```cpp #include <stdint.h> uint64_t gcd64(uint64_t u, uint64_t v) { if (!u || !v) return u | v; int shift = __builtin_ctzll(u | v); u >>= __builtin_ctzll(u); while (v) { v >>= __builtin_ctzll(v); if (u < v) { v -= u; } else { uint64_t t = u - v; u = v, v = t; } } return u << shift; } ``` - `__builtin_ctzll(u | v)` 等價於測驗 9 第一個 for 迴圈的功用 - `u >>= __builtin_ctzll(u)` 也等價於測驗 9 清除 2 的因數的迴圈 - 差別在於先算出尾端有幾個 0 一次 shift #### 延伸問題 - 寫了簡單的 program 反組譯看看 ```cpp int ctz(int x){ return __builtin_ctzll(x); } int main(void){ ctz(1 << 31); return 0; } ``` - 反組譯結果 - cltq 作用為 %eax 做 sign extention 結果放到 %rax - tzcnt 就是 count tailing zero - 待更新 - x86 asm: bsfq ```cpp 00000000004004e9 <ctz>: 4004e9: 55 push %rbp 4004ea: 48 89 e5 mov %rsp,%rbp 4004ed: 89 7d fc mov %edi,-0x4(%rbp) 4004f0: 8b 45 fc mov -0x4(%rbp),%eax 4004f3: 48 98 cltq 4004f5: f3 48 0f bc c0 tzcnt %rax,%rax 4004fa: 5d pop %rbp 4004fb: c3 retq ``` - 對兩種 ctz 做效能分析 (repeat 100 times) - __builtin_ctzll 版本 ```cpp #include <time.h> #include <stdlib.h> #include <stdint.h> #include <stdio.h> int ctz(uint64_t x){ return __builtin_ctzll(x); } int main(int argc, char** argv){ struct timespec start, end; int32_t shift = atoi(argv[1]); clock_gettime(CLOCK_REALTIME, &start); for(int i = 0; i < 100; i++) ctz(1 << shift); clock_gettime(CLOCK_REALTIME, &end); printf("%u %lu\n", shift, end.tv_nsec - start.tv_nsec); return 0; } ``` - loop 版本 ```cpp #include <time.h> #include <stdlib.h> #include <stdint.h> #include <stdio.h> int ctz_loop(uint64_t x){ int count = 0; for(; !(x & 1); count++){ x /= 2; } return count; } int main(int argc, char** argv){ struct timespec start, end; int32_t shift = atoi(argv[1]); clock_gettime(CLOCK_REALTIME, &start); for(int i = 0; i < 100; i++) ctz_loop(1 << shift); clock_gettime(CLOCK_REALTIME, &end); printf("%u %lu\n", shift, end.tv_nsec - start.tv_nsec); return 0; } ``` - test.sh - 為了降低快取影響所以加了 `sync && echo 3 > /proc/sys/vm/drop_caches` 每輪都清空快取 ```cpp #!/bin/bash gcc -o ctz ctz.c gcc -o ctz_loop ctz_loop.c rm *.txt for i in $(seq 0 63) do sync && echo 3 > /proc/sys/vm/drop_caches ./ctz $i >> out1.txt ./ctz_loop $i >> out2.txt done gnuplot ./plot.gp ``` - 實驗結果如下,從這張圖可以看出 builtin 版本表現的效能較好且為常數時間 - 資料待整理 - branch prediction - 效能分佈趨勢 ![](https://i.imgur.com/WNaNTUr.png) :::warning 嚇到吃手手了嗎?原因很簡單:你沒有如期寫作業 [review](https://hackmd.io/@sysprog/rJM4SPw8S),趕快動手! ::: --- ### 測驗 `11` 考慮到 [memcmp](http://man7.org/linux/man-pages/man3/memcmp.3.html) 一種實作如下: (行為和 ISO C90 有出入) ```cpp #include <stdint.h> #include <stddef.h> int memcmp(const uint8_t *m1, const uint8_t *m1, size_t n) { for (size_t i = 0; i < n; ++i ) { int diff = m1[i] - m2[i]; if (diff != 0) return (diff < 0) ? -1 : +1; } return 0; } ``` 我們可能因此承受 [information leakage](https://en.wikipedia.org/wiki/Information_leakage) 的風險,於是著手避免使用 conditional branching 一類的指令,從而避免 [side-channel attack](https://en.wikipedia.org/wiki/Side-channel_attack)。 為了避免 conditional branch 指令的出現,我們可將 `(res > 0) - (res < 0)` 替換為 `((res - 1) >> 8) + (res >> 8) + 1`。隨後我們實作下方功能等價但避免 branch 的 ` cst_memcmp`: ```cpp #include <stdint.h> #include <stddef.h> int cst_memcmp(const void *m1, const void *m2, size_t n) { const uint8_t *pm1 = (const uint8_t *) m1 + n; const uint8_t *pm2 = (const uint8_t *) m2 + n; int res = 0; if (n) { do { int diff = *--pm1 - *--pm2; /* Write your code here */ } while (pm1 != m1); } return ((res - 1) >> 8) + (res >> 8) + 1; } ``` 注意: 在 Linux 核心內部的實作方式可見: * [[PATCH] crypto_memcmp: add constant-time memcmp](https://www.spinics.net/lists/linux-crypto/msg09542.html) * [Re: [PATCH] crypto_memcmp: add constant-time memcmp](https://www.spinics.net/lists/linux-crypto/msg09551.html) 請補完程式碼,作答時需要一併包含原本函式內容。除了撰寫程式,你應該提供對應的程式碼註解。 注意: 在 `/* Write your code here */` 所在的程式碼作用區域 (scope) 中,不得存任何邏輯條件判斷 (如 `if`, `else`, `?`) 或迴圈 (如 `for`, `while`, `goto`, `switch`, `case`, `break`, `continue`) :::success 延伸問題: 1. 解釋上述程式的原理,需要從機率統計的觀點分析; * 為何不能用事先計算好的表格呢? (提示: cache 的影響) * 如何驗證程式正確以及 constant-time 呢? 2. 在 Linux 核心找到這類 constant-time 的操作程式碼,予以解說和設計實驗; ::: ==解題== ```cpp #include <stdint.h> #include <stddef.h> int cst_memcmp(const void *m1, const void *m2, size_t n) { const uint8_t *pm1 = (const uint8_t *) m1 + n; const uint8_t *pm2 = (const uint8_t *) m2 + n; int res = 0; if (n) { do { int diff = *--pm1 - *--pm2; res = (diff >> 8 | diff) + (!diff) * res; } while (pm1 != m1); } return ((res - 1) >> 8) + (res >> 8) + 1; } ``` - 要等價於原本程式碼所以要看最靠近低位址不相同的字元差為正號或負號 - `(diff >> 8 | diff)` - 如果小於 0 這項會變成 -1 - 大於 0 會是 diff - (!diff) * res - diff 為 0 時保留原本的值 #### 延伸問題 - [linux/crypto/memneq.c](https://github.com/torvalds/linux/blob/81160dda9a7aad13c04e78bb2cfd3c4630e3afab/include/crypto/algapi.h) ```cpp /* Compare two areas of memory without leaking timing information, * and with special optimizations for common sizes. Users should * not call this function directly, but should instead use * crypto_memneq defined in crypto/algapi.h. */ noinline unsigned long __crypto_memneq(const void *a, const void *b, size_t size) { switch (size) { case 16: return __crypto_memneq_16(a, b); default: return __crypto_memneq_generic(a, b, size); } } ``` - 看到 generic case - 第一個 while 迴圈一次檢查 sizeof(unsigned long)) 個 bytes (如果可以接受 unaligned acces 才會進入) - 第二個 while 迴圈一次檢查 1 byte - return 0 代表 a b 在 size bytes 內 memory content 一樣 ```cpp static inline unsigned long __crypto_memneq_generic(const void *a, const void *b, size_t size) { unsigned long neq = 0; #if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) while (size >= sizeof(unsigned long)) { neq |= *(unsigned long *)a ^ *(unsigned long *)b; OPTIMIZER_HIDE_VAR(neq); a += sizeof(unsigned long); b += sizeof(unsigned long); size -= sizeof(unsigned long); } #endif /* CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS */ while (size > 0) { neq |= *(unsigned char *)a ^ *(unsigned char *)b; OPTIMIZER_HIDE_VAR(neq); a += 1; b += 1; size -= 1; } return neq; } ``` - __crypto_memneq_16 概念一樣只是因為已知大小所以直接把迴圈展開 - 看得出來執行時間是依據 size 所以沒辦法從執行時間推敲出正確的字串長度 --- ### 測驗 `12` 給定一個 [circular linked list](https://en.wikipedia.org/wiki/Linked_list#Circular_linked_list) 實作如下: (檔案 `list.h`) ```cpp typedef struct __list_t { struct __list_t *prev, *next; } list_t; /* * Initialize a list to empty. Because these are circular lists, an "empty" * list is an entry where both links point to itself. This makes insertion * and removal simpler because they do not need any branches. */ static inline void list_init(list_t *list) { list->prev = list; list->next = list; } /* * Append the provided entry to the end of the list. This assumes the entry * is not in a list already because it overwrites the linked list pointers. */ static inline void list_push(list_t *list, list_t *entry) { list_t *prev = list->prev; entry->prev = prev; entry->next = list; prev->next = entry; list->prev = entry; } /* * Remove the provided entry from whichever list it is currently in. This * assumes that the entry is in a list. You do not need to provide the list * because the lists are circular, so the list's pointers will automatically * be updated if the first or last entries are removed. */ static inline void list_remove(list_t *entry) { list_t *prev = entry->prev; list_t *next = entry->next; prev->next = next; next->prev = prev; } /* * Remove and return the first entry in the list or NULL if the list is empty. */ static inline list_t *list_pop(list_t *list) { /* Write your code here */ } ``` 請依循程式註解的描述,參照 `list_push`, 實作可正確運作的 `list_pop`。作答時需要一併包含原本函式內容。除了撰寫程式,你應該提供對應的程式碼註解。 注意: 應善用 `list_remove` 和已實作的函式。 :::success 延伸問題: 1. 解釋上述程式的原理和技巧; 2. 在 Linux 核心找到這類的操作程式碼; ::: ==解題== ```cpp static inline list_t *list_pop(list_t *list) { if(!list || list->next = list) return NULL; list *tem = list->next; list_remove(list->next); return tem; } ``` - 先檢查是否為 null ptr or empty list - 再把第一個 entry 存在 tem 之後 call list_remove 刪除第一個 entry - return tem #### 延伸問題 - 用到了 doubly-linked list - 用 list->next 是否指向自己來判斷 empty 插入或刪除時可以避免 branch - 因為最少會剩一個節點,可以避免 dereference null ptr - [linux/include/linux/list.h](https://github.com/torvalds/linux/blob/a2d79c7174aeb43b13020dd53d85a7aefdd9f3e5/include/linux/list.h) - 這邊可以看到 init 時把前後指向自己 ```cpp #define LIST_HEAD_INIT(name) { &(name), &(name) } #define LIST_HEAD(name) \ struct list_head name = LIST_HEAD_INIT(name) ``` - 一開始先定義一個通用的插入,之後的 list_add_tail list_add 直接 reuse ```cpp static inline void __list_add(struct list_head *new, struct list_head *prev, struct list_head *next) { next->prev = new; new->next = next; new->prev = prev; prev->next = new; } ``` - 其實概念都跟這題大同小異,比較令人好奇的是,他的資料結構只有前後兩個指標,那資料放哪? - 下面可以看到它定義的其中一個巨集 - 看到註解的 type 和 name 可以知道這個資料結構是要被放在使用者的資料結構內的 ```cpp /** * list_entry - get the struct for this entry * @ptr: the &struct list_head pointer. * @type: the type of the struct this is embedded in. * @member: the name of the list_head within the struct. */ #define list_entry(ptr, type, member) \ container_of(ptr, type, member) ``` - container_of 定義在 [linux/include/linux/kernel.h](https://github.com/torvalds/linux/blob/a2d79c7174aeb43b13020dd53d85a7aefdd9f3e5/include/linux/kernel.h) - 先判斷使用者給的 ptr 和 type struct 內的 member 是一樣的資料型態,否則會在 compile time 報錯 - 再來 offsetof 得到 type 和 member address 的 offset - 最後 ptr 減掉 offset 得到 type 的 address ```cpp #define container_of(ptr, type, member) ({ \ void *__mptr = (void *)(ptr); \ BUILD_BUG_ON_MSG(!__same_type(*(ptr), ((type *)0)->member) && \ !__same_type(*(ptr), void), \ "pointer type mismatch in container_of()"); \ ((type *)(__mptr - offsetof(type, member))); }) ``` :::info 想請問特別宣告 `void *__mptr = (void *)(ptr)` 寫成 `((type *)(__mptr - offsetof(type, member)))` 有什麼特別用意嗎 直接轉形成 `char *` 或是 `void *` 運算也是可以的吧 `((type *)((void *)ptr - offsetof(type, member))); })` ::: :::warning 試著將不合法的型態帶入巨集,然後觀察會發生什麼事。 對照閱讀: [The Magical container_of() Macro](https://radek.io/2012/11/10/magical-container_of-macro/) :notes: jserv ::: - BUILD_BUG_ON_MSG 定義在 [linux/include/linux/build_bug.h](https://github.com/torvalds/linux/blob/master/include/linux/build_bug.h) - code 中的註解 BUILD_BUG_ON - break compile if a condition is true. ```cpp #define BUILD_BUG_ON_MSG(cond, msg) compiletime_assert(!(cond), msg) #define BUILD_BUG_ON(condition) \ BUILD_BUG_ON_MSG(condition, "BUILD_BUG_ON failed: " #condition) ``` --- ## 參考資料 - [https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=7,6150&techs=SSE](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=7,6150&techs=SSE) - [Header files for x86 SIMD intrinsics](https://stackoverflow.com/questions/11228855/header-files-for-x86-simd-intrinsics) - [c/c++ 程式碼中使用sse指令集加速](https://www.itread01.com/content/1544371404.html)