Try   HackMD

2022q1 Homework2 (quiz2)

contributed by < blueskyson >

題目連結

測驗 1

考慮以下對二個無號整數取平均值的程式碼:

#include <stdint.h>
uint32_t average(uint32_t a, uint32_t b)
{
    return (a + b) / 2;
}

這個直覺的解法會有 overflow 的問題,若我們已知 a, b 數值的大小,可用下方程式避免 overflow:

#include <stdint.h>
uint32_t average(uint32_t low, uint32_t high)
{
    return low + (high - low) / 2;
}

接著我們可改寫為以下等價的實作:

#include <stdint.h>
uint32_t average(uint32_t a, uint32_t b)
{
    return (a >> 1) + (b >> 1) + (EXP1);
}

我們再次改寫為以下等價的實作:

uint32_t average(uint32_t a, uint32_t b)
{
    return (EXP2) + ((EXP3) >> 1);
}

解題

在第一種改寫的實作中,當 ab 皆為奇數時時,在 a >> 1b >> 1a/2a/b 各會損失 1,需要把 1 加回去。因此 EXP1a & b & 1

在第二種改寫的實作則是使用加法器的概念,a & ba + b 的進位值,a ^ b 則是 a + b 的和。因此 EXP2a & bEXP3a ^ b

測驗 2

改寫〈解讀計算機編碼〉一文的「不需要分支的設計」一節提供的程式碼 min,我們得到以下實作 (max):

#include <stdint.h>
uint32_t max(uint32_t a, uint32_t b)
{
    return a ^ ((EXP4) & -(EXP5));
}

延伸閱讀:

解題

利用自己 xor 自己等於 0 的特殊性質,我預期當 a > b 時,max(a, b) 會執行 (a ^ 0) 以回傳 a,反之則執行 (a ^ a ^ b)。因此很明顯的得到 EXP4a ^ b

a > b(a ^ b) & -(EXP5) 必須為 0,才能使得 max(a, b) 回傳 a ^ 0;反之 (a ^ b) & -(EXP5) 必須為 a ^ b,才能使得 max(a, b) 回傳 a ^ a ^ b。故 EXP5a < b 時,恰好可以製造出 (a ^ b) & 0(a ^ b) & 0xffff 來控制 max 的回傳值。

測驗 3

考慮以下 64 位元 GCD (greatest common divisor, 最大公因數) 求值函式:

#include <stdint.h>
uint64_t gcd64(uint64_t u, uint64_t v)
{
    if (!u || !v) return u | v;
    while (v) {                               
        uint64_t t = v;
        v = u % v;
        u = t;
    }
    return u;
}

改寫為以下等價實作:

#include <stdint.h>
uint64_t gcd64(uint64_t u, uint64_t v)
{
    if (!u || !v) return u | v;
    int shift;
    for (shift = 0; !((u | v) & 1); shift++) {
        u /= 2, v /= 2;
    }
    while (!(u & 1))
        u /= 2;
    do {
        while (!(v & 1))
            v /= 2;
        if (u < v) {
            v -= u;
        } else {
            uint64_t t = u - v;
            u = v;
            v = t;
        }
    } while (COND);
    return RET;
}

解題

第 1 步:
if (!u || !v) return u | v; 判斷 u, v 是否是 0 ,若其中一個是 0 就回傳 0。

第 2 步:

for (shift = 0; !((u | v) & 1); shift++) {
    u /= 2, v /= 2;
}

u, v 同時可被 2 整除,就將 u, v 同除以 2 ,並且讓 shift 加 1 ,由此可知 u, v 同為 (0x1 << shift) 的倍數,也就是將

2shift 作為公因數提出來。

第 3 步:

while (!(u & 1))
    u /= 2;

在第 2 步已經將

2shift 提出來了,代表接下來 gcd 的過程不會再萃取出偶數公因數,但是 uv 可能還是偶數,繼續將 u 除以 2 直到 u 不是偶數。

第 4 步:

do {
    while (!(v & 1))
        v /= 2;
    if (u < v) {
        v -= u;
    } else {
        uint64_t t = u - v;
        u = v;
        v = t;
    }
} while (COND);

這個 do while 迴圈持續相減過程就是輾轉相除。與第 3 步同理,每一輪迭代都將 v 除以 2 直到 v 不是偶數。

  • v 大於 uv - u 可以視為 v ÷ u 的餘數,將 v 減去 u 之後執行下一輪迭代。
  • v 小於 uu - v 可以視為 u ÷ v 的餘數,將 u 減去 v 之後執行下一輪迭代。
  • v 等於 u 時,v 即為所求公因數。由此可以推斷 do while 的條件 COND 即為 v,當 u == vv 會與 u 相減變成 0 以跳出迴圈。

