Try   HackMD

2021q1 第 2 週測驗題

tags: linux2021

目的: 檢驗學員對 C 語言指標操作, linked list, bitwise 的認知

題目解說錄影
作答表單

測驗 1

考慮以下仿效 Linux 核心 include/linux/list.h 的精簡實作:

#include <stddef.h>

/**
 * container_of() - Calculate address of object that contains address ptr
 * @ptr: pointer to member variable
 * @type: type of the structure containing ptr
 * @member: name of the member variable in struct @type
 *
 * Return: @type pointer of object containing ptr
 */
#ifndef container_of
#define container_of(ptr, type, member)                            \
    __extension__({                                                \
        const __typeof__(((type *) 0)->member) *__pmember = (ptr); \
        (type *) ((char *) __pmember - offsetof(type, member));    \
    })
#endif

/**
 * struct list_head - Head and node of a doubly-linked list
 * @prev: pointer to the previous node in the list
 * @next: pointer to the next node in the list
 */
struct list_head {
    struct list_head *prev, *next;
};

/**
 * LIST_HEAD - Declare list head and initialize it
 * @head: name of the new object
 */
#define LIST_HEAD(head) struct list_head head = {&(head), &(head)}

/**
 * INIT_LIST_HEAD() - Initialize empty list head
 * @head: pointer to list head
 */
static inline void INIT_LIST_HEAD(struct list_head *head)
{
    head->next = head; head->prev = head;
}

/**
 * list_add_tail() - Add a list node to the end of the list
 * @node: pointer to the new node
 * @head: pointer to the head of the list
 */
static inline void list_add_tail(struct list_head *node, struct list_head *head)
{
    struct list_head *prev = head->prev;

    prev->next = node;
    node->next = head;
    node->prev = prev;
    head->prev = node;
}

/**
 * list_del() - Remove a list node from the list
 * @node: pointer to the node
 */
static inline void list_del(struct list_head *node)
{
    struct list_head *next = node->next, *prev = node->prev;
    next->prev = prev; prev->next = next;
}

/**
 * list_empty() - Check if list head has no nodes attached
 * @head: pointer to the head of the list
 */
static inline int list_empty(const struct list_head *head)
{
    return (head->next == head);
}

/**
 * list_is_singular() - Check if list head has exactly one node attached
 * @head: pointer to the head of the list
 */
static inline int list_is_singular(const struct list_head *head)
{
    return (!list_empty(head) && head->prev == head->next);
}

/**
 * list_splice_tail() - Add list nodes from a list to end of another list
 * @list: pointer to the head of the list with the node entries
 * @head: pointer to the head of the list
 */
static inline void list_splice_tail(struct list_head *list,
                                    struct list_head *head)
{
    struct list_head *head_last = head->prev;
    struct list_head *list_first = list->next, *list_last = list->prev;

    if (list_empty(list))
        return;

    head->prev = list_last;
    list_last->next = head;

    list_first->prev = head_last;
    head_last->next = list_first;
}

/**
 * list_cut_position() - Move beginning of a list to another list
 * @head_to: pointer to the head of the list which receives nodes
 * @head_from: pointer to the head of the list
 * @node: pointer to the node in which defines the cutting point
 */
static inline void list_cut_position(struct list_head *head_to,
                                     struct list_head *head_from,
                                     struct list_head *node)
{
    struct list_head *head_from_first = head_from->next;

    if (list_empty(head_from))
        return;

    if (head_from == node) {
        INIT_LIST_HEAD(head_to);
        return;
    }

    head_from->next = node->next;
    head_from->next->prev = head_from;

    head_to->prev = node;
    node->next = head_to;
    head_to->next = head_from_first;
    head_to->next->prev = head_to;
}

/**
 * list_entry() - Calculate address of entry that contains list node
 * @node: pointer to list node
 * @type: type of the entry containing the list node
 * @member: name of the list_head member variable in struct @type
 */
#define list_entry(node, type, member) container_of(node, type, member)

/**
 * list_first_entry() - get first entry of the list
 * @head: pointer to the head of the list
 * @type: type of the entry containing the list node
 * @member: name of the list_head member variable in struct @type
 */
#define list_first_entry(head, type, member) \
    list_entry((head)->next, type, member)

/**
 * list_for_each - iterate over list nodes
 * @node: list_head pointer used as iterator
 * @head: pointer to the head of the list
 */
#define list_for_each(node, head) \
    for (node = (head)->next; node != (head); node = node->next)

