# Assignment 1: RISC-V Assembly and Instruction Pipeline contributed by < [kaihsiang092](https://github.com/kaihsiang092/ca2025-quizzes.git) > ## Problem B : uf8 * [Assembly program for Problem B.](https://github.com/kaihsiang092/ca2025-quizzes/blob/main/q1-uf8.s) Problem B focuses on compressing a 20-bit unsigned integer into an 8-bit code using **logarithmic quantization**, which gives a **2.5:1 compression ratio**. This method reduces storage requirements while introducing only minor errors, making it suitable for applications where **range is more important than precision**, such as **temperature sensing, distance sensing, and computer graphics**. The following sections describe the **decoding** and **encoding** methods. ### C code : uf8_decode ```c= * Decode uf8 to uint32_t */ uint32_t uf8_decode(uf8 fl) { uint32_t mantissa = fl & 0x0f; uint8_t exponent = fl >> 4; uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; return (mantissa << exponent) + offset; } ``` 1. `uint32_t mantissa = fl & 0x0f;` takes the lower 4 bits of `fl` as the mantissa $m$. 2. `uint8_t exponent = fl >> 4;` takes the upper 4 bits of `fl` as the exponent $e$. 3. `(0x7FFF >> (15 - exponent)) << 4` calculates the starting value of each interval, which corresponds to the formula $\text{offset}(e) = (2^e - 1) \cdot 16$. Finally, the function returns the decoded value: the interval's starting value plus the mantissa scaled by the interval's step size $2^e$, i.e. $\text{value} = m \cdot 2^e + \text{offset}(e)$. ### C code : uf8_encode 1. If the value is less than 16, no compression is needed — return the original value directly. 2. Use CLZ (Count Leading Zeros) to find the position of the most significant bit (MSB), which determines which power-of-two range the value belongs to. 3. When the MSB position is at least 5, `msb - 4` is used as an initial estimate of the exponent (the C comment calls this formula empirical; the later loops correct it). If the estimate exceeds 15 (the maximum representable in 4 bits), it is clamped to 15. 4. Use a loop to calculate the starting value (offset) of the corresponding exponent range. 
This represents the minimum value that can be expressed in that range. 5. If the current overflow is already greater than the input value, it means the exponent was overestimated — decrease the exponent and adjust overflow accordingly. 6. Starting from the current overflow, gradually test the next range to see if it’s still less than the value, until the range that just contains the value is found. 7. Within that range, extract the fine-grained offset and right-shift it according to the exponent. 8. The upper 4 bits represent the exponent, and the lower 4 bits represent the mantissa — together forming the final 8-bit uf8 encoded value. ```c= /* Encode uint32_t to uf8 */ uf8 uf8_encode(uint32_t value) { /* Use CLZ for fast exponent calculation */ if (value < 16) return value; /* Find appropriate exponent using CLZ hint */ int lz = clz(value); int msb = 31 - lz; /* Start from a good initial guess */ uint8_t exponent = 0; uint32_t overflow = 0; if (msb >= 5) { /* Estimate exponent - the formula is empirical */ exponent = msb - 4; if (exponent > 15) exponent = 15; /* Calculate overflow for estimated exponent */ for (uint8_t e = 0; e < exponent; e++) overflow = (overflow << 1) + 16; /* Adjust if estimate was off */ while (exponent > 0 && value < overflow) { overflow = (overflow - 16) >> 1; exponent--; } } /* Find exact exponent */ while (exponent < 15) { uint32_t next_overflow = (overflow << 1) + 16; if (value < next_overflow) break; overflow = next_overflow; exponent++; } uint8_t mantissa = (value - overflow) >> exponent; return (exponent << 4) | mantissa; } ``` ### assembly code ```c= .data str1: .string ": produces value " str2: .string " but encodes back to " str3: .string ": value " str4: .string " <= previous_value " str5: .string "All tests passed.\n" str6: .string " tests failed.\n" .text .globl main main: jal ra, test beq a0, x0, fail la a0, str5 # print str5 li a7, 4 ecall li a7, 93 # ecall: exit li a0, 0 # exit code is 0, ecall fail: la a0, str6 # print str6 
li a7, 4 ecall li a7, 93 # ecall: exit li a0, 1 # exit code is 1 ecall # clz(a0 = x) -> a0 = result clz: li t0, 32 # n = 32 li t1, 16 # c = 16 clz_loop: srl t2, a0, t1 # y = x >> c beq t2, x0, clz_skip # if (y == 0) skip sub t0, t0, t1 # n -= c add a0, t2, x0 # x = y clz_skip: srli t1, t1, 1 # c >>= 1 bne t1, x0, clz_loop # while(c) sub a0, t0, a0 # return n - x jr ra # a0 = fl (uint8) # return a0 = value (uint32) uf8_decode: andi t0, a0, 0x0F # t0 = mantissa = fl & 0x0f; srli t1, a0, 4 # t1 = exponent = fl >> 4; li t2, 15 sub t2, t2, t1 # 15 - exponent li t3, 0x7FFF srl t3, t3, t2 # 0x7FFF >> (15 - exponent) slli t3, t3, 4 # t3 = offset sll t2, t0, t1 # mantissa << exponent add a0, t2, t3 # a0 = (mantissa << exponent) + offset jr ra # jump to ra uf8_encode: # if (value < 16) return value; addi sp, sp, -4 sw ra, 0(sp) add t6, a0, x0 # t6 = value (clz does not touch t6) addi t0, x0, 16 blt a0, t0, return1 jal ra, clz # msb = 31 - lz addi t1, x0, 31 sub t1, t1, a0 # t1 = msb = 31 - lz; addi t2, x0, 0 # t2 = exponent = 0; addi t3, x0, 0 # t3 = overflow = 0; addi t4, x0, 5 # t4 = 5; must be built from x0 because t4 still holds a stale value from the previous call blt t1, t4, exact_exponent # if (msb < 5) skip the estimate addi t2, t1, -4 # exponent = msb - 4; addi t4, x0, 15 bge t4, t2, Calculate_overflow # if (exponent <= 15) Calculate overflow for estimated exponent li t2, 15 Calculate_overflow: addi t4, x0, 0 # e = 0 for1: # do { overflow = (overflow << 1) + 16; } while (++e < exponent); exponent >= 1 here slli t5, t3, 1 addi t3, t5, 16 addi t4, t4, 1 blt t4, t2, for1 while1: blez t2, exact_exponent # if (exponent <= 0) break; bge t6, t3, exact_exponent # if (value >= overflow) break; addi t5, t3, -16 # (overflow - 16) srli t3, t5, 1 # t3 = overflow = (overflow - 16) >> 1; addi t2, t2, -1 j while1 exact_exponent: addi t5, x0, 15 while2: bge t2, t5, return0 slli t4, t3, 1 # (overflow << 1) addi t4, t4, 16 # t4 = next_overflow = (overflow << 1) + 16; blt t6, t4, return0 # if (value < next_overflow) break; add t3, t4, x0 addi t2, t2, 1 j while2 return0: sub t1, t6, t3 # (value - overflow) srl t1, t1, t2 # t1 = mantissa = (value - overflow) >> exponent; slli t0, t2, 4 # (exponent
<< 4) or a0, t0, t1 # return (exponent << 4) | mantissa; return1: lw ra, 0(sp) addi sp, sp, 4 jr ra # jump to ra test: addi sp, sp, -4 sw ra, 0(sp) addi s0, x0, -1 # s0 = previous_value addi s1, x0, 1 # s1 = passed addi s2, x0, 0 # fl = 0 addi s3, x0,256 # fl = 256 for2: add a0, s2, x0 jal ra uf8_decode add s4, a0, x0 # s4 = value add a0, s4, x0 jal ra uf8_encode add s5, a0, x0 # s5 = fl2 if1: beq s2, s5, if2 mv a0, s2 li a7, 34 ecall la a0, str1 # print str1 li a7, 4 ecall mv a0, s4 # print value li a7, 1 ecall la a0, str2 # print str2 li a7, 4 ecall mv a0, s5 li a7, 34 ecall li s1, 0 # passed = false if2: blt s0, s4, afterif mv a0, s2 # print f1 li a7, 34 ecall la a0, str3 # print str3 li a7, 4 ecall mv a0, s4 # print value li a7, 1 ecall la a0, str4 # print str4 li a7, 4 ecall mv a0, s0 # prepare to print previous_value li a7, 34 ecall li s1, 0 # passed = false afterif: mv s0, s4 # previous_value = value; addi s2, s2, 1 blt s2, s3, for2 mv a0, s1 # return passed lw ra, 0(sp) addi sp, sp, 4 jr ra # jump to ra ``` ### Result ![image](https://hackmd.io/_uploads/ByfaY3KTxx.png) ### Performance Cycles : 44397 ![image](https://hackmd.io/_uploads/S1PXKhFage.png) ### 5-stage pipelined processor 1. **Single-cycle processor** Every instruction completes in a single clock cycle—the fetch, decode, execute, memory access, and write-back all happen within the same tick. 2. **5-stage pipelined processor** The execution of each instruction is split into five stages (IF, ID, EX, MEM, WB), allowing multiple instructions to be in different stages simultaneously. Next, we’ll visualize and explain each of these five stages. #### Instruction Fetch (IF) : ![image](https://hackmd.io/_uploads/HkrXDAt6xe.png =40%x) * **PC:** Holds the address of the current instruction; shown as `PC = 0x00000000`. After each instruction executes, the PC updates to the address of the “next” instruction. 
* **Adder (+4):** Adds 4 to the current PC and outputs `0x00000004`, the address of the next sequential instruction. * **MUX:** The lower-left MUX selects the next PC. Normally it chooses `PC + 4` (sequential execution); for taken branches and jumps it selects the branch/jump target. The target of `jal x1, 316` is only computed when the instruction reaches the EX stage (see below). While `jal` is in IF, the MUX therefore still selects `PC + 4`. Once `jal` reaches EX, the MUX switches to the jump target `PC + 316 = 0x0000013C`. * **Instruction Memory:** Uses the current PC to fetch the instruction. With `addr = 0x00000000`, the fetched word is `0x13C000EF`, which encodes `jal x1, 316`. * **Compressed Decoder:** Ripes checks whether the instruction is a 16-bit RISC-V compressed instruction and, if so, expands it here into the equivalent 32-bit instruction. `jal` is a standard 32-bit instruction, so input and output are the same: `0x13C000EF`. * **IF/ID Pipeline Register:** Latches the fetched instruction and its PC at the next clock edge, passing them to the ID stage. With **enable** lit green and **clear** red (not asserted), the instruction flows normally into the ID stage. #### Instruction Decode / Register Fetch (ID) ![image](https://hackmd.io/_uploads/HJXniCtTxg.png =40%x) * **IF/ID Pipeline Register:** Supplies the PC (address of the current instruction) and the instruction `0x13C000EF` (`jal x1, 316`). This instruction word is forwarded to the **Decode** unit and the **Immediate Generator**. * **Decode:** Extracts the opcode (lowest 7 bits) to identify the instruction type (here: JAL). It also extracts the register indices `rd`, `rs1`, `rs2`, and the `funct3`/`funct7` fields, which select the ALU operation for other instruction types. * **Registers:** `R1 idx` and `R2 idx` would read source register values, but `jal` doesn't use source registers, so the inputs are `0x00000000`. The outputs `Reg1` and `Reg2` are also `0x00000000` since this instruction has no operand reads. * **Immediate Generator:** Extracts the immediate field from the instruction bits; here it outputs `0x0000013C`, the jump offset (316).
* **ID/EX Pipeline Register:** Latches these decoded fields at the next clock edge and passes them to the EX stage. The green **enable** light indicates that the register is latching new values normally (the ID stage is not stalled). #### Execute (EX) ![image](https://hackmd.io/_uploads/SJQT-J56ll.png =40%x) * **ALU:** Executes arithmetic (ADD/SUB), bitwise ops (AND/OR/XOR), comparisons (SLT/SLTU), and address calculations (branch/jump targets). Here its inputs are the PC (`0x00000000`) and the immediate (`0x0000013C`). Its output is `Res = 0x00000000 + 0x0000013C = 0x0000013C`, the **jump target address**. This value is sent back to the IF stage to update the PC on the next cycle. * **ALU input selectors:** The upper MUX selects the PC (`0x00000000`) and the lower MUX selects the immediate (`0x0000013C`), so the ALU computes `PC + immediate`. * **Branch Unit:** Determines whether a branch is taken (for `beq`, `bne`, `blt`, `bge`, etc.). For `jal`, no comparison is needed because the jump is unconditional, so **branch taken** is asserted (true). Once the jump/branch is resolved in EX, the target address (`0x0000013C`) is forwarded to the IF stage, which uses it on the next cycle to update the PC. The instructions fetched sequentially after `jal` in the meantime are flushed. #### Memory access (MEM) ![image](https://hackmd.io/_uploads/ByhIKJ5all.png =40%x) * **Data Memory:** `jal` does not access data memory, so no read or write is performed in this cycle. The instruction's values, including the return address (PC + 4), simply pass through to the WB stage. #### Register write-back (WB) ![image](https://hackmd.io/_uploads/ryDMoycall.png =25%x) * **MUX:** The three-input MUX selects the data source for write-back (ALU result, data-memory output, or PC + 4). The green light marks the currently selected input. * The rightmost output value `0x00000004` is sent back to the Register File and written to the destination register **x1 (ra)**. This is one of `jal`’s key functions: it stores the address of the next instruction (**PC + 4**) into **x1** as the return address.
![image](https://hackmd.io/_uploads/H1hj3JcTge.png) After execution, **x1 = 0x00000004** — because `jal` writes **PC + 4** (the return address) into **x1 (ra)**. ## Problem C : bfloat16 Square Root operation * [Assembly program for Problem C.](https://github.com/kaihsiang092/ca2025-quizzes/blob/main/q1-bfloat16.s) ### bfloat16 data bfloat16 (bf16) is the upper half of an IEEE 754 single-precision float, with three fields: 1 sign bit (bit 15), 8 exponent bits (bits 14–7, bias 127), and 7 mantissa bits (bits 6–0). It keeps float32's exponent range with far less precision, and converting bf16 to float32 is just a 16-bit left shift. ### bf16_isnan, bf16_isinf, bf16_iszero #### c code ```c= static inline bool bf16_isnan(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK); } static inline bool bf16_isinf(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK); } static inline bool bf16_iszero(bf16_t a) { return !(a.bits & 0x7FFF); } ``` #### assembly code ```c= bf16_isnan: addi t5, a0, 0 # save a.bits # t0 = EXP_MASK = 0x7F80 (build via 0xFF << 7) addi t0, x0, 255 slli t0, t0, 7 and t1, t5, t0 # t1 = a & EXP_MASK xor t2, t1, t0 # t2 = (a & EXP_MASK) ^ EXP_MASK sltiu a0, t2, 1 # a0 = (t2 == 0) ? 1 : 0 (exp_all_ones) # t3 = a & MANT_MASK (0x007F) andi t3, t5, 127 sltu t4, x0, t3 # t4 = (t3 != 0) ? 1 : 0 (mant_nonzero) and a0, a0, t4 # a0 = exp_all_ones && mant_nonzero jalr x0, ra, 0 bf16_isinf: addi t5, a0, 0 # save a.bits # t0 = EXP_MASK = 0x7F80 addi t0, x0, 255 slli t0, t0, 7 and t1, t5, t0 # t1 = a & EXP_MASK xor t2, t1, t0 # t2 = (a & EXP_MASK) ^ EXP_MASK sltiu a0, t2, 1 # a0 = (t2 == 0) ? 1 : 0 (exp_all_ones) # t3 = a & MANT_MASK andi t3, t5, 127 sltiu t4, t3, 1 # t4 = (t3 == 0) ? 1 : 0 (mant_zero) and a0, a0, t4 # a0 = exp_all_ones && mant_zero jalr x0, ra, 0 bf16_iszero: addi t5, a0, 0 # save a.bits # t0 = 0x7FFF = (0x7F << 8) | 0xFF addi t0, x0, 127 slli t0, t0, 8 ori t0, t0, 255 and t1, t5, t0 # t1 = a & 0x7FFF sltiu a0, t1, 1 # a0 = (t1 == 0) ?
1 : 0 jalr x0, ra, 0 ``` ### f32_to_bf16 #### C code ```c= static inline bf16_t f32_to_bf16(float val) { uint32_t f32bits; memcpy(&f32bits, &val, sizeof(float)); if (((f32bits >> 23) & 0xFF) == 0xFF) return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; f32bits += ((f32bits >> 16) & 1) + 0x7FFF; return (bf16_t) {.bits = f32bits >> 16}; } ``` #### Assembly code ```c= f32_to_bf16: # === 取出 exponent (8 bits) === srli t0, a0, 23 # t0 = f32bits >> 23 andi t0, t0, 0xFF # t0 = exponent (8 bits) li t1, 0xFF beq t0, t1, nan_inf # if exponent == 都是1 → NaN/Inf 處理 # === 一般數值的 round-to-nearest-even === srli t1, a0, 16 # t1 = f32bits >> 16 andi t1, t1, 1 # t1 = (f32bits >> 16) & 1 (LSB for tie-to-even) li t2, 0x7FFF add t1, t1, t2 # t1 = ((f32bits >>16)&1) + 0x7FFF add a0, a0, t1 # f32bits += t1 # 取高16位作為 bf16.bits srli a0, a0, 16 # a0 = f32bits >> 16 li t3, 0xFFFF # t3 = 0xFFFF and a0, a0, t3 # 保留低16位(確保乾淨) jr ra nan_inf: srli a0, a0, 16 # a0 = f32bits >> 16 li t3, 0xFFFF # t3 = 0xFFFF and a0, a0, t3 # (f32bits >> 16) & 0xFFFF jr ra ``` ### bf16_to_f32 #### C code ```c= static inline float bf16_to_f32(bf16_t val) { uint32_t f32bits = ((uint32_t) val.bits) << 16; float result; memcpy(&result, &f32bits, sizeof(float)); return result; } ``` #### Assembly code ```c= bf16_to_f32: slli a0, a0, 16 #f32bits = ((uint32_t) val.bits) << 16 jr ra ``` ### bf16_add #### C code ```c= static inline bf16_t bf16_add(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if (exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? 
b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } ``` #### Assembly code ```c= bf16_add: # a0 = a.bits (low 16) # a1 = b.bits (low 16) # registers used: # t0=sign_a, t1=sign_b, t2=exp_a, t3=exp_b, t4=mant_a, t5=mant_b, t6=temp # a2=result_sign, a3=result_exp, a4=result_mant, a5=exp_diff srli t0, a0, 15 # a.bits >> 15 andi t0, t0, 1 # t0=sign_a srli t1, a1, 15 # b.bits >> 15 andi t1, t1, 1 # t1=sign_b srli t2, a0, 7 # a.bits >> 7 andi t2, t2, 0xFF # t2=exp_a srli t3, a1, 7 # b.bits >> 7 andi t3, t3, 0xFF # t3=exp_b andi t4, a0, 0x7F # t4=mant_a andi t5, a1, 0x7F # t5=mant_b addi t6, x0, 0xFF bne t2, t6, if_1 # if exp_a != 0xFF -> jump to if_1 # exp_a == 0xFF: a is NaN or ±Inf bne t4, x0, return_a # if (mant_a) return a; (a is NaN) bne t3, t6, return_a # b is finite -> return a (±Inf) bne t5, x0, return_b # b is NaN -> return b beq t0, t1, return_b # Inf + Inf with equal signs -> return b addi a0, x0, 0xFF # Inf + (-Inf) -> BF16_NAN() = 0x7FC0 slli a0, a0, 7 # 0x7F80 ori a0, a0, 0x40 # 0x7FC0 jr ra if_1: bne t3, t6, if_2 addi a0, a1, 0 # if (exp_b == 0xFF) return b; jr ra if_2: # if (!exp_a && !mant_a) return b; bne t2, x0, if_3 bne t4, x0, if_3 addi a0, a1, 0 jr ra if_3: # if (!exp_b && !mant_b) return a; bne t3, x0, if_4 bne t5, x0, if_4 jr ra return_b: addi a0, a1, 0 jr ra return_a: jr ra if_4: # if
(exp_a) mant_a |= 0x80; beq t2, x0, if_5 ori t4, t4, 0x80 if_5: # if (exp_b) mant_b |= 0x80; beq t3, x0, if_6 ori t5, t5, 0x80 if_6: sub a5, t2, t3 # exp_diff = exp_a - exp_b # if exp_diff > 0 slt t6, x0, a5 # t6 = (0 < exp_diff) beq t6, x0, else_if # 若不是 >0,去檢查 <0 addi a3, t2, 0 # result_exp = exp_a addi t6, x0, 8 blt t6, a5, return_a # if 8 < exp_diff -> 直接回 a srl t5, t5, a5 # mant_b >>= exp_diff jal x0, if_8 else_if: # else if (exp_diff < 0) slt t6, a5, x0 # t6 = (exp_diff < 0) beq t6, x0, else # 若不是 <0 -> 相等 addi a3, t3, 0 # result_exp = exp_b sub t6, x0, a5 # t6 = -exp_diff = exp_b - exp_a addi a4, x0, 8 blt a4, t6, return_b # if (-exp_diff) > 8 -> 回 b srl t4, t4, t6 # mant_a >>= (-exp_diff) jal x0, if_8 else: # exp_diff == 0 add a3, t2, x0 # result_exp = exp_a if_8: xor t6, t0, t1 bne t6, x0, DIFF_SIGN # 若異號 add a2, t0, x0 # result_sign = sign_a add a4, t4, t5 # result_mant = mant_a + mant_b andi t6, a4, 0x100 # 檢查進位 (bit8) beq t6, x0, PACK srli a4, a4, 1 addi a3, a3, 1 # exp++ addi t6, x0, 0xFF bge a3, t6, RET_INF # 溢出 -> 回 ±Inf jal x0, PACK RET_INF: slli a0, a2, 15 # a0 = (result_sign << 15) addi t6, x0, 0xFF slli t6, t6, 7 # 0x7F80 = 0xFF<<7 or a0, a0, t6 jr ra DIFF_SIGN: sltu t6, t4, t5 # t6=1 if mant_a < mant_b bne t6, x0, USE_B addi a2, t0, 0 # result_sign = sign_a sub a4, t4, t5 # mant = a - b jal x0, AFTER_SUB USE_B: addi a2, t1, 0 # result_sign = sign_b sub a4, t5, t4 # mant = b - a AFTER_SUB: beq a4, x0, RET_ZERO NORM_LOOP: andi t6, a4, 0x80 bne t6, x0, PACK # 若已具 hidden bit -> 組裝 slli a4, a4, 1 # 左規格化 addi a3, a3, -1 # exp-- # 若 exp <= 0 -> 下溢為 0 slt t6, x0, a3 # t6 = (0 < exp) bne t6, x0, NORM_LOOP jal x0, RET_ZERO PACK: andi t6, a3, 255 slli t6, t6, 7 # (exp&0xFF)<<7 slli a0, a2, 15 # sign<<15 or a0, a0, t6 andi t6, a4, 0x7F or a0, a0, t6 jr ra RET_ZERO: addi a0, x0, 0 jr ra ``` ### bf16_sub #### C code ```c= static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } ``` #### Assembly code ```c= 
bf16_sub: lui t0, 0x8 # t0 = 0x00008000 = BF16_SIGN_MASK xor a1, a1, t0 # b.bits ^= BF16_SIGN_MASK (negate b) jal x0, bf16_add # tail call: return bf16_add(a, b). ra is left untouched, so bf16_add returns straight to our caller; the old jal ra, bf16_add + jalr x0, ra, 0 pair overwrote ra and then looped forever on that jalr ``` ### bf16_mul #### C code ```c= static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } ``` #### Assembly code ```c= bf16_mul: srli t0, a0, 15 # sign_a andi t0, t0, 1 srli t1, a1, 15 # sign_b andi t1, t1, 1 srli t2, a0, 7 # exp_a andi t2, t2, 0xFF srli t3, a1, 7 # exp_b andi t3, t3, 0xFF andi t4, a0, 0x7F # mant_a (7 bits) andi t5, a1, 0x7F # mant_b (7 bits) xor a2, t0, t1 # result_sign = sign_a ^
sign_b addi t6, x0, 255 bne t2, t6, 1f # if exp_a != 0xFF skip to 1: (numeric local labels must be referenced as Nf / Nb) bne t4, x0, RET_A # a is NaN → return a # a is ±Inf # if b is 0 (exp_b=0 and mant_b=0) → NaN bne t3, x0, RET_INF2 # b is neither subnormal nor 0 → ±Inf beq t5, x0, RET_NAN # b has mant==0 and exp_b==0 → NaN # b is subnormal (exp_b=0 but mant_b!=0), still finite → ±Inf jal x0, RET_INF2 1: # --- b's exponent == 0xFF --- bne t3, t6, 2f bne t5, x0, RET_B # b is NaN → return b # b is ±Inf bne t2, x0, RET_INF2 # a is nonzero → ±Inf beq t4, x0, RET_NAN # a==0 → NaN jal x0, RET_INF2 2: # --- either operand is 0 → return ±0 (sign = result_sign) --- # if a == 0 bne t2, x0, 3f beq t4, x0, RET_SIGNED_ZERO 3: # if b == 0 bne t3, x0, 4f beq t5, x0, RET_SIGNED_ZERO 4: # --- exp_adjust = 0 --- addi a5, x0, 0 # --- normalize / add hidden bit: operand a --- beq t2, x0, 5f ori t4, t4, 0x80 # mant_a |= 0x80 jal x0, 6f 5: # a is subnormal: shift left until the hidden bit is set, exp_adjust-- # do { if (mant_a & 0x80) break; mant_a<<=1; exp_adjust--; } while(1) andi t6, t4, 0x80 bne t6, x0, 555f 55: slli t4, t4, 1 addi a5, a5, -1 andi t6, t4, 0x80 beq t6, x0, 55b 555: addi t2, x0, 1 # exp_a = 1 6: # --- operand b --- beq t3, x0, 7f ori t5, t5, 0x80 # mant_b |= 0x80 jal x0, 8f 7: # b is subnormal andi t6, t5, 0x80 bne t6, x0, 777f 77: slli t5, t5, 1 addi a5, a5, -1 andi t6, t5, 0x80 beq t6, x0, 77b 777: addi t3, x0, 1 # exp_b = 1 8: # --- 8x8 multiply (shift-and-add), result_mant=a4 --- addi a4, x0, 0 # product = 0 addi t6, x0, 8 # 8 iterations 9: # loop: if (mant_b & 1) product += mant_a; t0 (sign_a) is free here because the sign is already in a2 andi t0, t5, 1 beq t0, x0, 10f add a4, a4, t4 10: slli t4, t4, 1 # mant_a <<= 1 srli t5, t5, 1 # mant_b >>= 1 addi t6, t6, -1 bne t6, x0, 9b # result_exp = exp_a + exp_b - 127 + exp_adjust add a3, t2, t3 addi a3, a3, -127 add a3, a3, a5 lui t6, 0x8 # t6 = 0x00008000 and t0, a4, t6 beq t0, x0, 11f srli a4, a4, 8 andi a4, a4, 0x7F addi a3, a3, 1 jal x0, 12f 11: # bit 15 clear srli a4, a4, 7 andi a4, a4, 0x7F 12: # --- overflow: result_exp >= 0xFF → ±Inf --- addi t6, x0, 255 bge a3, t6, RET_INF2 # --- underflow: result_exp <= 0 --- beq a3, x0, 13f slt t0, a3, x0 # t0=1 if a3<0 beq t0, x0, 14f # a3>0 → normal number 13: # result_exp <= 0 addi t6, x0, -6 blt a3, t6, RET_SIGNED_ZERO # too small → ±0 # result_mant >>= (1 -
result_exp); result_exp=0 addi t0, x0, 1 sub t0, t0, a3 # t0 = 1 - result_exp srl a4, a4, t0 addi a3, x0, 0 14: # --- 組裝正常/次正規結果 --- # bits = (sign<<15) | ((exp&0xFF)<<7) | (mant&0x7F) andi t6, a3, 255 slli t6, t6, 7 slli a0, a2, 15 or a0, a0, t6 andi t6, a4, 0x7F or a0, a0, t6 jalr x0, ra, 0 RET_A: # 回 a addi a0, a0, 0 jalr x0, ra, 0 RET_B: # 回 b addi a0, a1, 0 jalr x0, ra, 0 RET_INF2: # 回 ±Inf: (sign<<15)|(0xFF<<7) slli a0, a2, 15 addi t6, x0, 255 slli t6, t6, 7 # t6 = 0x7F80 or a0, a0, t6 jalr x0, ra, 0 RET_NAN: # 回 qNaN: 0x7FC0 addi t6, x0, 255 slli t6, t6, 7 # 0x7F80 ori a0, t6, 0x40 # 0x7FC0 jalr x0, ra, 0 RET_SIGNED_ZERO: # 回 ±0: (result_sign<<15) slli a0, a2, 15 jalr x0, ra, 0 ``` ### bf16_div #### c code ```c= static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { 
quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F)}; } ``` #### assembly code ```c= bf16_div: srli t0, a0, 15 # sign_a andi t0, t0, 1 srli t1, a1, 15 # sign_b andi t1, t1, 1 srli t2, a0, 7 # exp_a andi t2, t2, 0xFF srli t3, a1, 7 # exp_b andi t3, t3, 0xFF andi t4, a0, 0x7F # mant_a andi t5, a1, 0x7F # mant_b # result_sign = sign_a ^ sign_b xor a2, t0, t1 addi t6, x0, 255 bne t3, t6, 1f # if (exp_b != 0xFF) skip bne t5, x0, RET_B3 # mant_b!=0 → b is NaN, return b (propagate it) # b is ±Inf; if a is +Inf/-Inf (and not NaN) → NaN; otherwise a/Inf = ±0 bne t2, t6, RET_SIGNED_ZERO3 beq t4, x0, RET_NAN3 # a is ±Inf (mant_a==0) → NaN jal x0, RET_SIGNED_ZERO3 1: bne t3, x0, 2f bne t5, x0, 2f # exp_b==0 && mant_b==0 → b==0 # if a==0 → NaN; otherwise → ±Inf bne t2, x0, RET_INF3 beq t4, x0, RET_NAN3 jal x0, RET_INF3 2: bne t2, t6, 3f # if exp_a != 0xFF skip bne t4, x0, RET_A3 # a is NaN → return a # a is ±Inf → return ±Inf jal x0, RET_INF3 3: bne t2, x0, 4f beq t4, x0, RET_SIGNED_ZERO3 # a == 0 → ±0 via bf16_div's own return path (not bf16_mul's RET_SIGNED_ZERO) 4: beq t2, x0, 5f ori t4, t4, 0x80 # mant_a |= 0x80 5: beq t3, x0, 6f ori t5, t5, 0x80 # mant_b |= 0x80 6: slli a4, t4, 15 # dividend = (uint32)mant_a << 15 addi a5, x0, 0 # quotient = 0 # divisor_shift = mant_b << 15 (t6 holds cur = divisor << 15) slli t6, t5, 15 addi t0, x0, 16 # t0 = 16 iterations L_div_loop: slli a5, a5, 1 # quotient <<= 1 # if (dividend >= cur) { dividend -= cur; quotient |= 1; } sltu t1, a4, t6 # t1=1 if dividend < cur (unsigned compare) bne t1, x0, L_no_sub sub a4, a4, t6 ori a5, a5, 1 L_no_sub: srli t6, t6, 1 # cur >>= 1 addi t0, t0, -1 bne t0, x0, L_div_loop # result_exp = exp_a - exp_b + 127 sub a3, t2, t3 addi a3, a3, 127 # if (!exp_a) result_exp--; beq t2, x0, L_dec_exp_a jal x0, L_chk_exp_b L_dec_exp_a: addi a3, a3, -1 L_chk_exp_b: # if (!exp_b) result_exp++; beq t3, x0, L_inc_exp_b jal x0, L_norm_q L_inc_exp_b: addi a3, a3, 1
L_norm_q: # if (quotient & 0x8000) quotient >>= 8; lui t0, 0x8 # t0 = 0x00008000 and t1, a5, t0 beq t1, x0, L_shift_left_phase srli a5, a5, 8 jal x0, L_pack_mant L_shift_left_phase: # while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } L_norm_loop: and t1, a5, t0 bne t1, x0, L_after_left_norm addi t1, x0, 1 slt t1, t1, a3 # t1=1 if (1 < result_exp) 即 result_exp > 1 beq t1, x0, L_after_left_norm slli a5, a5, 1 addi a3, a3, -1 jal x0, L_norm_loop L_after_left_norm: srli a5, a5, 8 L_pack_mant: andi a5, a5, 0x7F # mantissa = quotient & 0x7F addi t0, x0, 255 bge a3, t0, RET_INF3 # result_exp >= 0xFF → ±Inf # if (result_exp <= 0) → ±0 beq a3, x0, RET_SIGNED_ZERO3 slt t1, a3, x0 # a3 < 0 ? bne t1, x0, RET_SIGNED_ZERO3 # bits = (sign<<15) | ((exp&0xFF)<<7) | (mant&0x7F) andi t0, a3, 255 slli t0, t0, 7 slli a0, a2, 15 or a0, a0, t0 or a0, a0, a5 jalr x0, ra 0 RET_A3: # 回 a addi a0, a0, 0 jalr x0, ra 0 RET_B3: # 回 b addi a0, a1, 0 jalr x0, ra 0 RET_INF3: # 回 ±Inf: (sign<<15)|(0xFF<<7) slli a0, a2, 15 addi t0, x0, 255 slli t0, t0, 7 # t0 = 0x7F80 or a0, a0, t0 jalr x0, ra 0 RET_NAN3: # 回 qNaN: 0x7FC0 addi t0, x0, 255 slli t0, t0, 7 # 0x7F80 ori a0, t0, 0x40 # 0x7FC0 jalr x0, ra 0 RET_SIGNED_ZERO3: # 回 ±0: (sign<<15) slli a0, a2, 15 jalr x0, ra 0 ``` ### bf16_sqrt #### c code ```c= static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t 
new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */

    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;   /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128; /* Default */

    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */

        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */

    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;

    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();

    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```

#### assembly code

```c=
# bf16_sqrt: a0 = bf16 bits in, a0 = bf16 bits of sqrt(a) out.
# Uses only RV32I (no M extension), so mid*mid is computed by shift-and-add.
# Clobbers t0-t6, a2-a6.
bf16_sqrt:
    # Split into sign (t0), exponent (t1), mantissa (t2)
    srli t0, a0, 15
    andi t0, t0, 1                # t0 = sign
    srli t1, a0, 7
    andi t1, t1, 0xFF             # t1 = exp
    andi t2, a0, 0x7F             # t2 = mant

    # exp == 0xFF: Inf or NaN
    addi t3, x0, 255
    bne  t1, t3, SKIP_INF_NAN
    bne  t2, x0, RET_A4           # NaN -> propagate a unchanged
    bne  t0, x0, RET_NAN4         # -Inf -> NaN
    jalr x0, ra, 0                # +Inf -> return a unchanged
