Try   HackMD

linux2024-homework4

contributed by < SHChang-Anderson >

第三週測驗題

完整題目

測驗一

計算開平方根: (版本一)

#include <math.h>
int i_sqrt(int N)
{
    int msb = (int) log2(N);
    int a = 1 << msb;
    int result = 0;
    while (a != 0) {
        if ((result + a) * (result + a) <= N)
            result += a;
        a >>= 1;
    }
    return result;
}

int msb = (int) log2(N); 找到變數 N 的最高有效位 (most significant bit) 位置。

計算最高有效位的值存入變數 a 中。

進入迴圈,使用逐位元掃描的方法,從最高有效位開始,逐位檢查是否可以將對應的值加入到 result 中,而不會讓 result 的平方超過 N。此方法的特色在於相較於逐一數值嘗試開平方根近似的方式,利用了二進制表示的特性,每次只需要檢查一個位元,因此能夠更有效地近似計算出整數平方根。

計算開平方根: (版本二)

int i_sqrt(int N)
{
+   int msb = 0;
+   int n = N;
+   while (n > 1) {
+       n >>= 1; // 將 n 右移一位,相當於除以 2
+       msb++; // 計數器加 1
+   }
-   int msb = (int) log2(N);
    int a = 1 << msb; 
    int result = 0;
    while (a != 0) {
        if ((result + a) * (result + a) <= N)
            result += a;
        a >>= 1; 
    }
    return result;
}

與版本一不同於不使用 log2 函式,而是使用迭代計算的方式找到最高位元。

這樣的優勢在於可以避免使用到浮點數運算,也可以在不支持 log2 的環境中運行。

計算開平方根: (版本三)

這個版本的開方根利用 Digit-by-digit calculation 的概念實作開平方根。

首先若要對

x(
x0
) 做開平方根 ,可以假設
x=N2
N
即為欲求得的平方根數值,接著,將
N
改寫為 2 的羃次和,即:

N2=(an+an1+an2+...+a0)2,am=2m or am=0

若將

(an+an1+an2+...+a0)2 做展開,透過矩陣觀察:

[a0a0a1a0a2a0...ana0a0a1a1a1a2a1...ana1a0a2a1a2a2a2...ana2a0ana1ana2an...anan]

主對角線元素:

[a0a0a1a1a2a2    anan]

其餘元素:

[a1a0a2a0...ana0a0a1a2a1...ana1a0a2a1a2...ana2 a0ana1ana2an...]

將主對角線元素與其於元素分開討論可將原式整理成:

N2=(i=0nai)2=i=0nai2+20i<jnaiaj

其中,

i=0nai2 為對角線上的平方項,另外
20i<jnaiaj
為其餘的元素交叉相乘展開項。

接著將等式拆解為:

i=0nai2+20i<jnaiaj=an2+i=0n1ai2+2ani=0n1ai+20i<jn1aiaj

移項之後做觀察:

i=0nai2+20i<jnaiaj=an2+2ani=0n1ai+(i=0n1ai2+20i<jn1aiaj)

觀察括號內的數學式,可將

(i=0n1ai2+20i<jn1aiaj) 改寫為:
(i=1nai)2

最終整理得到:

N2=a02+2a0(i=1nai)+(i=1nai)2

Pm=an+an1+...+am

則所求

N=P0

接著將計算式整理成:

N2=P02=a02+2a0(P12)+(P1)2

若推展成一般式可得:

Pm2=am2+2amPm+1+Pm+12=Pm+12+am(2Pm+1+am)

Pm2=Pm+12+am(2Pm+1+am)

可以將

am(2Pm+1+am) 令為
Ym
,則 
Pm2=Pm+12+Ym
 。

若從從

m=n 一路嘗試計算到
m=0
每一輪透過
Ym
 得到下一輪次的
Pm2
並測試
Pm2N2
是否成立,最終即可找到所求。

然而,每輪計算

Pm2 的成本過高,若將
N2Pm2
計算結果令為
Xm
,則可推得
Xm=N2Pm2=N2(Pm+12+Ym)
,最終推得遞迴式:
XmXm+1Ym

這樣一來透過方程式

Ym=Pm2Pm+12=2Pm+1am+am2 ,紀錄上一輪的
Pm+1
來計算
Ym
以這樣的方式避免較高的運算成本。

為了實現演算法設計,進一步將

Ym 拆成
cm
dm
,得到:

cm=Pm+12m+1

dm=(2m)2

Ym={cm+dmif am=2m0if am=0

可以藉由位元運算從

cm,dm 推出下一輪
cm1,dm1
,再利用
cm1,dm1
計算出
Ym
最終推得
am

  • cm1=Pm2m=(Pm+1+am)2m=Pm+12m+am2m={cm/2+dmif am=2mcm/2if am=0

  • dm1=dm4

綜合以上方法使用演算法尋求

P0 (
an+an1+...+a0
) ,從
Pn
的初始條件:

  • Xn+1=N2

  • Xn=Xn+1Yn

    • Xn=Xn+1(cn+dn)
  • cn=0

    • Pn+1=0cn=0
  • dn=an2=(2n)2=4n


int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= AAAA) {
        int b = z + m;
        z >>= BBBB;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

演算法中 z 對應到上述推導的

cn, 而 m 對應到上述推導的
dn

由於初始的

cn=0, 將 z 設為 0:

int z = 0;

另一方面,

dn=an2=(2n)2=4n ,因此可以利用以下程式碼計算 m:

int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL);


    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= AAAA) {
        int b = z + m;
        z >>= BBBB;
        if (x >= b)
            x -= b, z += m;               
    }

