# Assignment 1: RISC-V Assembly and Instruction Pipeline
> contributed by < [`kkevinhu`](https://github.com/kkevinhu) >

## Introduction
### 1009. Complement of Base 10 Integer
The complement of an integer is the integer you get when you flip all the 0's to 1's and all the 1's to 0's in its binary representation.

For example, the integer 5 is "101" in binary and its complement is "010", which is the integer 2.

Given an integer n, return its complement.

Example 1:
- Input: n = 5
- Output: 2
- Explanation: 5 is "101" in binary, with complement "010" in binary, which is 2 in base-10.

Example 2:
- Input: n = 7
- Output: 0
- Explanation: 7 is "111" in binary, with complement "000" in binary, which is 0 in base-10.

## Implementation
You can find the source code and more complete details [here](https://github.com/kkevinhu/ca2025-quizzes). The CPU I used in Ripes is the `5-stage processor` (a 5-stage in-order processor with hazard detection/elimination and forwarding).

### Motivation
The main task is to find the bitmask covering the significant bits of a number. A simple loop-based approach shifts one bit at a time, so its runtime depends on the bit length of the input. With CLZ (count leading zeros), we can determine the bit length directly and build the mask with a fixed number of operations. This makes the algorithm faster, more predictable, and closer to hardware-efficient execution on architectures that support CLZ.

#### For example :
`5` -> `0000...0101`. Its complement is `1111...1010`, but only the lower 3 bits matter. `CLZ(5)` is 29, so the bit length is 32 - 29 = 3 (the MSB is at bit index 2). The bit mask is then (1 << 3) - 1 = `0000...0111`. Finally, `1111...1010` & `0000...0111` = `0000...0010`, which is `2` in decimal.
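The worked example above can be sketched in C as follows. This is only a minimal illustration, not the code used in this note: `bit_length` and `complement_via_clz` are hypothetical helper names, and `__builtin_clz` is the GCC/Clang builtin, assumed here to operate on a 32-bit `unsigned int`.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper (not from the original note): bit length of n,
 * i.e. 32 - CLZ(n). __builtin_clz is undefined for 0, so n must be
 * nonzero. Assumes unsigned int is 32 bits wide. */
static unsigned bit_length(uint32_t n)
{
    assert(n != 0);
    return 32u - (unsigned) __builtin_clz(n);
}

/* Hypothetical helper: complement of n restricted to its significant bits. */
static uint32_t complement_via_clz(uint32_t n)
{
    if (n == 0)
        return 1; /* special case: "0" flips to "1" */
    unsigned len = bit_length(n);
    /* Avoid shifting by 32, which is undefined for a 32-bit type. */
    uint32_t mask = (len == 32u) ? UINT32_MAX : ((UINT32_C(1) << len) - 1u);
    return ~n & mask; /* keep only the flipped significant bits */
}
```

For `n = 5`, `bit_length` gives 3, the mask is `0b111`, and `~5 & 0b111` is 2, matching the example.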
### C code without CLZ
The code repeatedly shifts bits to find the highest set bit, extending the mask by one bit per iteration.
```c=
int bitwiseComplement(int n) {
    if (n == 0) return 1;
    int mask = 0, temp = n;
    while (temp > 0) {
        mask = (mask << 1) | 1;
        temp >>= 1;
    }
    return n ^ mask;
}
```
### C code with loopless CLZ
This CLZ finds the most significant bit (MSB) with a fixed sequence of five shift-and-compare steps, without needing a loop.
```c=
#include <stdint.h>

static inline unsigned clz(uint32_t x) {
    if (x == 0) return 32; // all bits are 0 → 32 leading zeros
    int n = 0;
    if ((x >> 16) == 0) { n += 16; x <<= 16; }
    if ((x >> 24) == 0) { n += 8; x <<= 8; }
    if ((x >> 28) == 0) { n += 4; x <<= 4; }
    if ((x >> 30) == 0) { n += 2; x <<= 2; }
    if ((x >> 31) == 0) n += 1;
    return n;
}

int bitwiseComplement(int n) {
    if (n == 0) return 1;
    unsigned lz = clz(n);
    unsigned msb = 31 - lz;             // index of the highest set bit
    int mask = (1 << (msb + 1)) - 1;    // 1s in bits 0..msb
    return n ^ mask;
}
```
### C code with branchless CLZ
This version keeps a loop, but the loop always runs exactly five times (c = 16, 8, 4, 2, 1), performing a binary search for the MSB.
```c=
#include <stdint.h>

static inline unsigned clz(uint32_t x) {
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

int bitwiseComplement(int n) {
    if (n == 0) return 1;
    int mask = (1 << (32 - clz(n))) - 1;
    return ~n & mask;
}
```
### Assembly for Complement of Base 10 Integer
Here I convert the C code above into RISC-V assembly.
- `main` : Verifies the correctness of `bitwiseComplement` and CLZ against a table of test inputs and expected answers.
- `bitwiseComplement` :
    * n is zero : return 1
    * Otherwise : First, count the leading zeros to find the position of the MSB in n. Then create a bit mask that has 1s in all positions up to and including the MSB. Finally, `XOR` with the mask flips all bits of n that fall under the mask.
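The structure of `main` described above (walk a table of inputs, call `bitwiseComplement`, compare with the expected answer, print PASS or FAIL) can be sketched in C. This is an illustrative sketch only: `run_tests` is a hypothetical name, the inputs and answers are the ones stored in the `.data` section of the assembly listings, and the loop-based `bitwiseComplement` is repeated so the block is self-contained.

```c
#include <assert.h>
#include <stdio.h>

/* Loop-based version, repeated from "C code without CLZ". */
static int bitwiseComplement(int n)
{
    if (n == 0)
        return 1;
    int mask = 0, temp = n;
    while (temp > 0) {
        mask = (mask << 1) | 1;
        temp >>= 1;
    }
    return n ^ mask;
}

/* Hypothetical helper: runs the table-driven check and returns the
 * number of failing cases (0 means every test passed). */
static int run_tests(void)
{
    static const int tests[]   = {0, 5, 7, 10, 121}; /* test inputs */
    static const int answers[] = {1, 2, 0, 5, 6};    /* expected results */
    const int n_tests = (int) (sizeof(tests) / sizeof(tests[0]));
    int failures = 0;

    /* Both tables must have the same number of entries. */
    assert(sizeof(tests) == sizeof(answers));

    for (int i = 0; i < n_tests; i++) {
        int result = bitwiseComplement(tests[i]);
        int ok = (result == answers[i]);
        printf("Test %d: input = %d, result = %d, %s\n",
               i + 1, tests[i], result, ok ? "PASS!" : "FAIL!");
        if (!ok)
            failures++;
    }
    return failures;
}
```

Swapping in either CLZ-based `bitwiseComplement` should leave the table and the loop unchanged, which is how the three assembly variants below share the same `main`.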
#### Without CLZ :::spoiler See More ```asm= .data tests: .word 0, 5, 7, 10, 121 # Test Input answers: .word 1, 2, 0, 5, 6 # Answers for test input n_tests: .word 5 msg_test: .string "Test " msg_input: .string ": input = " msg_result: .string ", result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, tests # s0 = tests la s1, answers # s1 = answers lw s2, n_tests # s2 = n_tests li s3, 0 # s3 = index loop_tests: beq s3, s2, done # Output "Test <index>" la a0, msg_test li a7, 4 ecall addi a0, s3, 1 li a7, 1 ecall la a0, msg_input li a7, 4 ecall # Get test input from tests slli t0, s3, 2 add t1, s0, t0 lw a0, 0(t1) addi t4, a0, 0 li a7, 1 ecall la a0, msg_result li a7, 4 ecall addi a0, t4, 0 jal ra, bitwiseComplement # a1 = result addi t5, a1, 0 # Output result addi a0, t5, 0 li a7, 1 ecall # Get ans from answers slli t0, s3, 2 add t2, s1, t0 lw t3, 0(t2) # Compare result and correct answers bne t5, t3, fail pass: la a0, msg_pass li a7, 4 ecall addi s3, s3, 1 j loop_tests fail: la a0, msg_fail li a7, 4 ecall addi s3, s3, 1 j loop_tests done: li a7, 10 # exit ecall bitwiseComplement: addi sp, sp, -16 sw ra, 12(sp) # if (n == 0) return 1; beqz a0, ret_one mv t0, a0 # temp = n li t1, 0 # mask = 0 build_mask: blez t0, mask_done # while (temp > 0) slli t1, t1, 1 # mask <<= 1 ori t1, t1, 1 # mask |= 1 srli t0, t0, 1 # temp >>= 1 j build_mask mask_done: xor a1, a0, t1 # return n ^ mask j done_func ret_one: li a1, 1 # return 1 done_func: lw ra, 12(sp) addi sp, sp, 16 jr ra ``` ::: - code line : 114 - cpu cycle : 506 #### With branchless CLZ :::spoiler See More ```asm= .data tests: .word 0, 5, 7, 10, 121 # Test Input answers: .word 1, 2, 0, 5, 6 # Answers for test input n_tests: .word 5 msg_test: .string "Test " msg_input: .string ": input = " msg_result: .string ", result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, tests # s0 = tests la s1, answers # s1 = answers lw s2, n_tests # 
s2 = n_tests li s3, 0 # s3 = index loop_tests: beq s3, s2, done # Output "Test <index>" la a0, msg_test li a7, 4 ecall addi a0, s3, 1 li a7, 1 ecall la a0, msg_input li a7, 4 ecall # Get test input from tests slli t0, s3, 2 add t1, s0, t0 lw a0, 0(t1) addi t4, a0, 0 li a7, 1 ecall la a0, msg_result li a7, 4 ecall addi a0, t4, 0 jal ra, bitwiseComplement # a1 = result addi t5, a1, 0 # Output result addi a0, t5, 0 li a7, 1 ecall # Get ans from answers slli t0, s3, 2 add t2, s1, t0 lw t3, 0(t2) # Compare result and correct answers bne t5, t3, fail pass: la a0, msg_pass li a7, 4 ecall addi s3, s3, 1 j loop_tests fail: la a0, msg_fail li a7, 4 ecall addi s3, s3, 1 j loop_tests done: li a7, 10 # exit ecall bitwiseComplement: addi sp, sp, -4 sw ra, 0(sp) beqz a0, zero jal ra, clz addi t0, a1, 0 li t1, 32 sub t0, t1, t0 li t1, 1 sll t0, t1, t0 sub t0, t0, t1 xor a1, a0, t0 j return zero: addi a1, a0, 1 j return return: lw ra, 0(sp) addi sp, sp, 4 ret clz: addi sp, sp, -16 sw ra, 12(sp) sw a0, 8(sp) li t0, 32 li t1, 16 loop: srl t2, a0, t1 beqz t2, skip sub t0, t0, t1 addi a0, t2, 0 skip: srli t1, t1, 1 bnez t1, loop sub a1, t0, a0 lw ra, 12(sp) lw a0, 8(sp) addi sp, sp, 16 jr ra ``` ::: - code line : 128 - cpu cycle : 600 #### With loopless CLZ :::spoiler See More ```asm= .data tests: .word 0, 5, 7, 10, 121 # Test Input answers: .word 1, 2, 0, 5, 6 # Answers for test input n_tests: .word 5 msg_test: .string "Test " msg_input: .string ": input = " msg_result: .string ", result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, tests # s0 = tests la s1, answers # s1 = answers lw s2, n_tests # s2 = n_tests li s3, 0 # s3 = index loop_tests: beq s3, s2, done # Output "Test <index>" la a0, msg_test li a7, 4 ecall addi a0, s3, 1 li a7, 1 ecall la a0, msg_input li a7, 4 ecall # Get test input from tests slli t0, s3, 2 add t1, s0, t0 lw a0, 0(t1) addi t4, a0, 0 li a7, 1 ecall la a0, msg_result li a7, 4 ecall addi a0, t4, 0 jal ra, 
bitwiseComplement                # a1 = result
    addi t5, a1, 0

    # Output result
    addi a0, t5, 0
    li a7, 1
    ecall

    # Get ans from answers
    slli t0, s3, 2
    add t2, s1, t0
    lw t3, 0(t2)

    # Compare result and correct answers
    bne t5, t3, fail

pass:
    la a0, msg_pass
    li a7, 4
    ecall
    addi s3, s3, 1
    j loop_tests

fail:
    la a0, msg_fail
    li a7, 4
    ecall
    addi s3, s3, 1
    j loop_tests

done:
    li a7, 10                    # exit
    ecall

bitwiseComplement:
    addi sp, sp, -4
    sw ra, 0(sp)
    beqz a0, zero
    jal ra, clz
    addi t0, a1, 0               # t0 = clz(n)
    li t1, 32
    sub t0, t1, t0               # t0 = 32 - clz(n) = bit length
    li t1, 1
    sll t0, t1, t0               # t0 = 1 << bit length
    sub t0, t0, t1               # t0 = mask
    xor a1, a0, t0               # a1 = n ^ mask
    j return
zero:
    addi a1, a0, 1               # n == 0 -> return 1
    j return
return:
    lw ra, 0(sp)
    addi sp, sp, 4
    ret

clz:
    addi sp, sp, -8
    sw ra, 0(sp)
    sw a0, 4(sp)
    beqz a0, clz_zero            # if x == 0 -> return 32
    li t0, 0                     # n = 0
chk_16:
    srli t1, a0, 16
    bnez t1, chk_8               # if (x >> 16) != 0 -> skip
    addi t0, t0, 16              # n += 16
    slli a0, a0, 16              # x <<= 16
chk_8:
    srli t1, a0, 24
    bnez t1, chk_4
    addi t0, t0, 8
    slli a0, a0, 8
chk_4:
    srli t1, a0, 28
    bnez t1, chk_2
    addi t0, t0, 4
    slli a0, a0, 4
chk_2:
    srli t1, a0, 30
    bnez t1, chk_31
    addi t0, t0, 2
    slli a0, a0, 2
chk_31:
    srli t1, a0, 31
    beqz t1, add_one             # if bit31 == 0 -> add 1
    j clz_done
add_one:
    addi t0, t0, 1
clz_done:
    mv a1, t0
    j clz_return
clz_zero:
    li a1, 32                    # return 32 if input = 0
clz_return:
    lw ra, 0(sp)
    lw a0, 4(sp)
    addi sp, sp, 8
    ret
```
:::
- code line : 159
- cpu cycle : 550

### Analysis
- Time complexity
    - The loop in branchless CLZ always runs exactly $\log_2 32 = 5$ times (c = 16, 8, 4, 2, 1), independent of the input, so its time complexity is constant, $O(1)$.
    - Loopless CLZ has no loop at all, only a fixed sequence of five checks, so its time complexity is also $O(1)$.
    - Compared with the version without CLZ, whose loop runs once per significant bit of n, the time complexity drops from $O(\log n)$ to $O(1)$.
- Cycle (branchless vs. loopless)
    - Cycles drop from 600 to 550 on the same test set, so loopless CLZ performs better than branchless CLZ. Likely reasons:
        - Fewer branch misprediction penalties
        - Fixed number of operations
        - Better pipeline utilization
- Summary :
    - Using CLZ achieves better time complexity than the version without CLZ. Although the non-CLZ version can take fewer CPU cycles in some cases (506 cycles on this small-valued test set), its cycle count grows with the bit length of the input, so it increases significantly for extreme input values.
    - Moreover, the loopless version requires fewer cycles than the branchless version to complete the same operation.

## uf8
In this part, I translate `q1-uf8.c` into assembly and use test data to verify the correctness of each function.
### C code
:::spoiler See More
```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

typedef uint8_t uf8;

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;

    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;

    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;

    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;

        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }

    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }

    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}

/* Test encode/decode round-trip */
static bool test(void)
{
    int32_t previous_value = -1;
    bool passed = true;

    for (int i = 0; i < 256; i++) {
        uint8_t fl = i;
        int32_t value = 
uf8_decode(fl); uint8_t fl2 = uf8_encode(value); if (fl != fl2) { printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2); passed = false; } if (value <= previous_value) { printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value); passed = false; } previous_value = value; } return passed; } int main(void) { if (test()) { printf("All tests passed.\n"); return 0; } return 1; } ``` ::: ### Assembly - #### uf8-decode :::spoiler See More ```asm= .data test: .word 0x00 .word 0xf .word 0x68 .word 0xa9 .word 0xFF ans: .word 0x00 .word 0xf .word 0x5f0 .word 0x63f0 .word 0xF7FF0 n_tests: .word 5 msg_input: .string "Input = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, test la s1, ans lw s2, n_tests li s4, 0 # index loop_t: beq s4, s2, done la a0, msg_input li a7, 4 ecall lw s3, 0(s0) addi a0, s3, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, decode addi t0, a1, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s1) addi s0, s0, 4 addi s1, s1, 4 addi s4, s4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall decode: addi sp, sp, -16 sw ra, 12(sp) sw s3, 8(sp) # s3 = f1 andi t0, s3, 0x0f srli t1, s3, 4 li t2, 15 sub t2, t2, t1 li t3, 0x7FFF srl t3, t3, t2 slli t3, t3, 4 sll t0, t0, t1 add a1, t0, t3 lw ra, 12(sp) lw s3, 8(sp) addi sp, sp, 16 jr ra ``` ::: - #### uf8-clz :::spoiler See More ```asm= .data test: .word 0x00 .word 0x24 .word 0x210 .word 0x63ff0 .word 0x10000000 ans: .word 0x20 .word 0x1a .word 0x16 .word 0xd .word 0x3 n_tests: .word 5 msg_input: .string "Input = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, test la s1, ans lw s2, n_tests li s4, 0 # index loop_t: beq s4, s2, done la a0, msg_input li a7, 4 ecall lw s3, 0(s0) addi a0, s3, 0 li a7, 34 ecall la a0, msg_result li a7, 4 
    ecall
    jal ra, clz
    addi t0, a1, 0
    addi a0, t0, 0
    li a7, 34
    ecall
    lw t1, 0(s1)
    addi s0, s0, 4
    addi s1, s1, 4
    addi s4, s4, 1
    bne t0, t1, fail
    la a0, msg_pass
    li a7, 4
    ecall
    j loop_t
fail:
    la a0, msg_fail
    li a7, 4
    ecall
    j loop_t
done:
    li a7, 10                    # exit
    ecall

clz:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw s3, 8(sp)                 # s3 = x
    li t0, 32                    # t0 = n
    li t1, 16                    # t1 = c
loop:
    srl t2, s3, t1               # t2 = y
    beqz t2, skip
    sub t0, t0, t1
    addi s3, t2, 0
skip:
    srli t1, t1, 1
    bnez t1, loop
    sub a1, t0, s3
    lw ra, 12(sp)
    lw s3, 8(sp)
    addi sp, sp, 16
    jr ra
```
:::
- #### uf8-encode
:::spoiler See More
```asm=
.data
test:
    .word 0x00
    .word 0xf
    .word 0xd0
    .word 0x2df0
    .word 0x000F7FF0
ans:
    .word 0x00
    .word 0xf
    .word 0x3c
    .word 0x97
    .word 0xFF
n_tests: .word 5
msg_input: .string "Input = "
msg_result: .string ", Result = "
msg_pass: .string ", PASS!\n"
msg_fail: .string ", FAIL!\n"

.text
.global main
main:
    la s0, test
    la s1, ans
    lw s2, n_tests
    li s4, 0                     # index
loop_t:
    beq s4, s2, done
    la a0, msg_input
    li a7, 4
    ecall
    lw s3, 0(s0)
    addi a0, s3, 0
    li a7, 34
    ecall
    la a0, msg_result
    li a7, 4
    ecall
    jal ra, encode
    addi t0, a1, 0
    addi a0, t0, 0
    li a7, 34
    ecall
    lw t1, 0(s1)
    addi s0, s0, 4
    addi s1, s1, 4
    addi s4, s4, 1
    bne t0, t1, fail
    la a0, msg_pass
    li a7, 4
    ecall
    j loop_t
fail:
    la a0, msg_fail
    li a7, 4
    ecall
    j loop_t
done:
    li a7, 10                    # exit
    ecall

encode:
    addi sp, sp, -32
    sw ra, 28(sp)
    sw s3, 24(sp)                # s3 = x
    sw s0, 20(sp)
    sw s1, 16(sp)
    sw s2, 12(sp)
    li t0, 16
    blt s3, t0, special          # value < 16 -> return value
    jal ra, clz
    addi t1, a1, 0               # t1 = lz
    li t2, 31
    sub t1, t2, t1               # t1 = msb
    li s0, 0                     # exponent = 0 (needed when msb < 5)
    li s1, 0                     # overflow = 0 (needed when msb < 5)
    li t2, 5
    blt t1, t2, find_exact_exp
    addi s0, t1, -4              # s0 = exponent
    li t2, 15
    bgt s0, t2, limit_exp
    j est_loop_init
limit_exp:
    addi s0, t2, 0
est_loop_init:
    li s1, 0                     # s1 = overflow
    li t1, 0                     # t1 = e
est_loop:
    bge t1, s0, adjust_est
    slli t2, s1, 1
    addi s1, t2, 16              # overflow = (overflow << 1) + 16
    addi t1, t1, 1
    j est_loop
adjust_est:
    beqz s0, find_exact_exp
adjust_loop:
    blt s3, s1, adjust_inner     # while (value < overflow)
    j find_exact_exp
adjust_inner:
    addi t2, s1, -16
    srli s1, t2, 1               # overflow = (overflow - 16) >> 1
    addi s0, s0, -1
    bnez s0, adjust_loop
find_exact_exp:
    li t2, 15
find_loop:
    bge s0, t2, find_done
    slli t0, s1, 1
    addi t0, t0, 16              # t0 = next_overflow
    blt s3, t0, find_done
    addi s1, t0, 0
    addi s0, s0, 1
    j find_loop
find_done:
    sub t0, s3, s1
    srl s2, t0, s0               # s2 = mantissa
    slli s0, s0, 4
    or a1, s0, s2
return:
    lw ra, 28(sp)
    lw s3, 24(sp)                # s3 = x
    lw s0, 20(sp)
    lw s1, 16(sp)
    lw s2, 12(sp)
    addi sp, sp, 32
    ret
special:
    addi a1, s3, 0
    j return

clz:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw s3, 8(sp)                 # s3 = x
    li t0, 32                    # t0 = n
    li t1, 16                    # t1 = c
loop:
    srl t2, s3, t1               # t2 = y
    beqz t2, skip
    sub t0, t0, t1
    addi s3, t2, 0
skip:
    srli t1, t1, 1
    bnez t1, loop
    sub a1, t0, s3
    lw ra, 12(sp)
    lw s3, 8(sp)
    addi sp, sp, 16
    jr ra
```
:::

## bfloat16
In this part, I translate `q1-bfloat16.c` into assembly and use test data to verify the correctness of each function.
### C code
:::spoiler See More
```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

typedef struct {
    uint16_t bits;
} bf16_t;

#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127

#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})

static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}

static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}

static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}

static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}

static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}

static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 
15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if (exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if 
(!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - 
i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F), }; } static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t new_exp; /* Get full mantissa with implicit 1 */ uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */ /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */ if (e & 1) { m <<= 1; /* Double mantissa for odd exponent */ new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS; } else { new_exp = (e >> 1) + BF16_EXP_BIAS; } /* Now m is in range [128, 256) or [256, 512) if exponent was odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= 
high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need to adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } } /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow */ if (new_exp >= 0xFF) return (bf16_t) {.bits = 0x7F80}; /* +Inf */ if (new_exp <= 0) return BF16_ZERO(); return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant}; } static inline bool bf16_eq(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return true; return a.bits == b.bits; } static inline bool bf16_lt(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return false; bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1; if (sign_a != sign_b) return sign_a > sign_b; return sign_a ? 
a.bits > b.bits : a.bits < b.bits; } static inline bool bf16_gt(bf16_t a, bf16_t b) { return bf16_lt(b, a); } #include <stdio.h> #include <time.h> #define TEST_ASSERT(cond, msg) \ do { \ if (!(cond)) { \ printf("FAIL: %s\n", msg); \ return 1; \ } \ } while (0) static int test_basic_conversions(void) { printf("Testing basic conversions...\n"); float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f, -0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f}; for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) { float orig = test_values[i]; bf16_t bf = f32_to_bf16(orig); float conv = bf16_to_f32(bf); if (orig != 0.0f) { TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch"); } if (orig != 0.0f && !bf16_isinf(f32_to_bf16(orig))) { float diff = (conv - orig); float rel_error = (diff < 0) ? -diff / orig : diff / orig; TEST_ASSERT(rel_error < 0.01f, "Relative error too large"); } } printf(" Basic conversions: PASS\n"); return 0; } static int test_special_values(void) { printf("Testing special values...\n"); bf16_t pos_inf = {.bits = 0x7F80}; /* +Infinity */ TEST_ASSERT(bf16_isinf(pos_inf), "Positive infinity not detected"); TEST_ASSERT(!bf16_isnan(pos_inf), "Infinity detected as NaN"); bf16_t neg_inf = {.bits = 0xFF80}; /* -Infinity */ TEST_ASSERT(bf16_isinf(neg_inf), "Negative infinity not detected"); bf16_t nan_val = BF16_NAN(); TEST_ASSERT(bf16_isnan(nan_val), "NaN not detected"); TEST_ASSERT(!bf16_isinf(nan_val), "NaN detected as infinity"); bf16_t zero = f32_to_bf16(0.0f); TEST_ASSERT(bf16_iszero(zero), "Zero not detected"); bf16_t neg_zero = f32_to_bf16(-0.0f); TEST_ASSERT(bf16_iszero(neg_zero), "Negative zero not detected"); printf(" Special values: PASS\n"); return 0; } static int test_arithmetic(void) { printf("Testing arithmetic operations...\n"); bf16_t a = f32_to_bf16(1.0f); bf16_t b = f32_to_bf16(2.0f); bf16_t c = bf16_add(a, b); float result = bf16_to_f32(c); float diff = result - 3.0f; TEST_ASSERT((diff < 0 ? 
-diff : diff) < 0.01f, "Addition failed"); c = bf16_sub(b, a); result = bf16_to_f32(c); diff = result - 1.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Subtraction failed"); a = f32_to_bf16(3.0f); b = f32_to_bf16(4.0f); c = bf16_mul(a, b); result = bf16_to_f32(c); diff = result - 12.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Multiplication failed"); a = f32_to_bf16(10.0f); b = f32_to_bf16(2.0f); c = bf16_div(a, b); result = bf16_to_f32(c); diff = result - 5.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Division failed"); /* Test square root */ a = f32_to_bf16(4.0f); c = bf16_sqrt(a); result = bf16_to_f32(c); diff = result - 2.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(4) failed"); a = f32_to_bf16(9.0f); c = bf16_sqrt(a); result = bf16_to_f32(c); diff = result - 3.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(9) failed"); printf(" Arithmetic: PASS\n"); return 0; } static int test_comparisons(void) { printf("Testing comparison operations...\n"); bf16_t a = f32_to_bf16(1.0f); bf16_t b = f32_to_bf16(2.0f); bf16_t c = f32_to_bf16(1.0f); TEST_ASSERT(bf16_eq(a, c), "Equality test failed"); TEST_ASSERT(!bf16_eq(a, b), "Inequality test failed"); TEST_ASSERT(bf16_lt(a, b), "Less than test failed"); TEST_ASSERT(!bf16_lt(b, a), "Not less than test failed"); TEST_ASSERT(!bf16_lt(a, c), "Equal not less than test failed"); TEST_ASSERT(bf16_gt(b, a), "Greater than test failed"); TEST_ASSERT(!bf16_gt(a, b), "Not greater than test failed"); bf16_t nan_val = BF16_NAN(); TEST_ASSERT(!bf16_eq(nan_val, nan_val), "NaN equality test failed"); TEST_ASSERT(!bf16_lt(nan_val, a), "NaN less than test failed"); TEST_ASSERT(!bf16_gt(nan_val, a), "NaN greater than test failed"); printf(" Comparisons: PASS\n"); return 0; } static int test_edge_cases(void) { printf("Testing edge cases...\n"); float tiny = 1e-45f; bf16_t bf_tiny = f32_to_bf16(tiny); float tiny_val = bf16_to_f32(bf_tiny); TEST_ASSERT(bf16_iszero(bf_tiny) || (tiny_val < 0 ? 
-tiny_val : tiny_val) < 1e-37f, "Tiny value handling"); float huge = 1e38f; bf16_t bf_huge = f32_to_bf16(huge); bf16_t bf_huge2 = bf16_mul(bf_huge, f32_to_bf16(10.0f)); TEST_ASSERT(bf16_isinf(bf_huge2), "Overflow should produce infinity"); bf16_t small = f32_to_bf16(1e-38f); bf16_t smaller = bf16_div(small, f32_to_bf16(1e10f)); float smaller_val = bf16_to_f32(smaller); TEST_ASSERT(bf16_iszero(smaller) || (smaller_val < 0 ? -smaller_val : smaller_val) < 1e-45f, "Underflow should produce zero or denormal"); printf(" Edge cases: PASS\n"); return 0; } static int test_rounding(void) { printf("Testing rounding behavior...\n"); float exact = 1.5f; bf16_t bf_exact = f32_to_bf16(exact); float back_exact = bf16_to_f32(bf_exact); TEST_ASSERT(back_exact == exact, "Exact representation should be preserved"); float val = 1.0001f; bf16_t bf = f32_to_bf16(val); float back = bf16_to_f32(bf); float diff2 = back - val; TEST_ASSERT((diff2 < 0 ? -diff2 : diff2) < 0.001f, "Rounding error should be small"); printf(" Rounding: PASS\n"); return 0; } #ifndef BFLOAT16_NO_MAIN int main(void) { printf("\n=== bfloat16 Test Suite ===\n\n"); int failed = 0; failed |= test_basic_conversions(); failed |= test_special_values(); failed |= test_arithmetic(); failed |= test_comparisons(); failed |= test_edge_cases(); failed |= test_rounding(); if (failed) { printf("\n=== TESTS FAILED ===\n"); return 1; } printf("\n=== ALL TESTS PASSED ===\n"); return 0; } #endif /* BFLOAT16_NO_MAIN */ ``` ::: ### Assembly - #### fp32_to_bf16 :::spoiler See More ```asm= .data test: .word 0x7FC00000 # NaN .word 0x7F800000 # inf .word 0x3F818000 # round up .word 0x3F808000 # round down ans: .word 0x7FC0 .word 0x7F80 .word 0x3F82 .word 0x3F80 n_tests: .word 4 msg_input: .string "Input = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, test la s1, ans lw s2, n_tests li s4, 0 # index loop_t: beq s4, s2, done la a0, msg_input li a7, 4 ecall lw s3, 
0(s0) addi a0, s3, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, f32_to_bf16 addi t0, a1, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s1) addi s0, s0, 4 addi s1, s1, 4 addi s4, s4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall f32_to_bf16: addi sp, sp, -16 sw ra, 12(sp) sw s3, 8(sp) li t0, 0xFF srli t1, s3, 23 and t1, t0, t1 beq t0, t1, isNAN srli t0, s3, 16 andi t0, t0, 1 li t1, 0x7FFF add t0, t0, t1 add a1, s3, t0 srli a1, a1, 16 lw ra, 12(sp) addi sp, sp, 16 jr ra isNAN: srli a1, s3, 16 ret ``` ::: - #### bf16_to_fp32 :::spoiler See More ```asm= .data test: .word 0x0000 .word 0x7F80 .word 0xBF82 ans: .word 0x00000000 .word 0x7F800000 .word 0xBF820000 n_tests: .word 3 msg_input: .string "Input = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s0, test la s1, ans lw s2, n_tests li s4, 0 # index loop_t: beq s4, s2, done la a0, msg_input li a7, 4 ecall lw s3, 0(s0) addi a0, s3, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, bf16_to_f32 addi t0, a1, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s1) addi s0, s0, 4 addi s1, s1, 4 addi s4, s4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall bf16_to_f32: addi sp, sp, -16 sw ra, 12(sp) sw s3, 8(sp) slli a1, s3, 16 lw ra, 12(sp) addi sp, sp, 16 jr ra ``` ::: - #### bf16_add :::spoiler See More ```asm= .data test1: .word 0x0000 # 0 .word 0x7F80 # inf .word 0x7FC0 # NaN .word 0x3F81 # 1.0078125 .word 0xBF82 # -1.015625 test2: .word 0x0000 .word 0x7F80 .word 0x7FC0 .word 0x3F81 .word 0xBF82 ans: .word 0x0000 .word 0x7F80 .word 0x7FC0 .word 0x4001 .word 0xC002 n_tests: .word 5 msg_input_a: .string "A = " msg_input_b: .string ", B = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s6, test1 
la s7, test2 la s10, ans lw s11, n_tests li a4, 0 # index loop_t: beq a4, s11, done la a0, msg_input_a li a7, 4 ecall lw s8, 0(s6) addi a0, s8, 0 li a7, 34 ecall la a0, msg_input_b li a7, 4 ecall lw s9, 0(s7) addi a0, s9, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, bf16_add addi t0, a2, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s10) addi s6, s6, 4 addi s10, s10, 4 addi s7, s7, 4 addi a4, a4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall bf16_add: addi sp, sp, -16 sw ra, 12(sp) sw s8, 8(sp) sw s9, 4(sp) # extract sign/exponent/mantissa from a (s8) srli s0, s8, 15 # sign_a = (a.bits >> 15) srli s2, s8, 7 # exp_a = (a.bits >> 7) & 0xFF andi s2, s2, 0xFF andi s4, s8, 0x7F # mant_a = a.bits & 0x7F # extract sign/exponent/mantissa from b (s9) srli s1, s9, 15 srli s3, s9, 7 andi s3, s3, 0xFF andi s5, s9, 0x7F # NaN / INF check li t6, 0xFF beq s2, t6, check_a_naninf # if exp_a == 0xFF beq s3, t6, handle_b_naninf # if only exp_b == 0xFF j check_b_zero check_a_naninf: bnez s4, return_nan # if mant_a != 0 => NaN beq s3, t6, both_inf # if exp_b == 0xFF too j return_a # else return Inf (a) handle_b_naninf: bnez s5, return_nan # if mant_b != 0 => NaN j return_b # else Inf(b) both_inf: bnez s5, return_nan # b is NaN beq s0, s1, return_a # same sign -> Inf j return_nan # opposite sign -> NaN # Normal zero handling check_b_zero: or t3, s2, s4 beqz t3, return_b or t3, s3, s5 beqz t3, return_a beqz s2, r1 ori s4, s4, 0x80 r1: beqz s3, r2 ori s5, s5, 0x80 r2: # Exp_diff sub t3, s2, s3 li t4, 8 bgt t3, t4, return_a neg t5, t3 bgt t5, t4, return_b # Align bgtz t3, shift_b bltz t3, shift_a j aligned shift_b: srl s5, s5, t3 mv t1, s2 j compute shift_a: neg t3, t3 srl s4, s4, t3 mv t1, s3 j compute aligned: mv t1, s2 # Compute compute: beq s0, s1, same_sign # Different signs slt t4, s4, s5 beqz t4, a_ge_b mv t0, s1 sub t2, s5, s4 j normalize a_ge_b: mv t0, s0 sub t2, s4, s5 j normalize # Same sign
same_sign: mv t0, s0 add t2, s4, s5 li t3, 0x100 and t3, t2, t3 beqz t3, pack srli t2, t2, 1 addi t1, t1, 1 j pack # Normalize normalize: beqz t2, return_zero norm_loop: andi t3, t2, 0x80 bnez t3, pack slli t2, t2, 1 addi t1, t1, -1 blez t1, return_zero j norm_loop # Result pack: andi t2, t2, 0x7F andi t1, t1, 0xFF slli t1, t1, 7 slli t0, t0, 15 or a2, t1, t2 or a2, a2, t0 j return return_nan: li a2, 0x7FC0 j return return_a: mv a2, s8 j return return_b: mv a2, s9 j return return_zero: li a2, 0 j return return: lw ra, 12(sp) lw s8, 8(sp) lw s9, 4(sp) addi sp, sp, 16 jr ra ``` ::: - #### bf16_mul :::spoiler See More ```asm= .data test1: .word 0x0000 # 0 .word 0x7F80 # inf .word 0x7FC0 # NaN .word 0x3F81 # 1.0078125 .word 0xBF82 # -1.015625 test2: .word 0x0000 .word 0x7F80 .word 0x7FC0 .word 0x3F81 .word 0xBF82 ans: .word 0x0000 .word 0x7F80 .word 0x7FC0 .word 0x3F82 .word 0x3F84 n_tests: .word 5 msg_input_a: .string "A = " msg_input_b: .string ", B = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s6, test1 la s7, test2 la s10, ans lw s11, n_tests li a4, 0 # index loop_t: beq a4, s11, done la a0, msg_input_a li a7, 4 ecall lw s8, 0(s6) addi a0, s8, 0 li a7, 34 ecall la a0, msg_input_b li a7, 4 ecall lw s9, 0(s7) addi a0, s9, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, bf16_mul addi t0, a2, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s10) addi s6, s6, 4 addi s10, s10, 4 addi s7, s7, 4 addi a4, a4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall bf16_mul: addi sp, sp, -16 sw ra, 12(sp) # extract sign/exponent/mantissa from a (a0) srli s0, s8, 15 # sign_a srli s2, s8, 7 andi s2, s2, 0xFF # exp_a andi s4, s8, 0x7F # mant_a # extract sign/exponent/mantissa from b (a1) srli s1, s9, 15 # sign_b srli s3, s9, 7 andi s3, s3, 0xFF # exp_b andi s5, s9, 0x7F # mant_b # result_sign = sign_a ^ sign_b xor s0, 
s0, s1 # NaN / Inf operands li t1, 0xFF beq s2, t1, check_a_inf beq s3, t1, check_b_nan_inf # a zero operand gives signed zero check_zero: or t2, s2, s4 beqz t2, ret_zero or t2, s3, s5 beqz t2, ret_zero norm_mant: li t5, 0 # exp_adjust = 0 # normalize a beqz s2, norm_a_sub ori s4, s4, 0x80 j norm_b norm_a_sub: li t2, 0 norm_a_shift: andi t3, s4, 0x80 bnez t3, norm_a_done slli s4, s4, 1 addi t2, t2, -1 j norm_a_shift norm_a_done: li s2, 1 mv t5, t2 norm_b: beqz s3, norm_b_sub ori s5, s5, 0x80 j mul_core norm_b_sub: li t2, 0 norm_b_shift: andi t3, s5, 0x80 bnez t3, norm_b_done slli s5, s5, 1 addi t2, t2, -1 j norm_b_shift norm_b_done: li s3, 1 add t5, t5, t2 # exp_adjust mul_core: # result_mant = mant_a * mant_b mul t6, s4, s5 # result_exp = exp_a + exp_b - 127 + exp_adjust add t1, s2, s3 add t1, t1, t5 addi t1, t1, -127 # normalize mantissa li t0, 0x8000 and t2, t6, t0 beqz t2, mant_shift7 srli t6, t6, 8 andi t6, t6, 0x7F addi t1, t1, 1 j check_exp mant_shift7: srli t6, t6, 7 andi t6, t6, 0x7F check_exp: li t3, 0xFF bge t1, t3, ret_inf blez t1, underflow # normal result slli t0, s0, 15 slli t1, t1, 7 or a2, t0, t1 or a2, a2, t6 j done_mul check_a_inf: bnez s4, ret_a # a is NaN beq s3, t1, check_b_nan_inf # b is NaN or Inf too or t2, s3, s5 beqz t2, ret_nan # Inf * 0 = NaN j ret_inf check_b_nan_inf: bnez s5, ret_b # b is NaN or t2, s2, s4 beqz t2, ret_nan # 0 * Inf = NaN j ret_inf underflow: li t2, -6 blt t1, t2, ret_zero li t3, 1 sub t3, t3, t1 srl t6, t6, t3 slli a2, s0, 15 # exponent field is 0 or a2, a2, t6 j done_mul ret_inf: li a2, 0x7F80 slli t0, s0, 15 or a2, a2, t0 j done_mul ret_a: mv a2, s8 j done_mul ret_b: mv a2, s9 j done_mul ret_zero: slli a2, s0, 15 j done_mul ret_nan: li a2, 0x7FC0 j done_mul done_mul: lw ra, 12(sp) addi sp, sp, 16 jr ra ``` ::: - #### bf16_div :::spoiler See More ```asm= .data test1: .word 0x7FC0 .word 0x7F80 .word 0x0000 .word 0x0080 .word 0x3E80 .word 0x3F80 test2: .word 0x3F80 .word 0x7F80 .word 0x0000 .word 0x3F80 .word 0x3F81 .word
0x0040 ans: .word 0x7FC0 .word 0x7FC0 .word 0x7FC0 .word 0x0080 .word 0x3E7E .word 0x7F80 n_tests: .word 6 msg_input_a: .string "A = " msg_input_b: .string ", B = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s6, test1 la s7, test2 la s10, ans lw s11, n_tests li a4, 0 # index loop_t: beq a4, s11, done la a0, msg_input_a li a7, 4 ecall lw s8, 0(s6) addi a0, s8, 0 li a7, 34 ecall la a0, msg_input_b li a7, 4 ecall lw s9, 0(s7) addi a0, s9, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall jal ra, bf16_div addi t0, a2, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s10) addi s6, s6, 4 addi s10, s10, 4 addi s7, s7, 4 addi a4, a4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall bf16_div: addi sp, sp, -16 sw ra, 12(sp) # extract fields srli s0, s8, 15 # sign_a srli s2, s8, 7 andi s2, s2, 0xFF # exp_a andi s4, s8, 0x7F # mant_a srli s1, s9, 15 # sign_b srli s3, s9, 7 andi s3, s3, 0xFF # exp_b andi s5, s9, 0x7F # mant_b # result_sign = sign_a ^ sign_b xor t0, s0, s1 li t1, 0xFF # Special cases: b is Inf or NaN beq s3, t1, check_b_inf # Special cases: b exponent == 0 (subnormal or zero) beqz s3, check_b_zero # Special cases: a is Inf or NaN beq s2, t1, check_a_inf # Special cases: a exponent == 0 (subnormal or zero) beqz s2, check_a_zero_needed # Normalize mantissas / handle subnormals j norm_mant # b is Inf or NaN check_b_inf: andi t2, s5, 0x7F # t2 = mant_b bnez t2, ret_b # b is NaN -> return b # b is +Inf/-Inf (mant_b == 0) # if a is also Inf (exp_a==0xFF) and mant_a==0 -> Inf/Inf = NaN beq s2, t1, a_maybe_inf # else finite / Inf => result is signed zero (result_sign << 15) slli a2, t0, 15 j done_div a_maybe_inf: andi t2, s4, 0x7F # t2 = mant_a beqz t2, ret_nan # a is Inf too -> Inf/Inf = NaN # else a is NaN -> return a mv a2, s8 j done_div # b exponent == 0 (subnormal or zero) check_b_zero: andi t2, s5, 0x7F # t2 = 
mant_b bnez t2, norm_mant # subnormal -> normalize then divide # b == 0 exactly -> division by zero andi t2, s4, 0x7F # t2 = mant_a beqz t2, ret_nan # 0/0 -> NaN # else -> signed Inf slli a2, t0, 15 li t3, 0x7F80 or a2, a2, t3 j done_div # a exponent == 0 (subnormal or zero) check check_a_zero_needed: andi t2, s4, 0x7F beqz t2, ret_zero # a == 0 -> signed zero j norm_mant # a is Inf or NaN check_a_inf: andi t2, s4, 0x7F bnez t2, ret_a # a is NaN -> return a # a is Inf -> result signed Inf (b not Inf here) slli a2, t0, 15 li t3, 0x7F80 or a2, a2, t3 j done_div # normalize mantissas (handle subnormals) norm_mant: # normalize a (if subnormal) beqz s2, norm_a_sub ori s4, s4, 0x80 # implicit 1 li t5, 0 # exp_adjust = 0 j norm_b norm_a_sub: li t5, 0 norm_a_shift: andi t2, s4, 0x80 bnez t2, norm_a_done slli s4, s4, 1 addi t5, t5, -1 j norm_a_shift norm_a_done: li s2, 1 norm_b: beqz s3, norm_b_sub ori s5, s5, 0x80 j div_core norm_b_sub: li t2, 0 norm_b_shift: andi t2, s5, 0x80 bnez t2, norm_b_done slli s5, s5, 1 addi t2, t2, -1 j norm_b_shift norm_b_done: li s3, 1 add t5, t5, t2 # t5 = exp_adjust (a_adjust + b_adjust) # div_core: long division 16 iterations div_core: # dividend = mant_a << 15 slli t1, s4, 15 # t1 = dividend mv t2, s5 # t2 = divisor li t3, 0 # t3 = quotient li t4, 15 # j = 15 down to 0 div_loop: slli t3, t3, 1 # compute shifted = divisor << j (use register shift: sll rd, rs1, rs2) sll t6, t2, t4 # t6 = divisor << j blt t1, t6, no_sub sub t1, t1, t6 ori t3, t3, 1 no_sub: addi t4, t4, -1 bgez t4, div_loop # continue while j >= 0 # compute result_exp = exp_a - exp_b + BF16_EXP_BIAS (127) + exp_adjust sub t1, s2, s3 addi t1, t1, 127 add t1, t1, t5 # normalize quotient: if bit15 set -> right shift 8; else left shift until bit15 set (decrement exp) li t4, 0x8000 and t2, t3, t4 bnez t2, q_shift8 q_norm_loop: and t2, t3, t4 bnez t2, q_norm_done slli t3, t3, 1 addi t1, t1, -1 bgt t1, x0, q_norm_loop q_norm_done: srli t3, t3, 8 j q_after_norm q_shift8: srli t3, t3, 8 
q_after_norm: andi t3, t3, 0x7F # mantissa (7 bits) # overflow / underflow li t2, 0xFF bge t1, t2, ret_inf blez t1, ret_zero # pack result slli t0, t0, 15 slli t1, t1, 7 or a2, t0, t1 or a2, a2, t3 j done_div # returns / special cases ret_a: mv a2, s8 j done_div ret_b: mv a2, s9 j done_div ret_inf: slli a2, t0, 15 li t2, 0x7F80 or a2, a2, t2 j done_div ret_zero: slli a2, t0, 15 j done_div ret_nan: li a2, 0x7FC0 j done_div done_div: lw ra, 12(sp) addi sp, sp, 16 jr ra ``` ::: - #### bf16_sqrt :::spoiler See More ```asm= .data test1: .word 0x7FC1 .word 0x7F80 .word 0x0000 .word 0x0040 .word 0x3E80 .word 0x407F ans: .word 0x7FC1 .word 0x7F80 .word 0x0000 .word 0x0000 .word 0x3F00 .word 0x3FFF n_tests: .word 6 msg_input_a: .string "A = " msg_result: .string ", Result = " msg_pass: .string ", PASS!\n" msg_fail: .string ", FAIL!\n" .text .globl main main: la s6, test1 la s10, ans lw s11, n_tests li a4, 0 # index loop_t: beq a4, s11, done la a0, msg_input_a li a7, 4 ecall lw s8, 0(s6) addi a0, s8, 0 li a7, 34 ecall la a0, msg_result li a7, 4 ecall mv a0, s8 jal ra, bf16_sqrt addi t0, a2, 0 addi a0, t0, 0 li a7, 34 ecall lw t1, 0(s10) addi s6, s6, 4 addi s10, s10, 4 addi a4, a4, 1 bne t0, t1, fail la a0, msg_pass li a7, 4 ecall j loop_t fail: la a0, msg_fail li a7, 4 ecall j loop_t done: li a7, 10 # exit ecall bf16_sqrt: addi sp, sp, -32 sw ra, 28(sp) sw a0, 24(sp) mv t0, a0 srli t1, t0, 15 # sign srli t2, t0, 7 andi t2, t2, 0xFF # exp andi t3, t0, 0x7F # mant # if exp == 0xFF li t4, 0xFF beq t2, t4, check_inf_nan # if exp==0 && mant==0 -> return 0 beqz t2, check_zero j check_neg check_zero: beqz t3, ret_zero j check_neg check_neg: bnez t1, ret_nan # negative -> NaN # flush denormals beqz t2, ret_zero # e = exp - 127 addi t5, t2, -127 # get mantissa with implicit 1 ori t6, t3, 0x80 # m = 0x80 | mant # adjust for odd exponent (t4 is free here; RV32 has no t7) andi t4, t5, 1 beqz t4, even_exp slli t6, t6, 1 addi t5, t5, -1 even_exp: srai t5, t5, 1 addi t5, t5, 127 # new_exp = (e>>1)+127 # binary search for
sqrt(m) li s0, 90 li s1, 256 li s2, 128 # result bs_loop: bgt s0, s1, bs_done add s3, s0, s1 srli s3, s3, 1 # mid = (low+high)>>1 mul s4, s3, s3 # mid*mid srli s4, s4, 7 # /128 bleu s4, t6, bs_leq addi s1, s3, -1 j bs_loop bs_leq: mv s2, s3 addi s0, s3, 1 j bs_loop bs_done: mv t6, s2 # normalize result li t4, 256 bge t6, t4, norm_shift li t4, 128 bge t6, t4, mant_ok norm_shift: srli t6, t6, 1 addi t5, t5, 1 mant_ok: andi t6, t6, 0x7F # mantissa only # check overflow/underflow li t4, 0xFF bge t5, t4, ret_inf blez t5, ret_zero slli t5, t5, 7 or a2, t5, t6 j done_sqrt check_inf_nan: bnez t3, ret_a # NaN propagation bnez t1, ret_nan # sqrt(-Inf)=NaN mv a2, t0 # sqrt(+Inf)=+Inf j done_sqrt ret_a: mv a2, t0 j done_sqrt ret_zero: li a2, 0 j done_sqrt ret_nan: li a2, 0x7FC0 j done_sqrt ret_inf: li a2, 0x7F80 j done_sqrt done_sqrt: lw ra, 28(sp) addi sp, sp, 32 jr ra ``` ::: ## Analysis The code was tested with the [Ripes](https://ripes.me/) simulator. ### Pseudo instruction ``` 00000000 <main>: 0: 10000417 auipc x8 0x10000 4: 00040413 addi x8 x8 0 8: 10000497 auipc x9 0x10000 c: 00c48493 addi x9 x9 12 10: 10000917 auipc x18 0x10000 14: 01892903 lw x18 24 x18 18: 00000993 addi x19 x0 0 0000001c <loop_tests>: 1c: 0b298863 beq x19 x18 176 <done> 20: 10000517 auipc x10 0x10000 24: 00c50513 addi x10 x10 12 28: 00400893 addi x17 x0 4 2c: 00000073 ecall 30: 00198513 addi x10 x19 1 34: 00100893 addi x17 x0 1 38: 00000073 ecall 3c: 10000517 auipc x10 0x10000 40: ff650513 addi x10 x10 -10 44: 00400893 addi x17 x0 4 48: 00000073 ecall 4c: 00299293 slli x5 x19 2 50: 00540333 add x6 x8 x5 54: 00032503 lw x10 0 x6 58: 00050e93 addi x29 x10 0 5c: 00100893 addi x17 x0 1 60: 00000073 ecall 64: 10000517 auipc x10 0x10000 68: fd950513 addi x10 x10 -39 6c: 00400893 addi x17 x0 4 70: 00000073 ecall 74: 000e8513 addi x10 x29 0 78: 05c000ef jal x1 92 <bitwiseComplement> 7c: 00058f13 addi x30 x11 0 80: 000f0513 addi x10 x30 0 84: 00100893 addi x17 x0 1 88: 00000073 ecall 8c: 00299293 slli x5 x19 2 90:
005483b3 add x7 x9 x5 94: 0003ae03 lw x28 0 x7 98: 01cf1e63 bne x30 x28 28 <fail> 0000009c <pass>: 9c: 10000517 auipc x10 0x10000 a0: fad50513 addi x10 x10 -83 a4: 00400893 addi x17 x0 4 a8: 00000073 ecall ac: 00198993 addi x19 x19 1 b0: f6dff06f jal x0 -148 <loop_tests> 000000b4 <fail>: b4: 10000517 auipc x10 0x10000 b8: f9e50513 addi x10 x10 -98 bc: 00400893 addi x17 x0 4 c0: 00000073 ecall c4: 00198993 addi x19 x19 1 c8: f55ff06f jal x0 -172 <loop_tests> 000000cc <done>: cc: 00a00893 addi x17 x0 10 d0: 00000073 ecall 000000d4 <bitwiseComplement>: d4: ffc10113 addi x2 x2 -4 d8: 00112023 sw x1 0 x2 dc: 02050463 beq x10 x0 40 <zero> e0: 038000ef jal x1 56 <clz> e4: 00058293 addi x5 x11 0 e8: 02000313 addi x6 x0 32 ec: 405302b3 sub x5 x6 x5 f0: 00100313 addi x6 x0 1 f4: 005312b3 sll x5 x6 x5 f8: 406282b3 sub x5 x5 x6 fc: 005545b3 xor x11 x10 x5 100: 00c0006f jal x0 12 <return> 00000104 <zero>: 104: 00150593 addi x11 x10 1 108: 0040006f jal x0 4 <return> 0000010c <return>: 10c: 00012083 lw x1 0 x2 110: 00410113 addi x2 x2 4 114: 00008067 jalr x0 x1 0 00000118 <clz>: 118: ff010113 addi x2 x2 -16 11c: 00112623 sw x1 12 x2 120: 00a12423 sw x10 8 x2 124: 02000293 addi x5 x0 32 128: 01000313 addi x6 x0 16 0000012c <loop>: 12c: 006553b3 srl x7 x10 x6 130: 00038663 beq x7 x0 12 <skip> 134: 406282b3 sub x5 x5 x6 138: 00038513 addi x10 x7 0 0000013c <skip>: 13c: 00135313 srli x6 x6 1 140: fe0316e3 bne x6 x0 -20 <loop> 144: 40a285b3 sub x11 x5 x10 148: 00c12083 lw x1 12 x2 14c: 00812503 lw x10 8 x2 150: 01010113 addi x2 x2 16 154: 00008067 jalr x0 x1 0 ``` ### 5-stage pipelined processor ![5stage_cpu](https://hackmd.io/_uploads/ByZETCP6le.png) Above is a 5-stage in-order processor with hazard detection / elimination and forwarding CPU. | Execution info | with CLZ | | --------------- | -------- | | Cycles | 606 | | Instrs. 
retired | 400 |
| CPI | 1.51 |
| IPC | 0.66 |
| Clock Rate | 10.31 Hz |

#### IF
![IF](https://hackmd.io/_uploads/SkMfDNYage.png)
- The PC in this stage is `0x0000000C`.
- Using this PC, the instruction `0x00C48493` (`addi x9, x9, 12`) is fetched from Instr. Memory.
- No branch is taken, so the next PC is PC + 4 (`0x00000010`), and the mux in front of the PC selects the adder output.

#### ID
![ID](https://hackmd.io/_uploads/BkFzw4Yael.png)
- The instruction `addi x9, x9, 12` is decoded into opcode `addi`, R1 idx `0x09`, Wr idx `0x09`, and imm `0x0C`.
- `addi` has no second source register, so the R2 idx field (which reads `0x0C` because it overlaps the immediate bits) is not used.
- Reg1 reads the value `0x00` from the register file.
- `0x0C` is sign-extended to 32 bits (`0x0000000C`) by the Imm. unit.

#### EX
![EX](https://hackmd.io/_uploads/SJA4FVF6ge.png)
- The first pair of multiplexers handles data hazards: if the value read from the register file is stale, the operand is forwarded from the MEM or WB stage instead. Here Reg1 read `0x00`, but the latest value of `x9` (`0x10000008`, produced by the preceding `auipc x9, 0x10000`) is still in the MEM stage, so it is forwarded from there.
- The second pair of multiplexers selects the ALU operands: the upper one selects RS1 (`0x10000008`) and the lower one selects the immediate (`0x0000000C`).
- The ALU adds the two operands, giving `0x10000014`.

#### MEM
![MEM](https://hackmd.io/_uploads/H14_sVY6xl.png)
- `addi` does not access Data Memory.
- The ALU result simply passes through this stage and reaches WB in the next cycle.

#### WB
![WB](https://hackmd.io/_uploads/Sk5gnNtTlg.png)
- The mux selects the ALU result.
- The value `0x10000014` is written back to register `0x09` (`x9`).

Before WB :
![before](https://hackmd.io/_uploads/H1GnnNYagl.png)

After WB :
![after](https://hackmd.io/_uploads/HkM63VFall.png)

## Reference
- [Leetcode:1009. Complement of Base 10 Integer](https://leetcode.com/problems/complement-of-base-10-integer/)
- [Lab1: RV32I Simulator](https://hackmd.io/@sysprog/H1TpVYMdB)
- [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1)