Try   HackMD

2024q1 Homework4 (quiz3+4)

第三周

測驗一

版本一 - 使用 log2 計算開平方

首先先使用 log2 取得最高位元,然後從最高位元開始檢查該位元是否有值,判斷依據使用 if ((result + a) * (result + a) <= N),若小於輸入代表還沒超過 N 的值,所以要加一,從高位元做到低位元,慢慢收斂。

版本二 - 不使用 log2 計算開平方

版本一和版本二概念可以說是一模一樣,差別是在最高位元的取得,版本一使用 log2 而版本二則使用 while loop 達成。

    while (n > 1) {
        n >>= 1;
        msb++;
    }

版本三 - Digit-by-digit calculation

N 是我們想要求的平方根,
N2=(an+an1+an2+...+a0)2,am=smoram=0

[a0a0a1a0a2a0...ana0a0a1a1a1a2a1...ana1a0a2a1a2a2a2...ana2...............a0ana1ana2an...anan]

[a0a0a1a0a2a0...ana0a0a1a0a2...a0an]

[a1a1a2a1...ana1a1a2...a1an]

[a2a2...ana2...a2an]

透過觀察上面的矩陣,我們可以把

N2 的式子整理成:
N2=an2+[2an+an1]an1+...+[2(i=1nai)+a0]a0

可以歸納出

Pm=an+an1+an2+...+am+1+am=Pm+1+aM

現在我們得到一個關鍵的式子

Pm=Pm+1+am,這說明了我們可以透過迭帶的方式求出
P0
。方法是從高位到低位檢查是否
Pm2N2
,若成立則代表該位元可以被設為
1
,若不成立則代表該位元為
0

講解完理論後該如何實作,上面提到檢查是否

Pm2N2,但算平方的成本太高,因此文章提出使用上一輪的資訊來判斷,我們定義一個新的參數
X
,然後將
Xm
表示為與上一輪有關,這時會有多一項參數
Ym
:
Xm=N2Pm2=Xm+1Ym

Ym 定義成下面式子:
Ym=Pm2Pm+12=(an+an1+...+am+1+am)2(an+an1+...+am+1)2=(2Pm+1+am)am

如果該位元為

0,則
Ym=0

如果該位元為
1
,則
Ym=2Pm+1am+am2=Pm+1am+1+am2


cm=Pm+1am+1
dm=am2

可以推出下一位元

cm1=Pm2m=(Pm+1+am)2m=Pm+12m+am2m
dm1=dm4

如果該位元為

0,則
cm1=cm/2

如果該位元為
1
,則
cm1=cm/2+dm

所以說我們就可以透過

cm1,最後一個位元為
n=0
,因此
c01=P0=an+an1+...+a0
即可算出所求。

