Try   HackMD

2024q1 Homework4 (quiz3+4)

contributed by < lintin528 >

2024q1 第 3 週測驗一

版本一 使用 <math.h> 之 log2 計算開平方根

該版本為透過 (int) log2(N) 求得 msb (Most Significant Bit) ,並且從 1 << msb 開始進行二分逼近法,往低位元逐個檢查當 該位元 = 1 即 (result + a) 時平方相較原本的 N 大或是小,以此設定低位元的所有 bits

int msb = (int) log2(N);
int a = 1 << msb;
int result = 0;
while (a != 0) {
    if ((result + a) * (result + a) <= N)
        result += a;
    a >>= 1;
}
return result;

版本二 不依賴 log2 並保持原本的函式原型 (prototype) 和精度 (precision)

可以直接透過 while loop 取得最高有效位元,而不使用 log2

int msb = 0;
int n = N;
while (n > 1) {
    n >>= 1;
    msb++;
}

版本三 Digit-by-digit calculation

透過

x=N2=(an+an1+an2+...+a0)2,am=2moram=0 的拆解方式,逐項比對
Pm2=(an+an1+...+am)2
N2
來設定
am
對應的 bit ,若
N2P2
則設定此 bit 為 1。

再透過

N2 項的展開,可歸納出

Pm2=Pm+12+(2Pm+1+am)am

此處為了效能考量,不直接逐步進行

N2P2 的驗證,採用另一種策略,透過
Xm,Ym
求得
cm,dm
,並在進行完迭代後得到最後的
c0
則是結果
P0
,按照以下步驟:

首先將每個 bit 的檢測

N2P2 的步驟轉換為計算

Xm=N2Pm2=Xm+1Ym 為當下之
Pm2N2
之差距,隨著 m 減少,此差距應會慢慢降低,意義與逐位元逼近相同。

Ym=Pm2Pm+12=(2Pm+1+am)am 理解為上次測量差距的減少量,即
Xm+1Xm

一開始我疑惑為何要將

Ym 進行拆解的動作,想了一下發現這就是這個演算法可以節省效能的關鍵,相比起原本需要大量的平方運算,在此處拆解為
cm,dm
將發現可以直接透過該輪的比對結果透過較為簡單的除法、加法運算取得下一次迭代的
cm1,dm1

至此,將

Ym 拆解為:

cm=2Pm+1am

dm=(am)2