SKIP_INF_NAN:
    # sqrt(+-0) = 0  (exp == 0 && mant == 0)
    bne  t1, x0, SKIP_ZERO
    bne  t2, x0, SKIP_ZERO
    addi a0, x0, 0
    jalr x0, ra, 0
SKIP_ZERO:
    # Negative finite value -> NaN
    beq  t0, x0, SKIP_NEG
    jal  x0, RET_NAN4
SKIP_NEG:
    # Subnormals are flushed to zero: exp == 0 -> 0
    bne  t1, x0, CONTINUE1
    addi a0, x0, 0
    jalr x0, ra, 0
CONTINUE1:
    addi a3, t1, -127             # a3 = e = exp - 127 (unbiased exponent)
    ori  a2, t2, 0x80             # a2 = m = 0x80 | mant (implicit leading 1)

    # Odd e: m <<= 1; new_exp = ((e - 1) >> 1) + 127
    andi t3, a3, 1
    beq  t3, x0, EVEN_EXP
    slli a2, a2, 1                # m <<= 1
    addi t3, a3, -1
    srai t3, t3, 1                # arithmetic shift: e may be negative
    addi a3, t3, 127              # a3 = new_exp
    jal  x0, INIT_BSEARCH
EVEN_EXP:
    srai t3, a3, 1
    addi a3, t3, 127              # a3 = new_exp

