# Assignment1: RISC-V Assembly and Instruction Pipeline contributed by < [hbshub](https://github.com/hbshub/ca2025-quizzes) > >[!Note] AI tools usage >I used ChatGPT to assist with Quiz 1 for code explanations, grammar revisions, preliminary research, and code summaries. ## Problem B In problem B, we implement a logarithmic 8-bit codec (uf8) that maps 20-bit unsigned integers in $[0, 1{,}015{,}792]$ to single bytes, using the decode and encode formulas below. **<font size=5>`Decode formula`</font><br>** $$D(\text{uf8}) = m \cdot 2^e + (2^e - 1) \cdot 16$$ where $e = \lfloor \text{uf8}/16 \rfloor$ and $m = \text{uf8} \bmod 16$. ``` uf8 notation : eeeemmmm ┌──────────────┬──────────────┐ │ Exponent (4) │ Mantissa (4) │ └──────────────┴──────────────┘ 7 4 3 0 E: Exponent bits (4 bits) M: Mantissa bits (4 bits) ``` The high 4 bits of `uf8` represent `e = floor(uf8/16)`, and the low 4 bits represent `m = uf8 mod 16`. Decoding yields the corresponding 20-bit unsigned integer. **<font size=5>`Encode formula`</font><br>** $$ E(v) = \begin{cases} v, & \text{if } v < 16 \\ 16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise} \end{cases} $$ where $\text{offset}(e) = (2^e - 1) \cdot 16$ and $e$ is the largest exponent $e \le 15$ with $\text{offset}(e) \le v$. This lets us encode values in the range $[0, 1{,}015{,}792]$ into the uf8 format, using only 8 bits to represent a 20-bit value. ### **<font size=5>`clz`</font><br>** Helper function that counts the leading zero bits of a 32-bit value; `uf8_encode` uses it to locate the MSB (`msb = 31 - clz(value)`). 
:::spoiler c source code ```c= // count leading zero static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; if (y) { n -= c; x = y; } c >>= 1; } while (c); return n - x; } ``` ::: :::spoiler rv32i assembly code ```asm= # clz(uint32_t x) # a0: x (input/output) # return a0 = count of leading zeros # clobbers t0, t1, t2 clz: li t0, 32 # n = 32 li t1, 16 # c = 16 clz_loop: srl t2, a0, t1 # y = x >> c beq t2, x0, skip_update # if (y == 0) skip update sub t0, t0, t1 # n = n - c mv a0, t2 # x = y skip_update: srli t1, t1, 1 # c = c / 2 bnez t1, clz_loop # while (c) loop sub a0, t0, a0 # return n - x ret ``` ::: ### **<font size=5>`uf8_decode`</font><br>** :::spoiler c source code ```c= /* Decode uf8 to uint32_t */ uint32_t uf8_decode(uf8 fl) { uint32_t mantissa = fl & 0x0f; uint8_t exponent = fl >> 4; uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; return (mantissa << exponent) + offset; } ``` ::: :::spoiler rv32i assembly code ```asm= # fewer instructions by simplifying the decode formula # => (m ≪ e) + offset # => (m ≪ e) + ((2^e − 1)⋅16) # => (m ≪ e) + (16 << e) - 16 # => ((m + 16) << e) - 16 # uf8_decode(uf8 f) # a0: f (input/output) # return a0 = uf8_decode(f) uf8_decode: srli t0, a0, 4 # t0 = e andi a0, a0, 0x0F # a0 = m addi a0, a0, 16 # a0 = m + 16 sll a0, a0, t0 # a0 = (m + 16) << e addi a0, a0, -16 # a0 = a0 - 16 ret ``` ::: ### **<font size=5>`uf8_encode`</font><br>** :::spoiler c source code ```c= /* Encode uint32_t to uf8 */ uf8 uf8_encode(uint32_t value) { /* Use CLZ for fast exponent calculation */ if (value < 16) return value; /* Find appropriate exponent using CLZ hint */ int lz = clz(value); int msb = 31 - lz; /* Start from a good initial guess */ uint8_t exponent = 0; uint32_t overflow = 0; if (msb >= 5) { /* Estimate exponent - the formula is empirical */ exponent = msb - 4; if (exponent > 15) exponent = 15; /* Calculate overflow for estimated exponent */ for (uint8_t e = 0; e < exponent; e++) overflow = (overflow << 1) + 16; /* Adjust if 
estimate was off */ while (exponent > 0 && value < overflow) { overflow = (overflow - 16) >> 1; exponent--; } } /* Find exact exponent */ while (exponent < 15) { uint32_t next_overflow = (overflow << 1) + 16; if (value < next_overflow) break; overflow = next_overflow; exponent++; } uint8_t mantissa = (value - overflow) >> exponent; return (exponent << 4) | mantissa; } ``` ::: :::spoiler rv32i assembly code ```asm= # uf8_encode : encode a value into 1-byte uf8 encoding # input a0 - uint32_t value # return a0 - uf8 uf8_encode(value) # clobbers t0-t4 (clz clobbers t0-t2); value comparisons are unsigned, like the C uint32_t code # m is saturated to 15, so inputs at or above 1,015,792 encode as 0xFF (the C reference does not saturate) uf8_encode: li t3, 16 bgeu a0, t3, e_nonzero # if (value >= 16) e_nonzero ret # value < 16 encodes as itself e_nonzero: # e != 0 addi sp, sp, -8 sw a0, 0(sp) # store a0 in stack sw ra, 4(sp) # store ra in stack jal ra, clz # a0 = clz(a0) li t3, 31 sub t0, t3, a0 # msb = t0 = 31 - clz(a0) lw a0, 0(sp) # restore a0 from stack lw ra, 4(sp) # restore ra from stack addi sp, sp, 8 li t1, 0 # exp = t1 = 0 li t2, 0 # of = t2 = 0 li t3, 15 # t3 = 15, bound for exp and m; set before any branch so find_exa_exp and calc_m always compare against 15 li t4, 5 blt t0, t4, find_exa_exp # if (msb < 5) find_exa_exp addi t1, t0, -4 # exp = msb - 4 li t0, 0 # t0 = cnt = 0 ble t1, t3, calc_exp # if (exp <= 15) calc_exp li t1, 15 # exp = 15 calc_exp: bge t0, t1, adj_exp # while (cnt < exp) slli t2, t2, 1 # of = of << 1 addi t2, t2, 16 # of = of + 16 addi t0, t0, 1 # cnt++ jal x0, calc_exp adj_exp: ble t1, x0, find_exa_exp # if (exp <= 0) find_exa_exp bgeu a0, t2, find_exa_exp # if (value >= of) find_exa_exp addi t2, t2, -16 # of = of - 16 srli t2, t2, 1 # of = of >> 1 addi t1, t1, -1 # exp-- jal x0, adj_exp find_exa_exp: bge t1, t3, calc_m # if (exp >= 15) calc_m slli t0, t2, 1 # t0 = of << 1 addi t0, t0, 16 # t0 = (of << 1) + 16 = next_of bltu a0, t0, calc_m # if (value < next_of) calc_m mv t2, t0 # of = next_of addi t1, t1, 1 # exp++ jal x0, find_exa_exp calc_m: sub t0, a0, t2 # t0 = value - of srl t0, t0, t1 # t0 = (value - of) >> exp = m ble t0, t3, cmb_num # if (m <= 15) cmb_num li t0, 15 # saturate m = 15 cmb_num: slli t1, t1, 4 # t1 = exp << 4 or a0, t1, t0 # a0 = (exp << 4) | m ret ``` ::: ### Run in Ripes 
All uf8 codes in [0, 255] were decoded and then encoded back, and every round-trip test passed. ![image](https://hackmd.io/_uploads/SJSnNTqaxx.png) ### 5-stage Pipelined Processor observation Instruction: `srli t0, a0, 4` ![image](https://hackmd.io/_uploads/rkp8GscTgl.png) 1. IF stage: fetch the instruction word `0x00455293` - opcode : `0010011` (OP-IMM) - rd : `00101` = x5 = t0 - funct3 : `101` (SRLI/SRAI) - rs1 : `01010` = x10 = a0 - shamt : `00100` = 4 - imm[11:5] : `0000000` (logical, not arithmetic, shift) The program counter is updated: pc = pc + 4 ![image](https://hackmd.io/_uploads/Bka37jcall.png) 2. ID stage: decode the instruction into four parts - opcode = `SRLI` - r1_reg = `0x0a` = x10 - wr_reg = `0x05` = x5 - imm (decoded) = `0x00000004` ![image](https://hackmd.io/_uploads/r1Qy8o5pll.png) 3. EX stage: execute the `SRLI` operation - op1 = 0x00000000 - op2 = 0x00000004 - res = 0x0 >> 0x4 = 0x0 (logical shift right) ![image](https://hackmd.io/_uploads/H1QTzh9pgg.png) 4. MEM stage: nothing happens for `SRLI`, because it does not access memory. ![image](https://hackmd.io/_uploads/SkGKmhqpxg.png) 5. WB stage - wr_data = `0x0` - wr_idx = `0x5` The result is written back to the register file (x5 = t0). 
![image](https://hackmd.io/_uploads/SkSTEn5pgx.png) ### Cycle count improvement Compiler-generated code ([RISC-V (32-bits) gcc (trunk)](https://godbolt.org/)) ![image](https://hackmd.io/_uploads/Hy0l7T56xe.png) My implementation ![image](https://hackmd.io/_uploads/Hy-4NT5agx.png) <font size=5>The hand-written version reduces the cycle count by 60%.</font><br> ## Problem C ### conversion function (f32_to_bf16, bf16_to_f32) :::spoiler c source code ```c= static inline bf16_t f32_to_bf16(float val) { uint32_t f32bits; memcpy(&f32bits, &val, sizeof(float)); if (((f32bits >> 23) & 0xFF) == 0xFF) return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; f32bits += ((f32bits >> 16) & 1) + 0x7FFF; return (bf16_t) {.bits = f32bits >> 16}; } static inline float bf16_to_f32(bf16_t val) { uint32_t f32bits = ((uint32_t) val.bits) << 16; float result; memcpy(&result, &f32bits, sizeof(float)); return result; } ``` ::: :::spoiler rv32i assembly code ```asm= # ---- BF16 Masks ---- .equ BF16_SIGN_MASK, 0x8000 .equ BF16_EXP_MASK, 0x7F80 .equ BF16_MANT_MASK, 0x007F # ---- BF16 Constant ---- .equ BF16_EXP_BIAS, 127 .equ BF16_NAN, 0x7FC0 .equ BF16_ZERO, 0x0000 .data .text # input a0 = f32 val # return bf16 a0 f32_to_bf16: srli t0, a0, 23 # t0 = s | e andi t0, t0, 0xFF # t0 = e li t1, 0xFF beq t0, t1, skip_rounding # if e == 0xFF -> skip rounding srli t1, a0, 16 # t1 = upper 16 bits andi t0, t1, 1 # t0 = LSB of upper half li t2, 0x7FFF add t0, t2, t0 # rounding bias = 0x7FFF + LSB(1/0) add a0, a0, t0 # RNE (round to nearest even) skip_rounding: srli a0, a0, 16 # return upper 16 bits ret bf16_to_f32: slli a0, a0, 16 # shift bf16 to upper half ret ``` ::: ### special value function (nan, inf, zero) :::spoiler c source code ```c= static inline bool bf16_isnan(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK); } static inline bool bf16_isinf(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK); } static inline bool bf16_iszero(bf16_t a) { 
return !(a.bits & 0x7FFF); } ``` ::: :::spoiler rv32i assembly code ```asm= # ---- BF16 Masks ---- .equ BF16_SIGN_MASK, 0x8000 .equ BF16_EXP_MASK, 0x7F80 .equ BF16_MANT_MASK, 0x007F # ---- BF16 Constant ---- .equ BF16_EXP_BIAS, 127 .equ BF16_NAN, 0x7FC0 .equ BF16_ZERO, 0x0000 # a0 = bf16 bits # return 1 if exp==0x7F80 && mant!=0, else 0 bf16_isnan: li t0, BF16_EXP_MASK # 0x7F80 and t1, a0, t0 # t1 = e bne t1, t0, nan_false # if e != 0x7F80 andi t1, a0, BF16_MANT_MASK # t1 = m sltu a0, x0, t1 # a0 = (m != 0) ret nan_false: li a0, 0 ret # a0 = bf16 bits # return 1 if exp==0x7F80 && mant==0, else 0 bf16_isinf: li t0, BF16_EXP_MASK # 0x7F80 and t1, a0, t0 # t1 = e bne t1, t0, inf_false # if e != 0x7F80 andi t1, a0, BF16_MANT_MASK # t1 = m sltiu a0, t1, 1 # a0 = (m == 0) ret inf_false: li a0, 0 ret # a0 = bf16 bits # return 1 if (a.bits & 0x7FFF)==0, else 0 bf16_iszero: li t0, 0x7FFF and a0, a0, t0 # clear sign bit sltiu a0, a0, 1 # a0 = (a0 == 0) ret ``` ::: ### compare function (eq, gt, lt) :::spoiler c source code ```c= static inline bool bf16_eq(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return true; return a.bits == b.bits; } static inline bool bf16_lt(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return false; bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1; if (sign_a != sign_b) return sign_a > sign_b; return sign_a ? 
a.bits > b.bits : a.bits < b.bits; } static inline bool bf16_gt(bf16_t a, bf16_t b) { return bf16_lt(b, a); } ``` ::: :::spoiler rv32i assembly code ```asm= # a0 = a, a1 = b # return: a0 = 1(true) / 0(false) # bf16_isnan(a0) -> a0=1/0 # bf16_iszero(a0) -> a0=1/0 bf16_eq: addi sp, sp, -16 sw ra, 12(sp) sw a0, 8(sp) # save a sw a1, 4(sp) # save b # (isnan(a)) return 0; lw a0, 8(sp) # a jal ra, bf16_isnan bnez a0, eq_false # (isnan(b)) return 0; lw a0, 4(sp) # b jal ra, bf16_isnan bnez a0, eq_false # (iszero(a) && iszero(b)) return 1; lw a0, 8(sp) # a jal ra, bf16_iszero beqz a0, cmp_bits # a != 0 -> cmp_bits lw a0, 4(sp) # b jal ra, bf16_iszero bnez a0, eq_true # a == b == 0 -> true cmp_bits: # a0 = (a == b) ? 1 : 0 (xor + sltiu) avoid branch lw t0, 8(sp) # t0 = a lw t1, 4(sp) # t1 = b xor t2, t0, t1 # t2 = a ^ b sltiu a0, t2, 1 # a0 = (t2 == 0) ? 1 : 0 j eq_ret eq_false: li a0, 0 j eq_ret eq_true: li a0, 1 # fallthrough to ret eq_ret: lw ra, 12(sp) addi sp, sp, 16 ret # a0 = a, a1 = b # return: a0 = 1 if (a < b) else 0 # bf16_isnan(a0) -> a0=1/0 # bf16_iszero(a0) -> a0=1/0 bf16_lt: addi sp, sp, -16 sw ra, 12(sp) sw a0, 8(sp) # save a sw a1, 4(sp) # save b # (isnan(a) || isnan(b)) return false; lw a0, 8(sp) # a jal ra, bf16_isnan bnez a0, lt_false lw a0, 4(sp) # b jal ra, bf16_isnan bnez a0, lt_false # (iszero(a) && iszero(b)) return false; lw a0, 8(sp) # a jal ra, bf16_iszero beqz a0, sign_cmp # a != 0 -> sign_cmp lw a0, 4(sp) # b jal ra, bf16_iszero bnez a0, lt_false # a == b == 0 -> lt_false sign_cmp: lw t0, 8(sp) # t0 = a.bits lw t1, 4(sp) # t1 = b.bits srli t2, t0, 15 andi t2, t2, 1 # t2 = sign_a srli t3, t1, 15 andi t3, t3, 1 # t3 = sign_b bne t2, t3, diff_sign # same sign # sign = 0 -> pos:a.bits < b.bits beqz t2, pos_cmp # sign = 1 -> neg:a < b <-> a.bits > b.bits sltu a0, t1, t0 # a0 = (b.bits < a.bits) j ret_common pos_cmp: sltu a0, t0, t1 # a0 = (a.bits < b.bits) j ret_common diff_sign: # return (sign_a > sign_b) sltu a0, t3, t2 # a0 = (sign_b < sign_a) j 
ret_common lt_false: li a0, 0 ret_common: lw ra, 12(sp) addi sp, sp, 16 ret # a0 = a, a1 = b # return: a0 = 1 if (a > b) else 0 bf16_gt: mv t0, a0 # swap a0,a1 mv a0, a1 mv a1, t0 j bf16_lt # tail-call, no stack frame needed ``` ::: ### arithmetic function (add, sub, mul, div, sqrt) :::spoiler add/sub c source code ```c= static inline bf16_t bf16_add(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if (exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } ``` ::: :::spoiler add/sub rv32i assembly code ```asm= # bf16_add(a0=a, a1=b) -> a0 # BF16: 
[sign:1][exp:8][frac:7] # const: INF=0x7F80, NAN=0x7FC0, ZERO=0x0000 # regs: # t0=a, t1=b # t2=sign_a, t3=exp_a, t4=mant_a # t5=sign_b, t6=exp_b # a1=mant_b(temp), a0=temp/result .data cases: .half 0x3F80,0x3F80 # 1.0 + 1.0 = 2.0 (0x4000) .half 0x4000,0x4000 # 2.0 + 2.0 = 4.0 (0x4080) .half 0x4380,0x4180 # 256 + 16 = 272 (0x4388) .half 0x4040,0xBF80 # 3.0 + -1.0 = 2.0 (0x4000) .half 0x3F80,0xBF00 # 1.0 + -0.5 = 0.5 (0x3F00) .half 0x3F80,0x4000 # 1.0 + 2.0 = 3.0 (0x4040), exercises exp_a < exp_b # .half 0x7FC0,0x3F80 # NaN + 1.0 # .half 0x3F80,0x7FC0 # 1.0 + NaN # .half 0x7F80,0x4080 # +Inf + 4.0 # .half 0xc88b, 0xe9c9 # -1091.375 + -476.78125 = -1568.15625 # .half 0xfb44, 0xa286 # -0.001953125 + -20.15625 = -20.158203125 out: .half 0,0,0,0,0,0 # 6 results .text main: la s0, cases # s0 -> input pairs (a,b), 4 bytes each la s1, out # s1 -> output, 2 bytes each li s2, 6 # loop count 1: lhu a0, 0(s0) # load a (zero-extend, like the mul/div tests) lhu a1, 2(s0) # load b (zero-extend) addi s0, s0, 4 # move to next input jal ra, bf16_add # call bf16_add sh a0, 0(s1) # store result addi s1, s1, 2 # move to next output addi s2, s2, -1 bnez s2, 1b halt: j halt # infinite loop bf16_add: # Save original a, b mv t0, a0 # t0 = a mv t1, a1 # t1 = b # ---------------- Special case: a Exp==0xFF ---------------- srli t3, t0, 7 # t3 = exp_a andi t3, t3, 0xFF li a0, 0xFF bne t3, a0, chk_b_ff andi t4, t0, 0x7F # mant_a bnez t4, ret_a # a is NaN # a is Inf, check b srli t6, t1, 7 # exp_b andi t6, t6, 0xFF bne t6, a0, ret_a # b not ExpFF -> return a andi a0, t1, 0x7F # mant_b bnez a0, ret_b # b is NaN -> return b # both a and b are Inf: same sign -> return b; diff sign -> NaN srli t2, t0, 15 andi t2, t2, 1 srli t5, t1, 15 andi t5, t5, 1 beq t2, t5, ret_b li a0, 0x7FC0 # NaN ret # ---------------- Special case: b Exp==0xFF ---------------- chk_b_ff: srli t6, t1, 7 # t6 = exp_b andi t6, t6, 0xFF li a0, 0xFF bne t6, a0, quick_zero andi a0, t1, 0x7F # mant_b bnez a0, ret_b # b is NaN mv a0, t1 # b is Inf ret # ---------------- Fast path: ±0 ---------------- quick_zero: # a == ±0 ? 
andi a0, t1, 0x7F # mant_b (temporarily use a0) beqz t6, 3f j 4f 3: beqz a0, ret_a # b is ±0 -> return a 4: # ---------------- Extract sign/exp/mant ---------------- srli t2, t0, 15 # sign_a andi t2, t2, 1 srli t5, t1, 15 # sign_b andi t5, t5, 1 andi t4, t0, 0x7F # mant_a andi a1, t1, 0x7F # mant_b # normal add implicit 1 beqz t3, 5f ori t4, t4, 0x80 5: beqz t6, 6f ori a1, a1, 0x80 6: # ---------------- Exponent alignment ---------------- sub a0, t3, t6 # a0 = exp_diff = exp_a - exp_b bgtz a0, diff_pos bltz a0, diff_neg mv a0, t3 # result_exp = exp_a (diff==0) j add_or_sub # exp_a > exp_b diff_pos: mv a0, t3 # a0 = result_exp = exp_a sub t6, t3, t6 # t6 = exp_diff (> 0 here) li t3, 8 bgt t6, t3, ret_a # diff>8 → return a srl a1, a1, t6 # mant_b >>= diff j add_or_sub # exp_a < exp_b diff_neg: mv a0, t6 # a0 = result_exp = exp_b (copy it before t6 is reused below) sub t3, t6, t3 # t3 = -exp_diff = exp_b - exp_a (> 0 here) li t6, 8 bgt t3, t6, ret_b # -diff>8 → return b srl t4, t4, t3 # mant_a >>= -diff # ---------------- Same sign → add; different sign → sub ---------------- add_or_sub: beq t2, t5, same_sign # different sign: larger minus smaller bgeu t4, a1, 8f sub t4, a1, t4 # result_mant = mant_b - mant_a mv t2, t5 # result_sign = sign_b beqz t4, ret_zero j norm_sub 8: sub t4, t4, a1 # result_mant = mant_a - mant_b # result_sign = sign_a (t2) beqz t4, ret_zero norm_sub: # Normalize (shift left until bit7=1) andi t6, t4, 0x80 bnez t6, pack slli t4, t4, 1 addi a0, a0, -1 # --exp blez a0, ret_zero j norm_sub same_sign: add t4, t4, a1 # result_mant = mant_a + mant_b li t6, 0x100 and t6, t4, t6 # carry out of bit 7 into bit 8? beqz t6, pack srli t4, t4, 1 # shift right back to 8 bits addi a0, a0, 1 # ++exp li t6, 0xFF blt a0, t6, pack # exp < 0xFF -> pack; otherwise overflow -> ±Inf li t6, 0x7F80 slli t2, t2, 15 or a0, t2, t6 ret # ---------------- Pack back to BF16 ---------------- pack: andi t4, t4, 0x7F # frac andi a0, a0, 0xFF # exp slli a0, a0, 7 slli t2, t2, 15 # sign or a0, a0, t4 or a0, a0, t2 ret #
---------------- Fast return ---------------- ret_a: mv a0, t0 ret ret_b: mv a0, t1 ret ret_zero: li a0, 0x0000 ret # a0=a, a1=b -> a0 = a - b bf16_sub: li t0, 0x8000 # mask: sign bit xor a1, a1, t0 # b = b ^ 0x8000 (flip sign bit) j bf16_add # tail call : a + (-b) ``` ::: :::spoiler mul c source code ```c= static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } ``` ::: :::spoiler mul rv32i assembly code ```asm= .data cases: .half 0x3FC0, 0x3FC0 # expected 0x4010 (~2.25) ; 1.5 * 1.5 .half 0x4040, 0x3F00 # expected 0x3FC0 (~1.5) ; 3.0 * 0.5 
.half 0x7F7F, 0x4000 # expected 0x7F80 (+Inf) ; max_finite * 2.0 .half 0x0080, 0x0080 # expected 0x0000 (+0) ; min_normal * min_normal .half 0x0001, 0x3F80 # expected 0x0000 (+0) ; min_subnormal * 1.0 .half 0x7F80, 0x0000 # expected 0x7FC0 (NaN) ; +Inf * +0 .half 0x7FC1, 0x4000 # expected 0x7FC1 (NaN) ; NaN(payload 0x01) * 2.0 .half 0xBFA0, 0x4000 # expected 0xC020 (~-2.5) ; -1.25 * 2.0 .half 0x0000, 0xC5A6 # expected 0x8000 (-0) ; +0 * negative .half 0x0001, 0x0001 # expected 0x0000 (+0) ; subnormal * subnormal out: .half 0,0,0,0,0,0,0,0,0,0 # 10 results .text # ------------------------------------------------------------ # main: run 10 test pairs in 'cases', write results into 'out' # ------------------------------------------------------------ main: la s0, cases # s0 -> (a,b) pairs, 4 bytes per pair la s1, out # s1 -> output buffer, 2 bytes per result li s2, 10 # number of test pairs loop: lhu a0, 0(s0) # load a (zero-extend) lhu a1, 2(s0) # load b (zero-extend) addi s0, s0, 4 # advance to next pair jal ra, bf16_mul # compute sh a0, 0(s1) # store result addi s1, s1, 2 # advance output pointer addi s2, s2, -1 bnez s2, loop halt: j halt # ------------------------------------------------------------ # mul8x8_u32: shift-add 8x8 unsigned multiply # IN : a2=x(8-bit), a3=y(8-bit) # OUT: a2 = x*y (lower 16 bits valid) # Clobbers: t0,t1,t2,a3 # ------------------------------------------------------------ mul8x8_u32: li t0, 0 li t1, 8 mul8_loop: andi t2, a2, 1 beqz t2, mul8_skip_add add t0, t0, a3 mul8_skip_add: srli a2, a2, 1 slli a3, a3, 1 addi t1, t1, -1 bnez t1, mul8_loop mv a2, t0 ret # ------------------------------------------------------------ # bf16_mul (a0=a_bits, a1=b_bits) -> a0=result_bits # RV32I only; truncation (no RNE); handles NaN/Inf/±0 and subnormals # ------------------------------------------------------------ bf16_mul: # extract fields srli t0, a0, 15 # sign_a andi t0, t0, 1 srli t1, a1, 15 # sign_b andi t1, t1, 1 xor t2, t0, t1 # result_sign = sign_a 
^ sign_b slli t2, t2, 15 # (sign<<15) mv a5, t2 # SAVE sign in a5 (callee won't clobber) srli t3, a0, 7 # exp_a andi t3, t3, 0xFF srli t4, a1, 7 # exp_b andi t4, t4, 0xFF andi t5, a0, 0x7F # mant_a andi t6, a1, 0x7F # mant_b # preload +Inf pattern for early special paths (will reload after mul) li a2, 0x7F80 # special cases: exp==0xFF? li t1, 0xFF beq t3, t1, special_a beq t4, t1, special_b # zero short-circuit beqz t3, check_a_zero j check_b_zero_done check_a_zero: beqz t5, ret_signed_zero check_b_zero_done: beqz t4, check_b_zero j norm_inputs check_b_zero: beqz t6, ret_signed_zero # normalize inputs: subnormals shift until bit7=1; normals add implicit 1 norm_inputs: li a4, 0 # exp_adjust = 0 # A: operand a beqz t3, norm_a_sub ori t5, t5, 0x80 # add implicit 1 j norm_b norm_a_sub: li t1, 0x80 norm_a_loop: and t2, t5, t1 bnez t2, norm_a_done slli t5, t5, 1 addi a4, a4, -1 j norm_a_loop norm_a_done: li t3, 1 # exp_a = 1 # B: operand b norm_b: beqz t4, norm_b_sub ori t6, t6, 0x80 j mul_mant norm_b_sub: li t1, 0x80 norm_b_loop: and t2, t6, t1 # use t2; do NOT touch a5 bnez t2, norm_b_done slli t6, t6, 1 addi a4, a4, -1 j norm_b_loop norm_b_done: li t4, 1 # exp_b = 1 # mantissa multiply (8x8) mul_mant: mv a2, t5 mv a3, t6 addi sp, sp, -8 sw ra, 4(sp) jal ra, mul8x8_u32 # clobbers t0,t1,t2 and a2/a3 lw ra, 4(sp) addi sp, sp, 8 mv t5, a2 # result_mant = product # RELOAD after call (callee clobbered these) li t0, 127 # bias li a2, 0x7F80 # +Inf pattern for ret_inf # result_exp = exp_a + exp_b - bias + exp_adjust add t6, t3, t4 sub t6, t6, t0 add t6, t6, a4 # normalize product to 1.x or 2.x; keep 7 fraction bits (truncate) li t1, 0x8000 and a4, t5, t1 beqz a4, norm_prod_1x # 2.x: (prod>>8)&0x7F; exp++ srli t5, t5, 8 andi t5, t5, 0x7F addi t6, t6, 1 j check_over_under norm_prod_1x: # 1.x: (prod>>7)&0x7F srli t5, t5, 7 andi t5, t5, 0x7F # overflow / underflow / subnormal check_over_under: li t1, 255 bge t6, t1, ret_inf # overflow -> ±Inf # underflow if exp <= 0 bge zero, t6, 
underflow_path j pack underflow_path: # exp < -6 -> ±0 li t1, -6 blt t6, t1, ret_signed_zero # subnormal output: mant >>= (1 - exp); exp = 0 li t1, 1 sub t1, t1, t6 srl t5, t5, t1 li t6, 0 j pack # pack back to bf16 pack: andi t6, t6, 0xFF slli t6, t6, 7 andi t5, t5, 0x7F or a0, a5, t6 # use saved sign in a5 or a0, a0, t5 ret # return ±Inf ret_inf: or a0, a5, a2 # a2 = 0x7F80 ret # return signed zero ret_signed_zero: mv a0, a5 # sign|0 ret # ------------------------------------------------------------ # special cases # special_a: a is Inf/NaN (exp_a==0xFF) # special_b: b is Inf/NaN (exp_b==0xFF) # ------------------------------------------------------------ # a is special: # (mant_a) return a; # NaN → propagate operand # (b==0) return NaN; # return ±Inf; special_a: bnez t5, ret_a beqz t4, chk_b_zero2 j ret_inf chk_b_zero2: bnez t6, ret_inf li a0, 0x7FC0 # canonical NaN ret ret_a: mv a0, a0 # return a as-is ret # b is special: # (mant_b) return b; # NaN → propagate operand # (a==0) return NaN; # return ±Inf; special_b: bnez t6, ret_b beqz t3, chk_a_zero2 j ret_inf chk_a_zero2: bnez t5, ret_inf li a0, 0x7FC0 # canonical NaN ret ret_b: mv a0, a1 # return b as-is ret ``` ::: :::spoiler div c source code ```c= static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if 
(exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F), }; } ``` ::: :::spoiler div rv32i assembly code ```asm= .data # ------------------------------------------------------------ # 10 pairs of (a,b) in BF16 (little-endian .half) # ------------------------------------------------------------ cases: # .half 0x3F80, 0x4000 # 1.0 / 2.0 -> 0x3F00 (0.5) # .half 0x4000, 0x3F80 # 2.0 / 1.0 -> 0x4000 (2.0) # .half 0x4040, 0x4000 # 3.0 / 2.0 -> 0x3FC0 (1.5) # .half 0x0000, 0x4000 # 0.0 / 2.0 -> 0x0000 (+0.0) # .half 0x3F80, 0x0000 # 1.0 / 0.0 -> 0x7F80 (+Inf) # .half 0x7F80, 0x4000 # +Inf / 2.0 -> 0x7F80 (+Inf) # .half 0x4000, 0x7F80 # 2.0 / +Inf -> 0x0000 (+0.0) # .half 0x7FC1, 0x4000 # NaN / 2.0 -> 0x7FC1 (NaN, payload pass) # .half 0x0040, 0x3F80 # subnormal_a / 1.0 -> 0x0000 (underflow→0) # .half 0x3F80, 0x0001 # 1.0 / tiny subnormal -> 0x7F80 (+Inf) .half 0x4040, 0x4000 # 3.0 / 2.0 -> 0x3FC0 (1.5) .half 0x40A0, 0x4080 # 5.0 / 4.0 -> 0x3FA0 (1.25) .half 0x4110, 0x4100 # 9.0 / 8.0 -> 0x3F90 (1.125) .half 0x40E0, 0x4100 # 7.0 / 8.0 -> 0x3F70 (0.875) .half 0x40C0, 0x4100 # 6.0 / 8.0 -> 0x3F40 (0.75) .half 0x40A0, 0x4100 # 5.0 / 8.0 -> 0x3F20 (0.625) .half 0x40A0, 0x4000 # 5.0 / 2.0 -> 0x4020 (2.5) .half 0x4200, 0x4100 # 32 / 8 -> 
0x4080 (4.0) .half 0x4140, 0x4040 # 12 / 3 -> 0x4080 (4.0) .half 0x4120, 0x4080 # 10 / 4 -> 0x4020 (2.5) out: .half 0,0,0,0,0,0,0,0,0,0 # 10 results .text # ------------------------------------------------------------ # main: run 10 test pairs in 'cases', write results into 'out' # ------------------------------------------------------------ main: la s0, cases # s0 -> (a,b) pairs, 4 bytes per pair la s1, out # s1 -> output buffer, 2 bytes per result li s2, 10 # number of test pairs loop: lhu a0, 0(s0) # load a (zero-extend) lhu a1, 2(s0) # load b (zero-extend) addi s0, s0, 4 # advance to next pair jal ra, bf16_div # compute sh a0, 0(s1) # store result addi s1, s1, 2 # advance output pointer addi s2, s2, -1 bnez s2, loop halt: j halt # ------------------------------------------------------------ # bf16_div (a0=a_bits, a1=b_bits) -> a0=result_bits # RV32I only; restoring division on mantissas; truncation (no RNE) # Handles NaN/Inf/±0 and subnormals; underflow -> ±0 (no gradual subnormals) # BF16: [sign:1][exp:8][frac:7], bias=127 # Clobbers: t0-t6, a2-a7 (caller-saved only; no s registers are touched) # ------------------------------------------------------------ bf16_div: # ------- unpack a ------- srli t2, a0, 15 # t2=sign_a andi t2, t2, 1 srli t3, a0, 7 # t3=exp_a andi t3, t3, 0xFF andi t4, a0, 0x7F # t4=mant_a mv a6, t3 # a6=exp_a (caller-saved; using s3 would require saving/restoring it) # ------- unpack b ------- srli t5, a1, 15 # t5=sign_b andi t5, t5, 1 srli t6, a1, 7 # t6=exp_b andi t6, t6, 0xFF andi a2, a1, 0x7F # a2=mant_b # result_sign = sign_a ^ sign_b xor t0, t2, t5 # t0=result_sign # consts li a3, 0xFF # 0xFF li a4, 0x7F80 # +Inf pattern li a5, 127 # bias # ------- b is Inf/NaN? ------- beq t6, a3, b_inf_nan # ------- b == 0 ? 
------- beqz t6, b_zero_check j a_inf_nan_check b_zero_check: beqz a2, div_by_zero # x/0 # b subnormal -> continue j a_inf_nan_check div_by_zero: # 0/0 -> NaN ; x/0 -> ±Inf beqz a6, a_zero_chk_for_00 j ret_signed_inf a_zero_chk_for_00: beqz t4, ret_nan # 0/0 -> NaN j ret_signed_inf b_inf_nan: # b is Inf/NaN bnez a2, ret_b # b is NaN -> return b # b is Inf: finite/Inf -> signed zero; a Inf/NaN handled at a_inf_then_nan beq a6, a3, a_inf_then_nan slli a0, t0, 15 # signed zero ret a_inf_then_nan: beqz t4, ret_nan # Inf/Inf -> NaN # a is NaN (mant!=0) -> return a; NOTE: the C reference returns ±0 for NaN/Inf, this code propagates the NaN instead mv a0, a0 ret # ------- a Inf/NaN? ------- a_inf_nan_check: bne a6, a3, a_zero_check bnez t4, ret_a # a is NaN -> return a # a is Inf ; Inf/finite -> ±Inf ret_signed_inf: slli t1, t0, 15 or a0, t1, a4 ret ret_a: mv a0, a0 ret # ------- a == 0 ? ------- a_zero_check: beqz a6, a_zero_exp j norm_mantissas a_zero_exp: beqz t4, ret_signed_zero # 0/x -> ±0 # a subnormal -> continue j norm_mantissas ret_signed_zero: slli a0, t0, 15 ret # ------- add hidden 1 for normals ------- norm_mantissas: beqz a6, 1f ori t4, t4, 0x80 # mant_a |= 0x80 1: beqz t6, 2f ori a2, a2, 0x80 # mant_b |= 0x80 2: # ------- restoring division (16 bits of quotient) ------- slli t1, t4, 15 # t1 = dividend slli t2, a2, 15 # t2 = d (divisor aligned) li t3, 0 # t3 = quotient li a7, 16 div_loop: slli t3, t3, 1 # quotient <<= 1 bltu t1, t2, no_sub sub t1, t1, t2 # dividend -= d ori t3, t3, 1 # quotient |= 1 no_sub: srli t2, t2, 1 # d >>= 1 addi a7, a7, -1 bnez a7, div_loop # ------- result_exp = exp_a - exp_b + bias (+subnormal adjust) ------- sub t1, a6, t6 # t1 = exp_a - exp_b add t1, t1, a5 # t1 += bias beqz a6, 3f j 4f 3: addi t1, t1, -1 # a subnormal -> --exp 4: beqz t6, 5f j 6f 5: addi t1, t1, 1 # b subnormal -> ++exp 6: # ------- normalize quotient to 1.x then drop hidden 1 and truncate ------- # (quotient & 0x8000) q >>= 8; else while(!(q&0x8000)&&exp>1){ q<<=1; exp--; } q >>= 8; li t2, 0x8000 and t4, t3, t2 bnez t4, q_has_one li t5, 1 # t5 = 1 for the exp > 1 test (sign_b in t5 is dead here) 
q_need_shift: # while (!(q&0x8000) && exp > 1) { q<<=1; exp--; } and t4, t3, t2 # test (q & 0x8000) bnez t4, q_align_done ble t1, t5, q_align_done # stop once exp <= 1 slli t3, t3, 1 # q <<= 1 addi t1, t1, -1 # exp-- j q_need_shift q_align_done: srli t3, t3, 8 # drop to frac range j after_q_align q_has_one: srli t3, t3, 8 after_q_align: andi t3, t3, 0x7F # keep 7-bit fraction # ------- overflow/underflow checks ------- ble t1, zero, ret_signed_zero # exp <= 0 -> ±0 (no subnormals generated) li t2, 255 bge t1, t2, ret_signed_inf # exp >= 255 -> ±Inf # ------- pack sign|exp|frac ------- slli t0, t0, 15 # sign bit andi t1, t1, 0xFF # mask exp (already 1..254 here) slli t1, t1, 7 # exp << 7 or t0, t0, t1 or a0, t0, t3 ret # ------- quick return ------- ret_b: mv a0, a1 ret ret_nan: li a0, 0x7FC0 # canonical qNaN ret ``` ::: :::spoiler sqrt c source code ```c= static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t new_exp; /* Get full mantissa with implicit 1 */ uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */ /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */ if (e & 1) { m <<= 1; /* Double mantissa for odd exponent */ new_exp = ((e - 
1) >> 1) + BF16_EXP_BIAS; } else { new_exp = (e >> 1) + BF16_EXP_BIAS; } /* Now m is in range [128, 256) or [256, 512) if exponent was odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need to adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } } /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow */ if (new_exp >= 0xFF) return (bf16_t) {.bits = 0x7F80}; /* +Inf */ if (new_exp <= 0) return BF16_ZERO(); return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant}; } ``` ::: :::spoiler sqrt rv32i assembly code ```asm= .data # ------------------------------------------------------------ # 10 BF16 inputs (.half, little-endian) for sqrt # rule: # NaN propagation # sqrt(+Inf) -> +Inf # sqrt(neg) -> NaN # sqrt(subnormal) -> 0(flush) # ------------------------------------------------------------ cases: # .half 0x3F80 # sqrt(1.0) -> 1.0 (0x3F80) # .half 0x4080 # sqrt(4.0) -> 2.0 (0x4000) # .half 0x4010 # sqrt(2.25) -> 1.5 (0x3FC0) # .half 0x3F00 # sqrt(0.5) -> ~0.7071 (≈0x3F59) # .half 0x0000 # sqrt(+0.0) -> +0.0 (0x0000) # .half 0x7F80 # sqrt(+Inf) -> +Inf (0x7F80) # .half 0x7FC1 # sqrt(NaN) -> NaN(payload reserved) (0x7FC1) # .half
0x0001 # sqrt(tiny subnormal) -> 0 (flush to 0) # .half 0x4110 # sqrt(9.0) -> 3.0 (0x4040) # .half 0xBF80 # sqrt(-1.0) -> NaN (0x7FC0) .half 0x3F10 # 0.5625 -> 0.75 (0x3F40) .half 0x3EC8 # 0.390625 -> 0.625 (0x3F20) .half 0x4010 # 2.25 -> 1.5 (0x3FC0) .half 0x40C8 # 6.25 -> 2.5 (0x4020) .half 0x3C80 # 1/64=0.015625 -> 0.125 (0x3E00) out: .half 0,0,0,0,0,0,0,0,0,0 # 10 results .text # ------------------------------------------------------------ # main: run 10 inputs in 'cases', write results into 'out' # a0=input_bits, ret a0=result_bits # ------------------------------------------------------------ main: la s0, cases # s0 -> inputs (2 bytes per case) la s1, out # s1 -> output buffer li s2, 5 # number of tests loop: lhu a0, 0(s0) # load input (zero-extend) addi s0, s0, 2 jal ra, bf16_sqrt # compute sqrt sh a0, 0(s1) # store result addi s1, s1, 2 addi s2, s2, -1 bnez s2, loop halt: j halt # ------------------------------------------------------------ # mul8x8_u32: 8x8 unsigned multiply (RV32I shift-add) # IN : a2=x (8-bit), a3=y (8-bit) # OUT: a2 = x*y (low 16 bits valid) # Clobbers: t0, t1, t2 # ------------------------------------------------------------ mul8x8_u32: li t0, 0 li t1, 8 mul_loop: andi t2, a2, 1 beqz t2, mul_skip_add add t0, t0, a3 mul_skip_add: srli a2, a2, 1 slli a3, a3, 1 addi t1, t1, -1 bnez t1, mul_loop mv a2, t0 ret # ------------------------------------------------------------ # bf16_sqrt (a0=input_bits) -> a0=result_bits # RV32I only; truncation (no RNE); NaN/Inf/±0/neg/subnormals handled. 
# Mantissa scale = 128 (1.0 -> 128) # ------------------------------------------------------------ bf16_sqrt: # ------------- parse fields ------------- srli t0, a0, 15 # t0 = sign andi t0, t0, 1 srli t1, a0, 7 # t1 = exp (8-bit) andi t1, t1, 0xFF andi t2, a0, 0x7F # t2 = mant (7-bit) # ------------- NaN, Inf ------------- li t3, 0xFF bne t1, t3, not_inf_nan # if exp==0xFF: NaN/Inf bnez t2, ret_nan_payload # NaN: payload propogation -> return a0 bnez t0, ret_qnan # -Inf -> NaN ret # +Inf -> +Inf not_inf_nan: beqz t1, exp_zero_or_subnorm # exp==0 ? j check_negative exp_zero_or_subnorm: beqz t2, ret_zero # ±0 -> +0(flush to 0) li a0, 0 # subnormal flush-to-zero -> 0 ret check_negative: beqz t0, core # non-zero & non-neg -> core ret_qnan: li a0, 0x7FC0 # quiet NaN ret core: # e = exp - 127 addi t3, t1, -127 # t3 = e (signed) # m = 0x80 | mant ori t4, t2, 0x80 # t4 = m (uint32) # (e & 1) { m<<=1; new_exp=((e-1)>>1)+127; } else { new_exp=(e>>1)+127; } andi t2, t3, 1 beqz t2, exp_even slli t4, t4, 1 # m <<= 1 addi t5, t3, -1 srai t5, t5, 1 addi t5, t5, 127 # t5 = new_exp j after_exp_adjust exp_even: mv t5, t3 srai t5, t5, 1 addi t5, t5, 127 # t5 = new_exp after_exp_adjust: # ------------- binary search, result in [128..255] ------------- li t6, 90 # low li a5, 255 # high li a1, 0 # result binsearch_loop: bltu a5, t6, binsearch_done # if (high < low) break add t1, t6, a5 # mid = (low + high) >> 1 srli t1, t1, 1 # --- sq = (mid*mid) >> 7 (scale to 128) mv a4, t1 # backup mid → a4(mul will clobber t1) mv a2, t1 # a2=mid mv a3, t1 # a3=mid addi sp, sp, -8 sw ra, 4(sp) jal ra, mul8x8_u32 # a2 = mid*mid (low 16 bits);mul will clobber t0,t1,t2 lw ra, 4(sp) addi sp, sp, 8 mv t1, a4 # restore mid mv t2, a2 srli t2, t2, 7 # t2 = sq # (sq <= m) { result=mid; low=mid+1; } else { high=mid-1; } bgtu t2, t4, shrink_high mv a1, t1 # result = mid addi t6, t1, 1 # low = mid + 1 j binsearch_loop shrink_high: addi a5, t1, -1 # high = mid - 1 j binsearch_loop binsearch_done: # ------------- 
pack back BF16 ------------- andi a1, a1, 0x7F # new_mant = result & 0x7F slli t5, t5, 7 # (new_exp << 7) or a0, t5, a1 # sign = 0 ret # ----------- quick return for special values ----------- ret_nan_payload: ret # NaN payload propogation -> return a0 ret_zero: li a0, 0 # +0 ret ``` :::