# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by [<`Jackiempty`>](https://github.com/Jackiempty)

## Problem B
### C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* uf8: 8-bit code, high nibble = exponent, low nibble = mantissa. */
typedef uint8_t uf8;

/*
 * Count leading zeros of x with a branch-based binary search over the
 * shift widths 16, 8, 4, 2, 1.  Returns 32 when x == 0.
 */
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    /* Here x is 0 or 1, so n - x is the final count. */
    return n - x;
}

/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    /* offset = (2^exponent - 1) * 16: first value of this exponent band */
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;

    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;

    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;

    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;

        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }

    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }

    /* Mantissa is the distance from the band start, scaled by the exponent. */
    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}

/* Test encode/decode round-trip */
/* Checks all 256 codes: decode->encode must return the same code and the
 * decoded values must be strictly increasing. */
static bool test(void)
{
    int32_t previous_value = -1;
    bool passed = true;

    for (int i = 0; i < 256; i++) {
        uint8_t fl = i;
        int32_t value = uf8_decode(fl);
        uint8_t fl2 = uf8_encode(value);

        if (fl != fl2) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2);
            passed = false;
        }

        if (value <= previous_value) {
            printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value);
            passed =
false; } previous_value = value; } return passed; } int main(void) { if (test()) { printf("All tests passed.\n"); return 0; } return 1; } ``` ### Assembly ```asm .data str_all_passed: .asciz "All tests passed.\n" str_fail1: .asciz "%02x: produces value %d but encodes back to %02x\n" str_fail2: .asciz "%02x: value %d <= previous_value %d\n" .text setup: li ra, -1 li sp, 0x7ffffff0 main: ####################################################### # < Function > # main procedure # # < Parameters > # NULL # # < Return Value > # NULL ####################################################### # < Local Variable > # s0: string ####################################################### ## Save ra & Callee Saved addi sp, sp, -12 sw ra, 8(sp) sw s0, 4(sp) sw s1, 0(sp) li s0, 0x01000000 ############### Call Function Procedure ############### # Caller Saved # Pass Arguments # Jump to Callee jal ra, FUNC_TEST ####################################################### ## Retrieve Caller Saved bne a0, x0 ,main_pass li s1, 88 # if not pass, load 88 to 0x01000000 sw s1, 0(s0) li a0, 1 # return 1 j main_exit main_pass: # la s0, str_all_passed # load string jal ra, print_str # go to print string li s1, 66 # if pass, load 66 to 0x01000000 sw s1, 0(s0) li a0, 0 # return 0 main_exit: ## Retrieve ra & Callee Saved lw ra, 8(sp) lw s0, 4(sp) sw s1, 0(sp) addi sp, sp, 12 ## return ret print_str: # print_str: s0=address of string # For simulation, replace with ecall or system call as needed ret FUNC_TEST: ####################################################### # < Function > # test # # < Parameters > # NULL # # < Return Value > # NULL ####################################################### # < Local Variable > # s0 : pass # s1 : previous_value # t0 : fl # t1 : value # t2 : fl2 # t3 : i ####################################################### ## Save ra & Callee Saved addi sp, sp, -12 sw s0, 8(sp) sw s1, 4(sp) sw ra, 0(sp) li s1, -1 # previous_value(s1) = -1 li s0, 1 # passed(s0) = true li t3, 0 # i(t3) = 
0 test_loop: li t4, 256 bge t3, t4, test_end mv t0, t3 # fl(t0) = i(t3) ############### Call Function Procedure ############### # Caller Saved addi sp, sp, -16 sw t0, 12(sp) sw t1, 8(sp) sw t2, 4(sp) sw t3, 0(sp) # Pass Arguments mv a0, t0 # Jump to Callee jal ra, uf8_decode ## Retrieve Caller Saved lw t0, 12(sp) lw t1, 8(sp) lw t2, 4(sp) lw t3, 0(sp) addi sp, sp, 16 mv t1, a0 # value(t1) = uf8_decode(fl) ####################################################### ############### Call Function Procedure ############### # Caller Saved addi sp, sp, -16 sw t0, 12(sp) sw t1, 8(sp) sw t2, 4(sp) sw t3, 0(sp) # Pass Arguments mv a0, t1 # a0 = value(t1) # Jump to Callee jal ra, uf8_encode ## Retrieve Caller Saved lw t0, 12(sp) lw t1, 8(sp) lw t2, 4(sp) lw t3, 0(sp) addi sp, sp, 16 mv t2, a0 # fl2(t2) = uf8_decode(value) ####################################################### andi t0, t0, 0xff andi t2, t2, 0xff bne t0, t2, test_fail1 endif1: bge s1, t1, test_fail2 endif2: mv s1, t1 # previous_value(s1) = value(t1) addi t3, t3, 1 # i(t3)++ j test_loop test_fail1: li s0, 0 # print fail1: skip for now j endif1 test_fail2: li s0, 0 # print fail2: skip for now j endif2 test_end: mv a0, s0 # return passed(s0) ## Retrieve ra & Callee Saved lw s0, 8(sp) lw s1, 4(sp) lw ra, 0(sp) addi sp, sp, 12 ## return ret clz: ####################################################### # < Function > # clz # # < Parameters > # a1 : x # # < Return Value > # a1 ####################################################### # < Local Variable > # t0 : n # t1 : c # t2 : y ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## function start li t0, 32 # n = 32 li t1, 16 # c = 16 clz_loop: srl t2, a1, t1 # y = x >> c beq t2, x0, clz_skip sub t0, t0, t1 # n -= c mv a1, t2 # x = y clz_skip: srli t1, t1, 1 # c >>= 1 bne t1, x0, clz_loop sub a1, t0, a1 # return n - x ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret uf8_decode: 
#######################################################
# < Function >
#    uf8_decode
#
# < Parameters >
#    a0 : fl
#
# < Return Value >
#    a0 : (mantissa << exponent) + offset
#######################################################
# < Local Variable >
#    t0 : mantissa
#    t1 : exponent
#    t2 : offset
#    t3 : 15 - exponent (scratch)
#######################################################
    ## Save ra & Callee Saved
    addi    sp, sp, -4
    sw      ra, 0(sp)

    ## function start
    andi    t0, a0, 0x0f        # mantissa = fl & 0x0f
    srli    t1, a0, 4           # exponent = fl >> 4
    li      t2, 0x7fff
    li      t3, 15
    sub     t3, t3, t1          # 15 - exponent
    srl     t2, t2, t3          # 0x7fff >> (15-exponent)
    slli    t2, t2, 4           # << 4
    sll     t0, t0, t1          # mantissa << exponent
    add     a0, t0, t2          # (mantissa << exponent) + offset

    ## Retrieve ra & Callee Saved
    lw      ra, 0(sp)
    addi    sp, sp, 4
    ## return
    ret

