# Assignment 1: RISC-V Assembly and Instruction Pipeline
contributed by [zz1888](https://github.com/zz1888/ca2025-quizzes)

# Problem B
**Algorithm Comparison for Count Leading Zeros (CLZ)**: The primary objective of this project is to implement, analyze, and compare the performance of a binary-search CLZ against a CLZ built on `The De Bruijn Sequence Method`.

### Overview of Problem B:
Problem B approximates 20-bit unsigned integers with the 8-bit uf8 format.

### CLZ
CLZ counts the number of leading zero bits in a 32-bit unsigned integer. This version uses a binary search to efficiently find the position of the most significant set bit (the first '1' from the left).
:::spoiler C Code
```c=
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}
```
:::

### Decoder:
$D(b) = m \cdot 2^e + (2^e - 1) \cdot 16$
:::spoiler C Code
```c=
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}
```
:::
:::spoiler Assembly Code
```c=
uf8_decoder:
    andi t0,a0,0x0F        #mantissa
    srli t1,a0,4           #exponent
    li t6,15               #15
    sub t2,t6,t1           #15-exponent
    li t3,0x7FFF
    srl t3,t3,t2           #0x7FFF >> (15 - exponent)
    slli t3,t3,4           #offset
    sll t4,t0,t1           #mantissa << exponent
    add a0,t4,t3
    ret
```
:::

### Encoder:
$E(v) = \begin{cases} v, & \text{if } v < 16 \\ 16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise} \end{cases}$
:::spoiler Code
```c=
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;

    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;

    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;

    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;

        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }

    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }

    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}
```
:::
:::spoiler Assembly Code
```c=
# uf8_encoder: a0 = value -> a0 = uf8 code. Calls clz.
uf8_encoder:
    addi sp, sp, -20
    sw ra, 16(sp)
    sw s0, 12(sp)
    sw s1, 8(sp)
    sw s2, 4(sp)
    sw s3, 0(sp)
    mv s0,a0                              # s0 = value
    slti t0,s0,16
    bnez t0,clz_encoder_if1_loop          # if (value < 16) return value;
    jal ra,clz                            # a0 = lz
    li t0,31
    sub s3,t0,a0                          # s3 = msb = 31 - lz (read a0 directly; s9 is not saved here)
    li s1,0                               # exponent
    li s2,0                               # overflow
    li t0,5
    bge s3,t0,clz_encoder_if2_loop        # if (msb >= 5)
    j clz_encoder_ending_while1
clz_encoder_if1_loop:
    mv a0,s0
    j encoder_end
clz_encoder_if2_loop:
    addi s1,s3,-4                         # exponent = msb - 4
    li t0,15
    ble s1,t0,clz_encoder_if3_loop
    li s1,15                              # clamp exponent to 15
clz_encoder_if3_loop:
    li t5,0                               # t5 = e
clz_encoder_for:
    bge t5,s1,clz_encoder_ending_for
    slli s2,s2,1
    addi s2,s2,16                         # overflow = (overflow << 1) + 16
    addi t5,t5,1
    j clz_encoder_for
clz_encoder_ending_for:                   # while (exponent > 0 && value < overflow)
    blez s1,clz_encoder_ending_while1
    bge s0,s2,clz_encoder_ending_while1
    addi s2,s2,-16
    srli s2,s2,1                          # overflow = (overflow - 16) >> 1
    addi s1,s1,-1
    j clz_encoder_ending_for
clz_encoder_ending_while1:                # while (exponent < 15), tested at the top as in the C code
    li t6,15
    bge s1,t6,clz_encoder_ending_while2
    slli t4,s2,1
    addi t4,t4,16                         # t4 = next_overflow
    blt s0,t4,clz_encoder_ending_while2   # if (value < next_overflow) break;
    mv s2,t4                              # overflow = next_overflow
    addi s1,s1,1
    j clz_encoder_ending_while1
clz_encoder_ending_while2:
    sub s3,s0,s2
    srl s3,s3,s1                          # mantissa = (value - overflow) >> exponent
    slli a0,s1,4
    or a0,a0,s3                           # (exponent << 4) | mantissa
encoder_end:
    lw s3,0(sp)
    lw s2,4(sp)
    lw s1,8(sp)
    lw s0,12(sp)
    lw ra,16(sp)
    addi sp, sp, 20
    ret
```
:::

### Test
:::spoiler C Code
```c=
static bool test(void)
{
    int32_t previous_value = -1;
    bool passed = true;

    for (int i = 0; i < 256; i++) {
        uint8_t fl = i;
        int32_t value = uf8_decode(fl);
        uint8_t fl2 = uf8_encode(value);

        if (fl != fl2) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2);
            passed = false;
        }
        if (value <= previous_value) {
            printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value);
            passed = false;
        }
        previous_value = value;
    }
    return passed;
}
```
:::
:::spoiler Assembly Code
```c=
# test_loop: s3 = passed, s4 = previous_value, s5 = i (initialised by main).
test_loop:
    addi sp, sp, -4
    sw ra,0(sp)                 # save ra once; the loop re-enters at test_loop_next
test_loop_next:
    li s10,256
    bge s5,s10,test_end
    mv a0,s5
    jal ra,uf8_decoder
    mv t1,a0                    # t1 = value (uf8_encoder/clz as written here do not touch t1)
    jal ra,uf8_encoder
    mv t2,a0                    # t2 = fl2
    bne s5,t2,test_if_loop1     # if (fl != fl2)
    ble t1,s4,test_if_loop2     # if (value <= previous_value)
    mv s4,t1
    j test_out
test_if_loop1:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,t2
    ble t1,s4,test_if_loop2
    mv s4,t1
    j test_out
test_if_loop2:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,s4
    mv s4,t1
    j test_out
test_out:
    beqz s3,fail
    addi s5,s5,1
    bge s5,s10,test_end
    j test_loop_next
test_end:
    lw ra,0(sp)
    addi sp,sp,4
    beq s5,s10,pass
    jr ra
```
:::

