Try   HackMD

Linux 核心專題: 重作第四次作業

執行人: Petakuo
解說影片

Reviewed by alenliao666

第三週測驗 2 ,有提到程式碼利用 d0d1d2 保留 lsb 數來的三個位元以利後續補回,想請問既然 d2 已保留這三個位元,為什麼還需要用到d0d1 呢?

因為在 (tmp >> 3) + (tmp >> 1) + tmp 中,我們總共做了 2 次的位元運算,假設 tmp 最低的三位元為 abc ,具體被省略掉的部分就分別為 0.abc0.c ,因此我們必須加上這兩個值,而這兩個值實際上為

8(a2+b4+c8+c2)=4a+2b+5c ,因此可由保留的位元得到,應為二進制加法的
4c+abc
,因此程式碼為 4d0 + d2
我原本的程式碼寫錯了,已更改。
Petakuo

Reviewed by 164253

第三週測驗 2 ,可以使用 barrett reduction ,不需要對特定的小數值逼近,且做除法和取餘時,如果直接寫出常數(包含字面常量、 const 變數等),現代編譯器會自動在編譯期做完計算並直接改在組合語言內。

我都不知道還有這個方法,非常實用,謝謝!
Petakuo

Reviewed by dockyu

第三週測驗 5,

x=0
x=1
時得到的值都是 1 ,這不是
log2
的正確答案

謝謝提醒!我已經將最後一行的 +1 更正為 x > 0 ,如此就不會對所有數進行 +1 的動作,在輸入

x=1 時就會得到正確的值。
至於
x=0
時的情況若要 return -Infinity 則需要用浮點數,因此這裡先以不讓它到 uint32_t 的最大值進行處理。
Petakuo

Reviewed by 56han

測驗例題選擇做四次 q = (q >> 8) + x,有特別的意義嗎?

而後續的 q = (q >> 8) + x 則是在對 0.8 逼近,做越多次就會越接近

有的,若逐一計算每次結果的值則可以發現 4 次的答案分別為 0.79998779296 0.79999995231 0.79999999981 0.79999999999 ,而再繼續做下去就會都是 0.8 了,因此而選擇做 4 次該運算。
Petakuo

Reviewed by gawei1206

關於第三周測驗 2 中,((((tmp >> 3) + (tmp >> 1) + tmp) << 3) + d0 + d1 + d2),這段程式碼是為了將右移損失的位元加回去,那我的問題是他真的有實際把損失的位元加回去,正確計算出 13tmp 嗎?

有的,但是在此例中,因為該值不會一直進位到第 8 個位元或是更高位,因此右移 7 個位元後並不會有影響。
Petakuo

任務簡介

重作第四次作業並彙整學員的成果

TODO: 重作第四次作業

重作並彙整其他學員的成果,適度標注出處,應重現實驗,確保程式碼的改進的確有效

第三週測驗題

測驗1

版本一

首先,利用 log2() 函式找到最高位的位元,接著由該位元開始測試,逐漸逼近答案,而逼近的方法為利用 if 函示判斷 (result + a) 的平方是否大於被要求開平方的值 N ,如果有,則繼續往低位進行逼近,如果沒有,則將原本的答案加上 a 進行更新,並同樣往低位進行逼近,如此做到 a 為 0 時即可得到 N 的平方根。

版本二

int msb = 0;
int n = N;
while (n > 1) {
    n >>= 1;
    msb++;
}

與版本一大致相同,唯改變了取最高位元的做法,將 log2() 用以上程式碼呈現,首先判斷 input 的 n 是否大於 1 ,有則將其右移一個位元,並將所求 +1 ,如此進行直到 n 只剩下 1 即代表該位元為最高位元,而所求也隨著每一輪迴圈的 +1 而得出。

版本三

利用 Digit-by-digit calculation 的方法計算,先將要求平方根的

N2 拆解成 :

N2=(an+an1+an2+...+a1+a0)2

並且先觀察較簡單的例子 :

(an+an1)2=an2+2anan1+an12=an2+[2an+an1]an1
(an+an1+an2)2=an2+[2an+an1]an1+[2(an+an1)+an2]an2

依此模式進行,我們最終就可以得到 :

N2=an2+[2an+an1]an1+[2(an+an1)+an2]an2+...+[2(i=1nai)+a0]a0

此時再令

Pm=an+an1+...+am ,則上述式子可以進一步被改寫為 :

N2=an2+[2Pn+an1]an1+[2Pn1+an2]an2+...+[2P1+a0]a0

並且

P0=an+an1+...+a0 即為所求
N

為求
P0
,我們可以由
Pm=Pm1+am
進行迭代,每一輪皆去判斷
Pm2N2
是否成立,若成立則將該位元設為 1 ,反之則設為 0 。
但此時由於去計算每一輪
Pm2
的成本太高,因此我們透過利用上一輪的資訊來做計算,首先定義
Xm=N2Pm2=Xm+1Ym

並且
Ym=Pm2Pm+12=2Pm+1am+am2

