# Computer Architecture — Fall 2025 Homework 1 ## Problem B ### 1. Introduction #### overview The **uf8** format is a custom 8-bit *logarithmic compression scheme* designed to represent large unsigned integers with compact precision. Instead of storing values linearly, it uses a **logarithmic scale** so that larger numbers occupy proportionally fewer bits. `uf8` behaves like a miniature floating-point format. Its 8 bits are divided as follows: | Bits | Field | Description | |:----:|:------|:-------------| | 7 – 4 | Exponent ($e$) | Controls the logarithmic scale | | 3 – 0 | Mantissa ($m$) | Provides local precision within the scale | The decoded value is given by $$ D(b) = m \cdot 2^{e} + (2^{e} - 1) \cdot 16 $$ Where $$ e = \left\lfloor \frac{b}{16} \right\rfloor, \quad m = b \bmod 16 $$ The encode value is given by $$ E(v) = \begin{cases} v, & \text{if } v < 16 \\[6pt] 16e + \left\lfloor \dfrac{v - \text{offset}(e)}{2^{e}} \right\rfloor, & \text{otherwise} \end{cases} $$ Where $$ \text{offset}(e) = (2^{e} - 1) \cdot 16 $$ --- ### 2.Implementation #### Original C code include test :::spoiler code ```c= #include <stdbool.h> #include <stdint.h> #include <stdio.h> #include <stdlib.h> typedef uint8_t uf8; static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; if (y) { n -= c; x = y; } c >>= 1; } while (c); return n - x; } /* Decode uf8 to uint32_t */ uint32_t uf8_decode(uf8 fl) { uint32_t mantissa = fl & 0x0f; uint8_t exponent = fl >> 4; uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; return (mantissa << exponent) + offset; } /* Encode uint32_t to uf8 */ uf8 uf8_encode(uint32_t value) { /* Use CLZ for fast exponent calculation */ if (value < 16) return value; /* Find appropriate exponent using CLZ hint */ int lz = clz(value); int msb = 31 - lz; /* Start from a good initial guess */ uint8_t exponent = 0; uint32_t overflow = 0; if (msb >= 5) { /* Estimate exponent - the formula is empirical */ exponent = msb - 4; if (exponent > 15) 
exponent = 15; /* Calculate overflow for estimated exponent */ for (uint8_t e = 0; e < exponent; e++) overflow = (overflow << 1) + 16; /* Adjust if estimate was off */ while (exponent > 0 && value < overflow) { overflow = (overflow - 16) >> 1; exponent--; } } /* Find exact exponent */ while (exponent < 15) { uint32_t next_overflow = (overflow << 1) + 16; if (value < next_overflow) break; overflow = next_overflow; exponent++; } uint8_t mantissa = (value - overflow) >> exponent; return (exponent << 4) | mantissa; } /* Test encode/decode round-trip */ static bool test(void) { int32_t previous_value = -1; bool passed = true; for (int i = 0; i < 256; i++) { uint8_t fl = i; int32_t value = uf8_decode(fl); uint8_t fl2 = uf8_encode(value); if (fl != fl2) { printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2); passed = false; } if (value <= previous_value) { printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value); passed = false; } previous_value = value; } return passed; } int main(void) { if (test()) { printf("All tests passed.\n"); return 0; } return 1; } ``` ::: #### RiscV assembly code include test :::spoiler code ```asm= .data PASS_MSG: .asciz "All tests PASSED!\n" FAIL_MSG: .asciz "Some tests FAILED!\n" .text .globl main main: jal ra, test mv t0, a0 beqz t0, print_fail print_pass: la a0, PASS_MSG li a7, 4 # ecall 4 = print string ecall j done print_fail: la a0, FAIL_MSG li a7, 4 ecall done: li a7, 10 # ecall 10 = exit ecall clz: li t0,32 #t0 = n li t1,16 #t1 = c clz_Loop: srl t2,a0,t1 #t2 = y sgtz t3,t2 beqz t3,shift sub t0,t0,t1 mv a0,t2 shift: srli t1,t1,1 sgtz t3,t1 bnez t3,clz_Loop sub a0,t0,a0 ret uf8_decode: andi t1,a0,0x0f #t1 = mantissa srli t2,a0,4 li t3,15 sub t0,t3,t2 li t3,0x7fff srl t3,t3,t0 slli t0,t3,4 # t0 = offset sll t1,t1,t2 add a0,t1,t0 ret uf8_encode: addi sp, sp, -24 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) mv s7,a0 #s7 = value addi t0,s7,-16 sltz t0,t0 beqz 
t0,find_exp j restore_regs find_exp: addi sp,sp,-4 sw ra,0(sp) jal ra,clz lw ra, 0(sp) addi sp,sp,4 mv s2,a0 #s2=clz li s3,31 sub s3,s3,s2 #s3 = msb li s4,0 #exp=0 li s5,0 #overflow=0 li t0,5 blt s3,t0,find_extract_exp addi s4,s3,-4 addi t0,s4,-15 sgtz t0,t0 beqz t0,end_check li s4,15 end_check: li t0,0 check_overflow: sub t1,t0,s4 sltz t1,t1 beqz t1,check_overflow_done slli s5,s5,1 addi s5,s5,16 addi t0,t0,1 j check_overflow check_overflow_done: adjust: sgtz t0,s4 sub t1,s7,s5 sltz t1,t1 and t0,t1,t0 beqz t0,adjust_done addi s5,s5,-16 srli s5,s5,1 addi s4,s4,-1 j adjust adjust_done: find_extract_exp: addi t0,s4,-15 sltz t0,t0 beqz t0,extract_done slli t0,s5,1 addi s6,t0,16 sub t0,s7,s6 sltz t0,t0 bnez t0,extract_done mv s5,s6 addi s4,s4,1 j find_extract_exp extract_done: sub t0,s7,s5 srl t0,t0,s4 slli t1,s4,4 or a0,t1,t0 j restore_regs restore_regs: lw s2, 0(sp) lw s3, 4(sp) lw s4, 8(sp) lw s5, 12(sp) lw s6, 16(sp) lw s7, 20(sp) addi sp, sp, 24 ret test: addi sp, sp, -32 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) sw s0, 24(sp) sw ra, 28(sp) # s2 = i # s3 = previous_value # s4 = pass flag # s5 = decoded value # s6 = re-encoded value # s7 = pointer to FAIL_RECORD li s2, 0 # i = 0 li s3, -1 # previous_value = -1 li s4, 1 # passed = true test_loop: li t0, 256 beq s2, t0, test_done mv a0, s2 # a0 = fl = i jal ra, uf8_decode mv s5, a0 # s5 = value mv a0, s5 jal ra, uf8_encode mv s6, a0 # s6 = fl2 # if (fl != fl2) bne s2, s6, record_fail # if (value <= previous_value) sub t1, s5, s3 blez t1, record_fail # update mv s3, s5 addi s2, s2, 1 j test_loop record_fail: sw s2, 0(s7) # fl sw s5, 4(s7) # value sw s6, 8(s7) # fl2 li s4, 0 # passed = false addi s7, s7, 12 # next record slot addi s2, s2, 1 j test_loop test_done: mv a0, s4 # return pass flag (1 = pass, 0 = fail) lw s2, 0(sp) lw s3, 4(sp) lw s4, 8(sp) lw s5, 12(sp) lw s6, 16(sp) lw s7, 20(sp) lw s0, 24(sp) lw ra, 28(sp) addi sp, sp, 32 ret ``` ::: #### testing result all cases pass 
successfully in Ripes!