## First Approach-The Binary Search Method
:::spoiler Full Assembly Code
```c=
.data
msg_pass: .asciz "All tests passed.\n"
msg_fail: .asciz "Some tests failed.\n"

.text
.global main
main:
    li s3,1                     # passed = true
    li s4,-1                    # previous_value = -1
    li s5,0                     # i = 0
    jal ra, test_loop
    beqz s3, fail
    j end
pass:
    la a0, msg_pass
    li a7, 4
    ecall
    j end
fail:
    mv a0,s5                    # print value
    li a7,1
    ecall
    la a0, msg_fail
    li a7, 4
    ecall
    j end

clz:
    addi sp, sp, -4
    sw ra, 0(sp)
    li a2, 32                   #n=32
    li a3, 16                   #c=16
clz_loop:
    srl a1,a0,a3                # y = x >> c   (x in a0, y in a1)
    beqz a1, clz_not_if_loop
    sub a2,a2,a3                # n -= c
    mv a0,a1                    # x = y
clz_not_if_loop:
    srli a3,a3,1                # c >>= 1
    bnez a3,clz_loop
    sub a0, a2, a0              # return n - x
    lw ra, 0(sp)
    addi sp, sp, 4
    ret

uf8_decoder:
    andi t0,a0,0x0F        #mantissa
    srli t1,a0,4           #exponent
    li t6,15               #15
    sub t2,t6,t1           #15-exponent
    li t3,0x7FFF
    srl t3,t3,t2           #0x7FFF >> (15 - exponent)
    slli t3,t3,4           #offset
    sll t4,t0,t1           #mantissa << exponent
    add a0,t4,t3
    ret

uf8_encoder:
    addi sp, sp, -20
    sw ra, 16(sp)
    sw s0, 12(sp)
    sw s1, 8(sp)
    sw s2, 4(sp)
    sw s3, 0(sp)
    mv s0,a0                              # s0 = value
    slti t0,s0,16
    bnez t0,clz_encoder_if1_loop          # if (value < 16) return value;
    jal ra,clz                            # a0 = lz
    li t0,31
    sub s3,t0,a0                          # s3 = msb = 31 - lz (read a0 directly; s9 is not saved here)
    li s1,0                               # exponent
    li s2,0                               # overflow
    li t0,5
    bge s3,t0,clz_encoder_if2_loop        # if (msb >= 5)
    j clz_encoder_ending_while1
clz_encoder_if1_loop:
    mv a0,s0
    j encoder_end
clz_encoder_if2_loop:
    addi s1,s3,-4                         # exponent = msb - 4
    li t0,15
    ble s1,t0,clz_encoder_if3_loop
    li s1,15                              # clamp exponent to 15
clz_encoder_if3_loop:
    li t5,0                               # t5 = e
clz_encoder_for:
    bge t5,s1,clz_encoder_ending_for
    slli s2,s2,1
    addi s2,s2,16                         # overflow = (overflow << 1) + 16
    addi t5,t5,1
    j clz_encoder_for
clz_encoder_ending_for:                   # while (exponent > 0 && value < overflow)
    blez s1,clz_encoder_ending_while1
    bge s0,s2,clz_encoder_ending_while1
    addi s2,s2,-16
    srli s2,s2,1                          # overflow = (overflow - 16) >> 1
    addi s1,s1,-1
    j clz_encoder_ending_for
clz_encoder_ending_while1:                # while (exponent < 15), tested at the top as in the C code
    li t6,15
    bge s1,t6,clz_encoder_ending_while2
    slli t4,s2,1
    addi t4,t4,16                         # t4 = next_overflow
    blt s0,t4,clz_encoder_ending_while2   # if (value < next_overflow) break;
    mv s2,t4                              # overflow = next_overflow
    addi s1,s1,1
    j clz_encoder_ending_while1
clz_encoder_ending_while2:
    sub s3,s0,s2
    srl s3,s3,s1                          # mantissa = (value - overflow) >> exponent
    slli a0,s1,4
    or a0,a0,s3                           # (exponent << 4) | mantissa
encoder_end:
    lw s3,0(sp)
    lw s2,4(sp)
    lw s1,8(sp)
    lw s0,12(sp)
    lw ra,16(sp)
    addi sp, sp, 20
    ret

test_loop:
    addi sp, sp, -4
    sw ra,0(sp)                 # save ra once; the loop re-enters at test_loop_next
test_loop_next:
    li s10,256
    bge s5,s10,test_end
    mv a0,s5
    jal ra,uf8_decoder
    mv t1,a0                    # t1 = value (uf8_encoder/clz as written here do not touch t1)
    jal ra,uf8_encoder
    mv t2,a0                    # t2 = fl2
    bne s5,t2,test_if_loop1     # if (fl != fl2)
    ble t1,s4,test_if_loop2     # if (value <= previous_value)
    mv s4,t1
    j test_out
test_if_loop1:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,t2
    ble t1,s4,test_if_loop2
    mv s4,t1
    j test_out
test_if_loop2:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,s4
    mv s4,t1
    j test_out
test_out:
    beqz s3,fail
    addi s5,s5,1
    bge s5,s10,test_end
    j test_loop_next
test_end:
    lw ra,0(sp)
    addi sp,sp,4
    beq s5,s10,pass
    jr ra

end:
    li a7,10
    ecall
```
:::

#### CLZ
```c=
clz:
    addi sp, sp, -4
    sw ra, 0(sp)
    li a2, 32                   #n=32
    li a3, 16                   #c=16
clz_loop:
    srl a1,a0,a3                # y = x >> c   (x in a0, y in a1)
    beqz a1, clz_not_if_loop
    sub a2,a2,a3                # n -= c
    mv a0,a1                    # x = y
clz_not_if_loop:
    srli a3,a3,1                # c >>= 1
    bnez a3,clz_loop
    sub a0, a2, a0              # return n - x
    lw ra, 0(sp)
    addi sp, sp, 4
    ret
```

#### Execution Info
![image](https://hackmd.io/_uploads/H1JStV6agl.png)

#### DATA
```c=
msg_pass: .asciz "All tests passed.\n"
msg_fail: .asciz "Some tests failed.\n"
```

### Test
![image](https://hackmd.io/_uploads/SkeGjEpael.png)
## Optimized Approach: Using The De Bruijn Sequence Method
:::spoiler Assembly Code
```c=
.data
msg_pass: .asciz "All tests passed.\n"
msg_fail: .asciz "Some tests failed.\n"
debruijn_table:
    .byte 0, 31, 9, 30, 3, 8, 13, 29, 2, 5, 7, 21, 12, 24, 28, 19
    .byte 1, 10, 4, 14, 6, 22, 25, 20, 11, 15, 23, 26, 16, 27, 17, 18

.text
.global main
main:
    li s3,1                     # passed = true
    li s4,-1                    # previous_value = -1
    li s5,0                     # i = 0
    jal ra, test_loop
    beqz s3, fail
    j end
pass:
    la a0, msg_pass
    li a7, 4
    ecall
    j end
fail:
    mv a0,s5                    # print value
    li a7,1
    ecall
    la a0, msg_fail
    li a7, 4
    ecall
    j end

# clz: a0 = x -> a0 = number of leading zeros (32 for x == 0). Clobbers t0.
clz:
    beqz a0, handle_zero
    srli t0, a0, 1              # x |= x >> 1;
    or a0, a0,t0
    srli t0, a0, 2              # x |= x >> 2;
    or a0, a0, t0
    srli t0, a0, 4              # x |= x >> 4;
    or a0,a0,t0
    srli t0,a0,8                # x |= x >> 8;
    or a0,a0,t0
    srli t0,a0,16               # x |= x >> 16;
    or a0,a0,t0
    addi a0,a0,1                # x++;
    li t0,0x076be629
    mul a0,a0,t0                # x * magic (low 32 bits)
    srli a0,a0,27               # top 5 bits = table index
    la t0,debruijn_table
    add a0,t0,a0
    lbu a0,0(a0)
    ret
handle_zero:
    li a0, 32
    ret

uf8_decoder:
    andi t0,a0,0x0F        #mantissa
    srli t1,a0,4           #exponent
    li t6,15               #15
    sub t2,t6,t1           #15-exponent
    li t3,0x7FFF
    srl t3,t3,t2           #0x7FFF >> (15 - exponent)
    slli t3,t3,4           #offset
    sll t4,t0,t1           #mantissa << exponent
    add a0,t4,t3
    ret

uf8_encoder:
    addi sp, sp, -20
    sw ra, 16(sp)
    sw s0, 12(sp)
    sw s1, 8(sp)
    sw s2, 4(sp)
    sw s3, 0(sp)
    mv s0,a0                              # s0 = value
    slti t0,s0,16
    bnez t0,clz_encoder_if1_loop          # if (value < 16) return value;
    jal ra,clz                            # a0 = lz
    li t0,31
    sub s3,t0,a0                          # s3 = msb = 31 - lz (read a0 directly; s9 is not saved here)
    li s1,0                               # exponent
    li s2,0                               # overflow
    li t0,5
    bge s3,t0,clz_encoder_if2_loop        # if (msb >= 5)
    j clz_encoder_ending_while1
clz_encoder_if1_loop:
    mv a0,s0
    j encoder_end
clz_encoder_if2_loop:
    addi s1,s3,-4                         # exponent = msb - 4
    li t0,15
    ble s1,t0,clz_encoder_if3_loop
    li s1,15                              # clamp exponent to 15
clz_encoder_if3_loop:
    li t5,0                               # t5 = e
clz_encoder_for:
    bge t5,s1,clz_encoder_ending_for
    slli s2,s2,1
    addi s2,s2,16                         # overflow = (overflow << 1) + 16
    addi t5,t5,1
    j clz_encoder_for
clz_encoder_ending_for:                   # while (exponent > 0 && value < overflow)
    blez s1,clz_encoder_ending_while1
    bge s0,s2,clz_encoder_ending_while1
    addi s2,s2,-16
    srli s2,s2,1                          # overflow = (overflow - 16) >> 1
    addi s1,s1,-1
    j clz_encoder_ending_for
clz_encoder_ending_while1:                # while (exponent < 15), tested at the top as in the C code
    li t6,15
    bge s1,t6,clz_encoder_ending_while2
    slli t4,s2,1
    addi t4,t4,16                         # t4 = next_overflow
    blt s0,t4,clz_encoder_ending_while2   # if (value < next_overflow) break;
    mv s2,t4                              # overflow = next_overflow
    addi s1,s1,1
    j clz_encoder_ending_while1
clz_encoder_ending_while2:
    sub s3,s0,s2
    srl s3,s3,s1                          # mantissa = (value - overflow) >> exponent
    slli a0,s1,4
    or a0,a0,s3                           # (exponent << 4) | mantissa
encoder_end:
    lw s3,0(sp)
    lw s2,4(sp)
    lw s1,8(sp)
    lw s0,12(sp)
    lw ra,16(sp)
    addi sp, sp, 20
    ret

test_loop:
    addi sp, sp, -4
    sw ra,0(sp)                 # save ra once; the loop re-enters at test_loop_next
test_loop_next:
    li s10,256
    bge s5,s10,test_end
    mv a0,s5
    jal ra,uf8_decoder
    mv t1,a0                    # t1 = value (uf8_encoder/clz as written here do not touch t1)
    jal ra,uf8_encoder
    mv t2,a0                    # t2 = fl2
    bne s5,t2,test_if_loop1     # if (fl != fl2)
    ble t1,s4,test_if_loop2     # if (value <= previous_value)
    mv s4,t1
    j test_out
test_if_loop1:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,t2
    ble t1,s4,test_if_loop2
    mv s4,t1
    j test_out
test_if_loop2:
    li s3,0                     # passed = false
    mv s6,s5
    mv s7,t1
    mv s8,s4
    mv s4,t1
    j test_out
test_out:
    beqz s3,fail
    addi s5,s5,1
    bge s5,s10,test_end
    j test_loop_next
test_end:
    lw ra,0(sp)
    addi sp,sp,4
    beq s5,s10,pass
    jr ra

end:
    li a7,10
    ecall
```
:::

#### Improve the CLZ by Using the `DeBruijn sequence` method
```c=
int clz(uint32_t x)
{
    static const char debruijn32[32] = {
        0, 31, 9, 30, 3, 8, 13, 29, 2, 5, 7, 21, 12, 24, 28, 19,
        1, 10, 4, 14, 6, 22, 25, 20, 11, 15, 23, 26, 16, 27, 17, 18
    };
    if (!x)
        return 32; /* same zero handling as the assembly version */
    x |= x>>1;
    x |= x>>2;
    x |= x>>4;
    x |= x>>8;
    x |= x>>16;
    x++;
    return debruijn32[x*0x076be629>>27];
}
```
**1. Convert Any Number into a Power of Two.**
Example: `0b00101101` → `0b01000000`

**2. Hashing with a `Magic Number: 0x076be629`**
We need a way to quickly map the 32 possible power-of-two values ($2^0, 2^1, \dots, 2^{31}$) to 32 unique indices (0 to 31).

`x * MagicNumber`: Since $x$ is a power of two ($2^k$), this multiplication is equivalent to a left bit-shift of the magic number by k places (`magic_number << k`), which leaves a unique 5-bit pattern in the top bits.

`>>27`: This operation isolates the top 5 bits of the result, converting the bit position k into a unique index between 0 and 31.
**3. Looking Up the Value**
Each table slot stores the CLZ result for the power of two that hashes to that slot.
Example: for `x = 1`, step 1 produces `2`; `(2 * 0x076be629) >> 27` is `1`, and `debruijn_table[1] = 31`, which is exactly `clz(1)`.

#### CLZ
```c=
# clz: a0 = x -> a0 = number of leading zeros (32 for x == 0). Clobbers t0.
clz:
    beqz a0, handle_zero
    srli t0,a0, 1               # x|= x >> 1;
    or a0,a0,t0
    srli t0,a0, 2               # x|= x >> 2;
    or a0,a0, t0
    srli t0,a0, 4               # x|= x >> 4;
    or a0,a0,t0
    srli t0,a0,8                # x|= x >> 8;
    or a0,a0,t0
    srli t0,a0,16               # x|= x >> 16;
    or a0,a0,t0
    addi a0,a0,1                # x++;
    li t0,0x076be629
    mul a0,a0,t0                # x * magic (low 32 bits)
    srli a0,a0,27               # top 5 bits = table index
    la t0,debruijn_table
    add a0,t0,a0
    lbu a0,0(a0)
    ret
handle_zero:
    li a0, 32
    ret
```

#### Execution Info
![image](https://hackmd.io/_uploads/rJCCjr6alx.png)

The De Bruijn sequence method is demonstrably superior. It not only reduces the total execution cycles and instruction count but also improves the processor's core efficiency. The lower CPI and higher IPC indicate that the fixed-path nature of the algorithm (its only branch is the zero check) allows the CPU pipeline to operate more smoothly and effectively, leading to a significant overall performance gain.
#### DATA
```c=
.data
msg_pass: .asciz "All tests passed.\n"
msg_fail: .asciz "Some tests failed.\n"
debruijn_table:
    .byte 0, 31, 9, 30, 3, 8, 13, 29, 2, 5, 7, 21, 12, 24, 28, 19
    .byte 1, 10, 4, 14, 6, 22, 25, 20, 11, 15, 23, 26, 16, 27, 17, 18
```

### Test
![image](https://hackmd.io/_uploads/SJCOpSa6xx.png)

# Problem C
## DATA
:::spoiler Assembly
```c=
.data
# Each test stores the bf16 operands followed by the expected bf16 results.
test1_a:    .half 0x40C0  # 6.0
test1_b:    .half 0xC000  # -2.0
test1_add:  .half 0x4080  # 4.0
test1_sub:  .half 0x4100  # 8.0
test1_mul:  .half 0xC140  # -12.0
test1_div:  .half 0xC040  # -3.0
test1_sqrt: .half 0x401D  # sqrt(6.0) is ~2.45 (0x401D = 2.453125; was 0x401E = 2.46875)
msg_test1:  .string "\nTest 1 (a=6.0, b=-2.0):\n"

test2_a:    .half 0x7F80  # +Inf
test2_b:    .half 0x40A0  # 5.0
test2_add:  .half 0x7F80  # +Inf
test2_sub:  .half 0x7F80  # +Inf
test2_mul:  .half 0x7F80  # +Inf
test2_div:  .half 0x7F80  # +Inf
test2_sqrt: .half 0x7F80  # sqrt(+Inf) is +Inf
msg_test2:  .string "\nTest 2 (a=+Inf, b=5.0):\n"

test3_a:    .half 0x0000  # 0.0
test3_b:    .half 0x8000  # -0.0
test3_add:  .half 0x0000  # 0.0  NOTE(review): bf16_add returns b (0x8000) when a is zero - confirm which is intended
test3_sub:  .half 0x0000  # 0.0
test3_mul:  .half 0x8000  # -0.0
test3_div:  .half 0x7FC0  # NaN (0/0)
test3_sqrt: .half 0x0000  # sqrt(0.0) is 0.0
msg_test3:  .string "\nTest 3 (a=0.0, b=-0.0):\n"

test4_a:    .half 0xBF80  # -1.0
test4_sqrt: .half 0x7FC0  # NaN
msg_test4:  .string "\nTest 4 (sqrt(-1.0)):\n"

test5_a:    .half 0x7FC1  # NaN
test5_sqrt: .half 0x7FC1  # NaN
msg_test5:  .string "\nTest 5 (sqrt(NaN)):\n"

msg_add:  .string " add: "
msg_sub:  .string " sub: "
msg_mul:  .string " mul: "
msg_div:  .string " div: "
msg_sqrt: .string " sqrt(a): "
msg_pass: .string "PASS\n"
msg_fail: .string "FAIL"
```
:::

## MAIN
```c=
It's not finished yet, sorry.
```

<!-- Register convention used by bf16_add/mul/div/sqrt below (inferred from how the
     routines read them; the unpack helper bf16_unpack is not part of this document):
     t0/t1 = sign_a/sign_b, t2/t3 = exp_a/exp_b, t4/t5 = mant_a/mant_b, a0/a1 = raw a/b.
     The label return_false used by Isnan/Isinf/Iszero is also defined elsewhere. -->

## Is NAN
:::spoiler C
```c=
static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK);
}
```
:::
:::spoiler Assembly
```c=
# Isnan: a0 = bf16 bits -> a0 = 1 when exponent is all ones and mantissa != 0.
# Clobbers t6, s10, s11 (same registers as before).
Isnan:
    li t6,0x7F80                # BF16_EXP_MASK
    and s10,a0,t6               # s10 = exponent field
    bne s10,t6,return_false
    andi s11,a0,0x7F            # s11 = mantissa field (BF16_MANT_MASK)
    beqz s11,return_false
    li a0,1
    jr ra
```
:::

## Is Inf
:::spoiler C
```c=
static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK);
}
```
:::
:::spoiler Assembly
```c=
# Isinf: a0 = bf16 bits -> a0 = 1 when exponent is all ones and mantissa == 0.
# Clobbers t6, s10, s11 (same registers as before).
Isinf:
    li t6,0x7F80                # BF16_EXP_MASK
    and s10,a0,t6               # s10 = exponent field
    bne s10,t6,return_false
    andi s11,a0,0x7F            # s11 = mantissa field (BF16_MANT_MASK)
    bnez s11,return_false
    li a0,1
    jr ra
```
:::

## Is Zero
:::spoiler C
```c=
static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}
```
:::
:::spoiler Assembly
```c=
# Iszero: a0 = bf16 bits -> a0 = 1 when everything except the sign bit is zero.
Iszero:
    li s10,0x7FFF
    and a0,a0,s10               # drop the sign bit
    bnez a0,return_false
    li a0,1
    jr ra
```
:::

## f32_to_bf16
:::spoiler C
```c=
static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}
```
:::
:::spoiler Assembly
```c=
f32_to_bf16:
    srli a4, a0, 23
    andi a4, a4, 0xFF           # a4 = exponent
    li a5, 0xFF
    beq a4, a5, is_all_one      # Inf/NaN: truncate without rounding
    srli a6, a0, 16
    andi a6, a6, 1              # a6 = (a0 >> 16) & 1
    li a7, 0x7FFF
    add a6, a6, a7              # a6 = round offset
    add a4, a0, a6              # a4 = a0 + round offset
    srli a0, a4, 16             # a0 = high 16 bits
    jr ra
is_all_one:
    srli a0, a0, 16
    li a5, 0xFFFF
    and a0, a0, a5              # a0 = bfloat16 result
    jr ra
```
:::

## bf16_to_f32
:::spoiler C
```c=
static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}
```
:::
:::spoiler Assembly
```c=
bf16_to_f32:
    slli a4, a0, 16             # a4 = a0 << 16
    mv a0, a4                   # a0 = f32 result
    jr ra
```
:::

## ADD
:::spoiler C
```c=
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;

    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }

    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;
        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }
        if (!result_mant)
            return BF16_ZERO();
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F),
    };
}
```
:::
:::spoiler Assembly
```c=
bf16_add:
    li t6,0xFF
    bne t2,t6,check_exp_b       # exp_a != 0xFF
    bnez t4,return_a            # a is NaN
check_expb:
    bne t3,t6,return_a          # a is Inf, b is finite
    bnez t5,return_b            # b is NaN
    beq t0,t1,return_b          # Inf + Inf with the same sign
    li a0, 0x7FC0               # Inf + (-Inf) = NaN
    jr ra
return_b:
    mv a0,a1
    jr ra
return_a:
    mv a0,a0
    jr ra
check_exp_b:
    bne t3,t6,check_expa_manta
    mv a0,a1                    # b is Inf/NaN
    jr ra
check_expa_manta:
    bnez t2,check_expb_mantb
    bnez t4,check_expb_mantb
    mv a0,a1                    # a is zero -> return b
    jr ra
check_expb_mantb:
    bnez t3,check_zero_expa
    bnez t5,check_zero_expa
    mv a0,a0                    # b is zero -> return a
    jr ra
check_zero_expa:
    beqz t2,check_zero_expb
    ori t4,t4,0x80              # restore implicit 1 of a
check_zero_expb:
    beqz t3,next
    ori t5,t5,0x80              # restore implicit 1 of b
next:
    sub s0,t2,t3                #s0=exp_diff
    beqz s0,equal
    blt s0,x0,expa_smaller
expa_bigger:
    mv s1,t2
    li t6,8
    bgt s0,t6,return_a
    srl t5,t5,s0
    j align_done
expa_smaller:
    mv s1, t3
    li t6,-8
    blt s0,t6,return_b
    sub s0,x0,s0
    srl t4,t4,s0
    j align_done
equal:
    mv s1,t2                    #s1=result_exp
align_done:
    beq t0,t1,same_sign
    bge t4,t5,mant_bigger
    mv s2,t1                    # result_sign = sign_b
    sub s3,t5,t4
    beqz s3,result_mant
    andi t6,s3,0x80
    bnez t6,normalize_done
while:
    slli s3,s3,1                # mant <<= 1
    addi s1,s1,-1               # exp--
    ble s1,x0,return_zero       # exp <= 0 -> return 0
    andi t6,s3,0x80             # check for the leading 1 again
    beqz t6,while               # still no leading 1 -> keep looping
normalize_done:
    slli s2,s2,15
    andi s1,s1,0xFF
    slli s1,s1,7
    andi s3,s3,0x7F
    or a0,s2,s1
    or a0,a0,s3
    jr ra
result_mant:
    li t6,0x0000
    mv a0,t6
    jr ra
mant_bigger:
    mv s2,t0                    # result_sign = sign_a
    sub s3,t4,t5
    beqz s3,result_mant
    andi t6,s3,0x80
    bnez t6,normalize_done
    j while
same_sign:
    mv s2,t0                    #result_sign
    add s3,t4,t5                #s3=result_mant
    andi s4,s3,0x100            # s4 = result_mant & 0x100
    beqz s4, build_result       # no carry -> build the result directly
    srli s3, s3, 1              # mantissa >>= 1
    addi s1, s1, 1              # exponent +1
    li t6, 0xFF
    bge s1, t6, return_out      # on overflow return +/-Inf
build_result:
    andi s1,s1,0xFF
    andi s3, s3, 0x7F
    slli s1, s1, 7
    or s1, s1, s3
    slli s2, s2, 15
    or a0, s2, s1
    jr ra
return_out:
    li t6,0x7F80
    slli s2,s2,15               # sign << 15
    or a0,s2,t6                 # combine the result
    jr ra
return_zero:
    li t6,0x0000
    mv a0,t6
    jr ra
```
:::

## SUB
:::spoiler C
```c=
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}
```
:::
:::spoiler Assembly
```c=
bf16_sub:
    addi sp, sp, -4
    sw ra, 0(sp)
    li t6, 0x8000
    xor a1, a1, t6              # flip the sign of b
    jal ra, bf16_unpack
    jal ra, bf16_add
    lw ra, 0(sp)
    addi sp, sp, 4
    jr ra
```
:::

## MUL
:::spoiler C
```c=
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)};
}
```
:::
:::spoiler Assembly
```c=
bf16_mul:
    xor s0,t0,t1                #s0 is result sign
    li t6, 0xFF
check_a_inf:
    bne t2, t6, check_b
    bnez t4,return_a            # a is NaN
    beq t3,t6,check_nan1        # extra check (not in the C code): b is Inf/NaN too
    j bb
check_nan1:
    bnez t5,return_nan          # Inf * NaN -> NaN
bb:
    bnez t3,check_a_inf_done
    bnez t5,check_a_inf_done
    j return_nan                # Inf * 0 -> NaN
check_a_inf_done:
    slli s0,s0,15
    li t6,0x7F80
    or s0,s0,t6
    mv a0,s0
    jr ra
check_b:
    li t6, 0xFF
    bne t3,t6,check_zero
    bnez t5,return_b            # b is NaN
    bnez t2,check_b_inf_done
    bnez t4,check_b_inf_done
    li t6,0x7FC0                # 0 * Inf -> NaN
    mv a0,t6
    jr ra
check_b_inf_done:
    slli s0,s0,15
    li t6,0x7F80
    or s0,s0,t6
    mv a0,s0
    jr ra
check_zero:                     # if ((!exp_a&&!mant_a) || (!exp_b&&!mant_b))
    beqz t2,check_a_zero
check_zero_b:
    beqz t3,check_b_zero
    j conti
check_a_zero:
    beqz t4,zero_done
    j check_zero_b              # a is subnormal: b must still be tested, or whiles_b never ends for b == 0
check_b_zero:
    beqz t5,zero_done
    j conti
zero_done:
    slli s0,s0,15
    mv a0,s0
    jr ra
conti:
    li s5,0                     # s5 is exp_adjust
    bnez t2,expa_not_if         #if(!expa)
whiles_a:
    li t6,0x80
    and t6,t4,t6
    bnez t6,set_expa
    slli t4,t4,1
    addi s5,s5,-1
    j whiles_a
set_expa:
    li t2,1
    j check_expb_zero
expa_not_if:
    ori t4,t4,0x80
    j check_expb_zero
check_expb_zero:
    bnez t3,expb_not_if
whiles_b:
    li t6,0x80
    and t6,t5,t6
    bnez t6,set_expb
    slli t5,t5,1
    addi s5,s5,-1
    j whiles_b
set_expb:
    li t3,1
    j conti2
expb_not_if:
    ori t5,t5,0x80
    j conti2
conti2:                         # shift-and-add multiply: s3 = mant_a * mant_b
    li s3, 0                    # s3 = result_mant = 0
    mv t6, t5                   # t6 = multiplier
    mv s7, t4                   # s7 = multiplicand
mul_mant_loop:
    beqz t6, mul_mant_done
    andi s8, t6, 1
    beqz s8, skip_add_mant
    add s3, s3, s7
skip_add_mant:
    slli s7, s7, 1
    srli t6, t6, 1
    j mul_mant_loop
mul_mant_done:
    add s1,t2,t3                #s1=result_exp
    addi s1,s1,-127
    add s1,s1,s5
check_resulit_mant:
    li t6,0x8000
    and t6,s3,t6
    beqz t6,set_result_mant
    srli s3,s3,8
    li t6,0x7F
    and s3,s3,t6
    addi s1,s1,1
    j check_result_exp
set_result_mant:
    srli s3,s3,7
    li t6,0x7F
    and s3,s3,t6
    j check_result_exp
check_result_exp:
    li t6,0xFF
    blt s1,t6,check_result_exp_zero
    li t6,0x7F80                # overflow -> +/-Inf
    slli s0,s0,15
    or s0,s0,t6
    mv a0,s0
    jr ra
check_result_exp_zero:
    bgt s1,x0,output
    li t6,-6
    blt s1,t6,returns
    li t6,1
    sub t6,t6,s1
    srl s3,s3,t6                # result_mant >>= (1 - result_exp)
    li s1,0
    j output
returns:
    slli s0,s0,15
    mv a0,s0
    jr ra
output:
    slli s0,s0,15
    li t6,0xFF
    and s1,s1,t6
    slli s1,s1,7
    li t6,0x7F
    and s3,s3,t6
    or a0,s0,s1
    or a0,a0,s3
    jr ra
```
:::

## DIV
:::spoiler C
```c=
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    if (!exp_b && !mant_b) {
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};

    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;

    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }

    int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
    if (!exp_a)
        result_exp--;
    if (!exp_b)
        result_exp++;

    if (quotient & 0x8000)
        quotient >>= 8;
    else {
        while (!(quotient & 0x8000) && result_exp > 1) {
            quotient <<= 1;
            result_exp--;
        }
        quotient >>= 8;
    }
    quotient &= 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F)};
}
```
:::
:::spoiler Assembly
```c=
bf16_div:
    xor s0,t0,t1                #s0 is result sign
    li t6,0xFF
    bne t3,t6,check_not_expb_and_mantb
    bnez t5,return_b            # b is NaN
    li t6,0xFF
    beq t2,t6,check_nan2        # extra check (not in the C code): a is Inf/NaN too
    j dd
check_nan2:
    bnez t4,return_nan          # NaN / Inf -> NaN
dd:
    bne t2,t6,returnss
    bnez t4,returnss
    li a0,0x7FC0                # Inf / Inf -> NaN
    jr ra
returnss:
    slli s0,s0,15               # finite / Inf -> signed zero
    mv a0,s0
    jr ra
check_not_expb_and_mantb:
    bnez t3,check_expaa
    bnez t5,check_expaa
    li t6,0xFF                  # b is zero
    beq t2,t6,check_nan3
    j cc
check_nan3:
    bnez t4,return_nan          # NaN / 0 -> NaN
    j cc
cc:
    bnez t2,returnsss
    bnez t4,returnsss
    li a0,0x7FC0                # 0 / 0 -> NaN
    jr ra
returnsss:
    li t6,0x7F80                # x / 0 -> signed Inf
    slli s0,s0,15
    or a0,s0,t6
    jr ra
check_expaa:
    li t6,0xFF
    bne t2,t6,check_not_expa_and_manta
    bnez t4,return_a            # a is NaN
    li t6,0x7F80                # Inf / finite -> signed Inf
    slli s0,s0,15
    or a0,s0,t6
    jr ra
check_not_expa_and_manta:
    bnez t2,check_expaaa
    bnez t4,check_expaaa
    slli s0,s0,15               # 0 / finite -> signed zero
    mv a0,s0
    jr ra
check_expaaa:
    beqz t2,check_expbbb
    ori t4,t4,0x80
check_expbbb:
    beqz t3,contii
    ori t5,t5,0x80
contii:
    slli s1,t4,15               #s1=dividend
    mv s2,t5                    #s2=divisor
    li s3,0                     #s3=quotient
    li s4,0                     #i=0
    li s5,16                    #16
    li s8,15
for:
    bge s4,s5,conti3
    slli s3,s3,1
    sub s6,s8,s4                #(15 - i)
    sll s7,s2,s6
    bge s1,s7,for_if
    addi s4,s4,1
    j for
for_if:
    sub s1,s1,s7
    ori s3,s3,1
    addi s4,s4,1
    j for
conti3:
    sub s9,t2,t3                #s9=result_exp
    addi s9,s9,127
    bnez t2,check_expb1
    addi s9,s9,-1
check_expb1:
    bnez t3,check_quotient
    addi s9,s9,1
check_quotient:
    li t6,0x8000
    and t6,s3,t6
    beqz t6,else
    srli s3,s3,8
    j conti4
else:
    li s7,0x8000
    li s8,1
    and t6,s3,s7
    bnez t6,else_conti
    ble s9,s8,else_conti
while2:
    slli s3,s3,1
    addi s9,s9,-1
    and t6,s3,s7
    bnez t6,else_conti
    ble s9,s8,else_conti
    j while2
else_conti:
    srli s3,s3,8
conti4:
    andi s3,s3,0x7F
    li t6,0xFF
    blt s9,t6,check_result_le_zero
    li t6,0x7F80                # overflow -> signed Inf
    slli s0,s0,15
    or a0,s0,t6
    jr ra
check_result_le_zero:
    bgtz s9,final
    slli a0,s0,15               # underflow -> signed zero
    jr ra
final:
    slli s0,s0,15
    li t6,0xFF
    and s9,s9,t6
    slli s9,s9,7
    li t6,0x7F
    and s3,s3,t6
    or a0,s0,s9
    or a0,a0,s3
    jr ra
```
:::

## Sqrt
:::spoiler C
```c=
static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;

    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a; /* sqrt(+Inf) = +Inf */
    }

    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();

    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();

    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();

    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */

    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;     /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;   /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128; /* Default */

    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */
        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */

    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;

    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();
    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
:::
:::spoiler Assembly
```c=
bf16_sqrt:                      #t0=SIGN_A #t2=EXP_A #t4=MANT_A
    li t6,0xFF
    bne t2,t6,check_exp_mant_is_zero
    bnez t4,return_a            # NaN propagation
    bnez t0,return_nan          # sqrt(-Inf) = NaN
    j return_a                  # sqrt(+Inf) = +Inf
check_exp_mant_is_zero:
    bnez t2,check_sign
    bnez t4,check_sign
    j return_zero               # sqrt(+/-0) = 0
check_sign:
    bnez t0,return_nan          # negative input -> NaN
check_exp:
    beqz t2,return_zero         # flush denormals to zero
conti5:
    addi s5,t2,-127             #s5=e
    li s6,0                     #s6=new_exp
    ori s7,t4,0x80              #s7=m
if:
    andi t6,s5,1
    beqz t6,else_if
    slli s7,s7,1                # odd exponent: m <<= 1
    addi t6,s5,-1
    srai t6,t6,1                # (e - 1) >> 1, arithmetic: e is negative for inputs below 1.0
    addi s6,t6,127
    j conti6
else_if:
    srai t6,s5,1                # e >> 1, arithmetic shift for negative e
    addi s6,t6,127
conti6:
    li s8,90                    #s8=low
    li s9,256                   #s9=high
    li s10,128                  #s10=result
    ble s8,s9,while_loop
    j if_2
while_loop:
    add s1,s8,s9                #s1=mid
    srli s1,s1,1
    # ===== mul s3 = s1 * s1 =====
    li s3, 0
    mv t6, s1
    mv t3, s1
mul_sqrt_loop:
    beqz t6, mul_sqrt_done
    andi t0, t6, 1
    beqz t0, skip_add_sqrt
    add s3, s3, t3
skip_add_sqrt:
    slli t3, t3, 1
    srli t6, t6, 1
    j mul_sqrt_loop
mul_sqrt_done:
    # ===== div s3 /= 128 =====
    srli s3,s3,7                # s3 is non-negative, so one shift replaces the subtraction loop
    bgt s3,s7,else_if_2
    mv s10,s1                   # result = mid
    addi s8,s1,1                # low = mid + 1
    ble s8,s9,while_loop
    j if_2
else_if_2:
    addi s9,s1,-1               # high = mid - 1
    ble s8,s9,while_loop
if_2:
    li t6,256
    blt s10,t6,else_if3
    srli s10,s10,1
    addi s6,s6,1
    j final_2
else_if3:
    li t6,128
    bge s10,t6,final_2
    li t6,1
    ble s6,t6,final_2           # loop only while new_exp > 1, as in the C code
while_loop2:
    slli s10,s10,1
    addi s6,s6,-1
    li t6,128
    bge s10,t6,final_2
    li t6,1
    ble s6,t6,final_2
    j while_loop2
final_2:
    andi s0,s10,0x7F            # s0 = new_mant
    li t6,0xFF
    bge s6,t6,final_if1
    blez s6,return_zero
    andi a0,s6,0xFF
    slli a0,a0,7
    or a0,a0,s0
    jr ra
final_if1:
    li a0,0x7F80                # +Inf
    mv a0,a0
    jr ra
return_nan:
    li a0, 0x7FC0
    jr ra
```
:::