INIT_BSEARCH:
    # Binary search: low = 90 (t4), high = 256 (t5), result = 128 (t6).
    # mid lives in a6, separate from result: result must keep the last
    # mid that satisfied sq <= m, not the last mid that was probed.
    addi t4, x0, 90
    addi t5, x0, 256
    addi t6, x0, 128
BSEARCH_LOOP:
    # while (low <= high)
    slt  t3, t5, t4               # t3 = 1 if high < low
    bne  t3, x0, BSEARCH_DONE

    add  a6, t4, t5
    srli a6, a6, 1                # a6 = mid = (low + high) >> 1

    # a4 = mid * mid via shift-and-add (a5 = multiplier, t3 = multiplicand)
    addi a4, x0, 0
    addi a5, a6, 0
    addi t3, a6, 0
MUL_LOOP:
    andi t0, a5, 1
    beq  t0, x0, MUL_SKIP_ADD
    add  a4, a4, t3               # prod += multiplicand
MUL_SKIP_ADD:
    slli t3, t3, 1                # multiplicand <<= 1
    srli a5, a5, 1                # multiplier >>= 1
    bne  a5, x0, MUL_LOOP

    srli a4, a4, 7                # a4 = sq = (mid * mid) / 128

    # if (sq <= m) { result = mid; low = mid + 1; } else { high = mid - 1; }
    sltu t0, a2, a4               # t0 = 1 if m < sq (unsigned)
    bne  t0, x0, TAKE_RIGHT
    addi t6, a6, 0                # result = mid
    addi t4, a6, 1                # low = mid + 1
    jal  x0, BSEARCH_LOOP