uf8_encode:
#######################################################
# < Function >
#    uf8_encode
#
# < Parameters >
#    a0 : value
#
# < Return Value >
#    a0 : (exponent << 4) | mantissa
#######################################################
# < Local Variable >
#    t0 : lz
#    t1 : msb
#    t2 : exponent
#    t3 : overflow
#    t4 : scratch constant / loop counter e
#    t5 : next_overflow, then mantissa
#######################################################
    ## Save ra & Callee Saved
    addi    sp, sp, -4
    sw      ra, 0(sp)

    ## function start
    li      t0, 16
    bltu    a0, t0, uf8_encode_ret  # if value < 16, return value

    ############### Call Function Procedure ###############
    # Caller Saved
    addi    sp, sp, -16
    sw      t0, 12(sp)
    sw      t1, 8(sp)
    sw      t2, 4(sp)
    sw      t3, 0(sp)
    # Pass Arguments
    mv      a1, a0              # clz takes its argument in a1 and leaves a0 untouched
    # Jump to Callee
    jal     ra, clz
    ## Retrieve Caller Saved
    lw      t0, 12(sp)
    lw      t1, 8(sp)
    lw      t2, 4(sp)
    lw      t3, 0(sp)
    addi    sp, sp, 16
    mv      t0, a1              # lz = clz(value)
    #######################################################

    li      t1, 31              # msb
    sub     t1, t1, t0          # msb(t1) = 31 - lz(t0)
    li      t2, 0               # exponent(t2) = 0
    li      t3, 0               # overflow(t3) = 0

    li      t4, 5
    bge     t1, t4, en_if1      # if (msb >= 5)
    j       en_endif1           # else
en_if1:
    addi    t2, t1, -4          # exponent = msb - 4
    li      t4, 15
    # Skip the clamp when exponent <= 15.  The original `blt t4, t2` had the
    # sense inverted: it clamped every in-range exponent to 15 and left
    # exponents above 15 unclamped.
    bge     t4, t2, en_endif2   # if (!(exponent > 15)) skip
    li      t2, 15              # exponent = 15