此時再將
Ym
改寫為
cm+dm
,其中
cm=Pm+1am+1,dm=am2
,因此下一位元即可被更新為
cm1=Pm2m=(Pm+1+am)2m=Pm+12m+am2m={cm/2+dmif am=2mcm/2if am=0

dm1=dm/4

到這裡可以發現,將 m = 0 代入,可得
c1=P020=P0=an+an1+...+a0=N
,即為所求。

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) /*m >>= AAAA*/ 
    {
        int b = z + m;
        z >>= 1; /*z >>= BBBB*/
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

在此程式碼中,利用 __builtin_clz(x) 找出 msb ,且 b 為上述推導過程中的

Ym ,因此 z 等同於
cm
而 m 等同於
dm
。由於每一輪
dm
都會除以 4 ,因此 AAAA 為 2 ,同樣道理,
cm
每一輪會除以 2 ,所以 BBBB 為 1 。

延伸問題 : 利用 ffs 取代 __builtin_clz ,使程式不依賴 GNU extension

static inline unsigned long __ffs(unsigned long word)
{
    int num = 0;
    if ((word & 0xffff) == 0) {
        num += 16;
        word >>= 16;
    }
    if ((word & 0xff) == 0) {
        num += 8;
        word >>= 8;
    }
    if ((word & 0xf) == 0) {
        num += 4;
        word >>= 4;
    }
    if ((word & 0x3) == 0) {
        num += 2;
        word >>= 2;
    }
    if ((word & 0x1) == 0)
        num += 1;
    return num;
}

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;

    int z = 0;
    for (int m = 1UL << ((31 - __ffs(x)) & ~1UL); m; m >>= 2) {
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

搭配閱讀 從 √2 的存在談開平方根的快速運算,注意裡頭的數學證明和 Mark Dickinson 提出的查表演算法。

測驗2

unsigned d2, d1, d0, q, r;

d0 = q & 0b1;
d1 = q & 0b111;
q = ((((tmp >> 3) + (tmp >> 1) + tmp) << 3) + 4*d0 + d1) >> 7;
r = tmp - (((q << 2) + q) << 1);

此段程式碼嘗試利用 bitwise 操作來取代正整數的除法和 mod 運算,這裡以 10 為例進行解說。
從數學的角度切入,除以 10 等同於乘以

110 ,因此我們必須找到在二進制中
110
的表示法,而因為不是所有數都能夠利用有限個 bits 來表達,所以只能取近似值來做運算,在此例中使用了
128139.84
來計算,接著,就是要想辦法透過 bit operation 來得出原數乘上
13128
的結果。
首先,
13128
可以被拆解成
138
再除以 16 ,而
138
又可以被拆解成
1+12+18
,如此一來,原數乘上
138
就可以用 q = (tmp >> 3) + (tmp >> 1) + tmp 計算出來,最後再透過 q = q << 4 來除以 16 即可。但若考慮到 mod 的運算,則在利用位元右移計算乘以 2 的冪時所被捨棄的位元就會造成計算上的錯誤,因此該程式碼利用了以下三行將原數最右邊的三個位元先保留下來。

d0 = q & 0b1;
d1 = q & 0b11;
d2 = q & 0b111;

在計算完

13tmp8 後,再將其乘上 8 (q = q << 3),目的是在於要加回被捨棄掉的 bits ,而在加完之後,再直接除以 128 (q = q >> 7) 即可。
至於 mod 的部分,可以直接利用除法原理得到,因為除法原理告訴我們
tmp=q×10+r
,所以餘數可由
r=tmpq×10
得出,因此在程式碼的最後寫了 ((q << 2) + q) << 1 ,即
(q×4+q)×2=q×10

將此概念包裝後,可以下方程式碼重新呈現 :

void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

    *div = (q >> 3); /**div = (q >> CCCC)*/
    *mod = in - ((q & ~0x7) + (*div << 1)); /*(*div << DDDD)*/
}

這段程式碼做的事情為將 in 這個數以 bitwise 操作同時除以和 mod 10 ,具體的想法為除以 10 等於乘以 0.1 ,因此可先乘上 0.8 再除以 8 。

提供數學證明!

第一行 (in | 1) - (in >> 2) 強制將 in 轉為奇數(不知道原因,且不管有沒有那個 +1 經過測試後的皆是正確的)並減去

in4 得到
3in4
,而第二行 (x >> 4) + x 則是在計算
3in64+3in4
,此計算結果為
51in640.79in
即代表前兩行做的事情為令一變數 q 等於 0.8x
而後續的 q = (q >> 8) + x 則是在對 0.8 逼近,做越多次就會越接近,最後需再將結果除以 8 ,因此 CCCC 為 3 。
餘數的部份,靠除法原理可知 *mod = in - 10*div ,因此 (q & ~0x7) + (*div << DDDD) 就會等於 10*div ,而 (q & ~0x7) 做的事情為將 q 最低位的 3 個位元清零,此做法等同於 (q >> 3) << 3 ,即 8*div ,因此 *div << DDDD 就必須等於 2*div , 最終可知 DDDD 為 1 。

延伸問題 : 練習撰寫不依賴任何除法指令的 % 9 (modulo 9) 和 % 5 (modulo 5) 程式碼,並證明在 uint32_t 的有效範圍可達到等效。

void divmod_9(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = in - (in >> 3);  /*in = (in*8/9)/8  8/9 = 7/8 + 7/512 */
    uint32_t q = (x >> 6) + x;
    x = q;
    q = (q >> 12) + x;
    q = (q >> 12) + x;
    q = (q >> 12) + x;
    q = (q >> 12) + x;
    
    *div = (q >> 3);
    *mod = in - (q & ~0x7) - *div;
    if(*mod == 10 || *mod == 9)
    {
        *div++;
        *mod -= 9;
    }
}

測試程式碼如下,首先撰寫出 randnum 來造一個隨機的 32 位元無號數,接下來利用實際的除法與 mod 運算來確認 divmod_9 是否正確 :

uint32_t randnum()
{
    uint32_t ran = 0;
    ran |= (uint32_t)(rand() & 0xFFFF) << 16;
    ran |= (uint32_t)rand() & 0xFFFF;
    return ran;
}

int main()
{
    uint32_t div, mod;
    for(int i = 0; i < 10; i++){
        uint32_t ran = randnum();
        uint32_t a = ran/9;
        uint32_t b = ran%9;
        divmod_9(ran, &div, &mod);
        printf("a = %d\ndiv = %u\nb = %d\nmod = %u\n", a, div, b, mod);
    }
    return 0;
}

列印出的 10 組數據如下 :

a = 129376363
div = 129376363
b = 3
mod = 3
a = 284115184
div = 284115184
b = 3
mod = 3
a = 410702193
div = 410702193
b = 6
mod = 6
a = 276433377
div = 276433377
b = 3
mod = 3
a = 58090291
div = 58090291
b = 2
mod = 2
a = 165404435
div = 165404435
b = 0
mod = 0
a = 122931853
div = 122931853
b = 6
mod = 6
a = 316698205
div = 316698205
b = 1
mod = 1
a = 905749
div = 905749
b = 5
mod = 5
a = 15525887
div = 15525887
b = 1
mod = 1

看起來是正確的,接著再改寫 main 的程式碼,把測試資料放大到 1000000 組 :

int main()
{
    uint32_t div, mod;
    int correct = 0;
    for(int i = 0; i < 1000000; i++){
        uint32_t ran = randnum();
        uint32_t a = ran/9;
        uint32_t b = ran%9;
        divmod_9(ran, &div, &mod);
        if(a == div && b == mod)
        {
            correct++;
        }
    }
    printf("%d/1000000\n", correct);
    return 0;
}

輸出的結果為 :

1000000/1000000

因此可驗證 divmod_9 為可行的程式碼。

而對於 mod5 也可以用同樣的方式驗證,程式碼如下 :

uint32_t randnum()
{
    uint32_t ran = 0;
    ran |= (uint32_t)(rand() & 0xFFFF) << 16;
    ran |= (uint32_t)rand() & 0xFFFF;
    return ran;
}

void divmod_5(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = in - (in >> 2);  /*in = (in*4/5)/4  4/5 = 3/4 + 3/64 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    
    *div = (q >> 2);
    *mod = in - (q & ~0x3) - *div;
    if(*mod == 6 || *mod == 5)
    {
        (*div)++;
        *mod -= 5;
    }
}

int main()
{
    uint32_t div, mod;
    int correct = 0;
    for(int i = 0; i < 1000000; i++){
        uint32_t ran = randnum();
        uint32_t a = ran/5;
        uint32_t b = ran%5;
        divmod_5(ran, &div, &mod);
        if(a == div && b == mod)
        {
            correct++;
        }
    }
    printf("%d/1000000\n", correct);
      return 0;
}

輸出的結果為 :

1000000/1000000

測驗3

版本一

首先將 log 設置為 -1 ,這是考慮到 input 為 0 的情況,因為

20 為 1 ,接著進入迴圈,若 input 不為 0 ,則將 log + 1 ,並且同時將 input 右移一個位元,也就是除以 2 ,在離開迴圈時,答案就會是最高位元的 1 位置,也等於對 input 取 log 的整數值。

版本二

static size_t ilog2(size_t i)
{
    size_t result = 0;
    while (i >= 65536) /*(i >= AAAA)*/
    {
        result += 16;
        i >>= 16;
    }
    while (i >= 256) /*(i >= BBBB)*/
    {
        result += 8;
        i >>= 8;
    }
    while (i >= 16) /*(i >= CCCC)*/
    {
        result += 4;
        i >>= 4;
    }
    while (i >= 2) 
    {
        result += 1;
        i >>= 1;
    }
    return result;
}

與版本一的概念相同,都是利用迴圈判斷是否將 input 右移,只是在這個版本利用多個判斷式子來加速進行,而判斷的依據為是否大於

216
28
24
以及
22
,因此 AAAA , BBBB , CCCC 分別為 65536 , 256 以及 16 。

版本三

int ilog32(uint32_t v)
{
    return (31 - __builtin_clz(v|1)); /*(31 - __builtin_clz(DDDD)*/
}

版本三利用了 GNU extension 中的函式 __builtin_clz 來改寫,該函式的功能為返回從最高位數來為 0 的個數,因此若利用總位元數減去該函數之結果就可以得到最高位為 1 的位置,即為所求,所以直覺上可以直接把 v 放入該函式中,但還必需注意到該函式在 input 為 0 時未定義,因此需將 v 改為 v|1 ,已確保 v 不為 0 ,藉此避免未定義的情況發生,因此 DDDDv|1

測驗4

struct ewma *ewma_add(struct ewma *avg, unsigned long val)
{
    avg->internal = avg->internal
                        ? (((avg->internal << avg->weight) - avg->internal) +
                           (val << avg->factor)) >> avg->weight
                        : (val << avg->factor);
    /*(((avg->internal << EEEE) - avg->internal) + (val << FFFF)) >> avg->weight : (val << avg->factor)*/
    return avg;
}

此程式碼用來計算 EWMA ,具體的數學定義為 :

St={Y0t=0αYt+(1α)St1t>0
St
為程式碼中的 internal
Yt
val << avg->factor ,會隨著 factor 值而改變,而
α
12avg>weight

首先,該程式碼應用了三元運算來對 t 是否為 0 的情況分類, 並且先將整個式子除以

α 做運算,因此在程式碼中若 t 不為 0 ,則需做 ((avg->internal << EEEE) - avg->internal) + (val << FFFF) ,即表示
1αSt1St1+Yt
,因此 EEEEavg->weightFFFFavg->factor ,算完後還需乘回原本的
α
,因此最後會有 >> avg->weight

測驗5

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;

    x--;
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > 1) + 1; /*(r | shift | x > GGGG)*/   
}

