
Studying the Linux kernel source code: list_sort.c

contributed by < Risheng1128 >

Linux2022 development log (lab0)
Linux2023 development log (lab0)

The Linux kernel's merge sort is built from three functions: list_sort(), merge() and merge_final(). The sections below analyze what each of them does.

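Before analyzing the functions, look at their prototypes: every declaration carries the function attribute __attribute__((nonnull(...))). According to the __attribute__((nonnull)) function attribute documentation, nonnull marks parameters that must not be null pointers, so the compiler can warn when a call passes NULL for one of them. The numbers inside nonnull(...) are 1-based parameter positions; for example, nonnull(2,3) on list_sort() covers head and cmp.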

This function attribute specifies function parameters that are not supposed to be null pointers. This enables the compiler to generate a warning on encountering such a parameter.
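A minimal user-space sketch of the attribute, assuming GCC or Clang. count_char() is a hypothetical helper, not kernel code; its nonnull(1) marks the first parameter, so a call such as count_char(NULL, 'a') should draw a -Wnonnull warning when compiled with -Wall.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: parameter 1 (s) must never be NULL.
 * With -Wall, GCC/Clang warn about a call like count_char(NULL, 'a'). */
__attribute__((nonnull(1)))
static size_t count_char(const char *s, char c)
{
    size_t n = 0;

    assert(c != '\0'); /* the terminating NUL is never counted */
    for (; *s; s++)
        if (*s == c)
            n++;
    return n;
}
```

The attribute only informs the compiler; it adds no runtime check, so a NULL that reaches the function at runtime is still undefined behavior.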

list_sort()

The definition and parameters of list_sort() are:

__attribute__((nonnull(2,3))) void list_sort(void *priv, struct list_head *head, list_cmp_func_t cmp)
@priv: private data, opaque to list_sort(), passed to @cmp
@head: the list to sort
@cmp: the elements comparison function

Based on the source code, here is some analysis of cmp:

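  1. The type of cmp is list_cmp_func_t, defined in include/linux/list_sort.h. It is a function pointer to a function that returns int and takes a void * and two const struct list_head * arguments:
    typedef int __attribute__((nonnull(2,3))) (*list_cmp_func_t)(void *, const struct list_head *, const struct list_head *);

  2. cmp must return a value > 0 when a should be placed after b (for an ascending sort, return > 0 when a > b).
  3. cmp must return a value <= 0 when a should be placed before b, or when the two should keep their original order. Because equal elements keep their input order, list_sort() is a stable sort.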
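To make the return-value convention concrete, here is a user-space sketch of an ascending comparator with the list_cmp_func_t shape. The simplified struct list_head and container_of(), the element type struct item, and cmp_ascending() are assumptions for illustration, not the kernel headers.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified user-space stand-ins for the kernel definitions (assumptions). */
struct list_head {
    struct list_head *next, *prev;
};

#define container_of(ptr, type, member) \
    ((type *)((char *)(ptr) - offsetof(type, member)))

typedef int (*list_cmp_func_t)(void *, const struct list_head *,
                               const struct list_head *);

/* Hypothetical element type embedding a list node. */
struct item {
    int value;
    struct list_head node;
};

/* Ascending order: > 0 means a must go after b; <= 0 means a may stay
 * before b, so equal values keep their original order. */
static int cmp_ascending(void *priv, const struct list_head *a,
                         const struct list_head *b)
{
    const struct item *ia = container_of(a, struct item, node);
    const struct item *ib = container_of(b, struct item, node);

    (void)priv; /* no private data needed here */
    return ia->value > ib->value;
}
```

Returning the result of a > b directly (1 or 0) is enough: list_sort() only distinguishes > 0 from <= 0, so there is no need to separate the "less than" and "equal" cases.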

Next, let's look at what the merge sort actually looks like in the Linux kernel, reading the source together with earlier write-ups on it.

 * This mergesort is as eager as possible while always performing at least
 * 2:1 balanced merges.  Given two pending sublists of size 2^k, they are
 * merged to a size-2^(k+1) list as soon as we have 2^k following elements.
 *
 * Thus, it will avoid cache thrashing as long as 3*2^k elements can
 * fit into the cache.  Not quite as good as a fully-eager bottom-up
 * mergesort, but it does use 0.2*n fewer comparisons, so is faster in
 * the common case that everything fits into L1.

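The comment says this merge sort always performs merges that are at least 2:1 balanced: the two lists being merged never differ in size by more than a factor of 2. Given two pending lists of size 2^k, the sort does not merge them right away; it waits until 2^k more elements are pending after them (the comment's "2^k following elements"), then merges the pair into one list of size 2^(k+1). That leaves a 2^(k+1) list next to 2^k newer elements, which keeps the 2:1 ratio.
As long as those 3·2^k elements fit in the cache, the sort avoids cache thrashing (repeatedly evicting data that is about to be used again). Per the comment, this is not quite as cache-friendly as a fully eager bottom-up merge sort, but it uses about 0.2·n fewer comparisons, so it is faster in the common case where everything fits in L1.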

 * The merging is controlled by "count", the number of elements in the
 * pending lists.  This is beautifully simple code, but rather subtle.
 *
 * Each time we increment "count", we set one bit (bit k) and clear
 * bits k-1 .. 0.  Each time this happens (except the very first time
 * for each bit, when count increments to 2^k), we merge two lists of
 * size 2^k into one list of size 2^(k+1).
 *
 * This merge happens exactly when the count reaches an odd multiple of
 * 2^k, which is when we have 2^k elements pending in smaller lists,
 * so it's safe to merge away two lists of size 2^k.
 *
 * After this happens twice, we have created two lists of size 2^(k+1),
 * which will be merged into a list of size 2^(k+2) before we create
 * a third list of size 2^(k+1), so there are never more than two pending.
 *

The comment above can be summarized as follows:

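  1. The variable count holds the total number of elements in the pending lists.
  2. Each time count is incremented, bit k is set to 1 and bits k-1 through 0 are cleared to 0.
  3. Two lists of size 2^k are merged into one list of size 2^(k+1) when count reaches an odd multiple of 2^k (except the very first time bit k is set, when count becomes 2^k). After this happens twice there are two lists of size 2^(k+1), which are merged into a list of size 2^(k+2) before a third list of size 2^(k+1) is created, so at most two lists of any size are pending.
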
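The bit rule in point 2 can be checked directly. This sketch assumes the GCC/Clang builtin __builtin_ctz and uses hypothetical helper names: it computes k as the index of the least-significant clear bit of count, then verifies that count + 1 equals count with bit k set and bits k-1 .. 0 cleared.

```c
#include <assert.h>

/* Hypothetical helper: index k of the least-significant clear bit of count.
 * __builtin_ctz counts trailing zeros of ~count and is undefined for 0,
 * hence the check that count is not all ones. */
static unsigned lowest_clear_bit(unsigned count)
{
    assert(~count != 0u);
    return (unsigned)__builtin_ctz(~count);
}

/* Point 2 above: count + 1 sets bit k and clears bits k-1 .. 0. */
static int increment_matches_rule(unsigned count)
{
    unsigned k = lowest_clear_bit(count);
    unsigned expected = (count | (1u << k)) & ~((1u << k) - 1u);

    return count + 1u == expected;
}
```

For example, count = 11 (0b1011) has its lowest clear bit at k = 2, and 11 + 1 = 12 (0b1100) indeed sets bit 2 and clears bits 1 and 0.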
 * The number of pending lists of size 2^k is determined by the
 * state of bit k of "count" plus two extra pieces of information:
 *
 * - The state of bit k-1 (when k == 0, consider bit -1 always set), and
 * - Whether the higher-order bits are zero or non-zero (i.e.
 *   is count >= 2^(k+1)).
 *
 * There are six states we distinguish.  "x" represents some arbitrary
 * bits, and "y" represents some arbitrary non-zero bits:
 * 0:  00x: 0 pending of size 2^k;           x pending of sizes < 2^k
 * 1:  01x: 0 pending of size 2^k; 2^(k-1) + x pending of sizes < 2^k
 * 2: x10x: 0 pending of size 2^k; 2^k     + x pending of sizes < 2^k
 * 3: x11x: 1 pending of size 2^k; 2^(k-1) + x pending of sizes < 2^k
 * 4: y00x: 1 pending of size 2^k; 2^k     + x pending of sizes < 2^k
 * 5: y01x: 2 pending of size 2^k; 2^(k-1) + x pending of sizes < 2^k
 * (merge and loop back to state 2)

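I first tried to understand this from the comments alone, but the comment above is quite hard to follow, so I decided to read it together with the code.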

Before sorting, the Linux kernel breaks the circular linked list open, turning it into a NULL-terminated list:

/* Convert to a null-terminated singly-linked list. */
head->prev->next = NULL;

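Next, it looks for the least-significant clear bit of count. Each trailing 1 bit that is skipped moves tail one pending list further back (to an older list) through the prev pointers: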

/* Find the least-significant clear bit in count */
for (bits = count; bits & 1; bits >>= 1)
	tail = &(*tail)->prev;
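The loop does not touch the list contents; it only walks tail back one run per trailing 1 bit. The sketch below (hypothetical helper merged_run_size()) replays the same loop on count alone and reports the size of the run that this iteration's merge would create, or 0 when no merge happens.

```c
#include <assert.h>

/* Hypothetical helper: replay list_sort()'s scan of count and return the size
 * of the run created by this iteration's merge, or 0 if no merge happens. */
static unsigned merged_run_size(unsigned count)
{
    unsigned bits, size = 1; /* size of each of the two runs to be merged */

    for (bits = count; bits & 1; bits >>= 1)
        size <<= 1; /* list_sort() does tail = &(*tail)->prev here */
    return bits ? 2 * size : 0;
}
```

For instance, count = 5 (0b101) has one trailing 1, so two runs of size 2 become one of size 4, while count = 3 (0b11) ends with bits = 0 and merges nothing.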

Next comes the function-like macro likely(), defined in /include/linux/compiler.h:

if (likely(bits)) {

There are two definitions:

// First definition
# ifndef likely
#  define likely(x)	(__branch_check__(x, 1, __builtin_constant_p(x)))
# endif
// Second definition
# define likely(x)	__builtin_expect(!!(x), 1)

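Reading the logic of the whole header shows that the choice between the first and the second likely() depends on CONFIG_TRACE_BRANCH_PROFILING, DISABLE_BRANCH_PROFILING and __CHECKER__, as the condition below shows: the first definition is used only when branch profiling is configured and neither DISABLE_BRANCH_PROFILING nor __CHECKER__ is defined.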

#if defined(CONFIG_TRACE_BRANCH_PROFILING) \
    && !defined(DISABLE_BRANCH_PROFILING) && !defined(__CHECKER__)

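Starting with the first definition, it uses two things I had not seen before: the macro __branch_check__() and __builtin_constant_p(), a GCC built-in function that evaluates to 1 if the compiler can tell its argument is a compile-time constant and 0 otherwise. Let's look further.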

__branch_check__() is defined in /include/linux/compiler.h:

#define __branch_check__(x, expect, is_constant) ({			\
			long ______r;					\
			static struct ftrace_likely_data		\
				__aligned(4)				\
				__section("_ftrace_annotated_branch")	\
				______f = {				\
				.data.func = __func__,			\
				.data.file = __FILE__,			\
				.data.line = __LINE__,			\
			};						\
			______r = __builtin_expect(!!(x), expect);	\
			ftrace_likely_update(&______f, ______r,		\
					     expect, is_constant);	\
			______r;					\
		})

It is somewhat hard to read, but looking things up piece by piece still reveals a lot. First, __aligned and __section are defined in include/linux/compiler_attributes.h:

#define __aligned(x)          __attribute__((__aligned__(x)))
#define __section(section)    __attribute__((__section__(section)))

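Next come __func__, __FILE__ and __LINE__. __FILE__ and __LINE__ are standard predefined macros (see Standard Predefined Macros), while __func__ is not a macro but a predefined identifier holding the name of the enclosing function (see Function Names as Strings). Together with the attributes above, this means ______f is a static struct ftrace_likely_data, 4-byte aligned and placed in the _ftrace_annotated_branch section, that records the function, file and line of each likely() call site.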

  • __FILE__

This macro expands to the name of the current input file, in the form of a C string constant. This is the path by which the preprocessor opened the file, not the short name specified in ‘#include’ or as the input file name argument. For example, "/usr/local/include/myheader.h" is a possible expansion of this macro.

  • __LINE__

This macro expands to the current input line number, in the form of a decimal integer constant. While we call it a predefined macro, it’s a pretty strange macro, since its “definition” changes with each new line of source code.

Next is the important __builtin_expect(). It is a GCC built-in function rather than a macro; Other Built-in Functions Provided by GCC explains what it means:

  • __builtin_expect()

You may use __builtin_expect to provide the compiler with branch prediction information. In general, you should prefer to use actual profile feedback for this (-fprofile-arcs), as programmers are notoriously bad at predicting how their programs actually perform. However, there are applications in which this data is hard to collect.

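Finally, ftrace_likely_update() is implemented in kernel/trace/trace_branch.c. I do not yet fully understand its logic and purpose (marked for later), but the body at least shows that, per call site, it counts how often the actual value matched the expected one (f->data.correct) and how often it did not (f->data.incorrect); a constant condition is counted in f->constant and always treated as correct.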

ftrace_likely_update()
void ftrace_likely_update(struct ftrace_likely_data *f, int val,
			  int expect, int is_constant)
{
	unsigned long flags = user_access_save();

	/* A constant is always correct */
	if (is_constant) {
		f->constant++;
		val = expect;
	}
	/*
	 * I would love to have a trace point here instead, but the
	 * trace point code is so inundated with unlikely and likely
	 * conditions that the recursive nightmare that exists is too
	 * much to try to get working. At least for now.
	 */
	trace_likely_condition(f, val, expect);

	/* FIXME: Make this atomic! */
	if (val == expect)
		f->data.correct++;
	else
		f->data.incorrect++;

	user_access_restore(flags);
}
EXPORT_SYMBOL(ftrace_likely_update);

The second definition is much simpler: it only calls __builtin_expect(), the same call the first definition makes, without any profiling bookkeeping.

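From the analysis above, the main points of likely() are:

  1. __builtin_expect(!!(x), 1) tells the compiler that !!(x) is very likely to be 1, so the compiler can optimize the code layout for that case and speed up execution.
  2. In other words, !!(x) being 1 means x is non-zero; !! normalizes any non-zero value to exactly 1 so that it matches the expected value 1.
  3. Following the source logic, if bits is still non-zero after the trailing 1 bits have been shifted out (that is, after the least-significant clear bit has been found), a merge is performed, and the hint marks this as the common case.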
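A minimal user-space sketch of the non-profiling definitions, assuming GCC or Clang. It shows that !! collapses any non-zero value to 1, and uses likely() for the same decision list_sort() makes; should_merge() is a hypothetical helper.

```c
#include <assert.h>

/* User-space versions of the non-profiling definitions (GCC/Clang). */
#define likely(x)   __builtin_expect(!!(x), 1)
#define unlikely(x) __builtin_expect(!!(x), 0)

/* !! maps 0 to 0 and every non-zero value to exactly 1. */
static int normalize(long x)
{
    return !!x;
}

/* Hypothetical helper making the same decision as list_sort():
 * merge only if bits is still non-zero after the trailing 1s are shifted out. */
static int should_merge(unsigned long count)
{
    unsigned long bits;

    for (bits = count; bits & 1; bits >>= 1)
        ; /* skip the trailing 1 bits */
    if (likely(bits))
        return 1;
    return 0;
}
```

__builtin_expect() returns the value of its first argument, so likely() never changes the result of a condition; it only affects how the compiler arranges the branches.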

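Back in list_sort(), the likely() check above decides whether to merge. When it does, a = *tail is the newer of the two lists and b = a->prev is the older one; they are merged with b as the first argument (so equal elements keep their input order), the result inherits b->prev, and *tail = a puts it in place of the two inputs. The diagrams later explain this in more detail.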

if (likely(bits)) {
    struct list_head *a = *tail, *b = a->prev;

    a = merge(priv, cmp, b, a);
    /* Install the merged result in place of the inputs */
    a->prev = b->prev;
    *tail = a;
}

Then, each iteration moves one node from the input list to pending as a new list of length 1:

/* Move one element from input list to pending */
list->prev = pending;
pending = list;
list = list->next;
pending->next = NULL;
count++;
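Ignoring merges, the step above pushes each input node onto a stack threaded through prev. The sketch below (hypothetical push_all(), simplified struct list_head) repeats only that step, so every node becomes its own run of length 1 and pending ends up pointing at the last node.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified user-space stand-in for the kernel type (assumption). */
struct list_head {
    struct list_head *next, *prev;
};

/* Hypothetical helper: move every node of a NULL-terminated list onto
 * pending exactly as list_sort()'s loop does, but without merging.
 * prev links pending runs from newest to oldest; next is cut to NULL. */
static struct list_head *push_all(struct list_head *list, size_t *count)
{
    struct list_head *pending = NULL;

    *count = 0;
    while (list) {
        list->prev = pending;
        pending = list;
        list = list->next;
        pending->next = NULL;
        (*count)++;
    }
    return pending;
}
```

Reading the nodes back through pending->prev visits them in reverse input order, which is why the newest (smallest) runs are always the ones found first when tail walks backwards.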

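The process above can be illustrated with a series of diagrams; each one is labeled with count and shows the state at the start of that do-while iteration, or at the step noted in the text.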

  1. count = 0
    The initial state is shown below. Assume there are 5 nodes in total, where node1 is head.






[Figure: count = 0. head (node1) → node2 → node3 → node4 → node5 → NULL; pending = NULL; tail = &pending; list = node2]

  2. count = 1
    After the first iteration no merge has happened; the second node (node2, the first element after head) has simply been moved to the pending list, as shown below.






[Figure: count = 1. node2 is now in pending (node2->prev = NULL, node2->next = NULL); pending = node2; tail = &pending; list = node3 → node4 → node5 → NULL]
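In the code below, since count = 1 the loop body runs once, so tail ends up pointing to node2's prev (tail = &node2->prev), as shown below. bits then becomes 0, so no merge happens in this iteration.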

/* Find the least-significant clear bit in count */
for (bits = count; bits & 1; bits >>= 1)
    tail = &(*tail)->prev;






[Figure: count = 1, after the loop: same state as above, except tail = &node2->prev]
  3. count = 2
    After 2 iterations the pending list holds two nodes; node3 has been added to pending, as shown below.






[Figure: count = 2. pending = node3 (node3->prev = node2, node3->next = NULL); node2->prev = NULL, node2->next = NULL; tail is still drawn at &node2->prev from the previous iteration; list = node4 → node5 → NULL]
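Because tail is declared at the top of the loop body with the line below, it is reset on every iteration: the tail that pointed to node2->prev now points back to pending, as shown below: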

struct list_head **tail = &pending;






[Figure: count = 2, after tail is reset: same state as above, except tail = &pending]
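The same loop runs again, but count = 2 = 0b10 has a clear lowest bit, so the loop body does not execute and bits stays 2. Then if (likely(bits)) is evaluated; since bits is non-zero, a merge is performed with a pointing to node3 and b pointing to node2. The diagram below shows the result, assuming node2 sorts before node3 (so the order is unchanged): node2->next now points to node3, and because of *tail = a, pending points to node2, the head of the merged list, whose prev is NULL (copied from b->prev).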







[Figure: count = 2, after the merge (assuming node2 sorts before node3): pending = node2 → node3 → NULL, node2->prev = NULL; tail = &pending; list = node4]
  4. count = 3
    After three insertions into pending and one merge, the state is as shown below; the pending lists are now [2 (node2, node3), 1 (node4)].






[Figure: count = 3. pending = node4 (node4->next = NULL, node4->prev = node2, the head of the run node2 → node3); tail = &pending; list = node5 → NULL]
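  5. count = 4
    In the count = 3 iteration (count = 0b11, so bits becomes 0 and no merge happens), list was advanced to NULL, so the loop has now exited. The diagram below shows the last state before leaving the loop.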






[Figure: count = 4, after the loop. pending = node5 (node5->prev = node4, node5->next = NULL); node4->prev = node2; pending runs from newest to oldest: [node5], [node4], [node2 → node3]; list = NULL]
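As the diagrams show, every node is now in the pending lists. The next step merges the pending lists together, starting from the newest (smallest) one, until only two lists are left for the final merge: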

/* End of input; merge together all the pending lists. */
list = pending;
pending = pending->prev;
for (;;) {
	struct list_head *next = pending->prev;

	if (!next)
		break;
	list = merge(priv, cmp, pending, list);
	pending = next;
}

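The last step merges the two remaining lists and, while doing so, reconnects the prev pointers that were broken during sorting, restoring the circular doubly linked list: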

/* The final merge, rebuilding prev links */
merge_final(priv, cmp, head, pending, list);
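Putting the pieces together, here is a compact user-space port for experimentation. It follows the list_sort() and merge() code quoted in this note, but assumes a simplified struct list_head, drops the nonnull attributes, and replaces merge_final() with a simpler merge followed by one relinking pass (so it omits the periodic cmp(priv, b, b) callback). struct item and cmp_key() are hypothetical test types.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified user-space stand-ins for the kernel definitions (assumptions). */
struct list_head { struct list_head *next, *prev; };
typedef int (*list_cmp_func_t)(void *, const struct list_head *,
                               const struct list_head *);

/* Merge two NULL-terminated runs via next pointers only; ties take a. */
static struct list_head *merge(void *priv, list_cmp_func_t cmp,
                               struct list_head *a, struct list_head *b)
{
    struct list_head *head, **tail = &head;
    for (;;) {
        if (cmp(priv, a, b) <= 0) {
            *tail = a;
            tail = &a->next;
            a = a->next;
            if (!a) {
                *tail = b;
                break;
            }
        } else {
            *tail = b;
            tail = &b->next;
            b = b->next;
            if (!b) {
                *tail = a;
                break;
            }
        }
    }
    return head;
}

/* Simplified merge_final(): merge, then rebuild prev links and the circle. */
static void merge_final(void *priv, list_cmp_func_t cmp, struct list_head *head,
                        struct list_head *a, struct list_head *b)
{
    struct list_head *tail = head, *node = merge(priv, cmp, a, b);
    for (; node; node = node->next) {
        tail->next = node;
        node->prev = tail;
        tail = node;
    }
    tail->next = head;
    head->prev = tail;
}

static void list_sort(void *priv, struct list_head *head, list_cmp_func_t cmp)
{
    struct list_head *list = head->next, *pending = NULL;
    size_t count = 0;

    if (list == head->prev) /* zero or one element */
        return;
    head->prev->next = NULL; /* null-terminate the list */
    do {
        size_t bits;
        struct list_head **tail = &pending;
        /* Skip the trailing 1 bits of count, walking to older runs. */
        for (bits = count; bits & 1; bits >>= 1)
            tail = &(*tail)->prev;
        if (bits) { /* merge two runs of size 2^k; b is the older one */
            struct list_head *a = *tail, *b = a->prev;
            a = merge(priv, cmp, b, a);
            a->prev = b->prev;
            *tail = a;
        }
        /* Move one element from the input to pending as a run of size 1. */
        list->prev = pending;
        pending = list;
        list = list->next;
        pending->next = NULL;
        count++;
    } while (list);
    /* Merge pending runs from newest to oldest, leaving the oldest one. */
    list = pending;
    pending = pending->prev;
    for (;;) {
        struct list_head *next = pending->prev;
        if (!next)
            break;
        list = merge(priv, cmp, pending, list);
        pending = next;
    }
    merge_final(priv, cmp, head, pending, list);
}

/* Hypothetical payload: node must be the first member for the casts below. */
struct item { struct list_head node; int key, id; };

static int cmp_key(void *priv, const struct list_head *a, const struct list_head *b)
{
    (void)priv;
    return ((const struct item *)a)->key > ((const struct item *)b)->key;
}
```

Sorting keys {3, 1, 2, 3, 1, 2} should give 1, 1, 2, 2, 3, 3 with equal keys in their input order, and afterwards the list is circular and doubly linked again.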

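Finally, there is the macro EXPORT_SYMBOL() (it appears, for example, right after ftrace_likely_update() above). The article "Linux Driver Development: Using EXPORT_SYMBOL" explains its purpose and usage: it exports a kernel symbol so that loadable kernel modules can use it.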

To show the state of the merges more clearly, the table below tracks the pending lists, listing their sizes from oldest to newest:

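| state | merge | count | pending lists at the start of the iteration | pending lists at the end of the iteration |
| --- | --- | --- | --- | --- |
| 0 | no | 0b0000 (0) | NULL | [1] |
| 0 | no | 0b0001 (1) | [1] | [1,1] |
| 1 | yes | 0b0010 (2) | [1,1] | [2,1] |
| 1 | no | 0b0011 (3) | [2,1] | [2,1,1] |
| 2 | yes | 0b0100 (4) | [2,1,1] | [2,2,1] |
| 2 | yes | 0b0101 (5) | [2,2,1] | [4,1,1] |
| 3 | yes | 0b0110 (6) | [4,1,1] | [4,2,1] |
| 3 | no | 0b0111 (7) | [4,2,1] | [4,2,1,1] |
| 4 | yes | 0b1000 (8) | [4,2,1,1] | [4,2,2,1] |
| 4 | yes | 0b1001 (9) | [4,2,2,1] | [4,4,1,1] |
| 5 | yes | 0b1010 (10) | [4,4,1,1] | [4,4,2,1] |
| 5 | yes | 0b1011 (11) | [4,4,2,1] | [8,2,1,1] |
| 2 | yes | 0b1100 (12) | [8,2,1,1] | [8,2,2,1] |
| 2 | yes | 0b1101 (13) | [8,2,2,1] | [8,4,1,1] |
| 3 | yes | 0b1110 (14) | [8,4,1,1] | [8,4,2,1] |
| 3 | no | 0b1111 (15) | [8,4,2,1] | [8,4,2,1,1] |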
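The table can be reproduced by simulating only the run sizes. The sketch below (hypothetical simulate_pending()) mirrors the count/bits logic of list_sort() on an array of sizes, with index 0 as the newest run, so its output reads the table's lists from right to left.

```c
#include <assert.h>

/* Hypothetical helper: simulate only the sizes of list_sort()'s pending runs
 * after n elements have been pushed. sizes[0] is the newest run (what pending
 * points to); higher indexes are older runs reached through prev.
 * Returns the number of runs. */
static int simulate_pending(unsigned n, unsigned sizes[], int cap)
{
    int runs = 0;

    for (unsigned count = 0; count < n; count++) {
        unsigned bits;
        int idx = 0; /* index of the run that *tail refers to */

        for (bits = count; bits & 1; bits >>= 1)
            idx++; /* tail = &(*tail)->prev */
        if (bits) { /* merge run idx with the older run idx + 1 */
            sizes[idx] += sizes[idx + 1];
            for (int i = idx + 1; i < runs - 1; i++)
                sizes[i] = sizes[i + 1];
            runs--;
        }
        assert(runs < cap);
        for (int i = runs; i > 0; i--) /* push a new run of size 1 */
            sizes[i] = sizes[i - 1];
        sizes[0] = 1;
        runs++;
    }
    return runs;
}
```

For example, after 12 elements it yields 1, 1, 2, 8 (newest first), which matches [8,2,1,1] at the end of the count = 11 row.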

merge()

Here is the full code of merge():

merge()
__attribute__((nonnull(2,3,4)))
static struct list_head *merge(void *priv, list_cmp_func_t cmp,
				struct list_head *a, struct list_head *b)
{
	struct list_head *head, **tail = &head;

	for (;;) {
		/* if equal, take 'a' -- important for sort stability */
		if (cmp(priv, a, b) <= 0) {
			*tail = a;
			tail = &a->next;
			a = a->next;
			if (!a) {
				*tail = b;
				break;
			}
		} else {
			*tail = b;
			tail = &b->next;
			b = b->next;
			if (!b) {
				*tail = a;
				break;
			}
		}
	}
	return head;
}

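merge() is much simpler than list_sort(). It first declares head, a pointer to struct list_head, and tail, a pointer to a pointer initialized to point at head. Writing through *tail therefore appends the next node without special-casing an empty result list: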

struct list_head *head, **tail = &head;

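It then enters an infinite loop and uses cmp to decide whether the next node is taken from a or from b: when cmp returns ≤ 0 it takes a, and when it returns > 0 it takes b. Taking a on ties keeps the sort stable, because list_sort() always passes the list holding the earlier input elements as a.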

if (cmp(priv, a, b) <= 0) {
    *tail = a;
    tail = &a->next;
    a = a->next;
    if (!a) {
        *tail = b;
        break;
    }
} else {
    *tail = b;
    tail = &b->next;
    b = b->next;
    if (!b) {
        *tail = a;
        break;
    }
}

When list a runs out of nodes, the rest of list b is attached directly; likewise, when b runs out, the rest of a is attached:

if (!a) {
	*tail = b;
	break;
}

if (!b) {
	*tail = a;
	break;
}
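A user-space sketch of merge() with a simplified struct list_head shows why taking a on ties matters: with run a = keys {1, 3, 3} and run b = keys {2, 3}, the three elements with key 3 keep their relative input order (a's before b's). struct item and cmp_key() are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Simplified user-space stand-ins for the kernel definitions (assumptions). */
struct list_head {
    struct list_head *next, *prev;
};

typedef int (*list_cmp_func_t)(void *, const struct list_head *,
                               const struct list_head *);

/* Hypothetical payload: node is the first member so a cast recovers the item. */
struct item {
    struct list_head node;
    int key, id;
};

static int cmp_key(void *priv, const struct list_head *a,
                   const struct list_head *b)
{
    (void)priv;
    return ((const struct item *)a)->key > ((const struct item *)b)->key;
}

/* Same logic as the merge() above: both inputs are non-empty,
 * NULL-terminated runs linked by next; so is the result. */
static struct list_head *merge(void *priv, list_cmp_func_t cmp,
                               struct list_head *a, struct list_head *b)
{
    struct list_head *head, **tail = &head;

    for (;;) {
        if (cmp(priv, a, b) <= 0) { /* ties take a: keeps the sort stable */
            *tail = a;
            tail = &a->next;
            a = a->next;
            if (!a) {
                *tail = b;
                break;
            }
        } else {
            *tail = b;
            tail = &b->next;
            b = b->next;
            if (!b) {
                *tail = a;
                break;
            }
        }
    }
    return head;
}
```

With ids 0, 1, 2 for run a and 3, 4 for run b, the merged order is ids 0, 3, 1, 2, 4. If cmp() were allowed to favor b on ties, id 4 could jump ahead of ids 1 and 2.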

merge_final()

Here is the full code of merge_final():

merge_final()
__attribute__((nonnull(2,3,4,5)))
static void merge_final(void *priv, list_cmp_func_t cmp, struct list_head *head,
			struct list_head *a, struct list_head *b)
{
	struct list_head *tail = head;
	u8 count = 0;

	for (;;) {
		/* if equal, take 'a' -- important for sort stability */
		if (cmp(priv, a, b) <= 0) {
			tail->next = a;
			a->prev = tail;
			tail = a;
			a = a->next;
			if (!a)
				break;
		} else {
			tail->next = b;
			b->prev = tail;
			tail = b;
			b = b->next;
			if (!b) {
				b = a;
				break;
			}
		}
	}

	/* Finish linking remainder of list b on to tail */
	tail->next = b;
	do {
		/*
		 * If the merge is highly unbalanced (e.g. the input is
		 * already sorted), this loop may run many iterations.
		 * Continue callbacks to the client even though no
		 * element comparison is needed, so the client's cmp()
		 * routine can invoke cond_resched() periodically.
		 */
		if (unlikely(!++count))
			cmp(priv, b, b);
		b->prev = tail;
		tail = b;
		b = b->next;
	} while (b);

	/* And the final links to make a circular doubly-linked list */
	tail->next = head;
	head->prev = tail;
}

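Besides merging, merge_final() rebuilds the prev pointers: every node appended to the result has its prev set to the preceding node, and the loop below finishes linking the remainder of list b. Because count is a u8, !++count becomes true once every 256 iterations (when the counter wraps to 0), so cmp(priv, b, b) is called periodically even though no comparison is needed, giving the client's cmp() a chance to call cond_resched().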

/* Finish linking remainder of list b on to tail */
tail->next = b;
do {
    /*
     * If the merge is highly unbalanced (e.g. the input is
     * already sorted), this loop may run many iterations.
     * Continue callbacks to the client even though no
     * element comparison is needed, so the client's cmp()
     * routine can invoke cond_resched() periodically.
     */
    if (unlikely(!++count))
        cmp(priv, b, b);
    b->prev = tail;
    tail = b;
    b = b->next;
} while (b);

/* And the final links to make a circular doubly-linked list */
tail->next = head;
head->prev = tail;
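To see the relinking and the u8 counter in isolation, this sketch moves merge_final()'s tail loop into a hypothetical helper finish_links() (uint8_t stands in for the kernel's u8) and counts the extra cmp() callbacks with a hypothetical count_calls(): appending a 600-node remainder should trigger exactly two of them, at the 256th and 512th node.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Simplified user-space stand-ins for the kernel definitions (assumptions). */
struct list_head {
    struct list_head *next, *prev;
};

typedef int (*list_cmp_func_t)(void *, const struct list_head *,
                               const struct list_head *);

#define unlikely(x) __builtin_expect(!!(x), 0)

/* Hypothetical helper holding merge_final()'s tail loop: append the
 * NULL-terminated run b after tail, set every prev pointer, close the circle
 * through head, and call cmp(priv, b, b) whenever the 8-bit counter wraps. */
static void finish_links(void *priv, list_cmp_func_t cmp, struct list_head *head,
                         struct list_head *tail, struct list_head *b)
{
    uint8_t count = 0;

    tail->next = b;
    do {
        if (unlikely(!++count)) /* true on every 256th node */
            cmp(priv, b, b);
        b->prev = tail;
        tail = b;
        b = b->next;
    } while (b);

    tail->next = head;
    head->prev = tail;
}

/* Stand-in for a client's cmp(): only counts calls through priv. */
static int count_calls(void *priv, const struct list_head *a,
                       const struct list_head *b)
{
    (void)a;
    (void)b;
    (*(int *)priv)++;
    return 0;
}
```

In the kernel the callback's return value is ignored here; the call exists only so that a cmp() that may call cond_resched() gets control back during long, comparison-free runs.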

References

list_sort.c
Study of the Linux kernel's lib/list_sort.c source code
Standard Predefined Macros
Function Names as Strings
Other Built-in Functions Provided by GCC