en_endif2:
    li      t4, 0               # e(t4) = 0
if1_for:
    bge     t4, t2, if1_for_end # if e(t4) >= exponent(t2)
    slli    t3, t3, 1
    addi    t3, t3, 16          # overflow = (overflow << 1) + 16
    addi    t4, t4, 1           # e(t4)++
    j       if1_for
if1_for_end:

while1:
    beq     t2, x0, en_endif1   # exponent == 0
    bltu    a0, t3, in_while1   # value < overflow
    j       en_endif1
in_while1:
    addi    t3, t3, -16
    srli    t3, t3, 1           # overflow = (overflow - 16) >> 1
    addi    t2, t2, -1          # exponent--
    j       while1

en_endif1:
    li      t4, 15
in_while2:
    bge     t2, t4, end_while2  # exponent >= 15
    slli    t5, t3, 1           # next_overflow(t5)
    addi    t5, t5, 16
    bltu    a0, t5, end_while2  # if value < next_overflow
    mv      t3, t5              # overflow = next_overflow
    addi    t2, t2, 1           # exponent++
    j       in_while2
end_while2:
    sub     t5, a0, t3          # mantissa(t5) = value - overflow
    srl     t5, t5, t2          #              >> exponent
    slli    t2, t2, 4
    or      a0, t2, t5          # return (exponent << 4) | mantissa

uf8_encode_ret:
    ## Retrieve ra & Callee Saved
    lw      ra, 0(sp)
    addi    sp, sp, 4
    ## return
    ret
```

## Problem C
### C code
```c
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

/* bfloat16 value stored as its raw 16-bit pattern:
 * 1 sign bit, 8 exponent bits, 7 mantissa bits (see the masks below). */
typedef struct {
    uint16_t bits;
} bf16_t;

#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127

#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})

/* NaN: exponent all ones and a non-zero mantissa. */
static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}

/* Infinity: exponent all ones and a zero mantissa (either sign). */
static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}

/* Zero: every bit except the sign bit is clear (+0 and -0). */
static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}

/* Convert float to bf16.  Inf/NaN keep their upper 16 bits; everything else
 * is rounded to nearest-even by adding 0x7FFF plus the lowest kept bit. */
static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}

/* Convert bf16 to float by placing the 16 bits in the upper half. */
static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}

/* a + b on raw bf16 fields; mantissas are aligned by right shift and the
 * shifted-out bits are discarded. */
static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b =
        ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    /* Inf/NaN operands */
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    /* A zero operand returns the other operand unchanged. */
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    /* Restore the implicit leading 1 for normal numbers. */
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;

    /* Align to the larger exponent; if the gap exceeds 8 bits the smaller
     * operand is dropped and the larger one is returned. */
    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }

    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;
        /* Carry out of bit 7: renormalise, overflow to infinity at 0xFF. */
        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        /* Opposite signs: subtract the smaller mantissa from the larger. */
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }
        if (!result_mant)
            return BF16_ZERO();
        /* Shift left until the leading 1 is back at bit 7; underflow -> 0. */
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (result_mant & 0x7F),
    };
}

/* a - b, implemented as a + (-b) by flipping b's sign bit. */
static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}

/* a * b on raw bf16 fields. */
static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    uint16_t result_sign = sign_a ^ sign_b;

    /* NaN propagates; Inf * 0 = NaN; Inf * x = signed Inf. */
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    /* Either operand zero -> signed zero (condition continues below). */
    if
        ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    /* Normalise denormal inputs, tracking the shift in exp_adjust;
     * normal inputs just get their implicit leading 1. */
    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    /* The 8x8-bit product has its leading 1 at bit 15 or bit 14. */
    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }

    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}

/* a / b on raw bf16 fields, using a 16-step restoring division. */
static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    uint16_t result_sign = sign_a ^ sign_b;

    /* Divisor is Inf/NaN */
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    /* Divisor is zero: 0/0 = NaN, otherwise signed Inf */
    if (!exp_b && !mant_b) {
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    /* Dividend is Inf/NaN */
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    /* Dividend is zero */
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};

    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;

    /* One quotient bit per iteration, most significant first. */
    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }

    int32_t result_exp = (int32_t) exp_a -
                         exp_b + BF16_EXP_BIAS;

    /* Denormal operands have an effective exponent of 1, not 0. */
    if (!exp_a)
        result_exp--;
    if (!exp_b)
        result_exp++;

    /* Bring the leading 1 to bit 15, then keep the top 8 bits. */
    if (quotient & 0x8000)
        quotient >>= 8;
    else {
        while (!(quotient & 0x8000) && result_exp > 1) {
            quotient <<= 1;
            result_exp--;
        }
        quotient >>= 8;
    }
    quotient &= 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (quotient & 0x7F),
    };
}