根據測驗3我們可以知道,若要求某數取 log2 的值只要找到該數最高位元的 1 即可,此程式碼也在做相同的事情,利用 rshift 記錄一共位移了幾個 bits ,最後再將他們相加得出結果。

首先,(x > 0xFFFF) << 4 是在判斷 x 的左半邊 16 個 bits 中是否有 1 存在,若有則將 r 設為 16 並且將 x 右半邊 16 個 bits 捨棄,若沒有則將 r 設為 0 , x 保持不變。後續就是不斷的對剩餘的位元數做相同的事情,也就是將該數切半並判斷,最後利用 or 運算將 rshift 以及 (x > 0x1) 相加,因此 GGGG 為 1 。

在程式碼的最後寫了 +1 是對應到了程式碼一開始的 x-- ,這裡考量了 input 本身為 2 的羃的情況,要注意到當沒有 x-- 且 input 為 2 的羃時,最後的結果會多 1 ,因此最後利用 +1 來修正 : 若 input 為 2 的羃,則加減 1 後答案不變,若不為 2 的羃,則加 1 可達到取 ceiling 的效果。

延伸問題 : 改進程式碼,使其得以處理 x = 0 的狀況,並仍是 branchless 。

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;

    x -= !!x;
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > 1) + 1;       
}