在迴圈中, int b = z + m;b 對應到推導中的

Ym 。 而由上述推導
cm1={cm/2+dmif am=2m cm/2if am=0
可知,無論
am
結果為何,都需要將
cm/2
,因此 z >>= BBBB; 就是
cm/2
的實作, BBBB 應替換為 1 。

另外,由

dm1=dm4 知道每一輪迴圈,需要將變數 m 除以 4, m >>= AAAA 應改為 m >>= 2

    if (x >= b)
        x -= b, z += m; 

以上條件判斷中, if (x >= b)

Xm+1>Ym
XmXm+1Ym
N2>Pm2
,因此將
am
加入 z 當中,因為
c1=P0=an+an1+...+a0
因此最終所求即 z


嘗試用 ffs / fls 取代 __builtin_clz

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

__builtin_clz(x) 函式回傳 x 的最高有效位前面連續的 0 位元的數量,那麼 31 - __builtin_clz(x) 就是最高有效位的位置。

同樣 fls(x) - 1 也可以找到最高有效位的位置,但 fls() 由索引值 1 開始計算,因此需要將 fls(x) - 1,從而計算需要左移多少位元才能得到最接近且不大於 x 的 2 的冪次數。

因此可以將程式碼改寫為:

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
+   int shift = (fls(x) - 1);
+   for (int m = 1U << ((shift) & ~1U); m; m >>= 2) {
-   for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

在 Linux 核心找出對整數進行平方根運算的程式碼

lib/math/int_sqrt.c,找到整數平方根運算的程式碼:

unsigned long int_sqrt(unsigned long x)
{
	unsigned long b, m, y = 0;

	if (x <= 1)
		return x;

	m = 1UL << (__fls(x) & ~1UL);
	while (m != 0) {
		b = y + m;
		y >>= 1;

		if (x >= b) {
			x -= b;
			y += m;
		}
		m >>= 2;
	}

	return y;
}

程式碼風格與實做原理與測驗一(版本三)類似,比較值得注意程式碼使用到 __fls(x) 來找到需要位移的位元數量,閱讀過去探討過關於 ffs 及 __ffs 加雙底線與否的不同 了解ffs__ffs (是否加雙底線) 的不同之處:參考 bitops 系列對外的介面: arch/arc/include/asm/bitops.h 中的註解得知:

  • __XXX 系列: 以 index 0 索引,範圍為 0 ~ {length of input type - 1}
  • XXX 系列: 以 index 1 索引,範圍為 1 ~ {length of input type}

由此可知使用 __fls(x) 來找到需要位移的位元數量,不需要 - 1 。

測驗二

mod 10div 10 連續操作

本測驗針對正整數在相鄰描述進行 mod 10div 10 操作的場景進行探討。若要優化這個計算情況,最直觀的方式是使用餘式定理。

根據餘式定理:
被除數 = (商 * 除數) + 餘數

對應到程式碼就是:
tmp = (tmp / 10) * 10 + (tmp % 10)

因此可以合併 mod 10 和 div 10 操作,改寫為以下程式碼:

carry = tmp / 10;
tmp = tmp - carry * 10;

若使用 bitwise operation 實作以上除法,會發現由於 10 存在因數 5 並非 2 的羃,因此可能會產生誤差。

測驗中題到了 tmp 不可能會大於 19 ,因此只需要考慮 19~0 的情況即可。其原因為:

  • tmp 是透過計算 (b[i] - '0') + (a[i] - '0') + carry 得到的。 b[i]a[i] 分別是字符 '0' 到 '9' 之間的數字字符,對應的數值範圍是 0 到 9。

  • 所以 (b[i] - '0') 和 (a[i] - '0') 的值範圍都是 0 到 9。

  • 將兩個 0 到 9 之間的數相加,最大值為 9 + 9 = 18。

  • 再加上最大可能的進位值 1,最大結果就是 18 + 1 = 19

接著,繼續針對方法繼續探究,針對此問題,提出了猜想:

  • 我們的目標是找到一個適當的除數 q ,使得 tmp / q 的結果至少在小數點後一位是精確的。
  • 假設最大的被除數是 n ,我們設 l 是一個比 n 小的非負整數。
  • 現在考慮兩個數 ab 和 cd ,其中 a 和 c 是十位數 , b 和 d 是個位數。
  • 如果存在一個除數 q ,使得 cd / q 的結果在精確度範圍內,那麼 ab / q 的結果也應該在精確度範圍內。

假設:

  • n=ab
    (
    a
    是十位數
    b
    是個位數)
  • l=cd
    (
    c
    是十位數
    d
    是個位數)

以下證明上述猜想:

a.bnxa.b9na.b9xna.b

分別對左右不等式進行探討:

na.b9xna.b

  • 右不等式:
    • xna.b
      得知
      x10
      必然再精度以內。
  • 左不等式:
    • na.b9x

接著討論

c.dlxc.d9,若我們將
na.b9
代入
x
可將不等式改寫為:

c.dl×a.b9nc.d9

分別將

l
n
替換為
cd
ab
可表現為:

c.dcd×a.b9abc.d9

  • 右不等式:

    • cd×a.b9abc.d9
      , 首先將不等式改寫為:
      cd×a.b0.09abc.d+0.09cd×(110+0.09a.b)c.d+0.09
      ,又透過分配律可得到:
      c.d+0.09cdabc.d+0.09
      ,由於
      ab>cd
      因此
      0.09cdab0.09
      ,上述不等式必成立。
  • 左不等式:

    • c.dcd×a.b9abc.dcd×(110+0.09a.b)c.dc.d+0.09a.b
      明顯成立。

由上述證明可得知,若 tmp 不可能會大於 19 ,只須透過不等式:

1.919x1.999.55x10 即可得知,除數介於
9.55
10
之間皆可程式中達到相同效果。

找除數的方法使用 bitwise operation

2Na 找到介於
9.55x10
的除數,若欲處理得數字為
n
商式可以寫成
an2N

2N=128,a=13,128139.84 為一個可用的除數,由於 13 可以拆成
13=8+4+1=23+22+20
2
的羃相加,因此範例程式中透過 (tmp >> 3) + (tmp >> 1) + tmp 得到
13tmp8
再將此式乘上 8 (向左位移 3 bits) 即可得到
13tmp
,只要再將其除以
128
(
27
) 即可得到目標商式
13tmp27

包裝後函式探討

#include <stdint.h>
void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

    *div = (q >> CCCC);
    *mod = in - ((q & ~0x7) + (*div << DDDD));   
}

uint32_t x = (in | 1) - (in >> 2);: 這行程式碼初始化 x。表達式 (in | 1) 確保 in 是奇數(將最低位設置為 1),然後減去向右移位 2 的 in(相當於將 in 除以 4),得到一個大約等於

in×0.75 (
34
) 的近似值。

uint32_t q = (x >> 4) + x; 相當於

3424in+34in 亦等於
5164in
若換算為小數約為
0.79in
,因此目前 q 值接近
810in

    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x; 

接著不斷透過持續的 bitwise 操作使的 q 值持續逼近

810in

*div = (q >> CCCC); 為了使商值 *div 正確被指定,CCCC 應替換為 3 得到

810in×18=in10

最後 *mod = in - ((q & ~0x7) + (*div << DDDD)); 應為透過餘式定理的餘數計算。根據餘式定理:
餘數 = 被除數 - 除數*商, ((q & ~0x7) + (*div << DDDD)); 計算的就是除數*商的部份也就是

quotient10((q & ~0x7) 的操作即 q & 0xFFFFFFF8 即將 q 最後三個位元清 0 ,這樣的作法使得目前 q 等價於
quotient8
,由於所求為
quotient10quotient(8+2)
因此 (*div << DDDD)) DDDD 應替換為 1 相當於
quotient2
以符合預期。

TODO:撰寫不依賴任何除法指令的 % 9 (modulo 9) 和 % 5 (modulo 5) 程式碼。

測驗三

ilog2 以 2 為底的對數 (版本一)

int ilog2(int i)
{
    int log = -1;
    while (i) {
        i >>= 1;
        log++;
    }
    return log;
}
  • 首先將 log 設為 -1。這樣做是為了在輸入值為 0 時,函式回傳 -1。

  • 接著進入一個 while 迴圈,當 i 不為 0 時持續執行迴圈體內的操作。

  • 在迴圈內, i 右移一位 (i >>= 1)。這個操作相當於將 i 除以2。

  • 每執行一次右移操作,就將 log 的值加1。

  • i 變為 0 時,迴圈終止。

  • 最後,函式即可求得輸入值 i 的對數值並回傳。

ilog2 以 2 為底的對數 (版本二)

static size_t ilog2(size_t i)
{
    size_t result = 0;
    while (i >= AAAA) {
        result += 16;
        i >>= 16;
    }
    while (i >= BBBB) {
        result += 8;
        i >>= 8;
    }
    while (i >= CCCC) {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) {
        result += 1;
        i >>= 1;
    }
    return result;
}

這段程式碼每一次檢查一半的位元數量並進行 bitwise 位移,這種作法的優點是計算速度更快,因為它將對數值的計算分成了多個階段,每個階段只需要處理一小部分位元,而不是像前一種方法那樣逐位處理整個數值。
分段計算:這段程式碼將對數值的計算分成了多個階段,每個階段對應不同的位移量(16, 8, 4, 1)。這種做法可以加速計算過程,尤其是在處理大數值時。
這段程式碼將對數值的計算分成了多個階段,每個階段對應不同的位移量。可以觀察程式碼從最高 16 位元進行判別,接著是 8 位元 4 位元,直到最後一個位元為止。透過這樣的觀察我們可以知道 AAAABBBBCCCC 分別對映為

216
28
24

可以發現這樣的方法就是尋找 i 最高有效位的位置。

Linux 核心 log2 的相關程式碼

linux/log2.h 中可以找到 log2 的相關實作。

int __ilog2_u32(u32 n)
{
	return fls(n) - 1;
}

int __ilog2_u64(u64 n)
{
	return fls64(n) - 1;
}

可以看到 log2 使用到測驗一提到的 fls 也就是透過fls(x) - 1 找到最高有效位的位置達成與 ilog2 以 2 為底的對數 (版本二) 相同的效果。

測驗四

EWMA 理解

EWMA (指數加權移動平均) 是一種統計資料取平均的手法,其數學定義如下:

St={Y0t=0αYt+(1α)St1t>0

其中:

  • St
    為第
    t
    個時間點的 EWMA 值
  • Yt
    為第
    t
    個時間點的觀測值
  • α
    為歷史資料加權常數,介於0與1之間

α 值越大時,EWMA 會給予較多的權重於最近的觀測值,因此計算出的平均曲線會較為敏感,能夠快速反映最新的數據變化。反之,若
α
值較小,則 EWMA 會給予較多權重於歷史數據,計算出的平均曲線會較為平滑,變化也相對較小。

EWMA 實作

閱讀並理解測驗四中對於 EWMA 實作。首先,先觀察結構體 ewma

struct ewma {
    unsigned long internal;
    unsigned long factor;
    unsigned long weight;
};

結構體中使用 unsigned long 除存所有參數,使用 2 的羃來除存所有參數以及權重。

void ewma_init(struct ewma *avg, unsigned long factor, unsigned long weight)
{
    if (!is_power_of_2(weight) || !is_power_of_2(factor)) 
        assert(0 && "weight and factor have to be a power of two!");

    avg->weight = ilog2(weight); 
    avg->factor = ilog2(factor);
    avg->internal = 0; 
}

ewma_init 函式用於初始化結構體中的參數。在函式內部,我們觀察到對要初始化的參數進行了檢查,確保其值是 2 的冪次方。這樣做是為了後續使用位元操作來提高效能,代替乘除法的運算。接著,將這些參數轉換為對數形式,以便後續的處理。

值得注意的是,透過程式碼中對於 2 的冪次方的檢測,可以得知實作希望使用定點數進行加權平均的計算。此外,在函式的註解中提到,factor 參數被用作準備定點數運算所需的平移值。

struct ewma *ewma_add(struct ewma *avg, unsigned long val)
{
    avg->internal = avg->internal
                        ? (((avg->internal << EEEE) - avg->internal) +
                           (val << FFFF)) >> avg->weight
                        : (val << avg->factor);
    return avg;
}

ewma_add 是實際執行 EWMA 計算的函式,其中 internal 對應到

St;而 val 對應到
Yt
。我推測 (((avg->internal << EEEE) - avg->internal) + (val << FFFF)) >> avg->weight 即為上述數學定義:
αYt+(1α)St1
的實作方式。

我注意到當 avg->internal 為 0 時,函式會執行 (val << avg->factor);,也就是說當初始計算 EWMA 時尚未有任何資料,直接將目前時間點的觀測值加入。這時函式將 val 向左位移,因此我推測 (val << FFFF) 也應該將 val 向左位移,由此可知 FFFF 應該替換為 avg->factor

接著,我繼續探究 ((avg->internal << EEEE) - avg->internal) 程式碼部份。假設將 EEEE 暫時設置為變數

x,則由於 weight 以對數方式儲存 (((avg->internal << EEEE) - avg->internal) + (val << FFFF)) >> avg->weight 在數學上的意義為:
((St12xSt1)+Yt)×12weight
。然而目標數學方程式應該為:
αYt+(1α)St1
,觀察可發現後者的
α
應為前式的
12weight
。因此,
(St12xSt1)×12weight
應該與
(112weight)St1
等價,由此可推得
x=weight
,因此 FFFF 應該替換為 avg->weight

在 Linux 核心原始程式碼找出 EWMA 的相關程式碼

linux/average.h 可以找到 EWMA 實作程式碼。

#define DECLARE_EWMA(name, _precision, _weight_rcp)	

linux/average.h 定義了 DECLARE_EWMA 巨集,這個巨集接受了三個參數, name(用於生成的 struct 和函式名稱)、 _precision (表示用於儲存小數部分的位元數)和 _weight_rcp (一個 2 的冪,決定了新舊值的加權)。


struct ewma_##name {						\
	unsigned long internal;					\
};	

linux/average.h 定義了一個結構體 ewma_,包含一個 unsigned long 型態的成員,用於儲存 EWMA 值。


static inline void ewma_##name##_init(struct ewma_##name *e) { \
    BUILD_BUG_ON(!__builtin_constant_p(_precision)); \
    BUILD_BUG_ON(!__builtin_constant_p(_weight_rcp)); \
    /* \
     * Even if you want to feed it just 0/1 you should have \
     * some bits for the non-fractional part... \
     */ \
    BUILD_BUG_ON((_precision) > 30); \
    BUILD_BUG_ON_NOT_POWER_OF_2(_weight_rcp); \
    e->internal = 0; \
}	

