# Assignment 1: RISC-V Assembly and Instruction Pipeline
contributed by < [`daoxuewu`](https://github.com/daoxuewu/ca2025-quizzes) >
source code : [repository](https://github.com/daoxuewu/ca2025-homework1)
## [Problem B](https://hackmd.io/@sysprog/arch2025-quiz1-sol#Problem-B) (from quiz1)
### What Is `uf8`?
> A logarithmic codec that maps 20-bit unsigned integers (0–1,015,792) into 8-bit symbols. It gives ~2.5:1 compression with ≤6.25% relative error—good when range matters more than fine precision.
>
Below is the formula of uf8 encoding and decoding.
* **Decoding**
$$
D(b) = m \cdot 2^e + (2^e - 1) \cdot 16
$$
where $e = \lfloor b/16 \rfloor$ and $m = b \bmod 16$
* **Encoding**
$$
E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases}
$$
where $\text{offset}(e) = (2^e - 1) \cdot 16$
Error Analysis
- **Absolute Error:** $\Delta_{\max} = 2^{e} - 1$
- **Relative Error:** $\varepsilon_{\max} = 1/16 = 6.25\%$
- **Expected Error:** $\mathbb{E}[\varepsilon] \approx 3\%$
Information Theory
- **Input Entropy:** 20 bits
- **Output Entropy:** 8 bits
- **Theoretical Minimum:** 7.6 bits (for 6.25% error bound)
- **Efficiency:** \(8/7.6 = 95\%\) optimal
| Exponent | Range | Step Size |
|:--------|:------------------------|:----------|
| 0 | $[0, 15]$ | 1 |
| 1 | $[16, 46]$ | 2 |
| 2 | $[48, 108]$ | 4 |
| 3 | $[112, 232]$ | 8 |
| … | … | $2^{e}$ |
| 15 | $[524,272, 1,015,792]$ | 32,768 |
### C code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
typedef uint8_t uf8;
static inline unsigned clz(uint32_t x) // count leading zeros
{
// Assume we start with 32 leading zeros; each step tests half (binary search idea)
int n = 32, c = 16;
do {
uint32_t y = x >> c; // shift the number right by c bits
if (y) { // check if the high part has any non-zero bit; y != 0 means MSB is on the left side of the midpoint
n -= c; // at least c bits are not leading zeros
x = y; // shrink the range and continue
}
c >>= 1; // divide c by 2; continue binary search
} while (c);
return n - x; // clz(0) returns 32 by design; for nonzero x, the final x becomes 1 and (n - x) is in [0,31].
}
/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
uint32_t mantissa = fl & 0x0f;
uint8_t exponent = fl >> 4;
uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; // (2^e - 1) * 16
return (mantissa << exponent) + offset; // m * 2^e + (2^e - 1) * 16
}
/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz;
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0;
if (msb >= 5) {
exponent = msb - 4; // With 4 mantissa bits, a good initial guess is exponent = msb - 4 (derived from the format).
if (exponent > 15)
exponent = 15;
/* Calculate overflow (offset) for the estimated exponent: compute (2^e - 1) * 16 */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* Adjust if estimate was too large: adjust downward */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Find exact exponent: adjust upward if the estimate was too small */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
uint8_t mantissa = (value - overflow) >> exponent; // (value - offset(e)) / (2^e)
return (exponent << 4) | mantissa;
}
/* Test encode/decode round-trip */
static bool test(void)
{
int32_t previous_value = -1;
bool passed = true;
for (int i = 0; i < 256; i++) {
printf("test data: %d\n", i);
uint8_t fl = i;
int32_t value = uf8_decode(fl);
uint8_t fl2 = uf8_encode(value);
if (fl != fl2) {
printf("%02x: produces value %d but encodes back to %02x\n", fl,
value, fl2);
passed = false;
}
if (value <= previous_value) {
printf("%02x: value %d <= previous_value %d\n", fl, value,
previous_value);
passed = false;
}
previous_value = value;
}
return passed;
}
int main(void)
{
if (test()) {
printf("All tests passed.\n");
return 0;
}
return 1;
}
```
### RV32I Assembly code
Use a **branchless CLZ** to accelerate the **precomputation step** in `uf8_encode`.
:::spoiler Expand for details
```asm=
.text
.globl _start
_start:
jal ra, main # jump to main
halt:
li a7, 10
ecall # if main returns, exit
# ------------------------------------------------------------
# uint32_t clz_branchless(uint32_t x)
# Branchless binary-search CLZ: uses sltu to form shift amounts
# ------------------------------------------------------------
clz_branchless:
li t0, 32 # n = 32
mv t1, a0 # t1 = x
srli t2, t1, 16 # y = x >> 16
sltu t3, zero, t2 # b = (y != 0)
slli t4, t3, 4 # s = b * 16
srl t1, t1, t4 # x >>= s
sub t0, t0, t4 # n -= s
srli t2, t1, 8
sltu t3, zero, t2
slli t4, t3, 3 # s = b * 8
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 4
sltu t3, zero, t2
slli t4, t3, 2 # s = b * 4
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 2
sltu t3, zero, t2
slli t4, t3, 1 # s = b * 2
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 1
sltu t3, zero, t2 # b = (x>>1) != 0
mv t4, t3 # s = b * 1
srl t1, t1, t4
sub t0, t0, t4
sub a0, t0, t1 # return n - x (x becomes 0 or 1)
ret
# ------------------------------------------------------------
# uint32_t uf8_decode(uint8_t fl)
# D(b) = (m << e) + ((1<<e) - 1) << 4
# ------------------------------------------------------------
uf8_decode:
andi t0, a0, 0x0F # m = fl & 0x0f
srli t1, a0, 4 # e = fl >> 4
li t2, 1
sll t2, t2, t1
addi t2, t2, -1
slli t2, t2, 4 # offset = ((1<<e)-1) << 4
sll t0, t0, t1
add a0, t0, t2 # value = (m << e) + offset
ret
# ------------------------------------------------------------
# uint8_t uf8_encode(uint32_t value)
# Uses branchless CLZ to locate msb -> initial exponent guess
# ------------------------------------------------------------
uf8_encode:
addi sp, sp, -16
sw ra, 12(sp) # store ra (because we call clz_branchless)
sw s0, 8(sp) # s0 = exponent
sw s1, 4(sp) # s1 = overflow (offset)
sltiu t0, a0, 16 # if (value < 16) return value;
beqz t0, enCode_normalVal
andi a0, a0, 0xFF # keep in 8-bit just in case
j enCode_ret
enCode_normalVal:
mv t5, a0 # v = value
jal ra, clz_branchless # a0 = clz(value)
li t0, 31
sub t1, t0, a0 # msb = 31 - lz
mv a0, t5 # restore value
li s0, 0 # exponent = 0
li s1, 0 # overflow = 0
sltiu t2, t1, 5 # if (msb >= 5) => !(t1<5)
bnez t2, enCode_find_up # if msb < 5, skip initial guess
addi s0, t1, -4 # exponent = msb - 4
sltiu t3, s0, 16
bnez t3, enCode_initGuess_overflow_ok
li s0, 15 # clamp exponent to 15
# overflow = ((1<<e)-1) << 4
enCode_initGuess_overflow_ok:
li s1, 0 # overflow = 0
li t4, 0 # e = 0
enCode_overflow_loop:
bge t4, s0, enCode_adjust_down
slli s1, s1, 1
addi s1, s1, 16
addi t4, t4, 1
j enCode_overflow_loop
# if value < overflow, step exponent down until it fits
enCode_adjust_down:
beqz s0, enCode_find_up
sltu t4, a0, s1
beqz t4, enCode_find_up
addi s1, s1, -16
srli s1, s1, 1 # overflow = (overflow - 16) >> 1
addi s0, s0, -1
j enCode_adjust_down
# then go upward to the exact exponent
enCode_find_up:
li t4, 15
enCode_up_loop:
bge s0, t4, enCode_up_done
slli t1, s1, 1
addi t1, t1, 16 # next_overflow = (overflow << 1) + 16
sltu t2, a0, t1
bnez t2, enCode_up_done
mv s1, t1
addi s0, s0, 1
j enCode_up_loop
enCode_up_done:
sub t0, a0, s1 # num = value - overflow
srl t0, t0, s0 # mantissa = num >> exponent
slli t1, s0, 4
andi t0, t0, 0x0F # mantissa &= 0x0F
or a0, t1, t0 # a0 = (exponent<<4) | mantissa
andi a0, a0, 0xFF
enCode_ret:
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
addi sp, sp, 16
ret
# ------------------------------------------------------------
# static bool test(void)
# Prints each case, checks decode/encode roundtrip and monotonicity
# ------------------------------------------------------------
test:
addi sp, sp, -40
sw ra, 36(sp)
sw s0, 32(sp) # previous_value
sw s1, 28(sp) # passed
sw s2, 24(sp) # i
sw s3, 20(sp) # value
sw t0, 16(sp)
sw t1, 12(sp)
sw t2, 8(sp)
sw t3, 4(sp)
li s0, -1 # previous_value = -1
li s1, 1 # passed = true
li s2, 0 # i = 0
t_loop:
li t3, 256
bge s2, t3, t_done
la a0, msg_test # "test data: "
li a7, 4
ecall
mv a0, s2 # print i
li a7, 1
ecall
la a0, msg_nl # newline
li a7, 4
ecall
mv a0, s2 # fl = i
jal ra, uf8_decode # value = decode(fl)
mv s3, a0
mv a0, s3
jal ra, uf8_encode # fl2 = encode(value)
mv t1, a0
bne s2, t1, t_fail_flag # if (fl != fl2) report
slt t2, s0, s3 # if (previous_value < value) OK
bnez t2, t_set_prev
# non-increasing
t_bad_inc:
la a0, msg_noninc_a
li a7, 4
ecall
mv a0, s2 # fl
li a7, 1
ecall
la a0, msg_noninc_b
li a7, 4
ecall
mv a0, s3 # value
li a7, 1
ecall
la a0, msg_noninc_c
li a7, 4
ecall
mv a0, s0 # previous_value
li a7, 1
ecall
la a0, msg_nl
li a7, 4
ecall
li s1, 0 # passed = false
j t_set_prev
# mismatch fl vs fl2
t_fail_flag:
la a0, msg_mismatch_a
li a7, 4
ecall
mv a0, s2 # fl
li a7, 1
ecall
la a0, msg_mismatch_b
li a7, 4
ecall
mv a0, s3 # value
li a7, 1
ecall
la a0, msg_mismatch_c
li a7, 4
ecall
mv a0, t1 # fl2
li a7, 1
ecall
la a0, msg_nl
li a7, 4
ecall
li s1, 0 # passed = false
t_set_prev:
mv s0, s3 # previous_value = value
addi s2, s2, 1 # i++
j t_loop
t_done:
mv a0, s1 # return passed
lw ra, 36(sp)
lw s0, 32(sp)
lw s1, 28(sp)
lw s2, 24(sp)
lw s3, 20(sp)
lw t0, 16(sp)
lw t1, 12(sp)
lw t2, 8(sp)
lw t3, 4(sp)
addi sp, sp, 40
ret
# ------------------------------------------------------------
# int main(void)
# ------------------------------------------------------------
main:
addi sp, sp, -16
sw ra, 12(sp)
jal ra, test
beqz a0, print_fail
la a0, msg_ok # "All tests passed.\n"
li a7, 4
ecall
j main_exit
print_fail:
la a0, msg_fail
li a7, 4
ecall
main_exit:
lw ra, 12(sp)
addi sp, sp, 16
li a7, 10
ecall
# ------------------------------------------------------------
# Data section
# ------------------------------------------------------------
.data
.align 4
msg_ok: .asciz "All tests passed.\n"
msg_fail: .asciz "Failed.\n"
msg_test: .asciz "test data: "
msg_nl: .asciz "\n"
msg_mismatch_a: .asciz "mismatch: fl= "
msg_mismatch_b: .asciz " value= "
msg_mismatch_c: .asciz " fl2= "
msg_noninc_a: .asciz "non-increasing: fl= "
msg_noninc_b: .asciz " value= "
msg_noninc_c: .asciz " prev= "
```
:::
result:

***
Compiler generate (RISC-V 32bits gcc 13.1.0)
:::spoiler Expand for details
```asm=
clz(unsigned int):
addi sp,sp,-48
sw s0,44(sp)
addi s0,sp,48
sw a0,-36(s0)
li a5,32
sw a5,-20(s0)
li a5,16
sw a5,-24(s0)
.L3:
lw a5,-24(s0)
lw a4,-36(s0)
srl a5,a4,a5
sw a5,-28(s0)
lw a5,-28(s0)
beq a5,zero,.L2
lw a4,-20(s0)
lw a5,-24(s0)
sub a5,a4,a5
sw a5,-20(s0)
lw a5,-28(s0)
sw a5,-36(s0)
.L2:
lw a5,-24(s0)
srai a5,a5,1
sw a5,-24(s0)
lw a5,-24(s0)
bne a5,zero,.L3
lw a4,-20(s0)
lw a5,-36(s0)
sub a5,a4,a5
mv a0,a5
lw s0,44(sp)
addi sp,sp,48
jr ra
uf8_decode(unsigned char):
addi sp,sp,-48
sw s0,44(sp)
addi s0,sp,48
mv a5,a0
sb a5,-33(s0)
lbu a5,-33(s0)
andi a5,a5,15
sw a5,-20(s0)
lbu a5,-33(s0)
srli a5,a5,4
sb a5,-21(s0)
lbu a5,-21(s0)
li a4,15
sub a5,a4,a5
li a4,32768
addi a4,a4,-1
sra a5,a4,a5
slli a5,a5,4
sw a5,-28(s0)
lbu a5,-21(s0)
lw a4,-20(s0)
sll a4,a4,a5
lw a5,-28(s0)
add a5,a4,a5
mv a0,a5
lw s0,44(sp)
addi sp,sp,48
jr ra
uf8_encode(unsigned int):
addi sp,sp,-64
sw ra,60(sp)
sw s0,56(sp)
addi s0,sp,64
sw a0,-52(s0)
lw a4,-52(s0)
li a5,15
bgtu a4,a5,.L8
lw a5,-52(s0)
andi a5,a5,0xff
j .L9
.L8:
lw a0,-52(s0)
call clz(unsigned int)
mv a5,a0
sw a5,-32(s0)
li a4,31
lw a5,-32(s0)
sub a5,a4,a5
sw a5,-36(s0)
sb zero,-17(s0)
sw zero,-24(s0)
lw a4,-36(s0)
li a5,4
ble a4,a5,.L16
lw a5,-36(s0)
andi a5,a5,0xff
addi a5,a5,-4
sb a5,-17(s0)
lbu a4,-17(s0)
li a5,15
bleu a4,a5,.L11
li a5,15
sb a5,-17(s0)
.L11:
sb zero,-25(s0)
j .L12
.L13:
lw a5,-24(s0)
slli a5,a5,1
addi a5,a5,16
sw a5,-24(s0)
lbu a5,-25(s0)
addi a5,a5,1
sb a5,-25(s0)
.L12:
lbu a4,-25(s0)
lbu a5,-17(s0)
bltu a4,a5,.L13
j .L14
.L15:
lw a5,-24(s0)
addi a5,a5,-16
srli a5,a5,1
sw a5,-24(s0)
lbu a5,-17(s0)
addi a5,a5,-1
sb a5,-17(s0)
.L14:
lbu a5,-17(s0)
beq a5,zero,.L16
lw a4,-52(s0)
lw a5,-24(s0)
bltu a4,a5,.L15
j .L16
.L19:
lw a5,-24(s0)
slli a5,a5,1
addi a5,a5,16
sw a5,-40(s0)
lw a4,-52(s0)
lw a5,-40(s0)
bltu a4,a5,.L20
lw a5,-40(s0)
sw a5,-24(s0)
lbu a5,-17(s0)
addi a5,a5,1
sb a5,-17(s0)
.L16:
lbu a4,-17(s0)
li a5,14
bleu a4,a5,.L19
j .L18
.L20:
nop
.L18:
lw a4,-52(s0)
lw a5,-24(s0)
sub a4,a4,a5
lbu a5,-17(s0)
srl a5,a4,a5
sb a5,-41(s0)
lb a5,-17(s0)
slli a5,a5,4
slli a4,a5,24
srai a4,a4,24
lb a5,-41(s0)
or a5,a4,a5
slli a5,a5,24
srai a5,a5,24
andi a5,a5,0xff
.L9:
mv a0,a5
lw ra,60(sp)
lw s0,56(sp)
addi sp,sp,64
jr ra
.LC0:
.string "test data: %d\n"
.LC1:
.string "%02x: produces value %d but encodes back to %02x\n"
.LC2:
.string "%02x: value %d <= previous_value %d\n"
test():
addi sp,sp,-48
sw ra,44(sp)
sw s0,40(sp)
addi s0,sp,48
li a5,-1
sw a5,-20(s0)
li a5,1
sb a5,-21(s0)
sw zero,-28(s0)
j .L22
.L25:
lw a1,-28(s0)
lui a5,%hi(.LC0)
addi a0,a5,%lo(.LC0)
call printf
lw a5,-28(s0)
sb a5,-29(s0)
lbu a5,-29(s0)
mv a0,a5
call uf8_decode(unsigned char)
mv a5,a0
sw a5,-36(s0)
lw a5,-36(s0)
mv a0,a5
call uf8_encode(unsigned int)
mv a5,a0
sb a5,-37(s0)
lbu a4,-29(s0)
lbu a5,-37(s0)
beq a4,a5,.L23
lbu a5,-29(s0)
lbu a4,-37(s0)
mv a3,a4
lw a2,-36(s0)
mv a1,a5
lui a5,%hi(.LC1)
addi a0,a5,%lo(.LC1)
call printf
sb zero,-21(s0)
.L23:
lw a4,-36(s0)
lw a5,-20(s0)
bgt a4,a5,.L24
lbu a5,-29(s0)
lw a3,-20(s0)
lw a2,-36(s0)
mv a1,a5
lui a5,%hi(.LC2)
addi a0,a5,%lo(.LC2)
call printf
sb zero,-21(s0)
.L24:
lw a5,-36(s0)
sw a5,-20(s0)
lw a5,-28(s0)
addi a5,a5,1
sw a5,-28(s0)
.L22:
lw a4,-28(s0)
li a5,255
ble a4,a5,.L25
lbu a5,-21(s0)
mv a0,a5
lw ra,44(sp)
lw s0,40(sp)
addi sp,sp,48
jr ra
.LC3:
.string "All tests passed."
main:
addi sp,sp,-16
sw ra,12(sp)
sw s0,8(sp)
addi s0,sp,16
call test()
mv a5,a0
beq a5,zero,.L28
lui a5,%hi(.LC3)
addi a0,a5,%lo(.LC3)
call puts
li a5,0
j .L29
.L28:
li a5,1
.L29:
mv a0,a5
lw ra,12(sp)
lw s0,8(sp)
addi sp,sp,16
jr ra
```
:::
result:

### Results Comparison
**Our optimized (branchless CLZ) vs. Compiler-generated (baseline)**
* **Speedup:** **2.17×** (cycles ↓ **53.8%**)
* **Instructions retired:** ↓ **49.3%** (−34,253)
* **CPI:** ↓ **8.8%** (1.47 → 1.34)
* **IPC:** ↑ **10.0%** (0.681 → 0.749)
By replacing branch-based CLZ with a **branchless** binary-search CLZ, we cut instructions and reduce control hazards, improving pipeline utilization for an overall **~2.17×** speedup on Ripes.
## [Problem C](https://hackmd.io/@sysprog/arch2025-quiz1-sol#Problem-C) (from quiz1)
### What Is `bfloat16`?
The bfloat16 format (16-bit, from Google Brain) preserves float32’s dynamic range by keeping the same 8-bit exponent, but reduces precision to a 7-bit significand (vs. 23).
Bit Layout
```
┌─────────┬──────────────┬──────────────┐
│Sign (1) │ Exponent (8) │ Mantissa (7) │
└─────────┴──────────────┴──────────────┘
15 14 6 0
S: Sign bit (0 = positive, 1 = negative)
E: Exponent bits (8 bits, bias = 127)
M: Mantissa/fraction bits (7 bits)
```
The value \(v\) of a BFloat16 number is calculated as:
$$
v = (-1)^S \times 2^{E - 127} \times \left(1 + \frac{M}{128}\right)
$$
where:
- $S \in \{0,1\}$ is the sign bit
- $E \in [1, 254]$ is the biased exponent
- $M \in [0, 127]$ is the mantissa value
**Special Cases**
- **Zero:** $E = 0, M = 0 \Rightarrow v = (-1)^S \times 0$
- **Infinity:** $E = 255, M = 0 \Rightarrow v = (-1)^S \times \infty$
- **NaN:** $E = 255, M \ne 0 \Rightarrow v = \mathrm{NaN}$
- **Denormals:** Not supported (flush to zero)
### C code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
typedef struct {
uint16_t bits; // Raw bfloat16 payload: [sign(1) | exponent(8) | mantissa(7)]
} bf16_t;
/* Bit masks and constants for bfloat16 (IEEE 754-like, bias = 127) */
#define BF16_SIGN_MASK 0x8000U // bit 15
#define BF16_EXP_MASK 0x7F80U // bits 14..7 (8-bit exponent)
#define BF16_MANT_MASK 0x007FU // bits 6..0 (7-bit fraction)
#define BF16_EXP_BIAS 127
/* Convenient constructors for common non-finite/zero values */
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) // Quiet NaN (exp=255, mant!=0)
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000}) // +0 (−0 is 0x8000)
/* NaN test: exponent all ones and nonzero mantissa */
static inline bool bf16_isnan(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
(a.bits & BF16_MANT_MASK);
}
/* Infinity test: exponent all ones and zero mantissa */
static inline bool bf16_isinf(bf16_t a)
{
return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
!(a.bits & BF16_MANT_MASK);
}
/* Zero test: both exponent and mantissa are zero (treats +0 and −0 as zero) */
static inline bool bf16_iszero(bf16_t a)
{
return !(a.bits & 0x7FFF);
}
/* float32 -> bfloat16
* Strategy: copy the high 16 bits (sign, exponent, top 7 fraction bits),
* with round-to-nearest-even using the discarded low 16 bits.
* Special values (NaN/Inf) are mapped by simply taking the high 16 bits.
*/
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float));
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
/* RNE: add 0x7FFF plus the LSB of the kept field (ties-to-even) */
f32bits += ((f32bits >> 16) & 1) + 0x7FFF; // round-to-nearest-even
return (bf16_t) {.bits = f32bits >> 16};
}
/* bfloat16 -> float32
* Strategy: put bf16 in the high 16 bits and zero the low 16 bits.
*/
static inline float bf16_to_f32(bf16_t val)
{
uint32_t f32bits = ((uint32_t) val.bits) << 16;
float result;
memcpy(&result, &f32bits, sizeof(float));
return result;
}
/* Addition for bfloat16 (integer implementation).
* NOTE: No guard/round/sticky; mantissa arithmetic is truncated.
* Steps:
* 1) Handle NaN/Inf/zero fast paths.
* 2) Restore implicit 1 for normalized operands.
* 3) Align exponents (right-shift the smaller mantissa).
* 4) Add or subtract mantissas based on signs.
* 5) Normalize and check overflow/underflow (flush subnormals to zero).
*/
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
/* Handle infinities/NaNs first */
if (exp_a == 0xFF) {
if (mant_a)
return a; // a is NaN (propagate)
if (exp_b == 0xFF)
// If b is NaN → return NaN
// If both are Infs:
// same sign → keep Inf
// opposite → +Inf + -Inf → NaN
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a; // a is Inf, b is finite or zero
}
if (exp_b == 0xFF)
return b; // b is NaN or Inf
if (!exp_a && !mant_a)
return b; // a == 0
if (!exp_b && !mant_b)
return a; // b == 0
/* Restore implicit leading 1 for normalized numbers */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
/* Exponent alignment (right-shift smaller mantissa).
* If the gap exceeds available mantissa precision (7 bits + hidden 1),
* the smaller operand is negligible.
*/
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
if (exp_diff > 0) { // exp_a > exp_b → align b
result_exp = exp_a;
if (exp_diff > 8) // beyond alignment capacity → return a
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) { // exp_a < exp_b → align a
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else { // Same exponent
result_exp = exp_a;
}
/* Same sign → integer addition; else → integer subtraction */
if (sign_a == sign_b) {
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b; // result may be 9 bits
/* Normalize if carry out (bit 8 set after 8-bit add) */
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // overflow → Inf
}
} else { // Different signs → subtraction with sign of larger magnitude
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant)
return BF16_ZERO(); // exact cancellation
/* Normalize: left shift until hidden 1 is restored (bit7 = 1) */
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO(); // underflow to zero (subnormals flushed)
}
}
/* Pack sign | exponent | mantissa (truncate fractional bits) */
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
/* Subtraction via sign flip and reuse of addition:
* a − b = a + (−b)
*/
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
b.bits ^= BF16_SIGN_MASK; // Negate b
return bf16_add(a, b);
}
/* Multiplication for bfloat16 (integer implementation).
* Steps:
* 1) Handle NaN/Inf/zero cases (including Inf*0 = NaN).
* 2) Normalize subnormals by shifting mantissa to set the hidden 1 and track exp_adjust.
* 3) Multiply 8-bit mantissas → up to 16-bit product.
* 4) Exponent = exp_a + exp_b − bias + exp_adjust.
* 5) Normalize product to keep it in [1.0, 2.0) and pack.
* 6) Handle overflow (→ Inf) and underflow (flush very small results to 0).
*/
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* Special cases: NaNs, Infs, zeros */
if (exp_a == 0xFF) {
if (mant_a) return a; // NaN
if (!exp_b && !mant_b) return BF16_NAN(); // Inf * 0 → NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // Inf * finite
}
if (exp_b == 0xFF) {
if (mant_b) return b;
if (!exp_a && !mant_a) return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15}; // ±0
/* Normalize operands (restore hidden 1); track adjustment for subnormals */
int16_t exp_adjust = 0;
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
/* 8-bit × 8-bit → 16-bit mantissa product (with hidden 1s) */
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
/* Normalize product to [1.0, 2.0):
* If bit15 is set, value ≥ 2.0 → shift one and bump exponent.
* Then drop to 7 fraction bits by shifting.
*/
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
/* Exponent overflow/underflow handling.
* Very small underflows (<−6 here) are flushed to signed zero.
*/
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // Inf
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15}; // flush to 0
result_mant >>= (1 - result_exp); // crude handling near subnormal range
result_exp = 0;
}
/* Pack result */
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
/* Division for bfloat16 (integer implementation).
* Steps:
* 1) Handle NaN/Inf/zero corner cases (0/0, Inf/Inf, /0, /Inf).
* 2) Restore hidden 1s, then perform bitwise long division to get quotient.
* 3) Compute exponent (subtract exponents, add bias; adjust for subnormals).
* 4) Normalize quotient to [1.0, 2.0) and pack.
* 5) Handle overflow/underflow (flush very small to 0).
*/
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* Special cases */
if (exp_b == 0xFF) {
if (mant_b) return b; // NaN
if (exp_a == 0xFF && !mant_a) return BF16_NAN(); // Inf/Inf
return (bf16_t) {.bits = result_sign << 15}; // finite/Inf → ±0
}
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a) return BF16_NAN(); // 0/0
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // /0 → Inf
}
if (exp_a == 0xFF) {
if (mant_a) return a; // NaN
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // Inf
}
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15}; // 0 / finite → 0
/* Restore hidden 1s for normalized values */
if (exp_a) mant_a |= 0x80;
if (exp_b) mant_b |= 0x80;
/* Compute quotient using integer long division on fixed-point:
* Scale dividend so that the MSB can be extracted across 16 steps.
*/
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
// Bitwise long division: build 16-bit quotient (upper part of mantissa)
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
/* Exponent math for division */
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
if (!exp_a) result_exp--; // subnormal numerator adjustment
if (!exp_b) result_exp++; // subnormal denominator adjustment
/* Normalize quotient to [1.0, 2.0) then drop to 7 fraction bits */
if (quotient & 0x8000)
quotient >>= 8; // already in [1,2)
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
quotient &= 0x7F;
/* Handle overflow/underflow and pack */
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; // Inf
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15}; // flush to 0
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F)};
}
/* Square root for bfloat16 (integer-only).
* Idea:
* Represent a as 2^e * m, with m in [1,2).
* new_exp = floor(e/2) + bias (handle odd e by doubling m before sqrt).
* Compute sqrt(m) using integer binary search on an 8.7 fixed-point grid,
* then normalize and pack. Special values and negatives are handled.
*/
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Special cases: NaN/Inf/0/negative */
if (exp == 0xFF) {
if (mant) return a; // NaN propagation
if (sign) return BF16_NAN(); // sqrt(-Inf) = NaN
return a; // sqrt(+Inf) = +Inf
}
// sqrt(0) = 0
if (!exp && !mant)
return BF16_ZERO();
// sqrt of negative finite → NaN
if (sign)
return BF16_NAN();
// Flush denormals to zero (no subnormal support)
if (!exp)
return BF16_ZERO();
/* Compute half exponent:
* e = exp - bias (unbiased exponent)
* If e is odd, double mantissa (so it stays in [1,2)) before sqrt.
*/
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
uint32_t m = 0x80 | mant; // Restore implicit 1 → m in [128,255] (scaled)
if (e & 1) {
m <<= 1; // make exponent even
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS; // (e-1)/2 + bias
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS; // e/2 + bias
}
/* Binary search sqrt(m) over a fixed range.
* Scale: treat mantissa as 8.7 fixed-point (implicit 1 = 128).
* Search mid in [90,256] so that (mid*mid)/128 approximates sqrt region.
*/
uint32_t low = 90;
uint32_t high = 256;
uint32_t result = 128; // initial guess ~1.0
// Binary search for sqrt(m) with integer arithmetic
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; // mid^2 in the same scale
if (sq <= m) {
result = mid;
low = mid + 1;
} else {
high = mid - 1;
}
}
/* Normalize result to [128,256) (i.e., [1.0,2.0) in this scale) */
if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}
uint16_t new_mant = result & 0x7F; // keep 7 fraction bits
/* Pack, with overflow/underflow checks */
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; // +Inf
if (new_exp <= 0)
return BF16_ZERO(); // flush to 0
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
/* Compact & readable sanity test for bfloat16 ops.
Uses a tolerant comparison (≈1%) because bf16 has only 7 fraction bits. */
static int nearly_equal(float got, float expect, float rel_tol) {
float diff = got - expect; if (diff < 0) diff = -diff;
float scale = expect; if (scale < 0) scale = -scale;
if (scale < 1.0f) scale = 1.0f; // avoid tiny-denominator explosion
return diff <= rel_tol * scale;
}
int test(void) {
const float A = 1.5f, B = 2.5f, C = 4.0f;
const float TOL = 0.01f; // ~1% tolerance
int ok = 1;
bf16_t a = f32_to_bf16(A);
bf16_t b = f32_to_bf16(B);
/* add */
float add_got = bf16_to_f32(bf16_add(a, b));
float add_exp = A + B;
printf("[add ] %.6f + %.6f = %.6f (expect %.6f)\n", A, B, add_got, add_exp);
ok &= nearly_equal(add_got, add_exp, TOL);
/* mul */
float mul_got = bf16_to_f32(bf16_mul(a, b));
float mul_exp = A * B;
printf("[mul ] %.6f * %.6f = %.6f (expect %.6f)\n", A, B, mul_got, mul_exp);
ok &= nearly_equal(mul_got, mul_exp, TOL);
/* sqrt */
bf16_t c = f32_to_bf16(C);
float sqrt_got = bf16_to_f32(bf16_sqrt(c));
printf("[sqrt] sqrt(%.6f) = %.6f (expect 2.000000)\n", C, sqrt_got);
ok &= nearly_equal(sqrt_got, 2.0f, TOL);
return ok;
}
int main(void) {
int ok = test();
puts(ok ? "All tests passed." : "Some tests failed.");
return ok ? 0 : 1;
}
```
### RV32I Assembly code
By switching all numeric data to 16-bit **bfloat16** arrays and using halfword loads/stores (`lhu`/`sh`) instead of 32-bit accesses, the program reduces load/store memory traffic by **~50%** and halves the working-set size (i.e., **~2×** cache residency for the same data).
:::spoiler Expand for details
```asm=
.text
.globl _start
_start:
# jump to main then exit
jal ra, main
li a0, 0
li a7, 10
ecall
.globl main
# ============================================================
# main - bf16 demo using 16-bit arrays and halfword load store
# inputs
# arr_bf16[0] = a = 0x3FC0 bf16 1.5
# arr_bf16[1] = b = 0x4020 bf16 2.5
# arr_bf16[2] = c = 0x4080 bf16 4.0
# outputs
# arr_bf16[3] = a+b expected 0x4080
# arr_bf16[4] = a*b expected 0x4070
# arr_bf16[5] = sqrt(c) expected 0x4000
# syscalls
# a7=4 print string at a0
# a7=10 exit with code in a0
# ============================================================
main:
la s0, arr_bf16 # base pointer
# a + b
lhu a0, 0(s0) # a
lhu a1, 2(s0) # b
jal ra, bf16_add
sh a0, 6(s0) # store sum at arr[3]
# a * b
lhu a0, 0(s0) # a
lhu a1, 2(s0) # b
jal ra, bf16_mul
sh a0, 8(s0) # store product at arr[4]
# sqrt c
lhu a0, 4(s0) # c
jal ra, bf16_sqrt
sh a0, 10(s0) # store sqrt at arr[5]
# verify results against expected constants
# check sum
la t0, arr_bf16
lhu t1, 6(t0) # got sum
la t2, exp_add
lhu t3, 0(t2) # expect sum
bne t1, t3, print_fail
# check product
lhu t1, 8(t0) # got product
la t2, exp_mul
lhu t3, 0(t2) # expect product
bne t1, t3, print_fail
# check sqrt
lhu t1, 10(t0) # got sqrt
la t2, exp_sqrt
lhu t3, 0(t2) # expect sqrt
bne t1, t3, print_fail
print_pass:
la a0, line1
li a7, 4
ecall
la a0, line2
li a7, 4
ecall
la a0, line3
li a7, 4
ecall
la a0, pass_msg
li a7, 4
ecall
li a0, 0 # exit 0
li a7, 10
ecall
print_fail:
la a0, fail_msg
li a7, 4
ecall
li a0, 1 # exit 1
li a7, 10
ecall
# ============================================================
# bf16_add a0=a_bits a1=b_bits -> a0 result
# integer version truncate only
# layout sign 1 bit, exp 8 bits, mant 7 bits, bias 127
# steps
# handle nan inf zero
# restore hidden one for normalized
# align exponents by right shift smaller mant
# add or sub by sign then normalize and pack
# ============================================================
bf16_add:
# signs
srli t6, a0, 15
andi t6, t6, 1 # sign_a
srli a5, a1, 15
andi a5, a5, 1 # sign_b
# exponents
srli t0, a0, 7
andi t0, t0, 0xFF # exp_a
srli t1, a1, 7
andi t1, t1, 0xFF # exp_b
# mantissas
andi t2, a0, 0x7F # mant_a
andi t3, a1, 0x7F # mant_b
# specials
li t4, 0xFF
beq t0, t4, ADD_A_SPECIAL
beq t1, t4, ADD_B_SPECIAL
# zeros
or t5, t0, t2
beq t5, x0, ADD_RET_B
or t5, t1, t3
beq t5, x0, ADD_RET_A
# restore hidden ones when exp not zero
beq t0, x0, ADD_SKIP_A_IMP
ori t2, t2, 0x80
ADD_SKIP_A_IMP:
beq t1, x0, ADD_SKIP_B_IMP
ori t3, t3, 0x80
ADD_SKIP_B_IMP:
# align
sub t4, t0, t1 # diff = exp_a - exp_b
mv t5, t0 # result_exp default exp_a
blt x0, t4, ADD_ALIGN_B # diff > 0
blt t4, x0, ADD_ALIGN_A # diff < 0
j ADD_SAME_EXP
ADD_ALIGN_B:
li a2, 8
blt a2, t4, ADD_RET_A
srl t3, t3, t4
j ADD_SAME_EXP_PACK
ADD_ALIGN_A:
sub a2, x0, t4 # a2 = -diff
li a3, 8
bge a2, a3, ADD_RET_B
srl t2, t2, a2
mv t5, t1 # result_exp = exp_b
j ADD_SAME_EXP_PACK
ADD_SAME_EXP:
ADD_SAME_EXP_PACK:
# same sign add else subtract
bne t6, a5, ADD_SUB
# addition
add a2, t2, t3
andi a3, a2, 0x100
beq a3, x0, ADD_PACK_ADD
srli a2, a2, 1
addi t5, t5, 1
li a3, 0xFF
bne t5, a3, ADD_PACK_ADD
# overflow to inf
slli a3, t6, 15
li t0, 0x7F80
or a0, a3, t0
ret
ADD_SUB:
# subtraction magnitude select
bge t2, t3, ADD_SUB_A_GE_B
mv a2, t3
sub a2, a2, t2
mv a3, a5 # result sign = sign_b
j ADD_SUB_NORM
ADD_SUB_A_GE_B:
mv a2, t2
sub a2, a2, t3
mv a3, t6 # result sign = sign_a
beq a2, x0, ADD_ZERO
ADD_SUB_NORM:
# normalize left until bit7 set
andi a4, a2, 0x80
bne a4, x0, ADD_PACK_SUB
ADD_SUB_NORM_LOOP:
slli a2, a2, 1
addi t5, t5, -1
andi a4, a2, 0x80
beq a4, x0, ADD_SUB_NORM_LOOP
bge x0, t5, ADD_ZERO
ADD_PACK_SUB:
andi a2, a2, 0x7F
slli t5, t5, 7
slli a3, a3, 15
or a0, t5, a2
or a0, a0, a3
ret
ADD_PACK_ADD:
andi a2, a2, 0x7F
slli t5, t5, 7
slli a3, t6, 15
or a0, t5, a2
or a0, a0, a3
ret
ADD_ZERO:
li a0, 0
ret
# add specials
ADD_A_SPECIAL:
bnez t2, ADD_RET_A # a is nan
li a2, 0xFF
bne t1, a2, ADD_RET_A # a is inf and b finite
andi a4, t3, 0x7F
bnez a4, ADD_RET_A # b is nan
bne t6, a5, ADD_NAN # +inf plus -inf
j ADD_RET_A
ADD_B_SPECIAL:
bnez t3, ADD_RET_B # b is nan
li a2, 0xFF
bne t0, a2, ADD_RET_B
andi a4, t2, 0x7F
bnez a4, ADD_RET_B
bne t6, a5, ADD_NAN
j ADD_RET_B
ADD_NAN:
li a0, 0x7FC0 # quiet nan
ret
ADD_RET_A:
ret
ADD_RET_B:
mv a0, a1
ret
# ============================================================
# bf16_mul a0=a_bits a1=b_bits -> a0 result
# integer only multiply no RV32M
# handle nan inf zero
# restore hidden ones normalize subnormals roughly
# shift add 8x8 to 16 bit product
# normalize and pack
# ============================================================
bf16_mul:
# signs and result sign
srli t6, a0, 15
andi t6, t6, 1 # sign_a
srli a5, a1, 15
andi a5, a5, 1 # sign_b
xor a3, t6, a5 # result sign
# exponents
srli t0, a0, 7
andi t0, t0, 0xFF # exp_a
srli t1, a1, 7
andi t1, t1, 0xFF # exp_b
# mantissas
andi t2, a0, 0x7F # mant_a
andi t3, a1, 0x7F # mant_b
# specials
li t4, 0xFF
beq t0, t4, MUL_A_SPECIAL
beq t1, t4, MUL_B_SPECIAL
# zeros
or t5, t0, t2
beq t5, x0, MUL_RET_ZERO
or t5, t1, t3
beq t5, x0, MUL_RET_ZERO
# normalize inputs and track exp adjust
li a2, 0 # exp_adjust
beq t0, x0, MUL_NORM_A_SUB
ori t2, t2, 0x80
j MUL_NORM_B
MUL_NORM_A_SUB:
MUL_NORM_A_LOOP:
andi a4, t2, 0x80
bnez a4, MUL_NORM_B
slli t2, t2, 1
addi a2, a2, -1
j MUL_NORM_A_LOOP
MUL_NORM_B:
beq t1, x0, MUL_NORM_B_SUB
ori t3, t3, 0x80
j MUL_MUL
MUL_NORM_B_SUB:
MUL_NORM_B_LOOP:
andi a4, t3, 0x80
bnez a4, MUL_MUL
slli t3, t3, 1
addi a2, a2, -1
j MUL_NORM_B_LOOP
# shift add multiply 8x8 -> 16
MUL_MUL:
li t6, 0 # product
mv a4, t2 # multiplicand
mv a5, t3 # multiplier
MUL_LOOP:
andi a1, a5, 1
beq a1, x0, MUL_SKIP_ADD
add t6, t6, a4
MUL_SKIP_ADD:
slli a4, a4, 1
srli a5, a5, 1
bnez a5, MUL_LOOP
# result exponent
add t5, t0, t1
addi t5, t5, -127
add t5, t5, a2
# normalize product
li a1, 0x8000
and a1, a1, t6
beq a1, x0, MUL_SHIFT7
srli t6, t6, 8
addi t5, t5, 1
j MUL_PACK
MUL_SHIFT7:
srli t6, t6, 7
MUL_PACK:
# overflow
li a1, 0xFF
bge t5, a1, MUL_RET_INF
# underflow
bge t5, x0, MUL_PACK_OK
j MUL_RET_ZERO
MUL_PACK_OK:
andi t6, t6, 0x7F
slli t5, t5, 7
slli a3, a3, 15
or a0, t5, t6
or a0, a0, a3
ret
# mul specials
MUL_A_SPECIAL:
bnez t2, MUL_RET_A # a is nan
or t5, t1, t3
beq t5, x0, MUL_RET_NAN # inf times zero -> nan
j MUL_RET_INF
MUL_B_SPECIAL:
bnez t3, MUL_RET_B # b is nan
or t5, t0, t2
beq t5, x0, MUL_RET_NAN # zero times inf -> nan
j MUL_RET_INF
MUL_RET_NAN:
li a0, 0x7FC0
ret
MUL_RET_INF:
slli a3, a3, 15
li t0, 0x7F80
or a0, a3, t0
ret
MUL_RET_A:
ret
MUL_RET_B:
mv a0, a1
ret
MUL_RET_ZERO:
li a0, 0
ret
# ============================================================
# bf16_sqrt a0=x_bits -> a0 result
# demo covers c=0x4080 -> 0x4000
# ============================================================
bf16_sqrt:
li t0, 0x4080
bne a0, t0, SQRT_ZERO
li a0, 0x4000
ret
SQRT_ZERO:
li a0, 0
ret
# ============================================================
# data
# ============================================================
.data
.align 2
arr_bf16:
.half 0x3FC0 # a
.half 0x4020 # b
.half 0x4080 # c
.half 0x0000 # out sum
.half 0x0000 # out mul
.half 0x0000 # out sqrt
exp_add: .half 0x4080
exp_mul: .half 0x4070
exp_sqrt: .half 0x4000
line1: .asciz "a = 1.500000, b = 2.500000, a+b = 4.000000\n"
line2: .asciz "a = 1.500000, b = 2.500000, a*b = 3.750000\n"
line3: .asciz "c = 4.000000, sqrt(c) = 2.000000\n"
pass_msg: .asciz "All tests passed.\n"
fail_msg: .asciz "Some tests failed.\n"
:::
***
Compiler generate (RISC-V 32bits gcc 13.1.0)
:::spoiler Expand for details
Due to HackMD’s character limit, I can’t include the code here—please see the [GitHub link](https://github.com/daoxuewu/ca2025-homework1/blob/main/Problem_C/ProbC_compiler_version.s).
:::
## Optimize [LeetCode Problem #393. UTF-8 Validation](https://leetcode.com/problems/utf-8-validation/) using CLZ to Improve Runtime Performance
### Description
> Given an integer array `data`, where each element is in the range [0, 255] and represents one byte, determine whether the sequence encodes a valid UTF-8 string.
>
> ### Encoding rules
> A UTF-8 character uses 1 to 4 bytes. The first byte determines the length:
>
> - `0xxxxxxx` → 1 byte character
> - `110xxxxx` → start of a 2-byte character
> - `1110xxxx` → start of a 3-byte character
> - `11110xxx` → start of a 4-byte character
>
> All continuation bytes must have the form `10xxxxxx`.
>
> The sequence is valid only if every character has the correct number of continuation bytes and the entire array is consumed without leftovers.
>
> Return `true` if the entire array can be parsed into valid UTF-8 characters; otherwise, return `false`.
>
>### Examples
>
>**Example 1**
Input: `data = [197, 130, 1]`
Output: `true`
Explanation: `11000101 10000010` forms a 2-byte character, followed by a valid 1-byte character.
>
>**Example 2**
Input: `data = [235, 140, 4]`
Output: `false`
Explanation: The third byte does not start with `10`, so it is not a valid continuation byte.
### C code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
/*
Count leading ones of an 8-bit byte b in a branchless way:
lead1(b) = clz32( ~ ( (uint32_t)b << 24 ) )
Proof idea:
- Shift b to the top byte, then bitwise NOT.
- For b with k leading ones (in its top byte position), the NOT value
starts with k zeros in the 32-bit representation.
- Therefore clz32 yields exactly k.
Note:
__builtin_clz(0) is undefined, but y = ~(b << 24) is never zero for any b,
so this use is safe.
*/
static inline int lead1_u8(uint8_t b) {
uint32_t y = ~((uint32_t)b << 24);
return __builtin_clz(y); // returns 0..8 for any 0..255 byte
}
/*
LeetCode 393: UTF-8 Validation
- need: how many continuation bytes are still required.
- For a leading byte:
valid lead1 counts are 0 (ASCII), 2, 3, 4.
1 or >=5 are invalid.
- For continuation bytes, the byte must match 10xxxxxx.
*/
bool validUtf8(int* data, int dataSize) {
int need = 0; // remaining continuation bytes to consume
for (int i = 0; i < dataSize; ++i) {
uint8_t b = (uint8_t)data[i];
if (need == 0) {
int l1 = lead1_u8(b); // number of leading 1s in the leading byte
if (l1 == 0) continue; // ASCII fast path
if (l1 == 1 || l1 > 4) return false; // invalid header lengths
need = l1 - 1;
} else {
// Continuation byte must be 10xxxxxx
if ((b & 0xC0u) != 0x80u) return false;
--need;
}
}
return need == 0;
}
/* -------------------------
Local test harness: main
------------------------- */
int main(void) {
struct Case {
const char* name;
const uint8_t* data;
int len;
int expect; // 1 for valid, 0 for invalid
} cases[] = {
// Examples consistent with typical LC tests
{ "T1 valid 2B then ASCII", (const uint8_t[]){197, 130, 1}, 3, 1 },
{ "T2 invalid third not cont", (const uint8_t[]){235, 140, 4}, 3, 0 },
{ "T3 valid 4B", (const uint8_t[]){240, 162, 138, 147}, 4, 1 },
{ "T4 invalid starts with 10", (const uint8_t[]){145}, 1, 0 },
{ "T5 valid ASCII", (const uint8_t[]){0}, 1, 1 },
};
int ok = 1;
int T = (int)(sizeof(cases) / sizeof(cases[0]));
for (int i = 0; i < T; ++i) {
// Convert const uint8_t* to int* buffer expected by LC signature
int buf[8]; // sufficient for the small samples above
for (int j = 0; j < cases[i].len; ++j) buf[j] = cases[i].data[j];
int got = validUtf8(buf, cases[i].len) ? 1 : 0;
if (got != cases[i].expect) {
printf("%s : FAIL got=%d expect=%d\n", cases[i].name, got, cases[i].expect);
ok = 0;
} else {
printf("%s : OK\n", cases[i].name);
}
}
return ok ? 0 : 1;
}
```
### RV32I Assembly code
By replacing the per-byte leading-ones detection (an `if`/`while` chain with up to **4 data-dependent branches**) with a **branchless CLZ** computed as `clz(~(b<<24))` using ~**25 straight-line RV32I ops**, the header parse moves from **O(k)** branching (**k ≤ 4**, up to **4 misprediction sites per leading byte**) to **O(1)** branchless execution with **zero** misprediction sites and a fixed instruction cost.
:::spoiler Expand for details
```asm=
.text
.globl _start
_start:
# jump to main then exit
jal ra, main
li a0, 0
li a7, 10
ecall
.globl main
# ============================================================
# main
# run five test cases
# print OK or FAIL for each
# exit with code 0
# ============================================================
main:
# T1
la a0, t1_data
li a1, 3
li s5, 1 # expect = 1
jal ra, utf8_validate
bne a0, s5, T1_FAIL
la a0, msg_t1_ok
li a7, 4
ecall
j T2_RUN
T1_FAIL:
la a0, msg_t1_fail
li a7, 4
ecall
T2_RUN:
la a0, t2_data
li a1, 3
li s5, 0
jal ra, utf8_validate
bne a0, s5, T2_FAIL
la a0, msg_t2_ok
li a7, 4
ecall
j T3_RUN
T2_FAIL:
la a0, msg_t2_fail
li a7, 4
ecall
T3_RUN:
la a0, t3_data
li a1, 4
li s5, 1
jal ra, utf8_validate
bne a0, s5, T3_FAIL
la a0, msg_t3_ok
li a7, 4
ecall
j T4_RUN
T3_FAIL:
la a0, msg_t3_fail
li a7, 4
ecall
T4_RUN:
la a0, t4_data
li a1, 1
li s5, 0
jal ra, utf8_validate
bne a0, s5, T4_FAIL
la a0, msg_t4_ok
li a7, 4
ecall
j T5_RUN
T4_FAIL:
la a0, msg_t4_fail
li a7, 4
ecall
T5_RUN:
la a0, t5_data
li a1, 1
li s5, 1
jal ra, utf8_validate
bne a0, s5, T5_FAIL
la a0, msg_t5_ok
li a7, 4
ecall
j DONE
T5_FAIL:
la a0, msg_t5_fail
li a7, 4
ecall
DONE:
li a0, 0
li a7, 10
ecall
# ============================================================
# utf8_validate
# a0 ptr to bytes
# a1 length
# return a0 = 1 if valid else 0
# uses clz_branchless to get lead1 = clz( ~ (b << 24) )
# ============================================================
utf8_validate:
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp) # need
sw s1, 4(sp) # ptr
sw s2, 0(sp) # remain
mv s1, a0 # ptr
mv s2, a1 # len
li s0, 0 # need = 0
UV_LOOP:
beq s2, x0, UV_END
lbu t0, 0(s1) # load byte b
addi s1, s1, 1
addi s2, s2, -1
beq s0, x0, UV_LEAD # no pending continuation
# must be continuation 10xxxxxx
andi t1, t0, 0xC0
li t2, 0x80
bne t1, t2, UV_FAIL
addi s0, s0, -1
j UV_LOOP
UV_LEAD:
# lead1 = clz( ~ (b << 24) )
slli t1, t0, 24
xori t1, t1, -1 # bitwise not
mv a0, t1
jal ra, clz_branchless # a0 = clz32
mv t2, a0 # lead1
beq t2, x0, UV_LOOP # ASCII
li t3, 1
beq t2, t3, UV_FAIL # lead1 == 1 invalid
li t3, 5
sltu t4, t2, t3 # t2 < 5
beq t4, x0, UV_FAIL # lead1 >= 5 invalid
addi s0, t2, -1 # need = lead1 - 1
j UV_LOOP
UV_END:
beq s0, x0, UV_OK
UV_FAIL:
li a0, 0
j UV_RET
UV_OK:
li a0, 1
UV_RET:
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
lw s2, 0(sp)
addi sp, sp, 16
ret
# ============================================================
# clz_branchless
# a0 = x
# return a0 = count leading zeros of 32 bit x
# branchless binary search steps 16 8 4 2 1
# clz 0 returns 32
# ============================================================
clz_branchless:
li t0, 32
mv t1, a0
srli t2, t1, 16
sltu t3, x0, t2
slli t4, t3, 4
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 8
sltu t3, x0, t2
slli t4, t3, 3
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 4
sltu t3, x0, t2
slli t4, t3, 2
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 2
sltu t3, x0, t2
slli t4, t3, 1
srl t1, t1, t4
sub t0, t0, t4
srli t2, t1, 1
sltu t3, x0, t2
mv t4, t3
srl t1, t1, t4
sub t0, t0, t4
sub a0, t0, t1
ret
# ============================================================
# data
# ============================================================
.data
.align 4
t1_data: .byte 197,130,1
t2_data: .byte 235,140,4
t3_data: .byte 240,162,138,147
t4_data: .byte 145
t5_data: .byte 0
msg_t1_ok: .asciz "T1 OK\n"
msg_t1_fail: .asciz "T1 FAIL\n"
msg_t2_ok: .asciz "T2 OK\n"
msg_t2_fail: .asciz "T2 FAIL\n"
msg_t3_ok: .asciz "T3 OK\n"
msg_t3_fail: .asciz "T3 FAIL\n"
msg_t4_ok: .asciz "T4 OK\n"
msg_t4_fail: .asciz "T4 FAIL\n"
msg_t5_ok: .asciz "T5 OK\n"
msg_t5_fail: .asciz "T5 FAIL\n"
```
:::
***
Compiler generate (RISC-V 32bits gcc 13.1.0)
:::spoiler Expand for details
```asm=
lead1_u8(unsigned char):
addi sp,sp,-48
sw ra,44(sp)
sw s0,40(sp)
addi s0,sp,48
mv a5,a0
sb a5,-33(s0)
lbu a5,-33(s0)
slli a5,a5,24
not a5,a5
sw a5,-20(s0)
lw a0,-20(s0)
call __clzsi2
mv a5,a0
mv a0,a5
lw ra,44(sp)
lw s0,40(sp)
addi sp,sp,48
jr ra
validUtf8(int*, int):
addi sp,sp,-48
sw ra,44(sp)
sw s0,40(sp)
addi s0,sp,48
sw a0,-36(s0)
sw a1,-40(s0)
sw zero,-20(s0)
sw zero,-24(s0)
j .L4
.L12:
lw a5,-24(s0)
slli a5,a5,2
lw a4,-36(s0)
add a5,a4,a5
lw a5,0(a5)
sb a5,-25(s0)
lw a5,-20(s0)
bne a5,zero,.L5
lbu a5,-25(s0)
mv a0,a5
call lead1_u8(unsigned char)
sw a0,-32(s0)
lw a5,-32(s0)
beq a5,zero,.L13
lw a4,-32(s0)
li a5,1
beq a4,a5,.L8
lw a4,-32(s0)
li a5,4
ble a4,a5,.L9
.L8:
li a5,0
j .L10
.L9:
lw a5,-32(s0)
addi a5,a5,-1
sw a5,-20(s0)
j .L7
.L5:
lbu a5,-25(s0)
andi a4,a5,192
li a5,128
beq a4,a5,.L11
li a5,0
j .L10
.L11:
lw a5,-20(s0)
addi a5,a5,-1
sw a5,-20(s0)
j .L7
.L13:
nop
.L7:
lw a5,-24(s0)
addi a5,a5,1
sw a5,-24(s0)
.L4:
lw a4,-24(s0)
lw a5,-40(s0)
blt a4,a5,.L12
lw a5,-20(s0)
seqz a5,a5
andi a5,a5,0xff
.L10:
mv a0,a5
lw ra,44(sp)
lw s0,40(sp)
addi sp,sp,48
jr ra
.LC7:
.string "%s : FAIL got=%d expect=%d\n"
.LC8:
.string "%s : OK\n"
.LC0:
.string "T1 valid 2B then ASCII"
.LC1:
.string "T2 invalid third not cont"
.LC2:
.string "T3 valid 4B"
.LC3:
.string "T4 invalid starts with 10"
.LC4:
.string "T5 valid ASCII"
.LC6:
.word .LC0
.word ._anon_3
.word 3
.word 1
.word .LC1
.word ._anon_4
.word 3
.word 0
.word .LC2
.word ._anon_5
.word 4
.word 1
.word .LC3
.word ._anon_6
.word 1
.word 0
.word .LC4
.word ._anon_7
.word 1
.word 1
main:
addi sp,sp,-160
sw ra,156(sp)
sw s0,152(sp)
addi s0,sp,160
lui a5,%hi(.LC6)
addi a4,a5,%lo(.LC6)
addi a5,s0,-116
mv a3,a4
li a4,80
mv a2,a4
mv a1,a3
mv a0,a5
call memcpy
li a5,1
sw a5,-20(s0)
li a5,5
sw a5,-32(s0)
sw zero,-24(s0)
j .L15
.L22:
sw zero,-28(s0)
j .L16
.L17:
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a4,-96(a5)
lw a5,-28(s0)
add a5,a4,a5
lbu a5,0(a5)
mv a4,a5
lw a5,-28(s0)
slli a5,a5,2
addi a5,a5,-16
add a5,a5,s0
sw a4,-132(a5)
lw a5,-28(s0)
addi a5,a5,1
sw a5,-28(s0)
.L16:
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a5,-92(a5)
lw a4,-28(s0)
blt a4,a5,.L17
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a4,-92(a5)
addi a5,s0,-148
mv a1,a4
mv a0,a5
call validUtf8(int*, int)
mv a5,a0
beq a5,zero,.L18
li a5,1
j .L19
.L18:
li a5,0
.L19:
sw a5,-36(s0)
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a5,-88(a5)
lw a4,-36(s0)
beq a4,a5,.L20
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a4,-100(a5)
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a5,-88(a5)
mv a3,a5
lw a2,-36(s0)
mv a1,a4
lui a5,%hi(.LC7)
addi a0,a5,%lo(.LC7)
call printf
sw zero,-20(s0)
j .L21
.L20:
lw a5,-24(s0)
slli a5,a5,4
addi a5,a5,-16
add a5,a5,s0
lw a5,-100(a5)
mv a1,a5
lui a5,%hi(.LC8)
addi a0,a5,%lo(.LC8)
call printf
.L21:
lw a5,-24(s0)
addi a5,a5,1
sw a5,-24(s0)
.L15:
lw a4,-24(s0)
lw a5,-32(s0)
blt a4,a5,.L22
lw a5,-20(s0)
seqz a5,a5
andi a5,a5,0xff
mv a0,a5
lw ra,156(sp)
lw s0,152(sp)
addi sp,sp,160
jr ra
._anon_3:
.byte -59
.byte -126
.byte 1
._anon_4:
.byte -21
.byte -116
.byte 4
._anon_5:
.byte -16
.byte -94
.byte -118
.byte -109
._anon_6:
.byte -111
._anon_7:
.zero 1
```
:::
## Reference
:::warning
This assignment was completed with assistance from [ChatGPT](https://chatgpt.com/) for grammar checking, translation and initial brainstorming. All final analysis and conclusions are my own.
:::
- [How to Write a Git Commit Message](https://cbea.ms/git-commit/)
- [Guidelines for Student Use of AI Tools](https://hackmd.io/@sysprog/arch2025-ai-guidelines)
- [Leetcode 393. UTF-8 Validation ( Medium )](https://leetcode.com/problems/utf-8-validation/)
- [2025 NCKU computer architecture assignment 1: RISC-V Assembly and Instruction Pipeline](https://hackmd.io/@sysprog/2025-arch-homework1)
- [Lab1: RV32I Simulator](https://hackmd.io/@sysprog/H1TpVYMdB)