Ym={cm+dm,ifam=2m0,ifam=0

在每次的迭代中,會透過比較

Xm+1,Ym 判斷出
am
是否有值,並作相對應的轉換,在以下程式碼中, x 則為上一次結果的平方差,b 為該次推算出的
Ym
z 則為
cm1
m 對應到
dm

if (x >= b)
    x -= b, z += m;  

而下次迭代的值即為:

cm1=Pm2m=(Pm+1+am)2m=Pm+12m+am2m={cm2+dm,  if  am=2mcm2

dm1=dm4

看到以下 for 迴圈,可以看到每次迭代結束後 m >>= AAAA 的行為,是確認完

am 之後更新
Xm
cm1
後,做
dm1
的更新,對應到
dm1=dm4
, 因此 AAAA 即為 2 ;觀察到 z >>= BBBB 的行為是對應到在設定完該次的
Ym
後,固定先將
cm1=cm2
,並根據
am
判斷是否要加上
dm

for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= AAAA) {
    int b = z + m;
    z >>= BBBB;
    if (x >= b)
        x -= b, z += m;               
}
return z;

最後當結束迴圈後的

c1 即為平方根。

在分析這個演算法的時候產生了一個疑問,為何

dm 一次是右移兩個 bit ,這該怎麼對應到逐位元檢查的方式呢? 後來又回去看了題目敘述一次,發現
dm
所代表的意義是
(2m)2
因此在逼近時,原本做法中,直接對平方根的估測值右移一個 bit
dm
則為該平方根估測值的平方,因此需要右移兩個 bit

2024q1 第 3 週測驗二 mod 10 和 div 10 操作

在不使用 % / 的前提下,要做到商和餘數的運算,題目的方式是透過結合 + - << >> 的方式去完成,當我們嘗試除以一個數時,會利用倒數的方式避免使用前兩個運算子,然而當除數的倒數中,分母不可以透過二的冪次表示時,結果將會是不精確的,因此提供了一個逼近的方式,得到一個除數使得最後的 tmp/10 結果能達到小數點後第一位都是精確的。

因為題目中是以 10 進位的方式去計算 ,所以之後的推導是以 tmp 不大於 19 作為假設 ( a, b 兩個某一位數 0-9 加上 carry bit ,最高即 9+9+1 ),在這之後,找到適合的除數,因為要透過 bit operation 得出,因此可以透過條件:

x=an2N,1.91.9x1.99
取得適當的 a, N,此處我們選擇
a=13,N=128

對應到以下的程式碼,在拼湊 13*tmp/8 = tmp/8 + tmp/2 +tmp 的過程中,因為會使用到最高 tmp >> 3 三個位元的位移,需要先將後三個位元保存,之後將 13*tmp/8 左移回 tmp 時。再將其補回,最後右移七位得到最後的 13*tmp/128 ,即所求商。

d0 = q & 0b1;
d1 = q & 0b11;
d2 = q & 0b111;
q = ((((tmp >> 3) + (tmp >> 1) + tmp) << 3) + d0 + d1 + d2) >> 7;
r = tmp - (((q << 2) + q) << 1);

這裡我很疑惑為何需要透過 13*tmp/8 嘗試拼湊 13*tmp 的過程中右移而不是選擇直接
tmp + tmp << 2 + tmp << 3 ,原本的想法是選擇捨去掉最低三位元的精確度來保證高位元不會因為左移而被捨去,但仔細思考發現在這個過程 (((tmp >> 3) + (tmp >> 1) + tmp) << 3) 中,我最後依舊還是會有一項進行左移 3 bit 跟一項左移 2 bit 的這個行為,這樣我不僅沒有避免高位元的左移,甚至還使得低位元因為右移被捨棄,需要修正回來的額外考量;另外既然這邊我已經成功拼湊出

13tmp8 了,在這裡先左移 3 bit 修正回
13tmp
後而不直接進行
13tmp8116
,直接繼續右移 4 bit 的原因是什麼,我的推測是為了修正右移後的誤差,才需要左移回來並加上 d0,d1,d2 ,如果我想做這個修正,或許可以做成 (((tmp >> 3) + (tmp >> 1) + tmp) + ((d0 + d1 + d2) >> 3)) >> 4 ,這樣既考慮到避免高位元的左移,又可以達到右移後的修正問題。

還有

13tmp8 實際上還是有可能會有溢位的情況發生,可以改成
13tmp1618
去實作上述的假設,即 (((tmp >> 4) + (tmp >> 2) + tmp >> 1) + ((d0 + d1 + d2 + d3) >> 4)) >> 3,雖然低位元的誤差來到了 4 bit,但也完全避免了溢位問題。

最後,題目中的程式碼乍看之下與上方討論的相差甚遠,實際上他們唯一的共同點就是在進行除數的逼近,但逼近的方式不同而已。

void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

    *div = (q >> CCCC);
    *mod = in - ((q & ~0x7) + (*div << 1));   
}

這裡一開始透過 uint32_t x = (in | 1) - (in >> 2); 取得 0.75 * in(in | 1) 這個部分當最低位元為 0 時將其改為 1 ,目前不知道實際的意義是什麼,再來的一連串

uint32_t q = (x >> 4) + x;
x = q;
q = (q >> 8) + x;
q = (q >> 8) + x;
q = (q >> 8) + x;
q = (q >> 8) + x;

即是為了逼近 0.8 * in ,並在最後右移 3bit 以達到逼近

