# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by < [JimmyCh1025](https://github.com/JimmyCh1025) >
[TOC]
###### tags: `RISC-V` `Computer Architecture 2025`
## Problem B
### UF8
UF8 implements a logarithmic 8-bit codec that is suitable for representing level-of-detail (LOD) distances and fog density values, but is not appropriate for financial calculations. It maps 20-bit unsigned integers to 8-bit symbols using logarithmic quantization, delivering 2.5:1 compression and a relative error of ≤6.25%.
UF8 format:
| exponent | mantissa |
|--------------|---------------|
| 4 bits | 4 bits |
Decoding
\begin{gather*}
D(b) = m \cdot 2^e + (2^e - 1) \cdot 16
\end{gather*}
* Here $e$ is the exponent (upper 4 bits) and $m$ is the mantissa (lower 4 bits). Both are at most 15, so the largest scaled mantissa is 15 × 2¹⁵ = 491,520. Adding the offset brings the largest decodable value close to the 20-bit maximum: 15 × 2¹⁵ + (2¹⁵ - 1) × 16 = 1,015,792 ≈ 1,048,575 = 2²⁰ - 1.
* The offset also keeps the value ranges of adjacent exponent groups from overlapping.
    * Without the offset:
        * e = 0 → [0 × 2⁰, 15 × 2⁰] = [0, 15]
        * e = 1 → [0 × 2¹, 15 × 2¹] = [0, 30] → overlaps with e = 0
    * With the offset:
        * e = 0 → [0, 15]
        * e = 1 → [16, 46]
Encoding
\begin{gather*}
E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases}
\end{gather*}
* For input values less than 16, UF8 encoding is lossless: E(v) = v, because the 4-bit mantissa represents these values directly, without any exponent shift.
* Here $\text{offset}(e) = (2^e - 1) \cdot 16$, the same offset that appears in the decoding formula.
* If the position of the most significant set bit (msb) is at least 5, the exponent is first estimated as msb - 4, which leaves 4 bits for the mantissa. The estimate is then corrected by comparing the value against the offset.
* UF8 only covers values up to about 1 million (4-bit exponent and 4-bit mantissa), so the encoder caps the exponent at 15 for larger inputs.
* All other values are encoded by finding the suitable exponent, subtracting that exponent's offset, and storing the remainder, shifted right by the exponent, in the mantissa field.
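As a worked example of the encoding rule, take $v = 1000$ (a value chosen only for illustration) and use the offset $(2^e - 1) \cdot 16$ from the decoding formula:

\begin{gather*}
(2^5 - 1) \cdot 16 = 496 \le 1000 < 1008 = (2^6 - 1) \cdot 16 \;\Rightarrow\; e = 5 \\
m = \lfloor (1000 - 496) / 2^5 \rfloor = \lfloor 15.75 \rfloor = 15 \\
E(1000) = 16 \cdot 5 + 15 = 95 = \texttt{0x5F} \\
D(\texttt{0x5F}) = 15 \cdot 2^5 + 496 = 976, \qquad \frac{1000 - 976}{1000} = 2.4\%
\end{gather*}

More generally, for inputs within the representable range the discarded remainder is at most $2^e - 1$ while $v \ge (2^e - 1) \cdot 16$, so the relative error is at most $1/16 = 6.25\%$.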
### C code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

typedef uint8_t uf8;

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;

    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;

    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;

    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;

        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }

    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }

    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}

/* Test encode/decode round-trip */
static bool test(void)
{
    int32_t previous_value = -1;
    bool passed = true;

    for (int i = 0; i < 256; i++) {
        uint8_t fl = i;
        int32_t value = uf8_decode(fl);
        uint8_t fl2 = uf8_encode(value);

        if (fl != fl2) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl,
                   value, fl2);
            passed = false;
        }

        if (value <= previous_value) {
            printf("%02x: value %d <= previous_value %d\n", fl, value,
                   previous_value);
            passed = false;
        }

        previous_value = value;
    }

    return passed;
}

int main(void)
{
    if (test()) {
        printf("All tests passed.\n");
        return 0;
    }
    return 1;
}
```
### Assembly code
:::spoiler More detailed information
```assembly=
#=======================================================================================
# File : uf8.s
# Author : Jimmy Chen
# Date : 2025-10-08
# Brief: Implements encoding and decoding for the custom 8-bit UF8 format.
#=======================================================================================
.data
# string
all_tests_passed_str:
.string "All tests passed.\n"
mismatch_prod_val_str:
.string ": produces value "
mismatch_encode_str:
.string " but encodes back to "
not_incr_val_str:
.string ": value "
not_incr_prev_val_str:
.string " <= previous_value "
endline_str:
.string "\n"
.text
.global main
# =======================================================
# Function : main()
# Parameter : none
# Variable :
# Description : execute test function and print pass if test return true.
# Return : 0 (exit program)
# =======================================================
main:
jal ra, test # call test
beq a0, x0, main_return1 # return 1
# printf "All tests passed.\n"
la a0, all_tests_passed_str
li a7, 4 # System call number 4 (print string)
ecall
# return 0
li a7, 10
add a0, x0, x0 # a0 = 0 (a0 still held the string address)
ecall
main_return1:
# return 1
li a7, 10
addi a0, a0, 1
ecall
# =======================================================
# Function : test()
# Parameter : none
# Variable :
# Description : Performs round-trip tests on UF8 encoding and decoding to ensure correctness.
# Return : 1(true) or 0(false)
# =======================================================
test:
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp)
sw s1, 4(sp)
sw s2, 0(sp)
addi s0, x0, -1 # previous_value = -1
addi s1, x0, 0 # i = 0
addi s2, x0, 1 # passed = true
test_loop:
# if i > 255, break
li s3, 0xFF
blt s3, s1, test_done
# uint8_t fl = i
andi s3, s1, 0xFF
addi a0, s3, 0
# value = uf8_decode(fl)
jal ra, uf8_decode
addi s4, a0, 0
# fl2 = uf8_encode(value)
addi a0, s4, 0
jal ra, uf8_encode
andi s5, a0, 0xFF
# if (fl != fl2)
bne s3, s5, mismatch
# if (value <= previous_value)
ble s4, s0, not_increasing
# previous_value = value
mv s0, s4
addi s1, s1, 1 # i++
j test_loop
mismatch:
# printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2);
addi a0, s3, 0
li a7, 1
ecall
la a0, mismatch_prod_val_str
li a7, 4
ecall
addi a0, s4, 0
li a7, 1
ecall
la a0, mismatch_encode_str
li a7, 4
ecall
addi a0, s5, 0
li a7, 1
ecall
la a0, endline_str
li a7, 4
ecall
add s2, x0, x0 # passed = false
j continue_loop
not_increasing:
# printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value);
addi a0, s3, 0 # fl
li a7, 1
ecall
la a0, not_incr_val_str
li a7, 4
ecall
addi a0, s4, 0 # value
li a7, 1
ecall
la a0, not_incr_prev_val_str
li a7, 4
ecall
addi a0, s0, 0 # previous_value
li a7, 1
ecall
la a0, endline_str
li a7, 4
ecall
add s2, x0, x0 # passed = false
continue_loop:
# previous_value = value
addi s0, s4, 0 # value is kept in s4; t0 is clobbered by the calls above
# i++
addi s1, s1, 1
j test_loop
test_done:
# return passed
addi a0, s2, 0
lw s2, 0(sp)
lw s1, 4(sp)
lw s0, 8(sp)
lw ra, 12(sp)
addi sp, sp, 16
jalr x0, x1, 0
# =======================================================
# Function : uf8_decode()
# Parameter : uf8 fl
# Variable :
# Description : Decodes an 8-bit UF8 value into a 32-bit unsigned integer.
# Return : uint32_t uf8_decode value
# =======================================================
uf8_decode:
# u32 mantissa = fl & 0x0F
andi t0, a0, 0x0F
# u8 exponent = fl >> 4
srli t1, a0, 4
andi t1, t1, 0xFF
# u32 offset = (0x7FFF >> (15-exponent)) << 4
li t2, 15
li t3, 0x7FFF
sub t2, t2, t1 # (15-exponent)
srl t3, t3, t2 # (0x7FFF >> (15-exponent))
slli t4, t3, 4 # (0x7FFF >> (15-exponent)) << 4
# return (mantissa << exponent) + offset;
sll a0, t0, t1
add a0, a0, t4
jalr x0, x1, 0
# =======================================================
# Function : uf8_encode()
# Parameter : uint32_t value
# Variable :
# Description : Encodes a 32-bit unsigned integer into an 8-bit UF8 representation.
# Return : uf8 encoded value
# =======================================================
uf8_encode:
li t0, 16
bltu a0, t0, uf8_encode_return_val # value is uint32_t, so compare unsigned
# assign t0 = value
addi t0, a0, 0
# call clz
addi sp, sp, -8
sw ra, 4(sp)
sw t0, 0(sp)
jal ra, clz
lw t0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
# lz = clz(value)
add t1, a0, x0
# msb = 31 - lz
li t2, 31
sub t2, t2, t1
# exponent = 0
li t3, 0
# overflow = 0
li t4, 0
# if msb < 5
li t5, 5
blt t2, t5, uf8_find_extra_exp
uf8_encode_bge5:
# exponent = msb - 4
addi t3, t2, -4
andi t3, t3, 0xFF
li t5, 15
# e = 0
li t6, 0
# if 15 < exp
blt t5, t3, uf8_exp_bg15
j uf8_calculate_overflow
uf8_exp_bg15:
# exp = 15
li t3, 15
uf8_calculate_overflow:
# e < exp
bge t6, t3, uf8_adjust_if_off
# (overflow << 1)+16
slli t4, t4, 1
addi t4, t4, 16
# e++
addi t6, t6, 1
j uf8_calculate_overflow
uf8_adjust_if_off:
# 0 >= exponent
bge x0, t3, uf8_find_extra_exp
# value >= overflow (unsigned compare, both are uint32_t)
bgeu t0, t4, uf8_find_extra_exp
# overflow = (overflow-16) >> 1
addi t4, t4, -16
srli t4, t4, 1
# exp--
addi t3, t3, -1
j uf8_adjust_if_off
uf8_find_extra_exp:
# while exp < 15
li t5, 15
# if exp >= 15, return value
bge t3, t5, uf8_encode_return
# next_overflow = (overflow << 1)+16
slli t5, t4, 1
addi t5, t5, 16
# value < next_overflow (unsigned compare, both are uint32_t)
bltu t0, t5, uf8_encode_return
# overflow = next_overflow
add t4, t5, x0
# exp++
addi t3, t3, 1
j uf8_find_extra_exp
uf8_encode_return:
# (value - overflow) >> exponent
sub t5, t0, t4
srl t5, t5, t3
andi t5, t5, 0xFF
# (exponent << 4)|mantissa
slli t1, t3, 4
andi a0, t1, 0xFF
or a0, a0, t5
jalr x0, x1, 0
uf8_encode_return_val:
# return value(a0)
jalr x0, x1, 0
# =======================================================
# Function : clz()
# Parameter : uint32_t x
# Variable :
# Description : Counts the number of leading zeros in a 32-bit unsigned integer.
# Return : unsigned val
# =======================================================
clz:
li t0, 32 # t0 = n
li t1, 16 # t1 = c
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # x = n - x
jalr x0, x1, 0 # return x
```
:::
### Result

----
## Problem C
### Float
| sign | exponent | mantissa |
| -------- | -------- | -------- |
| 1 bit | 8 bits | 23 bits |
### BF16
| sign | exponent | mantissa |
| -------- | -------- | -------- |
| 1 bit | 8 bits | 7 bits |
* Normalized: $±(1.mantissa) × 2^{exponent-127}$
* Denormalized (subnormal): $±(0.mantissa) × 2^{-126}$
* NaN: the exponent bits are all 1 and the mantissa bits are not all 0.
* Inf: the exponent bits are all 1 and the mantissa bits are all 0.
* Zero: the exponent bits are all 0 and the mantissa bits are all 0.
* F32toB16:
    * The line `f32bits += ((f32bits >> 16) & 1) + 0x7FFF` implements round to nearest, ties to even. It adds 0x7FFF as the rounding bias, plus the value of bit 16, which is the least significant bit that survives the truncation (the 17th bit when counting from 1). On an exact tie (discarded bits equal to 0x8000) the sum carries into the kept bits only when bit 16 is 1, so ties round to the even value.
### C code
```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

typedef struct {
    uint16_t bits;
} bf16_t;

#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127

#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})

static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}

static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}

static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}

static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}

static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}

static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;

    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }

    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;
        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }
        if (!result_mant)
            return BF16_ZERO();
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }

    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (result_mant & 0x7F),
    };
}

static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}

static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }

    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}

static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;
    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    if (!exp_b && !mant_b) {
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};

    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;

    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }

    int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
    if (!exp_a)
        result_exp--;
    if (!exp_b)
        result_exp++;

    if (quotient & 0x8000)
        quotient >>= 8;
    else {
        while (!(quotient & 0x8000) && result_exp > 1) {
            quotient <<= 1;
            result_exp--;
        }
        quotient >>= 8;
    }
    quotient &= 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (quotient & 0x7F)};
}

static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;

    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a;              /* sqrt(+Inf) = +Inf */
    }

    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();

    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();

    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();

    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */
    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;     /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;   /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128; /* Default */

    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */
        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */
    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;

    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();

    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```
### Assembly code
:::spoiler More detailed information
```assembly=
#=======================================================================================
# File : bfloat16.s
# Author : Jimmy Chen
# Date : 2025-10-08
# Brief : Implementation of bfloat16 arithmetic operations (add, sub, mul, div, sqrt)
# including conversion between float and bfloat16, with IEEE 754 support.
#=======================================================================================
.data
# define
.equ BF16_SIGN_MASK, 0x8000
.equ BF16_EXP_MASK, 0x7F80
.equ BF16_MANT_MASK, 0x007F
.equ BF16_EXP_BIAS, 127
.equ BF16_NAN, 0x7FC0
.equ BF16_ZERO, 0x0000
# input data
input_a:
.half 0x0000, 0x8000, 0x3f80, 0x4000, 0x4040, 0xbf80, 0x7f80, 0xff80, 0x7fc1, 0x4110, 0xc080, 0x0001, 0x0a4b, 0x40a0, 0x7f00
input_b:
.half 0x0000, 0x0000, 0x4000, 0x3f80, 0x4000, 0x3f80, 0x3f80, 0x7f80, 0x3f80, 0x0000, 0x0000, 0x0001, 0x0a4b, 0x0000, 0x7f00
input_float:
.word 0x3F800000, 0x3FC00000, 0x3F81AE14, 0x00000000, 0x80000000, 0x477FE000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x00000001, 0xC0200000, 0x40490FDB, 0xC2F6E979, 0x3EAAAAAB, 0x2F06C6D6
# ans
isNanAns:
.half 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
isInfAns:
.half 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
isZeroAns:
.half 0x0001, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
f32tob16Ans:
.half 0x3f80, 0x3fc0, 0x3f82, 0x0000, 0x8000, 0x4780, 0x7f80, 0xff80, 0x7fc0, 0x0000, 0xc020, 0x4049, 0xc2f7, 0x3eab, 0x2f07
b16tof32Ans:
.word 0x00000000, 0x80000000, 0x3f800000, 0x40000000, 0x40400000, 0xbf800000, 0x7f800000, 0xff800000, 0x7fc10000, 0x41100000, 0xc0800000, 0x00010000, 0x0a4b0000, 0x40a00000, 0x7f000000
addAns:
.half 0x0000, 0x0000, 0x4040, 0x4040, 0x40a0, 0x0000, 0x7f80, 0x7fc0, 0x7fc1, 0x4110, 0xc080, 0x0002, 0x0acb, 0x40a0, 0x7f80
subAns:
.half 0x8000, 0x8000, 0xbf80, 0x3f80, 0x3f80, 0xc000, 0x7f80, 0xff80, 0x7fc1, 0x4110, 0xc080, 0x0000, 0x0000, 0x40a0, 0x0000
mulAns:
.half 0x0000, 0x8000, 0x4000, 0x4000, 0x40c0, 0xbf80, 0x7f80, 0xff80, 0x7fc1, 0x0000, 0x8000, 0x0000, 0x0000, 0x0000, 0x7f80
divAns:
.half 0x7fc0, 0x7fc0, 0x3f00, 0x4000, 0x3fc0, 0xbf80, 0x7f80, 0x7fc0, 0x7fc1, 0x7f80, 0xff80, 0x3f80, 0x3f80, 0x7f80, 0x3f80
sqrtAns:
.half 0x0000, 0x0000, 0x3f80, 0x3fb5, 0x3fdd, 0x7fc0, 0x7f80, 0x7fc0, 0x7fc1, 0x4040, 0x7fc0, 0x0000, 0x24e4, 0x400f, 0x5f35
# string
msg_test_nan:
.string "======Test NAN======\n"
msg_test_inf:
.string "======Test INF======\n"
msg_test_zero:
.string "======Test ZERO======\n"
msg_test_f32tob16:
.string "======Test F32ToB16======\n"
msg_test_b16tof32:
.string "======Test B16ToF32======\n"
msg_test_add:
.string "======Test ADD======\n"
msg_test_sub:
.string "======Test SUB======\n"
msg_test_mul:
.string "======Test MUL======\n"
msg_test_div:
.string "======Test DIV======\n"
msg_test_sqrt:
.string "======Test SQRT======\n"
pass_str:
.string " => Pass\n"
fail_str:
.string " => Fail\n"
output_str:
.string "Output = "
answer_str:
.string ",Answer = "
.text
.global main
# =======================================================
# Function : main()
# Parameter : none
# Variable : i = s0, boundary = 15
# Description : execute all function and print the result
# Return : 0 (exit program)
# =======================================================
main:
# i = 0, boundary = 15
add s0, x0, x0
j main_for
main_for:
addi t0, x0, 15
# if i >= 15, return exit
bge s0, t0, main_exit
j main_for_run_nan
#======================nan=====================
main_for_run_nan:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i] as the argument of bf16_isnan()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
# call bf16_isnan
jal ra, bf16_isnan
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_nan
main_for_printf_nan:
# print test nan
la a0, msg_test_nan
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, isNanAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_nan_pass
j main_for_nan_fail
main_for_nan_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_inf
main_for_nan_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_inf
#======================inf=====================
main_for_run_inf:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i] as the argument of bf16_isinf()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
# call bf16_isinf
jal ra, bf16_isinf
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_inf
main_for_printf_inf:
# print test inf
la a0, msg_test_inf
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, isInfAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_inf_pass
j main_for_inf_fail
main_for_inf_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_zero
main_for_inf_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_zero
#======================zero=====================
main_for_run_zero:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i] to bf16_iszero()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
# call bf16_iszero
jal ra, bf16_iszero
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_zero
main_for_printf_zero:
# print test zero
la a0, msg_test_zero
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, isZeroAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_zero_pass
j main_for_zero_fail
main_for_zero_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_f32tob16
main_for_zero_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_f32tob16
#======================f32tob16=====================
main_for_run_f32tob16:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load test_float[i] to f32_to_bf16()
slli t0, s0, 2
la t1, input_float
add t1, t1, t0
lw a0, 0(t1)
# call f32_to_bf16
jal ra, f32_to_bf16
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_f32tob16
main_for_printf_f32tob16:
# print test f32tob16
la a0, msg_test_f32tob16
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, f32tob16Ans
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_f32tob16_pass
j main_for_f32tob16_fail
main_for_f32tob16_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_b16tof32
main_for_f32tob16_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_b16tof32
#======================b16tof32=====================
main_for_run_b16tof32:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i] to bf16_to_f32()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
# call bf16_to_f32
jal ra, bf16_to_f32
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_b16tof32
main_for_printf_b16tof32:
# print test b16tof32
la a0, msg_test_b16tof32
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
add a0, s3, x0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, b16tof32Ans
slli t0, s0, 2
add t0, s3, t0
lw a0, 0(t0)
li a7, 34
# for compare
add t6, a0, x0
ecall
beq t5, t6, main_for_b16tof32_pass
j main_for_b16tof32_fail
main_for_b16tof32_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_add
main_for_b16tof32_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_add
#======================add=====================
main_for_run_add:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i], b[i] to bf16_add()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
la t2, input_b
add t2, t2, t0
lhu a1, 0(t2) # input_b holds 16-bit entries as well
# call bf16_add
jal ra, bf16_add
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_add
main_for_printf_add:
# print test add
la a0, msg_test_add
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, addAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_add_pass
j main_for_add_fail
main_for_add_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_sub
main_for_add_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_sub
#======================sub=====================
main_for_run_sub:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i], b[i] to bf16_sub()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
la t2, input_b
add t2, t2, t0
lhu a1, 0(t2) # input_b holds 16-bit entries as well
# call bf16_sub
jal ra, bf16_sub
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_sub
main_for_printf_sub:
# print test sub
la a0, msg_test_sub
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, subAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_sub_pass
j main_for_sub_fail
main_for_sub_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_mul
main_for_sub_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_mul
#======================mul=====================
main_for_run_mul:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i], b[i] to bf16_mul()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
la t2, input_b
add t2, t2, t0
lhu a1, 0(t2) # input_b holds 16-bit entries as well
# call bf16_mul
jal ra, bf16_mul
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_mul
main_for_printf_mul:
# print test mul
la a0, msg_test_mul
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, mulAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_mul_pass
j main_for_mul_fail
main_for_mul_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_div
main_for_mul_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_div
#======================div=====================
main_for_run_div:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i], b[i] to bf16_div()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
la t2, input_b
add t2, t2, t0
lhu a1, 0(t2) # input_b holds 16-bit entries as well
# call bf16_div
jal ra, bf16_div
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_div
main_for_printf_div:
# print test div
la a0, msg_test_div
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, divAns
slli t0, s0, 1
add t0, s3, t0
lhu a0, 0(t0) # answer table holds 16-bit entries
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_div_pass
j main_for_div_fail
main_for_div_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
j main_for_run_sqrt
main_for_div_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
j main_for_run_sqrt
#======================sqrt=====================
main_for_run_sqrt:
addi sp, sp, -8
sw ra, 4(sp)
sw s0, 0(sp)
# load a[i] to bf16_sqrt()
slli t0, s0, 1
la t1, input_a
add t1, t1, t0
lhu a0, 0(t1) # input_a holds 16-bit entries, so load a halfword
# call bf16_sqrt
jal ra, bf16_sqrt
# value of return stores in s3
add s3, a0, x0
lw s0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
j main_for_printf_sqrt
main_for_printf_sqrt:
# print test sqrt
la a0, msg_test_sqrt
li a7, 4
ecall
# print output string
la a0, output_str
ecall
# print output value
li t0, 0xFFFF
and a0, s3, t0
li a7, 34
# for compare
add t5, a0, x0
ecall
# print answer string
la a0, answer_str
li a7, 4
ecall
# print ans value
la s3, sqrtAns
slli t0, s0, 1
add t0, s3, t0
lw a0, 0(t0)
li a7, 34
# for compare
add t6, a0, x0
li t4, 0xffff
and a0, t6, t4
and t6, t6, t4
ecall
beq t5, t6, main_for_sqrt_pass
j main_for_sqrt_fail
main_for_sqrt_pass:
# print pass
la a0, pass_str
li a7, 4
ecall
# ++i
addi s0, s0, 1
j main_for
main_for_sqrt_fail:
# print fail
la a0, fail_str
li a7, 4
ecall
# ++i
addi s0, s0, 1
j main_for
#======================main exit=====================
main_exit:
li a7, 10
ecall
# =======================================================
# Function : bf16_isnan()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if a is NaN; otherwise, returns false
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_isnan:
# t1 = (a.bits & BF16_EXP_MASK)
li t0, BF16_EXP_MASK
and t1, a0, t0
# if (a.bits & BF16_EXP_MASK) == BF16_EXP_MASK
bne t1, t0, bf16_isnan_ret0
# if (a.bits & BF16_MANT_MASK) == 0
li t0, BF16_MANT_MASK
and t1, a0, t0
beq t1, x0, bf16_isnan_ret0
# return 1
addi a0, x0, 1
jalr x0, ra, 0
bf16_isnan_ret0:
# return 0
add a0, x0, x0
jalr x0, ra, 0
# =======================================================
# Function : bf16_isinf()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if the input is +Infinity or -Infinity.
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_isinf:
# t1 = (a.bits & BF16_EXP_MASK)
li t0, BF16_EXP_MASK
and t1, a0, t0
# if (a.bits & BF16_EXP_MASK) == BF16_EXP_MASK
bne t1, t0, bf16_isinf_ret0
# if (a.bits & BF16_MANT_MASK) != 0, it is NaN rather than Inf, return 0
li t0, BF16_MANT_MASK
and t1, a0, t0
bne t1, x0, bf16_isinf_ret0
# return 1
addi a0, x0, 1
jalr x0, ra, 0
bf16_isinf_ret0:
# return 0
add a0, x0, x0
jalr x0, ra, 0
# =======================================================
# Function : bf16_iszero()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if the input is positive or negative zero.
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_iszero:
# t1 = (a.bits & 0x7FFF)
li t0, 0x7FFF
and t1, a0, t0
bne t1, x0, bf16_iszero_ret0
# return 1
addi a0, x0, 1
jalr x0, ra, 0
bf16_iszero_ret0:
# return 0
add a0, x0, x0
jalr x0, ra, 0
# =======================================================
# Function : f32_to_bf16()
# Parameter : float val
# Variable :
# Description : Convert a 32-bit float to 16-bit bfloat16 by keeping the upper 16 bits.
# Return : bf16_t value(a0)
# =======================================================
# test ok
f32_to_bf16:
# u32 t0 = f32bits
# memcpy(&f32bits, &val, sizeof(float));
add t0, a0, x0
# if (((f32bits >> 23) & 0xFF) == 0xFF)
srli t1, t0, 23
andi t2, t1, 0xFF
li t3, 0xFF
bne t2, t3 , f32_to_bf16_ret_exp_not_allOne
# return all exp 1
# (f32bits >> 16)& 0xFFFF
srli t1, t0, 16
li t2, 0xFFFF
and a0, t1, t2
# return all exp 1
jalr x0, ra, 0
f32_to_bf16_ret_exp_not_allOne:
# return exp not all 1
# f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
srli t1, t0, 16
andi t2, t1, 1
li t3, 0x7FFF
add t4, t2, t3
add t0, t0, t4
# return (bf16_t)f32bits >> 16
srli a0, t0, 16
li t1, 0xFFFF
and a0, a0, t1
jalr x0, ra, 0
# =======================================================
# Function : bf16_to_f32()
# Parameter : bf16_t a
# Variable :
# Description : Converts a bfloat16 value to 32-bit float by zero-extending the lower bits.
# Return : float value(a0)
# =======================================================
# test ok
bf16_to_f32:
# u32 f32bits = ((u32) val.bits) << 16;
slli t0, a0, 16
# memcpy(&result, &f32bits, sizeof(float))
add a0, t0, x0
# return result
jalr x0, ra, 0
# =======================================================
# Function : bf16_add()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 addition with proper handling of special cases (NaN, Inf, zero).
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_add:
addi sp, sp, -24
sw s6, 20(sp)
sw s5, 16(sp)
sw s4, 12(sp)
sw s3, 8(sp)
sw s2, 4(sp)
sw s1, 0(sp)
# sign_a = (a.bits >> 15) & 1
srli s1, a0, 15
andi s1, s1, 1
# sign_b = (b.bits >> 15) & 1;
srli s2, a1, 15
andi s2, s2, 1
# exp_a = ((a.bits >> 7) & 0xFF)
srli s3, a0, 7
andi s3, s3, 0xFF
# exp_b = ((b.bits >> 7) & 0xFF)
srli s4, a1, 7
andi s4, s4, 0xFF
# mant_a = a.bits & 0x7F
andi s5, a0, 0x7F
# mant_b = b.bits & 0x7F
andi s6, a1, 0x7F
# if exp_a == 0xFF
li t0, 0xFF
beq s3, t0, bf16_add_exp_a_allOne
# if exp_b == 0xFF
beq s4, t0, bf16_add_ret_b
# if (!exp_a && !mant_a) <=> exp and mant = 0
or t6, s3, s5
beq t6, x0, bf16_add_ret_b
# if (!exp_b && !mant_b) <=> exp and mant = 0
or t6, s4, s6
beq t6, x0, bf16_add_ret_a
# if (exp_a)
bne s3, x0, bf16_add_mant_a_or0x80
# if (exp_b)
bne s4, x0, bf16_add_mant_b_or0x80
j bf16_add_dif
bf16_add_mant_a_or0x80:
ori s5, s5, 0x80
bf16_add_exp_b_not0:
# if (exp_b)
bne s4, x0, bf16_add_mant_b_or0x80
j bf16_add_dif
bf16_add_mant_b_or0x80:
ori s6, s6, 0x80
bf16_add_dif:
# maybe some error
# exp_diff = exp_a - exp_b;
sub t0, s3, s4
# if (exp_diff > 0)
blt x0, t0, bf16_add_exp_dif_bgt0
# if (exp_diff < 0)
blt t0, x0, bf16_add_exp_dif_blt0
# if (exp_diff == 0)
beq x0, t0, bf16_add_exp_dif_beq0
# impossible
j bf16_add_check_sign
bf16_add_exp_dif_bgt0:
# result_exp = exp_a
add t2, s3, x0
# if (exp_diff > 8)
li t6, 8
blt t6, t0, bf16_add_ret_a
# mant_b >>= exp_diff
srl s6, s6, t0
# jump to if (sign_a == sign_b)
j bf16_add_check_sign
bf16_add_exp_dif_blt0:
# result_exp = exp_b
add t2, s4, x0
# if (exp_diff < -8)
li t6, -8
blt t0, t6, bf16_add_ret_b
# mant_a >>= -exp_diff
sub t6, x0, t0
srl s5, s5, t6
# jump to if (sign_a == sign_b)
j bf16_add_check_sign
bf16_add_exp_dif_beq0:
# result_exp = exp_a
add t2, s3, x0
bf16_add_check_sign:
# if (sign_a == sign_b) , eq jump to bf16_add_check_sign_eq
beq s1, s2, bf16_add_check_sign_eq
# else
# if (mant_a >= mant_b), true jump to gn
bge s5, s6, bf16_add_check_mant_gn
# else <
# result_sign = sign_b
add t1, s2, x0
# result_mant = mant_b - mant_a
sub t3, s6, s5
# jump bf16_add_check_result_mant
j bf16_add_check_result_mant
bf16_add_check_mant_gn:
# result_sign = sign_a
add t1, s1, x0
# result_mant = mant_a - mant_b
sub t3, s5, s6
bf16_add_check_result_mant:
# if (!result_mant)
beq t3, x0, bf16_add_ret0
bf16_add_check_result_mant_while:
# while (!(result_mant & 0x80))
andi t5, t3, 0x80
bne t5, x0, bf16_add_ret
# result_mant <<= 1
slli t3, t3, 1
# if (--result_exp <= 0)
addi t2, t2, -1
bge x0, t2, bf16_add_ret0
j bf16_add_check_result_mant_while
bf16_add_check_sign_eq:
# result_sign = sign_a
add t1, s1, x0
# result_mant = (uint32_t) mant_a + mant_b
add t3, s5, s6
# if (result_mant & 0x100), eq 0 jump to return
andi t6, t3, 0x100
beq t6, x0, bf16_add_ret
# result_mant >>= 1
srli t3, t3, 1
# if(++result_exp >= 0xFF)
# ++result_exp
addi t2, t2, 1
# (++result_exp >= 0xFF), if result_exp < 0xFF, jump return
li t6, 0xFF
blt t2, t6, bf16_add_ret
# else
# ((return result_sign << 15) | 0x7F80)
li t6, 0x7F80
slli t1, t1, 15
or a0, t1, t6
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_add_exp_a_allOne:
# if mant_a != 0
bne s5, x0, bf16_add_ret_a
# if (exp_b == 0xFF)
beq s4, t0, bf16_add_exp_b_allOne
j bf16_add_ret_a
bf16_add_exp_b_allOne:
# return (mant_b || sign_a == sign_b) ? b : BF16_NAN()
# if mant_b != 0 (b is NaN), return b
bne s6, x0, bf16_add_ret_b
# t0 = sign_a ^ sign_b, zero when the signs are equal
xor t0, s1, s2
# if sign_a == sign_b (Inf + Inf with the same sign), return b
beq t0, x0, bf16_add_ret_b
# otherwise (+Inf) + (-Inf): return BF16_NAN
li a0, BF16_NAN
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_add_ret_a:
# return a
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_add_ret_b:
# return b
add a0, a1, x0
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_add_ret0:
li a0, BF16_ZERO
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_add_ret:
# (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)
# (result_sign << 15)
slli a0, t1, 15
# ((result_exp & 0xFF) << 7)
andi t5, t2, 0xFF
slli t5, t5, 7
# (result_mant & 0x7F)
andi t6, t3, 0x7F
or a0, a0, t5
or a0, a0, t6
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
# =======================================================
# Function : bf16_sub()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 subtraction by flipping the sign of the second operand and adding.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_sub:
addi sp, sp, -4
sw ra, 0(sp)
# b.bits ^= BF16_SIGN_MASK
li t0, BF16_SIGN_MASK
xor a1, a1, t0
# call bf16_add
jal ra, bf16_add
lw ra, 0(sp)
addi sp, sp, 4
jalr x0, ra, 0
# =======================================================
# Function : bf16_mul()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 multiplication with normalization and special case handling.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_mul:
addi sp, sp, -24
sw s6, 20(sp)
sw s5, 16(sp)
sw s4, 12(sp)
sw s3, 8(sp)
sw s2, 4(sp)
sw s1, 0(sp)
# sign_a = (a.bits >> 15) & 1
srli s1, a0, 15
andi s1, s1, 1
# sign_b = (b.bits >> 15) & 1;
srli s2, a1, 15
andi s2, s2, 1
# exp_a = ((a.bits >> 7) & 0xFF)
srli s3, a0, 7
andi s3, s3, 0xFF
# exp_b = ((b.bits >> 7) & 0xFF)
srli s4, a1, 7
andi s4, s4, 0xFF
# mant_a = a.bits & 0x7F
andi s5, a0, 0x7F
# mant_b = b.bits & 0x7F
andi s6, a1, 0x7F
# result_sign = sign_a ^ sign_b
xor t1, s1, s2
# if (exp_a == 0xFF), exp_a all one jump to bf16_mul_a_exp_allOne
li t6, 0xFF
beq s3, t6, bf16_mul_a_exp_allOne
# if (exp_b == 0xFF), exp_b all one jump to bf16_mul_b_exp_allOne
beq s4, t6, bf16_mul_b_exp_allOne
# if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
or t5, s3, s5
or t6, s4, s6
beq t5, x0, bf16_mul_retSign_slli15
beq t6, x0, bf16_mul_retSign_slli15
# exp_adjust = 0
add t4, x0, x0
# if (!exp_a), a exp is zero, jump bf16_mul_a_exp_zero
beq s3, x0, bf16_mul_a_exp_zero
# mant_a |= 0x80
ori s5, s5, 0x80
# if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
beq s4, x0, bf16_mul_b_exp_zero
# mant_b |= 0x80
ori s6, s6, 0x80
j bf16_mul_result_exp_mant
bf16_mul_a_exp_zero:
# (mant_a & 0x80)
andi t6, s5, 0x80
# while (!(mant_a & 0x80))
beq t6, x0, bf16_mul_a_exp_zero_while
# exp_a = 1
addi s3, x0, 1
# if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
beq s4, x0, bf16_mul_b_exp_zero
# mant_b |= 0x80
ori s6, s6, 0x80
j bf16_mul_result_exp_mant
bf16_mul_a_exp_zero_while:
# mant_a <<= 1
slli s5, s5, 1
# exp_adjust--
addi t4, t4, -1
# (mant_a & 0x80)
andi t6, s5, 0x80
# while (!(mant_a & 0x80))
beq t6, x0, bf16_mul_a_exp_zero_while
# exp_a = 1
addi s3, x0, 1
# if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
beq s4, x0, bf16_mul_b_exp_zero
# mant_b |= 0x80
ori s6, s6, 0x80
j bf16_mul_result_exp_mant
bf16_mul_b_exp_zero:
# (mant_b & 0x80)
andi t6, s6, 0x80
# while (!(mant_b & 0x80))
beq t6, x0, bf16_mul_b_exp_zero_while
# exp_b = 1
addi s4, x0, 1
j bf16_mul_result_exp_mant
bf16_mul_b_exp_zero_while:
# mant_b <<= 1
slli s6, s6, 1
# exp_adjust--
addi t4, t4, -1
# (mant_b & 0x80)
andi t6, s6, 0x80
# while (!(mant_b & 0x80))
beq t6, x0, bf16_mul_b_exp_zero_while
# exp_b = 1
addi s4, x0, 1
bf16_mul_result_exp_mant:
# result_mant = (uint32_t) mant_a * mant_b
mul t3, s5, s6
# result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust
li t6, BF16_EXP_BIAS
# result_exp = exp_a + exp_b
add t2, s3, s4
# result_exp = result_exp - BF16_EXP_BIAS
sub t2, t2, t6
# result_exp = result_exp + exp_adjust
add t2, t2, t4
# if (result_mant & 0x8000)
li t6, 0x8000
and t6, t3, t6
beq t6, x0, bf16_mul_set_result_mant_srl7
# result_mant = (result_mant >> 8) & 0x7F
srli t3, t3, 8
andi t3, t3, 0x7F
# result_exp++
addi t2, t2, 1
j bf16_mul_check_result_exp
bf16_mul_set_result_mant_srl7:
# result_mant = (result_mant >> 7) & 0x7F
srli t3, t3, 7
andi t3, t3, 0x7F
bf16_mul_check_result_exp:
# if (result_exp >= 0xFF)
li t6, 0xFF
bge t2, t6, bf16_mul_result_exp_bg0xFF_ret
# if (result_exp <= 0), exp > 0, jump ret
blt x0, t2, bf16_mul_ret
# if (result_exp < -6)
li t6, -6
blt t2, t6, bf16_mul_result_exp_smNeg6_ret
# result_mant >>= (1 - result_exp)
li t6, 1
sub t5, t6, t2
srl t3, t3, t5
# result_exp = 0
add t2, x0, x0
# return
j bf16_mul_ret
bf16_mul_result_exp_bg0xFF_ret:
# return ((result_sign << 15) | 0x7F80)
li t6, 0x7F80
slli t1, t1, 15
or a0, t1, t6
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_result_exp_smNeg6_ret:
# return (result_sign << 15)
slli a0, t1, 15
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_a_exp_allOne:
# if (mant_a), mant_a isn't zero, jump return a
bne s5, x0, bf16_mul_reta
# if (!exp_b && !mant_b) <=> b exp and mant equal 0
or t5, s4, s6
beq t5, x0, bf16_mul_retNan
# return ((result_sign << 15) | 0x7F80)
li t6, 0x7F80
slli t1, t1, 15
or a0, t1, t6
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_b_exp_allOne:
# if (mant_b), mant_b isn't zero, jump return b
bne s6, x0, bf16_mul_retb
# if (!exp_a && !mant_a) <=> a exp and mant equal 0
or t5, s3, s5
beq t5, x0, bf16_mul_retNan
# return ((result_sign << 15) | 0x7F80)
li t6, 0x7F80
slli t1, t1, 15
or a0, t1, t6
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_retSign_slli15:
# return result_sign << 15
slli a0, t1, 15
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_retNan:
# return BF16_NAN
li a0, BF16_NAN
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_reta:
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_retb:
# a0 = a1
add a0, a1, x0
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
bf16_mul_ret:
# return ((result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F))
slli t1, t1, 15
andi t2, t2, 0xFF
slli t2, t2, 7
andi t3, t3, 0x7F
or a0, t1, t2
or a0, a0, t3
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
# =======================================================
# Function : bf16_div()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 division using bit-level integer division and handles edge cases.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_div:
addi sp, sp, -24
sw s6, 20(sp)
sw s5, 16(sp)
sw s4, 12(sp)
sw s3, 8(sp)
sw s2, 4(sp)
sw s1, 0(sp)
# sign_a = (a.bits >> 15) & 1
srli s1, a0, 15
andi s1, s1, 1
# sign_b = (b.bits >> 15) & 1;
srli s2, a1, 15
andi s2, s2, 1
# exp_a = ((a.bits >> 7) & 0xFF)
srli s3, a0, 7
andi s3, s3, 0xFF
# exp_b = ((b.bits >> 7) & 0xFF)
srli s4, a1, 7
andi s4, s4, 0xFF
# mant_a = a.bits & 0x7F
andi s5, a0, 0x7F
# mant_b = b.bits & 0x7F
andi s6, a1, 0x7F
# result_sign = sign_a ^ sign_b
xor t1, s1, s2
# if (exp_b == 0xFF)
li t6, 0xFF
beq s4, t6, bf16_div_b_exp_allOne
# if (!exp_b && !mant_b)
or t5, s4, s6
beq t5, x0, bf16_div_b_exp_mant_zero
# if (exp_a == 0xFF)
beq s3, t6, bf16_div_a_exp_allOne
# if (!exp_a && !mant_a), a exp and mant all zero
or t5, s3, s5
beq t5, x0, bf16_div_a_exp_mant_zero
# if (exp_a)
bne s3, x0, bf16_div_mant_a_or0x80
# if (exp_b)
bne s4, x0, bf16_div_mant_b_or0x80
j bf16_div_run
bf16_div_mant_a_or0x80:
# mant_a |= 0x80
ori s5, s5, 0x80
# if (exp_b)
bne s4, x0, bf16_div_mant_b_or0x80
j bf16_div_run
bf16_div_mant_b_or0x80:
# mant_b |= 0x80
ori s6, s6, 0x80
j bf16_div_run
bf16_div_run:
# dividend = (uint32_t) mant_a << 15
slli t3, s5, 15
# divisor = mant_b
add t4, s6, x0
# uint32_t quotient = 0
add t5, x0, x0
# set i = 0
add t0, x0, x0
li t6, 16
j bf16_div_for
bf16_div_for:
# for (int i = 0; i < 16; i++)
bge t0, t6, bf16_div_result_exp
# quotient <<= 1
slli t5, t5, 1
# (divisor << (15 - i))
li t6, 15
sub t6, t6, t0
sll t6, t4, t6
# if (dividend >= (divisor << (15 - i)))
bge t3, t6, bf16_div_for_divdend_divsor
li t6, 16
# i++
addi t0, t0, 1
j bf16_div_for
bf16_div_for_divdend_divsor:
# dividend -= (divisor << (15 - i))
sub t3, t3, t6
# quotient |= 1
ori t5, t5, 1
li t6, 16
# i++
addi t0, t0, 1
j bf16_div_for
bf16_div_result_exp:
# result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS
li t6, BF16_EXP_BIAS
sub t2, s3, s4
add t2, t2, t6
# if (!exp_a)
beq s3, x0, bf16_div_result_exp_minus1
# if (!exp_b)
beq s4, x0, bf16_div_result_exp_plus1
j bf16_div_quotient
bf16_div_result_exp_minus1:
# result_exp--
addi t2, t2, -1
# if (!exp_b)
beq s4, x0, bf16_div_result_exp_plus1
j bf16_div_quotient
bf16_div_result_exp_plus1:
# result_exp++
addi t2, t2, 1
j bf16_div_quotient
bf16_div_quotient:
# if (quotient & 0x8000), quot&0x8000 > 0, srli 8
li t6, 0x8000
and t6, t5, t6
bne t6, x0, bf16_div_quot_srl8
# else jump quot while
j bf16_div_quot_while
bf16_div_quot_while:
# while (!(quotient & 0x8000) && result_exp > 1)
bne t6, x0, bf16_div_quot_srl8
li t0, 1
bge t0, t2, bf16_div_quot_srl8
# quotient <<= 1
slli t5, t5, 1
# result_exp--
addi t2, t2, -1
# quotient & 0x8000
li t6, 0x8000
and t6, t5, t6
j bf16_div_quot_while
bf16_div_quot_srl8:
# quotient >>= 8
srli t5, t5, 8
j bf16_div_result_exp_ret
bf16_div_result_exp_ret:
# quotient &= 0x7F
andi t5, t5, 0x7F
# if (result_exp >= 0xFF)
li t6, 0xFF
bge t2, t6, bf16_div_ret_sign_sll15_or7F80
# if (result_exp <= 0)
bge x0, t2, bf16_div_ret_sign_sll15
j bf16_div_ret
bf16_div_a_exp_allOne:
# if (mant_a), mant_a != 0, return a
bne s5, x0, bf16_div_reta
# return ((result_sign << 15) | 0x7F80)
j bf16_div_ret_sign_sll15_or7F80
bf16_div_a_exp_mant_zero:
# return (result_sign << 15)
j bf16_div_ret_sign_sll15
bf16_div_b_exp_allOne:
# if (mant_b), b mant != 0, return b
bne s6, x0, bf16_div_retb
# if (exp_a == 0xFF && !mant_a)
li t6, 0xFF
beq s3, t6, bf16_div_b_check_a_NAN
# return result_sign << 15
j bf16_div_ret_sign_sll15
bf16_div_b_exp_mant_zero:
# if (!exp_a && !mant_a), a exp and mant all zero, return NAN
or t5, s3, s5
beq t5, x0, bf16_div_retNAN
# return ((result_sign << 15) | 0x7F80)
j bf16_div_ret_sign_sll15_or7F80
bf16_div_b_check_a_NAN:
# if (exp_a == 0xFF && !mant_a), exp = 0xFF, mant == 0 return nan
beq s5, x0, bf16_div_retNAN
# else mant_a != 0 (a is NaN)
# return result_sign << 15
j bf16_div_ret_sign_sll15
bf16_div_ret_sign_sll15_or7F80:
# return ((result_sign << 15) | 0x7F80)
slli a0, t1, 15
li t6, 0x7F80
or a0, a0, t6
j bf16_div_return
bf16_div_ret_sign_sll15:
# return result_sign << 15
slli a0, t1, 15
j bf16_div_return
bf16_div_retNAN:
# set a0 = BF16_NAN
li a0, BF16_NAN
j bf16_div_return
bf16_div_reta:
j bf16_div_return
bf16_div_retb:
# set a0 = b
add a0, a1, x0
j bf16_div_return
bf16_div_ret:
# ((result_sign << 15) | ((result_exp & 0xFF) << 7) |(quotient & 0x7F))
slli t1, t1, 15
andi t2, t2, 0xFF
slli t2, t2, 7
andi t5, t5, 0x7F
or a0, t1, t2
or a0, a0, t5
j bf16_div_return
bf16_div_return:
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
lw s4, 12(sp)
lw s5, 16(sp)
lw s6, 20(sp)
addi sp, sp, 24
jalr x0, ra, 0
# =======================================================
# Function : bf16_sqrt()
# Parameter : bf16_t a
# Variable :
# Description : Computes the square root of a bfloat16 number using bitwise operations and binary search.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_sqrt:
addi sp, sp, -12
sw s3, 8(sp)
sw s2, 4(sp)
sw s1, 0(sp)
# sign = (a.bits >> 15) & 1
srli s1, a0, 15
andi s1, s1, 1
# exp = ((a.bits >> 7) & 0xFF)
srli s2, a0, 7
andi s2, s2, 0xFF
# mant = a.bits & 0x7F
andi s3, a0, 0x7F
# if (exp == 0xFF), exp all one, jump bf16_sqrt_exp_allOne
li t6, 0xFF
beq s2, t6, bf16_sqrt_exp_allOne
# if (!exp && !mant), exp and mant all zeros, return zero
or t0, s2, s3
beq t0, x0, bf16_sqrt_retZero
# if (sign), sign = 1 => negative, return nan
bne s1, x0, bf16_sqrt_retNAN
# if (!exp), exp = 0, return zero
beq s2, x0, bf16_sqrt_retZero
# e = exp - BF16_EXP_BIAS
li t6, BF16_EXP_BIAS
sub t1, s2, t6
# m = 0x80 | mant
ori t3, s3, 0x80
j bf16_sqrt_adjust_odd_exp
bf16_sqrt_adjust_odd_exp:
# if (e & 1), odd, jump bf16_sqrt_adjust_odd
andi t5, t1, 1
bne t5, x0, bf16_sqrt_adjust_odd
# else jump bf16_sqrt_adjust_even
j bf16_sqrt_adjust_even
bf16_sqrt_adjust_odd:
# m <<= 1
slli t3, t3, 1
# new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS
addi t2, t1, -1
srai t2, t2, 1
li t6, BF16_EXP_BIAS
add t2, t2, t6
j bf16_sqrt_search_square
bf16_sqrt_adjust_even:
# new_exp = (e >> 1) + BF16_EXP_BIAS
srai t2, t1, 1
li t6, BF16_EXP_BIAS
add t2, t2, t6
j bf16_sqrt_search_square
bf16_sqrt_search_square:
# low = 90
li t4, 90
# high = 256
li t5, 256
# result = 128
li t6, 128
j bf16_sqrt_search_square_while
bf16_sqrt_search_square_while:
# while (low <= high), low > high, jump ensure result
blt t5, t4, bf16_sqrt_norm_ensure_result
# mid = (low + high) >> 1
add t0, t4, t5
srli t0, t0, 1
j bf16_sqrt_pow_init
bf16_sqrt_search_square_pow:
# sq = (mid * mid) / 128
srli t1, t1, 7
# if (sq <= m)
bge t3, t1, bf16_sqrt_search_square_while_midPlus1
# else
j bf16_sqrt_search_square_while_midMinus1
bf16_sqrt_pow_init:
# set sq = 0, a1 = mid (multiplicand), a2 = mid (multiplier)
# a1-a3 are caller-saved scratch registers, so no callee-saved register is clobbered
add a1, t0, x0
add a2, t0, x0
add t1, x0, x0
j bf16_sqrt_pow_forLoop
bf16_sqrt_pow_forLoop:
# if a2 == 0, return bf16_sqrt_search_square_pow
beq a2, x0, bf16_sqrt_search_square_pow
andi a3, a2, 1
# if ((a2 & 1) == 0), j bf16_sqrt_pow_lsbZero
beq a3, x0, bf16_sqrt_pow_lsbZero
j bf16_sqrt_pow_lsbOne
bf16_sqrt_pow_lsbOne:
# sq = sq + a1
add t1, t1, a1
# a1 <<= 1
slli a1, a1, 1
# a2 >>= 1
srli a2, a2, 1
j bf16_sqrt_pow_forLoop
bf16_sqrt_pow_lsbZero:
# a1 <<= 1
slli a1, a1, 1
# a2 >>= 1
srli a2, a2, 1
j bf16_sqrt_pow_forLoop
bf16_sqrt_search_square_while_midPlus1:
# result = mid;
add t6, t0, x0
# low = mid + 1;
addi t4, t0, 1
j bf16_sqrt_search_square_while
bf16_sqrt_search_square_while_midMinus1:
# high = mid - 1;
addi t5, t0, -1
j bf16_sqrt_search_square_while
bf16_sqrt_norm_ensure_result:
# if (result >= 256), bge
li t0, 256
bge t6, t0, bf16_sqrt_norm_ensure_result_bge256
# else if (result < 128)
li t0, 128
blt t6, t0, bf16_sqrt_norm_ensure_result_blt128
j bf16_sqrt_extract_mantissa
bf16_sqrt_norm_ensure_result_bge256:
# result >>= 1
srli t6, t6, 1
# new_exp++
addi t2, t2, 1
j bf16_sqrt_extract_mantissa
bf16_sqrt_norm_ensure_result_blt128:
li t0, 128
li t1, 1
j bf16_sqrt_norm_ensure_result_blt128_while
bf16_sqrt_norm_ensure_result_blt128_while:
# while (result < 128 && new_exp > 1)
# result < 128
bge t6, t0, bf16_sqrt_extract_mantissa
# new_exp > 1
bge t1, t2, bf16_sqrt_extract_mantissa
# result <<= 1
slli t6, t6, 1
# new_exp--
addi t2, t2, -1
j bf16_sqrt_norm_ensure_result_blt128_while
bf16_sqrt_extract_mantissa:
# new_mant = result & 0x7F
andi t4, t6, 0x7F
li t5, 0xFF
# if (new_exp >= 0xFF)
bge t2, t5, bf16_sqrt_retBge0xFF
# if (new_exp <= 0)
bge x0, t2, bf16_sqrt_retZero
j bf16_sqrt_ret
bf16_sqrt_exp_allOne:
# if (mant), mant != 0, return a
bne s3, x0, bf16_sqrt_reta
# if (sign), sign != 0, return nan
bne s1, x0, bf16_sqrt_retNAN
# return a
j bf16_sqrt_reta
bf16_sqrt_retBge0xFF:
# return 0x7F80
li a0, 0x7F80
j bf16_sqrt_return
bf16_sqrt_retZero:
# return zero
li a0, BF16_ZERO
j bf16_sqrt_return
bf16_sqrt_retNAN:
# return NAN
li a0, BF16_NAN
j bf16_sqrt_return
bf16_sqrt_reta:
# return a
j bf16_sqrt_return
bf16_sqrt_ret:
# return ((new_exp & 0xFF) << 7) | new_mant
andi a0, t2, 0xFF
slli a0, a0, 7
or a0, a0, t4
j bf16_sqrt_return
bf16_sqrt_return:
lw s1, 0(sp)
lw s2, 4(sp)
lw s3, 8(sp)
addi sp, sp, 12
jalr x0, ra, 0
```
:::
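In `bf16_sqrt` above, the square `mid * mid` inside the binary search is computed by the `bf16_sqrt_pow_*` labels with a shift-and-add loop rather than a `mul` instruction. The following is a minimal C sketch of that loop; the function names `shift_add_mul` and `scaled_square` are hypothetical and chosen only for illustration.

```c
#include <stdint.h>

/* Shift-and-add multiplication: for every set bit of the multiplier,
 * add the correspondingly shifted multiplicand to the product. */
static uint32_t shift_add_mul(uint32_t multiplicand, uint32_t multiplier)
{
    uint32_t product = 0;
    while (multiplier) {
        if (multiplier & 1)
            product += multiplicand;
        multiplicand <<= 1; /* next bit weighs twice as much */
        multiplier >>= 1;   /* consume the bit just handled */
    }
    return product;
}

/* sq = (mid * mid) / 128, the quantity compared against m in the search */
static uint32_t scaled_square(uint32_t mid)
{
    return shift_add_mul(mid, mid) >> 7;
}
```

The loop runs once per bit of the multiplier, so for the search range 90..256 it takes at most 9 iterations per square.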
### Result

----
## Problem B in [quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
#### clz
The clz (Count Leading Zeros) function works by examining the upper half of the input bits to determine where the first 1 appears, starting from the most significant bit. It uses a binary search–like approach to reduce the number of comparisons.
At each step, it right-shifts the number by c bits and checks whether the result is zero:
* If the result is zero, it means the upper half being examined contains only zeros, so the next check moves to the lower half.
* If the result is non-zero, there is at least one 1 in the upper half. In that case, the current number is updated to the shifted result, and the leading-zero count n is reduced by c.
The shift amount c is halved after each iteration, gradually narrowing down the range to locate the first 1. When the loop ends, x is either 0 or 1, so the function returns n - x.
### C code of clz
```c=
static inline unsigned clz(uint32_t x)
{
int n = 32, c = 16;
do {
uint32_t y = x >> c;
if (y) {
n -= c;
x = y;
}
c >>= 1;
} while (c);
return n - x;
}
```
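One way to gain confidence in the halving search is to compare it with a bit-by-bit scan. The sketch below repeats `clz` so that it compiles on its own; `clz_naive` and `clz_agrees` are hypothetical helpers introduced here, not part of the quiz code.

```c
#include <stdint.h>

/* Halving search, same algorithm as the listing above. */
static unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Reference: walk a mask down from bit 31 until a set bit is found. */
static unsigned clz_naive(uint32_t x)
{
    unsigned count = 0;
    for (uint32_t mask = 0x80000000u; mask && !(x & mask); mask >>= 1)
        count++;
    return count;
}

/* Returns 1 when both versions agree on x. */
static int clz_agrees(uint32_t x)
{
    return clz(x) == clz_naive(x);
}
```

For example, `clz(5)` leaves `x` unchanged for c = 16, 8 and 4, takes the non-zero branch at c = 2 (n becomes 30, x becomes 1), and returns 30 - 1 = 29. An input of 0 yields 32.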
### Assembly code of clz
```Assembly=
clz:
addi sp, sp, -16
sw ra, 12(sp)
sw t0, 8(sp)
sw t1, 4(sp)
sw t2, 0(sp)
li t0, 32 # t0 = n
li t1, 16 # t1 = c
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # return n - x
lw t2, 0(sp)
lw t1, 4(sp)
lw t0, 8(sp)
lw ra, 12(sp)
addi sp, sp, 16
jalr x0, ra, 0 # return to caller
```
---
## Problem [Leetcode 3370. Smallest Number With All Set Bits](https://leetcode.com/problems/smallest-number-with-all-set-bits/description/)
>You are given a positive number n.
>
>Return the smallest number x greater than or equal to n, such that the binary representation of x contains only set bits.
>
>Set Bit
A set bit refers to a bit in the binary representation of a number that has a value of 1.
>
>
>### Example 1:
>
>
>Input: n = 5
>Output: 7
>Explanation:
>The binary representation of 7 is "111".
>
>
>### Constraints:
>* 1 <= n <= 1000
>
## Solution
### Idea for problem solving
The problem asks us to find the smallest number x such that x >= n and all bits of x are 1 in binary representation.
This means we are looking for the smallest number in the form of 2^k - 1 such that 2^k - 1 >= n.
To solve this:
We can start from x = 1 and keep shifting left (i.e., multiplying by 2) until x >= n. The answer is then x - 1, or (x << 1) - 1 when x equals n exactly.
This approach is simple but takes O(log n) time.
To optimize, we can use the clz (Count Leading Zeros) function to directly find the position of the most significant 1 bit in n.
With clz, the bit length of n is 32 - clz(n), so the result is x = (1 << (32 - clz(n))) - 1. For a 32-bit input the clz loop always runs five iterations, so the cost is constant and no longer depends on n.
This is a typical bit manipulation problem, where understanding binary patterns and bitwise operations helps simplify the logic and improve performance.
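As a quick sanity check of the two ideas above, the sketch below compares the shifting loop with the clz formula over the whole constraint range 1..1000. It assumes a GCC or Clang toolchain and uses the compiler builtin `__builtin_clz` in place of the hand-written `clz`; the function names are hypothetical.

```c
/* Loop version: grow x = 1, 2, 4, ... until x >= n. */
static int smallest_by_shifting(int n)
{
    int x = 1;
    while (x < n)
        x <<= 1;
    return (x == n) ? (x << 1) - 1 : x - 1;
}

/* clz version. Assumption: GCC/Clang __builtin_clz, which is undefined
 * for 0, so n must be at least 1 (and below 2^31 to keep the shift valid). */
static int smallest_by_clz(int n)
{
    int bit_len = 32 - __builtin_clz((unsigned) n);
    return (1 << bit_len) - 1;
}

/* Counts the inputs in 1..limit where the two versions disagree. */
static int count_mismatches(int limit)
{
    int mismatches = 0;
    for (int n = 1; n <= limit; ++n)
        if (smallest_by_shifting(n) != smallest_by_clz(n))
            mismatches++;
    return mismatches;
}
```

Worked by hand for n = 5: the leading-zero count is 29, the bit length is 32 - 29 = 3, and (1 << 3) - 1 = 7, which matches Example 1.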
---
#### Original:
##### C code
```c=
int smallestNumber(int n) {
int x = 1;
while (x<n)
x <<=1;
if (x == n)
return (x<<1) - 1;
return x-1;
}
```
##### Assembly code
```Assembly=
smallestNumber:
addi sp, sp, -8
sw ra, 4(sp)
sw t0, 0(sp)
li t0, 1 # x = 1
sml_while_loop:
bge t0, a0, sml_cmp_xandn
slli t0, t0, 1 # x <<= 1
j sml_while_loop
sml_cmp_xandn:
bne t0, a0, sml_x_shift
slli t0, t0, 1
sml_x_shift:
addi a0, t0, -1
lw t0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
jalr x0, ra, 0 # return to caller
```
----
### Using clz:
#### C code
```c=
#include <stdint.h>
static inline unsigned clz(uint32_t x)
{
int n = 32, c = 16;
do {
uint32_t y = x >> c;
if (y) {
n -= c;
x = y;
}
c >>= 1;
} while (c);
return n - x;
}
int smallestNumber(int n) {
int bit_len = (1 << (32-clz(n)))-1;
return bit_len;
}
```
#### Assembly code
```Assembly=
smallestNumber:
addi sp, sp, -12
sw ra, 8(sp)
sw t0, 4(sp)
sw t1, 0(sp)
# call clz
addi sp, sp, -8 # store ra, a0(n)
sw ra, 4(sp)
sw a0, 0(sp)
jal ra, clz
li t0, 32 # bit_len = 32
li t1, 1
sub t0, t0, a0 # bit_len = 32 - clz(n)
lw a0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
sll t1, t1, t0
addi a0, t1, -1
lw t1, 0(sp)
lw t0, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
jalr x0, ra, 0
clz:
addi sp, sp, -16
sw ra, 12(sp)
sw t0, 8(sp)
sw t1, 4(sp)
sw t2, 0(sp)
li t0, 32 # t0 = n
li t1, 16 # t1 = c
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # return n - x
lw t2, 0(sp)
lw t1, 4(sp)
lw t0, 8(sp)
lw ra, 12(sp)
addi sp, sp, 16
jalr x0, ra, 0
```
----
### Using clz in Leetcode
#### C code
```c=
#include <stdio.h>
#include <stdint.h>
static inline unsigned clz(uint32_t x)
{
int n = 32, c = 16;
do {
uint32_t y = x >> c;
if (y) {
n -= c;
x = y;
}
c >>= 1;
} while (c);
return n - x;
}
int smallestNumber(int n) {
int bit_len = (1 << (32-clz(n)))-1;
return bit_len;
}
int main()
{
int input[] = {1, 509, 1000}, output;
int ans[] = {1, 511, 1023};
for (int i = 0 ; i < 3 ; ++i)
{
output = smallestNumber(input[i]);
printf("output = %d, answer = %d\n", output, ans[i]);
if (output == ans[i])
printf("True\n");
else
printf("False\n");
}
return 0;
}
```
#### Assembly code
```Assembly=
.data
input:
.word 1, 509, 1000
ans:
.word 1, 511, 1023
output_str:
.string "output = "
answer_str:
.string ", answer = "
endline_str:
.string "\n"
true_str:
.string "True\n"
false_str:
.string "False\n"
.text
.global main
main:
# load input array
la s0, input
# load answer array
la s1, ans
# i = 0
li t0, 0
# for loop boundary 3
li t1, 3
main_for:
# for (int i = 0 ; i < 3 ; ++i)
bge t0, t1, main_exit
# load input[i] to a0
slli t2, t0, 2
add t3, s0, t2
lw a0, 0(t3)
# call smallestNumber
addi sp, sp, -16
sw ra, 12(sp)
sw t2, 8(sp)
sw t1, 4(sp)
sw t0, 0(sp)
jal ra, smallestNumber
# output = smallestNumber(input[i])
add s2, a0, x0
lw t0, 0(sp)
lw t1, 4(sp)
lw t2, 8(sp)
lw ra, 12(sp)
addi sp, sp, 16
j print_result
print_result:
# print "output = "
la a0, output_str
li a7, 4
ecall
# print output
addi a0, s2, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[i] to t2
# print answer
add t2, s1, t2
lw a0, 0(t2)
add t2, a0, x0
li a7, 1
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output[i] == ans[i])
beq s2, t2, print_true
# else
j print_false
print_true:
# print "True\n"
la a0, true_str
li a7, 4
ecall
# ++i
addi t0, t0, 1
j main_for
print_false:
# print "False\n"
la a0, false_str
li a7, 4
ecall
# ++i
addi t0, t0, 1
j main_for
smallestNumber:
addi sp, sp, -12
sw ra, 8(sp)
sw t0, 4(sp)
sw t1, 0(sp)
# call clz
addi sp, sp, -8 # store ra, a0(n)
sw ra, 4(sp)
sw a0, 0(sp)
jal ra, clz
li t0, 32 # bit_len = 32
li t1, 1
sub t0, t0, a0 # bit_len = 32 - clz(n)
lw a0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
sll t1, t1, t0
addi a0, t1, -1
lw t1, 0(sp)
lw t0, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
jalr x0, x1, 0
clz:
li t0, 32 # t0 = n
li t1, 16 # t1 = c
j clz_whileLoop
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
j shift_right_1bit
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # return n - x
jalr x0, x1, 0
main_exit:
li a7, 10
ecall
```
#### Result

----
### Optimizing assembly code
#### The following techniques address the three types of pipeline hazards.
1. Structural hazard: add more hardware resources
2. Data hazard: data forwarding or instruction reordering
3. Control hazard: branch delay slot (not supported in Ripes, which flushes automatically) or branch prediction (1-bit, 2-bit, 3-bit)
* Loop unrolling can also reduce data and control hazards.
#### loop unrolling
**I use loop unrolling to reduce the number of branches.**
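To illustrate the idea at the C level, the sketch below shows the three-iteration test loop in rolled and unrolled form. It is only an illustration under stated assumptions: `smallest_all_ones` is a hypothetical stand-in for `smallestNumber`, and the other names are likewise invented for this example.

```c
/* Stand-in for smallestNumber: grow an all-ones value until it reaches n. */
static int smallest_all_ones(int n)
{
    int x = 1;
    while (x < n)
        x = (x << 1) | 1;
    return x;
}

/* Rolled form: one loop-condition branch per iteration. */
static void run_rolled(const int in[3], int out[3])
{
    for (int i = 0; i < 3; ++i)
        out[i] = smallest_all_ones(in[i]);
}

/* Unrolled form: the loop counter and its branch are gone. */
static void run_unrolled(const int in[3], int out[3])
{
    out[0] = smallest_all_ones(in[0]);
    out[1] = smallest_all_ones(in[1]);
    out[2] = smallest_all_ones(in[2]);
}

/* Returns 1 when both forms give the expected answers for the test inputs. */
static int unrolled_matches_rolled(void)
{
    const int in[3] = {1, 509, 1000};
    const int expected[3] = {1, 511, 1023};
    int rolled[3], unrolled[3];
    run_rolled(in, rolled);
    run_unrolled(in, unrolled);
    for (int i = 0; i < 3; ++i)
        if (rolled[i] != expected[i] || unrolled[i] != expected[i])
            return 0;
    return 1;
}
```

The trade-off is code size: each unrolled copy removes a compare-and-branch and the index arithmetic, at the cost of repeating the loop body once per element.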
:::spoiler More detailed information
```assembly=
.data
input:
.word 1, 509, 1000
ans:
.word 1, 511, 1023
output_str:
.string "output = "
answer_str:
.string ", answer = "
endline_str:
.string "\n"
true_str:
.string "True\n"
false_str:
.string "False\n"
.text
.global main
main:
# load input array
la s0, input
# load answer array
la s1, ans
main_for:
# loop unrolling
# load input[0] to a0
addi t3, s0, 0
lw a0, 0(t3)
# call smallestNumber
addi sp, sp, -4
sw ra, 0(sp)
jal ra, smallestNumber
# output = smallestNumber(input[0])
add s2, a0, x0
# load input[1] to a0
addi t3, s0, 4
lw a0, 0(t3)
# call smallestNumber
jal ra, smallestNumber
# output = smallestNumber(input[1])
add s3, a0, x0
# load input[2] to a0
addi t3, s0, 8
lw a0, 0(t3)
# call smallestNumber
jal ra, smallestNumber
# output = smallestNumber(input[2])
add s4, a0, x0
lw ra, 0(sp)
addi sp, sp, 4
j print_result_0
print_result_0:
# print "output = "
la a0, output_str
li a7, 4
ecall
# print output
addi a0, s2, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[0] into a0 and keep a copy in t2
# print answer
addi t2, s1, 0
lw a0, 0(t2)
add t2, a0, x0
li a7, 1
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[0])
beq s2, t2, print_true_0
# else
j print_false_0
print_true_0:
# print "True\n"
la a0, true_str
li a7, 4
ecall
j print_result_1
print_false_0:
# print "False\n"
la a0, false_str
li a7, 4
ecall
j print_result_1
print_result_1:
# print "output = "
la a0, output_str
li a7, 4
ecall
# print output
addi a0, s3, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[1] into a0 and keep a copy in t2
# print answer
addi t2, s1, 4
lw a0, 0(t2)
add t2, a0, x0
li a7, 1
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[1])
beq s3, t2, print_true_1
# else
j print_false_1
print_true_1:
# print "True\n"
la a0, true_str
li a7, 4
ecall
j print_result_2
print_false_1:
# print "False\n"
la a0, false_str
li a7, 4
ecall
j print_result_2
print_result_2:
# print "output = "
la a0, output_str
li a7, 4
ecall
# print output
addi a0, s4, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[2] into a0 and keep a copy in t2
# print answer
addi t2, s1, 8
lw a0, 0(t2)
add t2, a0, x0
li a7, 1
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[2])
beq s4, t2, print_true_2
# else
j print_false_2
print_true_2:
# print "True\n"
la a0, true_str
li a7, 4
ecall
j main_exit
print_false_2:
# print "False\n"
la a0, false_str
li a7, 4
ecall
j main_exit
smallestNumber:
addi sp, sp, -12
sw ra, 8(sp)
sw t0, 4(sp)
sw t1, 0(sp)
# call clz
addi sp, sp, -8 # store ra, a0(n)
sw ra, 4(sp)
sw a0, 0(sp)
jal ra, clz
li t0, 32 # bit_len = 32
li t1, 1
sub t0, t0, a0 # bit_len = 32 - clz(n)
lw a0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
sll t1, t1, t0
addi a0, t1, -1
lw t1, 0(sp)
lw t0, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
jalr x0, x1, 0
clz:
li t0, 32 # t0 = n
li t1, 16 # t1 = c
j clz_whileLoop
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
j shift_right_1bit
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # x = n - x
jalr x0, x1, 0
main_exit:
li a7, 10
ecall
```
:::
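As a rough C-level picture of what the unrolled assembly does, the sketch below contrasts a rolled loop with its fully unrolled form for the three test inputs. The helper names (`clz32`, `smallest_number`, `run_rolled`, `run_unrolled`) are illustrative and do not come from the listing above, and inputs are assumed to be below 2^31 so that the shift stays well defined in C.

```c
#include <stdint.h>

/* Binary-search count-leading-zeros, mirroring the clz routine in the assembly. */
static uint32_t clz32(uint32_t x)
{
    uint32_t n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {            /* upper part is non-zero: keep it and shrink n */
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Smallest all-ones number >= n. Assumption: n < 2^31. */
static uint32_t smallest_number(uint32_t n)
{
    uint32_t bit_len = 32 - clz32(n);
    return (1u << bit_len) - 1;
}

/* Rolled version: one loop counter update and one loop-back branch per element. */
static void run_rolled(const uint32_t in[3], uint32_t out[3])
{
    for (int i = 0; i < 3; ++i)
        out[i] = smallest_number(in[i]);
}

/* Unrolled version: same work, but no loop counter and no loop-back branch. */
static void run_unrolled(const uint32_t in[3], uint32_t out[3])
{
    out[0] = smallest_number(in[0]);
    out[1] = smallest_number(in[1]);
    out[2] = smallest_number(in[2]);
}
```

Both functions produce the same results. The unrolled form trades code size for fewer branches, which is the same trade-off made in the assembly listing.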
#### Loop unrolling + Reorder(load-use)
:::spoiler More detailed information
```assembly=
.data
input:
.word 1, 509, 1000
ans:
.word 1, 511, 1023
output_str:
.string "output = "
answer_str:
.string ", answer = "
endline_str:
.string "\n"
true_str:
.string "True\n"
false_str:
.string "False\n"
.text
.global main
main:
# load input array
la s0, input
# load answer array
la s1, ans
main_for:
# loop unrolling
# load input[0] to a0
addi t3, s0, 0
lw a0, 0(t3)
# call smallestNumber
addi sp, sp, -4
sw ra, 0(sp)
jal ra, smallestNumber
# output = smallestNumber(input[0])
add s2, a0, x0
# load input[1] to a0
addi t3, s0, 4
lw a0, 0(t3)
# call smallestNumber
jal ra, smallestNumber
# output = smallestNumber(input[1])
add s3, a0, x0
# load input[2] to a0
addi t3, s0, 8
lw a0, 0(t3)
# call smallestNumber
jal ra, smallestNumber
# output = smallestNumber(input[2])
add s4, a0, x0
lw ra, 0(sp)
addi sp, sp, 4
j print_result_0
print_result_0:
# print "output = "
la a0, output_str
li a7, 4
ecall
# print output
addi a0, s2, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[0] into a0 and keep a copy in t2
# print answer
addi t2, s1, 0
lw a0, 0(t2)
# reorder li and add
li a7, 1
add t2, a0, x0
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[0])
beq s2, t2, print_true_0
# else
j print_false_0
print_true_0:
# print "True\n" (a7 is still 4 from the previous print-string ecall)
la a0, true_str
ecall
j print_result_1
print_false_0:
# print "False\n"
la a0, false_str
ecall
j print_result_1
print_result_1:
# print "output = "
la a0, output_str
ecall
# print output
addi a0, s3, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[1] into a0 and keep a copy in t2
# print answer
addi t2, s1, 4
lw a0, 0(t2)
# reorder li and add
li a7, 1
add t2, a0, x0
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[1])
beq s3, t2, print_true_1
# else
j print_false_1
print_true_1:
# print "True\n"
la a0, true_str
ecall
j print_result_2
print_false_1:
# print "False\n"
la a0, false_str
ecall
j print_result_2
print_result_2:
# print "output = "
la a0, output_str
ecall
# print output
addi a0, s4, 0
li a7, 1
ecall
# ", answer = "
la a0, answer_str
li a7, 4
ecall
# load ans[2] into a0 and keep a copy in t2
# print answer
addi t2, s1, 8
lw a0, 0(t2)
# reorder li and add
li a7, 1
add t2, a0, x0
ecall
# print "\n"
la a0, endline_str
li a7, 4
ecall
# if (output == ans[2])
beq s4, t2, print_true_2
# else
j print_false_2
print_true_2:
# print "True\n"
la a0, true_str
ecall
j main_exit
print_false_2:
# print "False\n"
la a0, false_str
ecall
j main_exit
smallestNumber:
addi sp, sp, -12
sw ra, 8(sp)
sw t0, 4(sp)
sw t1, 0(sp)
# call clz
addi sp, sp, -8 # store ra, a0(n)
sw ra, 4(sp)
sw a0, 0(sp)
jal ra, clz
li t0, 32 # bit_len = 32
li t1, 1
sub t0, t0, a0 # bit_len = 32 - clz(n)
lw a0, 0(sp)
lw ra, 4(sp)
addi sp, sp, 8
sll t1, t1, t0
addi a0, t1, -1
lw t1, 0(sp)
lw t0, 4(sp)
lw ra, 8(sp)
addi sp, sp, 12
jalr x0, x1, 0
clz:
li t0, 32 # t0 = n
li t1, 16 # t1 = c
j clz_whileLoop
clz_whileLoop:
srl t2, a0, t1 # y = t2, y = x >> c
beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit
sub t0, t0, t1 # n = n - c
addi a0, t2, 0 # x = y
j shift_right_1bit
shift_right_1bit:
srli t1, t1, 1 # c = c >> 1
bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop
sub a0, t0, a0 # x = n - x
jalr x0, x1, 0
main_exit:
li a7, 10
ecall
```
:::
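To see why moving `li a7, 1` between the `lw` and the dependent `add` helps, here is a minimal C model that counts load-use stalls in a short instruction sequence. It assumes a classic 5-stage pipeline with full forwarding, where a load immediately followed by a consumer of its destination register costs one stall cycle. The `struct insn` layout and the function name are hypothetical and are not part of Ripes.

```c
#include <stdbool.h>
#include <stddef.h>

/* Minimal instruction record: register numbers only (0 means x0 or unused). */
struct insn {
    bool is_load;   /* true for lw-style loads */
    int rd;         /* destination register */
    int rs1;        /* first source register */
    int rs2;        /* second source register */
};

/* Count load-use stalls: a load immediately followed by a consumer of its rd. */
static int count_load_use_stalls(const struct insn *prog, size_t n)
{
    int stalls = 0;
    for (size_t i = 0; i + 1 < n; ++i) {
        const struct insn *ld = &prog[i];
        const struct insn *next = &prog[i + 1];
        if (ld->is_load && ld->rd != 0 &&
            (next->rs1 == ld->rd || next->rs2 == ld->rd))
            ++stalls;
    }
    return stalls;
}
```

With a0 = x10, a7 = x17 and t2 = x7, the sequence `lw a0, 0(t2)` / `add t2, a0, x0` / `li a7, 1` gives one stall under this model, while the reordered `lw` / `li` / `add` gives none.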
#### Performance and CPU Cycle Counts of the LeetCode Problem in Assembly
|C |Assembly|Unrolling |
|--------|--------|--------|
||||
|Unrolling+Reorder|
|--------|
||
Implementing the problem in hand-written assembly **reduced the CPU cycle count from 10785 to 437, a reduction of about 96%.** Further optimization with **loop unrolling and reordering brought the count down to 383 cycles, an additional improvement of about 12%**.
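The percentages can be checked from the cycle counts quoted above with a small helper. `percent_reduction` is an illustrative name, and the function simply computes (before - after) / before.

```c
/* Percentage by which `after` is lower than `before`. Assumption: before > 0. */
static double percent_reduction(unsigned before, unsigned after)
{
    return 100.0 * ((double)before - (double)after) / (double)before;
}
```

For the numbers reported here, `percent_reduction(10785, 437)` is roughly 95.9 and `percent_reduction(437, 383)` is roughly 12.4.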
## Analysis
### RISC-V operation process in Ripes
Ripes offers multiple execution models. **I selected the 5-stage pipeline with data forwarding and hazard detection** to run my program.

"5-stage" means the processor uses a five-stage pipeline, so several instructions can be in flight at once, each in a different stage.
The stages are:
1. Instruction fetch
2. Instruction decode
3. Execute
4. Memory access
5. Write back
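As a back-of-the-envelope illustration of why pipelining helps, the sketch below computes the ideal cycle count for a stall-free pipeline: the first instruction needs one cycle per stage, and each later instruction completes one cycle after the previous one. The function name is hypothetical, and the cycle counts Ripes reports are higher than this ideal because of stalls and flushes.

```c
/* Ideal cycle count for n_insns instructions on an n_stages-deep pipeline,
 * assuming no stalls and no flushes. */
static unsigned ideal_cycles(unsigned n_insns, unsigned n_stages)
{
    if (n_insns == 0)
        return 0;
    return n_stages + (n_insns - 1);
}
```

For example, a single instruction takes 5 cycles on a 5-stage pipeline, while 10 back-to-back instructions take 14 cycles rather than 50.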
---
### R-type format
|31 - 25|24 - 20|19 - 15|14 - 12|11 - 7|6 - 0|
|---------|-------|-------|--------|------|---------|
| funct7 | rs2 | rs1 | funct3 | rd | opcode |
1. opcode : operation code
2. rd : destination register number
3. funct3 : 3-bit function code
4. rs1 : the first source register number
5. rs2 : the second source register number
6. funct7 : 7-bit function code
Example: `add x28, x8, x7`
| funct7 |rs2(x7)|rs1(x8)| funct3 |rd(x28)| opcode |
|---------|-------|-------|--------|------|---------|
| 0000000 | 00111 | 01000 | 000 |11100 | 0110011 |
=> hex = 0x00740e33
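The hex value can be reproduced by packing the six fields at the bit positions listed in the table. The following is a minimal sketch, and `encode_rtype` is an illustrative helper name rather than part of any toolchain.

```c
#include <stdint.h>

/* Pack the six R-type fields into a 32-bit instruction word. */
static uint32_t encode_rtype(uint32_t funct7, uint32_t rs2, uint32_t rs1,
                             uint32_t funct3, uint32_t rd, uint32_t opcode)
{
    return ((funct7 & 0x7f) << 25) |   /* bits 31-25 */
           ((rs2    & 0x1f) << 20) |   /* bits 24-20 */
           ((rs1    & 0x1f) << 15) |   /* bits 19-15 */
           ((funct3 & 0x07) << 12) |   /* bits 14-12 */
           ((rd     & 0x1f) <<  7) |   /* bits 11-7  */
            (opcode & 0x7f);           /* bits 6-0   */
}
```

Calling `encode_rtype(0x00, 7, 8, 0x0, 28, 0x33)` yields the encoding of `add x28, x8, x7`.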
---
#### IF (Instruction fetch)

1. The program counter points to the current instruction address in memory.
2. Since no branch occurs, the PC is updated to PC + 4. The multiplexer before the PC selects the adder output as the next address.
3. The instruction memory fetches the instruction at the address, and the compressed decoder (if enabled) expands any compressed instructions to 32-bit format.
#### ID (Instruction decode)

1. Program Counter (PC)
* 0x00000000 is the previous program counter (PC), pointing to the last instruction.
* 0x00000004 is the current PC, pointing to the instruction currently being decoded. Each uncompressed RISC-V instruction is 4 bytes, so the PC increases by 4 per instruction.
2. Decoder
The decoder takes the binary instruction fetched from memory and extracts the following fields:
* rs1: the first source register
* rs2: the second source register
* rd: the destination register
* opcode: identifies the kind of instruction (e.g., arithmetic, load, branch)
3. Registers
* The register file uses the decoded rs1 and rs2 to read the values of the corresponding registers.
* These values are passed on to the next stage (the Execute stage).
4. Immediate
* In my example, no immediate is used because `add x28, x8, x7` is an R-type instruction (like sub, and, or, etc.).
* If it were an I-type, S-type, etc., the immediate value would be extracted during this stage.
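The field extraction performed by the decoder can be sketched in C with shifts and masks. The struct and function names below are hypothetical, and the sketch covers only the R-type layout discussed in this note.

```c
#include <stdint.h>

/* Fields of an R-type instruction, as extracted in the ID stage. */
struct rtype_fields {
    uint32_t opcode;
    uint32_t rd;
    uint32_t funct3;
    uint32_t rs1;
    uint32_t rs2;
    uint32_t funct7;
};

/* Split a 32-bit R-type instruction word into its six fields. */
static struct rtype_fields decode_rtype(uint32_t insn)
{
    struct rtype_fields f;
    f.opcode = insn & 0x7f;          /* bits 6-0   */
    f.rd     = (insn >> 7)  & 0x1f;  /* bits 11-7  */
    f.funct3 = (insn >> 12) & 0x07;  /* bits 14-12 */
    f.rs1    = (insn >> 15) & 0x1f;  /* bits 19-15 */
    f.rs2    = (insn >> 20) & 0x1f;  /* bits 24-20 */
    f.funct7 = (insn >> 25) & 0x7f;  /* bits 31-25 */
    return f;
}
```

Decoding `0x00740e33` gives opcode 0x33, rd = 28, rs1 = 8 and rs2 = 7, matching `add x28, x8, x7`.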
#### EX (Execute)

1. Multiplexer
* A 3-to-1 forwarding multiplexer selects input data from the ID stage (the register-file value), the MEM stage (the previous instruction's result), or the WB stage.
Since only one instruction is in the pipeline, nothing needs to be forwarded and the MUX selects the data from the ID stage.
* There are also two 2-to-1 multiplexers in the design:
  * The first multiplexer selects between the program counter (PC) and rs1.
  * The second multiplexer selects between the immediate value and rs2.
2. ALU
* The control unit decodes the instruction (opcode, funct3, and funct7) into an ALU control signal, which determines the operation the ALU performs.
* The ALU result is also routed back to the IF stage so that the PC multiplexer can choose between PC + 4 and the ALU result (e.g., for branches or jumps).
3. Branch
* The input data for the branch decision comes from the values in rs1 and rs2.
However, since no branch instructions are used, the branch logic can be ignored in this case.
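The choice made by the 3-to-1 multiplexer can be sketched with the textbook forwarding rule: prefer the newest result (from the instruction one ahead, now in MEM), then the older one (in WB), and otherwise use the value read in ID. This is a generic sketch and is not taken from the Ripes source; the enum and function names are assumptions.

```c
#include <stdbool.h>

/* Possible sources for one ALU operand. */
enum fwd_src { FWD_FROM_ID, FWD_FROM_MEM, FWD_FROM_WB };

/* Pick the operand source for source register `rs` of the instruction in EX.
 * mem_* describe the instruction currently in MEM, wb_* the one in WB. */
static enum fwd_src select_forward(int rs,
                                   bool mem_reg_write, int mem_rd,
                                   bool wb_reg_write, int wb_rd)
{
    if (rs == 0)                         /* x0 is always zero, never forwarded */
        return FWD_FROM_ID;
    if (mem_reg_write && mem_rd == rs)   /* newest producer wins */
        return FWD_FROM_MEM;
    if (wb_reg_write && wb_rd == rs)
        return FWD_FROM_WB;
    return FWD_FROM_ID;                  /* no hazard: use the register-file value */
}
```

With a single `add` in the pipeline there is no earlier producer, so both operands come from the ID stage.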
#### MEM (Memory access)

1. Data memory
* In a typical design, the data memory uses the ALU result as the address to load or store data.
However, since this example does not include lw or sw instructions, the data memory is not used.
2. The ALU result is also forwarded back to the EX (Execute) stage in case the next instruction needs it.

#### WB (Write back)

* The multiplexer selects the ALU result as the write-back value, so the output value is 0x00000000.
* The value is also forwarded to the EX stage in case a later instruction needs this instruction’s rd register.
* Either way, the value 0x00000000 is written into register x28.
After all these stages are done, the register file is updated as follows:


## Reference
* [RISC-V Instruction Set Manual](https://riscv.org/specifications/ratified/)
* [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1-sol)