Try   HackMD

2023q1 Homework3 (quiz3)

contributed by < paul90317 >

作業要求

測驗 1

觀察 原始碼

typedef struct __node {
    uintptr_t color;
    struct __node *left, *right;
    struct __node *next;
    long value;
} node_t __attribute__((aligned(sizeof(long))));

會將顏色與親節點指標的值存在 color 成員裡面,用於存取的程式碼如下

#define rb_parent(r) ((node_t *) (AAAA))
#define rb_color(r) ((color_t) (r)->color & 1)

由於指標後兩位(64 bits 系統是三位)都是 0,且我們只會用到其中一位,AAAA(r)->color & ~1

AAAA 不可以是 (r)->color & ~1u 因為在做位元且之前會將 ~1u 從 32 位元擴展到 64 位元,如果是無號擴展再位元且,將會遺失指標高位位元。

#define rb_is_red(r) (!rb_color(r))
#define rb_is_black(r) (rb_color(r))
#define rb_set_parent(r, p)                         \
    do {                                            \
        (r)->color = rb_color(r) | (uintptr_t) (p); \
    } while (0)
#define rb_set_red(r) \
    do {              \
        BBBB;         \
    } while (0)
#define rb_set_black(r) \
    do {                \
        CCCC;           \
    } while (0)

從以上程式碼看出紅色是 0,黑色是 1,BBBB 應該是 (r)->color &= ~1,而 CCCC 應該是 (r)->color |= 1

觀察 cmap_insert 函式的程式碼片段

        if (res < 0) {
            if (!cur->left) {
                cur->left = node;
                rb_set_parent(node, cur);
                cmap_fix_colors(obj, node);
                break;
            }
            DDDD;
        } else {
            if (!cur->right) {
                cur->right = node;
                rb_set_parent(node, cur);
                cmap_fix_colors(obj, node);
                break;
            }
            EEEE;
        }

根據該函式的意圖,DDDDEEEE 應是在 cur 不是 internal node 的情況下,存取下一個節點,故 DDDDcur = cur->leftEEEEcur = cur->right

觀察 tree_sort 程式碼

void tree_sort(node_t **list)
{
    node_t **record = list;
    cmap_t map = cmap_new(sizeof(long), sizeof(NULL), cmap_cmp_int);
    while (*list) {
        cmap_insert(map, *list, NULL);
        list = FFFF;
    }
    node_t *node = cmap_first(map), *first = node;
    for (; node; node = cmap_next(node)) {
        *list = node;
        list = GGGG;
    }
    HHHH;
    *record = first;
    free(map);
}

可以看出 while 迴圈想要將 list 所有節點插入 map,所以 FFFF 應是 &(*list)->next 以存取下一個要插入的節點。
for 迴圈想要用 cmap_next 走訪節點並構造新的數列,所以 GGGG 應該要是 &(*list)->next

順便一提指標的指標讓 list 可以直接在從 map 拿到 node 後再給予值,否則程式碼將變成如下

    node_t *node = cmap_first(map), *first = node, *last = NULL;
    for (; node; node = cmap_next(node)) {
        if (last != NULL) {
            last->next = node;
        }
        last = node;
    }

將於每次判斷 node 是否是第一個節點,或是是以下

    node_t *last = cmap_first(map), *node = cmap_next(map), *first = last;
    for (; node; node = cmap_next(node)) {
        last->next = node;
        last = node;
    }

將第一個節點分開判斷。

HHHH*list = NULL 對新數列進行收尾。

測驗 2

觀察 原始碼

struct avl_node {
    unsigned long parent_balance;
    struct avl_node *left, *right;
} AVL_NODE_ALIGNED;

enum avl_node_balance { AVL_NEUTRAL = 0, AVL_LEFT, AVL_RIGHT, };

static inline struct avl_node *avl_parent(struct avl_node *node)
{
    return (struct avl_node *) (IIII);
}
static inline enum avl_node_balance avl_balance(const struct avl_node *node)
{
    return (enum avl_node_balance)(JJJJ);
}

