---
tags: NCKU Linux Kernel Internals, C語言
---
# C語言:遞迴呼叫
[你所不知道的C語言:遞迴呼叫篇](https://hackmd.io/@sysprog/c-recursion?type=view)
## 從簡單的題目思考遞迴
思考以下題目: 假設有一個只含 0 / 1 的數列,要計算有多少的 0 前面是 1 的配對。以 {1,0,1,0,0,1} 為例,第一個 0 前面(左邊)有 1 個 1,第二個 0 前面有 2 個 1,第三個 0 前面有 2 個 1,所以應該要輸出 `1 + 2 + 2 = 5`。
### 解法 1
最直觀的想法是,迴圈找 0,然後去算 0 前面有幾個 1
```c=
int find(char *a,int n){
int ans,i;
for (i = 0, ans = 0 ; i < n ;i++) {
if (a[i] == 0) {
int j;
int before = 0;
for(j = 0; j < i; j++){
if (a[j] == 1)
before ++;
}
ans += before;
}
}
return ans;
}
```
這個作法的時間複雜度是 O(n^2)
### 解法 2
可以利用 prefix 的思維: 對於每次找到的 0,其前面的 1 是數量累積起來的。
```c=
int find(char *a,int n){
int ans,cnt,i;
for (i = 0, cnt = 0, ans = 0 ; i < n ;i++) {
if (a[i] == 0)
ans += cnt;
else cnt++;
}
return ans;
}
```
如此一來,便可以將時間複雜度降低至 O(n)
### 解法 3
轉換一下思維,讓我們用 divide and conquer 來解決此問題。如果把數列從中間切開分成左右,對於右邊的 y 個 0,左邊都各自有 x 個 1,則答案會有 x * y 個。然後再對左右的數列各自遞迴求解,詳細的程式碼如下:
```c=
int find(char *a, int left, int right) {
if (right - left < 2) return 0;
int mid = (left + right) / 2;
int w = dc(a,left, mid) + dc(a,mid, right);
int x,y;
x = y = 0;
// 計算左邊幾個 1
for(int i = left; i < mid ; i++)
if( a[i] == 1)
x++;
// 計算右邊幾個 0
for(int i = mid; i < right ; i++)
if( a[i] == 0)
y++;
return w + x * y;
}
```
然而,這個演算法的時間 T(n) = 2T(n/2) + O(n),得出時間複雜度為 O(n lgn)。
### 解法 4
解法 3 並不比解法 2 還要快,原因是因為我們對於 0/1 的數量重複計算,如果把程式改寫:
```c=
struct ret{
int ans;
int y; // 0 的數量
int x; // 1 的數量
};
struct ret dc(char *a, int left, int right) {
if (right - left < 2) {
if(a[left] == 0)
return (struct ret){0,1,0};
else
return (struct ret){0,0,1};
}
int mid = (left + right) / 2;
struct ret w1 = dc(a,left, mid);
struct ret w2 = dc(a,mid, right);
return (struct ret){w1.ans + w2.ans + w1.x * w2.y, w1.y + w2.y, w1.x + w2.x};
}
```
則時間是 T(n) = 2T(n/2) + O(1),時間複雜度是 O(n)。
### 總結
事實上,一個演算法的遞迴版本勢必會有對應的迴圈版本。以此題來看,看似遞迴的算法更難理解,然而如果把問題推廣到不只有 0 / 1 的數列呢(逆序數問題) ? 如果我們想要透過迴圈解題,可能會需要一個複雜的資料結構。但是如果透過遞迴,問題得以被簡單化(準確的說,使用更少的變數維護狀態,程式也更容易被理解)
對於逆序數的問題,其暴力解為:
```c=
int brute_check(int *a, int left ,int right){
int cnt = 0;
for(int i = left; i < right; i++){
for (int j = left; j < i; j++){
if(a[j] > a[i])
cnt++;
}
}
return cnt;
}
```
但是我們可以其實可以透過遞迴求解,而其解法其實只是 merge sort 的延伸:
``` c=
int merge(int *a, int left, int mid, int right){
int ans = 0;
int *tmp = malloc(sizeof(int)*(right - left));
int i = left;
int j = mid;
int k = 0;
while(i < mid && j < right)
{
if(a[i] <= a[j])
tmp[k++] = a[i++];
else
{
tmp[k++] = a[j++];
ans += mid-i;
}
}
while(i < mid )
tmp[k++] = a[i++];
while(j < right)
tmp[k++] = a[j++];
int count = 0;
for(int i = left;i < right; i++)
a[i] = tmp[count++];
free(tmp);
return ans;
}
int mergesort(int *a, int left, int right) {
if(right - left > 1)
{
int mid=(left+right)/2;
int w = mergesort(a,left,mid) + mergesort(a,mid,right);
int c = merge(a,left,mid,right);
return w + c;
}
else
return 0;
}
```
:::info
1. 詳細請直接參閱 [你所不知道的C語言:遞迴呼叫篇](https://hackmd.io/@sysprog/c-recursion?type=view),此處僅附上自己額外補充的程式!
2. 本人程式 sence 極差,如果找到程式執行的問題還請告知! 大家還請多多包容qq
:::
## 迷思: 迴圈 always 優於遞迴
前面有提及,所有可以用遞迴呈現的演算法都必定有迴圈的版本。而以往我們常常有相同類型的程式,迴圈的執行速度比遞迴快的迷思(額外的 allocate / push / pop stack frame)。這個想法固然沒錯,然而,得益於編譯器越來越強大,有些時候迴圈或者遞迴版本的程式,編譯器都可以直接幫我們優化成相同的 assembly。也就是說,寫成遞迴或者迴圈的形式在執行速度並無反別,然而遞迴程式反而可能會擁有更佳的可讀性! ~~compiler 萬歲~~
該如何寫遞迴,才可以使編譯器幫我們做到最佳化呢? 這裡就必須提到 Tail recursion。Tail recursion 是遞迴的一種特殊形式,指 function 的最後一步是呼叫自己。
遞迴之所以跑得慢,是因為大量使用 stack 來儲存某些資料,如果 function 的最後一步不僅是呼叫自己,則不斷遞迴的過程中,有許多變數/ return address 都需要額外的 stack。反過來說,如果 function 的最後一步是遞迴呼叫的話,從編譯器的角度來說,遞迴實際上是不需要不斷的 call stack 的,因此便可以被優化!
Reference:
* [遞迴的美麗與哀愁](https://www.ithome.com.tw/node/81087)
* [一般递归与尾递归(Tail Recursion)](http://qiusli.github.io/blog/2013/07/09/tail-recursion/)
* [浅谈尾递归](https://site.douban.com/196781/widget/notes/12161495/note/262014367/)
## 案例分析:數列輸出
思考以下程式的功能:
```c=
int p(int i, int N) {
return (i < N && printf("%d\n", i) && !p(i + 1, N))
|| printf("%d\n", i);
}
```
運行的關鍵在 `&&` 和 `||` 的特性。
* `A && B`: 如果 A statement 不成立,則 B 不會執行
* `A || B`: 如果 A statement 成立,則 B 不會執行
以及 [printf](https://linux.die.net/man/3/printf) 的回傳值是印出的字元數量,因此正常狀況下 p 裡的 printf 回傳不為 0。
因此,上面的程式其實等價於寫成:
```c=
int p(int i, int N) {
if (i < N){
printf("%d\n", i);
p(i + 1, N);
printf("%d\n", i);
}
else{
printf("%d\n", i);
}
}
```
## 案例分析:字串反轉
思考以下問題: 實做 char *reverse(char *s),反轉 NULL 結尾的字串,限定 in-place 與遞迴。
解題的思路是: 如果要 in-place 完成 reverse,只要透過數個 swap 就可以了。而提到 in-place 的 swap 勢必會想到 [XOR](https://hackmd.io/9tGfpJtLTwyyOzDvlyJS2w?view#XOR-Swap)。
```c=
void swap(char *a, char *b) {
*a = *a ^ *b; *b = *a ^ *b; *a = *a ^ *b;
}
```
有了 in-place 的 swap,就可以設計反轉字串的函式了。一個設計的思考模式是,對於一個字串 s,如果 reverse(s) 代表反轉 s,舉例反轉 `54321` 來說其實可以拆解成:
* reverse(s + 1): 先反轉後面的字串 -> **51234**
* 將 s 的 0 位置和 1 位置交換 -> **15234**
* reverse(s + 2): 把 1 位置後的字串恢復成原本的順序 -> **15432**
* reverse(s + 1): 位置 0 已經是正確的值了,此時再去反轉後面即可,字串的 0、1、... 會依序變成正確的數值,一直遞迴下去直到 '\0'
程式可以被寫成:
```c=
void reverse(char *s) {
if((*s == '\0') || (*(s + 1) == '\0'))
return;
reverse(s + 1);
swap(s, (s + 1));
if (*(s + 2) != '\0')
reverse(s + 2);
reverse(s + 1);
}
```
然而,這個演算的時間複雜度為 O(n^2) (對於長度為 n 的字串,每次只把第 0 個放好,然而再用同樣的方法去轉剩下 n - 1 個)
但是其實我們可以轉換一下思維,假如我們先計算出字串的長度的話,其實 reverse 長度 n 的字串就是把 0 <-> n-1 / 1 <-> n-2 ......依序 swap 而已,因此:
```c=
int reverse(char *head, int idx) {
if (head[idx] != '\0') {
int end = reverse(head, idx + 1);
if (idx > end / 2)
swap(head + idx, head + end - idx);
return end;
}
return idx - 1;
}
```
其實只要遞迴求長度,然後透過長度計算要交換的對應位置在哪裡就可以了! 複雜度便可以降低至O(n)。
### 補充: 迴圈版本
等同於上面遞迴的迴圈版本,可能會比較好理解一點?
```c=
void reverse(char *s) {
int len = 0;
while( *(s + len)!='\0')
len++;
int mid = len / 2;
for(int i = 0; i < mid; i++)
swap(s + i, s + len - i - 1);
}
```
## TODO
- [ ] 研究 [MapReduce with POSIX Thread](https://hackmd.io/@top30339/Hkb-lXkyg?type=view),深入了解何謂 MapReduce
- [ ] 研究 [Functional Programming 風格的 C 語言實作](https://hackmd.io/@sysprog/c-functional-programming),深入了解何謂 functional programming