上方程式碼的好處在於,只要 list_head 納入新的 C 結構的一個成員,即可操作,且不用自行維護一套 doubly-linked list 。

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

延伸閱讀: Linux 鏈結串列 struct list_head 研究

其中 GNU extension 的 typeof 是個 operator,用於回傳 object 的型別,搭配巨集使用,可在傳遞參數時接受多種型別,從而強化程式碼的彈性。示範:

#define max(a, b) \
    ({ typeof(a) _a = (a); \
      typeof(b) _b = (b); \
      _a > _b ? _a : _b; \
     })

C89/C99 的 offsetof 可接受給定成員的型態及成員的名稱,回傳「成員的位址減去 struct 的起始位址」

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

巨集 container_of 用途: 給定成員的位址、struct 的型態,及成員的名稱,container_of 會回傳此 struct 的位址,以下是示意圖

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

container_of 實作:

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

#define container_of(ptr, type, member)                            \
    __extension__({                                                \
        const __typeof__(((type *) 0)->member) *__pmember = (ptr); \
        (type *) ((char *) __pmember - offsetof(type, member));    \
    })

目的是從 struct 中的 member 推算出原本 struct 的位址。解析:

  • 先透過 __typeof__ 得到 type 中的成員 member 的型別,並產生一個 pointer to 給該型別的 __pmember
  • ptr assign 給 __pmember
  • __pmember 目前指向的是 member 的位址。
  • offsetof(type, member) 可以算出 membertype 這個 struct 中的位移量, offset 。
  • 將絕對位址 (char *) __pmember 減去 offsetof(type, member) ,可以得到 struct 的起始位址。
  • 最後 (type *) 再將起始位置轉型為 pointer to type

延伸閱讀: 你所不知道的 C 語言 : 指標篇

接著我們延伸上述程式碼來實作 doubly-linked list 的 merge sort:

#include <string.h>

typedef struct __element {
    char *value;
    struct __element *next;
    struct list_head list;
} list_ele_t;

typedef struct {
    list_ele_t *head; /* Linked list of elements */
    list_ele_t *tail;
    size_t size;
    struct list_head list;
} queue_t;

static list_ele_t *get_middle(struct list_head *list)
{
    struct list_head *fast = list->next, *slow;
    list_for_each (slow, list) {
        if (COND1 || COND2)
            break;
        fast = fast->next->next;
    }
    return list_entry(TTT, list_ele_t, list);
}

static void list_merge(struct list_head *lhs,
                       struct list_head *rhs,
                       struct list_head *head)
{
    INIT_LIST_HEAD(head);
    if (list_empty(lhs)) {
        list_splice_tail(lhs, head);
        return;
    }
    if (list_empty(rhs)) {
        list_splice_tail(rhs, head);
        return;
    }

    while (!list_empty(lhs) && !list_empty(rhs)) {
        char *lv = list_entry(lhs->next, list_ele_t, list)->value;
        char *rv = list_entry(rhs->next, list_ele_t, list)->value;
        struct list_head *tmp = strcmp(lv, rv) <= 0 ? lhs->next : rhs->next;
        list_del(tmp);
        list_add_tail(tmp, head);
    }
    list_splice_tail(list_empty(lhs) ? rhs : lhs, head);
}

void list_merge_sort(queue_t *q)
{
    if (list_is_singular(&q->list))
        return;

    queue_t left;
    struct list_head sorted;
    INIT_LIST_HEAD(&left.list);
    list_cut_position(&left.list, &q->list, MMM);
    list_merge_sort(&left);
    list_merge_sort(q);
    list_merge(&left.list, &q->list, &sorted);
    INIT_LIST_HEAD(&q->list);
    list_splice_tail(&sorted, &q->list);
}

搭配的測試程式碼如下:

#include <assert.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

static bool validate(queue_t *q)
{
    struct list_head *node;
    list_for_each (node, &q->list) {
        if (node->next == &q->list)
            break;
        if (strcmp(list_entry(node, list_ele_t, list)->value,
                   list_entry(node->next, list_ele_t, list)->value) > 0)
            return false;
    }
    return true;
}

static queue_t *q_new()
{
    queue_t *q = malloc(sizeof(queue_t));
    if (!q) return NULL;

    q->head = q->tail = NULL;
    q->size = 0;
    INIT_LIST_HEAD(&q->list);
    return q;
}