在以上程式碼中,由於 AVL 樹的節點有三個狀態,須給予兩個位元,故 IIIInode->parent_balance & ~3JJJJnode->parent_balance & 3

static void avl_set_parent(struct avl_node *node, struct avl_node *parent)
{
    node->parent_balance =
        (unsigned long) parent | (KKKK);
}
static void avl_set_balance(struct avl_node *node,
                            enum avl_node_balance balance)
{
    node->parent_balance = (LLLL) | balance;
}

在上述程式碼中,KKKK 應是原本的狀態,故是 avl_balance(node)LLLLavl_parent(node)
觀察 avl_insert_balance 的程式碼片段

/* compensate double left balance by rotation
* and stop afterwards
*/
switch (avl_balance(node)) {
default:
case AVL_LEFT:
case AVL_NEUTRAL:
    MMMM(node, parent, root);
    break;
case AVL_RIGHT:
    NNNN(node, parent, root);
    break;
}

以上程式碼發生於當 nodeparent 左邊的節點
當樹向左邊傾斜時,應該將 parent 向右旋轉,故 MMMMavl_rotate_right
node 是向右邊傾斜的情況,NNNN 應該先將node 向左轉再將 parent 向右轉,故是 avl_rotate_leftright

測驗 3

觀察 原始碼

static const char log_table_256[256] = {
#define _(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n
    -1,   0,    1,    1,    2,    2,    2,    2,    3,    3,    3,
    3,    3,    3,    3,    3,    _(4), _(5), _(5), _(6), _(6), _(6),
    _(6), _(7), _(7), _(7), _(7), _(7), _(7), _(7), _(7),
#undef _
};

#undef _

在查找表 log_table_256,輸入 0 得到 -1,輸 1 得 0,輸 2、3 得 1,輸

n
log2n
,最高到輸入 255 得 7。
255 所佔用的位元會是輸出值
log2255+1
,也就是 8 個位元。

/* ASSUME x >= 1
 * returns smallest integer b such that 2^b = 1 << b is >= x
 */
uint64_t log2_64(uint64_t v)
{
    unsigned r;
    uint64_t t, tt, ttt;

    ttt = v >> 32;
    if (ttt) {
        tt = ttt >> 16;
        if (tt) {
            t = tt >> 8;
            if (t) {
                r = AAAA + log_table_256[t];
            } else {
                r = BBBB + log_table_256[tt];
            }
        } else {
            t = ttt >> 8;
            if (t) {
                r = CCCC + log_table_256[t];
            } else {
                r = DDDD + log_table_256[ttt];
            }
        }
    } else {
        tt = v >> 16;
        if (tt) {
            t = tt >> 8;
            if (t) {
                r = EEEE + log_table_256[t];
            } else {
                r = FFFF + log_table_256[tt];
            }
        } else {
            t = v >> 8;
            if (t) {
                r = GGGG + log_table_256[t];
            } else {
                r = HHHH + log_table_256[v];
            }
        }
    }
    return r;
}

以上程式碼我覺得是有點類似遞迴的作法,只是為了更好的效率所以將遞迴函式展開,於是我將該程式傳回遞迴,如下

uint64_t _log2_64(uint64_t v, uint8_t shift)
{
    if (shift == 4) // 改 !(v >> 8) 剪枝 (pruning) 會有更好的效率
        return log_table_256[v]
    if (v >> shift) 
        return _log2_64(v >> shift, shift >> 1) + t;
    else
        return _log2_64(v, shift >> 1);
}
uint64_t log2_64(uint64_t v)
{
    _log2_64(v, 32)
}

對照兩者後可以得出

  • AAAA 是 32 + 16 + 8 是 56
  • BBBB 是 32 + 16 是 48
  • CCCC 是 32 + 8 是 40
  • DDDD 是 32
  • EEEE 是 16 + 8 是24
  • FFFF 是 16
  • GGGG 是 8
  • HHHH 是 0