x-- 改寫為 x -= !!x ,利用兩次的邏輯運算 ! 來判斷 x 的值是否為 0 ,如此一來若 x 為 0 則 - 0 ,若 x 為 1 則 - 1 。

但上述程式碼在 input = 1 時會 output 1 ,以下為更正後的版本 :

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;

    x-=!!x;
    r = (x > 0xFFFF) << 4;
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > 1) + (x > 0);       
}

第四週測驗題

測驗1

unsigned popcount_branchless(unsigned v)
{
    unsigned n;
    n = (v >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;
    n = (n >> 1) & 0x77777777;
    v -= n;

    v = (v + (v >> 4)) & 0x0F0F0F0F;
    v *= 0x01010101;                                     

    return v >> 24;
}

這段程式碼用來計算一個 32 bits 無號整數中 1 的個數,也就是在實作 popcount() 函式。

首先,假設一 32 bits 無號整數

x=b31b30b2b1b0 ,則該數中 1 的個數可以將每個位元相加而得,也就是
n=031bn=n=031(2ni=0n12i)bn

而另一種做法為利用數學式去推導,可以得知
xx2x4x231
也是一種算法,該程式碼就是利用這個觀念求得。

直觀的做法為不斷的將 v 除以二,並且用原本的 v 去扣除它,就如上方程式碼中的 n = (v >> 1)v -= n ,不過在 32 bits 的數值上運作就必須重複做 32 次循環,因此該程式碼將其分為 8 組,也就是將每 4 個位元 (nibble) 視為一個單位去做計算,如此本該重複 32 次迭代的程式碼就能簡寫為上方的形式。

首先從公式的部份討論每組 nibble 該如何計算,因為只有 4 個 bits ,所以可以將公式改寫為

n=03bn=n=03(2ni=0n12i)bn
展開後可得
(23222120)b3+(222120)b2+(2120)b1+20b0

而若是以數學是推導,可以得知在 4 個 bits 的情況下可由
xx2x4x8
得出答案,且該式等同於
(23b3+22b2+21b1+20b0)(22b3+21b2+20b1)(21b3+20b2)20b3
並經過改寫後可得
(23222120)b3+(222120)b2+(2120)b1+20b0
,與上述推導出的公式相同,驗證了其正確性。

此程式碼便是以這個概念進行運算,首先,v >> 1 是在進行

x2 的運算,而每多做一次就會再除以 2 並取其底,由於我們讓每 4 個 bits 為一組,因此需要做 3 次至
x23
,而因為每次在除以 2 時皆會將 v 右移一位,如此做法會造成前三組 nibble 的最高位元在移動後都會等於下一組的最低位元,計算出的結果就會錯誤。
直接列出來解釋如下 :

b31 b30 b29 b28 ... b3 b2 b1 b0 // v
  0 b31 b30 b29 ... b4 b3 b2 b1 // v >> 1

右移一位後, b4 會成為第一組 nibble 的最高位元,但我們在運算時不需要 b4 ,因此需將其設定為 0 ,這裡的做法是和 0x77777777& 運算,如此一來便可以同時將八組 nibble 的最高位歸零,不過這裡使用 0xF7777777 也可以達到相同效果,因為右移後最高位本身就會被 0 補上。

b31 b30 b29 b28 ... b3 b2 b1 b0 // v
  0 b31 b30 b29 ...  0 b3 b2 b1 // (v >> 1) & 0x77777777

重複做四次後,可以得到一組新的 32 bits 表示,令為

B7B6B1B0 其中每一
Bn
皆表示一組 nibble ,接著就是要將八組 nibble 所求出來 1 的個數相加,因此我們想要得到的數值為
B7+B6+B5+B4+B3+B2+B1+B0
,實際做法如下 :

   B7      B6      B5      B4      B3      B2      B1      B0    // v
    0      B7      B6      B5      B4      B3      B2      B1    // v >> 4
(B7+0) (B6+B7) (B5+B6) (B4+B5) (B3+B4) (B2+B3) (B1+B2) (B0+B1)  // v + (v >> 4)

先將該數右移一個 nibble 的長度,接著與原數相加,從得到的數觀察可以發現我們想要的數值為偶數組的 nibble 相加,因此再與 0x0F0F0F0F& 運算,可得 :

// & 0x0F0F0F0F
0 (B7+B6) 0 (B5+B4) 0 (B3+B2) 0 (B1+B0)

最後,為了把這些數值相加,我們讓 v 去和 0x01010101 相乘,以直式展開結果如下 :

                                 0 (B7+B6)  0 (B5+B4)  0 (B3+B2)  0 (B1+B0)
                              x  0       1  0       1  0       1  0       1
------------------------------------------------------------------------------
                                 0 (B7+B6)  0 (B5+B4)  0 (B3+B2)  0 (B1+B0)
                      0 (B7+B6)  0 (B5+B4)  0 (B3+B2)  0 (B1+B0)
           0 (B7+B6)  0 (B5+B4)  0 (B3+B2)  0 (B1+B0)
0 (B7+B6)  0 (B5+B4)  0 (B3+B2)  0 (B1+B0)
------------------------------------------------------------------------------
                                      ↑________(B7+B6)+(B5+B4)+(B3+B2)+(B1+B0)

到此可以看到我們需要的結果在相乘後數值的第七個 nibble 的位置,也就是我們可以將前面 24 bits 捨棄掉,因此在程式碼的最後用了 v >> 24 來進行此一操作,最終就能夠得到

(23222120)b31+(222120)b30+(2120)b29+20b28++(23222120)b3+(222120)b2+(2120)b1+20b0 ,即為所求。

測驗2

本測驗告訴我們如何不用除法就能得出某兩數相除之後的餘數,以下用 mod 3 當作例子進行操作和解釋 :
由數學上同餘 (modulo) 的特性可以知道,當

ab(mod  m)
cd(mod  m)
,則
  a+cb+d(mod  m)
acbd(mod  m)
,因此若我們有
11(mod  3)
21(mod  3)
,則可以推出
2k{1(mod  3)if  k  is  even1(mod  3)if  k  is  odd

假設一 32 bits 無號整數

n=b31b30b2b1b0 ,則 n 的值就會等於
i=031bi2i
,由上述推論可得
n=i=031bi(1)i  (mod  3)
。到這裡我們可以發現對於以二進位表示的數來講, mod 3 其實就是將偶數位元相加再減去奇數位元,因此可以利用 population count 類型的函式來進行操作,其中一種做法為 n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) ,因為 5 為 0101 , A 為 1010 ,所以如此運算便能將奇偶穿插著相減,達到預期的效果。

而上述做法其實還有別的寫法,那就是利用以下定理將其簡化 :

popcount(x&m)popcount(x&m)=popcount(xm)popcount(m)
因此我們就能將其改寫為 n = popcount(n ^ 0xAAAAAAAA) - 16 ,但這裡必須注意到 popcount(n ^ 0xAAAAAAAA) 這個值的範圍為 0 < n < 32 , -16 這個動作可能會導致算出來的值為負數,因此需透過加上 3 的倍數避免這樣的情況發生,而該數只要大於 16 即可,以下利用 39 做說明。

int mod3(unsigned n)
{
    n = popcount(n ^ 0xAAAAAAAA) + 23;
    n = popcount(n ^ 0x2A) - 3;
    return n + ((n >> 31) & 3);
}

程式碼第一行是由上述推倒而來的,也就是 n = popcount(n ^ 0xAAAAAAAA) - 16 + 39 ,但此時 n 的範圍就會變成 23 < n < 55 ,而我們希望最終的結果會落在 0 < n < 2 (因為 mod 3 ),因此可以再次利用上述定理縮小其範圍。
在經過第一行的運算後,其範圍已經變為 23 < n < 55 ,只有最低位的 6 個 bits 可能會有 1 存在,因此根據定理,第二行即為 popcount(n ^ 0x2A) - 3 (因為 0x2A 為 101010 ),此時 n 的範圍為 -3 < n < 2 ,我們只需再處理負數的部分即可。
處理的方法為得到負數就將其 + 3 ,首先,透過右移 31 位來判斷正負,若為 1 則代表是負數,需要 + 3 ,於是和 3 做 & 運算,此時會發現要用全 1 去做 & 才會是 + 3 的結果,也就是右移需要是算術位移而非邏輯位移,但上方的 input 卻是 unsigned n ,因此該程式碼有個小 bug 存在,可以透過將 input 宣告成 int 來解決,或者是修改 return 的部分,修改過後的程式碼如下 :

int mod3(unsigned n)
{
    n = __builtin_popcount(n ^ 0xAAAAAAAA) + 23;
    n = __builtin_popcount(n ^ 0x2A) - 3;
    return n + (((n >> 31) | (n >> 30)) & 3);
}

在最後一行的地方,我將 (n >> 31)(n >> 30) 做 or 運算,由於 3 的二進位表示法為 011 ,因此只須考慮最後 2 個 bits 即可,並且利用 or 來達到算術位移的效果,也就是將負數的高位元填滿 1 。
而事實上若改寫為這樣也不需要再 &3 了,因為若是負數則位移完會得到 011 ,若是正數則為 0 ,因此最終的程式碼為 :

int mod3(unsigned n)
{
    n = __builtin_popcount(n ^ 0xAAAAAAAA) + 23;
    n = __builtin_popcount(n ^ 0x2A) - 3;
    return n + ((n >> 31) | (n >> 30));
}

還有另一種做法為利用 lookup table :

int mod3(unsigned n)
{
    static char table[33] = {2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1 };
    n = popcount(n ^ 0xAAAAAAAA);
    return table[n];
}

首先要建立一個表格,該表格的內容為 mod 3 後可能的值,也就是 {012} ,而大小則取決於 n 的範圍,這裡是 0 ~ 32 ,因此表格大小為 33 。至於表格內值的排序則是由 n 所決定的,在這裡因為 n = popcount(n ^ 0xAAAAAAAA) ,可以想成 n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA) + 16 ,因此對於 table[0] 來講,要找到當程式碼內的 n = 0 時,實際上 mod 3 為多少,而這裡只要找出前 3 項,然後依序填完整個 table 即可 :
n = 0 : -16 mod 3 為 2 ,所以 table[0] = 2 ;
n = 1 : -15 mod 3 為 0 ,所以 table[1] = 0 ;
n = 2 : -14 mod 3 為 1 ,所以 table[2] = 1 ;
如此填完便可以得到正確的值。