in110 的結果,因此這邊的 CCCC 即為 3,最後的在計算餘數時透過除法原理,使用被除數減掉 除數乘以商,即 ((q & ~0x7) + (*div << 1)) ,而在此題中除數為 10 ,其中 (*div << 1) = 2 * 除數 ,因此剩餘的部分 ((q & ~0x7)8 * 除數 ,即為 (*div << 3) ,與這邊的 q & ~0x7 是同樣意義,這邊之所以需要過濾掉 q 的後三個位元是因為在逼近時這三個位元將會在計算商數時被捨去掉,因此 (*div << 3) 時,後三個位元應該都是 0。

2024q1 第 3 週測驗三 ilog2 計算

版本一

從二進位數表示法中可以看出,某數的

log2 值即為最高位元的位元數減一,此數可表示為
2N+an12N1+...+a020=2N(1.xxx)
log2
後整數部分即為
N

int ilog2(int i)
{
    int log = -1;
    while (i) {
        i >>= 1;
        log++;
    }
    return log;
}

版本二

此處在做的其實也是在尋找最高位元,並透過二分法的方式,先檢查最高位元在前 16 位還是後後 16 位,因此可推斷 AAAA65536 ,若成立即將結果加上 16 ,代表在高位元,檢查完之後,繼續確認他是在這 16 位元中的前半還是後半,到最後即可確定最高 bit 位數是第幾位,比較特殊的是這邊在最後的 while 迴圈比較是逐 bit 做縮減而不同於前面的二分法。

因此,以此類推,BBBBCCCC 應為 25616

static size_t ilog2(size_t i)
{
    size_t result = 0;
    while (i >= AAAA) {
        result += 16;
        i >>= 16;
    }
    while (i >= BBBB) {
        result += 8;
        i >>= 8;
    }
    while (i >= CCCC) {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) {
        result += 1;
        i >>= 1;
    }
    return result;
}

版本三 運用 GNU extension __builtin_clz

其中 __builtin_clz 計算該整數由高位元像低位元算起,到第一個 set 之前的 0 個數,因此這裡的 31 - __builtin_clz(DDDD) 可以理解成 32 - __builtin_clz(DDDD) - 1 即最高位元數 -1 ,考慮到 __builtin_clz() 在輸入為零時是未定義,因此為了增強安全性,希望在輸入零時函式也回傳 0 ,設定 DDDDv|1

return (31 - __builtin_clz(DDDD));

2024q1 第 3 週測驗四 EWMA 指數加權移動平均

該次測驗以設定權重的方式使得時間越久的資料銓重將會越低,在此前提下計算出平均數。
數學定義為:

St={Y0,t=0αYt+(1α)St1t>0
其中
Yt
為當下資料,
St
為目前的 EWMA
透過以上的式子,展開後可以得到:

St=αYt+(1α)St1

=α(Yt+(1α)Yt1+(1α)2Yt2+...+(1α)kYtk+...+Y0)

可以看出在多次的資料進入後,原本的資料將會乘上

(1α)k ,即最終的權重。

再來到 EWMA 的結構體:

其中 internal 代表當下的

Stfactor 根據程式碼中的註解 "Scaling factor for the internal representation of the value." ,可以得知在這個演算法中 factor 即為定點運算時使用到的 fraction bit 位數,而 weight 是實作過程中資料的 decay rate。這邊 factorweight 都為二的冪,因此在接下來的運算中都可以透過左右移的方式來計算。

struct ewma {
    unsigned long internal;
    unsigned long factor;
    unsigned long weight;
};

接下來分析新增資料時的行為,

St=αYt+(1α)St1 ,可對應到 avg->internal 的更新過程,此處的 input val 即為本次的輸入資料
Ym
,經過剛剛結構的定義,可以看到 weight 相對應的就是
α
,就是 decay rate ,但這兩者為倒數關係,由於這裡使用了定點數運算,不能直接使用浮點數,會先將
α
的倒數存入 weight 再進行右移以完成衰減的計算,這邊也可以知道
α
12N
, 當我們將原式以
α=1w
代入時,可推導出:
St=1wYt+(11w)St1

可以看到程式碼中最後將計算結果進行 >> avg->weight 的動作,可以判斷將

1w 提出來

St=Yt+(w1)St11w

整理到這樣即可以對應到 (((avg->internal << EEEE) - avg->internal) + (val << FFFF))

val << FFFF 對應到

Yt 並經過 fraction bit 個左移,因此可判斷 FFFFavg -> factor

(((avg->internal << EEEE) - avg->internal) 對應到

(w1)St11w ,可以判斷 EEEEavg -> weight

struct ewma *ewma_add(struct ewma *avg, unsigned long val)
{
    avg->internal = avg->internal
                        ? (((avg->internal << EEEE) - avg->internal) +
                           (val << FFFF)) >> avg->weight
                        : (val << avg->factor);
    return avg;
}

2024q1 第 3 週測驗五 ceil_ilog2

與第三題的運算其實基本上是相同的,差別在這邊先做了 x--; ,且每次在進行的判斷式,如 (x > 0xFFFF) 與測驗三 (x >= 65536) 是一樣的,這邊就是為了讓值為

2N 的數先在上面的判斷少一位,並在結果補上,這個部分就是在做 ceil,因此可以推斷每次的 x > 0xF..F 判斷,其值都為
2N1
,最後的 (r | shift | x > GGG) + 1 ,其實可以展開為,就是最後一個位元的二分法,所以 GGG 應為 1

r |= shift;
shift = (x > GGG);
r |= shift;
int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;

    x--;
    r = (x > 0xFFFF) << 4; 
    
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > GGG) + 1;       
}

2024q1 第 4 週測驗一 Total Hamming Distance

popcount_naive

為了計算出 Hamming Distance ,將會使用到 popcount_naive 找出二進位中所有 1 的個數,提供了兩種實作方式:

第一種

首先透過 v &= (v - 1) 這會刪除 v 中被設定為 1 的最低位元,其中的原理是透過 -1 的方式,使的 (v - 1) 中最低的 1 位元變成 0 ,透過例子來模擬,先將一個十進位的數拆解成

2N+2L+...+2M ,因此他的最低位元就是右邊數來第 (M+1) 位,再進行 -1 後必定可以得到一結果是
2N+2L+...+02M+2M1+2M2+...1
,因此在做 & 運算時,相當於最低位的 set 被刪除。

n = -(~n) 是利用二補數運算的特性得到 n++ 的結果,但不清楚為何需要這樣做,效能是否會提高。

unsigned popcount_naive(unsigned v)
{
    unsigned n = 0;
    while (v)
        v &= (v - 1), n = -(~n);
    return n;
}

由於以上的算法執行時間為 O(N) , N 為二進位表示法中的所有 set 數量,以下將其改良為常數時間。

unsigned popcount_branchless(unsigned v)
{
    unsigned n;
    n = (v >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;

    v = (v + (v >> 4)) & 0x0F0F0F0F;
    v *= 0x01010101;                                     

    return v >> 24;
}

這邊最主要的概念就是將每四個位元分成一個區間,如

b3b2b1b0 ,並在上面的三次左移與 v -= n 操作計算出
(23222120)b3+(222120)b2+(2120)b1+(20)b0
,即每一個四位元區間的 set 個數總和,後來再經過這兩個操作,讓 A6+A4+A2+A0 = B7+B6+B5+B4+B3+B2+B1+B0 ,在這裡 B0 即為 0-3 bit 中,所有 set 的總和,因此 B7+B6+B5+B4+B3+B2+B1+B0 即代表這個 32 位元的整數的二進位表示法中,所有 set 的個數總和。

v = (v + (v >> 4)) & 0x0F0F0F0F;
v *= 0x01010101;    

且這個總和將保存在第 25-28 bit 的這個區間內,所以最後才會需要 return v >> 24

再來看到 "兩個整數間的 Hamming distance 為其二進位的每個位元的差" ,因此當我們在分析 totalHammingDistance 時,事實上就是在做兩兩的 Hamming distance 計算,而該計算又可以寫為

int hammingDistance(int x, int y)
{
    return __builtin_popcount(x ^ y);
}

透過 XOR 得到這兩個整數的二進位表示法中那些位元是不同的,並透過剛剛定義的 popcount 得到最終的 Hamming distance

在多個整數之間,計算的程式碼如下:

int totalHammingDistance(int* nums, int numsSize)
{
    int total = 0;;
    for (int i = 0; i < numsSize;i++)
        for (int j = 0; j < numsSize;j++)
            total += __builtin_popcount(nums[i] ^ nums[j]); 
    return total >> AAAA;
}

將注意到這邊的兩個 for 迴圈中, jloop 起始條件並不是 j = i ,這將導致除了自己與自己本身的 Hamming distance 計算之外,其餘兩兩間的比較次數將會加倍,如將進行 __builtin_popcount(nums[0]) ^ nums[1])__builtin_popcount(nums[1]) ^ nums[0]),因此需在最後做修正,所以 AAAA 應為 1

