# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by [<`Jackiempty`>](https://github.com/Jackiempty)
## Problem B
### C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
/* uf8: one-byte code; uf8_decode() reads the low nibble as the mantissa
 * and the high nibble as the exponent. */
typedef uint8_t uf8;
/*
 * Count the leading zero bits of a 32-bit word.
 *
 * Halves the inspected window (16, 8, 4, 2, 1) and keeps the upper part
 * whenever it is non-zero.  Returns 32 for x == 0.
 */
static inline unsigned clz(uint32_t x)
{
    int remaining = 32;

    for (int width = 16; width != 0; width >>= 1) {
        uint32_t upper = x >> width;
        if (upper != 0) {
            remaining -= width;
            x = upper;
        }
    }
    /* x is now 0 (input was 0) or 1 (the top set bit). */
    return remaining - x;
}
/*
 * Decode uf8 to uint32_t.
 *
 * High nibble = exponent e, low nibble = mantissa m.
 * Result is (m << e) plus a per-exponent offset of ((0x7FFF >> (15 - e)) << 4).
 */
uint32_t uf8_decode(uf8 fl)
{
    uint8_t exp_bits = fl >> 4;
    uint32_t frac = fl & 0x0f;
    uint32_t bias = (0x7FFF >> (15 - exp_bits)) << 4;

    return bias + (frac << exp_bits);
}
/*
 * Encode uint32_t to uf8.
 *
 * Values below 16 are returned unchanged.  Otherwise the position of the
 * highest set bit (via clz) gives a first exponent guess, which is then
 * walked down and up until `base <= value < next base` (or exponent 15).
 */
uf8 uf8_encode(uint32_t value)
{
    if (value < 16)
        return value;

    /* Index of the most significant set bit. */
    int leading = clz(value);
    int top_bit = 31 - leading;

    uint8_t exp = 0;
    uint32_t base = 0; /* smallest value that uses exponent `exp` */

    if (top_bit >= 5) {
        /* Empirical first guess, capped at the largest exponent. */
        exp = top_bit - 4;
        if (exp > 15)
            exp = 15;

        /* Build the base for the guessed exponent: base <- 2*base + 16. */
        uint8_t step = 0;
        while (step < exp) {
            base = (base << 1) + 16;
            step++;
        }

        /* Guess too high: undo one step at a time. */
        while (exp > 0 && value < base) {
            base = (base - 16) >> 1;
            exp--;
        }
    }

    /* Guess too low: advance while the next base still fits under value. */
    for (; exp < 15; exp++) {
        uint32_t next_base = (base << 1) + 16;
        if (value < next_base)
            break;
        base = next_base;
    }

    uint8_t frac = (value - base) >> exp;
    return (exp << 4) | frac;
}
/*
 * Test encode/decode round-trip.
 *
 * For every code 0..255: decode, re-encode, and require the same code back;
 * also require decoded values to be strictly increasing.  Every failure is
 * printed; the function returns false if any occurred.
 */
static bool test(void)
{
    bool ok = true;
    int32_t last = -1;
    int code = 0;

    while (code < 256) {
        uint8_t enc = code;
        int32_t decoded = uf8_decode(enc);
        uint8_t reenc = uf8_encode(decoded);

        if (enc != reenc) {
            printf("%02x: produces value %d but encodes back to %02x\n", enc,
                   decoded, reenc);
            ok = false;
        }
        if (decoded <= last) {
            printf("%02x: value %d <= previous_value %d\n", enc, decoded,
                   last);
            ok = false;
        }

        last = decoded;
        code++;
    }
    return ok;
}
/* Run the round-trip test; exit status 0 on success, 1 on any failure. */
int main(void)
{
    if (!test())
        return 1;

    printf("All tests passed.\n");
    return 0;
}
```
### Assembly
```asm
.data
# Message strings mirroring the printf() format strings of the C version.
# NOTE(review): none of these labels is referenced by the code visible below
# (the only `la` that used one is commented out in main) - confirm whether
# printing is still planned.
str_all_passed: .asciz "All tests passed.\n"
str_fail1: .asciz "%02x: produces value %d but encodes back to %02x\n"
str_fail2: .asciz "%02x: value %d <= previous_value %d\n"
.text
setup:
    # Program entry: there is no caller, so ra and sp are seeded by hand.
    # NOTE(review): ra = -1 is presumably a sentinel that makes the final
    # `ret` of main leave the program - confirm against the simulator used.
    li      ra, -1
    li      sp, 0x7ffffff0
    # execution falls through into main
main:
#######################################################
# < Function >
#    main procedure: runs FUNC_TEST and publishes the verdict
#
# < Parameters >
#    NULL
#
# < Return Value >
#    a0 : 0 if FUNC_TEST reported success, 1 otherwise
#
# < Side effects >
#    stores 66 (pass) or 88 (fail) to address 0x01000000
#######################################################
# < Local Variable >
#    s0 : address of the status word (0x01000000)
#    s1 : status value to store
#######################################################
    ## Save ra & Callee Saved
    addi    sp, sp, -12
    sw      ra, 8(sp)
    sw      s0, 4(sp)
    sw      s1, 0(sp)

    li      s0, 0x01000000

    ############### Call Function Procedure ###############
    jal     ra, FUNC_TEST           # a0 = passed (non-zero means success)
    #######################################################
    bne     a0, x0, main_pass

    li      s1, 88                  # if not pass, store 88 to 0x01000000
    sw      s1, 0(s0)
    li      a0, 1                   # return 1
    j       main_exit

main_pass:
    # la s0, str_all_passed         # (disabled) string argument for print_str
    jal     ra, print_str           # currently a no-op stub
    li      s1, 66                  # if pass, store 66 to 0x01000000
    sw      s1, 0(s0)
    li      a0, 0                   # return 0

main_exit:
    ## Retrieve ra & Callee Saved
    lw      ra, 8(sp)
    lw      s0, 4(sp)
    lw      s1, 0(sp)               # restore s1 (was `sw`, which never restored it)
    addi    sp, sp, 12
    ## return
    ret
