Try  HackMD Logo HackMD

Linux 核心專題: 位元運算

執行人: alexc-313
解說錄影

Reviewed by MikazukiHikari

牛頓迭代法

// 檢查是否收斂
        if (new_guess == guess || (new_guess > guess && new_guess - guess <= 1) || (guess > new_guess && guess - new_guess <= 1)) {
            // 返回較小的整數作為平方根的整數部分
            return (new_guess * new_guess <= x) ? new_guess : new_guess - 1;
}

原判斷條件三重且略微冗餘,是否可簡化為:

if (abs(new_guess - guess) <= 1)

alexc-313
經過測試後發現此方法是可行的,感謝回饋。

Reviewed by rota1001

在浮點數開平方根的部份,是否有注意到指數部份是奇數的情況?在程式碼部份我有看到這個問題的處理,不過文字敘述方面沒有講到這個,如果補上會比較完整。

alexc-313
已補上文字說明,感謝回饋。

Reviewed by thwuedwin

我們可以用最簡單的二分查找法來實作開平方根,此方法簡單但運用多次除法,成本較高。

你的 x 是整數,所以你判斷 mid <= x / mid 這個是不是可以換成 mid * mid <= x。頂多確認一下是否有溢位,但有溢位就代表 mid 還太大。除以 2 也可以換成位元運算。解決了除法的問題,二分查找法還有什麼問題?

alexc-313
就算換成 mid * mid <= x 且假設檢查溢位的成本可忽略,以及用位元運算取代除法,相比於位元運算法,二分查找法需要利用乘法實際計算平方根,且在最差情況下,位元運算法需要~15次迴圈,而二分查找法需要~31次迴圈。

另外你確定你的浮點數平方根是正確的嗎?你只計算 mantissa 的部分,再乘上 exponent 的部分,你算出來的平方根值都會是偶數吧?會有一堆精度被吃掉。

alexc-313
已補上測試結果,若使用 64 位元則不會損失任何精度,也不會有都是偶數的問題。

Reviewed by Ian-Yen

int mid = left + (right - left) / 2;
這邊程式碼中使用要 / 2 的部分,能夠使用位元運算應盡量使用位元運算,也就是 ((right - left) >> 1)

alexc-313
感謝回饋,在二分查找法與牛頓迭代法的程式碼中確實有許多能改進的地方。

Reviewed by Max042004

好難,還在看
情境 0 的 while (bit > x) bit >>= 2 也可以用 clz,最差情況下應該會比一次移兩位好。

alexc-313
感謝您的回饋,已用 clz 對程式做出改進,並確認使用 __builtin_clzll 能夠達到與 sqrtf 相近的效能與精度。

Reviewed by LeoriumDev

int_sqrt 函式當中 while 迴圈判斷式可簡化為:

- while (bit != 0) {
+ while (bit) {

可利用位元運算將程式碼從 3 行簡化為 2 行:

- uint64_t bit = (uint64_t)1 << 62;
- int lz = __builtin_clzll(x);
- bit >>= lz - (lz & 1);

+ int shift = (63 - __builtin_clzll(x)) & ~1;
+ uint64_t bit = 1ULL << shift;

alexc-313
感謝回饋,確實此做法會讓程式更精簡。

Reviewed by h0w726

牛頓迭代法

 // 初始猜測值,可以選擇 x/2 或其他合理的值
    int guess = x / 2;

選擇

2x 的原因是什麼,其他合理的值要怎麼猜測?

alexc-313
感謝回饋,選擇

x2 的原因是因為其最直觀,也可以選擇其他初始值,例如當
 x[1,100]
時使用 guess = 0.4 * x + 0.6

TODO: 以定點數實作開平方根

情境 0: 輸入是 int 型態 (採用 LP64 data model),輸出也是 int 型態
情境 1: 輸入是 IEEE float (單精度),輸出是 int 型態
情境 2: 輸入是 IEEE float,輸出也是 float 型態
針對上述情境,皆不可使用 FPU,只能藉由定點數予以計算,務必降低誤差並比較 glibc 的效能表現,提出改進方案。要有數學分析,並提供多種實作手法

情境 0

二分查找法

int bin_sqrt(int x)
{
    if (x < 0) return 0;
    if (x <= 1) return x;

    int left = 0, right = x;
    int result = 0;

    while (left <= right) {
        int mid = left + (right - left) / 2;
        if (mid <= x / mid) {
            left = mid + 1;
            result = mid;
        } else {
            right = mid - 1;
        }
    }
    
    return result;
}

我們可以用最簡單的二分查找法來實作開平方根,此方法簡單但運用多次除法,成本較高。

牛頓迭代法

int new_sqrt(int x)
{
    if (x < 0) return 0;
    if (x <= 1) return x;
    
    // 初始猜測值,可以選擇 x/2 或其他合理的值
    int guess = x / 2;
    
    // 避免無限循環,設置最大迭代次數
    int max_iterations = 20;
    int iteration = 0;
    
    while (iteration < max_iterations) {
        int new_guess = (guess + x / guess) / 2;
        
        // 檢查是否收斂
        if (new_guess == guess || (new_guess > guess && new_guess - guess <= 1) || 
            (guess > new_guess && guess - new_guess <= 1)) {
            // 返回較小的整數作為平方根的整數部分
            return (new_guess * new_guess <= x) ? new_guess : new_guess - 1;
        }
        
        guess = new_guess;
        iteration++;
    }
    
    return (guess * guess <= x) ? guess : guess - 1;
}

位元運算法

int bit_sqrt(int x)
{
    uint32_t res = 0;
    uint32_t bit = 1UL << 30;

    while (bit > x) bit >>= 2;

    while (bit != 0) {
        if (x >= res + bit) {
            x -= res + bit;
            res = (res >> 1) + bit;
        }
        else res >>= 1;
        bit >>= 2;
    }
    
    return res;
}

此方法基於逐位試探法,從最高有效位開始,逐步向下檢查每個可能的位是否應該包含在平方根中。它類似於長除法的手動計算方式。

  • 初始化: uint32_t bit = 1UL << 30;
    232
    的平方根為 65536 ,需要 16 位元來表示,所以再任意 32 位元的整數中,最大的可能平方根需要用 15 位元來表示,又因為開平方後的位元大約為原本的兩倍,所以從 1 << 30 開始。接下來用 while (bit > x) bit >>= 2; 快速移動 bit 至初始位置。
  • 主迴圈:
    1. if (x >= res + bit) → 測試是否可以加入當前 bit。
    2. x -= res + bit → 減去對應的平方貢獻。
    3. res = (res >> 1) + bit → 更新平方根。
    4. bit >>= 2 → 處理下一個更低的位。

測試程式:

bool err = false;
for (size_t i = 0; i < INT_MAX; i++)
{
    int a = my_sqrt(i), b = sqrt(i);
    if (a != b) {
        printf("your answer: %d, correct answer: %d\n", a, b);
        err = true;
    } 
}
if (!err) {
    printf("all correct\n");
}

以上方法皆通過此測試。

效能測試

Figure_2

情境 1

首先我們先分析單精度 IEEE 754:

upload_2c378b894fc4efabab3388eae732bc4f

單精度 IEEE 754 可分為三個部分,分別為 Sign (S)、 Exponent (E)、 Mantissa (M),以下分別討論各部分在求平方根時需做的處理。

由於我們不討論虛數,所以可忽略 S,而 E 及 M 可以表示為:

M2E127=M2E1272由此可知,我們只須求出 M 的平方根,再將其乘上
2E1272
就可得出浮點數的平方根。
在不可使用 FPU,只能藉由定點數予以計算的條件下,我們先將 M 轉換為定點數,利用上述的位元運算法求出其平方根,再將其轉換為 int 即可得到浮點數的平方根。若遇到
2E1272
為奇數的狀況,則將 M 用位移方式乘二,就可避免 e >>= 1 奇數時損失精度。

int my_sqrt(float x)
{
    if (x < 1) return 0;
    if (x == 1.0f) return 1;
    uint32_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint32_t m = (bits & 0x7FFFFF) | 0x800000; //Q16.16
    m >>= 7;
    if (e & 1) m <<= 1;
    uint32_t root = int_sqrt(m); //Q8.8
    e >>= 1;
    root <<= 8; //convert from Q8.8 to Q16.16
    uint64_t root64 = (uint64_t)root;
    root64 <<= e;
    return root64 >> 16;
}

測試結果如下:

          Input         my_sqrt     sqrtf (int)       Abs Error
---------------------------------------------------------------
   0.000000e+00               0               0               0
  -1.000000e+00               0     -2147483648     -2147483648
   1.000000e-10               0               0               0
   1.000000e-04               0               0               0
   2.500000e-01               0               0               0
   5.000000e-01               0               0               0
   7.500000e-01               0               0               0
   1.000000e+00               1               1               0
   1.500000e+00               1               1               0
   2.000000e+00               1               1               0
   3.141590e+00               1               1               0
   4.000000e+00               2               2               0
   9.000000e+00               3               3               0
   1.600000e+01               4               4               0
   1.000000e+02              10              10               0
   1.234568e+04             111             111               0
   1.000000e+10           99840          100000            -160

首先當 x 為

1.000000×100 時,sqrtf (int) 回傳的數值很詭異,這是因為原本回傳值為 nan ,再轉換為 int 所造成的結果。
再來看到當 x 為
1.000000×1010
時,此方法產生了顯著的誤差,推斷是我使用 Q16.16 表示原有23位 mantissa 的浮點數,在 m >>= 7 時損失了後七位元提供的精度。

為了確認我的推斷,我改進了 int_sqrt 以及 my_sqrt,用 uint64_t 讓我能夠用 Q32.32 表示浮點數。

uint64_t int_sqrt64(uint64_t x)
{
    uint64_t res = 0;
    uint64_t bit = (uint64_t)1 << 62;

    while (bit > x) bit >>= 2;

    while (bit != 0) {
        if (x >= res + bit) {
            x -= res + bit;
            res = (res >> 1) + bit;
        }
        else res >>= 1;
        bit >>= 2;
    }
    
    return res;
}

int my_sqrt64(float x)
{
    if (x < 1) return 0;
    if (x == 1.0f) return 1;
    uint64_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint64_t m = (bits & 0x7FFFFF) | 0x800000; //Q32.32
    m <<= 9;
    if (e & 1) m <<= 1;
    uint64_t root = int_sqrt64(m); //Q16.16
    e >>= 1;
    root <<= 16; //convert from Q16.16 to Q32.32
    root <<= e;
    return root >> 32; //convert from Q32.32 to int
}

測試結果如下:

          Input       my_sqrt64     sqrtf (int)       Abs Error
---------------------------------------------------------------
   0.000000e+00               0               0               0
  -1.000000e+00               0     -2147483648     -2147483648
   1.000000e-10               0               0               0
   1.000000e-04               0               0               0
   2.500000e-01               0               0               0
   5.000000e-01               0               0               0
   7.500000e-01               0               0               0
   1.000000e+00               1               1               0
   1.500000e+00               1               1               0
   2.000000e+00               1               1               0
   3.141590e+00               1               1               0
   4.000000e+00               2               2               0
   9.000000e+00               3               3               0
   1.600000e+01               4               4               0
   1.000000e+02              10              10               0
   1.234568e+04             111             111               0
   1.000000e+10          100000          100000               0

可見此版本能夠通過 x 為

1.000000×1010

效能測試

Figure_1

由於效能並不理想,我再次審視程式碼後發現,由於 m = (bits & 0x7FFFFF) | 0x800000 ,只會用到 24 位元,這代表不需要使用 uint64_t ,可以使用 uint32_t 並採用 Q8.24 表示法,這樣不僅節省記憶體的使用,也可以增加效率

int my_sqrt_new(float x)
{
    if (x < 1) return 0;
    if (x == 1.0f) return 1;
    uint32_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint32_t m = (bits & 0x7FFFFF) | 0x800000; //Q16.16
    m <<= 1; //Q8.24
    if (e & 1) m <<= 1;
    uint32_t root = int_sqrt(m); //Q4.12
    e >>= 1;
    root <<= e;
    return root >> 12;
}

效能測試

Figure_1

可見效能相較於先前有顯著的進步,與 glibc 相近。

情境 2

float my_sqrt_f(float x) {
    if (x < 0) return NAN;
    uint64_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint64_t m = (bits & 0x7FFFFF) | 0x800000; //Q32.32
    m <<= 9;
    if (e & 1) m <<= 1;
    uint64_t root = int_sqrt64(m); //Q16.16
    e >>= 1;
    // Reconstruct float manually
    uint32_t new_bits = ((e + 127) << 23) | ((root << 7) & 0x7FFFFF);
    float result;
    memcpy(&result, &new_bits, sizeof(result));
    return result;
}

這裡和情境 1 所用的方法類似,差別在於最後不是將定點數轉為 int 表示,而是重組成 float。 (root << 7) 會將原本 Q16.16 轉換為 23 位元的 mantissa ,再搭配 (e + 127) 的位移,就完成 float 的重組了。

測試結果如下:

          Input       my_sqrt_f           sqrtf       Abs Error
---------------------------------------------------------------
   0.000000e+00        0.000000        0.000000        0.000000
  -1.000000e+00             nan             nan             nan
   1.000000e-10        0.000010        0.000010       -0.000000
   1.000000e-04        0.010000        0.010000       -0.000000
   2.500000e-01        0.500000        0.500000        0.000000
   5.000000e-01        0.707100        0.707107       -0.000007
   7.500000e-01        0.866020        0.866025       -0.000005
   1.000000e+00        1.000000        1.000000        0.000000
   1.500000e+00        1.224731        1.224745       -0.000013
   2.000000e+00        1.414200        1.414214       -0.000014
   3.141590e+00        1.772446        1.772453       -0.000008
   4.000000e+00        2.000000        2.000000        0.000000
   9.000000e+00        3.000000        3.000000        0.000000
   1.600000e+01        4.000000        4.000000        0.000000
   1.000000e+02       10.000000       10.000000        0.000000
   1.234568e+04      111.110352      111.111107       -0.000755
   1.000000e+10   100000.000000   100000.000000        0.000000

可見在非整數時會有些許誤差,我推斷是 uint64_t root = int_sqrt64(m); //Q16.16 ,由於開根號後會從原本的 Q32.32 降成 Q16.16 ,此時只有 16 bit 來表示原本 23 位元的 mantissa ,犧牲的精度會造成些微誤差。
於是我決定採用 Q18.46 來表示 float ,這樣開根號後就變成 Q9.23 ,能夠有效的保留小數點後的精度。

float my_sqrt_f_new(float x) {
    if (x < 0) return NAN;
    uint64_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint64_t m = (bits & 0x7FFFFF) | 0x800000; //Q18.46
    m <<= 23;
    if (e & 1) m <<= 1;
    uint64_t root = int_sqrt64(m); //Q9.23
    e >>= 1;
    // Reconstruct float manually
    uint32_t new_bits = ((e + 127) << 23) | ((root ) & 0x7FFFFF);
    float result;
    memcpy(&result, &new_bits, sizeof(result));
    return result;
}

測試結果如下:

          Input   my_sqrt_f_new           sqrtf       Abs Error
---------------------------------------------------------------
   0.000000e+00        0.000000        0.000000        0.000000
  -1.000000e+00             nan             nan             nan
   1.000000e-10        0.000010        0.000010        0.000000
   1.000000e-04        0.010000        0.010000        0.000000
   2.500000e-01        0.500000        0.500000        0.000000
   5.000000e-01        0.707107        0.707107        0.000000
   7.500000e-01        0.866025        0.866025        0.000000
   1.000000e+00        1.000000        1.000000        0.000000
   1.500000e+00        1.224745        1.224745       -0.000000
   2.000000e+00        1.414214        1.414214        0.000000
   3.141590e+00        1.772453        1.772453       -0.000000
   4.000000e+00        2.000000        2.000000        0.000000
   9.000000e+00        3.000000        3.000000        0.000000
   1.600000e+01        4.000000        4.000000        0.000000
   1.000000e+02       10.000000       10.000000        0.000000
   1.234568e+04      111.111107      111.111107        0.000000
   1.000000e+10   100000.000000   100000.000000        0.000000

可見所有測試皆與 sqrtf 的結果相同。

效能測試

Figure_1

由於需使用 64 位元的 int_sqrt64 ,所以效能與 sqrtf 有顯著差距。

若要增加效能,我們可以只用 uint32_t ,並用 Q2.30 來表示 float ,但就會損失 8 位元的精確度。

float my_sqrt_f32(float x) {
    if (x < 0) return NAN;
    uint32_t bits;
    memcpy(&bits, &x, sizeof(x));
    int e = ((bits >> 23) & 0xFF) - 127;
    uint32_t m = (bits & 0x7FFFFF) | 0x800000; //Q2.30
    m <<= 7; //Q2.30
    if (e & 1) m <<= 1;
    uint32_t root = int_sqrt(m); //1.15
    e >>= 1;
    // Reconstruct float manually
    uint32_t new_bits = ((e + 127) << 23) | ((root << 8) & 0x7FFFFF);
    float result;
    memcpy(&result, &new_bits, sizeof(result));
    return result;
}

測試結果如下:

          Input     my_sqrt_f32           sqrtf       Abs Error
---------------------------------------------------------------
   0.000000e+00        0.000000        0.000000        0.000000
  -1.000000e+00             nan             nan             nan
   1.000000e-10        0.000010        0.000010       -0.000000
   1.000000e-04        0.010000        0.010000       -0.000000
   2.500000e-01        0.500000        0.500000        0.000000
   5.000000e-01        0.707092        0.707107       -0.000014
   7.500000e-01        0.866013        0.866025       -0.000013
   1.000000e+00        1.000000        1.000000        0.000000
   1.500000e+00        1.224731        1.224745       -0.000013
   2.000000e+00        1.414185        1.414214       -0.000029
   3.141590e+00        1.772430        1.772453       -0.000023
   4.000000e+00        2.000000        2.000000        0.000000
   9.000000e+00        3.000000        3.000000        0.000000
   1.600000e+01        4.000000        4.000000        0.000000
   1.000000e+02       10.000000       10.000000        0.000000
   1.234568e+04      111.109375      111.111107       -0.001732
   1.000000e+10   100000.000000   100000.000000        0.000000

但在效能上就能與 glibc 相提並論

Figure_3

感謝 Max042004 提出可以使用 clz 改進 while (bit > x) bit >>= 2

uint64_t int_sqrt(uint64_t x)
{
    uint64_t res = 0;
    uint64_t bit = (uint64_t)1 << 62;
    
    /*while (bit > x) bit >>= 2;*/    //replaced by __builtin_clzll
    
    int lz = __builtin_clzll(x);
    bit >>= lz - (lz & 1);
    
    while (bit != 0) {
        if (x >= res + bit) {
            x -= res + bit;
            res = (res >> 1) + bit;
        }
        else res >>= 1;
        bit >>= 2;
    }
    
    return res;
}

這邊我使用 <x86intrin.h> 中的 __builtin_clzll 快速得出 clz ,取代原本的 while 迴圈。

測試結果如下:

          Input   my_sqrt_f_new           sqrtf       Abs Error
---------------------------------------------------------------
   0.000000e+00        0.000000        0.000000        0.000000
  -1.000000e+00             nan             nan             nan
   1.000000e-10        0.000010        0.000010        0.000000
   1.000000e-04        0.010000        0.010000        0.000000
   2.500000e-01        0.500000        0.500000        0.000000
   5.000000e-01        0.707107        0.707107        0.000000
   7.500000e-01        0.866025        0.866025        0.000000
   1.000000e+00        1.000000        1.000000        0.000000
   1.500000e+00        1.224745        1.224745       -0.000000
   2.000000e+00        1.414214        1.414214        0.000000
   3.141590e+00        1.772453        1.772453       -0.000000
   4.000000e+00        2.000000        2.000000        0.000000
   9.000000e+00        3.000000        3.000000        0.000000
   1.600000e+01        4.000000        4.000000        0.000000
   1.000000e+02       10.000000       10.000000        0.000000
   1.234568e+04      111.111107      111.111107        0.000000
   1.000000e+10   100000.000000   100000.000000        0.000000

效能測試

Figure_1

可以看到此版本在精度與效能都達到與 glibc 中的 sqrtf 相近的表現,再次感謝 Max042004 提出改進的想法!

以定點數計算 EWMA,闡述其原理、適用場景,並探討誤差,要有數學分析。

TODO: 探討上述在 Linux 核心的應用

搭配課程教材,理解開平方根和 EMWA 如何應用於 Linux 核心 (如 CPU 排程器),予以探討其實作手法