# 2021q1 quiz6A - arbitrary-precision arithmetic
contributed by < `tigger12613` >
> [GitHub](https://github.com/tigger12613/arbitrary-precision-arithmetic)
## 程式運作原理
### 資料結構
```cpp
/* Limb type: every element of a big number is one 32-bit unsigned word. */
#define UTYPE uint32_t
/* Size of one limb in bytes (4, matching the uint32_t chosen for UTYPE). */
#define UNIT_SIZE 4
/* Number of limbs per big number: 128 bytes / 4 bytes per limb = 32 limbs. */
#define BN_ARRAY_SIZE (128 / UNIT_SIZE) /* 128 is the size of a big number in bytes */
/* Fixed-size big number; array[0] holds the least-significant limb. */
struct bn { UTYPE array[BN_ARRAY_SIZE]; };
```
bn 由一串 unsigned int32 組成 array , `array[0]` 儲存最低位。
```cpp
舉例來說
array[0] array[1]
a = 0xFFFFFFFF 0x00000001 ...
a + 1 = 0x00000000 0x00000002 ...
```
### 給 bn 設值
如果機器是 big-endian ,一樣`array[0]`存入低位,`array[1]`存入高位。
改進:加入 big-endian, little-endian 判斷。
```cpp
/*
 * Set n to the value of the unsigned integer i.
 *
 * The low 32 bits of i go to array[0] and the next 32 bits to array[1];
 * every other limb is cleared by bn_init().
 *
 * No byte-order test is needed here: the limb order (array[0] = least
 * significant) is a property of this data structure, and `>>` operates on
 * values, not on the in-memory byte layout (C99 6.5.7), so the same code is
 * correct on big- and little-endian machines.  Swapping the two stores on
 * big-endian hosts would produce a different number.
 *
 * NOTE(review): assumes UTYPE_TMP is at least 64 bits wide so that the
 * shift by 32 is well defined -- confirm against its definition.
 */
static inline void bn_from_int(struct bn *n, UTYPE_TMP i) {
    bn_init(n);

    /* Implicit conversion to the limb type keeps the low 32 bits. */
    n->array[0] = i;

    /* Shift in the wide type, then store the upper word. */
    UTYPE_TMP high = i >> 32;
    n->array[1] = high;
}
```
c99 6.5.7
> 4. The result of E1 << E2 is E1 left-shifted E2 bit positions; vacated bits are filled with zeros. If E1 has an unsigned type, the value of the result is E1 × 2<sup>E2</sup>, reduced modulo one more than the maximum value representable in the result type. If E1 has a signed type and nonnegative value, and E1 × 2<sup>E2</sup> is representable in the result type, then that is the resulting value; otherwise, the behavior is undefined.
> 5. The result of E1 >> E2 is E1 right-shifted E2 bit positions. If E1 has an unsigned type or if E1 has a signed type and a nonnegative value, the value of the result is the integral part of the quotient of E1 / 2<sup>E2</sup>. If E1 has a signed type and a negative value, the resulting value is implementation-defined
>
這說明了不論在 big-endian 或是 little-endian 的機器上做 shift 都會是一樣的結果。
### bn 乘法
array 用 32-bit unsigned integer 儲存,可以將 32-bit unsigned integer 無損轉換成 64-bit unsigned integer,因此兩個元素相乘的結果不會溢位。
```cpp
static void bn_mul(struct bn *a, struct bn *b, struct bn *c) {
struct bn row, tmp;
bn_init(c);
for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
bn_init(&row);
for (int j = 0; j < BN_ARRAY_SIZE; ++j) {
if (i + j < BN_ARRAY_SIZE) {
bn_init(&tmp);
UTYPE_TMP intermediate = a->array[i] * (UTYPE_TMP)b->array[j];
bn_from_int(&tmp, intermediate);
lshift_unit(&tmp, i + j);
bn_add(&tmp, &row, &row);
}
}
bn_add(c, &row, c);
}
}
```
算法如下圖所示,將兩數相乘後位移到正確的 index 再加到 row 。
```cpp
a = ... a[2] a[1] a[0]
b = ... b[2] b[1] b[0]
*
----------------------------------------
... a[2]*b[0] a[1]*b[0] a[0]*b[0]
... a[1]*b[1] a[0]*b[1]
... a[0]*b[2]
+
----------------------------------------
row = ... row[2] row[1] row[0]
```
### print bn by hex
由於 bn 使用二進位儲存,因此無法簡單轉換成十進位輸出。
```cpp
/*
 * Write n as a hexadecimal string (most-significant limb first, leading
 * zeros removed) into str, which has room for nbytes bytes.
 *
 * - Never writes outside str[0 .. nbytes-1]; the result is always
 *   NUL-terminated when nbytes > 0.
 * - If the buffer is too small, only as many whole limbs as fit (starting
 *   from the most significant one) are formatted.
 * - A value of zero yields the empty string.
 *
 * NOTE(review): assumes FORMAT_STRING emits exactly 2 * UNIT_SIZE hex
 * digits per limb (zero padded) -- confirm against its definition.
 */
static void bn_to_str(struct bn *n, char *str, int nbytes) {
    const int digits_per_unit = 2 * UNIT_SIZE;
    int len = 0; /* number of characters formatted so far */

    if (nbytes <= 0)
        return;
    str[0] = '\0';

    /* Format a limb only when its digits AND the terminator still fit. */
    for (int j = BN_ARRAY_SIZE - 1;
         j >= 0 && (nbytes - len) > digits_per_unit; --j) {
        snprintf(&str[len], (size_t)(nbytes - len), FORMAT_STRING,
                 n->array[j]);
        len += digits_per_unit;
    }

    /* Count leading zeros, staying inside the formatted text. */
    int skip = 0;
    while (skip < len && str[skip] == '0')
        ++skip;

    /* Slide the significant digits to the front and terminate. */
    int out = 0;
    while (skip < len)
        str[out++] = str[skip++];
    str[out] = '\0';
}
```
### bn add
c = a + b
改進:如果 overflow return 0, 不然 return 1.
```cpp
/*
 * c = a + b, limb by limb with carry propagation.
 *
 * Each limb sum is formed in the wide type so the carry out of the limb is
 * visible as a value above MAX_VAL.
 *
 * Returns 1 when the sum fits, 0 when a carry is left over after the most
 * significant limb (overflow; c then holds the truncated sum).
 */
static int bn_add(struct bn *a, struct bn *b, struct bn *c) {
    int carry = 0;

    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE_TMP sum = (UTYPE_TMP)a->array[i] + b->array[i] + carry;
        c->array[i] = (sum & MAX_VAL);
        carry = (sum > MAX_VAL);
    }

    return !carry;
}
```
### bn dec
bn = bn - 1
改進:如果 underflow(bn 已經是 0,無法再減)return 0, 不然 return 1.
```cpp
/*
 * Decrement: subtract 1 from n in place.
 *
 * Returns 0 and leaves n untouched when n is zero (the result would
 * underflow); returns 1 after a successful decrement.
 *
 * The zero test must CALL bn_is_zero(n): using the bare function name in a
 * condition tests the function's address, which is always true.
 */
static int bn_dec(struct bn *n) {
    if (bn_is_zero(n)) {
        return 0;
    }

    for (int i = 0; i < BN_ARRAY_SIZE; ++i) {
        UTYPE before = n->array[i];
        n->array[i] = before - 1;
        /* Only a limb that was 0 wraps around and borrows from the next. */
        if (before != 0)
            break;
    }
    return 1;
}
```
## bn 輸出十進位
參考[sysprog21/bignum](https://github.com/sysprog21/bignum)中 `format.c`的實作
```cpp
/*
 * digit_div(n1, n0, d, q, r): issue the x86 `divl` instruction on operand d.
 * n0 is tied to output 0 (register constraint "a") and n1 to output 1
 * (register constraint "d"); after the instruction q is read from "a" and r
 * from "d".  On x86 this divides the 64-bit value n1:n0 by d, giving the
 * quotient in q and the remainder in r.
 * x86-only GCC-style inline assembly.
 * NOTE(review): the "g" constraint also admits an immediate operand, which
 * `divl` cannot encode; "rm" looks like the safer constraint -- confirm.
 */
#define digit_div(n1, n0, d, q, r) \
__asm__("divl %4" \
: "=a"(q), "=d"(r) \
: "0"(n0), "1"(n1), "g"(d))
/*
 * BN_NORMALIZE(u, usize): decrement usize while it is non-zero and
 * u[usize - 1] is zero, i.e. drop the zero limbs at the top of u.
 * usize is modified in place, so it must be an lvalue.  The expansion is a
 * bare `while` statement that already ends in ';'.
 */
#define BN_NORMALIZE(u, usize) \
while ((usize) && !(u)[(usize)-1]) \
--(usize);
/*
 * Divide the multi-limb number u[0 .. size-1] (u[0] least significant) by
 * the single limb v, in place, and return the remainder.
 *
 * - v must be non-zero and u must not be NULL (checked with assert).
 * - v == 1 or size == 0 leaves u unchanged and returns 0.
 *
 * The division walks from the most significant limb down, carrying the
 * running remainder into the next limb.  The two-limb-by-one-limb step is
 * done in uint64_t, so the code is portable C instead of x86 inline
 * assembly.  Because the running remainder is always < v, each partial
 * quotient fits in 32 bits.
 */
static uint32_t bn_ddivi(uint32_t *u, uint32_t size, uint32_t v) {
    assert(u != NULL);
    assert(v != 0);
    if (v == 1 || size == 0)
        return 0;

    uint32_t rem = 0;
    for (uint32_t *limb = u + size; limb != u;) {
        --limb;
        uint64_t cur = ((uint64_t)rem << 32) | *limb;
        *limb = (uint32_t)(cur / v);
        rem = (uint32_t)(cur % v);
    }
    return rem;
}
/* Digit characters indexed by digit value (0-35): 36 characters plus the terminating NUL. */
static const char radix_chars[37] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
#ifndef SWAP
/*
 * SWAP(x, y): exchange the values of two lvalues of the same type.
 *
 * Uses __typeof__, the GCC/Clang spelling that stays available in strict
 * ISO modes (plain `typeof` is only a keyword with GNU extensions or C23).
 * Both arguments are evaluated more than once, so they must be free of
 * side effects.
 */
#define SWAP(x, y)                         \
    do {                                   \
        __typeof__(x) swap_tmp_ = (x);     \
        (x) = (y);                         \
        (y) = swap_tmp_;                   \
    } while (0)
#endif
/*
 * Convert n to a NUL-terminated decimal string.
 *
 * n   - number to print; it is not modified (a local copy is divided).
 * out - destination buffer, or NULL to have one allocated with calloc().
 *       A caller-supplied buffer must hold every decimal digit plus the
 *       terminator; 10 bytes per limb plus 1 is always enough because one
 *       32-bit limb contributes fewer than 10 decimal digits.
 *
 * Returns out (or the newly allocated buffer, which the caller frees), or
 * NULL if the allocation fails.
 *
 * Method: repeatedly divide by 10^9 (the largest power of ten that fits in
 * a limb), emit the up-to-9 digits of each remainder least significant
 * first, then reverse the string.
 */
static char *bn_get_str(bn_t *n, char *out) {
    const unsigned int radix = 10;
    const uint32_t max_radix = 0x3B9ACA00U; /* 10^9 */
    const unsigned int max_power = 9;       /* digits per 10^9 chunk */

    /* Work on a stack copy: nothing to free on any return path. */
    bn_t tmp = *n;
    uint32_t *tmp_u = tmp.array;
    const size_t limbs = sizeof tmp.array / sizeof tmp.array[0];

    if (!out) {
        out = calloc(limbs * 10 + 1, sizeof(char));
        if (!out)
            return NULL;
    }

    /* Number of significant limbs: start from the full width, strip zeros. */
    uint32_t size = (uint32_t)limbs;
    BN_NORMALIZE(tmp_u, size);

    if (size == 0 || (size == 1 && tmp_u[0] < radix)) {
        out[0] = size ? radix_chars[tmp_u[0]] : '0';
        out[1] = '\0';
        return out;
    }

    char *outp = out;
    do {
        /* Multi-precision step: tmp /= 10^9, r = remainder. */
        uint32_t r = bn_ddivi(tmp_u, size, max_radix);
        /* The quotient loses at most one limb per division. */
        size -= (tmp_u[size - 1] == 0);

        /* Single-precision step: peel the decimal digits off r. */
        unsigned int i = 0;
        do {
            *outp++ = radix_chars[r % radix];
            r /= radix;
            if (size == 0 && r == 0) /* no leading zeros in the last chunk */
                break;
        } while (++i < max_power);
        assert(r == 0);
    } while (size != 0);

    /* Digits were produced least significant first: trim and reverse. */
    char *f = outp - 1;
    while (f > out && *f == '0')
        --f;
    f[1] = '\0';
    for (char *s = out; s < f; ++s, --f)
        SWAP(*s, *f);
    return out;
}
```
### 測試結果
印出 factorial(1) 到 factorial(100) 並跟 python 運算結果做比對
```cpp
/*
 * Print fac(1) .. fac(n) in decimal, one per line, where n is argv[1].
 *
 * Exit status: 0 on success, -1 when no argument is given, -2 when the
 * argument parses to 0.
 *
 * One stack buffer is reused for every line, so nothing is allocated or
 * leaked inside the loop.  The hexadecimal conversion whose output was
 * never printed has been dropped.
 */
int main(int argc, char *argv[]) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <n>\n", argc > 0 ? argv[0] : "q6");
        return -1;
    }

    unsigned int n = strtoul(argv[1], NULL, 10);
    if (!n)
        return -2;

    struct bn num, result;
    char decimal[512]; /* same capacity the per-iteration buffer used to have */

    for (unsigned int i = 1; i <= n; i++) {
        bn_from_int(&num, i);
        factorial(&num, &result);
        char *a = bn_get_str(&result, decimal);
        printf("fac(%u) = %s\n", i, a);
    }
    return 0;
}
```
```python
# Python: compute factorial for 1..100
def factorial(n):
    """Return n! for n >= 1.

    Starts from n and multiplies by every integer in 1 .. n-1.  For n <= 0
    the loop body never runs and n itself is returned.
    """
    product = n
    for factor in range(1, n):
        product *= factor
    return product


if __name__ == "__main__":
    for i in range(1, 101):
        print("fac(%d) =" % i, factorial(i))
```
用 shell script 來比對兩個結果
```bash
#!/bin/bash
# Compare the C program's output (q6ans) with the Python reference (pyans).
# diff prints any differences in unified format, ignoring blank-line changes;
# its exit status decides which verdict is echoed.
if diff -u -B q6ans pyans; then
    echo "passed."
else
    echo "failed."
fi
```
比對運算結果,與 python 計算結果一致。
```shell
$./q6 100 > q6ans
fac(1) = 1
fac(2) = 2
fac(3) = 6
fac(4) = 24
fac(5) = 120
...
fac(98) = 9426890448883247745626185743057242473809693764078951663494238777294707070023223798882976159207729119823605850588608460429412647567360000000000000000000000
fac(99) = 933262154439441526816992388562667004907159682643816214685929638952175999932299156089414639761565182862536979208272237582511852109168640000000000000000000000
fac(100) = 93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000
$python3 factorial > pyans
$bash diff.sh
passed.
```
## Karatsuba algorithm
### Karatsuba 實作
參考 [Karatsuba algorithm](https://en.wikipedia.org/wiki/Karatsuba_algorithm) 的 pseudocode
```python
procedure karatsuba(num1, num2)
if (num1 < 10) or (num2 < 10)
return num1 × num2
/* Calculates the size of the numbers. */
m = min(size_base10(num1), size_base10(num2))
m2 = floor(m / 2)
/* m2 = ceil(m / 2) will also work */
/* Split the digit sequences in the middle. */
high1, low1 = split_at(num1, m2)
high2, low2 = split_at(num2, m2)
/* 3 calls made to numbers approximately half the size. */
z0 = karatsuba(low1, low2)
z1 = karatsuba((low1 + high1), (low2 + high2))
z2 = karatsuba(high1, high2)
return (z2 × 10 ^ (m2 × 2)) + ((z1 - z2 - z0) × 10 ^ m2) + z0
```
```cpp=
/*
 * Helper: c = single->array[0] * multi, where multi has multi_size
 * significant limbs.  Limbs at or above multi_size are zero and are skipped.
 * The result is built in a local and copied out at the end, so c may be the
 * same object as either input.
 */
static void karatsuba_mul_single(bn_t *single, bn_t *multi, int multi_size,
                                 bn_t *c) {
    bn_t acc, term;

    bn_init(&acc);
    for (int j = 0; j < multi_size; ++j) {
        UTYPE_TMP intermediate =
            single->array[0] * (UTYPE_TMP)multi->array[j];
        bn_from_int(&term, intermediate);
        lshift_unit(&term, j);
        bn_add(&term, &acc, &acc);
    }
    *c = acc;
}

/*
 * Karatsuba multiplication: c = a * b (truncated to the width of bn_t).
 *
 * Base case: when either operand has at most one significant limb the
 * product is formed directly, one limb product at a time.
 *
 * Recursive case: both operands are split at mm = min(size) / 2 limbs into
 * high and low halves, and with B = 2^(32 * mm):
 *     z0 = low1 * low2
 *     z1 = (low1 + high1) * (low2 + high2)
 *     z2 = high1 * high2
 *     c  = z2 * B^2 + (z1 - z2 - z0) * B + z0
 *
 * Both operands have at least two limbs in the recursive case, so mm >= 1.
 * a and b are only read before c is written, so c may alias an input.
 */
static void karatsuba_mul(bn_t *a, bn_t *b, bn_t *c) {
    int m1 = bn_size(a);
    int m2 = bn_size(b);

    if (m1 <= 1) {
        karatsuba_mul_single(a, b, m2, c);
        return;
    }
    if (m2 <= 1) {
        karatsuba_mul_single(b, a, m1, c);
        return;
    }

    int m = (m1 > m2) ? m2 : m1; /* min */
    int mm = m / 2;

    bn_t high1, low1, high2, low2, z0, z1, z2, tmp1, tmp2;

    /* bn_split_shift clears and fills its two outputs. */
    bn_split_shift(a, &high1, &low1, mm);
    bn_split_shift(b, &high2, &low2, mm);

    karatsuba_mul(&low1, &low2, &z0);
    bn_add(&low1, &high1, &tmp1);
    bn_add(&low2, &high2, &tmp2);
    karatsuba_mul(&tmp1, &tmp2, &z1);
    karatsuba_mul(&high1, &high2, &z2);

    /* tmp1 = z1 - z2 - z0 */
    bn_sub(&z1, &z2, &tmp1);
    bn_sub(&tmp1, &z0, &tmp1);

    /* Move the middle and high terms to their limb positions and sum. */
    bn_left_shift(&tmp1, mm);
    bn_left_shift(&z2, 2 * mm);
    bn_add(&tmp1, &z2, &tmp2);
    bn_add(&tmp2, &z0, c);
}
```
### bn size
回傳 bn 去除前面的零的大小
```cpp
/*
 * Return the number of significant limbs in a: the index of the highest
 * non-zero limb plus one, or 0 when every limb is zero.
 */
static int bn_size(bn_t *a) {
    int used = (int)(BN_ARRAY_SIZE);

    /* Walk down from the top while the highest remaining limb is zero. */
    while (used > 0 && a->array[used - 1] == 0)
        --used;
    return used;
}
```
### bn split
以 offset 的位置分割 a 到 b 跟 c
```cpp
/*
 * Split a at limb index `offset`: a = b concat c.
 *
 *   c receives the low part,  limbs [0, offset) of a;
 *   b receives the high part, limbs [offset, BN_ARRAY_SIZE) of a, moved
 *     down so that a->array[offset] lands in b->array[0].
 *
 * offset is clamped to [0, BN_ARRAY_SIZE] so an out-of-range value can no
 * longer index outside the arrays.  b and c are cleared first, so neither
 * may be the same object as a.
 */
static void bn_split_shift(bn_t *a, bn_t *b, bn_t *c, const int offset) {
    int split = offset;
    if (split < 0)
        split = 0;
    if (split > BN_ARRAY_SIZE)
        split = BN_ARRAY_SIZE;

    bn_init(b);
    bn_init(c);
    for (int i = 0; i < split; i++) {
        c->array[i] = a->array[i];
    }
    for (int i = split; i < BN_ARRAY_SIZE; i++) {
        b->array[i - split] = a->array[i];
    }
}
```
### bn subtraction
bn 減法
```cpp
/*
 * c = a - b, limb by limb with borrow propagation.
 *
 * For each limb the subtrahend is b's limb plus the incoming borrow, formed
 * in the wide type so it cannot overflow.  The wide difference wraps
 * modulo the wide type, and masking with MAX_VAL keeps exactly the limb
 * value (a - b - borrow) mod 2^32.  This also covers a zero limb of a with
 * an incoming borrow, which must yield 2^32 - 1 - b, not 2^32 - b.
 *
 * Returns 1 when a >= b, 0 when a borrow is left over after the most
 * significant limb (underflow; c then holds the wrapped difference).
 * c may be the same object as a or b: each limb is read before it is
 * written.
 */
static int bn_sub(bn_t *a, bn_t *b, bn_t *c) {
    UTYPE_TMP borrow = 0;

    for (int i = 0; i < BN_ARRAY_SIZE; i++) {
        UTYPE_TMP minuend = (UTYPE_TMP)a->array[i];
        UTYPE_TMP subtrahend = (UTYPE_TMP)b->array[i] + borrow;

        c->array[i] = ((minuend - subtrahend) & MAX_VAL);
        borrow = (minuend < subtrahend);
    }

    return borrow ? 0 : 1;
}
```
### bn left shift
left shift bn number of offset,一個 offset 為 32bit
```cpp
/*
 * Shift a left by `offset` whole limbs in place (one limb = 32 bits).
 * Limbs shifted past the top are discarded; vacated low limbs become 0.
 *
 * offset <= 0 leaves a unchanged; offset >= BN_ARRAY_SIZE clears a.  Both
 * cases are handled explicitly so the loops never index outside the array.
 */
void bn_left_shift(bn_t *a, int offset) {
    if (offset <= 0)
        return;
    if (offset > BN_ARRAY_SIZE)
        offset = BN_ARRAY_SIZE;

    /* Copy from the top down so no source limb is overwritten early. */
    for (int i = (int)(BN_ARRAY_SIZE - 1); i >= offset; i--) {
        a->array[i] = a->array[i - offset];
    }
    for (int i = 0; i < offset; i++) {
        a->array[i] = 0;
    }
}
```
### 正確性
用 bn_mul 當作正確答案進行比對。手動更改初始值, bn_mul 跟 karatsuba_mul 的答案都一樣。
```cpp
/*
 * Correctness driver: prints the products of the same operand sequence
 * computed first with bn_mul and then with karatsuba_mul, so the two hex
 * dumps can be compared.
 *
 * In every round a is shifted left by one limb, b takes that shifted value,
 * the seed d is added to a, and a * b is printed.
 */
int main(int argc, char *argv[]) {
    const int rounds = 31;
    bn_t a, b, c, d;

    bn_init(&a);
    bn_init(&b);
    bn_init(&c);
    bn_init(&d);
    bn_from_int(&a, 1567579508);
    bn_from_int(&d, 1567579508);

    /* Pass 1: schoolbook multiplication. */
    for (int round = 0; round < rounds; ++round) {
        bn_left_shift(&a, 1);
        bn_assign(&b, &a);
        bn_add(&a, &d, &a);
        bn_mul(&a, &b, &c);
        bn_printhex(&c);
    }

    /* Pass 2: Karatsuba, restarted from the same seed. */
    bn_from_int(&a, 1567579508);
    for (int round = 0; round < rounds; ++round) {
        bn_left_shift(&a, 1);
        bn_assign(&b, &a);
        bn_add(&a, &d, &a);
        karatsuba_mul(&a, &b, &c);
        bn_printhex(&c);
    }
    return 0;
}
```
### 效率比較
隨機一個 unsigned int 作為初始值,每一次迴圈都 left shift 32bit 再加上一個隨機 unsigned int ,以此測試演算法在不同大小的效率。
```cpp
/* Nanoseconds elapsed between two clock_gettime() samples. */
static long long elapsed_ns(const struct timespec *start,
                            const struct timespec *end) {
    return (end->tv_sec * NANOSEC + end->tv_nsec) -
           (start->tv_sec * NANOSEC + start->tv_nsec);
}

/*
 * Benchmark driver: times bn_mul and karatsuba_mul on operands that grow by
 * one limb per round and appends "<bits> <nanoseconds>" lines to
 * origin.txt and karatsuba.txt respectively.
 *
 * For each trial a random 32-bit seed is drawn.  Both timing passes start
 * from the same state (a = b = seed), so the two algorithms are measured on
 * identical operand sequences; b is no longer left stale or zero at the
 * start of a pass.
 *
 * Returns 0 on success, 1 if an output file cannot be opened or closed.
 */
int main(int argc, char *argv[]) {
    enum { TRIALS = 50, ROUNDS = 31, REPEATS = 10 };
    struct timespec t_start, t_end;
    long long normal_time[ROUNDS], karatsuba_time[ROUNDS];
    int normal_bits[ROUNDS], karatsuba_bits[ROUNDS];
    bn_t a, b, c, d;

    FILE *fp1 = fopen("origin.txt", "a");
    if (!fp1) {
        perror("origin.txt");
        return 1;
    }
    FILE *fp2 = fopen("karatsuba.txt", "a");
    if (!fp2) {
        perror("karatsuba.txt");
        fclose(fp1);
        return 1;
    }

    srand(time(NULL));
    for (int k = 0; k < TRIALS; k++) {
        uint32_t rand_n = rand();
        bn_from_int(&d, rand_n);

        /* Pass 1: schoolbook multiplication. */
        bn_from_int(&a, rand_n);
        bn_assign(&b, &a);
        for (int i = 0; i < ROUNDS; i++) {
            normal_bits[i] = bn_bit_count(&a);
            clock_gettime(CLOCK_MONOTONIC, &t_start);
            for (int j = 0; j < REPEATS; j++) {
                bn_mul(&a, &b, &c);
            }
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            normal_time[i] = elapsed_ns(&t_start, &t_end);

            /* Grow the operands by one limb for the next round. */
            bn_left_shift(&a, 1);
            bn_assign(&b, &a);
            bn_add(&a, &d, &a);
        }
        for (int i = 0; i < ROUNDS; i++) {
            fprintf(fp1, "%d ", normal_bits[i]);
            fprintf(fp1, "%lld\n", normal_time[i]);
        }

        /* Pass 2: Karatsuba, restarted from the same seed. */
        bn_from_int(&a, rand_n);
        bn_assign(&b, &a);
        for (int i = 0; i < ROUNDS; i++) {
            karatsuba_bits[i] = bn_bit_count(&a);
            clock_gettime(CLOCK_MONOTONIC, &t_start);
            for (int j = 0; j < REPEATS; j++) {
                karatsuba_mul(&a, &b, &c);
            }
            clock_gettime(CLOCK_MONOTONIC, &t_end);
            karatsuba_time[i] = elapsed_ns(&t_start, &t_end);

            bn_left_shift(&a, 1);
            bn_assign(&b, &a);
            bn_add(&a, &d, &a);
        }
        for (int i = 0; i < ROUNDS; i++) {
            fprintf(fp2, "%d ", karatsuba_bits[i]);
            fprintf(fp2, "%lld\n", karatsuba_time[i]);
        }
    }

    /* fclose flushes the buffered results; report a failed flush. */
    int rc = 0;
    if (fclose(fp1) != 0)
        rc = 1;
    if (fclose(fp2) != 0)
        rc = 1;
    return rc;
}
```
從圖中可以發現 karatsuba algorithm 的耗時隨著 bit size 增加而增加。
bn_mul,不論大小耗時都一致。
karatsuba algorithm 在 bit size 超過一定的範圍後效率會不如 bn_mul 。
因為是隨機初始值,所以初始值很容易是 32 bit or 31 bit
![](https://i.imgur.com/tZ2T69s.png)