ewma_init() 函式用於初始化結構實例,將 internal 設為 0。並使用了 BUILD_BUG_ON 巨集,在編譯時檢查 _precision 和 _weight_rcp 參數是否符合要求。


ewma_##name##_read(struct ewma_##name *e) {
    BUILD_BUG_ON(!__builtin_constant_p(_precision));
    BUILD_BUG_ON(!__builtin_constant_p(_weight_rcp));
    BUILD_BUG_ON((_precision) > 30);
    BUILD_BUG_ON_NOT_POWER_OF_2(_weight_rcp);
    return e->internal >> (_precision);
}	

ewma_read() 函式用於讀取 EWMA 值。由於 EWMA 實作使用了定點數運算,因此 internal 成員儲存了一個經過左移_precision 位的值。internal 成員儲存了經過放大的 EWMA 值(包含整數和小數部分),而 ewma_read() 的右移操作則是為了將它縮小回原始的整數 EWMA 值。


static inline void ewma_##name##_add(struct ewma_##name *e,
                                     unsigned long val)
{
    unsigned long internal = READ_ONCE(e->internal);
    unsigned long weight_rcp = ilog2(_weight_rcp);
    unsigned long precision = _precision;

    BUILD_BUG_ON(!__builtin_constant_p(_precision));
    BUILD_BUG_ON(!__builtin_constant_p(_weight_rcp));
    BUILD_BUG_ON((_precision) > 30);
    BUILD_BUG_ON_NOT_POWER_OF_2(_weight_rcp);

    WRITE_ONCE(e->internal, internal ?
        (((internal << weight_rcp) - internal) +
            (val << precision)) >> weight_rcp :
        (val << precision));
}