TAKE_RIGHT:
    # sq > m: the root is below mid, so shrink the upper bound
    addi t5, a6, -1               # high = mid - 1
    jal  x0, BSEARCH_LOOP

BSEARCH_DONE:
    # result is in t6. Normalize: if (result >= 256) { result >>= 1; new_exp++; }
    addi t0, x0, 256
    slt  t1, t6, t0               # t1 = 1 if result < 256
    bne  t1, x0, CHECK_LOW
    srli t6, t6, 1
    addi a3, a3, 1
    jal  x0, AFTER_NORM
CHECK_LOW:
    # while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; }
    addi t0, x0, 128
NORM_LOOP_LOW:
    slt  t1, t6, t0               # t1 = 1 if result < 128
    beq  t1, x0, AFTER_NORM
    addi t1, x0, 1
    slt  t1, t1, a3               # t1 = 1 if 1 < new_exp
    beq  t1, x0, AFTER_NORM
    slli t6, t6, 1
    addi a3, a3, -1
    jal  x0, NORM_LOOP_LOW

AFTER_NORM:
    andi t6, t6, 0x7F             # new_mant = result & 0x7F (drop implicit 1)

    # Overflow / underflow checks (signed)
    addi t0, x0, 255
    bge  a3, t0, RET_POS_INF      # new_exp >= 0xFF -> +Inf
    bge  x0, a3, RET_ZERO4        # new_exp <= 0    -> 0

    # Assemble positive result: bits = ((new_exp & 0xFF) << 7) | new_mant
    andi t0, a3, 255
    slli t0, t0, 7
    or   a0, t0, t6
    jalr x0, ra, 0