對比程式碼

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= AAAA) {
        int b = z + m;
        z >>= BBBB;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

與前兩個版本相同,都是從最高位開始計算,因此可以看到初始值使用 __builtin_clz(x) 來取得最高位,然後我們需要對 x = 0 的情況做處理。然後 m 即為上述數學推導中的

dm,還記得不管該位元是否為
0
,下一輪的
dm
也就是
dm1=dm/4
,無論如何都會除以
4
,因此利用 AAAA = 2,進行迭代。接下來 b 為數學推導的
Ym
,所以 z
cm
。因為

若該位為

1
cm1=cm/2+dm

若該位為
0
cm1=cm/2

所以我們將

cm/2 提出去,z >>= 1,所以 BBBB = 1。然後再進行比較 if (x >= b),若成立的話就將
dm
也就是 m 加到 z,經過反覆迭代即可以求出
P0
所求。

測驗二

為了可以不要使用 / % 來做除法,可以使用近似的方式求解,假設我們的除數為

q,那麼
x/q
則可以表示成
x1q
倒數形式,如果
q
是 2 的冪,那麼除法實作則會相當簡單,但並不是所有的數字的倒數都可以使用 2 進位精準表示,如
1/3
使用二進位會表示成 0.010101010101010... 的無理數,因此如何近似除數就成為這個測驗的主要技巧。

接下來我們來探討除數為 10 的情況,

1/10 使用二進位可以表示成 0.000110001...,表示的不精準,因此我們也需要透過近似方式求解,文章中除數使用
128139.84
,這個除數是根據精確度而決定。

假設 l 是一個比 n 還小的非負整數,只要以

n 算出來的除數在
n
除以該除術後在精度內,則
l
除以該除數能然會在精度內,近似的這個手法就是根據這一個猜想。因此我們可以根據最大情況 19 來考慮:
1.919x1.999.55x10

我們就可以使用
an2N
來配出一個可以用的數字
128139.84

決定好除數為

12813 後,也就是
x13128
,我們就可以使用 shift 來拼湊出這一個數字,首先將
13128=138116
,因此我們可以得到
138tmp=tmp8+tmp2+tmp
,如下。

    (tmp >> 3) + (tmp >> 1) + tmp

但因為我們會先將 tmp 右移,造成最低 3 位會被捨棄,因此需要先將最低 3 位存起來,然後再合併起來。

    d0 = q & 0b1;
    d1 = q & 0b11;
    d2 = q & 0b111;
    ((((tmp >> 3) + (tmp >> 1) + tmp) << 3) + d0 + d1 + d2)

這一個步驟我想了很久,為甚麼不先將 tmp * 13 再除 8,如下。

    ((tmp << 3) + (tmp << 2) + tmp) >> 3

我認為是因為 overflow 的問題,理論上來說要先左移也可以,不過被保存的數字也需要做相應的調整,需要保存最高的 3 位。

最後考慮 128,即可以得到商數和餘數。

    q = ((((tmp >> 3) + (tmp >> 1) + tmp) << 3) + d0 + d1 + d2) >> 7;
    r = tmp - (((q << 2) + q) << 1);

包裝函式

接下來看到包裝的函式,這個函式與我們所討論的看起來落差相當大。

#include <stdint.h>
void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

    *div = (q >> CCCC);
    *mod = in - ((q & ~0x7) + (*div << DDDD));   
}

x = (in | 1) - (in >> 2); 主要是將x = 0.75 * in。而我們可以透過下方的程式碼讓 q 趨近於 0.8,如此一來我們可以用到上面的概念,將

x/10 表示成
xa8
,其中
a
趨近於 0.8。

下方程式碼則是在使 q 趨近於 0.8。首先經過 q = (x >> 4) + x; 得到

q=in5164=0.796875in,然後會經過 4 次的 q = (q >> 8) + x;,得到
q=in(5164+5164256)=0.79998in
,最後經過四次會讓 q 趨近於 0.8 * in。所以說要求商的話需要再除以 8,也就是左移 3 位,因此 CCCC = 3

    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

最後要算餘數,

mod=indiv10,比較 *mod = in - ((q & ~0x7) + (*div << DDDD));其中 ((q & ~0x7)) 的意思為
8div
,因此我們可以透過對 div 左移一位得到
2div
,所以DDDD = 1

比較這兩種方法,他們其實都是在處理除數的近似,只是數字不同罷了。

測驗三

版本一

看到第一種 ilog2 的實作,採用了非常簡單的概念,檢查 i 裡面有沒有值,如果有則 log++ 沒有的話就 shift 一位,會一直做到最高位被移走。可以注意到 log 的初始值為 -1,這是因為需要考慮到

20 的情況,也就是說不能把
20
算進去。這個實作方法相當簡單,同時精度也不高,只能處理 log2 的整數部分。

int ilog2(int i)
{
    int log = -1;
    while (i) {
        i >>= 1;
        log++;
    }
    return log;
}

版本二

接下來看到版本二的實作,觀察下方程式碼我們可以發現,基本概念與版本一相同,不同的地方是將

216,
28
,
24
,
22
特別處理,為甚麼要這麼做呢,就是因為如此一來,可以直接檢查很多位,不用像版本一這樣需要把所有的位數都檢查一遍。因此根據
216
,一次檢查 16 位,將 i shift 16 位,並將結果 +16,AAAA = 65536,依此類推,BBBB = 256CCCC = 16

這個版本可以有效加速運算,我們實際舉一個例子,假設 ilog2(257),因此會在 while (i >= 256) 這個迴圈執行一次然後就結束,但版本一要執行 9 次,因為 257 的 binary 有 9 位。

static size_t ilog2(size_t i)
{
    size_t result = 0;
    while (i >= AAAA) {
        result += 16;
        i >>= 16;
    }
    while (i >= BBBB) {
        result += 8;
        i >>= 8;
    }
    while (i >= CCCC) {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) {
        result += 1;
        i >>= 1;
    }
    return result;
}