將上述的概念應用於 tictactoe.c 程式碼,該程式碼用來模擬 100 萬次井字遊戲的對奕,並統計出第一次選擇 9 個不同位置的勝率,具體的輸出結果如下 :

Win probability for first move with random agents:
0.115 0.101 0.115 
0.101 0.132 0.102 
0.116 0.102 0.116 
Player 1 won 585067 times
Player 2 won 288241 times
126692 ties
0.055166 seconds
18.127107 million games/sec
static const uint32_t move_masks[9] = {
    0x40040040, 0x20004000, 0x10000404, 0x04020000, 0x02002022,
    0x01000200, 0x00410001, 0x00201000, 0x00100110,
};

至於程式碼的部份,首先可以看到位置是用 16 進位 32 bits 無號數來表示的,原因為 8 個 bits 正好可以對應到九宮格的 8 種連線方法,更具體一點的說,由最高位元開始依序往下會對應到的是九宮格中由上到下、由左到右的橫一、橫二、橫三、直一、直二、直三、右斜和左斜,而一組 nibble 中的前 3 個位元則代表了該位置是否已經被放置。以左上角第一個位置做說明,對於橫一、直一和右斜來講,就會是 100 ,其餘皆為 000 ,因此整體就會是 0x40040040

static inline uint32_t is_win(uint32_t player_board)
{
    return (player_board + 0x11111111) & 0x88888888; /*(player_board + BBBB)*/
}