# ========= Shared return paths =========
RET_A4:                           # return a unchanged (NaN propagation)
    jalr x0, ra, 0
RET_ZERO4:                        # return +0
    addi a0, x0, 0
    jalr x0, ra, 0
RET_POS_INF:                      # return +Inf = 0x7F80
    addi a0, x0, 255
    slli a0, a0, 7
    jalr x0, ra, 0
RET_NAN4:                         # return quiet NaN = 0x7FC0
    addi a0, x0, 255
    slli a0, a0, 7                # 0x7F80
    ori  a0, a0, 0x40             # 0x7FC0
    jalr x0, ra, 0
```

### bf16_eq

#### c code

```c=
static inline bool bf16_eq(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return true;
    return a.bits == b.bits;
}
```

#### assembly code

```c=
# bf16_eq: a0 = a, a1 = b; returns a0 = 1 if a == b, else 0.
# NaN compares unequal to everything; +0 and -0 compare equal.
bf16_eq:
    srli t0, a0, 7
    andi t0, t0, 255              # t0 = exp_a
    addi t1, x0, 255
    bne  t0, t1, CheckIsNanB      # exp_a != 0xFF -> a is not NaN
    andi t2, a0, 0x7F             # t2 = mant_a
    beq  t2, x0, CheckIsNanB      # mant_a == 0 -> a is Inf, not NaN
    addi a0, x0, 0                # a is NaN -> false
    jalr x0, ra, 0