print_str:
    # Placeholder for string output: the intended argument is s0 = address
    # of the string.  Nothing is printed yet; swap in an ecall / system call
    # suitable for the target simulator when output is needed.
    jr      ra
FUNC_TEST:
#######################################################
# < Function >
#    test: round-trip every uf8 code 0..255 through
#          uf8_decode / uf8_encode
#
# < Parameters >
#    NULL
#
# < Return Value >
#    a0 : passed (1 if every code round-trips and the decoded
#         values are strictly increasing, else 0)
#######################################################
# < Local Variable >
#    s0 : passed
#    s1 : previous_value
#    s2 : i (also fl, compared through its low 8 bits)
#    s3 : value
#    t0, t1 : scratch for the fl / fl2 comparison
#######################################################
    ## Save ra & Callee Saved
    addi    sp, sp, -20
    sw      ra, 16(sp)
    sw      s0, 12(sp)
    sw      s1, 8(sp)
    sw      s2, 4(sp)
    sw      s3, 0(sp)

    li      s0, 1                   # passed = true
    li      s1, -1                  # previous_value = -1
    li      s2, 0                   # i = 0

test_loop:
    li      t0, 256
    bge     s2, t0, test_end        # while (i < 256)

    mv      a0, s2                  # fl = i
    jal     ra, uf8_decode
    mv      s3, a0                  # value = uf8_decode(fl)

    jal     ra, uf8_encode          # a0 still holds value; a0 = uf8_encode(value)

    andi    t0, s2, 0xff            # fl  as uint8_t
    andi    t1, a0, 0xff            # fl2 as uint8_t
    beq     t0, t1, test_order
    li      s0, 0                   # round-trip mismatch (message printing not implemented)

test_order:
    blt     s1, s3, test_next       # ok when previous_value < value
    li      s0, 0                   # value <= previous_value (message printing not implemented)

test_next:
    mv      s1, s3                  # previous_value = value
    addi    s2, s2, 1               # i++
    j       test_loop

test_end:
    mv      a0, s0                  # return passed
    ## Retrieve ra & Callee Saved
    lw      ra, 16(sp)
    lw      s0, 12(sp)
    lw      s1, 8(sp)
    lw      s2, 4(sp)
    lw      s3, 0(sp)
    addi    sp, sp, 20
    ## return
    ret
clz:
#######################################################
# < Function >
#    clz: count leading zero bits of a 32-bit word
#
# < Parameters >
#    a1 : x
#
# < Return Value >
#    a1 : number of leading zeros (32 when x == 0)
#
# < Clobbers >
#    t0, t1, t2
#######################################################
# < Local Variable >
#    t0 : n  (bits still counted as leading zeros)
#    t1 : c  (current window width: 16, 8, 4, 2, 1)
#    t2 : y  (x >> c)
#######################################################
    # Leaf routine: ra is never modified, so no stack frame is needed.
    li      t1, 16                  # c = 16
    li      t0, 32                  # n = 32
clz_step:
    srl     t2, a1, t1              # y = x >> c
    beq     t2, x0, clz_halve       # upper part empty: keep x and n
    mv      a1, t2                  # x = y
    sub     t0, t0, t1              # n -= c
clz_halve:
    srli    t1, t1, 1               # c >>= 1
    bne     t1, x0, clz_step        # loop while c != 0
    sub     a1, t0, a1              # return n - x
    ret
uf8_decode:
#######################################################
# < Function >
#    uf8_decode: expand an 8-bit uf8 code to its 32-bit value
#
# < Parameters >
#    a0 : fl
#
# < Return Value >
#    a0 : (mantissa << exponent) + offset
#
# < Clobbers >
#    t0, t1, t2, t3
#######################################################
# < Local Variable >
#    t0 : mantissa
#    t1 : exponent
#    t2 : offset
#    t3 : 15 - exponent
#######################################################
    # Leaf routine: ra is never modified, so no stack frame is needed.
    srli    t1, a0, 4               # exponent = fl >> 4
    andi    t0, a0, 0x0f            # mantissa = fl & 0x0f

    li      t3, 15
    sub     t3, t3, t1              # 15 - exponent
    li      t2, 0x7fff
    srl     t2, t2, t3              # 0x7fff >> (15 - exponent)
    slli    t2, t2, 4               # offset = ... << 4

    sll     t0, t0, t1              # mantissa << exponent
    add     a0, t2, t0              # return offset + (mantissa << exponent)
    ret
uf8_encode:
#######################################################
# < Function >
#    uf8_encode: compress a 32-bit value into an 8-bit uf8 code
#
# < Parameters >
#    a0 : value
#
# < Return Value >
#    a0 : uf8 code in the low 8 bits
#         (high nibble = exponent, low nibble = mantissa)
#
# < Clobbers >
#    a1, t0-t5 (plus whatever clz clobbers)
#######################################################
# < Local Variable >
#    a1 : lz  (result of clz)
#    t1 : msb
#    t2 : exponent
#    t3 : overflow (smallest value that uses `exponent`)
#    t4 : scratch constant / loop counter e
#    t5 : next_overflow, later mantissa
#######################################################
    li      t0, 16
    bltu    a0, t0, uf8_encode_ret  # if value < 16, return value

    ############### Call Function Procedure ###############
    # Save ra and value: a0 is caller-saved, so do not rely on clz keeping it.
    addi    sp, sp, -8
    sw      ra, 4(sp)
    sw      a0, 0(sp)
    mv      a1, a0
    jal     ra, clz                 # a1 = clz(value)
    lw      a0, 0(sp)
    lw      ra, 4(sp)
    addi    sp, sp, 8
    #######################################################

    li      t1, 31
    sub     t1, t1, a1              # msb = 31 - lz
    li      t2, 0                   # exponent = 0
    li      t3, 0                   # overflow = 0

    li      t4, 5
    blt     t1, t4, en_endif1       # if (msb >= 5) { ... }

    addi    t2, t1, -4              # exponent = msb - 4
    li      t4, 15
    bge     t4, t2, en_endif2       # exponent <= 15: keep the estimate
    li      t2, 15                  # if (exponent > 15) exponent = 15