2024q1 第 4 週測驗二 Remainder by Summing digits

若除數為

m=2k±1 ,在此以三為例,則可以透過 popcount 的方式計算出該數的餘數,基本上是透過餘數的運算,將一整數以二進位的方式寫出,則可得到 input
n=bn1bn2...b1b0
,由餘數的加法規則可以知道該數的 餘數 = 奇數位元 set 個數 - 偶數位元 set 個數
ab(modm),cd(modm):(1)

a+cb+d(modm):(2)

acbd(modm):(3)

根據以上的 (3) ,可以列出

11(mod3)

21(mod3)

並將這兩者相乘,即可得到

2k(1)k(mod3)

再來透過 (2) ,可以知道 n 可表示為

n=bn12n1+bn22n2+...+b121+b020 ,因此才可以使用上面的結論 餘數 = 奇數位元 set 個數 - 偶數位元 set 個數 (mod3) ,以位元的方式觀察則是
n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) (mod3)

再來開始分析以下程式碼

int mod3(unsigned n)
{
    n = popcount(n ^ 0xAAAAAAAA) + 23;
    n = popcount(n ^ 0x2A) - 3;
    return n + ((n >> 31) & 3);
}

這邊先使用了以下的等式,將原本的餘數計算方式換為 XOR 的方式,這是為了在之後能夠使用加上一個正整數的方式去縮減模數範圍,經過原本的 n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) 後,可以發現模數的範圍是 +16 ~ -16 ,且根據 (2) 的公式,可以加上一個三的倍數,這裡的疑惑是不知道為何選擇 39 這個數字,目的只是為了讓他進行下一次的 n = popcount(n ^ 0x2A) - 3; ,所以我認為 18 應該也是可以的,繼續以原本的 39 計算下去的話,會發現模數範圍以被限制在 +55 ~ +23 之間,最高位元被限制在 6 bit 以內,才可以進行上面的運算,因為