/* Square root on raw bf16 fields, using an integer binary search. */
static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;

    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a;              /* sqrt(+Inf) = +Inf */
    }

    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();

    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();

    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();

    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */

    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */

    uint32_t low = 90;       /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;     /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128;   /* Default */

    /* Binary search for square root of m */
    /* Keeps the largest mid with (mid * mid) / 128 <= m. */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */

        if
            (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */

    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;

    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();

    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}

/* Equality: NaN never compares equal; +0 and -0 compare equal. */
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return true;
    return a.bits == b.bits;
}

/* a < b: false if either operand is NaN or both are zero.  With equal signs
 * the raw bit patterns are compared, reversed for negative values. */
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return false;
    bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1;
    if (sign_a != sign_b)
        return sign_a > sign_b;
    return sign_a ?
a.bits > b.bits : a.bits < b.bits; } static inline bool bf16_gt(bf16_t a, bf16_t b) { return bf16_lt(b, a); } #include <stdio.h> #include <time.h> #define TEST_ASSERT(cond, msg) \ do { \ if (!(cond)) { \ printf("FAIL: %s\n", msg); \ return 1; \ } \ } while (0) static int test_basic_conversions(void) { printf("Testing basic conversions...\n"); float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f, -0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f}; for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) { float orig = test_values[i]; bf16_t bf = f32_to_bf16(orig); float conv = bf16_to_f32(bf); if (orig != 0.0f) { TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch"); } if (orig != 0.0f && !bf16_isinf(f32_to_bf16(orig))) { float diff = (conv - orig); float rel_error = (diff < 0) ? -diff / orig : diff / orig; TEST_ASSERT(rel_error < 0.01f, "Relative error too large"); } } printf(" Basic conversions: PASS\n"); return 0; } static int test_special_values(void) { printf("Testing special values...\n"); bf16_t pos_inf = {.bits = 0x7F80}; /* +Infinity */ TEST_ASSERT(bf16_isinf(pos_inf), "Positive infinity not detected"); TEST_ASSERT(!bf16_isnan(pos_inf), "Infinity detected as NaN"); bf16_t neg_inf = {.bits = 0xFF80}; /* -Infinity */ TEST_ASSERT(bf16_isinf(neg_inf), "Negative infinity not detected"); bf16_t nan_val = BF16_NAN(); TEST_ASSERT(bf16_isnan(nan_val), "NaN not detected"); TEST_ASSERT(!bf16_isinf(nan_val), "NaN detected as infinity"); bf16_t zero = f32_to_bf16(0.0f); TEST_ASSERT(bf16_iszero(zero), "Zero not detected"); bf16_t neg_zero = f32_to_bf16(-0.0f); TEST_ASSERT(bf16_iszero(neg_zero), "Negative zero not detected"); printf(" Special values: PASS\n"); return 0; } static int test_arithmetic(void) { printf("Testing arithmetic operations...\n"); bf16_t a = f32_to_bf16(1.0f); bf16_t b = f32_to_bf16(2.0f); bf16_t c = bf16_add(a, b); float result = bf16_to_f32(c); float diff = result - 3.0f; TEST_ASSERT((diff < 0 ? 
                 -diff : diff) < 0.01f, "Addition failed");

    c = bf16_sub(b, a);
    result = bf16_to_f32(c);
    diff = result - 1.0f;
    TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Subtraction failed");

    a = f32_to_bf16(3.0f);
    b = f32_to_bf16(4.0f);
    c = bf16_mul(a, b);
    result = bf16_to_f32(c);
    diff = result - 12.0f;
    TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Multiplication failed");

    a = f32_to_bf16(10.0f);
    b = f32_to_bf16(2.0f);
    c = bf16_div(a, b);
    result = bf16_to_f32(c);
    diff = result - 5.0f;
    TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Division failed");

    /* Test square root */
    a = f32_to_bf16(4.0f);
    c = bf16_sqrt(a);
    result = bf16_to_f32(c);
    diff = result - 2.0f;
    TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(4) failed");

    a = f32_to_bf16(9.0f);
    c = bf16_sqrt(a);
    result = bf16_to_f32(c);
    diff = result - 3.0f;
    TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(9) failed");

    printf(" Arithmetic: PASS\n");
    return 0;
}