版本三 - 使用 GNU extension

版本三的程式相當簡潔,使用了 GNU extension,__builtin_clz 是一個內建函式,用來計算一個整數在二進制表示中,從最高位(最左邊)開始的零位的個數,前兩個版本都是從最低位開始算,而這個正好相反。

可以來推理 DDDD 的答案,值觀上來想,v 應該就是答案了,但是如果 v 是 0 呢? __builtin_clz() 若是參數為 0 則結果未定義,所以我們要讓 v 無論如何都會有 1 位,如此一來則不會有未定義的情況發生,因此 v|1 這個運算僅僅是在加強函式的強健性。

int ilog32(uint32_t v)
{
    return (31 - __builtin_clz(DDDD));
}

測驗四

這一測驗主要在計算指數加權移動平均 EMWA,使經過時間越久的歷史資料的權重也會越低,數學的定義為:

St={Y0,t=0αYt+(1α)St1t>0

觀察這一個式子可以發現

α 就是所謂的權重,
α
越高代表先前所影響越少,因為會被
(1α)
的次方稀釋掉,我們可以透過調整這一個參數以控制以往資訊對 EMWA 的重要性,其中
Y0
為初始的資料,
Yt
為當前時刻的資料,
St
為前一刻所計算出的平均數。

St=αYt+(1α)St1
=α(Yt+(1α)Yt1+(1α)2Yt2+...+(1α)kYtk+...+Y0)

接下來看到程式碼的部分,首先定義一個 struct ewma,裡面有 internal 代表計算出的平均數,factor 為一個縮放因子,我們要將 value 放進這一個結構裡都要乘上 factor,而 factor 為 2 的冪。weight 為權重,一樣是 2 的冪。

struct ewma {
    unsigned long internal;
    unsigned long factor;
    unsigned long weight;
};

以下程式碼為初始化的部分,可以觀察到存進去 weight factor 都是以 log2 存進去,而 interval 一開始會被設為 0。

void ewma_init(struct ewma *avg, unsigned long factor, unsigned long weight)
{
    if (!is_power_of_2(weight) || !is_power_of_2(factor))
        assert(0 && "weight and factor have to be a power of two!");

    avg->weight = ilog2(weight);
    avg->factor = ilog2(factor);
    avg->internal = 0;
}

接下來這個程式為主要計算平均值的地方,首先我們要先判斷 avg->internal 有沒有東西,如果沒有的話我們直接把 val 加進去 internal,但要注意的是,要將 val 乘上 factor,對應到數學式的

Y0,因為 avg->factor 已經取過對數,因此我們可以用左移來實現。

如果 avg->interval 已經有值的話,那我們需要使用

St=αYt+(1α)St1 來做運算,以數學表示來說,
α
是一個介於 0 到 1 的值,但我們的 struct 裡面定義的 weightunsigned long,因此需要對數學式做一些操作。

set  w=1α
St=Yt/w+(11/w)St1

    =(Yt/w+(11/w)St1)w/w

    =(Yt+(w1)St1)/w

    =((St1wSt1)+Yt)/w

根據上式,可以對照程式碼,EEEE = avg->weightDDDD = avg->factor

struct ewma *ewma_add(struct ewma *avg, unsigned long val)
{
    avg->internal = avg->internal
                        ? (((avg->internal << EEEE) - avg->internal) +
                           (val << FFFF)) >> avg->weight
                        : (val << avg->factor);
    return avg;
}

測驗五

測驗五結合了 ceil 和 log2,可以看到以下程式碼與測驗三的版本二類似。差在測驗五會先 x--,進而造成判斷的數值在測驗三為0d65536,而測驗五為 0d65536 - 0d1 = 0xFFFF

測驗三中的 result += 16 在測驗五中為 r |= shift,而在最後一行return (r | shift | x > GGG) + 1; 與測驗三不一樣的地方在於 x > GGG,這一部分是為了要無條件進位,GGG = 1,如此一來只要大於等於 2 的數字在第一個位元就一定是 1。

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;

    x--;
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > GGG) + 1;       
}