static void q_free(queue_t *q)
{
    if (!q) return;

    list_ele_t *current = q->head;
    while (current) {
        list_ele_t *tmp = current;
        current = current->next;
        free(tmp->value);
        free(tmp);
    }
    free(q);
}

bool q_insert_head(queue_t *q, char *s)
{
    if (!q) return false;

    list_ele_t *newh = malloc(sizeof(list_ele_t));
    if (!newh)
        return false;

    char *new_value = strdup(s);
    if (!new_value) {
        free(newh);
        return false;
    }

    newh->value = new_value;
    newh->next = q->head;
    q->head = newh;
    if (q->size == 0)
        q->tail = newh;
    q->size++;
    list_add_tail(&newh->list, &q->list);

    return true;
}

int main(void)
{
    FILE *fp = fopen("cities.txt", "r");
    if (!fp) {
        perror("failed to open cities.txt");
        exit(EXIT_FAILURE);
    }

    queue_t *q = q_new();
    char buf[256];
    while (fgets(buf, 256, fp))
        q_insert_head(q, buf);
    fclose(fp);

    list_merge_sort(q);
    assert(validate(q));

    q_free(q);

    return 0;
}

其中 cities.txt 取自 dict/cities.txt,內含世界超過 9 萬個都市的名稱。

可用以下命令來確認:

$ wc -l cities.txt 

預期輸出 93827

請補完程式碼,使其運作符合預期。

作答區

COND1 = ?

  • (a) slow->next == list
  • (b) slow->next != list
  • (c) fast->next == list
  • (d) fast->next != list
  • (e) slow == list
  • (f) slow != list
  • (g) fast == list
  • (h) fast != list

COND2 = ?

  • (a) slow->next->next == list
  • (b) fast->next->next == list
  • (c) !list
  • (d) fast != slow

MMM = ?

  • (a) get_middle(&q->list)
  • (b) get_middle(&(q->list))
  • (c) get_middle(q)
  • (d) get_middle(&q)
  • (e) &get_middle(&q->list)->list

TTT = ?

  • (a) slow
  • (b) fast

延伸問題:

  1. 解釋上述程式碼運作原理,指出改進空間並著手實作
  2. 研讀 Linux 核心的 lib/list_sort.c 原始程式碼,學習 sysprog21/linux-list 的手法,將 lib/list_sort.c 的實作抽離為可單獨執行 (standalone) 的使用層級應用程式,並設計效能評比的程式碼,說明 Linux 核心的 lib/list_sort.c 最佳化手法
  3. 將你在 quiz1 提出的改進實作,和 Linux 核心的 lib/list_sort.c 進行效能比較,需要提出涵蓋足夠廣泛的 benchmark

測驗 2

考慮函式 func 接受一個 16 位元無號整數

N,並回傳小於或等於
N
的 power-of-2 (漢字可寫作為 2 的冪)

女星楊冪的名字裡,也有一個「冪」字。且,她取這個名字,就有「次方」的意義,因為她一家三口都姓楊,所以是「楊的三次方」。

實作程式碼如下:

uint16_t func(uint16_t N) {
    /* change all right side bits to 1 */
    N |= N >> 1;
    N |= N >> X;
    N |= N >> Y;
    N |= N >> Z;

    return (N + 1) >> 1;
}

假定 N = 101012 = 2110,那麼 func(21) 要得到 16,請補完程式碼,使其運作符合預期。

作答區

X = ?

  • (a) 1
  • (b) 2
  • (c) 3
  • (d) 4
  • (e) 5
  • (f) 6
  • (g) 7
  • (h) 8
  • (i) 0

Y = ?

  • (a) 1
  • (b) 2
  • (c) 3
  • (d) 4
  • (e) 5
  • (f) 6
  • (g) 7
  • (h) 8
  • (i) 0

Z = ?

  • (a) 1
  • (b) 2
  • (c) 3
  • (d) 4
  • (e) 5
  • (f) 6
  • (g) 7
  • (h) 8
  • (i) 0

延伸問題:

  1. 解釋上述程式碼運作原理
  2. The Linux Kernel API 頁面搜尋 "power of 2",可見 is_power_of_2,查閱 Linux 核心原始程式碼,舉例說明其作用和考量,特別是 round up/down power of 2
    • 特別留意 __roundup_pow_of_two__rounddown_pow_of_two 的實作機制
  3. 研讀 slab allocator,探索 Linux 核心的 slab 實作,找出運用 power-of-2 的程式碼和解讀

測驗 3