CheckIsNanB:
    srli t0, a1, 7
    andi t0, t0, 255              # t0 = exp_b
    addi t1, x0, 255
    bne  t0, t1, CheckBothZero
    andi t2, a1, 0x7F             # t2 = mant_b
    beq  t2, x0, CheckBothZero
    addi a0, x0, 0                # b is NaN -> false
    jalr x0, ra, 0

CheckBothZero:
    lui  t3, 0x8                  # t3 = 0x00008000
    addi t3, t3, -1               # t3 = 0x00007FFF (everything but the sign)
    and  t0, a0, t3               # (a.bits & 0x7FFF) == 0 ?
    bne  t0, x0, CompareBits
    and  t1, a1, t3               # (b.bits & 0x7FFF) == 0 ?
    bne  t1, x0, CompareBits
    addi a0, x0, 1                # both are +-0 -> true
    jalr x0, ra, 0

CompareBits:
    xor  t0, a0, a1
    bne  t0, x0, ReturnFalse      # bit patterns differ -> false
    addi a0, x0, 1                # identical -> true
    jalr x0, ra, 0

ReturnFalse:
    addi a0, x0, 0
    jalr x0, ra, 0
```

### bf16_lt

#### c code

```c=
static inline bool bf16_lt(bf16_t a, bf16_t b)
{
    if (bf16_isnan(a) || bf16_isnan(b))
        return false;
    if (bf16_iszero(a) && bf16_iszero(b))
        return false;
    bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1;
    if (sign_a != sign_b)
        return sign_a > sign_b;
    return sign_a ? a.bits > b.bits : a.bits < b.bits;
}
```

#### assembly code

```c=
# bf16_lt: a0 = a, a1 = b; returns a0 = 1 if a < b, else 0.
# Leaf routine: clobbers only t0-t3 and returns through ra.
bf16_lt:
    srli t0, a0, 7
    andi t0, t0, 255              # t0 = exp_a
    addi t1, x0, 255
    bne  t0, t1, CheckIsNanB2
    andi t2, a0, 0x7F             # t2 = mant_a
    beq  t2, x0, CheckIsNanB2     # mant_a == 0 -> Inf, not NaN
    addi a0, x0, 0                # a is NaN -> false
    jalr x0, ra, 0