en_endif2:
    li      t4, 0                   # e = 0
if1_for:
    bge     t4, t2, while1          # for (e = 0; e < exponent; e++)
    slli    t3, t3, 1
    addi    t3, t3, 16              #     overflow = (overflow << 1) + 16
    addi    t4, t4, 1
    j       if1_for

while1:
    beq     t2, x0, en_endif1       # while (exponent > 0 &&
    bgeu    a0, t3, en_endif1       #        value < overflow) {
    addi    t3, t3, -16
    srli    t3, t3, 1               #     overflow = (overflow - 16) >> 1
    addi    t2, t2, -1              #     exponent--
    j       while1                  # }

en_endif1:
    li      t4, 15
in_while2:
    bge     t2, t4, end_while2      # while (exponent < 15) {
    slli    t5, t3, 1
    addi    t5, t5, 16              #     next_overflow = (overflow << 1) + 16
    bltu    a0, t5, end_while2      #     if (value < next_overflow) break
    mv      t3, t5                  #     overflow = next_overflow
    addi    t2, t2, 1               #     exponent++
    j       in_while2               # }

end_while2:
    sub     t5, a0, t3
    srl     t5, t5, t2              # mantissa = (value - overflow) >> exponent
    slli    t2, t2, 4
    or      a0, t2, t5              # (exponent << 4) | mantissa
    andi    a0, a0, 0xff            # truncate to 8 bits like the C uf8 return type
uf8_encode_ret:
    ## return
    ret
