# Assignment 1: RISC-V Assembly and Instruction Pipeline

>[!Note] AI tools usage
>I used ChatGPT and Claude while working on Quiz 1 for code explanations, related research, code summaries, debugging, and generating the test-data tables.

## Problem B

* Refer to [Quiz1 of Computer Architecture (2025 Fall) Problem B](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* You can find the source code [here](https://github.com/CarSam16/ca2025-quizzes). Feel free to fork and modify it.
* uf8 format:

```
┌──────────────┬────────────────┐
│ Exponent (4) │ Mantissa (4)   │
└──────────────┴────────────────┘
 7            4 3              0
```

### Assembly code : uf8 decode

```
# input  uf8 value    : a0
# output uint32 value : a0
uf8_decode:
    andi t1, a0, 0x0f       # mantissa = fl & 0x0f
    srli t2, a0, 4          # exponent = fl >> 4
    li   t3, 15
    sub  t3, t3, t2         # t3 = 15 - exponent
    li   t4, 0x7FFF
    srl  t3, t4, t3         # t3 = 0x7FFF >> (15 - exponent)
    slli t3, t3, 4          # t3 = offset = (0x7FFF >> (15 - exponent)) << 4
    sll  t1, t1, t2         # t1 = mantissa << exponent
    add  a0, t1, t3         # return (mantissa << exponent) + offset
    jr   ra
```

### Assembly code : uf8 encode

```
# input  uint32 : a0
# output uf8    : a0
uf8_encode:
    # callee save
    addi sp, sp, -8
    sw   ra, 0(sp)          # ra is overwritten by the call to CLZ
    sw   s0, 4(sp)
    add  s0, a0, x0         # s0 keeps the original value

    # /* Use CLZ for fast exponent calculation */
    li   t0, 16
    blt  s0, t0, return     # value < 16: return value unchanged (a0 still holds it)

    jal  ra, CLZ_myfunction # call CLZ
    li   t0, 31
    sub  t0, t0, a0         # msb = 31 - lz (lz is returned in a0)

    li   a1, 0              # exponent
    li   a2, 0              # overflow
    li   t1, 5
    blt  t0, t1, find_exact_exp  # msb < 5: search from exponent 0
    li   t1, 4
    sub  a1, t0, t1         # exponent = msb - 4
    li   t1, 15
    blt  a1, t1, cal_overflow
    li   a1, 15             # clamp exponent to 15

cal_overflow:
    li   t1, 0              # e = 0
cal_overflow_loop:
    bge  t1, a1, adjust_loop
    slli a2, a2, 1
    addi a2, a2, 16         # overflow = (overflow << 1) + 16
    addi t1, t1, 1          # e++
    j    cal_overflow_loop

adjust_loop:                # while (exponent > 0 && value < overflow)
    blez a1, find_exact_exp
    bge  s0, a2, find_exact_exp
    addi t2, a2, -16
    srli a2, t2, 1          # overflow = (overflow - 16) >> 1
    addi a1, a1, -1         # exponent--
    j    adjust_loop

find_exact_exp:             # while (exponent < 15)
    li   t1, 15
    bge  a1, t1, end_encode
    slli t2, a2, 1
    addi t2, t2, 16         # next_overflow = (overflow << 1) + 16
    blt  s0, t2, end_encode # value < next_overflow: exponent found
    add  a2, t2, x0         # overflow = next_overflow
    addi a1, a1, 1          # exponent++
    j    find_exact_exp

end_encode:
    sub  t2, s0, a2
    srl  t2, t2, a1         # mantissa = (value - overflow) >> exponent
    slli a1, a1, 4
    or   a0, a1, t2         # return (exponent << 4) | mantissa

return:
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 8
    jr   ra
```

### CLZ function

```
# Function: clz
# Arguments:
#   a0 = x (input)
# Return:
#   a0 = n - x (number of leading zeros of x)
# Temporaries:
#   t0 = n
#   t1 = c
#   t2 = y
CLZ_myfunction:
    li   t0, 32             # n = 32
    li   t1, 16             # c = 16
do_while:
    srl  t2, a0, t1         # y = x >> c
    bnez t2, if_loop        # if (y != 0) go to if_loop
    srli t1, t1, 1          # c >>= 1
    beqz t1, end_loop       # if (c == 0) break
    j    do_while           # continue loop
if_loop:
    sub  t0, t0, t1         # n -= c
    add  a0, t2, x0         # x = y
    srli t1, t1, 1          # c >>= 1
    beqz t1, end_loop       # if (c == 0) break
    j    do_while           # loop again
end_loop:
    sub  a0, t0, a0         # return n - x
    jr   ra
```

### Assembly code : uf8 test

VALIDATE_ROUNDTRIP: tests encoding and decoding for all 256 uf8 values. Every value must encode back to itself after being decoded, and the decoded values must be strictly increasing.

```
.data
msg1:    .asciz ": produces value "
msg2:    .asciz " but encodes back to "
msg3:    .asciz ": value "
msg4:    .asciz " <= previous_value "
msg5:    .asciz "All tests passed.\n"
msg6:    .asciz "Some tests failed.\n"
newline: .asciz "\n"
.align 2

.text
.globl main
main:
    jal  ra, test           # start to test
    beq  a0, x0, Not_pass   # fail
    la   a0, msg5           # print msg5 when passing
    li   a7, 4
    ecall
    li   a7, 93             # ecall: exit
    li   a0, 0              # exit code 0: success
    ecall
Not_pass:
    la   a0, msg6           # print msg6 when not passing
    li   a7, 4
    ecall
    li   a7, 93             # ecall: exit
    li   a0, 1              # exit code 1: failure
    ecall

test:
    addi sp, sp, -4
    sw   ra, 0(sp)          # test calls other functions, so save ra
    addi s0, x0, -1         # previous_value
    li   s1, 1              # passed (1 = true, 0 = false)
    li   s2, 0              # fl, counter from 0 to 255
    li   s3, 256            # counter end
For_2:
    add  a0, s2, x0         # prepare a0 for uf8_decode
    jal  ra, uf8_decode
    add  s4, a0, x0         # value (return value of uf8_decode)
    add  a0, s4, x0         # prepare a0 for uf8_encode
    jal  ra, uf8_encode
    add  s5, a0, x0         # fl2 (return value of uf8_encode)
test_if_1:                  # fl != fl2: report the mismatch
    beq  s2, s5, test_if_2
    mv   a0, s2             # print s2 (fl)
    li   a7, 34             # (RARS) print integer in hex
    ecall
    la   a0, msg1           # print msg1
    li   a7, 4
    ecall
    mv   a0, s4             # print value
    li   a7, 1
    ecall
    la   a0, msg2           # print msg2
    li   a7, 4
    ecall
    mv   a0, s5             # print fl2 (s5) in hexadecimal
    li   a7, 34             # (RARS) print integer in hex
    ecall
    la   a0, newline        # print newline
    li   a7, 4
    ecall
    li   s1, 0              # passed = false
test_if_2:                  # value <= previous_value: report it
    blt  s0, s4, after_if
    mv   a0, s2             # print s2 (fl)
    li   a7, 34             # (RARS) print integer in hex
    ecall
    la   a0, msg3           # print msg3
    li   a7, 4
    ecall
    mv   a0, s4             # print value
    li   a7, 1
    ecall
    la   a0, msg4           # print msg4
    li   a7, 4
    ecall
    mv   a0, s0             # print s0 (previous_value) in hexadecimal
    li   a7, 34             # (RARS) print integer in hex
    ecall
    la   a0, newline        # print newline
    li   a7, 4
    ecall
    li   s1, 0              # passed = false
after_if:
    mv   s0, s4             # previous_value = value
    addi s2, s2, 1
    blt  s2, s3, For_2
    mv   a0, s1             # return passed
    lw   ra, 0(sp)
    addi sp, sp, 4
    jr   ra                 # jump to ra

# Function: clz
# Arguments:
#   a0 = x (input)
# Return:
#   a0 = n - x
# Temporaries:
#   t0 = n
#   t1 = c
#   t2 = y
```

## Problem C

* Refer to [Quiz1 of Computer Architecture (2025 Fall) Problem C](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* bf16_t format

```
┌─────────┬──────────────┬──────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (7) │
└─────────┴──────────────┴──────────────┘
 15        14           7 6            0
S: Sign bit (0 : positive, 1 : negative)
E: Exponent bits (bias = 127)
M: Mantissa bits
```

* f32_t format

```
┌─────────┬──────────────┬───────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (23) │
└─────────┴──────────────┴───────────────┘
 31        30          23 22            0
S: Sign bit (0 : positive, 1 : negative)
E: Exponent bits (bias = 127)
M: Mantissa bits
```

### Arithmetic Part

#### Test data

| Test # | Operation | Input A | Input B | BF16 A | BF16 B | Expected (F32) | Expected (BF16) | Tolerance |
|--------|-----------|---------|---------|--------|--------|----------------|-----------------|-----------|
| 1 | Add | 1.0f | 2.0f | 0x3F80 | 0x4000 | 3.0f (0x40400000) | 0x4040 | ±0.01 |
| 2 | Sub | 2.0f | 1.0f | 0x4000 | 0x3F80 | 1.0f (0x3F800000) | 0x3F80 | ±0.01 |
| 3 | Mul | 3.0f | 4.0f | 0x4040 | 0x4080 | 12.0f (0x41400000) | 0x4140 | ±0.1 |
| 4 | Div | 10.0f | 2.0f | 0x4120 | 0x4000 | 5.0f (0x40A00000) | 0x40A0 | ±0.1 |
| 5 | Sqrt | 4.0f | - | 0x4080 | - | 2.0f (0x40000000) | 0x4000 | ±0.01 |

#### Test result

![image](https://hackmd.io/_uploads/S106h-sale.png)
![image](https://hackmd.io/_uploads/HyzLTZs6lx.png)

#### Assembly code :

```
.data
# ----------------------------
# Output strings
# ----------------------------
str_banner:       .asciz "\n=== BFloat16 Arithmetic Test ===\n\n"
str_test_arith:   .asciz "Testing arithmetic operations...\n"
str_add_fail:     .asciz "  [Addition failed]\n"
str_sub_fail:     .asciz "  [Subtraction failed]\n"
str_mul_fail:     .asciz "  [Multiplication failed]\n"
str_div_fail:     .asciz "  [Division failed]\n"
str_sqrt_fail:    .asciz "  [Sqrt failed]\n"
str_arith_pass:   .asciz "  Arithmetic: PASS\n"
str_summary:      .asciz "\n=== Test Summary ===\n"
str_passed_count: .asciz "Tests passed: "
str_failed_count: .asciz "Tests failed: "
str_newline:      .asciz "\n"
str_add_pass:     .asciz "[Add passed]\n"
str_sub_pass:     .asciz "[Sub passed]\n"
str_mul_pass:     .asciz "[Mul passed]\n"
str_div_pass:     .asciz "[Div passed]\n"
str_sqrt_pass:    .asciz "[Sqrt passed]\n"

# ----------------------------
# Counters
# ----------------------------
tests_passed: .word 0
tests_failed: .word 0

# ----------------------------
# BF16 constant words (used inside arithmetic implementations)
# Keep these labels for compatibility with the original implementation.
# ----------------------------
Inf_pos:        .word 0x7F800000, 0x7F80
Inf_neg:        .word 0xFF800000, 0xFF80
NaN:            .word 0xFFC00000, 0xFFC0
normal:         .word 0x40490fd0, 0x4049
denormal:       .word 0x40000fd0
BF16_SIGN_MASK: .word 0x8000
BF16_EXP_MASK:  .word 0x7F80
BF16_MANT_MASK: .word 0x007F
BF16_EXP_BIAS:  .word 127
BF16_NAN:       .word 0x7FC0
BF16_ZERO:      .word 0x0

.text
.globl main

# ============================================================================
# MAIN: entrypoint
# ============================================================================
main:
    addi sp, sp, -16
    sw   ra, 12(sp)

    # Print banner
    la   a0, str_banner
    call print_string

    # Run the arithmetic tests
    call test_arithmetic

    # Print summary
    call print_summary

    # Exit gracefully
    lw   ra, 12(sp)
    addi sp, sp, 16
    li   a7, 10
    ecall

# ============================================================================
# test_arithmetic
# ============================================================================
test_arithmetic:
    addi sp, sp, -32
    sw   ra, 28(sp)
    sw   s0, 24(sp)
    sw   s1, 20(sp)
    sw   s2, 16(sp)
    sw   s3, 12(sp)
    sw   s4, 8(sp)
    sw   s5, 4(sp)

    la   a0, str_test_arith
    call print_string

    # ---------------------------
    # Test 1: Addition 1.0 + 2.0 = 3.0
    # ---------------------------
    li   a0, 0x3F80         # bf16 bits for 1.0
    #call f32_to_bf16
    mv   s0, a0             # s0 = bf16(1.0)
    li   a0, 0x4000         # bf16 bits for 2.0
    #call f32_to_bf16
    mv   s1, a0             # s1 = bf16(2.0)
    mv   a0, s0
    mv   a1, s1
    call bf16_add
    mv   s2, a0             # s2 = result bf16
    mv   a0, s2
    #call bf16_to_f32
    mv   s3, a0             # s3 = result (bf16 bits; f32 conversion is disabled)
    li   s4, 0x4040         # expected = bf16(3.0)
    mv   a0, s3
    mv   a1, s4
    call check_relative_error_1pct
    beqz a0, arith_add_fail
    la   a0, str_add_pass
    call print_string
    call increment_passed
    j    arith_sub_test
arith_add_fail:
    la   a0, str_add_fail
    call print_string
    call increment_failed
    j    arith_sub_test

    # ---------------------------
    # Test 2: Subtraction 2.0 - 1.0 = 1.0
    # ---------------------------
arith_sub_test:
    li   a0, 0x4000         # bf16(2.0)
    #call f32_to_bf16
    mv   s0, a0
    li   a0, 0x3F80         # bf16(1.0)
    #call f32_to_bf16
    mv   s1, a0
    mv   a0, s0
    mv   a1, s1
    call bf16_sub
    mv   s2, a0
    mv   a0, s2
    #call bf16_to_f32
    mv   s3, a0
    li   s4, 0x3F80         # expected = bf16(1.0)
    mv   a0, s3
    mv   a1, s4
    call check_relative_error_1pct
    beqz a0, arith_sub_fail
    la   a0, str_sub_pass
    call print_string
    call increment_passed
    j    arith_mul_test
arith_sub_fail:
    la   a0, str_sub_fail
    call print_string
    call increment_failed
    j    arith_mul_test     # keep running the remaining tests

    # ---------------------------
    # Test 3: Multiplication 3.0 * 4.0 = 12.0
    # ---------------------------
arith_mul_test:
    li   a0, 0x4040         # bf16(3.0)
    #call f32_to_bf16
    mv   s0, a0
    li   a0, 0x4080         # bf16(4.0)
    #call f32_to_bf16
    mv   s1, a0
    mv   a0, s0
    mv   a1, s1
    call bf16_mul
    mv   s2, a0
    mv   a0, s2
    #call bf16_to_f32
    mv   s3, a0
    li   s4, 0x4140         # expected = bf16(12.0)
    mv   a0, s3
    mv   a1, s4
    call check_relative_error_1pct
    beqz a0, arith_mul_fail
    la   a0, str_mul_pass
    call print_string
    call increment_passed
    j    arith_div_test
arith_mul_fail:
    la   a0, str_mul_fail
    call print_string
    call increment_failed
    j    arith_div_test     # keep running the remaining tests

    # ---------------------------
    # Test 4: Division 10.0 / 2.0 = 5.0
    # ---------------------------
arith_div_test:
    li   a0, 0x4120         # bf16(10.0)
    #call f32_to_bf16
    mv   s0, a0
    li   a0, 0x4000         # bf16(2.0)
    #call f32_to_bf16
    mv   s1, a0
    mv   a0, s0
    mv   a1, s1
    call bf16_div
    mv   s2, a0
    mv   a0, s2
    #call bf16_to_f32
    mv   s3, a0
    li   s4, 0x40A0         # expected = bf16(5.0)
    mv   a0, s3
    mv   a1, s4
    call check_relative_error_1pct
    beqz a0, arith_div_fail
    la   a0, str_div_pass
    call print_string
    call increment_passed
    j    arith_sqrt_test
arith_div_fail:
    la   a0, str_div_fail
    call print_string
    call increment_failed
    j    arith_sqrt_test    # keep running the remaining tests

    # ---------------------------
    # Test 5: sqrt(9.0) = 3.0
    # ---------------------------
arith_sqrt_test:
    li   a0, 0x4110         # bf16(9.0)
    call bf16_sqrt
    mv   s1, a0
    li   s4, 0x4040         # expected = bf16(3.0)
    mv   a0, s1
    mv   a1, s4
    call check_relative_error_1pct
    beqz a0, arith_sqrt_fail
    la   a0, str_sqrt_pass
    call print_string
    call increment_passed
    j    arith_done
arith_sqrt_fail:
    la   a0, str_sqrt_fail
    call print_string
    call increment_failed
    j    arith_done

    # ---------------------------
    # All arithmetic tests done
    # ---------------------------
arith_done:
    lw   s5, 4(sp)
    lw   s4, 8(sp)
    lw   s3, 12(sp)
    lw   s2, 16(sp)
    lw   s1, 20(sp)
    lw   s0, 24(sp)
    lw   ra, 28(sp)
    addi sp, sp, 32
    ret

print_string:
    li   a7, 4
    ecall
    ret

print_int:
    li   a7, 1
    ecall
    ret

increment_passed:
    la   t0, tests_passed
    lw   t1, 0(t0)
    addi t1, t1, 1
    sw   t1, 0(t0)
    ret

increment_failed:
    la   t0, tests_failed
    lw   t1, 0(t0)
    addi t1, t1, 1
    sw   t1, 0(t0)
    ret

print_summary:
    addi sp, sp, -16
    sw   ra, 12(sp)
    la   a0, str_summary
    call print_string
    la   a0, str_passed_count
    call print_string
    la   t0, tests_passed
    lw   a0, 0(t0)
    call print_int
    la   a0, str_newline
    call print_string
    la   a0, str_failed_count
    call print_string
    la   t0, tests_failed
    lw   a0, 0(t0)
    call print_int
    la   a0, str_newline
    call print_string
    lw   ra, 12(sp)
    addi sp, sp, 16
    ret

# ----------------------------------------------------------------------------
# check_relative_error_1pct
# Input:  a0 = actual bf16 bits, a1 = expected bf16 bits
# Output: a0 = 1 if the two bit patterns differ by at most 1 ULP, else 0
# Both operands are 16-bit bf16 values, so they are compared directly.
# For normal values, 1 ULP of bf16 is below 1% relative error.
# ----------------------------------------------------------------------------
check_relative_error_1pct:
    sub  t2, a0, a1         # difference of the bf16 bit patterns
    bgez t2, crp_pos
    sub  t2, x0, t2         # abs
crp_pos:
    li   t3, 1
    ble  t2, t3, crp_ok
    li   a0, 0
    ret
crp_ok:
    li   a0, 1
    ret
```

#### Assembly code : Arithmetic functions

* Integrates bf16_add, bf16_sub, bf16_mul, bf16_div, bf16_sqrt from w1-bfloat16.s

```
# ============================================================================
# - f32_to_bf16
# - bf16_to_f32
# - bf16_add
# - bf16_sub
# - bf16_mul
# - bf16_div
# - bf16_sqrt
# ============================================================================

# ============================================================================
# f32_to_bf16
# Input:  a0 = 32-bit float bits
# Output: a0 = 16-bit bf16 bits
# ============================================================================
f32_to_bf16:
    # callee save
    addi sp, sp, -8
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    mv   s0, a0             # preserve the original argument in s0

    # Extract exponent: (f32bits >> 23) & 0xFF
    li   t0, 0xFF
    srli t2, s0, 23
    and  t1, t2, t0         # t1 = exponent
    beq  t1, t0, f32_exception   # exponent == 0xFF: NaN/Inf path

    # Normal case: round to nearest even
    srli t0, s0, 16         # t0 = f32 >> 16
    addi t1, x0, 1
    and  t0, t0, t1         # t0 = (f32 >> 16) & 1
    li   t1, 0x7FFF
    add  t0, t0, t1         # t0 = rounding bias (0x7FFF or 0x8000)
    add  t0, s0, t0         # add bias
    srli a0, t0, 16         # top 16 bits -> bf16
    beq  x0, x0, f32_to_bf16_end

f32_exception:
    srli a0, s0, 16         # Inf/NaN: just take the top 16 bits

f32_to_bf16_end:
    lw   ra, 0(sp)
    lw   s0, 4(sp)
    addi sp, sp, 8
    ret

# ============================================================================
# bf16_to_f32
# Input:  a0 = 16-bit bf16 bits
# Output: a0 = 32-bit float bits
# ============================================================================
bf16_to_f32:
    slli a0, a0, 16
    ret

# ============================================================================
# bf16_add: BF16 addition (bit-accurate software emulation)
# Input:  a0 = bf16 a, a1 = bf16 b
# Output: a0 = bf16 (a + b)
# ============================================================================
bf16_add:
    # callee save
    addi sp, sp, -28
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)
    sw   s4, 20(sp)
    sw   s5, 24(sp)

    # Decompose a and b into sign / exponent / mantissa
    srli s0, a0, 7          # s0 = a >> 7
    andi s0, s0, 0xFF       # s0 = exponent a
    srli s1, a1, 7          # s1 = b >> 7
    andi s1, s1, 0xFF       # s1 = exponent b
    andi s2, a0, 0x7F       # s2 = mantissa a
    andi s3, a1, 0x7F       # s3 = mantissa b
    srli s4, a0, 15         # s4 = sign a
    srli s5, a1, 15         # s5 = sign b
    li   t6, 0xFF           # t6 = 0xFF

    # a is +-inf / NaN
    bne  t6, s0, a_actual_num    # exp_a != 0xFF: a is finite
    bne  s2, x0, end_add         # a is NaN: return a
    beq  s1, t6, a_b_exp_FF      # exp_b == 0xFF: handle inf/NaN b
    jal  x0, end_add             # a = +-inf, b finite: return a
a_b_exp_FF:                      # both exponents are 0xFF
    bne  s3, x0, add_return_b    # b is NaN: return b
    beq  s4, s5, end_add         # inf + inf with the same sign: return a
    jal  x0, return_NAN          # inf + (-inf) = NaN
a_actual_num:
    beq  s1, t6, add_return_b    # exp_b == 0xFF: return b
    # a == +-0: return b
    slli t1, a0, 17              # shift the sign bit (bit 15) out
    beq  x0, t1, add_return_b
    # b == +-0: return a
    slli t1, a1, 17
    beq  x0, t1, end_add

    # a and b are both finite, non-zero numbers
add_adjust_mant_a:
    beq  x0, s0, add_adjust_mant_b   # exp_a == 0 (denormal): no hidden bit
    addi s2, s2, 0x80       # restore the hidden 1 of mantissa a
add_adjust_mant_b:
    beq  x0, s1, add_main   # exp_b == 0 (denormal): no hidden bit
    addi s3, s3, 0x80       # restore the hidden 1 of mantissa b

##### ADD MAIN #####
# s0 = result exp
# s1 = result mantissa
# s2 = mantissa a
# s3 = mantissa b
# s4 = result sign
# t6 = 0xFF
add_main:
    ### mantissa alignment ###
    sub  t0, s0, s1         # t0 = exp_diff = exp_a - exp_b
    li   t1, 8              # t1 = 8
    mv   t3, s2             # t3 = mant_a (to be aligned)
    mv   t4, s3             # t4 = mant_b (to be aligned)
    blt  t1, t0, end_add    # exp_diff > 8: |a| dominates, return a
    sub  t2, x0, t0         # t2 = -exp_diff
    blt  t1, t2, add_return_b   # exp_diff < -8: |b| dominates, return b
    beq  t0, x0, true_add   # exp_diff == 0: result exp = exp_a (already in s0)
    blt  t0, x0, add_align_a    # exp_diff < 0: shift mant_a instead
    # 0 < exp_diff <= 8
    srl  t4, s3, t0         # mant_b >>= exp_diff; result exp = exp_a
    jal  x0, true_add
add_align_a:                # -8 <= exp_diff < 0
    srl  t3, s2, t2         # mant_a >>= -exp_diff
    mv   s0, s1             # result exp = exp_b

true_add:
    mv   s2, t3             # s2 = aligned mant_a
    mv   s3, t4             # s3 = aligned mant_b
    bne  s4, s5, add_diff_sign

add_same_sign:
    slli s4, s4, 15         # s4 (result sign) = sign_a << 15
    add  s1, s2, s3         # s1 (result mantissa) = mant_a + mant_b
    andi t0, s1, 0x100      # t0 = overflow bit
    beq  t0, x0, add_result # no overflow: build the result
add_handle_overflow:
    srli s1, s1, 1          # result_mant >>= 1
    addi s0, s0, 1          # result_exp += 1
    blt  s0, t6, add_result # result_exp < 0xFF: still finite
    mv   s0, t6             # otherwise result_exp = 0xFF
    add  s1, x0, x0         # and result_mant = 0
    jal  x0, add_result     # overflow: return inf with the correct sign

add_diff_sign:
    blt  s2, s3, add_mant_b_larger
    slli s4, s4, 15         # mant_a >= mant_b: result sign = sign_a << 15
    sub  s1, s2, s3         # result_mant = mant_a - mant_b
    jal  x0, add_handle_zero
add_mant_b_larger:
    slli s4, s5, 15         # mant_a < mant_b: result sign = sign_b << 15
    sub  s1, s3, s2         # result_mant = mant_b - mant_a

    # A zero difference must be handled before normalization;
    # otherwise the loop below would never find a set bit.
add_handle_zero:
    beq  s1, x0, add_return_zero
add_adjust_mantissa:
    andi t1, s1, 0x80       # t1 = leading (hidden) bit of result_mant
    bne  t1, x0, add_result # leading bit is 1: normalized
    slli s1, s1, 1          # result_mant <<= 1
    addi s0, s0, -1         # result_exp -= 1
    bge  x0, s0, add_return_zero   # result_exp <= 0: underflow
    jal  x0, add_adjust_mantissa

# s0 = result exp
# s1 = result mantissa
# s4 = result sign (already shifted)
add_result:
    andi s0, s0, 0xFF       # keep the low 8 bits of the exponent
    slli s0, s0, 7          # move the exponent into bits 14..7
    andi s1, s1, 0x7F       # drop the hidden 1 of the mantissa (1.XX)
    or   a0, s4, s0         # result_sign | result_exp
    or   a0, a0, s1         # result_sign | result_exp | result_mant
    jal  x0, end_add
### end of the bf16_add main path

add_return_zero:
    mv   a0, x0
    jal  x0, end_add
add_return_b:
    mv   a0, a1
    jal  x0, end_add
return_NAN:
    lw   a0, BF16_NAN
    jal  x0, end_add
end_add:
    # restore ra and callee-saved registers
    lw   s5, 24(sp)
    lw   s4, 20(sp)
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 28
    ret

# ============================================================================
# bf16_sub: flips the sign of b and calls bf16_add
# Input:  a0 = a, a1 = b
# Output: a0 = a - b
# ============================================================================
bf16_sub:
    addi sp, sp, -4
    sw   ra, 0(sp)
    li   t0, 0x8000         # BF16_SIGN_MASK
    xor  a1, a1, t0         # flip the sign bit of a1 (b)
    jal  ra, bf16_add       # a - b = a + (-b)
    lw   ra, 0(sp)
    addi sp, sp, 4
    ret

# ============================================================================
# bf16_mul: bit-accurate multiplication
# Input:  a0 = a (bf16), a1 = b (bf16)
# Output: a0 = a * b (bf16)
# ============================================================================
bf16_mul:
    # callee save
    addi sp, sp, -28
    sw   ra, 0(sp)
    sw   s0, 4(sp)          # s0 = exponent a
    sw   s1, 8(sp)          # s1 = exponent b
    sw   s2, 12(sp)         # s2 = mantissa a
    sw   s3, 16(sp)         # s3 = mantissa b
    sw   s4, 20(sp)         # s4 = sign a
    sw   s5, 24(sp)         # s5 = sign b

    # Decompose a and b into sign / exponent / mantissa
    srli s0, a0, 7
    andi s0, s0, 0xFF       # s0 = exp_a = (a.bits >> 7) & 0xFF
    srli s1, a1, 7
    andi s1, s1, 0xFF       # s1 = exp_b = (b.bits >> 7) & 0xFF
    andi s2, a0, 0x7F       # s2 = mant_a = a.bits & 0x7F
    andi s3, a1, 0x7F       # s3 = mant_b = b.bits & 0x7F
    srli s4, a0, 15         # s4 = sign a
    srli s5, a1, 15         # s5 = sign b

    # t3 = result mant
    # t4 = result exp
    # t5 = result sign
    # t6 = 0xFF
    xor  t5, s4, s5         # t5 (result sign) = sign_a ^ sign_b
    addi t6, x0, 0xFF       # t6 = 0xFF

    # check exp_a
    bne  s0, t6, mul_check_exp_b   # exp_a != 0xFF: jump
    bne  s2, x0, end_mul           # a is NaN: return a
    # a is +-inf
    li   t0, 0x7FFF                # mask without the sign bit
    and  t1, a1, t0                # t1 = b & 0x7FFF
    beq  t1, x0, mul_return_NAN    # inf * 0 = NaN
    jal  x0, mul_return_inf        # otherwise return inf with the correct sign

mul_check_exp_b:
    bne  s1, t6, mul_check_zero    # exp_b != 0xFF: jump
    bne  s3, x0, mul_return_b      # b is NaN: return b
    # b is +-inf
    li   t0, 0x7FFF
    and  t1, a0, t0                # t1 = a & 0x7FFF
    beq  t1, x0, mul_return_NAN    # 0 * inf = NaN
    jal  x0, mul_return_inf

    ### a, b are both finite
    # a == 0 or b == 0: return zero with the correct sign
mul_check_zero:
    or   t0, s0, s2                # t0 = exp_a | mant_a
    beq  t0, x0, mul_return_zero
    or   t0, s1, s3                # t0 = exp_b | mant_b
    beq  t0, x0, mul_return_zero

    ### a, b are both non-zero finite numbers
    # normalize the mantissas, then multiply
mul_main:
    add  t0, x0, x0         # t0 = exp_adjust = 0
mul_adjust_a:
    beq  s0, x0, mul_denormal_a    # exp_a == 0: a is denormal
    ori  s2, s2, 0x80       # normal a: restore 1.XXX in mant_a
    jal  x0, mul_adjust_b
mul_denormal_a:             # find the first set bit
    andi t1, s2, 0x80       # t1 = mant_a & 0x80
    bne  t1, x0, mul_mant_a_aligned
    slli s2, s2, 1          # mant_a <<= 1
    addi t0, t0, -1         # exp_adjust--
    jal  x0, mul_denormal_a
mul_mant_a_aligned:
    addi s0, x0, 1          # first set bit found: exp_a = 1
mul_adjust_b:
    beq  s1, x0, mul_denormal_b    # exp_b == 0: b is denormal
    ori  s3, s3, 0x80       # normal b: restore 1.XXX in mant_b
    jal  x0, result_exp
mul_denormal_b:             # find the first set bit
    andi t1, s3, 0x80       # t1 = mant_b & 0x80
    bne  t1, x0, mul_mant_b_aligned
    slli s3, s3, 1          # mant_b <<= 1
    addi t0, t0, -1         # exp_adjust--
    jal  x0, mul_denormal_b
mul_mant_b_aligned:
    addi s1, x0, 1          # first set bit found: exp_b = 1

# mantissas are non-zero positive integers, so just multiply
# t0 = adjust_exp      s0 = exponent a
# t3 = result mant     s1 = exponent b
# t4 = result exp      s2 = mantissa a
# t5 = result sign     s3 = mantissa b
# t6 = 0xFF            s4 = sign a
#                      s5 = sign b
result_exp:
    add  t4, s0, s1         # result_exp = exp_a + exp_b
    addi t4, t4, -127       # result_exp = exp_a + exp_b - 127
    add  t4, t4, t0         # result_exp = exp_a + exp_b - 127 + exp_adjust

    # t3 = result_mant = (uint32_t) mant_a * mant_b (shift-and-add)
    # while (mant_b != 0) {
    #     if (mant_b & 1) result_mant += mant_a;
    #     mant_a <<= 1; mant_b >>= 1;
    # }
    add  t3, x0, x0         # result_mant = 0
true_mul:
    beq  x0, s3, mul_get_result    # while (mant_b != 0)
    andi t2, s3, 1          # t2 = mant_b & 1
    beq  x0, t2, mul_skip   # lsb is 0: do not add
    add  t3, t3, s2         # result_mant += (shifted) mant_a
mul_skip:
    srli s3, s3, 1          # mant_b >>= 1
    slli s2, s2, 1          # mant_a <<= 1
    jal  x0, true_mul

mul_get_result:
    # check for mantissa overflow (product >= 2.0)
    li   t0, 0x8000         # overflow mask
    and  t0, t0, t3         # t0 = result_mant & 0x8000
    beq  t0, x0, mul_result_mant_adjust
    srli t3, t3, 8
    andi t3, t3, 0x7F       # result_mant = (result_mant >> 8) & 0x7F
    addi t4, t4, 1          # result_exp++
    jal  x0, mul_result_exp_adjust
mul_result_mant_adjust:     # no mantissa overflow
    srli t3, t3, 7
    andi t3, t3, 0x7F       # result_mant = (result_mant >> 7) & 0x7F
mul_result_exp_adjust:
    blt  t4, t6, mul_result_check_denormal   # result_exp < 0xFF: finite
    jal  x0, mul_return_inf
mul_result_check_denormal:
    blt  x0, t4, mul_result        # result_exp > 0: normal
    li   t0, -6
    bge  t4, t0, mul_result_handle_denormal
    jal  x0, mul_return_zero       # result_exp < -6: underflow to zero
mul_result_handle_denormal:
    addi t0, t4, -1
    sub  t0, x0, t0         # t0 = 1 - result_exp
    srl  t3, t3, t0         # result_mant >>= (1 - result_exp)
    add  t4, x0, x0         # result_exp = 0
mul_result:
    slli t5, t5, 15         # result_sign << 15
    and  t4, t4, t6         # result_exp &= 0xFF
    slli t4, t4, 7          # (result_exp & 0xFF) << 7
    andi t3, t3, 0x7F       # result_mant & 0x7F
    or   a0, t5, t4
    or   a0, a0, t3
    jal  x0, end_mul

mul_return_inf:
    slli t5, t5, 15
    li   t0, 0x7F80
    or   a0, t5, t0         # (result_sign << 15) | 0x7F80
    jal  x0, end_mul
mul_return_zero:
    slli a0, t5, 15         # result_sign << 15
    jal  x0, end_mul
mul_return_b:
    mv   a0, a1
    jal  x0, end_mul
mul_return_NAN:
    lw   a0, BF16_NAN
    jal  x0, end_mul
end_mul:
    # restore ra and callee-saved registers
    lw   s5, 24(sp)
    lw   s4, 20(sp)
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 28
    ret

# ---------------------------------------------------------------------------
# bf16_div: bit-accurate division
# Input:  a0 = a (bf16), a1 = b (bf16)
# Output: a0 = a / b (bf16)
# ---------------------------------------------------------------------------
bf16_div:
    # callee save
    addi sp, sp, -28
    sw   ra, 0(sp)
    sw   s0, 4(sp)          # s0 = exponent a
    sw   s1, 8(sp)          # s1 = exponent b
    sw   s2, 12(sp)         # s2 = mantissa a
    sw   s3, 16(sp)         # s3 = mantissa b
    sw   s4, 20(sp)         # s4 = sign a
    sw   s5, 24(sp)         # s5 = sign b

    # Decompose a and b into sign / exponent / mantissa
    srli s0, a0, 7
    andi s0, s0, 0xFF       # s0 = exp_a = (a.bits >> 7) & 0xFF
    srli s1, a1, 7
    andi s1, s1, 0xFF       # s1 = exp_b = (b.bits >> 7) & 0xFF
    andi s2, a0, 0x7F       # s2 = mant_a = a.bits & 0x7F
    andi s3, a1, 0x7F       # s3 = mant_b = b.bits & 0x7F
    srli s4, a0, 15         # s4 = sign a
    srli s5, a1, 15         # s5 = sign b

    xor  t5, s4, s5         # t5 (result sign) = sign_a ^ sign_b
    addi t6, x0, 0xFF       # t6 = 0xFF

div_check_exp_b:            # exp_b == 0xFF
    bne  s1, t6, div_check_zero_b    # exp_b != 0xFF: jump
    bne  s3, x0, div_return_b        # b is NaN: return b
    # b is +-inf: inf / inf = NaN, anything else / inf = +-0
    bne  s0, t6, div_return_sign_zero
    bne  s2, x0, div_return_sign_zero
    jal  x0, div_return_NAN

div_check_zero_b:           # (!exp_b && !mant_b)
    bne  s1, x0, div_check_exp_a     # exp_b != 0: b != 0
    bne  s3, x0, div_check_exp_a     # mant_b != 0: b != 0
    # b == 0: 0 / 0 = NaN, otherwise a / 0 = +-inf
    bne  s0, x0, div_return_sign_inf
    bne  s2, x0, div_return_sign_inf
    jal  x0, div_return_NAN

div_check_exp_a:
    bne  s0, t6, div_check_zero_a    # exp_a != 0xFF: jump
    bne  s2, x0, end_div             # a is NaN: return a
    jal  x0, div_return_sign_inf     # inf / b = +-inf

div_check_zero_a:           # (!exp_a && !mant_a)
    bne  s0, x0, div_get_mant_a      # exp_a != 0: a != 0
    bne  s2, x0, div_get_mant_a      # mant_a != 0: a != 0
    # 0 / 0 was already handled above
    jal  x0, div_return_sign_zero    # 0 / b = +-0

    ### Division does not need the mantissas aligned to a common exponent
div_get_mant_a:
    beq  x0, s0, div_get_mant_b      # exp_a == 0 (denormal): no hidden bit
    ori  s2, s2, 0x80       # restore 1.XX
div_get_mant_b:
    beq  x0, s1, div_main            # exp_b == 0 (denormal): no hidden bit
    ori  s3, s3, 0x80       # restore 1.XX

# t1 = i                       s0 = exponent a
# t2 = divisor << (15 - i)     s1 = exponent b
# t4 = result exp              s2 = mantissa a
# t5 = result sign             s3 = mantissa b
# t6 = quotient                s4 = dividend
#                              s5 = divisor
div_main:
    slli s4, s2, 15         # dividend = (uint32_t) mant_a << 15
    add  s5, x0, s3         # divisor = mant_b
    add  t6, x0, x0         # quotient = 0
    addi t0, x0, 16         # t0 = 16 (quotient bits)
    add  t1, x0, x0         # i = 0
true_div:
    bge  t1, t0, div_get_result      # i >= 16: end loop
    slli t6, t6, 1          # quotient <<= 1
    addi t2, x0, 15
    sub  t2, t2, t1         # t2 = 15 - i
    sll  t2, s5, t2         # t2 = divisor << (15 - i)
    bltu s4, t2, div_i_plus_one      # dividend < shifted divisor: bit is 0
    sub  s4, s4, t2         # dividend -= divisor << (15 - i)
    ori  t6, t6, 1          # quotient |= 1
div_i_plus_one:
    addi t1, t1, 1          # i++
    jal  x0, true_div

div_get_result:
    sub  t4, s0, s1         # result_exp = exp_a - exp_b
    addi t4, t4, 127        # result_exp = exp_a - exp_b + BF16_EXP_BIAS
div_zero_exp_a_correction:  # if (!exp_a) result_exp--
    bne  x0, s0, div_zero_exp_b_correction
    addi t4, t4, -1
div_zero_exp_b_correction:  # if (!exp_b) result_exp++
    bne  x0, s1, div_result_mant_adjust
    addi t4, t4, 1

div_result_mant_adjust:
    # while (!(quotient & 0x8000) && result_exp > 1)
    li   t0, 0x8000
    and  t0, t6, t0         # t0 = quotient & 0x8000
    bne  t0, x0, div_result_mant_shift
    addi t1, t4, -1         # t1 = result_exp - 1
    bge  x0, t1, div_result_mant_shift   # result_exp <= 1: stop
    slli t6, t6, 1          # quotient <<= 1
    addi t4, t4, -1         # result_exp--
    jal  x0, div_result_mant_adjust
div_result_mant_shift:
    srli t6, t6, 8          # quotient >>= 8
    li   t0, 0xFF
    bge  t4, t0, div_return_sign_inf     # result_exp >= 0xFF: overflow
    bge  x0, t4, div_return_sign_zero    # result_exp <= 0: underflow
div_result:
    slli t5, t5, 15         # result_sign << 15
    and  t4, t4, t0         # result_exp &= 0xFF
    slli t4, t4, 7          # (result_exp & 0xFF) << 7
    andi t6, t6, 0x7F       # quotient & 0x7F
    or   a0, t5, t4
    or   a0, a0, t6
    jal  x0, end_div

div_return_sign_zero:
    slli a0, t5, 15         # result_sign << 15
    jal  x0, end_div
div_return_sign_inf:
    slli a0, t5, 15         # result_sign << 15
    li   t0, 0x7F80         # t0 = 0x7F80
    or   a0, a0, t0         # (result_sign << 15) | 0x7F80
    jal  x0, end_div
div_return_NAN:
    lw   a0, BF16_NAN
    jal  x0, end_div
div_return_b:
    mv   a0, a1
    jal  x0, end_div
end_div:
    # restore ra and callee-saved registers
    lw   s5, 24(sp)
    lw   s4, 20(sp)
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 28
    ret

# ---------------------------------------------------------------------------
# bf16_sqrt: bit-accurate square root (binary search)
# Input:  a0 = bf16
# Output: a0 = bf16(sqrt(a))
# ---------------------------------------------------------------------------
bf16_sqrt:
    addi sp, sp, -20
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)

    li   s3, 0xFF
    srli s0, a0, 7
    and  s0, s0, s3         # s0 = exp
    andi s1, a0, 0x7F       # s1 = mant
    srli s2, a0, 15         # s2 = sign

    # special cases
    bne  s0, s3, sqrt_ck_input_0
    bne  s1, x0, end_sqrt          # NaN: return a
    bne  s2, x0, sqrt_ret_nan      # sqrt(-inf) = NaN
    jal  x0, end_sqrt              # sqrt(+inf) = +inf
sqrt_ck_input_0:
    or   t0, s0, s1
    beq  t0, x0, sqrt_ret_zero     # sqrt(+-0) = 0
    bne  s2, x0, sqrt_ret_nan      # negative input: NaN
    beq  s0, x0, sqrt_ret_zero     # denormals are flushed to zero

sqrt_main:
    addi s0, s0, -127       # e = exp - BF16_EXP_BIAS
    ori  s1, s1, 0x80       # m = 0x80 | mant, in [128, 256)
    andi t0, s0, 0x1
    beq  t0, x0, handle_even_exp
handle_odd_exp:             # sqrt(2^e * m) = 2^((e - 1) / 2) * sqrt(2 * m)
    slli s1, s1, 1          # m <<= 1
    addi t6, s0, -1
    srai t6, t6, 1          # (e - 1) >> 1 (arithmetic shift keeps the sign)
    addi t6, t6, 127        # new_exp = ((e - 1) >> 1) + bias
    jal  x0, true_sqrt
handle_even_exp:
    srai t6, s0, 1          # e >> 1
    addi t6, t6, 127        # new_exp = (e >> 1) + bias

true_sqrt:                  # binary search for the integer square root
    li   t3, 90             # low
    li   t4, 256            # high
    li   t5, 128            # result (default)
binary_search:
    blt  t4, t3, sqrt_normalized_result   # while (low <= high)
    add  t1, t3, t4
    srli t1, t1, 1          # mid = (low + high) >> 1
    mv   a0, t1
    mv   a1, t1
    jal  ra, int_mul        # a0 = mid * mid (int_mul clobbers t0 and t2)
    srli t2, a0, 7          # sq = (mid * mid) / 128
    blt  s1, t2, sqrt_too_big         # sq > m
    mv   t5, t1             # result = mid
    addi t3, t1, 1          # low = mid + 1
    jal  x0, binary_search
sqrt_too_big:
    addi t4, t1, -1         # high = mid - 1
    jal  x0, binary_search

sqrt_normalized_result:     # normalize result into [128, 256)
    li   t0, 256
    blt  t5, t0, sqrt_exp_adjust_loop
    srli t5, t5, 1          # result >= 256: result >>= 1
    addi t6, t6, 1          # new_exp++
    jal  x0, remove_result_mant_one
sqrt_exp_adjust_loop:       # while (result < 128 && new_exp > 1)
    li   t0, 128
    li   t1, 1
    bge  t5, t0, remove_result_mant_one
    bge  t1, t6, remove_result_mant_one
    slli t5, t5, 1          # result <<= 1
    addi t6, t6, -1         # new_exp--
    jal  x0, sqrt_exp_adjust_loop
remove_result_mant_one:
    andi t5, t5, 0x7F       # drop the implicit 1
    bge  t6, s3, sqrt_ret_inf     # new_exp >= 0xFF: +inf
    bge  x0, t6, sqrt_ret_zero    # new_exp <= 0: zero
sqrt_get_result:
    and  a0, t6, s3
    slli a0, a0, 7
    or   a0, a0, t5         # ((new_exp & 0xFF) << 7) | new_mant
    jal  x0, end_sqrt
sqrt_ret_inf:
    li   a0, 0x7F80
    jal  x0, end_sqrt
sqrt_ret_zero:
    lw   a0, BF16_ZERO
    jal  x0, end_sqrt
sqrt_ret_nan:
    lw   a0, BF16_NAN
    jal  x0, end_sqrt
end_sqrt:
    lw   s3, 16(sp)
    lw   s2, 12(sp)
    lw   s1, 8(sp)
    lw   s0, 4(sp)
    lw   ra, 0(sp)
    addi sp, sp, 20
    ret

# ---------------------------------------------------------------------------
# int_mul: small integer multiply used by sqrt (shift-and-add)
# Input:  a0 = multiplicand, a1 = multiplier
# Output: a0 = a0 * a1
# ---------------------------------------------------------------------------
int_mul:
    add  t0, x0, x0         # product = 0
int_mul_loop:
    beq  x0, a1, end_int_mul
    andi t2, a1, 1
    beq  x0, t2, int_mul_skip
    add  t0, t0, a0         # product += multiplicand
int_mul_skip:
    srli a1, a1, 1          # multiplier >>= 1
    slli a0, a0, 1          # multiplicand <<= 1
    jal  x0, int_mul_loop
end_int_mul:
    mv   a0, t0
    ret

# ============================================================================
# End of file
# ============================================================================
```

### Conversions, special values, arithmetic, comparisons, edge cases

#### Basic Conversions Test

| Test # | Input (F32) | F32 Hex | Expected Behavior | Test Item |
|--------|-------------|---------|-------------------|-----------|
| 1 | 0.0f | 0x00000000 | Sign match + Error check | Positive zero |
| 2 | 1.0f | 0x3F800000 | Sign match + Error check | Positive number |
| 3 | -1.0f | 0xBF800000 | Sign match + Error check | Negative number |
| 4 | 2.0f | 0x40000000 | Sign match + Error check | Power of 2 |
| 5 | -2.0f | 0xC0000000 | Sign match + Error check | Negative power of 2 |
| 6 | 0.5f | 0x3F000000 | Sign match + Error check | Fraction |
| 7 | -0.5f | 0xBF000000 | Sign match + Error check | Negative fraction |
| 8 | 3.14159f | 0x40490FDB | Sign match + Error check | π |
| 9 | -3.14159f | 0xC0490FDB | Sign match + Error check | -π |
| 10 | 1e10f | 0x501502F9 | Sign match + Error check | Large number |
| 11 | -1e10f | 0xD01502F9 | Sign match + Error check | Large negative |

#### Special Values Test

| Test # | Input (F32) | F32 Hex | BF16 Expected | Test Function | Expected Result |
|--------|-------------|---------|---------------|---------------|-----------------|
| 1a | +∞ | 0x7F800000 | 0x7F80 | `bf16_isinf()` | Return 1 |
| 1b | +∞ | 0x7F800000 | 0x7F80 | `bf16_isnan()` | Return 0 |
| 2 | -∞ | 0xFF800000 | 0xFF80 | `bf16_isinf()` | Return 1 |
| 3a | NaN | 0x7FC00000 | 0x7FC0 | `bf16_isnan()` | Return 1 |
| 3b | NaN | 0x7FC00000 | 0x7FC0 | `bf16_isinf()` | Return 0 |
| 4 | +0.0 | 0x00000000 | 0x0000 | `bf16_iszero()` | Return 1 |
| 5 | -0.0 | 0x80000000 | 0x8000 | `bf16_iszero()` | Return 1 |

#### Assembly code :

```
# ============================================================================
# BFloat16 Complete Test
# ============================================================================
# Tests: conversions, special values, arithmetic, comparisons, edge cases
# ============================================================================

.data
# Test values for basic conversions (Float32 as uint32_t bits)
test_values:
    .word 0x00000000        # 0.0f
    .word 0x3F800000        # 1.0f
    .word 0xBF800000        # -1.0f
    .word 0x40000000        # 2.0f
    .word 0xC0000000        # -2.0f
    .word 0x3F000000        # 0.5f
    .word 0xBF000000        # -0.5f
    .word 0x40490FDB        # 3.14159f
    .word 0xC0490FDB        # -3.14159f
    .word 0x501502F9        # 1e10f
    .word 0xD01502F9        # -1e10f
test_count: .word 11

# Expected results for arithmetic
expected_add: .word 0x40400000  # 3.0f (1.0 + 2.0)
expected_sub: .word 0x3F800000  # 1.0f (2.0 - 1.0)
expected_mul: .word 0x41400000
# 12.0f (3.0 * 4.0) expected_div: .word 0x40A00000 # 5.0f (10.0 / 2.0) expected_sqrt4: .word 0x40000000 # 2.0f (sqrt(4.0)) expected_sqrt9: .word 0x40400000 # 3.0f (sqrt(9.0)) # Test counters tests_passed: .word 0 tests_failed: .word 0 # Output strings str_banner: .string "\n=== BFloat16 Test Suite ===\n\n" str_test_basic: .string "Testing basic conversions...\n" str_test_special: .string "Testing special values...\n" str_test_arith: .string "Testing arithmetic operations...\n" str_test_comp: .string "Testing comparison operations...\n" str_test_edge: .string "Testing edge cases...\n" str_test_round: .string "Testing rounding behavior...\n" str_pass: .string " PASS\n" str_fail: .string " FAIL\n" str_test_num: .string " Test " str_colon: .string ": " str_summary: .string "\n=== Test Summary ===\n" str_passed_count: .string "Tests passed: " str_failed_count: .string "Tests failed: " str_all_pass: .string "\n=== ALL TESTS PASSED ===\n" str_some_fail: .string "\n=== SOME TESTS FAILED ===\n" str_newline: .string "\n" # Test names for detailed output str_sign_mismatch: .string " [Sign mismatch]\n" str_error_large: .string " [Relative error too large]\n" str_inf_detect: .string " [Infinity detection]\n" str_nan_detect: .string " [NaN detection]\n" str_zero_detect: .string " [Zero detection]\n" str_add_fail: .string " [Addition failed]\n" str_sub_fail: .string " [Subtraction failed]\n" str_mul_fail: .string " [Multiplication failed]\n" str_div_fail: .string " [Division failed]\n" str_sqrt_fail: .string " [sqrt failed]\n" .text .globl main # ============================================================================ # MAIN: Entry point # ============================================================================ main: addi sp, sp, -16 sw ra, 12(sp) # Print banner la a0, str_banner call print_string # Run all test suites call test_basic_conversions call test_special_values # Print summary call print_summary # Check if any failed la t0, tests_failed lw t0, 0(t0) beqz t0, 
main_all_pass la a0, str_some_fail call print_string li a0, 1 j main_exit main_all_pass: la a0, str_all_pass call print_string li a0, 0 main_exit: lw ra, 12(sp) addi sp, sp, 16 li a7, 10 ecall # ============================================================================ # TEST_BASIC_CONVERSIONS # ============================================================================ test_basic_conversions: addi sp, sp, -32 sw ra, 28(sp) sw s0, 24(sp) # s0 = loop counter sw s1, 20(sp) # s1 = test_values pointer sw s2, 16(sp) # s2 = test_count sw s3, 12(sp) # s3 = original f32 sw s4, 8(sp) # s4 = bf16 result sw s5, 4(sp) # s5 = converted f32 la a0, str_test_basic call print_string la s1, test_values la t0, test_count lw s2, 0(t0) li s0, 0 basic_loop: bge s0, s2, basic_done lw s3, 0(s1) # s3 = original f32 # Convert f32 → bf16 mv a0, s3 call f32_to_bf16 mv s4, a0 # s4 = bf16 # Convert bf16 → f32 mv a0, s4 call bf16_to_f32 mv s5, a0 # s5 = converted f32 # Test 1: Check sign consistency (if not zero) beqz s3, basic_skip_sign mv a0, s3 mv a1, s5 call check_sign_match beqz a0, basic_sign_fail basic_skip_sign: # Test 2: Check relative error (if not zero and not inf) beqz s3, basic_continue mv a0, s3 call is_f32_infinity bnez a0, basic_continue mv a0, s3 mv a1, s5 call check_relative_error_1pct beqz a0, basic_error_fail basic_continue: # Both tests passed for this value addi s1, s1, 4 addi s0, s0, 1 j basic_loop basic_sign_fail: # Sign test failed - report and exit immediately la a0, str_sign_mismatch call print_string call increment_failed j basic_fail_exit basic_error_fail: # Error test failed - report and exit immediately la a0, str_error_large call print_string call increment_failed j basic_fail_exit basic_fail_exit: # Exit without printing PASS lw s5, 4(sp) lw s4, 8(sp) lw s3, 12(sp) lw s2, 16(sp) lw s1, 20(sp) lw s0, 24(sp) lw ra, 28(sp) addi sp, sp, 32 ret basic_done: # All tests passed call increment_passed la a0, str_pass call print_string lw s5, 4(sp) lw s4, 8(sp) lw s3, 
12(sp) lw s2, 16(sp) lw s1, 20(sp) lw s0, 24(sp) lw ra, 28(sp) addi sp, sp, 32 ret # ============================================================================ # TEST_SPECIAL_VALUES # ============================================================================ test_special_values: addi sp, sp, -16 sw ra, 12(sp) sw s0, 8(sp) sw s1, 4(sp) la a0, str_test_special call print_string # Test 1a: Positive infinity - should be infinity li s0, 0x7F800000 # +inf in f32 mv a0, s0 call f32_to_bf16 mv s1, a0 call bf16_isinf bnez a0, special_test1a_pass la a0, str_inf_detect call print_string call increment_failed j special_test1b special_test1a_pass: call increment_passed special_test1b: # Test 1b: Positive infinity - should NOT be NaN mv a0, s1 call bf16_isnan beqz a0, special_test1b_pass la a0, str_nan_detect call print_string call increment_failed j special_test2 special_test1b_pass: call increment_passed special_test2: # Test 2: Negative infinity li s0, 0xFF800000 # -inf in f32 mv a0, s0 call f32_to_bf16 mv s1, a0 call bf16_isinf bnez a0, special_test2_pass la a0, str_inf_detect call print_string call increment_failed j special_test3a special_test2_pass: call increment_passed special_test3a: # Test 3a: NaN - should be NaN li s0, 0x7FC00000 # NaN in f32 mv a0, s0 call f32_to_bf16 mv s1, a0 call bf16_isnan bnez a0, special_test3a_pass la a0, str_nan_detect call print_string call increment_failed j special_test3b special_test3a_pass: call increment_passed special_test3b: # Test 3b: NaN - should NOT be infinity mv a0, s1 call bf16_isinf beqz a0, special_test3b_pass la a0, str_inf_detect call print_string call increment_failed j special_test4 special_test3b_pass: call increment_passed special_test4: # Test 4: Positive zero li s0, 0x00000000 # +0.0 in f32 mv a0, s0 call f32_to_bf16 mv s1, a0 call bf16_iszero bnez a0, special_test4_pass la a0, str_zero_detect call print_string call increment_failed j special_test5 special_test4_pass: call increment_passed special_test5: # Test 5: 
Negative zero li s0, 0x80000000 # -0.0 in f32 mv a0, s0 call f32_to_bf16 mv s1, a0 call bf16_iszero bnez a0, special_test5_pass la a0, str_zero_detect call print_string call increment_failed j special_done special_test5_pass: call increment_passed special_done: la a0, str_pass call print_string lw s1, 4(sp) lw s0, 8(sp) lw ra, 12(sp) addi sp, sp, 16 ret # ============================================================================ # CONVERSION FUNCTIONS # ============================================================================ # F32_TO_BF16: Convert Float32 to BFloat16 # Input: a0 = float32 bits (32-bit) # Output: a0 = bfloat16 bits (16-bit, in lower 16 bits of a0) f32_to_bf16: # Get exponent (bits 30:23) srli t0, a0, 23 andi t0, t0, 0xFF # Check for special cases (NaN or Inf: exponent == 0xFF) li t1, 0xFF beq t0, t1, f32_special # Normal case: round-to-nearest-even # Rounding bias = (f32 >> 16) & 1 + 0x7FFF srli t2, a0, 16 # Get bit 16 (LSB of BF16) andi t2, t2, 1 # Isolate bit 16 # Create 0x7FFF lui t3, 0x8 # t3 = 0x8000 addi t3, t3, -1 # t3 = 0x7FFF # Add rounding bias add t2, t2, t3 # t2 = 0x7FFF or 0x8000 add a0, a0, t2 # Add bias to original value # Extract top 16 bits (this is the BF16 result) srli a0, a0, 16 ret f32_special: # For special values (NaN/Inf), just truncate srli a0, a0, 16 ret # BF16_TO_F32: Convert BFloat16 to Float32 # Input: a0 = bfloat16 bits (16-bit) # Output: a0 = float32 bits (32-bit) bf16_to_f32: # BF16 is top 16 bits of FP32, just shift left by 16 slli a0, a0, 16 ret # ============================================================================ # SPECIAL VALUE CHECKS # ============================================================================ # BF16_ISINF: Check if BF16 is infinity # Input: a0 = bf16 (16-bit) # Output: a0 = 1 if inf, 0 otherwise bf16_isinf: # Extract exponent (bits 14:7) srli t0, a0, 7 andi t0, t0, 0xFF # Extract mantissa (bits 6:0) andi t1, a0, 0x7F # Is inf if exponent==0xFF and mantissa==0 li t2, 0xFF bne t0, 
t2, isinf_false bnez t1, isinf_false li a0, 1 ret isinf_false: li a0, 0 ret # BF16_ISNAN: Check if BF16 is NaN # Input: a0 = bf16 (16-bit) # Output: a0 = 1 if NaN, 0 otherwise bf16_isnan: # Extract exponent (bits 14:7) srli t0, a0, 7 andi t0, t0, 0xFF # Extract mantissa (bits 6:0) andi t1, a0, 0x7F # Is NaN if exponent==0xFF and mantissa!=0 li t2, 0xFF bne t0, t2, isnan_false beqz t1, isnan_false li a0, 1 ret isnan_false: li a0, 0 ret # BF16_ISZERO: Check if BF16 is zero (including -0) # Input: a0 = bf16 (16-bit) # Output: a0 = 1 if zero, 0 otherwise bf16_iszero: # Zero if all bits except sign are 0 # Create mask 0x7FFF lui t1, 0x8 # t1 = 0x8000 addi t1, t1, -1 # t1 = 0x7FFF and t0, a0, t1 # Mask off sign bit seqz a0, t0 # a0 = (t0 == 0) ? 1 : 0 ret # ============================================================================ # HELPER FUNCTIONS FOR TESTING # ============================================================================ # CHECK_SIGN_MATCH: Check if two f32 have same sign # Input: a0 = f32_1, a1 = f32_2 # Output: a0 = 1 if same sign, 0 otherwise check_sign_match: srli t0, a0, 31 # Get sign bit of a0 srli t1, a1, 31 # Get sign bit of a1 xor t0, t0, t1 # XOR: 0 if same, 1 if different seqz a0, t0 # a0 = 1 if t0==0 (same sign) ret # IS_F32_INFINITY: Check if f32 is infinity # Input: a0 = f32 # Output: a0 = 1 if inf, 0 otherwise is_f32_infinity: # Extract exponent (bits 30:23) srli t0, a0, 23 andi t0, t0, 0xFF # Extract mantissa (bits 22:0) lui t2, 0x80 # Load 0x80000 addi t2, t2, -1 # t2 = 0x7FFFF and t1, a0, t2 # Mask mantissa # Inf if exponent==0xFF and mantissa==0 li t3, 0xFF bne t0, t3, f32_not_inf bnez t1, f32_not_inf li a0, 1 ret f32_not_inf: li a0, 0 ret # CHECK_RELATIVE_ERROR_1PCT: Check if relative error < 1% # Input: a0 = expected (f32), a1 = actual (f32) # Output: a0 = 1 if error acceptable, 0 otherwise check_relative_error_1pct: # For BF16, check if top 16 bits are close srli t0, a0, 16 srli t1, a1, 16 sub t2, t0, t1 # Difference # Get 
absolute difference bgez t2, check_positive sub t2, zero, t2 # Make positive check_positive: # Allow difference of ±1 for rounding li t3, 1 ble t2, t3, error_ok li a0, 0 ret error_ok: li a0, 1 ret # ============================================================================ # TEST COUNTER FUNCTIONS # ============================================================================ increment_passed: la t0, tests_passed lw t1, 0(t0) addi t1, t1, 1 sw t1, 0(t0) ret increment_failed: la t0, tests_failed lw t1, 0(t0) addi t1, t1, 1 sw t1, 0(t0) ret # ============================================================================ # OUTPUT FUNCTIONS # ============================================================================ print_summary: addi sp, sp, -16 sw ra, 12(sp) la a0, str_summary call print_string la a0, str_passed_count call print_string la t0, tests_passed lw a0, 0(t0) call print_int la a0, str_newline call print_string la a0, str_failed_count call print_string la t0, tests_failed lw a0, 0(t0) call print_int la a0, str_newline call print_string lw ra, 12(sp) addi sp, sp, 16 ret print_string: li a7, 4 ecall ret print_int: li a7, 1 ecall ret ``` ## Analysis 5-stage pipelined processor for problem B [Introduction to 5-stage pipelined processor on Ripes ](https://hackmd.io/@CarSam/Bkz2amw6xl) Analysing the first instruction of problem B ``` addi sp, sp, -16 ``` ![image](https://hackmd.io/_uploads/BkQMfsOTgx.png) ### Why its machine code is **ff010113**? According to the folloing picture, ```addi``` is I-Type instruction, its opcode is ```0010011``` . ![messageImage_1760137124787](https://hackmd.io/_uploads/B16HZs_Tll.jpg) The address of stack pointer(sp) at ```x2``` is ```0x7ffffff0```. 
The 12-bit two's-complement form of -16 is ```111111110000```.

Combining these fields, the machine code in binary is:
```
imm          rs1   funct3 rd    opcode
111111110000 00010 000    00010 0010011
```
Grouping the bits into nibbles and converting them to hexadecimal gives:
```
1111 1111 0000 0001 0000 0001 0001 0011
   F    F    0    1    0    1    1    3   -> 0xFF010113
```

### 5-stage pipeline processor

![image](https://hackmd.io/_uploads/Sky1Eou6xe.png)

There are 5 stages:
1. Instruction fetch (IF)
2. Instruction decode and register fetch (ID)
3. Execute (EX)
4. Memory access (MEM)
5. Register write back (WB)

Let's see how the instruction goes through each stage:

1. IF

![image](https://hackmd.io/_uploads/SyQAmsuTxl.png)

* The instruction address (```addr```) is ```0x00000000```.
* The machine code of the first instruction is ```0xFF010113```.
* The next PC automatically becomes PC + 4.

2. ID

![image](https://hackmd.io/_uploads/rylVMx6dTel.png)

* The opcode is ```0b0010011```.
* The ```rs1``` field selects register ```x2``` (```sp```). The value ```0x10``` read here matches bits [24:20] of the instruction, which in an I-type instruction hold ```imm[4:0] = 10000``` rather than an ```rs2``` index, so it is not used.
* The sign-extended immediate is ```0xfffffff0``` (-16).

3. EX

![image](https://hackmd.io/_uploads/Hy5Ef6uTxe.png)

```
ALU_result = Reg1 + immediate
           = 0x7FFFFFF0 + (-16)
           = 0x7FFFFFE0
```
* The output of the ALU is ```0x7fffffe0```.

4. MEM

![image](https://hackmd.io/_uploads/BJCVEadTll.png)

* Since ```addi``` neither loads nor stores data, memory is not accessed and the read-out value is ```0x00000000```.

5. WB

![image](https://hackmd.io/_uploads/ryp07T_6eg.png)

* The ALU output ```0x7fffffe0``` is sent back to the register file as the write data (Wr data) for ```rd``` = ```x2```.

After all these stages are done, the register file is updated like this:

![image](https://hackmd.io/_uploads/H14WFnYaeg.png)