CheckIsNanB2:
    srli t0, a1, 7
    andi t0, t0, 255              # t0 = exp_b
    addi t1, x0, 255
    # Must target this routine's own label: CheckBothZero belongs to bf16_eq
    # and would make bf16_lt return an equality result.
    bne  t0, t1, CheckBothZero2
    andi t2, a1, 0x7F             # t2 = mant_b
    beq  t2, x0, CheckBothZero2
    addi a0, x0, 0                # b is NaN -> false
    jalr x0, ra, 0

CheckBothZero2:
    lui  t3, 0x8                  # t3 = 0x00008000
    addi t3, t3, -1               # t3 = 0x00007FFF mask
    and  t0, a0, t3               # t0 = a.bits & 0x7FFF
    bne  t0, x0, TakeSigns
    and  t1, a1, t3               # t1 = b.bits & 0x7FFF
    bne  t1, x0, TakeSigns
    addi a0, x0, 0                # both are +-0 -> a < b is false
    jalr x0, ra, 0

TakeSigns:
    srli t0, a0, 15
    andi t0, t0, 1                # t0 = sign_a
    srli t1, a1, 15
    andi t1, t1, 1                # t1 = sign_b
    xor  t2, t0, t1               # t2 = sign_a ^ sign_b
    beq  t2, x0, SameSign
    # Signs differ: a < b exactly when a is negative, i.e. return sign_a
    addi a0, t0, 0
    jalr x0, ra, 0

