# Assignment 1: RISC-V Assembly and Instruction Pipeline
contributed by < [TWChris90](https://github.com/TWChris90/ca2025-quizzes) >

## Problem B:
### uf8_decode
#### Original version
:::spoiler C Code
```c=
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uint8_t fl)
{
    uint32_t mantissa = fl & 0x0f;  // m
    uint8_t exponent = fl >> 4;     // e
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}
```
:::
The original computes the segment start as:
$$
\mathrm{offset}(e)=((\texttt{0x7FFF}\,\texttt{>>}\,(15- e))\,\texttt{<<}\,4)=(2^{e}-1)\cdot 16
$$

#### Improved version
:::spoiler C Code
```c=
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uint8_t fl)
{
    uint32_t m = fl & 0x0F;  // m
    uint32_t e = fl >> 4;    // e
    uint32_t offset = (1u << (e + 4)) - 16u;
    return offset + (m << e);
}
```
:::
The improved version uses the algebraically equivalent closed form:
$$
\mathrm{offset}(e)=(1 \ll (e+4)) - 16 = 2^{\,e+4} - 16
$$
This removes the need to load `0x7FFF` and perform a variable right shift before a left shift.

#### Edge cases
- **e = 0**
  - **Original:** `(0x7FFF >> 15) << 4 = 0`
  - **Improved:** `(1u << 4) - 16u = 16 - 16 = 0`
- **e = 15**
  - **Original:** `(0x7FFF >> 0) << 4 = 0x7FFF << 4 = 524,272`
  - **Improved:** `(1u << 19) - 16u = 524,288 - 16 = 524,272`

Both produce the correct offsets at the boundaries; the decoded value remains:
$$
\mathrm{value}=\mathrm{offset}(e)+(m \ll e)
$$

#### Practical impact
- **Same complexity (O(1)), fewer moving parts.**
  The improved form typically compiles down to one variable left shift and one subtract of a small constant, instead of "load 0x7FFF + variable right shift + left shift by 4". That usually reduces the instruction count and the dependency chain on RV32I.
- **Portability & maintainability.**
  Using an unsigned, closed-form expression for the offset is portable and easier to audit.
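
The equivalence of the two offset formulas can also be checked exhaustively, since there are only 16 exponents. The following is a minimal sketch; the helper names `offset_original`, `offset_closed`, and `check_offsets` are made up for this illustration and are not part of the assignment code.

```c
#include <assert.h>
#include <stdint.h>

/* Segment start as written in the original version. */
static uint32_t offset_original(unsigned e)
{
    return (0x7FFFu >> (15u - e)) << 4;
}

/* Segment start in closed form: 2^(e+4) - 16. */
static uint32_t offset_closed(unsigned e)
{
    return (1u << (e + 4u)) - 16u;
}

/* Aborts if the two forms disagree for any exponent 0..15. */
static void check_offsets(void)
{
    for (unsigned e = 0; e <= 15; e++)
        assert(offset_original(e) == offset_closed(e));
}
```

Calling `check_offsets()` from a test driver covers both edge cases listed above (`e = 0` and `e = 15`) along with every exponent in between.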
### uf8_encode
#### Original version
:::spoiler C Code
```c=
// Encode uint32_t to uf8
uf8 uf8_encode(uint32_t value)
{
    if (value < 16)
        return (uf8)value;

    int lz = clz32(value);
    int msb = 31 - lz;

    uint8_t exponent = 0;
    uint32_t offset = 0;

    // coarse guess for e from msb
    if (msb >= 5) {
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        // forward recurrence to reach offset(e):
        // offset_{e+1} = 2*offset_e + 16
        for (uint8_t e = 0; e < exponent; e++)
            offset = (offset << 1) + 16;

        // if the guess overshoots, step downward (inverse recurrence)
        while (exponent > 0 && value < offset) {
            offset = (offset - 16) >> 1;
            exponent--;
        }
    }

    // then step upward until the next segment would exceed value
    while (exponent < 15) {
        uint32_t next = (offset << 1) + 16;  // start of segment e+1
        if (value < next)
            break;
        offset = next;
        exponent++;
    }

    uint8_t mantissa = (value - offset) >> exponent;  // floor
    return (uf8)((exponent << 4) | mantissa);
}
```
:::

#### Improved version
:::spoiler C Code
```c=
// Encode uint32_t to uf8
uf8 uf8_encode(uint32_t v)
{
    if (v < 16u)
        return (uf8)v;

    unsigned lz = clz32(v);
    int msb = 31 - (int)lz;

    // 1) start from e0 = clamp(msb - 4, 0..15)
    int e = msb - 4;
    if (e < 0) e = 0;
    if (e > 15) e = 15;

    // 2) closed-form segment starts
    uint32_t offset0 = (1u << (e + 4)) - 16u;  // start of e
    uint32_t next = (e < 15) ? ((1u << (e + 5)) - 16u)
                             : UINT32_MAX;     // start of e+1 (or +∞)

    // 3) decide e with two comparisons → e ∈ {e-1, e, e+1}
    if (v < offset0) {
        e -= 1;
        if (e < 0) {
            e = 0;
            offset0 = 0u;
        } else {
            offset0 = (1u << (e + 4)) - 16u;
        }
    } else if (v >= next) {
        e += 1;
        if (e > 15) e = 15;
        offset0 = (1u << (e + 4)) - 16u;
    }

    // 4) in-segment index
    uint32_t m = (v - offset0) >> e;  // floor
    return (uf8)((e << 4) | (m & 0x0F));
}
```
:::

#### How the segment `e` is chosen
- **Original approach**
  - First take a coarse guess `e ≈ msb - 4`.
  - Then use a **for loop** to advance `offset` to that `e` via the recurrence `offset_{e+1} = 2*offset_e + 16`.
  - If the guess overshoots, use a **while (down)** loop to inverse-step the recurrence and decrease `e`.
  - After that, use a **while (up)** loop to step forward until the next segment start would exceed `value`.
- **Improved approach**
  - Start with `e0 = clamp(msb - 4, 0..15)`.
  - Compute in one shot the closed-form boundaries:
    - `offset0 = (1u << (e0 + 4)) - 16u`
    - `next = (e0 < 15) ? ((1u << (e0 + 5)) - 16u) : UINT32_MAX`
  - With two comparisons, `value < offset0` and `value >= next`, decide which of `{ e0 - 1, e0, e0 + 1 }` is the true `e`.
  - This eliminates all for/while loops and makes the step count O(1) (fixed), independent of the input value.

#### Computing the segment start (`offset`)
- **Original approach**
  Define the start of segment `e` by a recurrence:
  $$
  \mathrm{offset}(0)=0,\quad \mathrm{offset}(e+1)=2\cdot\mathrm{offset}(e)+16
  $$
  In code, this typically requires a `for`/`while` loop to "walk" from `e=0` up to the target `e`.
- **Improved approach**
  Use the algebraically equivalent closed form:
  $$
  \mathrm{offset}(e)=(2^{e}-1)\cdot 16=2^{\,e+4}-16
  $$
  This computes the segment start in one shot: no recurrence, no loop, no intermediate state.

#### Complexity
- **Original approach**
  `O(# of loop iterations)`: data-dependent; may execute multiple for/while passes depending on where `value` falls.
- **Improved approach**
  Fixed `O(1)`: deterministic step count; shorter control flow with fewer branches.
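
The closed-form selection of `e` can be exercised with a round trip over all 256 codes. The sketch below is a compact variation, not the author's exact routine: the `_sketch` names and the portable `clz32_sketch` loop are assumptions made for this illustration. It relies on the observation that for `v >= 16`, `offset(msb - 4) = 2^msb - 16` never exceeds `v`, so only the upward correction is exercised.

```c
#include <assert.h>
#include <stdint.h>

/* Portable count-leading-zeros; stands in for the clz32 helper (assumption). */
static unsigned clz32_sketch(uint32_t x)
{
    unsigned n = 0;
    if (x == 0) return 32;
    while (!(x & 0x80000000u)) { x <<= 1; n++; }
    return n;
}

/* Closed-form segment start: 2^(e+4) - 16. */
static uint32_t seg_start(int e) { return (1u << (e + 4)) - 16u; }

static uint32_t uf8_decode_sketch(uint8_t fl)
{
    uint32_t m = fl & 0x0Fu, e = fl >> 4;
    return seg_start((int)e) + (m << e);
}

static uint8_t uf8_encode_sketch(uint32_t v)
{
    if (v < 16u) return (uint8_t)v;
    int e = 31 - (int)clz32_sketch(v) - 4;   /* e0 = msb - 4, and msb >= 4 here */
    if (e > 15) e = 15;
    /* seg_start(e0) = 2^msb - 16 <= v, so only a step up can be needed */
    if (e < 15 && v >= seg_start(e + 1)) e++;
    uint32_t m = (v - seg_start(e)) >> e;    /* floor */
    return (uint8_t)((e << 4) | (m & 0x0Fu));
}

/* Aborts if any code fails decode -> encode round trip. */
static void check_round_trip(void)
{
    for (unsigned code = 0; code < 256; code++)
        assert(uf8_encode_sketch(uf8_decode_sketch((uint8_t)code)) == code);
}
```

This is the same property the assembly test harness checks (decode each code, re-encode it, and compare), so the sketch can serve as a reference when debugging the RV32I version.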
### RISC-V32 code
:::spoiler Assembly Code
```
.data
msg1:    .asciz ": produces value "
msg2:    .asciz " but encodes back to "
msg3:    .asciz ": value "
msg4:    .asciz " <= previous_value "
msg5:    .asciz "All tests passed.\n"
msg6:    .asciz "Some tests failed.\n"
newline: .asciz "\n"
.align 2

.text
.globl main
main:
    jal ra, test            # run the full test
    beq a0, x0, Not_pass    # a0==0 => failed
    la a0, msg5             # print "All tests passed.\n"
    li a7, 4
    ecall
    li a7, 10               # exit(0)
    li a0, 0
    ecall
Not_pass:
    la a0, msg6             # print "Some tests failed.\n"
    li a7, 4
    ecall
    li a7, 10               # exit(1)
    li a0, 1
    ecall

test:
    addi sp, sp, -4
    sw ra, 0(sp)            # test calls other functions
    addi s11, x0, -1        # previous_value = -1
    li s10, 1               # pass = true
    li s9, 0                # code = 0
    li s8, 256              # end bound
For_2:
    add a0, s9, x0          # a0 = code
    jal ra, uf8_decode
    add s7, a0, x0          # s7 = decoded value
    add a0, s7, x0
    jal ra, uf8_encode
    add s6, a0, x0          # s6 = re-encoded code
test_if_1:
    beq s9, s6, test_if_2
    mv a0, s9               # print code (hex)
    li a7, 34
    ecall
    la a0, msg1             # ": produces value "
    li a7, 4
    ecall
    mv a0, s7               # print decoded value (dec)
    li a7, 1
    ecall
    la a0, msg2             # " but encodes back to "
    li a7, 4
    ecall
    mv a0, s6               # print re-encoded code (hex)
    li a7, 34
    ecall
    la a0, newline
    li a7, 4
    ecall
    li s10, 0               # pass = false
test_if_2:
    blt s11, s7, after_if
    mv a0, s9               # offending code (hex)
    li a7, 34
    ecall
    la a0, msg3             # ": value "
    li a7, 4
    ecall
    mv a0, s7               # current value (dec)
    li a7, 1
    ecall
    la a0, msg4             # " <= previous_value "
    li a7, 4
    ecall
    mv a0, s11              # previous_value (hex)
    li a7, 34
    ecall
    la a0, newline
    li a7, 4
    ecall
    li s10, 0               # pass = false
after_if:
    mv s11, s7              # update previous_value
    addi s9, s9, 1          # code++
    blt s9, s8, For_2
    mv a0, s10              # return pass flag
    lw ra, 0(sp)
    addi sp, sp, 4
    jr ra

CLZ:
    li a1, 0                # count = 0
    add a3, a0, x0          # a3 = x (working copy)
    srli a2, a3, 16         # y = x >> 16
    bne a2, x0, L1          # if y != 0, MSB in top half → keep y
    addi a1, a1, 16         # else top 16 are zero
    j L2
L1:
    add a3, a2, x0          # a3 = y (keep top half)
L2:
    srli a2, a3, 8          # y = a3 >> 8
    bne a2, x0, L3          # if y != 0, MSB in this 8-bit window
    addi a1, a1, 8          # else next 8 are zero
    j L4
L3:
    add a3, a2, x0          # a3 = y (keep this 8-bit window)
L4:
    li a2, 7                # i = 7 .. 0
L5:
    blt a2, x0, L_done      # i < 0 → done
    li t0, 1
    sll t0, t0, a2          # mask = 1 << i
    and t0, t0, a3          # bit = a3 & mask
    bne t0, x0, L_done      # first '1' found → stop
    addi a1, a1, 1          # count++
    addi a2, a2, -1         # i--
    j L5
L_done:
    add a0, a1, x0          # return count
    jr ra

uf8_decode:
    andi a1, a0, 0x0F       # m
    srli a2, a0, 4          # e
    addi a3, a2, 4          # e + 4
    li a4, 1
    sll a4, a4, a3          # 1 << (e+4)
    addi a4, a4, -16        # offset
    sll a3, a1, a2          # m << e
    add a0, a3, a4          # value
    jr ra

uf8_encode:
    addi sp, sp, -4
    sw ra, 0(sp)            # may call CLZ (software)
    add a7, a0, x0          # a7 = value
    li a1, 16
    blt a7, a1, UE_RET      # value < 16 → return value
    add a0, a7, x0          # call software CLZ
    jal ra, CLZ
    li a1, 31
    sub a1, a1, a0          # a1 = msb = 31 - lz
    addi a3, a1, -4         # e0 = msb - 4
    slti a2, a3, 0
    beqz a2, 1f
    li a3, 0                # clamp e0 to >= 0
1:
    li a2, 15
    ble a3, a2, 2f
    li a3, 15               # clamp e0 to <= 15
2:
    addi a4, a3, 4
    li a5, 1
    sll a5, a5, a4
    addi a5, a5, -16        # a5 = offset0
    addi a4, a3, 5
    li a2, 1
    sll a2, a2, a4
    addi a2, a2, -16        # a2 = next
    blt a7, a5, _dec_e      # value < offset0 → e = e0 - 1
    bge a7, a2, _inc_e      # value >= next → e = e0 + 1
    j _e_ok                 # else e = e0
_dec_e:
    addi a3, a3, -1
    bgez a3, 3f
    li a3, 0                # e cannot go below 0
    li a5, 0                # offset(0) = 0
    j 4f
3:
    # recompute offset(e) = (1 << (e+4)) - 16
    addi a4, a3, 4
    li a5, 1
    sll a5, a5, a4
    addi a5, a5, -16
4:
    j _e_ok
_inc_e:
    addi a3, a3, 1
    li a1, 15
    ble a3, a1, 5f
    li a3, 15               # e cannot exceed 15
5:
    # recompute offset(e)
    addi a4, a3, 4
    li a5, 1
    sll a5, a5, a4
    addi a5, a5, -16
    j _e_ok
_e_ok:
    sub a2, a7, a5
    srl a2, a2, a3          # mantissa = (value - offset) >> e
    slli a1, a3, 4
    or a0, a1, a2           # pack [eeee mmmm]
UE_RET:
    lw ra, 0(sp)
    addi sp, sp, 4
    jr ra
```
:::

### RISC-V
![image](https://hackmd.io/_uploads/B14hJE96xg.png)

### Improvement
- **Original approach**
![initial](https://hackmd.io/_uploads/SyFCdmYpeg.png)
- **Improved approach**
![Screenshot 2025-10-13 152031](https://hackmd.io/_uploads/H1Ktu7caee.png)
- **Calculation notes:**
  - **Cycle reduction** = (43,953 − 33,408) / 43,953 = 23.99%
  - **Instruction reduction** = (29,907 − 24,122) / 29,907 = 19.34%
  - **Speedup** = 43,953 / 33,408 ≈ 1.316×

## Problem C:
### 16 Bit Layout
```
┌─────────┬──────────────┬──────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (7) │
└─────────┴──────────────┴──────────────┘
 15        14           7 6            0

S: Sign bit (0 = positive, 1 = negative)
E: Exponent bits (8 bits, bias = 127)
M: Mantissa/fraction bits (7 bits)
```
- **bit 15:** `S` → Sign bit
  `0` means positive, `1` means negative.
- **bits 14–7:** `E` → 8-bit exponent
  The stored value already includes the bias **(bias = 127)**.
- **bits 6–0:** `M` → 7-bit mantissa (fraction part)
  Does **not** include the implicit leading **1**.

### Numerical Formula
The normalized **bfloat16** value is defined as:
$$
v = (-1)^S \times 2^{E - 127} \times \left(1 + \frac{M}{128}\right)
$$

**$(-1)^S$** :
- `S = 0` → `(+1)`
- `S = 1` → `(-1)`

Determines the **sign** of the number.

**$2^{E - 127}$** :
- The exponent field stores a **biased value**: `E_stored = e_real + 127`
- Therefore, the **actual exponent** is: `e_real = E - 127`.

**$1 + M / 128$** :
- The mantissa has **7 bits**, so `M ∈ [0, 127]`.
- These 7 bits represent the **fractional part**, with an **implicit leading 1** forming a value in `[1, 2)`.
- Since $2^7 = 128$, dividing `M` by 128 converts it into a fraction in `[0, 1)`.
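
The formula can be evaluated directly from the three bit fields. The sketch below is only an illustration for normalized encodings (`1 ≤ E ≤ 254`); the name `bf16_normal_value` is hypothetical, and the power of two is built by repeated doubling or halving to avoid a dependency on `<math.h>`.

```c
#include <assert.h>
#include <stdint.h>

/* Evaluate v = (-1)^S * 2^(E-127) * (1 + M/128) for a normalized bf16
 * bit pattern. Zero, subnormals, Inf and NaN are deliberately not handled. */
static double bf16_normal_value(uint16_t bits)
{
    int s = (bits >> 15) & 1;      /* sign bit */
    int e = (bits >> 7) & 0xFF;    /* biased exponent */
    int m = bits & 0x7F;           /* 7-bit fraction */
    double v = 1.0 + m / 128.0;    /* implicit leading 1 */
    for (int i = e - 127; i > 0; i--) v *= 2.0;
    for (int i = e - 127; i < 0; i++) v /= 2.0;
    return s ? -v : v;
}

/* Spot checks against bit patterns used in the test cases of this report. */
static void check_bf16_values(void)
{
    assert(bf16_normal_value(0x3F80) == 1.0);
    assert(bf16_normal_value(0x3FC0) == 1.5);
    assert(bf16_normal_value(0xC120) == -10.0);
}
```

For example, `0xC120` splits into `S = 1`, `E = 130`, `M = 32`, which gives `-(1 + 32/128) × 2^3 = -10.0`.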
### Tool Function:
#### bf16_isnan
- **Condition:** `E = 255, M ≠ 0`
- **Value:** `v = NaN`

:::spoiler C Code
```c=
static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}
```
:::
:::spoiler Assembly
```asm=
.globl bf16_isnan
bf16_isnan:
    lui t0, 0x8              # t0 = 0x00008000
    addi t0, t0, -128        # t0 = 0x00007F80 (exponent mask)
    and t1, a0, t0
    bne t1, t0, isnan_false  # if (exp != 0xFF) → not NaN/Inf
    andi t3, a0, 0x7F        # t3 = mantissa
    sltu a0, x0, t3          # a0 = (mantissa != 0)
    ret
isnan_false:
    addi a0, x0, 0
    ret
```
:::

#### bf16_isinf
- **Condition:** `E = 255 (0xFF), M = 0`
- **Value:** `v = (-1)^S × ∞`

:::spoiler C Code
```c=
static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}
```
:::
:::spoiler Assembly
```asm=
.globl bf16_isinf
bf16_isinf:
    lui t0, 0x8              # Load upper bits → 0x00008000
    addi t0, t0, -128        # Now t0 = 0x00007F80 (exponent mask)
    and t1, a0, t0           # Extract exponent field: t1 = a0 & 0x7F80
    bne t1, t0, isinf_false  # If exponent != 0xFF → not Inf/NaN
    andi t3, a0, 0x7F        # Isolate mantissa bits (lowest 7)
    sltiu a0, t3, 1          # a0 = (mantissa == 0)
    ret
isinf_false:
    addi a0, x0, 0           # Return 0 if not infinity
    ret
```
:::

#### bf16_iszero
- **Condition:** `E = 0, M = 0`
- **Value:** `v = (-1)^S × 0`
- **Bit patterns:**
  - `+0`: `S = 0, E = 0, M = 0`
  - `-0`: `S = 1, E = 0, M = 0`

:::spoiler C Code
```c=
static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}
```
:::
:::spoiler Assembly
```asm=
.globl bf16_iszero
bf16_iszero:
    lui t0, 0x8              # t0 = 0x00008000
    addi t0, t0, -1          # t0 = 0x00007FFF
    and t1, a0, t0           # t1 = a0 & 0x7FFF
    sltiu a0, t1, 1          # a0 = (t1 == 0)
    ret
```
:::

#### f32_to_bf16
:::spoiler C Code
```c=
static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}
```
:::
:::spoiler Assembly
```asm=
f32_to_bf16:
    addi sp, sp, -4
    sw s0, 0(sp)
    addi s0, a0, 0           # Save input (float32 bits) to s0
    srli t0, s0, 23
    andi t0, t0, 0xFF
    addi t1, x0, 0xFF
    bne t0, t1, unspecial    # If exp != 255 → normal number
    srli a0, s0, 16
    lui t0, 0x10             # t0 = 0x00010000
    addi t0, t0, -1          # t0 = 0x0000FFFF
    and a0, a0, t0           # Mask lower 16 bits
    jal x0, f32_to_bf16_done # Jump to end
unspecial:
    srli t0, s0, 16
    andi t0, t0, 1           # t0 = LSB of the upper 16 bits
    lui t1, 0x8              # t1 = 0x00008000
    addi t1, t1, -1          # t1 = 0x00007FFF
    add t0, t0, t1           # Add rounding offset (0x7FFF + LSB)
    add s0, s0, t0           # Apply rounding to full 32-bit value
    srli a0, s0, 16          # Take upper 16 bits as bfloat16 result
f32_to_bf16_done:
    lw s0, 0(sp)
    addi sp, sp, 4
    ret
```
:::

#### bf16_to_f32
:::spoiler C Code
```c=
static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}
```
:::
:::spoiler Assembly
```asm=
.globl bf16_to_f32
bf16_to_f32:
    slli a0, a0, 16
    ret
```
:::

#### ADD
##### bf16_add Test Cases
**Case 1: 1.5 + 0.5 = 2.0**
- `0x3FC0` → +1.5
- `0x3F00` → +0.5
- Expected result: `0x4000` → +2.0

Tests normal positive addition to verify mantissa alignment and exponent handling.

**Case 2: -1.0 + 0.0 = -1.0**
- `0xBF80` → -1.0
- `0x0000` → +0.0
- Expected result: `0xBF80` → -1.0

Tests the zero-operand shortcut: adding +0.0 must return the other operand unchanged, including its negative sign.

**Case 3: +∞ + (-10.0) = +∞**
- `0x7F80` → +∞
- `0xC120` → -10.0
  (float32 for 10.0 is `0x41200000`; bfloat16 = `0x4120`; adding sign bit gives `0xC120`)
- Expected result: `0x7F80` → +∞
  According to IEEE 754, ∞ plus a finite number remains ∞.

Tests special value handling (Infinity + finite number).
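
The test constants above can be derived mechanically in the same way as the `-10.0` example: take the float32 bit pattern and keep its upper 16 bits. The following is a minimal sketch, assuming `float` is IEEE 754 binary32 on the host; `bf16_bits_of` is a hypothetical helper that truncates rather than rounds, which is exact only for values that fit in 8 significant bits (true for every constant used here).

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Upper 16 bits of the float32 pattern (truncation, no rounding).
 * Assumes float is IEEE 754 binary32. */
static uint16_t bf16_bits_of(float x)
{
    uint32_t u;
    memcpy(&u, &x, sizeof u);
    return (uint16_t)(u >> 16);
}

/* Reproduces the operand encodings listed in the bf16_add test cases. */
static void check_add_constants(void)
{
    assert(bf16_bits_of(1.5f) == 0x3FC0);
    assert(bf16_bits_of(0.5f) == 0x3F00);
    assert(bf16_bits_of(2.0f) == 0x4000);
    assert(bf16_bits_of(-1.0f) == 0xBF80);
    assert(bf16_bits_of(-10.0f) == 0xC120);
}
```

For values that are not exactly representable, the rounding path of `f32_to_bf16` should be used instead of plain truncation.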
![image](https://hackmd.io/_uploads/B1JFIdWMZe.png)

:::spoiler C Code
```c=
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;

    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }

    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;
        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }
        if (!result_mant)
            return BF16_ZERO();
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }

    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (result_mant & 0x7F),
    };
}
```
:::
:::spoiler Assembly
```asm=
.data
newline:  .string "\n"
pass_msg: .asciz "Test Passed\n"
fail_msg: .asciz "Test Failed\n"

.text
.globl main
main:
    addi sp, sp, -4
    sw ra, 0(sp)
test1:
    li a0, 0x3FC0            # A = 1.5
    li a1, 0x3F00            # B = 0.5
    jal ra, bf16_add         # a0 = A + B
    li t1, 0x4000            # expected = 2.0
    bne a0, t1, test1_fail
    jal ra, print_pass
    j test2
test1_fail:
    jal ra, print_fail
    j test2
test2:
    li a0, 0xBF80            # A = -1.0
    li a1, 0x0000            # B = +0.0
    jal ra, bf16_add
    li t1, 0xBF80            # expected = -1.0
    bne a0, t1, test2_fail
    jal ra, print_pass
    j test3
test2_fail:
    jal ra, print_fail
    j test3
test3:
    li a0, 0x7F80            # A = +Inf
    li a1, 0xC120            # B = -10.0
    jal ra, bf16_add
    li t1, 0x7F80            # expected = +Inf
    bne a0, t1, test3_fail
    jal ra, print_pass
    j tests_done
test3_fail:
    jal ra, print_fail
    j tests_done
print_pass:
    la a0, pass_msg
    li a7, 4                 # syscall: print string
    ecall
    jr ra
print_fail:
    la a0, fail_msg
    li a7, 4
    ecall
    jr ra
tests_done:
    lw ra, 0(sp)
    addi sp, sp, 4
    li a7, 10                # syscall: exit
    ecall

.globl bf16_add
bf16_add:
    # extract sign, exponent, mantissa
    srli t0, a0, 15          # t0 = sign_a (bit 15)
    srli t1, a1, 15          # t1 = sign_b
    srli t2, a0, 7
    andi t2, t2, 0xFF        # t2 = exp_a (8 bits)
    srli t3, a1, 7
    andi t3, t3, 0xFF        # t3 = exp_b
    andi t4, a0, 0x7F        # t4 = mant_a (7 bits)
    andi t5, a1, 0x7F        # t5 = mant_b (7 bits)
    li t6, 0xFF
    bne t2, t6, check_exp_b
exp_a_checkall:
    bnez t4, ret_a           # mant_a != 0 → a is NaN, return a
    bne t3, t6, ret_a        # a is Inf, b is finite → return a
    bnez t5, return_b1       # b mantissa != 0 → b is NaN
    bne t0, t1, return_nan   # +Inf + -Inf → NaN
return_b1:
    mv a0, a1                # b is NaN or same-sign Inf
    ret
return_nan:
    li a0, 0x7FC0            # canonical NaN
ret_a:
    ret
check_exp_b:
    beq t3, t6, return_b2
    j check_0_a
return_b2:
    mv a0, a1                # b is NaN or Inf
    ret
check_0_a:
    bnez t2, check_0_b       # exp_a != 0 → not zero
    bnez t4, check_0_b       # mant_a != 0 → not zero
    mv a0, a1                # a is ±0 → result = b
    ret
check_0_b:
    bnez t3, norm_a
    bnez t5, norm_a
    ret                      # b is ±0 → result = a (a0)
norm_a:
    beqz t2, norm_b          # exp_a == 0 → subnormal
    ori t4, t4, 0x80         # mant_a |= 1 << 7
norm_b:
    beqz t3, end_check1
    ori t5, t5, 0x80
end_check1:
    addi sp, sp, -20
    sw s0, 16(sp)
    sw s1, 12(sp)
    sw s2, 8(sp)
    sw s3, 4(sp)
    sw s4, 0(sp)
    sub s0, t2, t3           # s0 = exp_diff = exp_a - exp_b
    blez s0, diff_neg        # exp_a <= exp_b
    mv s2, t2                # result_exp = exp_a
    li t6, 8
    bgt s0, t6, return_a     # if exp_diff > 8 → B too small
    srl t5, t5, s0           # shift mant_b
    j exp_done
diff_neg:
    bgez s0, diff_else       # exp_diff == 0
    mv s2, t3                # result_exp = exp_b
    li t6, -8
    bge s0, t6, shift_a      # if exp_diff >= -8 → shift A
    j return_b3              # else A too small → result ≈ B
shift_a:
    neg s4, s0               # s4 = -exp_diff
    srl t4, t4, s4           # shift mant_a
    j exp_done
diff_else:                   # exp_diff == 0
    mv s2, t2
    j exp_done
return_a:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
return_b3:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    mv a0, a1
    ret
exp_done:
    bne t0, t1, diff_sign    # sign differ → subtraction
same_sign:
    mv s1, t0                # result_sign
    add s3, t4, t5           # result_mant
    andi t6, s3, 0x100       # overflow into bit 8?
    beqz t6, norm_end
    srli s3, s3, 1           # shift mantissa
    addi s2, s2, 1           # exponent++
    li t6, 0xFF
    bge s2, t6, overflow_inf
    j norm_end
overflow_inf:
    slli a0, s1, 15          # sign (read s1 before it is restored)
    li t6, 0x7F80            # Inf exponent
    or a0, a0, t6
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
diff_sign:
    bge t4, t5, manta_gt_b
    mv s1, t1                # result_sign = sign_b
    sub s3, t5, t4           # mant_b - mant_a
    j mant_result
manta_gt_b:
    mv s1, t0                # result_sign = sign_a
    sub s3, t4, t5           # mant_a - mant_b
mant_result:
    beqz s3, return_zero     # exact zero
norm_loop:
    andi t6, s3, 0x80
    bnez t6, norm_end
    slli s3, s3, 1
    addi s2, s2, -1
    blez s2, return_zero
    j norm_loop
norm_end:
    slli a0, s1, 15          # sign
    andi t0, s2, 0xFF
    slli t0, t0, 7           # exponent
    or a0, a0, t0
    andi t0, s3, 0x7F        # mantissa
    or a0, a0, t0
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
return_zero:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    li a0, 0x0000            # +0
    ret
```
:::

#### SUB
##### bf16_sub Test Cases
**Case 1: 2.0 − 1.5 = 0.5**
- `0x4000` → +2.0
- `0x3FC0` → +1.5
- Expected result: `0x3F00` → +0.5

Tests normal positive subtraction to verify exponent alignment and mantissa subtraction logic.

**Case 2: -1.0 − 2.0 = -3.0**
- `0xBF80` → -1.0
- `0x4000` → +2.0
- Expected result: `0xC040` → -3.0

Tests negative minus positive, confirming correct handling of opposite signs (effectively addition).

**Case 3: 0.0 − (-2.0) = 2.0**
- `0x0000` → +0.0
- `0xC000` → -2.0
- Expected result: `0x4000` → +2.0

Tests subtraction involving a negative number, ensuring the sign bit inversion is correctly applied.

![image](https://hackmd.io/_uploads/BkEwtOWMWe.png)

:::spoiler C Code
```c=
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}
```
:::
:::spoiler Assembly
```asm=
.data
newline:  .string "\n"
pass_msg: .asciz "Test Passed\n"
fail_msg: .asciz "Test Failed\n"

.text
.globl main
main:
    addi sp, sp, -4
    sw ra, 0(sp)
test1:
    li a0, 0x4000            # A = 2.0
    li a1, 0x3FC0            # B = 1.5
    jal ra, bf16_sub         # a0 = A - B
    li t1, 0x3F00            # expected = 0.5
    bne a0, t1, test1_fail
    jal ra, print_pass
    j test2
test1_fail:
    jal ra, print_fail
    j test2
test2:
    li a0, 0xBF80            # A = -1.0
    li a1, 0x4000            # B = 2.0
    jal ra, bf16_sub
    li t1, 0xC040            # expected = -3.0
    bne a0, t1, test2_fail
    jal ra, print_pass
    j test3
test2_fail:
    jal ra, print_fail
    j test3
test3:
    li a0, 0x0000            # A = 0.0
    li a1, 0xC000            # B = -2.0
    jal ra, bf16_sub
    li t1, 0x4000            # expected = 2.0
    bne a0, t1, test3_fail
    jal ra, print_pass
    j tests_done
test3_fail:
    jal ra, print_fail
    j tests_done
print_pass:
    la a0, pass_msg
    li a7, 4                 # print string
    ecall
    jr ra
print_fail:
    la a0, fail_msg
    li a7, 4
    ecall
    jr ra
tests_done:
    lw ra, 0(sp)
    addi sp, sp, 4
    li a7, 10                # exit
    ecall

.globl bf16_sub
bf16_sub:
    li t0, 0x8000            # sign bit mask
    xor a1, a1, t0           # flip sign of B → -B
    j bf16_add               # reuse bf16_add

.globl bf16_add
bf16_add:
    # extract sign, exponent, mantissa
    srli t0, a0, 15          # t0 = sign_a
    srli t1, a1, 15          # t1 = sign_b
    srli t2, a0, 7           # t2 = exp_a
    andi t2, t2, 0xFF
    srli t3, a1, 7           # t3 = exp_b
    andi t3, t3, 0xFF
    andi t4, a0, 0x7F        # t4 = mant_a
    andi t5, a1, 0x7F        # t5 = mant_b
    li t6, 0xFF
    bne t2, t6, check_exp_b
exp_a_checkall:
    bnez t4, ret_a           # mant_a != 0 → a is NaN
    bne t3, t6, ret_a        # a is Inf, b finite → return a
    bnez t5, return_b1       # b mant != 0 → b is NaN
    bne t0, t1, return_nan   # +Inf + -Inf → NaN
return_b1:
    mv a0, a1
    ret
return_nan:
    li a0, 0x7FC0            # NaN
ret_a:
    ret
check_exp_b:
    beq t3, t6, return_b2    # b is NaN/Inf
    j check_0_a
return_b2:
    mv a0, a1
    ret
check_0_a:
    bnez t2, check_0_b       # exp_a != 0 → not zero
    bnez t4, check_0_b       # mant_a != 0 → not zero
    mv a0, a1                # a is zero
    ret
check_0_b:
    bnez t3, norm_a
    bnez t5, norm_a
    ret                      # b is zero → return a
norm_a:
    beqz t2, norm_b
    ori t4, t4, 0x80         # mant_a |= 1<<7
norm_b:
    beqz t3, end_check1
    ori t5, t5, 0x80         # mant_b |= 1<<7
end_check1:
    addi sp, sp, -20
    sw s0, 16(sp)            # exp_diff
    sw s1, 12(sp)            # result_sign
    sw s2, 8(sp)             # result_exp
    sw s3, 4(sp)             # result_mant
    sw s4, 0(sp)
    sub s0, t2, t3           # s0 = exp_a - exp_b
    blez s0, diff_neg        # exp_a <= exp_b
    mv s2, t2                # result_exp = exp_a
    li t6, 8
    bgt s0, t6, return_a     # diff > 8 → B too small
    srl t5, t5, s0           # shift mant_b
    j exp_done
diff_neg:
    bgez s0, diff_else       # exp_diff == 0
    mv s2, t3                # result_exp = exp_b
    li t6, -8
    bge s0, t6, shift_a      # diff >= -8 → shift A
    j return_b3
shift_a:
    neg s4, s0
    srl t4, t4, s4
    j exp_done
diff_else:
    mv s2, t2
    j exp_done
return_a:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
return_b3:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    mv a0, a1
    ret
exp_done:
    bne t0, t1, diff_sign    # sign differ → subtraction
same_sign:
    mv s1, t0                # result_sign
    add s3, t4, t5           # result_mant
    andi t6, s3, 0x100       # overflow into bit 8?
    beqz t6, norm_end
    srli s3, s3, 1
    addi s2, s2, 1
    li t6, 0xFF
    bge s2, t6, overflow_inf
    j norm_end
overflow_inf:
    slli a0, s1, 15          # sign (read s1 before it is restored)
    li t6, 0x7F80            # Inf
    or a0, a0, t6
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
diff_sign:
    bge t4, t5, manta_ge_b
    mv s1, t1                # |b| > |a| → sign = sign_b
    sub s3, t5, t4           # mant_b - mant_a
    j mant_result
manta_ge_b:
    mv s1, t0                # |a| >= |b| → sign = sign_a
    sub s3, t4, t5           # mant_a - mant_b
mant_result:
    beqz s3, return_zero
norm_loop:
    andi t6, s3, 0x80
    bnez t6, norm_end
    slli s3, s3, 1
    addi s2, s2, -1
    blez s2, return_zero
    j norm_loop
norm_end:
    slli a0, s1, 15
    andi t0, s2, 0xFF
    slli t0, t0, 7
    or a0, a0, t0
    andi t0, s3, 0x7F
    or a0, a0, t0
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    ret
return_zero:
    lw s0, 16(sp)
    lw s1, 12(sp)
    lw s2, 8(sp)
    lw s3, 4(sp)
    lw s4, 0(sp)
    addi sp, sp, 20
    li a0, 0x0000
    ret
```
:::

#### MUL
##### bf16_mul Test Cases
**Case 1: -2.0 × 0.5 = -1.0**
- A = `0xC000` → -2.0
- B = `0x3F00` → +0.5
- Expected result: `0xBF80` → -1.0

Tests a "negative × positive" case to verify:
- Correct sign handling using `sign_a XOR sign_b`;
- Correct exponent addition (with bias removal) and mantissa multiplication for normal-sized values.

**Case 2: -3.0 × -4.0 = +12.0**
- A = `0xC040` → -3.0
- B = `0xC080` → -4.0
- Expected result: `0x4140` → +12.0

Tests "negative × negative = positive" to ensure:
- The sign computed by `sign_a XOR sign_b` is positive;
- Exponent addition, mantissa multiplication, and normalization are all correct.

**Case 3: (+∞) × (-2.0) = -∞**
- A = `0x7F80` → +∞ (exponent = 0xFF, mantissa = 0)
- B = `0xC000` → -2.0
- Expected result: `0xFF80` → -∞

Tests special value handling (Infinity times a finite non-zero number), checking:
- Proper NaN / Inf classification branches;
- Correct result sign determined by `sign_a XOR sign_b`, yielding negative infinity.
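
Cases 1 and 2 can be traced by hand with a reduced version of the multiply that covers only normalized, in-range operands. This is a sketch for illustration, not the full routine: the name `bf16_mul_normal` is hypothetical, and zero, subnormal, Inf, NaN, overflow, and underflow handling are all left out.

```c
#include <assert.h>
#include <stdint.h>

/* Normal-number path only: XOR the signs, add the exponents and remove one
 * bias, multiply the 8-bit significands, then renormalize by one bit if the
 * product reached bit 15. */
static uint16_t bf16_mul_normal(uint16_t a, uint16_t b)
{
    uint16_t sign = ((a ^ b) >> 15) & 1;
    int exp = ((a >> 7) & 0xFF) + ((b >> 7) & 0xFF) - 127;
    uint32_t mant = (uint32_t)((a & 0x7F) | 0x80) * ((b & 0x7F) | 0x80);
    if (mant & 0x8000) {
        mant >>= 8;   /* product in [2, 4): shift one extra bit */
        exp++;
    } else {
        mant >>= 7;   /* product in [1, 2) */
    }
    return (uint16_t)((sign << 15) | ((exp & 0xFF) << 7) | (mant & 0x7F));
}

/* Cases 1 and 2 from the list above. */
static void check_mul_cases(void)
{
    assert(bf16_mul_normal(0xC000, 0x3F00) == 0xBF80);   /* -2.0 * 0.5  */
    assert(bf16_mul_normal(0xC040, 0xC080) == 0x4140);   /* -3.0 * -4.0 */
}
```

In Case 2, for instance, the significands are `0xC0` and `0x80`, their product is `0x6000` (bit 15 clear), so the fraction becomes `0x40` and the exponent is `128 + 129 - 127 = 130`, which packs to `0x4140`.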
![image](https://hackmd.io/_uploads/Sy2UlYWGZg.png)

:::spoiler C Code
```c=
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }

    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}
```
:::
:::spoiler Assembly
```asm=
.data
newline:  .string "\n"
pass_msg: .asciz "Test Passed\n"
fail_msg: .asciz "Test Failed\n"

.text
.globl main
main:
    addi sp, sp, -4
    sw ra, 0(sp)
test1:
    li a0, 0xC000            # A = -2.0
    li a1, 0x3F00            # B = 0.5
    jal ra, bf16_mul         # a0 = A * B
    li t1, 0xBF80            # expected = -1.0
    bne a0, t1, test1_fail
    jal ra, print_pass
    j test2
test1_fail:
    jal ra, print_fail
    j test2
test2:
    li a0, 0xC040            # A = -3.0
    li a1, 0xC080            # B = -4.0
    jal ra, bf16_mul
    li t1, 0x4140            # expected = 12.0
    bne a0, t1, test2_fail
    jal ra, print_pass
    j test3
test2_fail:
    jal ra, print_fail
    j test3
test3:
    li a0, 0x7F80            # A = +Inf
    li a1, 0xC000            # B = -2.0
    jal ra, bf16_mul
    li t1, 0xFF80            # expected = -Inf
    bne a0, t1, test3_fail
    jal ra, print_pass
    j tests_done
test3_fail:
    jal ra, print_fail
    j tests_done
print_pass:
    la a0, pass_msg
    li a7, 4                 # syscall: print string
    ecall
    jr ra
print_fail:
    la a0, fail_msg
    li a7, 4
    ecall
    jr ra
tests_done:
    lw ra, 0(sp)
    addi sp, sp, 4
    li a7, 10                # syscall: exit
    ecall

.globl bf16_mul
bf16_mul:
    addi sp, sp, -16
    sw s0, 0(sp)
    sw s1, 4(sp)
    sw s2, 8(sp)
    sw ra, 12(sp)
    srli t0, a0, 15          # sign_a = (a0 >> 15) & 1
    andi t0, t0, 1
    srli t1, a1, 15          # sign_b = (a1 >> 15) & 1
    andi t1, t1, 1
    srli t2, a0, 7           # exp_a = (a0 >> 7) & 0xFF
    andi t2, t2, 0xFF
    srli t3, a1, 7           # exp_b = (a1 >> 7) & 0xFF
    andi t3, t3, 0xFF
    andi t4, a0, 0x7F        # mant_a (7 bits)
    andi t5, a1, 0x7F        # mant_b (7 bits)
    xor s0, t0, t1           # s0 = result_sign = sign_a XOR sign_b
    li t6, 0xFF
    bne t2, t6, check_b_exp  # exp_a != 255 → a is not NaN/Inf
    bnez t4, return_a        # mant_a != 0 → a is NaN → return a
    bnez t3, result_inf_a    # a is Inf, exp_b != 0 → b is non-zero → signed Inf
    bnez t5, result_inf_a    # a is Inf, b is a non-zero subnormal → signed Inf
    j return_nan             # a is Inf, b is ±0 → Inf * 0 = NaN
result_inf_a:
    slli a0, s0, 15
    li t6, 0x7F80
    or a0, a0, t6
    j quit
check_b_exp:
    li t6, 0xFF
    bne t3, t6, check_0      # exp_b != 255 → b is not NaN/Inf
    bnez t5, return_b        # b is NaN → return b
    bnez t2, result_inf_b    # b is Inf, exp_a != 0 → a is non-zero → signed Inf
    bnez t4, result_inf_b    # b is Inf, a is a non-zero subnormal → signed Inf
    j return_nan             # b is Inf, a is ±0 → 0 * Inf = NaN
result_inf_b:
    slli a0, s0, 15
    li t6, 0x7F80
    or a0, a0, t6
    j quit
check_0:
    bnez t2, a_not_zero
    bnez t4, a_not_zero
    j return_0
a_not_zero:
    bnez t3, norm_mant
    bnez t5, norm_mant
return_0:
    slli a0, s0, 15          # signed zero
    j quit
norm_mant:
    li s1, 0                 # s1: exp_adjust = 0
    bnez t2, norm_a_else     # if exp_a != 0 → already normalized
norm_loop_a:
    andi t6, t4, 0x80        # check bit 7
    bnez t6, norm_loop_a_done
    slli t4, t4, 1           # shift mant_a left
    addi s1, s1, -1          # exp_adjust--
    j norm_loop_a
norm_loop_a_done:
    li t2, 1                 # treat as exponent = 1 for subnormal
    j check_exp_b_norm
norm_a_else:
    ori t4, t4, 0x80         # add implicit 1
check_exp_b_norm:
    bnez t3, else_norm_b     # if exp_b != 0 → normalized
norm_loop_b:
    andi t6, t5, 0x80
    bnez t6, norm_b_done
    slli t5, t5, 1
    addi s1, s1, -1
    j norm_loop_b
norm_b_done:
    li t3, 1
    j mul_mant
else_norm_b:
    ori t5, t5, 0x80         # add implicit 1
mul_mant:
    li s2, 0                 # s2 = product = 0
    li t6, 8                 # 8 bits to process
mul_loop:
    andi t0, t5, 1           # if (t5 & 1) add t4 to product
    beqz t0, skip_add
    add s2, s2, t4
skip_add:
    slli t4, t4, 1           # multiplicand <<= 1
    srli t5, t5, 1           # multiplier >>= 1
    addi t6, t6, -1
    bnez t6, mul_loop
    add t6, t2, t3
    addi t6, t6, -127
    add t6, t6, s1
    mv s1, t6                # s1 = result_exp
    li t6, 0x8000
    and t6, s2, t6
    beqz t6, mult_else
    srli s2, s2, 8           # keep highest 8 bits after overflow
    andi s2, s2, 0x7F        # keep 7 LSBs, bit 7 is implicit 1
    addi s1, s1, 1
    j check_exp_overflow
mult_else:
    srli s2, s2, 7
    andi s2, s2, 0x7F
check_exp_overflow:
    li t6, 0xFF
    blt s1, t6, underflow_check  # if result_exp < 255 → continue
    slli a0, s0, 15
    li t6, 0x7F80
    or a0, a0, t6
    j quit
underflow_check:
    bgt s1, x0, final        # if result_exp > 0 → normal
    li t6, -6
    blt s1, t6, return_0_udflow  # too small → signed zero
    li t6, 1
    sub t6, t6, s1           # shift = 1 - result_exp
    srl s2, s2, t6
    li s1, 0
    j final
zero_label:
    nop
return_0_udflow:
    slli a0, s0, 15          # signed zero
    j quit
final:
    slli a0, s0, 15          # sign
    andi s1, s1, 0xFF
    slli s1, s1, 7           # exponent
    andi s2, s2, 0x7F        # mantissa
    or a0, a0, s1
    or a0, a0, s2
    j quit
return_a:
    j quit
return_b:
    mv a0, a1                # return b (NaN)
    j quit
return_nan:
    li a0, 0x7FC0            # canonical NaN
    j quit
quit:
    lw s0, 0(sp)
    lw s1, 4(sp)
    lw s2, 8(sp)
    lw ra, 12(sp)
    addi sp, sp, 16
    ret
```
:::

#### DIV
##### bf16_div Test Cases
**Case 1: -6.0 ÷ 2.0 = -3.0**
- A = `0xC0C0` → -6.0
- B = `0x4000` → +2.0
- Expected result: `0xC040` → -3.0

This case tests a **negative divided by positive** scenario, verifying that:
- The result sign is correctly computed as `sign_a XOR sign_b` (negative);
- The exponent subtraction and bias re-addition are correct;
- The integer division of mantissas and subsequent normalization produce the right bf16 encoding.

**Case 2: 1.0 ÷ (-4.0) = -0.25**
- A = `0x3F80` → +1.0
- B = `0xC080` → -4.0
- Expected result: `0xBE80` → -0.25

This case checks a **positive divided by negative** that yields a **negative fraction**, confirming:
- The sign logic `sign_a XOR sign_b` still behaves correctly;
- The result exponent shrinks as expected when the quotient is less than 1;
- The quotient mantissa is normalized by shifting left/right in a correct and stable way.

**Case 3: (+∞) ÷ 4.0 = +∞**
- A = `0x7F80` → +∞ (exponent = 0xFF, mantissa = 0)
- B = `0x4080` → +4.0
- Expected result: `0x7F80` → +∞

This case focuses on **special value handling** where the numerator is Infinity:
- When the numerator is `+∞` and the denominator is a finite non-zero value, the result should remain `+∞`;
- It also validates that the NaN / ∞ branches correctly distinguish between:
  - denominator = ∞, and
  - numerator = ∞, denominator finite,

  and that only the latter returns a signed Infinity here.
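
The "integer division of mantissas" mentioned in Case 1 is a 16-step shift-subtract loop in `bf16_div`. The sketch below isolates that loop and compares it with plain integer division over all normalized significands; `div_mantissa` and `check_div_mantissa` are hypothetical names introduced only for this illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Shift-subtract (restoring) division of (mant_a << 15) by mant_b,
 * mirroring the 16-iteration loop in bf16_div. */
static uint32_t div_mantissa(uint32_t mant_a, uint32_t mant_b)
{
    uint32_t dividend = mant_a << 15;
    uint32_t quotient = 0;
    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (mant_b << (15 - i))) {
            dividend -= mant_b << (15 - i);
            quotient |= 1;
        }
    }
    return quotient;
}

/* For 8-bit significands with the implicit 1 set (0x80..0xFF), the loop
 * should agree with ordinary integer division. */
static void check_div_mantissa(void)
{
    for (uint32_t a = 0x80; a <= 0xFF; a++)
        for (uint32_t b = 0x80; b <= 0xFF; b++)
            assert(div_mantissa(a, b) == (a << 15) / b);
}
```

For Case 1 the significands are `0xC0` and `0x80`, so the quotient is `0xC000`; bit 15 is set, the fraction becomes `(0xC000 >> 8) & 0x7F = 0x40`, and with exponent `129 - 128 + 127 = 128` and a negative sign the packed result is `0xC040`.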
![image](https://hackmd.io/_uploads/By1fNY-zZx.png) :::spoiler C Code ```c= static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F)}; } ``` ::: :::spoiler Assembly ```asm= .data newline: .string "\n" pass_msg: .asciz "Test Passed\n" fail_msg: .asciz "Test Failed\n" .text .globl main main: addi sp, sp, -4 sw ra, 0(sp) test1: li a0, 0xC0C0 # A = -6.0 li a1, 0x4000 # B = 2.0 jal ra, bf16_div # a0 = A / B li t1, 0xC040 # expected = -3.0 bne a0, t1, test1_fail jal ra, print_pass j 
test2 test1_fail: jal ra, print_fail j test2 test2: li a0, 0x3F80 # A = 1.0 li a1, 0xC080 # B = -4.0 jal ra, bf16_div li t1, 0xBE80 # expected = -0.25 bne a0, t1, test2_fail jal ra, print_pass j test3 test2_fail: jal ra, print_fail j test3 test3: li a0, 0x7F80 # A = +Inf li a1, 0x4080 # B = 4.0 jal ra, bf16_div li t1, 0x7F80 # expected = +Inf bne a0, t1, test3_fail jal ra, print_pass j tests_done test3_fail: jal ra, print_fail j tests_done print_pass: la a0, pass_msg li a7, 4 # print string ecall jr ra print_fail: la a0, fail_msg li a7, 4 ecall jr ra tests_done: lw ra, 0(sp) addi sp, sp, 4 li a7, 10 # exit ecall .globl bf16_isnan bf16_isnan: li t0, 0x7F80 # exponent mask and t1, a0, t0 bne t1, t0, isnan_false # if (exp != 0xFF) → not NaN/Inf li t2, 0x007F # mantissa mask and t3, a0, t2 # t3 = mantissa snez a0, t3 # a0 = (mant != 0) ? 1 : 0 ret isnan_false: li a0, 0 # return 0 ret .globl bf16_isinf bf16_isinf: li t0, 0x7F80 # exponent mask and t1, a0, t0 bne t1, t0, isinf_false # if (exp != 0xFF) → not Inf/NaN li t2, 0x007F # mantissa mask and t3, a0, t2 # t3 = mantissa seqz a0, t3 # a0 = (mant == 0) ? 1 : 0 ret isinf_false: li a0, 0 ret .globl bf16_iszero bf16_iszero: li t0, 0x7FFF # mask out sign and t1, a0, t0 seqz a0, t1 # a0 = (bits_without_sign == 0) ? 
1 : 0 ret .globl f32_to_bf16 f32_to_bf16: addi sp, sp, -4 sw s0, 0(sp) mv s0, a0 srli t0, s0, 23 andi t0, t0, 0xFF li t1, 0xFF bne t0, t1, unspecial srli a0, s0, 16 li t0, 0xFFFF and a0, a0, t0 j f32_to_bf16_done unspecial: srli t0, s0, 16 andi t0, t0, 1 # low bit for tie-to-even li t1, 0x7FFF add t0, t0, t1 # t0 = 0x7FFF or 0x8000 add s0, s0, t0 srli a0, s0, 16 # take high 16 bits as bf16 f32_to_bf16_done: lw s0, 0(sp) addi sp, sp, 4 ret .globl bf16_to_f32 bf16_to_f32: slli a0, a0, 16 # place bf16 in high 16 bits of f32 ret .globl BF16_NAN BF16_NAN: li a0, 0x7FC0 # canonical NaN ret .globl BF16_ZERO BF16_ZERO: li a0, 0x0000 # +0 ret .globl bf16_div bf16_div: addi sp, sp, -16 sw s0, 0(sp) sw s1, 4(sp) sw s2, 8(sp) sw s3, 12(sp) srli t0, a0, 15 andi t0, t0, 1 # sign_a srli t1, a1, 15 andi t1, t1, 1 # sign_b srli t2, a0, 7 andi t2, t2, 0xFF # exp_a srli t3, a1, 7 andi t3, t3, 0xFF # exp_b andi t4, a0, 0x7F # mant_a (7 bits) andi t5, a1, 0x7F # mant_b (7 bits) xor s0, t0, t1 # s0 = result_sign li t6, 0xFF # common constant bne t3, t6, check_zero # if exp_b != 0xFF beqz t5, check_inf # mant_b == 0 → Inf mv a0, a1 # b is NaN → return b j recover check_inf: bne t2, t6, result_sign_1 bnez t4, result_sign_1 # a is NaN li a0, 0x7FC0 # Inf / Inf → NaN j recover result_sign_1: # return signed zero (finite / Inf → 0) slli a0, s0, 15 j recover check_zero: bnez t3, check_2_inf # if exp_b != 0 → not zero bnez t5, check_2_inf # mant_b != 0 → subnormal bnez t2, result_sign_2 # if a != 0 → Inf with sign bnez t4, result_sign_2 # a subnormal non-zero li a0, 0x7FC0 # 0 / 0 → NaN j recover result_sign_2: # division by zero → signed Inf slli a0, s0, 15 li t6, 0x7F80 or a0, a0, t6 j recover check_2_inf: bne t2, t6, check_div_zero beqz t4, result_3 # a is Inf (mant == 0) mv a0, a0 # a is NaN → return a j recover result_3: # a is Inf, b finite non-zero → signed Inf slli a0, s0, 15 li t6, 0x7F80 or a0, a0, t6 j recover check_div_zero: bnez t2, norm # exp_a != 0 → not zero bnez t4, norm # 
mant_a != 0 → not zero slli a0, s0, 15 # 0 / non-zero → signed zero j recover norm: beqz t2, norm_b ori t4, t4, 0x80 # mant_a |= 1 << 7 norm_b: beqz t3, norm_end ori t5, t5, 0x80 # mant_b |= 1 << 7 norm_end: slli s1, t4, 15 # s1: dividend = mant_a << 15 mv s2, t5 # s2: divisor = mant_b li s3, 0 # s3: quotient = 0 li t6, 0 # loop counter i = 0 div_loop: li a2, 16 bge t6, a2, end_div_loop # while (i < 16) slli s3, s3, 1 # quotient <<= 1 li a3, 15 sub a3, a3, t6 # shift = 15 - i sll a4, s2, a3 # (divisor << (15 - i)) bltu s1, a4, skip_sub # if dividend < shifted divisor → skip sub s1, s1, a4 # dividend -= shifted divisor ori s3, s3, 1 # quotient |= 1 skip_sub: addi t6, t6, 1 # i++ j div_loop end_div_loop: sub a2, t2, t3 # a2 = exp_a - exp_b addi a2, a2, 127 # + BF16_EXP_BIAS bnez t2, res_b addi a2, a2, -1 # if a subnormal, exponent-- res_b: bnez t3, q_check addi a2, a2, 1 # if b subnormal, exponent++ q_check: li t6, 0x8000 and a4, s3, t6 # check highest bit of quotient beqz a4, q_else srli s3, s3, 8 j check_overflow q_else: q_loop: li t6, 0x8000 and a4, s3, t6 bnez a4, q_loop_done # stop when MSB becomes 1 li t6, 1 ble a2, t6, q_loop_done # avoid exponent going below 1 slli s3, s3, 1 # shift mantissa left addi a2, a2, -1 # exponent-- j q_loop q_loop_done: srli s3, s3, 8 # keep top 8 bits (1.xxx) check_overflow: andi s3, s3, 0x7F # keep 7-bit mantissa li t6, 0xFF blt a2, t6, check_un slli a0, s0, 15 li t6, 0x7F80 or a0, a0, t6 j recover check_un: bgt a2, x0, final_result # exponent > 0 → normal number slli a0, s0, 15 j recover final_result: slli a0, s0, 15 # sign andi a2, a2, 0xFF slli a2, a2, 7 # exponent or a0, a0, a2 andi s3, s3, 0x7F # mantissa or a0, a0, s3 recover: lw s0, 0(sp) lw s1, 4(sp) lw s2, 8(sp) lw s3, 12(sp) addi sp, sp, 16 ret ``` ::: #### SQRT(Square Root) ##### bf16_sqrt Test Cases **Case 1: sqrt(1.0) = 1.0** - A = `0x3F80` → +1.0 - Expected result: `0x3F80` → +1.0 This case checks the situation where the value is **exactly 1.0**, verifying that: - 
The exponent halving plus bias re-adding is correct;
- The mantissa obtained from the binary search and normalization reconstructs the same bf16 value.

**Case 2: sqrt(0.25) = 0.5**
- A = `0x3E80` → +0.25
- Expected result: `0x3F00` → +0.5

This case tests a **positive value smaller than 1**:
- The unbiased exponent is -2 (biased 125); after the square root it is halved to -1 (biased 126);
- We verify that the mantissa from binary search, combined with the adjusted exponent, encodes 0.5 correctly.

**Case 3: sqrt(+Inf) = +Inf**
- A = `0x7F80` → +Inf (exponent = 0xFF, mantissa = 0)
- Expected result: `0x7F80` → +Inf

This case focuses on **special value handling**:
- When the input is positive infinity, the result must remain positive infinity;
- It confirms that the `exp = 0xFF, mantissa = 0` branch correctly treats the value as Inf (not NaN).

![image](https://hackmd.io/_uploads/HyrTKt-GZe.png)

:::spoiler C Code
```c=
static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;

    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a; /* sqrt(+Inf) = +Inf */
    }

    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();

    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();

    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();

    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was
odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need to adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } } /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow */ if (new_exp >= 0xFF) return (bf16_t) {.bits = 0x7F80}; /* +Inf */ if (new_exp <= 0) return BF16_ZERO(); return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant}; } ``` ::: :::spoiler Assembly ```asm= .data pass_msg: .asciz "Test Passed\n" fail_msg: .asciz "Test Failed\n" .text .globl main .globl bf16_sqrt main: addi sp, sp, -4 sw ra, 0(sp) test1: li a0, 0x3F80 # A = 1.0 jal ra, bf16_sqrt # a0 = sqrt(A) li t1, 0x3F80 # expected = 1.0 bne a0, t1, test1_fail jal ra, print_pass j test2 test1_fail: jal ra, print_fail j test2 test2: li a0, 0x3E80 # A = 0.25 jal ra, bf16_sqrt li t1, 0x3F00 # expected = 0.5 bne a0, t1, test2_fail jal ra, print_pass j test3 test2_fail: jal ra, print_fail j test3 test3: li a0, 0x7F80 # A = +Inf jal ra, bf16_sqrt li t1, 0x7F80 # expected = +Inf bne a0, t1, test3_fail jal ra, print_pass j tests_done test3_fail: jal ra, print_fail j tests_done print_pass: la a0, pass_msg li a7, 4 # print string ecall jr ra print_fail: la a0, fail_msg li a7, 
4 ecall jr ra tests_done: lw ra, 0(sp) addi sp, sp, 4 li a7, 10 # exit ecall bf16_sqrt: addi sp, sp, -32 sw ra, 28(sp) sw s0, 24(sp) sw s1, 20(sp) sw s2, 16(sp) sw s3, 12(sp) sw s4, 8(sp) sw s5, 4(sp) sw s6, 0(sp) srli t0, a0, 15 # t0 = sign bit andi t0, t0, 1 srli t1, a0, 7 # t1 = exponent (8 bits) andi t1, t1, 0xFF andi t2, a0, 0x7F # t2 = mantissa (7 bits) li t3, 0xFF bne t1, t3, check_zero # if exponent != 0xFF → not Inf/NaN bnez t2, return_a # mantissa != 0 → NaN, just return a bnez t0, return_nan # negative Inf → sqrt is NaN j return_a # +Inf → sqrt(+Inf) = +Inf check_zero: or t3, t1, t2 # if exponent==0 and mantissa==0 → zero bnez t3, check_negative j return_zero check_negative: bnez t0, return_nan # negative finite number → NaN bnez t1, compute_sqrt # if exponent != 0 → normal value j return_zero # subnormal very close to 0 → treat as 0 compute_sqrt: addi s0, t1, -127 # s0 = exp - bias ori s1, t2, 0x80 # normalized mantissa with implicit 1 andi t3, s0, 1 beqz t3, even_exp slli s1, s1, 1 # make mantissa bigger for odd exponent addi t4, s0, -1 srai t4, t4, 1 # (exp - 1) / 2 addi s2, t4, 127 # s2 = result exponent (biased) j binary_search even_exp: srai t4, s0, 1 # exp / 2 addi s2, t4, 127 # s2 = result exponent (biased) binary_search: li s3, 90 # low bound (approx range) li s4, 256 # high bound li s5, 128 # best candidate mantissa search_loop: bgt s3, s4, search_done # while low <= high add t3, s3, s4 srli t3, t3, 1 # mid = (low + high) / 2 mv a1, t3 # multiply mid * mid using shift-add mv a2, t3 jal ra, multiply # result in a0 mv t4, a0 # t4 = mid^2 srli t4, t4, 7 # align to compare with s1 bgt t4, s1, search_high # if mid^2 > mantissa → go left mv s5, t3 # mid is new best addi s3, t3, 1 # low = mid + 1 j search_loop search_high: addi s4, t3, -1 # high = mid - 1 j search_loop search_done: li t3, 256 blt s5, t3, check_low # if s5 < 256 → maybe need left shift srli s5, s5, 1 # s5 >= 256 → shift right once and increment exponent addi s2, s2, 1 j extract_mant 
check_low:
    li   t3, 128
    bge  s5, t3, extract_mant
norm_loop:
    li   t3, 128
    bge  s5, t3, extract_mant  # stop when MSB is at bit 7
    li   t3, 1
    ble  s2, t3, extract_mant  # avoid exponent underflow
    slli s5, s5, 1
    addi s2, s2, -1
    j    norm_loop
extract_mant:
    andi s6, s5, 0x7F          # keep 7-bit mantissa
    li   t3, 0xFF
    bge  s2, t3, return_inf    # overflow exponent → +Inf
    blez s2, return_zero       # exponent <= 0 → treat as 0
    andi t3, s2, 0xFF          # final exponent
    slli t3, t3, 7
    or   a0, t3, s6            # pack exponent + mantissa (sign is 0)
    j    cleanup
return_zero:
    li   a0, 0x0000            # +0
    j    cleanup
return_nan:
    li   a0, 0x7FC0            # canonical NaN
    j    cleanup
return_inf:
    li   a0, 0x7F80            # +Inf
    j    cleanup
return_a:
    j    cleanup               # just return the original a0
cleanup:
    lw   s6, 0(sp)
    lw   s5, 4(sp)
    lw   s4, 8(sp)
    lw   s3, 12(sp)
    lw   s2, 16(sp)
    lw   s1, 20(sp)
    lw   s0, 24(sp)
    lw   ra, 28(sp)
    addi sp, sp, 32
    ret
multiply:
    li   a0, 0                 # a0 = result = 0
    beqz a2, mult_done         # if multiplier == 0 → return 0
mult_loop:
    andi t0, a2, 1             # if (a2 & 1) add a1
    beqz t0, mult_skip
    add  a0, a0, a1
mult_skip:
    slli a1, a1, 1             # a1 <<= 1
    srli a2, a2, 1             # a2 >>= 1
    bnez a2, mult_loop         # loop while any bits remain
mult_done:
    ret
```
:::

### 5-Stage Pipeline
#### Overview of the Pipeline

The following figure shows the 5-Stage RISC-V Processor implemented in Ripes:

![image](https://hackmd.io/_uploads/B1EfGqWz-l.png)

**Pipeline Components:**
* **IF** – Instruction Fetch
* **ID** – Instruction Decode
* **EX** – Execute / ALU operations
* **MEM** – Memory Access
* **WB** – Write Back

Supporting elements include: **PC, Instruction Memory, Register File, ALU, Data Memory**, and several **MUXes** that handle forwarding and control signals.

#### Stage 1 — IF (Instruction Fetch)
**Function** : Fetch the instruction from Instruction Memory using the current PC.

![image](https://hackmd.io/_uploads/SJFnI5bGZe.png)

At this stage, the processor fetches the instruction from memory.
* The PC holds address `0x00000004` and sends it to the Instruction Memory.
* The memory outputs instruction `0x10001113`, decoded as `addi x2, x0, 256`.
* The IF/ID pipeline register stores this instruction for the next (ID) stage.
* The PC adder calculates the next address (`0x00000008`) for the next fetch.

#### Stage 2 — ID (Instruction Decode)
**Function** : Decode the instruction, read operands (`rs1, rs2`) from the Register File, and generate control signals.

![image](https://hackmd.io/_uploads/Hyt2O5-f-l.png)

At this stage, the CPU decodes the instruction `lw x3, 0(x2)`.
* The **Decode unit** extracts `opcode=0x03`, `rs1=x2`, `rd=x3`, and `imm=0`.
* The **Register File** reads the base register `x2`, whose value is `0x7ffffff0`.
* The **Immediate Generator** outputs `0x00000000` as the offset.
* All values and control signals are stored in the **ID/EX pipeline register** for the next (EX) stage.

#### Stage 3 — EX (Execute)
**Function** : Perform ALU operations and evaluate branch conditions.

![image](https://hackmd.io/_uploads/Hk3bJjZGWg.png)

In this stage, the CPU executes the instruction `add x5, x1, x3`.
* The **ALU** receives two operands: `x1 = 0x00000004` and `x3 = 0x00000004`.
* The operation performed is addition (`0x00000004 + 0x00000004`), producing the result `0x00000008`.
* The **Result** (`Res`) output from the ALU is sent forward to the EX/MEM pipeline register.
* The **Branch unit** is inactive (`Branch taken = 0`), meaning this is a normal arithmetic instruction with no control flow change.

#### Stage 4 — MEM (Memory Access)
**Function** : Perform data memory read or write.

![image](https://hackmd.io/_uploads/H1aa39-Gbl.png)

In this stage, the instruction `add x5, x1, x3` reaches the **MEM (Memory Access)** phase.
* Since this is an arithmetic instruction (not `lw` or `sw`), **no memory read or write** actually occurs.
* The **Data memory** block shows `WrEn = 0`, meaning write is disabled.
* The result value from the previous EX stage (`0x00000008`) is simply passed through the **EX/MEM → MEM/WB** pipeline register.
* The next stage (WB) will use this value to update register `x5`.

#### Stage 5 — WB (Write Back)
**Function** : Write ALU or memory results back to the Register File.

![image](https://hackmd.io/_uploads/By_eCcWGZg.png)

In this final stage, the instruction `add x5, x1, x3` performs the **write-back operation**.
* The result from the MEM/WB register (`0x00000008`) is selected by the multiplexer and sent to the **Register File**.
* The **destination register** is `x5` (`rd = 0x05`), and the **write enable signal** is active (green).
* The value `0x00000008` is now written into register `x5`, completing the instruction execution.
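
The EX, MEM, and WB walkthrough for `add x5, x1, x3` can be mirrored in a few lines of C. This is a minimal sketch and not how Ripes is implemented: the `pipe_reg_t` struct and the stage function names are hypothetical, `0x003082B3` is the standard RV32I encoding of `add x5, x1, x3`, and the register values `x1 = x3 = 4` are the ones quoted in the EX stage above. Forwarding, hazards, and all other instruction types are left out.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical pipeline-register contents carried from EX to WB. */
typedef struct {
    uint32_t alu_result; /* value latched in EX/MEM, then MEM/WB */
    uint32_t rd;         /* destination register index */
    int mem_write;       /* data memory WrEn */
    int reg_write;       /* register file write enable */
} pipe_reg_t;

/* ID + EX for an R-type ADD: decode the fields, read operands, add. */
static pipe_reg_t stage_id_ex_add(uint32_t insn, const uint32_t regs[32])
{
    uint32_t rd = (insn >> 7) & 0x1F;
    uint32_t rs1 = (insn >> 15) & 0x1F;
    uint32_t rs2 = (insn >> 20) & 0x1F;
    pipe_reg_t out = {regs[rs1] + regs[rs2], rd, 0, 1};
    return out;
}

/* MEM: an arithmetic instruction does not touch data memory, so the ALU
 * result is passed through unchanged (WrEn stays 0). */
static pipe_reg_t stage_mem(pipe_reg_t ex_mem)
{
    assert(!ex_mem.mem_write);
    return ex_mem;
}

/* WB: write the result to rd when write enable is set (x0 stays zero). */
static void stage_wb(pipe_reg_t mem_wb, uint32_t regs[32])
{
    if (mem_wb.reg_write && mem_wb.rd != 0)
        regs[mem_wb.rd] = mem_wb.alu_result;
}

/* Run the example from the text and return the final value of x5. */
static uint32_t run_add_example(void)
{
    uint32_t regs[32] = {0};
    regs[1] = 0x00000004;
    regs[3] = 0x00000004;
    pipe_reg_t ex_mem = stage_id_ex_add(0x003082B3u, regs);
    pipe_reg_t mem_wb = stage_mem(ex_mem);
    stage_wb(mem_wb, regs);
    return regs[5];
}
```

`run_add_example()` returns `0x00000008`, matching the value written to `x5` in the WB stage. The stages run one after another here, whereas the real pipeline overlaps them across five different instructions in the same cycle.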