# 2024q1 Homework4 (quiz3+4) contributed by `yu-hsiennn` ## quiz 3 > [題目](https://hackmd.io/TvCDjRwvQD-LVzCXq5gA_A) ### 測驗 1 假定 $P_m$ 為欲所求之平方根 $$ \begin{cases} N^2=(a_{n}+a_{n-1}+a_{n-2}+...+{a_0})^2, a_m=2^m\ or\ a_m=0 \\ P_m = a_n + a_{n-1} + ... + a_m \end{cases} $$ 經展開後,化簡可得 $$ N^2 = \sum_{i=0}^{n}a_i^2 + 2\sum_{i=0}^{n-1}P_{i+1}\cdot a_i $$ 令 $X_m$ 為 $N^2 - P_m^2$,且此輪與上一輪差設為 $Y_m$,即 $$ \begin{cases} X_m = N^2 - P_m^2 \\ Y_m = X_{m+1} - X_m = P_m^2 - P_{m+1}^2 = 2P_{m+1}a_m + a_m^2 \end{cases} $$ 我們再將式子利用 $c_m$ 及 $d_m$ 化簡 - $c_m = 2P_{m+1}a_m$ - $d_m = a_m^2$ $$ Y_m = \begin{cases} c_m + d_m & \text{if } a_m=2^m \\ 0 & \text{if } a_m = 0 \end{cases} $$ 再觀察下一輪的變化 $$ \begin{cases} c_{m-1} = 2P_{m}a_{m-1} = 2(P_{m+1}+a_m)a_{m-1} = P_{m+1}a_m + a_m^2 = \dfrac {c_m}{2} + d_m \\ d_{m-1} = \dfrac{d_m}{4} \end{cases} $$ 因此 ```c int i_sqrt(int x) { if (x <= 1) /* Assume x is always positive */ return x; int z = 0; for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) { // m: d_m, z: c_m int b = z + m; z >>= 1; if (x >= b) x -= b, z += m; } return z; } ``` 改用第二周測驗 2 的 [__ffs](https://hackmd.io/@sysprog/linux2024-quiz2) ```diff - for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) + for (int m = 1UL << ((31 - __ffs(x)) & ~1UL); m; m >>= 2) ``` #### 疑惑 $$ c_m = P_{m+1}2^{m+1} \\ d_m = (2^m)^2 \\ Y_m=\left. \begin{cases} c_m+d_m & \text{if } a_m=2^m \\ 0 & \text{if } a_m=0 \end{cases} \right. $$ 在一開始看到這個時,完全無法理解是怎麼跑出這些等號的。 後來才發現,其實有段開,只是 `HackMD` 沒有顯示出來 ### 測驗 2 因為乘法及除法的運算成本很高,故想利用位移及加減的方式去達到相同的效果 此處數字範圍為 `0~19`,只要這一段的計算有達到預期即可 題目是用 `N = 128` 來逼近 `1/10` $1/128 + 1/32 + 1/16 = (1 + 4 + 8) / 128 = 13/128 \approx 1/10$ 也就是說可以 ```c /* 13n / 128 */ q = (n + (n << 2) + (n << 3)) >> 7; ``` 但這邊會忽略掉一些值,如果要精確的話可以把他加回去 ```c d0 = q & 0b1; d1 = q & 0b11; d2 = q & 0b111; q = (n + (n << 2) + (n << 3) + d0 + d1 + d2) >> 7; ``` 不過在範圍 `0~19` 其實並不會影響結果 如果範圍是 `0~19` ,這邊做了一些測試 ```c // test1: n / 16 + n / 32 + 1 / 8 q = ((n << 1) + n + 4) >> 5; // test2: n / 16 + n / 32 + 1 / 16 q = ((n << 1) + n + 2) >> 5; ``` 結果為 ``` q: 0, r: 0 q: 0, r: 1 q: 0, r: 2 q: 0, r: 3 q: 0, r: 4 q: 0, r: 5 q: 0, r: 6 q: 0, r: 7 q: 0, r: 8 q: 0, r: 9 q: 1, r: 0 q: 1, r: 1 q: 1, r: 2 q: 1, r: 3 q: 1, r: 4 q: 1, r: 5 q: 1, r: 6 q: 1, r: 7 q: 1, r: 8 q: 1, r: 9 ``` 而題目 ```c void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod) { uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */ uint32_t q = (x >> 4) + x; x = q; q = (q >> 8) + x; q = (q >> 8) + x; q = (q >> 8) + x; q = (q >> 8) + x; *div = (q >> 3); *mod = in - ((q & ~0x7) + (*div << 1)); } ``` 可以看到這邊的作法是 $$ x = \dfrac{3}{4}n + \dfrac{3}{64}n = 0.796875n \approx \dfrac{8}{10}n $$ 所以將 q 右移 3 可得 $\dfrac{8}{10}n * \dfrac{1}{8}n = \dfrac{1}{10}n$ 而最後是在計算 `*mod = in - *div * 10` ### 測驗 3 利用 4 層迴圈來找尋以 2 為底的對數 - 第一個迴圈作用於處理大於等於 `65536` 的數值,直到小於才離開此迴圈 - 第二個迴圈作用於處理大於等於 `256` 的數值,直到小於才離開此迴圈 - 第三個迴圈作用於處理大於等於 `16` 的數值,直到小於才離開此迴圈 - 第四個迴圈作用於處理大於等於 `2` 的數值,直到小於才離開此迴圈,並回傳結果 舉個例子: n = 2049 2049 < 65536, (不執行第一個迴圈) 2049 >= 256, n: 8, result: 8 8 < 16, (不執行第三個迴圈) 8 >= 2, n: 4, result: 9 4 >= 2, n: 2, result: 10 2 >= 2, n: 1, result: 11 1 < 2, 回傳 result: 11 我們也可以換個想法,利用找尋目前 `n` 第一個 `1` 的前方有幾個 `0` (__builtin_clz(n)),再用 31 去減掉前方 `0` 的個數即可 參考 [log2.h](https://elixir.bootlin.com/linux/latest/source/include/linux/log2.h),了解到 linux 核心內部 `ilog2` 的實作方式,其中有以下幾點: - 會先判斷 `n` 是否使用了常數時間來編譯 - 不為常數時間則判斷其大小是否小於 4 byte,來執行後續相對應的 `fls` - `fls` 之所以需要減 1 是因為它是以 `1` 為底 ### 測驗 4 ```c struct ewma *ewma_add(struct ewma *avg, unsigned long val) { avg->internal = avg->internal ? (((avg->internal << avg->weight) - avg->internal) + (val << avg->factor)) >> avg->weight : (val << avg->factor); return avg; } ``` 根據公式 $S_t = \alpha (Y_t + (1 - \alpha)Y_{t-1} + (1 - \alpha)^2Y_{t-2} + \ldots + (1 - \alpha)^kY_{t-k} + \ldots + Y_0)$ $$ S_{t+1} = ((S_t * 2^{avg->weight} - S_t) + Y_t * 2^{avg->factor}) * 2^{-(avg->weight)} $$ 化簡後可得 $$ \begin{cases} S_{t+1} = (1 - 2^{-(avg->weight)})S_t + Y_t * 2^{avg->factor-(avg->weight)} \\ \alpha = 2^{-(avg->weight)} \end{cases} $$ 而 `avg->factor` 的用途在於,不是每個系統都支援浮點運算,因此藉由固定其大小來確保精確度 ### 測驗 5 ```c int ceil_ilog2(uint32_t x) { uint32_t r, shift; x--; r = (x > 0xFFFF) << 4; x >>= r; shift = (x > 0xFF) << 3; x >>= shift; r |= shift; shift = (x > 0xF) << 2; x >>= shift; r |= shift; shift = (x > 0x3) << 1; x >>= shift; return (r | shift | x > 1) + 1; } ``` 開頭的 `x--` 用意在於,若是 `x` 為 2 的冪,最後的 `+1` 會造成結果錯誤,所以先將 `x--` 來避免這個情況發生。 而利用後面一連串的判斷,來快速找出 `x` 的對數。 最後的回傳會判斷 `x` 是否還有值,且合併 `r` 與 `shift` 的值,並加上最後的進位 `1`。 #### 處理 `x = 0` 的狀況 可以發現到,當我們直接帶入 `0` 及 `1` 時,程式會回傳結果 `1`,而這並不是我們所預期的,於是可以如下修改 ```diff - x--; + x -= likely(x); /* x -= !!x */ - return (r | shift | x > 1) + 1; + return (r | shift | x > 1) + !(x == 0); ``` 想法是: - 先判斷是否為 `0` ,為 `0` 則不做 -1 - 最後回傳時再判斷 `x` 是否為 `0`,是則不做 +1 如此,即可以避免當 `x` 為 `0` 和`1`,回傳的錯誤結果 `1`,並維持 branchless --- ## quiz 4 > [題目](https://hackmd.io/f0JmelVzROSSw-mKnGZxSQ) ### 測驗 1 ```c n = (v >> 1) & 0x77777777; v -= n; n = (n >> 1) & 0x77777777; v -= n; n = (n >> 1) & 0x77777777; v -= n; ``` 計算方式採取每 4 個位元為單位, $v = (b_{31}b_{30}b_{29}b_{28})...b_{0}$ 右移 1 且只把最高位元(以 4 位元為單位)給變成0,以防其影響計算結果, $n = (0b_{31}b_{30}b_{29})(0b_{27}...b_{1})$ $v - n = v - \lfloor{{\dfrac{v}{2}}}\rfloor$ 跑完這段即可得, $v - \lfloor{{\dfrac{v}{2}}}\rfloor - \lfloor{{\dfrac{v}{4}}}\rfloor - \lfloor{{\dfrac{v}{8}}}\rfloor$ 接下來會執行 `v = (v + (v >> 4)) & 0x0F0F0F0F` 令 $v = B_7B_6B_5B_4B_3B_2B_1B_0$, ($B_n = b_{4n+3}b_{4n+2}b_{4n+1}b_{4n+0}$) 所以 $v = 0(B_7+B_6)0(B_5+B_4)0(B_3+B_2)0(B_1+B_0)$ 最後將 `v *= 0x01010101, v >> 24`,即可得 $popcount = B_7+B_6+B_5+B_4+B_3+B_2+B_2+B_0$ #### 撰寫出更高效的程式碼 [477. Total Hamming Distance](https://leetcode.com/problems/total-hamming-distance/) ```cpp int totalHammingDistance(vector<int>& nums) { int res = 0, n = nums.size(); /* TLE: O(n^2) * for(int i=0;i<nums.size();i++){ * for(int j=i+1;j<nums.size();j++){ * int xOR = nums[i]^nums[j]; * ans += __builtin_popcount(xOR); * } * } */ for (int i = 0; i < 32; i++) { int count = 0; for (int j = 0; j < n; j++) count += (nums[j] >> i) & 1; res += count * (n - count); } return res; } ``` 時間複雜度為 $O(32n) = O(n)$ ### 測驗 2 $$ 2^k \equiv \begin{cases} 1 (mod \ \ 3), \ \ k \ \ even\\ -1 (mod \ \ 3), \ \ k \ \ odd\\ \end{cases} $$ 令 $N = b_{n-1}b_{n-2}\cdots b_1 b_0$, 可以改寫成 $N \equiv \sum_{i=0}^{n-1} b_i\cdot (-1)^i \ \ (mod \ \ 3)$ 我們可以利用 `n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA)`,來達到相同效果 其中還可以利用以下定理來進一步的化簡 $$ popcount(x \And \overline{m}) - popcount(x \And m) = popcount(x \oplus m) - popcount(m) $$ 即 `popcount(m) = 16, (m = 0xAAAAAAAA)`,而為了避免 `n = popcount(n ^ 0xAAAAAAAA) - 16` 結果為負數,我們可以加入一個足夠大的 3 的倍數的數字給 `-16`,範例為 `39 - 16 -> 23` 再來執行 `n = popcount(n ^ 0x2A) - 3` 來限縮範圍至 `-3 ~ 3`,最後的 `return n + ((n >> 31) & 3)` 再將負的轉為正的 其中,參考 [mesohandsome](https://hackmd.io/@Peng-You/linux2024-homework4) 的筆記,發現到要將 `n` 做轉型才會得到正確結果,即 ```diff - return n + ((n >> 31) & 3); + return n + ((int)n >> 31) & 3; ``` ```c static inline uint32_t is_win(uint32_t player_board) { return (player_board + 0x11111111) & 0x88888888; } ``` 當連線時,會是 `0x7`,為了 `0x7 + ? == 0x8` 要為 `true`,所以 `?` 應為 `0x1` ```c static inline int mod7(uint32_t x) { x = (x >> 15) + (x & UINT32_C(0x7FFF)); /* Take reminder as (mod 8) by mul/shift. Since the multiplier * was calculated using ceil() instead of floor(), it skips the * value '7' properly. * M <- ceil(ldexp(8/7, 29)) */ return (int) ((x * UINT32_C(0x24924925)) >> 29); } ``` 先將 `x` 右移 `15` 來保留其商數,再利用 `x & UINT32_C(0x7FFF)` 來保留最低的 `15` 個位元 ### 測驗 3 #### 函式功用 - `xt_create`: 初始化一棵新的 `XTree` - `xt_destroy`: 利用遞迴的方式來將樹的每個節點的資源給釋放掉 - `xt_frist`: 找到樹中最小值的節點(最左邊) - `xt_last`: 找到樹中最大值的節點(最右邊) - `xt_rotate_left`: 執行下圖操作 ```graphviz digraph Tree { graph[ordering=out] p -> n n -> l n -> r l -> A l -> B label=origin } ``` ```graphviz digraph Tree { graph[ordering=out] p -> l l -> A l -> n n -> B n -> r label="After Rotate Left" } ``` - `xt_rotate_right`: 執行下圖操作 ```graphviz digraph Tree { graph[ordering=out] p -> n n -> l n -> r r -> A r -> B label=origin } ``` ```graphviz digraph Tree { graph[ordering=out] p -> r r -> n r -> B n -> l n -> A label="After Rotate Right" } ``` - `xt_balance`: 回傳左子樹樹高減去右子樹樹高 - `xt_update`: 確認樹是否平衡,不平衡則執行對應的旋轉操作 - `xt_insert`: 將新節點插入進樹中 - `xt_remove`: 將對應的節點從樹中移除 ```c static void __xt_remove(struct xt_node **root, struct xt_node *del) { if (xt_right(del)) { struct xt_node *least = xt_first(xt_right(del)); if (del == *root) *root = least; xt_replace_right(del, AAAA); xt_update(root, BBBB); return; } if (xt_left(del)) { struct xt_node *most = xt_last(xt_left(del)); if (del == *root) *root = most; xt_replace_left(del, CCCC); xt_update(root, DDDD); return; } if (del == *root) { *root = 0; return; } /* empty node */ struct xt_node *parent = xt_parent(del); if (xt_left(parent) == del) xt_left(parent) = 0; else xt_right(parent) = 0; xt_update(EEEE, FFFF); } ``` 想法是若欲刪除之節點在右子樹,則用右子樹的最小來填補,即 `least` ,而填補完需要確認樹是否平衡,所以執行 `xt_update(root, xt_right(least))`, 反之,則用左子樹的最大,即 `most`,且更新 `xt_update(root, xt_left(most))` 最後在判斷整棵樹有沒有平衡,即 `xt_update(root, parent)`