ewma_add() 函式是 EWMA 計算關鍵,用於將新值納入 EWMA 計算。它首先讀取 internal 值,根據 _weight_rcp 的值決定歷史資料與目前資料的加權。

相關應用程式碼

ath11k/core.h 找到 DECLARE_EWMA(avg_rssi, 10, 8) 定義了 EWMA 結構體,由 linux/average.h 可得知 fixed-precision values 為 10 而

weight=1weightrcp=18

  • ath11k: 高通 IEEE 802.11ax 裝置的 Linux 驅動程式:
    根據 Wireless Wiki,ath11k 是針對高通的 IEEE 802.11ax 無線網路裝置所設計的 Linux 驅動程式。它能夠支援在 SoC 類型裝置中的 AHB 匯流排和 PCI 接口。ath11k 基於 mac80211,這是 Linux 核心中用於無線網路裝置的通用框架。

  • avg_rssi 和 EWMA :
    我參考了 Received Signal Strength Indicator 一文來理解 RSSI。根據該資料,RSSI 是衡量設備從接收端接收訊號能力的指標,用於評估無線通訊中訊號的強度和品質。在 ath11k 的程式碼中,有一個名為 avg_rssi 的變數,被用來儲存指數加權移動平均值 (EWMA)。透過 EWMA,能夠平滑訊號強度的變化,提供更穩定的接收訊號強度指標 (RSSI)。