```
## Problem C
### C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
/* 16-bit value manipulated as sign (bit 15), exponent (bits 14..7) and
 * mantissa (bits 6..0), matching the three masks below. */
typedef struct {
uint16_t bits;
} bf16_t;
/* Field masks over bf16_t.bits. */
#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
/* Exponent bias used by the arithmetic routines. */
#define BF16_EXP_BIAS 127
/* Constant constructors: exponent all ones with the top mantissa bit set,
 * and all bits clear. */
#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})
/* True when the exponent field is all ones and the mantissa is non-zero. */
static inline bool bf16_isnan(bf16_t a)
{
    if ((a.bits & BF16_EXP_MASK) != BF16_EXP_MASK)
        return false;
    return (a.bits & BF16_MANT_MASK) != 0;
}
/* True when the exponent field is all ones and the mantissa is zero. */
static inline bool bf16_isinf(bf16_t a)
{
    if ((a.bits & BF16_EXP_MASK) != BF16_EXP_MASK)
        return false;
    return (a.bits & BF16_MANT_MASK) == 0;
}
/* True for +0 and -0: every bit except the sign bit is clear. */
static inline bool bf16_iszero(bf16_t a)
{
    uint16_t magnitude = a.bits & 0x7FFF;
    return magnitude == 0;
}
static inline bf16_t f32_to_bf16(float val)
{
uint32_t f32bits;
memcpy(&f32bits, &val, sizeof(float));
if (((f32bits >> 23) & 0xFF) == 0xFF)
return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
return (bf16_t) {.bits = f32bits >> 16};
}
/* Widen a bf16_t to float: its 16 bits become the upper half of the float's
 * bit pattern and the lower half is zero. */
static inline float bf16_to_f32(bf16_t val)
{
    float out;
    uint32_t widened = (uint32_t) val.bits;

    widened <<= 16;
    memcpy(&out, &widened, sizeof(float));
    return out;
}
/*
 * Add two bf16 values using integer operations only.
 *
 * Handles exponent-all-ones operands and zeros first, aligns the 8-bit
 * significands by the exponent difference (no guard bits are kept, so the
 * shifted-out bits are simply dropped), then adds or subtracts depending on
 * the signs and renormalises.
 */
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
/* Unpack both operands into sign / exponent / mantissa fields. */
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
/* a has exponent 0xFF: NaN a is returned as is; for infinite a, the result
 * is b when b is NaN or an infinity of the same sign, NaN for infinities of
 * opposite sign, and a otherwise. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (exp_b == 0xFF)
return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
return a;
}
/* Only b has exponent 0xFF: it decides the result. */
if (exp_b == 0xFF)
return b;
/* A zero operand returns the other operand unchanged. */
if (!exp_a && !mant_a)
return b;
if (!exp_b && !mant_b)
return a;
/* Operands with a non-zero exponent get the implicit leading 1 (bit 7). */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
int16_t exp_diff = exp_a - exp_b;
uint16_t result_sign;
int16_t result_exp;
uint32_t result_mant;
/* Align to the larger exponent; a gap of more than 8 returns the operand
 * with the larger exponent unchanged. */
if (exp_diff > 0) {
result_exp = exp_a;
if (exp_diff > 8)
return a;
mant_b >>= exp_diff;
} else if (exp_diff < 0) {
result_exp = exp_b;
if (exp_diff < -8)
return b;
mant_a >>= -exp_diff;
} else {
result_exp = exp_a;
}
if (sign_a == sign_b) {
/* Same sign: add magnitudes; a carry into bit 8 shifts right and bumps the
 * exponent, returning a signed infinity once it reaches 0xFF. */
result_sign = sign_a;
result_mant = (uint32_t) mant_a + mant_b;
if (result_mant & 0x100) {
result_mant >>= 1;
if (++result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
} else {
/* Opposite signs: subtract the smaller aligned magnitude from the larger;
 * the sign follows the larger one. */
if (mant_a >= mant_b) {
result_sign = sign_a;
result_mant = mant_a - mant_b;
} else {
result_sign = sign_b;
result_mant = mant_b - mant_a;
}
/* Exact cancellation yields +0. */
if (!result_mant)
return BF16_ZERO();
/* Shift left until bit 7 is set again; if the exponent drops to 0 or below
 * on the way, the result is returned as +0. */
while (!(result_mant & 0x80)) {
result_mant <<= 1;
if (--result_exp <= 0)
return BF16_ZERO();
}
}
/* Repack; the implicit bit is dropped by the 0x7F mask. */
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F),
};
}
/* a - b, computed as a + (-b): the sign bit of b is flipped and the pair is
 * handed to bf16_add. */
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    bf16_t negated = {.bits = b.bits ^ BF16_SIGN_MASK};
    return bf16_add(a, negated);
}
/*
 * Multiply two bf16 values using integer operations only.
 *
 * The result sign is the XOR of the operand signs.  Exponent-all-ones
 * operands and zeros are handled first; subnormal inputs are normalised
 * before the 8x8-bit significand product is formed and repacked.
 */
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
/* Unpack both operands. */
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* a has exponent 0xFF: NaN a is returned; Inf * 0 is NaN; otherwise a
 * signed infinity. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
if (!exp_b && !mant_b)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Same handling with the roles of a and b swapped. */
if (exp_b == 0xFF) {
if (mant_b)
return b;
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Either operand zero: signed zero. */
if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
return (bf16_t) {.bits = result_sign << 15};
/* Normalise: a zero exponent field means no implicit 1, so shift the
 * mantissa up until bit 7 is set and record the shifts in exp_adjust;
 * otherwise just add the implicit 1. */
int16_t exp_adjust = 0;
if (!exp_a) {
while (!(mant_a & 0x80)) {
mant_a <<= 1;
exp_adjust--;
}
exp_a = 1;
} else
mant_a |= 0x80;
if (!exp_b) {
while (!(mant_b & 0x80)) {
mant_b <<= 1;
exp_adjust--;
}
exp_b = 1;
} else
mant_b |= 0x80;
/* Product of two 8-bit significands (at most 16 bits); the biased
 * exponents add, minus one bias, plus the normalisation adjustment. */
uint32_t result_mant = (uint32_t) mant_a * mant_b;
int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;
/* Bit 15 set means the product needs one extra right shift and an exponent
 * increment.  Either way the low bits are truncated and only the 7 stored
 * mantissa bits are kept. */
if (result_mant & 0x8000) {
result_mant = (result_mant >> 8) & 0x7F;
result_exp++;
} else
result_mant = (result_mant >> 7) & 0x7F;
/* Exponent overflow: signed infinity. */
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
/* Exponent underflow: signed zero below -6, otherwise shift into the
 * subnormal range.
 * NOTE(review): result_mant was already masked to 7 bits above, so this
 * shift operates without the implicit leading 1 - confirm that is intended. */
if (result_exp <= 0) {
if (result_exp < -6)
return (bf16_t) {.bits = result_sign << 15};
result_mant >>= (1 - result_exp);
result_exp = 0;
}
return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(result_mant & 0x7F)};
}
/*
 * Divide a by b using integer operations only.
 *
 * The result sign is the XOR of the operand signs.  Special operands are
 * handled first; the quotient of the significands is then produced one bit
 * at a time by shift-and-subtract and repacked.
 */
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
/* Unpack both operands. */
uint16_t sign_a = (a.bits >> 15) & 1;
uint16_t sign_b = (b.bits >> 15) & 1;
int16_t exp_a = ((a.bits >> 7) & 0xFF);
int16_t exp_b = ((b.bits >> 7) & 0xFF);
uint16_t mant_a = a.bits & 0x7F;
uint16_t mant_b = b.bits & 0x7F;
uint16_t result_sign = sign_a ^ sign_b;
/* Divisor has exponent 0xFF: NaN b is returned; Inf/Inf is NaN; anything
 * else divided by infinity is a signed zero. */
if (exp_b == 0xFF) {
if (mant_b)
return b;
/* Inf/Inf = NaN */
if (exp_a == 0xFF && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = result_sign << 15};
}
/* Divisor is zero: 0/0 is NaN, anything else is a signed infinity. */
if (!exp_b && !mant_b) {
if (!exp_a && !mant_a)
return BF16_NAN();
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Dividend has exponent 0xFF: NaN a is returned, infinity stays a signed
 * infinity. */
if (exp_a == 0xFF) {
if (mant_a)
return a;
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
}
/* Dividend is zero: signed zero. */
if (!exp_a && !mant_a)
return (bf16_t) {.bits = result_sign << 15};
/* Operands with a non-zero exponent get the implicit leading 1 (bit 7). */
if (exp_a)
mant_a |= 0x80;
if (exp_b)
mant_b |= 0x80;
/* Shift-and-subtract division of (mant_a << 15) by mant_b: 16 iterations,
 * one quotient bit each, most significant first. */
uint32_t dividend = (uint32_t) mant_a << 15;
uint32_t divisor = mant_b;
uint32_t quotient = 0;
for (int i = 0; i < 16; i++) {
quotient <<= 1;
if (dividend >= (divisor << (15 - i))) {
dividend -= (divisor << (15 - i));
quotient |= 1;
}
}
/* Biased exponents subtract and one bias is added back; operands with a
 * zero exponent field shift the result exponent by one. */
int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;
if (!exp_a)
result_exp--;
if (!exp_b)
result_exp++;
/* Bring the leading 1 to bit 15 (without letting the exponent drop below 1),
 * then drop the low 8 bits so the significand sits in bits 7..0. */
if (quotient & 0x8000)
quotient >>= 8;
else {
while (!(quotient & 0x8000) && result_exp > 1) {
quotient <<= 1;
result_exp--;
}
quotient >>= 8;
}
/* Keep only the 7 stored mantissa bits. */
quotient &= 0x7F;
/* Exponent overflow gives a signed infinity; exponent <= 0 gives a signed
 * zero. */
if (result_exp >= 0xFF)
return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
if (result_exp <= 0)
return (bf16_t) {.bits = result_sign << 15};
return (bf16_t) {
.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
(quotient & 0x7F),
};
}
/*
 * Square root of a bf16 value using integer operations only.
 *
 * Special operands (exponent 0xFF, zero, negative, exponent field 0) return
 * early.  Otherwise the unbiased exponent is halved, the significand is
 * doubled when that exponent is odd, and a binary search over [90, 256]
 * finds the largest `mid` with (mid * mid) / 128 <= m.
 */
static inline bf16_t bf16_sqrt(bf16_t a)
{
uint16_t sign = (a.bits >> 15) & 1;
int16_t exp = ((a.bits >> 7) & 0xFF);
uint16_t mant = a.bits & 0x7F;
/* Handle special cases */
if (exp == 0xFF) {
if (mant)
return a; /* NaN propagation */
if (sign)
return BF16_NAN(); /* sqrt(-Inf) = NaN */
return a; /* sqrt(+Inf) = +Inf */
}
/* sqrt(0) = 0 (handle both +0 and -0) */
if (!exp && !mant)
return BF16_ZERO();
/* sqrt of negative number is NaN */
if (sign)
return BF16_NAN();
/* Flush denormals to zero */
if (!exp)
return BF16_ZERO();
/* Direct bit manipulation square root algorithm */
/* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
int32_t e = exp - BF16_EXP_BIAS;
int32_t new_exp;
/* Get full mantissa with implicit 1 */
uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */
/* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
if (e & 1) {
m <<= 1; /* Double mantissa for odd exponent */
new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
} else {
new_exp = (e >> 1) + BF16_EXP_BIAS;
}
/* Now m is in range [128, 256) or [256, 512) if exponent was odd */
/* Binary search for integer square root */
/* We want result where result^2 = m * 128 (since 128 represents 1.0) */
uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */
uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */
uint32_t result = 128; /* Default */
/* Binary search for square root of m */
while (low <= high) {
uint32_t mid = (low + high) >> 1;
uint32_t sq = (mid * mid) / 128; /* Square and scale */
if (sq <= m) {
result = mid; /* This could be our answer */
low = mid + 1;
} else {
high = mid - 1;
}
}
/* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
/* But we need to adjust the scale */
/* Since m is scaled where 128=1.0, result should also be scaled same way */
/* Normalize to ensure result is in [128, 256) */
if (result >= 256) {
result >>= 1;
new_exp++;
} else if (result < 128) {
while (result < 128 && new_exp > 1) {
result <<= 1;
new_exp--;
}
}
/* Extract 7-bit mantissa (remove implicit 1) */
uint16_t new_mant = result & 0x7F;
/* Check for overflow/underflow */
if (new_exp >= 0xFF)
return (bf16_t) {.bits = 0x7F80}; /* +Inf */
if (new_exp <= 0)
return BF16_ZERO();
return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
/* Equality: NaN never compares equal, +0 and -0 compare equal, anything else
 * compares by bit pattern. */
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
    bool any_nan = bf16_isnan(a) || bf16_isnan(b);
    if (any_nan)
        return false;

    bool both_zero = bf16_iszero(a) && bf16_iszero(b);
    return both_zero ? true : (a.bits == b.bits);
}
/*
 * Strict less-than.  False when either operand is NaN or both are zero.
 * With differing signs the negative operand is the smaller one; with equal
 * signs the raw bit patterns are compared, reversed for negatives.
 */
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return false;

    bool neg_a = (a.bits >> 15) & 1;
    bool neg_b = (b.bits >> 15) & 1;

    if (neg_a != neg_b)
        return neg_a; /* a < b exactly when a is the negative one */
    if (neg_a)
        return a.bits > b.bits;
    return a.bits < b.bits;
}
/* Strict greater-than, expressed as bf16_lt with the operands swapped. */
static inline bool bf16_gt(bf16_t a, bf16_t b)
{
    bool b_below_a = bf16_lt(b, a);
    return b_below_a;
}
#include <stdio.h>
#include <time.h>
/*
 * TEST_ASSERT(cond, msg): when cond is false, print "FAIL: <msg>" and make
 * the ENCLOSING function return 1.  Only usable inside functions that
 * return int.  The do { } while (0) wrapper lets it be used as a single
 * statement.
 */
#define TEST_ASSERT(cond, msg) \
do { \
if (!(cond)) { \
printf("FAIL: %s\n", msg); \
return 1; \
} \
} while (0)
/*
 * Round-trip a set of floats through bf16 and back.
 *
 * For every non-zero input the sign must survive, and (unless the bf16
 * value is infinite) the relative error must stay below 1%.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_basic_conversions(void)
{
    printf("Testing basic conversions...\n");
    float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f,
                           -0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f};
    for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) {
        float orig = test_values[i];
        bf16_t bf = f32_to_bf16(orig);
        float conv = bf16_to_f32(bf);
        if (orig != 0.0f) {
            TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch");
        }
        if (orig != 0.0f && !bf16_isinf(bf)) {
            /* Take the magnitude of the whole ratio: dividing |diff| by a
             * negative orig gave a value <= 0, so the check could never
             * fail for negative inputs. */
            float rel_error = (conv - orig) / orig;
            if (rel_error < 0)
                rel_error = -rel_error;
            TEST_ASSERT(rel_error < 0.01f, "Relative error too large");
        }
    }
    printf(" Basic conversions: PASS\n");
    return 0;
}
/*
 * Check the classification predicates on +Inf, -Inf, NaN, +0 and -0.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_special_values(void)
{
    printf("Testing special values...\n");

    const bf16_t plus_inf = {.bits = 0x7F80};  /* +Infinity */
    const bf16_t minus_inf = {.bits = 0xFF80}; /* -Infinity */
    const bf16_t not_a_number = BF16_NAN();
    const bf16_t plus_zero = f32_to_bf16(0.0f);
    const bf16_t minus_zero = f32_to_bf16(-0.0f);

    TEST_ASSERT(bf16_isinf(plus_inf), "Positive infinity not detected");
    TEST_ASSERT(!bf16_isnan(plus_inf), "Infinity detected as NaN");

    TEST_ASSERT(bf16_isinf(minus_inf), "Negative infinity not detected");

    TEST_ASSERT(bf16_isnan(not_a_number), "NaN not detected");
    TEST_ASSERT(!bf16_isinf(not_a_number), "NaN detected as infinity");

    TEST_ASSERT(bf16_iszero(plus_zero), "Zero not detected");
    TEST_ASSERT(bf16_iszero(minus_zero), "Negative zero not detected");

    printf(" Special values: PASS\n");
    return 0;
}
/*
 * Spot-check add, sub, mul, div and sqrt against known float results.
 * Each check compares the absolute error with a fixed tolerance.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_arithmetic(void)
{
    printf("Testing arithmetic operations...\n");

    float err;

    /* 1 + 2 and 2 - 1 */
    bf16_t one = f32_to_bf16(1.0f);
    bf16_t two = f32_to_bf16(2.0f);

    err = bf16_to_f32(bf16_add(one, two)) - 3.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.01f, "Addition failed");

    err = bf16_to_f32(bf16_sub(two, one)) - 1.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.01f, "Subtraction failed");

    /* 3 * 4 */
    bf16_t three = f32_to_bf16(3.0f);
    bf16_t four = f32_to_bf16(4.0f);

    err = bf16_to_f32(bf16_mul(three, four)) - 12.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.1f, "Multiplication failed");

    /* 10 / 2 */
    bf16_t ten = f32_to_bf16(10.0f);
    bf16_t divisor = f32_to_bf16(2.0f);

    err = bf16_to_f32(bf16_div(ten, divisor)) - 5.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.1f, "Division failed");

    /* Test square root */
    err = bf16_to_f32(bf16_sqrt(f32_to_bf16(4.0f))) - 2.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.01f, "sqrt(4) failed");

    err = bf16_to_f32(bf16_sqrt(f32_to_bf16(9.0f))) - 3.0f;
    TEST_ASSERT((err < 0 ? -err : err) < 0.01f, "sqrt(9) failed");

    printf(" Arithmetic: PASS\n");
    return 0;
}
/*
 * Check bf16_eq / bf16_lt / bf16_gt on ordinary values and on NaN.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_comparisons(void)
{
    printf("Testing comparison operations...\n");

    const bf16_t one = f32_to_bf16(1.0f);
    const bf16_t two = f32_to_bf16(2.0f);
    const bf16_t one_again = f32_to_bf16(1.0f);

    /* equality */
    TEST_ASSERT(bf16_eq(one, one_again), "Equality test failed");
    TEST_ASSERT(!bf16_eq(one, two), "Inequality test failed");

    /* ordering */
    TEST_ASSERT(bf16_lt(one, two), "Less than test failed");
    TEST_ASSERT(!bf16_lt(two, one), "Not less than test failed");
    TEST_ASSERT(!bf16_lt(one, one_again), "Equal not less than test failed");
    TEST_ASSERT(bf16_gt(two, one), "Greater than test failed");
    TEST_ASSERT(!bf16_gt(one, two), "Not greater than test failed");

    /* NaN is unordered and unequal, even to itself */
    const bf16_t not_a_number = BF16_NAN();
    TEST_ASSERT(!bf16_eq(not_a_number, not_a_number), "NaN equality test failed");
    TEST_ASSERT(!bf16_lt(not_a_number, one), "NaN less than test failed");
    TEST_ASSERT(!bf16_gt(not_a_number, one), "NaN greater than test failed");

    printf(" Comparisons: PASS\n");
    return 0;
}
/*
 * Check behaviour at the extremes: a tiny input, overflow on multiply and
 * underflow on divide.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_edge_cases(void)
{
    printf("Testing edge cases...\n");

    /* Tiny input: must convert to zero or to something below 1e-37. */
    bf16_t tiny_bf = f32_to_bf16(1e-45f);
    float tiny_back = bf16_to_f32(tiny_bf);
    float tiny_mag = tiny_back < 0 ? -tiny_back : tiny_back;
    TEST_ASSERT(bf16_iszero(tiny_bf) || tiny_mag < 1e-37f,
                "Tiny value handling");

    /* 1e38 * 10 must overflow to infinity. */
    bf16_t big = f32_to_bf16(1e38f);
    bf16_t bigger = bf16_mul(big, f32_to_bf16(10.0f));
    TEST_ASSERT(bf16_isinf(bigger), "Overflow should produce infinity");

    /* 1e-38 / 1e10 must underflow. */
    bf16_t little = f32_to_bf16(1e-38f);
    bf16_t littler = bf16_div(little, f32_to_bf16(1e10f));
    float littler_back = bf16_to_f32(littler);
    float littler_mag = littler_back < 0 ? -littler_back : littler_back;
    TEST_ASSERT(bf16_iszero(littler) || littler_mag < 1e-45f,
                "Underflow should produce zero or denormal");

    printf(" Edge cases: PASS\n");
    return 0;
}
/*
 * Check that an exactly representable value survives the round trip and
 * that a nearby non-representable value lands within 0.001 of the input.
 * Returns 0 on success, 1 on the first failed assertion.
 */
static int test_rounding(void)
{
    printf("Testing rounding behavior...\n");

    const float representable = 1.5f;
    float representable_back = bf16_to_f32(f32_to_bf16(representable));
    TEST_ASSERT(representable_back == representable,
                "Exact representation should be preserved");

    const float inexact = 1.0001f;
    float inexact_back = bf16_to_f32(f32_to_bf16(inexact));
    float delta = inexact_back - inexact;
    TEST_ASSERT((delta < 0 ? -delta : delta) < 0.001f, "Rounding error should be small");

    printf(" Rounding: PASS\n");
    return 0;
}
#ifndef BFLOAT16_NO_MAIN
int main(void)
{
printf("\n=== bfloat16 Test Suite ===\n\n");
int failed = 0;
failed |= test_basic_conversions();
failed |= test_special_values();
failed |= test_arithmetic();
failed |= test_comparisons();
failed |= test_edge_cases();
failed |= test_rounding();
if (failed) {
printf("\n=== TESTS FAILED ===\n");
return 1;
}
printf("\n=== ALL TESTS PASSED ===\n");
return 0;
}
#endif /* BFLOAT16_NO_MAIN */
```
### Assembly
```asm
.data
.text
setup:
    # Program entry: there is no caller, so ra and sp are seeded by hand.
    # NOTE(review): ra = -1 is presumably a sentinel that makes the final
    # `ret` of main leave the program - confirm against the simulator used.
    li      ra, -1
    li      sp, 0x7ffffff0
    # execution falls through into main
main:
#######################################################
# < Function >
#    main procedure: runs every test group and publishes the verdict
#
# < Parameters >
#    NULL
#
# < Return Value >
#    a0 : 0 if every test group returned 0, 1 otherwise
#
# < Side effects >
#    stores 66 (pass) or 88 (fail) to address 0x01000000
#######################################################
# < Local Variable >
#    s0 : address of the status word (0x01000000)
#    s1 : failed (bitwise OR of all test return values)
#######################################################
# NOTE(review): the six test routines called below are not defined in the
# visible part of this listing; `EDGE_CASED` may be a typo for EDGE_CASES -
# confirm against the definitions.
    ## Save ra & Callee Saved
    addi    sp, sp, -12
    sw      ra, 8(sp)
    sw      s0, 4(sp)
    sw      s1, 0(sp)

    li      s0, 0x01000000          # status word address (kept apart from `failed`)
    li      s1, 0                   # failed = 0

    jal     ra, BASIC_CONVERSIONS
    or      s1, s1, a0              # failed |= test_basic_conversions()
    jal     ra, SPECIAL_VALUES
    or      s1, s1, a0              # failed |= test_special_values()
    jal     ra, ARITHMETIC
    or      s1, s1, a0              # failed |= test_arithmetic()
    jal     ra, COMPARISONS
    or      s1, s1, a0              # failed |= test_comparisons()
    jal     ra, EDGE_CASED
    or      s1, s1, a0              # failed |= test_edge_cases()
    jal     ra, ROUNDING
    or      s1, s1, a0              # failed |= test_rounding()

    beq     s1, x0, main_pass       # nothing failed

    li      t0, 88                  # if not pass, store 88 to 0x01000000
    sw      t0, 0(s0)
    li      a0, 1                   # return 1
    j       main_exit

main_pass:
    li      t0, 66                  # if pass, store 66 to 0x01000000
    sw      t0, 0(s0)
    li      a0, 0                   # return 0

main_exit:
    ## Retrieve ra & Callee Saved
    lw      ra, 8(sp)
    lw      s0, 4(sp)
    lw      s1, 0(sp)
    addi    sp, sp, 12
    ## return
    ret
bf16_isnan:
#######################################################
# < Function >
#    bf16_isnan: exponent field all ones AND mantissa non-zero
#
# < Parameters >
#    a0 : a (bf16 bit pattern)
#
# < Return Value >
#    a0 : 1 if a is NaN, else 0
#
# < Clobbers >
#    t0, t1
#######################################################
# < Local Variable >
#    t0 : exponent test result
#    t1 : mask constant, then mantissa test result
#######################################################
    # 0x7f80 does not fit the 12-bit immediate of andi/xori, so it is
    # materialised in a register first.
    li      t1, 0x7f80              # BF16_EXP_MASK
    and     t0, a0, t1              # t0 = a & BF16_EXP_MASK
    xor     t0, t0, t1              # 0 iff the exponent field is all ones
    sltiu   t0, t0, 1               # t0 = (exponent == all ones)
    andi    t1, a0, 0x7f            # t1 = a & BF16_MANT_MASK
    sltu    t1, x0, t1              # t1 = (mantissa != 0), reduced to 0/1 before the AND
    and     a0, t0, t1              # logical AND of the two booleans
    ## return
    ret
bf16_isinf:
#######################################################
# < Function >
#    bf16_isinf: exponent field all ones AND mantissa zero
#
# < Parameters >
#    a0 : a (bf16 bit pattern)
#
# < Return Value >
#    a0 : 1 if a is +Inf or -Inf, else 0
#
# < Clobbers >
#    t0, t1
#######################################################
# < Local Variable >
#    t0 : exponent test result
#    t1 : mask constant, then mantissa test result
#######################################################
    # 0x7f80 does not fit the 12-bit immediate of andi/xori, so it is
    # materialised in a register first.
    li      t1, 0x7f80              # BF16_EXP_MASK
    and     t0, a0, t1              # t0 = a & BF16_EXP_MASK
    xor     t0, t0, t1              # 0 iff the exponent field is all ones
    sltiu   t0, t0, 1               # t0 = (exponent == all ones)
    andi    t1, a0, 0x7f            # t1 = a & BF16_MANT_MASK
    sltiu   t1, t1, 1               # t1 = (mantissa == 0): logical NOT, not a bit flip
    and     a0, t0, t1              # logical AND of the two booleans
    ## return
    ret
bf16_iszero:
#######################################################
# < Function >
#    bf16_iszero: every bit except the sign bit is clear
#
# < Parameters >
#    a0 : a (bf16 bit pattern)
#
# < Return Value >
#    a0 : 1 if a is +0 or -0, else 0
#
# < Clobbers >
#    t0
#######################################################
# < Local Variable >
#    t0 : mask constant, then masked value
#######################################################
    # 0x7fff does not fit the 12-bit immediate of andi, so it is
    # materialised in a register first.
    li      t0, 0x7fff
    and     t0, a0, t0              # t0 = a & 0x7FFF
    sltiu   a0, t0, 1               # return (t0 == 0): logical NOT, not a bit flip
    ## return
    ret
f32_to_bf16:
#######################################################
# < Function >
#    f32_to_bf16: keep the upper 16 bits of a float bit pattern
#
# < Parameters >
#    a0 : val (raw 32-bit float bit pattern)
#
# < Return Value >
#    a0 : bf16 bit pattern in the low 16 bits
#
# < Clobbers >
#    t0, t1
#######################################################
# < Local Variable >
#    t0 : exponent field, then rounding increment
#    t1 : constants that do not fit a 12-bit immediate
#######################################################
    srli    t0, a0, 23
    andi    t0, t0, 0xff            # t0 = (val >> 23) & 0xFF
    li      t1, 0xff
    beq     t0, t1, ftobf_truncate  # exponent all ones: copy the upper half unrounded

    # f32bits += ((f32bits >> 16) & 1) + 0x7FFF
    srli    t0, a0, 16
    andi    t0, t0, 1               # lowest bit that will be kept
    li      t1, 0x7fff              # too large for an addi immediate
    add     t0, t0, t1
    add     a0, a0, t0

ftobf_truncate:
    srli    a0, a0, 16              # upper half of the word
    li      t1, 0xffff              # too large for an andi immediate
    and     a0, a0, t1              # & 0xFFFF, as in the C version
    ## return
    ret
bf16_to_f32:
#######################################################
# < Function >
#    bf16_to_f32: widen a bf16 bit pattern to a float bit pattern
#
# < Parameters >
#    a0 : val (bf16 bit pattern)
#
# < Return Value >
#    a0 : val << 16
#
# < Clobbers >
#    t0
#######################################################
    # Leaf routine: ra is never modified, so no stack frame is needed.
    slli    t0, a0, 16              # f32bits = val << 16
    mv      a0, t0                  # return f32bits
    ret
bf16_add:
#######################################################
# < Function >
#    bf16_add: a + b on bf16 bit patterns, integer-only
#    (translation of the C bf16_add in this document; the
#     original body was an empty stub returning t0)
#
# < Parameters >
#    a0 : a (bf16 bit pattern, upper bits zero)
#    a1 : b (bf16 bit pattern, upper bits zero)
#
# < Return Value >
#    a0 : bf16 bit pattern of the sum
#
# < Clobbers >
#    t0-t6, a2, a3
#######################################################
# < Local Variable >
#    t0 : sign_a, later result_sign
#    t1 : sign_b
#    t2 : exp_a
#    t3 : exp_b
#    t4 : mant_a, later result_mant
#    t5 : mant_b
#    t6 : scratch / exp_diff
#    a2 : result_exp
#    a3 : scratch constant
#######################################################
    # Leaf routine: ra is never modified, so no stack frame is needed.
    ## unpack both operands
    srli    t0, a0, 15
    andi    t0, t0, 1               # sign_a
    srli    t1, a1, 15
    andi    t1, t1, 1               # sign_b
    srli    t2, a0, 7
    andi    t2, t2, 0xff            # exp_a
    srli    t3, a1, 7
    andi    t3, t3, 0xff            # exp_b
    andi    t4, a0, 0x7f            # mant_a
    andi    t5, a1, 0x7f            # mant_b

    ## a has exponent 0xFF (Inf or NaN)
    li      t6, 0xff
    bne     t2, t6, bf16_add_a_finite
    bne     t4, x0, bf16_add_ret_a  # a is NaN: return a
    bne     t3, t6, bf16_add_ret_a  # Inf + finite: return a
    bne     t5, x0, bf16_add_ret_b  # b is NaN: return b
    beq     t0, t1, bf16_add_ret_b  # infinities of the same sign: return b
    li      a0, 0x7fc0              # Inf + (-Inf) = NaN
    ret

bf16_add_a_finite:
    beq     t3, t6, bf16_add_ret_b  # only b has exponent 0xFF: return b
    or      t6, t2, t4
    beq     t6, x0, bf16_add_ret_b  # a is zero: return b
    or      t6, t3, t5
    beq     t6, x0, bf16_add_ret_a  # b is zero: return a

    ## implicit leading 1 for operands with a non-zero exponent
    beq     t2, x0, bf16_add_a_hidden_done
    ori     t4, t4, 0x80
bf16_add_a_hidden_done:
    beq     t3, x0, bf16_add_b_hidden_done
    ori     t5, t5, 0x80
bf16_add_b_hidden_done:

    ## align the significands to the larger exponent
    sub     t6, t2, t3              # exp_diff = exp_a - exp_b
    mv      a2, t2                  # result_exp = exp_a (exp_diff >= 0 case)
    beq     t6, x0, bf16_add_aligned
    blt     t6, x0, bf16_add_b_larger
    li      a3, 8
    blt     a3, t6, bf16_add_ret_a  # exp_diff > 8: return a
    srl     t5, t5, t6              # mant_b >>= exp_diff
    j       bf16_add_aligned
bf16_add_b_larger:
    mv      a2, t3                  # result_exp = exp_b
    sub     t6, x0, t6              # -exp_diff
    li      a3, 8
    blt     a3, t6, bf16_add_ret_b  # exp_diff < -8: return b
    srl     t4, t4, t6              # mant_a >>= -exp_diff

bf16_add_aligned:
    bne     t0, t1, bf16_add_opposite

    ## same sign: add magnitudes (result_sign = sign_a, already in t0)
    add     t4, t4, t5              # result_mant = mant_a + mant_b
    andi    t6, t4, 0x100
    beq     t6, x0, bf16_add_pack   # no carry out of bit 7
    srli    t4, t4, 1               # result_mant >>= 1
    addi    a2, a2, 1               # ++result_exp
    li      t6, 0xff
    blt     a2, t6, bf16_add_pack
    slli    a0, t0, 15              # exponent overflow: signed infinity
    li      t6, 0x7f80
    or      a0, a0, t6
    ret

bf16_add_opposite:
    ## opposite signs: larger magnitude minus smaller, sign of the larger
    bltu    t4, t5, bf16_add_b_dominates
    sub     t4, t4, t5              # result_mant = mant_a - mant_b, sign = sign_a
    j       bf16_add_normalize
bf16_add_b_dominates:
    sub     t4, t5, t4              # result_mant = mant_b - mant_a
    mv      t0, t1                  # result_sign = sign_b
bf16_add_normalize:
    beq     t4, x0, bf16_add_ret_zero   # exact cancellation: +0
bf16_add_norm_loop:
    andi    t6, t4, 0x80
    bne     t6, x0, bf16_add_pack   # bit 7 set: normalised
    slli    t4, t4, 1               # result_mant <<= 1
    addi    a2, a2, -1              # --result_exp
    bge     x0, a2, bf16_add_ret_zero   # result_exp <= 0: +0
    j       bf16_add_norm_loop

bf16_add_pack:
    ## (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)
    slli    a0, t0, 15
    andi    a2, a2, 0xff
    slli    a2, a2, 7
    or      a0, a0, a2
    andi    t4, t4, 0x7f
    or      a0, a0, t4
    ret

bf16_add_ret_zero:
    li      a0, 0                   # BF16_ZERO()
    ret

bf16_add_ret_b:
    mv      a0, a1                  # return b
bf16_add_ret_a:
    ## return (a0 already holds a on the return-a paths)
    ret
```