# Computer Architecture — Fall 2025 Homework 1
## Problem B
### 1. Introduction
#### overview
The **uf8** format is a custom 8-bit *logarithmic compression scheme* designed to represent large unsigned integers with compact precision.
Instead of storing values linearly, it uses a **logarithmic scale**: every value still occupies exactly 8 bits, but larger numbers are represented with proportionally coarser precision.
`uf8` behaves like a miniature floating-point format.
Its 8 bits are divided as follows:
| Bits | Field | Description |
|:----:|:------|:-------------|
| 7 – 4 | Exponent ($e$) | Controls the logarithmic scale |
| 3 – 0 | Mantissa ($m$) | Provides local precision within the scale |
The decoded value is given by
$$
D(b) = m \cdot 2^{e} + (2^{e} - 1) \cdot 16
$$
Where
$$
e = \left\lfloor \frac{b}{16} \right\rfloor, \quad m = b \bmod 16
$$
The encoded value is given by
$$
E(v) =
\begin{cases}
v, & \text{if } v < 16 \\[6pt]
16e + \left\lfloor \dfrac{v - \text{offset}(e)}{2^{e}} \right\rfloor, & \text{otherwise}
\end{cases}
$$
Where
$$
\text{offset}(e) = (2^{e} - 1) \cdot 16
$$
---
### 2. Implementation
#### Original C code include test
:::spoiler code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
typedef uint8_t uf8;
static inline unsigned clz(uint32_t x)
{
int n = 32, c = 16;
do {
uint32_t y = x >> c;
if (y) {
n -= c;
x = y;
}
c >>= 1;
} while (c);
return n - x;
}
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
uint32_t mantissa = fl & 0x0f;
uint8_t exponent = fl >> 4;
uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
return (mantissa << exponent) + offset;
}
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz;
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0;
if (msb >= 5) {
/* Estimate exponent - the formula is empirical */
exponent = msb - 4;
if (exponent > 15)
exponent = 15;
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* Adjust if estimate was off */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Find exact exponent */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
uint8_t mantissa = (value - overflow) >> exponent;
return (exponent << 4) | mantissa;
}
/* Test encode/decode round-trip */
static bool test(void)
{
int32_t previous_value = -1;
bool passed = true;
for (int i = 0; i < 256; i++) {
uint8_t fl = i;
int32_t value = uf8_decode(fl);
uint8_t fl2 = uf8_encode(value);
if (fl != fl2) {
printf("%02x: produces value %d but encodes back to %02x\n", fl,
value, fl2);
passed = false;
}
if (value <= previous_value) {
printf("%02x: value %d <= previous_value %d\n", fl, value,
previous_value);
passed = false;
}
previous_value = value;
}
return passed;
}
int main(void)
{
if (test()) {
printf("All tests passed.\n");
return 0;
}
return 1;
}
```
:::
#### RiscV assembly code include test
:::spoiler code
```asm=
.data
PASS_MSG: .asciz "All tests PASSED!\n"
FAIL_MSG: .asciz "Some tests FAILED!\n"
.text
.globl main
main:
jal ra, test
mv t0, a0
beqz t0, print_fail
print_pass:
la a0, PASS_MSG
li a7, 4 # ecall 4 = print string
ecall
j done
print_fail:
la a0, FAIL_MSG
li a7, 4
ecall
done:
li a7, 10 # ecall 10 = exit
ecall
clz:
li t0,32 #t0 = n
li t1,16 #t1 = c
clz_Loop:
srl t2,a0,t1 #t2 = y
sgtz t3,t2
beqz t3,shift
sub t0,t0,t1
mv a0,t2
shift:
srli t1,t1,1
sgtz t3,t1
bnez t3,clz_Loop
sub a0,t0,a0
ret
uf8_decode:
andi t1,a0,0x0f #t1 = mantissa
srli t2,a0,4
li t3,15
sub t0,t3,t2
li t3,0x7fff
srl t3,t3,t0
slli t0,t3,4 # t0 = offset
sll t1,t1,t2
add a0,t1,t0
ret
uf8_encode:
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
mv s7,a0 #s7 = value
addi t0,s7,-16
sltz t0,t0
beqz t0,find_exp
j restore_regs
find_exp:
addi sp,sp,-4
sw ra,0(sp)
jal ra,clz
lw ra, 0(sp)
addi sp,sp,4
mv s2,a0 #s2=clz
li s3,31
sub s3,s3,s2 #s3 = msb
li s4,0 #exp=0
li s5,0 #overflow=0
li t0,5
blt s3,t0,find_extract_exp
addi s4,s3,-4
addi t0,s4,-15
sgtz t0,t0
beqz t0,end_check
li s4,15
end_check:
li t0,0
check_overflow:
sub t1,t0,s4
sltz t1,t1
beqz t1,check_overflow_done
slli s5,s5,1
addi s5,s5,16
addi t0,t0,1
j check_overflow
check_overflow_done:
adjust:
sgtz t0,s4
sub t1,s7,s5
sltz t1,t1
and t0,t1,t0
beqz t0,adjust_done
addi s5,s5,-16
srli s5,s5,1
addi s4,s4,-1
j adjust
adjust_done:
find_extract_exp:
addi t0,s4,-15
sltz t0,t0
beqz t0,extract_done
slli t0,s5,1
addi s6,t0,16
sub t0,s7,s6
sltz t0,t0
bnez t0,extract_done
mv s5,s6
addi s4,s4,1
j find_extract_exp
extract_done:
sub t0,s7,s5
srl t0,t0,s4
slli t1,s4,4
or a0,t1,t0
j restore_regs
restore_regs:
lw s2, 0(sp)
lw s3, 4(sp)
lw s4, 8(sp)
lw s5, 12(sp)
lw s6, 16(sp)
lw s7, 20(sp)
addi sp, sp, 24
ret
test:
addi sp, sp, -32
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
sw s0, 24(sp)
sw ra, 28(sp)
# s2 = i
# s3 = previous_value
# s4 = pass flag
# s5 = decoded value
# s6 = re-encoded value
# s7 = pointer to FAIL_RECORD
li s2, 0 # i = 0
li s3, -1 # previous_value = -1
li s4, 1 # passed = true
test_loop:
li t0, 256
beq s2, t0, test_done
mv a0, s2 # a0 = fl = i
jal ra, uf8_decode
mv s5, a0 # s5 = value
mv a0, s5
jal ra, uf8_encode
mv s6, a0 # s6 = fl2
# if (fl != fl2)
bne s2, s6, record_fail
# if (value <= previous_value)
sub t1, s5, s3
blez t1, record_fail
# update
mv s3, s5
addi s2, s2, 1
j test_loop
record_fail:
sw s2, 0(s7) # fl
sw s5, 4(s7) # value
sw s6, 8(s7) # fl2
li s4, 0 # passed = false
addi s7, s7, 12 # next record slot
addi s2, s2, 1
j test_loop
test_done:
mv a0, s4 # return pass flag (1 = pass, 0 = fail)
lw s2, 0(sp)
lw s3, 4(sp)
lw s4, 8(sp)
lw s5, 12(sp)
lw s6, 16(sp)
lw s7, 20(sp)
lw s0, 24(sp)
lw ra, 28(sp)
addi sp, sp, 32
ret
```
:::
#### testing result
all cases pass successfully in ripes!

---
### 3. Encode Optimization
In the original C implementation of **`uf8_encode()`**, the encoding process relies on **two loops** to determine the correct exponent range.
The first loop estimates the exponent by gradually doubling a threshold value, and the second loop fine-tunes the result by checking for overflow and underflow conditions.
While this approach is functionally correct, it is computationally inefficient in RISC-V assembly due to multiple iterations and branches.
However, by analyzing the mathematical relationship of the exponent range, we can derive a **closed-form expression** that eliminates all iterative loops.
The goal of the exponent search is to find the integer \( e \) such that:
$$
(2^e - 1) \times 16 \le v < (2^{e+1} - 1) \times 16
$$
To simplify, we first divide both sides by 16:
$$
2^e - 1 \le \frac{v}{16} < 2^{e+1} - 1
$$
Let $w = \left\lfloor \frac{v}{16} \right\rfloor$.
Then we can rewrite the inequality as:
$$
2^e - 1 \le w < 2^{e+1} - 1
$$
Next, we add 1 to each term:
$$
2^e \le w + 1 < 2^{e+1}
$$
Now, by taking the base-2 logarithm on both sides, we can directly obtain:
$$
e = \lfloor \log_2(w + 1) \rfloor
$$
Since \( w = v >> 4 \), the final **closed-form equation** becomes:
$$
\boxed{e = \lfloor \log_2((v >> 4) + 1) \rfloor}
$$
#### Implementation Insight
In integer arithmetic, the logarithmic part can be efficiently implemented using the **CLZ (Count Leading Zeros)** instruction, since:
$$
\lfloor \log_2(x) \rfloor = 31 - \text{CLZ}(x)
$$
Therefore, the exponent can be calculated as:
$$
e = 31 - \text{CLZ}((v >> 4) + 1)
$$
This closed-form transformation **completely removes the original overflow and underflow loops**, turning the multi-step iterative process into a single constant-time computation.
After obtaining \( e \), the offset and mantissa can be computed directly as:
$$
\text{offset} = ((1 << e) - 1) << 4,
\qquad
m = \frac{v - \text{offset}}{2^e}
$$
and the final encoded value is:
$$
\text{uf8} = (e << 4) \,|\, (m \& 0x0F)
$$
Assembly code of the refined encode function
:::spoiler code
```asm=
uf8_encode:
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
mv s2,a0 #s2 = value
addi t0,s2,-16
sltz t0,t0
bnez t0,restore_regs #if v<16 return
# compute e = floor ( log2 (v/16)+1 )
srli t2,s2,4
addi t2,t2,1
mv a0,t2
addi sp,sp,-4
sw ra,0(sp)
jal ra,clz
lw ra,0(sp)
addi sp,sp,4
mv s3,a0
li t0,31
sub s3,t0,s3 #s3 = e = 31-clz
addi t0,s3,-15
sgtz t0,t0
beqz t0,compute_offset
li s3,15
# compute offser : offset = ((1<<e)-1) << 4
compute_offset:
li t0,1
sll t0,t0,s3
addi t0,t0,-1
slli s4,t0,4 #s4 = offset
# compute mantissa = v-offset >> e
sub t0,s2,s4
srl s5,t0,s3 #s5 = m
# pack result
slli t0,s3,4
andi t1,s5,0x0f
or a0,t0,t1
j restore_regs
```
:::
testing result:

#### Code Size and Cycle Count Comparison
To evaluate the optimization impact, we compare the **original loop-based implementation** with the **closed-form CLZ-based implementation** in terms of both *code size* and *average execution cycles*.
The closed-form version eliminates all loops and conditional branches in the exponent search, leading to a significantly smaller code footprint and a constant-time execution pattern.
| Version | Exponent Computation Method |Code lines in assembler| Cycles |
|:--------:|:----------------------------|:----------------:|:----------------------------:|
| **Original** | Iterative search + fine-tune (overflow/underflow loops) | 252 | 56531 |
| **Optimized (Closed Form)** | Single-step CLZ-based computation | 148 | 34067 |
---
<div align="center">
**Original (Loop-based):**

**Optimized (Closed-form):**
(
</div>
- The **original implementation** spends most of its time inside two nested loops for exponent adjustment (`overflow` and `underflow` searches).
- In the **closed-form version**, the exponent is determined by a single logarithmic estimation:
$$
e = 31 - \text{CLZ}((v >> 4) + 1)
$$
eliminating all iterative loops and reducing branch overhead.
- The improvement cuts the cycle count by **~40%**
---
### 4. CLZ Function Optimization
In the original assembly implementation, the **CLZ (Count Leading Zeros)** function was written using a loop structure.
However, on a pipelined CPU, each branch instruction may introduce **branch hazards** and **pipeline stalls**, which waste cycles and reduce performance.
To address this issue, we can **unroll the loop** and transform the function into a **branchless CLZ implementation**.
This approach replaces conditional branching with arithmetic and logical operations, allowing the CPU pipeline to execute the entire routine **sequentially without control flow interruptions**.
As a result, the branchless CLZ achieves the same functionality as the original version but with **predictable execution timing** and **significantly better performance**.
:::spoiler code
```asm=
clz:
li t0,32 #t0 = n
# chceck high 16 bits
# --- k = 16 ---
srli t1, a0, 16 # y = x >> 16
sltu t2, x0, t1 # cond = (y != 0) ? 1 : 0
sub t3, x0, t2 # mask = -cond (0x00000000 or 0xFFFFFFFF)
slli t4, t2, 4 # cond*16
sub t0, t0, t4 # n -= 16 (if cond)
xori t5, t3, -1 # ~mask
and a0, a0, t5 # a0 = (a0 & ~mask) | (y & mask)
and t1, t1, t3
or a0, a0, t1
# --- k = 8 ---
# similar to k = 16
srli t1, a0, 8
sltu t2, x0, t1
sub t3, x0, t2
slli t4, t2, 3 # cond*8
sub t0, t0, t4
xori t5, t3, -1
and a0, a0, t5
and t1, t1, t3
or a0, a0, t1
# --- k = 4 ---
srli t1, a0, 4
sltu t2, x0, t1
sub t3, x0, t2
slli t4, t2, 2 # cond*4
sub t0, t0, t4
xori t5, t3, -1
and a0, a0, t5
and t1, t1, t3
or a0, a0, t1
# --- k = 2 ---
srli t1, a0, 2
sltu t2, x0, t1
sub t3, x0, t2
slli t4, t2, 1 # cond*2
sub t0, t0, t4
xori t5, t3, -1
and a0, a0, t5
and t1, t1, t3
or a0, a0, t1
# --- k = 1 ---
srli t1, a0, 1
sltu t2, x0, t1
sub t3, x0, t2 # mask
# cond*1
sub t0, t0, t2 # n -= 1 if cond
xori t5, t3, -1
and a0, a0, t5
and t1, t1, t3
or a0, a0, t1
# return n - x
sub a0, t0, a0
ret
```
:::
The branchless implementation removes all conditional branches (`beq`, `bnez`)
and replaces them with pure arithmetic and bitwise masking operations.
This eliminates **branch misprediction penalties** and makes the pipeline execution **fully deterministic**.
Although the total number of ALU operations increases slightly,
the CPU can now execute the instructions linearly without waiting for branch resolution.
As a result, the overall performance improves from 34067 cycles down to 32867 cycles
while maintaining identical functional correctness for all test inputs.
We can also observe a **lower CPI (Cycles Per Instruction)**,
which indicates that the **pipeline is operating more efficiently**,
with fewer stalls and better instruction throughput across the execution stages.

---
### 5. LeetCode 2571 — Minimum Number of Operations to Reduce an Integer to 0
#### Problem Description
You are given a positive integer $n$.
In one operation, you may replace $n$ with either $n + 2^k$ or $n - 2^k$ for any integer $k \ge 0$.
Return the **minimum number of operations** required to make $n = 0$.
Example:
Input: n = 39
Output: 3
Explanation:
39 → 7 (subtract 32)
7 → 8 (add 1)
8 → 0 (subtract 8)
#### Core Idea
Every integer $n$ lies between two consecutive powers of two:
$$
2^e \le n < 2^{e+1}
$$
If we move $n$ to the nearest power of two ($2^e$ or $2^{e+1}$),
we effectively remove the **most significant bit** in one step.
This is optimal because:
- Smaller power jumps ($\pm 1, \pm 2, \pm 4, \dots$) cannot affect the highest bit.
- Removing the largest active bit first minimizes the total number of steps.
Hence the **greedy rule**:
1. Find the highest set bit $e = \lfloor \log_2 n \rfloor$
2. Compute $p = 2^e$ and $q = 2^{e+1}$
3. Move $n$ toward whichever of $p$ or $q$ is closer:
$$
d_1 = n - p, \quad d_2 = q - n
$$
Choose the smaller of $d_1, d_2$
4. Repeat until $n = 0$
#### Using CLZ (Count Leading Zeros)
To find $\lfloor \log_2 n \rfloor$ efficiently, we can use **CLZ**:
$$
e = 31 - \text{CLZ}(n)
$$
This operation counts the number of leading zeros in the 32-bit binary representation of $n$,
giving the position of the most significant 1-bit in $O(1)$ time.
#### Algorithm Steps
1. Initialize `steps = 0`
2. While $n > 0$:
- Find $e = 31 - \text{CLZ}(n)$
- Let $p = 2^e$, $q = 2^{e+1}$
- Compare:
$$
\text{if } 2n \le 3p \Rightarrow n = n - p
\quad \text{else} \quad n = q - n
$$
- Increment `steps`
3. Return `steps`
#### C Implementation
:::spoiler code
```c
#include <stdint.h>
static inline int msb_index(uint32_t x) {
return 31 - __builtin_clz(x);
}
int minOperations(int n) {
int steps = 0;
while (n) {
int e = msb_index((uint32_t)n);
uint32_t p = 1u << e;
uint32_t q = p << 1;
uint64_t two_n = ((uint64_t)n) << 1;
uint64_t three_p = (uint64_t)p + ((uint64_t)p << 1);
if (two_n <= three_p)
n -= (int)p;
else
n = (int)(q - (uint32_t)n);
++steps;
}
return steps;
}
```
:::
assembly code including test data
:::spoiler code
```asm=
.data
PASS_MSG: .asciz "PASS\n"
FAIL_MSG: .asciz "FAIL\n"
# 10 test inputs
test_cases:
.word 1, 2, 3, 5, 7, 15, 39, 100, 1234, 1023
# Expected minimal steps for each input
expected_steps:
.word 1, 1, 2, 2, 2, 2, 3, 3, 5, 2
.text
.globl main
# =====================================================
# main: runs 10 test cases; prints PASS if all match.
# =====================================================
main:
la s0, test_cases # s0 = ptr to inputs
la s1, expected_steps # s1 = ptr to expected
li s2, 10 # s2 = number of test cases
li s4, 1 # s4 = pass flag (1 = true)
case_loop:
beqz s2, all_done # if all cases done → finish
# Load n and expected steps
lw s5, 0(s0) # s5 = n
lw s6, 0(s1) # s6 = expected steps
addi s0, s0, 4
addi s1, s1, 4
# ---- Run algorithm for this n ----
mv s2, s2 # keep counter intact
mv s7, s5 # s7 = n (working copy)
li s3, 0 # s3 = steps = 0
Loop:
beqz s7, Loop_end # if (n == 0) break
# --- Compute e = 31 - clz(n) ---
mv a0, s7
jal ra, clz # a0 = number of leading zeros
li t0, 31
sub t0, t0, a0 # t0 = e = 31 - clz(n)
# --- Compute p = 2^e ---
li t1, 1
sll t1, t1, t0 # t1 = p = 1 << e
# --- Compute two_n = 2 * n ---
slli t3, s7, 1 # t3 = two_n
# --- Compute three_p = 3 * p = p + 2p ---
slli t4, t1, 1 # t4 = 2p
add t4, t4, t1 # t4 = 3p
# --- Unsigned compare: if (2n <= 3p) go to p else go to 2p ---
sltu t5, t4, t3 # t5 = (3p < 2n)
beqz t5, to_p # if (2n <= 3p) → to_p
# Case: move toward 2p → n = 2p - n
slli t6, t1, 1 # t6 = 2p
sub s7, t6, s7 # n = 2p - n
j step_done
to_p:
# Case: move toward p → n = n - p
sub s7, s7, t1 # n = n - p
step_done:
addi s3, s3, 1 # steps++
j Loop
Loop_end:
# Compare computed steps (s3) with expected (s6)
beq s3, s6, case_ok
li s4, 0 # mark FAIL
case_ok:
addi s2, s2, -1 # next case
j case_loop
all_done:
# Print PASS/FAIL
beqz s4, print_fail
print_pass:
la a0, PASS_MSG
li a7, 4 # print string
ecall
j exit
print_fail:
la a0, FAIL_MSG
li a7, 4 # print string
ecall
exit:
li a7, 10 # exit
ecall
# =====================================================
# clz: count leading zeros (returns in a0)
# Input: a0 = x (x > 0)
# Output: a0 = number of leading zeros in 32-bit x
# =====================================================
clz:
li t0, 32 # t0 = n (leading-zero count base)
li t1, 16 # t1 = c (shift amount)
clz_Loop:
srl t2, a0, t1 # t2 = y = x >> c
sgtz t3, t2 # t3 = (y > 0)
beqz t3, shift
sub t0, t0, t1 # n -= c
mv a0, t2 # x = y
shift:
srli t1, t1, 1 # c >>= 1
sgtz t3, t1 # c > 0 ?
bnez t3, clz_Loop
sub a0, t0, a0 # a0 = n - x (final adjust in this scheme)
ret
```
:::
and the cycle is 1948

#### implement with branchless clz
Using the branchless CLZ implementation, we achieved 1833 cycles, and the CPI decreased from 1.47 to 1.16. Although the total instruction count increased, both the cycle count and CPI were reduced, showing that the pipeline became more efficient with fewer control hazards and smoother instruction flow.

---
## Problem C
---
### 1. Introduction
#### Goal
The purpose of this assignment is to implement the **bfloat16 (BF16)** arithmetic operations using only **RV32I** instructions
This project aims to demonstrate understanding of:
- IEEE-754 floating-point representation
- Bitwise manipulation of sign, exponent, and mantissa
- Software emulation of floating-point behavior on integer hardware
#### Motivation
By building BF16 arithmetic from scratch, we can observe how floating-point units handle normalization, bias adjustment, rounding, and exceptional values (NaN/Inf/Zero).
It also helps strengthen low-level reasoning about performance and pipeline hazards.
---
### 2. Design Overview
#### Supported Operations
- `bf16_add`
- `bf16_sub`
- `bf16_mul`
- `bf16_div`
- `bf16_sqrt`
- Helper utilities:
`bf16_isnan`, `bf16_isinf`, `bf16_iszero`, `bf16_to_f32`, `f32_to_bf16`
#### Constraints
- Only **RV32I** instructions are allowed
- No hardware floating-point unit
- Must correctly handle **NaN**, **Inf**, **Zero**, and **Subnormal** values
---
### 3. Implementation
#### Data Representation
| Field | Bits | Description |
|-------|------|-------------|
| Sign | 1 | Positive (0) / Negative (1) |
| Exponent | 8 | Biased with bias = 127 |
| Mantissa | 7 | Fraction bits (with implicit 1 for normalized values) |
---
#### C code
:::spoiler code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
typedef struct {
uint16_t bits;
} bf16_t;
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})
static inline bool bf16_isnan(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_isinf(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
!(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_iszero(bf16_t a)
{
return !(a.bits & 0x7FFF);
}
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float));
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
return (bf16_t) {.bits = f32bits >> 16};
}
static inline float bf16_to_f32(bf16_t val)
{
uint32_t f32bits = ((uint32_t) val.bits) << 16;
float result;
memcpy(&result, &f32bits, sizeof(float));
return result;
}
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a;
}
if (exp_b == 0xFF)
return b;
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant)
return BF16_ZERO();
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
b.bits ^= BF16_SIGN_MASK;
return bf16_add(a, b);
}
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (!exp_b && !mant_b)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15};
int16_t exp_adjust = 0;
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15};
result_mant >>= (1 - result_exp);
result_exp = 0;
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
if (exp_b == 0xFF) {
if (mant_b)
return b;
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = result_sign << 15};
}
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (exp_a == 0xFF) {
if (mant_a)
return a;
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15};
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++;
if (quotient & 0x8000)
quotient >>= 8;
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
quotient &= 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15};
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F)};
}
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF) {
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(0) = 0 (handle both +0 and -0) */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* Direct bit manipulation square root algorithm */
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
if (e & 1) {
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for integer square root */
/* We want result where result^2 = m * 128 (since 128 represents 1.0) */
uint32_t low = 127; /* Min sqrt (roughly sqrt(128)) */
uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */
uint32_t result = 128; /* Default */
/* Binary search for square root of m */
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and scale */
if (sq <= m) {
result = mid; /* This could be our answer */
low = mid + 1;
} else {
high = mid - 1;
}
}
/* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
/* But we need to adjust the scale */
/* Since m is scaled where 128=1.0, result should also be scaled same way */
/* Normalize to ensure result is in [128, 256) */
/*if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}*/
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; +Inf
if (new_exp <= 0)
return BF16_ZERO();*/
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
:::
#### assembly code
#### bf16 isnan isinf iszero
:::spoiler code
```asm=
bf16_isnan:
# check exponent
li t1, 0x7F80
and t2, a0, t1
bne t2, t1, isnan_end
# check mantissa
li t1, 0x007F
and t2, a0, t1
beqz t2, isnan_end
li a0, 1 # return 1
ret
isnan_end:
li a0, 0 # return 0
ret
#----------------------------------------
bf16_isinf:
# check exponent
li t1, 0x7F80
and t2, a0, t1
bne t2, t1, isinf_end
# check mantissa
li t1, 0x007F
and t2, a0, t1
bnez t2, isinf_end
li a0, 1 # return 1
ret
isinf_end:
li a0, 0 # return 0
ret
#----------------------------------------
bf16_iszero:
li t1, 0x7FFF
and t2, a0, t1
bnez t2, not_zero
li a0, 1
ret
not_zero:
li a0, 0
ret
#----------------------------------------
```
#### bf32 to bf16
```asm=
bf32_to_bf16:
lw t0,0(a0)
srli t1,t0,23 #shift right t0 23bits
li t2,0xff #load 0xff
and t1,t1,t2
beq t1,t2,return_bf16_inf
srli t1,t0,16
andi t1,t1,1
li t2,0x7fff
add t1,t1,t2
add t0,t0,t1
srli a0,t0,16
ret
return_bf16_inf:
srli t1,t0,16
li t2,0xffff
and a0,t1,t2
ret
```
#### bf16 to bf32
```asm=
bf16_to_bf32:
slli a0,a0,16
ret
```
:::
#### bf16 add
:::spoiler code
```asm=
bf16_add:
# s2 = sign_a
# s3 = sign_b
# s4 = exp_a
# s5 = exp_b
# s6 = mant_a
# s7 = mant_b
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
#------------ extract a
addi t0, a0,0
srli t0, t0, 15
andi s2, t0, 1 # s2 = sign_a
addi t0, a0,0
srli t0, t0, 7
andi s4, t0, 0xFF # s4 = exp_a
addi t0, a0,0
andi s6, t0, 0x7F # s6 = mant_a
#------------ extract b
addi t0, a1,0
srli t0, t0, 15
andi s3, t0, 1 # s3 = sign_b
addi t0, a1,0
srli t0, t0, 7
andi s5, t0, 0xFF # s5 = exp_b
addi t0, a1,0
andi s7, t0, 0x7F # s7 = mant_b
#handle inf
li t0,0xff
bne s4,t0,check_exponent_b
bnez s6,return_a
bne s5,t0,return_a
xor t0,s2,s3 #if signa=signb t6=0 else t6=1
or t0,s7,t0
beqz t0,return_b
li a0,0x7FC0
j restore_regs
return_a:
j restore_regs
return_b:
addi a0,a1,0
j restore_regs
check_exponent_b:
li t0,0xff
beq s5,t0,return_b
seqz t1,s4 #s2=!exp_a
seqz t2,s6 #s3=!mantissa_a
and t1,t1,t2
bnez t1,return_b
seqz t1,s5 #s2=!exp_b
seqz t2,s7 #s3=!mantissa_b
and t1,t1,t2
bnez t1,return_a
beqz s4,skip_a
ori s6,s6,0x80
skip_a:
beqz s5,skip_b
ori s7,s7,0x80
skip_b:
sub t0,s4,s5 #t0=exp_diff=exp_a-exp_b
bgtz t0,exp_a_bigger
bltz t0,exp_b_bigger
exp_equal:
addi t1,s4,0 #t1=result_exp = exp_a
j sign_check
exp_a_bigger:
addi t1,s4,0 #t1=result_exp = exp_a
addi t0,t0,-8 # exp_diff -= 8
bgtz t0,return_a #exp_diff>0 return a
addi t0,t0,8
srl s7,s7,t0 #mantissa_b >> exp_diff
j sign_check
exp_b_bigger:
addi t1,s5,0 #t1= result_exp =b
addi t0,t0,8
bltz t0,return_b
addi t0,t0,-8
neg t0,t0
srl s6,s6,t0
sign_check:
bne s2,s3,sign_diff
addi t2,s2,0 #t2 = result_sign = sign_a
add t3,s6,s7 #t3 = result_mantissa = man_a + man_b
andi t0,t3,0x100
beqz t0,return_result
srli t3,t3,1
addi t1,t1,1
addi t0,t1,-0xff
bltz t0,return_result
slli t2,t2,15
li t0,0x7F80
or t2,t2,t0
addi a0,t2,0
j restore_regs
sign_diff:
sub t0,s6,s7
bltz t0,manb_greater
addi t2,s2,0 #t2 = result_sign = sign_a
sub t3,s6,s7 #t3 = result_mantissa = mana-manb
j check_result
manb_greater:
addi t2,s3,0 #t2 = result_sign = sign_b
sub t3,s7,s6 #t3 = result_man = manb-mana
check_result:
beqz t3,bf16_zero
normalize_loop:
andi t0,t3,0x80
bnez t0,return_result
slli t3,t3,1
addi t1,t1,-1
blez t1,bf16_zero
j normalize_loop
bf16_zero:
li a0,0x0000
j restore_regs
return_result:
slli t2,t2,15
andi t1,t1,0xFF
slli t1,t1,7
andi t3,t3,0x7F
or a0,t1,t2
or a0,a0,t3
j restore_regs
```
:::
#### bf16 sub
:::spoiler code
```asm=
bf16_sub:
li t0,0x8000
xor a1, a1, t0 # b.bits ^= 0x8000
# call bf16_add(a, b)
jal ra, bf16_add
```
:::
#### bf16 mul
:::spoiler code
```asm=
bf16_mul:
#-----------------------------------------------------------------------
# bf16 bf16_mul(bf16 a, bf16 b)
# In:  a0 = a, a1 = b (bf16 bit patterns: 1 sign | 8 exp | 7 mant)
# Out: a0 = a * b
# Special cases: Inf*0 -> qNaN, NaN propagated, signed zero/Inf results.
# Uses shift-and-add for the 8x8-bit mantissa product (no `mul` allowed).
# Spills s2-s7 here; restore_regs (defined elsewhere) reloads and rets.
# Register roles:
# s2 = sign_a
# s3 = sign_b
# s4 = exp_a
# s5 = exp_b
# s6 = mant_a
# s7 = mant_b
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
#------------ extract fields of a
addi t0, a0,0
srli t0, t0, 15
andi s2, t0, 1 # s2 = sign_a
addi t0, a0,0
srli t0, t0, 7
andi s4, t0, 0xFF # s4 = exp_a
addi t0, a0,0
andi s6, t0, 0x7F # s6 = mant_a
#------------ extract fields of b
addi t0, a1,0
srli t0, t0, 15
andi s3, t0, 1 # s3 = sign_b
addi t0, a1,0
srli t0, t0, 7
andi s5, t0, 0xFF # s5 = exp_b
addi t0, a1,0
andi s7, t0, 0x7F # s7 = mant_b
xor t1,s2,s3 # t1 = result_sign = sign_a ^ sign_b
# --- special cases: exp == 0xFF means Inf (mant==0) or NaN (mant!=0)
li t0,0xff
beq t0,s4,mul_a_inf
beq t0,s5,mul_b_inf
# --- zero check: (exp_a==0 && mant_a==0) || (exp_b==0 && mant_b==0)
seqz t2,s4
seqz t3,s6
seqz t4,s5
seqz t5,s7
and t2,t2,t3 # t2 = (a == ±0)
and t3,t4,t5 # t3 = (b == ±0)
or t0,t2,t3
beqz t0,mul_process
# either operand is zero -> return signed zero
slli t1,t1,15
addi a0,t1,0
j restore_regs
mul_a_inf:
beqz s6,check_expbmantb
# mant_a != 0 -> a is NaN: return a unchanged (still in a0)
j restore_regs
check_expbmantb:
# a is Inf; Inf * 0 is invalid -> qNaN, otherwise signed Inf
seqz t2,s5
seqz t3,s7
and t0,t2,t3 # t0 = (b == ±0)
bnez t0,mul_return_nan
slli t1,t1,15
li t0,0x7f80 # Inf bit pattern (exp=0xFF, mant=0)
or a0,t1,t0
j restore_regs
mul_return_nan:
li a0,0x7fc0 # canonical quiet NaN
j restore_regs
mul_b_inf:
beqz s7,check_expamanta
# mant_b != 0 -> b is NaN: return b
addi a0,a1,0
j restore_regs
check_expamanta:
# b is Inf; 0 * Inf is invalid -> qNaN, otherwise signed Inf
seqz t2,s4
seqz t3,s6
and t0,t2,t3 # t0 = (a == ±0)
bnez t0,mul_return_nan
slli t1,t1,15
li t0,0x7f80
or a0,t1,t0
j restore_regs
mul_process:
li t2,0 # t2 = exp_adjust (decremented while normalizing subnormals)
# normalize a: normals get the implicit leading 1, subnormals are
# shifted left until bit 7 is set (exp_adjust tracks the shifts)
beqz s4,mul_normalized_a
ori s6,s6,0x80
j mul_a_done
mul_normalized_a:
andi t0,s6,0x80
bnez t0,mul_norm_a_done
slli s6,s6,1
addi t2,t2,-1
j mul_normalized_a
mul_norm_a_done:
li s4,1 # subnormal effective exponent is 1
mul_a_done:
# normalize b the same way
beqz s5,mul_normalized_b
ori s7,s7,0x80
j mul_b_done
mul_normalized_b:
andi t0,s7,0x80
bnez t0,mul_norm_b_done
slli s7,s7,1
addi t2,t2,-1
j mul_normalized_b
mul_norm_b_done:
li s5,1
mul_b_done:
add t3, s4, s5 # exp_a + exp_b
add t3, t3, t2 # + exp_adjust
addi t2, t3, -127 # - bias; t2 = result_exp, t1 = result_sign
# mantissa multiplier: 8-iteration shift-and-add (no `mul` allowed)
# t3 = product accumulator, t4 = shifted multiplicand, t5 = multiplier
li t3,0
addi t4,s6,0
addi t5,s7,0
li t6,8 # 8 multiplier bits to process
mul_loop:
andi t0,t5,1 # t0 = (t5 & 1), current multiplier bit
neg t0,t0 # t0 = 0 or -1 (all ones) — branchless mask
and t0,t0,t4 # t0 = (t5&1) ? t4 : 0
add t3,t3,t0 # t3 += t0
slli t4,t4,1
srli t5,t5,1
addi t6,t6,-1
bnez t6,mul_loop
# renormalize the 16-bit product: 1.m * 1.m is in [1.0, 4.0), so the
# top bit is either bit 15 (shift by 8, exp++) or bit 14 (shift by 7)
li t0,0x8000
and t0,t0,t3
beqz t0,result_mant_zero
srli t0,t3,8
andi t3,t0,0x7F
addi t2,t2,1
j check_mantissa_done
result_mant_zero:
srli t0,t3,7
andi t3,t0,0x7f
check_mantissa_done:
# check result exponent range
# check overflow
li t0, 0xff
bge t2, t0, mul_overflow # if result_exp >= 255 → INF
# check underflow
blez t2, mul_underflow # if result_exp <= 0 → handle underflow
j mul_done # else → normal
mul_overflow:
slli t1, t1, 15 # sign
li t0, 0x7f80 # INF exponent
or a0, t1, t0
j restore_regs
mul_underflow:
li t0, -6
blt t2, t0, mul_to_zero # if result_exp < -6 → return 0
# else: produce a subnormal by shifting mantissa >> (1 - exp)
li t0, 1
sub t0, t0, t2 # t0 = 1 - result_exp
srl t3, t3, t0
li t2, 0 # exp = 0 (subnormal encoding)
j mul_done
mul_to_zero:
slli t1, t1, 15 # just sign
addi a0, t1, 0 # return ±0
j restore_regs
mul_done:
# pack (sign << 15) | (exp << 7) | mant
slli t1, t1, 15 # sign
andi t2, t2, 0xff # exp
slli t2, t2, 7
andi t3, t3, 0x7f # mant
or t1, t1, t2
or t1, t1, t3
addi a0, t1, 0
j restore_regs
```
:::
Since the use of the hardware multiply instruction (mul) is not allowed, this program implements mantissa multiplication using the shift-and-add method. This approach mimics how binary multiplication works at the bit level.
#### bf16 div
:::spoiler code
```asm=
bf16_div:
#-----------------------------------------------------------------------
# bf16 bf16_div(bf16 a, bf16 b)
# In:  a0 = a, a1 = b (bf16 bit patterns: 1 sign | 8 exp | 7 mant)
# Out: a0 = a / b
# Special cases: x/0 -> ±Inf, 0/0 and Inf/Inf -> qNaN, x/Inf -> ±0,
#   NaN operands -> qNaN. Long division produces a 16-bit quotient.
# Spills s2-s7 here; restore_regs (defined elsewhere) reloads and rets.
#-----------------------------------------------------------------------
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
#--------------------------------
# Extract fields
#--------------------------------
srli t0, a0, 15
andi s2, t0, 1 # s2 = sign_a
srli t0, a1, 15
andi s3, t0, 1 # s3 = sign_b
srli t0, a0, 7
andi s4, t0, 0xFF # s4 = exp_a
srli t0, a1, 7
andi s5, t0, 0xFF # s5 = exp_b
andi s6, a0, 0x7F # mant_a
andi s7, a1, 0x7F # mant_b
xor t1, s2, s3 # result_sign = sign_a ^ sign_b
#--------------------------------
# Handle special cases
#--------------------------------
li t0, 0xFF
beq s5, t0, div_b_inf_or_nan # b = Inf or NaN
beqz s5, div_b_zero_check # b = subnormal or zero?
j div_check_a # else continue
div_b_inf_or_nan:
bnez s7, div_return_nan # mant_b != 0 -> b is NaN
# mant_b == 0 -> b is Inf
li t0, 0xFF
bne s4, t0, div_return_zero # finite / Inf -> signed 0
bnez s6, div_return_nan # a is NaN -> qNaN
# a is Inf, b is Inf -> NaN (invalid operation)
j div_return_nan
div_b_zero_check:
beqz s7, div_b_is_zero
j div_check_a # b is subnormal: keep going
div_b_is_zero:
beqz s4, div_a_zero_check_for_bzero
j div_by_zero # normal / 0 -> Inf
div_a_zero_check_for_bzero:
beqz s6, div_return_nan # 0 / 0 → NaN
j div_by_zero # subnormal / 0 -> Inf
div_by_zero:
li t0, 0x7F80 # Inf exponent
slli t1, t1, 15
or a0, t0, t1
j restore_regs
div_return_zero:
slli t1, t1, 15
addi a0, t1, 0 # signed zero
j restore_regs
div_return_nan:
li a0, 0x7FC0 # canonical quiet NaN
j restore_regs
div_check_a:
li t0, 0xFF
beq s4, t0, div_a_inf_check # a = Inf or NaN
beqz s4, div_a_zero_check # a = zero or subnormal
j div_process
div_a_inf_check:
beqz s6, div_return_inf # Inf / finite -> signed Inf
j div_return_nan # a is NaN
div_a_zero_check:
beqz s6, div_return_zero # 0 / finite -> signed 0
j div_process # a subnormal: keep going
div_return_inf:
li t0, 0x7F80
slli t1, t1, 15
or a0, t1, t0
j restore_regs
#--------------------------------
# Main process (normal / subnormal)
#--------------------------------
div_process:
# add implicit 1 to normal operands (exp != 0)
beqz s4, div_norm_a_done
ori s6, s6, 0x80
div_norm_a_done:
beqz s5, div_norm_b_done
ori s7, s7, 0x80
div_norm_b_done:
# exp = exp_a - exp_b + bias
sub t2, s4, s5
addi t2, t2, 127 # t2 = result_exp
# long division: 16 quotient bits of (mant_a << 15) / mant_b
# t4 = remainder (dividend), t5 = divisor, t3 = quotient, t6 = bit index
slli t4,s6,15 #dividend
addi t5,s7,0 #divisor
li t3,0 # quotient
li t6,0 #loop_count
div_loop:
slli t3,t3,1
li t0,15
sub t0,t0,t6
sll t0,t5,t0 # t0 = divisor aligned to current quotient bit
blt t4,t0,div_skip # remainder < aligned divisor -> quotient bit 0
sub t4,t4,t0
ori t3,t3,1 # quotient bit 1
div_skip:
addi t6,t6,1
li t0,16
blt t6,t0,div_loop
#div_end
# exp_a == 0 (subnormal a, effective exp 1) → exp_adjust--
beqz s4, div_exp_a_zero
j div_exp_a_done
div_exp_a_zero:
addi t2, t2, -1
div_exp_a_done:
# exp_b == 0 (subnormal b, effective exp 1) → exp_adjust++
beqz s5, div_exp_b_zero
j div_exp_b_done
div_exp_b_zero:
addi t2, t2, 1
div_exp_b_done:
# mantissa normalization: shift quotient left until bit 15 set,
# decrementing result_exp, but never drive the exponent below 1
li t0, 0x8000
and t0, t0, t3
bnez t0,div_msb_one
div_shift_check:
li t0, 0x8000
and t0, t3, t0 # t0 = (t3 & 0x8000)
bnez t0, div_norm_done # MSB==1 done
addi t0, t2, -1 # t0 = t2 - 1
blez t0, div_norm_done # would make exp < 1 → stop
slli t3, t3, 1 # quotient <<= 1
addi t2, t2, -1 # result_exp--
j div_shift_check
div_norm_done:
srli t3, t3, 8 # keep top 8 bits (implicit 1 + 7 mant)
j div_done
div_msb_one:
srli t3,t3,8
j div_done
div_done:
andi t3, t3, 0x7F # drop implicit 1, keep 7-bit mantissa
# --- check overflow (result_exp >= 0xFF) ---
li t0, 0xFF
bge t2, t0, div_overflow # if exp >= 255 → INF
# --- check underflow (result_exp <= 0) ---
blez t2, div_underflow # if exp <= 0 → ZERO (flush to zero)
# --- normal case: pack (sign<<15)|(exp<<7)|mant ---
slli t1, t1, 15 # sign << 15
andi t2, t2, 0xFF
slli t2, t2, 7 # exp << 7
or t1, t1, t2
or t1, t1, t3
addi a0, t1, 0
j restore_regs
# --- overflow (INF) ---
div_overflow:
slli t1, t1, 15
li t0, 0x7F80
or a0, t1, t0
j restore_regs
# --- underflow (ZERO) ---
div_underflow:
slli t1, t1, 15
addi a0, t1, 0
j restore_regs
```
:::
#### bf16 sqrt
:::spoiler code
```asm=
bf16_sqrt:
#-----------------------------------------------------------------------
# bf16 bf16_sqrt(bf16 a)
# In:  a0 = bf16 operand (1 sign | 8 exp | 7 mant)
# Out: a0 = sqrt(a)
# Special cases: NaN -> NaN (input propagated), +Inf -> +Inf,
#   -Inf and negative normals -> quiet NaN, ±0 -> ±0 (sign preserved),
#   subnormal inputs are flushed to +0 (approximation).
# Mantissa root is found by binary search over [128, 256).
# Spills s2-s7 here; restore_regs (defined elsewhere) reloads and rets.
# Register roles:
# s2 = sign_a
# s3 = exp_a
# s4 = mant_a
# s5 = result_sign
# s6 = result_exp
# s7 = result_mant
addi sp, sp, -24
sw s2, 0(sp)
sw s3, 4(sp)
sw s4, 8(sp)
sw s5, 12(sp)
sw s6, 16(sp)
sw s7, 20(sp)
#------------ extract a (bf16 in a0)
srli t0, a0, 15
andi s2, t0, 1 # s2 = sign_a
srli t0, a0, 7
andi s3, t0, 0xFF # s3 = exp_a
andi s4, a0, 0x7F # s4 = mant_a
# result sign is always 0 (negative inputs never reach the main path)
li s5, 0
#=====================================
# Special cases
#=====================================
# exp == 0xFF → Inf or NaN
li t0, 0xFF
beq s3, t0, sqrt_is_inf_or_nan
# exp == 0 → zero or subnormal
beqz s3, sqrt_is_zero_or_sub
# negative (normal, non-zero) → NaN
bnez s2, sqrt_neg_input
# others go main process
j sqrt_process
# ---------- exp==0xFF ----------
sqrt_is_inf_or_nan:
# mant != 0 → NaN: propagate the input NaN
bnez s4, sqrt_return_input_nan
# mant == 0 → Inf: sqrt(−Inf) → NaN, sqrt(+Inf) → +Inf
bnez s2, sqrt_return_qnan
li a0, 0x7F80 # +Inf
j restore_regs
sqrt_return_input_nan:
addi a0, a0, 0 # a0 already holds the NaN input
j restore_regs
sqrt_return_qnan:
li a0, 0x7FC0 # quiet NaN
j restore_regs
# ---------- exp==0 ----------
sqrt_is_zero_or_sub:
# mant == 0 → ±0: sqrt(±0) = ±0, so the sign bit must survive
beqz s4, sqrt_return_same_zero
# subnormal input: flushed to +0 (approximation; subnormals unsupported)
li a0, 0x0000
j restore_regs
sqrt_return_same_zero:
# a0 is untouched so far — return the input ±0 bit pattern unchanged
# (was `li a0, 0x0000`, which wrongly turned -0 into +0)
j restore_regs
# ---------- negative (normal, non-zero) ----------
sqrt_neg_input:
li a0, 0x7FC0 # NaN
j restore_regs
sqrt_process:
addi t2,s3,-127 # t2 = unbiased exponent e
ori t1,s4,0x80 # t1 = full mantissa with implicit 1, in [128, 256)
# halve the exponent; if e is odd, fold one factor of 2 into the
# mantissa first so (mant << 1) in [256, 512) pairs with (e-1)/2
andi t0,t2,1
beqz t0,even_exp
slli t1,t1,1
addi t0,t2,-1
srai t0,t0,1
addi t0,t0,127 # re-bias
mv s6,t0 # s6 = new_exp
j sqrt_adjust_done
even_exp:
srai t0,t2,1
addi t0,t0,127 # re-bias
mv s6,t0 # s6 = new_exp
sqrt_adjust_done:
# binary search r in [128, 256) such that (r*r)>>7 <= m
li t2, 128 # low
li t3, 256 # high
li t4, 128 # result (best candidate so far)
sqrt_binary_search_loop:
# while (low <= high)
blt t3, t2, sqrt_binary_search_end
# mid = (low + high) >> 1
add t0, t2, t3
srli t0, t0, 1 # t0 = mid
# sq = (mid * mid) >> 7, via repeated addition (no `mul` allowed)
li t5, 0
mv t6, t0 # counter = mid
sqrt_mul_loop:
beqz t6, sqrt_mul_done
add t5, t5, t0 # t5 += mid
addi t6, t6, -1
j sqrt_mul_loop
sqrt_mul_done:
srli t5, t5, 7 # sq = (mid*mid)/128
# if (sq <= m) { result=mid; low=mid+1; } else { high=mid-1; }
blt t1, t5, sqrt_go_hi # if m < sq → high = mid - 1
mv t4, t0 # result = mid
addi t2, t0, 1 # low = mid + 1
j sqrt_binary_search_loop
sqrt_go_hi:
addi t3, t0, -1 # high = mid - 1
j sqrt_binary_search_loop
sqrt_binary_search_end:
# renormalize result into [128, 256), adjusting the exponent
li t0, 256
blt t4, t0, sqrt_check_low
# result >= 256 → shift down, exp++
srli t4, t4, 1
addi s6, s6, 1
sqrt_check_low:
li t0, 128
bge t4, t0, pack_result
sqrt_norm_loop:
slli t4, t4, 1
addi s6, s6, -1
blt t4, t0, sqrt_norm_loop
pack_result:
# new_mant = result & 0x7F (drop the implicit leading 1)
andi s7, t4, 0x7F
# return (sign<<15) | (new_exp<<7) | new_mant
slli t1, s5, 15 # sign
andi t2, s6, 0xFF
slli t2, t2, 7 # exp << 7
andi t3, s7, 0x7F # mant
or t1, t1, t2
or a0, t1, t3
j restore_regs
```
:::
In the sqrt function, we use binary search to approximate the square root of the mantissa.
The mantissa part of a bfloat16 number has 7 bits. After adding the implicit leading 1, it becomes an 8-bit value ranging from $128\ (b'10000000)$ to $255\ (b'11111111)$.
If the exponent is even, the mantissa range is $[128, 256)$. If the exponent is odd, we subtract $1$ from the exponent and multiply the mantissa by $2$, resulting in a range of $[256, 512)$.
Thus, before taking the square root, the mantissa lies within $[128, 512).$
Since the square-root operation effectively halves the exponent, the normalized mantissa result always falls within $[128, 256).$
Therefore, we can safely use $[128, 256)$ as a unified binary-search range for both even and odd exponent cases.
In the original implementation, a wider range of $[90, 256]$ was used to guarantee correctness under all inputs.
However, this range is unnecessarily large and increases the number of iterations required for convergence.
By refining the bounds to $[128, 256)$, the search interval becomes tighter, allowing the algorithm to converge faster while maintaining full accuracy.
Moreover, because the result is guaranteed to remain within $[128, 256)$ (exclusive of $256$), it will never overflow or underflow the mantissa field.
As a result, additional overflow or underflow checks are no longer required, simplifying the post-processing logic.
### Test Cases
#### ADD Test Cases (bfloat16)
| No. | Operand A (hex) | Operand B (hex) | Expected (hex) | Description |
|:---:|:----------------:|:----------------:|:----------------:|:-------------|
| 1 | `0x3F40` | `0x3EA0` | `0x3F88` | $0.75 + 0.3125 = \mathbf{1.0625}$ |
| 2 | `0x3F80` | `0xBFF0` | `0xBF60` | $1.0 + (-1.875) = \mathbf{-0.875}$ |
| 3 | `0x3F80` | `0xBF80` | `0x0000` | $1.0 + (-1.0) = \mathbf{0.0}$ |
| 4 | `0x7F80` | `0x3F80` | `0x7F80` | $+\infty + 1.0 = \mathbf{+\infty}$ |
| 5 | `0x7F80` | `0xFF80` | `0x7FC0` | $+\infty + (-\infty) = \mathbf{\text{NaN}}$ |
| 6 | `0x7F00` | `0x7F00` | `0x7F80` | $\ 0x7F00 + 0x7F00 = \mathbf{+\infty}$ |
#### SUB Test Cases (bfloat16)
| No. | Operand A (hex) | Operand B (hex) | Expected (hex) | Description |
|:---:|:----------------:|:----------------:|:----------------:|:-------------|
| 1 | `0x3FC0` | `0x3F40` | `0x3F40` | $1.5 - 0.75 = \mathbf{0.75}$ |
| 2 | `0x3F80` | `0x4000` | `0xBF80` | $1.0 - 2.0 = \mathbf{-1.0}$ |
| 3 | `0x3F80` | `0x3F80` | `0x0000` | $1.0 - 1.0 = \mathbf{0.0}$ |
| 4 | `0x7F80` | `0x7F80` | `0x7FC0` | $+\infty - +\infty = \mathbf{\text{NaN}}$ |
#### MUL Test Cases (bfloat16)
| No. | A (hex) | B (hex) | Expected (hex) | Description |
|:--:|:--------:|:--------:|:---------------:|:-------------|
| 1 | `0x4040` | `0x40C0` | `0x4190` | $3.0 \times 6.0 = 18.0$ |
| 2 | `0x3F80` | `0x3F40` | `0x3F40` | $1.0 \times 0.75 = 0.75$ |
| 3 | `0x3FC8` | `0x40D2` | `0x4124` | $1.5625 \times 6.5625 = 10.25390625$ |
| 4 | `0x7F80` | `0xBF80` | `0xFF80` | $+\infty \times -1.0 = -\infty$ |
| 5 | `0x7F80` | `0x0000` | `0x7FC0` | $+\infty \times +0 = \text{NaN (invalid)}$ |
| 6 | `0x0000` | `0x0000` | `0x0000` | $0 \times 0 = 0$ |
#### DIV Test Cases (bfloat16)
| No. | Operand A (hex) | Operand B (hex) | Expected (hex) | Description |
|:---:|:----------------:|:----------------:|:----------------:|:-------------|
| 1 | `0x3F80` | `0x4000` | `0x3F00` | **1.0 ÷ 2.0 = 0.5** |
| 2 | `0x40B8` | `0x4000` | `0x4038` | **5.75 ÷ 2.0 = 2.875** |
| 3 | `0xBF80` | `0x4000` | `0xBF00` | **−1.0 ÷ 2.0 = −0.5** |
| 4 | `0x7F80` | `0x4000` | `0x7F80` | **+Inf ÷ 2.0 = +Inf** |
| 5 | `0x3F80` | `0x7F80` | `0x0000` | **1.0 ÷ +Inf = +0** |
| 6 | `0x0000` | `0x0000` | `0x7FC0` | **0 ÷ 0 = NaN (invalid)** |
| 7 | `0x3F80` | `0x0000` | `0x7F80` | **1.0 ÷ 0 = +Inf** |
#### SQRT Test Cases (bfloat16)
| No. | Operand A (hex) | Expected (hex) | Description |
|:---:|:----------------:|:----------------:|:-------------|
| 1 | `0x4080` | `0x4000` | **sqrt(4.0) = 2.0** |
| 2 | `0x40C8` | `0x4020` | **sqrt(6.25) = 2.5** |
| 3 | `0xBF80` | `0x7FC0` | **sqrt(−1.0) = NaN** |
| 4 | `0x7F80` | `0x7F80` | **sqrt(+Inf) = +Inf** |
| 5 | `0x0000` | `0x0000` | **sqrt(+0) = +0** |
### Testing Result
