# 2021q1 quiz6A - arbitrary-precision arithmetic

contributed by < `tigger12613` >

> [GitHub](https://github.com/tigger12613/arbitrary-precision-arithmetic)

## 程式運作原理

### 資料結構

```cpp
#define UTYPE uint32_t
/* how large the underlying array size should be */
#define UNIT_SIZE 4
/* BN_ARRAY_SIZE = 32 */
#define BN_ARRAY_SIZE (128 / UNIT_SIZE) /* 128 bytes per big number, split into UNIT_SIZE-byte words */
struct bn {
    UTYPE array[BN_ARRAY_SIZE];
};
```

bn 由 `BN_ARRAY_SIZE`(32)個 `uint32_t` 組成的 array 表示,總共 128 bytes(1024 bits);`array[0]` 儲存最低位的 32 bits,`array[BN_ARRAY_SIZE - 1]` 儲存最高位。

舉例來說,下面的 a 加 1 之後,`array[0]` 會進位到 `array[1]`:

```
          array[0]     array[1]   ...
a     =   0xFFFFFFFF   0x00000001 ...
a + 1 =   0x00000000   0x00000002 ...
```

### 給 bn 設值

原本的程式碼留有 `FIXME: what if machine is not little-endian?`。我們希望不論機器是 big-endian 或 little-endian,都一樣是 `array[0]` 存入低位、`array[1]` 存入高位。

一開始的改進想法是加入 big-endian / little-endian 判斷。但如下方 C99 6.5.7 所述,shift 與 mask 是對「值」運算,結果與 byte order 無關。而且若在 big-endian 分支把高位寫進 `array[0]`,反而會違反 `bn_add`、`bn_mul` 等函式「`array[0]` 為最低位」的假設。因此不需要判斷 endianness,寫成單一版本即可:

```cpp
static inline void bn_from_int(struct bn *n, UTYPE_TMP i)
{
    bn_init(n);
    /* Shifts and masks act on values, not on the in-memory byte layout
     * (C99 6.5.7), so this is correct on big- and little-endian machines
     * alike: array[0] always receives the low 32 bits of i. */
    n->array[0] = (UTYPE) (i & MAX_VAL);
    /* i >> 32 is well-defined only because UTYPE_TMP is wider than 32 bits
     * (presumably uint64_t); it yields the high 32 bits. */
    n->array[1] = (UTYPE) (i >> 32);
}
```

C99 6.5.7:

> 4. The result of E1 << E2 is E1 left-shifted E2 bit positions; vacated bits are filled with zeros. If E1 has an unsigned type, the value of the result is E1 × 2<sup>E2</sup>, reduced modulo one more than the maximum value representable in the result type. If E1 has a signed type and nonnegative value, and E1 × 2<sup>E2</sup> is representable in the result type, then that is the resulting value; otherwise, the behavior is undefined.
>
> 5. The result of E1 >> E2 is E1 right-shifted E2 bit positions. If E1 has an unsigned type or if E1 has a signed type and a nonnegative value, the value of the result is the integral part of the quotient of E1 / 2<sup>E2</sup>.
> If E1 has a signed type and a negative value, the resulting value is implementation-defined.

這說明了 shift 的結果是以「值」來定義的,與資料在記憶體中的 byte order 無關。所以不論在 big-endian 或是 little-endian 的機器上做 shift,都會得到一樣的結果,`bn_from_int` 並不需要依 endianness 分別處理。

### bn 乘法

array 的每個元素都是 32-bit unsigned integer,可以無損轉換成 64-bit unsigned integer。兩個 32-bit 數相乘的結果最大為 (2<sup>32</sup> − 1)<sup>2</sup> < 2<sup>64</sup>,因此以 64 bits 相乘後不會溢位。

```cpp
static void bn_mul(struct bn *a, struct bn *b, struct bn *c)
{
    struct bn row, tmp;
    bn_init(c);
    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        bn_init(&row);
        for (int j = 0; j < BN_ARRAY_SIZE; ++j) {
            if (i + j < BN_ARRAY_SIZE) {
                bn_init(&tmp);
                UTYPE_TMP intermediate = a->array[i] * (UTYPE_TMP)b->array[j];
                bn_from_int(&tmp, intermediate);
                lshift_unit(&tmp, i + j);
                bn_add(&tmp, &row, &row);
            }
        }
        bn_add(c, &row, c);
    }
}
```

算法如下圖所示:把每個 `a[i] * b[j]` 轉成 bn,位移 `i + j` 個 unit 到正確的 index 後加到 `row`;外層迴圈每跑完一輪,就把 `row` 加進 `c`。`i + j >= BN_ARRAY_SIZE` 的部分超出 1024 bits,直接捨棄。

```
          ...  [2]        [1]        [0]
a   =     ...  a[2]       a[1]       a[0]
b   =     ...  b[2]       b[1]       b[0]
*   ----------------------------------------
          ...  a[2]*b[0]  a[1]*b[0]  a[0]*b[0]
          ...  a[1]*b[1]  a[0]*b[1]
          ...  a[0]*b[2]
+   ----------------------------------------
c   =     ...  c[2]       c[1]       c[0]
```

### print bn by hex

由於 bn 使用二進位儲存,無法簡單地轉換成十進位輸出。但每 4 bits 恰好對應一個十六進位數字,所以每個 32-bit 元素可以直接印成 `2 * UNIT_SIZE` 個十六進位字元,從最高位的元素依序印出即可。

改進:
- 原本判斷剩餘空間時只保留 1 byte(`nbytes > i + 1`),但每次 `sprintf` 會寫入 `2 * UNIT_SIZE` 個字元再加上結尾的 `'\0'`。
- 去除前導零時的 `str[i] = 0`,在沒有前導零的情況下會寫到 `str[nbytes]`,超出 buffer。
- 值為 0 時會輸出空字串。

以下版本修正這些問題:

```cpp
static void bn_to_str(struct bn *n, char *str, int nbytes)
{
    if (nbytes <= 0)
        return;

    /* index into array - reading "MSB" first -> big-endian */
    int j = BN_ARRAY_SIZE - 1;
    int i = 0; /* index into string representation */

    /* reading last array-element "MSB" first -> big endian.
     * Each element takes 2 * UNIT_SIZE hex digits plus the NUL that sprintf
     * appends, so only print while all of that still fits in the buffer. */
    while ((j >= 0) && (nbytes > (i + 2 * UNIT_SIZE))) {
        sprintf(&str[i], FORMAT_STRING, n->array[j]);
        i += (2 * UNIT_SIZE); /* step UNIT_SIZE hex-byte(s) forward in the string */
        j -= 1;               /* step one element back in the array */
    }
    str[i] = 0; /* terminate even if nothing was printed */

    /* Count leading zeros, but keep the last digit so that 0 prints as "0" */
    for (j = 0; str[j] == '0' && str[j + 1] != 0; j++)
        ;
    /* Move the string j places ahead (up to its terminator), effectively
     * skipping leading zeros without touching bytes past the terminator */
    for (i = 0; str[i + j] != 0; ++i)
        str[i] = str[i + j];
    str[i] = 0;
}
```

### bn add

c = a + b

改進:如果 overflow(最高位產生進位)則 return 0,否則 return 1。
```cpp
static int bn_add(struct bn *a, struct bn *b, struct bn *c)
{
    int carry = 0;
    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE_TMP tmp = (UTYPE_TMP)a->array[i] + b->array[i] + carry;
        carry = (tmp > MAX_VAL);
        c->array[i] = (tmp & MAX_VAL);
    }
    /* a carry out of the top word means the sum did not fit */
    if (carry == 1) {
        return 0;
    } else {
        return 1;
    }
}
```

### bn dec

bn = bn - 1

改進:如果 underflow(bn 原本為 0)則 return 0 且不修改 bn,否則 return 1。

注意條件必須寫成 `bn_is_zero(n)`。若寫成 `if (bn_is_zero)`,判斷的是函式位址,條件永遠為真,`bn_dec` 就會一律直接 return 0,完全不做減法。

```cpp
/* Decrement: subtract 1 from n */
static int bn_dec(struct bn *n)
{
    /* 0 - 1 would wrap around: report underflow and leave n unchanged */
    if (bn_is_zero(n)) {
        return 0;
    }
    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE tmp = n->array[i];
        UTYPE res = tmp - 1;
        n->array[i] = res;
        /* the borrow moves on to the next word only if this word wrapped */
        if (!(res > tmp))
            break;
    }
    return 1;
}
```

## bn 輸出十進位

參考 [sysprog21/bignum](https://github.com/sysprog21/bignum) 中 `format.c` 的實作,做法如下:
1. 反覆把數值除以 10<sup>9</sup>(`max_radix`,也就是 32 bits 能容納的最大 10 的冪次)。
2. 每次得到的餘數再拆成最多 9 個十進位數字。
3. 最後把整個字串反轉。

`digit_div` 使用 x86 的 `divl` 指令,把 `n1:n0` 組成的 64-bit 值除以 32-bit 的 `d`,所以這段程式碼只能在 x86 / x86-64 上編譯。

改進:
- 原本從 `tmp->array[32]` 開始往下找最高的非零元素,但 `BN_ARRAY_SIZE` 為 32,合法 index 只到 31,會讀到陣列之外。數值為 0 時 `size` 也不會被初始化。現在改為先令 `size = BN_ARRAY_SIZE`,再以 `BN_NORMALIZE(tmp_u, size)` 計算。
- 預設配置的 256 bytes 放不下 1024-bit 數值的十進位表示(2<sup>1024</sup> 約有 309 位數)。改依 `BN_ARRAY_SIZE` 計算大小,並檢查 `calloc` 是否成功。
- 數值小於 10 而提早 return 時沒有 `free(tmp)`,會造成 memory leak。
- `bn_ddivi` 只需要處理仍有效的 `tsize` 個元素。

```cpp
#define digit_div(n1, n0, d, q, r) \
    __asm__("divl %4"              \
            : "=a"(q), "=d"(r)     \
            : "0"(n0), "1"(n1), "g"(d))

#define BN_NORMALIZE(u, usize)         \
    while ((usize) && !(u)[(usize)-1]) \
        --(usize);

/* Divide the size-word number u[0..size) (least significant word first)
 * in place by v, and return the remainder. */
static uint32_t bn_ddivi(uint32_t *u, uint32_t size, uint32_t v)
{
    assert(u != NULL);
    assert(v != 0);
    assert(size != 0); /* the do-while below would wrap size around */

    if (v == 1)
        return 0;

    uint32_t s1 = 0;
    u += size;
    do {
        uint32_t s0 = *--u;
        uint32_t q, r;
        if (s1 == 0) {
            q = s0 / v;
            r = s0 % v;
        } else {
            /* s1 is the previous remainder, so s1 < v and the 64/32-bit
             * quotient always fits in 32 bits (divl cannot fault) */
            digit_div(s1, s0, v, q, r);
        }
        *u = q;
        s1 = r;
    } while (--size);
    return s1;
}

static const char radix_chars[37] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";

#ifndef SWAP
#define SWAP(x, y)           \
    do {                     \
        typeof(x) __tmp = x; \
        x = y;               \
        y = __tmp;           \
    } while (0)
#endif

static char *bn_get_str(bn_t *n, char *out)
{
    unsigned int radix = 10;
    const uint32_t max_radix = 0x3B9ACA00U; /* 10^9 */
    const unsigned int max_power = 9;

    /* A 32-bit word needs at most 10 decimal digits, and the 9-digit groups
     * emitted below add at most max_power characters of slack, so this
     * covers the largest value plus the terminating NUL. A caller-supplied
     * buffer must be at least this large. */
    if (!out)
        out = calloc(BN_ARRAY_SIZE * 10 + max_power + 1, sizeof(char));
    if (!out)
        return NULL;
    char *outp = out;

    bn_t *tmp = bn_tmp_copy(n);
    uint32_t *tmp_u = tmp->array;

    /* Number of significant words (valid indices are 0 .. BN_ARRAY_SIZE - 1) */
    uint32_t size = BN_ARRAY_SIZE;
    BN_NORMALIZE(tmp_u, size);

    if (size == 0 || (size == 1 && tmp_u[0] < radix)) {
        out[0] = size ? radix_chars[tmp_u[0]] : '0';
        out[1] = '\0';
        free(tmp);
        return out;
    }

    uint32_t tsize = size;
    do {
        /* Multi-precision: divide U by largest power of RADIX to fit in
         * one apm_digit and extract remainder.
         */
        uint32_t r = bn_ddivi(tmp_u, tsize, max_radix);
        tsize -= (tmp_u[tsize - 1] == 0);
        /* Single-precision: extract K remainders from that remainder,
         * where K is the largest integer such that RADIX^K < 2^BITS.
         */
        unsigned int i = 0;
        do {
            uint32_t rq = r / radix;
            uint32_t rr = r % radix;
            *outp++ = radix_chars[rr];
            r = rq;
            if (tsize == 0 && r == 0) /* Eliminate any leading zeroes */
                break;
        } while (++i < max_power);
        assert(r == 0);
        /* Loop until TMP = 0. */
    } while (tsize != 0);

    free(tmp);

    char *f = outp - 1;
    /* Eliminate leading (trailing) zeroes */
    while (*f == '0')
        --f;
    /* NULL terminate */
    f[1] = '\0';
    /* Reverse digits */
    for (char *s = out; s < f; ++s, --f)
        SWAP(*s, *f);
    return out;
}
```

### 測試結果

印出 factorial(1) 到 factorial(100),並跟 python 的運算結果做比對。`main` 中每次迴圈 `calloc` 的 `str` 原本沒有釋放,這裡補上 `free`,並先檢查 `argc`,避免沒有給參數時讀取 `argv[1]`。

```cpp
int main(int argc, char *argv[])
{
    struct bn num, result;
    char buf[8192];

    if (argc < 2)
        return -1;
    unsigned int n = strtoul(argv[1], NULL, 10);
    if (!n)
        return -2;

    for (int i = 1; i <= n; i++) {
        bn_from_int(&num, i);
        factorial(&num, &result);
        bn_to_str(&result, buf, sizeof(buf));
        // printf("factorial(%d) = %s\n", i, buf);
        uint32_t string_size = 512;
        char *str = calloc(1, string_size);
        if (!str)
            return -3;
        char *a = bn_get_str(&result, str);
        printf("fac(%d) = %s\n", i, a);
        free(str);
    }
    return 0;
}
```

```python
#python 計算 factorial from 1~100
def factorial(n):
    for i in range(n-1,0,-1):
        n *= i
    return n

if __name__=="__main__":
    for i in range(1,101):
        print("fac(%d) =" %i,factorial(i))
```

用 shell script 來比對兩個結果:

```bash
#!/bin/bash
diff -u -B q6ans pyans
ret=$?
if [[ $ret -eq 0 ]]; then
    echo "passed."
else
    echo "failed."
fi
```

比對運算結果,與 python 計算結果一致。

```shell
$ ./q6 100 > q6ans
fac(1) = 1
fac(2) = 2
fac(3) = 6
fac(4) = 24
fac(5) = 120
...
fac(98) = 9426890448883247745626185743057242473809693764078951663494238777294707070023223798882976159207729119823605850588608460429412647567360000000000000000000000
fac(99) = 933262154439441526816992388562667004907159682643816214685929638952175999932299156089414639761565182862536979208272237582511852109168640000000000000000000000
fac(100) = 93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000
$ python3 factorial > pyans
$ bash diff.sh
passed.
```

## Karatsuba algorithm

### Karatsuba 實作

參考 Wikipedia [Karatsuba algorithm](https://en.wikipedia.org/wiki/Karatsuba_algorithm) 條目中的 pseudocode:

```python
procedure karatsuba(num1, num2)
    if (num1 < 10) or (num2 < 10)
        return num1 × num2

    /* Calculates the size of the numbers. */
    m = min(size_base10(num1), size_base10(num2))
    m2 = floor(m / 2)
    /* m2 = ceil(m / 2) will also work */

    /* Split the digit sequences in the middle. */
    high1, low1 = split_at(num1, m2)
    high2, low2 = split_at(num2, m2)

    /* 3 calls made to numbers approximately half the size.
*/ z0 = karatsuba(low1, low2) z1 = karatsuba((low1 + high1), (low2 + high2)) z2 = karatsuba(high1, high2) return (z2 × 10 ^ (m2 × 2)) + ((z1 - z2 - z0) × 10 ^ m2) + z0 ``` ```cpp= static void karatsuba_mul(bn_t *a, bn_t *b, bn_t *c) { int m1 = bn_size(a); int m2 = bn_size(b); bn_init(c); if (m1 <= 1 && m2 <= 1) { UTYPE_TMP intermediate = a->array[0] * (UTYPE_TMP)b->array[0]; bn_from_int(c, intermediate); return; }else if(m1 <= 1){ bn_t row ,tmp; bn_init(&row); for (int j = 0; j < BN_ARRAY_SIZE; ++j) { bn_init(&tmp); UTYPE_TMP intermediate = a->array[0] * (UTYPE_TMP)b->array[j]; bn_from_int(&tmp, intermediate); lshift_unit(&tmp, j); bn_add(&tmp, &row, &row); } bn_add(c, &row, c); return; }else if( m2 <=1){ bn_t row ,tmp; bn_init(&row); for (int j = 0; j < BN_ARRAY_SIZE; ++j) { if ( j < BN_ARRAY_SIZE) { bn_init(&tmp); UTYPE_TMP intermediate = a->array[j] * (UTYPE_TMP)b->array[0]; bn_from_int(&tmp, intermediate); lshift_unit(&tmp, j); bn_add(&tmp, &row, &row); } } bn_add(c, &row, c); return; } //min int m = (m1 > m2) ? 
m2 : m1; int mm = m / 2; bn_t high1, low1, high2, low2, z0, z1, z2, tmp1, tmp2; bn_init(&high1); bn_init(&low1); bn_init(&high2); bn_init(&low2); bn_init(&z0); bn_init(&z1); bn_init(&z2); bn_init(&tmp1); bn_init(&tmp2); bn_split_shift(a, &high1, &low1, mm); bn_split_shift(b, &high2, &low2, mm); karatsuba_mul(&low1, &low2, &z0); bn_add(&low1, &high1, &tmp1); bn_add(&low2, &high2, &tmp2); karatsuba_mul(&tmp1, &tmp2, &z1); karatsuba_mul(&high1, &high2, &z2); bn_sub(&z1, &z2, &tmp1); bn_sub(&tmp1, &z0, &tmp1); bn_left_shift(&tmp1, mm); bn_left_shift(&z2, 2*mm); bn_add(&tmp1, &z2, &tmp2); bn_add(&tmp2, &z0, c); return; } ``` ### bn size 回傳 bn 去除前面的零的大小 ```cpp static int bn_size(bn_t *a) { for (int i = (int)(BN_ARRAY_SIZE - 1); i >= 0; i--) { if (a->array[i] != 0) { return i + 1; } } return 0; } ``` ### bn split 以 offset 的位置分割 a 到 b 跟 c ```cpp // a = b concat c static void bn_split_shift(bn_t *a, bn_t *b, bn_t *c, const int offset) { bn_init(b); bn_init(c); for (int i = 0; i < offset; i++) { c->array[i] = a->array[i]; } for (int i = 0; i < BN_ARRAY_SIZE - offset; i++) { b->array[i] = a->array[i + offset]; } } ``` ### bn subtraction bn 減法 ```cpp // c = a - b static int bn_sub(bn_t *a, bn_t *b, bn_t *c) { //printf("sub\n"); UTYPE_TMP carry = 0; for (int i = 0; i < BN_ARRAY_SIZE; i++) { if (a->array[i] == 0 && carry == 1) { c->array[i] = MAX_VAL - (UTYPE_TMP)b->array[i] + 1; carry = 1; } else { if (a->array[i] - carry < b->array[i]) { c->array[i] = (MAX_VAL - (UTYPE_TMP)b->array[i])+ 1 + (UTYPE_TMP)a->array[i] - carry; carry = 1; } else { c->array[i] = (UTYPE_TMP)a->array[i] - carry - (UTYPE_TMP)b->array[i]; carry = 0; } } } if (carry == 1) { return 0; } else { return 1; } } ``` ### bn left shift left shift bn number of offset,一個 offset 為 32bit ```cpp void bn_left_shift(bn_t *a, int offset) { for (int i = (int)(BN_ARRAY_SIZE - 1); i >= offset; i--) { a->array[i] = a->array[i - offset]; } for (int i = 0; i < offset; i++) { a->array[i] = 0; } } ``` ### 正確性 用 bn_mul 
當作正確答案進行比對:手動更改初始值後,`bn_mul` 跟 `karatsuba_mul` 的答案都一樣。

```cpp
int main(int argc, char *argv[])
{
    bn_t a, b, c, d;
    bn_init(&a);
    bn_init(&b);
    bn_init(&c);
    bn_init(&d);
    bn_from_int(&a, 1567579508);
    bn_from_int(&d, 1567579508);
    int loop = 31;
    for (uint64_t i = 0; i < loop; i++) {
        bn_left_shift(&a, 1);
        bn_assign(&b, &a);
        bn_add(&a, &d, &a);
        bn_mul(&a, &b, &c);
        bn_printhex(&c);
    }
    bn_from_int(&a, 1567579508);
    for (uint64_t i = 0; i < loop; i++) {
        bn_left_shift(&a, 1);
        bn_assign(&b, &a);
        bn_add(&a, &d, &a);
        karatsuba_mul(&a, &b, &c);
        bn_printhex(&c);
    }
    return 0;
}
```

### 效率比較

測試方式如下,以此比較兩種演算法在不同大小下的效率:
- 每一輪隨機取一個 unsigned int 作為初始值。
- 之後每一次迴圈都 left shift 32 bits,再加上同一個初始值。
- 每個大小連續做 10 次乘法並記錄總耗時。
- 共重複 50 輪。

改進:
- 原本第二個迴圈開始前沒有重設 `b`,使 `karatsuba_mul` 第一次量測時用的是前一個迴圈留下的最大 `b`,兩種演算法量測的運算元並不相同。現在兩個迴圈開始前都令 `b = a`。
- 移除未使用、且被迴圈變數遮蔽的 `long long k`。
- 檢查 `fopen` 是否成功。

```cpp
int main(int argc, char *argv[])
{
    struct timespec t_start, t_end;
    const int size = 1000;
    long long normal_time[size], karatsuba_time[size];
    int normal_bits[size], karatsuba_bits[size];
    bn_t a, b, c, d;
    bn_init(&a);
    bn_init(&b);
    bn_init(&c);
    bn_init(&d);
    srand(time(NULL));

    FILE *fp1 = fopen("origin.txt", "a");
    FILE *fp2 = fopen("karatsuba.txt", "a");
    if (!fp1 || !fp2) {
        perror("fopen");
        /* fclose(NULL) is undefined, so only close what was opened */
        if (fp1)
            fclose(fp1);
        if (fp2)
            fclose(fp2);
        return 1;
    }

    for (int k = 0; k < 50; k++) {
        uint32_t rand_n = rand();
        int loop = 31;

        bn_from_int(&a, rand_n);
        bn_from_int(&d, rand_n);
        bn_assign(&b, &a); /* both measurements start from the same operands */
        for (int i = 0; i < loop; i++) {
            normal_bits[i] = bn_bit_count(&a);
            clock_gettime(CLOCK_MONOTONIC, &t_start);
            for (int j = 0; j < 10; j++) {
                bn_mul(&a, &b, &c);
            }
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            normal_time[i] = (t_end.tv_sec * NANOSEC + t_end.tv_nsec) -
                             (t_start.tv_sec * NANOSEC + t_start.tv_nsec);
            bn_left_shift(&a, 1);
            bn_assign(&b, &a);
            bn_add(&a, &d, &a);
        }
        for (int i = 0; i < loop; i++)
            fprintf(fp1, "%d %lld\n", normal_bits[i], normal_time[i]);

        bn_from_int(&a, rand_n);
        bn_assign(&b, &a);
        for (int i = 0; i < loop; i++) {
            karatsuba_bits[i] = bn_bit_count(&a);
            clock_gettime(CLOCK_MONOTONIC, &t_start);
            for (int j = 0; j < 10; j++) {
                karatsuba_mul(&a, &b, &c);
            }
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            karatsuba_time[i] = (t_end.tv_sec * NANOSEC + t_end.tv_nsec) -
                                (t_start.tv_sec * NANOSEC + t_start.tv_nsec);
            bn_left_shift(&a, 1);
            bn_assign(&b, &a);
            bn_add(&a, &d, &a);
        }
        for (int i = 0; i < loop; i++)
            fprintf(fp2, "%d %lld\n", karatsuba_bits[i], karatsuba_time[i]);
    }

    fclose(fp1);
    fclose(fp2);
    return 0;
}
```

從圖中可以觀察到以下幾點:
- karatsuba algorithm 的耗時隨著 bit size 增加而增加。
- `bn_mul` 不論大小耗時都一致。這是因為它的兩層迴圈固定跑滿 `BN_ARRAY_SIZE` 個元素(即使高位都是 0),成本與數值大小無關。
- karatsuba algorithm 在 bit size 超過一定的範圍後,效率會不如 `bn_mul`。原因是 `karatsuba_mul` 每一層遞迴的 `bn_add`、`bn_sub`、`bn_split_shift` 以及 base case 的迴圈,同樣都以整個 `BN_ARRAY_SIZE` 為長度運算,不會因運算元變短而變快;遞迴呼叫的次數增加後,總成本便超過 `bn_mul`。

因為是隨機初始值,所以初始值的位元數很容易是 32 bits 或 31 bits。

![](https://i.imgur.com/tZ2T69s.png)