第 5 步

return RET;

回傳時要將原本同除的

2shift 乘回去,所以 RETu << shift

測驗 4

在影像處理中,bit array (也稱 bitset) 廣泛使用,考慮以下程式碼:

#include <stddef.h>
size_t naive(uint64_t *bitmap, size_t bitmapsize, uint32_t *out)
{
    size_t pos = 0;
    for (size_t k = 0; k < bitmapsize; ++k) {
        uint64_t bitset = bitmap[k];
        size_t p = k * 64;
        for (int i = 0; i < 64; i++) {
            if ((bitset >> i) & 0x1)
                out[pos++] = p + i;
        }
    }
    return pos;
}

考慮 GNU extension 的 __builtin_ctzll 的行為是回傳由低位往高位遇上連續多少個 0 才碰到 1

範例: 當 a = 16
16 這個十進位數值的二進位表示法為 00000000 00000000 00000000 00010000
從低位元 (即右側) 往高位元,我們可發現 0 → 0 → 0 → 0 → 1,於是 ctz 就為 4,表示最低位元往高位元有 4 個 0

用以改寫的程式碼如下:

#include <stddef.h> size_t improved(uint64_t *bitmap, size_t bitmapsize, uint32_t *out) { size_t pos = 0; uint64_t bitset; for (size_t k = 0; k < bitmapsize; ++k) { bitset = bitmap[k]; while (bitset != 0) { uint64_t t = EXP6; int r = __builtin_ctzll(bitset); out[pos++] = k * 64 + r; bitset ^= t; } } return pos; }

其中第 9 行的作用是找出目前最低位元的 1,並紀錄到 t 變數。若 bitmap 越鬆散 (即 1 越少),於是 improved 的效益就更高。

解題

improved 改寫程式碼根據 trailing zero 的數量來判斷最靠近 LSB 的 1 的 bit,所以每次紀錄完最靠近 LSB 的 1 的 bit,都必須將該 bit 的值變為 0 且其他 bit 保持不變。EXP6 的用途就是把 LSB 單獨提取出來,並賦值給 t,故 EXP6bitset & -bitset

a & -a 的特性:

    bitset = xxxx 1000
&  -bitset = yyyy 1000
--------------------------
         t = 0000 1000

我們可以看到在 bitset & -bitset 之後,t 變成只剩 LSB 的數值。之後再讓 bitset ^ t 就能把 bitset 的 LSB,進行下一輪計算:

    bitset = xxxx 1000
^        t = 0000 1000
--------------------------
    bitset = xxxx 0000

測驗 5

以下是 LeetCode 166. Fraction to Recurring Decimal 的可能實作:

#include <stdbool.h>                       
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "list.h"
    
struct rem_node {
    int key;
    int index;
    struct list_head link;
};  
    
static int find(struct list_head *heads, int size, int key)
{
    struct rem_node *node;
    int hash = key % size;
    list_for_each_entry (node, &heads[hash], link) {
        if (key == node->key)
            return node->index;
    }
    return -1;
}

char *fractionToDecimal(int numerator, int denominator)
{
    int size = 1024;
    char *result = malloc(size);
    char *p = result;

    if (denominator == 0) {
        result[0] = '\0';
        return result;
    }

    if (numerator == 0) {
        result[0] = '0';
        result[1] = '\0';
        return result;
    }

    /* using long long type make sure there has no integer overflow */
    long long n = numerator;
    long long d = denominator;

    /* deal with negtive cases */
    if (n < 0)
        n = -n;
    if (d < 0)
        d = -d;

    bool sign = (float) numerator / denominator >= 0;
    if (!sign)
        *p++ = '-';

    long long remainder = n % d;
    long long division = n / d;

    sprintf(p, "%ld", division > 0 ? (long) division : (long) -division);
    if (remainder == 0)
        return result;

    p = result + strlen(result);
    *p++ = '.';

    /* Using a map to record all of reminders and their position.
     * if the reminder appeared before, which means the repeated loop begin,
     */
    char *decimal = malloc(size);
    memset(decimal, 0, size);
    char *q = decimal;

    size = 1333;
    struct list_head *heads = malloc(size * sizeof(*heads));
    for (int i = 0; i < size; i++)
        INIT_LIST_HEAD(&heads[i]);

    for (int i = 0; remainder; i++) {
        int pos = find(heads, size, remainder);
        if (pos >= 0) {
            while (PPP > 0)
                *p++ = *decimal++;
            *p++ = '(';
            while (*decimal != '\0')
                *p++ = *decimal++;
            *p++ = ')';
            *p = '\0';
            return result;
        }
        struct rem_node *node = malloc(sizeof(*node));
        node->key = remainder;
        node->index = i;

        MMM(&node->link, EEE);

        *q++ = (remainder * 10) / d + '0';
        remainder = (remainder * 10) % d;
    }

    strcpy(p, decimal);
    return result;
}