/* Check eq/lt/gt on ordinary values and on NaN.
 * Returns 0 on success, 1 on the first failed assertion. */
static int test_comparisons(void)
{
    printf("Testing comparison operations...\n");

    bf16_t a = f32_to_bf16(1.0f);
    bf16_t b = f32_to_bf16(2.0f);
    bf16_t c = f32_to_bf16(1.0f);

    TEST_ASSERT(bf16_eq(a, c), "Equality test failed");
    TEST_ASSERT(!bf16_eq(a, b), "Inequality test failed");
    TEST_ASSERT(bf16_lt(a, b), "Less than test failed");
    TEST_ASSERT(!bf16_lt(b, a), "Not less than test failed");
    TEST_ASSERT(!bf16_lt(a, c), "Equal not less than test failed");
    TEST_ASSERT(bf16_gt(b, a), "Greater than test failed");
    TEST_ASSERT(!bf16_gt(a, b), "Not greater than test failed");

    /* NaN is unordered: every comparison involving it must be false. */
    bf16_t nan_val = BF16_NAN();
    TEST_ASSERT(!bf16_eq(nan_val, nan_val), "NaN equality test failed");
    TEST_ASSERT(!bf16_lt(nan_val, a), "NaN less than test failed");
    TEST_ASSERT(!bf16_gt(nan_val, a), "NaN greater than test failed");

    printf(" Comparisons: PASS\n");
    return 0;
}

/* Check behaviour at the extremes: tiny inputs, overflow, underflow.
 * Returns 0 on success, 1 on the first failed assertion. */
