# [2021q1](http://wiki.csie.ncku.edu.tw/linux/schedule) Week 6 Quiz
###### tags: `linux2021`
:::info
Goal: assess participants' understanding of numerical processing, Linux processes, and hash tables
:::
==[Answer form](https://docs.google.com/forms/d/e/1FAIpQLSe8fqVoE_np1Mioa-mnuc-bPJae7hmL_rvdslnPLIvdT3ZZnw/viewform)==
### Quiz `1`
The Week 3 assignment [fibdrv](https://hackmd.io/@sysprog/linux2021-fibdrv) discusses [arbitrary-precision arithmetic](https://en.wikipedia.org/wiki/Arbitrary-precision_arithmetic) (i.e., "bignum arithmetic"). Below is another implementation, which represents and manipulates values internally in binary. Consider the factorial operation:
```cpp
/*
 * factorial(N) := N * (N-1) * (N-2) * ... * 1
 */
static void factorial(struct bn *n, struct bn *res) {
    struct bn tmp;
    bn_assign(&tmp, n);
    bn_dec(n);
    while (!bn_is_zero(n)) {
        bn_mul(&tmp, n, res); /* res = tmp * n */
        bn_dec(n);            /* n -= 1 */
        bn_assign(&tmp, res); /* tmp = res */
    }
    bn_assign(res, &tmp);
}

int main() {
    struct bn num, result;
    char buf[8192];
    bn_from_int(&num, 100);
    factorial(&num, &result);
    bn_to_str(&result, buf, sizeof(buf));
    printf("factorial(100) = %s\n", buf);
    return 0;
}
```
Note that this implementation does not produce decimal output; the expected output for $100!$ is in hexadecimal:
> 1b30964ec395dc24069528d54bbda40d16e966ef9a70eb21b5b2943a321cdf10391745570cca9420c6ecb3b72ed2ee8b02ea2735c61a000000000000000000000000
We can verify it with Python's built-in arbitrary-precision integers (run `$ python3`, then enter the following statements at the prompt):
```python
import math
"%x" % math.factorial(100)
```
The corresponding arbitrary-precision implementation follows. It targets little-endian machines, with `x86_64` as the primary test platform:
```cpp
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* how large the underlying array size should be */
#define UNIT_SIZE 4

/* These are dedicated to UNIT_SIZE == 4 */
#define UTYPE uint32_t
#define UTYPE_TMP uint64_t
#define FORMAT_STRING "%.08x"
#define MAX_VAL ((UTYPE_TMP) 0xFFFFFFFF)

/* number of array elements; each big number occupies 128 bytes */
#define BN_ARRAY_SIZE (128 / UNIT_SIZE)

/* bn consists of array of UTYPE */
struct bn { UTYPE array[BN_ARRAY_SIZE]; };

static inline void bn_init(struct bn *n) {
    for (int i = 0; i < BN_ARRAY_SIZE; ++i)
        n->array[i] = 0;
}

static inline void bn_from_int(struct bn *n, UTYPE_TMP i) {
    bn_init(n);

    /* FIXME: what if machine is not little-endian? */
    n->array[0] = i;
    /* bit-shift with U64 operands to force 64-bit results */
    UTYPE_TMP tmp = i >> 32;
    n->array[1] = tmp;
}

static void bn_to_str(struct bn *n, char *str, int nbytes) {
    /* index into array - reading "MSB" first -> big-endian */
    int j = BN_ARRAY_SIZE - 1;
    int i = 0; /* index into string representation */

    /* reading last array-element "MSB" first -> big endian */
    while ((j >= 0) && (nbytes > (i + 1))) {
        sprintf(&str[i], FORMAT_STRING, n->array[j]);
        i += (2 * UNIT_SIZE); /* step UNIT_SIZE hex-byte(s) forward in the string */
        j -= 1;               /* step one element back in the array */
    }

    /* Count leading zeros: */
    for (j = 0; str[j] == '0'; j++)
        ;

    /* Move string j places ahead, effectively skipping leading zeros */
    for (i = 0; i < (nbytes - j); ++i)
        str[i] = str[i + j];

    str[i] = 0;
}

/* Decrement: subtract 1 from n */
static void bn_dec(struct bn *n) {
    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE tmp = n->array[i];
        UTYPE res = tmp - 1;
        n->array[i] = res;

        COND;
    }
}

static void bn_add(struct bn *a, struct bn *b, struct bn *c) {
    int carry = 0;
    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE_TMP tmp = (UTYPE_TMP) a->array[i] + b->array[i] + carry;
        carry = (tmp > MAX_VAL);
        c->array[i] = (tmp & MAX_VAL);
    }
}

static inline void lshift_unit(struct bn *a, int n_units) {
    int i;
    /* Shift whole units */
    for (i = (BN_ARRAY_SIZE - 1); i >= n_units; --i)
        a->array[i] = a->array[i - n_units];
    /* Zero pad shifted units */
    for (; i >= 0; --i)
        a->array[i] = 0;
}

static void bn_mul(struct bn *a, struct bn *b, struct bn *c) {
    struct bn row, tmp;
    bn_init(c);

    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        bn_init(&row);

        for (int j = 0; j < BN_ARRAY_SIZE; ++j) {
            if (i + j < BN_ARRAY_SIZE) {
                bn_init(&tmp);
                III;
                bn_from_int(&tmp, intermediate);
                lshift_unit(&tmp, i + j);
                bn_add(&tmp, &row, &row);
            }
        }
        bn_add(c, &row, c);
    }
}

static bool bn_is_zero(struct bn *n) {
    for (int i = 0; i < BN_ARRAY_SIZE; ++i)
        if (n->array[i])
            return false;
    return true;
}

/* Copy src into dst. i.e. dst := src */
static void bn_assign(struct bn *dst, struct bn *src) {
    for (int i = 0; i < BN_ARRAY_SIZE; ++i)
        dst->array[i] = src->array[i];
}
```
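To make the output format concrete, here is a minimal sketch of how an array of little-endian 32-bit elements maps to a hexadecimal string like the one shown earlier. The helper name `limbs_to_hex` and its buffer contract are assumptions made for illustration and are not part of the quiz code. Unlike `bn_to_str`, it moves only the bytes it has actually written.

```c
#include <assert.h>
#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical helper: format `len` little-endian 32-bit elements as a
 * hexadecimal string without leading zeros.
 * `out` must hold at least 8 * len + 1 bytes. */
static void limbs_to_hex(const uint32_t *limbs, int len, char *out)
{
    assert(limbs && out && len > 0);

    /* Print the most significant element first, eight digits per element */
    int pos = 0;
    for (int i = len - 1; i >= 0; --i)
        pos += sprintf(out + pos, "%08" PRIx32, limbs[i]);

    /* Count leading zeros, but keep one digit so that zero prints as "0" */
    int skip = 0;
    while (skip < pos - 1 && out[skip] == '0')
        ++skip;

    /* Shift the remaining digits and the terminating NUL to the front */
    memmove(out, out + skip, (size_t) (pos - skip) + 1);
}
```

For example, the array `{0x00000001, 0xdeadbeef, 0, 0}` is formatted as `deadbeef00000001`: the element at index 0 holds the least significant 32 bits, so it is printed last.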
==Answer area==
COND = ?
* `(a)` `if (res == tmp) break`
* `(b)` `/* no operation */`
* `(c)` `if (!(res > tmp)) break`
* `(d)` `if (res > tmp) break`
III = ?
* `(a)` `UTYPE intermediate = a->array[i] * b->array[j]`
* `(b)` `UTYPE_TMP intermediate = a->array[i] * b->array[j]`
* `(c)` `UTYPE_TMP intermediate = a->array[i] * (UTYPE_TMP) b->array[j]`
* `(d)` `UTYPE intermediate = (UTYPE_TMP) a->array[i] * b->array[j]`
* `(e)` `UTYPE_TMP intermediate = a->array[i] + b->array[j]`
:::success
Extension questions:
1. Explain how the code above works, point out flaws in its design and implementation, and improve them
2. [sysprog21/bignum](https://github.com/sysprog21/bignum) provides a more efficient [arbitrary-precision arithmetic](https://en.wikipedia.org/wiki/Arbitrary-precision_arithmetic) implementation, in which [format.c](https://github.com/sysprog21/bignum/blob/master/format.c) converts the internal binary representation to decimal output. Study that technique and rewrite the code above so that the factorial result can be printed in decimal
3. Multiplication is often quite time-consuming in [arbitrary-precision arithmetic](https://en.wikipedia.org/wiki/Arbitrary-precision_arithmetic), which is why fast multiplication algorithms such as the [Karatsuba algorithm](https://en.wikipedia.org/wiki/Karatsuba_algorithm) and the [Schönhage–Strassen algorithm](https://en.wikipedia.org/wiki/Sch%C3%B6nhage%E2%80%93Strassen_algorithm) were proposed. Introduce them into the code above; see [mul.c](https://github.com/sysprog21/bignum/blob/master/mul.c) in [sysprog21/bignum](https://github.com/sysprog21/bignum)
4. Find cases of arbitrary-precision arithmetic in the Linux kernel source; see the [lib/mpi](https://github.com/torvalds/linux/tree/master/lib/mpi) directory
:::
---
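### Quiz `2`
[Two Sum](https://leetcode.com/problems/two-sum/), problem 1 on [LeetCode](https://leetcode.com/), looks simple, yet as LeetCode's opening problem it is a classic among classics. As the saying goes, "if you have never met [Two Sum](https://leetcode.com/problems/two-sum/), grinding through all of [LeetCode](https://leetcode.com/) is in vain." It is like the first word of an English vocabulary book always being [Abandon](https://www.dictionary.com/browse/abandon): many people who lack the perseverance to continue only ever remember that one word, so usually only the first few pages of such a book show signs of use, while the rest stay in [mint condition](https://en.wikipedia.org/wiki/Mint_condition). The point needs no belaboring and no further pep talk; those who understand will understand.
> The passage above is adapted from [Two Sum 兩數之和](https://www.cnblogs.com/grandyang/p/4130379.html)
> [mint condition](https://en.wikipedia.org/wiki/Mint_condition): besides the herb, "mint" can also refer to a place where coins are made, and the "mint" in "mint condition" relates to that sense. Some collectors acquire coins right when they are first issued because such coins look brand new, and they describe the state of these coins as "mint condition," emphasizing "as if fresh from the mint." The phrase later came to mean "a second-hand item that is as good as new."
Given an array `nums` and a target value `target`, find the indices of the two elements of `nums` whose sum equals `target`. The problem guarantees exactly one solution, and the indices may be returned in either order. For example, given `nums = [2, 7, 11, 15]` and `target = 9`, the only elements that add up to `9` are `2` and `7`, so the answer is their indices `[0, 1]`.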
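As a baseline for the problem statement above, a brute-force search can be sketched as follows. The function name `two_sum_bruteforce` is hypothetical, and the sketch assumes that the caller frees the returned array, as is the convention on LeetCode.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical baseline: try every pair (i, j) with i < j.
 * Returns a malloc'ed array of two indices, or NULL if no pair matches. */
int *two_sum_bruteforce(const int *nums, int n, int target, int *return_size)
{
    assert(nums && return_size);
    *return_size = 0;

    for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
            /* Widen before adding so the sum cannot overflow int */
            if ((long long) nums[i] + nums[j] != target)
                continue;

            int *res = malloc(2 * sizeof(*res));
            if (!res)
                return NULL;
            res[0] = i, res[1] = j;
            *return_size = 2;
            return res;
        }
    }
    return NULL;
}
```

With `nums = [2, 7, 11, 15]` and `target = 9`, the first pair examined, indices `0` and `1`, already matches. In the worst case every one of the $n(n-1)/2$ pairs is examined.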
Consider the following C implementation, which sorts {value, index} pairs and then scans them from both ends:
```cpp
#include <stdlib.h>

static int cmp(const void *lhs, const void *rhs) {
    if (*(int *) lhs == *(int *) rhs)
        return 0;
    return *(int *) lhs < *(int *) rhs ? -1 : 1;
}

static int *alloc_wrapper(int a, int b, int *returnSize) {
    *returnSize = 2;
    int *res = (int *) malloc(sizeof(int) * 2);
    res[0] = a, res[1] = b;
    return res;
}

int *twoSum(int *nums, int numsSize, int target, int *returnSize)
{
    *returnSize = 2;
    int arr[numsSize][2]; /* {value, index} pair */
    for (int i = 0; i < numsSize; ++i) {
        arr[i][0] = nums[i];
        arr[i][1] = i;
    }
    qsort(arr, numsSize, sizeof(arr[0]), cmp);
    for (int i = 0, j = numsSize - 1; i < j; ) {
        if (arr[i][0] + arr[j][0] == target)
            return alloc_wrapper(arr[i][1], arr[j][1], returnSize);
        if (arr[i][0] + arr[j][0] < target)
            ++i;
        else
            --j;
    }
    *returnSize = 0;
    return NULL;
}
```
Submitting it to the [LeetCode](https://leetcode.com/) online judge yields the following result:
![](https://i.imgur.com/9Gt5dw5.png)
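It falls short of [PR=99](http://www.stat.nuk.edu.tw/SouthShow.asp?myid=1503) (the 99th percentile rank), which is hard to accept. Let's find a way to improve it.
A brute-force approach takes $O(n^2)$ time, which clearly falls short of expectations. Instead, we can use a [hash table](https://en.wikipedia.org/wiki/Hash_table) (hereafter `HT`) to record the missing value (i.e., `target - nums[i]`) together with the corresponding index. Consider the following case:
> nums = `[2, 11, 7, 15]`, target = `9`:
The corresponding steps:
1. `nums[0]` is `2`; `HT[2]` does not exist, so create `HT[9 - 2] = 0`
2. `nums[1]` is `11`; `HT[11]` does not exist, so create `HT[9 - 11] = 1`
3. `nums[2]` is `7`; `HT[7]` exists (set in step `1`), so return `[2, HT[7]] = [2, 0]`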
![](https://i.imgur.com/5FQZ6Lo.png)
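The steps above can be sketched with a small open-addressing table. This is only an illustration of the idea, not the `hlist`-based table this quiz builds later: the names `two_sum_hashed`, `struct slot`, `slot_of`, and `probe` are hypothetical, linear probing is a swapped-in technique chosen for brevity, and the sketch assumes that `target - nums[i]` does not overflow `int`.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

/* One table entry: maps a missing value to the index that needs it */
struct slot {
    int key;   /* the missing value, target - nums[index] */
    int index; /* position in nums that recorded this key */
    int used;  /* 0 while the slot is empty */
};

/* Multiplicative hash keeping the top `bits` bits (1 <= bits <= 31) */
static unsigned slot_of(int key, unsigned bits)
{
    return ((uint32_t) key * UINT32_C(0x61C88647)) >> (32 - bits);
}

/* Return the slot holding `key`, or the empty slot where it would go */
static struct slot *probe(struct slot *ht, unsigned bits, int key)
{
    unsigned mask = (1u << bits) - 1;
    unsigned h = slot_of(key, bits);
    while (ht[h].used && ht[h].key != key)
        h = (h + 1) & mask; /* linear probing */
    return &ht[h];
}

int *two_sum_hashed(const int *nums, int n, int target, int *return_size)
{
    assert(nums && return_size && n >= 0 && n <= (1 << 20));
    *return_size = 0;

    /* Power-of-two capacity of at least 2n, so the table is never full */
    unsigned bits = 1;
    while ((1u << bits) < 2u * (unsigned) n)
        ++bits;
    struct slot *ht = calloc((size_t) 1 << bits, sizeof(*ht));
    if (!ht)
        return NULL;

    int *res = NULL;
    for (int i = 0; i < n; ++i) {
        struct slot *s = probe(ht, bits, nums[i]);
        if (s->used) { /* an earlier element was waiting for nums[i] */
            res = malloc(2 * sizeof(*res));
            if (res) {
                res[0] = i, res[1] = s->index;
                *return_size = 2;
            }
            break;
        }
        /* Record the value that would complete nums[i] */
        s = probe(ht, bits, target - nums[i]);
        if (!s->used)
            *s = (struct slot){target - nums[i], i, 1};
    }
    free(ht);
    return res;
}
```

For `nums = [2, 11, 7, 15]` and `target = 9`, the sketch records keys `7` and `-2`, then finds `7` at index `2` and returns `[2, 0]`, matching step 3. Each element costs an expected constant number of probes while the table stays at most half full.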
`hlist` is used to implement hash tables; its data structures are defined in [include/linux/types.h](https://github.com/torvalds/linux/blob/master/include/linux/types.h):
```cpp
struct hlist_head {
    struct hlist_node *first;
};

struct hlist_node {
    struct hlist_node *next, **pprev;
};
```
Illustrated below:
![](https://i.imgur.com/QYQLqvC.png)
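Like those of `list`, the `hlist` operations are defined in [include/linux/list.h](https://github.com/torvalds/linux/blob/master/include/linux/list.h) and are prefixed with `hlist_`. `hlist_head` and `hlist_node` implement the buckets of a hash table: nodes with the same hash value are placed on the same `hlist`. To save space, `hlist_head` holds only a single `first` pointer to an `hlist_node` and has no pointer to the tail of the list.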
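One consequence of the `pprev` field, which points at whichever `first` or `next` pointer currently refers to the node, is that a node can be unlinked without knowing which bucket it lives in and without a special case for the first node. The following is a minimal sketch modeled on the kernel's `__hlist_del`; the name `hlist_unlink` is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

struct hlist_node { struct hlist_node *next, **pprev; };
struct hlist_head { struct hlist_node *first; };

/* Unlink `n` from whatever list it is on.
 * `n->pprev` must be non-NULL, i.e. the node is currently hashed. */
static void hlist_unlink(struct hlist_node *n)
{
    assert(n && n->pprev);

    struct hlist_node *next = n->next, **pprev = n->pprev;

    /* Redirect the pointer that referred to n: either head->first or
     * the previous node's next field. The code does not need to know which. */
    *pprev = next;
    if (next)
        next->pprev = pprev;

    /* Mark the node as unhashed */
    n->next = NULL;
    n->pprev = NULL;
}
```

With a plain `prev` pointer to the previous node, removing the first node would require access to the `hlist_head` itself, because the head is not an `hlist_node`. Storing the address of the incoming pointer avoids that branch while keeping the head at a single pointer.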
Below is an implementation that introduces a [hash table](https://en.wikipedia.org/wiki/Hash_table), written in the style of the Linux kernel code:
```cpp
#include <stddef.h>
#include <stdlib.h>

struct hlist_node { struct hlist_node *next, **pprev; };
struct hlist_head { struct hlist_node *first; };
typedef struct { int bits; struct hlist_head *ht; } map_t;

#define MAP_HASH_SIZE(bits) (1 << (bits))

map_t *map_init(int bits) {
    map_t *map = malloc(sizeof(map_t));
    if (!map)
        return NULL;

    map->bits = bits;
    map->ht = malloc(sizeof(struct hlist_head) * MAP_HASH_SIZE(map->bits));
    if (map->ht) {
        for (int i = 0; i < MAP_HASH_SIZE(map->bits); i++)
            (map->ht)[i].first = NULL;
    } else {
        free(map);
        map = NULL;
    }
    return map;
}

struct hash_key {
    int key;
    void *data;
    struct hlist_node node;
};

#define container_of(ptr, type, member)               \
    ({                                                \
        void *__mptr = (void *) (ptr);                \
        ((type *) (__mptr - offsetof(type, member))); \
    })

#define GOLDEN_RATIO_32 0x61C88647
static inline unsigned int hash(unsigned int val, unsigned int bits) {
    /* High bits are more random, so use them. */
    return (val * GOLDEN_RATIO_32) >> (32 - bits);
}

static struct hash_key *find_key(map_t *map, int key) {
    struct hlist_head *head = &(map->ht)[hash(key, map->bits)];
    for (struct hlist_node *p = head->first; p; p = p->next) {
        struct hash_key *kn = container_of(p, struct hash_key, node);
        if (kn->key == key)
            return kn;
    }
    return NULL;
}

void *map_get(map_t *map, int key)
{
    struct hash_key *kn = find_key(map, key);
    return kn ? kn->data : NULL;
}

void map_add(map_t *map, int key, void *data)
{
    struct hash_key *kn = find_key(map, key);
    if (kn)
        return;

    kn = malloc(sizeof(struct hash_key));
    kn->key = key, kn->data = data;

    struct hlist_head *h = &map->ht[hash(key, map->bits)];
    struct hlist_node *n = &kn->node, *first = h->first;
    NNN;
    if (first)
        first->pprev = &n->next;
    h->first = n;
    PPP;
}

void map_deinit(map_t *map)
{
    if (!map)
        return;

    for (int i = 0; i < MAP_HASH_SIZE(map->bits); i++) {
        struct hlist_head *head = &map->ht[i];
        for (struct hlist_node *p = head->first; p;) {
            struct hash_key *kn = container_of(p, struct hash_key, node);
            struct hlist_node *n = p;
            p = p->next;

            if (!n->pprev) /* unhashed */
                goto bail;

            struct hlist_node *next = n->next, **pprev = n->pprev;
            *pprev = next;
            if (next)
                next->pprev = pprev;
            n->next = NULL, n->pprev = NULL;

        bail:
            free(kn->data);
            free(kn);
        }
    }
    free(map->ht); /* release the bucket array as well */
    free(map);
}

int *twoSum(int *nums, int numsSize, int target, int *returnSize)
{
    map_t *map = map_init(10);
    *returnSize = 0;
    int *ret = malloc(sizeof(int) * 2);
    if (!ret)
        goto bail;

    for (int i = 0; i < numsSize; i++) {
        int *p = map_get(map, target - nums[i]);
        if (p) { /* found */
            ret[0] = i, ret[1] = *p;
            *returnSize = 2;
            break;
        }

        p = malloc(sizeof(int));
        *p = i;
        map_add(map, nums[i], p);
    }

bail:
    map_deinit(map);
    return ret;
}
```
==Answer area==
NNN = ?
* `(a)` `/* no operation */`
* `(b)` `n->pprev = first`
* `(c)` `n->next = first`
* `(d)` `n->pprev = n`
PPP = ?
* `(a)` `n->pprev = &h->first`
* `(b)` `n->next = h`
* `(c)` `n->next = n`
* `(d)` `n->next = h->first`
* `(e)` `n->next = &h->first`
:::success
Extension exercises:
1. Explain how the code above works
2. Study the Linux kernel source [include/linux/hashtable.h](https://github.com/torvalds/linux/blob/master/include/linux/hashtable.h) and the accompanying document [How does the kernel implements Hashtables?](https://kernelnewbies.org/FAQ/Hashtables), explain the design and implementation of the hash table, and pay attention to `GOLDEN_RATIO_PRIME` in [tools/include/linux/hash.h](https://github.com/torvalds/linux/blob/master/tools/include/linux/hash.h), discussing the considerations behind it
:::
---
### Quiz `3`
A user-level thread, also known as a [fiber](https://en.wikipedia.org/wiki/Fiber_(computer_science)), is designed to provide an execution unit smaller than a native thread. We attempt to implement a [fiber](https://en.wikipedia.org/wiki/Fiber_(computer_science)) with the [clone](https://man7.org/linux/man-pages/man2/clone.2.html) system call discussed in [Linux: 不僅是個執行單元的 Process](https://hackmd.io/s/r1ojuBGgE) and [UNIX 作業系統 fork/exec 系統呼叫的前世今生](https://hackmd.io/@sysprog/unix-fork-exec).
Given the following code:
```cpp
static void fibonacci() {
    int fib[2] = {0, 1};
    printf("Fib(0) = 0\nFib(1) = 1\n");
    for (int i = 2; i < 15; ++i) {
        int nextFib = fib[0] + fib[1];
        printf("Fib(%d) = %d\n", i, nextFib);
        fib[0] = fib[1];
        fib[1] = nextFib;
        fiber_yield();
    }
}

static void squares() {
    for (int i = 1; i < 10; ++i) {
        printf("%d * %d = %d\n", i, i, i * i);
        fiber_yield();
    }
}

int main() {
    fiber_init();

    fiber_spawn(&fibonacci);
    fiber_spawn(&squares);

    /* Since fibers are non-preemptive, we must allow them to run */
    fiber_wait_all();

    return 0;
}
```
The expected output is as follows (one possible output):
```
Fib(0) = 0
Fib(1) = 1
Fib(2) = 1
Fib(3) = 2
Fib(4) = 3
Fib(5) = 5
Fib(6) = 8
Fib(7) = 13
Fib(8) = 21
Fib(9) = 34
Fib(10) = 55
Fib(11) = 89
Fib(12) = 144
Fib(13) = 233
Fib(14) = 377
1 * 1 = 1
2 * 2 = 4
3 * 3 = 9
4 * 4 = 16
5 * 5 = 25
6 * 6 = 36
7 * 7 = 49
8 * 8 = 64
9 * 9 = 81
```
The corresponding [fiber](https://en.wikipedia.org/wiki/Fiber_(computer_science)) implementation is as follows:
```cpp
#define FIBER_NOERROR 0
#define FIBER_MAXFIBERS 1
#define FIBER_MALLOC_ERROR 2
#define FIBER_CLONE_ERROR 3
#define FIBER_INFIBER 4

/* The maximum number of fibers that can be active at once. */
#define MAX_FIBERS 10

/* The size of the stack for each fiber. */
#define FIBER_STACK (1024 * 1024)

#define _GNU_SOURCE

#include <sched.h> /* For clone */
#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h> /* For pid_t */
#include <sys/wait.h>  /* For wait */
#include <unistd.h>    /* For getpid */

typedef struct {
    pid_t pid;   /* The pid of the child thread as returned by clone */
    void *stack; /* The stack pointer */
} fiber_t;

/* The fiber "queue" */
static fiber_t fiber_list[MAX_FIBERS];

/* The pid of the parent process */
static pid_t parent;

/* The number of active fibers */
static int num_fibers = 0;

void fiber_init()
{
    for (int i = 0; i < MAX_FIBERS; ++i)
        fiber_list[i].pid = 0, fiber_list[i].stack = 0;
    parent = getpid();
}

/* Yield control to another execution context */
void fiber_yield()
{
    /* move the current process to the end of the process queue. */
    sched_yield();
}

struct fiber_args {
    void (*func)(void);
};

static int fiber_start(void *arg)
{
    struct fiber_args *args = (struct fiber_args *) arg;
    void (*func)() = args->func;
    free(args);

    func();
    return 0;
}

/* Creates a new fiber, running the func that is passed as an argument. */
int fiber_spawn(void (*func)(void))
{
    if (num_fibers == MAX_FIBERS)
        return FIBER_MAXFIBERS;

    if ((fiber_list[num_fibers].stack = malloc(FIBER_STACK)) == 0)
        return FIBER_MALLOC_ERROR;

    struct fiber_args *args;
    if ((args = malloc(sizeof(*args))) == 0) {
        free(fiber_list[num_fibers].stack);
        return FIBER_MALLOC_ERROR;
    }
    args->func = func;

    fiber_list[num_fibers].pid = clone(
        fiber_start, KKK,
        SIGCHLD | CLONE_FS | CLONE_FILES | CLONE_SIGHAND | CLONE_VM, args);
    if (fiber_list[num_fibers].pid == -1) {
        free(fiber_list[num_fibers].stack);
        free(args);
        return FIBER_CLONE_ERROR;
    }

    num_fibers++;
    return FIBER_NOERROR;
}

/* Execute the fibers until they all quit. */
int fiber_wait_all()
{
    /* Check to see if we are in a fiber, since we do not get signals in the
     * child threads
     */
    pid_t pid = getpid();
    if (pid != parent)
        return FIBER_INFIBER;

    /* Wait for the fibers to quit, then free the stacks */
    while (num_fibers > 0) {
        if ((pid = wait(0)) == -1)
            exit(1);

        /* Find the fiber, free the stack, and swap it with the last one */
        for (int i = 0; i < num_fibers; ++i) {
            if (CCC) {
                free(fiber_list[i].stack);
                if (i != --num_fibers)
                    fiber_list[i] = fiber_list[num_fibers];
                break;
            }
        }
    }

    return FIBER_NOERROR;
}
```
Complete the code so that it behaves as expected.
==Answer area==
KKK = ?
* `(a)` `(char *) fiber_list[num_fibers].stack - FIBER_STACK`
* `(b)` `(char *) fiber_list[num_fibers].stack + FIBER_STACK`
* `(c)` `fiber_list[num_fibers].stack + 1`
* `(d)` `fiber_list[num_fibers].stack - 1`
CCC = ?
* `(a)` `fiber_list[i].pid != pid`
* `(b)` `fiber_list[i].pid == pid`
* `(c)` `fiber_list[i].pid != 0`
* `(d)` `fiber_list[i].pid == parent`
:::success
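Extension questions:
1. Explain how the code above works, point out flaws in its design and implementation, and improve them
2. Running the code repeatedly shows that the strings printed by `fibonacci` and `squares` may fail to appear as whole lines (i.e., interleaving); identify the cause and fix it
3. Study [A (Bare-Bones) User-Level Thread Library](https://www.schaertl.me/posts/a-bare-bones-user-level-thread-library/), understand how fibers are implemented, and design experiments to quantify the switching costs of execution units such as processes, native threads (i.e., [NPTL](https://man7.org/linux/man-pages/man7/nptl.7.html)-conforming implementations), and [fiber](https://en.wikipedia.org/wiki/Fiber_(computer_science)) / [coroutine](https://en.wikipedia.org/wiki/Coroutine)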
:::