# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by < [hbshub](https://github.com/hbshub/ca2025-quizzes) >
>[!Note] AI tools usage
>I used ChatGPT to assist with this assignment: code explanations, grammar revision, background research, and code summaries.
## Problem B
In Problem B, we implement a logarithmic 8-bit codec that maps 20-bit unsigned integers in $[0, 1{,}015{,}792]$ to a single byte. The decode and encode formulas are given below.
**<font size=5>`Decode formula`</font><br>**
$$D(b) = m \cdot 2^e + (2^e - 1) \cdot 16$$
where $e = \lfloor b/16 \rfloor$ and $m = b \bmod 16$.
```
uf8 notation : eeeemmmm
┌──────────────┬──────────────┐
│ Exponent (4) │ Mantissa (4) │
└──────────────┴──────────────┘
 7            4 3            0
E: Exponent bits (4 bits)
M: Mantissa bits (4 bits)
```
The high 4 bits of `uf8` hold `e = floor(uf8/16)`.
The low 4 bits of `uf8` hold `m = uf8 mod 16`.
Decoding a byte therefore recovers a 20-bit unsigned integer.
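For example, decoding $b = \text{0x4A}$ gives $e = 4$ and $m = 10$, so $D = 10 \cdot 2^4 + (2^4 - 1) \cdot 16 = 160 + 240 = 400$.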
**<font size=5>`Encode formula`</font><br>**
$$
E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases}
$$
where $\text{offset}(e) = (2^e - 1) \cdot 16$
We can encode values in the range $[0, 1{,}015{,}792]$ into the uf8 format, using only 8 bits to represent a 20-bit value.
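Continuing the example, encoding $v = 400$: since $v \ge 16$ and the MSB of 400 is bit 8, the exponent estimate is $e = 8 - 4 = 4$ with $\text{offset}(4) = 240$, so $E(400) = 16 \cdot 4 + \lfloor (400 - 240)/2^4 \rfloor = 64 + 10 = 74 = \text{0x4A}$, which round-trips back to the byte decoded above.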
### **<font size=5>`clz`</font><br>**
Helper function that counts leading zeros; the encoder uses it to locate the MSB of a value.
:::spoiler c source code
```c=
// count leading zeros
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}
```
:::
:::spoiler rv32i assembly code
```asm=
# clz(uint32_t x)
# a0: x (input/output)
# return a0 = count of leading zeros
clz:
    li t0, 32              # n = 32
    li t1, 16              # c = 16
clz_loop:
    srl t2, a0, t1         # y = x >> c
    beq t2, x0, skip_update   # if (y == 0) skip update
    sub t0, t0, t1         # n = n - c
    mv a0, t2              # x = y
skip_update:
    srli t1, t1, 1         # c = c / 2
    bnez t1, clz_loop      # while (c) loop
    sub a0, t0, a0         # return n - x
    ret
```
:::
### **<font size=5>`uf8_decode`</font><br>**
:::spoiler c source code
```c=
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}
```
:::
:::spoiler rv32i assembly code
```asm=
# fewer instructions by simplifying the decode formula:
# => (m << e) + offset
# => (m << e) + ((2^e - 1) * 16)
# => (m << e) + (16 << e) - 16
# => ((m + 16) << e) - 16
# uf8_decode(uf8 f)
# a0: f (input/output)
# return a0 = uf8_decode(f)
uf8_decode:
    srli t0, a0, 4         # t0 = e
    andi a0, a0, 0x0F      # a0 = m
    addi a0, a0, 16        # a0 = m + 16
    sll a0, a0, t0         # a0 = (m + 16) << e
    addi a0, a0, -16       # a0 = a0 - 16
    ret
```
:::
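To double-check the algebraic simplification above, a small host-side C loop (my own sanity check, not part of the reference code) can compare the original decode formula against the simplified one over all 256 codes:

```c=
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    for (uint32_t fl = 0; fl < 256; fl++) {
        uint32_t m = fl & 0x0F, e = fl >> 4;
        uint32_t orig = (m << e) + (((1u << e) - 1) << 4); /* (m << e) + (2^e - 1) * 16 */
        uint32_t simp = ((m + 16) << e) - 16;              /* form used by the assembly */
        if (orig != simp)
            printf("mismatch at 0x%02X\n", (unsigned) fl);
    }
    printf("done\n");
    return 0;
}
```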
### **<font size=5>`uf8_encode`</font><br>**
:::spoiler c source code
```c=
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;
    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;
    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;
    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;
        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;
        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }
    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }
    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}
```
:::
:::spoiler rv32i assembly code
```asm=
# uf8_encode : encode a value into 1-byte uf8 encoding
# input  a0 - uint32_t value
# return a0 - uf8 uf8_encode(value)
uf8_encode:
    li t3, 16
    bge a0, t3, exp_nonzero    # if (a0 >= 16) exp_nonzero
    ret
exp_nonzero:                   # e != 0
    addi sp, sp, -8
    sw a0, 0(sp)               # store a0 on stack
    sw ra, 4(sp)               # store ra on stack
    jal ra, clz                # a0 = clz(a0)
    li t3, 31
    sub t0, t3, a0             # msb = t0 = 31 - clz(a0)
    lw a0, 0(sp)               # restore a0 from stack
    lw ra, 4(sp)               # restore ra from stack
    addi sp, sp, 8
    li t1, 0                   # exp = t1 = 0
    li t2, 0                   # of = t2 = 0
    li t3, 5
    blt t0, t3, find_exa_exp   # if (msb < 5) find_exa_exp
    addi t1, t0, -4            # exp = msb - 4
    li t0, 0                   # t0 = cnt = 0
    li t3, 15                  # t3 = 15, cmp value
    ble t1, t3, calc_exp       # if (exp <= 15) calc_exp
    li t1, 15                  # exp = 15
calc_exp:
    bge t0, t1, adj_exp        # if (cnt >= exp) stop looping
    slli t2, t2, 1             # of = of << 1
    addi t2, t2, 16            # of = of + 16
    addi t0, t0, 1             # cnt++
    jal x0, calc_exp
adj_exp:
    ble t1, x0, find_exa_exp   # if (exp <= 0) find_exa_exp
    bge a0, t2, find_exa_exp   # if (a0 >= of) find_exa_exp
    addi t2, t2, -16           # of = of - 16
    srli t2, t2, 1             # of = of >> 1
    addi t1, t1, -1            # exp--
    jal x0, adj_exp
find_exa_exp:
    li t3, 15                  # t3 = 15 (bug fix: the msb < 5 path arrives with t3 = 5)
    bge t1, t3, calc_m         # if (exp >= 15) calc_m
    slli t0, t2, 1             # t0 = of << 1
    addi t0, t0, 16            # t0 = (of << 1) + 16 = of_e+1
    blt a0, t0, calc_m         # if (a0 < of_e+1) calc_m
    mv t2, t0                  # of = of_e+1
    addi t1, t1, 1             # exp++
    jal x0, find_exa_exp
calc_m:
    sub t0, a0, t2             # t0 = value - of
    srl t0, t0, t1             # t0 = (value - of) >> exp = m
    ble t0, t3, cmb_num        # if (m <= 15) cmb_num
    li t0, 15                  # m = 15
cmb_num:
    slli t1, t1, 4             # t1 = exp << 4
    or a0, t1, t0              # a0 = (exp << 4) | m
    ret
```
:::
### Run in Ripes
All 256 uf8 codes (0~255) survive the decode-then-encode round trip; every test passed.
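The same exhaustive check can be reproduced on the host with a small C harness along these lines (my own sketch; it uses a simplified loop-only encoder that matches the reference output but skips the CLZ fast path):

```c=
#include <stdint.h>
#include <stdio.h>

/* Simplified host-side reference, for testing only. */
static uint32_t uf8_decode(uint8_t fl)
{
    uint32_t m = fl & 0x0F, e = fl >> 4;
    return ((m + 16) << e) - 16;            /* same simplification as the asm */
}

static uint8_t uf8_encode(uint32_t v)
{
    uint32_t e = 0, of = 0;                 /* of = offset(e) = (2^e - 1) * 16 */
    while (e < 15 && v >= (of << 1) + 16) { /* advance while v >= offset(e+1) */
        of = (of << 1) + 16;
        e++;
    }
    return (uint8_t) ((e << 4) | ((v - of) >> e));
}

int main(void)
{
    for (uint32_t b = 0; b < 256; b++)
        if (uf8_encode(uf8_decode((uint8_t) b)) != b)
            printf("round-trip failed at 0x%02X\n", (unsigned) b);
    printf("done\n");
    return 0;
}
```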

### 5-stage Pipelined Processor observation
Observed instruction: `srli t0, a0, 4`

1. IF stage
fetch the instruction word `0x00455293`
- opcode : `0010011`
- rd : `00101` = x5 = t0
- funct3 : `101` (SRLI)
- rs1 : `01010` = x10 = a0
- shamt : `00100` = 4
program counter update
- pc = pc + 4

2. ID stage
decode the instruction into its fields
- opcode = `SRLI`
- r1_reg = `0x0a` = x10
- wr_reg = `0x05` = x5
- imm = `0x00000004`

3. EX stage
execute the `SRLI` operation
- op1 = 0x00000000
- op2 = 0x00000004
- res = 0x0 >> 0x4 = 0x0 (logical shift right)

4. MEM stage
Nothing happens here for `SRLI`, because the instruction performs no memory access.

5. WB stage
- wr_data = `0x0`
- wr_idx = `0x5`
write the result back to the register file.

### cycle count improvement
Compiler-generated code ([RISC-V (32-bits) gcc (trunk)](https://godbolt.org/)):

My implementation:

<font size=5>My implementation reduces the cycle count by about 60%.</font><br>
## Problem C
### conversion function (f32_to_bf16, bf16_to_f32)
:::spoiler c source code
```c=
static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}
static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}
```
:::
:::spoiler rv32i assembly code
```asm=
# ---- BF16 Masks ----
.equ BF16_SIGN_MASK, 0x8000
.equ BF16_EXP_MASK,  0x7F80
.equ BF16_MANT_MASK, 0x007F
# ---- BF16 Constants ----
.equ BF16_EXP_BIAS, 127
.equ BF16_NAN,  0x7FC0
.equ BF16_ZERO, 0x0000
.text
# input a0 = f32 val
# return bf16 a0
f32_to_bf16:
    srli t0, a0, 23            # t0 = s | e
    andi t0, t0, 0xFF          # t0 = e
    li t1, 0xFF
    beq t0, t1, skip_rounding  # if e == 0xFF -> skip rounding (NaN/Inf)
    srli t1, a0, 16            # t1 = upper 16 bits
    andi t0, t1, 1             # t0 = LSB of upper half
    li t2, 0x7FFF
    add t0, t2, t0             # rounding bias = 0x7FFF + LSB(1/0)
    add a0, a0, t0             # RNE (round to nearest even)
skip_rounding:
    srli a0, a0, 16            # return upper 16 bits
    ret
bf16_to_f32:
    slli a0, a0, 16            # shift bf16 to upper half
    ret
```
:::
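To see the rounding trick at work, consider the bias `((f32bits >> 16) & 1) + 0x7FFF` on two hand-picked halfway cases (a minimal host-side demo; it skips the NaN/Inf guard of the real function):

```c=
#include <stdint.h>
#include <stdio.h>

/* Apply only the round-to-nearest-even step of f32_to_bf16. */
static uint16_t rne_upper16(uint32_t f32bits)
{
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (uint16_t) (f32bits >> 16);
}

int main(void)
{
    /* A lower half of exactly 0x8000 is a tie: round toward the even result. */
    printf("0x3F808000 -> 0x%04X\n", rne_upper16(0x3F808000)); /* LSB 0: stays 0x3F80 */
    printf("0x3F818000 -> 0x%04X\n", rne_upper16(0x3F818000)); /* LSB 1: rounds up to 0x3F82 */
    return 0;
}
```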
### special value function (nan, inf, zero)
:::spoiler c source code
```c=
static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}
static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}
```
:::
:::spoiler rv32i assembly code
```asm=
# ---- BF16 Masks ----
.equ BF16_SIGN_MASK, 0x8000
.equ BF16_EXP_MASK,  0x7F80
.equ BF16_MANT_MASK, 0x007F
# ---- BF16 Constants ----
.equ BF16_EXP_BIAS, 127
.equ BF16_NAN,  0x7FC0
.equ BF16_ZERO, 0x0000
# a0 = bf16 bits
# return 1 if exp==0x7F80 && mant!=0, else 0
bf16_isnan:
    li t0, BF16_EXP_MASK        # 0x7F80
    and t1, a0, t0              # t1 = e
    bne t1, t0, nan_false       # if e != 0x7F80
    andi t1, a0, BF16_MANT_MASK # t1 = m
    sltu a0, x0, t1             # a0 = (m != 0)
    ret
nan_false:
    li a0, 0
    ret
# a0 = bf16 bits
# return 1 if exp==0x7F80 && mant==0, else 0
bf16_isinf:
    li t0, BF16_EXP_MASK        # 0x7F80
    and t1, a0, t0              # t1 = e
    bne t1, t0, inf_false       # if e != 0x7F80
    andi t1, a0, BF16_MANT_MASK # t1 = m
    sltiu a0, t1, 1             # a0 = (m == 0)
    ret
inf_false:
    li a0, 0
    ret
# a0 = bf16 bits
# return 1 if (a.bits & 0x7FFF)==0, else 0
bf16_iszero:
    li t0, 0x7FFF
    and a0, a0, t0              # clear sign bit
    sltiu a0, a0, 1             # a0 = (a0 == 0)
    ret
```
:::
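A few sanity checks on well-known bit patterns (a hypothetical host-side harness; the classifiers are copied from the C code above):

```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

typedef struct { uint16_t bits; } bf16_t;
#define BF16_EXP_MASK  0x7F80
#define BF16_MANT_MASK 0x007F

static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK);
}
static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK);
}
static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}

int main(void)
{
    printf("%d\n", bf16_isnan((bf16_t) {.bits = 0x7FC0}));  /* 1: quiet NaN */
    printf("%d\n", bf16_isinf((bf16_t) {.bits = 0x7F80}));  /* 1: +Inf */
    printf("%d\n", bf16_isinf((bf16_t) {.bits = 0x7FC0}));  /* 0: mantissa != 0, so NaN, not Inf */
    printf("%d\n", bf16_iszero((bf16_t) {.bits = 0x8000})); /* 1: -0 counts as zero */
    return 0;
}
```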
### compare function (eq, gt, lt)
:::spoiler c source code
```c=
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return true;
    return a.bits == b.bits;
}
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return false;
    bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1;
    if (sign_a != sign_b)
        return sign_a > sign_b;
    return sign_a ? a.bits > b.bits : a.bits < b.bits;
}
static inline bool bf16_gt(bf16_t a, bf16_t b)
{
    return bf16_lt(b, a);
}
```
:::
:::spoiler rv32i assembly code
```asm=
# a0 = a, a1 = b
# return: a0 = 1(true) / 0(false)
# bf16_isnan(a0)  -> a0=1/0
# bf16_iszero(a0) -> a0=1/0
bf16_eq:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw a0, 8(sp)           # save a
    sw a1, 4(sp)           # save b
    # if (isnan(a)) return 0;
    lw a0, 8(sp)           # a
    jal ra, bf16_isnan
    bnez a0, eq_false
    # if (isnan(b)) return 0;
    lw a0, 4(sp)           # b
    jal ra, bf16_isnan
    bnez a0, eq_false
    # if (iszero(a) && iszero(b)) return 1;
    lw a0, 8(sp)           # a
    jal ra, bf16_iszero
    beqz a0, cmp_bits      # a != 0 -> cmp_bits
    lw a0, 4(sp)           # b
    jal ra, bf16_iszero
    bnez a0, eq_true       # a == b == 0 -> true
cmp_bits:
    # a0 = (a == b) ? 1 : 0 (xor + sltiu) avoids a branch
    lw t0, 8(sp)           # t0 = a
    lw t1, 4(sp)           # t1 = b
    xor t2, t0, t1         # t2 = a ^ b
    sltiu a0, t2, 1        # a0 = (t2 == 0) ? 1 : 0
    j eq_ret
eq_false:
    li a0, 0
    j eq_ret
eq_true:
    li a0, 1
    # fall through to eq_ret
eq_ret:
    lw ra, 12(sp)
    addi sp, sp, 16
    ret
# a0 = a, a1 = b
# return: a0 = 1 if (a < b) else 0
# bf16_isnan(a0)  -> a0=1/0
# bf16_iszero(a0) -> a0=1/0
bf16_lt:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw a0, 8(sp)           # save a
    sw a1, 4(sp)           # save b
    # if (isnan(a) || isnan(b)) return false;
    lw a0, 8(sp)           # a
    jal ra, bf16_isnan
    bnez a0, lt_false
    lw a0, 4(sp)           # b
    jal ra, bf16_isnan
    bnez a0, lt_false
    # if (iszero(a) && iszero(b)) return false;
    lw a0, 8(sp)           # a
    jal ra, bf16_iszero
    beqz a0, sign_cmp      # a != 0 -> sign_cmp
    lw a0, 4(sp)           # b
    jal ra, bf16_iszero
    bnez a0, lt_false      # a == b == 0 -> lt_false
sign_cmp:
    lw t0, 8(sp)           # t0 = a.bits
    lw t1, 4(sp)           # t1 = b.bits
    srli t2, t0, 15
    andi t2, t2, 1         # t2 = sign_a
    srli t3, t1, 15
    andi t3, t3, 1         # t3 = sign_b
    bne t2, t3, diff_sign
    # same sign
    # sign = 0 -> positive: a < b <-> a.bits < b.bits
    beqz t2, pos_cmp
    # sign = 1 -> negative: a < b <-> a.bits > b.bits
    sltu a0, t1, t0        # a0 = (b.bits < a.bits)
    j ret_common
pos_cmp:
    sltu a0, t0, t1        # a0 = (a.bits < b.bits)
    j ret_common
diff_sign:
    # return (sign_a > sign_b)
    sltu a0, t3, t2        # a0 = (sign_b < sign_a)
    j ret_common
lt_false:
    li a0, 0
ret_common:
    lw ra, 12(sp)
    addi sp, sp, 16
    ret
# a0 = a, a1 = b
# return: a0 = 1 if (a > b) else 0
bf16_gt:
    mv t0, a0              # swap a0, a1
    mv a0, a1
    mv a1, t0
    j bf16_lt              # tail call, no stack frame needed
```
:::
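A quick example of the sign-magnitude comparison logic: for $a = -2.0$ (0xC000) and $b = -1.0$ (0xBF80), both signs are 1, and since 0xC000 > 0xBF80 the negative-branch rule `a.bits > b.bits` correctly reports $a < b$; for mixed signs such as $-1.0$ versus $+1.0$ (0xBF80 vs 0x3F80), `sign_a > sign_b` settles the comparison immediately.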
### arithmetic function (add, sub, mul, div, sqrt)
:::spoiler add/sub c source code
```c=
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;
    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;
    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }
    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;
        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }
        if (!result_mant)
            return BF16_ZERO();
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (result_mant & 0x7F),
    };
}
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}
```
:::
:::spoiler add/sub rv32i assembly code
```asm=
# bf16_add(a0=a, a1=b) -> a0
# BF16: [sign:1][exp:8][frac:7]
# const: INF=0x7F80, NAN=0x7FC0, ZERO=0x0000
# regs:
#   t0=a, t1=b
#   t2=sign_a, t3=exp_a, t4=mant_a
#   t5=sign_b, t6=exp_b
#   a1=mant_b(temp), a0=temp/result
.data
cases:
    .half 0x3F80, 0x3F80
    .half 0x4000, 0x4000
    .half 0x4380, 0x4180
    .half 0x4040, 0xBF80
    .half 0x3F80, 0xBF00
    # .half 0x4380, 0x4180 # NaN + 1.0
    # .half 0x3F80, 0x7FC0 # 1.0 + NaN
    # .half 0x7F80, 0x4080 # +Inf + 4.0
    # .half 0xc88b, 0xe9c9 # -1091.375 + -476.78125 = -1568.15625
    # .half 0xfb44, 0xa286 # -0.001953125 + -20.15625 = -20.158203125
out:
    .half 0, 0, 0, 0, 0   # 5 results
.text
main:
    la s0, cases           # s0 -> input pairs (a,b), 4 bytes each
    la s1, out             # s1 -> output, 2 bytes each
    li s2, 5               # loop count
1:  lh a0, 0(s0)           # load a
    lh a1, 2(s0)           # load b
    addi s0, s0, 4         # move to next input
    jal ra, bf16_add       # call bf16_add
    sh a0, 0(s1)           # store result
    addi s1, s1, 2         # move to next output
    addi s2, s2, -1
    bnez s2, 1b
halt:
    j halt                 # infinite loop
bf16_add:
    # save original a, b
    mv t0, a0              # t0 = a
    mv t1, a1              # t1 = b
    # ---------------- Special case: exp_a == 0xFF ----------------
    srli t3, t0, 7         # t3 = exp_a
    andi t3, t3, 0xFF
    li a0, 0xFF
    bne t3, a0, chk_b_ff
    andi t4, t0, 0x7F      # mant_a
    bnez t4, ret_a         # a is NaN
    # a is Inf, check b
    srli t6, t1, 7         # exp_b
    andi t6, t6, 0xFF
    bne t6, a0, ret_a      # b's exp not 0xFF -> return a
    andi a0, t1, 0x7F      # mant_b
    bnez a0, ret_b         # b is NaN -> return b
    # both a and b are Inf: same sign -> return b; diff sign -> NaN
    srli t2, t0, 15
    andi t2, t2, 1
    srli t5, t1, 15
    andi t5, t5, 1
    beq t2, t5, ret_b
    li a0, 0x7FC0          # NaN
    ret
# ---------------- Special case: exp_b == 0xFF ----------------
chk_b_ff:
    srli t6, t1, 7         # t6 = exp_b
    andi t6, t6, 0xFF
    li a0, 0xFF
    bne t6, a0, quick_zero
    andi a0, t1, 0x7F      # mant_b
    bnez a0, ret_b         # b is NaN
    mv a0, t1              # b is Inf
    ret
# ---------------- Fast path: ±0 ----------------
quick_zero:
    # a == ±0 ?
    andi t4, t0, 0x7F      # mant_a
    beqz t3, 1f            # exp_a == 0 ?
    j 2f
1:  beqz t4, ret_b         # a is ±0 -> return b
2:
    # b == ±0 ?
    andi a0, t1, 0x7F      # mant_b (temporarily in a0)
    beqz t6, 3f
    j 4f
3:  beqz a0, ret_a         # b is ±0 -> return a
4:
    # ---------------- Extract sign/exp/mant ----------------
    srli t2, t0, 15        # sign_a
    andi t2, t2, 1
    srli t5, t1, 15        # sign_b
    andi t5, t5, 1
    andi t4, t0, 0x7F      # mant_a
    andi a1, t1, 0x7F      # mant_b
    # normals get the implicit 1
    beqz t3, 5f
    ori t4, t4, 0x80
5:  beqz t6, 6f
    ori a1, a1, 0x80
6:
    # ---------------- Exponent alignment ----------------
    sub a0, t3, t6         # a0 = exp_diff = exp_a - exp_b
    bgtz a0, diff_pos
    bltz a0, diff_neg
    mv a0, t3              # result_exp = exp_a (diff == 0)
    j add_or_sub
# exp_a > exp_b
diff_pos:
    mv a0, t3              # a0 = result_exp = exp_a
    sub t6, t3, t6         # t6 = exp_diff
    li t3, 8
    bgt t6, t3, ret_a      # diff > 8 -> return a
    beqz t6, add_or_sub
    srl a1, a1, t6         # mant_b >>= diff
    j add_or_sub
# exp_a < exp_b
diff_neg:
    mv a0, t6              # a0 = result_exp = exp_b (bug fix: save before t6 is reused below)
    sub t3, t6, t3         # t3 = -exp_diff = exp_b - exp_a
    li t6, 8
    bgt t3, t6, ret_b      # -diff > 8 -> return b
    srl t4, t4, t3         # mant_a >>= -diff
    # fall through to add_or_sub
# ---------------- Same sign -> add; different sign -> sub ----------------
add_or_sub:
    beq t2, t5, same_sign
    # different sign: larger minus smaller
    bgeu t4, a1, 8f
    sub t4, a1, t4         # result_mant = mant_b - mant_a
    mv t2, t5              # result_sign = sign_b
    beqz t4, ret_zero
    j norm_sub
8:
    sub t4, t4, a1         # result_mant = mant_a - mant_b
    # result_sign = sign_a (already in t2)
    beqz t4, ret_zero
norm_sub:                  # normalize (shift left until bit 7 = 1)
    andi t6, t4, 0x80
    bnez t6, pack
    slli t4, t4, 1
    addi a0, a0, -1        # --exp
    blez a0, ret_zero
    j norm_sub
same_sign:
    add t4, t4, a1         # result_mant = mant_a + mant_b
    li t6, 0x100
    and t6, t4, t6         # check carry into bit 8
    beqz t6, pack
    srli t4, t4, 1         # shift right back to 8 bits
    addi a0, a0, 1         # ++exp
    li t6, 0xFF
    blt a0, t6, pack
    # overflow -> ±Inf
    li t6, 0x7F80
    slli t2, t2, 15
    or a0, t2, t6
    ret
# ---------------- Pack back to BF16 ----------------
pack:
    andi t4, t4, 0x7F      # frac
    andi a0, a0, 0xFF      # exp
    slli a0, a0, 7
    slli t2, t2, 15        # sign
    or a0, a0, t4
    or a0, a0, t2
    ret
# ---------------- Fast returns ----------------
ret_a:
    mv a0, t0
    ret
ret_b:
    mv a0, t1
    ret
ret_zero:
    li a0, 0x0000
    ret
# a0=a, a1=b -> a0 = a - b
bf16_sub:
    li t0, 0x8000          # mask: sign bit
    xor a1, a1, t0         # b = b ^ 0x8000 (flip sign bit)
    j bf16_add             # tail call: a + (-b)
```
:::
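Tracing the first test case (0x3F80 + 0x3F80, i.e. 1.0 + 1.0): both mantissas become 0x80 once the implicit 1 is attached, the sum 0x100 carries into bit 8, so the mantissa shifts back to 0x80 and the exponent bumps from 127 to 128, packing to 0x4000 = 2.0.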
:::spoiler mul c source code
```c=
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};
    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;
    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;
    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}
```
:::
:::spoiler mul rv32i assembly code
```asm=
.data
cases:
    .half 0x3FC0, 0x3FC0   # expected 0x4010 (~2.25) ; 1.5 * 1.5
    .half 0x4040, 0x3F00   # expected 0x3FC0 (~1.5)  ; 3.0 * 0.5
    .half 0x7F7F, 0x4000   # expected 0x7F80 (+Inf)  ; max_finite * 2.0
    .half 0x0080, 0x0080   # expected 0x0000 (+0)    ; min_normal * min_normal
    .half 0x0001, 0x3F80   # expected 0x0000 (+0)    ; min_subnormal * 1.0
    .half 0x7F80, 0x0000   # expected 0x7FC0 (NaN)   ; +Inf * +0
    .half 0x7FC1, 0x4000   # expected 0x7FC1 (NaN)   ; NaN(payload 0x01) * 2.0
    .half 0xBFA0, 0x4000   # expected 0xC020 (~-2.5) ; -1.25 * 2.0
    .half 0x0000, 0xC5A6   # expected 0x8000 (-0)    ; +0 * negative
    .half 0x0001, 0x0001   # expected 0x0000 (+0)    ; subnormal * subnormal
out:
    .half 0, 0, 0, 0, 0, 0, 0, 0, 0, 0   # 10 results
.text
# ------------------------------------------------------------
# main: run 10 test pairs in 'cases', write results into 'out'
# ------------------------------------------------------------
main:
    la s0, cases           # s0 -> (a,b) pairs, 4 bytes per pair
    la s1, out             # s1 -> output buffer, 2 bytes per result
    li s2, 10              # number of test pairs
loop:
    lhu a0, 0(s0)          # load a (zero-extend)
    lhu a1, 2(s0)          # load b (zero-extend)
    addi s0, s0, 4         # advance to next pair
    jal ra, bf16_mul       # compute
    sh a0, 0(s1)           # store result
    addi s1, s1, 2         # advance output pointer
    addi s2, s2, -1
    bnez s2, loop
halt:
    j halt
# ------------------------------------------------------------
# mul8x8_u32: shift-add 8x8 unsigned multiply
# IN : a2 = x (8-bit), a3 = y (8-bit)
# OUT: a2 = x*y (lower 16 bits valid)
# Clobbers: t0, t1, t2
# ------------------------------------------------------------
mul8x8_u32:
    li t0, 0
    li t1, 8
mul8_loop:
    andi t2, a2, 1
    beqz t2, mul8_skip_add
    add t0, t0, a3
mul8_skip_add:
    srli a2, a2, 1
    slli a3, a3, 1
    addi t1, t1, -1
    bnez t1, mul8_loop
    mv a2, t0
    ret
# ------------------------------------------------------------
# bf16_mul (a0=a_bits, a1=b_bits) -> a0=result_bits
# RV32I only; truncation (no RNE); handles NaN/Inf/±0 and subnormals
# ------------------------------------------------------------
bf16_mul:
    # extract fields
    srli t0, a0, 15        # sign_a
    andi t0, t0, 1
    srli t1, a1, 15        # sign_b
    andi t1, t1, 1
    xor t2, t0, t1         # result_sign = sign_a ^ sign_b
    slli t2, t2, 15        # (sign << 15)
    mv a5, t2              # save sign in a5 (mul8x8_u32 does not touch a5)
    srli t3, a0, 7         # exp_a
    andi t3, t3, 0xFF
    srli t4, a1, 7         # exp_b
    andi t4, t4, 0xFF
    andi t5, a0, 0x7F      # mant_a
    andi t6, a1, 0x7F      # mant_b
    # preload +Inf pattern for early special paths (reloaded after mul)
    li a2, 0x7F80
    # special cases: exp == 0xFF?
    li t1, 0xFF
    beq t3, t1, special_a
    beq t4, t1, special_b
    # zero short-circuit
    beqz t3, check_a_zero
    j check_b_zero_done
check_a_zero:
    beqz t5, ret_signed_zero
check_b_zero_done:
    beqz t4, check_b_zero
    j norm_inputs
check_b_zero:
    beqz t6, ret_signed_zero
# normalize inputs: subnormals shift until bit7=1; normals add implicit 1
norm_inputs:
    li a4, 0               # exp_adjust = 0
    # operand a
    beqz t3, norm_a_sub
    ori t5, t5, 0x80       # add implicit 1
    j norm_b
norm_a_sub:
    li t1, 0x80
norm_a_loop:
    and t2, t5, t1
    bnez t2, norm_a_done
    slli t5, t5, 1
    addi a4, a4, -1
    j norm_a_loop
norm_a_done:
    li t3, 1               # exp_a = 1
# operand b
norm_b:
    beqz t4, norm_b_sub
    ori t6, t6, 0x80
    j mul_mant
norm_b_sub:
    li t1, 0x80
norm_b_loop:
    and t2, t6, t1         # use t2; do NOT touch a5
    bnez t2, norm_b_done
    slli t6, t6, 1
    addi a4, a4, -1
    j norm_b_loop
norm_b_done:
    li t4, 1               # exp_b = 1
# mantissa multiply (8x8)
mul_mant:
    mv a2, t5
    mv a3, t6
    addi sp, sp, -8
    sw ra, 4(sp)
    jal ra, mul8x8_u32     # clobbers t0, t1, t2 and a2/a3
    lw ra, 4(sp)
    addi sp, sp, 8
    mv t5, a2              # result_mant = product
    # reload after call (callee clobbered these)
    li t0, 127             # bias
    li a2, 0x7F80          # +Inf pattern for ret_inf
    # result_exp = exp_a + exp_b - bias + exp_adjust
    add t6, t3, t4
    sub t6, t6, t0
    add t6, t6, a4
    # normalize product to 1.x or 2.x; keep 7 fraction bits (truncate)
    li t1, 0x8000
    and a4, t5, t1
    beqz a4, norm_prod_1x
    # 2.x: (prod >> 8) & 0x7F; exp++
    srli t5, t5, 8
    andi t5, t5, 0x7F
    addi t6, t6, 1
    j check_over_under
norm_prod_1x:
    # 1.x: (prod >> 7) & 0x7F
    srli t5, t5, 7
    andi t5, t5, 0x7F
# overflow / underflow / subnormal
check_over_under:
    li t1, 255
    bge t6, t1, ret_inf    # overflow -> ±Inf
    # underflow if exp <= 0
    bge zero, t6, underflow_path
    j pack
underflow_path:
    # exp < -6 -> ±0
    li t1, -6
    blt t6, t1, ret_signed_zero
    # subnormal output: mant >>= (1 - exp); exp = 0
    li t1, 1
    sub t1, t1, t6
    srl t5, t5, t1
    li t6, 0
    j pack
# pack back to bf16
pack:
    andi t6, t6, 0xFF
    slli t6, t6, 7
    andi t5, t5, 0x7F
    or a0, a5, t6          # use saved sign in a5
    or a0, a0, t5
    ret
# return ±Inf
ret_inf:
    or a0, a5, a2          # a2 = 0x7F80
    ret
# return signed zero
ret_signed_zero:
    mv a0, a5              # sign | 0
    ret
# ------------------------------------------------------------
# special cases
# special_a: a is Inf/NaN (exp_a == 0xFF)
# special_b: b is Inf/NaN (exp_b == 0xFF)
# ------------------------------------------------------------
# a is special:
#   if (mant_a) return a;   # NaN -> propagate operand
#   if (b == 0) return NaN;
#   return ±Inf;
special_a:
    bnez t5, ret_a
    beqz t4, chk_b_zero2
    j ret_inf
chk_b_zero2:
    bnez t6, ret_inf
    li a0, 0x7FC0          # canonical NaN
    ret
ret_a:
    ret                    # a0 still holds a
# b is special:
#   if (mant_b) return b;   # NaN -> propagate operand
#   if (a == 0) return NaN;
#   return ±Inf;
special_b:
    bnez t6, ret_b
    beqz t3, chk_a_zero2
    j ret_inf
chk_a_zero2:
    bnez t5, ret_inf
    li a0, 0x7FC0          # canonical NaN
    ret
ret_b:
    mv a0, a1              # return b as-is
    ret
```
:::
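Tracing the first test case (0x3FC0 * 0x3FC0, i.e. 1.5 * 1.5): the shift-add multiply gives 0xC0 * 0xC0 = 0x9000; bit 15 is set, so the product is a 2.x value, the fraction becomes (0x9000 >> 8) & 0x7F = 0x10, and the exponent is 127 + 127 - 127 + 1 = 128, packing to 0x4010 = 2.25 as expected.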
:::spoiler div c source code
```c=
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    if (!exp_b && !mant_b) {
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;
    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;
    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }
    int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
    if (!exp_a)
        result_exp--;
    if (!exp_b)
        result_exp++;
    if (quotient & 0x8000)
        quotient >>= 8;
    else {
        while (!(quotient & 0x8000) && result_exp > 1) {
            quotient <<= 1;
            result_exp--;
        }
        quotient >>= 8;
    }
    quotient &= 0x7F;
    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (quotient & 0x7F),
    };
}
```
:::
:::spoiler div rv32i assembly code
```asm=
.data
# ------------------------------------------------------------
# 10 pairs of (a,b) in BF16 (little-endian .half)
# ------------------------------------------------------------
cases:
    # .half 0x3F80, 0x4000 # 1.0 / 2.0 -> 0x3F00 (0.5)
    # .half 0x4000, 0x3F80 # 2.0 / 1.0 -> 0x4000 (2.0)
    # .half 0x4040, 0x4000 # 3.0 / 2.0 -> 0x3FC0 (1.5)
    # .half 0x0000, 0x4000 # 0.0 / 2.0 -> 0x0000 (+0.0)
    # .half 0x3F80, 0x0000 # 1.0 / 0.0 -> 0x7F80 (+Inf)
    # .half 0x7F80, 0x4000 # +Inf / 2.0 -> 0x7F80 (+Inf)
    # .half 0x4000, 0x7F80 # 2.0 / +Inf -> 0x0000 (+0.0)
    # .half 0x7FC1, 0x4000 # NaN / 2.0 -> 0x7FC1 (NaN, payload pass)
    # .half 0x0040, 0x3F80 # subnormal_a / 1.0 -> 0x0000 (underflow -> 0)
    # .half 0x3F80, 0x0001 # 1.0 / tiny subnormal -> 0x7F80 (+Inf)
    .half 0x4040, 0x4000   # 3.0 / 2.0 -> 0x3FC0 (1.5)
    .half 0x40A0, 0x4080   # 5.0 / 4.0 -> 0x3FA0 (1.25)
    .half 0x4110, 0x4100   # 9.0 / 8.0 -> 0x3F90 (1.125)
    .half 0x40E0, 0x4100   # 7.0 / 8.0 -> 0x3F70 (0.875)
    .half 0x40C0, 0x4100   # 6.0 / 8.0 -> 0x3F40 (0.75)
    .half 0x40A0, 0x4100   # 5.0 / 8.0 -> 0x3F20 (0.625)
    .half 0x40A0, 0x4000   # 5.0 / 2.0 -> 0x4020 (2.5)
    .half 0x4200, 0x4100   # 32 / 8   -> 0x4080 (4.0)
    .half 0x4140, 0x4040   # 12 / 3   -> 0x4080 (4.0)
    .half 0x4120, 0x4080   # 10 / 4   -> 0x4020 (2.5)
out:
    .half 0, 0, 0, 0, 0, 0, 0, 0, 0, 0   # 10 results
.text
# ------------------------------------------------------------
# main: run 10 test pairs in 'cases', write results into 'out'
# ------------------------------------------------------------
main:
    la s0, cases           # s0 -> (a,b) pairs, 4 bytes per pair
    la s1, out             # s1 -> output buffer, 2 bytes per result
    li s2, 10              # number of test pairs
loop:
    lhu a0, 0(s0)          # load a (zero-extend)
    lhu a1, 2(s0)          # load b (zero-extend)
    addi s0, s0, 4         # advance to next pair
    jal ra, bf16_div       # compute
    sh a0, 0(s1)           # store result
    addi s1, s1, 2         # advance output pointer
    addi s2, s2, -1
    bnez s2, loop
halt:
    j halt
# ------------------------------------------------------------
# bf16_div (a0=a_bits, a1=b_bits) -> a0=result_bits
# RV32I only; restoring division on mantissas; truncation (no RNE)
# Handles NaN/Inf/±0 and subnormals; underflow -> ±0 (no gradual subnormals)
# BF16: [sign:1][exp:8][frac:7], bias=127
# Note: clobbers s3 without saving it; acceptable here only because
#       main never uses s3.
# ------------------------------------------------------------
bf16_div:
    # ------- unpack a -------
    srli t2, a0, 15        # t2 = sign_a
    andi t2, t2, 1
    srli t3, a0, 7         # t3 = exp_a
    andi t3, t3, 0xFF
    andi t4, a0, 0x7F      # t4 = mant_a
    mv s3, t3              # s3 = exp_a
    # ------- unpack b -------
    srli t5, a1, 15        # t5 = sign_b
    andi t5, t5, 1
    srli t6, a1, 7         # t6 = exp_b
    andi t6, t6, 0xFF
    andi a2, a1, 0x7F      # a2 = mant_b
    # result_sign = sign_a ^ sign_b
    xor t0, t2, t5         # t0 = result_sign
    # consts
    li a3, 0xFF            # 0xFF
    li a4, 0x7F80          # +Inf pattern
    li a5, 127             # bias
    # ------- b is Inf/NaN? -------
    beq t6, a3, b_inf_nan
    # ------- b == 0 ? -------
    beqz t6, b_zero_check
    j a_inf_nan_check
b_zero_check:
    beqz a2, div_by_zero   # x / 0
    # b subnormal -> continue
    j a_inf_nan_check
div_by_zero:
    # 0/0 -> NaN ; x/0 -> ±Inf
    beqz s3, a_zero_chk_for_00
    j ret_signed_inf
a_zero_chk_for_00:
    beqz t4, ret_nan       # 0/0 -> NaN
    j ret_signed_inf
b_inf_nan:
    # b is Inf/NaN
    bnez a2, ret_b         # b is NaN -> return b
    # b is Inf: x/Inf -> signed zero; Inf/Inf handled below
    beq s3, a3, a_inf_then_nan
    slli a0, t0, 15        # signed zero
    ret
a_inf_then_nan:
    beqz t4, ret_nan       # Inf/Inf -> NaN
    ret                    # a is NaN (exp 0xFF, mant != 0) -> return a as-is
# ------- a Inf/NaN? -------
a_inf_nan_check:
    bne s3, a3, a_zero_check
    bnez t4, ret_a         # a is NaN -> return a
    # a is Inf ; Inf/finite -> ±Inf
ret_signed_inf:
    slli t1, t0, 15
    or a0, t1, a4
    ret
ret_a:
    ret                    # a0 still holds a
# ------- a == 0 ? -------
a_zero_check:
    beqz s3, a_zero_exp
    j norm_mantissas
a_zero_exp:
    beqz t4, ret_signed_zero   # 0/x -> ±0
    # a subnormal -> continue
    j norm_mantissas
ret_signed_zero:
    slli a0, t0, 15
    ret
# ------- add hidden 1 for normals -------
norm_mantissas:
    beqz s3, 1f
    ori t4, t4, 0x80       # mant_a |= 0x80
1:  beqz t6, 2f
    ori a2, a2, 0x80       # mant_b |= 0x80
2:
    # ------- restoring division (16 bits of quotient) -------
    slli t1, t4, 15        # t1 = dividend
    slli t2, a2, 15        # t2 = d (divisor aligned)
    li t3, 0               # t3 = quotient
    li a7, 16
div_loop:
    slli t3, t3, 1         # quotient <<= 1
    bltu t1, t2, no_sub
    sub t1, t1, t2         # dividend -= d
    ori t3, t3, 1          # quotient |= 1
no_sub:
    srli t2, t2, 1         # d >>= 1
    addi a7, a7, -1
    bnez a7, div_loop
    # ------- result_exp = exp_a - exp_b + bias (+ subnormal adjust) -------
    sub t1, s3, t6         # t1 = exp_a - exp_b
    add t1, t1, a5         # t1 += bias
    beqz s3, 3f
    j 4f
3:  addi t1, t1, -1        # a subnormal -> --exp
4:  beqz t6, 5f
    j 6f
5:  addi t1, t1, 1         # b subnormal -> ++exp
6:
    # ------- normalize quotient to 1.x, then drop hidden 1 and truncate -------
    # if (q & 0x8000) q >>= 8; else { while (!(q & 0x8000) && exp > 1) { q <<= 1; exp--; } q >>= 8; }
    li t2, 0x8000
    and t4, t3, t2
    bnez t4, q_has_one
q_need_shift:
    and t4, t3, t2         # test (q & 0x8000)
    bnez t4, q_align_done
    li a6, 1
    ble t1, a6, q_align_done
    slli t3, t3, 1         # q <<= 1
    addi t1, t1, -1        # exp--
    j q_need_shift
q_align_done:
    srli t3, t3, 8         # drop to frac range
    j after_q_align
q_has_one:
    srli t3, t3, 8
after_q_align:
    andi t3, t3, 0x7F      # keep 7-bit fraction
    # ------- overflow/underflow checks -------
    ble t1, zero, ret_signed_zero  # exp <= 0 -> ±0 (no subnormals generated)
    li t2, 255
    bge t1, t2, ret_signed_inf     # exp >= 255 -> ±Inf
    # ------- pack sign|exp|frac -------
    slli t0, t0, 15        # sign bit
    andi t1, t1, 0xFF      # clamp exp
    slli t1, t1, 7         # exp << 7
    or t0, t0, t1
    or a0, t0, t3
    ret
# ------- quick returns -------
ret_b:
    mv a0, a1
    ret
ret_nan:
    li a0, 0x7FC0          # canonical qNaN
    ret
```
:::
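Tracing the first active test case (0x4040 / 0x4000, i.e. 3.0 / 2.0): with the hidden bits attached, the restoring division computes 0xC0 / 0x80 over 16 quotient bits, giving q = 0xC000 (1.5 in Q1.15); bit 15 is set, so q >> 8 = 0xC0, the kept fraction is 0x40, and the exponent is 128 - 128 + 127 = 127, packing to 0x3FC0 = 1.5.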
:::spoiler sqrt c source code
```c=
static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;
    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a; /* sqrt(+Inf) = +Inf */
    }
    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();
    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();
    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();
    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;
    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }
    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */
    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;     /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;   /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128; /* Default */
    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */
        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */
    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }
    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;
    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();
    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
:::
:::spoiler sqrt rv32i assembly code
```asm=
.data
# ------------------------------------------------------------
# BF16 inputs (.half, little-endian) for sqrt; 5 active cases,
# more kept commented out
# rules:
#   NaN propagation
#   sqrt(+Inf) -> +Inf
#   sqrt(neg)  -> NaN
#   sqrt(subnormal) -> 0 (flush)
# ------------------------------------------------------------
cases:
    # .half 0x3F80 # sqrt(1.0) -> 1.0 (0x3F80)
    # .half 0x4080 # sqrt(4.0) -> 2.0 (0x4000)
    # .half 0x4010 # sqrt(2.25) -> 1.5 (0x3FC0)
    # .half 0x3F00 # sqrt(0.5) -> ~0.7071 (≈0x3F35)
    # .half 0x0000 # sqrt(+0.0) -> +0.0 (0x0000)
    # .half 0x7F80 # sqrt(+Inf) -> +Inf (0x7F80)
    # .half 0x7FC1 # sqrt(NaN) -> NaN (payload preserved) (0x7FC1)
    # .half 0x0001 # sqrt(tiny subnormal) -> 0 (flush to 0)
    # .half 0x4110 # sqrt(9.0) -> 3.0 (0x4040)
    # .half 0xBF80 # sqrt(-1.0) -> NaN (0x7FC0)
    .half 0x3F10   # 0.5625 -> 0.75 (0x3F40)
    .half 0x3EC8   # 0.390625 -> 0.625 (0x3F20)
    .half 0x4010   # 2.25 -> 1.5 (0x3FC0)
    .half 0x40C8   # 6.25 -> 2.5 (0x4020)
    .half 0x3C80   # 1/64 = 0.015625 -> 0.125 (0x3E00)
out:
    .half 0, 0, 0, 0, 0, 0, 0, 0, 0, 0   # result buffer (10 slots, 5 used)
.text
# ------------------------------------------------------------
# main: run the 5 active inputs in 'cases', write results into 'out'
# a0 = input_bits, ret a0 = result_bits
# ------------------------------------------------------------
main:
    la s0, cases           # s0 -> inputs (2 bytes per case)
    la s1, out             # s1 -> output buffer
    li s2, 5               # number of tests
loop:
    lhu a0, 0(s0)          # load input (zero-extend)
    addi s0, s0, 2
    jal ra, bf16_sqrt      # compute sqrt
    sh a0, 0(s1)           # store result
    addi s1, s1, 2
    addi s2, s2, -1
    bnez s2, loop
halt:
    j halt
# ------------------------------------------------------------
# mul8x8_u32: 8x8 unsigned multiply (RV32I shift-add)
# IN : a2 = x (8-bit), a3 = y (8-bit)
# OUT: a2 = x*y (low 16 bits valid)
# Clobbers: t0, t1, t2
# ------------------------------------------------------------
mul8x8_u32:
    li t0, 0
    li t1, 8
mul_loop:
    andi t2, a2, 1
    beqz t2, mul_skip_add
    add t0, t0, a3
mul_skip_add:
    srli a2, a2, 1
    slli a3, a3, 1
    addi t1, t1, -1
    bnez t1, mul_loop
    mv a2, t0
    ret
# ------------------------------------------------------------
# bf16_sqrt (a0=input_bits) -> a0=result_bits
# RV32I only; truncation (no RNE); NaN/Inf/±0/neg/subnormals handled.
# Mantissa scale = 128 (1.0 -> 128)
# ------------------------------------------------------------
bf16_sqrt:
    # ------------- parse fields -------------
    srli t0, a0, 15        # t0 = sign
    andi t0, t0, 1
    srli t1, a0, 7         # t1 = exp (8-bit)
    andi t1, t1, 0xFF
    andi t2, a0, 0x7F      # t2 = mant (7-bit)
    # ------------- NaN, Inf -------------
    li t3, 0xFF
    bne t1, t3, not_inf_nan       # exp == 0xFF means NaN/Inf
    bnez t2, ret_nan_payload      # NaN: payload propagation -> return a0
    bnez t0, ret_qnan             # -Inf -> NaN
    ret                           # +Inf -> +Inf
not_inf_nan:
    beqz t1, exp_zero_or_subnorm  # exp == 0 ?
    j check_negative
exp_zero_or_subnorm:
    beqz t2, ret_zero      # ±0 -> +0
    li a0, 0               # subnormal: flush to zero
    ret
check_negative:
    beqz t0, core          # non-zero & non-negative -> core
ret_qnan:
    li a0, 0x7FC0          # quiet NaN
    ret
core:
    # e = exp - 127
    addi t3, t1, -127      # t3 = e (signed)
    # m = 0x80 | mant
    ori t4, t2, 0x80       # t4 = m (uint32)
    # if (e & 1) { m <<= 1; new_exp = ((e-1)>>1) + 127; } else { new_exp = (e>>1) + 127; }
    andi t2, t3, 1
    beqz t2, exp_even
    slli t4, t4, 1         # m <<= 1
    addi t5, t3, -1
    srai t5, t5, 1
    addi t5, t5, 127       # t5 = new_exp
    j after_exp_adjust
exp_even:
    mv t5, t3
    srai t5, t5, 1
    addi t5, t5, 127       # t5 = new_exp
after_exp_adjust:
    # ------------- binary search, result in [128..255] -------------
    # high = 255 keeps the result <= 255, so the C code's ">= 256"
    # renormalization step is not needed here
    li t6, 90              # low
    li a5, 255             # high
    li a1, 0               # result
binsearch_loop:
    bltu a5, t6, binsearch_done   # if (high < low) break
    add t1, t6, a5         # mid = (low + high) >> 1
    srli t1, t1, 1
    # --- sq = (mid*mid) >> 7 (scale to 128)
    mv a4, t1              # back up mid in a4 (mul clobbers t1)
    mv a2, t1              # a2 = mid
    mv a3, t1              # a3 = mid
    addi sp, sp, -8
    sw ra, 4(sp)
    jal ra, mul8x8_u32     # a2 = mid*mid (low 16 bits); clobbers t0, t1, t2
    lw ra, 4(sp)
    addi sp, sp, 8
    mv t1, a4              # restore mid
    mv t2, a2
    srli t2, t2, 7         # t2 = sq
    # if (sq <= m) { result = mid; low = mid + 1; } else { high = mid - 1; }
    bgtu t2, t4, shrink_high
    mv a1, t1              # result = mid
    addi t6, t1, 1         # low = mid + 1
    j binsearch_loop
shrink_high:
    addi a5, t1, -1        # high = mid - 1
    j binsearch_loop
binsearch_done:
    # ------------- pack back to BF16 -------------
    andi a1, a1, 0x7F      # new_mant = result & 0x7F
    slli t5, t5, 7         # new_exp << 7
    or a0, t5, a1          # sign = 0
    ret
# ----------- quick returns for special values -----------
ret_nan_payload:
    ret                    # NaN payload propagation -> return a0
ret_zero:
    li a0, 0               # +0
    ret
```
:::
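Tracing the third active test case (0x4010, i.e. sqrt(2.25)): exp = 128 gives e = 1, which is odd, so m = (0x80 | 0x10) << 1 = 0x120 = 288 and new_exp = ((1 - 1) >> 1) + 127 = 127; the binary search finds the largest mid with (mid*mid) >> 7 <= 288, namely mid = 192 (since (192*192) >> 7 = 288), so the fraction is 192 & 0x7F = 0x40 and the result packs to 0x3FC0 = 1.5.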