# Assignment1: RISC-V Assembly and Instruction Pipeline > contributed by < [Shaoen-Lin](https://github.com/Shaoen-Lin) > ## Problem `B` in Quiz1 ### C code ```c #include <stdbool.h> #include <stdint.h> #include <stdio.h> #include <stdlib.h> typedef uint8_t uf8; static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; if (y) { n -= c; x = y; } c >>= 1; } while (c); return n - x; } /* Decode uf8 to uint32_t */ uint32_t uf8_decode(uf8 fl) { uint32_t mantissa = fl & 0x0f; uint8_t exponent = fl >> 4; uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; return (mantissa << exponent) + offset; } /* Encode uint32_t to uf8 */ uf8 uf8_encode(uint32_t value) { /* Use CLZ for fast exponent calculation */ if (value < 16) return value; /* Find appropriate exponent using CLZ hint */ int lz = clz(value); int msb = 31 - lz; /* Start from a good initial guess */ uint8_t exponent = 0; uint32_t overflow = 0; if (msb >= 5) { /* Estimate exponent - the formula is empirical */ exponent = msb - 4; if (exponent > 15) exponent = 15; /* Calculate overflow for estimated exponent */ for (uint8_t e = 0; e < exponent; e++) overflow = (overflow << 1) + 16; /* Adjust if estimate was off */ while (exponent > 0 && value < overflow) { overflow = (overflow - 16) >> 1; exponent--; } } /* Find exact exponent */ while (exponent < 15) { uint32_t next_overflow = (overflow << 1) + 16; if (value < next_overflow) break; overflow = next_overflow; exponent++; } uint8_t mantissa = (value - overflow) >> exponent; return (exponent << 4) | mantissa; } /* Test encode/decode round-trip */ static bool test(void) { int32_t previous_value = -1; bool passed = true; for (int i = 0; i < 256; i++) { uint8_t fl = i; int32_t value = uf8_decode(fl); uint8_t fl2 = uf8_encode(value); if (fl != fl2) { printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2); passed = false; } if (value <= previous_value) { printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value); 
passed = false; } previous_value = value; } return passed; } int main(void) { if (test()) { printf("All tests passed.\n"); return 0; } return 1; } ``` ### RV32I Assembly code The following RV32I code includes several test data cases and uses **automated testing** for verification. ```riscv .data nl: .string "\n" msg0: .string ": produces value " msg1: .string " but encodes back to " msg2: .string ": value " msg3: .string " <= previous_value " msg4: .string "All tests passed." .text .global main main: jal ra, test beqz a0, return_1 la a0, msg4 li a7, 4 ecall la a0, nl li a7, 4 ecall li a0, 0 li a7, 10 ecall return_1: li a0, 1 li a7, 10 ecall # ============================================================ # clz: Count leading zeros (binary search) # Input : a0 (unsigned int) # Output: a0 = leading zero count # Halves the shift amount on every pass, exactly like the C loop # Clobbers: s0, s1, t0 # ============================================================ clz: li s0, 32 li s1, 16 clz_while_loop: srl t0, a0, s1 beqz t0, clz_halve sub s0, s0, s1 add a0, t0, zero clz_halve: srli s1, s1, 1 bnez s1, clz_while_loop sub a0, s0, a0 ret # ============================================================ # uf8_decode: Decode uf8 -> uint32_t # ============================================================ uf8_decode: andi s0, a0, 0x0f srli s1, a0, 4 li t0, 15 sub t0, t0, s1 li s2, 0x7FFF srl s2, s2, t0 slli s2, s2, 4 sll a0, s0, s1 add a0, a0, s2 ret # ============================================================ # uf8_encode: Encode uint32_t -> uf8 # ============================================================ uf8_encode: li t0, 16 blt a0, t0, return_a0 addi sp, sp, -8 sw ra, 0(sp) sw a0, 4(sp) jal ra, clz add s4, a0, zero lw ra, 0(sp) lw a0, 4(sp) addi sp, sp, 8 li t0, 31 sub s5, t0, s4 li s6, 0 li s7, 0 li t0, 5 bge s5, t0, encode_if_msb_bge_5 msb_less_5: # If msb < 5, find exponent loop li t0, 15 check_while_loop2_condition: blt s6, t0, encode_while_loop2 encode_return: sub s0, a0, s7 srl s0, s0, s6 slli t0, s6, 4 or a0, t0, s0 ret
encode_if_msb_bge_5: # If msb >= 5, estimate exponent addi s6, s5, -4 li t0, 15 bgt s6, t0, set_expoent_15 back_encode_if_msb_bge_5: li t0, 0 encode_for_loop: # overflow = (overflow << 1) + 16 bge t0, s6, out_of_encode_loop slli t1, s7, 1 addi s7, t1, 16 addi t0, t0, 1 j encode_for_loop out_of_encode_loop: bgt s6, zero, encode_while_loop1 j msb_less_5 encode_while_loop1: # Adjust exponent if overflow too large bge a0, s7, msb_less_5 addi t0, s7, -16 srli s7, t0, 1 addi s6, s6, -1 j out_of_encode_loop set_expoent_15: addi s6, zero, 15 j back_encode_if_msb_bge_5 encode_while_loop2: # Find exact exponent slli s8, s7, 1 addi s8, s8, 16 blt a0, s8, encode_return add s7, s8, x0 addi s6, s6, 1 j check_while_loop2_condition return_a0: ret # ============================================================ # test: Run encode/decode test loop # ============================================================ test: li s0, -1 li s1, 1 addi t2, zero, 0 li t3, 256 test_for_loop: bge t2, t3, out_test_for_loop addi t4, t2, 0 # call uf8_decode(fl) addi sp, sp, -12 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) add a0, t4, x0 jal ra, uf8_decode addi t5, a0, 0 lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) addi sp, sp, 12 # call uf8_encode(value) addi sp, sp, -12 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) jal ra, uf8_encode addi t6, a0, 0 lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) addi sp, sp, 12 bne t4, t6, test_if_1 # if (fl != fl2) out_test_if_1: ble t5, s0, test_if_2 # if (value <= previous_value) out_test_if_2: add s0, t5, x0 addi t2, t2, 1 j test_for_loop out_test_for_loop: add a0, s1, zero ret # Print mismatch: fl != fl2 test_if_1: mv a0, t4 li a7, 34 ecall la a0, msg0 li a7, 4 ecall mv a0, t5 li a7, 1 ecall la a0, msg1 li a7, 4 ecall mv a0, t6 li a7, 34 ecall la a0, nl li a7, 4 ecall li s1, 0 j out_test_if_1 # Print non-monotonic: value <= previous_value test_if_2: mv a0, t4 li a7, 34 ecall la a0, msg2 li a7, 4 ecall mv a0, t5 li a7, 1 ecall la a0, msg3 li a7, 4 ecall mv a0, s0 li a7, 34 ecall la a0, nl li 
a7, 4 ecall li s1, 0 j out_test_if_2 ``` ## Problem `C` in Quiz1 Both C versions (with and without bf16_sqrt) use the multiplication operator (*), so their RV32I assembly translations must also perform multiplication. However, the **RV32I** base instruction set **does not include any hardware multiplication instruction** (`mul` belongs to the M extension), so this project uses the **Egyptian Multiplication algorithm** to emulate multiplication in software. ### Egyptian Multiplication algorithm Egyptian Multiplication is an algorithm based on doubling and halving. Its core idea is to repeatedly shift the multiplicand left (doubling it) and shift the multiplier right (halving it), adding the current multiplicand to the result whenever the least significant bit (LSB) of the multiplier is 1. The loop stops when the multiplier reaches 0, so the product is obtained using only shifts and conditional additions. With multiplier $n$ and multiplicand $m$, the algorithm follows this recurrence: $$ nm = \begin{cases} 0 & \text{if } n = 0, \\ \frac{n}{2} \cdot 2m & \text{if } n > 0 \text{ is even}, \\ \frac{n-1}{2} \cdot 2m + m & \text{if } n \text{ is odd}.
\end{cases} $$ Following is the RISC-V assembly code of Egyptian Multiplication: ```riscv # ======================================================= # multiply8(a0, a1): Egyptian Multiplication # ======================================================= # Parameters: # a0 = multiplicand (8-bit) # a1 = multiplier (8-bit) # # Return: # a0 = 16-bit result # ======================================================= multiply8: mv s10, a0 mv s9, a1 li a0, 0 mul_loop: beqz s9, mul_done andi s8, s9, 1 beqz s8, skip_add add a0, a0, s10 skip_add: slli s10, s10, 1 srli s9, s9, 1 j mul_loop mul_done: ret ``` ### C code without `bf16_sqrt` ```c #include <stdbool.h> #include <stdint.h> #include <string.h> typedef struct { uint16_t bits; } bf16_t; #define BF16_SIGN_MASK 0x8000U #define BF16_EXP_MASK 0x7F80U #define BF16_MANT_MASK 0x007FU #define BF16_EXP_BIAS 127 #define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) #define BF16_ZERO() ((bf16_t) {.bits = 0x0000}) static inline bool bf16_isnan(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK); } static inline bool bf16_isinf(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK); } static inline bool bf16_iszero(bf16_t a) { return !(a.bits & 0x7FFF); } static inline bf16_t f32_to_bf16(float val) { uint32_t f32bits; memcpy(&f32bits, &val, sizeof(float)); if (((f32bits >> 23) & 0xFF) == 0xFF) return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; f32bits += ((f32bits >> 16) & 1) + 0x7FFF; return (bf16_t) {.bits = f32bits >> 16}; } static inline float bf16_to_f32(bf16_t val) { uint32_t f32bits = ((uint32_t) val.bits) << 16; float result; memcpy(&result, &f32bits, sizeof(float)); return result; } static inline bf16_t bf16_add(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if 
(exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; 
int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 
0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F)}; } ``` ### RV32I Assembly code without `bf16_sqrt` The following RV32I code includes several test data cases and uses **automated testing** for verification. ```riscv .data BF16_SIGN_MASK: .half 0x8000 BF16_EXP_MASK: .half 0x7F80 BF16_MANT_MASK: .half 0x007F BF16_EXP_BIAS: .half 127 BF16_NAN: .half 0x7FC0 BF16_ZERO: .half 0x0000 nl: .string "\n" msg_case: .string "Test case " msg_input: .string "Input: " msg_output: .string "Output: " msg_expect: .string "Expect: " msg_ok: .string "✅ Correct\n" msg_wrong: .string "❌ Wrong\n" # ====== CONVERSION TEST String ====== msg_conv1: .string "\n=== BF16 -> F32 TESTS ===\n" msg_conv2: .string "\n=== F32 -> BF16 TESTS ===\n" msg_add: .string "\n=== BF16 ADD TESTS ===\n" msg_sub: .string "\n=== BF16 SUB TESTS ===\n" msg_mul: .string "\n=== BF16 MUL TESTS ===\n" msg_div: .string "\n=== BF16 DIV TESTS ===\n" # ====== ADD String (msg6 operand 0x3800 = 2^-15, too small to change 1.0) ====== msg1: .string "1.0 + 2.0 = " msg2: .string "2.0 + (-2.0) = " msg3: .string "inf + 1.0 = " msg4: .string "inf + -inf = " msg5: .string "NaN + 1.0 = " msg6: .string "1.0 + 2^-15 = " # ====== SUB String ====== msgs1: .string "2.0 - 1.0 = " msgs2: .string "5.0 - 2.0 = " msgs3: .string "1.0 - 2.0 = " msgs4: .string "(-2.0) - 3.0 = " msgs5: .string "Inf - Inf = " msgs6: .string "NaN - 1.0 = " msgs7: .string "0.0 - 1.0 = " msgs8: .string "1.0 - 0.0 = " # ====== MUL String ====== msgm1: .string "1.0 * 2.0 = " msgm2: .string "0.5 * 0.5 = " msgm3: .string "-1.0 * 3.0 = " msgm4: .string "Inf * 2.0 = " msgm5: .string "0 * 123.0 = " msgm6: .string "Inf * 0 = " msgm7: .string "NaN * 5.0 = " msgm8: .string "subnormal * 2.0 = " #
====== DIV String ====== msgd1: .string "1.0 / 2.0 = " msgd2: .string "2.0 / 1.0 = " msgd3: .string "1.0 / 0.0 = " msgd4: .string "0.0 / 1.0 = " msgd5: .string "Inf / Inf = " msgd6: .string "NaN / 1.0 = " msgd7: .string "(-2.0) / 1.0 = " # ======= CONVERSION expected output ======= conv_expect_b2f: .word 0x3F800000, 0xC0000000 conv_expect_f2b: .half 0x4060, 0xC194 # ======= ADD expected output ======= add_expect: .half 0x4040, 0x0000, 0x7F80, 0x7FC0, 0x7FC1, 0x3F80 # ======= SUB expected output ======= sub_expect: .half 0x3F80, 0x4040, 0xBF80, 0xC0A0, 0x7FC0, 0x7FC0, 0xBF80, 0x3F80 # ======= MUL expected output ======= mul_expect: .half 0x4000, 0x3E80, 0xC040, 0x7F80, 0x0000, 0x7FC0, 0x7FC1, 0x0000 # ======= DIV expected output ======= div_expect: .half 0x3F00, 0x4000, 0x7F80, 0x0000, 0x7FC0, 0x7FC0, 0xC000 .text .global main main: li s0, 255 lui t0, 0x8 addi s11, t0, -0x80 # s11 = 0x7F80 (Inf mask) # ------------------------------ # BF16 -> F32 TESTS # ------------------------------ la a0, msg_conv1 li a7, 4 ecall la a0, msg_input li a7, 4 ecall li a0, 0x3F80 li a7, 34 ecall li a0, 0x3F80 jal ra, bf16_to_f32 mv t0, a0 la a1, conv_expect_b2f li a2, 0 li a3, 1 jal ra, compare_result la a0, msg_input li a7, 4 ecall li a0, 0xC000 li a7, 34 ecall li a0, 0xC000 jal ra, bf16_to_f32 mv t0, a0 la a1, conv_expect_b2f li a2, 1 li a3, 1 jal ra, compare_result # ------------------------------ # F32 -> BF16 TESTS # ------------------------------ la a0, msg_conv2 li a7, 4 ecall la a0, msg_input li a7, 4 ecall li a0, 0x40600000 li a7, 34 ecall li a0, 0x40600000 jal ra, f32_to_bf16 mv t0, a0 la a1, conv_expect_f2b li a2, 0 li a3, 0 jal ra, compare_result la a0, msg_input li a7, 4 ecall li a0, 0xC19447AE li a7, 34 ecall li a0, 0xC19447AE jal ra, f32_to_bf16 mv t0, a0 la a1, conv_expect_f2b li a2, 1 li a3, 0 jal ra, compare_result # ------------------------------ # ADD TEST # ------------------------------ la a0, msg_add li a7, 4 ecall la a0, msg1 li a7, 4 ecall li a0, 0x3F80 li a1, 
0x4000 jal ra, bf16_add la a1, add_expect li a2, 0 li a3, 0 jal ra, compare_result la a0, msg2 li a7, 4 ecall li a0, 0x4000 li a1, 0xC000 jal ra, bf16_add la a1, add_expect li a2, 1 jal ra, compare_result la a0, msg3 li a7, 4 ecall li a0, 0x7F80 li a1, 0x3F80 jal ra, bf16_add la a1, add_expect li a2, 2 jal ra, compare_result la a0, msg4 li a7, 4 ecall li a0, 0x7F80 li a1, 0xFF80 jal ra, bf16_add la a1, add_expect li a2, 3 jal ra, compare_result la a0, msg5 li a7, 4 ecall li a0, 0x7FC1 li a1, 0x3F80 jal ra, bf16_add la a1, add_expect li a2, 4 jal ra, compare_result la a0, msg6 li a7, 4 ecall li a0, 0x3F80 li a1, 0x3800 jal ra, bf16_add la a1, add_expect li a2, 5 jal ra, compare_result # ------------------------------ # SUB TEST # ------------------------------ la a0, msg_sub li a7, 4 ecall la a0, msgs1 li a7, 4 ecall li a0, 0x4000 li a1, 0x3F80 jal ra, bf16_sub la a1, sub_expect li a2, 0 jal ra, compare_result la a0, msgs2 li a7, 4 ecall li a0, 0x40A0 li a1, 0x4000 jal ra, bf16_sub la a1, sub_expect li a2, 1 jal ra, compare_result la a0, msgs3 li a7, 4 ecall li a0, 0x3F80 li a1, 0x4000 jal ra, bf16_sub la a1, sub_expect li a2, 2 jal ra, compare_result la a0, msgs4 li a7, 4 ecall li a0, 0xC000 li a1, 0x4040 jal ra, bf16_sub la a1, sub_expect li a2, 3 jal ra, compare_result la a0, msgs5 li a7, 4 ecall li a0, 0x7F80 li a1, 0x7F80 jal ra, bf16_sub la a1, sub_expect li a2, 4 jal ra, compare_result la a0, msgs6 li a7, 4 ecall li a0, 0x7FC0 li a1, 0x3F80 jal ra, bf16_sub la a1, sub_expect li a2, 5 jal ra, compare_result la a0, msgs7 li a7, 4 ecall li a0, 0x0000 li a1, 0x3F80 jal ra, bf16_sub la a1, sub_expect li a2, 6 jal ra, compare_result la a0, msgs8 li a7, 4 ecall li a0, 0x3F80 li a1, 0x0000 jal ra, bf16_sub la a1, sub_expect li a2, 7 jal ra, compare_result # ------------------------------ # MUL TEST # ------------------------------ la a0, msg_mul li a7, 4 ecall la a0, msgm1 li a7, 4 ecall li a0, 0x3F80 li a1, 0x4000 jal ra, bf16_mul la a1, mul_expect li a2, 0
jal ra, compare_result la a0, msgm2 li a7, 4 ecall li a0, 0x3F00 li a1, 0x3F00 jal ra, bf16_mul la a1, mul_expect li a2, 1 jal ra, compare_result la a0, msgm3 li a7, 4 ecall li a0, 0xBF80 li a1, 0x4040 jal ra, bf16_mul la a1, mul_expect li a2, 2 jal ra, compare_result la a0, msgm4 li a7, 4 ecall li a0, 0x7F80 li a1, 0x4000 jal ra, bf16_mul la a1, mul_expect li a2, 3 jal ra, compare_result la a0, msgm5 li a7, 4 ecall li a0, 0x0000 li a1, 0x42F6 jal ra, bf16_mul la a1, mul_expect li a2, 4 jal ra, compare_result la a0, msgm6 li a7, 4 ecall li a0, 0x7F80 li a1, 0x0000 jal ra, bf16_mul la a1, mul_expect li a2, 5 jal ra, compare_result la a0, msgm7 li a7, 4 ecall li a0, 0x7FC1 li a1, 0x40A0 jal ra, bf16_mul la a1, mul_expect li a2, 6 jal ra, compare_result la a0, msgm8 li a7, 4 ecall li a0, 0x0001 li a1, 0x4000 jal ra, bf16_mul la a1, mul_expect li a2, 7 jal ra, compare_result # ------------------------------ # DIV TEST # ------------------------------ la a0, msg_div li a7, 4 ecall la a0, msgd1 li a7, 4 ecall li a0, 0x3F80 li a1, 0x4000 jal ra, bf16_div la a1, div_expect li a2, 0 jal ra, compare_result la a0, msgd2 li a7, 4 ecall li a0, 0x4000 li a1, 0x3F80 jal ra, bf16_div la a1, div_expect li a2, 1 jal ra, compare_result la a0, msgd3 li a7, 4 ecall li a0, 0x3F80 li a1, 0x0000 jal ra, bf16_div la a1, div_expect li a2, 2 jal ra, compare_result la a0, msgd4 li a7, 4 ecall li a0, 0x0000 li a1, 0x3F80 jal ra, bf16_div la a1, div_expect li a2, 3 jal ra, compare_result la a0, msgd5 li a7, 4 ecall li a0, 0x7F80 li a1, 0x7F80 jal ra, bf16_div la a1, div_expect li a2, 4 jal ra, compare_result la a0, msgd6 li a7, 4 ecall li a0, 0x7FC0 li a1, 0x3F80 jal ra, bf16_div la a1, div_expect li a2, 5 jal ra, compare_result la a0, msgd7 li a7, 4 ecall li a0, 0xC000 li a1, 0x3F80 jal ra, bf16_div la a1, div_expect li a2, 6 jal ra, compare_result li a7, 10 ecall # ======================================================= # compare_result(a0, expect_addr, idx, is32bit) # 
======================================================= # a0 = actual result (16-bit or 32-bit) # a1 = address of expected value table # a2 = test case index (0-based) # a3 = is32bit flag (1 = 32-bit, 0 = 16-bit; 16-bit results are compared on their low 16 bits) # ======================================================= compare_result: addi sp, sp, -20 sw t0, 0(sp) sw t1, 4(sp) sw t2, 8(sp) sw t3, 12(sp) sw t4, 16(sp) mv t0, a0 beqz a3, half_case slli t1, a2, 2 add t2, a1, t1 lw t3, 0(t2) j load_done half_case: slli t1, a2, 1 add t2, a1, t1 lhu t3, 0(t2) lui t4, 16 addi t4, t4, -1 and t0, t0, t4 load_done: la a0, nl li a7, 4 ecall # ---- Output (hex) ---- la a0, msg_output li a7, 4 ecall mv a0, t0 li a7, 34 ecall la a0, nl li a7, 4 ecall # ---- Expect (hex) ---- la a0, msg_expect li a7, 4 ecall mv a0, t3 li a7, 34 ecall la a0, nl li a7, 4 ecall # ---- Compare ---- beq t0, t3, print_ok la a0, msg_wrong li a7, 4 ecall j done print_ok: la a0, msg_ok li a7, 4 ecall done: lw t0, 0(sp) lw t1, 4(sp) lw t2, 8(sp) lw t3, 12(sp) lw t4, 16(sp) addi sp, sp, 20 ret # ------------------------------ # bf16_isnan(a0): check if NaN # ------------------------------ bf16_isnan: la t0, BF16_EXP_MASK lh t1, 0(t0) and t2, a0, t1 bne t2, t1, not_nan la t0, BF16_MANT_MASK lh t1, 0(t0) and t2, a0, t1 beqz t2, not_nan li a0, 1 ret not_nan: li a0, 0 ret # ------------------------------ # bf16_isinf(a0): check if Inf # ------------------------------ bf16_isinf: la t0, BF16_EXP_MASK lh t1, 0(t0) and t2, a0, t1 bne t2, t1, not_inf la t0, BF16_MANT_MASK lh t1, 0(t0) and t2, a0, t1 bnez t2, not_inf li a0, 1 ret not_inf: li a0, 0 ret # ------------------------------ # bf16_iszero(a0): a0 = 1 if a0 is +0 or -0 (bits & 0x7FFF == 0), else 0 # ------------------------------ bf16_iszero: lui t0, 8 addi t0, t0, -1 and a0, a0, t0 seqz a0, a0 ret # ------------------------------ # f32_to_bf16(a0): convert f32 → bf16 # ------------------------------ f32_to_bf16: srli t0, a0, 23 li t1, 255 and t0, t0, t1 beq t0, t1, is_nan_inf srli t0, a0, 16 andi t0, t0,
1 lui t1, 8 addi t1, t1, -1 add t0, t0, t1 add a0, a0, t0 srli a0, a0, 16 ret is_nan_inf: srli t0, a0, 16 lui t1, 16 addi t1, t1, -1 and a0, t0, t1 ret # ------------------------------ # bf16_to_f32(a0): extend bf16 → f32 # ------------------------------ bf16_to_f32: slli a0, a0, 16 ret # ------------------------------ # bf16_add(a0, a1): BF16 addition # ------------------------------ bf16_add: # Extract sign/exponent/mantissa srli t0, a0, 15 andi t0, t0, 1 srli t1, a1, 15 andi t1, t1, 1 srli t2, a0, 7 and t2, t2, s0 srli t3, a1, 7 and t3, t3, s0 andi t4, a0, 127 andi t5, a1, 127 # Handle Inf/NaN and zeros beq t2, s0, a_inf_nan beq t3, s0, return_b beqz t2, check_mantissa_a_zero jal x0, next check_mantissa_a_zero: beqz t4, return_b next: beqz t3, check_mantissa_b_zero jal x0, next0 check_mantissa_b_zero: beqz t5, return_a next0: beqz t2, skip_a_implicit_1 ori t4, t4, 0x80 skip_a_implicit_1: beqz t3, skip_b_implicit_1 ori t5, t5, 0x80 skip_b_implicit_1: jal x0, next1 # --- handle special cases --- a_inf_nan: bnez t4, return_a beq t3, s0, a_and_b_inf_nan ret a_and_b_inf_nan: bnez t5, return_b beq t0, t1, return_b jal x0, return_nan # --- align exponents --- next1: sub s2, t2, t3 bgt s2, zero, greater_than_zero blt s2, zero, less_than_zero add s3, zero, t2 jal x0, next2 greater_than_zero: add s3, zero, t2 li t6, 8 bgt s2, t6, return_a srl t5, t5, s2 jal x0, next2 less_than_zero: add s3, zero, t3 li t6, -8 blt s2, t6, return_b neg t6, s2 srl t4, t4, t6 # --- perform mantissa add/sub --- next2: beq t0, t1, signa_eq_signb bge t4, t5, mant_a_greater_mant_b add s4, t1, zero sub s5, t5, t4 jal x0, next4 mant_a_greater_mant_b: add s4, t0, zero sub s5, t4, t5 # --- normalize result --- next4: bnez s5, normalize_loop la t6, BF16_ZERO lh a0, 0(t6) ret normalize_loop: andi t6, s5, 0x80 bnez t6, final_return addi s3, s3, -1 blez s3, underflow_zero slli s5, s5, 1 j normalize_loop underflow_zero: la t6, BF16_ZERO lh a0, 0(t6) ret # --- same sign addition --- signa_eq_signb: add s4, 
t0, zero add s5, t4, t5 andi t6, s5, 0x100 beqz t6, final_return srli s5, s5, 1 addi s3, s3, 1 blt s3, s0, final_return # exponent overflow (>= 0xFF): return signed Inf slli a0, s4, 15 lui t6, 0x8 addi t6, t6, -0x80 or a0, a0, t6 ret # --- pack result --- final_return: slli s4, s4, 15 and s3, s3, s0 slli s3, s3, 7 andi s5, s5, 0x7F or a0, s3, s4 or a0, a0, s5 ret # ------------------------------ # bf16_sub(a0, a1): subtraction # ------------------------------ bf16_sub: lui t6, 0x8 xor a1, a1, t6 # flip sign bit addi sp, sp, -4 sw ra, 0(sp) jal ra, bf16_add lw ra, 0(sp) addi sp, sp, 4 ret # ------------------------------ # bf16_mul(a0, a1): BF16 multiplication # ------------------------------ bf16_mul: # Extract sign / exponent / mantissa srli t0, a0, 15 andi t0, t0, 1 srli t1, a1, 15 andi t1, t1, 1 srli t2, a0, 7 and t2, t2, s0 srli t3, a1, 7 and t3, t3, s0 andi t4, a0, 127 andi t5, a1, 127 xor s1, t0, t1 # Check for NaN / Inf cases bne t2, s0, check_b_exp bnez t4, return_a beqz t3, check_b_mant back1: slli a0, s1, 15 or a0, a0, s11 ret check_b_mant: bnez t5, back1 j return_nan # --- check exponent B special case (b NaN -> return b; Inf * 0 -> NaN; else signed Inf) --- check_b_exp: bne t3, s0, next6 bnez t5, return_b beqz t2, check_a_mant back2: slli a0, s1, 15 or a0, a0, s11 ret check_a_mant: bnez t4, back2 jal x0, return_nan # --- handle zero operands --- next6: beqz t2, check_a_is_zero check_b: beqz t3, check_b_is_zero a_b_no_zero: j next7 check_a_is_zero: beqz t4, a_or_b_is_zero j check_b check_b_is_zero: beqz t5, a_or_b_is_zero j a_b_no_zero a_or_b_is_zero: slli a0, s1, 15 ret # --- normalize subnormal exponents --- next7: add s2, zero, zero beqz t2, exp_a_zero ori t4, t4, 0x80 j check_b_exp_zero exp_a_zero: addi t2, zero, 1 andi t6, t4, 0x80 bnez t6, check_b_exp_zero slli t4, t4, 1 addi s2, s2, -1 j exp_a_zero # --- same for operand B (exp_b == 0 -> normalize subnormal mantissa) --- check_b_exp_zero: beqz t3, exp_b_zero ori t5, t5, 0x80 j next8 exp_b_zero: addi t3, zero, 1 andi t6, t5, 0x80 bnez t6, next8 slli t5, t5, 1 addi s2, s2, -1 j exp_b_zero # --- perform mantissa multiplication (Egyptian method) --- next8: addi sp, sp, -12 sw a0, 0(sp) sw a1, 4(sp) sw ra, 8(sp) add
a0, t4, zero add a1, t5, zero jal ra, multiply8 add s3, a0, zero lw a0, 0(sp) lw a1, 4(sp) lw ra, 8(sp) addi sp, sp, 12 # Calculate result exponent add s4, t2, t3 la t6, BF16_EXP_BIAS lh t6, 0(t6) sub s4, s4, t6 add s4, s4, s2 # Normalize mantissa (product bit 15 set -> shift by 8 and bump exponent, else shift by 7) lui t6, 0x8 and t6, s3, t6 bnez t6, mant_carry srli t6, s3, 7 andi s3, t6, 0x7F j ret_exp mant_carry: srli t6, s3, 8 andi s3, t6, 0x7F addi s4, s4, 1 # --- check overflow/underflow --- ret_exp: bge s4, s0, over_ff ble s4, zero, under_zero mul_final_return: slli s1, s1, 15 and s4, s4, s0 slli s4, s4, 7 andi s3, s3, 0x7F or a0, s1, s4 or a0, a0, s3 ret # overflow → Inf over_ff: slli a0, s1, 15 or a0, a0, s11 ret # underflow → 0 under_zero: li t6, -6 blt s4, t6, shift_sign_15 li t6, 1 sub t6, t6, s4 srl s3, s3, t6 li s4, 0 j mul_final_return shift_sign_15: slli a0, s1, 15 ret # ------------------------------ # bf16_div(a0, a1): BF16 division # ------------------------------ bf16_div: # Extract sign / exponent / mantissa srli t0, a0, 15 andi t0, t0, 1 srli t1, a1, 15 andi t1, t1, 1 srli t2, a0, 7 and t2, t2, s0 srli t3, a1, 7 and t3, t3, s0 andi t4, a0, 127 andi t5, a1, 127 xor s1, t0, t1 # --- check special cases --- beq t3, s0, div_b_inf_nan beqz t3, div_b_check_mant_0 b_is_not_zero_but_exp_0: beq t2, s0, div_a_inf_nan beqz t2, div_a_exp_0_check_mant_0 a_is_not_zero_but_exp_0: bnez t2, set_a_mant also_check_b: bnez t3, set_b_mant j set_div # --- handle b = Inf/NaN --- div_b_inf_nan: bnez t5, return_b beq t2, s0, b_check_a_mant div_b_inf_nan_return: slli a0, s1, 15 ret b_check_a_mant: beqz t4, return_nan j div_b_inf_nan_return # --- b = 0 case (0 / 0 -> NaN, any other a -> signed Inf) --- div_b_check_mant_0: beqz t5, div_a_check_0 j b_is_not_zero_but_exp_0 div_b_check_0_return: slli a0, s1, 15 or a0, a0, s11 ret div_a_check_0: beqz t2, div_a_check_mant_0 j div_b_check_0_return div_a_check_mant_0: beqz t4, return_nan j div_b_check_0_return # --- a = Inf/NaN (NaN a -> return a, Inf / finite -> signed Inf) --- div_a_inf_nan: bnez t4, return_a j div_b_check_0_return # --- a = 0 case ---
div_a_exp_0_check_mant_0: beqz t4, div_a_exp_0_check_0 j a_is_not_zero_but_exp_0 div_a_exp_0_check_0: slli a0, s1, 15 ret # --- set implicit 1 --- set_a_mant: ori t4, t4, 0x80 j also_check_b set_b_mant: ori t5, t5, 0x80 j set_div # ------------------------------ # division core (long division) # ------------------------------ set_div: slli s2, t4, 15 add s3, t5, zero li s4, 0 li t6, 0 li s5, 16 # --- for loop: binary long division --- for_loop: bge t6, s5, out_for_loop slli s4, s4, 1 addi s6, zero, 15 sub s6, s6, t6 sll s6, s3, s6 blt s2, s6, out_if sub s2, s2, s6 ori s4, s4, 1 out_if: addi t6, t6, 1 j for_loop out_for_loop: # --- compute exponent --- sub s5, t2, t3 la t6, BF16_EXP_BIAS lh t6, 0(t6) add s5, s5, t6 bnez t2, exp_a_isnot_zero addi s5, s5, -1 exp_a_isnot_zero: bnez t3, exp_b_isnot_zero addi s5, s5, 1 exp_b_isnot_zero: # --- normalize quotient --- lui t6, 0x8 and t6, s4, t6 beqz t6, check_while_condition srli s4, s4, 8 j next9 # --- normalization loop --- check_while_condition: lui t6, 0x8 and t6, s4, t6 beqz t6, check_result_exp else_shift_quotient: srli s4, s4, 8 j next9 check_result_exp: li s6, 1 bgt s5, s6, while_loop j else_shift_quotient while_loop: slli s4, s4, 1 addi s5, s5, -1 j check_while_condition # --- pack final result --- next9: andi s4, s4, 0x7F bge s5, s0, exp_greater_all_one ble s5, zero, exp_less_equal_zero slli s1, s1, 15 and s5, s5, s0 slli s5, s5, 7 andi s4, s4, 0x7F or a0, s1, s5 or a0, a0, s4 ret # --- overflow / underflow handling --- exp_greater_all_one: slli a0, s1, 15 or a0, a0, s11 ret exp_less_equal_zero: slli a0, s1, 15 ret # Common return labels return_zero: la t6, BF16_ZERO lh a0, 0(t6) ret return_nan: la t6, BF16_NAN lh a0, 0(t6) ret return_a: ret return_b: add a0, a1, zero ret # ======================================================= # multiply8(a0, a1): Egyptian Multiplication # ======================================================= # Parameters: # a0 = multiplicand (8-bit) # a1 = multiplier (8-bit) # # Return: # a0 
= 16-bit result # ======================================================= multiply8: mv s10, a0 mv s9, a1 li a0, 0 mul_loop: beqz s9, mul_done andi s8, s9, 1 beqz s8, skip_add add a0, a0, s10 skip_add: slli s10, s10, 1 srli s9, s9, 1 j mul_loop mul_done: ret ``` ### C code with `bf16_sqrt` only ```c static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t new_exp; /* Get full mantissa with implicit 1 */ uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */ /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */ if (e & 1) { m <<= 1; /* Double mantissa for odd exponent */ new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS; } else { new_exp = (e >> 1) + BF16_EXP_BIAS; } /* Now m is in range [128, 256) or [256, 512) if exponent was odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need 
to adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } } /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow */ if (new_exp >= 0xFF) return (bf16_t){.bits = 0x7F80}; /* +Inf */ if (new_exp <= 0) return BF16_ZERO(); return (bf16_t){.bits = ((new_exp & 0xFF) << 7) | new_mant}; } ``` ### RV32I Assembly Code with `bf16_sqrt` only The following RV32I code includes several test data cases and uses **automated testing** for verification. ```riscv .data # ======= Constants ======= BF16_SIGN_MASK: .half 0x8000 BF16_EXP_MASK: .half 0x7F80 BF16_MANT_MASK: .half 0x007F BF16_EXP_BIAS: .half 127 BF16_NAN: .half 0x7FC0 BF16_ZERO: .half 0x0000 # ======= Common Messages ======= nl: .string "\n" msg_case: .string "Test case " msg_input: .string "Input: " msg_output: .string "Output: " msg_expect: .string "Expect: " msg_ok: .string "✅ Correct\n" msg_wrong: .string "❌ Wrong\n" # ======= SQRT Test Labels ======= msg_sqrt: .string "\n=== BF16 SQRT TESTS ===\n" msg0: .string "sqrt(0.0) = " msg1: .string "sqrt(1.0) = " msg2: .string "sqrt(4.0) = " msg3: .string "sqrt(9.0) = " msg4: .string "sqrt(-1.0) = " msg5: .string "sqrt(+Inf) = " msg6: .string "sqrt(-Inf) = " msg7: .string "sqrt(0.25) = " msg8: .string "sqrt(16.0) = " msg9: .string "sqrt(2.0) = " # ======= Test Inputs ======= val0: .half 0x0000 # 0.0 val1: .half 0x3F80 # 1.0 val2: .half 0x4080 # 4.0 val3: .half 0x4110 # 9.0 val4: .half 0xBF80 # -1.0 val5: .half 0x7F80 # +Inf val6: .half 0xFF80 # -Inf val7: .half 0x3E80 # 0.25 val8: .half 0x4180 # 16.0 val9: .half 0x4000 # 2.0 # ======= Expected Outputs ======= sqrt_expect: .half 0x0000, 0x3F80, 0x4000, 0x4040, 0x7FC0, 0x7F80, 0x7FC0, 0x3F00, 0x4080, 0x3FB5 .text .global main main: li s0, 255 
lui t0, 0x8 addi s11, t0, -0x80 # s11 = 0x7F80 (Inf mask) # ==== Print Header ==== la a0, msg_sqrt li a7, 4 ecall # ======================================================= # Test 0: sqrt(0.0) # ======================================================= la a0, msg0 li a7, 4 ecall la a0, val0 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 0 mv a0, t5 jal ra, compare_result # ======================================================= # Test 1: sqrt(1.0) # ======================================================= la a0, msg1 li a7, 4 ecall la a0, val1 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 1 mv a0, t5 jal ra, compare_result # ======================================================= # Test 2: sqrt(4.0) # ======================================================= la a0, msg2 li a7, 4 ecall la a0, val2 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 2 mv a0, t5 jal ra, compare_result # ======================================================= # Test 3: sqrt(9.0) # ======================================================= la a0, msg3 li a7, 4 ecall la a0, val3 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 3 mv a0, t5 jal ra, compare_result # ======================================================= # Test 4: sqrt(-1.0) # ======================================================= la a0, msg4 li a7, 4 ecall la a0, val4 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 4 mv a0, t5 jal ra, compare_result # ======================================================= # Test 5: sqrt(+Inf) # ======================================================= la a0, msg5 li a7, 4 ecall la a0, val5 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 5 mv a0, t5 jal ra, compare_result # ======================================================= # Test 6: sqrt(-Inf) # ======================================================= la a0, msg6 li a7, 4 ecall la a0, val6 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la 
a1, sqrt_expect li a2, 6 mv a0, t5 jal ra, compare_result # ======================================================= # Test 7: sqrt(0.25) # ======================================================= la a0, msg7 li a7, 4 ecall la a0, val7 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 7 mv a0, t5 jal ra, compare_result # ======================================================= # Test 8: sqrt(16.0) # ======================================================= la a0, msg8 li a7, 4 ecall la a0, val8 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 8 mv a0, t5 jal ra, compare_result # ======================================================= # Test 9: sqrt(2.0) # ======================================================= la a0, msg9 li a7, 4 ecall la a0, val9 lh a0, 0(a0) jal ra, bf16_sqrt mv t5, a0 la a1, sqrt_expect li a2, 9 mv a0, t5 jal ra, compare_result # ======================================================= # End of Program # ======================================================= li a7, 10 ecall # ======================================================= # compare_result(a0, expect_addr, idx) # ======================================================= # compare_result(a0, expect_addr, idx) # a0 = actual result (16-bit) # a1 = address of the expected value table # a2 = test case index (0-based) # ======================================================= compare_result: addi sp, sp, -16 sw t0, 0(sp) sw t1, 4(sp) sw t2, 8(sp) sw t3, 12(sp) mv t0, a0 slli t1, a2, 1 add t2, a1, t1 lhu t3, 0(t2) li t4, 0xFFFF and t0, t0, t4 and t3, t3, t4 la a0, nl li a7, 4 ecall # ---- Output (hex) ---- la a0, msg_output li a7, 4 ecall mv a0, t0 li a7, 34 ecall la a0, nl li a7, 4 ecall # ---- Expect (hex) ---- la a0, msg_expect li a7, 4 ecall mv a0, t3 li a7, 34 ecall la a0, nl li a7, 4 ecall # ---- Compare ---- beq t0, t3, print_ok la a0, msg_wrong li a7, 4 ecall j done print_ok: la a0, msg_ok li a7, 4 ecall done: lw t0, 0(sp) lw t1, 4(sp) lw t2, 8(sp) lw t3, 
12(sp) addi sp, sp, 16 ret # ============================================================ # Function: bf16_sqrt # Input : a0 (BF16 value) # Output: a0 (sqrt result) # Requires: s0 = 255 (exponent mask) and s11 = 0x7F80 (+Inf), both set up in main # Clobbers: s1-s7, t0-t4, t6, plus s8-s10 through multiply8 # ============================================================ bf16_sqrt: srai s1, a0, 15 andi s1, s1, 1 srai s2, a0, 7 and s2, s2, s0 andi s3, a0, 0x7F beq s2, s0, Handle_special_cases beqz s2, sqrt_check_mant bnez s1, return_nan la s4, BF16_EXP_BIAS lh s4, 0(s4) sub t0, s2, s4 ori t2, s3, 0x80 andi t3, t0, 1 bnez t3, Adjust_for_odd_exponents srai t3, t0, 1 add t1, t3, s4 # Binary search: find integer sqrt(mantissa) low_high_result: li s5, 90 li s6, 256 li s7, 128 Binary_search_loop: bgt s5, s6, out_Binary_Search add t3, s5, s6 srli t3, t3, 1 addi sp, sp, -12 sw a0, 0(sp) sw a1, 4(sp) sw ra, 8(sp) mv a0, t3 mv a1, t3 jal ra, multiply8 mv t4, a0 lw a0, 0(sp) lw a1, 4(sp) lw ra, 8(sp) addi sp, sp, 12 srli t4, t4, 7 ble t4, t2, binary_search_if addi s6, t3, -1 j Binary_search_loop binary_search_if: add s7, t3, x0 addi s5, t3, 1 j Binary_search_loop # Post-processing after binary search out_Binary_Search: li t3, 256 bge s7, t3, result_greater_256 li t3, 128 blt s7, t3, result_less_128 Extract_7_bit_mantissa: andi t3, s7, 0x7F bge t1, s0, sqrt_overflow ble t1, zero, return_zero and a0, t1, s0 slli a0, a0, 7 or a0, a0, t3 ret # Special / edge case handlers Handle_special_cases: bnez s3, return_a bnez s1, return_nan j return_a # exp == 0: +/-0 -> +0, negative denormal -> NaN, positive denormal -> +0 (flush to zero), # as in the C reference; must not fall into Adjust_for_odd_exponents (t0/t2/s4 unset here) sqrt_check_mant: beqz s3, return_zero bnez s1, return_nan j return_zero Adjust_for_odd_exponents: slli t2, t2, 1 addi t3, t0, -1 srai t3, t3, 1 add t1, t3, s4 j low_high_result # Handle result normalization result_greater_256: srli s7, s7, 1 addi t1, t1, 1 j Extract_7_bit_mantissa result_less_128: li t3, 128 blt s7, t3, sqrt_check_new_exp j Extract_7_bit_mantissa sqrt_check_new_exp: li t3, 1 bgt t1, t3, sqrt_while_loop_2 j Extract_7_bit_mantissa sqrt_while_loop_2: slli s7, s7, 1 addi t1, t1, -1 j result_less_128 sqrt_overflow: add a0, zero, s11 ret # Return helper sections return_zero: la t6, 
BF16_ZERO lh a0, 0(t6) ret return_nan: la t6, BF16_NAN lh a0, 0(t6) ret return_a: ret return_b: add a0, a1, zero ret # ============================================================ # multiply8: Egyptian Multiplication # Input : a0=a, a1=b # Output: a0=a*b (16-bit result) # ============================================================ multiply8: mv s10, a0 mv s9, a1 li a0, 0 mul_loop: beqz s9, mul_done andi s8, s9, 1 beqz s8, skip_add add a0, a0, s10 skip_add: slli s10, s10, 1 srli s9, s9, 1 j mul_loop mul_done: ret ``` ## LeetCode 260. Single Number III ### Description Given an integer array nums, in which exactly two elements appear only once and all the other elements appear exactly twice. Find the two elements that appear only once. You can return the answer **in any order**. You must write an algorithm that runs in linear runtime complexity and uses only constant extra space. Example 1: > **Input:** nums = [1,2,1,3,2,5] > **Output:** [3,5] > **Explanation:** [5, 3] is also a valid answer. Example 2: > **Input:** nums = [-1,0] > **Output:** [-1,0] Example 3: > **Input:** nums = [0,1] > **Output:** [1,0] Constraints: * $2 \le \text{nums.length} \le 3 \times 10^4$ * $-2^{31} \le \text{nums}[i] \le 2^{31} - 1$ * Each integer in `nums` appears exactly twice, except for exactly two integers, which appear once. ### `clz` function This function computes **the number of leading zeros** of a 32-bit unsigned integer by a binary search over the shift amounts 16, 8, 4, 2 and 1, so it needs only five shift-and-test steps instead of scanning the bits one by one. ```c static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; if (y) { n -= c; x = y; } c >>= 1; } while (c); return n - x; } ``` ### Solution Concept **Step 1.** XOR all the numbers to get xor_all = a ^ b, where a and b are the two unique numbers. **Step 2.** Find any bit in xor_all that is set to 1. Such a "set bit" marks a position where a and b differ (xor_all is non-zero because a ≠ b).
* In the C code without the `clz` version, we use unsigned int set_bit = (unsigned int)xor_val & -(unsigned int)xor_val; This isolates the lowest set bit (the rightmost 1) of xor_all. The casts are needed because xor_all may equal INT_MIN (-2147483648), whose binary form is: 10000000 00000000 00000000 00000000 In 32-bit signed integers, this value has no positive counterpart — `-INT_MIN` would require 2147483648, which cannot be represented in int (the max is 2147483647). That causes signed overflow, which is **undefined behavior** in C. Casting to **unsigned** first fixes this: unsigned arithmetic is defined modulo $2^{32}$, so the negation simply operates on the raw bit pattern, without interpreting the sign bit, and can never overflow. * In the C code with the `clz` version, we use int shift = 31 - clz((uint32_t)xor_val); unsigned int mask = 1U << shift; Here we find the position of the highest set bit (the leftmost 1) and store it in **shift**. Then `1U << shift` builds a mask in which only that highest differing bit is 1 and all other bits are 0. This mask is used to separate a and b. Note that this relies on xor_all ≠ 0, which is guaranteed because a ≠ b: for x = 0 our `clz` returns 32, which would give shift = -1 and make `1U << shift` undefined. * In the C code with the `__builtin_clz` version, we use int shift = 31 - __builtin_clz((unsigned int)xor_val); unsigned int mask = 1U << shift; The idea is exactly the same as in the `clz` version (and `__builtin_clz(0)` is likewise undefined, so the same xor_all ≠ 0 guarantee is needed). **Step 3.** Use that bit to divide the entire array into two groups: Group 1: numbers where this bit is 0 Group 2: numbers where this bit is 1 This ensures that: * All paired (duplicate) numbers fall into the same group (since they are identical) * The two unique numbers a and b fall into different groups **Step 4.** XOR all numbers within each group separately to obtain the two unique numbers.
### C code without `clz` ```c int *singleNumber(int *nums, int numsSize, int *returnSize) { // This is Step 1 int xor_val = 0; for (int i = 0; i < numsSize; i++) xor_val ^= nums[i]; // This is Step 2 unsigned int set_bit = (unsigned int)xor_val & -(unsigned int)xor_val; // This is Step 3 int a = 0, b = 0; for (int i = 0; i < numsSize; i++) { if (nums[i] & set_bit) a ^= nums[i]; else b ^= nums[i]; } // This is Step 4 int *res = malloc(sizeof(int) * 2); res[0] = a; res[1] = b; *returnSize = 2; return res; } ``` ### C code with `clz` ```c #include <stdint.h> #include <stdlib.h> static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; if (y) { n -= c; x = y; } c >>= 1; } while (c); return n - x; } int *singleNumber(int *nums, int numsSize, int *returnSize) { // This is Step 1 long xor_val = 0; for (int i = 0; i < numsSize; i++) { xor_val ^= nums[i]; } // This is Step 2 int shift = 31 - clz((uint32_t)xor_val); unsigned int mask = 1U << shift; // This is Step 3 int a = 0, b = 0; for (int i = 0; i < numsSize; i++) { if (nums[i] & mask) a ^= nums[i]; else b ^= nums[i]; } // This is Step 4 int *res = (int *)malloc(2 * sizeof(int)); res[0] = a; res[1] = b; *returnSize = 2; return res; } ``` ### C code with `__builtin_clz` ```c int *singleNumber(int *nums, int numsSize, int *returnSize) { int xor_val = 0; for (int i = 0; i < numsSize; i++) { xor_val ^= nums[i]; } int shift = 31 - __builtin_clz((unsigned int)xor_val); unsigned int mask = 1U << shift; int a = 0, b = 0; for (int i = 0; i < numsSize; i++) { if (nums[i] & mask) a ^= nums[i]; else b ^= nums[i]; } int *res = (int *)malloc(2 * sizeof(int)); res[0] = a; res[1] = b; *returnSize = 2; return res; } ``` ### RV32I Assembly code without `clz` ```riscv .data # ==== Three test cases ==== nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99 nums1_size: .word 12 ans1: .word 1, 0 nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17 nums2_size: .word 14 ans2: .word 99, 100 nums3: 
.word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9 nums3_size: .word 22 ans3: .word -5, -6 # === Pointer tables === test_cases: .word nums1, nums2, nums3 test_sizes: .word nums1_size, nums2_size, nums3_size test_ans: .word ans1, ans2, ans3 ressize: .word 0 result: .word 0, 0 # ===== Display strings ===== msg_case: .string "Test case " msg_input: .string "Input: " msg_output: .string "Output: " msg_expect: .string "Expect: " msg_ok: .string "✅ Correct\n" msg_wrong: .string "❌ Wrong\n" space: .string " " nl: .string "\n" .text .global main # ===================================================== # main: iterate through three test cases for singleNumber # ===================================================== main: # Setup iterators for the three pointer tables la s5, test_cases la s6, test_sizes la s7, test_ans li s3, 3 # total test cases li s4, 1 # current case index (1-based) loop_cases: beqz s3, end_main # Load current case pointers lw s0, 0(s5) # s0 = &nums lw s1, 0(s6) # s1 = &size lw s2, 0(s7) # s2 = &expected answer # Print "Test case i" la a0, msg_case li a7, 4 ecall mv a0, s4 li a7, 1 ecall la a0, nl li a7, 4 ecall # Print input array la a0, msg_input li a7, 4 ecall lw t3, 0(s1) li t4, 0 print_input_loop: bge t4, t3, print_input_done slli t5, t4, 2 add t6, s0, t5 lw a0, 0(t6) li a7, 1 ecall la a0, space li a7, 4 ecall addi t4, t4, 1 j print_input_loop print_input_done: la a0, nl li a7, 4 ecall # Save caller-saved registers addi sp, sp, -24 sw ra, 20(sp) sw s0, 16(sp) sw s1, 12(sp) sw s2, 8(sp) sw s3, 4(sp) sw s4, 0(sp) # Call singleNumber(nums, size, &ressize) mv a0, s0 lw a1, 0(s1) la a2, ressize jal singleNumber mv t6, a0 # Restore saved registers lw ra, 20(sp) lw s0, 16(sp) lw s1, 12(sp) lw s2, 8(sp) lw s3, 4(sp) lw s4, 0(sp) addi sp, sp, 24 # Print output la a0, msg_output li a7, 4 ecall lw t0, 0(t6) mv a0, t0 li a7, 1 ecall la a0, space li a7, 4 ecall lw t1, 4(t6) mv a0, t1 li a7, 1 ecall la a0, nl li a7, 4 ecall # 
Print expected answer la a0, msg_expect li a7, 4 ecall lw t2, 0(s2) mv a0, t2 li a7, 1 ecall la a0, space li a7, 4 ecall lw t3, 4(s2) mv a0, t3 li a7, 1 ecall la a0, nl li a7, 4 ecall # Check correctness lw t4, 0(t6) lw t5, 4(t6) beq t4, t2, check_second j print_wrong check_second: beq t5, t3, print_ok j print_wrong print_ok: la a0, msg_ok li a7, 4 ecall j next_case print_wrong: la a0, msg_wrong li a7, 4 ecall next_case: addi s5, s5, 4 addi s6, s6, 4 addi s7, s7, 4 addi s4, s4, 1 addi s3, s3, -1 j loop_cases end_main: li a7, 10 ecall # ===================================================== # function: singleNumber # input: a0 = pointer to nums, a1 = numsSize, a2 = returnSize # output: a0 = pointer to integer array # ===================================================== singleNumber: li s2, 0 li t0, 0 # First loop: XOR all numbers for_loop_1: bge t0, a1, after_for_1 slli t1, t0, 2 add t1, a0, t1 lw t1, 0(t1) xor s2, s2, t1 addi t0, t0, 1 j for_loop_1 after_for_1: # diff_bit = xor_all & (-xor_all) neg t0, s2 and s4, s2, t0 li t1, 0 li t2, 0 li t0, 0 # Second loop: XOR numbers into two groups for_loop_2: bge t0, a1, done slli t3, t0, 2 add t3, a0, t3 lw t3, 0(t3) and t4, t3, s4 beqz t4, else_part xor t1, t1, t3 j inc_i else_part: xor t2, t2, t3 inc_i: addi t0, t0, 1 j for_loop_2 # Store result [a, b] and return done: la a0, result sw t1, 0(a0) sw t2, 4(a0) li t3, 2 sw t3, 0(a2) ret ``` ### RV32I Assembly code with `clz` ```riscv .data nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99 nums1_size: .word 12 ans1: .word 1, 0 nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17 nums2_size: .word 14 ans2: .word 100, 99 nums3: .word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9 nums3_size: .word 22 ans3: .word -5, -6 test_cases: .word nums1, nums2, nums3 test_sizes: .word nums1_size, nums2_size, nums3_size test_ans: .word ans1, ans2, ans3 ressize: .word 0 result: .word 0, 0 msg_case: .string "Test case " msg_input: 
.string "Input: " msg_output: .string "Output: " msg_expect: .string "Expect: " msg_ok: .string "✅ Correct\n" msg_wrong: .string "❌ Wrong\n" space: .string " " nl: .string "\n" .text .global main main: # initialize iterators for case tables la s5, test_cases la s6, test_sizes la s7, test_ans li s3, 3 li s4, 1 loop_cases: beqz s3, end_main # load current pointers lw s0, 0(s5) lw s1, 0(s6) lw s2, 0(s7) # print "Test case i" la a0, msg_case li a7, 4 ecall mv a0, s4 li a7, 1 ecall la a0, nl li a7, 4 ecall # print input array la a0, msg_input li a7, 4 ecall lw t3, 0(s1) li t4, 0 print_input_loop: bge t4, t3, print_input_done slli t5, t4, 2 add t6, s0, t5 lw a0, 0(t6) li a7, 1 ecall la a0, space li a7, 4 ecall addi t4, t4, 1 j print_input_loop print_input_done: la a0, nl li a7, 4 ecall # save registers before call addi sp, sp, -24 sw ra, 20(sp) sw s0, 16(sp) sw s1, 12(sp) sw s2, 8(sp) sw s3, 4(sp) sw s4, 0(sp) # call singleNumber(a0=nums, a1=size, a2=&ressize) mv a0, s0 lw a1, 0(s1) la a2, ressize jal singleNumber mv t6, a0 # restore registers lw ra, 20(sp) lw s0, 16(sp) lw s1, 12(sp) lw s2, 8(sp) lw s3, 4(sp) lw s4, 0(sp) addi sp, sp, 24 # print output la a0, msg_output li a7, 4 ecall lw t0, 0(t6) mv a0, t0 li a7, 1 ecall la a0, space li a7, 4 ecall lw t1, 4(t6) mv a0, t1 li a7, 1 ecall la a0, nl li a7, 4 ecall # print expected result la a0, msg_expect li a7, 4 ecall lw t2, 0(s2) mv a0, t2 li a7, 1 ecall la a0, space li a7, 4 ecall lw t3, 4(s2) mv a0, t3 li a7, 1 ecall la a0, nl li a7, 4 ecall # compare results lw t4, 0(t6) lw t5, 4(t6) beq t4, t2, check_second j print_wrong check_second: beq t5, t3, print_ok j print_wrong print_ok: la a0, msg_ok li a7, 4 ecall j next_case print_wrong: la a0, msg_wrong li a7, 4 ecall next_case: addi s5, s5, 4 addi s6, s6, 4 addi s7, s7, 4 addi s4, s4, 1 addi s3, s3, -1 j loop_cases end_main: li a7, 10 ecall ################################## # Count Leading Zeros (clz) # a0 = input, return a0 = clz(x) ################################## 
clz: li s0, 32 li s1, 16 clz_loop: srl t0, a0, s1 bnez t0, clz_if srli s1, s1, 1 j clz_check clz_if: sub s0, s0, s1 mv a0, t0 clz_check: bnez s1, clz_loop sub a0, s0, a0 ret ################################## # singleNumber # a0 = nums ptr, a1 = size, a2 = &returnSize # return a0 = &result ################################## singleNumber: li s2, 0 li t0, 0 # first loop: xor all numbers for1_cond: blt t0, a1, for1_body li t1, 31 addi sp, sp, -8 sw ra, 0(sp) sw a0, 4(sp) mv a0, s2 jal ra, clz sub s3, t1, a0 lw a0, 4(sp) lw ra, 0(sp) addi sp, sp, 8 li t1, 1 sll s4, t1, s3 # mask = 1U << shift li t1, 0 # a = 0 li t2, 0 # b = 0 li t0, 0 j for2_cond for1_body: slli t1, t0, 2 add t1, a0, t1 lw t1, 0(t1) xor s2, s2, t1 addi t0, t0, 1 j for1_cond # second loop: split by mask bit for2_cond: blt t0, a1, for2_body la a0, result sw t1, 0(a0) sw t2, 4(a0) li t3, 2 sw t3, 0(a2) ret for2_body: slli t3, t0, 2 add t3, a0, t3 lw t3, 0(t3) and t4, t3, s4 bnez t4, for2_if xor t2, t2, t3 j for2_next for2_if: xor t1, t1, t3 for2_next: addi t0, t0, 1 j for2_cond ``` ### Loop Unrolling Optimization In our RV32I assembly implementation, we identified three major loops: 1. the clz (Count Leading Zeros) function, 2. the first loop (which performs a global XOR across all numbers), and 3. the second loop (which splits elements into two groups based on a mask bit). Our goal is to optimize these three loops using loop unrolling to reduce branching overhead and improve instruction-level parallelism (ILP) on a 5-stage pipeline processor. * **clz:** * The clz function takes a 32-bit integer as input and performs a do-while loop that iterates exactly **five times**—checking shifts by 16, 8, 4, 2, and 1 bits respectively. Because this loop always executes a fixed number of iterations regardless of the input value, we can **fully unroll** it. * By expanding each iteration manually, we completely remove all branch and jump instructions associated with the loop control. 
This eliminates the dynamic branching cost of the loop, which is especially beneficial in a pipelined RV32I core, where a taken branch or jump can flush instructions that were already fetched and stall the pipeline. * In other words, clz becomes five straight-line shift-and-test steps with no loop counter and no backward jump; the only branches left are the data-dependent `beqz` tests that skip an update. This trades a small increase in code size for better runtime predictability and speed. * **First Loop: Partial Unrolling by Four** * The first loop performs a simple reduction: ```c for (int i = 0; i < numsSize; i++) { xor_val ^= nums[i]; } ``` It iterates `numsSize` times — i.e. once per element of `nums`. * The original version processes one element per iteration: ```riscv for1_body: slli t1, t0, 2 add t1, a0, t1 lw t1, 0(t1) xor s2, s2, t1 addi t0, t0, 1 j for1_cond ``` We apply loop unrolling by a factor of 4, so that each iteration processes four consecutive integers. * Here is the optimized version: ```riscv loop1_unroll4: addi t2, s1, 4 bgt t2, a1, loop1_remainder lw t3, 0(s0) lw t4, 4(s0) lw t5, 8(s0) lw t6, 12(s0) xor s2, s2, t3 xor s2, s2, t4 xor s2, s2, t5 xor s2, s2, t6 addi s0, s0, 16 # ptr += 16 addi s1, s1, 4 # i += 4 j loop1_unroll4 ``` A small remainder loop handles any leftover elements: ```riscv loop1_remainder: bge s1, a1, after_loop1 loop1_rem_iter: bge s1, a1, after_loop1 lw t3, 0(s0) xor s2, s2, t3 addi s0, s0, 4 addi s1, s1, 1 j loop1_rem_iter ``` Although the total number of assembly lines increases, the number of branch and jump instructions per processed element decreases by roughly **75%**. The original loop executes one `blt` and one `j` for every element, while the unrolled loop executes one `bgt` and one `j` for every four elements. Replacing the per-element index arithmetic (`slli`/`add`) with one pointer increment per block also removes instructions. In addition, issuing the four `lw`s before the four `xor`s avoids the load-use stall that the original `lw t1` → `xor s2, s2, t1` pair incurs on every element. You can verify this improvement in the **Performance** section later in this document. * **Second Loop: Partial Unrolling by Four** * The second loop also iterates over `numsSize` elements but performs conditional XOR operations depending on whether (`nums[i] & mask`) is zero.
Like the first loop, we unroll it by four iterations to minimize branch overhead while preserving correctness. * Each unrolled block loads four elements, performs four conditional checks, and applies the corresponding XORs into two accumulators (a and b). * Even though this increases the code length, the loop executes significantly faster in the common case. ### Optimized RV32I Assembly code with loop unrolling ```riscv .data nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99 nums1_size: .word 12 ans1: .word 1, 0 nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17 nums2_size: .word 14 ans2: .word 100, 99 nums3: .word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9 nums3_size: .word 22 ans3: .word -5, -6 test_cases: .word nums1, nums2, nums3 test_sizes: .word nums1_size, nums2_size, nums3_size test_ans: .word ans1, ans2, ans3 ressize: .word 0 result: .word 0, 0 msg_case: .string "Test case " msg_input: .string "Input: " msg_output: .string "Output: " msg_expect: .string "Expect: " msg_ok: .string "✅ Correct\n" msg_wrong: .string "❌ Wrong\n" space: .string " " nl: .string "\n" .text .global main main: # initialize iterators for case tables la s5, test_cases la s6, test_sizes la s7, test_ans li s3, 3 # total cases li s4, 1 # case index (1-based) loop_cases: beqz s3, end_main # load current pointers lw s0, 0(s5) # s0 = nums ptr lw s1, 0(s6) # s1 = size ptr lw s2, 0(s7) # s2 = ans ptr # print "Test case i" la a0, msg_case li a7, 4 ecall mv a0, s4 li a7, 1 ecall la a0, nl li a7, 4 ecall # print input array la a0, msg_input li a7, 4 ecall lw t3, 0(s1) # size li t4, 0 print_input_loop: bge t4, t3, print_input_done slli t5, t4, 2 add t6, s0, t5 lw a0, 0(t6) li a7, 1 ecall la a0, space li a7, 4 ecall addi t4, t4, 1 j print_input_loop print_input_done: la a0, nl li a7, 4 ecall # save registers before call addi sp, sp, -24 sw ra, 20(sp) sw s0, 16(sp) sw s1, 12(sp) sw s2, 8(sp) sw s3, 4(sp) sw s4, 0(sp) # call 
singleNumber(a0=nums, a1=size, a2=&ressize) mv a0, s0 lw a1, 0(s1) la a2, ressize jal singleNumber mv t6, a0 # t6 = &result # restore registers lw ra, 20(sp) lw s0, 16(sp) lw s1, 12(sp) lw s2, 8(sp) lw s3, 4(sp) lw s4, 0(sp) addi sp, sp, 24 # print output la a0, msg_output li a7, 4 ecall lw t0, 0(t6) mv a0, t0 li a7, 1 ecall la a0, space li a7, 4 ecall lw t1, 4(t6) mv a0, t1 li a7, 1 ecall la a0, nl li a7, 4 ecall # print expected result la a0, msg_expect li a7, 4 ecall lw t2, 0(s2) mv a0, t2 li a7, 1 ecall la a0, space li a7, 4 ecall lw t3, 4(s2) mv a0, t3 li a7, 1 ecall la a0, nl li a7, 4 ecall # compare results lw t4, 0(t6) lw t5, 4(t6) beq t4, t2, check_second j print_wrong check_second: beq t5, t3, print_ok j print_wrong print_ok: la a0, msg_ok li a7, 4 ecall j next_case print_wrong: la a0, msg_wrong li a7, 4 ecall next_case: addi s5, s5, 4 addi s6, s6, 4 addi s7, s7, 4 addi s4, s4, 1 addi s3, s3, -1 j loop_cases end_main: li a7, 10 ecall ################################## # Count Leading Zeros (unrolled) # input : a0 = 32-bit unsigned # output: a0 = #leading zeros ################################## clz: addi sp, sp, -16 sw s8, 0(sp) sw s9, 4(sp) sw s10, 8(sp) sw s11,12(sp) beqz a0, clz_zero li s8, 32 mv s9, a0 # (x >> 16) srli s10, s9, 16 beqz s10, clz_chk8 addi s8, s8, -16 mv s9, s10 clz_chk8: # (x >> 8) srli s10, s9, 8 beqz s10, clz_chk4 addi s8, s8, -8 mv s9, s10 clz_chk4: # (x >> 4) srli s10, s9, 4 beqz s10, clz_chk2 addi s8, s8, -4 mv s9, s10 clz_chk2: # (x >> 2) srli s10, s9, 2 beqz s10, clz_chk1 addi s8, s8, -2 mv s9, s10 clz_chk1: # (x >> 1) srli s10, s9, 1 beqz s10, clz_ret addi s8, s8, -1 mv s9, s10 clz_ret: sub a0, s8, s9 lw s8, 0(sp) lw s9, 4(sp) lw s10, 8(sp) lw s11, 12(sp) addi sp, sp, 16 ret clz_zero: li a0, 32 lw s8, 0(sp) lw s9, 4(sp) lw s10, 8(sp) lw s11, 12(sp) addi sp, sp, 16 ret ################################## # singleNumber (loop unrolled) # input : a0 = nums*, a1 = size, a2 = &returnSize # output: a0 = &result 
################################## singleNumber: # first loop: XOR all numbers (unrolled ×4) li s2, 0 li s1, 0 mv s0, a0 loop1_unroll4: addi t2, s1, 4 bgt t2, a1, loop1_remainder lw t3, 0(s0) lw t4, 4(s0) lw t5, 8(s0) lw t6, 12(s0) xor s2, s2, t3 xor s2, s2, t4 xor s2, s2, t5 xor s2, s2, t6 addi s0, s0, 16 addi s1, s1, 4 j loop1_unroll4 loop1_remainder: bge s1, a1, after_loop1 loop1_rem_iter: bge s1, a1, after_loop1 lw t3, 0(s0) xor s2, s2, t3 addi s0, s0, 4 addi s1, s1, 1 j loop1_rem_iter after_loop1: # compute mask bit li t1, 31 addi sp, sp, -8 sw ra, 0(sp) sw a0, 4(sp) mv a0, s2 jal ra, clz sub s3, t1, a0 lw a0, 4(sp) lw ra, 0(sp) addi sp, sp, 8 li t1, 1 sll s4, t1, s3 # second loop: split by mask (unrolled ×4) li t1, 0 li t2, 0 li s1, 0 mv s0, a0 loop2_unroll4: addi t3, s1, 4 bgt t3, a1, loop2_remainder lw t3, 0(s0) lw t4, 4(s0) lw t5, 8(s0) lw t6, 12(s0) # element1 and t0, t3, s4 beqz t0, l2_e1 xor t1, t1, t3 j l2_n1 l2_e1: xor t2, t2, t3 l2_n1: # element2 and t0, t4, s4 beqz t0, l2_e2 xor t1, t1, t4 j l2_n2 l2_e2: xor t2, t2, t4 l2_n2: # element3 and t0, t5, s4 beqz t0, l2_e3 xor t1, t1, t5 j l2_n3 l2_e3: xor t2, t2, t5 l2_n3: # element4 and t0, t6, s4 beqz t0, l2_e4 xor t1, t1, t6 j l2_n4 l2_e4: xor t2, t2, t6 l2_n4: addi s0, s0, 16 addi s1, s1, 4 j loop2_unroll4 loop2_remainder: bge s1, a1, end_loop2 loop2_rem_iter: bge s1, a1, end_loop2 lw t3, 0(s0) and t0, t3, s4 beqz t0, l2_eR xor t1, t1, t3 j l2_nR l2_eR: xor t2, t2, t3 l2_nR: addi s0, s0, 4 addi s1, s1, 1 j loop2_rem_iter end_loop2: la a0, result sw t1, 0(a0) sw t2, 4(a0) li t3, 2 sw t3, 0(a2) ret ``` ## Performance By using [Ripes](https://github.com/mortbopet/Ripes) simulator, We evaluated multiple implementations of the same algorithm — both at the C and assembly levels — with and without loop unrolling and the clz optimization. 
| C code without `clz` | C code with `clz` | C code with `__builtin_clz` | | -------- | -------- | -------- | | ![image](https://hackmd.io/_uploads/r1UcuZ-aee.png =100%x) | ![image](https://hackmd.io/_uploads/HyN8u-Zaxx.png =110%x) | ![image](https://hackmd.io/_uploads/ByTJRZWTxg.png =80%x) | * We can see that when we directly compiled C code into Ripes, the number of cycles is extremely high (≈98 k). * This is because the compiler generates many generic operations that are not tailored for the RV32I instruction set. * Even though the clz function or built-in intrinsic helps reduce the logical complexity, it does not reduce the actual loop control cost, because the compiler still emits similar branching and comparison sequences. | RV32I Assembly code without `clz` | RV32I Assembly code with `clz` | Optimized RV32I Assembly code with loop unrolling | | -------- | -------- | -------- | | ![image](https://hackmd.io/_uploads/rJiBvW-agl.png =100%x) | ![image](https://hackmd.io/_uploads/S1vRDZbage.png =100%x)| ![image](https://hackmd.io/_uploads/SJeqD--Tex.png =75%x) | * As you can see, the hand-written RV32I assembly is far more efficient than the C-compiled output. * However, the key improvement comes from loop unrolling: * The optimized version executes **fewer total instructions (−20%)**. * The total **cycle count decreases by roughly 27%** compared to the baseline assembly version. * **CPI improves from 1.60 → 1.45**, and IPC increases to 0.692, indicating higher pipeline utilization. > As the number of test data and test cases increases, the performance advantage of the unrolled version becomes even more significant. > This is because loop unrolling amortizes the loop control overhead across more iterations, leading to higher efficiency gains for larger datasets. > In our experiment, there are **only three test cases**, and **each test case contains fewer than twenty data elements** on average — yet the optimized version already shows a noticeable performance gap. 
> This clearly demonstrates that even with small input sizes, loop unrolling can effectively reduce control overhead and improve execution throughput on RV32I processors. ## Analysis As before, we test our code with the [Ripes](https://github.com/mortbopet/Ripes) simulator. ### 5-stage pipelined processor Ripes supports six processor models: 1. Single-cycle processor 2. 5-stage pipelined processor w/o forwarding or hazard detection 3. 5-stage pipelined processor w/o hazard detection 4. 5-Stage pipelined processor w/o forwarding unit 5. 5-stage pipelined processor 6. 6-stage dual-issue processor For this assignment, the full 5-stage pipelined processor (model 5, with both forwarding and hazard detection) has been selected as the target device because it is the most commonly used architecture. Its block diagram looks like this: ![image](https://hackmd.io/_uploads/Hy35njeTgx.png) "5-stage" means the processor splits the execution of each instruction into five pipeline stages, so that up to five instructions can be in flight at the same time. The stages are: | Stage | Description | | -------- | -------- | | IF | Instruction Fetch | | ID | Instruction Decode and Register Fetch | | EX | Execution or Address Calculation | | MEM | Memory Access | | WB | Register Write Back | **Main Task:** * **IF Stage :** Fetch the next instruction from memory (using the Program Counter) and update the PC. * **ID Stage :** Decode the instruction, read source registers, and determine the operation type. * **EX Stage :** Perform arithmetic or logic operations in the ALU, or compute a memory address. * **MEM Stage :** For load/store instructions, read from or write to data memory. * **WB Stage :** Write the result from the ALU or memory back to the destination register. Instructions of different formats pass through the same five stages, but each format activates a different set of control signals. Let's discuss the I-type format in detail with the example below.
#### I-type format `slli x8, x4, 5` This is an **I-type** instruction in RISC-V that performs a **logical left shift by an immediate**: the value of register x4 is shifted left by 5 bits, and the result is stored in register x8. ##### I-Type Instruction Format: `| funct7[31:25] | shamt[24:20] | rs1[19:15] | funct3[14:12] | rd[11:7] | opcode[6:0] |` * **funct7 :** the 7-bit function code for `slli` is `0000000 ` * **shamt :** shift amount is `5(00101)` * **rs1 :** the source register `x4(00100)` * **funct3 :** the 3-bit function code for `slli` is `001 ` * **rd :** the destination register is `x8(01000) ` * **opcode :** the opcode for `slli` is `0010011 ` Thus, the machine code of `slli x8, x4, 5` is `0000000 00101 00100 001 01000 0010011(bin)` = `0x00521413(hex)` ##### 1. Instruction Fetch (IF) ![image](https://hackmd.io/_uploads/SypREaxpeg.png =50%x) * We start from the instruction located at `0x00000000`, so `addr` is equal to `0x00000000`. * The machine code of the instruction is `0x00521413`, so `instr` is equal to `0x00521413`. * The adder above automatically increments the PC by 4, because every RV32I instruction is 32 bits (4 bytes) long. * Because no branch occurs, the next instruction is at PC + 4, so the multiplexer in front of the PC selects the input coming from the adder. ##### 2. Instruction Decode (ID) ![image](https://hackmd.io/_uploads/ryYJnngaxx.png =50%x) * Instruction `0x00521413` is decoded into five parts: * `opcode` = `slli` * `Wr idx` = `0x08` * `imm.` = `0x00000005 ` * `R1 idx` = `0x04` * `R2 idx` = `0x05` * Although the decoder extracts both `R1 idx(0x04)` and `R2 idx(0x05)`, the register value selected by `R2 idx` will not be used in the EX stage. For this I-type instruction, bits [24:20] hold the shift amount, not a second source register. * `R1 idx(0x04)` and `R2 idx(0x05)` are sent to the register file to read the corresponding register values. Both values are `0x00000000`, because these registers are initially `0x00000000`. * The current PC value `(0x00000000)`, the next PC value `(0x00000004)` and `Wr idx (0x08)` are simply passed through this stage; they are not used here. ##### 3.
Execute (EX) ![image](https://hackmd.io/_uploads/SkMYahgpxx.png =50%x) * The first-level multiplexers select the values coming from `Reg 1` and `Reg 2`. Since this is an I-type instruction, `Reg 2` is not used and is filtered out by the second-level multiplexer. * For the ALU operands, the second-level multiplexers select `Reg 1` rather than the current PC value as `Op1` (upper multiplexer), and the immediate rather than `Reg 2` as `Op2` (lower multiplexer). * The ALU shifts `Op1` left by `Op2` (5) bits, so `Res` is equal to `0x00000000` (0 << 5 is still 0). * `Reg 1` and `Reg 2` are also sent to the branch unit, but no branch is taken. * The next PC value `(0x00000004)` and `Wr idx (0x08)` are simply passed through this stage; they are not used here. ##### 4. Memory access (MEM) ![image](https://hackmd.io/_uploads/Hyd06hl6lx.png =45%x) * `Res` from the ALU is sent three ways: * It passes through this stage and goes to the WB stage (the lower line). * It is forwarded back to the EX stage for the next instruction to use (the upper line). * It is used as the data memory address (the middle line). Memory reads data at address `0x00000000`, so `Read out` is equal to `0x00521413`. Because `slli` is not a load, this value is not used. The table below shows the data section of memory. * ![image](https://hackmd.io/_uploads/HJMjCAlaee.png) * In addition, `Reg 2` is sent to `Data in`, but memory writing is not enabled, so nothing is stored. * The next PC value `(0x00000004)` and `Wr idx (0x08)` are simply passed through this stage; they are not used here. ##### 5. Register write back (WB) ![image](https://hackmd.io/_uploads/SJ6SAhgale.png =40%x) * The multiplexer chooses `Res` from the ALU (the middle line) as the final output, so the output value is `0x00000000`. * The output value and `Wr idx` are sent back to the register file in the ID stage, and `Wr En` is 1. Finally, the value `0x00000000` is written into register `x8`, whose ABI name is `s0`.
After all these stages are done, the register file is updated like this: ![image](https://hackmd.io/_uploads/B1J4lybTle.png =30%x) Finally, all the source code mentioned above can be found [here](https://github.com/Shaoen-Lin/ca2025-quizzes). Feel free to check it out! ## Reference * [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1-sol) * [Assignment 1: RISC-V Assembly and Instruction Pipeline](https://hackmd.io/@sysprog/2025-arch-homework1) * [LeetCode 260. Single Number III](https://leetcode.com/problems/single-number-iii/description/)