contributed by <Shiritai
>
將以下程式:
#include <stdint.h>
static inline uint64_t pow2(uint8_t e) { return ((uint64_t)1) << e; }
uint64_t next_pow2(uint64_t x)
{
uint8_t lo = 0, hi = 63;
while (lo < hi) {
uint8_t test = (lo + hi)/2;
if (x < pow2(test)) { hi = test; }
else if (pow2(test) < x) { lo = test+1; }
else { return pow2(test); }
}
return pow2(lo);
}
改為位元運算風格如下,應填入何?
uint64_t next_pow2(uint64_t x)
{
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> AAAA;
x |= x >> 16;
x |= x >> BBBB;
return CCCC;
}
「填補」的説明提示為
然而上述程式碼存在分支,於是我們可考慮建構以下「填補」位元表示中的 1:
x = 0010000000000000 x = 0011000000000000 x = 0011110000000000 x = 0011111111000000 x = 0011111111111111
可見往左每次補左移 1, 2, 4, 8, … 皆為二的冪,以此我們重新整理題目的程式碼,並加上適當的空行註解:
// shift 1 then or
x |= x >> 1;
// shift 2 then or
x |= x >> 1;
x |= x >> 1;
// shift 4 then or
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
x |= x >> 1;
// shift -> 8 then or
x |= x >> AAAA;
// shift 16 then or
x |= x >> 16;
// shift -> 32 then or
x |= x >> BBBB;
// 00...011...11 + 1 = result twos power
return CCCC;
答案呼之欲出:
AAAA
為 8
,BBBB
為 32
,CCCC
為 x + 1
。
本程式碼透過以 1
填滿最高位 1
位元以降的所有位元,來獲得大於原值之二的冪減一,最後加一即獲得所求。
不過實際上本算法有個問題:當原本 x 即為二的冪時反而求得比其大的二的冪…
__builtin_clzl
改寫有 count leading zero,尋找目標二的冪變的更加簡單:取得 leading zero 中最高位的 0
的位移量,產生一個以零為基礎,將該為設為 1
後的結果:
uint64_t next_pow2(uint64_t x)
{
return 1 << (64 - __builtin_clzl(x));
}
極其簡單,不過如我於作業三的開發紀錄中 Count leading zero 實作一章所述,__builtin_clzl
有未定義行為,可能需要使用 branching 排除例外:
if (x)
return 1 << (64 - __builtin_clzl(x));
else // zero
return 1;
前面提到過 x 本身即二的冪的可能性,這可以透過 __builtin_ctzl
來判斷:當 lead + tail + 1
為 64
時即表示 x
為二的冪:
int lead = __builtin_clzl(x);
int tail = __builtin_ctzl(x);
if (lead + tail == 63)
return x;
else if (x)
return 1 << (64 - lead);
else // zero
return 1;
此時不仿利用邏輯與位元運算離規避上述 branching:
int lead = __builtin_clzl(x);
int tail = __builtin_ctzl(x);
_Bool is2power = lead + tail == 63;
int res;
(res = (-is2power & x)) || (res = ((-!!x) & (1u << (64 - lead))) + !x);
return res;
注意到利用邏輯運算子之 short circuit (短路求值) 的特性,第 12 行故意使用 ||
使當前者若非零就短路求值,這導致 res
停留在捕獲若 x
為二的冪時的結果,否則捕獲後者的結果。||
中後者的說明如下:
/**
* !!x -> whether x is non-zero
* -!!x -> if x is non-zero, got all ones (-1 = 0xff..ff)
* otherwise, still zero
*
* !x -> if x is zero, got 1
*/
((-!!x) & (1u << (64 - lead))) // for x != 0
+ !x; // for x == 0
如此便可實現看似無 branching 版的 next_pow2
。
前述我們討論了兩種以 __builtin_clzl
為基礎的實現,差別在一個使用純 branching,,另一個極端的使用邏輯與位元運算規避 branching。一個重要的問題是:後面這樣折騰,是否真的比較高效?看看產生的指令也許可以略知一二:
以下所有組語使用 cc -O2 -std=c99 -S next_pow2.c
編譯。
前幾行求 lead
和 tail
皆相同,出現 __builtin_clzl
對應之組語 bsrq
和 __builtin_ctzl
對應之 bsfq
:
bsrq %rdi, %rax
xorl %edx, %edx
xorq $63, %rax
rep bsfq %rdi, %rdx
addl %eax, %edx
...
為了比較極端使用非分支與極端使用分支,我額外寫一個分支判斷二的冪後即不使用分支的中間版本作為比較…
...
movl %eax, %esi
movq %rdi, %rax
cmpl $63, %edx
je .L4
movl $1, %eax
testq %rdi, %rdi
je .L4
movl $64, %ecx
subl %esi, %ecx
sall %cl, %eax
cltq
.L4:
ret
...
cmpl $63, %ecx
je .L4
negq %rdi
movl $64, %ecx
sbbl %esi, %esi
subl %edx, %ecx
movl $1, %edx
salq %cl, %rdx
movslq %esi, %rsi
andq %rdx, %rsi
cmpq $1, %rax
movq %rsi, %rax
adcq $0, %rax
.L4:
ret
...
cmpl $63, %eax
sete %al
movzbl %al, %eax
negl %eax
andl %edi, %eax
jne .L2
movq %rdi, %rax
movl $64, %ecx
negq %rax
sbbl %eax, %eax
subl %edx, %ecx
movl $1, %edx
salq %cl, %rdx
andl %edx, %eax
cmpq $1, %rdi
adcl $0, %eax
.L2:
cltq
ret
可以發現全無分支版的竟然出現 jne
(jump if not equals) 指令,不過對於它的出現也不應該感到意外,這就是 short circuit 的實作。考慮到條件轉跳於預測錯誤時 CPU flush 的性能損耗,分支版有兩條件轉跳,其餘兩者皆為單條件轉跳,單轉跳的可能表現比較好。不過極端版組合語言明顯長非常多。綜合來看,道取中庸可能是比較好的選擇,其對應的 C 語言如下:
int lead = __builtin_clzl(x);
int tail = __builtin_ctzl(x);
if (lead + tail == 63)
return x;
else
return ((-!!x) & (1llu << (64 - lead))) + !x;
串接 \(1\) 至 \(n\) 間所有二進制數的值。
int concatenatedBinary(int n)
{
const int M = 1e9 + 7;
int len = 0; /* the bit length to be shifted */
/* use long here as it potentially could overflow for int */
long ans = 0;
for (int i = 1; i <= n; i++) {
/* removing the rightmost set bit
* e.g. 100100 -> 100000
* 000001 -> 000000
* 000000 -> 000000
* after removal, if it is 0, then it means it is power of 2
* as all power of 2 only contains 1 set bit
* if it is power of 2, we increase the bit length
*/
if (!(DDDD))
len++;
ans = (i | (EEEE)) % M;
}
return ans;
}
考慮判斷去除最低位的 1
後是否為零的邏輯,便是一個使用 __builtin_ctz
的好時機:將 tailing zero 位加一之位元去除,故 DDDD
為 x & ~(1 << __builtin_ctz(x))
。由於暫時不確定是針對哪個變數操作,故先以 x
代替。
之後我們看到第 18 行需要一個與 ans
於之前迴圈舊的值位移後有關的運算結果,推斷 DDDD
的邏輯應該是為了複製貼上 i
的位元所做的準備,故 DDDD
的變數應該與 i
有關,也就是 i
。這樣思考的話 EEEE
的答案也呼之欲出:即將 ans
位移 len
。
故答案為:
DDDD
: i & ~(1 << __builtin_ctz(i))
EEEE
: ans << len
由於直覺想到使用 __builtin_ctz
,與延伸不謀而合,於是我去觀賞其他學員們的答案,發現許多人 DDDD
都使用 i & (i - 1)
作答,其所代表的意義為直接判斷是否為二的冪,也十分精妙。
當初確診時看這題題目敘述就昏頭,現在看一次就懂了,可惡 qq
以 SWAR 實作計算 UTF8 字數的函式。
size_t swar_count_utf8(const char *buf, size_t len)
{
const uint64_t *qword = (const uint64_t *) buf;
const uint64_t *end = qword + len >> 3;
size_t count = 0;
for (; qword != end; qword++) {
const uint64_t t0 = *qword;
const uint64_t t1 = ~t0;
const uint64_t t2 = t1 & 0x04040404040404040llu;
const uint64_t t3 = t2 + t2;
const uint64_t t4 = t0 & t3;
count += __builtin_popcountll(t4);
}
count = (1 << AAAA) * (len / 8) - count;
count += (len & BBBB) ? count_utf8((const char *) end, len & CCCC) : DDDD;
return count;
}
當中出現的 count_utf8
為以單字元為基礎掃描的版本:
#include <stddef.h>
#include <stdint.h>
size_t count_utf8(const char *buf, size_t len)
{
const int8_t *p = (const int8_t *) buf;
size_t counter = 0;
for (size_t i = 0; i < len; i++) {
/* -65 is 0b10111111, anything larger in two-complement's should start
* new code point.
*/
if (p[i] > -65)
counter++;
}
return counter;
}
首先觀察 swar_count_utf8
中初始化的兩指標,qword
表四重 words (2 bytes),也就是 \(64\) 位元,而 end
指向 qword
遍歷結束的位址,除以 \(8\) (右移 \(3\)) 處理了 char *
至uint64_t *
位移量的型別轉換。
遍歷完迴圈後,可以預期 count
已經紀錄 continuation bytes 數,以總位元組數減去之後便是真正的 UTF8 字數。考量如此,(1 << AAAA) * (len / 8)
應該為總位元組數,也就是假設 len
整除於 \(8\) (\(64\) bits) 情況下的總位元組數:整數除 \(8\) 後乘 \(8\) == 1 << 3
,故 AAAA = 3
。
但 buf
未必能對齊於 uint64_t
,需額外計算未對其的部分。由此推斷 BBBB
為判斷對齊與否的邏輯:len & 0b111
,故 BBBB = 0b111 = 7
,同時判斷 DDDD
即 0
,表已經對齊不需補加。
至於 CCCC
,由於希望 count_utf8
幫我們處理未對齊的部分,故應該略去已經對齊的量,透過 & 0b111
即可。
故答案為:
AAAA = 3
BBBB = 7
CCCC = 7
DDDD = 0
由於單次尋找量為非 SWAR 的八倍,但單次迴圈內的計算量增加,故推測效能會有顯著而小於八倍的成長。
TODO: 更具體的效能比較…
Shiritai
#include <stdint.h>
#include <stdbool.h>
bool is_pattern(uint16_t x)
{
if (!x)
return 0;
for (; x > 0; x <<= 1) {
if (!(x & 0x8000))
return false;
}
return true;
}
以上述函式確認 \(16\) bits 無號數是否符合以下樣式:
8000, c000, e000, f000, f800, fc00, fe00, ff00, ff80, ffc0, ffe0, fff0, fff8, fffc, fffe, ffff,
改由以下函式更加有效:
bool is_pattern(uint16_t x)
{
const uint16_t n = EEEE;
return (n ^ x) < FFFF;
}
眼尖便能發現特定樣式的為最高位必是一個以上(含)連續的一。如果不使用題目的框架,我可能會這樣寫:
return x && __builtin_ctzs(x) + __builtin_clzs(~x) == 16;
不過題目框架已經給好了,這還真腦筋急轉不過來…只好參考其他學員們的答案 (EEEE = ~x + 1
, FFFF = x
),但還是覺得沒有學到什麼,就不額外說明了。