# 2024q1 Homework4 (quiz3+4)
contributed by < `YangYeh-PD` >
## Problem `1`: Computing the Square Root
### How it works
Suppose that the integer square root $N$ of the input $N^2$ is decomposed bit by bit as
$$
N = a_n + a_{n-1} + \dots + a_1 + a_0, \textrm{ where } a_m \in \{0, 2^m\},
$$
then
$$
\begin{split}
N^2 & = (a_n+a_{n-1}+a_{n-2}+...+a_1+a_0)^2 \\
& = a_n^2+[2a_n+a_{n-1}]a_{n-1}+[2(a_n+a_{n-1})+a_{n-2}]a_{n-2}+...+\left[2\left(\sum_{i=1}^{n} a_i\right)+a_0\right]a_0 \\
& = a_n^2+[2P_n+a_{n-1}]a_{n-1}+[2(P_{n-1})+a_{n-2}]a_{n-2}+...+\left[2P_1+a_0\right]a_0, \\
\end{split}
$$
$$
\begin{split}
\textrm{where } P_m = a_n + & a_{n-1} + ... + a_m, P_m = P_{m+1} + a_m, \\
\textrm{and } P_0 = a_n + & a_{n-1} + ... + a_0 = N.
\end{split}
$$
Therefore, the problem reduces to determining each term $a_m$, from $m = n$ down to $0$, starting with $P_{n+1} = 0$. If setting the $m$-th bit to 1 would make $P_m^2 > N^2$, we leave it as 0 and continue with the next bit:
$$
\begin{cases}
P_m = P_{m+1} + 2^m &, \textrm{if } (P_{m+1} + 2^m)^2 \leq N^2, \\
P_m = P_{m+1} &, \textrm{otherwise}. \\
\end{cases}
$$
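
As a small worked example (the input is chosen only for illustration), take $N^2 = 40$. The highest bit the root can have is $m = 2$, since $(2^2)^2 = 16 \leq 40 < 64 = (2^3)^2$:
$$
\begin{array}{c|c|c|c}
m & P_{m+1} + 2^m & (P_{m+1} + 2^m)^2 & P_m \\
\hline
2 & 0 + 4 = 4 & 16 \leq 40 & 4 \\
1 & 4 + 2 = 6 & 36 \leq 40 & 6 \\
0 & 6 + 1 = 7 & 49 > 40 & 6
\end{array}
$$
so $N = P_0 = 6$, which is indeed $\lfloor \sqrt{40} \rfloor$.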
Thus, the simplest way to compute the square root is to find the most significant bit (MSB) of the input, use it as the first candidate bit, and test the bits one by one from there downward. (In the code below, `N` denotes the input, i.e., the $N^2$ of the derivation.) In binary, the simplest way to obtain the MSB index of `N` is to directly **take the logarithm base 2** and explicitly cast the result to an integer.
```c
#include <math.h>
int i_sqrt(int N)
{
    /* Assume N > 0: log2(0) is -inf and cannot be converted to int. */
    int msb = (int) log2(N);
    int a = 1 << msb; /* first candidate bit */
```
Then, a loop tests each bit in turn, from `a` downward, and keeps it only if the square of the partial result does not exceed `N`:
```c
    int result = 0;
    while (a != 0) {
        /* Multiply in long long: (result + a)^2 can overflow int for large N. */
        if ((long long) (result + a) * (result + a) <= N)
            result += a;
        a >>= 1;
    }
    return result;
}
```
If we want to avoid the `log2()` function, we can repeatedly right-shift a copy of `N` and count the shifts until it becomes less than or equal to 1.
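```c
int i_sqrt(int N)
{
    int msb = 0;
    int n = N;
    while (n > 1) { /* count right shifts until n <= 1 */
        n >>= 1;
        msb++;
    }
    int a = 1 << msb;
    /* ... continues with the same loop as in the log2() version */
```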
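
For reference, here is one way to join the shift-based MSB search with the loop from the `log2()` version into a single runnable function. The squaring is done in `long long` so that `(result + a)^2` cannot overflow `int` for large inputs. The function name `i_sqrt_shift` is hypothetical, and the sketch assumes `N >= 0`.

```c
/* Sketch: shift-based MSB search + bit-by-bit square root.
 * Assumes N >= 0. The name i_sqrt_shift is hypothetical. */
int i_sqrt_shift(int N)
{
    /* Index of the most significant set bit of N (0 for N <= 1). */
    int msb = 0;
    int n = N;
    while (n > 1) {
        n >>= 1;
        msb++;
    }

    int a = 1 << msb; /* first candidate bit */
    int result = 0;
    while (a != 0) {
        /* Keep bit a only if (result + a)^2 <= N; square in long long. */
        long long t = (long long) (result + a) * (result + a);
        if (t <= N)
            result += a;
        a >>= 1;
    }
    return result;
}
```

For example, `i_sqrt_shift(40)` tries 32, 16, 8 (all too large), then keeps 4 and 2 and rejects 1, returning 6.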
Computing $P_m^2$ directly requires a multiplication in every iteration, so instead we track the difference $N^2 - P_m^2$ incrementally. Setting
$$
\begin{split}
X_m & = N^2 - P_m^2 \\
& = (N^2 - P_{m+1}^2) - (P_m^2 - P_{m+1}^2) \\
& = X_{m+1} - [(P_{m+1}+a_m)^2 - P_{m+1}^2] \\
& = X_{m+1} - (2P_{m+1}a_m + a_m^2). \\
\end{split}
$$
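By setting $Y_m = 2P_{m+1}a_m + a_m^2$, together with $c_m = 2P_{m+1}2^m = P_{m+1}2^{m+1}$ and $d_m = (2^m)^2$ (the values that $2P_{m+1}a_m$ and $a_m^2$ would take if bit $m$ were set), we get
$$
Y_m =
\begin{cases}
c_m + d_m &, a_m = 2^m, \\
0 &, a_m = 0. \\
\end{cases}
$$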
Moreover, $c_m$ and $d_m$ can be updated from one bit to the next without any multiplication:
$$
c_{m-1} = P_m2^m = (P_{m+1} + a_m)2^m =
\begin{cases}
P_{m+1}2^m + (2^m)^2 = c_m/2 + d_m &, \textrm{if } a_m = 2^m, \\
P_{m+1}2^m = c_m/2 &, \textrm{if } a_m = 0, \\
\end{cases}
$$
$$
d_{m-1} = (2^{m-1})^2 = 2^{2m-2} = (2^m)^2/4 = d_m/4.
$$
Next, we turn this into an algorithm. Here $n$ is the index of the highest bit the square root can have, i.e., the largest $n$ with $4^n \leq N^2$, so the initial conditions are as follows:
* $X_{n+1} = N^2 - P_{n+1}^2 = N^2$, since $P_{n+1} = 0$.
* $d_n = (2^n)^2 = 4^n$.
```c
int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL);
```
> But I still don't know why I should add `& ~1UL` at the end of the expression.
[name=ChenYang Yeh] [time=Sun, Mar 17, 2024 06:31 AM]
> Check C standard for integer promotion. :notes: jserv
* $c_n = 2P_{n+1}2^n = 0$ and $c_{-1} = P_02^0 = P_0 = N$, so the final value of $c$ is the answer.

In each iteration:
* $Y_m = c_m + d_m$.
* $c_{m-1} = c_m / 2$.
* If $X_{m+1} \geq Y_m$, bit $m$ is set: $X_m = X_{m+1} - Y_m$ and $c_{m-1} = c_m / 2 + d_m$. Otherwise, $X_m = X_{m+1}$.
* $d_{m-1} = d_m / 4$.
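
As an illustration (the input is chosen arbitrarily), applying these update rules to $N^2 = 40$, where $n = 2$ because $4^2 = 16 \leq 40 < 4^3$, starts from $X_3 = 40$, $c_2 = 0$, $d_2 = 16$ and gives:
$$
\begin{array}{c|c|c|c|c|c}
m & d_m & c_m & Y_m = c_m + d_m & X_{m+1} \geq Y_m ? & X_m,\ c_{m-1} \\
\hline
2 & 16 & 0 & 16 & 40 \geq 16 & X_2 = 24,\ c_1 = 0 + 16 = 16 \\
1 & 4 & 16 & 20 & 24 \geq 20 & X_1 = 4,\ c_0 = 8 + 4 = 12 \\
0 & 1 & 12 & 13 & 4 < 13 & X_0 = 4,\ c_{-1} = 6
\end{array}
$$
The final $c_{-1} = 6$ is $\lfloor \sqrt{40} \rfloor$, and $X_0 = 4 = 40 - 6^2$ is the remainder.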
Translating this into code, with `x` holding $X_{m+1}$, `b` holding $Y_m$, `z` holding $c_m$, and `m` holding $d_m$:
```c
int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always nonnegative */
        return x;

    /* x = X_{m+1}, b = Y_m, z = c_m, m = d_m */
    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m; /* Y_m = c_m + d_m */
        z >>= 1;       /* c_{m-1} = c_m / 2 */
        if (x >= b)    /* bit m is set */
            x -= b, z += m;
    }
    return z;
}
```
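
Since $d_n = 4^n$ must be a power of four, the shift amount used for the initial `m` has to be even. `31 - __builtin_clz(x)` is the index of the highest set bit of `x`, and clearing bit 0 of that index rounds it down to an even number, which yields the largest power of four not exceeding `x`. The sketch below (both helper names are hypothetical, and it assumes `x > 0`) compares that expression, written with plain `int` masks, against a straightforward loop.

```c
/* Hypothetical helper: the initial value of m in i_sqrt(), i.e. d_n.
 * Assumes x > 0 so that __builtin_clz() is well defined. */
int initial_d(int x)
{
    int msb = 31 - __builtin_clz(x); /* index of the highest set bit */
    return 1 << (msb & ~1);          /* round the index down to even */
}

/* Hypothetical reference: largest power of four not exceeding x (x > 0). */
int largest_pow4(int x)
{
    int p = 1;
    while (p <= x / 4) /* compare with x / 4 so that p never overflows */
        p *= 4;
    return p;
}
```

For `x = 8`, the highest set bit has index 3, which rounds down to 2, so both helpers return 4 rather than 8.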
### Further Improvement
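Since `__builtin_clz(x)` is a GCC built-in whose result is undefined when `x == 0`, it is tempting to replace it with [ffs](https://man7.org/linux/man-pages/man3/ffs.3.html), which returns `0` when the input has no set bit. However, `ffs()` (declared in `<strings.h>`) returns the 1-based position of the **least** significant set bit, so `31 - ffs(x)` does not give the MSB index. A simpler way to avoid the special case is to pass `x | 1`, which is never `0` and has the same MSB as `x` whenever `x > 1`:
```c
for (int m = 1UL << ((31 - __builtin_clz(x | 1)) & ~1UL); m; m >>= 2) {
```
With this change, `x == 0` and `x == 1` no longer need the early return.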
To get rid of the branch, we can replace the `if` with bitwise operations on a mask:
```c
int mask = (b - x - 1) >> 31; /* -1 (all ones) if x >= b, otherwise 0 */
x -= b & mask;
z += m & mask;
```
Note the `- 1`: when `x == b`, `(b - x) >> 31` would be `0` even though bit `m` should be set, so we subtract `1` inside the parentheses. Then `b - x - 1` is negative exactly when `x >= b`, and its arithmetic right shift by 31 yields `-1` (all ones); otherwise the mask is `0`.
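
A complete branchless version might look like the following sketch. It passes `x | 1` to `__builtin_clz()`, which is never zero and has the same most significant bit as `x` whenever `x > 1`, so the early return for `x <= 1` is not needed. It assumes `x >= 0` and that `>>` on a negative `int` is an arithmetic shift, which GCC documents but ISO C leaves implementation-defined. The name `i_sqrt_branchless` is hypothetical.

```c
/* Sketch: branchless digit-by-digit integer square root.
 * Assumes x >= 0 and an arithmetic right shift for negative int (as in GCC).
 * The name i_sqrt_branchless is hypothetical. */
int i_sqrt_branchless(int x)
{
    int z = 0; /* c_m */
    /* m starts at the largest power of four not exceeding max(x, 1). */
    for (int m = 1 << ((31 - __builtin_clz(x | 1)) & ~1); m; m >>= 2) {
        int b = z + m;                 /* Y_m = c_m + d_m */
        z >>= 1;                       /* c_{m-1} = c_m / 2 */
        int mask = (b - x - 1) >> 31;  /* -1 if x >= b, otherwise 0 */
        x -= b & mask;                 /* X_m = X_{m+1} - Y_m when set */
        z += m & mask;                 /* c_{m-1} += d_m when set */
    }
    return z;
}
```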
### In Linux Kernel
The Linux kernel's [lib/math/int_sqrt.c](https://github.com/torvalds/linux/blob/master/lib/math/int_sqrt.c) file utilizes a similar method for computing square roots.
When I opened it, I was surprised to find that it actually uses a branch, and I initially thought about submitting a patch to remove it. However, it seemed strange that commit [aa6159a](https://github.com/torvalds/linux/commit/aa6159ab99a9ab5df835b4750b66cf132a5aa292) had been there for 4 years without anyone noticing. So I ran my own tests, and only then did I realize that the branchless version brings no noticeable performance improvement.
![runtime](https://hackmd.io/_uploads/HyqukGNC6.png)
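After compiling the two code snippets into assembly, the bit-mask version (second listing) turned out to have **5 more instructions** than the branch version (first listing): 17 versus 12, although the second listing also includes the `m >>= 2` update (`sarl $2`) that the first one omits. Therefore, although the bit-mask version has no branch, the extra instructions offset the benefit, and overall performance does not improve. Note that both listings appear to be unoptimized (`-O0`) output, since every variable is loaded from and stored to the stack, so the comparison is worth repeating with optimizations enabled.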
```asm
.L8:
movl -12(%rbp), %edx
movl -8(%rbp), %eax
addl %edx, %eax
movl %eax, -4(%rbp)
sarl -12(%rbp)
movl -20(%rbp), %eax
cmpl -4(%rbp), %eax
jl .L7
movl -4(%rbp), %eax
subl %eax, -20(%rbp)
movl -8(%rbp), %eax
addl %eax, -12(%rbp)
```
```asm
.L7:
movl -16(%rbp), %edx
movl -12(%rbp), %eax
addl %edx, %eax
movl %eax, -8(%rbp)
sarl -16(%rbp)
movl -8(%rbp), %eax
subl -20(%rbp), %eax
subl $1, %eax
sarl $31, %eax
movl %eax, -4(%rbp)
movl -8(%rbp), %eax
andl -4(%rbp), %eax
subl %eax, -20(%rbp)
movl -12(%rbp), %eax
andl -4(%rbp), %eax
addl %eax, -16(%rbp)
sarl $2, -12(%rbp)
```
In addition, since the kernel uses `unsigned long` instead of `int`, a right shift there is a ==logical shift==, so `(b - x - 1) >> 31` would not produce an all-ones mask (and on 64-bit targets the shift amount would also have to be 63). To get an arithmetic shift, it would be necessary to declare a separate signed `mask` variable or perform an explicit type conversion.
Last but not least, right-shifting a negative signed value is implementation-defined, so whether `>>` performs an arithmetic shift depends on the compiler being used. This is specified in paragraph 5 of section 6.5.7 of [ISO/IEC 9899](https://www.open-std.org/jtc1/sc22/wg14/www/docs/n1124.pdf#page=97):
> The result of **E1 >> E2** is **E1** right-shifted **E2** bit positions. If **E1** has an unsigned type or if E1 has a signed type and a nonnegative value, the value of the result is the integral part of the quotient of $E1 / 2^{E2}$. If **E1** has a signed type and a negative value, the resulting value is implementation-defined.
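
If one wanted a branchless variant on `unsigned long` anyway, one option (an assumption, not what the kernel does) is to build the mask by negating the 0/1 comparison result: `-(unsigned long)(x >= b)` is either `0` or all ones by the rules of unsigned arithmetic, so no signed shift is involved at all. The sketch below mirrors the `i_sqrt()` loop from Problem 1; the name `int_sqrt_mask` is hypothetical, and the code is not taken from `lib/math/int_sqrt.c`.

```c
#include <limits.h>

/* Sketch: branchless integer square root on unsigned long.
 * The mask comes from negating the comparison, so it does not rely on
 * implementation-defined signed shifts. The name is hypothetical. */
unsigned long int_sqrt_mask(unsigned long x)
{
    unsigned long b, m, y = 0;

    if (x <= 1)
        return x;

    /* Largest power of four not exceeding x. */
    m = 1UL << ((sizeof(unsigned long) * CHAR_BIT - 1 - __builtin_clzl(x))
                & ~1UL);
    for (; m; m >>= 2) {
        b = y + m;
        y >>= 1;
        unsigned long mask = -(unsigned long) (x >= b); /* all ones if x >= b */
        x -= b & mask;
        y += m & mask;
    }
    return y;
}
```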
## Problem `3`: Calculate the logarithm with base 2 I
### How it works
The technique for computing the base-2 logarithm was already used in Problem 1: $\lfloor \log_2 N \rfloor$ is the position of ==the Most Significant Bit (MSB)== of `N`, counted from 1, minus `1`. The loop below counts the right shifts needed to clear `i`, starting from `-1`.
```c
int ilog2(int i)
{
    /* Assume i >= 0: for negative i, i >>= 1 may never reach 0. */
    int log = -1; /* ilog2(0) returns -1 */
    while (i) {
        i >>= 1;
        log++;
    }
    return log;
}
```
We can also utilize the technique of ==binary search== to expedite the process.
```c
#include <stddef.h>

static size_t ilog2(size_t i)
{
    size_t result = 0;
    while (i >= 65536) {
        result += 16;
        i >>= 16;
    }
    while (i >= 256) {
        result += 8;
        i >>= 8;
    }
    while (i >= 16) {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) {
        result += 1;
        i >>= 1;
    }
    return result;
}
```
Alternatively, we can use GCC's built-in count-leading-zeros function.
```c
#include <stdint.h>

int ilog32(uint32_t v)
{
    return (31 - __builtin_clz(v | 1));
}
```
Note that GCC's `__builtin_clz()` gives an undefined result when its argument is `0`, so we must pass a nonzero value. Using `v | 1` guarantees this without changing the MSB of any `v > 1`; as a result, `ilog32(0)` returns `0`.
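
A quick way to check that the three variants compute the same thing is to run them side by side. The sketch below copies them under hypothetical names (`ilog2_loop`, `ilog2_search`) so they can coexist, and checks that they agree for positive inputs. They differ only at `0`, where the loop version returns `-1` and the other two return `0`.

```c
#include <stddef.h>
#include <stdint.h>

/* Copies of the three variants above, renamed so they can coexist. */
static int ilog2_loop(int i)
{
    int log = -1;
    while (i) {
        i >>= 1;
        log++;
    }
    return log;
}

static size_t ilog2_search(size_t i)
{
    size_t result = 0;
    while (i >= 65536) { result += 16; i >>= 16; }
    while (i >= 256) { result += 8; i >>= 8; }
    while (i >= 16) { result += 4; i >>= 4; }
    while (i >= 2) { result += 1; i >>= 1; }
    return result;
}

static int ilog32(uint32_t v)
{
    return 31 - __builtin_clz(v | 1);
}

/* Returns 1 if all three variants agree for every v in [1, limit). */
static int ilog2_variants_agree(int limit)
{
    for (int v = 1; v < limit; v++) {
        int a = ilog2_loop(v);
        if ((size_t) a != ilog2_search((size_t) v) || a != ilog32((uint32_t) v))
            return 0;
    }
    return 1;
}
```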
### In Linux Kernel
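There are two integer `log2` implementations in [linux/log2.h](https://github.com/torvalds/linux/blob/master/include/linux/log2.h).
The first one is simple: it returns `fls(n) - 1`, where `fls()` ("find last set") returns the 1-based position of the most significant set bit, or `0` if `n == 0`.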
```c
#ifndef CONFIG_ARCH_HAS_ILOG2_U32
static __always_inline __attribute__((const))
int __ilog2_u32(u32 n)
{
return fls(n) - 1;
}
#endif
```
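
Outside the kernel, `fls()` is not part of standard C, but its behavior is easy to mimic. The sketch below uses GCC's `__builtin_clz()` and guards the zero case explicitly; both helper names are hypothetical, and the kernel's real `fls()` may be implemented with architecture-specific instructions instead.

```c
#include <stdint.h>

/* Hypothetical userspace stand-in for the kernel's fls(): 1-based position
 * of the most significant set bit, and 0 when no bit is set. */
static int my_fls(uint32_t n)
{
    return n ? 32 - __builtin_clz(n) : 0;
}

/* Same shape as __ilog2_u32(): floor(log2(n)), and -1 for n == 0. */
static int my_ilog2_u32(uint32_t n)
{
    return my_fls(n) - 1;
}
```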
The second one is a little bit tricky. It lists out **all the possibilities** for the highest set bit of `n` and uses a chain of `? :` ternary operators to perform ==nested if-condition checks==. Checking from bit 63 downward is what makes the result correct: the first bit that matches is the MSB, which is exactly $\lfloor \log_2 n \rfloor$. The whole chain is guarded by `__builtin_constant_p(n)`, so it is only selected when `n` is a compile-time constant; otherwise the macro evaluates to `-1`.
```c
#define const_ilog2(n) \
( \
__builtin_constant_p(n) ? ( \
(n) < 2 ? 0 : \
(n) & (1ULL << 63) ? 63 : \
(n) & (1ULL << 62) ? 62 : \
(n) & (1ULL << 61) ? 61 : \
(n) & (1ULL << 60) ? 60 : \
(n) & (1ULL << 59) ? 59 : \
(n) & (1ULL << 58) ? 58 : \
(n) & (1ULL << 57) ? 57 : \
(n) & (1ULL << 56) ? 56 : \
(n) & (1ULL << 55) ? 55 : \
(n) & (1ULL << 54) ? 54 : \
(n) & (1ULL << 53) ? 53 : \
(n) & (1ULL << 52) ? 52 : \
(n) & (1ULL << 51) ? 51 : \
(n) & (1ULL << 50) ? 50 : \
(n) & (1ULL << 49) ? 49 : \
(n) & (1ULL << 48) ? 48 : \
(n) & (1ULL << 47) ? 47 : \
(n) & (1ULL << 46) ? 46 : \
(n) & (1ULL << 45) ? 45 : \
(n) & (1ULL << 44) ? 44 : \
(n) & (1ULL << 43) ? 43 : \
(n) & (1ULL << 42) ? 42 : \
(n) & (1ULL << 41) ? 41 : \
(n) & (1ULL << 40) ? 40 : \
(n) & (1ULL << 39) ? 39 : \
(n) & (1ULL << 38) ? 38 : \
(n) & (1ULL << 37) ? 37 : \
(n) & (1ULL << 36) ? 36 : \
(n) & (1ULL << 35) ? 35 : \
(n) & (1ULL << 34) ? 34 : \
(n) & (1ULL << 33) ? 33 : \
(n) & (1ULL << 32) ? 32 : \
(n) & (1ULL << 31) ? 31 : \
(n) & (1ULL << 30) ? 30 : \
(n) & (1ULL << 29) ? 29 : \
(n) & (1ULL << 28) ? 28 : \
(n) & (1ULL << 27) ? 27 : \
(n) & (1ULL << 26) ? 26 : \
(n) & (1ULL << 25) ? 25 : \
(n) & (1ULL << 24) ? 24 : \
(n) & (1ULL << 23) ? 23 : \
(n) & (1ULL << 22) ? 22 : \
(n) & (1ULL << 21) ? 21 : \
(n) & (1ULL << 20) ? 20 : \
(n) & (1ULL << 19) ? 19 : \
(n) & (1ULL << 18) ? 18 : \
(n) & (1ULL << 17) ? 17 : \
(n) & (1ULL << 16) ? 16 : \
(n) & (1ULL << 15) ? 15 : \
(n) & (1ULL << 14) ? 14 : \
(n) & (1ULL << 13) ? 13 : \
(n) & (1ULL << 12) ? 12 : \
(n) & (1ULL << 11) ? 11 : \
(n) & (1ULL << 10) ? 10 : \
(n) & (1ULL << 9) ? 9 : \
(n) & (1ULL << 8) ? 8 : \
(n) & (1ULL << 7) ? 7 : \
(n) & (1ULL << 6) ? 6 : \
(n) & (1ULL << 5) ? 5 : \
(n) & (1ULL << 4) ? 4 : \
(n) & (1ULL << 3) ? 3 : \
(n) & (1ULL << 2) ? 2 : \
1) : \
-1)
```
> TODO: check the lecture material about compiler optimizations to determine why optimizing compilers like GCC can generate the corresponding constant result. :notes: jserv
## Problem `5`: Calculate the logarithm with base 2 II
### How it works
The following program computes the ceiling of the base-2 logarithm, $\lceil \log_2 x \rceil$. It accumulates the intermediate result in `r | shift`. Since, after the initial `x--`, the floor of the logarithm of a `uint32_t` is at most 31, 5 bits are enough to represent it.
The problem thus reduces to deciding whether each of these 5 bits should be set, from the highest one down. To decide each bit, we again use ==binary search==: check whether the remaining value exceeds the mask for the lower half of its current width, and if so, shift it right by that half-width.
```c
#include <stdint.h>

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;
    x--;                     /* So exact powers of two, e.g. 0x00008000, are not rounded up */
    r = (x > 0xFFFF) << 4;   /* If x > 0xFFFF, r = 0b10000 = 16 */
    x >>= r;                 /* Drop the low 16 bits if the high half is nonzero */
    shift = (x > 0xFF) << 3; /* If x > 0xFF, shift = 0b1000 = 8 */
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;  /* If x > 0xF, shift = 0b100 = 4 */
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;  /* If x > 0x3, shift = 0b10 = 2 */
    x >>= shift;
    /* If the remaining x > 1, add 1 to r | shift; the final + 1 turns
     * floor(log2(x - 1)) into ceil(log2(x)). */
    return (r | shift | (x > 1)) + 1;
}
```
### Further Improvement
However, there is a special case. When the input value is `0`, `x--` wraps around to `0xFFFFFFFF`, so the result is `32`, which is wrong.
```
$ ./log2
0
32
```
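We can make the decrement conditional, so that it only happens when `x` is nonzero (`1 && x` evaluates to `1` if `x != 0` and to `0` otherwise).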
```c
x -= 1 && x;
```
> Inputs 0 and 1 still give wrong results: mathematically, $\log_2 1 = 0$, and $\log_2 0$ is undefined, yet both return `1`.
> ```
> $./log2
> 0
> 1
> $./log2
> 1
> 1
> ```
> [name=ChenYang Yeh] [time=Sun, Mar 17, 2024 20:34 PM]
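
One possible fix, assuming we adopt the convention that the function returns `0` for both `0` and `1` (since $\log_2 0$ is undefined, any value returned for `0` is only a convention), is to record whether `x > 1` before the decrement and add that instead of a constant `1`. The sketch keeps a conditional decrement equivalent to `x -= 1 && x;`; the name `ceil_ilog2_fixed` is hypothetical.

```c
#include <stdint.h>

/* Sketch: ceil(log2(x)) that returns 0 for x == 1 and, by an assumed
 * convention, also 0 for x == 0. The name is hypothetical. */
int ceil_ilog2_fixed(uint32_t x)
{
    uint32_t big = x > 1;  /* 1 if the result is at least 1 */
    uint32_t r, shift;

    x -= (x != 0);         /* skip the decrement for x == 0 */
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | (x > 1)) + big;
}
```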
### In Linux Kernel
> I cannot find it...
> [name=ChenYang Yeh] [time=Sun, Mar 18, 2024 17:46 PM]
> Search `log_2` `ilog2` and `int_log2` instead. :notes: jserv