Try   HackMD

Tail-call Optimization

如果一個函式 (function) 的最後行為是呼叫函式,我們稱這個最後行為為 tail call。

舉個例子,考慮到以下程式碼,函式 A 的最後行為是執行函式 B

int A(int data)
{
    return B(data);
}

要注意的是, tail call 的最後行為只允許呼叫函式,所以以下的程式碼就不是 tail call,因為他的最後行為是進行 2 * 1 + B(data) 的運算,而非單純呼叫函式。

int A(int data)
{
    return 2 * 1 + B(data);
}

Tail-call optimization,每當函式呼叫時,我們需要建立 stack frame 來紀錄函式的參數以及 return address,tail-call optimization 最主要的目地是避免建立過多的 stack frame 而造成 stack overflow,所以呼叫函式時,只記錄參數,接著跳到函式執行處開始執行。

如果一個函式的最後行為是呼叫自己本身,那我們稱這個最後行為為 tail recursion。Tail-call optimization 最常見的使用情境就是優化 tail-recursion 的函式。我們只要將 tail call 所呼叫函式的參數記錄下來接著跳到 caller 函式所在位置,即可執行。以下用兩個例子來說明 Tail-call optimization

Example 1

第一個範例是階層計算,以下階層計算程式碼符合 tail recursion,因此可以以 Tail-call optimization 來優化

int factorial(int n, int sum) 
{
    if (n == 0)
        return sum;
    sum *= n;
    return factorial(n - 1, sum);
}

分別比較使用以及不是用 tail-call optimization 產生之指令,頻繁的使用 load 和 store 指令,操作 stack frame 中的參數或 return address,顯然優化後的程式碼少了這些指令來操作 stack frame,效率較佳。

  • Without tail-call optimization
factorial:
	addi	sp,sp,-32
	sw	ra,28(sp)
	sw	s0,24(sp)
	addi	s0,sp,32
	sw	a0,-20(s0)
	sw	a1,-24(s0)
	lw	a5,-20(s0)
	bne	a5,zero,.L2
	lw	a5,-24(s0)
	j	.L3
.L2:
	lw	a4,-24(s0)
	lw	a5,-20(s0)
	mul	a5,a4,a5
	sw	a5,-24(s0)
	lw	a5,-20(s0)
	addi	a5,a5,-1
	lw	a1,-24(s0)
	mv	a0,a5
	call	factorial
	mv	a5,a0
.L3:
	mv	a0,a5
	lw	ra,28(sp)
	lw	s0,24(sp)
	addi	sp,sp,32
	jr	ra
  • With tail-call optimization
factorial:
	mv	a5,a0
	mv	a0,a1
.L3:
	beq	a5,zero,.L4
	mul	a0,a0,a5
	addi	a5,a5,-1
	j	.L3
.L4:
	ret

Example 2

第二個範例是計算 fibonacci number,以下程式碼是典型利用遞迴來計算 fibonacci number,但以下程式碼並不符合 tail recursion,因為最後行為是兩個函式的回傳值相加

int fib(int n) 
{
    if (n == 1 || n == 2)
        return 1;
    return fib(n - 1) + fib(n - 2);
}

以下是已經優化過的指令,可以發現不符合 tail-call 的函式(無法利用 tail-call optimization),即使優化過,仍然不可避免的必須對 stack frame 進行 load 和 store 操作。

fib:
	addi	sp,sp,-16
	sw	s0,8(sp)
	sw	s1,4(sp)
	sw	s2,0(sp)
	sw	ra,12(sp)
	addi	s0,a0,-1
	li	s1,0
	li	s2,1
.L3:
	bleu	s0,s2,.L5
	mv	a0,s0
	call	fib
	add	s1,s1,a0
	addi	s0,s0,-2
	j	.L3
.L5:
	lw	ra,12(sp)
	lw	s0,8(sp)
	lw	s2,0(sp)
	addi	a0,s1,1
	lw	s1,4(sp)
	addi	sp,sp,16
	jr	ra

為了可以成功利用 tail-call optimization 優化,我們進一步將這個計算函式改為符合 tail recursion 的版本再進行優化,程式碼如下:

// left = fib(n - 2), right = fib(n - 1)
int fib(int n, int left, int right)
{
    if (n == 0) {
        return left;
    }
    return fib(n-1, right, left + right);
}

以下是已經優化過的指令,可以發現,相較典型利用遞迴來計算 fibonacci number 的寫法,tail recursion 版本的 fibonacci number 計算函式帶入的參數較多,可讀性也較差,然而利用 tail-call optimization 優化後的指令,明顯簡潔許多,可看出 tail-call optimization 顯著的優化。

fib:
	mv	a5,a0
	mv	a0,a1
.L3:
	beq	a5,zero,.L4
	add	a4,a0,a2
	addi	a5,a5,-1
	mv	a0,a2
	mv	a2,a4
	j	.L3
.L4:
	ret