# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by <[`hsuhsuhs`](https://github.com/hsuhsuhs/2025_NCKU_Computer_Architecture/tree/main/Assignment1)>
---
[TOC]
## Problem `B` in [Quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
- [ ] Decoding
$$ D(b) = m \cdot 2^e + (2^e - 1) \cdot 16 $$
where $e = \lfloor b/16 \rfloor$ (upper 4 bits) and $m = b \bmod 16$ (lower 4 bits).
- [ ] Encoding
$$ E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases} $$
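where $\text{offset}(e) = (2^e - 1) \cdot 16$ and $e$ is the largest exponent in $[0, 15]$ with $\text{offset}(e) \le v$ (the encoding truncates, so $D(E(v)) \le v$).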
### Original C program
<details>
<summary><b>Open to see the complete C program </b></summary>
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
/* uf8: 8-bit code; the high nibble is the exponent, the low nibble the mantissa */
typedef uint8_t uf8;
/* Count leading zero bits of a 32-bit word; returns 32 when x == 0.
 *
 * Binary search over the shift amounts 16, 8, 4, 2, 1: whenever x >> shift is
 * still non-zero the highest set bit lies in that upper part, so the shift is
 * subtracted from the running count and the search continues there. After the
 * last probe x is 0 or 1, which is subtracted as well.
 */
static inline unsigned clz(uint32_t x)
{
    int remaining = 32;
    for (int shift = 16; shift != 0; shift /= 2) {
        uint32_t upper = x >> shift;
        if (upper != 0) {
            remaining -= shift;
            x = upper;
        }
    }
    return remaining - x;
}
/* Decode a uf8 code to its integer value:
 *   value = (mantissa << exponent) + offset(exponent),
 *   offset(e) = (2^e - 1) * 16.
 */
uint32_t uf8_decode(uf8 fl)
{
    const unsigned e = fl >> 4;
    const uint32_t m = fl & 0x0fu;
    /* 0x7FFF >> (15 - e) leaves exactly e one-bits, i.e. 2^e - 1 */
    const uint32_t base = (uint32_t) (0x7FFF >> (15 - e)) << 4;
    return base + (m << e);
}
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz;
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0;
if (msb >= 5) {
// 1) Roughly estimate e
/* Estimate exponent - the formula is empirical */
exponent = msb - 4;
if (exponent > 15) // [eeee | mmmm] e <= 15
exponent = 15;
// 2) Compute offset(e) via recurrence
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
// 3) If the estimate was too large, adjust downward (safety correction)
/* Adjust if estimate was off */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1; // Invert recurrence: go back to offset(e-1)
exponent--;
}
}
// 4) Adjust upward to find the exact e: keep advancing while we can cross into the next bucket
/* Find exact exponent */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16; // = offset(e+1)
if (value < next_overflow) // Can't cross further → current e is correct
break;
overflow = next_overflow; // Move forward to the next bucket
exponent++;
}
uint8_t mantissa = (value - overflow) >> exponent; // = floor((value-offset(e))/2^e)
return (exponent << 4) | mantissa;
}
/* Exhaustive round-trip check over all 256 codes: decoding and then
 * re-encoding must give back the same code, and decoded values must strictly
 * increase with the code. Each violation is printed; returns true when none
 * occurred.
 */
static bool test(void)
{
    bool ok = true;
    int32_t prev = -1;
    for (int code = 0; code <= 255; ++code) {
        const uint8_t fl = (uint8_t) code;
        const int32_t value = uf8_decode(fl);
        const uint8_t fl2 = uf8_encode(value);
        if (fl2 != fl) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl,
                   value, fl2);
            ok = false;
        }
        if (!(value > prev)) {
            printf("%02x: value %d <= previous_value %d\n", fl, value, prev);
            ok = false;
        }
        prev = value;
    }
    return ok;
}
/* Run the round-trip test; exit status 0 on success, 1 on failure. */
int main(void)
{
    if (!test())
        return 1;
    printf("All tests passed.\n");
    return 0;
}
```
</details>
### Assembly code
<h4>
<span style="color:darkblue">version <code>1</code></span>
</h4>
```s
.data
# Inputs for the three encode/decode tests run by main, followed by the
# console messages printed by the harness.
# Test data
test1: .word 15 # small value
test2: .word 108 # medium value
test3: .word 1000000 # large value
# NOTE: 1000000 is not exactly representable in uf8: it encodes to 0xFE (254),
# which decodes to 983024.
# Test messages
test_start_msg: .string "=== UF8 Automated Test ===\n\n"
test1_msg: .string "Test 1 (small): "
test2_msg: .string "Test 2 (medium): "
test3_msg: .string "Test 3 (large): "
arrow_msg: .string " -> Encoded: "
decode_result_msg: .string ", Decoded: "
pass_msg: .string " PASS\n"
fail_msg: .string " FAIL\n"
separator: .string "\n====================\n"
all_pass_msg: .string "All tests passed.\n"
some_fail_msg: .string "Tests failed.\n"
# newline is not referenced by the code in this listing
newline: .string "\n"
.text
.globl _start
# Entry point: call main, then environment call 10 (exit in the RARS/Ripes convention).
_start:
jal main
li a7, 10
ecall
# print_string: environment call 4 with the string address in a0.
print_string:
li a7, 4
ecall
jr ra
# print_int: environment call 1 with the integer in a0.
print_int:
li a7, 1
ecall
jr ra
# print_char: environment call 11 with the character in a0
# (not called anywhere in this listing).
print_char:
li a7, 11
ecall
jr ra
# clz function (count leading zeros)
# In: a0 = x. Out: a0 = number of leading zero bits in x (32 when x == 0).
# Same binary search as the C clz: probe shifts 16, 8, 4, 2, 1. Clobbers t0-t2.
clz:
li t0, 32 # n = 32
li t1, 16 # c = 16
clz_loop:
srl t2, a0, t1 # y = x >> c
beqz t2, clz_skip # if (y == 0) skip
sub t0, t0, t1 # n -= c
mv a0, t2 # x = y
clz_skip:
srli t1, t1, 1 # c >>= 1
bnez t1, clz_loop # while (c != 0)
sub a0, t0, a0 # return n - x
jr ra
# UF8 decoding function
# In: a0 = uf8 code (assumed to fit in 8 bits). Out: a0 = decoded value
# (m << e) + ((2^e - 1) << 4). Clobbers t0-t3.
# li t2, 0x7FFF does not fit a 12-bit signed immediate, so it assembles to two
# instructions (lui + addi).
uf8_decode:
andi t0, a0, 0x0F # mantissa = fl & 0x0F
srli t1, a0, 4 # exponent = fl >> 4
# offset = (0x7FFF >> (15 - exponent)) << 4
li t2, 0x7FFF
li t3, 15
sub t3, t3, t1 # 15 - exponent
srl t2, t2, t3 # 0x7FFF >> (15 - exponent)
slli t2, t2, 4 # << 4
sll t0, t0, t1 # mantissa << exponent
add a0, t0, t2 # return value
jr ra
# UF8 encoding function
# In: a0 = value, treated as unsigned 32-bit (uint32_t in the C reference).
# Out: a0 = uf8 code (exponent << 4) | mantissa; the mantissa saturates at 15.
# Comparisons on the value use unsigned branches (bltu/bgeu): with signed
# branches an input >= 2^31 looks negative and would be returned unencoded
# through encode_small.
# Saves/restores ra and s0-s3; clobbers t0-t2 and whatever clz clobbers.
uf8_encode:
# Special case: value < 16 (unsigned compare)
li t0, 16
bltu a0, t0, encode_small
addi sp, sp, -20
sw ra, 16(sp)
sw s0, 12(sp) # value
sw s1, 8(sp) # exponent
sw s2, 4(sp) # overflow
sw s3, 0(sp) # temporary variable
mv s0, a0 # s0 = value
li s1, 0 # exponent = 0
li s2, 0 # overflow = 0
# Estimate using CLZ
mv a0, s0
jal clz
li t0, 31
sub t0, t0, a0 # msb = 31 - clz(value)
# If msb >= 5, estimate exponent
li t1, 5
blt t0, t1, upscan_init
# Estimate exponent = msb - 4
addi s1, t0, -4 # exponent = msb - 4
li t1, 15
ble s1, t1, calc_overflow
li s1, 15 # clamp to max = 15
calc_overflow:
# overflow = 16 * (2^exponent - 1), built with offset(e+1) = 2*offset(e) + 16
beqz s1, upscan_init # if exponent == 0 → overflow = 0
li t2, 0 # loop counter
li s2, 0 # overflow = 0
build_loop:
slli s2, s2, 1 # overflow << 1
addi s2, s2, 16 # overflow + 16
addi t2, t2, 1
blt t2, s1, build_loop
j check_adjust # NOTE: targets the very next label; only costs cycles
check_adjust:
beqz s1, upscan_init # if exponent == 0 → skip
# If value < overflow, adjust downward. One step suffices: for
# exponent = msb - 4, overflow = 2^msb - 16 < value, so this is a safety net.
bgeu s0, s2, upscan_init
addi s2, s2, -16 # overflow - 16
srli s2, s2, 1 # (overflow - 16) >> 1
addi s1, s1, -1 # exponent--
j upscan_init # continue (only adjust once)
upscan_init:
# Start upward adjustment
li s3, 15 # max_exponent = 15
upscan_loop:
bge s1, s3, upscan_done
# next_overflow = (overflow << 1) + 16
slli t0, s2, 1
addi t0, t0, 16
# Stop if value < next_overflow (unsigned compare)
bltu s0, t0, upscan_done
# Move to next range
mv s2, t0 # overflow = next_overflow
addi s1, s1, 1 # exponent++
j upscan_loop
upscan_done:
# mantissa = (value - overflow) >> exponent
sub t0, s0, s2
srl t0, t0, s1
# Limit mantissa to 15 (saturates values above the largest uf8 value)
li t1, 15
ble t0, t1, encode_pack
li t0, 15
encode_pack:
# Combine: (exponent << 4) | mantissa
slli s1, s1, 4
or a0, s1, t0
lw ra, 16(sp)
lw s0, 12(sp)
lw s1, 8(sp)
lw s2, 4(sp)
lw s3, 0(sp)
addi sp, sp, 20
jr ra
encode_small:
# For value < 16, return directly
jr ra
# Automated test function – test one value (returns 0=pass, 1=fail)
# In: a0 = value. Prints "value -> Encoded: code, Decoded: decoded" followed by
# PASS or FAIL. Saves/restores ra and s0-s2.
# NOTE(review): uf8 is lossy, so requiring decoded == original fails for any
# value that is not exactly representable. With this listing's data, test 3
# (1000000 -> 254 -> 983024) reports FAIL. The C reference instead checks
# encode(decode(code)) == code.
test_single_value:
# a0 = test value
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp) # original value
sw s1, 4(sp) # encoded
sw s2, 0(sp) # decoded
mv s0, a0 # save original
# Show test info
mv a0, s0
jal print_int
la a0, arrow_msg
jal print_string
# Encode
mv a0, s0
jal uf8_encode
mv s1, a0 # save encoded result
# Print encoded value
mv a0, s1
jal print_int
# Decode
mv a0, s1
jal uf8_decode
mv s2, a0 # save decoded result
# Print decoded result
la a0, decode_result_msg
jal print_string
mv a0, s2
jal print_int
# Check decoded == original
bne s2, s0, test_fail
# Pass
la a0, pass_msg
jal print_string
li a0, 0
j test_done
test_fail:
# Fail
la a0, fail_msg
jal print_string
li a0, 1
test_done:
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
lw s2, 0(sp)
addi sp, sp, 16
jr ra
# Main – automated tests for three values
# Runs test_single_value on test1..test3, ORs each 0 (pass) / 1 (fail) result
# into s0, then prints the separator and an overall verdict.
# Saves/restores ra and s0.
main:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp) # test summary (0 = all passed)
li s0, 0 # assume all pass
la a0, test_start_msg
jal print_string
# Test 1: small value
la a0, test1_msg
jal print_string
lw a0, test1
jal test_single_value
or s0, s0, a0 # merge result
# Test 2: medium value
la a0, test2_msg
jal print_string
lw a0, test2
jal test_single_value
or s0, s0, a0
# Test 3: large value
la a0, test3_msg
jal print_string
lw a0, test3
jal test_single_value
or s0, s0, a0
# Print summary
la a0, separator
jal print_string
bnez s0, tests_failed
# All passed
la a0, all_pass_msg
jal print_string
j main_done
tests_failed:
# Some tests failed
la a0, some_fail_msg
jal print_string
main_done:
lw ra, 4(sp)
lw s0, 0(sp)
addi sp, sp, 8
jr ra
```
**1. Execution information** - <span style="color:darkblue">**$755 \;\text{cycles}$**</span>

---
<h4>
<span style="color:darkblue">version <code>2</code></span>
</h4>
**1. Improvement**
* **Revised** the `uf8_decode` function in <span style="color:darkblue">**version**</span> **`1`**:
use `((1 << e) - 1) << 4` instead of `(0x7FFF >> (15 - e)) << 4`
```s
# UF8 decoding function --> value = (m<<e) + (((1<<e)-1)<<4)
# In: a0 = uf8 code (assumed to fit in 8 bits). Out: a0 = decoded value.
# Clobbers t0-t2. li t2, 1 is a single addi, unlike version 1's li t2, 0x7FFF
# (lui + addi), and the 15 - e subtraction is no longer needed.
uf8_decode:
andi t0, a0, 0x0F # m = b & 0x0F
srli t1, a0, 4 # e = b >> 4
li t2, 1
sll t2, t2, t1 # 1<<e
addi t2, t2, -1 # (1<<e)-1
slli t2, t2, 4 # offset = ((1<<e)-1)<<4
sll t0, t0, t1 # m<<e
add a0, t0, t2 # value
jr ra
```
**2. Execution information** - <span style="color:darkblue">**$749 \;\text{cycles}$**</span>

**3. Why is it faster?**
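The difference lies in the number of instructions and the cost of loading constants. `li t2, 0x7FFF` does not fit a 12-bit immediate, so the assembler expands it into `lui` + `addi`, while `li t2, 1` is a single `addi`. Version 2 also drops the `15 - e` subtraction. That saves 2 instructions per call; `uf8_decode` runs 3 times, which matches the 6-cycle difference (755 → 749).
| | `decode_v1` | `decode_v2` |
|:--------------------:|:-----------:|:-----------:|
| `li` pseudo-instructions | 2 (`li t2, 0x7FFF` expands to `lui` + `addi`) | 1 |
| machine instructions (incl. `jr`) | 11 | 9 |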
---
<h4>
<span style="color:darkblue">version <code>3</code></span>
</h4>
**1. Improvement**
* **Unrolled** the loop in the `clz` function of <span style="color:darkblue">**version**</span> **`2`**:
```s
# clz function (count leading zeros) --> unloop
# In: a0 = x. Out: a0 = number of leading zero bits (32 when x == 0).
# Straight-line binary search: each stage tests the top 16/8/4/2/1 bits and,
# if they are all zero, adds that width to n and shifts x left to bring the
# next candidate bits to the top. Clobbers t0, t1.
clz:
beqz a0, 9f # x==0 → 32
li t0, 0 # n = 0
# Check upper 16 bits
srli t1, a0, 16 # t1 = x >> 16
bnez t1, 1f # if (t1 != 0) skip
addi t0, t0, 16 # n += 16
slli a0, a0, 16 # x <<= 16
1:
# Check upper 8 bits
srli t1, a0, 24 # t1 = x >> 24
bnez t1, 2f # if (t1 != 0) skip
addi t0, t0, 8 # n += 8
slli a0, a0, 8 # x <<= 8
2:
# Check upper 4 bits
srli t1, a0, 28 # t1 = x >> 28
bnez t1, 3f # if (t1 != 0) skip
addi t0, t0, 4 # n += 4
slli a0, a0, 4 # x <<= 4
3:
# Check upper 2 bits
srli t1, a0, 30 # t1 = x >> 30
bnez t1, 4f # if (t1 != 0) skip
addi t0, t0, 2 # n += 2
slli a0, a0, 2 # x <<= 2
4:
# Check the most significant bit
srli t1, a0, 31 # t1 = x >> 31
bnez t1, 5f # if (t1 != 0) skip
addi t0, t0, 1 # n += 1
5:
mv a0, t0 # return n (number of leading zeros)
jr ra # return
# Case when x == 0
9:
li a0, 32 # return 32 (all bits are zero)
jr ra
```
**2. Execution information** - <span style="color:darkblue">**$712 \;\text{cycles}$**</span>

**3. Why is it faster?**
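* **Fewer dynamic instructions**:
The unrolled version replaces the `16/8/4/2/1` loop iterations with straight-line conditional checks. Each stage performs only a simple "test + (if needed) add to n / shift x" step. The loop-control instructions (`srli t1, t1, 1` and `bnez t1, clz_loop`) are gone.
* **No loop-carried counter**:
The stages still depend on each other through `x` and `n`, but there is **no backward branch**. The whole sequence becomes a **fixed-depth linear dependency chain** that a simple pipeline can forward smoothly.
* **No per-iteration branch penalty**:
The loop-closing `bnez` was taken on every iteration except the last. On a simple in-order pipeline without branch prediction (assumed here), each taken branch costs extra flush cycles, and the unrolled code never pays this penalty.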
---
<h4>
<span style="color:darkblue">complete version</span>
</h4>
```s
.data
# Inputs for the three encode/decode tests run by main, followed by the
# console messages printed by the harness.
# Test data
test1: .word 15 # small value
test2: .word 108 # medium value
test3: .word 1000000 # large value
# NOTE: 1000000 is not exactly representable in uf8: it encodes to 0xFE (254),
# which decodes to 983024.
# Test messages
test_start_msg: .string "=== UF8 Automated Test ===\n\n"
test1_msg: .string "Test 1 (small): "
test2_msg: .string "Test 2 (medium): "
test3_msg: .string "Test 3 (large): "
arrow_msg: .string " -> Encoded: "
decode_result_msg: .string ", Decoded: "
pass_msg: .string " PASS\n"
fail_msg: .string " FAIL\n"
separator: .string "\n====================\n"
all_pass_msg: .string "All tests passed.\n"
some_fail_msg: .string "Tests failed.\n"
# newline is not referenced by the code in this listing
newline: .string "\n"
.text
.globl _start
# Entry point: call main, then environment call 10 (exit in the RARS/Ripes convention).
_start:
jal main
li a7, 10
ecall
# print_string: environment call 4 with the string address in a0.
print_string:
li a7, 4
ecall
jr ra
# print_int: environment call 1 with the integer in a0.
print_int:
li a7, 1
ecall
jr ra
# print_char: environment call 11 with the character in a0
# (not called anywhere in this listing).
print_char:
li a7, 11
ecall
jr ra
# clz function (count leading zeros) --> unloop
# In: a0 = x. Out: a0 = number of leading zero bits (32 when x == 0).
# Straight-line binary search: each stage tests the top 16/8/4/2/1 bits and,
# if they are all zero, adds that width to n and shifts x left to bring the
# next candidate bits to the top. Clobbers t0, t1.
clz:
beqz a0, 9f # x==0 → 32
li t0, 0 # n = 0
# Check upper 16 bits
srli t1, a0, 16 # t1 = x >> 16
bnez t1, 1f # if (t1 != 0) skip
addi t0, t0, 16 # n += 16
slli a0, a0, 16 # x <<= 16
1:
# Check upper 8 bits
srli t1, a0, 24 # t1 = x >> 24
bnez t1, 2f # if (t1 != 0) skip
addi t0, t0, 8 # n += 8
slli a0, a0, 8 # x <<= 8
2:
# Check upper 4 bits
srli t1, a0, 28 # t1 = x >> 28
bnez t1, 3f # if (t1 != 0) skip
addi t0, t0, 4 # n += 4
slli a0, a0, 4 # x <<= 4
3:
# Check upper 2 bits
srli t1, a0, 30 # t1 = x >> 30
bnez t1, 4f # if (t1 != 0) skip
addi t0, t0, 2 # n += 2
slli a0, a0, 2 # x <<= 2
4:
# Check the most significant bit
srli t1, a0, 31 # t1 = x >> 31
bnez t1, 5f # if (t1 != 0) skip
addi t0, t0, 1 # n += 1
5:
mv a0, t0 # return n (number of leading zeros)
jr ra # return
# Case when x == 0
9:
li a0, 32 # return 32 (all bits are zero)
jr ra
# UF8 decoding function --> value = (m<<e) + (((1<<e)-1)<<4)
# In: a0 = uf8 code (assumed to fit in 8 bits). Out: a0 = decoded value.
# Clobbers t0-t2.
uf8_decode:
andi t0, a0, 0x0F # m = b & 0x0F
srli t1, a0, 4 # e = b >> 4
li t2, 1
sll t2, t2, t1 # 1<<e
addi t2, t2, -1 # (1<<e)-1
slli t2, t2, 4 # offset = ((1<<e)-1)<<4
sll t0, t0, t1 # m<<e
add a0, t0, t2 # value
jr ra
# UF8 encoding function
# In: a0 = value, treated as unsigned 32-bit (uint32_t in the C reference).
# Out: a0 = uf8 code (exponent << 4) | mantissa; the mantissa saturates at 15.
# Comparisons on the value use unsigned branches (bltu/bgeu): with signed
# branches an input >= 2^31 looks negative and would be returned unencoded
# through encode_small.
# Saves/restores ra and s0-s3; clobbers t0-t2 and whatever clz clobbers.
uf8_encode:
# Special case: value < 16 (unsigned compare)
li t0, 16
bltu a0, t0, encode_small
addi sp, sp, -20
sw ra, 16(sp)
sw s0, 12(sp) # value
sw s1, 8(sp) # exponent
sw s2, 4(sp) # overflow
sw s3, 0(sp) # temporary variable
mv s0, a0 # s0 = value
li s1, 0 # exponent = 0
li s2, 0 # overflow = 0
# Estimate using CLZ
mv a0, s0
jal clz
li t0, 31
sub t0, t0, a0 # msb = 31 - clz(value)
# If msb >= 5, estimate exponent
li t1, 5
blt t0, t1, upscan_init
# Estimate exponent = msb - 4
addi s1, t0, -4 # exponent = msb - 4
li t1, 15
ble s1, t1, calc_overflow
li s1, 15 # clamp to max = 15
calc_overflow:
# overflow = 16 * (2^exponent - 1), built with offset(e+1) = 2*offset(e) + 16
beqz s1, upscan_init # if exponent == 0 → overflow = 0
li t2, 0 # loop counter
li s2, 0 # overflow = 0
build_loop:
slli s2, s2, 1 # overflow << 1
addi s2, s2, 16 # overflow + 16
addi t2, t2, 1
blt t2, s1, build_loop
j check_adjust # NOTE: targets the very next label; only costs cycles
check_adjust:
beqz s1, upscan_init # if exponent == 0 → skip
# If value < overflow, adjust downward. One step suffices: for
# exponent = msb - 4, overflow = 2^msb - 16 < value, so this is a safety net.
bgeu s0, s2, upscan_init
addi s2, s2, -16 # overflow - 16
srli s2, s2, 1 # (overflow - 16) >> 1
addi s1, s1, -1 # exponent--
j upscan_init # continue (only adjust once)
upscan_init:
# Start upward adjustment
li s3, 15 # max_exponent = 15
upscan_loop:
bge s1, s3, upscan_done
# next_overflow = (overflow << 1) + 16
slli t0, s2, 1
addi t0, t0, 16
# Stop if value < next_overflow (unsigned compare)
bltu s0, t0, upscan_done
# Move to next range
mv s2, t0 # overflow = next_overflow
addi s1, s1, 1 # exponent++
j upscan_loop
upscan_done:
# mantissa = (value - overflow) >> exponent
sub t0, s0, s2
srl t0, t0, s1
# Limit mantissa to 15 (saturates values above the largest uf8 value)
li t1, 15
ble t0, t1, encode_pack
li t0, 15
encode_pack:
# Combine: (exponent << 4) | mantissa
slli s1, s1, 4
or a0, s1, t0
lw ra, 16(sp)
lw s0, 12(sp)
lw s1, 8(sp)
lw s2, 4(sp)
lw s3, 0(sp)
addi sp, sp, 20
jr ra
encode_small:
# For value < 16, return directly
jr ra
# Automated test function – test one value (returns 0=pass, 1=fail)
# In: a0 = value. Prints "value -> Encoded: code, Decoded: decoded" followed by
# PASS or FAIL. Saves/restores ra and s0-s2.
# uf8 keeps only 4 mantissa bits, so encoding is lossy for most values
# (e.g. 1000000 encodes to 254, which decodes to 983024). Requiring
# decoded == original would fail for such inputs. The test therefore checks the
# two properties the encoder guarantees:
#   1. decoded <= original (the mantissa is truncated, never rounded up)
#   2. uf8_encode(decoded) == encoded (the decoded value maps back to the same code)
# The extra uf8_encode call adds cycles compared with an equality-only check.
test_single_value:
# a0 = test value
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp) # original value
sw s1, 4(sp) # encoded
sw s2, 0(sp) # decoded
mv s0, a0 # save original
# Show test info
mv a0, s0
jal print_int
la a0, arrow_msg
jal print_string
# Encode
mv a0, s0
jal uf8_encode
mv s1, a0 # save encoded result
# Print encoded value
mv a0, s1
jal print_int
# Decode
mv a0, s1
jal uf8_decode
mv s2, a0 # save decoded result
# Print decoded result
la a0, decode_result_msg
jal print_string
mv a0, s2
jal print_int
# Check 1: decoded must not exceed the original (unsigned)
bltu s0, s2, test_fail
# Check 2: re-encoding the decoded value must reproduce the same code
mv a0, s2
jal uf8_encode
bne a0, s1, test_fail
# Pass
la a0, pass_msg
jal print_string
li a0, 0
j test_done
test_fail:
# Fail
la a0, fail_msg
jal print_string
li a0, 1
test_done:
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
lw s2, 0(sp)
addi sp, sp, 16
jr ra
# Main – automated tests for three values
# Runs test_single_value on test1..test3, ORs each 0 (pass) / 1 (fail) result
# into s0, then prints the separator and an overall verdict.
# Saves/restores ra and s0.
main:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp) # test summary (0 = all passed)
li s0, 0 # assume all pass
la a0, test_start_msg
jal print_string
# Test 1: small value
la a0, test1_msg
jal print_string
lw a0, test1
jal test_single_value
or s0, s0, a0 # merge result
# Test 2: medium value
la a0, test2_msg
jal print_string
lw a0, test2
jal test_single_value
or s0, s0, a0
# Test 3: large value
la a0, test3_msg
jal print_string
lw a0, test3
jal test_single_value
or s0, s0, a0
# Print summary
la a0, separator
jal print_string
bnez s0, tests_failed
# All passed
la a0, all_pass_msg
jal print_string
j main_done
tests_failed:
# Some tests failed
la a0, some_fail_msg
jal print_string
main_done:
lw ra, 4(sp)
lw s0, 0(sp)
addi sp, sp, 8
jr ra
```
---
## Problem `C` in [Quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
### Original C program
<details>
<summary><b>Open to see the complete C program </b></summary>
```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
/* bfloat16 value held as raw bits: 1 sign bit, 8 exponent bits, 7 mantissa bits */
typedef struct {
uint16_t bits;
} bf16_t;
/* Field masks for bf16_t.bits and the exponent bias */
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
/* NaN (all-ones exponent, top mantissa bit set) and positive zero */
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})
/* True when a is a NaN: exponent field all ones and a non-zero mantissa. */
static inline bool bf16_isnan(bf16_t a)
{
    const uint16_t exp_field = a.bits & BF16_EXP_MASK;
    const uint16_t mant_field = a.bits & BF16_MANT_MASK;
    return exp_field == BF16_EXP_MASK && mant_field != 0;
}
/* True when a is +/-Infinity: exponent field all ones and a zero mantissa. */
static inline bool bf16_isinf(bf16_t a)
{
    const uint16_t exp_field = a.bits & BF16_EXP_MASK;
    const uint16_t mant_field = a.bits & BF16_MANT_MASK;
    return exp_field == BF16_EXP_MASK && mant_field == 0;
}
/* True for +0 and -0: every bit except the sign bit is clear. */
static inline bool bf16_iszero(bf16_t a)
{
    return (a.bits & ~BF16_SIGN_MASK) == 0;
}
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float));
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
return (bf16_t) {.bits = f32bits >> 16};
}
/* Widen bfloat16 to float: the 16 stored bits become the high half of the
 * float and the low half is zero, so the conversion is exact. */
static inline float bf16_to_f32(bf16_t val)
{
    const uint32_t widened = (uint32_t) val.bits << 16;
    float out;
    memcpy(&out, &widened, sizeof out);
    return out;
}
/* Add two bfloat16 values.
 * NaN operands propagate, Inf + Inf of opposite signs is NaN, and a zero
 * operand returns the other operand unchanged. Otherwise the significands are
 * aligned to the larger exponent, added or subtracted and renormalized. Bits
 * shifted out are dropped (truncation, no rounding); results that would need
 * exponent <= 0 are flushed to +0, and exponent overflow gives signed Inf.
 */
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
/* a is Inf or NaN */
if (exp_a == 0xFF) {
if (mant_a)
return a;
/* a is Inf: b NaN propagates, Inf + Inf keeps the sign, Inf - Inf is NaN */
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a;
}
/* b is Inf or NaN while a is finite */
if (exp_b == 0xFF)
return b;
/* Zero operand: return the other one. NOTE: +0 + -0 therefore yields -0,
 * whereas IEEE 754 round-to-nearest gives +0. */
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
/* Restore the implicit leading one of normal numbers.
 * NOTE(review): subnormals keep exponent 0 here although their effective
 * exponent is 1 (bf16_mul uses 1), so alignment against a normal operand may
 * be off by one bit — confirm. */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
/* Align to the larger exponent. Beyond 8 steps the smaller significand
 * (at most 8 bits) shifts out entirely, so the larger operand is returned. */
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
/* Same sign: add magnitudes; a carry into bit 8 renormalizes right by one */
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
/* Opposite signs: subtract the smaller magnitude from the larger one */
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
/* Exact cancellation gives +0 */
if (!result_mant)
return BF16_ZERO();
/* Renormalize left; no subnormal output, underflow flushes to +0 */
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
/* Pack sign, 8-bit exponent and the 7 explicit mantissa bits */
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
/* a - b, evaluated as a + (-b) by flipping the sign bit of b. */
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    const bf16_t neg_b = {.bits = (uint16_t) (b.bits ^ BF16_SIGN_MASK)};
    return bf16_add(a, neg_b);
}
/* Multiply two bfloat16 values.
 *
 * Special cases: a NaN operand is returned (a is checked first), Inf * 0 is
 * NaN, Inf * nonzero is a signed Inf, and a zero operand gives a signed zero.
 * Subnormal operands are normalized first (their effective exponent is 1).
 * The 8x8-bit significand product is truncated to 8 bits. Results below the
 * normal range become subnormals (down to 2^-133) or a signed zero; results
 * whose exponent reaches 0xFF become a signed Inf.
 */
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    /* NaN operands propagate before any Inf handling, so Inf * NaN is NaN. */
    if (exp_a == 0xFF && mant_a)
        return a;
    if (exp_b == 0xFF && mant_b)
        return b;
    if (exp_a == 0xFF) { /* a is Inf */
        if (!exp_b && !mant_b)
            return BF16_NAN(); /* Inf * 0 */
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) { /* b is Inf */
        if (!exp_a && !mant_a)
            return BF16_NAN(); /* 0 * Inf */
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    /* Bring both significands into [0x80, 0xFF]; exp_adjust counts the extra
     * left shifts applied to subnormal operands. */
    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else {
        mant_a |= 0x80;
    }
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else {
        mant_b |= 0x80;
    }

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    /* The product lies in [0x4000, 0xFE01]. Keep its top 8 bits, including the
     * implicit leading one, so the one can still be shifted into the mantissa
     * if the result turns out to be subnormal. */
    if (result_mant & 0x8000) {
        result_mant >>= 8;
        result_exp++;
    } else {
        result_mant >>= 7;
    }

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        /* Subnormal: denormalize by (1 - result_exp) bits. A shift of 8 or
         * more would leave nothing, hence the signed zero. */
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}
/* Divide two bfloat16 values (a / b).
 *
 * Special cases:
 *   - a NaN operand is returned (b is checked first);
 *   - Inf / Inf and 0 / 0 are NaN;
 *   - finite / Inf is a signed zero;
 *   - nonzero / 0 and Inf / finite are a signed Inf;
 *   - 0 / finite is a signed zero.
 * The significand quotient comes from a 16-step restoring division and is
 * truncated. Results below the normal range are flushed to a signed zero, and
 * results whose exponent reaches 0xFF become a signed Inf.
 * NOTE: the 16-bit quotient cannot represent a ratio >= 2 (possible only with a
 * subnormal divisor); it then comes out as 0xFFFF, just under 2.
 */
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    /* NaN operands propagate before any Inf/zero handling, so NaN / Inf and
     * NaN / 0 stay NaN. */
    if (exp_b == 0xFF && mant_b)
        return b;
    if (exp_a == 0xFF && mant_a)
        return a;
    if (exp_b == 0xFF) { /* b is Inf */
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    if (!exp_b && !mant_b) { /* b is zero */
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_a == 0xFF) /* a is Inf, b finite nonzero */
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};

    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    /* Restoring long division: quotient = (mant_a << 15) / mant_b, so bit 15
     * of the quotient corresponds to a significand ratio of 1.0. */
    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;
    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }

    /* A subnormal operand has an effective exponent of 1 rather than 0: a
     * subnormal dividend raises the result exponent by one, a subnormal
     * divisor lowers it by one. */
    int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
    if (!exp_a)
        result_exp++;
    if (!exp_b)
        result_exp--;

    /* Normalize the quotient so bit 15 is the leading one, without letting the
     * exponent drop below the minimum normal exponent 1. */
    while (!(quotient & 0x8000) && result_exp > 1) {
        quotient <<= 1;
        result_exp--;
    }
    /* Still unnormalized: the result is subnormal. Flush it to a signed zero,
     * like the result_exp <= 0 case below, instead of packing it as normal. */
    if (!(quotient & 0x8000))
        return (bf16_t) {.bits = result_sign << 15};
    quotient = (quotient >> 8) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (quotient & 0x7F),
    };
}
/* Square root of a bfloat16 value.
 * NaN is returned unchanged, sqrt(+Inf) = +Inf, and sqrt(-Inf) or of any other
 * negative nonzero value is NaN. sqrt(+/-0) returns +0, and subnormal inputs
 * are flushed to +0. The significand root comes from an integer binary search
 * and is truncated (no rounding).
 */
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF) {
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(0) = 0 (handle both +0 and -0).
 * NOTE: -0 also yields +0 here; IEEE 754 specifies sqrt(-0) = -0. */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* Direct bit manipulation square root algorithm */
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
/* NOTE(review): e can be negative. Right-shifting a negative signed value is
 * implementation-defined in C; this assumes an arithmetic shift. Both shifted
 * values are even, so "/ 2" would give the same result portably. */
if (e & 1) {
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for integer square root */
/* We want result where result^2 = m * 128 (since 128 represents 1.0) */
uint32_t low = 90; /* lower bound: 90*90/128 = 63 < 128 <= m */
uint32_t high = 256; /* upper bound: 256*256/128 = 512 > m */
uint32_t result = 128; /* Default */
/* Binary search for square root of m: finds the largest mid with
 * mid*mid/128 <= m (integer division) */
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and scale */
if (sq <= m) {
result = mid; /* This could be our answer */
low = mid + 1;
} else {
high = mid - 1;
}
}
/* result is now about sqrt(m * 128). In the fixed-point scale where 128 */
/* represents 1.0, this is the square root of the significand m / 128, */
/* so it can be used directly as the result significand. */
/* Normalize to ensure result is in [128, 256) */
/* NOTE(review): for m in [128, 512) the search yields result in [128, 255], */
/* so neither branch below appears reachable; kept as a safeguard. */
if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow */
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; /* +Inf */
if (new_exp <= 0)
return BF16_ZERO();
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
/* IEEE-style equality: a NaN equals nothing (not even itself), +0 equals -0,
 * and otherwise two values are equal exactly when their bit patterns match. */
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    return a.bits == b.bits || (bf16_iszero(a) && bf16_iszero(b));
}
/* a < b. False when either operand is NaN or when both are zeros of any sign.
 * Negative values are below positive ones. Within one sign, the raw bit
 * patterns order magnitudes, so the comparison is reversed for negatives. */
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
    const bool unordered = bf16_isnan(a) || bf16_isnan(b);
    if (unordered || (bf16_iszero(a) && bf16_iszero(b)))
        return false;
    const bool neg_a = (a.bits & BF16_SIGN_MASK) != 0;
    const bool neg_b = (b.bits & BF16_SIGN_MASK) != 0;
    if (neg_a != neg_b)
        return neg_a;
    return neg_a ? (b.bits < a.bits) : (a.bits < b.bits);
}
/* a > b, evaluated as b < a; NaN operands therefore also yield false. */
static inline bool bf16_gt(bf16_t a, bf16_t b)
{
    const bf16_t lhs = b;
    const bf16_t rhs = a;
    return bf16_lt(lhs, rhs);
}
#include <stdio.h>
#include <time.h>
/* When cond is false, print "FAIL: <msg>" and make the enclosing function
 * return 1. Only usable inside functions that return int. */
#define TEST_ASSERT(cond, msg) \
do { \
if (!(cond)) { \
printf("FAIL: %s\n", msg); \
return 1; \
} \
} while (0)
/* Round-trip a set of floats through bfloat16. Checks that the sign survives
 * for nonzero inputs and that the relative error |conv - orig| / |orig| stays
 * below 1% for finite nonzero inputs. Returns 0 on success, 1 at the first
 * failed assertion.
 */
static int test_basic_conversions(void)
{
    printf("Testing basic conversions...\n");
    float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f,
                           -0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f};
    for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) {
        float orig = test_values[i];
        bf16_t bf = f32_to_bf16(orig);
        float conv = bf16_to_f32(bf);
        if (orig != 0.0f) {
            TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch");
        }
        if (orig != 0.0f && !bf16_isinf(f32_to_bf16(orig))) {
            /* Divide by |orig|: dividing by the signed orig makes the error
             * negative for negative inputs, so the check always passed. */
            float diff = conv - orig;
            float abs_diff = diff < 0 ? -diff : diff;
            float abs_orig = orig < 0 ? -orig : orig;
            float rel_error = abs_diff / abs_orig;
            TEST_ASSERT(rel_error < 0.01f, "Relative error too large");
        }
    }
    printf(" Basic conversions: PASS\n");
    return 0;
}
/* Check classification of +/-Infinity, NaN and +/-0.
 * Returns 0 when every check passes, 1 at the first failure. */
static int test_special_values(void)
{
    printf("Testing special values...\n");

    const bf16_t inf_pos = {.bits = 0x7F80}; /* +Infinity */
    const bf16_t inf_neg = {.bits = 0xFF80}; /* -Infinity */
    const bf16_t qnan = BF16_NAN();

    TEST_ASSERT(bf16_isinf(inf_pos), "Positive infinity not detected");
    TEST_ASSERT(!bf16_isnan(inf_pos), "Infinity detected as NaN");
    TEST_ASSERT(bf16_isinf(inf_neg), "Negative infinity not detected");
    TEST_ASSERT(bf16_isnan(qnan), "NaN not detected");
    TEST_ASSERT(!bf16_isinf(qnan), "NaN detected as infinity");

    const bf16_t zero_pos = f32_to_bf16(0.0f);
    TEST_ASSERT(bf16_iszero(zero_pos), "Zero not detected");
    const bf16_t zero_neg = f32_to_bf16(-0.0f);
    TEST_ASSERT(bf16_iszero(zero_neg), "Negative zero not detected");

    printf(" Special values: PASS\n");
    return 0;
}
/* |actual - expected| for float, without pulling in <math.h>. */
static float bf16_test_abs_err(float actual, float expected)
{
    float d = actual - expected;
    return d < 0 ? -d : d;
}

/* Check add, sub, mul, div and sqrt on small exactly representable operands
 * against the expected results within a small absolute tolerance.
 * Returns 0 when every check passes, 1 at the first failure. */
static int test_arithmetic(void)
{
    printf("Testing arithmetic operations...\n");

    const bf16_t one = f32_to_bf16(1.0f);
    const bf16_t two = f32_to_bf16(2.0f);
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_add(one, two)), 3.0f) < 0.01f,
                "Addition failed");
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_sub(two, one)), 1.0f) < 0.01f,
                "Subtraction failed");

    const bf16_t three = f32_to_bf16(3.0f);
    const bf16_t four = f32_to_bf16(4.0f);
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_mul(three, four)), 12.0f) < 0.1f,
                "Multiplication failed");

    const bf16_t ten = f32_to_bf16(10.0f);
    const bf16_t two_again = f32_to_bf16(2.0f);
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_div(ten, two_again)), 5.0f) < 0.1f,
                "Division failed");

    /* Test square root */
    const bf16_t four_again = f32_to_bf16(4.0f);
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_sqrt(four_again)), 2.0f) < 0.01f,
                "sqrt(4) failed");
    const bf16_t nine = f32_to_bf16(9.0f);
    TEST_ASSERT(bf16_test_abs_err(bf16_to_f32(bf16_sqrt(nine)), 3.0f) < 0.01f,
                "sqrt(9) failed");

    printf(" Arithmetic: PASS\n");
    return 0;
}
/* Check eq/lt/gt on ordered values, on equal values, and with a NaN operand.
 * Returns 0 when every check passes, 1 at the first failure. */
static int test_comparisons(void)
{
    printf("Testing comparison operations...\n");

    const bf16_t one = f32_to_bf16(1.0f);
    const bf16_t two = f32_to_bf16(2.0f);
    const bf16_t one_again = f32_to_bf16(1.0f);
    const bf16_t qnan = BF16_NAN();

    TEST_ASSERT(bf16_eq(one, one_again), "Equality test failed");
    TEST_ASSERT(!bf16_eq(one, two), "Inequality test failed");
    TEST_ASSERT(bf16_lt(one, two), "Less than test failed");
    TEST_ASSERT(!bf16_lt(two, one), "Not less than test failed");
    TEST_ASSERT(!bf16_lt(one, one_again), "Equal not less than test failed");
    TEST_ASSERT(bf16_gt(two, one), "Greater than test failed");
    TEST_ASSERT(!bf16_gt(one, two), "Not greater than test failed");
    TEST_ASSERT(!bf16_eq(qnan, qnan), "NaN equality test failed");
    TEST_ASSERT(!bf16_lt(qnan, one), "NaN less than test failed");
    TEST_ASSERT(!bf16_gt(qnan, one), "NaN greater than test failed");

    printf(" Comparisons: PASS\n");
    return 0;
}
/* Check extreme magnitudes:
 *   - a float far below bf16's normal range converts to zero or stays tiny;
 *   - 1e38 * 10 overflows to infinity;
 *   - 1e-38 / 1e10 underflows to zero or a tiny value.
 * Returns 0 when every check passes, 1 at the first failure. */
static int test_edge_cases(void)
{
    printf("Testing edge cases...\n");

    const bf16_t tiny_bf = f32_to_bf16(1e-45f);
    const float tiny_back = bf16_to_f32(tiny_bf);
    const float tiny_mag = tiny_back < 0 ? -tiny_back : tiny_back;
    TEST_ASSERT(bf16_iszero(tiny_bf) || tiny_mag < 1e-37f,
                "Tiny value handling");

    const bf16_t huge_bf = f32_to_bf16(1e38f);
    const bf16_t overflowed = bf16_mul(huge_bf, f32_to_bf16(10.0f));
    TEST_ASSERT(bf16_isinf(overflowed), "Overflow should produce infinity");

    const bf16_t small_bf = f32_to_bf16(1e-38f);
    const bf16_t underflowed = bf16_div(small_bf, f32_to_bf16(1e10f));
    const float under_val = bf16_to_f32(underflowed);
    const float under_mag = under_val < 0 ? -under_val : under_val;
    TEST_ASSERT(bf16_iszero(underflowed) || under_mag < 1e-45f,
                "Underflow should produce zero or denormal");

    printf(" Edge cases: PASS\n");
    return 0;
}
/* Check that an exactly representable float (1.5) survives the round trip
 * unchanged and that a nearby value (1.0001) comes back within 0.001.
 * Returns 0 when every check passes, 1 at the first failure. */
static int test_rounding(void)
{
    printf("Testing rounding behavior...\n");

    const float exact = 1.5f;
    const float exact_back = bf16_to_f32(f32_to_bf16(exact));
    TEST_ASSERT(exact_back == exact,
                "Exact representation should be preserved");

    const float approx = 1.0001f;
    const float err = bf16_to_f32(f32_to_bf16(approx)) - approx;
    TEST_ASSERT((err < 0 ? -err : err) < 0.001f,
                "Rounding error should be small");

    printf(" Rounding: PASS\n");
    return 0;
}
#ifndef BFLOAT16_NO_MAIN
int main(void)
{
printf("\n=== bfloat16 Test Suite ===\n\n");
int failed = 0;
failed |= test_basic_conversions();
failed |= test_special_values();
failed |= test_arithmetic();
failed |= test_comparisons();
failed |= test_edge_cases();
failed |= test_rounding();
if (failed) {
printf("\n=== TESTS FAILED ===\n");
return 1;
}
printf("\n=== ALL TESTS PASSED ===\n");
return 0;
}
#endif /* BFLOAT16_NO_MAIN */
```
</details>
### Assembly code
<h4>
<span style="color:darkblue">Part<code>1</code> : Bfloat16 <-----> float32</span>
</h4>
Includes special-value test cases (±0, ±Infinity, NaN) in addition to normal values.
```s
.data
# Test Case Description Strings ---------------------------------------------
# Printed verbatim after "Testing" to label each case.
str_t1: .asciz " 1.0f"
str_t_neg_simple: .asciz " -4.0f"
str_t3: .asciz " 3.14159f"
str_t4: .asciz " 0.1f"
str_t_neg_round: .asciz " -2.7f"
str_t5: .asciz " +0.0f"
str_t6: .asciz " -0.0f"
str_t7: .asciz " +Infinity"
str_t8: .asciz " -Infinity"
str_t9: .asciz " NaN"
# Normal Value Tests (with corrected golden values) ------------------------
# Three parallel 5-entry arrays indexed by test number:
#   *_strings : pointer to the description string
#   *_inputs  : f32 bit pattern fed to f32_to_bf16
#   *_golden  : expected bf16 bit pattern (upper 16 bits of the input,
#               rounded to nearest, ties to even; e.g. 0x3dcccccd -> 0x3dcd)
normal_test_strings:
.word str_t1, str_t_neg_simple, str_t3, str_t4, str_t_neg_round
normal_test_inputs:
.word 0x3f800000, 0xc0800000, 0x40490fdb, 0x3dcccccd, 0xc02ccccd
normal_test_golden:
.word 0x3f80, 0xc080, 0x4049, 0x3dcd, 0xc02d
# Special Value Tests (with corrected golden values) ---------------------------
# Same layout: +0, -0, +Inf, -Inf and a quiet NaN, whose expected bf16
# patterns are the plain upper halves of the inputs.
special_test_strings:
.word str_t5, str_t6, str_t7, str_t8, str_t9
special_test_inputs:
.word 0x00000000, 0x80000000, 0x7f800000, 0xff800000, 0x7fc00000
special_test_golden:
.word 0x0000, 0x8000, 0x7f80, 0xff80, 0x7fc0
# tests message ---------------------------------------------
str_header_normal: .asciz "\n--- Running Normal Value Test Cases ---\n"
str_header_special: .asciz "\n--- Running Special Value Test Cases (Zero, Inf, NaN) ---\n"
str_testing: .asciz "Testing"
str_orig_label: .asciz "\n Original f32: 0x"
str_bf16_label: .asciz " -> bf16: 0x"
str_restored_label: .asciz " -> Restored f32: 0x"
str_success: .asciz " [PASS]"
str_fail: .asciz " [FAIL]"
str_actual: .asciz " (Actual: 0x"
str_expected: .asciz ", Expected: 0x"
str_close_paren: .asciz ")\n"
str_summary: .asciz "\n\n[Summary: "
str_summary_middle: .asciz " / "
str_summary_end: .asciz " Tests Passed]\n"
newline: .asciz "\n"
str_all_pass: .asciz "\n--- All tests passed! ---\n"
str_some_fail: .asciz "\n--- Some tests failed! ---\n"
.text
.globl main
# ==================== main ====================
# Runs the normal-value suite and then the special-value suite, totals
# their failure counts in s0, prints one overall verdict, and exits
# through environment call 10.
main:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    li s0, 0                        # s0 = failures summed over all suites
    # ---- Suite 1: normal values ----
    la a0, str_header_normal
    jal ra, print_string
    la a0, normal_test_strings
    la a1, normal_test_inputs
    la a2, normal_test_golden
    li a3, 5                        # five cases in this suite
    jal ra, run_test_suite
    add s0, s0, a0
    # ---- Suite 2: zero / infinity / NaN ----
    la a0, str_header_special
    jal ra, print_string
    la a0, special_test_strings
    la a1, special_test_inputs
    la a2, special_test_golden
    li a3, 5
    jal ra, run_test_suite
    add s0, s0, a0
    # ---- Verdict: choose the message, then print it once ----
    la a0, str_all_pass
    beqz s0, main_print_verdict
    la a0, str_some_fail
main_print_verdict:
    jal ra, print_string
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    li a7, 10                       # exit
    ecall
# =============================================================================
# run_test_suite: Loops through a set of tests, runs them,
# and returns the number of failures.
# Arguments: a0 = array of description-string pointers
#            a1 = array of f32 input bit patterns
#            a2 = array of expected bf16 bit patterns
#            a3 = number of tests (a value <= 0 runs no tests)
# Return Value: a0 = Number of failures in this suite.
# All values that must survive the calls below live in callee-saved
# registers; no temporary register is assumed to be preserved by a callee.
#==============================================================================
run_test_suite:
    # Function Prologue: 28 bytes (ra + s0-s5)
    addi sp, sp, -28
    sw ra, 24(sp)
    sw s0, 20(sp)
    sw s1, 16(sp)
    sw s2, 12(sp)
    sw s3, 8(sp)
    sw s4, 4(sp)
    sw s5, 0(sp)
    mv s0, a0                       # s0 = current description pointer slot
    mv s1, a1                       # s1 = current input slot
    mv s2, a2                       # s2 = current golden slot
    mv s3, a3                       # s3 = tests remaining
    mv s4, a3                       # s4 = total test count (kept for the summary)
    mv s5, zero                     # s5 = failure counter
    # The loop below tests its counter at the bottom, so an empty or negative
    # count must skip it; otherwise s3 would count down past zero indefinitely.
    blez s3, suite_summary
test_loop:
    mv a0, s0
    mv a1, s1
    mv a2, s2
    jal ra, run_single_test         # a0 = 1 on failure, 0 on success
    add s5, s5, a0
    # Advance every array to the next test case
    addi s0, s0, 4
    addi s1, s1, 4
    addi s2, s2, 4
    addi s3, s3, -1
    bnez s3, test_loop
suite_summary:
    # Prints "[Summary: <passed> / <total> Tests Passed]"
    la a0, str_summary
    jal ra, print_string
    sub a0, s4, s5                  # passed = total - failed
    jal ra, print_int
    la a0, str_summary_middle
    jal ra, print_string
    mv a0, s4
    jal ra, print_int
    la a0, str_summary_end
    jal ra, print_string
    # Function Epilogue
    mv a0, s5                       # return the failure count
    lw s5, 0(sp)
    lw s4, 4(sp)
    lw s3, 8(sp)
    lw s2, 12(sp)
    lw s1, 16(sp)
    lw s0, 20(sp)
    lw ra, 24(sp)
    addi sp, sp, 28
    ret
# =============================================================================
# run_single_test: Performs a round-trip conversion and checks the result.
# Arguments: a0=str_ptr, a1=input_ptr, a2=golden_ptr (each points at one
#            array element; the values are loaded here)
# Return Value: a0 = 1 if fail, 0 if success.
# The restored and expected values are held in callee-saved s-registers
# because they are needed after several calls to the print helpers.
#==============================================================================
run_single_test:
    # Function Prologue: ra + s0-s4
    addi sp, sp, -24
    sw ra, 20(sp)
    sw s0, 16(sp)
    sw s1, 12(sp)
    sw s2, 8(sp)
    sw s3, 4(sp)
    sw s4, 0(sp)
    # Load test data
    lw s0, 0(a0)                    # s0 = address of description string
    lw s1, 0(a1)                    # s1 = original f32 input value
    lw s2, 0(a2)                    # s2 = expected bf16 golden value
    slli s2, s2, 16                 # s2 = golden value widened to f32
    # Perform Conversions
    mv a0, s1
    jal ra, f32_to_bf16
    mv s3, a0                       # s3 = actual bf16 result
    jal ra, bf16_to_f32             # a0 still holds the bf16 result
    mv s4, a0                       # s4 = restored f32 result
    # ===== Print Results on a Single Line ==========
    la a0, str_testing
    jal ra, print_string            # "Testing"
    mv a0, s0
    jal ra, print_string            # description, e.g. " 1.0f"
    la a0, str_orig_label
    jal ra, print_string            # "\n Original f32: 0x"
    mv a0, s1
    jal ra, print_hex32
    la a0, str_bf16_label
    jal ra, print_string            # " -> bf16: 0x"
    mv a0, s3
    jal ra, print_hex32
    la a0, str_restored_label
    jal ra, print_string            # " -> Restored f32: 0x"
    mv a0, s4
    jal ra, print_hex32
    # Compare restored f32 with the widened golden value
    beq s4, s2, test_success
test_fail:
    la a0, str_fail
    jal ra, print_string
    la a0, str_actual
    jal ra, print_string
    mv a0, s4
    jal ra, print_hex32
    la a0, str_expected
    jal ra, print_string
    mv a0, s2
    jal ra, print_hex32
    la a0, str_close_paren
    jal ra, print_string
    addi a0, zero, 1                # Return 1 for failure
    j end_single_test
test_success:
    la a0, str_success
    jal ra, print_string
    la a0, newline
    jal ra, print_string
    addi a0, zero, 0                # Return 0 for success
end_single_test:
    # Function Epilogue
    lw s4, 0(sp)
    lw s3, 4(sp)
    lw s2, 8(sp)
    lw s1, 12(sp)
    lw s0, 16(sp)
    lw ra, 20(sp)
    addi sp, sp, 24
    ret
# ==================== bf16_t f32_to_bf16(float val) ========================
# Input:  a0 = f32 bit pattern
# Output: a0 = bf16 bit pattern (low 16 bits), rounded to nearest, ties to even
# Clobbers t0-t4.
f32_to_bf16:
    # Inf/NaN (exponent field all ones) bypass rounding: adding the rounding
    # constant to a NaN with a large payload would carry past the exponent.
    srli t0, a0, 23
    andi t0, t0, 0xFF               # t0 = exponent field
    addi t1, zero, 0xFF
    beq t0, t1, is_nan_or_inf
    # ====== Finite values (zero and subnormals included): round magnitude ======
    lui t0, 0x80000                 # t0 = 0x80000000 (sign mask)
    and t1, a0, t0                  # t1 = sign bit
    not t0, t0                      # t0 = 0x7FFFFFFF
    and t2, a0, t0                  # t2 = magnitude
    srli t3, t2, 16
    andi t3, t3, 1                  # t3 = LSB of the kept half (tie-breaking bit)
    lui t4, 0x8
    addi t4, t4, -1                 # t4 = 0x7FFF
    add t3, t3, t4                  # addend = 0x7FFF or 0x8000
    add t2, t2, t3                  # round; a carry into the exponent yields Inf
    srli t2, t2, 16                 # keep the upper half of the magnitude
    srli t1, t1, 16                 # move the sign to bit 15
    or a0, t2, t1
    ret
is_nan_or_inf:
    slli t0, a0, 9                  # t0 = the 23 f32 mantissa bits, left-aligned
    srli a0, a0, 16                 # truncate: sign, exponent, top 7 mantissa bits
    beqz t0, f32_to_bf16_done       # Infinity: truncation is exact
    # NaN: force the quiet bit so a payload held only in the discarded low
    # 16 bits (e.g. 0x7F800001) cannot turn into the Infinity pattern 0x7F80.
    ori a0, a0, 0x40
f32_to_bf16_done:
    ret
# ================ float bf16_to_f32(bf16_t val) ======================
# Widening is exact: the bf16 bits become the upper half of the f32 and
# the 16 new low mantissa bits are zero.
bf16_to_f32:
    slli a0, a0, 16
    ret
# =========== Helper Print Functions ==========
# Each wrapper loads an environment-call number into a7 and traps; the
# value to print is passed through unchanged in a0.
print_string:                       # a0 = address of a NUL-terminated string
    li a7, 4
    ecall
    ret
print_hex32:                        # a0 = value printed in hexadecimal
    li a7, 34
    ecall
    ret
print_int:                          # a0 = value printed as a decimal integer
    li a7, 1
    ecall
    ret
```
**1. Execution information** - <span style="color:darkblue">**$2075 \;\text{cycles}$**</span>

<h4>
<span style="color:darkblue">Part<code>2</code> : Bfloat16 arithmetic operations</span>
</h4>
Includes a special-value test (NaN, Infinity, zero) in addition to the arithmetic tests.
```s
.data
# Each case is three bf16 bit patterns stored in the low half of a word:
# operand a, operand b, and the expected result (compared bit-for-bit).
# bf16 layout: bit 15 sign, bits 14-7 exponent (bias 127), bits 6-0 mantissa.
# Addition test cases
add_a1:   .word 0x00003F80          # 1.0 in bfloat16
add_b1:   .word 0x00004000          # 2.0 in bfloat16
add_exp1: .word 0x00004040          # 3.0 expected (1.0 + 2.0)
add_a2:   .word 0x00003FC0          # 1.5 in bfloat16
add_b2:   .word 0x00003FC0          # 1.5 in bfloat16
add_exp2: .word 0x00004040          # 3.0 expected (1.5 + 1.5)
add_a3:   .word 0x00003F00          # 0.5 in bfloat16
add_b3:   .word 0x00003F00          # 0.5 in bfloat16
add_exp3: .word 0x00003F80          # 1.0 expected (0.5 + 0.5)
# Subtraction test cases
sub_a1:   .word 0x00004040          # 3.0 in bfloat16
sub_b1:   .word 0x00004000          # 2.0 in bfloat16
sub_exp1: .word 0x00003F80          # 1.0 expected (3.0 - 2.0)
sub_a2:   .word 0x00004000          # 2.0 in bfloat16
sub_b2:   .word 0x00003F80          # 1.0 in bfloat16
sub_exp2: .word 0x00003F80          # 1.0 expected (2.0 - 1.0)
sub_a3:   .word 0x00004040          # 3.0 in bfloat16
sub_b3:   .word 0x00003FC0          # 1.5 in bfloat16
sub_exp3: .word 0x00003FC0          # 1.5 expected (3.0 - 1.5)
# Multiplication test cases
mul_a1:   .word 0x00004040          # 3.0 in bfloat16
mul_b1:   .word 0x00004080          # 4.0 in bfloat16
mul_exp1: .word 0x00004140          # 12.0 expected (3.0 * 4.0); 0x4180 would be 16.0
mul_a2:   .word 0x00004000          # 2.0 in bfloat16
mul_b2:   .word 0x00004020          # 2.5 in bfloat16
mul_exp2: .word 0x000040A0          # 5.0 expected (2.0 * 2.5); 0x4140 would be 12.0
mul_a3:   .word 0x00003FC0          # 1.5 in bfloat16
mul_b3:   .word 0x00004000          # 2.0 in bfloat16
mul_exp3: .word 0x00004040          # 3.0 expected (1.5 * 2.0)
# Division test cases
div_a1:   .word 0x000040A0          # 5.0 in bfloat16 (0x4140 would be 12.0)
div_b1:   .word 0x00004000          # 2.0 in bfloat16
div_exp1: .word 0x00004020          # 2.5 expected (5.0 / 2.0)
div_a2:   .word 0x000040C0          # 6.0 in bfloat16 (0x4180 would be 16.0)
div_b2:   .word 0x00004000          # 2.0 in bfloat16
div_exp2: .word 0x00004040          # 3.0 expected (6.0 / 2.0)
div_a3:   .word 0x00004040          # 3.0 in bfloat16
div_b3:   .word 0x00004000          # 2.0 in bfloat16
div_exp3: .word 0x00003FC0          # 1.5 expected (3.0 / 2.0)
# Special values for testing edge cases
test_nan:  .word 0x00007FC0         # NaN value
test_inf:  .word 0x00007F80         # Positive infinity
test_zero: .word 0x00000000         # Zero value
# Result storage for test verification
result: .word 0x00000000
# tests message ---------------------------------------------
test_start: .string "bfloat16 Complete Test Suite\n\n"
add_header: .string "=== Addition Tests ===\n"
sub_header: .string "\n=== Subtraction Tests ===\n"
mul_header: .string "\n=== Multiplication Tests ===\n"
div_header: .string "\n=== Division Tests ===\n"
special_header: .string "\n=== Special Values Tests ===\n"
# Addition test descriptions
add_test1: .string "1.0 + 2.0 = 3.0: "
add_test2: .string "1.5 + 1.5 = 3.0: "
add_test3: .string "0.5 + 0.5 = 1.0: "
# Subtraction test descriptions
sub_test1: .string "3.0 - 2.0 = 1.0: "
sub_test2: .string "2.0 - 1.0 = 1.0: "
sub_test3: .string "3.0 - 1.5 = 1.5: "
# Multiplication test descriptions
mul_test1: .string "3.0 * 4.0 = 12.0: "
mul_test2: .string "2.0 * 2.5 = 5.0: "
mul_test3: .string "1.5 * 2.0 = 3.0: "
# Division test descriptions
div_test1: .string "5.0 / 2.0 = 2.5: "
div_test2: .string "6.0 / 2.0 = 3.0: "
div_test3: .string "3.0 / 2.0 = 1.5: "
# Special values test description
special_test: .string "Special values (NaN, Inf, Zero): "
# Test result messages
pass_msg: .string "PASS\n"
fail_msg: .string "FAIL\n"
# Summary and statistics messages
summary: .string "\n=== Test Summary ===\n"
total_tests: .string "Total tests: "
passed_tests: .string "Passed: "
failed_tests: .string "Failed: "
all_pass: .string "\nAll tests passed!\n"
some_fail: .string "\nSome tests failed.\n"
newline: .string "\n"
.text
.globl _start
# ==================== main tests ====================
# Test driver. s0 counts passed tests and s1 counts run tests; both are
# updated by check_test_result and by the special-value section below.
# Each arithmetic test follows the same pattern:
#   1. print the description;
#   2. load the operands and call the operation;
#   3. store the result to `result` (this code never reads it back);
#   4. load the expected value into a1 and call bf16_eq (a0 still holds
#      the result);
#   5. call check_test_result with the 0/1 answer.
_start:
la a0, test_start # Print test suite header
addi a7, zero, 4
ecall
addi s0, zero, 0 # s0 = pass counter
addi s1, zero, 0 # s1 = total test counter
# ==================== addition tests ====================
la a0, add_header
addi a7, zero, 4
ecall
# Test 1.1: 1.0 + 2.0 = 3.0
la a0, add_test1
addi a7, zero, 4
ecall
la t0, add_a1 # Load test values and perform addition
lw a0, 0(t0)
la t0, add_b1
lw a1, 0(t0)
jal bf16_add
la t0, result # Store and verify result
sw a0, 0(t0)
la t0, add_exp1
lw a1, 0(t0)
jal bf16_eq # a0 = 1 if result == expected (bitwise)
jal check_test_result # Check and display result
# Test 1.2: 1.5 + 1.5 = 3.0
la a0, add_test2
addi a7, zero, 4
ecall
la t0, add_a2
lw a0, 0(t0)
la t0, add_b2
lw a1, 0(t0)
jal bf16_add
la t0, result
sw a0, 0(t0)
la t0, add_exp2
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 1.3: 0.5 + 0.5 = 1.0
la a0, add_test3
addi a7, zero, 4
ecall
la t0, add_a3
lw a0, 0(t0)
la t0, add_b3
lw a1, 0(t0)
jal bf16_add
la t0, result
sw a0, 0(t0)
la t0, add_exp3
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# ==================== subtraction tests ====================
la a0, sub_header
addi a7, zero, 4
ecall
# Test 2.1: 3.0 - 2.0 = 1.0
la a0, sub_test1
addi a7, zero, 4
ecall
la t0, sub_a1
lw a0, 0(t0)
la t0, sub_b1
lw a1, 0(t0)
jal bf16_sub
la t0, result
sw a0, 0(t0)
la t0, sub_exp1
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 2.2: 2.0 - 1.0 = 1.0
la a0, sub_test2
addi a7, zero, 4
ecall
la t0, sub_a2
lw a0, 0(t0)
la t0, sub_b2
lw a1, 0(t0)
jal bf16_sub
la t0, result
sw a0, 0(t0)
la t0, sub_exp2
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 2.3: 3.0 - 1.5 = 1.5
la a0, sub_test3
addi a7, zero, 4
ecall
la t0, sub_a3
lw a0, 0(t0)
la t0, sub_b3
lw a1, 0(t0)
jal bf16_sub
la t0, result
sw a0, 0(t0)
la t0, sub_exp3
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# ==================== multiplication tests ====================
la a0, mul_header
addi a7, zero, 4
ecall
# Test 3.1: 3.0 * 4.0 = 12.0
la a0, mul_test1
addi a7, zero, 4
ecall
la t0, mul_a1
lw a0, 0(t0)
la t0, mul_b1
lw a1, 0(t0)
jal bf16_mul
la t0, result
sw a0, 0(t0)
la t0, mul_exp1
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 3.2: 2.0 * 2.5 = 5.0
la a0, mul_test2
addi a7, zero, 4
ecall
la t0, mul_a2
lw a0, 0(t0)
la t0, mul_b2
lw a1, 0(t0)
jal bf16_mul
la t0, result
sw a0, 0(t0)
la t0, mul_exp2
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 3.3: 1.5 * 2.0 = 3.0
la a0, mul_test3
addi a7, zero, 4
ecall
la t0, mul_a3
lw a0, 0(t0)
la t0, mul_b3
lw a1, 0(t0)
jal bf16_mul
la t0, result
sw a0, 0(t0)
la t0, mul_exp3
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# ==================== division tests ====================
la a0, div_header
addi a7, zero, 4
ecall
# Test 4.1: 5.0 / 2.0 = 2.5
la a0, div_test1
addi a7, zero, 4
ecall
la t0, div_a1
lw a0, 0(t0)
la t0, div_b1
lw a1, 0(t0)
jal bf16_div
la t0, result
sw a0, 0(t0)
la t0, div_exp1
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 4.2: 6.0 / 2.0 = 3.0
la a0, div_test2
addi a7, zero, 4
ecall
la t0, div_a2
lw a0, 0(t0)
la t0, div_b2
lw a1, 0(t0)
jal bf16_div
la t0, result
sw a0, 0(t0)
la t0, div_exp2
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# Test 4.3: 3.0 / 2.0 = 1.5
la a0, div_test3
addi a7, zero, 4
ecall
la t0, div_a3
lw a0, 0(t0)
la t0, div_b3
lw a1, 0(t0)
jal bf16_div
la t0, result
sw a0, 0(t0)
la t0, div_exp3
lw a1, 0(t0)
jal bf16_eq
jal check_test_result
# ==================== special value test ====================
# The three predicate checks below are counted as a single test: any
# failing predicate jumps to special_fail.
la a0, special_header
addi a7, zero, 4
ecall
la a0, special_test
addi a7, zero, 4
ecall
# Test NaN detection
la t0, test_nan
lw a0, 0(t0)
jal bf16_isnan
beq a0, zero, special_fail
# Test Infinity detection
la t0, test_inf
lw a0, 0(t0)
jal bf16_isinf
beq a0, zero, special_fail
# Test zero detection
la t0, test_zero
lw a0, 0(t0)
jal bf16_iszero
beq a0, zero, special_fail
# All special value tests passed
addi s0, s0, 1
addi s1, s1, 1
la a0, pass_msg
addi a7, zero, 4
ecall
j print_summary
special_fail:
# Special value test failed
addi s1, s1, 1
la a0, fail_msg
addi a7, zero, 4
ecall
# ==================== test summary ====================
print_summary:
la a0, summary # Print test summary header
addi a7, zero, 4
ecall
sub s2, s1, s0 # s2 = failed tests = total - passed
# Print total tests count
la a0, total_tests
addi a7, zero, 4
ecall
add a0, zero, s1
addi a7, zero, 1
ecall
la a0, newline
addi a7, zero, 4
ecall
# Print passed tests count
la a0, passed_tests
addi a7, zero, 4
ecall
add a0, zero, s0
addi a7, zero, 1
ecall
la a0, newline
addi a7, zero, 4
ecall
# Print failed tests count
la a0, failed_tests
addi a7, zero, 4
ecall
add a0, zero, s2
addi a7, zero, 1
ecall
la a0, newline
addi a7, zero, 4
ecall
# Print final result message (a0 selected here, printed at `exit`)
beq s2, zero, all_passed_msg
la a0, some_fail
j exit
all_passed_msg:
la a0, all_pass
exit:
addi a7, zero, 4
ecall
addi a7, zero, 10 # Exit system call
ecall
# ==================== TEST RESULT CHECKER SUBROUTINE ====================
# Input:  a0 = non-zero if the test passed, 0 if it failed
# Effect: s1 += 1 always; s0 += 1 on a pass; prints "PASS\n" or "FAIL\n".
# Leaf routine (only an ecall, which leaves ra intact), so no stack frame.
check_test_result:
    addi s1, s1, 1                  # one more test has run
    bnez a0, check_test_pass
    la a0, fail_msg
    j test_end
check_test_pass:
    addi s0, s0, 1                  # one more test has passed
    la a0, pass_msg
test_end:
    li a7, 4                        # print string
    ecall
    ret
# ==================== BF16 Subtraction Function ====================
# Implements: a - b = a + (-b)
# Input: a0 = bf16 a, a1 = bf16 b
# Output: a0 = a - b in bf16 format
# Negates b by flipping its sign bit (0x8000), then tail-calls bf16_add.
# bf16_add returns directly to this routine's caller, so no frame is needed.
bf16_sub:
    lui t0, 0x8                     # t0 = 0x8000 (BF16_SIGN_MASK)
    xor a1, a1, t0                  # b.bits ^= BF16_SIGN_MASK
    j bf16_add
# ==================== BF16 Addition Function ====================
# Computes a + b for two bfloat16 values (the result is truncated, not rounded).
# Input:  a0 = bf16 a, a1 = bf16 b (low 16 bits used)
# Output: a0 = a + b in bf16
# Special values: NaN operands propagate, Inf + finite = Inf, and
# Inf + (-Inf) = NaN (0x7FC0). Exponent overflow gives a signed Inf;
# exponent underflow is flushed to a signed zero.
bf16_add:
    addi sp, sp, -16
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s1, 8(sp)
    sw s2, 12(sp)
    add s0, zero, a0                # Save input values
    add s1, zero, a1
    # Extract sign, exponent, mantissa from a
    srli t0, s0, 15
    andi t0, t0, 1                  # t0 = sign_a
    srli t1, s0, 7
    andi t1, t1, 255                # t1 = exp_a
    andi t2, s0, 127                # t2 = mant_a
    # Extract sign, exponent, mantissa from b
    srli t3, s1, 15
    andi t3, t3, 1                  # t3 = sign_b
    srli t4, s1, 7
    andi t4, t4, 255                # t4 = exp_b
    andi t5, s1, 127                # t5 = mant_b
    # Inf/NaN operands (exponent 255) never reach the mantissa arithmetic
    li t6, 255
    beq t1, t6, add_a_special
    beq t4, t6, add_return_b        # a finite, b Inf/NaN: the result is b
    # Add implicit 1 to mantissas for normalized numbers
    beq t1, zero, skip_impl_a
    ori t2, t2, 128
skip_impl_a:
    beq t4, zero, skip_impl_b
    ori t5, t5, 128
skip_impl_b:
    # Align exponents by shifting the smaller operand's mantissa right.
    # srl uses only the low 5 bits of the shift amount, so the amount is
    # clamped to 16 (which already clears any 8-bit mantissa); otherwise a
    # difference of 32 or more would wrap and leave the mantissa intact.
    sub t6, t1, t4
    bge t6, zero, a_greater_exp
    sub t6, t4, t1                  # b has the larger exponent
    add t1, zero, t4                # result exponent = exp_b (t4 is free now)
    li t4, 16
    bltu t6, t4, add_shift_a
    mv t6, t4
add_shift_a:
    srl t2, t2, t6
    j exponents_aligned
a_greater_exp:
    li t4, 16
    bltu t6, t4, add_shift_b
    mv t6, t4
add_shift_b:
    srl t5, t5, t6
exponents_aligned:
    # Perform addition or subtraction based on signs
    bne t0, t3, subtract
    add t2, t2, t5                  # same sign: add magnitudes
    add s2, zero, t0
    j normalize
subtract:
    bgeu t2, t5, a_greater_mant
    sub t2, t5, t2                  # |b| > |a|
    add s2, zero, t3
    j normalize
a_greater_mant:
    sub t2, t2, t5                  # |a| >= |b|
    add s2, zero, t0
normalize:
    beq t2, zero, zero_result
    # The sum of two 8-bit mantissas is below 512, so a carry needs at most
    # one right shift.
    li t3, 256
    bltu t2, t3, add_norm_left
    srli t2, t2, 1
    addi t1, t1, 1
    j add_norm_check
add_norm_left:
    # Otherwise shift left until the implicit bit (bit 7) is set. The check
    # comes before each shift, so an already-normalized value keeps its exponent.
    li t3, 128
add_norm_left_loop:
    bgeu t2, t3, add_norm_check
    slli t2, t2, 1
    addi t1, t1, -1
    j add_norm_left_loop
add_norm_check:
    li t3, 255
    bge t1, t3, overflow
    ble t1, zero, underflow
pack_result:
    andi t2, t2, 127                # remove the implicit bit
    slli a0, s2, 15                 # sign
    slli t1, t1, 7                  # exponent
    or a0, a0, t1
    or a0, a0, t2
    j done
underflow:
    slli a0, s2, 15                 # flush to a signed zero
    j done
overflow:
    li a0, 0x7F80                   # +inf
    beq s2, zero, done
    li a0, 0xFF80                   # -inf
    j done
zero_result:
    add a0, zero, zero              # exact cancellation gives +0
    j done
add_a_special:
    # a is Inf or NaN; t6 still holds 255
    bne t2, zero, add_return_a      # NaN + x = a (the NaN)
    bne t4, t6, add_return_a        # Inf + finite = Inf
    bne t5, zero, add_return_b      # Inf + NaN = b (the NaN)
    beq t0, t3, add_return_a        # Inf + Inf (same sign) = Inf
    li a0, 0x7FC0                   # Inf + (-Inf) = NaN
    j done
add_return_b:
    slli a0, s1, 16                 # return b (low 16 bits only)
    srli a0, a0, 16
    j done
add_return_a:
    slli a0, s0, 16                 # return a (low 16 bits only)
    srli a0, a0, 16
done:
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s1, 8(sp)
    lw s2, 12(sp)
    addi sp, sp, 16
    ret
# ==================== BF16 Multiplication Function ====================
# Computes a * b for two bfloat16 values (the result is truncated, not rounded).
# Input:  a0 = bf16 a, a1 = bf16 b (low 16 bits used)
# Output: a0 = a * b in bf16
# Subnormal inputs (exponent field 0) are treated as zero. Exponent
# overflow gives a signed Inf, and underflow is flushed to a signed zero.
# Inf * 0, and any operation with a NaN operand, return NaN (0x7FC0).
bf16_mul:
    addi sp, sp, -24
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s1, 8(sp)
    sw s2, 12(sp)
    sw s3, 16(sp)
    sw s4, 20(sp)
    mv s0, a0                       # a
    mv s1, a1                       # b
    # Extract sign bits
    srli s2, s0, 15
    andi s2, s2, 1                  # sign_a
    srli t0, s1, 15
    andi t0, t0, 1                  # sign_b
    xor s2, s2, t0                  # result_sign = sign_a ^ sign_b
    # Extract exponents
    srli s3, s0, 7
    andi s3, s3, 0xFF               # exp_a
    srli t0, s1, 7
    andi t0, t0, 0xFF               # exp_b
    # Extract mantissas
    andi s4, s0, 0x7F               # mant_a
    andi t1, s1, 0x7F               # mant_b
    # Special cases: Inf/NaN first, then zero
    li t2, 0xFF
    beq s3, t2, mul_inf_nan_a
    beq t0, t2, mul_inf_nan_b
    beq s3, zero, mul_zero_result
    beq t0, zero, mul_zero_result
    # Add implicit 1 to mantissas for normalized numbers
    ori s4, s4, 0x80
    ori t1, t1, 0x80
    # Product of two values in [128, 255] is in [0x4000, 0xFE01]:
    # the real product lies in [1.0, 4.0), scaled by 2^14
    mul t2, s4, t1
    # Result exponent = exp_a + exp_b - bias
    add s3, s3, t0
    addi s3, s3, -127
    # Normalize: bit 15 set means the product is in [2.0, 4.0)
    li t3, 0x8000
    bltu t2, t3, mul_normalized
    srli t2, t2, 1
    addi s3, s3, 1
mul_normalized:
    srli t2, t2, 7                  # leading 1 now at bit 7
    andi t2, t2, 0x7F               # drop the implicit bit
    # Check exponent bounds
    ble s3, zero, mul_zero_result   # underflow: flush to signed zero
    li t3, 0xFF
    bge s3, t3, mul_inf_result      # overflow: signed infinity
    # Pack result
    slli a0, s2, 15                 # sign
    slli t3, s3, 7                  # exponent
    or a0, a0, t3
    or a0, a0, t2                   # mantissa
    j mul_done
mul_inf_nan_a:
    # a is Inf or NaN (t2 still holds 0xFF)
    bne s4, zero, mul_nan_result    # NaN * x = NaN
    beq t0, zero, mul_nan_result    # Inf * 0 = NaN
    bne t0, t2, mul_inf_result      # Inf * finite = Inf
    bne t1, zero, mul_nan_result    # Inf * NaN = NaN
    j mul_inf_result                # Inf * Inf = Inf
mul_inf_nan_b:
    # b is Inf or NaN, a is finite
    bne t1, zero, mul_nan_result    # x * NaN = NaN
    beq s3, zero, mul_nan_result    # 0 * Inf = NaN
    j mul_inf_result                # finite * Inf = Inf
mul_zero_result:
    slli a0, s2, 15                 # zero carrying the product's sign
    j mul_done
mul_inf_result:
    li a0, 0x7F80                   # +inf
    beq s2, zero, mul_done
    li a0, 0xFF80                   # -inf
    j mul_done
mul_nan_result:
    li a0, 0x7FC0                   # NaN
mul_done:
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s1, 8(sp)
    lw s2, 12(sp)
    lw s3, 16(sp)
    lw s4, 20(sp)
    addi sp, sp, 24
    ret
# ==================== BF16 Division Function ====================
# Computes a / b for two bfloat16 values (the result is truncated, not rounded).
# Input:  a0 = bf16 a (dividend), a1 = bf16 b (divisor); low 16 bits used
# Output: a0 = a / b in bf16
# Subnormal inputs (exponent field 0) are treated as zero. Results whose
# exponent underflows are flushed to a signed zero.
# Special results: NaN operands -> NaN; Inf/Inf and 0/0 -> NaN;
# x/0 -> signed Inf; finite/Inf -> signed zero.
bf16_div:
    addi sp, sp, -24
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s1, 8(sp)
    sw s2, 12(sp)
    sw s3, 16(sp)
    sw s4, 20(sp)
    mv s0, a0                       # a (dividend)
    mv s1, a1                       # b (divisor)
    # Extract sign bits
    srli s2, s0, 15
    andi s2, s2, 1                  # sign_a
    srli t0, s1, 15
    andi t0, t0, 1                  # sign_b
    xor s2, s2, t0                  # result_sign = sign_a ^ sign_b
    # Extract exponents
    srli s3, s0, 7
    andi s3, s3, 0xFF               # exp_a
    srli t0, s1, 7
    andi t0, t0, 0xFF               # exp_b
    # Extract mantissas
    andi s4, s0, 0x7F               # mant_a
    andi t1, s1, 0x7F               # mant_b
    # Special cases. Inf/NaN are checked before zero so that NaN / 0 is NaN
    # and Inf / 0 is Inf.
    li t2, 0xFF
    beq s3, t2, div_inf_nan_a       # a is Inf or NaN
    beq t0, t2, div_inf_nan_b       # b is Inf or NaN (a finite)
    beq t0, zero, div_by_zero_check # b is zero
    beq s3, zero, div_zero_result   # 0 / finite non-zero = 0
    # Add implicit 1 to mantissas
    ori s4, s4, 0x80
    ori t1, t1, 0x80
    # Restoring division: quotient = (mant_a << 15) / mant_b, as 16 bits.
    # mant_b << 16 > mant_a << 15 for mantissas in [128, 255], so 16
    # quotient bits are enough; the quotient lies in (2^14, 2^16).
    slli t4, s4, 15                 # remainder = mant_a << 15
    slli t5, t1, 15                 # divisor aligned to quotient bit 15
    li t2, 0                        # quotient
    li t3, 16                       # one iteration per quotient bit
div_loop:
    slli t2, t2, 1                  # make room for the next quotient bit
    bltu t4, t5, div_skip_sub
    sub t4, t4, t5                  # divisor fits: subtract it
    ori t2, t2, 1                   # and record a 1 bit
div_skip_sub:
    srli t5, t5, 1                  # move to the next lower bit position
    addi t3, t3, -1
    bnez t3, div_loop
    # Calculate result exponent: exp_a - exp_b + bias
    sub s3, s3, t0
    addi s3, s3, 127
    # Normalize: bit 15 set means mant_a >= mant_b (ratio in [1, 2)).
    # Otherwise the ratio is in (0.5, 1) and a single left shift suffices.
    li t3, 0x8000
    bgeu t2, t3, div_normalized
    slli t2, t2, 1
    addi s3, s3, -1
div_normalized:
    srli t2, t2, 8                  # keep the 8 leading quotient bits
    andi t2, t2, 0x7F               # remove the implicit bit
    # Check exponent bounds
    ble s3, zero, div_underflow
    li t3, 0xFF
    bge s3, t3, div_overflow
    # Pack result
    slli a0, s2, 15                 # sign
    slli t3, s3, 7                  # exponent
    or a0, a0, t3
    or a0, a0, t2                   # mantissa
    j div_done
div_by_zero_check:
    # b is zero (a is finite): 0 / 0 = NaN, otherwise signed infinity
    beq s3, zero, div_nan_result
    j div_inf_result
div_inf_nan_a:
    bne s4, zero, div_nan_result    # a is NaN
    # a is infinity: Inf / Inf and Inf / NaN are NaN, anything else is Inf
    li t2, 0xFF
    beq t0, t2, div_nan_result
    j div_inf_result
div_inf_nan_b:
    # a is finite here
    bne t1, zero, div_nan_result    # b is NaN
    j div_zero_result               # finite / Inf = signed zero
div_overflow:
div_inf_result:
    li a0, 0x7F80                   # +inf
    beq s2, zero, div_done
    li a0, 0xFF80                   # -inf
    j div_done
div_underflow:
div_zero_result:
    slli a0, s2, 15                 # zero carrying the quotient's sign
    j div_done
div_nan_result:
    li a0, 0x7FC0                   # NaN
div_done:
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s1, 8(sp)
    lw s2, 12(sp)
    lw s3, 16(sp)
    lw s4, 20(sp)
    addi sp, sp, 24
    ret
# ==================== utility functions ====================
# The classification predicates look only at bits 0-14 of a0 (exponent
# and mantissa; the sign is ignored) and return 1 in a0 when true, else 0.

# NaN: exponent all ones with a non-zero mantissa, i.e. the 15 magnitude
# bits are strictly greater than the +Infinity pattern 0x7F80.
bf16_isnan:
    slli t0, a0, 17                 # drop bit 15 and everything above it
    srli t0, t0, 17                 # t0 = magnitude bits 0-14
    li t1, 0x7F80
    sltu a0, t1, t0                 # 0x7F80 < magnitude  <=>  NaN
    ret
# Infinity: the 15 magnitude bits are exactly 0x7F80.
bf16_isinf:
    slli t0, a0, 17
    srli t0, t0, 17                 # t0 = magnitude bits 0-14
    li t1, 0x7F80
    xor t0, t0, t1
    seqz a0, t0                     # equal to the Inf pattern  <=>  Infinity
    ret
# Zero: all 15 magnitude bits clear (+0 or -0).
bf16_iszero:
    slli t0, a0, 17                 # keep only bits 0-14, now at the top
    seqz a0, t0
    ret
# Equality: compares the full registers a0 and a1 bit for bit. Unlike IEEE
# equality, identical NaN patterns compare equal and +0 / -0 compare unequal.
bf16_eq:
    xor t0, a0, a1
    seqz a0, t0
    ret
```
**1. Execution information** - <span style="color:darkblue">**$1823 \;\text{cycles}$**</span>

---
## [Leetcode`66`](https://leetcode.com/problems/plus-one/)
### Description
You are given a large integer represented as an integer array `digits`, where each `digits[i]` is the `i`-th digit of the integer. The digits are ordered from **most significant to least significant** in left-to-right order. The large integer does **not** contain any leading `0`s.
Increment the large integer by one and **return the resulting array of digits**.
#### Example `1`
```vb
Input: digits = [1,2,3]
Output: [1,2,4]
Explanation: The array represents the integer 123.
Incrementing by one gives 123 + 1 = 124, so the result is [1,2,4].
```
#### Example `2`
```vb
Input: digits = [9,9,9]
Output: [1,0,0,0]
Explanation: The array represents the integer 999.
Incrementing by one gives 1000, so the result is [1,0,0,0].
```
### Original C program
```c
/*
 * plusOne - add one to a non-negative integer stored as decimal digits.
 *
 * @digits:     input digits, most significant first (not modified)
 * @digitsSize: number of input digits
 * @returnSize: out-parameter; receives the number of digits in the result
 *
 * Returns a newly allocated array that the caller must free(). It holds
 * digitsSize or digitsSize + 1 digits; the longer form occurs only when
 * every digit was 9 (e.g. 999 -> 1000). Returns NULL, with *returnSize = 0,
 * if digitsSize is negative or the allocation fails.
 */
int* plusOne(int* digits, int digitsSize, int* returnSize) {
    *returnSize = 0;
    if (digitsSize < 0) {
        return NULL;
    }
    /* One spare slot for a possible new leading digit. */
    int* result = malloc(sizeof *result * ((size_t)digitsSize + 1));
    if (result == NULL) {
        return NULL;
    }
    for (int i = 0; i < digitsSize; i++) {
        result[i] = digits[i];
    }
    result[digitsSize] = 0;

    /* Propagate the +1 from the least significant digit. The first digit
     * that can absorb it ends the work; 9s roll over to 0 and pass the carry on. */
    for (int i = digitsSize - 1; i >= 0; i--) {
        if (result[i] < 9) {
            result[i] += 1;
            *returnSize = digitsSize;
            return result;
        }
        result[i] = 0;
    }

    /* Every digit rolled over, so result[0..digitsSize] are all 0. The
     * answer is 1 followed by digitsSize zeros, so no shifting is needed. */
    result[0] = 1;
    *returnSize = digitsSize + 1;
    return result;
}
```
### Assembly code
<h4>
<span style="color:darkblue">version <code>1</code></span>
</h4>
```s
# Plus One - RISC-V Assembly Implementation
# Implements: digits = digits + 1 for large integers represented as arrays
# Each digit occupies one 32-bit word (.word), most significant digit first.
.data
# Test data arrays: input digits, input length, expected output digits
test1: .word 1, 2, 3 # [1,2,3] -> [1,2,4]
test1_len: .word 3
test1_exp: .word 1, 2, 4 # Expected result
test2: .word 4, 3, 2, 1 # [4,3,2,1] -> [4,3,2,2]
test2_len: .word 4
test2_exp: .word 4, 3, 2, 2 # Expected result
test3: .word 9, 9, 9 # [9,9,9] -> [1,0,0,0]
test3_len: .word 3
test3_exp: .word 1, 0, 0, 0 # Expected result (length changes!)
# Output buffer for results
# plus_one writes its answer here and compare_arrays reads it back.
# Capacity is 5 words: inputs may have at most 5 digits, and at most 4
# when a carry out of the top digit is possible (e.g. all nines).
result: .word 0, 0, 0, 0, 0 # 5 words buffer
# Test messages (printed with ecall a7 = 4)
test_start_msg: .string "=== Plus One Automated Test ===\n\n"
test1_msg: .string "Test 1: [1,2,3] -> [1,2,4] "
test2_msg: .string "Test 2: [4,3,2,1] -> [4,3,2,2] "
test3_msg: .string "Test 3: [9,9,9] -> [1,0,0,0] "
pass_msg: .string "PASS\n"
fail_msg: .string "FAIL\n"
separator: .string "\n====================\n"
all_pass_msg: .string "All tests passed.\n"
some_fail_msg: .string "Some tests failed.\n"
newline: .string "\n" # not referenced by the code in this listing
.text
.globl main
# Test driver: runs plus_one on three inputs, checks each result with
# compare_arrays, prints PASS/FAIL per test plus a final summary, then exits.
# Per test, plus_one returns the result length in a0, and that a0 is passed
# unchanged to compare_arrays as its "actual length" argument. The digits
# themselves are read by compare_arrays from the global `result` buffer.
# s0/s1 hold the counters across calls; both callees save/restore s0/s1.
main:
# Print test start message
la a0, test_start_msg
li a7, 4 # ecall 4: print the string at a0
ecall
# Initialize test counters
li s0, 0 # s0 = pass count
li s1, 3 # s1 = total test count
# Test 1: [1,2,3] -> [1,2,4]
la a0, test1_msg
li a7, 4
ecall
la a0, test1 # input array
lw a1, test1_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test1_exp # expected result
lw a2, test1_len # expected length (same as input: no carry out)
jal ra, compare_arrays
beqz a0, test1_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test1_end
test1_fail:
la a0, fail_msg
test1_end:
li a7, 4
ecall
# Test 2: [4,3,2,1] -> [4,3,2,2]
la a0, test2_msg
li a7, 4
ecall
la a0, test2 # input array
lw a1, test2_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test2_exp # expected result
lw a2, test2_len # expected length (same as input: no carry out)
jal ra, compare_arrays
beqz a0, test2_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test2_end
test2_fail:
la a0, fail_msg
test2_end:
li a7, 4
ecall
# Test 3: [9,9,9] -> [1,0,0,0]
la a0, test3_msg
li a7, 4
ecall
la a0, test3 # input array
lw a1, test3_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test3_exp # expected result
li a2, 4 # expected length is 4 (changed due to carry)
jal ra, compare_arrays
beqz a0, test3_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test3_end
test3_fail:
la a0, fail_msg
test3_end:
li a7, 4
ecall
# Print separator
la a0, separator
li a7, 4
ecall
# Print final results: all passed iff pass count == total count
beq s0, s1, all_passed
la a0, some_fail_msg
j print_final
all_passed:
la a0, all_pass_msg
print_final:
li a7, 4
ecall
# Exit program
li a7, 10 # ecall 10: exit
ecall
# Plus One function
# Corresponds to: int* plusOne(int* digits, int digitsSize, int* returnSize)
# Register mapping: a0 = digits, a1 = digitsSize. Instead of returning a
# pointer, the digits are written to the global `result` buffer and the new
# length (the C version's *returnSize) is returned in a0.
# NOTE: `result` holds 5 words and nothing here checks the length, so an
# input of more than 5 digits, or 5 digits that are all 9, writes past the
# end of the buffer.
# Saves/restores ra and s0-s3; clobbers t0-t5.
plus_one:
addi sp, sp, -20
sw ra, 0(sp)
sw s0, 4(sp) # input array address
sw s1, 8(sp) # original length
sw s2, 12(sp) # result buffer address
sw s3, 16(sp) # carry flag
mv s0, a0 # s0 = input array address
mv s1, a1 # s1 = original length
la s2, result # s2 = result buffer address
# Initialize result buffer to zeros (always all 5 words, whatever the length)
la t0, result
li t1, 5 # max possible length
li t2, 0
init_loop:
beqz t1, init_done
sw t2, 0(t0)
addi t0, t0, 4
addi t1, t1, -1
j init_loop
init_done:
# Copy input to result
mv t0, s0 # source = input array
mv t1, s2 # destination = result buffer
mv t2, s1 # counter = length
copy_loop:
beqz t2, copy_done
lw t3, 0(t0)
sw t3, 0(t1)
addi t0, t0, 4
addi t1, t1, 4
addi t2, t2, -1
j copy_loop
copy_done:
# Start from least significant digit with carry = 1
li s3, 1 # carry flag = true (initially adding 1)
# Calculate pointer to last element
la t0, result # result array
addi t1, s1, -1 # index of last element
slli t1, t1, 2 # convert to byte offset
add t0, t0, t1 # point to last element
mv t2, s1 # counter = length
process_digits:
beqz t2, check_final_carry
# Only add while the carry flag is set. Once it is cleared, the remaining
# iterations just walk the pointer (this version has no early exit).
beqz s3, no_addition
# Add 1 to current digit
lw t3, 0(t0) # load current digit
addi t3, t3, 1 # add 1
# Check for carry
li t4, 10
blt t3, t4, no_overflow
# Handle overflow: set digit to 0, keep carry flag
li t3, 0
li s3, 1 # carry remains true
j store_digit
no_overflow:
# No overflow, clear carry flag
li s3, 0
store_digit:
sw t3, 0(t0)
no_addition:
# Move to next digit (more significant)
addi t0, t0, -4
addi t2, t2, -1
j process_digits
check_final_carry:
# If we still have carry after processing all digits
beqz s3, plus_one_done
# Need to expand array: shift right and add 1 at front
# Calculate new length
addi s1, s1, 1
# Shift all elements right by one position, last element first
la t0, result
addi t1, s1, -2 # index of second last element in new array
slli t1, t1, 2
add t2, t0, t1 # source pointer (starts at original last element)
addi t3, t2, 4 # destination pointer (one position right)
mv t4, s1 # counter = new length - 1
addi t4, t4, -1
shift_loop:
beqz t4, shift_done
lw t5, 0(t2)
sw t5, 0(t3)
addi t2, t2, -4
addi t3, t3, -4
addi t4, t4, -1
j shift_loop
shift_done:
# Add 1 as the new most significant digit
li t0, 1
la t1, result
sw t0, 0(t1)
plus_one_done:
mv a0, s1 # return new length
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
lw s2, 12(sp)
lw s3, 16(sp)
addi sp, sp, 20
ret
# Compare arrays function
# Compares the global `result` buffer (the array written by plus_one)
# against an expected array.
# Input: a0 = actual length, a1 = expected array address, a2 = expected length
# Output: a0 = 1 if arrays match, 0 otherwise
# Lengths are compared first; elements are compared only if they agree.
# This routine makes no calls, so the ra save is not strictly needed.
# s0/s1 are callee-saved and therefore restored. Clobbers t0-t4.
compare_arrays:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
mv s0, a1 # s0 = expected array
mv s1, a2 # s1 = expected length
# First check if lengths match
bne a0, s1, compare_fail
# Now compare each element
la t0, result # actual result array
mv t1, s0 # expected array
mv t2, s1 # counter
compare_loop:
beqz t2, compare_success
lw t3, 0(t0) # actual value
lw t4, 0(t1) # expected value
bne t3, t4, compare_fail
addi t0, t0, 4
addi t1, t1, 4
addi t2, t2, -1
j compare_loop
compare_success:
li a0, 1
j compare_done
compare_fail:
li a0, 0
compare_done:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
```
**1. Execution information** - <span style="color:darkblue">**$843 \;\text{cycles}$**</span>

---
<h4>
<span style="color:darkblue">version <code>2</code></span>
</h4>
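**1. Improvements**
* **Simplified carry handling** (compared with <span style="color:darkblue">**version**</span> **`1`**): the carry is added to each digit directly instead of first testing a separate carry flag: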
```s
# Excerpt from plus_one (complete version): the carry in t1 is added to
# every digit visited, replacing version 1's separate carry-flag test.
add_loop:
# Load current digit
lw t2, 0(t0)
# Add carry
add t2, t2, t1
li t1, 0 # reset carry
# Check if digit >= 10
li t3, 10
blt t2, t3, no_carry
# Handle carry: set digit to 0, set carry to 1
li t2, 0
li t1, 1
# (execution continues at no_carry: store the digit, then exit if t1 == 0)
```
* **Added an early exit** (missing in <span style="color:darkblue">**version**</span> **`1`**): the loop stops as soon as the carry becomes `0`:
```s
# Runs right after the digit is stored. With no carry left, no digit further
# to the left can change, so the loop ends immediately.
beqz t1, add_done # Exit early if no carry
```
* **Eliminated unnecessary initialization**
The input is copied straight into the result buffer; the separate loop that first zeroed the whole buffer is gone.
**2. Execution information** - <span style="color:darkblue">**$610 \;\text{cycles}$**</span>

**3. Why is it faster?**
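* **Fewer memory accesses**
`plus_one` now saves and restores only `ra`, `s0` and `s1` (version `1` also spilled `s2` and `s3`). It also no longer zeroes the 5-word result buffer on every call, so each call performs fewer loads and stores.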
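* **Fewer branches**
The simpler loop executes fewer conditional branches per digit and stops as soon as the carry is gone. In a simple in-order 5-stage pipeline such as Ripes', a taken branch discards the instructions already fetched behind it, so every branch avoided saves cycles.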
* **Lower instruction count**
* Removed the initialization loop that zeroed all 5 words of the result buffer on every call.
* Replaced the per-digit carry-flag test (`beqz s3, no_addition`) with adding the carry directly (`add t2, t2, t1`).
* Added an immediate exit once carry propagation stops.
<h4>
<span style="color:darkblue">complete version</span>
</h4>
```s
# Plus One - RISC-V Assembly Implementation
# Implements: digits = digits + 1 for large integers represented as arrays
# Each digit occupies one 32-bit word (.word), most significant digit first.
.data
# Test data arrays: input digits, input length, expected output digits
test1: .word 1, 2, 3 # [1,2,3] -> [1,2,4]
test1_len: .word 3
test1_exp: .word 1, 2, 4 # Expected result
test2: .word 4, 3, 2, 1 # [4,3,2,1] -> [4,3,2,2]
test2_len: .word 4
test2_exp: .word 4, 3, 2, 2 # Expected result
test3: .word 9, 9, 9 # [9,9,9] -> [1,0,0,0]
test3_len: .word 3
test3_exp: .word 1, 0, 0, 0 # Expected result (length changes!)
# Output buffer for results
# plus_one writes its answer here and compare_arrays reads it back.
# Capacity is 5 words: inputs may have at most 5 digits, and at most 4
# when a carry out of the top digit is possible (e.g. all nines).
result: .word 0, 0, 0, 0, 0 # 5 words buffer
# Test messages (printed with ecall a7 = 4)
test_start_msg: .string "=== Plus One Automated Test ===\n\n"
test1_msg: .string "Test 1: [1,2,3] -> [1,2,4] "
test2_msg: .string "Test 2: [4,3,2,1] -> [4,3,2,2] "
test3_msg: .string "Test 3: [9,9,9] -> [1,0,0,0] "
pass_msg: .string "PASS\n"
fail_msg: .string "FAIL\n"
separator: .string "\n====================\n"
all_pass_msg: .string "All tests passed.\n"
some_fail_msg: .string "Some tests failed.\n"
newline: .string "\n" # not referenced by the code in this listing
.text
.globl main
# Test driver: runs plus_one on three inputs, checks each result with
# compare_arrays, prints PASS/FAIL per test plus a final summary, then exits.
# Per test, plus_one returns the result length in a0, and that a0 is passed
# unchanged to compare_arrays as its "actual length" argument. The digits
# themselves are read by compare_arrays from the global `result` buffer.
# s0/s1 hold the counters across calls; both callees save/restore s0/s1.
main:
# Print test start message
la a0, test_start_msg
li a7, 4 # ecall 4: print the string at a0
ecall
# Initialize test counters
li s0, 0 # s0 = pass count
li s1, 3 # s1 = total test count
# Test 1: [1,2,3] -> [1,2,4]
la a0, test1_msg
li a7, 4
ecall
la a0, test1 # input array
lw a1, test1_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test1_exp # expected result
lw a2, test1_len # expected length (same as input: no carry out)
jal ra, compare_arrays
beqz a0, test1_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test1_end
test1_fail:
la a0, fail_msg
test1_end:
li a7, 4
ecall
# Test 2: [4,3,2,1] -> [4,3,2,2]
la a0, test2_msg
li a7, 4
ecall
la a0, test2 # input array
lw a1, test2_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test2_exp # expected result
lw a2, test2_len # expected length (same as input: no carry out)
jal ra, compare_arrays
beqz a0, test2_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test2_end
test2_fail:
la a0, fail_msg
test2_end:
li a7, 4
ecall
# Test 3: [9,9,9] -> [1,0,0,0]
la a0, test3_msg
li a7, 4
ecall
la a0, test3 # input array
lw a1, test3_len # array length
jal ra, plus_one # a0 = result length (kept as compare_arrays' 1st arg)
la a1, test3_exp # expected result
li a2, 4 # expected length is 4 (changed due to carry)
jal ra, compare_arrays
beqz a0, test3_fail # a0 = 1 on match, 0 on mismatch
addi s0, s0, 1 # increment pass count
la a0, pass_msg
j test3_end
test3_fail:
la a0, fail_msg
test3_end:
li a7, 4
ecall
# Print separator
la a0, separator
li a7, 4
ecall
# Print final results: all passed iff pass count == total count
beq s0, s1, all_passed
la a0, some_fail_msg
j print_final
all_passed:
la a0, all_pass_msg
print_final:
li a7, 4
ecall
# Exit program
li a7, 10 # ecall 10: exit
ecall
# Optimized plus_one function
# Adds 1 to a decimal number stored one digit per word, most significant
# digit first, and writes the answer into the global `result` buffer.
# Input:  a0 = address of the input digits, a1 = number of digits (may be 0)
# Output: a0 = number of digits written to `result` (a1, or a1 + 1 when the
#         carry ripples out of the most significant digit)
# Saves/restores ra, s0, s1; clobbers t0-t3.
# NOTE: `result` holds 5 words and there is no capacity check. Inputs must
# have at most 5 digits, and at most 4 if a carry out of the top digit can
# occur.
# NOTE: the cycle count quoted for this version was measured before the
# empty-input guard and the O(1) expansion below were added; re-measure.
plus_one:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp) # result buffer address
sw s1, 8(sp) # original length
la s0, result # s0 = result buffer address
mv s1, a1 # s1 = original length
# Copy input to result
mv t0, a0 # source
mv t1, s0 # destination
mv t2, a1 # counter
copy_loop:
beqz t2, copy_done
lw t3, 0(t0)
sw t3, 0(t1)
addi t0, t0, 4
addi t1, t1, 4
addi t2, t2, -1
j copy_loop
copy_done:
# An empty input is the number 0, so the answer is [1]. Without this check
# the add loop would start at result[-1] and overwrite the word just before
# the buffer.
beqz s1, expand_array
# Add 1 starting from the last digit
addi t0, s1, -1 # index of last element
slli t0, t0, 2 # convert to byte offset
add t0, s0, t0 # pointer to last element
li t1, 1 # carry = 1 (we're adding 1)
add_loop:
# Load current digit
lw t2, 0(t0)
# Add carry
add t2, t2, t1
li t1, 0 # reset carry
# Check if digit >= 10
li t3, 10
blt t2, t3, no_carry
# Handle carry: set digit to 0, set carry to 1
li t2, 0
li t1, 1
no_carry:
# Store updated digit
sw t2, 0(t0)
# Check if we need to continue
beqz t1, add_done # no carry, we're done
# Move to previous digit
addi t0, t0, -4
# Stop once the pointer moves before the buffer (unsigned: it is an address)
bltu t0, s0, expand_array
j add_loop
add_done:
mv a0, s1 # return original length
j plus_one_exit
expand_array:
# The carry left the most significant digit, so every digit is now 0 and the
# answer is a 1 followed by s1 zeros. Rather than shifting the zeros right
# one slot at a time, write the new last digit and the leading 1 directly.
# Store order matters when s1 == 0: result[0] = 0 is then replaced by 1.
slli t0, s1, 2 # byte offset of result[s1]
add t0, s0, t0
sw zero, 0(t0) # result[s1] = 0
li t0, 1
sw t0, 0(s0) # result[0] = 1
addi a0, s1, 1 # return new length
plus_one_exit:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
# Compare arrays function
# Compares the global `result` buffer (the array written by plus_one)
# against an expected array.
# Input: a0 = actual length, a1 = expected array address, a2 = expected length
# Output: a0 = 1 if arrays match, 0 otherwise
# Lengths are compared first; elements are compared only if they agree.
# This routine makes no calls, so the ra save is not strictly needed.
# s0/s1 are callee-saved and therefore restored. Clobbers t0-t4.
compare_arrays:
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
mv s0, a1 # s0 = expected array
mv s1, a2 # s1 = expected length
# First check if lengths match
bne a0, s1, compare_fail
# Now compare each element
la t0, result # actual result array
mv t1, s0 # expected array
mv t2, s1 # counter
compare_loop:
beqz t2, compare_success
lw t3, 0(t0) # actual value
lw t4, 0(t1) # expected value
bne t3, t4, compare_fail
addi t0, t0, 4
addi t1, t1, 4
addi t2, t2, -1
j compare_loop
compare_success:
li a0, 1
j compare_done
compare_fail:
li a0, 0
compare_done:
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
ret
```
## Ripes Simulator
We use <span style="color:blue">**Ripes**</span>, a graphical processor simulator and assembly editor for the **RISC-V** architecture.
### 5-stage pipelined processor

A pipelined processor divides instruction execution into five stages:
* **Instruction Fetch (IF)**
The Instruction Fetch stage is the first step in the pipeline process. Here, the CPU retrieves an instruction from the program memory. The primary tasks in this stage include:
* `Program Counter (PC) Management`
* `Instruction Memory Access`
* `Buffering`
Efficiency at this stage is paramount as delays here can **stall** the entire pipeline, leading to performance bottlenecks. **Advanced CPUs** often use techniques like prefetching to mitigate potential delays.
* **Decode (ID)**
Once the instruction is fetched, it enters the Instruction Decode stage. This stage involves interpreting the fetched instruction and preparing the necessary operands for execution. Key activities include:
* `Opcode Decoding`
* `Register Read`
* `Instruction Classification`
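This stage also hosts **hazard detection**. Data dependencies are resolved by forwarding results from later stages into EX where possible. The pipeline is stalled only when forwarding cannot help, e.g. a load immediately followed by an instruction that uses the loaded value.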
* **Execute (EX)**
The Execute stage is where the actual computation or operation specified by the instruction takes place. This stage includes:
* `ALU Operations`
* `Address Calculation`
* `Branch Evaluation`
The execution stage is critical for the **overall performance**, as complex instructions can take multiple cycles, potentially causing pipeline stalls.
* **Memory Access (MEM)**
After execution, some instructions require access to memory to read or write data. The Memory Access stage handles these operations. Key functions include:
* `Load Operations`
* `Store Operations`
* `Address Translation`
Efficiency in this stage is achieved through techniques like caching, which **reduces latency** by storing frequently accessed data closer to the CPU.
* **Write Back (WB)**
The final stage in the pipeline is Write Back. Here, the results of the executed instructions are written back to the CPU's register file. This stage involves:
* `Register Write`
* `Completion Logging`
Write Back is crucial for ensuring that subsequent instructions have access to the correct and updated data, **maintaining the integrity and consistency** of the processor’s state.
---
**Reference** :
* [leetcode 66](https://leetcode.com/problems/plus-one/)
* [5 Stages of Pipeline in Computer Architecture](https://medium.com/@aylia.zulfiqar29/5-stages-of-pipeline-in-computer-architecture-dc9fca11784e)
* [arch2025-quiz1-sol](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* [@sysprog21/ca2025-quizzes](https://github.com/sysprog21/ca2025-quizzes/tree/main)