# Assignment 1: RISC-V Assembly and Instruction Pipeline
>[!Note] AI tools usage
>I used ChatGPT and Claude while working on Quiz 1 for code explanations, related research, code summaries, debugging, and generating the test-data tables.
## Problem B
* Refer to [Quiz1 of Computer Architecture (2025 Fall) Problem B](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* You can find the source code [here](https://github.com/CarSam16/ca2025-quizzes). Feel free to fork and modify it.
* uf8 format:
```
┌──────────────┬────────────────┐
│ Exponent (4) │ Mantissa (4) │
└──────────────┴────────────────┘
7 4 3 0
```
### Assembly code : uf8 decode
```
# input uf8 value : a0
# output uint32 value : a0
uf8_decode:
andi t1, a0, 0x0f # mantissa = fl & 0x0f
srli t2, a0, 4 # exponent = fl >> 4
li t3, 15
sub t3, t3, t2 # t3 = (15 - exponent)
li t4, 0x7FFF
srl t3, t4, t3 # t3 = (0x7FFF >> (15 - exponent))
slli t3, t3, 4 # t3 = offset = (0x7FFF >> (15 - exponent)) << 4
sll t1, t1, t2 # t1 = (mantissa << exponent)
add a0, t1, t3 # a0 = (mantissa << exponent) + offset
jr ra
```
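As a cross-check for the assembly above, here is a minimal Python reference model of the same decode formula. It is a sketch for comparing values by hand, not part of the submitted code; the function name simply mirrors the assembly label.

```python
def uf8_decode(fl: int) -> int:
    """Decode an 8-bit uf8 code into its unsigned integer value."""
    mantissa = fl & 0x0F                        # low 4 bits
    exponent = fl >> 4                          # high 4 bits
    offset = (0x7FFF >> (15 - exponent)) << 4   # equals (2**exponent - 1) * 16
    return (mantissa << exponent) + offset


# Smallest code, first code with exponent 1, and the largest code
print(uf8_decode(0x00), uf8_decode(0x10), uf8_decode(0xFF))  # 0 16 1015792
```

The three printed values are convenient spot checks when stepping through `uf8_decode` in a simulator.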
### Assembly code : uf8 encode
```
# input uint32 : a0
# output uf8 : a0
uf8_encode:
# callee save
addi sp, sp, -8
sw ra, 0(sp) # used to call CLZ function
sw s0, 4(sp)
add s0, a0, x0 # s0 keep original value
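# /* Use CLZ for fast exponent calculation */
li a1, 0 # exp = 0 (set before the early exit, because end_encode reads a1 and a2)
li a2, 0 # overflow = 0
li t0, 16
blt s0, t0, end_encode # value < 16 encodes as itself
jal ra, CLZ_myfunction # call CLZ (clobbers a0, t0-t2 only)
li t0, 31
sub t0, t0, a0 # msb = 31 - lz, lz is in a0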
li t1, 5
blt t0, t1, find_exact_exp
li t1, 4
sub a1, t0, t1
li t1, 15
blt a1, t1, cal_overflow
li a1, 15
cal_overflow:
li t1, 0
cal_overflow_loop:
bge t1, a1, adjust_loop
slli a2, a2, 1
addi a2, a2, 16 # overflow = (overflow << 1) + 16
addi t1, t1, 1 # counter ++
j cal_overflow_loop
adjust_loop:
blez a1, find_exact_exp # loop only while exponent > 0
bge s0, a2, find_exact_exp
addi t2, a2, -16
srli a2, t2, 1
addi a1, a1, -1
j adjust_loop
find_exact_exp:
li t1, 15
bge a1, t1, end_encode
slli t2, a2, 1
addi t2, t2, 16
blt s0, t2, end_encode
add a2, t2, x0
addi a1, a1, 1
j find_exact_exp
end_encode:
sub t2, s0, a2
srl t2, t2, a1
slli a1, a1, 4
or a0, a1, t2
return:
lw s0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 8
jr ra
```
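The encoder is easier to follow next to a high-level model. The sketch below follows the same steps as the assembly (estimate the exponent from the MSB position, build the overflow threshold, adjust down, then search up) and runs the round-trip check over all 256 codes. Python's `int.bit_length()` stands in for `CLZ_myfunction`; that substitution is an assumption made for brevity.

```python
def uf8_decode(fl: int) -> int:
    mantissa = fl & 0x0F
    exponent = fl >> 4
    offset = (0x7FFF >> (15 - exponent)) << 4
    return (mantissa << exponent) + offset


def uf8_encode(value: int) -> int:
    if value < 16:                       # small values encode as themselves
        return value
    msb = value.bit_length() - 1         # same as 31 - clz(value) for 32-bit input
    exponent = 0
    overflow = 0
    if msb >= 5:
        exponent = min(msb - 4, 15)      # initial estimate, clamped to 15
        for _ in range(exponent):
            overflow = (overflow << 1) + 16
        while exponent > 0 and value < overflow:
            overflow = (overflow - 16) >> 1   # estimate was too high
            exponent -= 1
    while exponent < 15:                 # search upward for the exact exponent
        next_overflow = (overflow << 1) + 16
        if value < next_overflow:
            break
        overflow = next_overflow
        exponent += 1
    mantissa = (value - overflow) >> exponent
    return (exponent << 4) | mantissa


# Round trip: every code must decode and encode back to itself,
# and decoded values must be strictly increasing.
previous = -1
for fl in range(256):
    value = uf8_decode(fl)
    assert uf8_encode(value) == fl
    assert value > previous
    previous = value
print("All tests passed.")
```

Values that are not exact decode results are rounded down to the nearest representable code, because the mantissa is obtained with a right shift.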
### CLZ function
```
# Function: clz
# Arguments:
# a0 = x (input)
# Return:
# a0 = n - x (the number of leading zeros of the input)
# Temporaries:
# t0 = n
# t1 = c
# t2 = y
CLZ_myfunction:
li t0, 32 # n
li t1, 16 # c
do_while:
srl t2, a0, t1 # y = x >> c
bnez t2, if_loop # if (y != 0) -> go to if_loop
srli t1, t1, 1 # c >>= 1
beqz t1, end_loop # if (c == 0) break
j do_while # continue loop
if_loop:
sub t0, t0, t1 # n -= c
add a0, t2, x0 # x = y
srli t1, t1, 1 # c >>= 1
beqz t1, end_loop # if (c == 0) break
j do_while # loop again
end_loop:
sub a0, t0, a0
jr ra
```
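The loop above is a binary search on the position of the highest set bit. A short Python model of the same loop, using the same variable names as the register comments (`n`, `c`, `y`), is given below as a sketch for checking results; the name `clz32` is hypothetical.

```python
def clz32(x: int) -> int:
    """Count leading zeros of a 32-bit value by halving the shift width."""
    n, c = 32, 16
    while c:
        y = x >> c
        if y:            # upper part is non-zero: discard the lower c bits
            n -= c
            x = y
        c >>= 1
    return n - x         # x is 0 or 1 here


print(clz32(1), clz32(0x00010000), clz32(0x80000000))  # 31 15 0
```

For an input of zero the loop never takes the `if` branch, so the function returns 32.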
### Assembly code : uf8 test
VALIDATE_ROUNDTRIP: Test encoding and decoding for all 256 values
```
.data
msg1: .asciz ": produces value "
msg2: .asciz " but encodes back to "
msg3: .asciz ": value "
msg4: .asciz " <= previous_value "
msg5: .asciz "All tests passed.\n"
msg6: .asciz "Some tests failed.\n"
newline: .asciz "\n"
.align 2
.text
.globl main
main:
jal ra, test # start to test
beq a0, x0, Not_pass # fail
la a0, msg5 # print msg5 when passing
li a7, 4
ecall
li a7, 93 # ecall: exit
li a0, 0 # exit code is 0, successful
ecall
Not_pass:
la a0, msg6 # print msg6 when not passing
li a7, 4
ecall
li a7, 93 # ecall: exit
li a0, 1 # exit code is 1, not successful
ecall
test:
addi sp, sp, -4
sw ra, 0(sp) # because test need to call other function
addi s0, x0, -1 # previous_value
li s1, 1 # passed, 1 means true, 0 means false
li s2, 0 # fl, counter from 0 to 255
li s3, 256 # counter's end
For_2:
add a0, s2, x0 # prepare a0 for uf8_decode
jal ra, uf8_decode
add s4, a0, x0 # value (return value from uf8_decode)
add a0, s4, x0 # prepare a0 for uf8_encode
jal ra, uf8_encode
add s5, a0, x0 # fl2 (return value from uf8_encode)
test_if_1:
beq s2, s5, test_if_2
mv a0, s2 # print s2(fl)
li a7, 34 # (RARS) print integer in hex
ecall
la a0, msg1 # print msg1
li a7, 4
ecall
mv a0, s4 # print value
li a7, 1
ecall
la a0, msg2 # print msg2
li a7, 4
ecall
mv a0, s5 # prepare to print fl2(s5) in hexadecimal
li a7, 34 # (RARS) print integer in hex
ecall
la a0, newline # print newline
li a7, 4
ecall
li s1, 0 # passed = false
test_if_2:
blt s0, s4, after_if
mv a0, s2 # print s2(fl)
li a7, 34 # (RARS) print integer in hex
ecall
la a0, msg3 # print msg3
li a7, 4
ecall
mv a0, s4 # print value
li a7, 1
ecall
la a0, msg4 # print msg4
li a7, 4
ecall
mv a0, s0 # prepare to print s0(previous_value) in hexadecimal
li a7, 34 # (RARS) print integer in hex
ecall
la a0, newline # print newline
li a7, 4
ecall
li s1, 0 # passed = false
after_if:
mv s0, s4
addi s2, s2, 1
blt s2, s3, For_2
mv a0, s1 # return passed
lw ra, 0(sp)
addi sp, sp, 4
jr ra # jump to ra
```
## Problem C
* Refer to [Quiz1 of Computer Architecture (2025 Fall) Problem C](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* bf16_t format
```
┌─────────┬──────────────┬──────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (7) │
└─────────┴──────────────┴──────────────┘
15 14 7 6 0
S: Sign bit (0 : positive, 1 : negative)
E: Exponent bits (bias = 127)
M: Mantissa bits
```
* f32_t format
```
┌─────────┬──────────────┬───────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (23) │
└─────────┴──────────────┴───────────────┘
31 30 23 22 0
S: Sign bit (0 : positive, 1 : negative)
E: Exponent bits (bias = 127)
M: Mantissa bits
```
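To make the two layouts concrete, the following Python sketch splits a bf16 bit pattern into its three fields and computes the value it represents. The helper names `bf16_fields` and `bf16_value` are hypothetical and used only for illustration.

```python
import math


def bf16_fields(bits: int):
    """Return (sign, exponent, mantissa) of a 16-bit bf16 pattern."""
    return (bits >> 15) & 1, (bits >> 7) & 0xFF, bits & 0x7F


def bf16_value(bits: int) -> float:
    sign, exp, mant = bf16_fields(bits)
    if exp == 0xFF:                              # Inf or NaN
        if mant:
            return math.nan
        return -math.inf if sign else math.inf
    if exp == 0:                                 # zero or denormal: no hidden 1
        magnitude = (mant / 128.0) * 2.0 ** -126
    else:                                        # normal: hidden 1, bias 127
        magnitude = (1.0 + mant / 128.0) * 2.0 ** (exp - 127)
    return -magnitude if sign else magnitude


print(bf16_fields(0x4120), bf16_value(0x4120))  # (0, 130, 32) 10.0
```

Because bf16 keeps the f32 sign and exponent fields unchanged, the same pattern shifted left by 16 bits is the f32 encoding of the same value.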
### Arithmetic Part
#### Test data
| Test # | Operation | Input A | Input B | BF16 A | BF16 B | Expected (F32) | Expected (BF16) | Tolerance |
|--------|-----------|---------|---------|--------|--------|----------------|-----------------|-----------|
| 1 | Add | 1.0f | 2.0f | 0x3F80 | 0x4000 | 3.0f (0x40400000) | 0x4040 | ±0.01 |
| 2 | Sub | 2.0f | 1.0f | 0x4000 | 0x3F80 | 1.0f (0x3F800000) | 0x3F80 | ±0.01 |
| 3 | Mul | 3.0f | 4.0f | 0x4040 | 0x4080 | 12.0f (0x41400000) | 0x4140 | ±0.1 |
| 4 | Div | 10.0f | 2.0f | 0x4120 | 0x4000 | 5.0f (0x40A00000) | 0x40A0 | ±0.1 |
| 5 | Sqrt | 9.0f | - | 0x4110 | - | 3.0f (0x40400000) | 0x4040 | ±0.01 |
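The binary-operation rows of the table can be sanity-checked on a host machine. The sketch below widens each bf16 pattern to f32 with `struct`, applies the operation in floating point, and compares against the expected bf16 pattern within the listed tolerance. It only validates the table entries; it does not exercise the assembly routines.

```python
import operator
import struct


def bf16_to_float(bits: int) -> float:
    """Widen a bf16 pattern to f32 by placing it in the upper 16 bits."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


# (operation, BF16 A, BF16 B, expected BF16, tolerance), copied from tests 1 to 4
rows = [
    (operator.add,     0x3F80, 0x4000, 0x4040, 0.01),
    (operator.sub,     0x4000, 0x3F80, 0x3F80, 0.01),
    (operator.mul,     0x4040, 0x4080, 0x4140, 0.1),
    (operator.truediv, 0x4120, 0x4000, 0x40A0, 0.1),
]

for op, a, b, expected, tol in rows:
    got = op(bf16_to_float(a), bf16_to_float(b))
    want = bf16_to_float(expected)
    assert abs(got - want) <= tol, (hex(a), hex(b), got, want)
print("table rows are consistent")
```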
#### Test result


#### Assembly code :
```
.data
# ----------------------------
# Output strings
# ----------------------------
str_banner: .asciz "\n=== BFloat16 Arithmetic Test ===\n\n"
str_test_arith: .asciz "Testing arithmetic operations...\n"
str_add_fail: .asciz " [Addition failed]\n"
str_sub_fail: .asciz " [Subtraction failed]\n"
str_mul_fail: .asciz " [Multiplication failed]\n"
str_div_fail: .asciz " [Division failed]\n"
str_sqrt_fail: .asciz " [Sqrt failed]\n"
str_arith_pass: .asciz " Arithmetic: PASS\n"
str_summary: .asciz "\n=== Test Summary ===\n"
str_passed_count: .asciz "Tests passed: "
str_failed_count: .asciz "Tests failed: "
str_newline: .asciz "\n"
str_add_pass: .asciz "[Add passed]\n"
str_sub_pass: .asciz "[Sub passed]\n"
str_mul_pass: .asciz "[Mul passed]\n"
str_div_pass: .asciz "[Div passed]\n"
str_sqrt_pass: .asciz "[Sqrt passed]\n"
# ----------------------------
# Counters
# ----------------------------
tests_passed: .word 0
tests_failed: .word 0
# ----------------------------
# BF16 constant words (used inside arithmetic implementations)
# Keep these labels for compatibility with the original implementation.
# ----------------------------
Inf_pos: .word 0x7F800000, 0x7F80
Inf_neg: .word 0xFF800000, 0xFF80
NaN: .word 0xFFC00000, 0xFFC0
normal: .word 0x40490fd0, 0x4049
denormal: .word 0x40000fd0
BF16_SIGN_MASK: .word 0x8000
BF16_EXP_MASK: .word 0x7F80
BF16_MANT_MASK: .word 0x007F
BF16_EXP_BIAS: .word 127
BF16_NAN: .word 0x7FC0
BF16_ZERO: .word 0x0
.text
.globl main
# ============================================================================
# MAIN: entrypoint
# ============================================================================
main:
addi sp, sp, -16
sw ra, 12(sp)
# Print banner
la a0, str_banner
call print_string
# Run the arithmetic tests
call test_arithmetic
# Print summary
call print_summary
# Exit gracefully
lw ra, 12(sp)
addi sp, sp, 16
li a7, 10
ecall
# ============================================================================
# test_arithmetic
# ============================================================================
test_arithmetic:
addi sp, sp, -32
sw ra, 28(sp)
sw s0, 24(sp)
sw s1, 20(sp)
sw s2, 16(sp)
sw s3, 12(sp)
sw s4, 8(sp)
sw s5, 4(sp)
la a0, str_test_arith
call print_string
# ---------------------------
# Test 1: Addition 1.0 + 2.0 = 3.0
# ---------------------------
li a0, 0x3F80 # float bits for 1.0
#call f32_to_bf16
mv s0, a0 # s0 = bf16(1.0)
li a0, 0x4000 # float bits for 2.0
#call f32_to_bf16
mv s1, a0 # s1 = bf16(2.0)
mv a0, s0
mv a1, s1
call bf16_add
mv s2, a0 # s2 = result bf16
mv a0, s2
#call bf16_to_f32
mv s3, a0 # s3 = result as f32 bits
li s4, 0x4040 # expected = 3.0f (bits)
mv a0, s3
mv a1, s4
call check_relative_error_1pct
beqz a0, arith_add_fail
la a0, str_add_pass
call print_string
call increment_passed
j arith_sub_test
arith_add_fail:
la a0, str_add_fail
call print_string
call increment_failed
j arith_sub_test
# ---------------------------
# Test 2: Subtraction 2.0 - 1.0 = 1.0
# ---------------------------
arith_sub_test:
li a0, 0x4000 # 2.0f
#call f32_to_bf16
mv s0, a0
li a0, 0x3F80 # 1.0f
#call f32_to_bf16
mv s1, a0
mv a0, s0
mv a1, s1
call bf16_sub
mv s2, a0
mv a0, s2
#call bf16_to_f32
mv s3, a0
li s4, 0x3F80 # expected = 1.0f
mv a0, s3
mv a1, s4
call check_relative_error_1pct
beqz a0, arith_sub_fail
la a0, str_sub_pass
call print_string
call increment_passed
j arith_mul_test
arith_sub_fail:
la a0, str_sub_fail
call print_string
call increment_failed
j arith_mul_test # keep running the remaining tests, as the addition failure path does
# ---------------------------
# Test 3: Multiplication 3.0 * 4.0 = 12.0
# ---------------------------
arith_mul_test:
li a0, 0x4040 # 3.0f
#call f32_to_bf16
mv s0, a0
li a0, 0x4080 # 4.0f
#call f32_to_bf16
mv s1, a0
mv a0, s0
mv a1, s1
call bf16_mul
mv s2, a0
mv a0, s2
#call bf16_to_f32
mv s3, a0
li s4, 0x4140 # expected = 12.0f
mv a0, s3
mv a1, s4
call check_relative_error_1pct
beqz a0, arith_mul_fail
la a0, str_mul_pass
call print_string
call increment_passed
j arith_div_test
arith_mul_fail:
la a0, str_mul_fail
call print_string
call increment_failed
j arith_div_test # keep running the remaining tests, as the addition failure path does
# ---------------------------
# Test 4: Division 10.0 / 2.0 = 5.0
# ---------------------------
arith_div_test:
li a0, 0x4120 # 10.0f
#call f32_to_bf16
mv s0, a0
li a0, 0x4000 # 2.0f
#call f32_to_bf16
mv s1, a0
mv a0, s0
mv a1, s1
call bf16_div
mv s2, a0
mv a0, s2
#call bf16_to_f32
mv s3, a0
li s4, 0x40A0 # expected = 5.0f
mv a0, s3
mv a1, s4
call check_relative_error_1pct
beqz a0, arith_div_fail
la a0, str_div_pass
call print_string
call increment_passed
j arith_sqrt_test
arith_div_fail:
la a0, str_div_fail
call print_string
call increment_failed
j arith_sqrt_test # keep running the remaining tests, as the addition failure path does
# ---------------------------
# Test 5: sqrt(9.0) = 3.0
# ---------------------------
arith_sqrt_test:
li a0, 0x4110 # 9.0f
call bf16_sqrt
mv s1, a0
li s4, 0x4040 # expected = 3.0f
mv a0, s1
mv a1, s4
call check_relative_error_1pct
beqz a0, arith_sqrt_fail
la a0, str_sqrt_pass
call print_string
call increment_passed
j arith_done
arith_sqrt_fail:
la a0, str_sqrt_fail
call print_string
call increment_failed
j arith_done
# ---------------------------
# All arithmetic tests done
# ---------------------------
arith_done:
lw s5, 4(sp)
lw s4, 8(sp)
lw s3, 12(sp)
lw s2, 16(sp)
lw s1, 20(sp)
lw s0, 24(sp)
lw ra, 28(sp)
addi sp, sp, 32
ret
print_string:
li a7, 4
ecall
ret
print_int:
li a7, 1
ecall
ret
increment_passed:
la t0, tests_passed
lw t1, 0(t0)
addi t1, t1, 1
sw t1, 0(t0)
ret
increment_failed:
la t0, tests_failed
lw t1, 0(t0)
addi t1, t1, 1
sw t1, 0(t0)
ret
print_summary:
addi sp, sp, -16
sw ra, 12(sp)
la a0, str_summary
call print_string
la a0, str_passed_count
call print_string
la t0, tests_passed
lw a0, 0(t0)
call print_int
la a0, str_newline
call print_string
la a0, str_failed_count
call print_string
la t0, tests_failed
lw a0, 0(t0)
call print_int
la a0, str_newline
call print_string
lw ra, 12(sp)
addi sp, sp, 16
ret
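check_relative_error_1pct:
# a0 = actual bf16 bits, a1 = expected bf16 bits.
# Both are 16-bit patterns (the bf16_to_f32 calls in the tests are commented out),
# so compare them directly: shifting right by 16 would zero both operands
# and make every test pass.
sub t2, a0, a1 # t2 = actual - expected
bgez t2, crp_pos
sub t2, x0, t2 # abs
crp_pos:
li t3, 1
ble t2, t3, crp_ok # accept a difference of at most 1 ULP
li a0, 0
ret
crp_ok:
li a0, 1
ret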
```
#### Assembly code : Arithmetic functions
* Integrates bf16_add, bf16_sub, bf16_mul, bf16_div, bf16_sqrt from w1-bfloat16.s
```
# ============================================================================
# - f32_to_bf16
# - bf16_to_f32
# - bf16_add
# - bf16_sub
# - bf16_mul
# - bf16_div
# - bf16_sqrt
# ============================================================================
# ============================================================================
# f32_to_bf16
# Input: a0 = 32-bit float bits
# Output: a0 = 16-bit bf16 bits
# ============================================================================
f32_to_bf16:
# callee-save
addi sp, sp, -8
sw ra, 0(sp)
sw s0, 4(sp)
mv s0, a0 # preserve original argument in s0
# Extract exponent: (f32bits >> 23) & 0xFF
li t0, 0xFF
srli t2, s0, 23
and t1, t2, t0 # t1 = exponent
beq t1, t0, f32_exception # if exponent == 0xFF -> NaN/Inf path
# Normal case: compute rounding bias
srli t0, s0, 16 # t0 = f32 >> 16
addi t1, x0, 1
and t0, t0, t1 # t0 = (f32 >> 16) & 1
li t1, 0x7FFF
add t0, t0, t1 # t0 = rounding bias (0x7FFF or 0x8000)
add t0, s0, t0 # add bias
srli a0, t0, 16 # top 16 bits -> bf16
# restore and return
beq x0, x0, f32_to_bf16_end
f32_exception:
srli a0, s0, 16 # for Inf/NaN: just take top 16 bits
f32_to_bf16_end:
lw ra, 0(sp)
lw s0, 4(sp)
addi sp, sp, 8
ret
# ============================================================================
# bf16_to_f32:
# Input: a0 = 16-bit bf16 bits
# Output: a0 = 32-bit float bits
# ============================================================================
bf16_to_f32:
slli a0, a0, 16
ret
# ============================================================================
# bf16_add: BF16 addition (bit-accurate software emulation)
# Input: a0 = bf16 a, a1 = bf16 b
# Output: a0 = bf16 (a + b)
# ============================================================================
bf16_add:
# callee save
addi sp, sp, -28
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
sw s2, 12(sp)
sw s3, 16(sp)
sw s4, 20(sp)
sw s5, 24(sp)
# if a is +-inf / NaN
srli s0, a0, 7 # s0 = a >> 7
andi s0, s0, 0xFF # s0 = exponent a
srli s1, a1, 7 # s1 = b >> 7
andi s1, s1, 0xFF # s1 = exponent b
andi s2, a0, 0x7F # s2 = mantissa a
andi s3, a1, 0x7F # s3 = mantissa b
srli s4, a0, 15 # s4 = sign a
srli s5, a1, 15 # s5 = sign b
li t6, 0xFF # t6 = 0xFF
bne t6, s0, a_actual_num # if a_exp != 0xFF jump
# a exponent == 0xFF
bne s2, x0, end_add # if a is NaN (mant_a != 0), return a
beq s1, t6, a_b_exp_FF # a is +-inf; if b exp = 0xFF, jump to handling section
jal x0, end_add # b is a finite number, return a
# a b exp = 0xFF
a_b_exp_FF:
bne s3, x0, return_NAN # if b is NAN, return NAN
beq s4, s5, end_add # if a b have same sign return a
jal x0, return_NAN # else return NaN
a_actual_num:
beq s1, t6, add_return_b # b exp = 0xFF, return b
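# b = +-0: return a
slli t1, a1, 17 # drop the sign bit (bit 15) and everything above it
beq x0, t1, end_add
# a = +-0: return b
slli t1, a0, 17 # drop the sign bit (bit 15) and everything above it
beq x0, t1, add_return_b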
# a b are both actual numbers
add_adjust_mant_a:
beq x0, s0, add_adjust_mant_b # a exp == 0,no need to adjust
addi s2, s2, 0x80 # retrieve hidden 1. for mantissa a
add_adjust_mant_b:
beq x0, s1, add_main # b exp == 0,no need to adjust
addi s3, s3, 0x80 # retrieve hidden 1. for mantissa b
##### ADD MAIN #####
# s0 = result exp
# s1 = result mantissa
# s2 = mantissa a
# s3 = mantissa b
# s4 = result sign
# t6 = 0xFF
add_main:
### fraction alignment ###
sub t0, s0, s1 # t0 = a_exp - b_exp
li t1, 8 # t1 = 8
mv t3, s2 # put mant_a in buffer
mv t4, s3 # put mant_b in buffer
# exp_diff > 8
blt t1, t0, end_add # a dominates b, return a
# exp_diff < -8
sub t2, x0, t0 # t2 = -exp_diff
blt t1, t2, add_return_b # b dominates a, return b
# exp_diff == 0
beq t0, x0, true_add # s0 = a_exp already, just jump
# -8 <= exp_diff < 0
blt t0, x0, add_align_mant_a # b has the larger exponent, align mant_a instead
# 0 < exp_diff <= 8
srl t4, s3, t0 # t4 = mant_b >> exp_diff, result exp stays a_exp (s0)
jal x0, true_add
add_align_mant_a:
srl t3, s2, t2 # t3 = mant_a >> -exp_diff
mv s0, s1 # s0 (result exp) = b exp
true_add:
mv s2, t3 # move aligned t3 (mant_a) to s2
mv s3, t4 # move aligned t4 (mant_b) to s3
bne s4, s5, add_diff_sign
add_same_sign:
slli s4, s4, 15 ## s4(result sign) = a sign << 15
add s1, s2, s3 ## s1(result mantissa) = mantissa (a + b)
andi t0, s1, 0x100 # t0 = overflow bit
beq t0, x0, add_result # if no overflow (t0==0), get result
add_handle_overflow:
srli s1, s1, 1 # s1(result_mant) >>= 1
addi s0, s0, 1 # s0(result_exp) += 1
blt s0, t6, add_result # s0(result_exp) < 0xFF, result number is normal, get result
mv s0, t6 # else set s0(result_exp) = 0xFF
add s1, x0, x0 # s1(result mantissa) = 0
jal x0, add_result # overflow => return inf with according sign
add_diff_sign:
# mantissa a(s2) >= b(s3): result takes sign_a
# (s4 still holds sign_a here, so it must be shifted before s4 is reused)
slli s4, s4, 15 # s4(result sign) = sign_a << 15
sub s1, s2, s3 # s1(result mantissa) = s2(mant_a) - s3(mant_b)
bge s2, s3, add_handle_zero
# otherwise mantissa a < b: result takes sign_b
slli s4, s5, 15 # s4(result sign) = sign_b << 15
sub s1, s3, s2 # s1(result mantissa) = s3(mant_b) - s2(mant_a)
# check that the difference of the aligned mantissas is not zero,
# otherwise the normalization loop below would never terminate
add_handle_zero:
beq s1, x0, add_return_zero # exact cancellation, return zero
add_adjust_mantissa:
andi t1, s1, 0x80 # t1 is the leading (hidden) bit of result_mant
bne t1, x0, add_result # leading bit set: normalized, done
slli s1, s1, 1 # s1(result_mant) <<= 1
addi s0, s0, -1 # s0(result_exp) -= 1
bge x0, s0, add_return_zero # if result_exp <= 0, underflow to zero
jal x0, add_adjust_mantissa
###
# s0 = result exp
# s1 = result mantissa
# s2 = mantissa a
# s3 = mantissa b
# s4 = result sign
add_result:
andi s0, s0, 0xFF # mask out the logic bit out for neg exp
slli s0, s0, 7 # left shift to match the correct format
andi s1, s1, 0x7F # mask out the first bit of mantissa (1.XX)
or a0, s4, s0 # result_sign | result exp
or a0, a0, s1 # result_sign | result exp | result mantissa
jal x0, end_add
### end of bf16_add function
add_return_zero:
mv a0, x0 #
jal x0, end_add
add_return_b:
mv a0, a1
jal x0, end_add
return_NAN:
lw a0, BF16_NAN
jal x0, end_add
end_add:
# retrieve ra and callee save
lw s5, 24(sp)
lw s4, 20(sp)
lw s3, 16(sp)
lw s2, 12(sp)
lw s1, 8(sp)
lw s0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 28
ret
# ============================================================================
# bf16_sub: implemented simply by flipping sign of b and calling bf16_add
# Input: a0 = a, a1 = b
# Output: a0 = a - b
# ============================================================================
bf16_sub:
addi sp, sp, -4
sw ra, 0(sp)
li t0, 0x8000 # BF16 sign mask (la would load the label's address, not its value)
xor a1, a1, t0 # flip sign bit of a1 (b); a0 already holds a
jal ra, bf16_add
lw ra, 0(sp)
addi sp, sp, 4
ret
# ============================================================================
# bf16_mul: bit-accurate multiplication
# Input: a0 = a (bf16), a1 = b (bf16)
# Output: a0 = a * b (bf16)
# ============================================================================
bf16_mul:
# callee save
addi sp, sp, -28
sw ra, 0(sp)
sw s0, 4(sp) # s0 = exponent a
sw s1, 8(sp) # s1 = exponent b
sw s2, 12(sp) # s2 = mantissa a
sw s3, 16(sp) # s3 = mantissa b
sw s4, 20(sp) # s4 = sign a
sw s5, 24(sp) # s5 = sign b
# Decomposite a and b into sign / exp / mantissa
srli s0, a0, 7 # s0 = a >> 7
andi s0, s0, 0xFF # s0 = exp_a = ((a.bits >> 7) & 0xFF)
srli s1, a1, 7 # s1 = b >> 7
andi s1, s1, 0xFF # s1 = exp_b = ((b.bits >> 7) & 0xFF)
andi s2, a0, 0x7F # s2 = mant_a = a.bits & 0x7F
andi s3, a1, 0x7F # s3 = mant_b = b.bits & 0x7F
srli s4, a0, 15 # s4 = sign a
srli s5, a1, 15 # s5 = sign b
#####
# t3 = result mant
# t4 = result exp
# t5 = result sign
# t6 = 0xFF
xor t5, s4, s5 # t5(result sign) = sign_a ^ sign_b
addi t6, x0, 0xFF # t6 = 0xFF
# check a exp
bne s0, t6, mul_check_exp_b # if exp_a != 0xFF, jump
bne s2, x0, end_mul # if a = NaN (mant_a != 0), return a
# a is +-inf
li t0, 0x7FFF # sign filter
and t1, a1, t0 # t1 = b & 0x7FFF
beq t1, x0, mul_return_NAN # if b = +- 0, inf * 0 = NaN
# return inf with correct sign
slli t5, t5, 15 #
li t0, 0x7F80
or a0, t5, t0 # (result_sign << 15) | 0x7F80
jal x0, end_mul
mul_check_exp_b:
bne s1, t6, mul_check_zero # if exp_b != 0xFF, jump
bne s3, x0, mul_return_b # if b = NaN (mant_b != 0), return b
# b is +-inf
li t0, 0x7FFF # sign filter
and t1, a0, t0 # t1 = a & 0x7FFF
beq t1, x0, mul_return_NAN # if a = +- 0, 0 * inf = NaN
# return inf with correct sign
slli t5, t5, 15 #
li t0, 0x7F80
or a0, t5, t0 # (result_sign << 15) | 0x7F80
jal x0, end_mul
### a, b both are finite numbers
# if either operand is +-0, the product is a signed zero
mul_check_zero:
or t0, s0, s2 # t0 = exp_a | mant_a
beq t0, x0, mul_return_zero # a == +-0
or t0, s1, s3 # t0 = exp_b | mant_b
bne t0, x0, mul_main # b != 0 as well, jump
mul_return_zero:
slli a0, t5, 15 # a0 = result_sign << 15
jal x0, end_mul
### a, b both are non-zero finite numbers
# align the mantissas and then multiply
mul_main:
add t0, x0, x0 # t0 = adjust exp = 0
mul_adjust_a:
beq s0, x0, mul_denormal_a # if s0(exp_a) = 0, a is denormal number
# a is non-zero normal number
ori s2, s2, 0x80 # retrieve 1.XXX in mant_a
jal x0, mul_adjust_b
mul_denormal_a:
# find the first set bit
andi t1, s2, 0x80 # t1 = s2(mant_a) & 0x80
bne t1, x0, mul_mant_a_aligned # while t1(first bit) is not found, loop
slli s2, s2, 1 # left shift s2(mant_a) to find the first set bit
addi t0, t0, -1 # exp_adjust--
jal x0, mul_denormal_a
mul_mant_a_aligned:
addi s0, x0, 1 # first set bit is found, exp_a should change from 0 to 1
mul_adjust_b:
beq s1, x0, mul_denormal_b # if s1(exp_b) = 0, b is denormal number
# b is a non-zero normal number
ori s3, s3, 0x80 # retrieve 1.XXX in s3(mant_b)
jal x0, result_exp
mul_denormal_b:
# find the first set bit
andi t1, s3, 0x80 # t1 = s3(mant_b) & 0x80
bne t1, x0, mul_mant_b_aligned # leading bit found, stop shifting
slli s3, s3, 1 # left shift s3(mant_b) to find the first set bit
addi t0, t0, -1 # exp_adjust--
jal x0, mul_denormal_b
mul_mant_b_aligned:
addi s1, x0, 1 # first set bit is found, exp_b should change from 0 to 1
# mantissas are non-zero positive integer, just multiple
# t0 = adjust_exp s0 = exponent a
# t3 = result mant s1 = exponent b
# t4 = result exp s2 = mantissa a
# t5 = result sign s3 = mantissa b
# t6 = 0xFF s4 = sign a
# s5 = sign b
result_exp:
add t4, s0, s1 # t4 (result exp) = exp_a + exp_b
addi t4, t4, -127 # result exp = exp_a + exp_b - 127
add t4, t4, t0 # result exp = exp_a + exp_b - 127 + adjust_exp
add t3, x0, x0 # t3 (result_mant) = 0 before accumulating partial products
# t3 = result_mant = (uint32_t) mant_a * mant_b, by shift-and-add:
# while (mant_b != 0)
#   if (mant_b & 1) result_mant += mant_a
#   mant_a <<= 1, mant_b >>= 1
true_mul:
beq x0, s3, mul_get_result # while (mant_b != 0)
andi t2, s3, 1 # t2 = mant_b & 1 = lsb bit
beq x0, t2, mul_skip # if lsb bit of mant_b is 0, do not add to result
add t3, t3, s2 # result_mant += (shifted) mant_a
mul_skip:
srli s3, s3, 1 # mant_b >> 1
slli s2, s2 ,1 # mant_a << 1
jal x0, true_mul # go back to while
mul_get_result:
# check result mantissa overflow
li t0, 0x8000 # overflow mask
and t0, t0, t3 # t0 = 0x8000 & t3 (result_mant)
beq t0, x0, mul_result_mant_adjust # if 0x8000 & t3 = 0 goto else
# overflow happened
srli t3, t3, 8 # result_mant >> 8
andi t3, t3, 0x7F # result_mant = (result_mant >> 8) & 0x7F
addi t4, t4, 1 # result_exp ++
jal x0, mul_result_exp_adjust
mul_result_mant_adjust:
# no mantissa overflow
srli t3, t3, 7 # t3 (result_mant) >> 7
andi t3, t3, 0x7F # t3 (result_mant) = (result_mant >> 7) & 0x7F
mul_result_exp_adjust:
# check exp again after checking mantissa overflow
blt t4, t6, mul_result_check_denormal
# return inf with correct sign
slli t5, t5, 15 # t5 (result_sign << 15)
li t0, 0x7F80
or a0, t5, t0 # (result_sign << 15) | 0x7F80
jal x0, end_mul
mul_result_check_denormal:
# 0 < result_exp, not denormal
blt x0, t4, mul_result
li t0, -6
bge t4, t0, mul_result_handle_denormal
# return zero with correct sign
slli a0, t5, 15 # a0 = result_sign << 15
jal x0, end_mul
mul_result_handle_denormal:
addi t0, t4, -1 # t0 = t4(result_exp) - 1
sub t0, x0, t0 # t0 = 1 - t4(result_exp)
srl t3, t3, t0 # t3 (result_mant) >>= (1 - result_exp)
add t4, x0, x0 # t4 (result_exp) = 0
mul_result:
slli t5, t5, 15 # t5 (result_sign << 15)
and t4, t4, t6 # t4 (result_exp) &= 0xFF
slli t4, t4, 7 # (result_exp & 0xFF) << 7
andi t3, t3, 0x7F # result_mant & 0x7F
or a0, t5, t4
or a0, a0, t3
jal x0, end_mul
mul_return_b:
mv a0, a1
jal x0, end_mul
mul_return_NAN:
lw a0, BF16_NAN
jal x0, end_mul
end_mul:
# retrieve ra and callee save
lw s5, 24(sp)
lw s4, 20(sp)
lw s3, 16(sp)
lw s2, 12(sp)
lw s1, 8(sp)
lw s0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 28
ret
# ---------------------------------------------------------------------------
# bf16_div: bit-accurate division
# Input: a0 = a (bf16), a1 = b (bf16)
# Output: a0 = a / b (bf16)
# ---------------------------------------------------------------------------
bf16_div:
####
## input argument
# a0 = (bf16) a
# a1 = (bf16) b
## output argument
# a0 = (bf16) a/b
####
# callee save
addi sp, sp, -28
sw ra, 0(sp)
sw s0, 4(sp) # s0 = exponent a
sw s1, 8(sp) # s1 = exponent b
sw s2, 12(sp) # s2 = mantissa a
sw s3, 16(sp) # s3 = mantissa b
sw s4, 20(sp) # s4 = sign a
sw s5, 24(sp) # s5 = sign b
# Decomposite a and b into sign / exp / mantissa
srli s0, a0, 7 # s0 = a >> 7
andi s0, s0, 0xFF # s0 = exp_a = ((a.bits >> 7) & 0xFF)
srli s1, a1, 7 # s1 = b >> 7
andi s1, s1, 0xFF # s1 = exp_b = ((b.bits >> 7) & 0xFF)
andi s2, a0, 0x7F # s2 = mant_a = a.bits & 0x7F
andi s3, a1, 0x7F # s3 = mant_b = b.bits & 0x7F
srli s4, a0, 15 # s4 = sign a
srli s5, a1, 15 # s5 = sign b
# t0 = adjust_exp s0 = exponent a
# t3 = result mant s1 = exponent b
# t4 = result exp s2 = mantissa a
# t5 = result sign s3 = mantissa b
# t6 = 0xFF s4 = sign a
# s5 = sign b
xor t5, s4, s5 # t5(result sign) = sign_a ^ sign_b
addi t6, x0, 0xFF # t6 = 0xFF
# (exp_b == 0xFF)
div_check_exp_b:
bne s1, t6, div_check_zero_b # if s1 (exp_b) != 0xFF, jump
bne s3, x0, div_return_b # if b = NaN (mant_b != 0), return b
# b is +-inf:
# inf / inf = NaN
# exp_a != 0xFF (a is finite) -> return_sign_zero
# exp_a == 0xFF and mant_a != 0 -> return_sign_zero
bne s0, t6, div_return_sign_zero
bne s2, x0, div_return_sign_zero
jal x0, div_return_NAN
div_check_zero_b:
# (!exp_b && !mant_b)
bne s1, x0, div_check_exp_a # if s1 (exp_b) != 0, b!=0, pass
bne s3, x0, div_check_exp_a # if s3 (mant_b) != 0, b!=0, pass
# if (!exp_a && !mant_a), i.e. a = b = 0, return NaN
# else return signed inf
bne s0, x0, div_return_sign_inf # if s0 (exp_a) != 0, a!=0, a/0=inf
bne s2, x0, div_return_sign_inf # if s2 (mant_a) != 0, a!=0, a/0=inf
# 0/0 = NAN
jal x0, div_return_NAN
div_check_exp_a:
bne s0, t6, div_check_zero_a # if s0 (exp_a) != 0xFF, jump
bne s2, x0, end_div # if a = NaN (mant_a != 0), return a
jal x0, div_return_sign_inf # else inf/b = inf
div_check_zero_a:
# (!exp_a && !mant_a)
bne s0, x0, div_get_mant_a # if s0 (exp_a) != 0, a!=0, pass
bne s2, x0, div_get_mant_a # if s2 (mant_a) != 0, a!=0, pass
# we have already handle 0/0 in the previous code
jal x0, div_return_sign_zero # else 0/b = 0
###
# During division, there is no need to align the mantissas to a common exponent
div_get_mant_a:
beq x0, s0, div_get_mant_b # exp_a == 0: denormal a, no hidden bit
ori s2, s2, 0x80 # retrieve 1.XX
div_get_mant_b:
beq x0, s1, div_main # exp_b == 0: denormal b, no hidden bit
ori s3, s3, 0x80 # retrieve 1.XX
# t1 = i s0 = exponent a
# t2 = s1 = exponent b
# t3 = result mant s2 = mantissa a
# t4 = result exp s3 = mantissa b
# t5 = result sign s4 = dividend
# t6 = quotient s5 = divisor
div_main:
slli s4, s2, 15 # dividend = (uint32_t) mant_a << 15
add s5, x0, s3 # divisor = mant_b
add t6, x0, x0 # initial quotient = 0
addi t0, x0, 16 # t0 = 16 (maximum precision)
add t1, x0, x0 # t1 = i = 0
true_div:
bge t1, t0, div_get_result # if t1 (i) >= 16 end loop
slli t6, t6, 1 # t6 (quotient) <<= 1
addi t2, x0, 15 # t2 = 15
sub t2, t2, t1 # t2 = 15 - i
sll t2, s5, t2 # divisor << (15 - i)
blt s4, t2, div_i_minus_one ## dividend < (divisor << (15 - i)), no quotient
sub s4, s4, t2 # dividend -= (divisor << (15 - i))
ori t6, t6, 1 # quotient |= 1
div_i_minus_one:
addi t1, t1, 1 # i++
jal x0, true_div
###
div_get_result:
sub t4, s0, s1 # t4 (result exp) = exp_a - exp_b
addi t4, t4, 127 # t4 (result exp) = exp_a - exp_b + BF16_EXP_BIAS
div_zero_exp_a_correction:
## if (!exp_a), result_exp--
bne x0, s0, div_zero_exp_b_correction
addi t4, t4, -1
div_zero_exp_b_correction:
## if (!exp_b), result_exp++
bne x0, s1, div_check_quotient
addi t4, t4, 1
div_check_quotient:
li t0, 0x8000 # t0 = 0x8000
and t0, t6, t0 # t0 = quotient & 0x8000
bne t0, x0, div_result_mant_shift
div_result_mant_adjust:
## find the first set bit
# t0 = quotient & 0x8000
li t0, 0x8000 # t0 = 0x8000
and t0, t6, t0
bne t0, x0, div_result_mant_shift
# result_exp-1 <= 0, jump
addi t1, t4, -1 # t1 = result_exp - 1
bge x0, t1, div_result_mant_shift
slli t6, t6, 1 # t6 (quotient) << 1
addi t4, t4, -1 # result_exp --
jal x0, div_result_mant_adjust # repeat while (!(quotient & 0x8000) && result_exp > 1)
div_result_mant_shift:
srli t6, t6, 8 # quotient >>= 8
li t0, 0xFF
bge t4, t0, div_return_sign_inf
bge x0, t4, div_return_sign_zero
div_result:
slli t5, t5, 15 # t5 (result_sign << 15)
and t4, t4, t0 # t4 (result_exp) &= 0xFF
slli t4, t4, 7 # (result_exp & 0xFF) << 7
andi t6, t6, 0x7F # quotient & 0x7F
or a0, t5, t4
or a0, a0, t6
jal x0, end_div
####
div_return_sign_zero:
slli a0, t5, 15 # result_sign << 15
jal x0, end_div
div_return_sign_inf:
slli a0, t5, 15 # result_sign << 15
li t0, 0x7F80 # t0 = 0x7F80
or a0, a0, t0 # result_sign << 15 | 0x7F80
jal x0, end_div
div_return_NAN:
lw a0, BF16_NAN
jal x0, end_div
div_return_b:
mv a0, a1
jal x0, end_div
end_div:
# retrieve ra and callee save
lw s5, 24(sp)
lw s4, 20(sp)
lw s3, 16(sp)
lw s2, 12(sp)
lw s1, 8(sp)
lw s0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 28
ret
# ---------------------------------------------------------------------------
# bf16_sqrt: bit-accurate square-root implementation (binary-search)
# Input: a0 = bf16
# Output: a0 = bf16(sqrt(a))
# ---------------------------------------------------------------------------
bf16_sqrt:
addi sp, sp, -20
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
sw s2, 12(sp)
sw s3, 16(sp)
li s3, 0xFF
srli s0, a0, 7
and s0, s0, s3
andi s1, a0, 0x7F
srli s2, a0, 15
bne s2, x0, sqrt_ret_nan
bne s0, s3, sqrt_ck_input_0
jal x0, end_sqrt
sqrt_ck_input_0:
bne s0, x0, sqrt_ck_input_denormal
bne s1, x0, sqrt_ret_zero
jal x0, sqrt_ret_zero # plain jump; a one-operand jal would overwrite ra
sqrt_ck_input_denormal:
beq s0, x0, sqrt_ret_zero
sqrt_main:
addi s0, s0, -127
ori s1, s1, 0x80
andi t0, s0, 0x1
beq t0, x0, handle_even_exp
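handle_odd_exp:
slli s1, s1, 1 # m <<= 1 so the remaining exponent is even
addi t6, s0, -1
srai t6, t6, 1 # arithmetic shift: e is negative for inputs below 1.0
addi t6, t6, 127
jal x0, true_sqrt
handle_even_exp:
add t6, s0, x0
srai t6, t6, 1 # arithmetic shift keeps the sign of e
addi t6, t6, 127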
true_sqrt:
li t3, 90
li t4, 256
li t5, 128 # default result = 128 (mantissa 1.0); overwritten by the search
binary_search:
blt t4, t3, sqrt_normalized_result
add t1, t3, t4
srli t1, t1, 1
mv a0, t1
mv a1, t1
jal ra, int_mul
srli t2, a0, 7
blt s1, t2, sqrt_too_big
mv t5, t1
addi t3, t1, 1
jal x0, binary_search
sqrt_too_big:
addi t4, t1, -1
jal x0, binary_search
sqrt_normalized_result:
li t0, 256
blt t5, t0, sqer_borrow_exp
srli t5, t5, 1
addi t6, t6, 1
jal x0, remove_result_mant_one
sqer_borrow_exp:
li t0, 128
li t1, 1
bge t5, t0, remove_result_mant_one
bge t1, t6, remove_result_mant_one
slli t5, t5, 1
addi t6, t6, -1
jal x0, sqrt_exp_adjust_loop
sqrt_exp_adjust_loop:
bge t5, t0, remove_result_mant_one
bge t1, t6, remove_result_mant_one
slli t5, t5, 1
addi t6, t6, -1
jal x0, sqrt_exp_adjust_loop
remove_result_mant_one:
andi t5, t5, 0x7F
bge t6, s3, sqrt_ret_inf
bge x0, t6, sqrt_ret_zero
sqrt_get_result:
and a0, t6, s3
slli a0, a0, 7
or a0, a0, t5
jal x0, end_sqrt
sqrt_ret_inf:
li a0, 0x7F80
jal x0, end_sqrt
sqrt_ret_zero:
lw a0, BF16_ZERO
jal x0, end_sqrt
sqrt_ret_nan:
lw a0, BF16_NAN
jal x0, end_sqrt
end_sqrt:
lw s3, 16(sp)
lw s2, 12(sp)
lw s1, 8(sp)
lw s0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 20
ret
# ---------------------------------------------------------------------------
# int_mul: small integer multiply utility used by sqrt (shift-and-add)
# Input: a0 = integer multiplicand, a1 = integer multiplier
# Output: a0 = a0 * a1 (clobbers a1, t0, t2)
# ---------------------------------------------------------------------------
int_mul:
add t0, x0, x0 # t0 = accumulator = 0
int_mul_loop:
beq x0, a1, end_int_mul # stop when no multiplier bits remain
andi t2, a1, 1
beq x0, t2, int_mul_skip # skip the add when the low bit is 0
add t0, t0, a0 # accumulator += shifted multiplicand
int_mul_skip:
srli a1, a1, 1 # move to the next multiplier bit
slli a0, a0, 1 # multiplicand <<= 1
jal x0, int_mul_loop
end_int_mul:
mv a0, t0
ret
# ============================================================================
# End of file
# ============================================================================
```
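The exponent halving and binary search in `bf16_sqrt` above can be cross-checked against a small C model. This is only a sketch: the names `shift_add_mul` and `bf16_sqrt_model` are hypothetical, and the constants `0x7FC0` for `BF16_NAN` and `0x0000` for `BF16_ZERO` are assumptions about the data section rather than values taken from it.

```c
#include <stdint.h>

/* Shift-and-add multiply, mirroring int_mul (no M extension required). */
static uint32_t shift_add_mul(uint32_t a, uint32_t b)
{
    uint32_t acc = 0;
    while (b != 0) {
        if (b & 1u)
            acc += a;   /* add the shifted multiplicand when the low bit is set */
        a <<= 1;
        b >>= 1;
    }
    return acc;
}

/* Hypothetical C model of bf16_sqrt working on raw BF16 bit patterns. */
static uint16_t bf16_sqrt_model(uint16_t a)
{
    uint32_t sign = a >> 15;
    int32_t exp = (a >> 7) & 0xFF;
    uint32_t mant = a & 0x7F;

    if (sign)
        return 0x7FC0;          /* assumed BF16_NAN; negative input checked first */
    if (exp == 0xFF)
        return a;               /* +Inf or NaN is returned unchanged */
    if (exp == 0)
        return 0x0000;          /* assumed BF16_ZERO; denormals flush to zero */

    int32_t e = exp - 127;      /* unbiased exponent, may be negative */
    uint32_t m = mant | 0x80;   /* restore the implicit leading 1 */
    int32_t new_exp;
    if (e & 1) {
        m <<= 1;                        /* make the remaining exponent even */
        new_exp = (e - 1) / 2 + 127;    /* e - 1 is even, so the division is exact */
    } else {
        new_exp = e / 2 + 127;          /* e is even, so the division is exact */
    }

    /* Largest mid in [90, 256] with (mid * mid) >> 7 <= m. */
    uint32_t low = 90, high = 256, result = 128;
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = shift_add_mul(mid, mid) >> 7;
        if (sq <= m) {
            result = mid;
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* Renormalize so that result lies in [128, 256). */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    if (new_exp >= 0xFF)
        return 0x7F80;          /* overflow to +Inf */
    if (new_exp <= 0)
        return 0x0000;          /* underflow to zero */
    return (uint16_t)(((uint32_t)new_exp << 7) | (result & 0x7F));
}
```

Inputs with a negative unbiased exponent, such as 0.5 (`0x3F00`) and 0.25 (`0x3E80`), are useful extra test vectors because they exercise the signed exponent halving.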
### Conversions, special values, arithmetic, comparisons, edge cases
#### Basic Conversions Test
| Test # | Input (F32) | F32 Hex | Expected Behavior | Test Item |
|--------|-------------|---------|-------------------|-----------|
| 1 | 0.0f | 0x00000000 | Sign match + Error check | Positive zero |
| 2 | 1.0f | 0x3F800000 | Sign match + Error check | Positive number |
| 3 | -1.0f | 0xBF800000 | Sign match + Error check | Negative number |
| 4 | 2.0f | 0x40000000 | Sign match + Error check | Power of 2 |
| 5 | -2.0f | 0xC0000000 | Sign match + Error check | Negative power of 2 |
| 6 | 0.5f | 0x3F000000 | Sign match + Error check | Fraction |
| 7 | -0.5f | 0xBF000000 | Sign match + Error check | Negative fraction |
| 8 | 3.14159f | 0x40490FDB | Sign match + Error check | π |
| 9 | -3.14159f | 0xC0490FDB | Sign match + Error check | -π |
| 10 | 1e10f | 0x501502F9 | Sign match + Error check | Large number |
| 11 | -1e10f | 0xD01502F9 | Sign match + Error check | Large negative |
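The round trip exercised by the table above can be sketched in C on raw bit patterns. The function names `f32_bits_to_bf16` and `bf16_to_f32_bits` are hypothetical; the rounding rule (add `0x7FFF` plus the current BF16 LSB, then drop the low 16 bits) is the round-to-nearest-even scheme used by the assembly later in this section.

```c
#include <stdint.h>

/* Convert float32 bits to bfloat16 bits with round-to-nearest-even. */
static uint16_t f32_bits_to_bf16(uint32_t f32)
{
    /* Inf and NaN (exponent == 0xFF) are truncated without rounding. */
    if (((f32 >> 23) & 0xFF) == 0xFF)
        return (uint16_t)(f32 >> 16);

    /* bias is 0x7FFF when the kept LSB is 0, and 0x8000 when it is 1,
     * so exact ties round toward the even BF16 pattern. */
    uint32_t bias = 0x7FFFu + ((f32 >> 16) & 1u);
    return (uint16_t)((f32 + bias) >> 16);
}

/* Convert bfloat16 bits back to float32 bits: BF16 is the top half of F32. */
static uint32_t bf16_to_f32_bits(uint16_t bf16)
{
    return (uint32_t)bf16 << 16;
}
```

For example, `0x3F808000` is an exact tie and stays at the even pattern `0x3F80`, while `0x3F818000` rounds up to `0x3F82`.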
#### Special Values Test
| Test # | Input (F32) | F32 Hex | BF16 Expected | Test Function | Expected Result |
|--------|-------------|---------|---------------|---------------|-----------------|
| 1a | +∞ | 0x7F800000 | 0x7F80 | `bf16_isinf()` | Return 1 |
| 1b | +∞ | 0x7F800000 | 0x7F80 | `bf16_isnan()` | Return 0 |
| 2 | -∞ | 0xFF800000 | 0xFF80 | `bf16_isinf()` | Return 1 |
| 3a | NaN | 0x7FC00000 | 0x7FC0 | `bf16_isnan()` | Return 1 |
| 3b | NaN | 0x7FC00000 | 0x7FC0 | `bf16_isinf()` | Return 0 |
| 4 | +0.0 | 0x00000000 | 0x0000 | `bf16_iszero()` | Return 1 |
| 5 | -0.0 | 0x80000000 | 0x8000 | `bf16_iszero()` | Return 1 |
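The classification rules behind this table fit in a few lines of C. This is a minimal sketch on raw BF16 bit patterns; the `_model` names are hypothetical and only mirror the assembly routines of the same base name.

```c
#include <stdint.h>

/* Infinity: exponent all ones and mantissa zero. */
static int bf16_isinf_model(uint16_t b)
{
    return ((b >> 7) & 0xFF) == 0xFF && (b & 0x7F) == 0;
}

/* NaN: exponent all ones and mantissa non-zero. */
static int bf16_isnan_model(uint16_t b)
{
    return ((b >> 7) & 0xFF) == 0xFF && (b & 0x7F) != 0;
}

/* Zero: every bit except the sign bit is clear, so -0 also counts. */
static int bf16_iszero_model(uint16_t b)
{
    return (b & 0x7FFF) == 0;
}
```

Each row of the table maps to one call, for example `bf16_isnan_model(0x7F80)` for test 1b.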
#### Assembly code :
```
# ============================================================================
# BFloat16 Complete Test
# ============================================================================
# Tests: conversions, special values, arithmetic, comparisons, edge cases
# ============================================================================
.data
# Test values for basic conversions (Float32 as uint32_t bits)
test_values:
.word 0x00000000 # 0.0f
.word 0x3F800000 # 1.0f
.word 0xBF800000 # -1.0f
.word 0x40000000 # 2.0f
.word 0xC0000000 # -2.0f
.word 0x3F000000 # 0.5f
.word 0xBF000000 # -0.5f
.word 0x40490FDB # 3.14159f
.word 0xC0490FDB # -3.14159f
.word 0x501502F9 # 1e10f
.word 0xD01502F9 # -1e10f
test_count: .word 11
# Expected results for arithmetic
expected_add: .word 0x40400000 # 3.0f (1.0 + 2.0)
expected_sub: .word 0x3F800000 # 1.0f (2.0 - 1.0)
expected_mul: .word 0x41400000 # 12.0f (3.0 * 4.0)
expected_div: .word 0x40A00000 # 5.0f (10.0 / 2.0)
expected_sqrt4: .word 0x40000000 # 2.0f (sqrt(4.0))
expected_sqrt9: .word 0x40400000 # 3.0f (sqrt(9.0))
# Test counters
tests_passed: .word 0
tests_failed: .word 0
# Output strings
str_banner: .string "\n=== BFloat16 Test Suite ===\n\n"
str_test_basic: .string "Testing basic conversions...\n"
str_test_special: .string "Testing special values...\n"
str_test_arith: .string "Testing arithmetic operations...\n"
str_test_comp: .string "Testing comparison operations...\n"
str_test_edge: .string "Testing edge cases...\n"
str_test_round: .string "Testing rounding behavior...\n"
str_pass: .string " PASS\n"
str_fail: .string " FAIL\n"
str_test_num: .string " Test "
str_colon: .string ": "
str_summary: .string "\n=== Test Summary ===\n"
str_passed_count: .string "Tests passed: "
str_failed_count: .string "Tests failed: "
str_all_pass: .string "\n=== ALL TESTS PASSED ===\n"
str_some_fail: .string "\n=== SOME TESTS FAILED ===\n"
str_newline: .string "\n"
# Test names for detailed output
str_sign_mismatch: .string " [Sign mismatch]\n"
str_error_large: .string " [Relative error too large]\n"
str_inf_detect: .string " [Infinity detection]\n"
str_nan_detect: .string " [NaN detection]\n"
str_zero_detect: .string " [Zero detection]\n"
str_add_fail: .string " [Addition failed]\n"
str_sub_fail: .string " [Subtraction failed]\n"
str_mul_fail: .string " [Multiplication failed]\n"
str_div_fail: .string " [Division failed]\n"
str_sqrt_fail: .string " [sqrt failed]\n"
.text
.globl main
# ============================================================================
# MAIN: Entry point
# ============================================================================
main:
addi sp, sp, -16
sw ra, 12(sp)
# Print banner
la a0, str_banner
call print_string
# Run all test suites
call test_basic_conversions
call test_special_values
# Print summary
call print_summary
# Check if any failed
la t0, tests_failed
lw t0, 0(t0)
beqz t0, main_all_pass
la a0, str_some_fail
call print_string
li a0, 1
j main_exit
main_all_pass:
la a0, str_all_pass
call print_string
li a0, 0
main_exit:
lw ra, 12(sp)
addi sp, sp, 16
li a7, 10
ecall
# ============================================================================
# TEST_BASIC_CONVERSIONS
# ============================================================================
test_basic_conversions:
addi sp, sp, -32
sw ra, 28(sp)
sw s0, 24(sp) # s0 = loop counter
sw s1, 20(sp) # s1 = test_values pointer
sw s2, 16(sp) # s2 = test_count
sw s3, 12(sp) # s3 = original f32
sw s4, 8(sp) # s4 = bf16 result
sw s5, 4(sp) # s5 = converted f32
la a0, str_test_basic
call print_string
la s1, test_values
la t0, test_count
lw s2, 0(t0)
li s0, 0
basic_loop:
bge s0, s2, basic_done
lw s3, 0(s1) # s3 = original f32
# Convert f32 → bf16
mv a0, s3
call f32_to_bf16
mv s4, a0 # s4 = bf16
# Convert bf16 → f32
mv a0, s4
call bf16_to_f32
mv s5, a0 # s5 = converted f32
# Test 1: Check sign consistency (if not zero)
beqz s3, basic_skip_sign
mv a0, s3
mv a1, s5
call check_sign_match
beqz a0, basic_sign_fail
basic_skip_sign:
# Test 2: Check relative error (if not zero and not inf)
beqz s3, basic_continue
mv a0, s3
call is_f32_infinity
bnez a0, basic_continue
mv a0, s3
mv a1, s5
call check_relative_error_1pct
beqz a0, basic_error_fail
basic_continue:
# Both tests passed for this value
addi s1, s1, 4
addi s0, s0, 1
j basic_loop
basic_sign_fail:
# Sign test failed - report and exit immediately
la a0, str_sign_mismatch
call print_string
call increment_failed
j basic_fail_exit
basic_error_fail:
# Error test failed - report and exit immediately
la a0, str_error_large
call print_string
call increment_failed
j basic_fail_exit
basic_fail_exit:
# Exit without printing PASS
lw s5, 4(sp)
lw s4, 8(sp)
lw s3, 12(sp)
lw s2, 16(sp)
lw s1, 20(sp)
lw s0, 24(sp)
lw ra, 28(sp)
addi sp, sp, 32
ret
basic_done:
# All tests passed
call increment_passed
la a0, str_pass
call print_string
lw s5, 4(sp)
lw s4, 8(sp)
lw s3, 12(sp)
lw s2, 16(sp)
lw s1, 20(sp)
lw s0, 24(sp)
lw ra, 28(sp)
addi sp, sp, 32
ret
# ============================================================================
# TEST_SPECIAL_VALUES
# ============================================================================
test_special_values:
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp)
sw s1, 4(sp)
la a0, str_test_special
call print_string
# Test 1a: Positive infinity - should be infinity
li s0, 0x7F800000 # +inf in f32
mv a0, s0
call f32_to_bf16
mv s1, a0
call bf16_isinf
bnez a0, special_test1a_pass
la a0, str_inf_detect
call print_string
call increment_failed
j special_test1b
special_test1a_pass:
call increment_passed
special_test1b:
# Test 1b: Positive infinity - should NOT be NaN
mv a0, s1
call bf16_isnan
beqz a0, special_test1b_pass
la a0, str_nan_detect
call print_string
call increment_failed
j special_test2
special_test1b_pass:
call increment_passed
special_test2:
# Test 2: Negative infinity
li s0, 0xFF800000 # -inf in f32
mv a0, s0
call f32_to_bf16
mv s1, a0
call bf16_isinf
bnez a0, special_test2_pass
la a0, str_inf_detect
call print_string
call increment_failed
j special_test3a
special_test2_pass:
call increment_passed
special_test3a:
# Test 3a: NaN - should be NaN
li s0, 0x7FC00000 # NaN in f32
mv a0, s0
call f32_to_bf16
mv s1, a0
call bf16_isnan
bnez a0, special_test3a_pass
la a0, str_nan_detect
call print_string
call increment_failed
j special_test3b
special_test3a_pass:
call increment_passed
special_test3b:
# Test 3b: NaN - should NOT be infinity
mv a0, s1
call bf16_isinf
beqz a0, special_test3b_pass
la a0, str_inf_detect
call print_string
call increment_failed
j special_test4
special_test3b_pass:
call increment_passed
special_test4:
# Test 4: Positive zero
li s0, 0x00000000 # +0.0 in f32
mv a0, s0
call f32_to_bf16
mv s1, a0
call bf16_iszero
bnez a0, special_test4_pass
la a0, str_zero_detect
call print_string
call increment_failed
j special_test5
special_test4_pass:
call increment_passed
special_test5:
# Test 5: Negative zero
li s0, 0x80000000 # -0.0 in f32
mv a0, s0
call f32_to_bf16
mv s1, a0
call bf16_iszero
bnez a0, special_test5_pass
la a0, str_zero_detect
call print_string
call increment_failed
j special_done
special_test5_pass:
call increment_passed
special_done:
la a0, str_pass
call print_string
lw s1, 4(sp)
lw s0, 8(sp)
lw ra, 12(sp)
addi sp, sp, 16
ret
# ============================================================================
# CONVERSION FUNCTIONS
# ============================================================================
# F32_TO_BF16: Convert Float32 to BFloat16
# Input: a0 = float32 bits (32-bit)
# Output: a0 = bfloat16 bits (16-bit, in lower 16 bits of a0)
f32_to_bf16:
# Get exponent (bits 30:23)
srli t0, a0, 23
andi t0, t0, 0xFF
# Check for special cases (NaN or Inf: exponent == 0xFF)
li t1, 0xFF
beq t0, t1, f32_special
# Normal case: round-to-nearest-even
# Rounding bias = ((f32 >> 16) & 1) + 0x7FFF
srli t2, a0, 16 # Get bit 16 (LSB of BF16)
andi t2, t2, 1 # Isolate bit 16
# Create 0x7FFF
lui t3, 0x8 # t3 = 0x8000
addi t3, t3, -1 # t3 = 0x7FFF
# Add rounding bias
add t2, t2, t3 # t2 = 0x7FFF or 0x8000
add a0, a0, t2 # Add bias to original value
# Extract top 16 bits (this is the BF16 result)
srli a0, a0, 16
ret
f32_special:
# For special values (NaN/Inf), just truncate
srli a0, a0, 16
ret
# BF16_TO_F32: Convert BFloat16 to Float32
# Input: a0 = bfloat16 bits (16-bit)
# Output: a0 = float32 bits (32-bit)
bf16_to_f32:
# BF16 is top 16 bits of FP32, just shift left by 16
slli a0, a0, 16
ret
# ============================================================================
# SPECIAL VALUE CHECKS
# ============================================================================
# BF16_ISINF: Check if BF16 is infinity
# Input: a0 = bf16 (16-bit)
# Output: a0 = 1 if inf, 0 otherwise
bf16_isinf:
# Extract exponent (bits 14:7)
srli t0, a0, 7
andi t0, t0, 0xFF
# Extract mantissa (bits 6:0)
andi t1, a0, 0x7F
# Is inf if exponent==0xFF and mantissa==0
li t2, 0xFF
bne t0, t2, isinf_false
bnez t1, isinf_false
li a0, 1
ret
isinf_false:
li a0, 0
ret
# BF16_ISNAN: Check if BF16 is NaN
# Input: a0 = bf16 (16-bit)
# Output: a0 = 1 if NaN, 0 otherwise
bf16_isnan:
# Extract exponent (bits 14:7)
srli t0, a0, 7
andi t0, t0, 0xFF
# Extract mantissa (bits 6:0)
andi t1, a0, 0x7F
# Is NaN if exponent==0xFF and mantissa!=0
li t2, 0xFF
bne t0, t2, isnan_false
beqz t1, isnan_false
li a0, 1
ret
isnan_false:
li a0, 0
ret
# BF16_ISZERO: Check if BF16 is zero (including -0)
# Input: a0 = bf16 (16-bit)
# Output: a0 = 1 if zero, 0 otherwise
bf16_iszero:
# Zero if all bits except sign are 0
# Create mask 0x7FFF
lui t1, 0x8 # t1 = 0x8000
addi t1, t1, -1 # t1 = 0x7FFF
and t0, a0, t1 # Mask off sign bit
seqz a0, t0 # a0 = (t0 == 0) ? 1 : 0
ret
# ============================================================================
# HELPER FUNCTIONS FOR TESTING
# ============================================================================
# CHECK_SIGN_MATCH: Check if two f32 have same sign
# Input: a0 = f32_1, a1 = f32_2
# Output: a0 = 1 if same sign, 0 otherwise
check_sign_match:
srli t0, a0, 31 # Get sign bit of a0
srli t1, a1, 31 # Get sign bit of a1
xor t0, t0, t1 # XOR: 0 if same, 1 if different
seqz a0, t0 # a0 = 1 if t0==0 (same sign)
ret
# IS_F32_INFINITY: Check if f32 is infinity
# Input: a0 = f32
# Output: a0 = 1 if inf, 0 otherwise
is_f32_infinity:
# Extract exponent (bits 30:23)
srli t0, a0, 23
andi t0, t0, 0xFF
# Extract mantissa (bits 22:0)
lui t2, 0x800 # t2 = 0x800000
addi t2, t2, -1 # t2 = 0x7FFFFF (23-bit mantissa mask)
and t1, a0, t2 # Mask mantissa
# Inf if exponent==0xFF and mantissa==0
li t3, 0xFF
bne t0, t3, f32_not_inf
bnez t1, f32_not_inf
li a0, 1
ret
f32_not_inf:
li a0, 0
ret
# CHECK_RELATIVE_ERROR_1PCT: Check if relative error < 1%
# Input: a0 = expected (f32), a1 = actual (f32)
# Output: a0 = 1 if error acceptable, 0 otherwise
check_relative_error_1pct:
# Approximation: compare the top 16 bits (the BF16 pattern) instead of computing a true relative error
srli t0, a0, 16
srli t1, a1, 16
sub t2, t0, t1 # Difference
# Get absolute difference
bgez t2, check_positive
sub t2, zero, t2 # Make positive
check_positive:
# Allow difference of ±1 for rounding
li t3, 1
ble t2, t3, error_ok
li a0, 0
ret
error_ok:
li a0, 1
ret
# ============================================================================
# TEST COUNTER FUNCTIONS
# ============================================================================
increment_passed:
la t0, tests_passed
lw t1, 0(t0)
addi t1, t1, 1
sw t1, 0(t0)
ret
increment_failed:
la t0, tests_failed
lw t1, 0(t0)
addi t1, t1, 1
sw t1, 0(t0)
ret
# ============================================================================
# OUTPUT FUNCTIONS
# ============================================================================
print_summary:
addi sp, sp, -16
sw ra, 12(sp)
la a0, str_summary
call print_string
la a0, str_passed_count
call print_string
la t0, tests_passed
lw a0, 0(t0)
call print_int
la a0, str_newline
call print_string
la a0, str_failed_count
call print_string
la t0, tests_failed
lw a0, 0(t0)
call print_int
la a0, str_newline
call print_string
lw ra, 12(sp)
addi sp, sp, 16
ret
print_string:
li a7, 4
ecall
ret
print_int:
li a7, 1
ecall
ret
```
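The tolerance check in `check_relative_error_1pct` compares bit patterns rather than computing a true relative error. A minimal C sketch of that comparison follows; the name `within_one_bf16_step` is hypothetical.

```c
#include <stdint.h>

/* Returns 1 when the top 16 bits of the two float32 patterns differ by at most 1. */
static int within_one_bf16_step(uint32_t expected, uint32_t actual)
{
    int32_t diff = (int32_t)(expected >> 16) - (int32_t)(actual >> 16);
    if (diff < 0)
        diff = -diff;   /* absolute difference of the BF16 patterns */
    return diff <= 1;
}
```

With this rule, the π test value `0x40490FDB` and its round-trip result `0x40490000` are accepted because their top halves are identical.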
## Analysis of the 5-stage pipelined processor for Problem B
[Introduction to 5-stage pipelined processor on Ripes](https://hackmd.io/@CarSam/Bkz2amw6xl)
Analyzing the first instruction of Problem B:
```
addi sp, sp, -16
```

### Why is its machine code **0xff010113**?
According to the following picture, ```addi``` is an I-type instruction, and its opcode is ```0010011```.

The stack pointer (sp) is register ```x2```, so both rs1 and rd are ```00010```; its current value is ```0x7ffffff0```.
The 12-bit two's-complement form of -16 is ```111111110000```.
Combining this information, the machine code in binary form is
```
imm          rs1   funct3 rd    opcode
111111110000 00010 000    00010 0010011
```
Then convert it into hexadecimal form:
```
0xFF010113
```
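The field packing above can be reproduced with a short C helper. This is an illustrative sketch; `encode_itype` is a hypothetical name, and the bit positions follow the I-type layout (imm[11:0] at bits 31:20, rs1 at 19:15, funct3 at 14:12, rd at 11:7, opcode at 6:0).

```c
#include <stdint.h>

/* Pack the five I-type fields into one 32-bit instruction word. */
static uint32_t encode_itype(int32_t imm, uint32_t rs1, uint32_t funct3,
                             uint32_t rd, uint32_t opcode)
{
    return (((uint32_t)imm & 0xFFFu) << 20)   /* 12-bit two's-complement immediate */
         | ((rs1 & 0x1Fu) << 15)
         | ((funct3 & 0x7u) << 12)
         | ((rd & 0x1Fu) << 7)
         | (opcode & 0x7Fu);
}
```

Calling `encode_itype(-16, 2, 0, 2, 0x13)` corresponds to `addi sp, sp, -16`.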
### 5-stage pipelined processor

There are 5 stages:
1. Instruction fetch (IF)
2. Instruction decode and register fetch (ID)
3. Execute (EX)
4. Memory access (MEM)
5. Register write back (WB)
Let's see how the instruction goes through each stage:
1. IF

* The ```addr``` of the input instruction is ```0x00000000```.
* The machine code of the first instruction is ```0xFF010113```.
* The next PC is set to PC + 4 automatically.
2. ID

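* The opcode is ```0b0010011```.
* As an I-type instruction, ```addi``` reads only rs1, which is ```x2``` (sp) with value ```0x7ffffff0```. Bits 24:20 of the instruction (```0x10```) belong to the immediate and do not name a second source register.
* The sign-extended immediate is ```0xfffffff0```.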
3. EX

```
ALU_result = Reg1 + immediate
           = 0x7FFFFFF0 + (-16)
           = 0x7FFFFFE0
```
* Output of ALU is ```0x7fffffe0```
4. MEM

* Because ```addi``` neither loads from nor stores to data memory, this stage does no work for it; Read out is ```0x00000000```.
5. WB

* The result ```0x7fffffe0``` is sent back to the register file as the write data (Wr data) for ```x2```.
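The ID and EX steps listed above can be mirrored in C as a quick sanity check. This is a sketch under the assumption that only the I-type fields matter; `itype_fields`, `decode_itype`, and `execute_addi` are hypothetical names, not part of Ripes.

```c
#include <stdint.h>

/* Fields extracted in the ID stage of an I-type instruction. */
typedef struct {
    uint32_t opcode;
    uint32_t rd;
    uint32_t funct3;
    uint32_t rs1;
    int32_t imm;    /* sign-extended 12-bit immediate */
} itype_fields;

static itype_fields decode_itype(uint32_t insn)
{
    itype_fields f;
    f.opcode = insn & 0x7F;
    f.rd = (insn >> 7) & 0x1F;
    f.funct3 = (insn >> 12) & 0x7;
    f.rs1 = (insn >> 15) & 0x1F;
    /* XOR with the sign bit and subtract it to sign-extend from 12 bits. */
    uint32_t imm12 = insn >> 20;
    f.imm = (int32_t)(imm12 ^ 0x800u) - 0x800;
    return f;
}

/* EX stage of addi: 32-bit wrap-around addition of rs1 and the immediate. */
static uint32_t execute_addi(uint32_t rs1_value, int32_t imm)
{
    return rs1_value + (uint32_t)imm;
}
```

Decoding `0xFF010113` gives rs1 = rd = 2 and imm = -16, and `execute_addi(0x7FFFFFF0, -16)` reproduces the ALU output shown in the EX stage.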
After all these stages are done, the register is updated like this:
