Try   HackMD

2024q1 Homework4 (quiz3+4)

contributed by yu-hsiennn

quiz 3

題目

測驗 1

假定

Pm 為欲所求之平方根
{N2=(an+an1+an2+...+a0)2,am=2m or am=0Pm=an+an1+...+am

經展開後,化簡可得
N2=i=0nai2+2i=0n1Pi+1ai

Xm
N2Pm2
,且此輪與上一輪差設為
Ym
,即
{Xm=N2Pm2Ym=Xm+1Xm=Pm2Pm+12=2Pm+1am+am2

我們再將式子利用
cm
dm
化簡

  • cm=2Pm+1am
  • dm=am2

    Ym={cm+dmif am=2m0if am=0

再觀察下一輪的變化

{cm1=2Pmam1=2(Pm+1+am)am1=Pm+1am+am2=cm2+dmdm1=dm4
因此

int i_sqrt(int x)
{
    if (x <= 1) /* Assume x is always positive */
        return x;
    int z = 0;
    for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2) {
        // m: d_m, z: c_m
        int b = z + m;
        z >>= 1;
        if (x >= b)
            x -= b, z += m;               
    }
    return z;
}

改用第二周測驗 2 的 __ffs

-   for (int m = 1UL << ((31 - __builtin_clz(x)) & ~1UL); m; m >>= 2)
+   for (int m = 1UL << ((31 - __ffs(x)) & ~1UL); m; m >>= 2)

疑惑

cm=Pm+12m+1dm=(2m)2Ym={cm+dmif am=2m0if am=0

在一開始看到這個時,完全無法理解是怎麼跑出這些等號的。
後來才發現,其實有段開,只是 HackMD 沒有顯示出來

測驗 2

因為乘法及除法的運算成本很高,故想利用位移及加減的方式去達到相同的效果
此處數字範圍為 0~19,只要這一段的計算有達到預期即可
題目是用 N = 128 來逼近 1/10

1/128+1/32+1/16=(1+4+8)/128=13/1281/10
也就是說可以

/* 13n / 128 */
q = (n + (n << 2) + (n << 3)) >> 7;

但這邊會忽略掉一些值,如果要精確的話可以把他加回去

d0 = q & 0b1;
d1 = q & 0b11;
d2 = q & 0b111;
q = (n + (n << 2) + (n << 3) + d0 + d1 + d2) >> 7;

不過在範圍 0~19 其實並不會影響結果
如果範圍是 0~19 ,這邊做了一些測試

// test1: n / 16 + n / 32 + 1 / 8
q = ((n << 1) + n + 4) >> 5;
// test2: n / 16 + n / 32 + 1 / 16
q = ((n << 1) + n + 2) >> 5;

結果為

q: 0, r: 0
q: 0, r: 1
q: 0, r: 2
q: 0, r: 3
q: 0, r: 4
q: 0, r: 5
q: 0, r: 6
q: 0, r: 7
q: 0, r: 8
q: 0, r: 9
q: 1, r: 0
q: 1, r: 1
q: 1, r: 2
q: 1, r: 3
q: 1, r: 4
q: 1, r: 5
q: 1, r: 6
q: 1, r: 7
q: 1, r: 8
q: 1, r: 9

而題目

void divmod_10(uint32_t in, uint32_t *div, uint32_t *mod)
{
    uint32_t x = (in | 1) - (in >> 2); /* div = in/10 ==> div = 0.75*in/8 */
    uint32_t q = (x >> 4) + x;
    x = q;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;
    q = (q >> 8) + x;

    *div = (q >> 3);
    *mod = in - ((q & ~0x7) + (*div << 1));   
}

可以看到這邊的作法是

x=34n+364n=0.796875n810n
所以將 q 右移 3 可得
810n18n=110n

而最後是在計算 *mod = in - *div * 10

測驗 3

利用 4 層迴圈來找尋以 2 為底的對數

  • 第一個迴圈作用於處理大於等於 65536 的數值,直到小於才離開此迴圈
  • 第二個迴圈作用於處理大於等於 256 的數值,直到小於才離開此迴圈
  • 第三個迴圈作用於處理大於等於 16 的數值,直到小於才離開此迴圈
  • 第四個迴圈作用於處理大於等於 2 的數值,直到小於才離開此迴圈,並回傳結果

舉個例子:
n = 2049
2049 < 65536, (不執行第一個迴圈)
2049 >= 256, n: 8, result: 8
8 < 16, (不執行第三個迴圈)
8 >= 2, n: 4, result: 9
4 >= 2, n: 2, result: 10
2 >= 2, n: 1, result: 11
1 < 2, 回傳 result: 11

我們也可以換個想法,利用找尋目前 n 第一個 1 的前方有幾個 0 (__builtin_clz(n)),再用 31 去減掉前方 0 的個數即可

參考 log2.h,了解到 linux 核心內部 ilog2 的實作方式,其中有以下幾點:

  • 會先判斷 n 是否使用了常數時間來編譯
  • 不為常數時間則判斷其大小是否小於 4 byte,來執行後續相對應的 fls
  • fls 之所以需要減 1 是因為它是以 1 為底

測驗 4

struct ewma *ewma_add(struct ewma *avg, unsigned long val)
{
    avg->internal = avg->internal
                        ? (((avg->internal << avg->weight) - avg->internal) +
                           (val << avg->factor)) >> avg->weight
                        : (val << avg->factor);
    return avg;
}

根據公式

St=α(Yt+(1α)Yt1+(1α)2Yt2++(1α)kYtk++Y0)

St+1=((St2avg>weightSt)+Yt2avg>factor)2(avg>weight)
化簡後可得
{St+1=(12(avg>weight))St+Yt2avg>factor(avg>weight)α=2(avg>weight)

avg->factor 的用途在於,不是每個系統都支援浮點運算,因此藉由固定其大小來確保精確度

測驗 5

int ceil_ilog2(uint32_t x)
{
    uint32_t r, shift;
    
    x--;
    r = (x > 0xFFFF) << 4;                                                                                                                                    
    x >>= r;
    shift = (x > 0xFF) << 3;
    x >>= shift;
    r |= shift;
    shift = (x > 0xF) << 2;
    x >>= shift;
    r |= shift;
    shift = (x > 0x3) << 1;
    x >>= shift;
    return (r | shift | x > 1) + 1;       
}

開頭的 x-- 用意在於,若是 x 為 2 的冪,最後的 +1 會造成結果錯誤,所以先將 x-- 來避免這個情況發生。
而利用後面一連串的判斷,來快速找出 x 的對數。
最後的回傳會判斷 x 是否還有值,且合併 rshift 的值,並加上最後的進位 1

處理 x = 0 的狀況

可以發現到,當我們直接帶入 01 時,程式會回傳結果 1,而這並不是我們所預期的,於是可以如下修改

-   x--;
+   x -= likely(x); /* x -= !!x */

-   return (r | shift | x > 1) + 1;
+   return (r | shift | x > 1) + !(x == 0);

想法是:

  • 先判斷是否為 0 ,為 0 則不做 -1
  • 最後回傳時再判斷 x 是否為 0,是則不做 +1

如此,即可以避免當 x01,回傳的錯誤結果 1,並維持 branchless


quiz 4

題目

測驗 1

n = (v >> 1) & 0x77777777;
v -= n;
n = (n >> 1) & 0x77777777;
v -= n;
n = (n >> 1) & 0x77777777;
v -= n;

計算方式採取每 4 個位元為單位,

v=(b31b30b29b28)...b0
右移 1 且只把最高位元(以 4 位元為單位)給變成0,以防其影響計算結果,
n=(0b31b30b29)(0b27...b1)

vn=vv2

跑完這段即可得,
vv2v4v8

接下來會執行 v = (v + (v >> 4)) & 0x0F0F0F0F
v=B7B6B5B4B3B2B1B0
, (
Bn=b4n+3b4n+2b4n+1b4n+0
)
所以
v=0(B7+B6)0(B5+B4)0(B3+B2)0(B1+B0)

最後將 v *= 0x01010101, v >> 24,即可得
popcount=B7+B6+B5+B4+B3+B2+B2+B0

撰寫出更高效的程式碼 477. Total Hamming Distance

int totalHammingDistance(vector<int>& nums) {
    int res = 0, n = nums.size();
    /* TLE: O(n^2)
     *  for(int i=0;i<nums.size();i++){
     *       for(int j=i+1;j<nums.size();j++){
     *           int xOR = nums[i]^nums[j];
     *           ans += __builtin_popcount(xOR);
     *       }
     *   } 
     */
    for (int i = 0; i < 32; i++) {
        int count = 0;
        for (int j = 0; j < n; j++)
            count += (nums[j] >> i) & 1;
        res += count * (n - count);
    } 
    return res;
}

時間複雜度為

O(32n)=O(n)

測驗 2

2k{1(mod  3),  k  even1(mod  3),  k  odd
N=bn1bn2b1b0

可以改寫成
Ni=0n1bi(1)i  (mod  3)

我們可以利用 n = popcount(n & 0x55555555) - popcount(n & 0xAAAAAAAA),來達到相同效果
其中還可以利用以下定理來進一步的化簡
popcount(x&m)popcount(x&m)=popcount(xm)popcount(m)

popcount(m) = 16, (m = 0xAAAAAAAA),而為了避免 n = popcount(n ^ 0xAAAAAAAA) - 16 結果為負數,我們可以加入一個足夠大的 3 的倍數的數字給 -16,範例為 39 - 16 -> 23
再來執行 n = popcount(n ^ 0x2A) - 3 來限縮範圍至 -3 ~ 3,最後的 return n + ((n >> 31) & 3) 再將負的轉為正的
其中,參考 mesohandsome 的筆記,發現到要將 n 做轉型才會得到正確結果,即

-   return n + ((n >> 31) & 3);
+   return n + ((int)n >> 31) & 3;
static inline uint32_t is_win(uint32_t player_board)
{
    return (player_board + 0x11111111) & 0x88888888;
}

當連線時,會是 0x7,為了 0x7 + ? == 0x8 要為 true,所以 ? 應為 0x1

static inline int mod7(uint32_t x)
{
    x = (x >> 15) + (x & UINT32_C(0x7FFF));
    /* Take reminder as (mod 8) by mul/shift. Since the multiplier
     * was calculated using ceil() instead of floor(), it skips the
     * value '7' properly.
     *    M <- ceil(ldexp(8/7, 29))
     */
    return (int) ((x * UINT32_C(0x24924925)) >> 29);
}

先將 x 右移 15 來保留其商數,再利用 x & UINT32_C(0x7FFF) 來保留最低的 15 個位元

測驗 3

函式功用

  • xt_create: 初始化一棵新的 XTree
  • xt_destroy: 利用遞迴的方式來將樹的每個節點的資源給釋放掉
  • xt_frist: 找到樹中最小值的節點(最左邊)
  • xt_last: 找到樹中最大值的節點(最右邊)
  • xt_rotate_left: 執行下圖操作






Tree

origin


p

p



n

n



p->n





l

l



n->l





r

r



n->r





A

A



l->A





B

B



l->B











Tree

After Rotate Left


p

p



l

l



p->l





A

A



l->A





n

n



l->n





B

B



n->B





r

r



n->r





  • xt_rotate_right: 執行下圖操作






Tree

origin


p

p



n

n



p->n





l

l



n->l





r

r



n->r





A

A



r->A





B

B



r->B











Tree

After Rotate Right


p

p



r

r



p->r





n

n



r->n





B

B



r->B





l

l



n->l





A

A



n->A





  • xt_balance: 回傳左子樹樹高減去右子樹樹高
  • xt_update: 確認樹是否平衡,不平衡則執行對應的旋轉操作
  • xt_insert: 將新節點插入進樹中
  • xt_remove: 將對應的節點從樹中移除
static void __xt_remove(struct xt_node **root, struct xt_node *del)
{
    if (xt_right(del)) {
        struct xt_node *least = xt_first(xt_right(del));
        if (del == *root)
            *root = least;

        xt_replace_right(del, AAAA);
        xt_update(root, BBBB);
        return;
    }

    if (xt_left(del)) {
        struct xt_node *most = xt_last(xt_left(del));
        if (del == *root)
            *root = most;

        xt_replace_left(del, CCCC);
        xt_update(root, DDDD);
        return;
    }

    if (del == *root) {
        *root = 0;
        return;
    }

    /* empty node */
    struct xt_node *parent = xt_parent(del);

    if (xt_left(parent) == del)
        xt_left(parent) = 0;
    else
        xt_right(parent) = 0;

    xt_update(EEEE, FFFF);
}

想法是若欲刪除之節點在右子樹,則用右子樹的最小來填補,即 least ,而填補完需要確認樹是否平衡,所以執行 xt_update(root, xt_right(least))
反之,則用左子樹的最大,即 most,且更新 xt_update(root, xt_left(most))
最後在判斷整棵樹有沒有平衡,即 xt_update(root, parent)