測驗三 - 版本二

    while (i >= 65536) {
        result += 16;
        i >>= 16;
    }
    while (i >= 256) {
        result += 8;
        i >>= 8;
    }
    ...

第四周

測驗一

popcount 是用來計算二進位表示中有多少位元是 1。透過 v &= (v - 1),可以將最低位元的 1 清掉並記錄,舉例來說 0b10100 - 0b1 = 0b10011,會讓最低位元為 1 以下(含)的位元取反,然後再做且運算即可以達到效果。而 n = -(~n) 可以用二補數來解釋

n= n+1 為二補數的定義,因此這一段程式碼要表達的其實與 n++ 相同。

unsigned popcount_naive(unsigned v)
{
    unsigned n = 0;
    while (v)
        v &= (v - 1), n = -(~n);
    return n;
}

然而因為上述算法的時間複雜度取決於 set bit 的個數,因此可以改寫為常數時間時間複雜度的實作。

以下程式碼的運算基於以下數學運算:

popcount(x)=xx2x4...x231

假設

x=b31...b3b2b1b0
x[3:0]
這四個位元可以推導成:

popcount(x)=(23b3+22b2+21b1+20b0)(22b3+21b2+20b1)(21b3+20b2)20b3
  =(23222120)b3+(222120)b2+(2120)b1+20b0

程式碼實作上以 4 bit 為單位,以上運算對應到

    n = (v >> 1) & 0x77777777;
    v -= n

透過將 v 右移 1 位並減掉達成

(23222120)b3,重複三次可以得到
(23b3+22b2+21b1+20b0)(22b3+21b2+20b1)(21b3+20b2)
如此一來我們便可以得到 4 bit 的 popcount。

接下來這一步驟,目的是要獲取所有位元也就是 8 個 4 bit 的 popcount,首先 (v + (v >> 4)) 透過位移得到前後兩組 4 bit 的總和,再透過 0x0F0F0F0F 當作遮罩,濾掉重複部分。

B7 B6 B5 B4 B3 B2 B1 B0  // v
 0 B7 B6 B5 B4 B3 B2 B1  // (v >> 4)
 
 0 (B7+B6) 0 (B5+B4) 0 (B3+B2) 0 (B1+B0) // (v + (v >> 4)) & 0x0F0F0F0F

接下來 v *= 0x01010101 利用乘法值式的特性,會再中間項出現所有 nibble 的累加,而這個位置剛好會在

27 這個位置,因此最後將 v 右移 24 位。

    v = (v + (v >> 4)) & 0x0F0F0F0F;
    v *= 0x01010101;
unsigned popcount_branchless(unsigned v)
{
    unsigned n;
    n = (v >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;

    v = (v + (v >> 4)) & 0x0F0F0F0F;
    v *= 0x01010101;                                     

    return v >> 24;
}

接下來針對 LeetCode 477 考慮 totalHammingDistance。我們可以用矩陣的觀點來看,total hamming distance 是倆倆比較的總和,假設有 A, B, C,三個元素相互比較可以建出下表,其中 hd()hammingDistance()

A B C
A 0 hd(A, B) hd(A, C)
B hd(A, B) 0 hd(B, C)
C hd(A, C) hd(B, C) 0

透過這個表格我們可以明白對稱性之影響,因此兩個 for loop 加總起來必須除以 2,AAAA = 1

int totalHammingDistance(int* nums, int numsSize)
{
    int total = 0;;
    for (int i = 0;i < numsSize;i++)
        for (int j = 0; j < numsSize;j++)
            total += __builtin_popcount(nums[i] ^ nums[j]); 
    return total >> AAAA;
}

測驗二

這一個測驗的目標是不使用任何除法就計算出餘數,相比第三周的題目,需要先算出商數,再使用被除數減掉除數成以商數,這種方法可以直接計算出餘數。根據Hacker's Delight 提出的定理。

If  ab(mod  m)  andcd(mod  m),then
a+cb+d(mod  m)and

acbd(mod  m)

因此我們以

mod  3 為例可以看到
11(mod  3),  21(mod  3)
,我們可以透過對
1
1
取餘數表示成一個二進位的式子。

2k=1(mod  3)(1(mod  3))k

並套用 Hacker's Delight 提出的定理。

2k=(1)k(mod  3)

n 設為輸入,將
n
表示為
bn1bn2...b1b0
,則
n=i=0n1bi2i=n=i=0n1bi(1)i(mod  3)
,如此一來便可以透過對奇數位和偶數位的位元總和計算出
3
的模數。

看到以下程式碼 popcount(n ^ 0xAAAAAAAA) + 23,這一步是從另外一個定理推導出來的,

popcount(x&m)popcount(x&m)=popcount(xm)popcount(m)

n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA),透過這一個式子我們可以比較奇數位偶數位的模數,因為若是奇數位有值,他的模數會是 -1,偶數則是 1,透過相加即可以計算出所有的模數,所以我們要減掉奇數位的 popcount