測驗五

ceil_ilog2 程式碼理解

ceil_log2 這個程式碼實現了一個函式,用於計算給定的 32 位元無號整數

x 的最小次方指數值向上進位的結果。也就是說,對於傳入的參數
x
,回傳最小的整數 n,滿足
2nx

可以注意到函式的最開始將 x 減 1,這樣可以確保當 x 是 2 的冪次時,計算出的指數正確。

r = (x > 0xFFFF) << 4;
x >>= r;
shift = (x > 0xFF) << 3;
x >>= shift;
r |= shift;
shift = (x > 0xF) << 2;
x >>= shift;
r |= shift;

接著根據以上程式碼的操作,我們可以觀察到類似於測驗三中的使用以 2 為底的對數 (版本二) 的二分搜尋法來找到最高位元位置。然而,與測驗三程式碼不同的是,這裡使用變數 r 來記錄位移量,而 (x > 0xFFFF) << 4; 的操作等價於以下程式碼:

while (i >= 65536) {
    result += 16;
    i >>= 16;
}

在這裡,(x > 0xFFFF) 的結果是一個布林值。如果 x 的高 16 位元不為 0,則 (x > 0xFFFF) 的結果為 True,亦即等於 1。因此,位移量 r 被設定為 1 << 4,即

16,達到與測驗三 log2 程式碼相同的效果。

shift = (x > 0xFF) << 3;
x >>= shift;
r |= shift;

值得注意的是,在程式碼中除了執行位移操作外,還將目前的位移量與變數 r 做了

OR 運算,即 r |= shift;。由於位移量都是 2 的冪次方,因此這樣的
OR
運算等同於對位移量進行加法操作(result +=)。因此,在進行位移後,程式碼將持續累加位移量,以找到最高位元位置。

return (r | shift | x > GGG) + 1;

最終函式回傳,總位移量 + 1,然而若對照測驗三 log2 程式碼可發現,少了一個條件判斷:

while (i >= 2) {
    result += 1;
    i >>= 1;
}

因此 | x > GGG 即判斷

x 是否
2
,也因此 GGG 應填入
1
以符合預期。

改進程式碼

試想上述程式碼若傳入參數

x=0 時,程式碼的第一行, x-- 會將 x 值改變為 0xFFFFFFFF ,這樣一來,會使得接下來的迴圈以及條件判斷不符合預期,因此需要對此做修正。

簡單的修正方法即為加入 if 條件判斷,避開

x=0 時做減法,然而這樣的方式並不符合 branchless