此段程式碼是用來決定是否已經存在一方獲勝的狀態,也就是上述八種可能有其中一種被填滿,舉例來說,如果是橫二被填滿,那當下的狀態應為 0x07022222 ,換句話說,被填滿的那條線必定會出現 7(0b0111) 這個數字,因此可以將當下的狀態加上 0x11111111 並和 0x8888888& 運算,若有值則代表有存在著連線,否則會 return 0
BBBB0x11111111

static inline int mod7(uint32_t x)
{
    x = (x >> 15) + (x & UINT32_C(0x7FFF)); /*(x >> CCCC)*/
    /* Take reminder as (mod 8) by mul/shift. Since the multiplier
     * was calculated using ceil() instead of floor(), it skips the
     * value '7' properly.
     *    M <- ceil(ldexp(8/7, 29))
     */
    return (int) ((x * UINT32_C(0x24924925)) >> 29);
}

此段程式碼想要得到 mod 7 的結果,在第一行利用了 Hacker's Delight 提出"若將 n 拆成兩部分並且將左半右移加上右半不會影響同餘的結果"的結論來進行範圍的縮小,可以看到 x & UINT32_C(0x7FFF) 是在計算右半邊 15 位,因此 x >> CCCC 就是在計算左半邊 17 位, CCCC 為 15 。
而這段程式碼也應用了另一定理 :

n  (mod  7)87n  (mod  8) ,因此在程式碼的最後出現了 0x24924925 這個數字,該數字即為
2327
,而再往右位移 29 位即可得到
87

測驗3

XTree 為一種兼具 AVL tree 和 red-black tree 部分特性的結構,以下將說明節點刪除的部份 :