0x2A=000....001010102

popcount(x&m)popcount(x&m)=popcount(xm)popcount(m)

最後觀察到 n + ((n >> 31) & 3) ,此時 n 的範圍已縮小到 +3 ~ -3 了,透過有號數的算術位移 (n >> 31) ,若是正數這邊將會是 0 ,若是負數則會是全部為一的二進位數 (-1) ,在與 3& 運算,所以這邊其實在做的是將最後的結果範圍縮小到 +3 - 0 即最後的餘數。

在最後這邊,我突然發現最後出來的結果是 +3 的話就不是我們想要的餘數了,因此我就思考到剛剛的 +39 這件事情,其實是特地為了把第一輪結束的 n 控制在 +55 ~ +23 的結果內,因為在第二次的 popcount(n ^ 0x2A) ,可以推測說算出 6 (就是最後餘數為 3 的狀況) ,就是在第一輪結果為 21 = 0000...010101 的時候,所以才用 39 這個數來修正,也因此第二次的 n 範圍進一步縮小到 +2 ~ -3 ,所以在這個例子中第一次計算選擇用來修正的數應該可以用 39, 42, 45 (16 + 45 = 61 ),如果用 48 就會超過第二次計算設定的 6 bit (63) 了。

look up table 的方式,則是直接將範圍修正為 +32 ~ 0 ,所以可以推測其實是做了 n = popcount(n ^ 0xAAAAAAAA) - 16; 的修改,在此之上又加 16 作為標籤使用,可以透過觀察原本回傳 16 的結果,來觀察這個表,所以原本的

161 會放置在更新後的第 table[32] ,也就是最後一個位置,所以原本的模數
150
會放在 table[31] ,以此類推,即做出整個表格。

int mod3(unsigned n)
{
    static char table[33] = {2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1 };
    n = popcount(n ^ 0xAAAAAAAA);
    return table[n];
}

tic-tac-toc 的部分,在每次的移動,透過這九個數值去表示他所在的格子,用 16 進位的一個字元來表示位置,而由高到低每個字元的意義為: "第一橫排,第一橫排,第一橫排,第一橫排,第一橫排,第一橫排,左上到右下斜線,右上到左下斜線" ,可由上至下,由左至右將一個字元分成四個 bit ,並用後三個 bit 去表示該線上的元素有沒有值,若第一條橫線中, 0110 即代表 [0][1] 與
[0][2] 都有值。

static const uint32_t move_masks[9] = {
    0x40040040, 0x20004000, 0x10000404, 0x04020000, 0x02002022,
    0x01000200, 0x00410001, 0x00201000, 0x00100110,
};

再來為了判斷勝利條件,即代表在剛才定義的每一條線上,只要該 16 進位位元在二進位表示為 0111 時,則代表連成一線,在觀察到算式後面的 & 0x88888888 ,就可以推斷出 BBBB0x11111111

static inline uint32_t is_win(uint32_t player_board)
{
    return (player_board + BBBB) & 0x88888888;
}

參考 Hacker's Delight 中的實作,利用

8k1(mod7) ,因此只要對 n 進行三的倍數的位移,即可在不造成模數改變的情況下減少 n 的範圍,這邊每次在進行的計算其實就是將他二進位表示法中的後 3k 位加到計算過後的模數,經過多次的範圍縮減之後,就能夠使用 look up table 進行模數計算。

int remu7(unsigned n) {
    static char table [75] = {0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4};
    
    n = (n >> 15) + (n & 0x7FFF);
    n = (n >> 9) + (n & 0x001FF);
    n = (n >> 6) + (n & 0x0003F);
    return table[n];
}