static int test_edge_cases(void)
{
    printf("Testing edge cases...\n");

    float tiny = 1e-45f;
    bf16_t bf_tiny = f32_to_bf16(tiny);
    float tiny_val = bf16_to_f32(bf_tiny);
    TEST_ASSERT(bf16_iszero(bf_tiny) || (tiny_val < 0 ?
-tiny_val : tiny_val) < 1e-37f, "Tiny value handling"); float huge = 1e38f; bf16_t bf_huge = f32_to_bf16(huge); bf16_t bf_huge2 = bf16_mul(bf_huge, f32_to_bf16(10.0f)); TEST_ASSERT(bf16_isinf(bf_huge2), "Overflow should produce infinity"); bf16_t small = f32_to_bf16(1e-38f); bf16_t smaller = bf16_div(small, f32_to_bf16(1e10f)); float smaller_val = bf16_to_f32(smaller); TEST_ASSERT(bf16_iszero(smaller) || (smaller_val < 0 ? -smaller_val : smaller_val) < 1e-45f, "Underflow should produce zero or denormal"); printf(" Edge cases: PASS\n"); return 0; } static int test_rounding(void) { printf("Testing rounding behavior...\n"); float exact = 1.5f; bf16_t bf_exact = f32_to_bf16(exact); float back_exact = bf16_to_f32(bf_exact); TEST_ASSERT(back_exact == exact, "Exact representation should be preserved"); float val = 1.0001f; bf16_t bf = f32_to_bf16(val); float back = bf16_to_f32(bf); float diff2 = back - val; TEST_ASSERT((diff2 < 0 ? -diff2 : diff2) < 0.001f, "Rounding error should be small"); printf(" Rounding: PASS\n"); return 0; } #ifndef BFLOAT16_NO_MAIN int main(void) { printf("\n=== bfloat16 Test Suite ===\n\n"); int failed = 0; failed |= test_basic_conversions(); failed |= test_special_values(); failed |= test_arithmetic(); failed |= test_comparisons(); failed |= test_edge_cases(); failed |= test_rounding(); if (failed) { printf("\n=== TESTS FAILED ===\n"); return 1; } printf("\n=== ALL TESTS PASSED ===\n"); return 0; } #endif /* BFLOAT16_NO_MAIN */ ``` ### Assembly ``` .data .text setup: li ra, -1 li sp, 0x7ffffff0 main: ####################################################### # < Function > # main procedure # # < Parameters > # NULL # # < Return Value > # NULL ####################################################### # < Local Variable > # s0: failed ####################################################### ## Save ra & Callee Saved addi sp, sp, -8 sw ra, 4(sp) sw s0, 0(sp) li s0, 0x01000000 jal ra, BASIC_CONVERSIONS and s0, s0, a0 # failed = test_basic_conversions() 
jal ra, SPECIAL_VALUES and s0, s0, a0 # failed = test_special_vaules() jal ra, ARITHMETIC and s0, s0, a0 # failed = test_arithmetic() jal ra, COMPARISONS and s0, s0, a0 # failed = test_comparisons() jal ra, EDGE_CASED and s0, s0, a0 # failed = test_edge_cases() jal ra, ROUNDING and s0, s0, a0 # failed = test_rounding() bne a0, x0 ,main_pass li s1, 88 # if not pass, load 88 to 0x01000000 sw s1, 0(s0) li a0, 1 # return 1 j main_exit main_pass: li s1, 66 # if pass, load 66 to 0x01000000 sw s1, 0(s0) li a0, 0 # return 0 main_exit: ## Retrieve ra & Callee Saved lw ra, 4(sp) lw s0, 0(sp) addi sp, sp, 8 ## return ret bf16_isnan: ####################################################### # < Function > # bf16_isnan # # < Parameters > # a0 : a # # < Return Value > # a0 ####################################################### # < Local Variable > # t0 : return # t1 : temp ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start andi t0, a0, 0x7f80 # t0 = a & BF16_EXP_MASK xori t0, t0, 0x7f80 # t0 = (t0 == BF16_EXP_MASK) sltiu t0, t0, 1 # t0 = (t0 < 1) ? 1 : 0 andi t1, a0, 0x007f # t1 = a & BF16_MANT_MASK and t0, t0, t1 # t0 = t0 && t1 mv a0, t0 # return t0 ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret bf16_isinf: ####################################################### # < Function > # bf16_isinf # # < Parameters > # a0 : a # # < Return Value > # a0 ####################################################### # < Local Variable > # t0 : return # t1 : temp ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start andi t0, a0, 0x7f80 # t0 = a & BF16_EXP_MASK xori t0, t0, 0x7f80 # t0 = (t0 == BF16_EXP_MASK) sltiu t0, t0, 1 # t0 = (t0 < 1) ? 
1 : 0 andi t1, a0, 0x007f # t1 = a & BF16_MANT_MASK xori t1, t1, 0xffff # t1 = !t1 and t0, t0, t1 # t0 = t0 && t1 mv a0, t0 # return t0 ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret bf16_iszero: ####################################################### # < Function > # bf16_iszero # # < Parameters > # a0 : a # # < Return Value > # a0 ####################################################### # < Local Variable > # t0 : return ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start andi t0, a0, 0x7fff # t0 = a & 0x7FFF xori t0, t0, 0xffff # t0 = !t0 mv a0, t0 # return t0 ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret f32_to_bf16: ####################################################### # < Function > # f32_to_bf16 # # < Parameters > # a0 : val # # < Return Value > # a0: return ####################################################### # < Local Variable > # t0 : f32bits # t1 : temp1 ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start srli t0, a0, 23 # t0 = val >> 23 andi t0, t0, 0xff # t0 = t0 & 0xFF xori t0, t0, 0xff # t0 = t0 == 0xFF sltiu t0, t0, 1 # t0 = (t0 < 1) ? 
1 : 0 li t1, 0 beq t0, t1, ftobf_else # if (t0 == 0) -> ftobf_else srli t0, a0, 16 # t0 = val >> 16 andi t0, t0, 0xffff # t0 = t0 & 0xFFFF mv a0, t0 # return t0 j ftof_end ftobf_else: srli t0, a0, 16 # t0 = val >> 16 andi t0, t0, 1 # t0 &= 1 addi t0, t0, 0x7fff # t0 += 0x7FFF add t0, a0, t0 # t0 = val + ((val >> 16) & 1) + 0x7FFF srli t0, t0, 16 # t0 = t0 >> 16 mv a0, t0 # return t0 ftobf_end: ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret bf16_to_f32: ####################################################### # < Function > # bf16_to_f32 # # < Parameters > # a0 : val # # < Return Value > # a0 ####################################################### # < Local Variable > # t0 : f32bits ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start slli t0, a0, 16 # t0 = val << 16 mv a0, t0 # return t0 ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret bf16_add: ####################################################### # < Function > # bf16_add # # < Parameters > # a0 : a # a1 : b # # < Return Value > # a0 ####################################################### # < Local Variable > # t0 : f32bits ####################################################### ## Save ra & Callee Saved addi sp, sp, -4 sw ra, 0(sp) ## funtion start mv a0, t0 # return t0 ## Retrieve ra & Callee Saved lw ra, 0(sp) addi sp, sp, 4 ## return ret ```