我嘗試將程式碼做以下更動:

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;
+   x = x - (x > 0);
-   x--;
    r = (x > 0xFFFF) << 4;                                                                                                                                    
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > GGG) + 1;       
}

加入 x = x - (x > 0); 後使得在

x>0 時才會產生布林值 1 ,達到減 1 的效果,並仍是 branchless

第四週測驗題

完整題目

測驗一

Population count 程式碼理解

population count 簡稱 popcount 或叫 sideways sum,是計算數值的二進位表示中,有多少位元是 1。

閱讀到關鍵程式碼:

n = (v >> 1) & 0x77777777; v -= n; n = (n >> 1) & 0x77777777; v -= n; n = (n >> 1) & 0x77777777; v -= n;

不了解為何 n = (v >> 1) & 0x77777777 即可將數值分為四個位元一個單位做減法,並透過 v -= n 即可求得

vv2 。因此我將數學式列出,最一開始,傳入值 v 即為
231b3+230b2+229b1+...+20b0
而將 v >> 1 可得
0+230b2+229b1+...+21b1

而我們可以得知 (v >> 1) & 0x77777777; 結果為:

    0 b_31 b_30 b_29 b_28 b_27 ... b_4 b_3 b_2 b_1
&   0    1    1    1    0    1       0   1   1   1
---------------------------------------------------
    0 b_31 b_30 b_29    0 b_27 ...   0 b_3 b_2 b_1

寫成數學式為:

(0+230b31+229b30+228b29)+(0+226b27+225b26+224b25)+...+(0+22b3+21b2+20b1)

若持續重複執行 n = (n >> 1) & 0x77777777; v -= n; 可以分別得到:

(0+0+229b31+228b30)+(0+0+225b27+224b26)+...+(0+0+21b3+20b2)

(0+0+0+228b31)+(0+0+0+224b27)+...+(0+0+0+20b3)

因此,若以四項為單位做相減即可得到:

(23b3+22b2+21b1+20b0)(22b3+21b2+20b1)(21b3+20b2)20b3

並對應到以每 4 個位元 (nibble) 為一個單位計算 1 的個數。

接著透過一系列位移以及 bitmask 操作可得所求的 popcount 值。

Hamming Distance

int hammingDistance(int x, int y)
{
    return __builtin_popcount(x ^ y);
}

Hamming Distance 是指這兩個數字對應位元的不同位置的個數。例如:數字 3 (二進制為

0011 )和數字 5 (二進制為
0101
)的漢明距離為 2,因為它們在第一位和第三位不同。

在程式碼中,使用了位元運算子 XOR(

) 來找出兩個數字在哪些位置不同。對於任何兩個位元,只有當它們不同時, XOR 的結果才會是 1。因此,
xy
的結果會是一個數字,其二進制表示中 1 的位置就是 x 和 y 不同的位置。

接著程式使用 __builtin_popcount 函式來計算

xy 中有多少個 1,也就是 x 和 y 有多少個不同的位元。這個函數的實作方式是直接對輸入的數字計算其二進制表示中 1 的個數,效率相當快。因此,這一行程式碼實際上就是在快速計算出 x 和 y 的漢明距離。

int totalHammingDistance(int* nums, int numsSize)
{
    int total = 0;;
    for (int i = 0;i < numsSize;i++)
        for (int j = 0; j < numsSize;j++)
            total += __builtin_popcount(nums[i] ^ nums[j]); 
    return total >> AAAA;
}

以上程式碼用於計算 nums 陣列中所有數字之間的漢明距離總和。需特別注意的是,程式中重複考慮了兩個數字之間的距離,因此最終結果為總距離的兩倍。為了將結果除以 2,使用右移操作符將總和向右移動1位。因此,AAAA 應填入 1。

Total Hamming Distance 程式碼改進

從位元展現的樣貌,觀察 Total Hamming Distance 的規則:

n 'th bit 4 3 2 1 0
Input 7 0 0 1 1 1
Input 5 0 0 1 0 1
Input 10 0 1 0 1 0
Input 17 1 0 0 0 1

首先,我們觀察第 0 個位元位置。在這個位置上,數字 7、5、17 都是 1,而數字 10 是 0。進一步探究 Hamming Distance:







so



1

1



0

0



1->0










so



1

1



0

0



1->0




11

1



11->0




111

1



111->0




從上圖可理解為:每個 1 位元可以與 1 個 0 位元產生距離為 1 的 Hamming Distance。因此,由於有 3 個 1 位元,總 Hamming Distance 為

1×3

接下來,我們觀察第 1 個位元位置:







so



1

1



0

0



1->0




00

0



1->00










so



1

1



0

0



1->0




00

0



1->00




11

1



11->0




11->00




在第 1 個位元位置上,數字 7 和 10 都是 1,而數字 5 和 17 是 0。觀察上圖可理解為:每個 1 位元可以與 2 個 0 位元產生距離為 1 的 Hamming Distance。因此,由於有 2 個 1 位元,總 Hamming Distance 為

2×2

總結來說,可以計算每個位置上的 1 位元數量,並將每個位置的 1 位元數量乘以 0 位元數量,以求得 Total Hamming Distance。

根據以上方法實作改進後的程式碼:

int totalHammingDistance_(int* nums, int numsSize)
{
    int total = 0;
    for (int i = 0;i < 32;i++) {
        int ct = 0;
        for (int j = 0; j < numsSize;j++)
            ct += ((nums[j] >> i) & 1);
            
        total += ct * (numsSize - ct);
    }
    return total;
}

撰寫程式驗證其正確性:

commit 3e675c3

使用 perf 分析其效能差異,針對 10000 筆數字進行 Total Hamming Distance 計算。

改進前:

 Performance counter stats for './totalHammingDistance':

     1,141,370,997      cycles                                                                

       0.425123945 seconds time elapsed

       0.421050000 seconds user
       0.003972000 seconds sys

改進後:

 Performance counter stats for './totalHammingDistance_':

         4,986,285      cycles                                                                

       0.003329801 seconds time elapsed

       0.000000000 seconds user
       0.003289000 seconds sys

可以發現改進後的程式碼大幅減少了 clock cycles 數量,同時也縮減了執行時間。

測驗二

Remainder by Summing digits

為了在不使用除法的情況下計算某數除以另一個數的餘數,使用了模同餘的概念。

當除數為 3 時,我們可以觀察到

11(mod  3)
21(mod  3)
。 根據
acbd(mod  m)
的性質,我們可以進行以下推導:

2k{1(mod  3),  k 為偶數1(mod  3),  k 為奇數

當我們將

n 以二進位表示時,可以寫為
bn1bn2bn3...b1b0

根據前述推導,我們得知當

k 為偶數時,同餘為 1;當
k
為奇數時,同餘為 -1。因此,我們可以得到以下表達式:

n=bn12n1+bn22n2++b323+b222+b121+b0b3+b2b1+b0 (mod 3)

接著,我們使用以下定理進行化簡:

popcount(xm)popcount(xm)=popcount(xm)popcount(m)

因此,n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) 可以寫為 n = popcount(n ^ 0xAAAAAAAA) - 16

然而,以上計算結果的範圍會落在 -16 至 16 之間。考慮到希望餘數為正數的情況,我們需要加上一個 3 的倍數以確保餘數在同餘情況下為正數。

至於為何要加上 39 ? 參閱 《Hacker's Delight》中的說明:

We want to apply this transformation again, until n is in the range 0 to 2, if possible. But it is best to avoid producing a negative value of n, because the sign bit would not be treated properly on the next round. A negative value can be avoided by adding a sufficiently large multiple of 3 to n. Bonzini’s code, shown in Figure 10–21, increases the constant by 39. This is larger than necessary to make n nonnegative, but it causes n to range from –3 to 2 (rather than –3 to 3) after the second round of reduction. This simplifies the code on the return statement, which is adding 3 if n is negative. The function executes in 11 instructions, counting two to load the large constant.

在文中指出,將常數增加了 39。這個值比僅使非負數所需的常數值更大,可使得在第二輪計算後,值落在 -3 到 2 的範圍內(而非 -3 到 3),也因此簡化了程式碼,只需在 n 為負數時加 3 即可。

另一種方法是直接將 0 到 32 的所有數字除以 3 得到的餘數事先儲存在一個 lookup table 中,這樣就可以直接透過查表的方式找到對應的餘數。程式碼如下所示:

int mod3(unsigned n)
{
    static char table[33] = {2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1 };
    n = popcount(n ^ 0xAAAAAAAA);
    return table[n];
}

井字遊戲程式碼理解

此程式實現的是一個井字遊戲的變體,該變體將遊戲目標從傳統的在3x3棋盤上實現三個連續的棋子,擴展到了在任何八條可能的直線上達到三個連續的棋子。這使得玩家的策略更加多樣化,因為一步棋可能會影響多條直線。

程式中設計了 available_moves[] 陣列,此陣列即為 3x3 井字遊戲上的九個位置選擇,因此我們可以看到 play_random_game 函式中,每做出一個選擇將其選擇從陣列中移除:

uint32_t move = available_moves[i];
available_moves[i] = available_moves[n_moves - 1];

實作方式就是將選擇的走法用最後一個可用的走法來替換,然後將最後一個可用的走法移到了被選擇的走法所在的位置。

board |= move_masks[move]; 這行程式碼中,將選定的走法 move 對應的連線狀態更新到了玩家的棋盤狀態中。在這個井字遊戲變體中,棋盤狀態不是按照傳統的九宮格形式表示,而是以8條可能的連線來表示。這意味著玩家的棋盤狀態存儲了對應於這8條連線的狀態,而不僅僅是傳統的棋子擺放狀態。

對於 board |= move_masks[move]; 這行程式碼,move_masks[move] 取得了選定走法 move 對應的連線狀態,然後將它與玩家的棋盤狀態 board 做位元或運算,從而將選定走法的影響更新到了玩家的棋盤狀態中。由於每個 move_masks 元素都包含了對應位置下棋可能影響的所有連線,因此這個操作有效地更新了玩家棋盤狀態中的所有可能連線。

static const uint32_t move_masks[9] = {
    0x40040040, 0x20004000, 0x10000404, 0x04020000, 0x02002022,
    0x01000200, 0x00410001, 0x00201000, 0x00100110,
};

move_masks 陣列中的每個元素代表了在將棋子放置到特定位置後,對於所有可能連線狀態的影響。每個元素的二進位表示描述了在該位置放置棋子後,連線狀態發生變化的情況。