static unsigned int N_BUCKETS;
static unsigned char N_BITS;
void set_N_BUCKETS(unsigned int n)
{
    N_BUCKETS = n;
}
void set_N_BITS()
{
    N_BITS = log2_64(N_BUCKETS);
}
/* n == number of totally available buckets, so buckets = \{0, ...,, n-1\}
 * ASSUME n < (1 << 32)
 */
unsigned int bucket_number(uint64_t x)
{
    uint64_t mask111 = (1 << (N_BITS + 1)) - 1;
    uint64_t mask011 = (1 << (N_BITS)) - 1; /* one 1 less */

    unsigned char leq = ((x & mask111) < N_BUCKETS);
    /* leq (less or equal) is 0 or 1. */

    return (leq * (x & IIII)) +
           ((1 - leq) * ((x >> (N_BITS + 1)) & JJJJ));
    /* 'x >> (N_BITS + 1)' : take different set of bits -> better uniformity.
     * '... & mask011' guarantees that the result is less or equal N_BUCKETS.
     */
}

舉一個例子,假設 N_BUCKETS 是 256,N_BITS 就是 8,mask111 是 511,將這個數字對 x 做位元且可能會超出桶數。
leq 為 1 時,bucket_number 應輸出正常的桶號,也就是 x & mask111,故 IIIImask111
反之,根據註解,應當要取較少的位元,故 JJJJmask011

測驗 4

int ceil_log2(uint32_t x)
{
    uint32_t r, shift;

    x--;
    r = (x > 0xFFFF) << 4;                 
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (KKKK) + 1;       
}

因為

log2n=log2(n1)+1,先將程式碼改寫成方便理解的形式,如下。

inline int floor_log2(uint32_t x) { uint32_t r, shift; // 若 x 超過 16 位元,向右位移 16 個位元 r = (x > 0xFFFF) << 4; x >>= r; // 若 x 超過 8 位元,向右位移 8 個位元 shift = (x > 0xFF) << 3; x >>= shift; r |= shift; // 視為加法 // 若 x 超過 4 位元,向右位移 4 個位元 shift = (x > 0xF) << 2; x >>= shift; r |= shift; // 視為加法 // 若 x 超過 2 位元,向右位移 2 個位元 shift = (x > 0x3) << 1; x >>= shift; return (KKKK); } int ceil_log2(uint32_t x) { return ceil_log2(x - 1) + 1; }

經過 21 行的位移後,x 最多只剩 2 個位元,於是 KKKK 就是前一次的結果 r | shift 加上左邊數來第二個位元的結果 x >> 1
KKKKr | shift | x >> 1

2. 改進程式碼,使其得以處理 x = 0 的狀況,並仍是 branchless
程式碼會在 x-- 的時候出問題,最簡單的作法就是使用測驗 3 bucket_number 最後一行的作法,將 ceil_log2 改寫如下

inline int _floor_log2(uint32_t x)
{
    uint32_t r, shift;
    
    // 若 x 超過 16 位元,向右位移 16 個位元
    r = (x > 0xFFFF) << 4;                 
    x >>= r;
    
    // 若 x 超過 8 位元,向右位移 8 個位元
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift; // 視為加法
    
    // 若 x 超過 4 位元,向右位移 4 個位元
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift; // 視為加法
    
    // 若 x 超過 2 位元,向右位移 2 個位元
    shift = (x > 0x3) << 1;
    x >>= shift;
    return r | shift | x >> 1;       
}
int floor_log2(uint32_t x)
{
    unsigned char iszero = !x;
    return -iszero + (1 - iszero) * _floor_log2(x);
}
int ceil_log2(uint32_t x)
{
    unsigned char iszero = !x;
    return -iszero + (1 - iszero) * (_floor_log2(x - 1) + 1);
}