--- tags: NCKU Linux Kernel Internals, C語言 --- # C語言:遞迴呼叫 [你所不知道的C語言:遞迴呼叫篇](https://hackmd.io/@sysprog/c-recursion?type=view) ## 從簡單的題目思考遞迴 思考以下題目: 假設有一個只含 0 / 1 的數列,要計算有多少的 0 前面是 1 的配對。以 {1,0,1,0,0,1} 為例,第一個 0 前面(左邊)有 1 個 1,第二個 0 前面有 2 個 1,第三個 0 前面有 2 個 1,所以應該要輸出 `1 + 2 + 2 = 5`。 ### 解法 1 最直觀的想法是,迴圈找 0,然後去算 0 前面有幾個 1 ```c= int find(char *a,int n){ int ans,i; for (i = 0, ans = 0 ; i < n ;i++) { if (a[i] == 0) { int j; int before = 0; for(j = 0; j < i; j++){ if (a[j] == 1) before ++; } ans += before; } } return ans; } ``` 這個作法的時間複雜度是 O(n^2) ### 解法 2 可以利用 prefix 的思維: 對於每次找到的 0,其前面的 1 是數量累積起來的。 ```c= int find(char *a,int n){ int ans,cnt,i; for (i = 0, cnt = 0, ans = 0 ; i < n ;i++) { if (a[i] == 0) ans += cnt; else cnt++; } return ans; } ``` 如此一來,便可以將時間複雜度降低至 O(n) ### 解法 3 轉換一下思維,讓我們用 divide and conquer 來解決此問題。如果把數列從中間切開分成左右,對於右邊的 y 個 0,左邊都各自有 x 個 1,則答案會有 x * y 個。然後再對左右的數列各自遞迴求解,詳細的程式碼如下: ```c= int find(char *a, int left, int right) { if (right - left < 2) return 0; int mid = (left + right) / 2; int w = dc(a,left, mid) + dc(a,mid, right); int x,y; x = y = 0; // 計算左邊幾個 1 for(int i = left; i < mid ; i++) if( a[i] == 1) x++; // 計算右邊幾個 0 for(int i = mid; i < right ; i++) if( a[i] == 0) y++; return w + x * y; } ``` 然而,這個演算法的時間 T(n) = 2T(n/2) + O(n),得出時間複雜度為 O(n lgn)。 ### 解法 4 解法 3 並不比解法 2 還要快,原因是因為我們對於 0/1 的數量重複計算,如果把程式改寫: ```c= struct ret{ int ans; int y; // 0 的數量 int x; // 1 的數量 }; struct ret dc(char *a, int left, int right) { if (right - left < 2) { if(a[left] == 0) return (struct ret){0,1,0}; else return (struct ret){0,0,1}; } int mid = (left + right) / 2; struct ret w1 = dc(a,left, mid); struct ret w2 = dc(a,mid, right); return (struct ret){w1.ans + w2.ans + w1.x * w2.y, w1.y + w2.y, w1.x + w2.x}; } ``` 則時間是 T(n) = 2T(n/2) + O(1),時間複雜度是 O(n)。 ### 總結 事實上,一個演算法的遞迴版本勢必會有對應的迴圈版本。以此題來看,看似遞迴的算法更難理解,然而如果把問題推廣到不只有 0 / 1 的數列呢(逆序數問題) ? 如果我們想要透過迴圈解題,可能會需要一個複雜的資料結構。但是如果透過遞迴,問題得以被簡單化(準確的說,使用更少的變數維護狀態,程式也更容易被理解) 對於逆序數的問題,其暴力解為: ```c= int brute_check(int *a, int left ,int right){ int cnt = 0; for(int i = left; i < right; i++){ for (int j = left; j < i; j++){ if(a[j] > a[i]) cnt++; } } return cnt; } ``` 但是我們可以其實可以透過遞迴求解,而其解法其實只是 merge sort 的延伸: ``` c= int merge(int *a, int left, int mid, int right){ int ans = 0; int *tmp = malloc(sizeof(int)*(right - left)); int i = left; int j = mid; int k = 0; while(i < mid && j < right) { if(a[i] <= a[j]) tmp[k++] = a[i++]; else { tmp[k++] = a[j++]; ans += mid-i; } } while(i < mid ) tmp[k++] = a[i++]; while(j < right) tmp[k++] = a[j++]; int count = 0; for(int i = left;i < right; i++) a[i] = tmp[count++]; free(tmp); return ans; } int mergesort(int *a, int left, int right) { if(right - left > 1) { int mid=(left+right)/2; int w = mergesort(a,left,mid) + mergesort(a,mid,right); int c = merge(a,left,mid,right); return w + c; } else return 0; } ``` :::info 1. 詳細請直接參閱 [你所不知道的C語言:遞迴呼叫篇](https://hackmd.io/@sysprog/c-recursion?type=view),此處僅附上自己額外補充的程式! 2. 本人程式 sence 極差,如果找到程式執行的問題還請告知! 大家還請多多包容qq ::: ## 迷思: 迴圈 always 優於遞迴 前面有提及,所有可以用遞迴呈現的演算法都必定有迴圈的版本。而以往我們常常有相同類型的程式,迴圈的執行速度比遞迴快的迷思(額外的 allocate / push / pop stack frame)。這個想法固然沒錯,然而,得益於編譯器越來越強大,有些時候迴圈或者遞迴版本的程式,編譯器都可以直接幫我們優化成相同的 assembly。也就是說,寫成遞迴或者迴圈的形式在執行速度並無反別,然而遞迴程式反而可能會擁有更佳的可讀性! ~~compiler 萬歲~~ 該如何寫遞迴,才可以使編譯器幫我們做到最佳化呢? 這裡就必須提到 Tail recursion。Tail recursion 是遞迴的一種特殊形式,指 function 的最後一步是呼叫自己。 遞迴之所以跑得慢,是因為大量使用 stack 來儲存某些資料,如果 function 的最後一步不僅是呼叫自己,則不斷遞迴的過程中,有許多變數/ return address 都需要額外的 stack。反過來說,如果 function 的最後一步是遞迴呼叫的話,從編譯器的角度來說,遞迴實際上是不需要不斷的 call stack 的,因此便可以被優化! Reference: * [遞迴的美麗與哀愁](https://www.ithome.com.tw/node/81087) * [一般递归与尾递归(Tail Recursion)](http://qiusli.github.io/blog/2013/07/09/tail-recursion/) * [浅谈尾递归](https://site.douban.com/196781/widget/notes/12161495/note/262014367/) ## 案例分析:數列輸出 思考以下程式的功能: ```c= int p(int i, int N) { return (i < N && printf("%d\n", i) && !p(i + 1, N)) || printf("%d\n", i); } ``` 運行的關鍵在 `&&` 和 `||` 的特性。 * `A && B`: 如果 A statement 不成立,則 B 不會執行 * `A || B`: 如果 A statement 成立,則 B 不會執行 以及 [printf](https://linux.die.net/man/3/printf) 的回傳值是印出的字元數量,因此正常狀況下 p 裡的 printf 回傳不為 0。 因此,上面的程式其實等價於寫成: ```c= int p(int i, int N) { if (i < N){ printf("%d\n", i); p(i + 1, N); printf("%d\n", i); } else{ printf("%d\n", i); } } ``` ## 案例分析:字串反轉 思考以下問題: 實做 char *reverse(char *s),反轉 NULL 結尾的字串,限定 in-place 與遞迴。 解題的思路是: 如果要 in-place 完成 reverse,只要透過數個 swap 就可以了。而提到 in-place 的 swap 勢必會想到 [XOR](https://hackmd.io/9tGfpJtLTwyyOzDvlyJS2w?view#XOR-Swap)。 ```c= void swap(char *a, char *b) { *a = *a ^ *b; *b = *a ^ *b; *a = *a ^ *b; } ``` 有了 in-place 的 swap,就可以設計反轉字串的函式了。一個設計的思考模式是,對於一個字串 s,如果 reverse(s) 代表反轉 s,舉例反轉 `54321` 來說其實可以拆解成: * reverse(s + 1): 先反轉後面的字串 -> **51234** * 將 s 的 0 位置和 1 位置交換 -> **15234** * reverse(s + 2): 把 1 位置後的字串恢復成原本的順序 -> **15432** * reverse(s + 1): 位置 0 已經是正確的值了,此時再去反轉後面即可,字串的 0、1、... 會依序變成正確的數值,一直遞迴下去直到 '\0' 程式可以被寫成: ```c= void reverse(char *s) { if((*s == '\0') || (*(s + 1) == '\0')) return; reverse(s + 1); swap(s, (s + 1)); if (*(s + 2) != '\0') reverse(s + 2); reverse(s + 1); } ``` 然而,這個演算的時間複雜度為 O(n^2) (對於長度為 n 的字串,每次只把第 0 個放好,然而再用同樣的方法去轉剩下 n - 1 個) 但是其實我們可以轉換一下思維,假如我們先計算出字串的長度的話,其實 reverse 長度 n 的字串就是把 0 <-> n-1 / 1 <-> n-2 ......依序 swap 而已,因此: ```c= int reverse(char *head, int idx) { if (head[idx] != '\0') { int end = reverse(head, idx + 1); if (idx > end / 2) swap(head + idx, head + end - idx); return end; } return idx - 1; } ``` 其實只要遞迴求長度,然後透過長度計算要交換的對應位置在哪裡就可以了! 複雜度便可以降低至O(n)。 ### 補充: 迴圈版本 等同於上面遞迴的迴圈版本,可能會比較好理解一點? ```c= void reverse(char *s) { int len = 0; while( *(s + len)!='\0') len++; int mid = len / 2; for(int i = 0; i < mid; i++) swap(s + i, s + len - i - 1); } ``` ## TODO - [ ] 研究 [MapReduce with POSIX Thread](https://hackmd.io/@top30339/Hkb-lXkyg?type=view),深入了解何謂 MapReduce - [ ] 研究 [Functional Programming 風格的 C 語言實作](https://hackmd.io/@sysprog/c-functional-programming),深入了解何謂 functional programming