0x40040040 為例,用圖示來進行說明:

考慮 0x40040040 九宮格棋盤中為左上角 (0) 的選擇:

1 2 3
1 0 1 2
2 3 4 5
3 6 7 8

此位置將有三種可能連線:

1 2 3
1 0 1 2
2 3 4 5
3 6 7 8

對應到 board 二進位表示即:
0111 0000 0000 0000 0000 0000 0000 0000

1 2 3
1 0 1 2
2 3 4 5
3 6 7 8

對應到 board 二進位表示即:
0000 0000 0000 0111 0000 0000 0000 0000

1 2 3
1 0 1 2
2 3 4 5
3 6 7 8

對應到 board 二進位表示即:
0000 0000 0000 0000 0000 0000 0111 0000

0x40040040 以二進位可表示為:
0100 0000 0000 0100 0000 0000 0100 0000 玩家棋盤狀態與其做

or 運算可更新棋子擺上棋盤 0 位置後對連線狀態的影響。

static inline uint32_t is_win(uint32_t player_board)
{
    return (player_board + BBBB) & 0x88888888;
}

勝利的條件判斷即 player_board 以四個位元為單位出現 0111 即判斷該玩家獲勝,可以看到程式碼將 (player_board + BBBB)0x88888888

and 運算,由此可知,當出現 0111 時需要將棋結果轉為 1000 ,而將 0111 + 1 即可達成此效果,因此 BBBB 應填入 0x11111111

Modulo 7 程式碼理解

TODO

測驗三

XTree

treeint.c 為二元樹測試程式,用來測量在不同的操作下,如插入、查找和刪除,二元樹的性能表現。

  • treeint_ops 結構,該結構包含指向各種樹操作函式的指標。
  • xt_opstreeint_ops 的實例 ,並將其函式指標設定為特定的實作函式。

xtree.[ch] 二元搜尋樹的實現,它採用了一些特定的策略來保持樹的平衡。

二元樹的結構定義包含 xt_tree 以及 xt_node ,值得注意的是在 xt_node 中加入了 hint 作為平衡參數。

程式使用不同函式實現二元樹的不同功能:

  • xt_create 函數創建一個空的樹。

  • xt_destroy 和 __xt_destroy 函數用來遞迴釋放樹中所有節點的記憶體。

  • xt_rotate_left 和 xt_rotate_right 函數用於節點的左旋和右旋操作,這是平衡樹的關鍵操作之一。

  • xt_update 函數根據節點的平衡因子進行相應的旋轉操作,以維持樹的平衡。

  • __xt_find 和 __xt_find2 函數實現了在樹中查找特定鍵值的節點。

  • __xt_insert 和 xt_insert 函數實現了向樹中插入新節點的功能。

樹的刪除操作:

刪除操作相對複雜,尋找替代節點(右子樹的最小節點或左子樹的最大節點)並進行替換。

if (xt_right(del)) {
    struct xt_node *least = xt_first(xt_right(del));
    if (del == *root)
        *root = least;

    xt_replace_right(del, AAAA);
    xt_update(root, BBBB);
    return;
}

if (xt_left(del)) {
    struct xt_node *most = xt_last(xt_left(del));
    if (del == *root)
        *root = most;

    xt_replace_left(del, CCCC);
    xt_update(root, DDDD);
    return;
}
  • 函式首先檢查被刪除的節點 del 是否有右子節點 (xt_right(del))

  • 若存在,它會找到右子樹中最小的節點(xt_first(xt_right(del))),這個最小節點將會取代要被刪除的節點。

  • 如果被刪除的節點是根節點(del == *root),則將根節點更新為這個最小節點 (*root = least)

  • 接著,函式呼叫 xt_replace_right 找到的最小節點來替換被刪除的節點。因此 AAAA 應該替換為 least

  • 最後,對替換後的新樹結構進行 xt_update ,以維持樹的平衡。由於替換操作後,least 節點被移動到了 del 的位置,而 least 的原位置現在由它的右子節點所取代,針對原位置的節點進行更新,因此BBBB 應替換為 xt_right(least)

同樣的,若刪除的節點有左子節點, CCCC 應該被替換為 most,表示將 most 節點放到 del 節點的位置。

如果被刪除的節點 del 有左子節點,會找到這個左子樹中的最大節點 most 來替代 del。而 DDDD 應替換為 xt_left(most)

最後,當欲刪除的節點沒有子節點時分為兩種情況進行處理:

  • 節點為根節點:如果要刪除的節點 del 正好是根節點,直接將根節點指標設置為 NULL,樹將變為空。
  • 節點非根節點:如果 del 不是根節點,那麼首先找到 del 的親代節點 parent ,接著判斷 del 是其親代節點的左子節點還是右子節點。如果 del 是左子節點,則將 parent 的左子節點指針設置為 NULL。如果 del 是右子節點,則將 parent 的右子節點指針設置為 NULL。此舉使得節點從二元樹中斷開。
  • 平衡更新:xt_update(EEEE, FFFF) 來更新樹的平衡。 EEEE 應傳入 root。由於刪除 del 後,需要重新平衡的是 del 的親代節點 parent。因此 FFFF 應該是 parent 節點。