# Assignment1: RISC-V Assembly and Instruction Pipeline
> contributed by < [Shaoen-Lin](https://github.com/Shaoen-Lin) >
## Problem `B` in Quiz1
### C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
typedef uint8_t uf8;
/* Count leading zeros of a 32-bit value using a binary search over
 * shift widths 16, 8, 4, 2, 1.  Returns 32 for x == 0. */
static inline unsigned clz(uint32_t x)
{
    unsigned zeros = 32;
    for (unsigned probe = 16; probe != 0; probe >>= 1) {
        uint32_t high = x >> probe;
        if (high != 0) {
            /* Top half non-empty: the leading zeros are fewer than probe. */
            zeros -= probe;
            x = high;
        }
    }
    /* Here x is 1 for any non-zero input, 0 for input 0. */
    return zeros - x;
}
/* Decode uf8 to uint32_t */
/* Decode a uf8 byte into its 32-bit value.
 * Layout: high nibble = exponent e, low nibble = mantissa m.
 * value = (m << e) + offset(e), where offset(e) = (2^e - 1) * 16. */
uint32_t uf8_decode(uf8 fl)
{
    uint8_t exponent = fl >> 4;
    uint32_t mantissa = fl & 0x0f;
    /* (0x7FFF >> (15 - e)) == 2^e - 1, so offset = (2^e - 1) * 16. */
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return offset + (mantissa << exponent);
}
/* Encode uint32_t to uf8 */
/* Encode uint32_t to uf8.
 * Inverse of uf8_decode: chooses the largest exponent whose cumulative
 * offset does not exceed value, then derives the 4-bit mantissa.
 * The encoding is monotonic, so the round-trip test can rely on it. */
uf8 uf8_encode(uint32_t value)
{
/* Use CLZ for fast exponent calculation */
/* Values below 16 encode verbatim (exponent 0, offset 0). */
if (value < 16)
return value;
/* Find appropriate exponent using CLZ hint */
int lz = clz(value);
int msb = 31 - lz; /* index of the highest set bit */
/* Start from a good initial guess */
uint8_t exponent = 0;
uint32_t overflow = 0; /* offset of `exponent`: sum of 16 * 2^k for k < exponent */
if (msb >= 5) {
/* Estimate exponent - the formula is empirical */
exponent = msb - 4;
if (exponent > 15)
exponent = 15;
/* Calculate overflow for estimated exponent */
for (uint8_t e = 0; e < exponent; e++)
overflow = (overflow << 1) + 16;
/* Adjust if estimate was off */
/* Step down while the estimated offset already exceeds value. */
while (exponent > 0 && value < overflow) {
overflow = (overflow - 16) >> 1;
exponent--;
}
}
/* Find exact exponent */
/* Step up while the next exponent's offset still fits in value. */
while (exponent < 15) {
uint32_t next_overflow = (overflow << 1) + 16;
if (value < next_overflow)
break;
overflow = next_overflow;
exponent++;
}
/* By construction 0 <= (value - overflow) >> exponent <= 15. */
uint8_t mantissa = (value - overflow) >> exponent;
return (exponent << 4) | mantissa;
}
/* Test encode/decode round-trip */
/* Exhaustively test all 256 uf8 codes.
 * Checks two properties: encode(decode(fl)) == fl (round trip), and that
 * decoded values strictly increase with the code (monotonicity).
 * Returns true when every code passes; prints a diagnostic per failure. */
static bool test(void)
{
/* -1 is below the smallest decoded value (0), so the first compare passes. */
int32_t previous_value = -1;
bool passed = true;
for (int i = 0; i < 256; i++) {
uint8_t fl = i;
int32_t value = uf8_decode(fl);
uint8_t fl2 = uf8_encode(value);
/* Round-trip check */
if (fl != fl2) {
printf("%02x: produces value %d but encodes back to %02x\n", fl,
value, fl2);
passed = false;
}
/* Monotonicity check */
if (value <= previous_value) {
printf("%02x: value %d <= previous_value %d\n", fl, value,
previous_value);
passed = false;
}
previous_value = value;
}
return passed;
}
/* Entry point: run the exhaustive self-test and report via exit status. */
int main(void)
{
    if (!test())
        return 1;
    printf("All tests passed.\n");
    return 0;
}
```
### RV32I Assembly code
The following RV32I code includes several test data cases and uses **automated testing** for verification.
```riscv
.data
nl: .string "\n"
msg0: .string ": produces value "
msg1: .string " but encodes back to "
msg2: .string ": value "
msg3: .string " <= previous_value "
msg4: .string "All tests passed."
.text
.global main
main:
jal ra, test # run the round-trip self-test; a0 = 1 on success
beqz a0, return_1
la a0, msg4 # print "All tests passed."
li a7, 4 # syscall: print string
ecall
la a0, nl
li a7, 4
ecall
li a0, 0 # exit(0)
li a7, 10 # syscall: exit
ecall
return_1:
li a0, 1 # exit(1) on failure
li a7, 10
ecall
# ============================================================
# clz: Count leading zeros (binary search)
# Input : a0 (unsigned int)
# Output: a0 = leading zero count
# ============================================================
clz:
li s0, 32 # s0 = n (running count); clobbers callee-saved s0/s1 -- callers save them
li s1, 16 # s1 = c (current probe shift)
clz_while_loop:
srl t0, a0, s1 # y = x >> c
bnez t0, clz_if
srli s1, s1, 1 # c >>= 1 (only on the y == 0 path; see note below)
j check_condition
clz_if:
sub s0, s0, s1 # n -= c
add a0, t0, zero # x = y
check_condition:
bnez s1, clz_while_loop
sub a0, s0, a0 # return n - x (x is 0 or 1 here); yields 32 for input 0
ret
# NOTE(review): unlike the C version, c is not halved when the if-branch is
# taken; the loop still terminates (the next probe of the reduced x gives 0)
# but runs extra iterations. Halving c unconditionally would match the C.
# ============================================================
# uf8_decode: Decode uf8 -> uint32_t
# ============================================================
uf8_decode:
andi s0, a0, 0x0f # mantissa = fl & 0xF
srli s1, a0, 4 # exponent = fl >> 4
li t0, 15
sub t0, t0, s1 # 15 - exponent
li s2, 0x7FFF
srl s2, s2, t0 # 0x7FFF >> (15 - exponent) == 2^exponent - 1
slli s2, s2, 4 # offset = (2^exponent - 1) * 16
sll a0, s0, s1 # mantissa << exponent
add a0, a0, s2 # value = (mantissa << exponent) + offset
ret # NOTE: clobbers callee-saved s0-s2; callers save around the call
# ============================================================
# uf8_encode: Encode uint32_t -> uf8
# ============================================================
uf8_encode:
li t0, 16
blt a0, t0, return_a0 # values < 16 encode verbatim
addi sp, sp, -8
sw ra, 0(sp)
sw a0, 4(sp) # clz clobbers a0; preserve the input value
jal ra, clz
add s4, a0, zero # s4 = lz
lw ra, 0(sp)
lw a0, 4(sp)
addi sp, sp, 8
li t0, 31
sub s5, t0, s4 # s5 = msb = 31 - lz
li s6, 0 # s6 = exponent
li s7, 0 # s7 = overflow (cumulative offset of exponent)
li t0, 5
bge s5, t0, encode_if_msb_bge_5
msb_less_5: # If msb < 5, find exponent loop
li t0, 15
check_while_loop2_condition:
blt s6, t0, encode_while_loop2
encode_return:
sub s0, a0, s7 # value - overflow
srl s0, s0, s6 # mantissa = (value - overflow) >> exponent
slli t0, s6, 4
or a0, t0, s0 # pack (exponent << 4) | mantissa
ret # NOTE: clobbers callee-saved s0, s4-s8
encode_if_msb_bge_5: # If msb >= 5, estimate exponent
addi s6, s5, -4 # exponent = msb - 4 (empirical estimate)
li t0, 15
bgt s6, t0, set_expoent_15 # clamp the estimate to 15
back_encode_if_msb_bge_5:
li t0, 0
encode_for_loop: # overflow = (overflow << 1) + 16
bge t0, s6, out_of_encode_loop
slli t1, s7, 1
addi s7, t1, 16
addi t0, t0, 1
j encode_for_loop
out_of_encode_loop:
bgt s6, zero, encode_while_loop1
j msb_less_5
encode_while_loop1: # Adjust exponent if overflow too large
bge a0, s7, msb_less_5 # stop once value >= overflow
addi t0, s7, -16 # undo one step: overflow = (overflow - 16) >> 1
srli s7, t0, 1
addi s6, s6, -1
j out_of_encode_loop
set_expoent_15:
addi s6, zero, 15
j back_encode_if_msb_bge_5
encode_while_loop2: # Find exact exponent
slli s8, s7, 1
addi s8, s8, 16 # next_overflow = (overflow << 1) + 16
blt a0, s8, encode_return # value < next_overflow: exponent found
add s7, s8, x0
addi s6, s6, 1
j check_while_loop2_condition
return_a0:
ret
# ============================================================
# test: Run encode/decode test loop
# ============================================================
test:
li s0, -1
li s1, 1
addi t2, zero, 0
li t3, 256
test_for_loop:
bge t2, t3, out_test_for_loop
addi t4, t2, 0
# call uf8_decode(fl)
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
add a0, t4, x0
jal ra, uf8_decode
addi t5, a0, 0
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
# call uf8_encode(value)
addi sp, sp, -12
sw ra, 0(sp)
sw s0, 4(sp)
sw s1, 8(sp)
jal ra, uf8_encode
addi t6, a0, 0
lw ra, 0(sp)
lw s0, 4(sp)
lw s1, 8(sp)
addi sp, sp, 12
bne t4, t6, test_if_1 # if (fl != fl2)
out_test_if_1:
ble t5, s0, test_if_2 # if (value <= previous_value)
out_test_if_2:
add s0, t5, x0
addi t2, t2, 1
j test_for_loop
out_test_for_loop:
add a0, s1, zero
ret
# Print mismatch: fl != fl2
test_if_1:
mv a0, t4
li a7, 34
ecall
la a0, msg0
li a7, 4
ecall
mv a0, t5
li a7, 1
ecall
la a0, msg1
li a7, 4
ecall
mv a0, t6
li a7, 34
ecall
la a0, nl
li a7, 4
ecall
li s1, 0
j out_test_if_1
# Print non-monotonic: value <= previous_value
test_if_2:
mv a0, t4
li a7, 34
ecall
la a0, msg2
li a7, 4
ecall
mv a0, t5
li a7, 1
ecall
la a0, msg3
li a7, 4
ecall
mv a0, s0
li a7, 34
ecall
la a0, nl
li a7, 4
ecall
li s1, 0
j out_test_if_2
```
## Problem `C` in Quiz1
The C code (both with and without `bf16_sqrt`) relies on integer multiplication (`*`), so the corresponding RV32I assembly implementation must perform multiplication as well.
However, the **RV32I** base instruction set **does not include a hardware multiply instruction** (`mul` belongs to the M extension), so this project adopts the **Egyptian Multiplication algorithm** to emulate multiplication in software.
### Egyptian Multiplication algorithm
Egyptian Multiplication is an algorithm based on doubling and halving.
Its core idea is to repeatedly shift the multiplicand left and shift the multiplier right, accumulating the multiplicand into the result whenever the least significant bit (LSB) of the multiplier is 1.
This sequence of shifts and conditional additions effectively simulates the multiplication process without hardware support.
Following is the algorithm of Egyptian Multiplication:
$$
nm =
\begin{cases}
\frac{n}{2} \cdot 2m & \text{if } n \text{ is even}, \\
\frac{n-1}{2} \cdot 2m + m & \text{if } n \text{ is odd}, \\
m & \text{if } n = 1.
\end{cases}
$$
Following is the RISC-V assembly code of Egyptian Multiplication:
```riscv
# =======================================================
# multiply8(a0, a1): Egyptian Multiplication
# =======================================================
# Parameters:
# a0 = multiplicand (8-bit)
# a1 = multiplier (8-bit)
#
# Return:
# a0 = 16-bit result
# =======================================================
multiply8:
mv s10, a0 # running multiplicand (doubled each round)
mv s9, a1 # remaining multiplier (halved each round)
li a0, 0 # accumulator for the product
mul_loop:
beqz s9, mul_done
andi s8, s9, 1 # LSB of the multiplier decides whether to accumulate
beqz s8, skip_add
add a0, a0, s10
skip_add:
slli s10, s10, 1 # double the multiplicand
srli s9, s9, 1 # halve the multiplier
j mul_loop
mul_done:
ret # NOTE(review): clobbers callee-saved s8-s10 without saving them
```
### C code without `bf16_sqrt`
```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
/* bf16: 1 sign bit, 8 exponent bits (bias 127), 7 mantissa bits --
 * the top 16 bits of an IEEE-754 binary32 pattern. */
typedef struct {
uint16_t bits;
} bf16_t;
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) /* canonical quiet NaN */
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000}) /* +0.0 */
/* NaN: exponent field all ones and a non-zero mantissa. */
static inline bool bf16_isnan(bf16_t a)
{
    bool exp_all_ones = (a.bits & BF16_EXP_MASK) == BF16_EXP_MASK;
    bool mant_nonzero = (a.bits & BF16_MANT_MASK) != 0;
    return exp_all_ones && mant_nonzero;
}
/* Infinity: exponent field all ones and a zero mantissa. */
static inline bool bf16_isinf(bf16_t a)
{
    if ((a.bits & BF16_EXP_MASK) != BF16_EXP_MASK)
        return false;
    return (a.bits & BF16_MANT_MASK) == 0;
}
/* True for both +0 and -0 (the sign bit is ignored). */
static inline bool bf16_iszero(bf16_t a)
{
    return (a.bits & 0x7FFF) == 0;
}
/* Convert binary32 to bf16 with round-to-nearest-even.
 * NaN/Inf inputs (exponent all ones) are truncated so the payload and
 * infinity encoding survive without rounding. */
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float)); /* type-pun without aliasing UB */
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
/* Round to nearest even: bias of 0x7FFF plus the LSB of the kept half. */
f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
return (bf16_t) {.bits = f32bits >> 16};
}
/* Widen bf16 to binary32: the bf16 bits become the top half of the
 * float's bit pattern, with zeros below. Exact (no rounding). */
