# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by [rrrchii](https://github.com/rrrchii/ca2025-quizzes/tree/main)
---
## Quiz 1 Prob B : UF8 ↔ Unsigned Integer Conversion
---
### Overview
**This problem implements a codec that converts between a 20-bit unsigned integer and the 8-bit uf8 format.**
**Input:** 20-bit unsigned integer
**Output:** 8-bit uf8 value that packs a 4-bit exponent (e) and a 4-bit mantissa (m)
```c
[ e e e e ][ m m m m ]
high 4 low 4
```
### Decode (uf8 → uint32)
---
Given an 8-bit uf8 value $b$ with exponent $e$ (high 4 bits) and mantissa $m$ (low 4 bits):
- $\text{offset}(e) = (2^e - 1) \cdot 16$
- $D(b) = m \cdot 2^e + (2^e - 1) \cdot 16$
:::spoiler C Code
```c
uint32_t uf8_decode(uf8 fl)
{
uint32_t mantissa = fl & 0x0f; // low 4 bits: mantissa
uint8_t exponent = fl >> 4; // high 4 bits: exponent
uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; // (2^e - 1) * 16
/* e = 0: (0x7FFF >> 15) << 4 = 0x0 << 4 = 0
e = 1: (0x7FFF >> 14) << 4 = 0x1 << 4 = 16
e = 2: (0x7FFF >> 13) << 4 = 0x3 << 4 = 48 ... */
return (mantissa << exponent) + offset; // m << e means m * 2^e
}
```
:::
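
As a quick sanity check on the decode formula, the sketch below compares the closed form $D(b)$ with the shift-based offset used in `uf8_decode`. The function names `uf8_decode_formula`, `uf8_decode_shift`, and `uf8_decode_self_check` are hypothetical and exist only for this illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Closed form from the text: D(b) = m * 2^e + (2^e - 1) * 16 */
uint32_t uf8_decode_formula(uint8_t b)
{
    uint32_t m = b & 0x0F;
    uint32_t e = b >> 4;
    return m * (1u << e) + ((1u << e) - 1u) * 16u;
}

/* Shift-based variant, mirroring the offset trick in uf8_decode */
uint32_t uf8_decode_shift(uint8_t b)
{
    uint32_t m = b & 0x0F;
    uint32_t e = b >> 4;
    uint32_t offset = (0x7FFFu >> (15 - e)) << 4; /* (2^e - 1) * 16 */
    return (m << e) + offset;
}

/* Both variants should agree on all 256 uf8 codes */
void uf8_decode_self_check(void)
{
    for (int i = 0; i < 256; i++)
        assert(uf8_decode_formula((uint8_t) i) == uf8_decode_shift((uint8_t) i));
}
```

Calling `uf8_decode_self_check()` from a test `main` exercises every code; for instance, `0x21` decodes to 52 under both variants.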
### Encode (uint32 → uf8)
---
Given an unsigned integer $v$ :
$E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases}$
where $\text{offset}(e) = (2^e - 1) \cdot 16$
:::spoiler C Code
```c
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value; // if value < 16 return value
/* Find appropriate exponent using CLZ hint */
int lz = clz(value); // counting leading zeros
int msb = 31 - lz; // get MSB
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0; // overflow is offset(e)
if (msb >= 5) {
/* Estimate exponent - the formula is empirical */
exponent = msb - 4; // e = msb - 4
if (exponent > 15)
exponent = 15; // max exponent is 15
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* e = 0, overflow = 16
e = 1, overflow = 16 * 2 + 16 = 48
e = 2, overflow = 48 * 2 + 16 = 112 ... */
/* Adjust down if estimate was off */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Adjust up if estimate was too small */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
// now offset(e) <= value < offset(e + 1) holds (or e == 15), so the exponent is exact
uint8_t mantissa = (value - overflow) >> exponent;
return (exponent << 4) | mantissa;
}
/* Test encode/decode round-trip */
static bool test(void)
{
int32_t previous_value = -1;
bool passed = true;
for (int i = 0; i < 256; i++) { // range for all possible uf8 value
uint8_t fl = i;
int32_t value = uf8_decode(fl); // decode the 8 bits value
uint8_t fl2 = uf8_encode(value); // encode it back
if (fl != fl2) {
printf("%02x: produces value %d but encodes back to %02x\n", fl,value, fl2);
passed = false; // mark the test as failed
}
if (value <= previous_value) { // verify that decoded values are strictly increasing
printf("%02x: value %d <= previous_value %d\n", fl, value,
previous_value);
passed = false;
}
previous_value = value;
}
return passed;
}
```
:::
### Note
---
- **Test Data**
- These values are the exact integers that the UF8 format can represent without any loss after a round-trip conversion.
```
0,1,2,...,15, //step = 1
16,18,20,...,46, //step = 2
48,52,56,...,108, //step = 4
112,120,128,...,232, //step = 8
240,256,272,...,480, //step = 16
... (continues until exponent = 15)
```
- **Encoding and Decoding Behavior**
- If the unsigned integer `v < 16`, return `v` directly.
- Example: `v = 7 → 0x07`
Decode :
$m = 7, e = 0 ⇒ D(0x07) = 7 \cdot2^0+(2^0-1)\cdot16 = 7$
- For larger values, the exponent `e` determines the segment.
- Example: `v = 52 = 0b110100`
- MSB = 5
- Initial estimate: e = 5 - 4 = 1
- Segment offset: offset(e = 1) = 16
- Next offset: offset(e = 2) = 48, and 52 ≥ 48
- Adjust the exponent up to e = 2
- New segment: 48 ≤ v < 112
- Encoding :
$m = \lfloor(v - \text{offset}(e))/2^e\rfloor=\lfloor(52 - 48)/2^2\rfloor = 1$
$E(52)= \ \ (e<<4) \ | \ m \ \ = \ \ (2<<4) \ | \ 1 \ = \ 0x21$
Decoding :
$m = 1, e = 2 ⇒ D(0x21) = 1 \cdot2^2+(2^2-1)\cdot16 = 4+48=52$
(52 is exactly representable, so the round trip is lossless. In general the mantissa is truncated, so the decoded value is never greater than the original: $D(E(v)) \le v$. For example, $v = 53$ also encodes to $0x21$ and decodes to 52.)
- **Exponent and Segment Relationship**
- Each exponent e defines a segment of length: $16 × 2^e$
- Larger exponents correspond to larger segment lengths.
- **Immediate Range**
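- A RISC-V I-type instruction (e.g. `addi`, `andi`) carries a 12-bit signed immediate, so the representable immediate range is $-2048$ to $2047$. Larger constants such as `0x7FFF` have to be loaded with `li`.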
- **Rationale for Encoding After Decoding in Tests**
- The test decodes every uf8 code and encodes the result back, so each code is checked against a value that uf8 can represent exactly.
- The loop bound is 256 because a uf8 code is one byte, so iterating over 0–255 covers every possible uf8 value.
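
The segment structure and the lossy round trip described in the notes above can be sketched in a few lines of C. This is a minimal illustration, not the CLZ-based encoder from the quiz: `uf8_offset`, `uf8_decode_ref`, and `uf8_encode_scan` are hypothetical names, the encoder finds the exponent by a plain linear scan, and the clamp for values beyond the largest code is an added assumption.

```c
#include <assert.h>
#include <stdint.h>

/* offset(e) = (2^e - 1) * 16: first value of segment e */
uint32_t uf8_offset(unsigned e)
{
    return ((1u << e) - 1u) * 16u;
}

/* Reference decoder: D(b) = m * 2^e + offset(e) */
uint32_t uf8_decode_ref(uint8_t b)
{
    unsigned e = b >> 4;
    return ((uint32_t) (b & 0x0F) << e) + uf8_offset(e);
}

/* Simplified encoder: pick the largest e with offset(e) <= v,
 * then truncate (v - offset(e)) / 2^e to get the mantissa. */
uint8_t uf8_encode_scan(uint32_t v)
{
    unsigned e = 0;
    while (e < 15 && v >= uf8_offset(e + 1))
        e++;
    uint32_t m = (v - uf8_offset(e)) >> e;
    if (m > 15)
        m = 15; /* assumption: clamp inputs above the largest code */
    return (uint8_t) ((e << 4) | m);
}

void uf8_note_self_check(void)
{
    /* Segment starts: 0, 16, 48, 112, 240, ... */
    assert(uf8_offset(0) == 0 && uf8_offset(1) == 16 && uf8_offset(2) == 48);
    assert(uf8_offset(3) == 112 && uf8_offset(4) == 240);

    /* 52 is representable; 53 is truncated down to 52 */
    assert(uf8_encode_scan(52) == 0x21 && uf8_decode_ref(0x21) == 52);
    assert(uf8_encode_scan(53) == 0x21);

    /* Every code survives decode followed by encode */
    for (int i = 0; i < 256; i++)
        assert(uf8_encode_scan(uf8_decode_ref((uint8_t) i)) == (uint8_t) i);
}
```

The linear scan trades speed for readability; the CLZ-based version reaches the same exponent with fewer iterations.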
### Implementation
---
#### Execution info for C code (using rv64i compiler)
I could not find a `riscv32-unknown-elf-gcc.exe` build for the Windows environment, so the rv64 compiler was used instead.
Is there a way to get a 32-bit toolchain on Windows?
I installed `riscv64-unknown-elf-gcc.exe` by following this page:
https://hackmd.io/@accomdemy/BJprQ8Xjc/https%3A%2F%2Fhackmd.io%2F%40accomdemy%2FSyoatR-sc

#### [RISC-V](https://github.com/rrrchii/ca2025-quizzes/blob/main/uf8_to_uint32_last_version.s)
:::spoiler Assembly code
```
.data
text1: .asciz ": produces value "
text2: .asciz " but encodes back to "
text3: .asciz ": value "
text4: .asciz " <= previous_value "
text5: .asciz "tests pass."
text6: .asciz "some tests failed."
newline: .asciz "\n"
.text
.globl main
main:
jal ra, test
beq a0, x0, fail # a0 is passed flag if (passed == 0) means check fail
la a0, text5
li a7, 4
ecall # print "tests pass."
la a0, newline
li a7, 4
ecall # print "\n"
li a0, 0 # a0 = 0 means success
li a7, 10
ecall # exit
test:
addi sp, sp, -4
sw ra, 0(sp) # protect ra
addi s0, x0, -1 # previous value
li s1, 1 # default: s1 (passed flag) = 1 means true
li s2, 0 # counter i
li s3, 256 # end of counter i
test_loop:
bgeu s2, s3, return_passed
mv a0, s2
jal ra, uf8_decode
mv s4, a0 # s4 decode value
jal ra, uf8_encode
mv s5, a0 # s5 = value after decode and encode
check_1:
beq s2, s5, check_2 # if s2 == s5 mean check_1 success
mv a0, s2
li a7, 34
ecall # print "i" in hex
la a0, text1
li a7, 4
ecall # print "produces value"
mv a0, s4
li a7, 1
ecall # print "value" in dec
la a0, text2
li a7, 4
ecall # print "but encodes back to"
mv a0, s5
li a7, 34
ecall # print the "value" that has been decode and encode in hex
la a0, newline
li a7, 4
ecall # print "\n" for newline
li s1, 0 # set s1 = 0 means passed = false
check_2:
bgt s4, s0, check_done # skip if value > previous_value
mv a0, s2
li a7, 34
ecall # print "i" in hex
la a0, text3
li a7, 4
ecall # print ": value "
mv a0, s4
li a7, 1
ecall # print "value" in dec
la a0, text4
li a7, 4
ecall # print " <= previous_value "
mv a0, s0
li a7, 1
ecall # print "previous_value" in dec
la a0, newline
li a7, 4
ecall # print newline
li s1, 0 # set s1 = 0 means passed = false
check_done:
mv s0, s4 # previous_value = value
mv a0, s1 # return passed
addi s2, s2, 1
j test_loop
fail:
la a0, text6
li a7, 4
ecall # print "some tests failed."
la a0, newline
li a7, 4
ecall # print "\n"
li a0, 1 # a0 return 1 means fail
li a7, 10
ecall # exit
return_passed:
lw ra, 0(sp) # load back the return address from main:
addi sp, sp, 4
ret
clz:
beq a0, x0, zero # if (x == 0) return 32
li t0, 0 # n = 0
srli t1, a0, 16
bne t1, x0, chk8
addi t0, t0, 16
slli a0, a0, 16 # if ((x >> 16) == 0) { n += 16; x <<= 16; }
chk8:
srli t1, a0, 24
bne t1, x0, chk4
addi t0, t0, 8
slli a0, a0, 8 # if ((x >> 24) == 0) { n += 8; x <<= 8; }
chk4:
srli t1, a0, 28
bne t1, x0, chk2
addi t0, t0, 4
slli a0, a0, 4 # if ((x >> 28) == 0) { n += 4; x <<= 4; }
chk2:
srli t1, a0, 30
bne t1, x0, chk1
addi t0, t0, 2
slli a0, a0, 2 # if ((x >> 30) == 0) { n += 2; x <<= 2; }
chk1:
srli t1, a0, 31
bne t1, x0, ret
addi t0, t0, 1 # if ((x >> 31) == 0) { n += 1; }
ret:
mv a0, t0 # return n
ret
zero:
li a0, 32
ret
uf8_decode:
andi t0, a0, 0x0F # t0 is mantissa
srli t1, a0, 4 # t1 is exponent
li t2, 15
sub t2, t2, t1 # t2 = 15 - exponent
li t3, 0x7FFF
srl t3, t3, t2
slli t3, t3, 4 # t3 is offset
sll t0, t0, t1 # mantissa << exponent
add a0, t0, t3 # fl = (mantissa << exponent) + offset
ret
uf8_encode:
addi sp, sp, -4
sw ra, 0(sp) # because it will call CLZ in this function
li t4, 16 # t4 = 16
bltu a0, t4, fast_return # if( value < 16 ) return value;
mv t5, a0 # t5 = value (kept in t5 across the clz call)
jal ra, clz
mv t0, a0 # t0 = lz
li t1, 31
sub t1, t1, t0 # t1 = msb
li t2, 0 # exponent = 0
li t3, 0 # overflow = 0
li t4, 5 # t4 = 5
blt t1, t4, after_adjust_down # if (msb < 5) skip the estimate
addi t2, t1, -4 # exponent = msb - 4
li t4, 15 # t4 = 15
bleu t2, t4, exp_ok # if (exponent > 15)
li t2, 15 # exponent = 15
exp_ok:
li t4, 0 # t4 = counter e = 0
build_overflow:
bgeu t4, t2, adjust_overflow
slli t3, t3, 1 # overflow <<= 1
addi t3, t3, 16 # (overflow << 1) + 16
addi t4, t4, 1 # e ++
j build_overflow
adjust_overflow:
adjust_down:
beq t2, x0, after_adjust_down # exp > 0
bgeu t5, t3, after_adjust_down # value < overflow
addi t3, t3, -16 # overflow - 16
srli t3, t3, 1 # (overflow - 16) >> 1
addi t2, t2, -1 # exp --
j adjust_down
after_adjust_down:
li t6, 15
li t4, 0
adjust_up:
bgeu t2, t6, after_estimate_overflow
slli t4, t3, 1 # next_overflow = (overflow << 1)
addi t4, t4, 16 # next_overflow = (overflow << 1) + 16
blt t5, t4, after_estimate_overflow
mv t3, t4 # overflow = next_overflow
addi t2, t2, 1 # exp ++
j adjust_up
after_estimate_overflow:
ret_val:
sub t5, t5, t3
srl t5, t5, t2 # t5 is mantissa
slli a0, t2, 4
or a0, a0, t5
fast_return:
lw ra, 0(sp)
addi sp, sp, 4
ret
```
:::
#### Execution info for RISC-V code

### Improvement of CLZ
---
#### Analysis
Both functions compute the number of leading zeros in a 32-bit integer. The key differences are control flow and dependencies.
#### Looped version (`do { ... } while (c)`)
- Each iteration does:
1) `y = x >> c`
2) branch on `if (y)` (data-dependent)
3) `c >>= 1`
4) loop back branch `while (c)` (**backward branch**)
- Up to 5 iterations (c=16,8,4,2,1), **2 branches per iteration** → many branches including a back-edge.
- `x` and `c` are updated each round → **loop-carried dependencies**, limiting ILP and scheduling.
- Ends with `return n - x;` (a trick that reduces readability and may hinder optimization).
#### Branch-and-shift (16/8/4/2/1) version
```c
static inline unsigned clz32(uint32_t x) {
if (x == 0) return 32;
unsigned n = 0;
if ((x & 0xFFFF0000u) == 0) { n += 16; x <<= 16; }
if ((x & 0xFF000000u) == 0) { n += 8; x <<= 8; }
if ((x & 0xF0000000u) == 0) { n += 4; x <<= 4; }
if ((x & 0xC0000000u) == 0) { n += 2; x <<= 2; }
if ((x & 0x80000000u) == 0) { n += 1; }
return n;
}
```
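
To check that the two variants agree, the sketch below places them side by side. The looped version is reconstructed from the description in the analysis above (`y = x >> c`, `c >>= 1`, `return n - x`), so its exact form is an assumption; the names `clz_loop`, `clz_branch`, and `clz_self_check` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Looped variant, reconstructed from the analysis (an assumption) */
unsigned clz_loop(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {        /* data-dependent branch */
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);        /* backward branch */
    return (unsigned) (n - (int) x); /* x is 0 or 1 at this point */
}

/* Branch-and-shift variant, same logic as clz32 above */
unsigned clz_branch(uint32_t x)
{
    if (x == 0) return 32;
    unsigned n = 0;
    if ((x & 0xFFFF0000u) == 0) { n += 16; x <<= 16; }
    if ((x & 0xFF000000u) == 0) { n += 8;  x <<= 8; }
    if ((x & 0xF0000000u) == 0) { n += 4;  x <<= 4; }
    if ((x & 0xC0000000u) == 0) { n += 2;  x <<= 2; }
    if ((x & 0x80000000u) == 0) { n += 1; }
    return n;
}

void clz_self_check(void)
{
    assert(clz_loop(0) == 32 && clz_branch(0) == 32);
    for (unsigned k = 0; k < 32; k++) {
        uint32_t x = 1u << k; /* highest set bit at position k */
        assert(clz_loop(x) == 31 - k);
        assert(clz_branch(x) == 31 - k);
        assert(clz_loop(x | (x >> 1)) == clz_branch(x | (x >> 1)));
    }
}
```

The check only covers zero and inputs whose highest set bit is at each position, which is enough to exercise every branch of both variants.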
#### Execution info for RISC-V code after improvement

The cycle count improved by about 12%.
---
## Quiz 1 Prob C : BF16 ↔ FP32 Conversion
---
### Note
- **Subnormal (Denormalized) Values**
- Definition:
- the exponent field = 0 (all exponent bits are 0)
- the mantissa (fraction) ≠ 0.
- Purpose:
- Subnormal numbers allow for a **gradual underflow** between zero and the smallest normalized value.
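
The definition above translates directly into a bit test. The helper below is a minimal sketch; `bf16_bits_is_subnormal` is a hypothetical name, and it works on the raw 16-bit pattern rather than on the `bf16_t` struct.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>

/* Subnormal: exponent field (bits 14..7) all zero, mantissa (bits 6..0) non-zero */
bool bf16_bits_is_subnormal(uint16_t bits)
{
    return ((bits & 0x7F80u) == 0) && ((bits & 0x007Fu) != 0);
}

void bf16_subnormal_self_check(void)
{
    assert(bf16_bits_is_subnormal(0x0001));  /* smallest positive subnormal */
    assert(bf16_bits_is_subnormal(0x007F));  /* largest positive subnormal */
    assert(bf16_bits_is_subnormal(0x8001));  /* the sign bit does not matter */
    assert(!bf16_bits_is_subnormal(0x0000)); /* zero: mantissa is 0 */
    assert(!bf16_bits_is_subnormal(0x0080)); /* smallest normal: exponent = 1 */
    assert(!bf16_bits_is_subnormal(0x7FC0)); /* NaN: exponent all ones */
}
```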
### BF16 Conversion and Classification Utilities
:::spoiler C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
typedef struct {
uint16_t bits;
} bf16_t;
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) //exp = all 1, mantissa != 0
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})
static inline bool bf16_isnan(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
(a.bits & BF16_MANT_MASK); //exp == all 1, mantissa != 0
}
static inline bool bf16_isinf(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
!(a.bits & BF16_MANT_MASK); // exp == all 1, mantissa == 0
}
static inline bool bf16_iszero(bf16_t a)
{
return !(a.bits & 0x7FFF); // exp == 0, mantissa == 0
}
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float)); // copy the bit pattern of val into f32bits
if (((f32bits >> 23) & 0xFF) == 0xFF) // exponent all ones: Inf or NaN
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; // truncate without rounding
f32bits += ((f32bits >> 16) & 1) + 0x7FFF; // round to nearest, ties to even
return (bf16_t) {.bits = f32bits >> 16}; // after rounding, return the upper 16 bits
}
static inline float bf16_to_f32(bf16_t val)
{
uint32_t f32bits = ((uint32_t) val.bits) << 16; // bf16 value left shift 16 bits
float result;
memcpy(&result, &f32bits, sizeof(float)); // copy the bit pattern of f32bits into the float result
return result;
}
```
:::
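
The rounding line in `f32_to_bf16` implements round-to-nearest, ties-to-even. The sketch below applies the same step to raw FP32 bit patterns so that each rounding case can be checked by hand; `f32bits_to_bf16bits` is a hypothetical name introduced for this illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Same rounding step as f32_to_bf16, applied to raw FP32 bits.
 * Inf/NaN inputs (exponent all ones) are truncated, as in the original. */
uint16_t f32bits_to_bf16bits(uint32_t f32bits)
{
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (uint16_t) (f32bits >> 16);
    f32bits += ((f32bits >> 16) & 1u) + 0x7FFFu;
    return (uint16_t) (f32bits >> 16);
}

void bf16_rounding_self_check(void)
{
    assert(f32bits_to_bf16bits(0x3F807FFFu) == 0x3F80); /* below half: round down */
    assert(f32bits_to_bf16bits(0x3F808000u) == 0x3F80); /* tie, LSB 0: stay even */
    assert(f32bits_to_bf16bits(0x3F818000u) == 0x3F82); /* tie, LSB 1: round up to even */
    assert(f32bits_to_bf16bits(0x3F808001u) == 0x3F81); /* above half: round up */
    assert(f32bits_to_bf16bits(0x7F800000u) == 0x7F80); /* +Inf passes through */
}
```

Adding `0x7FFF` plus the current LSB carries into bit 16 exactly when the low half is above `0x8000`, or equal to `0x8000` with an odd upper half.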
#### Add and Sub
##### Note
- **Rounding Rule: Ties to Even**
| Low 16 bits | LSB of the upper 16 bits | Round up? |
| ---------- | -------- | -------------|
| `< 0x8000` | any | ❌ |
| `= 0x8000` | 0 | ❌ |
| `= 0x8000` | 1 | ✅ |
| `> 0x8000` | any | ✅ |
An exact tie (`0x8000`) rounds to the neighbor whose last bit is 0 (the even one), which avoids accumulating bias over many operations.
- **Bitwise Copy for Floating-Point Reinterpretation**
- Reinterpreting floating-point bits through a pointer cast (for example, `*(float *) &f32bits`) violates the strict aliasing rule and is undefined behavior in C.
- To reinterpret the bits safely, use memcpy:
`memcpy(&result, &f32bits, sizeof(float));`
This copies the exact bit pattern from f32bits into result without altering the binary layout.
- **Normalizing the Mantissa**
- Once the number is confirmed to be within the representable range, we must restore the implicit leading `1` of the mantissa for normalized values:
`mant_a |= 0x80;`
This makes the mantissa represent $1.mantissa$ instead of $0.mantissa$,
which is required for correct reconstruction:
$(-1)^{sign} \ × \ (1.mantissa) \ × \ 2^{(exp-bias)}$
Without this step, a normalized value would be computed as $0.mantissa$, as if it were subnormal.
- **BF16 Addition: Exponent Alignment**
| Condition | Meaning | Action | Effect |
| --------------- | ------- | -------------------- | ------------ |
| `exp_diff > 0` | a has the larger exponent | `mant_b >>= exp_diff` | Align exp_b to exp_a |
| `exp_diff < 0` | b has the larger exponent | `mant_a >>= -exp_diff` | Align exp_a to exp_b |
| `exp_diff > 8` | b is too small | `return a` | b can't influence the result |
| `exp_diff < -8` | a is too small | `return b` | a can't influence the result |
| `exp_diff == 0` | already aligned | no shift needed | add the mantissas directly |
After alignment, both mantissas share the same exponent base, allowing direct addition or subtraction.
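
The alignment step can be traced with a deliberately reduced adder. The sketch below handles only positive, normalized inputs; specials, zeros, subnormals, and mixed signs are left out on purpose, and `bf16_add_pos_normal` is a hypothetical name rather than the function used in this assignment.

```c
#include <assert.h>
#include <stdint.h>

/* Simplified sketch: adds two positive, normalized BF16 bit patterns. */
uint16_t bf16_add_pos_normal(uint16_t a, uint16_t b)
{
    int exp_a = (a >> 7) & 0xFF, exp_b = (b >> 7) & 0xFF;
    uint32_t mant_a = (a & 0x7F) | 0x80; /* restore the hidden 1 */
    uint32_t mant_b = (b & 0x7F) | 0x80;
    int diff = exp_a - exp_b;
    int exp = exp_a;

    if (diff > 8)  return a;             /* b is too small to matter */
    if (diff < -8) return b;             /* a is too small to matter */
    if (diff > 0) {
        mant_b >>= diff;                 /* align b to a */
    } else if (diff < 0) {
        mant_a >>= -diff;                /* align a to b */
        exp = exp_b;
    }

    uint32_t mant = mant_a + mant_b;
    if (mant & 0x100) {                  /* carry out of bit 7 */
        mant >>= 1;
        exp++;
    }
    if (exp >= 0xFF)
        return 0x7F80;                   /* overflow to +Inf */
    return (uint16_t) ((exp << 7) | (mant & 0x7F));
}

void bf16_align_self_check(void)
{
    /* 1.5 + 0.5: mant_b 0x80 >> 1 = 0x40, 0xC0 + 0x40 = 0x100, renormalize -> 2.0 */
    assert(bf16_add_pos_normal(0x3FC0, 0x3F00) == 0x4000);
    /* 2.0 + 1.0 = 3.0 */
    assert(bf16_add_pos_normal(0x4000, 0x3F80) == 0x4040);
    /* 1.0 + 2^-9: exponent gap of 9 exceeds 8, so the result is 1.0 */
    assert(bf16_add_pos_normal(0x3F80, 0x3B00) == 0x3F80);
}
```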
#### C Code
:::spoiler C code
```c
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1; // get a, b sign bit
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF); // get a, b exp bits
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F; // get a, b mantissa bits
if (exp_a == 0xFF) { // a is Inf/NaN
// a is NaN
if (mant_a)
return a;
// a is Inf
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
/* mant_b != 0: b is NaN, so a + b is NaN
* sign_a == sign_b: both are Inf with the same sign, so a + b is that Inf
* sign_a != sign_b: +Inf + -Inf = NaN */
return a;
}
/* Summary of specials handled above/below:
Inf + NaN = NaN
NaN + NaN = NaN
+Inf + +Inf = +Inf
-Inf + -Inf = -Inf
+Inf + -Inf = NaN
*/
if (exp_b == 0xFF)
return b; // b is NaN or Inf, so a + b = b
if (!exp_a && !mant_a)
return b; // a = 0, so a + b = b
if (!exp_b && !mant_b)
return a; // b = 0, so a + b = a
if (exp_a)
mant_a |= 0x80; // exp_a != 0: restore the implicit leading 1 of the mantissa
if (exp_b)
mant_b |= 0x80; // exp_b != 0: restore the implicit leading 1 of the mantissa
int16_t exp_diff = exp_a - exp_b; // difference between exp_a and exp_b
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
if (exp_diff > 0) { // define exponent bits and align mantissa bits
result_exp = exp_a; // a is bigger so use exp_a
if (exp_diff > 8)
return a;
mant_b >>= exp_diff; // align mant_b to mant_a
} else if (exp_diff < 0) {
result_exp = exp_b; // b is bigger so use exp_b
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff; // align mant_a to mant_b
} else {
result_exp = exp_a; // exp_a == exp_b so use exp_a
}
if (sign_a == sign_b) {
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b; // a and b have same sign so we can add mantissa
if (result_mant & 0x100) { // if addition of mantissa bigger than 0xFF
result_mant >>= 1; // result_mant right shift 1 bit
if (++result_exp >= 0xFF) // if result_exp overflow
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // return ±Inf depends on the sign
}
} else { // sign_a != sign_b
if (mant_a >= mant_b) { // mant_a bigger than mant_b
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else { // mant_a < mant_b
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant) // exact cancel → +0 (per BF16_ZERO macro)
return BF16_ZERO(); // sign_a != sign_b && mant_a == mant_b
// Normalize left until hidden 1 (bit7) is restored; decrement exponent
while (!(result_mant & 0x80)) {
result_mant <<= 1; // mantissa left shift times 2
if (--result_exp <= 0) // exp - 1
return BF16_ZERO(); // if adjust to zero return BF16_ZERO()
}
}
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
}; // Pack result
}
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
b.bits ^= BF16_SIGN_MASK; // flip sign_b
return bf16_add(a, b);
}
```
:::
#### RISC-V
:::spoiler Assembly code
```
bf16_add:
srli t0, a0, 15
andi t0, t0, 1 # t0 = sign_a (0/1)
srli t1, a1, 15 # t1 = sign_b (0/1)
andi t1, t1, 1
srli t2, a0, 7
andi t2, t2, 0xFF # t2 = exp_a (8-bit)
srli t3, a1, 7
andi t3, t3, 0xFF # t3 = exp_b (8-bit)
andi t4, a0, 0x7F # t4 = mant_a (7-bit)
andi t5, a1, 0x7F # t5 = mant_b (7-bit)
li t6, 0xFF
bne t2, t6, check_exp_b_FF
beqz t4, a_is_Inf
mv a0, a0
ret # a is NaN
a_is_Inf:
bne t3, t6, ret_a_inf # a is Inf, b is finite
bnez t5, ret_b # b is NaN: return b
bne t0, t1, nan # +Inf + -Inf = NaN
ret_b:
mv a0, a1
ret # return b
ret_a_inf:
mv a0, a0
ret # a is Inf and b is finite: return a
check_exp_b_FF: # a is not NaN or Inf
bne t3, t6, check_a_zero # b is not NaN or Inf either
mv a0, a1 # b is NaN return NaN, b is Inf return Inf
ret
check_a_zero:
bnez t2, check_b_zero
bnez t4, check_b_zero
mv a0, a1 # a is zero
ret # return b
check_b_zero:
bnez t3, add_implict1
bnez t5, add_implict1 # a and b are not NaN Inf zero
mv a0, a0 # b is zero
ret # return a
add_implict1:
beqz t2, a_no_implict1
ori t4, t4, 0x80
a_no_implict1:
beqz t3, b_no_implict1
ori t5, t5, 0x80
b_no_implict1:
sub t6, t2, t3 # t6 <- exp_diff
bgez t6, exp_diff_ge0
mv s0, t3 # s0 = result_exp <- exp_b
li s1, -8
blt t6, s1, ret_b # a is too small
neg s1, t6 # s1 = -exp_diff
srl t4, t4, s1
j exp_aligned
exp_diff_ge0:
beqz t6, exp_equal
mv s0, t2 # s0 = result_exp <- exp_a
li s1, 8
bgt t6, s1, ret_a # b is too small
srl t5, t5, t6
j exp_aligned
exp_equal:
mv s0, t2 # s0 = result_exp <- exp_a
exp_aligned:
bne t0, t1, diff_sign
mv s2, t0 # s2 = result_sign <- sign_a
add s3, t4, t5 # s3 = result_mant <- mant_a + mant_b
andi s5, s3, 0x100 # did the mantissa overflow into bit 8?
beqz s5, pack
srli s3, s3, 1
addi s0, s0, 1
li s4, 0xFF
blt s0, s4, pack # check is exp overflow
li a0, 0x7F80
slli s2, s2, 15
or a0, a0, s2
ret # exp overflow return +-Inf
diff_sign:
blt t4, t5, b_bigger
mv s2, t0 # s2 = result_sign <- sign_a
sub s3, t4, t5 # s3 = result_mant <- mant_a - mant_b
j after_sub
b_bigger:
mv s2, t1 # s2 = result_sign <- sign_b
sub s3, t5, t4 # s3 = result_mant <- mant_b - mant_a
after_sub:
beqz s3, ret_zero
norm_loop:
andi s5, s3, 0x80
bnez s5, pack # while (!(result_mant & 0x80))
slli s3, s3, 1
addi s0, s0, -1
blez s0, ret_zero
j norm_loop
pack: # s0 : result_exp s2 : result_sign s3 : result_mant
andi s3, s3, 0x7F
andi s0, s0, 0xFF
slli s0, s0, 7
slli s2, s2, 15
or a0, s0, s3
or a0, a0, s2
ret
ret_a:
mv a0, a0
ret
nan:
li a0, 0x7FC0 # NaN
ret
ret_zero:
mv a0, x0 # zero
ret
bf16_sub:
li t0, 0x8000
xor a1, a1, t0 # reverse b_sign
j bf16_add
```
:::
#### Execution info for RISC-V code

#### Test Case
| Case | Operation | Input A (Hex) | Input A (value) | Input B (Hex) | Input B (value) | expect (Hex) | expect (value) | test type |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **TC1** | add | `0x0000` | $+0$ | `0x3FC0` | $+1.5$ | `0x3FC0` | $+1.5$ | normal |
| **TC2** | add | `0x7F80` | $+\infty$ | `0xFF80` | $-\infty$ | `NaN` | $NaN$ | edge |
| **TC3** | add | `0x3FC0` | $+1.5$ | `0x3F00` | $+0.5$ | `0x4000` | $+2.0$ | normal |
| **TC4** | sub | `0x0000` | $+0$ | `0x3FC0` | $+1.5$ | `0xBFC0` | $-1.5$ | normal |
| **TC5** | sub | `0x7F80` | $+\infty$ | `0xFF80` | $-\infty$ | `0x7F80` | $+\infty$ | edge |
| **TC6** | sub | `0x3FC0` | $+1.5$ | `0x3F00` | $+0.5$ | `0x3F80` | $+1.0$ | normal |
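
The hex patterns in the table can be cross-checked against their decimal values by widening each BF16 pattern to FP32. This is a small sketch that assumes `float` is an IEEE 754 binary32; `bf16_bits_to_float` is a hypothetical helper equivalent to `bf16_to_f32` but taking raw bits.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Widen a BF16 bit pattern to a float by placing it in the upper 16 bits */
float bf16_bits_to_float(uint16_t bits)
{
    uint32_t f32bits = (uint32_t) bits << 16;
    float out;
    memcpy(&out, &f32bits, sizeof out);
    return out;
}

void bf16_table_self_check(void)
{
    assert(bf16_bits_to_float(0x3FC0) == 1.5f);
    assert(bf16_bits_to_float(0x3F00) == 0.5f);
    assert(bf16_bits_to_float(0x4000) == 2.0f);  /* TC3: 1.5 + 0.5 */
    assert(bf16_bits_to_float(0xBFC0) == -1.5f); /* TC4: 0 - 1.5 */
    assert(bf16_bits_to_float(0x3F80) == 1.0f);  /* TC6: 1.5 - 0.5 */
}
```

All values in the table are exactly representable, so comparing with `==` is safe here.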
### Multiply
#### Note
- Special case: `a = Inf,b = NaN`
- updated C handles this explicitly
```c
// a = Inf or NaN
if (exp_a == 0xFF) {
if (mant_a) // a = NaN
return a; // a * b = NaN
/*extra condition*/
if (exp_b == 0xFF && mant_b)
return BF16_NAN(); // Inf * NaN = NaN
if (!exp_b && !mant_b)
return BF16_NAN(); // a * b = NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
// Without the extra check, a = Inf, b = NaN would reach the return above
// and yield Inf, although Inf * NaN = NaN
}
```
- **Exponent normalization when an operand’s exponent is zero**
- Goal: Convert subnormal to normalized before multiply (if possible).
- Example :
- Input:
`exp_a = 0`, `mant_a = 0x10 (0001 0000)`
`exp_b = 130`, `mant_b = 0x40 (0100 0000)`
- Process:
- For `a` (subnormal): shift left until hidden-1 restored
→ shift 3 times: `mant_a = 0x80`, `exp_adjust = -3`, `set exp_a = 1`
- For `b` (already normal): `exp_b = 130`, then restore hidden-1:
`mant_b |= 0x80` → `mant_b = 0xC0`
- Meaning of `mant_a |= 0x80;`
BF16 mantissa stores bits[6:0]. For normalized values we must insert the hidden 1 at bit7, so arithmetic uses 1.mantissa (not 0.mantissa) in:
$(-1)^{sign} \ × \ (1.mantissa) \ × \ 2^{(exp-bias)}$
- Post-multiply normalization when `result_mant >= 2`
- Example 1 – no adjust
`mant_a = 1.1000000b (1.5)`
`mant_b = 1.0010000b (1.125)`
→ `mant_a * mant_b = 1.6875 = 1.1011b` (in `[1,2)`) → no shift.
- Example 2 – needs adjust
`mant_a = 1.1111111b (1.992)`
`mant_b = 1.1111111b (1.992)`
→ `mant_a * mant_b ≈ 3.969 ≈ 11.111b` → `result_mant >= 2`
→ normalize:
`result_mant = result_mant >> 1` → `result_exp ++`
This keeps the mantissa in the normalized range `[1, 2)` with the hidden-1 at bit7.
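
The two normalization steps in the notes above can be reproduced on bare significands. The sketch below is an illustration under stated assumptions: `normalize_subnormal` and `mul_significands` are hypothetical helpers, and unlike `bf16_mul` they return the 8-bit significand with the hidden 1 still in bit 7 instead of masking it with `0x7F`.

```c
#include <assert.h>
#include <stdint.h>

/* Shift a non-zero subnormal mantissa left until the hidden 1 reaches bit 7.
 * Returns the (negative) exponent adjustment. */
int normalize_subnormal(uint16_t *mant)
{
    int exp_adjust = 0;
    while (!(*mant & 0x80)) {
        *mant <<= 1;
        exp_adjust--;
    }
    return exp_adjust;
}

/* Multiply two 8-bit significands (hidden 1 at bit 7) and renormalize.
 * *exp_inc is set to 1 when the product was >= 2.0, otherwise 0. */
uint32_t mul_significands(uint32_t mant_a, uint32_t mant_b, int *exp_inc)
{
    uint32_t prod = mant_a * mant_b; /* 16-bit product, value in [1, 4) */
    if (prod & 0x8000) {             /* bit 15 set: product >= 2.0 */
        *exp_inc = 1;
        return prod >> 8;            /* align bit 15 to bit 7 */
    }
    *exp_inc = 0;
    return prod >> 7;                /* align bit 14 to bit 7 */
}

void bf16_mul_note_self_check(void)
{
    uint16_t mant = 0x10;
    int inc;
    assert(normalize_subnormal(&mant) == -3 && mant == 0x80);
    /* Example 1: 1.5 * 1.125 = 1.6875 -> 0xD8 / 128, no exponent change */
    assert(mul_significands(0xC0, 0x90, &inc) == 0xD8 && inc == 0);
    /* Example 2: product >= 2, shifted once more and exponent incremented */
    assert(mul_significands(0xFF, 0xFF, &inc) == 0xFE && inc == 1);
}
```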
#### C Code
:::spoiler C Code
```c
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1; // get a, b sign bit
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF); // get a, b exp bits
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F; // get a, b mant bits
uint16_t result_sign = sign_a ^ sign_b; // same sign : 0, different sign : 1
if (exp_a == 0xFF) { // a = Inf or NaN
if (mant_a) // a = NaN
return a; // a * b = NaN
// a = Inf
if (exp_b == 0xFF && mant_b) // b = NaN
return BF16_NAN(); // Inf * NaN = NaN
if (!exp_b && !mant_b) // a = Inf, b = 0
return BF16_NAN(); // a * b = NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // return ±Inf depends on the sign
}
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (exp_a == 0xFF && mant_a) // a = NaN
return BF16_NAN(); // Inf * NaN = NaN
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
} // same as above
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15}; // except for above condition, when a or b is zero return ±0 depends on the sign
int16_t exp_adjust = 0;
if (!exp_a) { // exp_a = 0
while (!(mant_a & 0x80)) { // mant_a's MSB is not 1
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80; // restore the implicit bit 7: the stored mantissa is bits 6..0, and a normalized value is computed as 1.xxxx
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--; // if both a and b are subnormal, the adjustments accumulate
}
exp_b = 1;
} else
mant_b |= 0x80;
uint32_t result_mant = (uint32_t) mant_a * mant_b; // both mantissas are 1.xxxx; the exponent is handled separately via exp_adjust
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; // compute exp part
if (result_mant & 0x8000) { // bit 15 set: the product is >= 2.0
result_mant = (result_mant >> 8) & 0x7F; // shift 7 + 1 bits: align bit 15 to bit 7 (halving the product back into [1, 2)) and keep bits 6..0
result_exp++; // compensate for the halving by incrementing the exponent
} else
result_mant = (result_mant >> 7) & 0x7F; // bit 15 clear means bit 14 is set: the product is < 2.0, so shift 7 bits to align bit 14 to bit 7
if (result_exp >= 0xFF) //if result_exp overflow
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // return ±Inf depends on the sign
if (result_exp <= 0) {
if (result_exp < -6) // below -6 every mantissa bit would be shifted out
return (bf16_t) {.bits = result_sign << 15}; // too small to represent: return ±0 depending on the sign
result_mant >>= (1 - result_exp); // result_exp in [-6, 0]: shift the mantissa right to form a subnormal
result_exp = 0; // subnormals use exponent field 0
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)}; // combine all component and return
}
```
:::
#### RISC-V
:::spoiler Assembly code
```
.data
msg_pass:
.asciz "tests pass\n"
msg_fail:
.asciz "some tests fail\n"
.text
.globl main
main:
li s10, 0 # fail_count = 0
# ------------- case1 ---------------
# TC1: 1.5 * 0.5 = 0.75
# a = 0x3FC0 (1.5), b = 0x3F00 (0.5), exp = 0x3F40 (0.75)
li a0, 0x3FC0 # a = 1.5 (bf16)
li a1, 0x3F00 # b = 0.5 (bf16)
jal ra, bf16_mul
li t0, 0x3F40 # expected = 0.75 (bf16)
bne a0, t0, 1f # if result != expected -> fail++
j 2f
1:
addi s10, s10, 1
2:
# ------------- case2 (edge) --------
# TC2 (edge): +Inf * +0 = NaN
# a = 0x7F80 (+Inf), b = 0x0000 (+0)
li a0, 0x7F80 # +Inf
li a1, 0x0000 # +0
jal ra, bf16_mul
li t1, 0x7F80
and t2, a0, t1 # extract exponent bits
bne t2, t1, 3f # exp != 0xFF -> fail
andi t3, a0, 0x7F # extract mantissa
beqz t3, 3f # mant == 0 -> not NaN -> fail
j 4f
3:
addi s10, s10, 1
4:
# ------------- case3 ---------------
# TC3: (-2.0) * (-0.5) = +1.0
# a = 0xC000 (-2.0), b = 0xBF00 (-0.5), exp = 0x3F80 (+1.0)
li a0, 0xC000 # -2.0 (bf16)
li a1, 0xBF00 # -0.5 (bf16)
jal ra, bf16_mul
li t0, 0x3F80 # +1.0 (bf16)
bne a0, t0, 5f
j 6f
5:
addi s10, s10, 1
6:
# ----------------------------
# print result
# ----------------------------
beqz s10, print_pass # fail_count == 0 -> pass
la a0, msg_fail
li a7, 4 # print_string
ecall
j exit
print_pass:
la a0, msg_pass
li a7, 4 # print_string
ecall
exit:
li a7, 10 # exit
ecall
bf16_mul:
srli t0, a0, 15
andi t0, t0, 1 # t0 = sign_a (0/1)
srli t1, a1, 15 # t1 = sign_b (0/1)
andi t1, t1, 1
srli t2, a0, 7
andi t2, t2, 0xFF # t2 = exp_a (8-bit)
srli t3, a1, 7
andi t3, t3, 0xFF # t3 = exp_b (8-bit)
andi t4, a0, 0x7F # t4 = mant_a (7-bit)
andi t5, a1, 0x7F # t5 = mant_b (7-bit)
xor s0, t0, t1 # s0 = result_sign
li t6, 0xFF
bne t2, t6, check_b_FF # exp_a != 0xFF ?
bnez t4, nan # exp_a == FF mant_a != 0 => NaN
check_b_FF:
bne t3, t6, both_not_nan # exp_b != 0xFF ?
bnez t5, nan # exp_b == FF mant_b != 0 => NaN
both_not_nan:
beq t2, t6, a_is_inf
j check_b_inf
a_is_inf:
beqz t3, b_is_zero
j check_b_inf
b_is_zero:
beqz t5, nan # a = Inf b = 0 return nan
check_b_inf:
bne t3, t6, check_inf
beqz t2, b_is_inf
j check_inf
b_is_inf:
beqz t4, nan # b = Inf a = 0 return nan
check_inf:
beq t2, t6, inf # b != Inf a == Inf ret Inf
beq t3, t6, inf # a != Inf b == Inf ret Inf
beqz t2, check_mant_a_zero
j check_exp_b_zero
check_mant_a_zero:
beqz t4, zero
check_exp_b_zero:
beqz t3, check_mant_b_zero
j exp_adjust
check_mant_b_zero:
beqz t5, zero
exp_adjust:
li a2, 0
beqz t2, while_a
li s7, 0x80
or t4, t4, s7
j adjust_b
while_a: # while (!(mant_a & 0x80)) { mant_a <<=1; exp_adjust--; }
andi a4, t4, 0x80 # (mant_a & 0x80)
bnez a4, done_adjust_a
slli t4, t4, 1
addi a2, a2, -1
j while_a
done_adjust_a:
li t2, 1
adjust_b:
beqz t3, while_b
ori t5, t5, 0x80
j adjust_done
while_b:
andi a4, t5, 0x80
bnez a4, done_adjust_b
slli t5, t5, 1
addi a2, a2, -1
j while_b
done_adjust_b:
li t3, 1
adjust_done:
li a3, 0
li a4, 8 # counter
mul_loop: # mant_a * mant_b
andi a5, t5, 1
beqz a5, skip_add
add a3, a3, t4
skip_add:
slli t4, t4, 1
srli t5, t5, 1
addi a4, a4, -1
bnez a4, mul_loop
# result_exp = exp_a + exp_b - 127 + exp_adjust
add a4, t2, t3 # exp_a + exp_b
add a4, a4, a2 # + exp_adjust
addi a4, a4, -127 # BF16_EXP_BIAS
li a5, 0x8000
and a5, a3, a5
beqz a5, no_big
srli a3, a3, 8
andi a3, a3, 0x7F
addi a4, a4, 1
j scaled
no_big:
srli a3, a3, 7
andi a3, a3, 0x7F
scaled:
li a5, 0xFF
ble a5, a4, inf
blez a4, underflow
slli t0, s0, 15 # sign
andi t1, a4, 0xFF
slli t1, t1, 7 # exp<<7
or t0, t0, t1
or t0, t0, a3 # + mant
mv a0, t0
ret
inf:
slli a0, s0, 15
li s8, 0x7F80
or a0, a0, s8
ret
zero:
slli a0, s0, 15 # 0
ret
nan:
li a0, 0x7FC0 # NaN
ret
underflow:
li a5, -6
blt a4, a5, zero
li a5, 1
sub a5, a5, a4
shift_dn:
beqz a5, pack_den
srli a3, a3, 1
addi a5, a5, -1
j shift_dn
pack_den:
andi a3, a3, 0x7F
slli t0, s0, 15 # sign (result_sign is in s0; t6 holds 0xFF)
or t0, t0, a3 # exp = 0
mv a0, t0 # keep the sign bit in the result
ret
```
:::
#### Test Case
| Case | Operation | Input A (Hex) | Input A (value) | Input B (Hex) | Input B (value) | expect (Hex) | expect (value) | test type |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **TC1** | mul | `0x3FC0` | $+1.5$ | `0x3F00` | $+0.5$ | `0x3F40` | $+0.75$ | normal |
| **TC2** | mul | `0x7F80` | $+\infty$ | `0x0000` | $0$ | `NaN` | $NaN$ | edge |
| **TC3** | mul | `0xC000` | $-2.0$ | `0xBF00` | $-0.5$ | `0x3F80` | $+1.0$ | normal |
#### Execution info for RISC-V code

### Divide
#### Note
- **Meaning of `if (!exp_a) result_exp--;`**
When performing floating-point multiply/divide operations in hardware, the mantissa is assumed to be in normalized form $1.xxx$.
To simplify the arithmetic, subnormal numbers (where $exp = 0$) are temporarily treated as if normalized by setting $exp = 1$.
After the computation, the exponent must be adjusted back to correct for this artificial bias:
If the dividend `a` was subnormal: `result_exp--` (divide by 2)
If the divisor `b` was subnormal: `result_exp++` (multiply by 2)
#### C code
:::spoiler C code
```c
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1; // get a, b sign bit
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF); // get a, b exp bits
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F; // get a, b mant bits
uint16_t result_sign = sign_a ^ sign_b; // same sign : 0, different sign : 1
if (exp_b == 0xFF) {
if (mant_b)
return b; // b is NaN, anything / NaN = NaN
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN(); // a and b are Inf, Inf / Inf = NaN
return (bf16_t) {.bits = result_sign << 15}; // b = Inf, anything / ±Inf = ±0 return ±0 depends on the result_sign
}
if (!exp_b && !mant_b) { // if b = ±0
if (!exp_a && !mant_a) // if a = ±0
return BF16_NAN(); // 0 / 0 = NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // b = ±0 a != 0 a / ±0 = ±Inf
}
if (exp_a == 0xFF) {
if (mant_a)
return a; // a is NaN, NaN / anything = NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // a is ±Inf
}
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15}; // a is 0 ,except above condition a / b = 0
if (exp_a)
mant_a |= 0x80; // let mantissa's bit 7 be 1 to ensure the computing
if (exp_b)
mant_b |= 0x80; // same as above
uint32_t dividend = (uint32_t) mant_a << 15; // dividend: widened to 32 bits, mant_a[7..0] moves to bits [22..15]
uint32_t divisor = mant_b; // divisor
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1; // left shift to get a empty LSB
if (dividend >= (divisor << (15 - i))) { //divisor align to dividend and check is dividend bigger than divisor, if not go to next iteration
dividend -= (divisor << (15 - i)); // if yes "dividend - divisor"
quotient |= 1; // and let the quotient set to 1
}
}
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++; // explain at Note
if (quotient & 0x8000) // if 商's MSB is at bit15 -> >=2
quotient >>= 8; // right shift 7+1bits to align bit15 to bit7
else { //商's MSB is at bit14 -> < 2
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
} // first align to bit15 to get the exact result_exp
quotient >>= 8; // right shift 7+1bits to align bit15 to bit7
}
quotient &= 0x7F; //extract last 7 bits [bit6..bit0]
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // ±Inf
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15};// ±0
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F)};
} //combine all component
```
:::
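The 16-step loop in `bf16_div` is a shift-and-subtract (restoring) division. As a minimal sketch, the loop can be isolated into a helper; `mant_divide` is a hypothetical name that does not appear in the original code, and it assumes both mantissas already carry the implicit leading 1 (values in `[0x80, 0xFF]`).

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper (not part of the original code): the 16-step
 * shift-and-subtract loop from bf16_div, applied to 8-bit mantissas
 * that already include the implicit leading 1. */
static uint32_t mant_divide(uint32_t mant_a, uint32_t mant_b)
{
    assert(mant_b != 0);               /* zero divisors are handled before the loop */
    uint32_t dividend = mant_a << 15;  /* mant_a moves to bits [22..15] */
    uint32_t quotient = 0;
    for (int i = 0; i < 16; i++) {
        quotient <<= 1;                          /* open an empty LSB */
        if (dividend >= (mant_b << (15 - i))) {  /* aligned divisor fits */
            dividend -= mant_b << (15 - i);
            quotient |= 1;
        }
    }
    return quotient;
}
```

For TC1 ($1.5 / 0.5$) the mantissas are `0xC0` and `0x80`, and `mant_divide(0xC0, 0x80)` gives `0xC000`. Bit 15 is set, so the quotient is shifted right by 8 to `0xC0`, the stored mantissa is `0x40`, and with `result_exp = 127 - 126 + 127 = 128` the packed result is `0x4040`. Under the stated assumption on the inputs, the helper returns the same value as `(mant_a << 15) / mant_b`.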
#### RISC-V
:::spoiler RISC-V
```
bf16_div:
srli t0, a0, 15
andi t0, t0, 1
srli t1, a1, 15
andi t1, t1, 1
srli t2, a0, 7
andi t2, t2, 0xFF
srli t3, a1, 7
andi t3, t3, 0xFF
andi t4, a0, 0x7F
andi t5, a1, 0x7F
xor a2, t0, t1
check_b_FF:
addi t6, x0, 255
bne t3, t6, check_b_zerocase
bne t5, x0, ret_b
bne t2, t6, zero
beq t4, x0, nan
jal x0, zero
check_b_zerocase:
bne t3, x0, check_a_FF
bne t5, x0, check_a_FF
bne t2, x0, inf
beq t4, x0, nan
jal x0, inf
check_a_FF:
bne t2, t6, check_a_zerocase
bne t4, x0, ret_a
jal x0, inf
check_a_zerocase:
bne t2, x0, implict_1_a
beq t4, x0, zero
implict_1_a:
beq t2, x0, implict_1_b
ori t4, t4, 0x80
implict_1_b:
beq t3, x0, start_divide
ori t5, t5, 0x80
start_divide:
slli a4, t4, 15
addi a5, x0, 0
slli t6, t5, 15
addi t0, x0, 16
div_loop:
slli a5, a5, 1
sltu t1, a4, t6
bne t1, x0, no_sub
sub a4, a4, t6
ori a5, a5, 1
no_sub:
srli t6, t6, 1
addi t0, t0, -1
bne t0, x0, div_loop
result_exp:
sub a3, t2, t3
addi a3, a3, 127
beq t2, x0, dec_exp_a
jal x0, chk_exp_b
dec_exp_a:
addi a3, a3, -1
chk_exp_b:
beq t3, x0, inc_exp_b
jal x0, norm_q
inc_exp_b:
addi a3, a3, 1
norm_q:
lui t0, 0x8
and t1, a5, t0
beq t1, x0, shift_left_phase
srli a5, a5, 8
jal x0, combine_all_component
shift_left_phase:
norm_loop:
and t1, a5, t0
bne t1, x0, after_left_norm
addi t1, x0, 1
slt t1, t1, a3
beq t1, x0, after_left_norm
slli a5, a5, 1
addi a3, a3, -1
jal x0, norm_loop
after_left_norm:
srli a5, a5, 8
combine_all_component:
andi a5, a5, 0x7F
addi t0, x0, 255
bge a3, t0, inf
beq a3, x0, zero
slt t1, a3, x0
bne t1, x0, zero
andi t0, a3, 255
slli t0, t0, 7
slli a0, a2, 15
or a0, a0, t0
or a0, a0, a5
jalr x0, ra, 0
ret_a:
addi a0, a0, 0
jalr x0, ra, 0
ret_b:
addi a0, a1, 0
jalr x0, ra, 0
inf:
slli a0, a2, 15
addi t0, x0, 255
slli t0, t0, 7
or a0, a0, t0
jalr x0, ra, 0
nan:
addi t0, x0, 255
slli t0, t0, 7
ori a0, t0, 0x40
jalr x0, ra, 0
zero:
slli a0, a2, 15
jalr x0, ra, 0
```
:::
#### Test Case
| Case | Operation | Input A (Hex) | Input A (value) | Input B (Hex) | Input B (value) | expect (Hex) | expect (value) | test type |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **TC1** | div | `0x3FC0` | $+1.5$ | `0x3F00` | $+0.5$ | `0x4040` | $+3.0$ | normal |
| **TC2** | div | `0x4000` | $+2.0$ | `0x0000` | $+0$ | `0x7F80` | $+Inf$ | edge |
| **TC3** | div | `0xC000` | $-2.0$ | `0xBF00` | $-0.5$ | `0x4080` | $+4.0$ | normal |
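The hex patterns in the table can be checked by widening each bf16 value to a `float`, since a bf16 bit pattern is the upper 16 bits of the corresponding float32. The sketch below uses a hypothetical helper `bf16_bits_to_float` (not part of the original code) and assumes `float` is IEEE 754 binary32.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: widen a bf16 bit pattern to float by placing it
 * in the upper 16 bits of a 32-bit word (assumes IEEE 754 binary32). */
static float bf16_bits_to_float(uint16_t bits)
{
    uint32_t wide = (uint32_t) bits << 16;
    float f;
    memcpy(&f, &wide, sizeof f);  /* reinterpret the bits without aliasing issues */
    return f;
}
```

With this helper, `0x3FC0` reads back as `1.5f`, `0x3F00` as `0.5f`, and `0x4040` as `3.0f`, which matches TC1.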
#### Execution info for RISC-V code

### square root
#### Note
- The binary search part is not fully clear to me yet; I wrote the RISC-V version first and will come back to review it afterwards.
#### C Code
:::spoiler C code
```c
static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;    // sign bit of a
    int16_t exp = ((a.bits >> 7) & 0xFF);  // exponent bits of a
    uint16_t mant = a.bits & 0x7F;         // mantissa bits of a
    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)      // a is NaN
            return a;  /* NaN propagation */
        if (sign)      // a is -Inf
            return BF16_NAN();  // sqrt(-Inf) = NaN
        return a;               // sqrt(+Inf) = +Inf
    }
    if (!exp && !mant)       // a is ±0
        return BF16_ZERO();  // sqrt(0) = 0
    if (sign)                // a is negative
        return BF16_NAN();   // sqrt(negative value) = NaN
    if (!exp)                // a is subnormal (too small)
        return BF16_ZERO();  // flush to zero
    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;  // unbiased exponent e
    int32_t new_exp;
    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant;  // restore the implicit leading 1 (subnormals were already flushed to zero above)
    /* Range [128, 256) representing [1.0, 2.0), using integers to represent the fractional value */
    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1;  // double the mantissa for an odd exponent
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;  // e -> (e-1) / 2
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;  // e -> e / 2
    }
    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */
    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;      /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;    /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128;  /* Default */
    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128;  /* Square and scale */
        if (sq <= m) {
            result = mid;  /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }
    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */
    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;  // result is at least 256: halve it and increment the exponent
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;  // result is below 128: double it and decrement the exponent
        }
    }
    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;
    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80};  // +Inf
    if (new_exp <= 0)
        return BF16_ZERO();  // 0
    /* combine exponent and mantissa (sign is always 0 here) */
    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
:::
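To make the binary search step easier to follow, here is a minimal sketch that isolates it. `sqrt_scaled` is a hypothetical helper name, not part of the original code. It assumes `m` is the adjusted mantissa in `[128, 512)` on the scale where 128 means 1.0. The search returns the largest `mid` with `mid * mid / 128 <= m`, so `result / 128` approximates $\sqrt{m / 128}$.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper isolating the search in bf16_sqrt.
 * m is a mantissa scaled so that 128 means 1.0, expected in [128, 512). */
static uint32_t sqrt_scaled(uint32_t m)
{
    assert(m >= 128 && m < 512);
    uint32_t low = 90, high = 256, result = 128;
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128;  /* square, then scale back down */
        if (sq <= m) {
            result = mid;                 /* mid is still a valid candidate */
            low = mid + 1;                /* try a larger value */
        } else {
            high = mid - 1;               /* mid is too large */
        }
    }
    return result;
}
```

For the input $4.0$ (`0x4080`) the exponent is even and `m = 128`, so `sqrt_scaled(128)` returns `128`, which is a mantissa of exactly 1.0 and gives `0x4000`. For $2.0$ the exponent is odd and `m = 256`, so `sqrt_scaled(256)` returns `181`, and $181 / 128 \approx 1.414$.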
#### RISC-V
:::spoiler RISC-V
```
bf16_sqrt:
srli t0, a0, 15
andi t0, t0, 1 # sign
srli t1, a0, 7
andi t1, t1, 0xFF # exp
andi t2, a0, 0x7F # mant
check_FF:
addi t3, x0, 0xFF
bne t1, t3, check_zerocase
bne t2, x0, nan
bne t0, x0, nan
jal x0, inf
check_zerocase:
bne t1, x0, check_negative
bne t2, x0, check_negative
jal x0, zero
check_negative:
beq t0, x0, check_too_small
jal x0, nan
check_too_small:
beq t1, x0, zero
real_exp:
addi t1, t1, -127 # e replaces exp
li t4, 0 # new_exp
implict_1:
ori t2, t2, 0x80 # m replaces mant
start_square_root:
andi t6, t1, 1
sll t2, t2, t6 # mant <<= mask
sub t4, t1, t6 # new_exp = e - mask
srai t4, t4, 1 # new_exp >>= 1 (new_exp is signed, so an arithmetic shift is needed)
addi t4, t4, 127 # new_exp += BF16_EXP_BIAS
binary_search_init:
li a1, 90 # low
li a2, 255 # high
li a3, 128 # result
binary_search_loop:
bgt a1, a2, binary_loop_end
add a4, a1, a2
srli a4, a4, 1 # mid
mul_loop_init:
li a5, 0 # sq
li s0, 0 # i
li s1, 32
mul_loop:
bge s0, s1, mul_loop_end
srl t6, a4, s0 # mask
andi t6, t6, 1
sub t6, x0, t6
sll s3, a4, s0
and s3, s3, t6
add a5, a5, s3
addi s0, s0, 1 # i++
j mul_loop
mul_loop_end:
srli a5, a5, 7
ble a5, t2, set_result
addi a2, a4, -1
j binary_search_loop
set_result:
mv a3, a4
addi a1, a4, 1
j binary_search_loop
binary_loop_end:
li s6, 0xFF
andi t2, a3, 0x7F
blt t4, s6, chk_new_exp
jal x0, inf
chk_new_exp:
bgt t4, x0, return
jal x0, zero
return:
andi t4, t4, 0xFF
slli t4, t4, 7
or a0, t4, t2
ret
inf:
slli a0, t0, 15
addi t0, x0, 255
slli t0, t0, 7
or a0, a0, t0
jalr x0, ra, 0
nan: # 0x7FC0
li t0, 0 # sign = 0
addi t1, x0, 255
slli t1, t1, 7
ori a0, t1, 0x40 # mant = 0x40
jalr x0, ra, 0
zero: # 0x0000
mv a0, x0
jalr x0, ra, 0
```
:::
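The `mul_loop` in the assembly computes `mid * mid` without a multiply instruction, adding a shifted copy of `mid` for every set bit. A rough C equivalent is sketched below; `square_shift_add` is a hypothetical name used only for illustration.

```c
#include <stdint.h>

/* Hypothetical C mirror of mul_loop: square x by shift-and-add.
 * For each bit i of x, a mask of all ones (bit set) or all zeros
 * (bit clear) selects whether x << i is added to the accumulator. */
static uint32_t square_shift_add(uint32_t x)
{
    uint32_t acc = 0;
    for (uint32_t i = 0; i < 32; i++) {
        uint32_t mask = 0u - ((x >> i) & 1u);  /* 0xFFFFFFFF if bit i is set, else 0 */
        acc += (x << i) & mask;
    }
    return acc;
}
```

The mask avoids a branch inside the loop, which is the same idea as the `sub t6, x0, t6` / `and s3, s3, t6` pair in the assembly. Since `mid` never exceeds 255 here, the loop could stop after 8 iterations instead of 32.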
#### Test Case
| Case | Operation | Input A (Hex) | Input A (value) | expect (Hex) | expect (value) | test type |
| :---: |:---: | :---: | :---: | :---: | :---: | :---: |
| **TC1** | sqrt | `0x4080` | $+4.0$ | `0x4000` | $+2.0$ | normal |
| **TC2** | sqrt | `0xBF80` | $-1.0$ | `0x7FC0` | $NaN$ | edge |
| **TC3** | sqrt | `0x7F80` | $+Inf$ | `0x7F80` | $+Inf$ | edge |
#### Execution info for RISC-V code

### Reference
https://hackmd.io/@3xOSPTI6QMGdj6jgMMe08w/Bk-uxCYxz
https://zhuanlan.zhihu.com/p/540887151
## Five Stage
| Stage | Name | Function Summary | Where to Observe in Ripes |
| ------- | ------------------ | ------------------------------------------ | ------------------------------------------------ |
| 1️⃣ IF | Instruction Fetch | Fetch the instruction from memory | **Instruction Memory**, **Program Counter (PC)** |
| 2️⃣ ID | Instruction Decode | Decode the instruction and read registers | **Register File** |
| 3️⃣ EX | Execute / ALU | Perform arithmetic or calculate address | **ALU**, **MUX (ALUSrc)** |
| 4️⃣ MEM | Memory Access | Access data memory (load/store) | **Data Memory** |
| 5️⃣ WB | Write Back | Write the result back to the register file | **Register File (Write Enable)** |
### `li s10, 10`
#### Instruction Function
This is a pseudo-instruction, which the assembler expands into the real RISC-V instruction: `addi s10, x0, 10`
#### IF(Instruction Fetch)

- The PC (Program Counter) starts at address `0x00000000`.
- The Instruction Memory outputs the 32-bit instruction stored at that address, `0x00a00d13`, which corresponds to `addi s10, x0, 10`.
#### ID(Instruction Decode)

- The Decoder interprets the instruction `0x00a00d13` and splits it into its fields:
```
opcode = 0010011 (I-type, addi)
funct3 = 000
rd = s10 (x26)
rs1 = x0
imm = 10
```
- Reg1 (x0) reads the value `0x00000000`, since register x0 is always hardwired to zero.
- Because this is an I-type instruction, there is no Reg2 input; the second ALU operand comes from the immediate value instead. Thus, Reg2 shows the default value `0x00000000`.
- The Immediate Generator outputs the immediate value 10, represented as `0x0000000a`.
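The field split above can be reproduced with a few shifts and masks. The sketch below follows the standard RISC-V I-type layout; the `itype_fields` struct and `decode_itype` function are hypothetical names introduced only for this illustration.

```c
#include <stdint.h>

/* Hypothetical container for the fields of an I-type instruction. */
typedef struct {
    uint32_t opcode;  /* bits [6:0]   */
    uint32_t rd;      /* bits [11:7]  */
    uint32_t funct3;  /* bits [14:12] */
    uint32_t rs1;     /* bits [19:15] */
    int32_t imm;      /* bits [31:20], sign-extended */
} itype_fields;

static itype_fields decode_itype(uint32_t insn)
{
    itype_fields f;
    f.opcode = insn & 0x7F;
    f.rd = (insn >> 7) & 0x1F;
    f.funct3 = (insn >> 12) & 0x7;
    f.rs1 = (insn >> 15) & 0x1F;
    uint32_t raw = insn >> 20;                 /* 12-bit immediate, unsigned */
    f.imm = (int32_t) (raw ^ 0x800) - 0x800;   /* portable sign extension from 12 bits */
    return f;
}
```

Calling `decode_itype(0x00a00d13)` yields `opcode = 0x13` (binary `0010011`), `rd = 26` (`s10`), `funct3 = 0`, `rs1 = 0`, and `imm = 10`, matching the values shown in Ripes.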
#### EX(Execute / ALU)

- The ALU performs the operation `Reg[x0] + imm` → `0 + 10 = 10`.
- Since this is an `addi` instruction, the MUX (Multiplexer) selects the immediate (imm) as the second ALU input instead of a register.
- The Result (Res) output is `0x0000000a`, indicating the ALU computed the correct result.
#### MEM(Memory Access)

- This instruction does not access the data memory because it is neither a load nor a store instruction.
- The ALU result simply passes through to the next stage without modifying memory.
#### WB(Write Back)

- The ALU result (10, or `0x0000000a`) is written back to the destination register `rd = s10` (x26, register index `0x1a`).
- After execution, you can confirm that s10 now contains the value `0x0000000a`.