static void __xt_remove(struct xt_node **root, struct xt_node *del)
{
    if (xt_right(del)) {
        struct xt_node *least = xt_first(xt_right(del));
        if (del == *root)
            *root = least;

        xt_replace_right(del, least); /*xt_replace_right(del, AAAA);*/
        xt_update(root, xt_right(least)); /*xt_update(root, BBBB)*/
        return;
    }

    if (xt_left(del)) {
        struct xt_node *most = xt_last(xt_left(del));
        if (del == *root)
            *root = most;

        xt_replace_left(del, most); /*xt_replace_left(del, CCCC);*/
        xt_update(root, xt_left(most)); /*xt_update(root, DDDD)*/
        return;
    }

    if (del == *root) {
        *root = 0;
        return;
    }

    /* empty node */
    struct xt_node *parent = xt_parent(del);

    if (xt_left(parent) == del)
        xt_left(parent) = 0;
    else
        xt_right(parent) = 0;

    xt_update(root, parent); /*xt_update(EEEE, FFFF)*/
}

在完整的程式碼中所定義的函式 :

  • xt_first(n) : 尋找由節點 n 所代表的子樹中最左側的節點,也就是該子樹中最小的節點。
  • xt_last(n) : 尋找由節點 n 所代表的子樹中最右側的節點,也就是該子樹中最大的節點。
  • xt_replace_right(n, r) : 用右子樹中的節點 r 來替換調節點 n ,並保持樹狀結構。
  • xt_replace_left(n, l) : 用左子樹中的節點 l 來替換調節點 n ,並保持樹狀結構。
  • xt_update(n, m) : 用來更新插入或刪除節點後的樹,利用一參數 hint 檢查樹是否平衡,以此決定是否需要 rotate 。

從程式碼的第一段可以看到,該程式碼會先判斷欲刪除的節點是否存在右子樹,若存在,則必須找到該子樹中最小的節點與 del 互換,並且在換完後重新調整以保持平衡,因此 AAAAleastBBBBxt_right(least) 。而左子樹也是同樣道理,但由於 BST 的特性,若要交換則要找左子樹中最大的節點,因此使用了 xt_last 函式,其餘的部份就是將右子樹處理的方法稍做變換,因此 CCCCmostDDDDxt_left(most)
在程式碼的最後,還需檢查整棵樹是否達到平衡,因此 EEEErootFFFFparent


問題記錄

  • 第三周測驗二不懂為何要有 (in | 1) ,因為就算不強行轉為奇數即果還是正確的。
  • 第四周測驗二的最後知道算出了
    87
    ,但為何這樣就同時做了 mod8 的動作。
  • 在撰寫 mod9 和 mod5 的程式碼時,是由逼近的方法求解的,但數字一放大微小的差異也就跟著放大,因此我用了 if 來做調整,但好像可以不用這樣,應該從數值本身去探討其他寫法。

補充議題 : 浮點數乘法的實作探討

float_mul2 探討

IEEE 754, float is single precision. Assume 32-bit

float float_mul2(float x)
{
    // using bitwise operation, no mul
    int a = *((int*)&x);
    int b = *((int*)&x);
    a = (a & 0x7F800000) >> 23;
    a++;
    b = (b & 0x807FFFFF) | (a << 23);
    return *((float*)&b);
}

考慮 overflow 後的做法 :

float float_mul2(float x)
{
    // using bitwise operation, no mul
    int a = *((int*)&x);
    int b = *((int*)&x);
    a = (a & 0x7F800000) >> 23;
    a++;
    if(a == 0xFFFFFFFF)
        return NAN;
    else 
        b = (b & 0x807FFFFF) | (a << 23);
    return *((float*)&b);
}

https://godbolt.org/

  • Instruction latency
    從老師提供的 instruction tables 找到對應的處理器架構,並且將上方程式碼轉換成組合語言去逐行計算,得到的結果為 :
float_mul2:                             # @float_mul2
        push    rbp
        mov     rbp, rsp
        movss   dword ptr [rbp - 4], xmm0
        mov     eax, dword ptr [rbp - 4]
        mov     dword ptr [rbp - 8], eax
        mov     eax, dword ptr [rbp - 4]
        mov     dword ptr [rbp - 12], eax
        mov     eax, dword ptr [rbp - 8]
        and     eax, 2139095040
        sar     eax, 23
        mov     dword ptr [rbp - 8], eax
        mov     eax, dword ptr [rbp - 8]
        add     eax, 1
        mov     dword ptr [rbp - 8], eax
        mov     eax, dword ptr [rbp - 12]
        and     eax, -2139095041
        mov     ecx, dword ptr [rbp - 8]
        shl     ecx, 23
        or      eax, ecx
        mov     dword ptr [rbp - 12], eax
        movss   xmm0, dword ptr [rbp - 12]      # xmm0 = mem[0],zero,zero,zero
        pop     rbp
        ret
Move instructions Arithmetic instructions Logic Control transfer instructions
push * 1, mov * 12, movss * 2, pop * 1 add * 1 and * 2, sar * 1, shl * 1, or * 1 ret * 1

計算過後可知該程式總共需要 29 (2+12+2+1+1+2+7+1+1)個 clock cycle 。

以下為稍微改進過後的版本,嘗試只用一個變數來節省記憶體空間 :

float float_mul2(float x)
{
    // using bitwise operation, no mul
    int a = *((int*)&x);
    a = ((a & 0x7F800000) + 0x00800000) | (a & 0x807FFFFF);
    return *((float*)&a);
}

而這麼做還是太複雜,因為要先得到 exponent 的那 8 個 bits 加 1 後再把其他部分補回,因此更簡單的想法為直接在 exponent 那項加 1 ,修改過後的程式碼如下 :

float float_mul2(float x)
{
    // using bitwise operation, no mul
   *(int*)&x += 0x00800000;
    return x;
}

轉換為組合語言 :