static inline float bf16_to_f32(bf16_t val)
{
    uint32_t widened = (uint32_t) val.bits << 16;
    float result;
    memcpy(&result, &widened, sizeof result);
    return result;
}
/* Add two bf16 values.
 * Handles NaN/Inf propagation and zero shortcuts, aligns mantissas to the
 * larger exponent (operands more than 8 binades apart degenerate to the
 * larger one), then renormalizes with overflow-to-Inf and flush-to-zero
 * on exponent underflow. */
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
/* a is NaN or Inf */
if (exp_a == 0xFF) {
if (mant_a)
return a; /* NaN propagates */
if (exp_b == 0xFF)
/* Inf + Inf: same sign -> Inf; opposite signs -> NaN */
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a; /* Inf + finite = Inf */
}
if (exp_b == 0xFF)
return b; /* b is NaN or Inf */
/* x + 0 and 0 + x shortcuts */
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
/* Restore the implicit leading 1 for normal numbers */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
/* Align mantissas to the larger exponent */
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a; /* b too small to affect a's 8-bit mantissa */
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
/* Same sign: add magnitudes; a carry renormalizes and may hit Inf */
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
/* Opposite signs: subtract the smaller magnitude from the larger */
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
if (!result_mant)
return BF16_ZERO(); /* exact cancellation -> +0 */
/* Renormalize; flush to zero if the exponent underflows */
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
/* a - b, computed as a + (-b) by flipping b's sign bit. */
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    bf16_t negated = b;
    negated.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, negated);
}
/* Multiply two bf16 values.
 * NaN propagates; Inf * 0 and 0 * Inf give NaN; zero operands give a
 * signed zero. Subnormal inputs are normalized into (exp = 1, shifted
 * mantissa) with the shifts tracked in exp_adjust. Results overflow to
 * Inf; exponents in (-6, 0] are denormalized, below -6 flush to zero. */
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* a is NaN or Inf */
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (!exp_b && !mant_b)
return BF16_NAN(); /* Inf * 0 */
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* b is NaN or Inf */
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (!exp_a && !mant_a)
return BF16_NAN(); /* 0 * Inf */
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* either operand zero -> signed zero */
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15};
int16_t exp_adjust = 0;
/* Normalize subnormal a: shift until the implicit-1 position is set */
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
/* Same for subnormal b */
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
/* 8x8 -> 16-bit fixed-point product of the mantissas */
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
/* Product in [2.0, 4.0): take the high 7 bits and bump the exponent */
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; /* overflow */
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15}; /* underflow -> 0 */
result_mant >>= (1 - result_exp); /* denormalize */
result_exp = 0;
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
/* Divide a by b.
 * NaN operands propagate; Inf/Inf and 0/0 give NaN; x/0 gives signed Inf;
 * 0/x and finite/Inf give signed zero. The quotient is produced by 16
 * rounds of restoring binary long division on the mantissas. */
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* b is NaN or Inf */
if (exp_b == 0xFF) {
if (mant_b)
return b;
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = result_sign << 15}; /* finite / Inf = 0 */
}
/* b is zero */
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a)
return BF16_NAN(); /* 0 / 0 */
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; /* x / 0 = Inf */
}
/* a is NaN or Inf */
if (exp_a == 0xFF) {
if (mant_a)
return a;
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; /* Inf / x */
}
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15}; /* 0 / x = 0 */
/* Restore the implicit leading 1 for normal numbers */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
/* 16-step restoring long division (dividend pre-scaled by 2^15) */
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
/* Compensate exponents for subnormal operands (no implicit 1) */
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++;
/* Normalize: bring the leading 1 to bit 15, then keep the top 7 bits */
if (quotient & 0x8000)
quotient >>= 8;
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
quotient &= 0x7F;
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; /* overflow */
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15}; /* underflow -> 0 */
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F)};
}
```
### RV32I Assembly code without `bf16_sqrt`
The following RV32I code includes several test data cases and uses **automated testing** for verification.
```riscv
.data
BF16_SIGN_MASK: .half 0x8000
BF16_EXP_MASK: .half 0x7F80
BF16_MANT_MASK: .half 0x007F
BF16_EXP_BIAS: .half 127
BF16_NAN: .half 0x7FC0
BF16_ZERO: .half 0x0000
nl: .string "\n"
msg_case: .string "Test case "
msg_input: .string "Input: "
msg_output: .string "Output: "
msg_expect: .string "Expect: "
msg_ok: .string "✅ Correct\n"
msg_wrong: .string "❌ Wrong\n"
# ====== CONVERSION TEST String ======
msg_conv1: .string "\n=== BF16 -> F32 TESTS ===\n"
msg_conv2: .string "\n=== F32 -> BF16 TESTS ===\n"
msg_add: .string "\n=== BF16 ADD TESTS ===\n"
msg_sub: .string "\n=== BF16 SUB TESTS ===\n"
msg_mul: .string "\n=== BF16 MUL TESTS ===\n"
msg_div: .string "\n=== BF16 DIV TESTS ===\n"
# ====== ADD String ======
msg1: .string "1.0 + 2.0 = "
msg2: .string "2.0 + (-2.0) = "
msg3: .string "inf + 1.0 = "
msg4: .string "inf + -inf = "
msg5: .string "NaN + 1.0 = "
msg6: .string "1.0 + 0.015625 = "
# ====== SUB String ======
msgs1: .string "2.0 - 1.0 = "
msgs2: .string "5.0 - 2.0 = "
msgs3: .string "1.0 - 2.0 = "
msgs4: .string "(-2.0) - 3.0 = "
msgs5: .string "Inf - Inf = "
msgs6: .string "NaN - 1.0 = "
msgs7: .string "0.0 - 1.0 = "
msgs8: .string "1.0 - 0.0 = "
# ====== MUL String ======
msgm1: .string "1.0 * 2.0 = "
msgm2: .string "0.5 * 0.5 = "
msgm3: .string "-1.0 * 3.0 = "
msgm4: .string "Inf * 2.0 = "
msgm5: .string "0 * 123.0 = "
msgm6: .string "Inf * 0 = "
msgm7: .string "NaN * 5.0 = "
msgm8: .string "subnormal * 2.0 = "
# ====== DIV String ======
msgd1: .string "1.0 / 2.0 = "
msgd2: .string "2.0 / 1.0 = "
msgd3: .string "1.0 / 0.0 = "
msgd4: .string "0.0 / 1.0 = "
msgd5: .string "Inf / Inf = "
msgd6: .string "NaN / 1.0 = "
msgd7: .string "(-2.0) / 1.0 = "
# ======= CONVERSION expected output =======
conv_expect_b2f: .word 0x3F800000, 0xC0000000
conv_expect_f2b: .half 0x4060, 0xC194
# ======= ADD expected output =======
add_expect: .half 0x4040, 0x0000, 0x7F80, 0x7FC0, 0x7FC1, 0x3F80
# ======= SUB expected output =======
sub_expect: .half 0x3F80, 0x4040, 0xBF80, 0xC0A0, 0x7FC0, 0x7FC0, 0xBF80, 0x3F80
# ======= MUL expected output =======
mul_expect: .half 0x4000, 0x3E80, 0xC040, 0x7F80, 0x0000, 0x7FC0, 0x7FC1, 0x0000
# ======= DIV expected output =======
div_expect: .half 0x3F00, 0x4000, 0x7F80, 0x0000, 0x7FC0, 0x7FC0, 0xC000
.text
.global main
main:
li s0, 255
lui t0, 0x8
addi s11, t0, -0x80 # s11 = 0x7F80 (Inf mask)
# ------------------------------
# BF16 -> F32 TESTS
# ------------------------------
la a0, msg_conv1
li a7, 4
ecall
la a0, msg_input
li a7, 4
ecall
li a0, 0x3F80
li a7, 34
ecall
li a0, 0x3F80
jal ra, bf16_to_f32
mv t0, a0
la a1, conv_expect_b2f
li a2, 0
li a3, 1
jal ra, compare_result
la a0, msg_input
li a7, 4
ecall
li a0, 0xC000
li a7, 34
ecall
li a0, 0xC000
jal ra, bf16_to_f32
mv t0, a0
la a1, conv_expect_b2f
li a2, 1
li a3, 1
jal ra, compare_result
# ------------------------------
# F32 -> BF16 TESTS
# ------------------------------
la a0, msg_conv2
li a7, 4
ecall
la a0, msg_input
li a7, 4
ecall
li a0, 0x40600000
li a7, 34
ecall
li a0, 0x40600000
jal ra, f32_to_bf16
mv t0, a0
la a1, conv_expect_f2b
li a2, 0
li a3, 0
jal ra, compare_result
la a0, msg_input
li a7, 4
ecall
li a0, 0xC19447AE
li a7, 34
ecall
li a0, 0xC19447AE
jal ra, f32_to_bf16
mv t0, a0
la a1, conv_expect_f2b
li a2, 1
li a3, 0
jal ra, compare_result
# ------------------------------
# ADD TEST
# ------------------------------
la a0, msg_add
li a7, 4
ecall
la a0, msg1
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x4000
jal ra, bf16_add
la a1, add_expect
li a2, 0
jal ra, compare_result
la a0, msg2
li a7, 4
ecall
li a0, 0x4000
li a1, 0xC000
jal ra, bf16_add
la a1, add_expect
li a2, 1
jal ra, compare_result
la a0, msg3
li a7, 4
ecall
li a0, 0x7F80
li a1, 0x3F80
jal ra, bf16_add
la a1, add_expect
li a2, 2
jal ra, compare_result
la a0, msg4
li a7, 4
ecall
li a0, 0x7F80
li a1, 0xFF80
jal ra, bf16_add
la a1, add_expect
li a2, 3
jal ra, compare_result
la a0, msg5
li a7, 4
ecall
li a0, 0x7FC1
li a1, 0x3F80
jal ra, bf16_add
la a1, add_expect
li a2, 4
jal ra, compare_result
la a0, msg6
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x3800
jal ra, bf16_add
la a1, add_expect
li a2, 5
jal ra, compare_result
# ------------------------------
# SUB TEST
# ------------------------------
la a0, msg_sub
li a7, 4
ecall
la a0, msgs1
li a7, 4
ecall
li a0, 0x4000
li a1, 0x3F80
jal ra, bf16_sub
la a1, sub_expect
li a2, 0
jal ra, compare_result
la a0, msgs2
li a7, 4
ecall
li a0, 0x40A0
li a1, 0x4000
jal ra, bf16_sub
la a1, sub_expect
li a2, 1
jal ra, compare_result
la a0, msgs3
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x4000
jal ra, bf16_sub
la a1, sub_expect
li a2, 2
jal ra, compare_result
la a0, msgs4
li a7, 4
ecall
li a0, 0xC000
li a1, 0x4040
jal ra, bf16_sub
la a1, sub_expect
li a2, 3
jal ra, compare_result
la a0, msgs5
li a7, 4
ecall
li a0, 0x7F80
li a1, 0x7F80
jal ra, bf16_sub
la a1, sub_expect
li a2, 4
jal ra, compare_result
la a0, msgs6
li a7, 4
ecall
li a0, 0x7FC0
li a1, 0x3F80
jal ra, bf16_sub
la a1, sub_expect
li a2, 5
jal ra, compare_result
la a0, msgs7
li a7, 4
ecall
li a0, 0x0000
li a1, 0x3F80
jal ra, bf16_sub
la a1, sub_expect
li a2, 6
jal ra, compare_result
la a0, msgs8
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x0000
jal ra, bf16_sub
la a1, sub_expect
li a2, 7
jal ra, compare_result
# ------------------------------
# MUL TEST
# ------------------------------
la a0, msg_mul
li a7, 4
ecall
la a0, msgm1 # "1.0 * 2.0 = "
li a7, 4
ecall
li a0, 0x3F80 # 1.0
li a1, 0x4000 # 2.0
jal ra, bf16_mul
# (bug fix: removed a stray `li a7, 34; ecall` debug print here -- no other
# test case prints the raw result, and compare_result already prints the
# Output/Expect lines)
la a1, mul_expect
li a2, 0
jal ra, compare_result
la a0, msgm2
li a7, 4
ecall
li a0, 0x3F00
li a1, 0x3F00
jal ra, bf16_mul
la a1, mul_expect
li a2, 1
jal ra, compare_result
la a0, msgm3
li a7, 4
ecall
li a0, 0xBF80
li a1, 0x4040
jal ra, bf16_mul
la a1, mul_expect
li a2, 2
jal ra, compare_result
la a0, msgm4
li a7, 4
ecall
li a0, 0x7F80
li a1, 0x4000
jal ra, bf16_mul
la a1, mul_expect
li a2, 3
jal ra, compare_result
la a0, msgm5
li a7, 4
ecall
li a0, 0x0000
li a1, 0x42F6
jal ra, bf16_mul
la a1, mul_expect
li a2, 4
jal ra, compare_result
la a0, msgm6
li a7, 4
ecall
li a0, 0x7F80
li a1, 0x0000
jal ra, bf16_mul
la a1, mul_expect
li a2, 5
jal ra, compare_result
la a0, msgm7
li a7, 4
ecall
li a0, 0x7FC1
li a1, 0x40A0
jal ra, bf16_mul
la a1, mul_expect
li a2, 6
jal ra, compare_result
la a0, msgm8
li a7, 4
ecall
li a0, 0x0001
li a1, 0x4000
jal ra, bf16_mul
la a1, mul_expect
li a2, 7
jal ra, compare_result
# ------------------------------
# DIV TEST
# ------------------------------
la a0, msg_div
li a7, 4
ecall
la a0, msgd1
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x4000
jal ra, bf16_div
la a1, div_expect
li a2, 0
jal ra, compare_result
la a0, msgd2
li a7, 4
ecall
li a0, 0x4000
li a1, 0x3F80
jal ra, bf16_div
la a1, div_expect
li a2, 1
jal ra, compare_result
la a0, msgd3
li a7, 4
ecall
li a0, 0x3F80
li a1, 0x0000
jal ra, bf16_div
la a1, div_expect
li a2, 2
jal ra, compare_result
la a0, msgd4
li a7, 4
ecall
li a0, 0x0000
li a1, 0x3F80
jal ra, bf16_div
la a1, div_expect
li a2, 3
jal ra, compare_result
la a0, msgd5
li a7, 4
ecall
li a0, 0x7F80
li a1, 0x7F80
jal ra, bf16_div
la a1, div_expect
li a2, 4
jal ra, compare_result
la a0, msgd6
li a7, 4
ecall
li a0, 0x7FC0
li a1, 0x3F80
jal ra, bf16_div
la a1, div_expect
li a2, 5
jal ra, compare_result
la a0, msgd7
li a7, 4
ecall
li a0, 0xC000
li a1, 0x3F80
jal ra, bf16_div
la a1, div_expect
li a2, 6
jal ra, compare_result
li a7, 10
ecall
# =======================================================
# compare_result(a0, expect_addr, idx, is32bit)
# =======================================================
# a0 = actual result (16-bit or 32-bit)
# a1 = address of expected value table
# a2 = test case index (0-based)
# a3 = is32bit flag (1 = 32-bit, 0 = 16-bit)
# =======================================================
compare_result:
addi sp, sp, -20 # preserve temporaries across the syscalls below
sw t0, 0(sp)
sw t1, 4(sp)
sw t2, 8(sp)
sw t3, 12(sp)
sw t4, 16(sp)
mv t0, a0 # t0 = actual result
beqz a3, half_case
slli t1, a2, 2 # 32-bit table: byte offset = idx * 4
add t2, a1, t1
lw t3, 0(t2) # t3 = expected (word)
j load_done
half_case:
slli t1, a2, 1 # 16-bit table: byte offset = idx * 2
add t2, a1, t1
lhu t3, 0(t2) # t3 = expected (zero-extended halfword)
load_done:
li t4, 0xFFFFFFFF # full-word mask (no-op on RV32; kept for symmetry)
and t0, t0, t4
and t3, t3, t4
la a0, nl
li a7, 4
ecall
# ---- Output (hex) ----
la a0, msg_output
li a7, 4
ecall
mv a0, t0
li a7, 34 # syscall: print a0 in hex
ecall
la a0, nl
li a7, 4
ecall
# ---- Expect (hex) ----
la a0, msg_expect
li a7, 4
ecall
mv a0, t3
li a7, 34
ecall
la a0, nl
li a7, 4
ecall
# ---- Compare ----
beq t0, t3, print_ok
la a0, msg_wrong
li a7, 4
ecall
j done
print_ok:
la a0, msg_ok
li a7, 4
ecall
done:
lw t0, 0(sp)
lw t1, 4(sp)
lw t2, 8(sp)
lw t3, 12(sp)
lw t4, 16(sp)
addi sp, sp, 20
ret
# ------------------------------
# bf16_isnan(a0): check if NaN
# ------------------------------
bf16_isnan:
la t0, BF16_EXP_MASK
lh t1, 0(t0) # t1 = 0x7F80
and t2, a0, t1
bne t2, t1, not_nan # exponent must be all ones
la t0, BF16_MANT_MASK
lh t1, 0(t0) # t1 = 0x007F
and t2, a0, t1
beqz t2, not_nan # and the mantissa non-zero
li a0, 1
ret
not_nan:
li a0, 0
ret
# ------------------------------
# bf16_isinf(a0): check if Inf
# ------------------------------
bf16_isinf:
la t0, BF16_EXP_MASK
lh t1, 0(t0) # t1 = 0x7F80
and t2, a0, t1
bne t2, t1, not_inf # exponent must be all ones
la t0, BF16_MANT_MASK
lh t1, 0(t0) # t1 = 0x007F
and t2, a0, t1
bnez t2, not_inf # and the mantissa zero
li a0, 1
ret
not_inf:
li a0, 0
ret
# ------------------------------
# bf16_iszero(a0): check if Zero
# ------------------------------
bf16_iszero:
lui t0, 8
addi t0, t0, -1 # t0 = 0x7FFF (mask off the sign bit)
and a0, a0, t0
beqz a0, is_zero
li a0, 0 # non-zero magnitude -> false
ret
is_zero:
li a0, 1 # +0 / -0 -> true
ret
# (bug fix: the return values were inverted relative to the C bf16_iszero,
# which returns true for zero; the asm returned 0 for zero and 1 otherwise)
# ------------------------------
# f32_to_bf16(a0): convert f32 → bf16
# ------------------------------
f32_to_bf16:
srli t0, a0, 23
li t1, 255
and t0, t0, t1 # t0 = f32 exponent field
beq t0, t1, is_nan_inf # NaN/Inf: truncate, preserving the payload
srli t0, a0, 16
andi t0, t0, 1 # LSB of the kept half (round-to-nearest-even tie bit)
lui t1, 8
addi t1, t1, -1 # t1 = 0x7FFF
add t0, t0, t1
add a0, a0, t0 # add the rounding bias
srli a0, a0, 16 # keep the top 16 bits
ret
is_nan_inf:
srli t0, a0, 16
lui t1, 16
addi t1, t1, -1 # t1 = 0xFFFF
and a0, t0, t1
ret
# ------------------------------
# bf16_to_f32(a0): extend bf16 → f32
# ------------------------------
bf16_to_f32:
slli a0, a0, 16 # bf16 bits are the top half of the f32 bit pattern
ret
# ------------------------------
# bf16_add(a0, a1): BF16 addition
# ------------------------------
bf16_add:
# Extract sign/exponent/mantissa
srli t0, a0, 15
andi t0, t0, 1
srli t1, a1, 15
andi t1, t1, 1
srli t2, a0, 7
and t2, t2, s0
srli t3, a1, 7
and t3, t3, s0
andi t4, a0, 127
andi t5, a1, 127
# Handle Inf/NaN and zeros
beq t2, s0, a_inf_nan
beq t3, s0, return_b
beqz t2, check_mantissa_a_zero
jal x0, next
check_mantissa_a_zero:
beqz t4, return_b
next:
beqz t3, check_mantissa_b_zero
jal x0, next0
check_mantissa_b_zero:
beqz t5, return_a
next0:
beqz t2, skip_a_implicit_1
ori t4, t4, 0x80
skip_a_implicit_1:
beqz t3, skip_b_implicit_1
ori t5, t5, 0x80
skip_b_implicit_1:
jal x0, next1
# --- handle special cases ---
a_inf_nan:
bnez t4, return_a
beq t3, s0, a_and_b_inf_nan
ret
a_and_b_inf_nan:
bnez t5, return_b
beq t0, t1, return_b
jal x0, return_nan
# --- align exponents ---
next1:
sub s2, t2, t3
bgt s2, zero, greater_than_zero
blt s2, zero, less_than_zero
add s3, zero, t2
jal x0, next2
greater_than_zero:
add s3, zero, t2
li t6, 8
bgt s2, t6, return_a
srl t5, t5, s2
jal x0, next2
less_than_zero:
add s3, zero, t3
li t6, -8
blt s2, t6, return_b
neg t6, s2
srl t4, t4, t6
# --- perform mantissa add/sub ---
next2:
beq t0, t1, signa_eq_signb
bge t4, t5, mant_a_greater_mant_b
add s4, t1, zero
sub s5, t5, t4
jal x0, next4
mant_a_greater_mant_b:
add s4, t0, zero
sub s5, t4, t5
# --- normalize result ---
next4:
bnez s5, normalize_loop
la t6, BF16_ZERO
lh a0, 0(t6)
ret
normalize_loop:
andi t6, s5, 0x80
bnez t6, final_return
addi s3, s3, -1
blez s3, underflow_zero
slli s5, s5, 1
j normalize_loop
underflow_zero:
la t6, BF16_ZERO
lh a0, 0(t6)
ret
# --- same sign addition ---
signa_eq_signb:
add s4, t0, zero # result sign = sign_a
add s5, t4, t5 # mantissa sum (may carry into bit 8)
andi t6, s5, 0x100
beqz t6, final_return
srli s5, s5, 1 # renormalize after the carry
addi s3, s3, 1
blt s3, s0, final_return # exponent still below 0xFF: pack normally
slli a0, s4, 15 # exponent saturated: return signed infinity
or a0, a0, s11 # (bug fix: matches C `if (++result_exp >= 0xFF) -> Inf`,
ret # which the asm previously lacked)
# --- pack result ---
final_return:
slli s4, s4, 15
and s3, s3, s0
slli s3, s3, 7
andi s5, s5, 0x7F
or a0, s3, s4
or a0, a0, s5
ret
# ------------------------------
# bf16_sub(a0, a1): subtraction
# ------------------------------
bf16_sub:
lui t6, 0x8 # t6 = 0x8000 (sign mask)
xor a1, a1, t6 # flip sign bit
addi sp, sp, -4
sw ra, 0(sp)
jal ra, bf16_add # a - b == a + (-b)
lw ra, 0(sp)
addi sp, sp, 4
ret
# ------------------------------
# bf16_mul(a0, a1): BF16 multiplication
# ------------------------------
bf16_mul:
# Extract sign / exponent / mantissa
srli t0, a0, 15
andi t0, t0, 1
srli t1, a1, 15
andi t1, t1, 1
srli t2, a0, 7
and t2, t2, s0
srli t3, a1, 7
and t3, t3, s0
andi t4, a0, 127
andi t5, a1, 127
xor s1, t0, t1
# Check for NaN / Inf cases
bne t2, s0, check_b_exp
bnez t4, return_a
beqz t3, check_b_mant
back1:
slli a0, s1, 15
or a0, a0, s11
ret
check_b_mant:
bnez t5, back1
j return_nan
# --- check exponent B special case ---
check_b_exp:
bne t3, s0, next6
bnez t4, return_b
beqz t2, check_a_mant
back2:
slli a0, s1, 15
or a0, a0, s11
ret
check_a_mant:
bnez t4, back2
jal x0, return_nan
# --- handle zero operands ---
next6:
beqz t2, check_a_is_zero
check_b:
beqz t3, check_b_is_zero
a_b_no_zero:
j next7
check_a_is_zero:
beqz t4, a_or_b_is_zero
j check_b
check_b_is_zero:
beqz t5, a_or_b_is_zero
j a_b_no_zero
a_or_b_is_zero:
slli a0, s1, 15
ret
# --- normalize subnormal exponents ---
next7:
add s2, zero, zero
beqz t2, exp_a_zero
ori t4, t4, 0x80
j check_b_exp_zero
exp_a_zero:
addi t2, zero, 1
andi t6, t4, 0x80
bnez t6, check_b_exp_zero
slli t4, t4, 1
addi s2, s2, -1
j exp_a_zero
# --- same for operand B ---
check_b_exp_zero:
beqz t3, exp_b_zero # exp_b == 0: subnormal b, normalize below
# (bug fix: was `beq t3, s0, exp_b_zero`, comparing against 0xFF -- that
# case is handled earlier and never fires, so subnormal b skipped
# normalization, diverging from the C code's `if (!exp_b)` path)
ori t5, t5, 0x80 # normal b: restore the implicit leading 1
j next8
exp_b_zero:
addi t3, zero, 1 # treat subnormal exponent as 1
andi t6, t5, 0x80
bnez t6, next8
slli t5, t5, 1 # shift the mantissa up, compensating via s2
addi s2, s2, -1
j exp_b_zero
# --- perform mantissa multiplication (Egyptian method) ---
next8:
addi sp, sp, -12
sw a0, 0(sp)
sw a1, 4(sp)
sw ra, 8(sp)
add a0, t4, zero
add a1, t5, zero
jal ra, multiply8
add s3, a0, zero
lw a0, 0(sp)
lw a1, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
# Calculate result exponent
add s4, t2, t3
la t6, BF16_EXP_BIAS
lh t6, 0(t6)
sub s4, s4, t6
add s4, s4, s2
# Normalize mantissa
lui t6, 0x8
and t6, s3, t6
bnez t6, ret_val_is_neg
srli t6, s3, 7
andi s3, t6, 0x7F
j ret_exp
ret_val_is_neg:
srli t6, s3, 8
andi s3, t6, 0x7F
addi s4, s4, 1
# --- check overflow/underflow ---
ret_exp:
bge s4, s0, over_ff
ble s4, zero, under_zero
mul_final_return:
slli s1, s1, 15
and s4, s4, s0
slli s4, s4, 7
andi s3, s3, 0x7F
or a0, s1, s4
or a0, a0, s3
ret
# overflow → Inf
over_ff:
slli a0, s1, 15
or a0, a0, s11
ret
# underflow → 0
under_zero:
li t6, -6
blt s4, t6, shift_sign_15
li t6, 1
sub t6, t6, s4
srl s3, s3, t6
li s4, 0
j mul_final_return
shift_sign_15:
slli a0, s1, 15
ret
# ------------------------------
# bf16_div(a0, a1): BF16 division
# ------------------------------
bf16_div:
# Extract sign / exponent / mantissa
srli t0, a0, 15
andi t0, t0, 1
srli t1, a1, 15
andi t1, t1, 1
srli t2, a0, 7
and t2, t2, s0
srli t3, a1, 7
and t3, t3, s0
andi t4, a0, 127
andi t5, a1, 127
xor s1, t0, t1
# --- check special cases ---
beq t3, s0, div_b_inf_nan
beqz t3, div_b_check_mant_0
b_is_not_zero_but_exp_0:
beq t2, s0, div_a_inf_nan
beqz t2, div_a_exp_0_check_mant_0
a_is_not_zero_but_exp_0:
bnez t2, set_a_mant
also_check_b:
bnez t3, set_b_mant
j set_div
# --- handle b = Inf/NaN ---
div_b_inf_nan:
bnez t5, return_b
beq t2, s0, b_check_a_mant
div_b_inf_nan_return:
slli a0, s1, 15
ret
b_check_a_mant:
beqz t4, return_nan
j div_b_inf_nan_return
# --- b = 0 case ---
div_b_check_mant_0:
beqz t5, div_a_check_0
j b_is_not_zero_but_exp_0
div_b_check_0_return:
slli a0, s1, 15
or a0, a0, s11
ret
div_a_check_0:
beq t2, s0, div_a_check_mant_0
j div_b_check_0_return
div_a_check_mant_0:
beqz t4, return_nan
j div_b_check_0_return
# --- a = Inf/NaN ---
div_a_inf_nan:
bnez t4, return_a
slli a0, s1, 15
ret
# --- a = 0 case ---
div_a_exp_0_check_mant_0:
beqz t4, div_a_exp_0_check_0
j a_is_not_zero_but_exp_0
div_a_exp_0_check_0:
slli a0, s1, 15
ret
# --- set implicit 1 ---
set_a_mant:
ori t4, t4, 0x80
j also_check_b
set_b_mant:
ori t5, t5, 0x80
j set_div
# ------------------------------
# division core (long division)
# ------------------------------
set_div:
slli s2, t4, 15
add s3, t5, zero
li s4, 0
li t6, 0
li s5, 16
# --- for loop: binary long division ---
for_loop:
bge t6, s5, out_for_loop
slli s4, s4, 1
addi s6, zero, 15
sub s6, s6, t6
sll s6, s3, s6
blt s2, s6, out_if
sub s2, s2, s6
ori s4, s4, 1
out_if:
addi t6, t6, 1
j for_loop
out_for_loop:
# --- compute exponent ---
sub s5, t2, t3
la t6, BF16_EXP_BIAS
lh t6, 0(t6)
add s5, s5, t6
bnez t2, exp_a_isnot_zero
addi s5, s5, -1
exp_a_isnot_zero:
bnez t3, exp_b_isnot_zero
addi s5, s5, 1
exp_b_isnot_zero:
# --- normalize quotient ---
lui t6, 0x8
and t6, s4, t6
beqz t6, check_while_condition
srli s4, s4, 8
j next9
# --- normalization loop ---
check_while_condition:
lui t6, 0x8
and t6, s4, t6
beqz t6, check_result_exp
else_shift_quotient:
srli s4, s4, 8
j next9
check_result_exp:
li s6, 1
bgt s5, s6, while_loop
j else_shift_quotient
while_loop:
slli s4, s4, 1
addi s5, s5, -1
j check_while_condition
# --- pack final result ---
next9:
andi s4, s4, 0x7F
bge s5, s0, exp_greater_all_one
ble s5, zero, exp_less_equal_zero
slli s1, s1, 15
and s5, s5, s0
slli s5, s5, 7
andi s4, s4, 0x7F
or a0, s1, s5
or a0, a0, s4
ret
# --- overflow / underflow handling ---
exp_greater_all_one:
slli a0, s1, 15
or a0, a0, s11
ret
exp_less_equal_zero:
slli a0, s1, 15
ret
# Common return labels
return_zero:
la t6, BF16_ZERO
lh a0, 0(t6) # a0 = +0.0 bit pattern
ret
return_nan:
la t6, BF16_NAN
lh a0, 0(t6) # a0 = canonical quiet NaN (0x7FC0)
ret
return_a:
ret # operand a is already in a0
return_b:
add a0, a1, zero # return operand b
ret
# =======================================================
# multiply8(a0, a1): Egyptian Multiplication
# =======================================================
# Parameters:
# a0 = multiplicand (8-bit)
# a1 = multiplier (8-bit)
#
# Return:
# a0 = 16-bit result
# =======================================================
multiply8:
mv s10, a0
mv s9, a1
li a0, 0
mul_loop:
beqz s9, mul_done
andi s8, s9, 1
beqz s8, skip_add
add a0, a0, s10
skip_add:
slli s10, s10, 1
srli s9, s9, 1
j mul_loop
mul_done:
ret
```
### C code with `bf16_sqrt` only
```c
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF)
{
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(0) = 0 (handle both +0 and -0) */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* Direct bit manipulation square root algorithm */
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
if (e & 1)
{
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
}
else
{
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for integer square root */
/* We want result where result^2 = m * 128 (since 128 represents 1.0) */
uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */
uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */
uint32_t result = 128; /* Default */
/* Binary search for square root of m */
while (low <= high)
{
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and scale */
if (sq <= m)
{
result = mid; /* This could be our answer */
low = mid + 1;
}
else
{
high = mid - 1;
}
}
/* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
/* But we need to adjust the scale */
/* Since m is scaled where 128=1.0, result should also be scaled same way */
/* Normalize to ensure result is in [128, 256) */
if (result >= 256)
{
result >>= 1;
new_exp++;
}
else if (result < 128)
{
while (result < 128 && new_exp > 1)
{
result <<= 1;
new_exp--;
}
}
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow */
if (new_exp >= 0xFF)
return (bf16_t){.bits = 0x7F80}; /* +Inf */
if (new_exp <= 0)
return BF16_ZERO();
return (bf16_t){.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
### RV32I Assembly Code with `bf16_sqrt` only
The following RV32I code includes several test data cases and uses **automated testing** for verification.
```riscv
.data
# ======= Constants =======
BF16_SIGN_MASK: .half 0x8000
BF16_EXP_MASK: .half 0x7F80
BF16_MANT_MASK: .half 0x007F
BF16_EXP_BIAS: .half 127
BF16_NAN: .half 0x7FC0
BF16_ZERO: .half 0x0000
# ======= Common Messages =======
nl: .string "\n"
msg_case: .string "Test case "
msg_input: .string "Input: "
msg_output: .string "Output: "
msg_expect: .string "Expect: "
msg_ok: .string "✅ Correct\n"
msg_wrong: .string "❌ Wrong\n"
# ======= SQRT Test Labels =======
msg_sqrt: .string "\n=== BF16 SQRT TESTS ===\n"
msg0: .string "sqrt(0.0) = "
msg1: .string "sqrt(1.0) = "
msg2: .string "sqrt(4.0) = "
msg3: .string "sqrt(9.0) = "
msg4: .string "sqrt(-1.0) = "
msg5: .string "sqrt(+Inf) = "
msg6: .string "sqrt(-Inf) = "
msg7: .string "sqrt(0.25) = "
msg8: .string "sqrt(16.0) = "
msg9: .string "sqrt(2.0) = "
# ======= Test Inputs =======
val0: .half 0x0000 # 0.0
val1: .half 0x3F80 # 1.0
val2: .half 0x4080 # 4.0
val3: .half 0x4110 # 9.0
val4: .half 0xBF80 # -1.0
val5: .half 0x7F80 # +Inf
val6: .half 0xFF80 # -Inf
val7: .half 0x3E80 # 0.25
val8: .half 0x4180 # 16.0
val9: .half 0x4000 # 2.0
# ======= Expected Outputs =======
sqrt_expect: .half 0x0000, 0x3F80, 0x4000, 0x4040, 0x7FC0, 0x7F80, 0x7FC0, 0x3F00, 0x4080, 0x3FB5
.text
.global main
main:
li s0, 255
lui t0, 0x8
addi s11, t0, -0x80 # s11 = 0x7F80 (Inf mask)
# ==== Print Header ====
la a0, msg_sqrt
li a7, 4
ecall
# =======================================================
# Test 0: sqrt(0.0)
# =======================================================
la a0, msg0
li a7, 4
ecall
la a0, val0
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 0
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 1: sqrt(1.0)
# =======================================================
la a0, msg1
li a7, 4
ecall
la a0, val1
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 1
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 2: sqrt(4.0)
# =======================================================
la a0, msg2
li a7, 4
ecall
la a0, val2
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 2
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 3: sqrt(9.0)
# =======================================================
la a0, msg3
li a7, 4
ecall
la a0, val3
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 3
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 4: sqrt(-1.0)
# =======================================================
la a0, msg4
li a7, 4
ecall
la a0, val4
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 4
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 5: sqrt(+Inf)
# =======================================================
la a0, msg5
li a7, 4
ecall
la a0, val5
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 5
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 6: sqrt(-Inf)
# =======================================================
la a0, msg6
li a7, 4
ecall
la a0, val6
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 6
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 7: sqrt(0.25)
# =======================================================
la a0, msg7
li a7, 4
ecall
la a0, val7
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 7
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 8: sqrt(16.0)
# =======================================================
la a0, msg8
li a7, 4
ecall
la a0, val8
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 8
mv a0, t5
jal ra, compare_result
# =======================================================
# Test 9: sqrt(2.0)
# =======================================================
la a0, msg9
li a7, 4
ecall
la a0, val9
lh a0, 0(a0)
jal ra, bf16_sqrt
mv t5, a0
la a1, sqrt_expect
li a2, 9
mv a0, t5
jal ra, compare_result
# =======================================================
# End of Program
# =======================================================
li a7, 10
ecall
# =======================================================
# compare_result(a0, expect_addr, idx)
# =======================================================
# compare_result(a0, expect_addr, idx)
# a0 = actual result (16-bit)
# a1 = address of the expected value table
# a2 = test case index (0-based)
# =======================================================
# Prints the actual and expected values (hex) and an OK/Wrong verdict.
# a0 = actual result (low 16 bits significant)
# a1 = base address of the .half expected-value table
# a2 = 0-based test index (table offset = index * 2)
# t0-t3 are preserved across the call via the stack; t4 is used as
# scratch and clobbered (t-registers are caller-saved anyway).
compare_result:
addi sp, sp, -16
sw t0, 0(sp) # preserve the temporaries used below
sw t1, 4(sp)
sw t2, 8(sp)
sw t3, 12(sp)
mv t0, a0 # t0 = actual result
slli t1, a2, 1 # halfword table: offset = index * 2
add t2, a1, t1
lhu t3, 0(t2) # t3 = expected value (zero-extended)
li t4, 0xFFFF
and t0, t0, t4 # compare low 16 bits only
and t3, t3, t4
la a0, nl
li a7, 4
ecall
# ---- Output (hex) ----
la a0, msg_output
li a7, 4
ecall
mv a0, t0
li a7, 34 # syscall 34: print integer in hex
ecall
la a0, nl
li a7, 4
ecall
# ---- Expect (hex) ----
la a0, msg_expect
li a7, 4
ecall
mv a0, t3
li a7, 34
ecall
la a0, nl
li a7, 4
ecall
# ---- Compare and print verdict ----
beq t0, t3, print_ok
la a0, msg_wrong
li a7, 4
ecall
j done
print_ok:
la a0, msg_ok
li a7, 4
ecall
done:
lw t0, 0(sp) # restore preserved temporaries
lw t1, 4(sp)
lw t2, 8(sp)
lw t3, 12(sp)
addi sp, sp, 16
ret
# ============================================================
# Function: bf16_sqrt
# Input : a0 (BF16 value)
# Output: a0 (sqrt result)
# ============================================================
# ============================================================
# Function: bf16_sqrt
# Input : a0 (BF16 value in the low 16 bits, sign-extended by lh)
# Output: a0 (BF16 sqrt result)
# Expects: s0 = 255, s11 = 0x7F80 (initialized in main).
# Fixes:
#   * Denormal inputs (exp == 0, mant != 0) previously fell through
#     sqrt_check_mant into Adjust_for_odd_exponents with t0/t2/s4
#     uninitialized. They now match the reference C: NaN when
#     negative, flushed to +0 otherwise.
#   * Removed an unreachable "beqz s2, return_zero" (s2 is known
#     nonzero at that point because the exp == 0 case already branched).
# ============================================================
bf16_sqrt:
srai s1, a0, 15
andi s1, s1, 1 # s1 = sign bit
srai s2, a0, 7
and s2, s2, s0 # s2 = 8-bit biased exponent
andi s3, a0, 0x7F # s3 = 7-bit mantissa
beq s2, s0, Handle_special_cases # exp == 0xFF: Inf / NaN
beqz s2, sqrt_check_mant # exp == 0: zero or denormal
bnez s1, return_nan # sqrt of negative -> NaN
la s4, BF16_EXP_BIAS
lh s4, 0(s4) # s4 = 127
sub t0, s2, s4 # t0 = unbiased exponent e
ori t2, s3, 0x80 # t2 = mantissa with implicit 1, [128, 256)
andi t3, t0, 1
bnez t3, Adjust_for_odd_exponents # odd e: double mantissa first
srai t3, t0, 1
add t1, t3, s4 # t1 = e/2 + bias (result exponent)
# Binary search: find integer sqrt of the scaled mantissa
low_high_result:
li s5, 90 # low  ~ sqrt(128)
li s6, 256 # high ~ sqrt(512)
li s7, 128 # default result
Binary_search_loop:
bgt s5, s6, out_Binary_Search
add t3, s5, s6
srli t3, t3, 1 # mid = (low + high) / 2
addi sp, sp, -12 # save a0/a1/ra around multiply8
sw a0, 0(sp)
sw a1, 4(sp)
sw ra, 8(sp)
mv a0, t3
mv a1, t3
jal ra, multiply8 # a0 = mid * mid
mv t4, a0
lw a0, 0(sp)
lw a1, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
srli t4, t4, 7 # sq = mid*mid / 128 (128 represents 1.0)
ble t4, t2, binary_search_if
addi s6, t3, -1 # sq > m: high = mid - 1
j Binary_search_loop
binary_search_if:
add s7, t3, x0 # sq <= m: candidate result = mid
addi s5, t3, 1 # low = mid + 1
j Binary_search_loop
# Post-processing: renormalize result into [128, 256)
out_Binary_Search:
li t3, 256
bge s7, t3, result_greater_256
li t3, 128
blt s7, t3, result_less_128
Extract_7_bit_mantissa:
andi t3, s7, 0x7F # drop the implicit leading 1
bge t1, s0, sqrt_overflow # exponent >= 0xFF -> +Inf
ble t1, zero, return_zero # exponent <= 0 -> underflow to zero
and a0, t1, s0
slli a0, a0, 7
or a0, a0, t3 # pack: sign is always 0 here (input >= 0)
ret
# Special / edge case handlers (exp == 0xFF)
Handle_special_cases:
bnez s3, return_a # NaN propagates unchanged
bnez s1, return_nan # sqrt(-Inf) = NaN
j return_a # sqrt(+Inf) = +Inf
# exp == 0: zero or denormal
sqrt_check_mant:
beqz s3, return_zero # sqrt(+/-0) = +0
bnez s1, return_nan # negative denormal: sqrt(<0) -> NaN
j return_zero # positive denormal flushed to zero
Adjust_for_odd_exponents:
slli t2, t2, 1 # double mantissa: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2m)
addi t3, t0, -1
srai t3, t3, 1
add t1, t3, s4 # t1 = (e-1)/2 + bias
j low_high_result
# Result normalization
result_greater_256:
srli s7, s7, 1
addi t1, t1, 1
j Extract_7_bit_mantissa
result_less_128:
li t3, 128
blt s7, t3, sqrt_check_new_exp
j Extract_7_bit_mantissa
sqrt_check_new_exp:
li t3, 1
bgt t1, t3, sqrt_while_loop_2 # while (result < 128 && new_exp > 1)
j Extract_7_bit_mantissa
sqrt_while_loop_2:
slli s7, s7, 1
addi t1, t1, -1
j result_less_128
sqrt_overflow:
add a0, zero, s11 # +Inf (0x7F80)
ret
# Common return helpers
return_zero:
la t6, BF16_ZERO
lh a0, 0(t6)
ret
return_nan:
la t6, BF16_NAN
lh a0, 0(t6)
ret
return_a:
ret
return_b:
add a0, a1, zero
ret
# ============================================================
# multiply8: Egyptian Multiplication
# Input : a0=a, a1=b
# Output: a0=a*b (16-bit result)
# ============================================================
# multiply8: shift-and-add multiplication, a0 = a0 * a1 (no MUL needed).
# Used by bf16_sqrt to square the binary-search midpoint; the 32-bit
# registers hold the full 16-bit product.
# NOTE(review): s8-s10 are callee-saved per the ABI but are clobbered
# here without saving; acceptable in this program, verify before reuse.
multiply8:
mv s10, a0 # shifted multiplicand
mv s9, a1 # multiplier bits still to process
li a0, 0 # accumulator
mul_loop:
beqz s9, mul_done
andi s8, s9, 1 # low multiplier bit
beqz s8, skip_add
add a0, a0, s10 # accumulate when bit is set
skip_add:
slli s10, s10, 1
srli s9, s9, 1
j mul_loop
mul_done:
ret
```
## LeetCode 260. Single Number III
### Description
Given an integer array nums, in which exactly two elements appear only once and all the other elements appear exactly twice. Find the two elements that appear only once. You can return the answer **in any order**.
You must write an algorithm that runs in linear runtime complexity and uses only constant extra space.
Example 1:
> **Input:** nums = [1,2,1,3,2,5]
> **Output:** [3,5]
> **Explanation:** [5, 3] is also a valid answer.
Example 2:
> **Input:** nums = [-1,0]
> **Output:** [-1,0]
Example 3:
> **Input:** nums = [0,1]
> **Output:** [1,0]
Constraints:
* 2 <= nums.length <= 3 * 10^4
* -2^31 <= nums[i] <= 2^31 - 1
* Each integer in nums will appear twice, only two integers will appear once
### `clz` function
The purpose of this function is to accelerate **the computation of the number of leading zeros** in a 32-bit unsigned integer using a combination of right shifts and binary search.
```c
/* Count leading zeros of a 32-bit value by binary search.
 * Narrows the window with shifts of 16, 8, 4, 2, 1; for any nonzero
 * input x ends up as 1, so n - x is the leading-zero count.
 * Returns 32 when x == 0 (x stays 0, n stays 32). */
static inline unsigned clz(uint32_t x)
{
    int n = 32;
    for (int step = 16; step != 0; step >>= 1) {
        uint32_t top = x >> step;
        if (top) {
            n -= step;
            x = top;
        }
    }
    return n - x;
}
```
### Solution Concept
**Step 1.** XOR all the numbers to get xor_all = a ^ b, where a and b are the two unique numbers.
**Step 2.** Find any bit in xor_all that is set to 1 — this bit represents a position where a and b differ, where the bit is also called "set bit".
* In the C code without `clz` version, we use
unsigned int set_bit = (unsigned int)xor_val & -(unsigned int)xor_val;
We are trying to isolate the lowest set bit (the rightmost 1) in xor_all. However, if xor_all equals INT_MIN (-2147483648), its binary form is:
10000000 00000000 00000000 00000000
In 32-bit signed integers, this value has no positive counterpart —
`-INT_MIN` would require 2147483648, which cannot be represented in int (the max is 2147483647). That causes signed overflow, which is **undefined behavior** in C.
So as to fix the problem, we cast signed to **unsigned**, we tell the compiler to treat the bits as pure binary, without interpreting the sign bit.
* In the C code with `clz` version, we use
int shift = 31 - clz((uint32_t)xor_val);
unsigned int mask = 1U << shift;
We are trying to isolate the highest set bit (the leftmost 1) then store it in **shift**.
Then, the mask sets only the highest differing bit to 1 and clears all others. We use this mask to separate a and b.
* In the C code with `__builtin_clz` version, we use
int shift = 31 - __builtin_clz((unsigned int)xor_val);
unsigned int mask = 1U << shift;
where the concept is totally same with `clz` version.
**Step 3.** Use that bit to divide the entire array into two groups:
Group 1: numbers where this bit is 0
Group 2: numbers where this bit is 1
This ensures that:
* All paired (duplicate) numbers fall into the same group (since they are identical)
* The two unique numbers a and b fall into different groups
**Step 4.** XOR all numbers within each group separately to obtain the two unique numbers.
### C code without `clz`
```c
/*
 * LeetCode 260: exactly two elements of nums appear once, all others twice.
 * Returns a heap-allocated array of the two singletons (caller frees) and
 * sets *returnSize = 2. On allocation failure returns NULL with
 * *returnSize = 0. O(n) time, O(1) extra space.
 */
int *singleNumber(int *nums, int numsSize, int *returnSize)
{
    /* Step 1: XOR everything; duplicate pairs cancel, leaving a ^ b. */
    int xor_val = 0;
    for (int i = 0; i < numsSize; i++)
        xor_val ^= nums[i];
    /* Step 2: isolate the lowest set bit of a ^ b. The unsigned casts
     * avoid signed-overflow UB when xor_val == INT_MIN (-x in unsigned
     * arithmetic is well-defined wraparound). */
    unsigned int set_bit = (unsigned int)xor_val & -(unsigned int)xor_val;
    /* Step 3: partition by that bit; a and b land in different groups,
     * while each duplicate pair stays together and cancels. */
    int a = 0, b = 0;
    for (int i = 0; i < numsSize; i++) {
        if (nums[i] & set_bit)
            a ^= nums[i];
        else
            b ^= nums[i];
    }
    /* Step 4: publish the result (order is unspecified per problem). */
    int *res = malloc(2 * sizeof *res);
    if (!res) {
        *returnSize = 0;
        return NULL;
    }
    res[0] = a;
    res[1] = b;
    *returnSize = 2;
    return res;
}
```
### C code with `clz`
```c
#include <stdint.h>
#include <stdlib.h>
/* Count leading zeros by binary search; returns 32 for x == 0. */
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x; /* x ends as 1 for nonzero input, so this is the clz */
}

/*
 * LeetCode 260 via clz: split on the HIGHEST differing bit of a ^ b.
 * Returns a heap array of the two singletons (caller frees) and sets
 * *returnSize = 2; returns NULL with *returnSize = 0 on OOM.
 * Note: xor_val is an int (not long) for consistency with the other
 * variants; all accumulated values are ints, so nothing is lost.
 */
int *singleNumber(int *nums, int numsSize, int *returnSize)
{
    /* Step 1: XOR everything; pairs cancel, leaving a ^ b (nonzero,
     * since the problem guarantees two distinct singletons). */
    int xor_val = 0;
    for (int i = 0; i < numsSize; i++)
        xor_val ^= nums[i];
    /* Step 2: highest set bit of a ^ b -- a position where a, b differ. */
    int shift = 31 - clz((uint32_t)xor_val);
    unsigned int mask = 1U << shift;
    /* Step 3: partition by that bit and XOR within each group. */
    int a = 0, b = 0;
    for (int i = 0; i < numsSize; i++) {
        if (nums[i] & mask)
            a ^= nums[i];
        else
            b ^= nums[i];
    }
    /* Step 4: publish the result. */
    int *res = malloc(2 * sizeof *res);
    if (!res) {
        *returnSize = 0;
        return NULL;
    }
    res[0] = a;
    res[1] = b;
    *returnSize = 2;
    return res;
}
```
### C code with `__builtin_clz`
```c
/*
 * LeetCode 260 via the compiler intrinsic: split on the highest
 * differing bit of a ^ b. Returns a heap array of the two singletons
 * (caller frees) and sets *returnSize = 2; returns NULL with
 * *returnSize = 0 on allocation failure.
 */
int *singleNumber(int *nums, int numsSize, int *returnSize)
{
    /* XOR everything; duplicate pairs cancel, leaving a ^ b. */
    int xor_val = 0;
    for (int i = 0; i < numsSize; i++)
        xor_val ^= nums[i];
    /* __builtin_clz(0) is undefined behavior; xor_val != 0 is
     * guaranteed here because the two singletons are distinct. */
    int shift = 31 - __builtin_clz((unsigned int)xor_val);
    unsigned int mask = 1U << shift;
    /* Partition by the differing bit and XOR within each group. */
    int a = 0, b = 0;
    for (int i = 0; i < numsSize; i++) {
        if (nums[i] & mask)
            a ^= nums[i];
        else
            b ^= nums[i];
    }
    int *res = malloc(2 * sizeof *res);
    if (!res) {
        *returnSize = 0;
        return NULL;
    }
    res[0] = a;
    res[1] = b;
    *returnSize = 2;
    return res;
}
```
### RV32I Assembly code without `clz`
```riscv
.data
# ==== Three test cases ====
nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99
nums1_size: .word 12
ans1: .word 1, 0
nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17
nums2_size: .word 14
ans2: .word 99, 100
nums3: .word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9
nums3_size: .word 22
ans3: .word -5, -6
# === Pointer tables ===
test_cases: .word nums1, nums2, nums3
test_sizes: .word nums1_size, nums2_size, nums3_size
test_ans: .word ans1, ans2, ans3
ressize: .word 0
result: .word 0, 0
# ===== Display strings =====
msg_case: .string "Test case "
msg_input: .string "Input: "
msg_output: .string "Output: "
msg_expect: .string "Expect: "
msg_ok: .string "✅ Correct\n"
msg_wrong: .string "❌ Wrong\n"
space: .string " "
nl: .string "\n"
.text
.global main
# =====================================================
# main: iterate through three test cases for singleNumber
# =====================================================
main:
# Setup iterators for the three pointer tables
la s5, test_cases
la s6, test_sizes
la s7, test_ans
li s3, 3 # total test cases
li s4, 1 # current case index (1-based)
loop_cases:
beqz s3, end_main
# Load current case pointers
lw s0, 0(s5) # s0 = &nums
lw s1, 0(s6) # s1 = &size
lw s2, 0(s7) # s2 = &expected answer
# Print "Test case i"
la a0, msg_case
li a7, 4
ecall
mv a0, s4
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# Print input array
la a0, msg_input
li a7, 4
ecall
lw t3, 0(s1)
li t4, 0
print_input_loop:
bge t4, t3, print_input_done
slli t5, t4, 2
add t6, s0, t5
lw a0, 0(t6)
li a7, 1
ecall
la a0, space
li a7, 4
ecall
addi t4, t4, 1
j print_input_loop
print_input_done:
la a0, nl
li a7, 4
ecall
# Save caller-saved registers
addi sp, sp, -24
sw ra, 20(sp)
sw s0, 16(sp)
sw s1, 12(sp)
sw s2, 8(sp)
sw s3, 4(sp)
sw s4, 0(sp)
# Call singleNumber(nums, size, &ressize)
mv a0, s0
lw a1, 0(s1)
la a2, ressize
jal singleNumber
mv t6, a0
# Restore saved registers
lw ra, 20(sp)
lw s0, 16(sp)
lw s1, 12(sp)
lw s2, 8(sp)
lw s3, 4(sp)
lw s4, 0(sp)
addi sp, sp, 24
# Print output
la a0, msg_output
li a7, 4
ecall
lw t0, 0(t6)
mv a0, t0
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t1, 4(t6)
mv a0, t1
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# Print expected answer
la a0, msg_expect
li a7, 4
ecall
lw t2, 0(s2)
mv a0, t2
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t3, 4(s2)
mv a0, t3
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# Check correctness
lw t4, 0(t6)
lw t5, 4(t6)
beq t4, t2, check_second
j print_wrong
check_second:
beq t5, t3, print_ok
j print_wrong
print_ok:
la a0, msg_ok
li a7, 4
ecall
j next_case
print_wrong:
la a0, msg_wrong
li a7, 4
ecall
next_case:
addi s5, s5, 4
addi s6, s6, 4
addi s7, s7, 4
addi s4, s4, 1
addi s3, s3, -1
j loop_cases
end_main:
li a7, 10
ecall
# =====================================================
# function: singleNumber
# input: a0 = pointer to nums, a1 = numsSize, a2 = returnSize
# output: a0 = pointer to integer array
# =====================================================
singleNumber:
li s2, 0
li t0, 0
# First loop: XOR all numbers
for_loop_1:
bge t0, a1, after_for_1
slli t1, t0, 2
add t1, a0, t1
lw t1, 0(t1)
xor s2, s2, t1
addi t0, t0, 1
j for_loop_1
after_for_1:
# diff_bit = xor_all & (-xor_all)
neg t0, s2
and s4, s2, t0
li t1, 0
li t2, 0
li t0, 0
# Second loop: XOR numbers into two groups
for_loop_2:
bge t0, a1, done
slli t3, t0, 2
add t3, a0, t3
lw t3, 0(t3)
and t4, t3, s4
beqz t4, else_part
xor t1, t1, t3
j inc_i
else_part:
xor t2, t2, t3
inc_i:
addi t0, t0, 1
j for_loop_2
# Store result [a, b] and return
done:
la a0, result
sw t1, 0(a0)
sw t2, 4(a0)
li t3, 2
sw t3, 0(a2)
ret
```
### RV32I Assembly code with `clz`
```riscv
.data
nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99
nums1_size: .word 12
ans1: .word 1, 0
nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17
nums2_size: .word 14
ans2: .word 100, 99
nums3: .word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9
nums3_size: .word 22
ans3: .word -5, -6
test_cases: .word nums1, nums2, nums3
test_sizes: .word nums1_size, nums2_size, nums3_size
test_ans: .word ans1, ans2, ans3
ressize: .word 0
result: .word 0, 0
msg_case: .string "Test case "
msg_input: .string "Input: "
msg_output: .string "Output: "
msg_expect: .string "Expect: "
msg_ok: .string "✅ Correct\n"
msg_wrong: .string "❌ Wrong\n"
space: .string " "
nl: .string "\n"
.text
.global main
main:
# initialize iterators for case tables
la s5, test_cases
la s6, test_sizes
la s7, test_ans
li s3, 3
li s4, 1
loop_cases:
beqz s3, end_main
# load current pointers
lw s0, 0(s5)
lw s1, 0(s6)
lw s2, 0(s7)
# print "Test case i"
la a0, msg_case
li a7, 4
ecall
mv a0, s4
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# print input array
la a0, msg_input
li a7, 4
ecall
lw t3, 0(s1)
li t4, 0
print_input_loop:
bge t4, t3, print_input_done
slli t5, t4, 2
add t6, s0, t5
lw a0, 0(t6)
li a7, 1
ecall
la a0, space
li a7, 4
ecall
addi t4, t4, 1
j print_input_loop
print_input_done:
la a0, nl
li a7, 4
ecall
# save registers before call
addi sp, sp, -24
sw ra, 20(sp)
sw s0, 16(sp)
sw s1, 12(sp)
sw s2, 8(sp)
sw s3, 4(sp)
sw s4, 0(sp)
# call singleNumber(a0=nums, a1=size, a2=&ressize)
mv a0, s0
lw a1, 0(s1)
la a2, ressize
jal singleNumber
mv t6, a0
# restore registers
lw ra, 20(sp)
lw s0, 16(sp)
lw s1, 12(sp)
lw s2, 8(sp)
lw s3, 4(sp)
lw s4, 0(sp)
addi sp, sp, 24
# print output
la a0, msg_output
li a7, 4
ecall
lw t0, 0(t6)
mv a0, t0
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t1, 4(t6)
mv a0, t1
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# print expected result
la a0, msg_expect
li a7, 4
ecall
lw t2, 0(s2)
mv a0, t2
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t3, 4(s2)
mv a0, t3
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# compare results
lw t4, 0(t6)
lw t5, 4(t6)
beq t4, t2, check_second
j print_wrong
check_second:
beq t5, t3, print_ok
j print_wrong
print_ok:
la a0, msg_ok
li a7, 4
ecall
j next_case
print_wrong:
la a0, msg_wrong
li a7, 4
ecall
next_case:
addi s5, s5, 4
addi s6, s6, 4
addi s7, s7, 4
addi s4, s4, 1
addi s3, s3, -1
j loop_cases
end_main:
li a7, 10
ecall
##################################
# Count Leading Zeros (clz)
# a0 = input, return a0 = clz(x)
##################################
clz:
li s0, 32
li s1, 16
clz_loop:
srl t0, a0, s1
bnez t0, clz_if
srli s1, s1, 1
j clz_check
clz_if:
sub s0, s0, s1
mv a0, t0
clz_check:
bnez s1, clz_loop
sub a0, s0, a0
ret
##################################
# singleNumber
# a0 = nums ptr, a1 = size, a2 = &returnSize
# return a0 = &result
##################################
singleNumber:
li s2, 0
li t0, 0
# first loop: xor all numbers
for1_cond:
blt t0, a1, for1_body
li t1, 31
addi sp, sp, -8
sw ra, 0(sp)
sw a0, 4(sp)
mv a0, s2
jal ra, clz
sub s3, t1, a0
lw a0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 8
li t1, 1
sll s4, t1, s3 # mask = 1U << shift
li t1, 0 # a = 0
li t2, 0 # b = 0
li t0, 0
j for2_cond
for1_body:
slli t1, t0, 2
add t1, a0, t1
lw t1, 0(t1)
xor s2, s2, t1
addi t0, t0, 1
j for1_cond
# second loop: split by mask bit
for2_cond:
blt t0, a1, for2_body
la a0, result
sw t1, 0(a0)
sw t2, 4(a0)
li t3, 2
sw t3, 0(a2)
ret
for2_body:
slli t3, t0, 2
add t3, a0, t3
lw t3, 0(t3)
and t4, t3, s4
bnez t4, for2_if
xor t2, t2, t3
j for2_next
for2_if:
xor t1, t1, t3
for2_next:
addi t0, t0, 1
j for2_cond
```
### Loop Unrolling Optimization
In our RV32I assembly implementation, we identified three major loops:
1. the clz (Count Leading Zeros) function,
2. the first loop (which performs a global XOR across all numbers), and
3. the second loop (which splits elements into two groups based on a mask bit).
Our goal is to optimize these three loops using loop unrolling to reduce branching overhead and improve instruction-level parallelism (ILP) on a 5-stage pipeline processor.
* **clz:**
* The clz function takes a 32-bit integer as input and performs a do-while loop that iterates exactly **five times**—checking shifts by 16, 8, 4, 2, and 1 bits respectively.
Because this loop always executes a fixed number of iterations regardless of the input value, we can **fully unroll** it.
* By expanding each iteration manually, we completely remove all branch and jump instructions associated with the loop control. This eliminates dynamic branching cost, which is especially beneficial in RV32I pipelines where every conditional branch can cause a flush and stall.
* In other words, clz becomes a straight-line sequence of shift and compare instructions. This trades a small increase in code size for a significant improvement in runtime predictability and speed.
* **First Loop: Partial Unrolling by Four**
* The first loop performs a simple reduction:
```c
for (int i = 0; i < numsSize; i++)
{
xor_val ^= nums[i];
}
```
It iterates `numsSize` times — i.e. once per test data element.
* The original version processes one element per iteration:
```c
for1_body:
slli t1, t0, 2
add t1, a0, t1
lw t1, 0(t1)
xor s2, s2, t1
addi t0, t0, 1
j for1_cond
```
We apply loop unrolling by a factor of 4, so that each iteration processes four consecutive integers.
* Here is the optimized version:
```riscv
loop1_unroll4:
addi t2, s1, 4
bgt t2, a1, loop1_remainder
lw t3, 0(s0)
lw t4, 4(s0)
lw t5, 8(s0)
lw t6, 12(s0)
xor s2, s2, t3
xor s2, s2, t4
xor s2, s2, t5
xor s2, s2, t6
addi s0, s0, 16 # ptr += 16
addi s1, s1, 4 # i += 4
j loop1_unroll4
```
A small remainder loop handles any leftover elements:
```riscv
loop1_remainder:
bge s1, a1, after_loop1
loop1_rem_iter:
bge s1, a1, after_loop1
lw t3, 0(s0)
xor s2, s2, t3
addi s0, s0, 4
addi s1, s1, 1
j loop1_rem_iter
```
Although the total number of assembly lines increases, the number of branch and jump instructions per processed element decreases by roughly **75%**.
This reduces control hazards and improves ILP, since multiple lw and xor instructions can now overlap in the pipeline.
You can verify this improvement in the **Performance** section later in this document.
* **Second Loop: Partial Unrolling by Four**
* The second loop also iterates over `numsSize` elements but performs conditional XOR operations depending on whether (`nums[i] & mask`) is zero.
Like the first loop, we unroll it by four iterations to minimize branch overhead while preserving correctness.
* Each unrolled block loads four elements, performs four conditional checks, and applies the corresponding XORs into two accumulators (a and b).
* Even though this increases the code length, the loop executes significantly faster in the common case.
### Optimized RV32I Assembly code with loop unrolling
```riscv
.data
nums1: .word 2, 2, 3, 3, 4, 4, 0, 1, 100, 100, 99, 99
nums1_size: .word 12
ans1: .word 1, 0
nums2: .word 101, 17, 102, 102, -98, 0, 1, 101, 0, 1, 99, -98, 100, 17
nums2_size: .word 14
ans2: .word 100, 99
nums3: .word -2, -2, -2, 2, 2, 2, -6, -9, -2, 2, 2, -5, 2, -6, -2, -10, -11, -10, -11, -2, -6, -9
nums3_size: .word 22
ans3: .word -5, -6
test_cases: .word nums1, nums2, nums3
test_sizes: .word nums1_size, nums2_size, nums3_size
test_ans: .word ans1, ans2, ans3
ressize: .word 0
result: .word 0, 0
msg_case: .string "Test case "
msg_input: .string "Input: "
msg_output: .string "Output: "
msg_expect: .string "Expect: "
msg_ok: .string "✅ Correct\n"
msg_wrong: .string "❌ Wrong\n"
space: .string " "
nl: .string "\n"
.text
.global main
main:
# initialize iterators for case tables
la s5, test_cases
la s6, test_sizes
la s7, test_ans
li s3, 3 # total cases
li s4, 1 # case index (1-based)
loop_cases:
beqz s3, end_main
# load current pointers
lw s0, 0(s5) # s0 = nums ptr
lw s1, 0(s6) # s1 = size ptr
lw s2, 0(s7) # s2 = ans ptr
# print "Test case i"
la a0, msg_case
li a7, 4
ecall
mv a0, s4
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# print input array
la a0, msg_input
li a7, 4
ecall
lw t3, 0(s1) # size
li t4, 0
print_input_loop:
bge t4, t3, print_input_done
slli t5, t4, 2
add t6, s0, t5
lw a0, 0(t6)
li a7, 1
ecall
la a0, space
li a7, 4
ecall
addi t4, t4, 1
j print_input_loop
print_input_done:
la a0, nl
li a7, 4
ecall
# save registers before call
addi sp, sp, -24
sw ra, 20(sp)
sw s0, 16(sp)
sw s1, 12(sp)
sw s2, 8(sp)
sw s3, 4(sp)
sw s4, 0(sp)
# call singleNumber(a0=nums, a1=size, a2=&ressize)
mv a0, s0
lw a1, 0(s1)
la a2, ressize
jal singleNumber
mv t6, a0 # t6 = &result
# restore registers
lw ra, 20(sp)
lw s0, 16(sp)
lw s1, 12(sp)
lw s2, 8(sp)
lw s3, 4(sp)
lw s4, 0(sp)
addi sp, sp, 24
# print output
la a0, msg_output
li a7, 4
ecall
lw t0, 0(t6)
mv a0, t0
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t1, 4(t6)
mv a0, t1
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# print expected result
la a0, msg_expect
li a7, 4
ecall
lw t2, 0(s2)
mv a0, t2
li a7, 1
ecall
la a0, space
li a7, 4
ecall
lw t3, 4(s2)
mv a0, t3
li a7, 1
ecall
la a0, nl
li a7, 4
ecall
# compare results
lw t4, 0(t6)
lw t5, 4(t6)
beq t4, t2, check_second
j print_wrong
check_second:
beq t5, t3, print_ok
j print_wrong
print_ok:
la a0, msg_ok
li a7, 4
ecall
j next_case
print_wrong:
la a0, msg_wrong
li a7, 4
ecall
next_case:
addi s5, s5, 4
addi s6, s6, 4
addi s7, s7, 4
addi s4, s4, 1
addi s3, s3, -1
j loop_cases
end_main:
li a7, 10
ecall
##################################
# Count Leading Zeros (unrolled)
# input : a0 = 32-bit unsigned
# output: a0 = #leading zeros
##################################
clz:
addi sp, sp, -16
sw s8, 0(sp)
sw s9, 4(sp)
sw s10, 8(sp)
sw s11,12(sp)
beqz a0, clz_zero
li s8, 32
mv s9, a0
# (x >> 16)
srli s10, s9, 16
beqz s10, clz_chk8
addi s8, s8, -16
mv s9, s10
clz_chk8:
# (x >> 8)
srli s10, s9, 8
beqz s10, clz_chk4
addi s8, s8, -8
mv s9, s10
clz_chk4:
# (x >> 4)
srli s10, s9, 4
beqz s10, clz_chk2
addi s8, s8, -4
mv s9, s10
clz_chk2:
# (x >> 2)
srli s10, s9, 2
beqz s10, clz_chk1
addi s8, s8, -2
mv s9, s10
clz_chk1:
# (x >> 1)
srli s10, s9, 1
beqz s10, clz_ret
addi s8, s8, -1
mv s9, s10
clz_ret:
sub a0, s8, s9
lw s8, 0(sp)
lw s9, 4(sp)
lw s10, 8(sp)
lw s11, 12(sp)
addi sp, sp, 16
ret
clz_zero:
li a0, 32
lw s8, 0(sp)
lw s9, 4(sp)
lw s10, 8(sp)
lw s11, 12(sp)
addi sp, sp, 16
ret
##################################
# singleNumber (loop unrolled)
# input : a0 = nums*, a1 = size, a2 = &returnSize
# output: a0 = &result
##################################
singleNumber:
# first loop: XOR all numbers (unrolled ×4)
li s2, 0
li s1, 0
mv s0, a0
loop1_unroll4:
addi t2, s1, 4
bgt t2, a1, loop1_remainder
lw t3, 0(s0)
lw t4, 4(s0)
lw t5, 8(s0)
lw t6, 12(s0)
xor s2, s2, t3
xor s2, s2, t4
xor s2, s2, t5
xor s2, s2, t6
addi s0, s0, 16
addi s1, s1, 4
j loop1_unroll4
loop1_remainder:
bge s1, a1, after_loop1
loop1_rem_iter:
bge s1, a1, after_loop1
lw t3, 0(s0)
xor s2, s2, t3
addi s0, s0, 4
addi s1, s1, 1
j loop1_rem_iter
after_loop1:
# compute mask bit
li t1, 31
addi sp, sp, -8
sw ra, 0(sp)
sw a0, 4(sp)
mv a0, s2
jal ra, clz
sub s3, t1, a0
lw a0, 4(sp)
lw ra, 0(sp)
addi sp, sp, 8
li t1, 1
sll s4, t1, s3
# second loop: split by mask (unrolled ×4)
li t1, 0
li t2, 0
li s1, 0
mv s0, a0
loop2_unroll4:
addi t3, s1, 4
bgt t3, a1, loop2_remainder
lw t3, 0(s0)
lw t4, 4(s0)
lw t5, 8(s0)
lw t6, 12(s0)
# element1
and t0, t3, s4
beqz t0, l2_e1
xor t1, t1, t3
j l2_n1
l2_e1:
xor t2, t2, t3
l2_n1:
# element2
and t0, t4, s4
beqz t0, l2_e2
xor t1, t1, t4
j l2_n2
l2_e2:
xor t2, t2, t4
l2_n2:
# element3
and t0, t5, s4
beqz t0, l2_e3
xor t1, t1, t5
j l2_n3
l2_e3:
xor t2, t2, t5
l2_n3:
# element4
and t0, t6, s4
beqz t0, l2_e4
xor t1, t1, t6
j l2_n4
l2_e4:
xor t2, t2, t6
l2_n4:
addi s0, s0, 16
addi s1, s1, 4
j loop2_unroll4
loop2_remainder:
bge s1, a1, end_loop2
loop2_rem_iter:
bge s1, a1, end_loop2
lw t3, 0(s0)
and t0, t3, s4
beqz t0, l2_eR
xor t1, t1, t3
j l2_nR
l2_eR:
xor t2, t2, t3
l2_nR:
addi s0, s0, 4
addi s1, s1, 1
j loop2_rem_iter
end_loop2:
la a0, result
sw t1, 0(a0)
sw t2, 4(a0)
li t3, 2
sw t3, 0(a2)
ret
```
## Performance
By using the [Ripes](https://github.com/mortbopet/Ripes) simulator, we evaluated multiple implementations of the same algorithm — both at the C and assembly levels — with and without loop unrolling and the clz optimization.
| C code without `clz` | C code with `clz` | C code with `__builtin_clz` |
| -------- | -------- | -------- |
|  |  |  |
* We can see that when we directly compiled C code into Ripes, the number of cycles is extremely high (≈98 k).
* This is because the compiler generates many generic operations that are not tailored for the RV32I instruction set.
* Even though the clz function or built-in intrinsic helps reduce the logical complexity, it does not reduce the actual loop control cost, because the compiler still emits similar branching and comparison sequences.
| RV32I Assembly code without `clz` | RV32I Assembly code with `clz` | Optimized RV32I Assembly code with loop unrolling |
| -------- | -------- | -------- |
|  | |  |
* As you can see, the hand-written RV32I assembly is far more efficient than the C-compiled output.
* However, the key improvement comes from loop unrolling:
* The optimized version executes **fewer total instructions (−20%)**.
* The total **cycle count decreases by roughly 27%** compared to the baseline assembly version.
* **CPI improves from 1.60 → 1.45**, and IPC increases to 0.692, indicating higher pipeline utilization.
> As the number of test data and test cases increases, the performance advantage of the unrolled version becomes even more significant.
> This is because loop unrolling amortizes the loop control overhead across more iterations, leading to higher efficiency gains for larger datasets.
> In our experiment, there are **only three test cases**, and **each test case contains fewer than twenty data elements** on average — yet the optimized version already shows a noticeable performance gap.
> This clearly demonstrates that even with small input sizes, loop unrolling can effectively reduce control overhead and improve execution throughput on RV32I processors.
## Analysis
Again, we test our code by using [Ripes](https://github.com/mortbopet/Ripes) simulator.
### 5-stage pipelined processor
Ripes supports six types of processors:
1. Single-cycle processor
2. 5-stage pipelined processor w/o forwarding or hazard detection
3. 5-stage pipelined processor w/o hazard detection
4. 5-Stage pipelined processor w/o forwarding unit
5. 5-stage pipelined processor
6. 6-stage dual-issue processor
For this assignment, the 5-stage pipelined processor has been selected as the target device because it is the most commonly used architecture.
Its block diagram looks like this:

The "5-stage" means this processor uses a five-stage pipeline to overlap the execution of instructions. The stages are:
| Stage | Description |
| -------- | -------- |
| IF | Instruction Fetch |
| ID | Instruction Decode and Register Fetch |
| EX | Execution or Address Calculation |
| MEM | Memory Access |
| WB | Register Write Back |
**Main Task:**
* **IF Stage :** Fetch the next instruction from memory (using the Program Counter) and update the PC.
* **ID Stage :** Decode the instruction, read source registers, and determine the operation type.
* **EXE Stage :** Perform arithmetic or logic operations in the ALU, or compute a memory address.
* **MEM Stage :** For load/store instructions, read from or write to data memory.
* **WB Stage :** Write the result from the ALU or memory back to the destination register.
Instructions of different formats go through the 5 stages with different control signals asserted. Let's discuss the I-type format in detail with an example below.
#### I-type format
`slli x8, x4, 5`
This is an **I-type** instruction in RISC-V, performing a **logical left shift by an immediate**, where the value of register x4 is shifted left by 5 bits, and the result is stored in register x8.
##### I-Type Instruction Format:
`| funct7[31:25] | shamt[24:20] | rs1[19:15] | funct3[14:12] | rd[11:7] | opcode[6:0] |`
* **funct7 :** the 7-bit function code for `slli` is `0000000 `
* **shamt :** shift amount is `5(00101)`
* **rs1 :** the source register `x4(00100)`
* **funct3 :** the 3-bit function code for `slli` is `001 `
* **rd :** the destination register is `x8(01000) `
* **opcode :** the opcode for `slli` is `0010011 `
Thus, the machine code of `slli x8, x4, 5` is `0000000 00101 00100 001 01000 0010011(bin)` = `0x00521413(hex)`
##### 1. Instruction Fetch (IF)

* We start from instruction put at `0x00000000`, so `addr` is equal to `0x00000000`
* The machine code of the instruction is `0x00521413`, so `instr` is equal to 0x00521413.
* PC will increment by 4 automatically using the above adder, because the instruction of RV32I is 32 bits long.
* Because there is no branch occur, next instruction will be at PC + 4, so the multiplexer before PC choose input come from adder.
##### 2. Instruction Decode (ID)

* Instruction `0x00521413` is decoded into five parts:
* `opcode` = `slli`
* `Wr idx` = `0x08`
* `imm.` = `0x00000005 `
* `R1 idx` = `0x04`
* `R2 idx` = `0x05`
* Though the I-type format reads both `R1 idx(0x04)` and `R2 idx(0x05)`, the register value in `R2 idx` will not be used in the EX stage.
* `R1 idx(0x04)` and `R2 idx(0x05)` will be sent to Registers for extracting the register value which are both `0x00000000`, because the initial value is `0x00000000`.
* Current PC value `(0x00000000)`, next PC value `(0x00000004)` and `Wr idx (0x08)` are just sent through this stage, we don't use them.
##### 3. Execute (EX)

* First level multiplexers choose value come from `Reg 1` and `Reg 2`, but this is an I-type format instruction, we don't use `Reg 2`. So they are filtered by second level multiplexer.
* The second-level multiplexers choose the value from `Reg 1` (rather than the current PC value, the upper input) as `Op1`, and the immediate (the lower input) as `Op2` of the ALU, for executing the shift-left instruction.
* The ALU shifts `Op1` left by `Op2`, so `Res` is equal to `0x00000000` (0 << 5 is still 0).
* `Reg 1` and `Reg 2` are also sent to the branch block, but no branch is taken.
* Next PC value `(0x00000004)` and `Wr idx (0x08)` are just sent through this stage; we don't use them.
##### 4. Memory access (MEM)

* `Res` from the ALU is sent three ways:
* Pass through this stage and go to WB stage (the lower line)
* Send back to EX stage for next instruction to use (the upper line)
* Used as the data memory address (the middle line). The memory reads data at address `0x00000000`, so `Read out` is equal to `0x00521413`. The table below denotes the data section of memory.
* 
* Meanwhile, `Reg 2` is sent to `Data in`, but the memory write is not enabled.
* Next PC value `(0x00000004)` and `Wr idx (0x08)` are just sent through this stage; we don't use them.
##### 5. Register write back (WB)

* The multiplexer chooses `Res` from the ALU (the middle line) as the final output, so the output value is `0x00000000`.
* The output value and `Wr idx` are sent back to the registers block in the ID stage, and `Wr En` is 1. Finally, the value `0x00000000` will be written into the `x8` register, whose ABI name is `s0`.
After all these stages are done, the register file is updated like this:

Finally, all the source code mentioned above can be found [Here](https://github.com/Shaoen-Lin/ca2025-quizzes). Feel free to check it out !
## Reference
* [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
* [Assignment 1: RISC-V Assembly and Instruction Pipeline](https://hackmd.io/@sysprog/2025-arch-homework1)
* [LeetCode 260. Single Number III](https://leetcode.com/problems/single-number-iii/description/)