考慮到一個 bitcpy 實作,允許開發者對指定的記憶體區域,逐 bit 進行資料複製,這個將指定的位元偏移量及位元數複製到目標位址的函式原型宣告如下:

void bitcpy(void *_dest,      /* Address of the buffer to write to */
            size_t _write,    /* Bit offset to start writing to */
            const void *_src, /* Address of the buffer to read from */
            size_t _read,     /* Bit offset to start reading from */
            size_t count)
  • input/_src: 長度為 8 個 uint8 陣列 (總共 64 位元)。注意: 其位元順序布局由每個位元組的 MSB (Most Significant Bit) 往上增加,如下圖一所示。
  • output/_dest: 長度為 8 個 uint8 陣列 (總共 64 位元),其位元順序布局如 input/_dest 所述。
  • _write: 從 _dest 的第 _write 個位元開始寫入 _count 位元數。
  • _read: 從 _src 的第 _read 個位元開始讀取 _count 位元數。
  • count: 讀取/寫入的位元數。

input/output 變數位元順序示意圖

我們可指定以下變數:

size_t read_lhs = _read & 7;
size_t read_rhs = 8 - read_lhs;

_read & 7 作用等同於 _read % 8

read_lhs > 0,表示起始位元沒有對齊位元組的 MSB,再者,宣告
可分為以下 2 種情況:

  • bitsize <= read_rhs : 欲讀取位元數,不需跨兩個位元組
  • bitsize > read_rhs : 欲讀取位元數,需跨兩個位元組

bitcpy 實作程式碼如下:

#include <stdint.h>

void bitcpy(void *_dest,      /* Address of the buffer to write to */
            size_t _write,    /* Bit offset to start writing to */
            const void *_src, /* Address of the buffer to read from */
            size_t _read,     /* Bit offset to start reading from */
            size_t count)
{
    size_t read_lhs = _read & 7;
    size_t read_rhs = 8 - read_lhs;
    const uint8_t *source = (const uint8_t *) _src + (_read / 8);
    size_t write_lhs = _write & 7;
    size_t write_rhs = 8 - write_lhs;
    uint8_t *dest = (uint8_t *) _dest + (_write / 8);

    static const uint8_t read_mask[] = {
        0x00, /*    == 0    00000000b   */
        0x80, /*    == 1    10000000b   */
        0xC0, /*    == 2    11000000b   */
        0xE0, /*    == 3    11100000b   */
        0xF0, /*    == 4    11110000b   */
        0xF8, /*    == 5    11111000b   */
        0xFC, /*    == 6    11111100b   */
        0xFE, /*    == 7    11111110b   */
        0xFF  /*    == 8    11111111b   */
    };

    static const uint8_t write_mask[] = {
        0xFF, /*    == 0    11111111b   */
        0x7F, /*    == 1    01111111b   */
        0x3F, /*    == 2    00111111b   */
        0x1F, /*    == 3    00011111b   */
        0x0F, /*    == 4    00001111b   */
        0x07, /*    == 5    00000111b   */
        0x03, /*    == 6    00000011b   */
        0x01, /*    == 7    00000001b   */
        0x00  /*    == 8    00000000b   */
    };

    while (count > 0) {
        uint8_t data = *source++;
        size_t bitsize = (count > 8) ? 8 : count;
        if (read_lhs > 0) {
            RRR;
            if (bitsize > read_rhs)
                data |= (*source >> read_rhs);
        }

        if (bitsize < 8)
            data &= read_mask[bitsize];

        uint8_t original = *dest;
        uint8_t mask = read_mask[write_lhs];
        if (bitsize > write_rhs) {
            /* Cross multiple bytes */
            *dest++ = (original & mask) | (data >> write_lhs);
            original = *dest & write_mask[bitsize - write_rhs];
            *dest = original | (data << write_rhs);
        } else {
            // Since write_lhs + bitsize is never >= 8, no out-of-bound access.
            DDD;
            *dest++ = (original & mask) | (data >> write_lhs);
        }

        count -= bitsize;
    }
}

搭配的測試程式碼:

#include <stdio.h>
#include <string.h>

static uint8_t output[8], input[8];

static inline void dump_8bits(uint8_t _data)
{   
    for (int i = 0; i < 8; ++i)
        printf("%d", (_data & (0x80 >> i)) ? 1 : 0);
}

static inline void dump_binary(uint8_t *_buffer, size_t _length)
{   
    for (int i = 0; i < _length; ++i)
        dump_8bits(*_buffer++);
}