解題

為了揣摩這個程式的邏輯,舉循環小數 12 / 11 為例:

step 0:

首先計算 12 / 11 的商數 division = 1 與餘數 remainder = 1,此時 result = "1."。接下來初始化 hash table,然後進入 for 迴圈計算小數部份。此時的狀態如下:







G


cluster_0

hash table



result

1

.

 

 

 



p

p



p->result:2





r

result



r->result:w





decimal

 

 

 

 

 



q

q



q->decimal:0





d

decimal



d->decimal:w






step 1:

進到 for 迴圈後第一件事就是透過 find,從 hash table 中尋找過去是否除過當前的餘數,若是就代表陷入循環小數,回傳發生第一次循環小數的位數。因為此時 hash table 還是空的,所以 find 回傳 -1,並賦值給 pos。因為沒有陷入循環小數,所以不會進到 if 區塊中。

把當前的小數位數 i == 0 以及餘數 remainder == 1 裝進 rem_node 放入 hash table 中。接下來進行長除法,將餘數 1 乘以 10,再除以除數 11,也就是計算 10 / 11。將 10 / 11 的商數 0 轉為字元存到 decimal 中,然後把 remainder 為更新為 10 / 11 的餘數 10,此時狀態如下:







G


cluster_0

hash table



result

1

.

 

 

 



p

p



p->result:2





r

result



r->result:w





decimal

0

 

 

 

 



q

q



q->decimal:1





d

decimal



d->decimal:w





n1

index = 0

remainder = 1



step 2:

此時 remainder == 10,hash table 中並沒有 remainder == 10 的元素,所以 find 回傳 -1,不會進到 if 區塊中。

把當前的小數位數 i == 1 以及餘數 remainder == 10 裝進 rem_node 放入 hash table 中。接下來進行長除法,將 10 乘以 10,再除以 11,也就是計算 100 / 11。將 100 / 11 的商數 9 轉為字元存到 decimal 中,然後把 remainder 為更新為 100 / 11 的餘數 1,此時狀態如下:







G


cluster_0

hash table



result

1

.

 

 

 



p

p



p->result:2





r

result



r->result:w





decimal

0

9

 

 

 



q

q



q->decimal:2





d

decimal



d->decimal:w





n1

index = 1

remainder = 10



n2

index = 0

remainder = 1



step 3:

此時 remainder == 1,恰好 hash table 中存在 index == 0, remainder == 1 的元素,代表已經陷入循環小數了,回傳 index 的值 0pos,然後進入 if 區塊。

在 if 區塊中,PPP 的 while 是將未循環的位數填到 result 中(例如 0.12(34)12),這個例子中從小數後第 0 位開始都是循環小數,所以不會執行這個 while 迴圈。接下來在 result 填入左括弧、填入 decimal 循環的部份、填入右括弧。最後回傳 result







G


cluster_0

hash table



result

1

.

(

0

9

)

\0



p

p



p->result:6





r

result



r->result:w





decimal

0

9

 

 

 



q

q



q->decimal:2





d

decimal



d->decimal:w





n1

index = 1

remainder = 10



n2

index = 0

remainder = 1



理清楚程式的邏輯後,很明顯的得出 PPPpos-- 以填入未循環的位數;MMMlist_add 把元素放入 hash table;EEE&heads[remainder % size],用以找到對應的 hash 的 entry。

測驗 6

__alignof__ 是 GNU extension,以下是其可能的實作方式:

/*
 * ALIGNOF - get the alignment of a type
 * @t: the type to test
 *
 * This returns a safe alignment for the given type.
 */
#define ALIGNOF(t) \
    ((char *)(&((struct { char c; t _h; } *)0)->M) - (char *)X)

解題

慢慢剖析這個巨集,首先 (struct { char c; t _h; } *) 0,是將 0x0 (nil) 這個位址開頭的記憶體視為一個 struct { char c; t _h; } 物件。

(char *)(&((struct { char c; t _h; } *)0)->M) 則是以 0x0 作為此物件的起始點,取成員 M 的位址,再將其轉形為 char *,待會便能以 1 byte 為單位計算位址的差距。因為 ALIGNOF 是用來計算 t 的 alignment,很明顯 M_h

取得 _h 的位址後,我們只要將其減去 0x0 就能得到型態 t 的位移量,所以 X0。以下測試常用型態的位移量:

// test.c
#include <stdio.h>
#define ALIGNOF(t) \
    ((char *)(&((struct { char c; t _h; } *)0)->_h) - (char *)0)

int main(void) {
    printf("alignof char: %ld\n", ALIGNOF(char));
    printf("alignof short: %ld\n", ALIGNOF(short));
    printf("alignof int: %ld\n", ALIGNOF(int));
    printf("alignof double: %ld\n", ALIGNOF(double));
    printf("alignof long: %ld\n", ALIGNOF(long));
    printf("alignof long long: %ld\n", ALIGNOF(long long));
    return 0;
}
$ gcc test.c
$ ./a.out
alignof char: 1
alignof short: 2
alignof int: 4
alignof double: 8
alignof long: 8
alignof long long: 8

測驗 7

考慮貌似簡單卻蘊含實作深度的 FizzBuzz 題目:

  • 從 1 數到 n,如果是 3的倍數,印出 “Fizz”
  • 如果是 5 的倍數,印出 “Buzz”
  • 如果是 15 的倍數,印出 “FizzBuzz”
  • 如果都不是,就印出數字本身

直覺的實作程式碼如下: (naive.c)

#include <stdio.h>
int main() {
    for (unsigned int i = 1; i < 100; i++) {
        if (i % 3 == 0) printf("Fizz");
        if (i % 5 == 0) printf("Buzz");
        if (i % 15 == 0) printf("FizzBuzz");
        if ((i % 3) && (i % 5)) printf("%u", i);
        printf("\n");
    }
    return 0;
}

觀察 printf 的(格式)字串,可分類為以下三種:

  1. 整數格式字串 "%d" : 長度為 2 B
  2. “Fizz” 或 “Buzz” : 長度為 4 B
  3. “FizzBuzz” : 長度為 8 B

考慮下方程式碼:

#define MSG_LEN 8
char fmt[MSG_LEN + 1];
strncpy(fmt, &"FizzBuzz%u"[start], length);
fmt[length] = '\0';
printf(fmt, i);
printf("\n");

我們若能精準從給定輸入 i 的規律去控制 startlength ,即可符合 FizzBuzz 題意:

string literal: "fizzbuzz%u"
        offset:  0   4   8

以下是利用 bitwise 和上述技巧實作的 FizzBuzz 程式碼: (bitwise.c)

static inline bool is_divisible(uint32_t n, uint64_t M)
{
    return n * M <= M - 1;
}

static uint64_t M3 = UINT64_C(0xFFFFFFFFFFFFFFFF) / 3 + 1;
static uint64_t M5 = UINT64_C(0xFFFFFFFFFFFFFFFF) / 5 + 1;

int main(int argc, char **argv)
{
    for (size_t i = 1; i <= 100; i++) {
        uint8_t div3 = is_divisible(i, M3);
        uint8_t div5 = is_divisible(i, M5);
        unsigned int length = (2 << KK1) << KK2;

        char fmt[9];
        strncpy(fmt, &"FizzBuzz%u"[(9 >> div5) >> (KK3)], length);
        fmt[length] = '\0';

        printf(fmt, i);
        printf("\n");
    }
    return 0;
}

其中 is_divisible 函式技巧來自 Faster remainders when the divisor is a constant: beating compilers and libdivide,甚至 gcc-9 還內建了 FizzBuzz optimization (Bug 82853 - Optimize x % 3 == 0 without modulo)。

請補完。

對於處理器來說,每個運算所花的成本是不同的,比如 add, sub 就低於 mul,而這些運算的 cost 又低於 div 。依據〈Infographics: Operation Costs in CPU Clock Cycles〉,可發現整數除法的成本幾乎是整數加法的 50 倍。

解題

length 作為 strncpy 複製的字串長度,當

i|3
i|5
length 預期為 4,即 "Fizz""Buzz" 的長度;當
i|15
length 預期為 8,即 "FizzBuzz" 的長度。因此 KK1div3KK2div5

&"FizzBuzz%u"[(9 >> div5) >> (KK3)] 代表要從字元陣列 "FizzBuzz%u" 的哪個位址開始複製。以下是在各種情況期望的起始位址與對應的字串:

  • i|3
    : (9 >> div5) >> (KK3) == 0,i.e. "Fizz"
  • i|5
    : (9 >> div5) >> (KK3) == 4,i.e. "Buzz"
  • i|15
    : (9 >> div5) >> (KK3) == 0,i.e. "FizzBuzz"
  • default: (9 >> div5) >> (KK3) == 8,i.e. "%u"

KK3div3 << 2 時可以達成上述期望。

(9 >> div5) >> (KK3) 應改成 (8 >> div5) >> (KK3),否則會複製到 "u" 而非 "%u"。