觀察上面的方法,這裡也是為了進行 x 的範圍縮減,將後面 15 bit 的值加回右移後的數以維持模數不變,因此可以推測 CCCC 為 15,而最後使用到的 ((x * UINT32_C(0x24924925)) >> 29) ,是利用到 Hacker's Delight 文章中推導出來的

n(mod7)=popcount(n) ,且
n(mod7)87n  (mod  8)
,以求得 mod 7 的模數。

static inline int mod7(uint32_t x)
{
    x = (x >> CCCC) + (x & UINT32_C(0x7FFF));
    /* Take reminder as (mod 8) by mul/shift. Since the multiplier
     * was calculated using ceil() instead of floor(), it skips the
     * value '7' properly.
     *    M <- ceil(ldexp(8/7, 29))
     */
    return (int) ((x * UINT32_C(0x24924925)) >> 29);
}

2024q1 第 4 週測驗三 Xtree

為了研究 Xtree 的特性時,首先先分析他的 xt_max_hint 、 xt_balance 、 xt_update 是如何被實作出來的,可以看到儲存平衡因子的變數稱為 hint ,並在 insertdelete 時做 update ,第一個看到 xt_balance

static inline int xt_balance(struct xt_node *n)
{
    int l = 0, r = 0;

    if (xt_left(n))
        l = xt_left(n)->hint + 1;

    if (xt_right(n))
        r = xt_right(n)->hint + 1;

    return l - r;
}

這邊在做的是平衡判斷,與 AVL 數相同,這裡的 hint 意義應該是從該點到其子數中最遠的 leaf node 的長度,即樹高。

static inline int xt_max_hint(struct xt_node *n)
{
    int l = 0, r = 0;

    if (xt_left(n))
        l = xt_left(n)->hint + 1;

    if (xt_right(n))
        r = xt_right(n)->hint + 1;

    return l > r ? l : r;
}

這邊在做的是在 update 的最後,更新當下的 hint

再來觀察 update 內部

if (b < -1) {
    /* leaning to the right */
    if (n == *root)
        *root = xt_right(n);
    xt_rotate_right(n);
}

這邊負責處理失去平衡時的左右旋轉,以右旋轉來觀察,這邊為止與 AVL 樹都是相同的。

最後就是題目的部分,此處為移除一節點之後的行為,根據 AVL 樹的行為,跟程式碼上方的註解,可以判斷當我移除一節點 del 時,我會將該節點右子樹中最小的元素 least 替換掉原本的 del ,這則透過 xt_replace_right 達成,因此 AAAA 應為 least ,這裡的 xt_replace_right 可以分析為多個部分,分別做的事情有 "連接 least 節點與 delparent 節點、連接 least 節點與 del 之右節點 (有右子樹的情形) 、將 least 節點原本的右子樹連接上 leastparent 節點" ,並在這之後進行 xt_update 也就是平衡的檢查,根據註解 an update operation is invoked on the right child of the newly inserted node. ,可以推斷 BBBBxt_right(least) ;若刪除的節點缺少右子樹,則選擇左子樹中最大的節點代替,相對的可以得知 CCCCDDDD 分別為 mostxt_left(most);若被刪除節點為 leaf node 的話將不用進行交換,且對 parent 進行 update ,所以判斷 EEEE, FFFFroot, parent

static void __xt_remove(struct xt_node **root, struct xt_node *del)
{
    if (xt_right(del)) {
        struct xt_node *least = xt_first(xt_right(del));
        if (del == *root)
            *root = least;

        xt_replace_right(del, AAAA);
        xt_update(root, BBBB);
        return;
    }

    if (xt_left(del)) {
        struct xt_node *most = xt_last(xt_left(del));
        if (del == *root)
            *root = most;

        xt_replace_left(del, CCCC);
        xt_update(root, DDDD);
        return;
    }

    if (del == *root) {
        *root = 0;
        return;
    }

    /* empty node */
    struct xt_node *parent = xt_parent(del);

    if (xt_left(parent) == del)
        xt_left(parent) = 0;
    else
        xt_right(parent) = 0;

    xt_update(EEEE, FFFF);
}

struct xt_node *xt_find(struct xt_tree *tree, void *key)
{
    return __xt_find2(tree, key);
}

int xt_remove(struct xt_tree *tree, void *key)
{
    struct xt_node *n = xt_find(tree, key);
    if (!n)
        return -1;

    __xt_remove(&xt_root(tree), n);
    tree->destroy_node(n);

    return 0;
}