float_mul2:                             # @float_mul2
        push    rbp
        mov     rbp, rsp
        movss   dword ptr [rbp - 4], xmm0
        mov     eax, dword ptr [rbp - 4]
        add     eax, 8388608
        mov     dword ptr [rbp - 4], eax
        movss   xmm0, dword ptr [rbp - 4]       # xmm0 = mem[0],zero,zero,zero
        pop     rbp
        ret
Move instructions Arithmetic instructions Logic Control transfer instructions
push * 1, mov * 3, movss * 2, pop * 1 add * 1 none ret * 1

計算過後可知該程式總共需要 9 (2+3+2+1+1)個 clock cycle ,可以發現確實比第一個版本少了許多。

參考 類神經網路的 ReLU 及其常數時間複雜度實作 教材,裡面提到了由於對浮點數的操作成本太高,因此可以使用 union 的技巧來優化改善, union 可以讓浮點數和整數共用一段記憶體空間,並且在任一時刻只有一個成員有效,宣告的方法如下 :

union{
    float f;
    int i;
} out = {.f=x};

宣告一個名為 out 的 union ,裡面輸入需要操作的資料型態,而最後的 {.f=x} 是在做資料得初始化,將 input x 的值指派到 f ,後續則可以根據需求使用 out.fout.i 來獲得兩種不同的資料型態。
以下給出一個簡單的例子 :

//code
float example(float x){
    union{
    float f;
    int i;
    } out = {.f=x};
    
    printf("out.i = %d\n", out.i);
    printf("out.f = %f\n", out.f);
}

int main()
{
    float x = 2;
    example(x);
    return 0;
}

//output
out.i = 1073741824
out.f = 2.000000

原因為在這段記憶體裡存著 0x40000000 ,以浮點數表示為 2 ,以整數表示則為

230 = 1073741824 。

最後,將原本的程式碼改寫 :

float float_mul2(float x)
{
    // using bitwise operation, no mul
    union{
        float f;
        int i;
    } out = {.f=x};
    
    out.i += 0x00800000;
    return out.f;
}

轉換為組合語言 :

float_mul2:                             # @float_mul2
        push    rbp
        mov     rbp, rsp
        movss   dword ptr [rbp - 4], xmm0
        movss   xmm0, dword ptr [rbp - 4]       # xmm0 = mem[0],zero,zero,zero
        movss   dword ptr [rbp - 8], xmm0
        mov     eax, dword ptr [rbp - 8]
        add     eax, 8388608
        mov     dword ptr [rbp - 8], eax
        movss   xmm0, dword ptr [rbp - 8]       # xmm0 = mem[0],zero,zero,zero
        pop     rbp
        ret
Move instructions Arithmetic instructions Logic Control transfer instructions
push * 1, mov * 3, movss * 4, pop * 1 add * 1 none ret * 1

計算過後可知該程式總共需要 11 (2+3+4+1+1)個 clock cycle ,雖然相較於第二個版本多了 2 個 clock cycle ,但用 union 的寫法會比較安全,因為強制將資料型別轉換可能會造成 strict aliasing ,在 C 語言中會發生為定義的行為。

float_mul_power_of_2 探討

float float_mul_power_of_2(float x, int e)
{
    // using bitwise operation, no mul
    union{
        float f;
        int i;
    } out = {.f=x};
    
    out.i += (e << 23);
    return out.f;
}

利用和上述 float_mul2 相同的概念,若要乘上 2 的冪,則把指數部分的 +1 改為 +e 即可,但若是直接用 (e << 23) 這樣的寫法會有問題 :
若該整數 e 夠大或夠小,則位移後會影響到最左方的 sign bit,因此我們需想辦法將其大小控制在 8 個 bits 以內,以符合 IEEE 754 的範圍。
解決方法 :
首先,由於浮點數的指數部分的範圍為 0 ~ 255 ,因此我們可以把指數部分單獨取出來看是否有在範圍內,對應的程式碼如下 :

int exp = (out.i >> 23) & 0xFF;
exp += e;
if(exp < 0 || exp > 255)
return 0;

該程式碼的作用為先將 fraction bits 23 位過濾掉,然後再和 8 個 1 做 & 運算,以取得原數的指數部分,接著用整數型態的變數和 input e 相加,最後再判斷是否有在範圍內,完整的程式碼如下 :

float float_mul_power_of_2(float x, int e)
{
    // using bitwise operation, no mul
    union{
        float f;
        int i;
    } out = {.f=x};
    
    int exp = (out.i >> 23) & 0xFF;
    exp += e;
    if(exp < 0 || exp > 255)
    return 0;
    else
    out.i = (out.i & 0x807FFFFF) | (exp << 23);
    return out.f;
}

其中 return 0 的部分若是引進 math.h 函式庫則可以再細分為不同的情況來表達,以下為修改過後的程式碼 :

float float_mul_power_of_2(float x, int e)
{
    // using bitwise operation, no mul
    union{
        float f;
        int i;
    } out = {.f=x};
    
    int exp = (out.i >> 23) & 0xFF;
    exp += e;
    if(exp < 0)
    return 0;
    else if(exp >= 255)
    return INFINITY;
    else
    out.i = (out.i & 0x807FFFFF) | (exp << 23);
    return out.f;
}

TODO: 強化延伸問題指定的 Linux 核心應用場景

不該只有列出 Linux 核心原始程式碼,並適度予以解說和推論其設計考量