popcount(n ^ 0xAAAAAAAA) - 16 + 39,這邊為了不要讓算出來的數落在負數,所以我們可以加上一個 3 的倍數,在這個程式選擇了 39,理論上可以選擇任何 3 的倍數,但後續的運算也需要做相對調整,此時

23n23+32=55,所以我們要針對 n 在做一次運算,因為這時 n 的最大值為 55,也就是說不會超過 6 bit,所以使用 n = popcount(n^ 0x2A) - 3 來計算模數,這時
3n2

return n + ((n >> 31) & 3); 這時

3n2,所以我們要針對負數做處理,方法是將 n >> 31 也就是判斷 sign bit,注意這邊是使用算數位移而非邏輯位移,所以若是負數的話會造成全 1,此時可以與 3 座且運算,也就是若小於 0 則加 3。

int mod3(unsigned n)
{
    n = popcount(n ^ 0xAAAAAAAA) + 23;
    n = popcount(n ^ 0x2A) - 3;
    return n + ((n >> 31) & 3);
}

n = popcount(n ^ 0xAAAAAAAA)可以看成 n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) + 16,可以看到因為我們加上了 16 所以會造成 table 有偏移的狀況,我們需要手動修正。假設輸入為 16,也就是 0b10000,經過運算得到的 n 為 17,就可以對第 17 個元素修正為 16(mod 3),推回來可以決定第 0 個元素是從 2 開始。

int mod3(unsigned n)
{
    static char table[33] = {2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1 };
    n = popcount(n ^ 0xAAAAAAAA);
    return table[n];
}

Hacker's Delight 還提出了一種不用使用 popcount 的計算模數方法,其中關鍵在於,可以將 n 拆成兩部分,並將左邊的部分右移並與右邊相加不會對模數計算造成影響,所以要怎麼拆就是一大關鍵。取 7 的模數會有此關係式:

8k1(mod  7)
8k=23k
,所以可以看出只要是 shift 3 的倍數就不會影響到 mod 7 運算。值得注意的是 mod 不同的基數會有不同的限制。

int remu7(unsigned n) {
     static char table[75] = {0,1,2,3,4,5,6, 0,1,2,3,4,5,6,
                              0,1,2,3,4,5,6, 0,1,2,3,4,5,6, 0,1,2,3,4,5,6,
                              0,1,2,3,4,5,6, 0,1,2,3,4,5,6, 0,1,2,3,4,5,6,
                              0,1,2,3,4,5,6, 0,1,2,3,4,5,6, 0,1,2,3,4};
     n = (n >> 15) + (n & 0x7FFF); // Max 0x27FFE.
     n = (n >> 9) + (n & 0x001FF); // Max 0x33D.
     n = (n >> 6) + (n & 0x0003F); // Max 0x4A.
     return table[n];
}

這邊利用了

n  mod  7=87n(mod  8)0x24924924
2327
。最後一步則是將原本的 0~7 mapping 為 0~6。

int remu7(unsigned n) {
     n = (0x24924924*n + (n >> 1) + (n >> 4)) >> 29;
     return n & ((int)(n - 7) >> 31);
}

透過結合上述兩種方法,可以判斷 CCCC = 15,因為 0x7FFF,將 x 拆成兩部分。

static inline int mod7(uint32_t x)
{
    x = (x >> CCCC) + (x & UINT32_C(0x7FFF));
    /* Take reminder as (mod 8) by mul/shift. Since the multiplier
     * was calculated using ceil() instead of floor(), it skips the
     * value '7' properly.
     *    M <- ceil(ldexp(8/7, 29))
     */
    return (int) ((x * UINT32_C(0x24924925)) >> 29);
}