int main(int _argc, char **_argv)
{
    memset(&input[0], 0xFF, sizeof(input));

    for (int i = 1; i <= 32; ++i) {
        for (int j = 0; j < 16; ++j) {
            for (int k = 0; k < 16; ++k) {
                memset(&output[0], 0x00, sizeof(output));
                printf("%2d:%2d:%2d ", i, k, j);
                bitcpy(&output[0], k, &input[0], j, i);
                dump_binary(&output[0], 8);
                printf("\n");
            }
        }
    }

    return 0;
}

測試程式碼的參考輸出:

 1: 0: 0 1000000000000000000000000000000000000000000000000000000000000000
 1: 1: 0 0100000000000000000000000000000000000000000000000000000000000000
 1: 2: 0 0010000000000000000000000000000000000000000000000000000000000000                                                                                     
 ...
 1:15: 0 0000000000000001000000000000000000000000000000000000000000000000
 1: 0: 1 1000000000000000000000000000000000000000000000000000000000000000
 1: 1: 1 0100000000000000000000000000000000000000000000000000000000000000
 1: 2: 1 0010000000000000000000000000000000000000000000000000000000000000
 ...
 1:15: 1 0000000000000001000000000000000000000000000000000000000000000000
 ...

請補完程式碼,使其運作符合預期。

作答區

RRR = ?

  • (a) data <<= read_lhs
  • (b) data &= read_lhs
  • (c) data |= read_lhs
  • (d) data >>= read_lhs

DDD = ?

  • (a) mask |= write_mask[bitsize - write_lhs]
  • (b) mask |= write_mask[write_lhs + bitsize]