SameSign:
    bne  t0, x0, BothNegative
    # Both positive: magnitude order equals bit order -> a.bits < b.bits
    sltu t2, a0, a1               # t2 = (a0 < a1) ? 1 : 0
    addi a0, t2, 0
    jalr x0, ra, 0

BothNegative:
    # Both negative: larger bit pattern is the more negative value -> a.bits > b.bits
    sltu t2, a1, a0               # t2 = (a1 < a0) ? 1 : 0
    addi a0, t2, 0
    jalr x0, ra, 0
```

### bf16_gt

#### c code

```c=
static inline bool bf16_gt(bf16_t a, bf16_t b)
{
    return bf16_lt(b, a);
}
```

#### assembly code

```c=
# bf16_gt: a0 = a, a1 = b; returns a0 = 1 if a > b, else 0.
# Implemented as bf16_lt(b, a). The call is a tail call (jal x0), so ra
# still holds our caller's return address and bf16_lt returns straight
# there. A plain "jal ra, bf16_lt" followed by "jalr x0, ra, 0" would
# overwrite ra and loop forever on the jalr.
bf16_gt:
    addi t0, a0, 0                # t0 = a
    addi a0, a1, 0                # a0 = b
    addi a1, t0, 0                # a1 = a
    jal  x0, bf16_lt              # tail call bf16_lt(b, a)
```