# Assignment 1: RISC-V Assembly and Instruction Pipeline
> contributed by < [`kkevinhu`](https://github.com/kkevinhu) >
## Introduction
### 1009. Complement of Base 10 Integer
The complement of an integer is the integer you get when you flip all the 0's to 1's and all the 1's to 0's in its binary representation.
For example, the integer 5 is "101" in binary and its complement is "010", which is the integer 2.
Given an integer n, return its complement.
Example 1:
- Input: n = 5
- Output: 2
- Explanation: 5 is "101" in binary, with complement "010" in binary, which is 2 in base-10.
Example 2:
- Input: n = 7
- Output: 0
- Explanation: 7 is "111" in binary, with complement "000" in binary, which is 0 in base-10.
## Implementation
You can find the source code and more complete details [here](https://github.com/kkevinhu/ca2025-quizzes).
The CPU model I used in Ripes is the `5-stage processor`
(a 5-stage in-order processor with hazard detection/elimination and forwarding).
### Motivation
The main task is to find the bitmask covering the significant bits of a number. A simple loop-based approach shifts bits repeatedly, causing variable runtime. By using CLZ, we can determine the bit length directly and build the mask in constant time. This makes the algorithm faster, more predictable, and closer to hardware-efficient execution on architectures that support CLZ.
#### For example :
`5` -> `0000...0101`; its complement is `1111...1010`, but only the lower 3 bits matter. `CLZ(5)` returns 29, so the bit length is 32 - 29 = 3.
Then we can build a bit mask as (1 << 3) - 1 = `0000...0111`.
Finally, the result is `1111...1010` & `0000...0111` = `0000...0010`, which is `2` in decimal.
### C code without CLZ
This version shifts `temp` right one bit at a time and grows the mask by one bit per iteration, so the loop runs once per significant bit of `n`.
```c=
int bitwiseComplement(int n) {
    if (n == 0) return 1;
    int mask = 0, temp = n;
    while (temp > 0) {
        mask = (mask << 1) | 1;
        temp >>= 1;
    }
    return n ^ mask;
}
```
### C code with loopless CLZ
This CLZ is an unrolled binary search: five fixed comparisons (on the top 16, 8, 4, 2, and 1 bits) narrow down the position of the most significant bit (MSB), with no loop.
```c=
#include <stdint.h>

static inline unsigned clz(uint32_t x)
{
    if (x == 0) return 32; // all bits are 0 → 32 leading zeros
    int n = 0;
    if ((x >> 16) == 0) {
        n += 16;
        x <<= 16;
    }
    if ((x >> 24) == 0) {
        n += 8;
        x <<= 8;
    }
    if ((x >> 28) == 0) {
        n += 4;
        x <<= 4;
    }
    if ((x >> 30) == 0) {
        n += 2;
        x <<= 2;
    }
    if ((x >> 31) == 0) n += 1;
    return n;
}
int bitwiseComplement(int n) {
    if (n == 0) return 1;
    unsigned lz = clz(n);
    unsigned msb = 31 - lz;            // index of the highest set bit
    int mask = (1 << (msb + 1)) - 1;   // ones in bits 0..msb
    return n ^ mask;
}
```
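### C code with branchless CLZ
This version halves the shift width `c` on every iteration (16, 8, 4, 2, 1), so the loop always runs five times. Despite the name used here, it still contains a data-dependent `if (y)` test inside the loop.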
```c=
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}
int bitwiseComplement(int n) {
    if (n == 0) return 1;
    int mask = (1 << (32 - clz(n))) - 1;
    return ~n & mask;
}
```
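
The three C variants above should return the same value for every valid input. A minimal sketch of a cross-check follows; the function names (`complement_loop`, `clz_unrolled`, `clz_halving`, `complement_with`, `count_mismatches`) are hypothetical renames introduced only so that all three variants can live in one file, and the mask is built with unsigned arithmetic as a defensive choice.

```c
#include <assert.h>
#include <stdint.h>

/* Variant 1: shift loop, no CLZ. */
static int complement_loop(int n)
{
    if (n == 0)
        return 1;
    int mask = 0, temp = n;
    while (temp > 0) {
        mask = (mask << 1) | 1;
        temp >>= 1;
    }
    return n ^ mask;
}

/* Variant 2: CLZ as an unrolled binary search (the "loopless" version). */
static unsigned clz_unrolled(uint32_t x)
{
    if (x == 0)
        return 32;
    unsigned n = 0;
    if ((x >> 16) == 0) { n += 16; x <<= 16; }
    if ((x >> 24) == 0) { n += 8;  x <<= 8; }
    if ((x >> 28) == 0) { n += 4;  x <<= 4; }
    if ((x >> 30) == 0) { n += 2;  x <<= 2; }
    if ((x >> 31) == 0) n += 1;
    return n;
}

/* Variant 3: CLZ as a five-iteration halving loop. */
static unsigned clz_halving(uint32_t x)
{
    unsigned n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) { n -= c; x = y; }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Complement built from whichever CLZ is passed in (n must be >= 0). */
static int complement_with(unsigned (*clz_fn)(uint32_t), int n)
{
    if (n == 0)
        return 1;
    uint32_t mask = (UINT32_C(1) << (32 - clz_fn((uint32_t) n))) - 1;
    return (int) (~(uint32_t) n & mask);
}

/* Number of inputs in [0, limit) on which a CLZ variant disagrees
 * with the shift-loop reference. */
static int count_mismatches(int limit)
{
    int bad = 0;
    for (int n = 0; n < limit; n++) {
        int ref = complement_loop(n);
        if (complement_with(clz_unrolled, n) != ref) bad++;
        if (complement_with(clz_halving, n) != ref) bad++;
    }
    return bad;
}
```

Calling `count_mismatches(1 << 16)` from a small `main` and checking that it returns 0 gives some confidence in the C reference before the assembly versions are compared against it.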
### Assembly for Complement of Base 10 Integer
Here I translate the C code above into RISC-V assembly.
- `main` : Verifies the correctness of `bitwiseComplement` and CLZ against a table of expected answers
- `bitwiseComplement` :
* n is zero : return 1
* Otherwise : First, count the leading zeros, which gives the position of the MSB of n. Then build a bit mask that has 1s in every position up to and including the MSB. Finally, `XOR` with the mask flips exactly the bits of n that fall under it.
#### Without CLZ
:::spoiler See More
```asm=
.data
tests: .word 0, 5, 7, 10, 121 # Test Input
answers: .word 1, 2, 0, 5, 6 # Answers for test input
n_tests: .word 5
msg_test: .string "Test "
msg_input: .string ": input = "
msg_result: .string ", result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, tests # s0 = tests
la s1, answers # s1 = answers
lw s2, n_tests # s2 = n_tests
li s3, 0 # s3 = index
loop_tests:
beq s3, s2, done
# Output "Test <index>"
la a0, msg_test
li a7, 4
ecall
addi a0, s3, 1
li a7, 1
ecall
la a0, msg_input
li a7, 4
ecall
# Get test input from tests
slli t0, s3, 2
add t1, s0, t0
lw a0, 0(t1)
addi t4, a0, 0
li a7, 1
ecall
la a0, msg_result
li a7, 4
ecall
addi a0, t4, 0
jal ra, bitwiseComplement
# a1 = result
addi t5, a1, 0
# Output result
addi a0, t5, 0
li a7, 1
ecall
# Get ans from answers
slli t0, s3, 2
add t2, s1, t0
lw t3, 0(t2)
# Compare result and correct answers
bne t5, t3, fail
pass:
la a0, msg_pass
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
fail:
la a0, msg_fail
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
done:
li a7, 10 # exit
ecall
bitwiseComplement:
addi sp, sp, -16
sw ra, 12(sp)
# if (n == 0) return 1;
beqz a0, ret_one
mv t0, a0 # temp = n
li t1, 0 # mask = 0
build_mask:
blez t0, mask_done # while (temp > 0)
slli t1, t1, 1 # mask <<= 1
ori t1, t1, 1 # mask |= 1
srli t0, t0, 1 # temp >>= 1
j build_mask
mask_done:
xor a1, a0, t1 # return n ^ mask
j done_func
ret_one:
li a1, 1 # return 1
done_func:
lw ra, 12(sp)
addi sp, sp, 16
jr ra
```
:::
- code line : 114
- cpu cycle : 506
#### With branchless CLZ
:::spoiler See More
```asm=
.data
tests: .word 0, 5, 7, 10, 121 # Test Input
answers: .word 1, 2, 0, 5, 6 # Answers for test input
n_tests: .word 5
msg_test: .string "Test "
msg_input: .string ": input = "
msg_result: .string ", result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, tests # s0 = tests
la s1, answers # s1 = answers
lw s2, n_tests # s2 = n_tests
li s3, 0 # s3 = index
loop_tests:
beq s3, s2, done
# Output "Test <index>"
la a0, msg_test
li a7, 4
ecall
addi a0, s3, 1
li a7, 1
ecall
la a0, msg_input
li a7, 4
ecall
# Get test input from tests
slli t0, s3, 2
add t1, s0, t0
lw a0, 0(t1)
addi t4, a0, 0
li a7, 1
ecall
la a0, msg_result
li a7, 4
ecall
addi a0, t4, 0
jal ra, bitwiseComplement
# a1 = result
addi t5, a1, 0
# Output result
addi a0, t5, 0
li a7, 1
ecall
# Get ans from answers
slli t0, s3, 2
add t2, s1, t0
lw t3, 0(t2)
# Compare result and correct answers
bne t5, t3, fail
pass:
la a0, msg_pass
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
fail:
la a0, msg_fail
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
done:
li a7, 10 # exit
ecall
bitwiseComplement:
addi sp, sp, -4
sw ra, 0(sp)
beqz a0, zero
jal ra, clz
addi t0, a1, 0
li t1, 32
sub t0, t1, t0
li t1, 1
sll t0, t1, t0
sub t0, t0, t1
xor a1, a0, t0
j return
zero:
addi a1, a0, 1
j return
return:
lw ra, 0(sp)
addi sp, sp, 4
ret
clz:
addi sp, sp, -16
sw ra, 12(sp)
sw a0, 8(sp)
li t0, 32
li t1, 16
loop:
srl t2, a0, t1
beqz t2, skip
sub t0, t0, t1
addi a0, t2, 0
skip:
srli t1, t1, 1
bnez t1, loop
sub a1, t0, a0
lw ra, 12(sp)
lw a0, 8(sp)
addi sp, sp, 16
jr ra
```
:::
- code line : 128
- cpu cycle : 600
#### With loopless CLZ
:::spoiler See More
```asm=
.data
tests: .word 0, 5, 7, 10, 121 # Test Input
answers: .word 1, 2, 0, 5, 6 # Answers for test input
n_tests: .word 5
msg_test: .string "Test "
msg_input: .string ": input = "
msg_result: .string ", result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, tests # s0 = tests
la s1, answers # s1 = answers
lw s2, n_tests # s2 = n_tests
li s3, 0 # s3 = index
loop_tests:
beq s3, s2, done
# Output "Test <index>"
la a0, msg_test
li a7, 4
ecall
addi a0, s3, 1
li a7, 1
ecall
la a0, msg_input
li a7, 4
ecall
# Get test input from tests
slli t0, s3, 2
add t1, s0, t0
lw a0, 0(t1)
addi t4, a0, 0
li a7, 1
ecall
la a0, msg_result
li a7, 4
ecall
addi a0, t4, 0
jal ra, bitwiseComplement
# a1 = result
addi t5, a1, 0
# Output result
addi a0, t5, 0
li a7, 1
ecall
# Get ans from answers
slli t0, s3, 2
add t2, s1, t0
lw t3, 0(t2)
# Compare result and correct answers
bne t5, t3, fail
pass:
la a0, msg_pass
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
fail:
la a0, msg_fail
li a7, 4
ecall
addi s3, s3, 1
j loop_tests
done:
li a7, 10 # exit
ecall
bitwiseComplement:
addi sp, sp, -4
sw ra, 0(sp)
beqz a0, zero
jal ra, clz
addi t0, a1, 0
li t1, 32
sub t0, t1, t0
li t1, 1
sll t0, t1, t0
sub t0, t0, t1
xor a1, a0, t0
j return
zero:
addi a1, a0, 1
j return
return:
lw ra, 0(sp)
addi sp, sp, 4
ret
clz:
addi sp, sp, -8
sw ra, 0(sp)
sw a0, 4(sp)
beqz a0, clz_zero # if x == 0 -> return 32
li t0, 0 # n = 0
chk_16:
srli t1, a0, 16
bnez t1, chk_8 # if (x >> 16) != 0 -> skip
addi t0, t0, 16 # n += 16
slli a0, a0, 16 # x <<= 16
chk_8:
srli t1, a0, 24
bnez t1, chk_4
addi t0, t0, 8
slli a0, a0, 8
chk_4:
srli t1, a0, 28
bnez t1, chk_2
addi t0, t0, 4
slli a0, a0, 4
chk_2:
srli t1, a0, 30
bnez t1, chk_31
addi t0, t0, 2
slli a0, a0, 2
chk_31:
srli t1, a0, 31
beqz t1, add_one # if bit31 == 0 -> add 1
j clz_done
add_one:
addi t0, t0, 1
clz_done:
mv a1, t0
j clz_return
clz_zero:
li a1, 32 # return 32 if input = 0
clz_return:
lw ra, 0(sp)
lw a0, 4(sp)
addi sp, sp, 8
ret
```
:::
- code line : 159
- cpu cycle : 550
### Analysis
- Time complexity
- The loop in the branchless CLZ always runs exactly $\log_2 32 = 5$ times (c = 16, 8, 4, 2, 1), regardless of the input, so its time complexity is constant, $O(1)$.
- The loopless CLZ has no loop at all, so its time complexity is also $O(1)$.
- Both CLZ versions therefore reduce the time complexity from $O(\log n)$ to $O(1)$.
- Cycle (branchless vs. loopless)
- The cycle count drops from 600 to 550, so the loopless CLZ outperforms the branchless CLZ on this test set. Likely reasons:
- Fewer taken branches, hence fewer pipeline flushes
- A fixed number of operations, with no loop counter to update
- Better pipeline utilization
- Summary :
- Using CLZ gives a better time complexity than the version without CLZ. On this small-input test set the non-CLZ version actually needs fewer cycles (506), but its loop runs once per significant bit, so its cycle count grows with the magnitude of the input.
- Moreover, the loopless version needs fewer cycles than the branchless version to complete the same operation.
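
To make the complexity argument concrete, the sketch below counts loop iterations only; it is a proxy for the trend and not a measurement of Ripes cycles. The helper names `shift_loop_iterations` and `halving_loop_iterations` are hypothetical and exist only for this illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Iterations of the shift loop in the non-CLZ version:
 * one per significant bit of n. */
static unsigned shift_loop_iterations(uint32_t n)
{
    unsigned iters = 0;
    while (n > 0) {
        n >>= 1;
        iters++;
    }
    return iters;
}

/* Iterations of the halving loop in the CLZ version:
 * c takes the values 16, 8, 4, 2, 1 for any input. */
static unsigned halving_loop_iterations(uint32_t x)
{
    unsigned iters = 0, c = 16;
    do {
        uint32_t y = x >> c;
        if (y)
            x = y;
        c >>= 1;
        iters++;
    } while (c);
    return iters;
}
```

For `n = 5` the shift loop runs 3 times and for `n = 1 << 30` it runs 31 times, while the halving loop runs 5 times in both cases. This is consistent with the non-CLZ version being cheaper on the small test inputs used above and more expensive on large ones.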
## uf8
In this part, I translate `q1-uf8.c` into assembly and use a few test vectors to verify the correctness of each function.
### C code
:::spoiler See More
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
typedef uint8_t uf8;
static inline unsigned clz(uint32_t x)
{
int n = 32, c = 16;
do {
uint32_t y = x >> c;
if (y) {
n -= c;
x = y;
}
c >>= 1;
} while (c);
return n - x;
}
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
uint32_t mantissa = fl & 0x0f;
uint8_t exponent = fl >> 4;
uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
return (mantissa << exponent) + offset;
}
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz;
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0;
if (msb >= 5) {
/* Estimate exponent - the formula is empirical */
exponent = msb - 4;
if (exponent > 15)
exponent = 15;
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* Adjust if estimate was off */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Find exact exponent */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
uint8_t mantissa = (value - overflow) >> exponent;
return (exponent << 4) | mantissa;
}
/* Test encode/decode round-trip */
static bool test(void)
{
int32_t previous_value = -1;
bool passed = true;
for (int i = 0; i < 256; i++) {
uint8_t fl = i;
int32_t value = uf8_decode(fl);
uint8_t fl2 = uf8_encode(value);
if (fl != fl2) {
printf("%02x: produces value %d but encodes back to %02x\n", fl,
value, fl2);
passed = false;
}
if (value <= previous_value) {
printf("%02x: value %d <= previous_value %d\n", fl, value,
previous_value);
passed = false;
}
previous_value = value;
}
return passed;
}
int main(void)
{
if (test()) {
printf("All tests passed.\n");
return 0;
}
return 1;
}
```
:::
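
The offset term in `uf8_decode` can be read as a closed form: `0x7FFF >> (15 - e)` equals $2^e - 1$, so the offset is $16 \cdot (2^e - 1)$. The sketch below checks that reading against the original expression for all 256 codes; `decode_ref`, `decode_closed_form`, and `forms_agree` are hypothetical names used only here.

```c
#include <assert.h>
#include <stdint.h>

/* Same computation as uf8_decode in the C listing. */
static uint32_t decode_ref(uint8_t fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFFu >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Closed form: value = (m << e) + 16 * (2^e - 1). */
static uint32_t decode_closed_form(uint8_t fl)
{
    uint32_t m = fl & 0x0f, e = fl >> 4;
    return (m << e) + (((UINT32_C(1) << e) - 1) << 4);
}

/* Returns 1 if both forms agree on all 256 codes, 0 otherwise. */
static int forms_agree(void)
{
    for (int i = 0; i < 256; i++)
        if (decode_ref((uint8_t) i) != decode_closed_form((uint8_t) i))
            return 0;
    return 1;
}
```

As a worked case, `0x68` has mantissa 8 and exponent 6, so the value is `(8 << 6) + 16 * 63 = 512 + 1008 = 0x5f0`.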
### Assembly
- #### uf8-decode
:::spoiler See More
```asm=
.data
test: .word 0x00
.word 0xf
.word 0x68
.word 0xa9
.word 0xFF
ans: .word 0x00
.word 0xf
.word 0x5f0
.word 0x63f0
.word 0xF7FF0
n_tests: .word 5
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, test
la s1, ans
lw s2, n_tests
li s4, 0 # index
loop_t:
beq s4, s2, done
la a0, msg_input
li a7, 4
ecall
lw s3, 0(s0)
addi a0, s3, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, decode
addi t0, a1, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s1)
addi s0, s0, 4
addi s1, s1, 4
addi s4, s4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
decode:
addi sp, sp, -16
sw ra, 12(sp)
sw s3, 8(sp) # s3 = f1
andi t0, s3, 0x0f
srli t1, s3, 4
li t2, 15
sub t2, t2, t1
li t3, 0x7FFF
srl t3, t3, t2
slli t3, t3, 4
sll t0, t0, t1
add a1, t0, t3
lw ra, 12(sp)
lw s3, 8(sp)
addi sp, sp, 16
jr ra
```
:::
- #### uf8-clz
:::spoiler See More
```asm=
.data
test: .word 0x00
.word 0x24
.word 0x210
.word 0x63ff0
.word 0x10000000
ans: .word 0x20
.word 0x1a
.word 0x16
.word 0xd
.word 0x3
n_tests: .word 5
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, test
la s1, ans
lw s2, n_tests
li s4, 0 # index
loop_t:
beq s4, s2, done
la a0, msg_input
li a7, 4
ecall
lw s3, 0(s0)
addi a0, s3, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, clz
addi t0, a1, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s1)
addi s0, s0, 4
addi s1, s1, 4
addi s4, s4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
clz:
addi sp, sp, -16
sw ra, 12(sp)
sw s3, 8(sp) # s3 = x
li t0, 32 # t0 = n
li t1, 16 # t1 = c
loop:
srl t2, s3, t1 # t2 = y
beqz t2, skip
sub t0, t0, t1
addi s3, t2, 0
skip:
srli t1, t1, 1
bnez t1, loop
sub a1, t0, s3
lw ra, 12(sp)
lw s3, 8(sp)
addi sp, sp, 16
jr ra
```
:::
- #### uf8-encode
:::spoiler See More
```asm=
.data
test: .word 0x00
.word 0xf
.word 0xd0
.word 0x2df0
.word 0x000F7FF0
ans: .word 0x00
.word 0xf
.word 0x3c
.word 0x97
.word 0xFF
n_tests: .word 5
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.global main
main:
la s0, test
la s1, ans
lw s2, n_tests
li s4, 0 # index
loop_t:
beq s4, s2, done
la a0, msg_input
li a7, 4
ecall
lw s3, 0(s0)
addi a0, s3, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, encode
addi t0, a1, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s1)
addi s0, s0, 4
addi s1, s1, 4
addi s4, s4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
encode:
addi sp, sp, -32
sw ra, 28(sp)
sw s3, 24(sp) # s3 = x
sw s0, 20(sp)
sw s1, 16(sp)
sw s2, 12(sp)
li t0, 16
blt s3, t0, special
jal ra, clz
addi t1, a1, 0 # t1 = lz
li t2, 31
sub t1, t2, t1 # t1 = msb
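li s0, 0 # exponent = 0 (s0 still holds main's test pointer here)
li s1, 0 # overflow = 0 (s1 still holds main's answer pointer here)
li t2, 5
blt t1, t2, find_exact_exp # msb < 5: skip the estimate, start from exponent 0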
addi s0, t1, -4 # s0 = exponent
li t2, 15
bgt s0, t2, limit_exp
j est_loop_init
limit_exp:
addi s0, t2, 0
est_loop_init:
li s1, 0 # s1 = overflow
li t1, 0 # t1 = e
est_loop:
bge t1, s0, adjust_est
slli t2, s1, 1
addi s1, t2, 16
addi t1, t1, 1
j est_loop
adjust_est:
beqz s0, find_exact_exp
adjust_loop:
blt s3, s1, adjust_inner # while (value < overflow), as in the C code
j find_exact_exp
adjust_inner:
addi t2, s1, -16
srli s1, t2, 1
addi s0, s0, -1
bnez s0, adjust_loop
find_exact_exp:
li t2, 15
find_loop:
bge s0, t2, find_done
slli t0, s1, 1
addi t0, t0, 16
blt s3, t0, find_done
addi s1, t0, 0
addi s0, s0, 1
j find_loop
find_done:
sub t0, s3, s1
srl s2, t0, s0 # s2 = mantissa
slli s0, s0, 4
or a1, s0, s2
return:
lw ra, 28(sp)
lw s3, 24(sp) # s3 = x
lw s0, 20(sp)
lw s1, 16(sp)
lw s2, 12(sp)
addi sp, sp, 32
ret
special:
addi a1, s3, 0
j return
clz:
addi sp, sp, -16
sw ra, 12(sp)
sw s3, 8(sp) # s3 = x
li t0, 32 # t0 = n
li t1, 16 # t1 = c
loop:
srl t2, s3, t1 # t2 = y
beqz t2, skip
sub t0, t0, t1
addi s3, t2, 0
skip:
srli t1, t1, 1
bnez t1, loop
sub a1, t0, s3
lw ra, 12(sp)
lw s3, 8(sp)
addi sp, sp, 16
jr ra
```
:::
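
The five encode test vectors above all have either `value < 16` or `msb >= 5`, so the path where `msb == 4` (inputs 16 to 31) is not exercised. The sketch below is a C reference that could supply expected answers for such boundary inputs; `clz32`, `decode_ref`, `encode_ref`, and `roundtrip_ok` are hypothetical renames of the functions in the C listing, copied so the block compiles on its own.

```c
#include <assert.h>
#include <stdint.h>

static unsigned clz32(uint32_t x)
{
    unsigned n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

static uint32_t decode_ref(uint8_t fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFFu >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

static uint8_t encode_ref(uint32_t value)
{
    if (value < 16)
        return (uint8_t) value;
    int msb = 31 - (int) clz32(value);
    uint8_t exponent = 0;
    uint32_t overflow = 0;
    if (msb >= 5) {
        /* Initial estimate, then correct it downwards if needed. */
        exponent = (uint8_t) (msb - 4);
        if (exponent > 15)
            exponent = 15;
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }
    /* Walk upwards to the exact exponent. */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }
    uint8_t mantissa = (uint8_t) ((value - overflow) >> exponent);
    return (uint8_t) ((exponent << 4) | mantissa);
}

/* Returns 1 if encode(decode(c)) == c for all 256 codes. */
static int roundtrip_ok(void)
{
    for (int i = 0; i < 256; i++)
        if (encode_ref(decode_ref((uint8_t) i)) != (uint8_t) i)
            return 0;
    return 1;
}
```

With this reference, inputs 16, 31, and 48 encode to `0x10`, `0x17`, and `0x20`. Adding values like these to the `test` and `ans` tables would cover the branch that skips the exponent estimate.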
## bfloat16
In this part, I translate `q1-bfloat16.c` into assembly and use a few test vectors to verify the correctness of each function.
### C code
:::spoiler See More
```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
typedef struct {
uint16_t bits;
} bf16_t;
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})
static inline bool bf16_isnan(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_isinf(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
!(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_iszero(bf16_t a)
{
return !(a.bits & 0x7FFF);
}
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float));
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
return (bf16_t) {.bits = f32bits >> 16};
}
static inline float bf16_to_f32(bf16_t val)
{
uint32_t f32bits = ((uint32_t) val.bits) << 16;
float result;
memcpy(&result, &f32bits, sizeof(float));
return result;
}
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a;
}
if (exp_b == 0xFF)
return b;
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant)
return BF16_ZERO();
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
b.bits ^= BF16_SIGN_MASK;
return bf16_add(a, b);
}
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (!exp_b && !mant_b)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15};
int16_t exp_adjust = 0;
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15};
result_mant >>= (1 - result_exp);
result_exp = 0;
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
if (exp_b == 0xFF) {
if (mant_b)
return b;
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = result_sign << 15};
}
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (exp_a == 0xFF) {
if (mant_a)
return a;
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15};
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++;
if (quotient & 0x8000)
quotient >>= 8;
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
quotient &= 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15};
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F),
};
}
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF) {
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(0) = 0 (handle both +0 and -0) */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* Direct bit manipulation square root algorithm */
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
if (e & 1) {
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for integer square root */
/* We want result where result^2 = m * 128 (since 128 represents 1.0) */
uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */
uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */
uint32_t result = 128; /* Default */
/* Binary search for square root of m */
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and scale */
if (sq <= m) {
result = mid; /* This could be our answer */
low = mid + 1;
} else {
high = mid - 1;
}
}
/* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
/* But we need to adjust the scale */
/* Since m is scaled where 128=1.0, result should also be scaled same way */
/* Normalize to ensure result is in [128, 256) */
if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow */
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; /* +Inf */
if (new_exp <= 0)
return BF16_ZERO();
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
if (bf16_isnan(a) || bf16_isnan(b))
return false;
if (bf16_iszero(a) && bf16_iszero(b))
return true;
return a.bits == b.bits;
}
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
if (bf16_isnan(a) || bf16_isnan(b))
return false;
if (bf16_iszero(a) && bf16_iszero(b))
return false;
bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1;
if (sign_a != sign_b)
return sign_a > sign_b;
return sign_a ? a.bits > b.bits : a.bits < b.bits;
}
static inline bool bf16_gt(bf16_t a, bf16_t b)
{
return bf16_lt(b, a);
}
#include <stdio.h>
#include <time.h>
#define TEST_ASSERT(cond, msg) \
do { \
if (!(cond)) { \
printf("FAIL: %s\n", msg); \
return 1; \
} \
} while (0)
static int test_basic_conversions(void)
{
printf("Testing basic conversions...\n");
float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f,
-0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f};
for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) {
float orig = test_values[i];
bf16_t bf = f32_to_bf16(orig);
float conv = bf16_to_f32(bf);
if (orig != 0.0f) {
TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch");
}
if (orig != 0.0f && !bf16_isinf(f32_to_bf16(orig))) {
float diff = (conv - orig);
float rel_error = (diff < 0) ? -diff / orig : diff / orig;
TEST_ASSERT(rel_error < 0.01f, "Relative error too large");
}
}
printf(" Basic conversions: PASS\n");
return 0;
}
static int test_special_values(void)
{
printf("Testing special values...\n");
bf16_t pos_inf = {.bits = 0x7F80}; /* +Infinity */
TEST_ASSERT(bf16_isinf(pos_inf), "Positive infinity not detected");
TEST_ASSERT(!bf16_isnan(pos_inf), "Infinity detected as NaN");
bf16_t neg_inf = {.bits = 0xFF80}; /* -Infinity */
TEST_ASSERT(bf16_isinf(neg_inf), "Negative infinity not detected");
bf16_t nan_val = BF16_NAN();
TEST_ASSERT(bf16_isnan(nan_val), "NaN not detected");
TEST_ASSERT(!bf16_isinf(nan_val), "NaN detected as infinity");
bf16_t zero = f32_to_bf16(0.0f);
TEST_ASSERT(bf16_iszero(zero), "Zero not detected");
bf16_t neg_zero = f32_to_bf16(-0.0f);
TEST_ASSERT(bf16_iszero(neg_zero), "Negative zero not detected");
printf(" Special values: PASS\n");
return 0;
}
static int test_arithmetic(void)
{
printf("Testing arithmetic operations...\n");
bf16_t a = f32_to_bf16(1.0f);
bf16_t b = f32_to_bf16(2.0f);
bf16_t c = bf16_add(a, b);
float result = bf16_to_f32(c);
float diff = result - 3.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Addition failed");
c = bf16_sub(b, a);
result = bf16_to_f32(c);
diff = result - 1.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Subtraction failed");
a = f32_to_bf16(3.0f);
b = f32_to_bf16(4.0f);
c = bf16_mul(a, b);
result = bf16_to_f32(c);
diff = result - 12.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Multiplication failed");
a = f32_to_bf16(10.0f);
b = f32_to_bf16(2.0f);
c = bf16_div(a, b);
result = bf16_to_f32(c);
diff = result - 5.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Division failed");
/* Test square root */
a = f32_to_bf16(4.0f);
c = bf16_sqrt(a);
result = bf16_to_f32(c);
diff = result - 2.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(4) failed");
a = f32_to_bf16(9.0f);
c = bf16_sqrt(a);
result = bf16_to_f32(c);
diff = result - 3.0f;
TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(9) failed");
printf(" Arithmetic: PASS\n");
return 0;
}
static int test_comparisons(void)
{
printf("Testing comparison operations...\n");
bf16_t a = f32_to_bf16(1.0f);
bf16_t b = f32_to_bf16(2.0f);
bf16_t c = f32_to_bf16(1.0f);
TEST_ASSERT(bf16_eq(a, c), "Equality test failed");
TEST_ASSERT(!bf16_eq(a, b), "Inequality test failed");
TEST_ASSERT(bf16_lt(a, b), "Less than test failed");
TEST_ASSERT(!bf16_lt(b, a), "Not less than test failed");
TEST_ASSERT(!bf16_lt(a, c), "Equal not less than test failed");
TEST_ASSERT(bf16_gt(b, a), "Greater than test failed");
TEST_ASSERT(!bf16_gt(a, b), "Not greater than test failed");
bf16_t nan_val = BF16_NAN();
TEST_ASSERT(!bf16_eq(nan_val, nan_val), "NaN equality test failed");
TEST_ASSERT(!bf16_lt(nan_val, a), "NaN less than test failed");
TEST_ASSERT(!bf16_gt(nan_val, a), "NaN greater than test failed");
printf(" Comparisons: PASS\n");
return 0;
}
static int test_edge_cases(void)
{
printf("Testing edge cases...\n");
float tiny = 1e-45f;
bf16_t bf_tiny = f32_to_bf16(tiny);
float tiny_val = bf16_to_f32(bf_tiny);
TEST_ASSERT(bf16_iszero(bf_tiny) || (tiny_val < 0 ? -tiny_val : tiny_val) < 1e-37f,
"Tiny value handling");
float huge = 1e38f;
bf16_t bf_huge = f32_to_bf16(huge);
bf16_t bf_huge2 = bf16_mul(bf_huge, f32_to_bf16(10.0f));
TEST_ASSERT(bf16_isinf(bf_huge2), "Overflow should produce infinity");
bf16_t small = f32_to_bf16(1e-38f);
bf16_t smaller = bf16_div(small, f32_to_bf16(1e10f));
float smaller_val = bf16_to_f32(smaller);
TEST_ASSERT(bf16_iszero(smaller) || (smaller_val < 0 ? -smaller_val : smaller_val) < 1e-45f,
"Underflow should produce zero or denormal");
printf(" Edge cases: PASS\n");
return 0;
}
static int test_rounding(void)
{
printf("Testing rounding behavior...\n");
float exact = 1.5f;
bf16_t bf_exact = f32_to_bf16(exact);
float back_exact = bf16_to_f32(bf_exact);
TEST_ASSERT(back_exact == exact,
"Exact representation should be preserved");
float val = 1.0001f;
bf16_t bf = f32_to_bf16(val);
float back = bf16_to_f32(bf);
float diff2 = back - val;
TEST_ASSERT((diff2 < 0 ? -diff2 : diff2) < 0.001f, "Rounding error should be small");
printf(" Rounding: PASS\n");
return 0;
}
#ifndef BFLOAT16_NO_MAIN
int main(void)
{
printf("\n=== bfloat16 Test Suite ===\n\n");
int failed = 0;
failed |= test_basic_conversions();
failed |= test_special_values();
failed |= test_arithmetic();
failed |= test_comparisons();
failed |= test_edge_cases();
failed |= test_rounding();
if (failed) {
printf("\n=== TESTS FAILED ===\n");
return 1;
}
printf("\n=== ALL TESTS PASSED ===\n");
return 0;
}
#endif /* BFLOAT16_NO_MAIN */
```
:::
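
Every arithmetic routine in the listing starts by splitting the 16-bit pattern into sign (1 bit), exponent (8 bits), and mantissa (7 bits). The sketch below isolates that step so the hex test vectors used later are easier to read. `bf16_fields`, `bf16_split`, and `bf16_bits_to_float` are hypothetical helpers, and the code assumes `float` is IEEE 754 binary32.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper type holding the three bfloat16 fields. */
typedef struct {
    unsigned sign; /* bit 15 */
    unsigned exp;  /* bits 14..7, biased by 127 */
    unsigned mant; /* bits 6..0, without the implicit leading 1 */
} bf16_fields;

static bf16_fields bf16_split(uint16_t bits)
{
    bf16_fields f;
    f.sign = (bits >> 15) & 1u;
    f.exp = (bits >> 7) & 0xFFu;
    f.mant = bits & 0x7Fu;
    return f;
}

/* Widen a bfloat16 pattern to float by placing it in the upper 16 bits,
 * the same operation as bf16_to_f32 in the listing. */
static float bf16_bits_to_float(uint16_t bits)
{
    uint32_t w = (uint32_t) bits << 16;
    float out;
    memcpy(&out, &w, sizeof out);
    return out;
}
```

For example, `0x3F81` splits into sign 0, exponent 127, mantissa 1, which is $(1 + 1/128) \cdot 2^0 = 1.0078125$, and `0x7FC0` has exponent 255 with a non-zero mantissa, which marks it as NaN.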
### Assembly
- #### fp32_to_bf16
:::spoiler See More
```asm=
.data
test: .word 0x7FC00000 # NaN
.word 0x7F800000 # inf
.word 0x3F818000 # tie: rounds up to the even value 0x3F82
.word 0x3F808000 # tie: rounds down to the even value 0x3F80
ans: .word 0x7FC0
.word 0x7F80
.word 0x3F82
.word 0x3F80
n_tests: .word 4
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, test
la s1, ans
lw s2, n_tests
li s4, 0 # index
loop_t:
beq s4, s2, done
la a0, msg_input
li a7, 4
ecall
lw s3, 0(s0)
addi a0, s3, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, f32_to_bf16
addi t0, a1, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s1)
addi s0, s0, 4
addi s1, s1, 4
addi s4, s4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
f32_to_bf16:
addi sp, sp, -16
sw ra, 12(sp)
sw s3, 8(sp)
li t0, 0xFF
srli t1, s3, 23
and t1, t0, t1
beq t0, t1, isNAN
srli t0, s3, 16
andi t0, t0, 1
li t1, 0x7FFF
add t0, t0, t1
add a1, s3, t0
srli a1, a1, 16
lw ra, 12(sp)
addi sp, sp, 16
jr ra
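isNAN:
srli a1, s3, 16 # Inf/NaN: keep the upper 16 bits unchanged
lw ra, 12(sp)
addi sp, sp, 16 # release the stack frame on this path as well
ret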
```
:::
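
The conversion rounds to nearest with ties going to the even result: adding `0x7FFF` plus the current lowest kept bit pushes a value over the boundary only when it is above the halfway point, or exactly halfway with an odd kept bit. A minimal sketch working directly on bit patterns follows; `f32_bits_to_bf16` is a hypothetical name and takes the raw 32-bit pattern rather than a `float`, mirroring what the assembly receives.

```c
#include <assert.h>
#include <stdint.h>

/* Round a binary32 bit pattern to a bfloat16 bit pattern. */
static uint16_t f32_bits_to_bf16(uint32_t f32bits)
{
    /* Exponent all ones: Inf or NaN, so truncate without rounding. */
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (uint16_t) (f32bits >> 16);
    /* Bias is 0x7FFF, plus 1 when the kept LSB (bit 16) is odd. */
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (uint16_t) (f32bits >> 16);
}
```

Beyond the four vectors used in the assembly test, `0x3F808001` (just above a tie) rounds up to `0x3F81` and `0x3F807FFF` (just below a tie) rounds down to `0x3F80`.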
- #### bf16_to_fp32
:::spoiler See More
```asm=
.data
test: .word 0x0000
.word 0x7F80
.word 0xBF82
ans: .word 0x00000000
.word 0x7F800000
.word 0xBF820000
n_tests: .word 3
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s0, test
la s1, ans
lw s2, n_tests
li s4, 0 # index
loop_t:
beq s4, s2, done
la a0, msg_input
li a7, 4
ecall
lw s3, 0(s0)
addi a0, s3, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, bf16_to_f32
addi t0, a1, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s1)
addi s0, s0, 4
addi s1, s1, 4
addi s4, s4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
bf16_to_f32:
addi sp, sp, -16
sw ra, 12(sp)
sw s3, 8(sp)
slli a1, s3, 16
lw ra, 12(sp)
addi sp, sp, 16
jr ra
```
:::
- #### bf16_add
:::spoiler See More
```asm=
.data
test1: .word 0x0000 # 0
.word 0x7F80 # inf
.word 0x7FC0 # NaN
.word 0x3F81 # 1.0078125
.word 0xBF82 # -1.015625
test2: .word 0x0000
.word 0x7F80
.word 0x7FC0
.word 0x3F81
.word 0xBF82
ans: .word 0x0000
.word 0x7F80
.word 0x7FC0
.word 0x4001
.word 0xC002
n_tests: .word 5
msg_input_a: .string "A = "
msg_input_b: .string ", B = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s6, test1
la s7, test2
la s10, ans
lw s11, n_tests
li a4, 0 # index
loop_t:
beq a4, s11, done
la a0, msg_input_a
li a7, 4
ecall
lw s8, 0(s6)
addi a0, s8, 0
li a7, 34
ecall
la a0, msg_input_b
li a7, 4
ecall
lw s9, 0(s7)
addi a0, s9, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, bf16_add
addi t0, a2, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s10)
addi s6, s6, 4
addi s10, s10, 4
addi s7, s7, 4
addi a4, a4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
bf16_add:
addi sp, sp, -16
sw ra, 12(sp)
sw s8, 8(sp)
sw s9, 4(sp)
# extract sign/exponent/mantissa from s8 (a)
srli s0, s8, 15 # sign_a = (a.bits >> 15)
srli s2, s8, 7 # exp_a = (a.bits >> 7) & 0xFF
andi s2, s2, 0xFF
andi s4, s8, 0x7F # mant_a = a.bits & 0x7F
# extract sign/exponent/mantissa from s9 (b)
srli s1, s9, 15
srli s3, s9, 7
andi s3, s3, 0xFF
andi s5, s9, 0x7F
# NaN / INF check
li t6, 0xFF
bne s2, t6, check_b_naninf # exp_a != 0xFF -> a is finite
check_a_naninf:
bnez s4, return_nan # if mant_a != 0 => NaN
beq s3, t6, both_inf # a is Inf and b is Inf or NaN
j return_a # Inf + finite -> Inf (a)
check_b_naninf:
bne s3, t6, check_b_zero # neither operand is NaN/Inf
handle_b_naninf:
bnez s5, return_nan # if mant_b != 0 => NaN
j return_b # finite + Inf -> Inf (b)
both_inf:
bnez s5, return_nan # b is NaN
beq s0, s1, return_a # same sign -> Inf
j return_nan # opposite sign -> NaN
# Normal zero handling
check_b_zero:
or t3, s2, s4
beqz t3, return_b
or t3, s3, s5
beqz t3, return_a
beqz s2, r1
ori s4, s4, 0x80
r1:
beqz s3, r2
ori s5, s5, 0x80
r2:
# Exp_diff
sub t3, s2, s3
li t4, 8
bgt t3, t4, return_a
neg t5, t3
bgt t5, t4, return_b
# Align
bgtz t3, shift_b
bltz t3, shift_a
j aligned
shift_b:
srl s5, s5, t3
mv t1, s2
j compute
shift_a:
neg t3, t3
srl s4, s4, t3
mv t1, s3
j compute
aligned:
mv t1, s2
# Compute
compute:
beq s0, s1, same_sign
# Different signs
slt t4, s4, s5
beqz t4, a_ge_b
mv t0, s1
sub t2, s5, s4
j normalize
a_ge_b:
mv t0, s0
sub t2, s4, s5
j normalize
# Same sign
same_sign:
mv t0, s0
add t2, s4, s5
li t3, 0x100
and t3, t2, t3
beqz t3, pack
srli t2, t2, 1
addi t1, t1, 1
j pack
# Normalize
normalize:
beqz t2, return_zero
norm_loop:
andi t3, t2, 0x80
bnez t3, pack
slli t2, t2, 1
addi t1, t1, -1
blez t1, return_zero
j norm_loop
# Result
pack:
andi t2, t2, 0x7F
andi t1, t1, 0xFF
slli t1, t1, 7
slli t0, t0, 15
or a2, t1, t2
or a2, a2, t0
j return
return_nan:
li a2, 0x7FC0
j return
return_a:
mv a2, s8
j return
return_b:
mv a2, s9
j return
return_zero:
li a2, 0
j return
return:
lw ra, 12(sp)
lw s8, 8(sp)
lw s9, 4(sp)
addi sp, sp, 16
jr ra
```
:::
- #### bf16_mul
:::spoiler See More
```asm=
.data
test1: .word 0x0000 # 0
.word 0x7F80 # inf
.word 0x7FC0 # NaN
.word 0x3F81 # 1.0078125
.word 0xBF82 # -1.015625
test2: .word 0x0000
.word 0x7F80
.word 0x7FC0
.word 0x3F81
.word 0xBF82
ans: .word 0x0000
.word 0x7F80
.word 0x7FC0
.word 0x3F82
.word 0x3F84
n_tests: .word 5
msg_input_a: .string "A = "
msg_input_b: .string ", B = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s6, test1
la s7, test2
la s10, ans
lw s11, n_tests
li a4, 0 # index
loop_t:
beq a4, s11, done
la a0, msg_input_a
li a7, 4
ecall
lw s8, 0(s6)
addi a0, s8, 0
li a7, 34
ecall
la a0, msg_input_b
li a7, 4
ecall
lw s9, 0(s7)
addi a0, s9, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, bf16_mul
addi t0, a2, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s10)
addi s6, s6, 4
addi s10, s10, 4
addi s7, s7, 4
addi a4, a4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
bf16_mul:
addi sp, sp, -16
sw ra, 12(sp)
# extract sign/exponent/mantissa from a (s8)
srli s0, s8, 15 # sign_a
srli s2, s8, 7
andi s2, s2, 0xFF # exp_a
andi s4, s8, 0x7F # mant_a
# extract sign/exponent/mantissa from b (s9)
srli s1, s9, 15 # sign_b
srli s3, s9, 7
andi s3, s3, 0xFF # exp_b
andi s5, s9, 0x7F # mant_b
# result_sign = sign_a ^ sign_b
xor s0, s0, s1
# if exp_a == 0xFF
li t1, 0xFF
beq s2, t1, check_a_inf
check_b_inf:
beq s3, t1, check_b_nan_inf
check_zero:
beqz s2, check_a_zero
beqz s3, check_b_zero
norm_mant:
li t5, 0 # exp_adjust stays 0 when both operands are normal
# normalize a
beqz s2, norm_a_sub
ori s4, s4, 0x80
j norm_b
norm_a_sub:
li t2, 0
norm_a_shift:
andi t3, s4, 0x80
bnez t3, norm_a_done
slli s4, s4, 1
addi t2, t2, -1
j norm_a_shift
norm_a_done:
li s2, 1
mv t5, t2
norm_b:
beqz s3, norm_b_sub
ori s5, s5, 0x80
j mul_core
norm_b_sub:
li t2, 0
norm_b_shift:
andi t3, s5, 0x80
bnez t3, norm_b_done
slli s5, s5, 1
addi t2, t2, -1
j norm_b_shift
norm_b_done:
li s3, 1
add t5, t5, t2 # exp_adjust
mul_core:
# result_mant = mant_a * mant_b
mul t6, s4, s5
# result_exp = exp_a + exp_b - 127 + exp_adjust
add t1, s2, s3
add t1, t1, t5
addi t1, t1, -127
# normalize mantissa
li t0, 0x8000
and t2, t6, t0
beqz t2, mant_shift7
srli t6, t6, 8
andi t6, t6, 0x7F
addi t1, t1, 1
j check_exp
mant_shift7:
srli t6, t6, 7
andi t6, t6, 0x7F
check_exp:
li t3, 0xFF
bge t1, t3, ret_inf
blez t1, underflow
# normal result
slli t0, s0, 15
slli t1, t1, 7
or a2, t0, t1
or a2, a2, t6
j done_mul
check_a_inf:
andi t2, s4, 0x7F
bnez t2, ret_a
beq s3, t1, both_inf
j ret_inf
check_b_nan_inf:
andi t2, s5, 0x7F
bnez t2, ret_b
beq s2, x0, ret_nan
j ret_inf
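check_a_zero:
beqz s4, ret_zero # a == 0 -> signed zero
bnez s3, norm_mant # a is subnormal and exp_b != 0 -> multiply
check_b_zero:
beqz s5, ret_zero # b == 0 -> signed zero
j norm_mant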
underflow:
li t2, -6
blt t1, t2, ret_zero
li t3, 1
sub t3, t3, t1
srl t6, t6, t3
li t1, 0
slli t0, s0, 15 # result sign is kept in s0
slli t1, t1, 7
or a2, t0, t1
or a2, a2, t6
j done_mul
ret_inf:
li a2, 0x7F80
slli t0, s0, 15 # result sign is kept in s0
or a2, a2, t0
j done_mul
ret_a:
mv a2, s8
j done_mul
ret_b:
mv a2, s9
j done_mul
ret_zero:
slli a2, s0, 15 # signed zero, result sign is kept in s0
j done_mul
ret_nan:
li a2, 0x7FC0
j done_mul
both_inf:
li a2, 0x7F80
slli t0, s0, 15 # result sign is kept in s0
or a2, a2, t0
j done_mul
done_mul:
lw ra, 12(sp)
addi sp, sp, 16
jr ra
```
:::
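The mantissa step in `mul_core` can be illustrated in C. This is a sketch under the assumption that both inputs are already normalized 8-bit significands with the implicit 1 in bit 7, so the product has its leading 1 in bit 15 or bit 14. The struct and function names are hypothetical.

```c
#include <stdint.h>

/* Hypothetical result type: 7-bit mantissa plus exponent increment. */
struct mul_norm {
    uint8_t mant;
    int exp_inc;
};

static struct mul_norm normalize_product(uint8_t sig_a, uint8_t sig_b)
{
    uint32_t prod = (uint32_t)sig_a * sig_b;
    struct mul_norm r;

    if (prod & 0x8000) {
        /* Leading 1 in bit 15: product is in [2, 4), shift by 8 and bump the exponent. */
        r.mant = (prod >> 8) & 0x7F;
        r.exp_inc = 1;
    } else {
        /* Leading 1 in bit 14: product is in [1, 2), shift by 7. */
        r.mant = (prod >> 7) & 0x7F;
        r.exp_inc = 0;
    }
    return r;
}
```

For the fourth test case (`0x3F81` times itself) the significands are both `0x81`, which gives mantissa `0x02` with no exponent increment, consistent with the expected `0x3F82`.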
- #### bf16_div
:::spoiler See More
```asm=
.data
test1: .word 0x7FC0
.word 0x7F80
.word 0x0000
.word 0x0080
.word 0x3E80
.word 0x3F80
test2: .word 0x3F80
.word 0x7F80
.word 0x0000
.word 0x3F80
.word 0x3F81
.word 0x0040
ans: .word 0x7FC0
.word 0x7FC0
.word 0x7FC0
.word 0x0080
.word 0x3E7E
.word 0x7F80
n_tests: .word 6
msg_input_a: .string "A = "
msg_input_b: .string ", B = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s6, test1
la s7, test2
la s10, ans
lw s11, n_tests
li a4, 0 # index
loop_t:
beq a4, s11, done
la a0, msg_input_a
li a7, 4
ecall
lw s8, 0(s6)
addi a0, s8, 0
li a7, 34
ecall
la a0, msg_input_b
li a7, 4
ecall
lw s9, 0(s7)
addi a0, s9, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
jal ra, bf16_div
addi t0, a2, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s10)
addi s6, s6, 4
addi s10, s10, 4
addi s7, s7, 4
addi a4, a4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
bf16_div:
addi sp, sp, -16
sw ra, 12(sp)
# extract fields
srli s0, s8, 15 # sign_a
srli s2, s8, 7
andi s2, s2, 0xFF # exp_a
andi s4, s8, 0x7F # mant_a
srli s1, s9, 15 # sign_b
srli s3, s9, 7
andi s3, s3, 0xFF # exp_b
andi s5, s9, 0x7F # mant_b
# result_sign = sign_a ^ sign_b
xor t0, s0, s1
li t1, 0xFF
# Special cases: b is Inf or NaN
beq s3, t1, check_b_inf
# Special cases: b exponent == 0 (subnormal or zero)
beqz s3, check_b_zero
# Special cases: a is Inf or NaN
beq s2, t1, check_a_inf
# Special cases: a exponent == 0 (subnormal or zero)
beqz s2, check_a_zero_needed
# Normalize mantissas / handle subnormals
j norm_mant
# b is Inf or NaN
check_b_inf:
andi t2, s5, 0x7F # t2 = mant_b
bnez t2, ret_b # b is NaN -> return b
# b is +Inf/-Inf (mant_b == 0)
# if a is also Inf (exp_a==0xFF) and mant_a==0 -> Inf/Inf = NaN
beq s2, t1, a_maybe_inf
# else finite / Inf => result is signed zero (result_sign << 15)
slli a2, t0, 15
j done_div
a_maybe_inf:
andi t2, s4, 0x7F # t2 = mant_a
beqz t2, ret_nan # a is Inf too -> Inf/Inf = NaN
# else a is NaN -> return a
mv a2, s8
j done_div
# b exponent == 0 (subnormal or zero)
check_b_zero:
andi t2, s5, 0x7F # t2 = mant_b
bnez t2, norm_mant # subnormal -> normalize then divide
# b == 0 exactly -> division by zero
or t2, s2, s4 # t2 == 0 only if a is exactly zero (exp_a == 0 and mant_a == 0)
beqz t2, ret_nan # 0/0 -> NaN
# else -> signed Inf
slli a2, t0, 15
li t3, 0x7F80
or a2, a2, t3
j done_div
# a exponent == 0 (subnormal or zero) check
check_a_zero_needed:
andi t2, s4, 0x7F
beqz t2, ret_zero # a == 0 -> signed zero
j norm_mant
# a is Inf or NaN
check_a_inf:
andi t2, s4, 0x7F
bnez t2, ret_a # a is NaN -> return a
# a is Inf -> result signed Inf (b not Inf here)
slli a2, t0, 15
li t3, 0x7F80
or a2, a2, t3
j done_div
# normalize mantissas (handle subnormals)
norm_mant:
# normalize a (if subnormal)
beqz s2, norm_a_sub
ori s4, s4, 0x80 # implicit 1
li t5, 0 # exp_adjust = 0
j norm_b
norm_a_sub:
li t5, 0
norm_a_shift:
andi t2, s4, 0x80
bnez t2, norm_a_done
slli s4, s4, 1
addi t5, t5, -1
j norm_a_shift
norm_a_done:
li s2, 1
norm_b:
beqz s3, norm_b_sub
ori s5, s5, 0x80
j div_core
norm_b_sub:
li t2, 0
norm_b_shift:
andi t2, s5, 0x80
bnez t2, norm_b_done
slli s5, s5, 1
addi t2, t2, -1
j norm_b_shift
norm_b_done:
li s3, 1
add t5, t5, t2 # t5 = exp_adjust (a_adjust + b_adjust)
# div_core: long division 16 iterations
div_core:
# dividend = mant_a << 15
slli t1, s4, 15 # t1 = dividend
mv t2, s5 # t2 = divisor
li t3, 0 # t3 = quotient
li t4, 15 # j = 15 down to 0
div_loop:
slli t3, t3, 1
# compute shifted = divisor << j (use register shift: sll rd, rs1, rs2)
sll t6, t2, t4 # t6 = divisor << j
blt t1, t6, no_sub
sub t1, t1, t6
ori t3, t3, 1
no_sub:
addi t4, t4, -1
bgez t4, div_loop # continue while j >= 0
# compute result_exp = exp_a - exp_b + BF16_EXP_BIAS (127) + exp_adjust
sub t1, s2, s3
addi t1, t1, 127
add t1, t1, t5
# normalize quotient: if bit15 set -> right shift 8; else left shift until bit15 set (decrement exp)
li t4, 0x8000
and t2, t3, t4
bnez t2, q_shift8
q_norm_loop:
and t2, t3, t4
bnez t2, q_norm_done
slli t3, t3, 1
addi t1, t1, -1
bgt t1, x0, q_norm_loop
q_norm_done:
srli t3, t3, 8
j q_after_norm
q_shift8:
srli t3, t3, 8
q_after_norm:
andi t3, t3, 0x7F # mantissa (7 bits)
# overflow / underflow
li t2, 0xFF
bge t1, t2, ret_inf
blez t1, ret_zero
# pack result
slli t0, t0, 15
slli t1, t1, 7
or a2, t0, t1
or a2, a2, t3
j done_div
# returns / special cases
ret_a:
mv a2, s8
j done_div
ret_b:
mv a2, s9
j done_div
ret_inf:
slli a2, t0, 15
li t2, 0x7F80
or a2, a2, t2
j done_div
ret_zero:
slli a2, t0, 15
j done_div
ret_nan:
li a2, 0x7FC0
j done_div
done_div:
lw ra, 12(sp)
addi sp, sp, 16
jr ra
```
:::
- #### bf16_sqrt
:::spoiler See More
```asm=
.data
test1: .word 0x7FC1
.word 0x7F80
.word 0x0000
.word 0x0040
.word 0x3E80
.word 0x407F
ans: .word 0x7FC1
.word 0x7F80
.word 0x0000
.word 0x0000
.word 0x3F00
.word 0x3FFF
n_tests: .word 6
msg_input_a: .string "A = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"
.text
.globl main
main:
la s6, test1
la s10, ans
lw s11, n_tests
li a4, 0 # index
loop_t:
beq a4, s11, done
la a0, msg_input_a
li a7, 4
ecall
lw s8, 0(s6)
addi a0, s8, 0
li a7, 34
ecall
la a0, msg_result
li a7, 4
ecall
mv a0, s8
jal ra, bf16_sqrt
addi t0, a2, 0
addi a0, t0, 0
li a7, 34
ecall
lw t1, 0(s10)
addi s6, s6, 4
addi s10, s10, 4
addi a4, a4, 1
bne t0, t1, fail
la a0, msg_pass
li a7, 4
ecall
j loop_t
fail:
la a0, msg_fail
li a7, 4
ecall
j loop_t
done:
li a7, 10 # exit
ecall
bf16_sqrt:
addi sp, sp, -32
sw ra, 28(sp)
sw a0, 24(sp)
mv t0, a0
srli t1, t0, 15 # sign
srli t2, t0, 7
andi t2, t2, 0xFF # exp
andi t3, t0, 0x7F # mant
# if exp == 0xFF
li t4, 0xFF
beq t2, t4, check_inf_nan
# if exp==0 && mant==0 -> return 0
beqz t2, check_zero
j check_neg
check_zero:
beqz t3, ret_zero
j check_neg
check_neg:
bnez t1, ret_nan # negative -> NaN
# flush denormals
beqz t2, ret_zero
# e = exp - 127
addi t5, t2, -127
# get mantissa with implicit 1
ori t6, t3, 0x80 # m = 0x80 | mant
# adjust for odd exponent
andi a5, t5, 1 # RV32 has no t7 register, use a5 as scratch
beqz a5, even_exp
slli t6, t6, 1
addi t5, t5, -1
even_exp:
srai t5, t5, 1
addi t5, t5, 127 # new_exp = (e>>1)+127
# binary search for sqrt(m)
li s0, 90
li s1, 256
li s2, 128 # result
bs_loop:
bgt s0, s1, bs_done
add s3, s0, s1
srli s3, s3, 1 # mid = (low+high)>>1
mul s4, s3, s3 # mid*mid
srli s4, s4, 7 # /128
bleu s4, t6, bs_leq
addi s1, s3, -1
j bs_loop
bs_leq:
mv s2, s3
addi s0, s3, 1
j bs_loop
bs_done:
mv t6, s2
# normalize result
li a5, 256 # a5 is scratch (there is no t7 register)
bge t6, a5, norm_shift
li a5, 128
bge t6, a5, mant_ok
norm_shift:
srli t6, t6, 1
addi t5, t5, 1
mant_ok:
andi t6, t6, 0x7F # mantissa only
# check overflow/underflow
li a5, 0xFF # a5 is scratch (there is no t7 register)
bge t5, a5, ret_inf
blez t5, ret_zero
slli t5, t5, 7
or a2, t5, t6
j done_sqrt
check_inf_nan:
bnez t3, ret_a # NaN propagation
bnez t1, ret_nan # sqrt(-Inf)=NaN
mv a2, t0 # sqrt(+Inf)=+Inf
j done_sqrt
ret_a:
mv a2, t0
j done_sqrt
ret_zero:
li a2, 0
j done_sqrt
ret_nan:
li a2, 0x7FC0
j done_sqrt
ret_inf:
li a2, 0x7F80
j done_sqrt
done_sqrt:
lw ra, 28(sp)
addi sp, sp, 32
jr ra
```
:::
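The core of `bf16_sqrt` is an integer binary search on the significand. A minimal C sketch of that path is shown here, assuming a positive normal input only (special cases such as NaN, Inf, zero, negatives, and subnormals are left out). The helper name `bf16_sqrt_normal` is hypothetical.

```c
#include <stdint.h>

/* Hypothetical helper: square root of a positive, normal bf16 bit pattern. */
static uint16_t bf16_sqrt_normal(uint16_t a)
{
    int32_t e = (int32_t)((a >> 7) & 0xFF) - 127; /* unbiased exponent */
    uint32_t m = 0x80 | (a & 0x7F);               /* significand with implicit 1 */

    /* Make the exponent even so it can be halved exactly. */
    if (e & 1) {
        m <<= 1;
        e -= 1;
    }
    int32_t new_exp = e / 2 + 127;

    /* Largest r in [90, 256] with (r * r) / 128 <= m. */
    uint32_t low = 90, high = 256, result = 128;
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        if (((mid * mid) >> 7) <= m) {
            result = mid;
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* For m in [128, 510] the result stays in [128, 255], so no renormalization is needed. */
    return (uint16_t)(((uint32_t)new_exp << 7) | (result & 0x7F));
}
```

With the test inputs `0x3E80` and `0x407F` this sketch gives `0x3F00` and `0x3FFF`, the same values listed in `ans`.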
## Analysis
The code is tested with the [Ripes](https://ripes.me/) simulator.
### Pseudo instruction
```
00000000 <main>:
0: 10000417 auipc x8 0x10000
4: 00040413 addi x8 x8 0
8: 10000497 auipc x9 0x10000
c: 00c48493 addi x9 x9 12
10: 10000917 auipc x18 0x10000
14: 01892903 lw x18 24 x18
18: 00000993 addi x19 x0 0
0000001c <loop_tests>:
1c: 0b298863 beq x19 x18 176 <done>
20: 10000517 auipc x10 0x10000
24: 00c50513 addi x10 x10 12
28: 00400893 addi x17 x0 4
2c: 00000073 ecall
30: 00198513 addi x10 x19 1
34: 00100893 addi x17 x0 1
38: 00000073 ecall
3c: 10000517 auipc x10 0x10000
40: ff650513 addi x10 x10 -10
44: 00400893 addi x17 x0 4
48: 00000073 ecall
4c: 00299293 slli x5 x19 2
50: 00540333 add x6 x8 x5
54: 00032503 lw x10 0 x6
58: 00050e93 addi x29 x10 0
5c: 00100893 addi x17 x0 1
60: 00000073 ecall
64: 10000517 auipc x10 0x10000
68: fd950513 addi x10 x10 -39
6c: 00400893 addi x17 x0 4
70: 00000073 ecall
74: 000e8513 addi x10 x29 0
78: 05c000ef jal x1 92 <bitwiseComplement>
7c: 00058f13 addi x30 x11 0
80: 000f0513 addi x10 x30 0
84: 00100893 addi x17 x0 1
88: 00000073 ecall
8c: 00299293 slli x5 x19 2
90: 005483b3 add x7 x9 x5
94: 0003ae03 lw x28 0 x7
98: 01cf1e63 bne x30 x28 28 <fail>
0000009c <pass>:
9c: 10000517 auipc x10 0x10000
a0: fad50513 addi x10 x10 -83
a4: 00400893 addi x17 x0 4
a8: 00000073 ecall
ac: 00198993 addi x19 x19 1
b0: f6dff06f jal x0 -148 <loop_tests>
000000b4 <fail>:
b4: 10000517 auipc x10 0x10000
b8: f9e50513 addi x10 x10 -98
bc: 00400893 addi x17 x0 4
c0: 00000073 ecall
c4: 00198993 addi x19 x19 1
c8: f55ff06f jal x0 -172 <loop_tests>
000000cc <done>:
cc: 00a00893 addi x17 x0 10
d0: 00000073 ecall
000000d4 <bitwiseComplement>:
d4: ffc10113 addi x2 x2 -4
d8: 00112023 sw x1 0 x2
dc: 02050463 beq x10 x0 40 <zero>
e0: 038000ef jal x1 56 <clz>
e4: 00058293 addi x5 x11 0
e8: 02000313 addi x6 x0 32
ec: 405302b3 sub x5 x6 x5
f0: 00100313 addi x6 x0 1
f4: 005312b3 sll x5 x6 x5
f8: 406282b3 sub x5 x5 x6
fc: 005545b3 xor x11 x10 x5
100: 00c0006f jal x0 12 <return>
00000104 <zero>:
104: 00150593 addi x11 x10 1
108: 0040006f jal x0 4 <return>
0000010c <return>:
10c: 00012083 lw x1 0 x2
110: 00410113 addi x2 x2 4
114: 00008067 jalr x0 x1 0
00000118 <clz>:
118: ff010113 addi x2 x2 -16
11c: 00112623 sw x1 12 x2
120: 00a12423 sw x10 8 x2
124: 02000293 addi x5 x0 32
128: 01000313 addi x6 x0 16
0000012c <loop>:
12c: 006553b3 srl x7 x10 x6
130: 00038663 beq x7 x0 12 <skip>
134: 406282b3 sub x5 x5 x6
138: 00038513 addi x10 x7 0
0000013c <skip>:
13c: 00135313 srli x6 x6 1
140: fe0316e3 bne x6 x0 -20 <loop>
144: 40a285b3 sub x11 x5 x10
148: 00c12083 lw x1 12 x2
14c: 00812503 lw x10 8 x2
150: 01010113 addi x2 x2 16
154: 00008067 jalr x0 x1 0
```
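The listing shows how pseudo-instructions are expanded: each `la` becomes an `auipc`/`addi` pair, and `lw` from a label becomes an `auipc`/`lw` pair. The address arithmetic of such a pair can be sketched in C. The helper name `pcrel_target` is hypothetical and the operands are taken as they appear in the listing.

```c
#include <stdint.h>

/*
 * Hypothetical helper: address produced by "auipc rd, hi20" at address pc
 * followed by an instruction that adds the signed 12-bit offset lo12.
 */
static uint32_t pcrel_target(uint32_t pc, uint32_t hi20, int32_t lo12)
{
    return pc + (hi20 << 12) + (uint32_t)lo12;
}
```

For the pair at addresses `0x8` and `0xc`, `pcrel_target(0x8, 0x10000, 12)` gives `0x10000014`, the value that ends up in `x9`.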
### 5-stage pipelined processor

The processor used is a 5-stage in-order processor with hazard detection/elimination and forwarding.
| Execution info | with CLZ |
| --------------- | -------- |
| Cycles | 606 |
| Instrs. retired | 400 |
| CPI | 1.51 |
| IPC | 0.66 |
| Clock Rate      | 10.31 Hz |
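
CPI and IPC in the table follow directly from the cycle and instruction counts. The small C sketch below (function names are hypothetical) reproduces the arithmetic: 606 / 400 = 1.515, shown as 1.51 in the table, and 400 / 606 is about 0.66.

```c
/* Hypothetical helpers: derive CPI and IPC from the Ripes counters. */
static double cpi(unsigned long cycles, unsigned long instrs)
{
    return (double)cycles / (double)instrs;
}

static double ipc(unsigned long cycles, unsigned long instrs)
{
    return (double)instrs / (double)cycles;
}
```

A CPI above 1 on this pipeline reflects the stall and flush cycles caused by hazards and taken branches.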
#### IF

- The PC in this stage is `0x0000000C`.
- Using this PC, the instruction `0x00C48493` (`addi x9, x9, 12`) is fetched from the instruction memory.
- No branch is taken, so the mux in front of the PC selects the adder output and the next PC is PC + 4 (`0x00000010`).
#### ID

- The instruction `addi x9, x9, 12` is decoded into opcode `addi`, R1 idx `0x09`, Wr idx `0x09`, and imm `0x0C`.
- `addi` is an I-type instruction, so R2 is not needed.
- Reg1 reads the value `0x00` from the register file.
- The immediate `0x0C` is sign-extended to 32 bits (`0x0000000C`) by the Imm. unit.
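
The field extraction done in this stage can be sketched in C for an I-type instruction. The struct and function names are hypothetical; the bit positions follow the RV32I I-type format.

```c
#include <stdint.h>

/* Hypothetical container for the decoded I-type fields. */
struct itype {
    uint32_t opcode;
    uint32_t rd;
    uint32_t funct3;
    uint32_t rs1;
    int32_t imm;
};

static struct itype decode_itype(uint32_t insn)
{
    struct itype d;
    d.opcode = insn & 0x7F;         /* bits  6:0  */
    d.rd     = (insn >> 7) & 0x1F;  /* bits 11:7  */
    d.funct3 = (insn >> 12) & 0x7;  /* bits 14:12 */
    d.rs1    = (insn >> 15) & 0x1F; /* bits 19:15 */
    d.imm    = (int32_t)(insn >> 20); /* bits 31:20 */
    if (d.imm & 0x800)              /* sign-extend the 12-bit immediate */
        d.imm -= 0x1000;
    return d;
}
```

Decoding `0x00C48493` this way yields opcode `0x13`, `rd` = 9, `rs1` = 9, and an immediate of 12.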
#### EX

- The first set of multiplexers resolves data hazards: if the register value read in ID is stale, the newest value is forwarded from the MEM or WB stage. Here Reg1 holds `0x00`, but the newest value of `x9` is in the MEM stage, so `0x10000008` is forwarded from MEM.
- The second set of multiplexers selects the ALU operands: the upper one selects RS1 (`0x10000008`) and the lower one selects the immediate (`0x0000000C`).
- The ALU adds the two operands, giving `0x10000014`.
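
The forwarding decision described above can be sketched in C. This is the textbook forwarding condition and is only an assumption about how the selection behaves, not code taken from Ripes. All names are hypothetical.

```c
#include <stdbool.h>
#include <stdint.h>

/* Hypothetical selector values for the forwarding mux. */
enum fwd_src { FWD_REGFILE, FWD_MEM, FWD_WB };

/*
 * Choose where the EX stage takes source register rs from.
 * mem_rd/mem_wr describe the instruction currently in MEM,
 * wb_rd/wb_wr the instruction currently in WB.
 */
static enum fwd_src select_forward(uint32_t rs,
                                   uint32_t mem_rd, bool mem_wr,
                                   uint32_t wb_rd, bool wb_wr)
{
    if (rs == 0)
        return FWD_REGFILE; /* x0 is always zero, never forwarded */
    if (mem_wr && mem_rd == rs)
        return FWD_MEM;     /* newest value is one stage ahead */
    if (wb_wr && wb_rd == rs)
        return FWD_WB;      /* otherwise take the value being written back */
    return FWD_REGFILE;     /* value read in ID is already up to date */
}
```

In the situation described above, `addi x9, x9, 12` is in EX while `auipc x9` is in MEM, so `select_forward(9, 9, true, 8, true)` picks the MEM stage.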
#### MEM

- `addi` does not access the data memory.
- The ALU result simply passes through this stage and reaches WB in the next cycle.
#### WB

- The mux selects the ALU result.
- The value `0x10000014` is then written back to register `x9` (Wr idx `0x09`).
Before WB :

After WB :

## Reference
- [LeetCode: 1009. Complement of Base 10 Integer](https://leetcode.com/problems/complement-of-base-10-integer/)
- [Lab1: RV32I Simulator](https://hackmd.io/@sysprog/H1TpVYMdB)
- [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1)