Linux 核心原始程式碼定義 bitmap,後者為 unsigned long 陣列 (詳見: include/linux/types.h)。其中,bitmap_set() 設定某一特定區間位元全為 1,程式碼如下:

  1. __builtin_constant_p() 編譯器優化: 此 GCC 內建函式判斷其參數值是否在編譯時期就已經知道。如果是,則回傳 1 (代表編譯器可做 constant folding)。bitmap_set 有兩個相關優化:
    • nbits = 1: 底下範例程式碼將數值 1 傳給 nbits,所以 if (__builtin_constant_p(nbits) && nbits == 1) 條件成立。
    ​​​​/* Source: arch/mips/kernel/smp-cps.c */
    ​​​​static void boot_core(unsigned int core, unsigned int vpe_id)
    ​​​​{
    ​​​​    ...
    ​​​​    /* The core is now powered up */
    ​​​​    bitmap_set(core_power, core, 1);
    ​​​​}
    
    • startnbits 對齊 8 個位元組 (bitmap 為 unsigned long 陣列,所以最小單位為 unsigned long): 此狀況可呼叫 memset。
    ​​​​/* Source: include/linux/bitmap.h */
    ​​​​static __always_inline void bitmap_set(unsigned long *map, unsigned int start,
    ​​​​                unsigned int nbits)
    ​​​​{
    ​​​​        if (__builtin_constant_p(nbits) && nbits == 1)
    ​​​​                __set_bit(start, map);
    ​​​​        else if (__builtin_constant_p(start & BITMAP_MEM_MASK) &&
    ​​​​             IS_ALIGNED(start, BITMAP_MEM_ALIGNMENT) &&
    ​​​​                 __builtin_constant_p(nbits & BITMAP_MEM_MASK) &&
    ​​​​                 IS_ALIGNED(nbits, BITMAP_MEM_ALIGNMENT))
    ​​​​                memset((char *)map + start / 8, 0xff, nbits / 8);
    ​​​​        else
    ​​​​                __bitmap_set(map, start, nbits);
    ​​​​}
    
  2. 呼叫 __bitmap_set(): 此函式依據參數 startlen 設定其對應的位元為 1。說明如下:
    • p: 依據起始位元 (start) 找出對應的 bitmap 陣列元素。
    • size: 紀錄最後一個 bitmap 陣列元素中,欲設定的位元總數。
    • bits_to_set: 此次所設定 bitmap 陣列元素的位元總數。
    • mask_to_set: 此次所設定 bitmap 陣列元素的位元遮罩。
    • while迴圈: 此迴圈設定每一個陣列元素的所有位元為 1 (當所欲設定的位元總數超過 sizeof (unsigned long) ,也就是超過 64 個位元。
    • 最後一個 bitmap 陣列元素設定: 依據 startlen 找出對應的 mask_to_set
      ​​​​​​​​/* Source: lib/bitmap.c */
      ​​​​​​​​void __bitmap_set(unsigned long *map, unsigned int start, int len)
      ​​​​​​​​{
      ​​​​​​​​    unsigned long *p = map + BIT_WORD(start);
      ​​​​​​​​    const unsigned int size = start + len;
      ​​​​​​​​    int bits_to_set = BITS_PER_LONG - (start % BITS_PER_LONG);
      ​​​​​​​​    unsigned long mask_to_set = BITMAP_FIRST_WORD_MASK(start);
      
      ​​​​​​​​    while (len - bits_to_set >= 0) {
      ​​​​​​​​        *p |= mask_to_set;
      ​​​​​​​​        len -= bits_to_set;
      ​​​​​​​​        bits_to_set = BITS_PER_LONG;
      ​​​​​​​​        mask_to_set = ~0UL;
      ​​​​​​​​        p++;
      ​​​​​​​​    }
      ​​​​​​​​    if (len) {
      ​​​​​​​​        mask_to_set &= BITMAP_LAST_WORD_MASK(size);
      ​​​​​​​​        *p |= mask_to_set;
      ​​​​​​​​    }
      ​​​​​​​​}
      

延伸問題:

  1. 解釋上述程式碼運作原理,並嘗試重寫為同樣功能但效率更高的實作
  2. 在 Linux 核心原始程式碼找出逐 bit 進行資料複製的程式碼,並解說相對應的情境

測驗 4

預計藉由探討 string interning 來討論 "immutable" 的概念,為日後學習 Rust 程式語言做準備。儘管 C 語言沒有在語言層級支援 string interning,但用 C 語言開發的 Linux 核心卻有頗多地方實作 CoW (copy-on-write),從而改進記憶體使用效率。

"intern" 這詞在許多人的認知是公司的實習生,不過原本的意思是「拘留」和「扣押」,對於公司來說,將在校生「關在」公司做符合商業利益的事,這形式就是 "intern"。上述的 "string interning" 可翻譯為「字串駐留」,這種最佳化手段可將某些出現多處的字串,簡約為單一儲存空間,換言之,實際存取字串時,並非副本,而是指標或記憶體參照 (reference),例如在 Python 虛擬機器中,就提供 string interning 的實作。

為了用 C 語言實作 string interning,我們會需要有效的 hash,讓字串轉為數值,而且不能有碰撞,再者是 cache 機制,搭配針對小字串的記憶體管理器 —— 這幾個要素恰好都在 Linux 核心內部出現。

以下是 string interning 的簡易實作:

  • cstr.h
#pragma once
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

enum {
    CSTR_PERMANENT = 1,
    CSTR_INTERNING = 2,
    CSTR_ONSTACK = 4,
};

#define CSTR_INTERNING_SIZE (32)
#define CSTR_STACK_SIZE (128)

typedef struct __cstr_data {
    char *cstr;
    uint32_t hash_size;
    uint16_t type;
    uint16_t ref;
} * cstring;

typedef struct __cstr_buffer {
    cstring str;
} cstr_buffer[1];

#define CSTR_S(s) ((s)->str)

#define CSTR_BUFFER(var)                                                      \
    char var##_cstring[CSTR_STACK_SIZE] = {0};                                \
    struct __cstr_data var##_cstr_data = {var##_cstring, 0, CSTR_ONSTACK, 0}; \
    cstr_buffer var;                                                          \
    var->str = &var##_cstr_data;

#define CSTR_LITERAL(var, cstr)                                               \
    static cstring var = NULL;                                                \
    if (!var) {                                                               \
        cstring tmp = cstr_clone("" cstr, (sizeof(cstr) / sizeof(char)) - 1); \
        if (tmp->type == 0) {                                                 \
            tmp->type = CSTR_PERMANENT;                                       \
            tmp->ref = 0;                                                     \
        }                                                                     \
        if (!__sync_bool_compare_and_swap(&var, NULL, tmp)) {                 \
            if (tmp->type == CSTR_PERMANENT)                                  \
                free(tmp);                                                    \
        }                                                                     \
    }

#define CSTR_CLOSE(var)               \
    do {                              \
        if (!(var)->str->type)        \
            cstr_release((var)->str); \
    } while (0)

/* Public API */
cstring cstr_grab(cstring s);
cstring cstr_clone(const char *cstr, size_t sz);
cstring cstr_cat(cstr_buffer sb, const char *str);
int cstr_equal(cstring a, cstring b);
void cstr_release(cstring s);
  • cstr.c
#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <string.h>

#include "cstr.h"

#define INTERNING_POOL_SIZE 1024

#define HASH_START_SIZE 16 /* must be power of 2 */

struct __cstr_node {
    char buffer[CSTR_INTERNING_SIZE];
    struct __cstr_data str;
    struct __cstr_node *next;
};

struct __cstr_pool {
    struct __cstr_node node[INTERNING_POOL_SIZE];
};

struct __cstr_interning {
    int lock;
    int index;
    unsigned size;
    unsigned total;
    struct __cstr_node **hash;
    struct __cstr_pool *pool;
};

static struct __cstr_interning __cstr_ctx;

/* FIXME: use C11 atomics */
#define CSTR_LOCK()                                               \
    ({                                                            \
        while (__sync_lock_test_and_set(&(__cstr_ctx.lock), 1)) { \
        }                                                         \
    })
#define CSTR_UNLOCK() ({ __sync_lock_release(&(__cstr_ctx.lock)); })

static void *xalloc(size_t n)
{
    void *m = malloc(n);
    if (!m)
        exit(-1);
    return m;
}

static inline void insert_node(struct __cstr_node **hash,
                               int sz,
                               struct __cstr_node *node)
{
    uint32_t h = node->str.hash_size;
    int index = h & (sz - 1);
    node->next = hash[index];
    hash[index] = node;
}

static void expand(struct __cstr_interning *si)
{
    unsigned new_size = si->size * 2;
    if (new_size < HASH_START_SIZE)
        new_size = HASH_START_SIZE;

    struct __cstr_node **new_hash =
        xalloc(sizeof(struct __cstr_node *) * new_size);
    memset(new_hash, 0, sizeof(struct __cstr_node *) * new_size);

    for (unsigned i = 0; i < si->size; ++i) {
        struct __cstr_node *node = si->hash[i];
        while (node) {
            struct __cstr_node *tmp = node->next;
            insert_node(new_hash, new_size, node);
            node = tmp;
        }
    }

    free(si->hash);
    si->hash = new_hash;
    si->size = new_size;
}

static cstring interning(struct __cstr_interning *si,
                         const char *cstr,
                         size_t sz,
                         uint32_t hash)
{
    if (!si->hash)
        return NULL;

    int index = (int) (hash & (si->size - 1));
    struct __cstr_node *n = si->hash[index];
    while (n) {
        if (n->str.hash_size == hash) {
            if (!strcmp(n->str.cstr, cstr))
                return &n->str;
        }
        n = n->next;
    }
    // 80% (4/5) threshold
    if (si->total * 5 >= si->size * 4)
        return NULL;
    if (!si->pool) {
        si->pool = xalloc(sizeof(struct __cstr_pool));
        si->index = 0;
    }
    n = &si->pool->node[si->index++];
    memcpy(n->buffer, cstr, sz);
    n->buffer[sz] = 0;

    cstring cs = &n->str;
    cs->cstr = n->buffer;
    cs->hash_size = hash;
    cs->type = CSTR_INTERNING;
    cs->ref = 0;

    n->next = si->hash[index];
    si->hash[index] = n;

    return cs;
}

static cstring cstr_interning(const char *cstr, size_t sz, uint32_t hash)
{
    cstring ret;
    CSTR_LOCK();
    ret = interning(&__cstr_ctx, cstr, sz, hash);
    if (!ret) {
        expand(&__cstr_ctx);
        ret = interning(&__cstr_ctx, cstr, sz, hash);
    }
    ++__cstr_ctx.total;
    CSTR_UNLOCK();
    return ret;
}

static inline uint32_t hash_blob(const char *buffer, size_t len)
{
    const uint8_t *ptr = (const uint8_t *) buffer;
    size_t h = len;
    size_t step = (len >> 5) + 1;
    for (size_t i = len; i >= step; i -= step)
        h = h ^ ((h << 5) + (h >> 2) + ptr[i - 1]);
    return h == 0 ? 1 : h;
}

cstring cstr_clone(const char *cstr, size_t sz)
{
    if (sz < CSTR_INTERNING_SIZE)
        return cstr_interning(cstr, sz, hash_blob(cstr, sz));
    cstring p = xalloc(sizeof(struct __cstr_data) + sz + 1);
    if (!p)
        return NULL;
    void *ptr = (void *) (p + 1);
    p->cstr = ptr;
    p->type = 0;
    p->ref = 1;
    memcpy(ptr, cstr, sz);
    ((char *) ptr)[sz] = 0;
    p->hash_size = 0;
    return p;
}

cstring cstr_grab(cstring s)
{
    if (s->type & (CSTR_PERMANENT | CSTR_INTERNING))
        return s;
    if (s->type == CSTR_ONSTACK)
        return cstr_clone(s->cstr, s->hash_size);
    if (s->ref == 0)
        s->type = CSTR_PERMANENT;
    else
        __sync_add_and_fetch(&s->ref, 1);
    return s;
}

void cstr_release(cstring s)
{
    if (s->type || !s->ref)
        return;
    if (__sync_sub_and_fetch(&s->ref, 1) == 0)
        free(s);
}

static size_t cstr_hash(cstring s)
{
    if (s->type == CSTR_ONSTACK)
        return hash_blob(s->cstr, s->hash_size);
    if (s->hash_size == 0)
        s->hash_size = hash_blob(s->cstr, strlen(s->cstr));
    return s->hash_size;
}

int cstr_equal(cstring a, cstring b)
{
    if (a == b)
        return 1;
    if ((a->type == CSTR_INTERNING) && (b->type == CSTR_INTERNING))
        return 0;
    if ((a->type == CSTR_ONSTACK) && (b->type == CSTR_ONSTACK)) {
        if (a->hash_size != b->hash_size)
            return 0;
        return memcmp(a->cstr, b->cstr, a->hash_size) == 0;
    }
    uint32_t hasha = cstr_hash(a);
    uint32_t hashb = cstr_hash(b);
    if (hasha != hashb)
        return 0;
    return !strcmp(a->cstr, b->cstr);
}

static cstring cstr_cat2(const char *a, const char *b)
{
    size_t sa = strlen(a), sb = strlen(b);
    if (sa + sb < CSTR_INTERNING_SIZE) {
        char tmp[CSTR_INTERNING_SIZE];
        memcpy(tmp, a, sa);
        memcpy(tmp + sa, b, sb);
        tmp[sa + sb] = 0;
        return cstr_interning(tmp, sa + sb, hash_blob(tmp, sa + sb));
    }
    cstring p = xalloc(sizeof(struct __cstr_data) + sa + sb + 1);
    if (!p)
        return NULL;

    char *ptr = (char *) (p + 1);
    p->cstr = ptr;
    p->type = 0;
    p->ref = 1;
    memcpy(ptr, a, sa);
    memcpy(ptr + sa, b, sb);
    ptr[sa + sb] = 0;
    p->hash_size = 0;
    return p;
}

cstring cstr_cat(cstr_buffer sb, const char *str)
{
    cstring s = sb->str;
    if (s->type == CSTR_ONSTACK) {
        int i = CCC;
        while (i < CSTR_STACK_SIZE - 1) {
            s->cstr[i] = *str;
            if (*str == 0)
                return s;
            ++s->hash_size;
            ++str;
            ++i;
        }
        s->cstr[i] = 0;
    }
    cstring tmp = s;
    sb->str = cstr_cat2(tmp->cstr, str);
    cstr_release(tmp);
    return sb->str;
}

對應的測試程式:

#include <stdio.h>

#include "cstr.h"

static cstring cmp(cstring t)
{
    CSTR_LITERAL(hello, "Hello string");
    CSTR_BUFFER(ret);
    cstr_cat(ret, cstr_equal(hello, t) ? "equal" : "not equal");
    return cstr_grab(CSTR_S(ret));
}

static void test_cstr()
{
    CSTR_BUFFER(a);
    cstr_cat(a, "Hello ");
    cstr_cat(a, "string");
    cstring b = cmp(CSTR_S(a));
    printf("%s\n", b->cstr);
    CSTR_CLOSE(a);
    cstr_release(b);
}

int main(int argc, char *argv[])
{
    test_cstr();
    return 0;
}

預期執行結果為:

equal

作答區

CCC = ?

  • (a) strlen(str) + 1
  • (b) strlen(str) * 2
  • (c) s->hash_size
  • (d) s->size
  • (e) strlen(s->cstr) + 1

延伸問題:

  1. 解釋上述程式碼運作原理
  2. 上述程式碼使用到 gcc atomics builtins,請透過 POSIX Thread 在 GNU/Linux 驗證多執行緒的環境中,string interning 能否符合預期地運作?若不能,請提出解決方案
  3. chriso/intern 是個更好的 string interning 實作,請探討其特性和設計技巧,並透過內建的 benchmark 分析其效能表現