![截圖 2025-10-10 晚上8.38.59](https://hackmd.io/_uploads/rkzO0dU6xg.png)

---

### 3. Encode Optimization

In the original C implementation of **`uf8_encode()`**, the encoding process relies on **two loops** to determine the correct exponent range. The first loop estimates the exponent by gradually doubling a threshold value, and the second loop fine-tunes the result by checking for overflow and underflow conditions. While this approach is functionally correct, it is computationally inefficient in RISC-V assembly due to multiple iterations and branches.

However, by analyzing the mathematical relationship of the exponent range, we can derive a **closed-form expression** that eliminates all iterative loops. The goal of the exponent search is to find the integer \( e \) such that:

$$
(2^e - 1) \times 16 \le v < (2^{e+1} - 1) \times 16
$$

To simplify, we first divide both sides by 16:

$$
2^e - 1 \le \frac{v}{16} < 2^{e+1} - 1
$$

Let $w = \left\lfloor \frac{v}{16} \right\rfloor$. Because the bounds are integers, the inequality can be rewritten as:

$$
2^e - 1 \le w < 2^{e+1} - 1
$$

Next, we add 1 to each term:

$$
2^e \le w + 1 < 2^{e+1}
$$

Now, by taking the base-2 logarithm on both sides, we can directly obtain:

$$
e = \lfloor \log_2(w + 1) \rfloor
$$

Since \( w = v >> 4 \), the final **closed-form equation** becomes:

$$
\boxed{e = \lfloor \log_2((v >> 4) + 1) \rfloor}
$$

#### Implementation Insight

In integer arithmetic, the logarithmic part can be efficiently implemented using the **CLZ (Count Leading Zeros)** instruction, since:

$$
\lfloor \log_2(x) \rfloor = 31 - \text{CLZ}(x)
$$

Therefore, the exponent can be calculated as:

$$
e = 31 - \text{CLZ}((v >> 4) + 1)
$$

This closed-form transformation **completely removes the original overflow and underflow loops**, turning the multi-step iterative process into a single constant-time computation.
After obtaining \( e \), the offset and mantissa can be computed directly as:

$$
\text{offset} = ((1 << e) - 1) << 4, \qquad m = \frac{v - \text{offset}}{2^e}
$$

and the final encoded value is:

$$
\text{uf8} = (e << 4) \,|\, (m \& 0x0F)
$$

Assembly code of the refined encode function:

:::spoiler code
```asm=
uf8_encode:
    addi sp, sp, -24
    sw s2, 0(sp)
    sw s3, 4(sp)
    sw s4, 8(sp)
    sw s5, 12(sp)
    sw s6, 16(sp)
    sw s7, 20(sp)
    mv s2,a0              #s2 = value
    addi t0,s2,-16
    sltz t0,t0
    bnez t0,restore_regs  #if v<16 return
    # compute e = floor( log2((v >> 4) + 1) )
    srli t2,s2,4
    addi t2,t2,1
    mv a0,t2
    addi sp,sp,-4
    sw ra,0(sp)
    jal ra,clz
    lw ra,0(sp)
    addi sp,sp,4
    mv s3,a0
    li t0,31
    sub s3,t0,s3          #s3 = e = 31-clz
    addi t0,s3,-15
    sgtz t0,t0
    beqz t0,compute_offset
    li s3,15
# compute offset : offset = ((1<<e)-1) << 4
compute_offset:
    li t0,1
    sll t0,t0,s3
    addi t0,t0,-1
    slli s4,t0,4          #s4 = offset
    # compute mantissa = (v-offset) >> e
    sub t0,s2,s4
    srl s5,t0,s3          #s5 = m
    # pack result
    slli t0,s3,4
    andi t1,s5,0x0f
    or a0,t0,t1
    j restore_regs
```
:::

testing result:

![截圖 2025-10-11 下午4.02.21](https://hackmd.io/_uploads/Hkpfycvpxx.png)

#### Code Size and Cycle Count Comparison

To evaluate the optimization impact, we compare the **original loop-based implementation** with the **closed-form CLZ-based implementation** in terms of both *code size* and *average execution cycles*. The closed-form version eliminates all loops and conditional branches in the exponent search, leading to a significantly smaller code footprint and a constant-time execution pattern.
| Version | Exponent Computation Method | Code lines in assembler | Cycles |
|:--------:|:----------------------------|:----------------:|:----------------------------:|
| **Original** | Iterative search + fine-tune (overflow/underflow loops) | 252 | 56531 |
| **Optimized (Closed Form)** | Single-step CLZ-based computation | 148 | 34067 |

---

<div align="center">

**Original (Loop-based):**

![截圖 2025-10-13 下午2.58.40](https://hackmd.io/_uploads/HkcVmQcpeg.png)

**Optimized (Closed-form):**

![截圖 2025-10-11 下午4.19.25](https://hackmd.io/_uploads/Hy0MQ5D6el.png)

</div>

- The **original implementation** spends most of its time inside two nested loops for exponent adjustment (`overflow` and `underflow` searches).
- In the **closed-form version**, the exponent is determined by a single logarithmic estimation:
$$
e = 31 - \text{CLZ}((v >> 4) + 1)
$$
eliminating all iterative loops and reducing branch overhead.
- The improvement cuts the cycle count by **~40%**.

---

### 4. CLZ Function Optimization

In the original assembly implementation, the **CLZ (Count Leading Zeros)** function was written using a loop structure. However, on a pipelined CPU, each branch instruction may introduce **branch hazards** and **pipeline stalls**, which waste cycles and reduce performance.

To address this issue, we can **unroll the loop** and transform the function into a **branchless CLZ implementation**. This approach replaces conditional branching with arithmetic and logical operations, allowing the CPU pipeline to execute the entire routine **sequentially without control flow interruptions**. As a result, the branchless CLZ achieves the same functionality as the original version but with **predictable execution timing** and **significantly better performance**.

:::spoiler code
```asm=
clz:
    li t0,32             #t0 = n
    # check high 16 bits
    # --- k = 16 ---
    srli t1, a0, 16      # y = x >> 16
    sltu t2, x0, t1      # cond = (y != 0) ?
1 : 0
    sub t3, x0, t2       # mask = -cond (0x00000000 or 0xFFFFFFFF)
    slli t4, t2, 4       # cond*16
    sub t0, t0, t4       # n -= 16 (if cond)
    xori t5, t3, -1      # ~mask
    and a0, a0, t5       # a0 = (a0 & ~mask) | (y & mask)
    and t1, t1, t3
    or a0, a0, t1
    # --- k = 8 ---
    # similar to k = 16
    srli t1, a0, 8
    sltu t2, x0, t1
    sub t3, x0, t2
    slli t4, t2, 3       # cond*8
    sub t0, t0, t4
    xori t5, t3, -1
    and a0, a0, t5
    and t1, t1, t3
    or a0, a0, t1
    # --- k = 4 ---
    srli t1, a0, 4
    sltu t2, x0, t1
    sub t3, x0, t2
    slli t4, t2, 2       # cond*4
    sub t0, t0, t4
    xori t5, t3, -1
    and a0, a0, t5
    and t1, t1, t3
    or a0, a0, t1
    # --- k = 2 ---
    srli t1, a0, 2
    sltu t2, x0, t1
    sub t3, x0, t2
    slli t4, t2, 1       # cond*2
    sub t0, t0, t4
    xori t5, t3, -1
    and a0, a0, t5
    and t1, t1, t3
    or a0, a0, t1
    # --- k = 1 ---
    srli t1, a0, 1
    sltu t2, x0, t1
    sub t3, x0, t2       # mask = -cond
    sub t0, t0, t2       # n -= 1 if cond
    xori t5, t3, -1
    and a0, a0, t5
    and t1, t1, t3
    or a0, a0, t1
    # return n - x
    sub a0, t0, a0
    ret
```
:::

The branchless implementation removes all conditional branches (`beq`, `bnez`) and replaces them with pure arithmetic and bitwise masking operations. This eliminates **branch misprediction penalties** and makes the pipeline execution **fully deterministic**. Although the total number of ALU operations increases slightly, the CPU can now execute the instructions linearly without waiting for branch resolution. As a result, the overall performance improves from 34067 cycles down to 32867 cycles while maintaining identical functional correctness for all test inputs.

We can also observe a **lower CPI (Cycles Per Instruction)**, which indicates that the **pipeline is operating more efficiently**, with fewer stalls and better instruction throughput across the execution stages.

![截圖 2025-10-11 下午5.57.25](https://hackmd.io/_uploads/rkZMqjvall.png)

---

### 5. LeetCode 2571 — Minimum Number of Operations to Reduce an Integer to 0

#### Problem Description

You are given a positive integer $n$.
In one operation, you may replace $n$ with either $n + 2^k$ or $n - 2^k$ for any integer $k \ge 0$. Return the **minimum number of operations** required to make $n = 0$. Example: Input: n = 39 Output: 3 Explanation: 39 → 7 (subtract 32) 7 → 8 (add 1) 8 → 0 (subtract 8) #### Core Idea Every integer $n$ lies between two consecutive powers of two: $$ 2^e \le n < 2^{e+1} $$ If we move $n$ to the nearest power of two ($2^e$ or $2^{e+1}$), we effectively remove the **most significant bit** in one step. This is optimal because: - Smaller power jumps ($\pm 1, \pm 2, \pm 4, \dots$) cannot affect the highest bit. - Removing the largest active bit first minimizes the total number of steps. Hence the **greedy rule**: 1. Find the highest set bit $e = \lfloor \log_2 n \rfloor$ 2. Compute $p = 2^e$ and $q = 2^{e+1}$ 3. Move $n$ toward whichever of $p$ or $q$ is closer: $$ d_1 = n - p, \quad d_2 = q - n $$ Choose the smaller of $d_1, d_2$ 4. Repeat until $n = 0$ #### Using CLZ (Count Leading Zeros) To find $\lfloor \log_2 n \rfloor$ efficiently, we can use **CLZ**: $$ e = 31 - \text{CLZ}(n) $$ This operation counts the number of leading zeros in the 32-bit binary representation of $n$, giving the position of the most significant 1-bit in $O(1)$ time. #### Algorithm Steps 1. Initialize `steps = 0` 2. While $n > 0$: - Find $e = 31 - \text{CLZ}(n)$ - Let $p = 2^e$, $q = 2^{e+1}$ - Compare: $$ \text{if } 2n \le 3p \Rightarrow n = n - p \quad \text{else} \quad n = q - n $$ - Increment `steps` 3. 
Return `steps` #### C Implementation :::spoiler code ```c #include <stdint.h> static inline int msb_index(uint32_t x) { return 31 - __builtin_clz(x); } int minOperations(int n) { int steps = 0; while (n) { int e = msb_index((uint32_t)n); uint32_t p = 1u << e; uint32_t q = p << 1; uint64_t two_n = ((uint64_t)n) << 1; uint64_t three_p = (uint64_t)p + ((uint64_t)p << 1); if (two_n <= three_p) n -= (int)p; else n = (int)(q - (uint32_t)n); ++steps; } return steps; } ``` ::: assembly code including test data :::spoiler code ```asm= .data PASS_MSG: .asciz "PASS\n" FAIL_MSG: .asciz "FAIL\n" # 10 test inputs test_cases: .word 1, 2, 3, 5, 7, 15, 39, 100, 1234, 1023 # Expected minimal steps for each input expected_steps: .word 1, 1, 2, 2, 2, 2, 3, 3, 5, 2 .text .globl main # ===================================================== # main: runs 10 test cases; prints PASS if all match. # ===================================================== main: la s0, test_cases # s0 = ptr to inputs la s1, expected_steps # s1 = ptr to expected li s2, 10 # s2 = number of test cases li s4, 1 # s4 = pass flag (1 = true) case_loop: beqz s2, all_done # if all cases done → finish # Load n and expected steps lw s5, 0(s0) # s5 = n lw s6, 0(s1) # s6 = expected steps addi s0, s0, 4 addi s1, s1, 4 # ---- Run algorithm for this n ---- mv s2, s2 # keep counter intact mv s7, s5 # s7 = n (working copy) li s3, 0 # s3 = steps = 0 Loop: beqz s7, Loop_end # if (n == 0) break # --- Compute e = 31 - clz(n) --- mv a0, s7 jal ra, clz # a0 = number of leading zeros li t0, 31 sub t0, t0, a0 # t0 = e = 31 - clz(n) # --- Compute p = 2^e --- li t1, 1 sll t1, t1, t0 # t1 = p = 1 << e # --- Compute two_n = 2 * n --- slli t3, s7, 1 # t3 = two_n # --- Compute three_p = 3 * p = p + 2p --- slli t4, t1, 1 # t4 = 2p add t4, t4, t1 # t4 = 3p # --- Unsigned compare: if (2n <= 3p) go to p else go to 2p --- sltu t5, t4, t3 # t5 = (3p < 2n) beqz t5, to_p # if (2n <= 3p) → to_p # Case: move toward 2p → n = 2p - n slli t6, t1, 1 # t6 = 
2p sub s7, t6, s7 # n = 2p - n j step_done to_p: # Case: move toward p → n = n - p sub s7, s7, t1 # n = n - p step_done: addi s3, s3, 1 # steps++ j Loop Loop_end: # Compare computed steps (s3) with expected (s6) beq s3, s6, case_ok li s4, 0 # mark FAIL case_ok: addi s2, s2, -1 # next case j case_loop all_done: # Print PASS/FAIL beqz s4, print_fail print_pass: la a0, PASS_MSG li a7, 4 # print string ecall j exit print_fail: la a0, FAIL_MSG li a7, 4 # print string ecall exit: li a7, 10 # exit ecall # ===================================================== # clz: count leading zeros (returns in a0) # Input: a0 = x (x > 0) # Output: a0 = number of leading zeros in 32-bit x # ===================================================== clz: li t0, 32 # t0 = n (leading-zero count base) li t1, 16 # t1 = c (shift amount) clz_Loop: srl t2, a0, t1 # t2 = y = x >> c sgtz t3, t2 # t3 = (y > 0) beqz t3, shift sub t0, t0, t1 # n -= c mv a0, t2 # x = y shift: srli t1, t1, 1 # c >>= 1 sgtz t3, t1 # c > 0 ? bnez t3, clz_Loop sub a0, t0, a0 # a0 = n - x (final adjust in this scheme) ret ``` ::: and the cycle is 1948 ![截圖 2025-10-13 下午4.25.14](https://hackmd.io/_uploads/H1uOPV96lx.png) #### implement with branchless clz Using the branchless CLZ implementation, we achieved 1833 cycles, and the CPI decreased from 1.47 to 1.16. Although the total instruction count increased, both the cycle count and CPI were reduced, showing that the pipeline became more efficient with fewer control hazards and smoother instruction flow. ![截圖 2025-10-13 下午5.31.14](https://hackmd.io/_uploads/B1JePScTxg.png) --- ## Problem C --- ### 1. 
Introduction

#### Goal

The purpose of this assignment is to implement the **bfloat16 (BF16)** arithmetic operations using only **RV32I** instructions. This project aims to demonstrate understanding of:

- IEEE-754 floating-point representation
- Bitwise manipulation of sign, exponent, and mantissa
- Software emulation of floating-point behavior on integer hardware

#### Motivation

By building BF16 arithmetic from scratch, we can observe how floating-point units handle normalization, bias adjustment, rounding, and exceptional values (NaN/Inf/Zero). It also helps strengthen low-level reasoning about performance and pipeline hazards.

---

### 2. Design Overview

#### Supported Operations

- `bf16_add`
- `bf16_sub`
- `bf16_mul`
- `bf16_div`
- `bf16_sqrt`
- Helper utilities: `bf16_isnan`, `bf16_isinf`, `bf16_iszero`, `bf16_to_f32`, `f32_to_bf16`

#### Constraints

- Only **RV32I** instructions are allowed
- No hardware floating-point unit
- Must correctly handle **NaN**, **Inf**, **Zero**, and **Subnormal** values

---

### 3.
Implementation #### Data Representation | Field | Bits | Description | |-------|------|-------------| | Sign | 1 | Positive (0) / Negative (1) | | Exponent | 8 | Biased with bias = 127 | | Mantissa | 7 | Fraction bits (with implicit 1 for normalized values) | --- #### C code :::spoiler code ```c= #include <stdbool.h> #include <stdint.h> #include <string.h> typedef struct { uint16_t bits; } bf16_t; #define BF16_SIGN_MASK 0x8000U #define BF16_EXP_MASK 0x7F80U #define BF16_MANT_MASK 0x007FU #define BF16_EXP_BIAS 127 #define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) #define BF16_ZERO() ((bf16_t) {.bits = 0x0000}) static inline bool bf16_isnan(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK); } static inline bool bf16_isinf(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK); } static inline bool bf16_iszero(bf16_t a) { return !(a.bits & 0x7FFF); } static inline bf16_t f32_to_bf16(float val) { uint32_t f32bits; memcpy(&f32bits, &val, sizeof(float)); if (((f32bits >> 23) & 0xFF) == 0xFF) return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; f32bits += ((f32bits >> 16) & 1) + 0x7FFF; return (bf16_t) {.bits = f32bits >> 16}; } static inline float bf16_to_f32(bf16_t val) { uint32_t f32bits = ((uint32_t) val.bits) << 16; float result; memcpy(&result, &f32bits, sizeof(float)); return result; } static inline bf16_t bf16_add(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if (exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? 
b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 
1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; 
result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F)}; } static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t new_exp; /* Get full mantissa with implicit 1 */ uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */ /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */ if (e & 1) { m <<= 1; /* Double mantissa for odd exponent */ new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS; } else { new_exp = (e >> 1) + BF16_EXP_BIAS; } /* Now m is in range [128, 256) or [256, 512) if exponent was odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 127; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need to 
adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ /*if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } }*/ /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow if (new_exp >= 0xFF) return (bf16_t) {.bits = 0x7F80}; +Inf if (new_exp <= 0) return BF16_ZERO();*/ return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant}; } ``` ::: #### assembly code #### bf16 isnan isinf iszero :::spoiler code ```asm= bf16_isnan: # check exponent li t1, 0x7F80 and t2, a0, t1 bne t2, t1, isnan_end # check mantissa li t1, 0x007F and t2, a0, t1 beqz t2, isnan_end li a0, 1 # return 1 ret isnan_end: li a0, 0 # return 0 ret #---------------------------------------- bf16_isinf: # check exponent li t1, 0x7F80 and t2, a0, t1 bne t2, t1, isinf_end # check mantissa li t1, 0x007F and t2, a0, t1 bnez t2, isinf_end li a0, 1 # return 1 ret isinf_end: li a0, 0 # return 0 ret #---------------------------------------- bf16_iszero: li t1, 0x7FFF and t2, a0, t1 bnez t2, not_zero li a0, 1 ret not_zero: li a0, 0 ret #---------------------------------------- ``` #### bf32 to bf16 ```asm= bf32_to_bf16: lw t0,0(a0) srli t1,t0,23 #shift right t0 23bits li t2,0xff #load 0xff and t1,t1,t2 beq t1,t2,return_bf16_inf srli t1,t0,16 andi t1,t1,1 li t2,0x7fff add t1,t1,t2 add t0,t0,t1 srli a0,t0,16 ret return_bf16_inf: srli t1,t0,16 li t2,0xffff and a0,t1,t2 ret ``` #### bf16 to bf32 ```asm= bf16_to_bf32: slli a0,a0,16 ret ``` ::: #### bf16 add :::spoiler code ```asm= bf16_add: # s2 = sign_a # s3 = sign_b # s4 = exp_a # s5 = exp_b # s6 = mant_a # s7 = mant_b addi sp, sp, -24 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) #------------ extract a addi t0, a0,0 srli t0, t0, 15 andi s2, t0, 1 # s2 = sign_a addi t0, a0,0 srli t0, t0, 7 andi s4, 
t0, 0xFF # s4 = exp_a addi t0, a0,0 andi s6, t0, 0x7F # s6 = mant_a #------------ extract b addi t0, a1,0 srli t0, t0, 15 andi s3, t0, 1 # s3 = sign_b addi t0, a1,0 srli t0, t0, 7 andi s5, t0, 0xFF # s5 = exp_b addi t0, a1,0 andi s7, t0, 0x7F # s7 = mant_b #handle inf li t0,0xff bne s4,t0,check_exponent_b bnez s6,return_a bne s5,t0,return_a xor t0,s2,s3 #if signa=signb t6=0 else t6=1 or t0,s7,t0 beqz t0,return_b li a0,0x7FC0 j restore_regs return_a: j restore_regs return_b: addi a0,a1,0 j restore_regs check_exponent_b: li t0,0xff beq s5,t0,return_b seqz t1,s4 #s2=!exp_a seqz t2,s6 #s3=!mantissa_a and t1,t1,t2 bnez t1,return_b seqz t1,s5 #s2=!exp_b seqz t2,s7 #s3=!mantissa_b and t1,t1,t2 bnez t1,return_a beqz s4,skip_a ori s6,s6,0x80 skip_a: beqz s5,skip_b ori s7,s7,0x80 skip_b: sub t0,s4,s5 #t0=exp_diff=exp_a-exp_b bgtz t0,exp_a_bigger bltz t0,exp_b_bigger exp_equal: addi t1,s4,0 #t1=result_exp = exp_a j sign_check exp_a_bigger: addi t1,s4,0 #t1=result_exp = exp_a addi t0,t0,-8 # exp_diff -= 8 bgtz t0,return_a #exp_diff>0 return a addi t0,t0,8 srl s7,s7,t0 #mantissa_b >> exp_diff j sign_check exp_b_bigger: addi t1,s5,0 #t1= result_exp =b addi t0,t0,8 bltz t0,return_b addi t0,t0,-8 neg t0,t0 srl s6,s6,t0 sign_check: bne s2,s3,sign_diff addi t2,s2,0 #t2 = result_sign = sign_a add t3,s6,s7 #t3 = result_mantissa = man_a + man_b andi t0,t3,0x100 beqz t0,return_result srli t3,t3,1 addi t1,t1,1 addi t0,t1,-0xff bltz t0,return_result slli t2,t2,15 li t0,0x7F80 or t2,t2,t0 addi a0,t2,0 j restore_regs sign_diff: sub t0,s6,s7 bltz t0,manb_greater addi t2,s2,0 #t2 = result_sign = sign_a sub t3,s6,s7 #t3 = result_mantissa = mana-manb j check_result manb_greater: addi t2,s3,0 #t2 = result_sign = sign_b sub t3,s7,s6 #t3 = result_man = manb-mana check_result: beqz t3,bf16_zero normalize_loop: andi t0,t3,0x80 bnez t0,return_result slli t3,t3,1 addi t1,t1,-1 blez t1,bf16_zero j normalize_loop bf16_zero: li a0,0x0000 j restore_regs return_result: slli t2,t2,15 andi t1,t1,0xFF slli 
t1,t1,7 andi t3,t3,0x7F or a0,t1,t2 or a0,a0,t3 j restore_regs ``` ::: #### bf16 sub :::spoiler code ```asm= bf16_sub: li t0,0x8000 xor a1, a1, t0 # b.bits ^= 0x8000 # call bf16_add(a, b) jal ra, bf16_add ``` ::: #### bf16 mul :::spoiler code ```asm= bf16_mul: # s2 = sign_a # s3 = sign_b # s4 = exp_a # s5 = exp_b # s6 = mant_a # s7 = mant_b addi sp, sp, -24 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) #------------ extract a addi t0, a0,0 srli t0, t0, 15 andi s2, t0, 1 # s2 = sign_a addi t0, a0,0 srli t0, t0, 7 andi s4, t0, 0xFF # s4 = exp_a addi t0, a0,0 andi s6, t0, 0x7F # s6 = mant_a #------------ extract b addi t0, a1,0 srli t0, t0, 15 andi s3, t0, 1 # s3 = sign_b addi t0, a1,0 srli t0, t0, 7 andi s5, t0, 0xFF # s5 = exp_b addi t0, a1,0 andi s7, t0, 0x7F # s7 = mant_b xor t1,s2,s3 #result_sign = signa xor signb li t0,0xff beq t0,s4,mul_a_inf beq t0,s5,mul_b_inf seqz t2,s4 seqz t3,s6 seqz t4,s5 seqz t5,s7 and t2,t2,t3 and t3,t4,t5 or t0,t2,t3 beqz t0,mul_process slli t1,t1,15 addi a0,t1,0 j restore_regs mul_a_inf: beqz s6,check_expbmantb j restore_regs check_expbmantb: seqz t2,s5 seqz t3,s7 and t0,t2,t3 bnez t0,mul_return_nan slli t1,t1,15 li t0,0x7f80 or a0,t1,t0 j restore_regs mul_return_nan: li a0,0x7fc0 j restore_regs mul_b_inf: beqz s7,check_expamanta addi a0,a1,0 j restore_regs check_expamanta: seqz t2,s4 seqz t3,s6 and t0,t2,t3 bnez t0,mul_return_nan slli t1,t1,15 li t0,0x7f80 or a0,t1,t0 j restore_regs mul_process: li t2,0 #exp_adjust beqz s4,mul_normalized_a ori s6,s6,0x80 j mul_a_done mul_normalized_a: andi t0,s6,0x80 bnez t0,mul_norm_a_done slli s6,s6,1 addi t2,t2,-1 j mul_normalized_a mul_norm_a_done: li s4,1 mul_a_done: beqz s5,mul_normalized_b ori s7,s7,0x80 j mul_b_done mul_normalized_b: andi t0,s7,0x80 bnez t0,mul_norm_b_done slli s7,s7,1 addi t2,t2,-1 j mul_normalized_b mul_norm_b_done: li s5,1 mul_b_done: add t3, s4, s5 # exp_a + exp_b add t3, t3, t2 # + exp_adjust addi t2, t3, -127 # - bias #t2=result_exp 
t1=result_sign #mantissa multiplier li t3,0 addi t4,s6,0 addi t5,s7,0 li t6,8 mul_loop: andi t0,t5,1 # t0 = (t5 & 1) neg t0,t0 # t0 = 0 or -1 (all ones) and t0,t0,t4 # t0 = (t5&1)? t4 : 0 add t3,t3,t0 # t3 += t0 slli t4,t4,1 srli t5,t5,1 addi t6,t6,-1 bnez t6,mul_loop #check result mantissa li t0,0x8000 and t0,t0,t3 beqz t0,result_mant_zero srli t0,t3,8 andi t3,t0,0x7F addi t2,t2,1 j check_mantissa_done result_mant_zero: srli t0,t3,7 andi t3,t0,0x7f check_mantissa_done: #check result exp # check overflow li t0, 0xff bge t2, t0, mul_overflow # if result_exp >= 255 → INF # check underflow blez t2, mul_underflow # if result_exp <= 0 → handle underflow j mul_done # else → normal mul_overflow: slli t1, t1, 15 # sign li t0, 0x7f80 # INF exponent or a0, t1, t0 j restore_regs mul_underflow: li t0, -6 blt t2, t0, mul_to_zero # if result_exp < -6 → return 0 # else: shift mantissa >> (1 - exp) li t0, 1 sub t0, t0, t2 # t0 = 1 - result_exp srl t3, t3, t0 li t2, 0 # exp = 0 j mul_done mul_to_zero: slli t1, t1, 15 # just sign addi a0, t1, 0 # return ±0 j restore_regs mul_done: slli t1, t1, 15 # sign andi t2, t2, 0xff # exp slli t2, t2, 7 andi t3, t3, 0x7f # mant or t1, t1, t2 or t1, t1, t3 addi a0, t1, 0 j restore_regs ``` ::: Since the use of the hardware multiply instruction (mul) is not allowed, this program implements mantissa multiplication using the shift-and-add method. This approach mimics how binary multiplication works at the bit level. 
#### bf16 div :::spoiler code ```asm= bf16_div: addi sp, sp, -24 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) #-------------------------------- # Extract fields #-------------------------------- srli t0, a0, 15 andi s2, t0, 1 # s2 = sign_a srli t0, a1, 15 andi s3, t0, 1 # s3 = sign_b srli t0, a0, 7 andi s4, t0, 0xFF # s4 = exp_a srli t0, a1, 7 andi s5, t0, 0xFF # s5 = exp_b andi s6, a0, 0x7F # mant_a andi s7, a1, 0x7F # mant_b xor t1, s2, s3 # result_sign = sign_a ^ sign_b #-------------------------------- # Handle special cases #-------------------------------- li t0, 0xFF beq s5, t0, div_b_inf_or_nan # b = Inf or NaN beqz s5, div_b_zero_check # b = subnormal or zero? j div_check_a # else continue div_b_inf_or_nan: bnez s7, div_return_nan # mant_b != 0 → b is NaN li t0, 0xFF bne s4, t0, div_return_zero bnez s6, div_return_nan # mant_a != 0 → a is NaN j div_return_nan # a is Inf, b is Inf → NaN div_b_zero_check: beqz s7, div_b_is_zero j div_check_a div_b_is_zero: beqz s4, div_a_zero_check_for_bzero j div_by_zero div_a_zero_check_for_bzero: beqz s6, div_return_nan # 0 / 0 → NaN j div_by_zero div_by_zero: li t0, 0x7F80 # Inf exponent slli t1, t1, 15 or a0, t0, t1 j restore_regs div_return_zero: slli t1, t1, 15 addi a0, t1, 0 j restore_regs div_return_nan: li a0, 0x7FC0 j restore_regs div_check_a: li t0, 0xFF beq s4, t0, div_a_inf_check beqz s4, div_a_zero_check j div_process div_a_inf_check: beqz s6, div_return_inf j div_return_nan div_a_zero_check: beqz s6, div_return_zero j div_process div_return_inf: li t0, 0x7F80 slli t1, t1, 15 or a0, t1, t0 j restore_regs #-------------------------------- # Main process (normal / subnormal) #-------------------------------- div_process: # add implicit 1 beqz s4, div_norm_a_done ori s6, s6, 0x80 div_norm_a_done: beqz s5, div_norm_b_done ori s7, s7, 0x80 div_norm_b_done: # exp = exp_a - exp_b + bias sub t2, s4, s5 addi t2, t2, 127 # t2 = result_exp #divider slli t4,s6,15 #dividend addi t5,s7,0 #divisor li t3,0 # quotient 
li t6,0 #loop_count div_loop: slli t3,t3,1 li t0,15 sub t0,t0,t6 sll t0,t5,t0 blt t4,t0,div_skip sub t4,t4,t0 ori t3,t3,1 div_skip: addi t6,t6,1 li t0,16 blt t6,t0,div_loop #div_end # exp_a == 0 → exp_adjust-- beqz s4, div_exp_a_zero j div_exp_a_done div_exp_a_zero: addi t2, t2, -1 div_exp_a_done: # exp_b == 0 → exp_adjust++ beqz s5, div_exp_b_zero j div_exp_b_done div_exp_b_zero: addi t2, t2, 1 div_exp_b_done: #mantissa normalization li t0, 0x8000 and t0, t0, t3 bnez t0,div_msb_one div_shift_check: li t0, 0x8000 and t0, t3, t0 # t0 = (t3 & 0x8000) bnez t0, div_norm_done # MSB==1 done addi t0, t2, -1 # t0 = t2 - 1 blez t0, div_norm_done # t0<1 done slli t3, t3, 1 # quotient <<= 1 addi t2, t2, -1 # result_exp-- j div_shift_check div_norm_done: srli t3, t3, 8 # quotient >>= 8 j div_done div_msb_one: srli t3,t3,8 j div_done div_done: andi t3, t3, 0x7F # mantissa & 0x7F # --- check overflow (result_exp >= 0xFF) --- li t0, 0xFF bge t2, t0, div_overflow # if exp >= 255 → INF # --- check underflow (result_exp <= 0) --- blez t2, div_underflow # if exp <= 0 → ZERO # --- normal case --- slli t1, t1, 15 # sign << 15 andi t2, t2, 0xFF slli t2, t2, 7 # exp << 7 or t1, t1, t2 or t1, t1, t3 addi a0, t1, 0 j restore_regs # --- overflow (INF) --- div_overflow: slli t1, t1, 15 li t0, 0x7F80 or a0, t1, t0 j restore_regs # --- underflow (ZERO) --- div_underflow: slli t1, t1, 15 addi a0, t1, 0 j restore_regs ``` ::: #### bf16 sqrt :::spoiler code ```asm= bf16_sqrt: # s2 = sign_a # s3 = exp_a # s4 = mant_a # s5 = result_sign # s6 = result_exp # s7 = result_mant addi sp, sp, -24 sw s2, 0(sp) sw s3, 4(sp) sw s4, 8(sp) sw s5, 12(sp) sw s6, 16(sp) sw s7, 20(sp) #------------ extract a (bf16 in a0) srli t0, a0, 15 andi s2, t0, 1 # s2 = sign_a srli t0, a0, 7 andi s3, t0, 0xFF # s3 = exp_a andi s4, a0, 0x7F # s4 = mant_a # default sign=0 li s5, 0 #===================================== # Special cases #===================================== # exp == 0xFF → Inf or NaN li t0, 0xFF beq s3, t0, 
sqrt_is_inf_or_nan # exp == 0 → zero or subnormal beqz s3, sqrt_is_zero_or_sub # negative (normal, non-zero) → NaN bnez s2, sqrt_neg_input # others go main process j sqrt_process # ---------- exp==0xFF ---------- sqrt_is_inf_or_nan: # mant != 0 NaN bnez s4, sqrt_return_input_nan # mant == 0 → Inf # −Inf → NaN;+Inf → +Inf bnez s2, sqrt_return_qnan li a0, 0x7F80 # +Inf j restore_regs sqrt_return_input_nan: addi a0, a0, 0 j restore_regs sqrt_return_qnan: li a0, 0x7FC0 # quiet NaN j restore_regs # ---------- exp==0 ---------- sqrt_is_zero_or_sub: # mant == 0 → ±0 beqz s4, sqrt_return_same_zero li a0, 0x0000 j restore_regs sqrt_return_same_zero: li a0, 0x0000 j restore_regs # ---------- negative (normal, non-zero) ---------- sqrt_neg_input: li a0, 0x7FC0 # NaN j restore_regs sqrt_process: addi t2,s3,-127 #t2= exp-bias ori t1,s4,0x80 # get full mantissa andi t0,t2,1 beqz t0,even_exp slli t1,t1,1 addi t0,t2,-1 srai t0,t0,1 addi t0,t0,127 mv s6,t0 #s6 = new_exp j sqrt_adjust_done even_exp: srai t0,t2,1 addi t0,t0,127 mv s6,t0 #s6 = new_exp sqrt_adjust_done: li t2, 128 # low li t3, 256 # high li t4, 128 # result sqrt_binary_search_loop: # while (low <= high) blt t3, t2, sqrt_binary_search_end # mid = (low + high) >> 1 add t0, t2, t3 srli t0, t0, 1 # t0 = mid # sq = (mid * mid) >> 7 li t5, 0 mv t6, t0 # counter = mid sqrt_mul_loop: beqz t6, sqrt_mul_done add t5, t5, t0 # t5 += mid addi t6, t6, -1 j sqrt_mul_loop sqrt_mul_done: srli t5, t5, 7 # sq = (mid*mid)/128 # if (sq <= m) { result=mid; low=mid+1; } else { high=mid-1; } blt t1, t5, sqrt_go_hi # if m < sq → high = mid - 1 mv t4, t0 # result = mid addi t2, t0, 1 # low = mid + 1 j sqrt_binary_search_loop sqrt_go_hi: addi t3, t0, -1 # high = mid - 1 j sqrt_binary_search_loop sqrt_binary_search_end: li t0, 256 blt t4, t0, sqrt_check_low # result >= 256 exp++ srli t4, t4, 1 addi s6, s6, 1 sqrt_check_low: li t0, 128 bge t4, t0, sqrt_norm_done sqrt_norm_loop: slli t4, t4, 1 addi s6, s6, -1 blt t4, t0, sqrt_norm_loop 
sqrt_norm_done: andi s7, t4, 0x7F j pack_result pack_result: # new_mant = result & 0x7F andi s7, t4, 0x7F # return (sign<<15) | (new_exp<<7) | new_mant; slli t1, s5, 15 # sign andi t2, s6, 0xFF slli t2, t2, 7 # exp << 7 andi t3, s7, 0x7F # mant or t1, t1, t2 or a0, t1, t3 j restore_regs ``` ::: In the sqrt function, we use binary search to approximate the square root of the mantissa. The mantissa part of a bfloat16 number has 7 bits. After adding the implicit leading 1, it becomes an 8-bit value ranging from $128\ (b'10000000)$ to $255\ (b'11111111)$. If the exponent is even, the mantissa range is $[128, 256)$. If the exponent is odd, we subtract $1$ from the exponent and multiply the mantissa by $2$, resulting in a range of $[256, 512)$. Thus, before taking the square root, the mantissa lies within $[128, 512).$ Since the square-root operation effectively halves the exponent, the normalized mantissa result always falls within $[128, 256).$ Therefore, we can safely use $[128, 256)$ as a unified binary-search range for both even and odd exponent cases. In the original implementation, a wider range of $[90, 256]$ was used to guarantee correctness under all inputs. However, this range is unnecessarily large and increases the number of iterations required for convergence. By refining the bounds to $[128, 256)$, the search interval becomes tighter, allowing the algorithm to converge faster while maintaining full accuracy. Moreover, because the result is guaranteed to remain within $[128, 256)$ (exclusive of $256$), it will never overflow or underflow the mantissa field. As a result, additional overflow or underflow checks are no longer required, simplifying the post-processing logic. ### Test Cases #### ADD Test Cases (bfloat16) | No. 
| Operand A (hex) | Operand B (hex) | Expected (hex) | Description | |:---:|:----------------:|:----------------:|:----------------:|:-------------| | 1 | `0x3F40` | `0x3EA0` | `0x3F88` | $0.75 + 0.3125 = \mathbf{1.0625}$ | | 2 | `0x3F80` | `0xBFF0` | `0xBF60` | $1.0 + (-1.875) = \mathbf{-0.875}$ | | 3 | `0x3F80` | `0xBF80` | `0x0000` | $1.0 + (-1.0) = \mathbf{0.0}$ | | 4 | `0x7F80` | `0x3F80` | `0x7F80` | $+\infty + 1.0 = \mathbf{+\infty}$ | | 5 | `0x7F80` | `0xFF80` | `0x7FC0` | $+\infty + (-\infty) = \mathbf{\text{NaN}}$ | | 6 | `0x7F00` | `0x7F00` | `0x7F80` | $\ 0x7F00 + 0x7F00 = \mathbf{+\infty}$ | #### SUB Test Cases (bfloat16) | No. | Operand A (hex) | Operand B (hex) | Expected (hex) | Description | |:---:|:----------------:|:----------------:|:----------------:|:-------------| | 1 | `0x3FC0` | `0x3F40` | `0x3F40` | $1.5 - 0.75 = \mathbf{0.75}$ | | 2 | `0x3F80` | `0x4000` | `0xBF80` | $1.0 - 2.0 = \mathbf{-1.0}$ | | 3 | `0x3F80` | `0x3F80` | `0x0000` | $1.0 - 1.0 = \mathbf{0.0}$ | | 4 | `0x7F80` | `0x7F80` | `0x7FC0` | $+\infty - +\infty = \mathbf{\text{NaN}}$ | #### MUL Test Cases (bfloat16) | No. | A (hex) | B (hex) | Expected (hex) | Description | |:--:|:--------:|:--------:|:---------------:|:-------------| | 1 | `0x4040` | `0x40C0` | `0x4190` | $3.0 \times 6.0 = 18.0$ | | 2 | `0x3F80` | `0x3F40` | `0x3F40` | $1.0 \times 0.75 = 0.75$ | | 3 | `0x3FC8` | `0x40D2` | `0x4124` | $1.5625 \times 6.5625 = 10.25390625$ | | 4 | `0x7F80` | `0xBF80` | `0xFF80` | $+\infty \times -1.0 = -\infty$ | | 5 | `0x7F80` | `0x0000` | `0x7FC0` | $+\infty \times +0 = \text{NaN (invalid)}$ | | 6 | `0x0000` | `0x0000` | `0x0000` | $0 \times 0 = 0$ | #### DIV Test Cases (bfloat16) | No. 
| Operand A (hex) | Operand B (hex) | Expected (hex) | Description | |:---:|:----------------:|:----------------:|:----------------:|:-------------| | 1 | `0x3F80` | `0x4000` | `0x3F00` | **1.0 ÷ 2.0 = 0.5** | | 2 | `0x40B8` | `0x4000` | `0x4038` | **5.75 ÷ 2.0 = 2.875** | | 3 | `0xBF80` | `0x4000` | `0xBF00` | **−1.0 ÷ 2.0 = −0.5** | | 4 | `0x7F80` | `0x4000` | `0x7F80` | **+Inf ÷ 2.0 = +Inf** | | 5 | `0x3F80` | `0x7F80` | `0x0000` | **1.0 ÷ +Inf = +0** | | 6 | `0x0000` | `0x0000` | `0x7FC0` | **0 ÷ 0 = NaN (invalid)** | | 7 | `0x3F80` | `0x0000` | `0x7F80` | **1.0 ÷ 0 = +Inf** | #### SQRT Test Cases (bfloat16) | No. | Operand A (hex) | Expected (hex) | Description | |:---:|:----------------:|:----------------:|:-------------| | 1 | `0x4080` | `0x4000` | **sqrt(4.0) = 2.0** | | 2 | `0x40C8` | `0x4020` | **sqrt(6.25) = 2.5** | | 3 | `0xBF80` | `0x7FC0` | **sqrt(−1.0) = NaN** | | 4 | `0x7F80` | `0x7F80` | **sqrt(+Inf) = +Inf** | | 5 | `0x0000` | `0x0000` | **sqrt(+0) = +0** | ### Testing Result ![截圖 2025-10-09 晚上9.01.23](https://hackmd.io/_uploads/r1hwMVSpge.png)