觀察 move_masks,我們可以發現紀錄井字遊戲的位置時不單單只是記錄九宮格的位置,而是使用八條可能的路徑紀錄,三條橫線,三條直線,兩條斜線,也就是說一個位置的改變會改變到多條路徑。move_masks 以 4 bit 表示一條路徑,若為 0b0100,則代表該路徑上的第一個位置有值,0b0010代表該路徑上的第二個位置有值,以此類推。而從最高位到最低位依序代表的路徑為第一條橫線,第二條橫線,第三條橫線,第一條直線,第二條直線,第三條直線,第一條斜線,第二條斜線。所以我們以中間的位置為例,該點影響了第二條直線,第二條橫線和兩條斜線,而在三條路徑上剛好都是第二個位置有值,此點可表達為 0x02002022

static const uint32_t move_masks[9] = {
    0x40040040, 0x20004000, 0x10000404, 0x04020000, 0x02002022,
    0x01000200, 0x00410001, 0x00201000, 0x00100110,
};

顯而易見的,勝利的條件即為任意一條路徑三個位置都為同一方,也就是 0b0111,因此只要有任意一條 +0x1 然後與 0x8 且運算有值,
即為獲勝,所以 BBBB = 0x11111111

/* Determine if the tic-tac-toe board is in a winning state. */
static inline uint32_t is_win(uint32_t player_board)
{
    return (player_board + BBBB) & 0x88888888;
}

測驗三

X Tree 結合了 AVL Tree 和 紅黑樹的特性。

看到 remove 的函式,首先檢查愈刪除節點的右子樹,若存在那我們會尋找最小的節點,使用 xt_first 尋找,反之若要找最大局點,則使用 xt_last 。找到最小節點後,利用 xt_replace_right 將最小節點放到 del 指標,如此才會符合 小-中-大 的特性,所以 AAAA = least。因為改變了 del 右子樹的結構,所以我們要使用 xt_update 去檢查目前是否平衡,若不平衡就要進行 rotate。

同理,若存在左子樹,就要將左子樹的最大節點放到 del 的位置,因此 CCCC = most,而 DDDD = xt_left(most)

函式的最後我們要檢查整棵樹是否達成平衡,因此 EEEE = root FFFF = parent,也就是檢查 root 是否達到平衡。

static void __xt_remove(struct xt_node **root, struct xt_node *del)
{
    if (xt_right(del)) {
        struct xt_node *least = xt_first(xt_right(del));
        if (del == *root)
            *root = least;

        xt_replace_right(del, AAAA);
        xt_update(root, BBBB);
        return;
    }

    if (xt_left(del)) {
        struct xt_node *most = xt_last(xt_left(del));
        if (del == *root)
            *root = most;

        xt_replace_left(del, CCCC);
        xt_update(root, DDDD);
        return;
    }

    if (del == *root) {
        *root = 0;
        return;
    }

    /* empty node */
    struct xt_node *parent = xt_parent(del);

    if (xt_left(parent) == del)
        xt_left(parent) = 0;
    else
        xt_right(parent) = 0;

    xt_update(EEEE, FFFF);
}

來看看 xt_update ,勇於更新樹中節點的平衡狀態和 hint 值,首先透過 xt_balance 計算平衡因子,並獲取節點 n 的前一個 hint 和父節點 p,若平衡因子小於 -1 代表樹向右傾斜,要將樹進行右旋轉,若大於 1 代表向左傾斜,要將樹進行左旋轉,然後更新 nhint 為右子樹中的最大值。最後如果節點的提示值為 0 或者 hint 發生了變化,則遞迴地使用 xt_update,更新父節點 p 的平衡狀態和 hint

static inline void xt_update(struct xt_node **root, struct xt_node *n)
{
    if (!n)
        return;

    int b = xt_balance(n);
    int prev_hint = n->hint;
    struct xt_node *p = xt_parent(n);

    if (b < -1) {
        /* leaning to the right */
        if (n == *root)
            *root = xt_right(n);
        xt_rotate_right(n);
    }

    else if (b > 1) {
        /* leaning to the left */
        if (n == *root)
            *root = xt_left(n);
        xt_rotate_left(n);
    }

    n->hint = xt_max_hint(n);
    if (n->hint == 0 || n->hint != prev_hint)
        xt_update(root, p);
}