# 2024q1 Homework4 (quiz3+4)
contributed by < `YangYeh-PD` >

## Problem `1`: Computing the Square Root

### How it works

Suppose that the square root $N$ has its most significant bit at position $n$, and write it as the sum of its bits:

$$
N = a_n + a_{n-1} + \dots + a_1 + a_0, \quad \textrm{where } a_m \in \{0, 2^m\}.
$$

Then

$$
\begin{split}
N^2 & = (a_n+a_{n-1}+a_{n-2}+\dots+a_1+a_0)^2 \\
& = a_n^2+[2a_n+a_{n-1}]a_{n-1}+[2(a_n+a_{n-1})+a_{n-2}]a_{n-2}+\dots+\left[2\left(\sum_{i=1}^{n} a_i\right)+a_0\right]a_0 \\
& = a_n^2+[2P_n+a_{n-1}]a_{n-1}+[2P_{n-1}+a_{n-2}]a_{n-2}+\dots+\left[2P_1+a_0\right]a_0,
\end{split}
$$

$$
\begin{split}
\textrm{where } P_m &= a_n + a_{n-1} + \dots + a_m = P_{m+1} + a_m, \quad P_{n+1} = 0, \\
\textrm{and } P_0 &= a_n + a_{n-1} + \dots + a_0 = N.
\end{split}
$$

Therefore, the problem reduces to determining each term $a_m$, from the highest bit down. If setting bit $m$ to 1 would make $P_m^2 > N^2$, we leave it as 0 and continue with the next bit:

$$
\begin{cases}
P_m = P_{m+1} + 2^m, & \textrm{if } (P_{m+1} + 2^m)^2 \leq N^2, \\
P_m = P_{m+1}, & \textrm{otherwise}.
\end{cases}
$$

Thus, the most straightforward way to compute the integer square root of a number `N` (in the code, `N` is the input number, i.e. $N^2$ above) is to first find the most significant bit (MSB) of `N`, clear all bits below it to get the first candidate bit, and then test the candidate bits from high to low. In binary, the simplest way to obtain the MSB of `N` is to directly **take the base-2 logarithm** and cast it to an integer.

```c
#include <math.h>

int i_sqrt(int N) {
    if (N <= 1)               /* log2(0) is -inf; assume N >= 0 */
        return N;
    int msb = (int) log2(N);  /* index of the most significant bit */
    int a = 1 << msb;         /* keep only the MSB */
```

Then, use a loop to check, bit by bit, whether adding the bit keeps the square no greater than `N`, and return the result.

```c
    int result = 0;
    while (a != 0) {
        /* widen to long long so that the square cannot overflow int */
        if ((long long) (result + a) * (result + a) <= N)
            result += a;
        a >>= 1;
    }
    return result;
}
```

If we want to avoid calling `log2()`, we can repeatedly **right-shift** a copy of `N`, counting the shifts until it becomes less than or equal to 1.
```c
int i_sqrt(int N) {
    int msb = 0;
    int n = N;
    while (n > 1) {   /* count right shifts until only the MSB is left */
        n >>= 1;
        msb++;
    }
    int a = 1 << msb;
    /* ... followed by the same loop as above */
```

Since recomputing $P_m^2$ from scratch to test whether $N^2 - P_m^2 \geq 0$ is expensive, we update the difference incrementally instead. Let

$$
\begin{split}
X_m & = N^2 - P_m^2 \\
& = N^2 - P_{m+1}^2 - (P_m^2 - P_{m+1}^2) \\
& = N^2 - P_{m+1}^2 - [(P_{m+1}+a_m)^2 - P_{m+1}^2] \\
& = N^2 - P_{m+1}^2 - (2P_{m+1}a_m + a_m^2) \\
& = X_{m+1} - (2P_{m+1}a_m + a_m^2).
\end{split}
$$

Let $Y_m = 2P_{m+1}a_m + a_m^2$, and define $c_m = 2^{m+1}P_{m+1}$ and $d_m = (2^m)^2 = 4^m$. Then

$$
Y_m =
\begin{cases}
c_m + d_m, & \textrm{if } a_m = 2^m, \\
0, & \textrm{if } a_m = 0.
\end{cases}
$$

Both $c_m$ and $d_m$ can be updated cheaply from one bit to the next:

$$
c_{m-1} = 2^m P_m = 2^m (P_{m+1} + a_m) =
\begin{cases}
c_m/2 + d_m, & \textrm{if } a_m = 2^m, \\
c_m/2, & \textrm{if } a_m = 0,
\end{cases}
$$

$$
d_{m-1} = (2^{m-1})^2 = 4^{m-1} = d_m/4.
$$

Next, we turn this into an algorithm. Since $n$ is the index of the MSB of the root $N$, the initial conditions are:

* $X_{n+1} = N^2 - P_{n+1}^2 = N^2$, because $P_{n+1} = 0$.
* $d_n = (2^n)^2 = 4^n$, the largest power of 4 not exceeding $N^2$. In the code, where `x` holds $N^2$:

```c
int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL);
```

> But I still don't know why I should add `& ~1UL` at the end of the expression.
> [name=ChenYang Yeh] [time=Sun, Mar 17, 2024 06:31 AM]

> Check C standard for integer promotion. :notes: jserv

* $c_n = 2^{n+1}P_{n+1} = 0$, and after the last bit, $c_{-1} = 2^0 P_0 = P_0 = N$, so $c$ holds the answer at the end.

In each iteration (from $m = n$ down to $m = 0$):

* $Y_m = c_m + d_m$.
* $c_{m-1} = c_m / 2$.
* If $X_{m+1} \geq Y_m$ (bit $m$ of the root is set):
    * $X_m = N^2 - P_{m+1}^2 - Y_m = X_{m+1} - Y_m$.
    * $c_{m-1} = c_m / 2 + d_m$.
* Otherwise, $X_m = X_{m+1}$.
* $d_{m-1} = d_m / 4$.
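To see what the `& ~1UL` in the initialization of `m` does, the sketch below isolates that expression in a small helper. `initial_d()` is a hypothetical name used only for illustration, and it relies on the GCC/Clang built-in `__builtin_clz()`, so it assumes `x > 0`. `31 - __builtin_clz(x)` is the 0-based index of the MSB of `x`; clearing bit 0 of that index rounds it down to an even number $2n$, so `m` becomes $2^{2n} = 4^n = d_n$. Without the mask, for example `x = 8` would give `m = 8`, and `m >>= 2` would then visit 8 and 2, which are not powers of 4, so `m` would no longer equal $d_m$. (The helper writes `& ~1` on an `int`; for a non-negative index this gives the same value as the original `& ~1UL`.)

```c
#include <stdint.h>

/* Hypothetical helper for illustration: the starting value of m in
 * i_sqrt(), i.e. d_n = 4^n. Requires x > 0, because __builtin_clz(0)
 * is undefined (GCC/Clang built-in). */
static uint32_t initial_d(uint32_t x) {
    int msb = 31 - __builtin_clz(x); /* 0-based index of the highest set bit */
    int even = msb & ~1;             /* clear bit 0: round the index down to even */
    return UINT32_C(1) << even;      /* 2^(2k) = 4^k, the largest power of 4 <= x */
}
```

For example, `x = 40` (`0b101000`) has its MSB at index 5; rounding down to 4 gives `m = 16`, which matches $d_2 = 4^2$ for the root 6 (`0b110`), whose MSB is at index 2.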
Therefore, we can express it in the following code:

```c
int i_sqrt(int x) {
    if (x <= 1) /* Assume x is never negative */
        return x;

    /* b = Y_m, z = c_m, m = d_m, x = X_{m+1} */
    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;      /* Y_m = c_m + d_m */
        z >>= 1;            /* c_{m-1} = c_m / 2 */
        if (x >= b)         /* bit m of the root is set */
            x -= b, z += m; /* X_m = X_{m+1} - Y_m, c_{m-1} = c_m / 2 + d_m */
    }                       /* m >>= 2 computes d_{m-1} = d_m / 4 */
    return z;               /* c_{-1} = N */
}
```

### Further Improvement

`__builtin_clz(x)` is a GCC built-in whose result is undefined when `x == 0`, so we may prefer a bit-scan function that is defined for every input. Note, however, that [ffs](https://man7.org/linux/man-pages/man3/ffs.3.html) (declared in `<strings.h>`) returns the 1-based position of the **least** significant set bit, so it cannot replace `31 - __builtin_clz(x)`, which is the index of the **most** significant set bit. What we need is its counterpart `fls()` ("find last set"), which glibc does not provide, so we would implement it ourselves. For nonzero `x`, `fls(x) - 1 == 31 - __builtin_clz(x)`:

```c
for (int m = 1UL << ((fls(x) - 1) & ~1UL); m; m >>= 2) {
```

Like `ffs()`, such an `fls()` returns `0` when the input has no set bit (and the `x == 0` case is already handled by the early return above).

To get rid of the branch, we can make the best use of bitwise operations:

```c
int mask = (b - x - 1) >> 31; /* all ones if x >= b, otherwise 0 */
x -= b & mask;
z += m & mask;
```

When `x >= b`, `b - x - 1` is negative, and an arithmetic right shift by 31 turns it into `-1` (all bits set); otherwise the result is `0`. The `- 1` is required because when `x == b`, `(b - x) >> 31` would be `0` even though the bit should be set.

### In Linux Kernel

The Linux kernel's [lib/math/int_sqrt.c](https://github.com/torvalds/linux/blob/master/lib/math/int_sqrt.c) uses a similar method to compute square roots. When I opened it, I was surprised to find that it still uses a branch. I initially thought about submitting a patch to remove this branch. However, it seemed strange that commit [aa6159a](https://github.com/torvalds/linux/commit/aa6159ab99a9ab5df835b4750b66cf132a5aa292) had been there for 4 years without anyone noticing. So I ran my own tests, and only then realized that the branch-free version brings no noticeable performance improvement.

![runtime](https://hackmd.io/_uploads/HyqukGNC6.png)

After compiling the two versions into assembly, I found that the bit-mask version has **5 more instructions** than the branch version.
Therefore, although the bit-mask version has no branches, its overall performance does not improve, because it executes more instructions.

Branch version:

```asm
.L8:
    movl    -12(%rbp), %edx
    movl    -8(%rbp), %eax
    addl    %edx, %eax
    movl    %eax, -4(%rbp)
    sarl    -12(%rbp)
    movl    -20(%rbp), %eax
    cmpl    -4(%rbp), %eax
    jl      .L7
    movl    -4(%rbp), %eax
    subl    %eax, -20(%rbp)
    movl    -8(%rbp), %eax
    addl    %eax, -12(%rbp)
```

Bit-mask version:

```asm
.L7:
    movl    -16(%rbp), %edx
    movl    -12(%rbp), %eax
    addl    %edx, %eax
    movl    %eax, -8(%rbp)
    sarl    -16(%rbp)
    movl    -8(%rbp), %eax
    subl    -20(%rbp), %eax
    subl    $1, %eax
    sarl    $31, %eax
    movl    %eax, -4(%rbp)
    movl    -8(%rbp), %eax
    andl    -4(%rbp), %eax
    subl    %eax, -20(%rbp)
    movl    -12(%rbp), %eax
    andl    -4(%rbp), %eax
    addl    %eax, -16(%rbp)
    sarl    $2, -12(%rbp)
```

In addition, since the kernel uses `unsigned long` instead of `int`, a right shift is a ==logical shift==. If an arithmetic shift is desired, we would need to declare a separate `int mask` or perform an explicit type conversion.

Last but not least, whether `>>` performs an arithmetic shift on a negative signed value depends on the compiler: the result is implementation-defined, as stated in [ISO/IEC 9899 (draft N1124), §6.5.7, paragraph 5](https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1124.pdf#page=97):

> The result of **E1 >> E2** is **E1** right-shifted **E2** bit positions. If **E1** has an unsigned type or if **E1** has a signed type and a nonnegative value, the value of the result is the integral part of the quotient of $E1 / 2^{E2}$. If **E1** has a signed type and a negative value, the resulting value is implementation-defined.

## Problem `3`: Calculate the logarithm with base 2 I

### How it works

The technique of taking the base-2 logarithm was already used in Problem 1: $\lfloor \log_2 N \rfloor$ is ==the position of the Most Significant Bit (MSB)== of `N`, counted from 1, minus `1`.

```c
int ilog2(int i) {       /* assume i >= 0 */
    int log = -1;
    while (i) {          /* count how many bits remain */
        i >>= 1;
        log++;
    }
    return log;          /* -1 when i == 0 */
}
```

We can also use the idea of ==binary search== to speed up the process.
```c
#include <stddef.h>

static size_t ilog2(size_t i) {
    size_t result = 0;
    while (i >= 65536) {
        result += 16;
        i >>= 16;
    }
    while (i >= 256) {
        result += 8;
        i >>= 8;
    }
    while (i >= 16) {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) {
        result += 1;
        i >>= 1;
    }
    return result;
}
```

Or use the built-in count-leading-zeros function:

```c
#include <stdint.h>

int ilog32(uint32_t v) {
    return (31 - __builtin_clz(v | 1));
}
```

Since GCC's `__builtin_clz()` gives an undefined result when the input is `0`, we must pass it a number $> 0$. Using `v | 1` guarantees this without changing the MSB of any `v > 0`; as a consequence, `ilog32(0)` returns `0`.

### In Linux Kernel

There are two `log2()` implementations in [linux/log2.h](https://github.com/torvalds/linux/blob/master/include/linux/log2.h). The first one is simple: it returns `fls(n) - 1`, where `fls()` gives the 1-based position of the MSB.

```c
#ifndef CONFIG_ARCH_HAS_ILOG2_U32
static __always_inline __attribute__((const))
int __ilog2_u32(u32 n)
{
    return fls(n) - 1;
}
#endif
```

The second one is a little bit tricky. It lists **all the possibilities** for the highest set bit of `n` and uses a chain of `? :` ternary operators to perform ==nested if-condition checks==. Because the chain stops at the first condition that is true, it must test from the highest bit down, so that the returned value is the index of the *highest* set bit. Note also that the chain is only used when `__builtin_constant_p(n)` is true, i.e. when `n` is a compile-time constant; otherwise the macro evaluates to `-1`.

```c
#define const_ilog2(n) \
( \
    __builtin_constant_p(n) ? ( \
        (n) < 2 ? 0 : \
        (n) & (1ULL << 63) ? 63 : \
        (n) & (1ULL << 62) ? 62 : \
        (n) & (1ULL << 61) ? 61 : \
        (n) & (1ULL << 60) ? 60 : \
        (n) & (1ULL << 59) ? 59 : \
        (n) & (1ULL << 58) ? 58 : \
        (n) & (1ULL << 57) ? 57 : \
        (n) & (1ULL << 56) ? 56 : \
        (n) & (1ULL << 55) ? 55 : \
        (n) & (1ULL << 54) ? 54 : \
        (n) & (1ULL << 53) ? 53 : \
        (n) & (1ULL << 52) ? 52 : \
        (n) & (1ULL << 51) ? 51 : \
        (n) & (1ULL << 50) ? 50 : \
        (n) & (1ULL << 49) ? 49 : \
        (n) & (1ULL << 48) ? 48 : \
        (n) & (1ULL << 47) ? 47 : \
        (n) & (1ULL << 46) ? 46 : \
        (n) & (1ULL << 45) ? 45 : \
        (n) & (1ULL << 44) ? 44 : \
        (n) & (1ULL << 43) ? 43 : \
        (n) & (1ULL << 42) ? 42 : \
        (n) & (1ULL << 41) ? 41 : \
        (n) & (1ULL << 40) ? 40 : \
        (n) & (1ULL << 39) ? 39 : \
        (n) & (1ULL << 38) ? 38 : \
        (n) & (1ULL << 37) ? 37 : \
        (n) & (1ULL << 36) ? 36 : \
        (n) & (1ULL << 35) ? 35 : \
        (n) & (1ULL << 34) ? 34 : \
        (n) & (1ULL << 33) ? 33 : \
        (n) & (1ULL << 32) ? 32 : \
        (n) & (1ULL << 31) ? 31 : \
        (n) & (1ULL << 30) ? 30 : \
        (n) & (1ULL << 29) ? 29 : \
        (n) & (1ULL << 28) ? 28 : \
        (n) & (1ULL << 27) ? 27 : \
        (n) & (1ULL << 26) ? 26 : \
        (n) & (1ULL << 25) ? 25 : \
        (n) & (1ULL << 24) ? 24 : \
        (n) & (1ULL << 23) ? 23 : \
        (n) & (1ULL << 22) ? 22 : \
        (n) & (1ULL << 21) ? 21 : \
        (n) & (1ULL << 20) ? 20 : \
        (n) & (1ULL << 19) ? 19 : \
        (n) & (1ULL << 18) ? 18 : \
        (n) & (1ULL << 17) ? 17 : \
        (n) & (1ULL << 16) ? 16 : \
        (n) & (1ULL << 15) ? 15 : \
        (n) & (1ULL << 14) ? 14 : \
        (n) & (1ULL << 13) ? 13 : \
        (n) & (1ULL << 12) ? 12 : \
        (n) & (1ULL << 11) ? 11 : \
        (n) & (1ULL << 10) ? 10 : \
        (n) & (1ULL << 9) ? 9 : \
        (n) & (1ULL << 8) ? 8 : \
        (n) & (1ULL << 7) ? 7 : \
        (n) & (1ULL << 6) ? 6 : \
        (n) & (1ULL << 5) ? 5 : \
        (n) & (1ULL << 4) ? 4 : \
        (n) & (1ULL << 3) ? 3 : \
        (n) & (1ULL << 2) ? 2 : \
        1) : \
    -1)
```

> TODO: check the lecture material about compiler optimizations to determine why optimizing compilers like GCC can generate the corresponding constant result. :notes: jserv

## Problem `5`: Calculate the logarithm with base 2 II

### How it works

The following program also computes a base-2 logarithm, this time rounded up: it returns $\lceil \log_2 x \rceil$ as $\lfloor \log_2 (x - 1) \rfloor + 1$. The value $\lfloor \log_2 (x - 1) \rfloor$ is accumulated in `r | shift`. For a `uint32_t` input it is at most 31, so 5 bits are enough to represent it, and the problem becomes deciding whether each of those 5 bits should be set. To decide this, we again use ==binary search==: each step checks whether the remaining value exceeds a threshold and, if so, sets the corresponding bit and shifts that many bits away.
```c
#include <stdint.h>

int ceil_ilog2(uint32_t x) {
    uint32_t r, shift;

    x--;                            /* so that exact powers of two, e.g. 0x00008000, are not rounded up */
    r = (x > 0xFFFF) << 4;          /* if x > 0xFFFF, r = 0b10000 = 16 */
    x >>= r;                        /* x >>= 16 */
    shift = (x > 0xFF) << 3;        /* if x > 0xFF, shift = 0b1000 = 8 */
    x >>= shift;                    /* x >>= 8 */
    r |= shift;                     /* r = r | shift */
    shift = (x > 0xF) << 2;         /* if x > 0xF, shift = 0b100 = 4 */
    x >>= shift;                    /* x >>= 4 */
    r |= shift;                     /* r = r | shift */
    shift = (x > 0x3) << 1;         /* if x > 0x3, shift = 0b10 = 2 */
    x >>= shift;                    /* x >>= 2 */
    return (r | shift | x > 1) + 1; /* if the remaining x > 1, set bit 0; then add 1 */
}
```

### Further Improvement

However, there is a special case. When the input is `0`, `x--` wraps around to `0xFFFFFFFF` and the result is `32`, which is wrong.

```
$ ./log2
0
32
```

We can simply change `x--` so that it never decrements `0`:

```c
x -= 1 && x;   /* subtract 1 only when x != 0 */
```

> When the input is 0 or 1, there are still issues, because mathematically $\log{1} = 0$ and $\log{0}$ is undefined.
> ```
> $ ./log2
> 0
> 1
> $ ./log2
> 1
> 1
> ```
> [name=ChenYang Yeh] [time=Sun, Mar 17, 2024 20:34 PM]

### In Linux Kernel

> I cannot find it...
> [name=ChenYang Yeh] [time=Sun, Mar 18, 2024 17:46 PM]

> Search `log_2` `ilog2` and `int_log2` instead. :notes: jserv
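Back to the edge cases raised under Further Improvement: one option is to force the result to `0` whenever `x <= 1`, keeping the branch-free style by multiplying by the comparison result instead of using `if`. This is a sketch, not part of the original assignment; the name `ceil_ilog2_safe` is hypothetical, and returning `0` for `x == 0` is only an assumed convention, since $\log_2 0$ is undefined.

```c
#include <stdint.h>

/* Hypothetical variant of ceil_ilog2(), for illustration only.
 * Assumed convention: returns 0 for x == 1 (log2(1) = 0) and also
 * for x == 0, where log2 is undefined. */
static int ceil_ilog2_safe(uint32_t x) {
    uint32_t valid = x > 1; /* 1 if x >= 2, 0 if x is 0 or 1 */
    uint32_t r, shift;

    x -= valid;             /* x - 1 only when x >= 2, so 0 never wraps */
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    /* floor(log2(x - 1)) + 1 for x >= 2; multiplying by valid forces 0 otherwise */
    return (int) (((r | shift | (x > 1)) + 1) * valid);
}
```

For example, `ceil_ilog2_safe(0x8000)` yields `15` and `ceil_ilog2_safe(0x8001)` yields `16`, while inputs `0` and `1` both yield `0`. Whether `0` should map to `0` or to an error value is a design choice left to the caller.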