# Assignment1: RISC-V Assembly and Instruction Pipeline contributed by <[`hsuhsuhs`](https://github.com/hsuhsuhs/2025_NCKU_Computer_Architecture/tree/main/Assignment1)> --- [TOC] ## Problem `B` in [Quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol) - [ ] Decoding $$ D(b) = m \cdot 2^e + (2^e - 1) \cdot 16 $$ Where $e = \lfloor b/16 \rfloor$ (upper 4 bits)and $m = b \bmod 16$(lower 4 bits) - [ ] Encoding $$ E(v) = \begin{cases} v, & \text{if } v < 16 \\ 16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise} \end{cases} $$ where $\text{offset}(e) = (2^e - 1) \cdot 16$ ### Original C program <details> <summary><b>Open to see the complete C program </b></summary> ```c= #include <stdbool.h> #include <stdint.h> #include <stdio.h> #include <stdlib.h> typedef uint8_t uf8; // count leading zeros) static inline unsigned clz(uint32_t x) { int n = 32, c = 16; do { uint32_t y = x >> c; // If the result is still non-zero after right-shifting by c bits, it means the highest bit is still in y if (y) { n -= c; // Reduce the number of leading x = y; // Narrow down the search range } c >>= 1; // 16 -> 8 -> 4 -> 2 -> 1 -> 0 } while (c); return n - x; } /* Decode uf8 to uint32_t */ uint32_t uf8_decode(uf8 fl) { uint32_t mantissa = fl & 0x0f; uint8_t exponent = fl >> 4; // (0x7FFF >> (15 - exponent)) means (2^e)-1 uint32_t offset = (0x7FFF >> (15 - exponent)) << 4; return (mantissa << exponent) + offset; } /* Encode uint32_t to uf8 */ uf8 uf8_encode(uint32_t value) { /* Use CLZ for fast exponent calculation */ if (value < 16) return value; /* Find appropriate exponent using CLZ hint */ int lz = clz(value); int msb = 31 - lz; /* Start from a good initial guess */ uint8_t exponent = 0; uint32_t overflow = 0; if (msb >= 5) { // 1) Roughly estimate e /* Estimate exponent - the formula is empirical */ exponent = msb - 4; if (exponent > 15) // [eeee | mmmm] e <= 15 exponent = 15; // 2) Compute offset(e) via recurrence /* Calculate overflow for estimated exponent */ for (uint8_t e = 0; e < exponent; e++) overflow = (overflow << 1) + 16; // 3) If the estimate was too large, adjust downward (safety correction) /* Adjust if estimate was off */ while (exponent > 0 && value < overflow) { overflow = (overflow - 16) >> 1; // Invert recurrence: go back to offset(e-1) exponent--; } } // 4) Adjust upward to find the exact e: keep advancing while we can cross into the next bucket /* Find exact exponent */ while (exponent < 15) { uint32_t next_overflow = (overflow << 1) + 16; // = offset(e+1) if (value < next_overflow) // Can't cross further → current e is correct break; overflow = next_overflow; // Move forward to the next bucket exponent++; } uint8_t mantissa = (value - overflow) >> exponent; // = floor((value-offset(e))/2^e) return (exponent << 4) | mantissa; } /* Test encode/decode round-trip */ static bool test(void) { int32_t previous_value = -1; bool passed = true; for (int i = 0; i < 256; i++) { uint8_t fl = i; int32_t value = uf8_decode(fl); uint8_t fl2 = uf8_encode(value); if (fl != fl2) { printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2); passed = false; } if (value <= previous_value) { printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value); passed = false; } previous_value = value; } return passed; } int main(void) { if (test()) { printf("All tests passed.\n"); return 0; } return 1; } ``` </details> ### Assembly code <h4> <span style="color:darkblue">version <code>1</code></span> </h4> ```s .data # Test data test1: .word 15 # small value test2: .word 108 # medium value test3: .word 1000000 # large value # Test messages test_start_msg: .string "=== UF8 Automated Test ===\n\n" test1_msg: .string "Test 1 (small): " test2_msg: .string "Test 2 (medium): " test3_msg: .string "Test 3 (large): " arrow_msg: .string " -> Encoded: " decode_result_msg: .string ", Decoded: " pass_msg: .string " PASS\n" fail_msg: .string " FAIL\n" separator: .string "\n====================\n" all_pass_msg: .string "All tests passed.\n" some_fail_msg: .string "Tests failed.\n" newline: .string "\n" .text .globl _start _start: jal main li a7, 10 ecall print_string: li a7, 4 ecall jr ra print_int: li a7, 1 ecall jr ra print_char: li a7, 11 ecall jr ra # clz function (count leading zeros) clz: li t0, 32 # n = 32 li t1, 16 # c = 16 clz_loop: srl t2, a0, t1 # y = x >> c beqz t2, clz_skip # if (y == 0) skip sub t0, t0, t1 # n -= c mv a0, t2 # x = y clz_skip: srli t1, t1, 1 # c >>= 1 bnez t1, clz_loop # while (c != 0) sub a0, t0, a0 # return n - x jr ra # UF8 decoding function uf8_decode: andi t0, a0, 0x0F # mantissa = fl & 0x0F srli t1, a0, 4 # exponent = fl >> 4 # offset = (0x7FFF >> (15 - exponent)) << 4 li t2, 0x7FFF li t3, 15 sub t3, t3, t1 # 15 - exponent srl t2, t2, t3 # 0x7FFF >> (15 - exponent) slli t2, t2, 4 # << 4 sll t0, t0, t1 # mantissa << exponent add a0, t0, t2 # return value jr ra # UF8 encoding function uf8_encode: # Special case: value < 16 li t0, 16 blt a0, t0, encode_small addi sp, sp, -20 sw ra, 16(sp) sw s0, 12(sp) # value sw s1, 8(sp) # exponent sw s2, 4(sp) # overflow sw s3, 0(sp) # temporary variable mv s0, a0 # s0 = value li s1, 0 # exponent = 0 li s2, 0 # overflow = 0 # Estimate using CLZ mv a0, s0 jal clz li t0, 31 sub t0, t0, a0 # msb = 31 - clz(value) # If msb >= 5, estimate exponent li t1, 5 blt t0, t1, upscan_init # Estimate exponent = msb - 4 addi s1, t0, -4 # exponent = msb - 4 li t1, 15 ble s1, t1, calc_overflow li s1, 15 # clamp to max = 15 calc_overflow: # overflow = 16 * (2^exponent - 1) beqz s1, upscan_init # if exponent == 0 → overflow = 0 li t2, 0 # loop counter li s2, 0 # overflow = 0 build_loop: slli s2, s2, 1 # overflow << 1 addi s2, s2, 16 # overflow + 16 addi t2, t2, 1 blt t2, s1, build_loop j check_adjust check_adjust: beqz s1, upscan_init # if exponent == 0 → skip # If value < overflow, adjust downward bgeu s0, s2, upscan_init addi s2, s2, -16 # overflow - 16 srli s2, s2, 1 # (overflow - 16) >> 1 addi s1, s1, -1 # exponent-- j upscan_init # continue (only adjust once) upscan_init: # Start upward adjustment li s3, 15 # max_exponent = 15 upscan_loop: bge s1, s3, upscan_done # next_overflow = (overflow << 1) + 16 slli t0, s2, 1 addi t0, t0, 16 # Stop if value < next_overflow blt s0, t0, upscan_done # Move to next range mv s2, t0 # overflow = next_overflow addi s1, s1, 1 # exponent++ j upscan_loop upscan_done: # mantissa = (value - overflow) >> exponent sub t0, s0, s2 srl t0, t0, s1 # Limit mantissa to 15 li t1, 15 ble t0, t1, encode_pack li t0, 15 encode_pack: # Combine: (exponent << 4) | mantissa slli s1, s1, 4 or a0, s1, t0 lw ra, 16(sp) lw s0, 12(sp) lw s1, 8(sp) lw s2, 4(sp) lw s3, 0(sp) addi sp, sp, 20 jr ra encode_small: # For value < 16, return directly jr ra # Automated test function – test one value (returns 0=pass, 1=fail) test_single_value: # a0 = test value addi sp, sp, -16 sw ra, 12(sp) sw s0, 8(sp) # original value sw s1, 4(sp) # encoded sw s2, 0(sp) # decoded mv s0, a0 # save original # Show test info mv a0, s0 jal print_int la a0, arrow_msg jal print_string # Encode mv a0, s0 jal uf8_encode mv s1, a0 # save encoded result # Print encoded value mv a0, s1 jal print_int # Decode mv a0, s1 jal uf8_decode mv s2, a0 # save decoded result # Print decoded result la a0, decode_result_msg jal print_string mv a0, s2 jal print_int # Check decoded == original bne s2, s0, test_fail # Pass la a0, pass_msg jal print_string li a0, 0 j test_done test_fail: # Fail la a0, fail_msg jal print_string li a0, 1 test_done: lw ra, 12(sp) lw s0, 8(sp) lw s1, 4(sp) lw s2, 0(sp) addi sp, sp, 16 jr ra # Main – automated tests for three values main: addi sp, sp, -8 sw ra, 4(sp) sw s0, 0(sp) # test summary (0 = all passed) li s0, 0 # assume all pass la a0, test_start_msg jal print_string # Test 1: small value la a0, test1_msg jal print_string lw a0, test1 jal test_single_value or s0, s0, a0 # merge result # Test 2: medium value la a0, test2_msg jal print_string lw a0, test2 jal test_single_value or s0, s0, a0 # Test 3: large value la a0, test3_msg jal print_string lw a0, test3 jal test_single_value or s0, s0, a0 # Print summary la a0, separator jal print_string bnez s0, tests_failed # All passed la a0, all_pass_msg jal print_string j main_done tests_failed: # Some tests failed la a0, some_fail_msg jal print_string main_done: lw ra, 4(sp) lw s0, 0(sp) addi sp, sp, 8 jr ra ``` **1. Execution information** - <span style="color:darkblue">**$755 \;\text{cycles}$**</span> ![image](https://hackmd.io/_uploads/S1_lvvA2ee.png =350x) --- <h4> <span style="color:darkblue">version <code>2</code></span> </h4> **1. Improvement** * **Revised** the `uf8_decode` function in <span style="color:darkblue">**version**</span> **`1`**: use `((1 << e) - 1) << 4` instead of `(0x7FFF >> (15 - e)) << 4` ```s # UF8 decoding function --> value = (m<<e) + (((1<<e)-1)<<4) uf8_decode: andi t0, a0, 0x0F # m = b & 0x0F srli t1, a0, 4 # e = b >> 4 li t2, 1 sll t2, t2, t1 # 1<<e addi t2, t2, -1 # (1<<e)-1 slli t2, t2, 4 # offset = ((1<<e)-1)<<4 sll t0, t0, t1 # m<<e add a0, t0, t2 # value jr ra ``` **2. Execution information** - <span style="color:darkblue">**$749 \;\text{cycles}$**</span> ![截圖 2025-10-04 下午5.51.06](https://hackmd.io/_uploads/HkVO0vR3le.png =350x) **3. why faster ?** The difference lies in the number of instructions and the overhead of loading constants. | | `decode_v1` | `decode_v2` | |:--------------------:|:-----------:|:-----------:| | immediate instrs.num | 2 | 1 | | instrs.num | 10 | 9 | --- <h4> <span style="color:darkblue">version <code>3</code></span> </h4> **1. Improvement** * **Unloop** `CLZ` function in <span style="color:darkblue">**version**</span> **`2`**: ```s # clz function (count leading zeros) --> unloop clz: beqz a0, 9f # x==0 → 32 li t0, 0 # n = 0 # Check upper 16 bits srli t1, a0, 16 # t1 = x >> 16 bnez t1, 1f # if (t1 != 0) skip addi t0, t0, 16 # n += 16 slli a0, a0, 16 # x <<= 16 1: # Check upper 8 bits srli t1, a0, 24 # t1 = x >> 24 bnez t1, 2f # if (t1 != 0) skip addi t0, t0, 8 # n += 8 slli a0, a0, 8 # x <<= 8 2: # Check upper 4 bits srli t1, a0, 28 # t1 = x >> 28 bnez t1, 3f # if (t1 != 0) skip addi t0, t0, 4 # n += 4 slli a0, a0, 4 # x <<= 4 3: # Check upper 2 bits srli t1, a0, 30 # t1 = x >> 30 bnez t1, 4f # if (t1 != 0) skip addi t0, t0, 2 # n += 2 slli a0, a0, 2 # x <<= 2 4: # Check the most significant bit srli t1, a0, 31 # t1 = x >> 31 bnez t1, 5f # if (t1 != 0) skip addi t0, t0, 1 # n += 1 5: mv a0, t0 # return n (number of leading zeros) jr ra # return # Case when x == 0 9: li a0, 32 # return 32 (all bits are zero) jr ra ``` **2. Execution information** - <span style="color:darkblue">**$712 \;\text{cycles}$**</span> ![截圖 2025-10-04 晚上8.23.26](https://hackmd.io/_uploads/Hk3bfq0hgx.png =350x) **3. why faster ?** * **fewer dynamic instructions**: The unrolled version replaces the `16/8/4/2/1` loop iterations with straight-line conditional checks.Each stage only performs a simple `“test + (if needed) add n / shift x”` operation, * **fewer loop-carried dependencies.** The unrolled version still has dependencies, but since there is **no backward branch**,the entire sequence becomes a **fixed-depth linear dependency chain** that can be forwarded smoothly through a simple pipeline. * **without accumulating per-iteration branch decision latency** --- <h4> <span style="color:darkblue">complete version</span> </h4> ```s .data # Test data test1: .word 15 # small value test2: .word 108 # medium value test3: .word 1000000 # large value # Test messages test_start_msg: .string "=== UF8 Automated Test ===\n\n" test1_msg: .string "Test 1 (small): " test2_msg: .string "Test 2 (medium): " test3_msg: .string "Test 3 (large): " arrow_msg: .string " -> Encoded: " decode_result_msg: .string ", Decoded: " pass_msg: .string " PASS\n" fail_msg: .string " FAIL\n" separator: .string "\n====================\n" all_pass_msg: .string "All tests passed.\n" some_fail_msg: .string "Tests failed.\n" newline: .string "\n" .text .globl _start _start: jal main li a7, 10 ecall print_string: li a7, 4 ecall jr ra print_int: li a7, 1 ecall jr ra print_char: li a7, 11 ecall jr ra # clz function (count leading zeros) --> unloop clz: beqz a0, 9f # x==0 → 32 li t0, 0 # n = 0 # Check upper 16 bits srli t1, a0, 16 # t1 = x >> 16 bnez t1, 1f # if (t1 != 0) skip addi t0, t0, 16 # n += 16 slli a0, a0, 16 # x <<= 16 1: # Check upper 8 bits srli t1, a0, 24 # t1 = x >> 24 bnez t1, 2f # if (t1 != 0) skip addi t0, t0, 8 # n += 8 slli a0, a0, 8 # x <<= 8 2: # Check upper 4 bits srli t1, a0, 28 # t1 = x >> 28 bnez t1, 3f # if (t1 != 0) skip addi t0, t0, 4 # n += 4 slli a0, a0, 4 # x <<= 4 3: # Check upper 2 bits srli t1, a0, 30 # t1 = x >> 30 bnez t1, 4f # if (t1 != 0) skip addi t0, t0, 2 # n += 2 slli a0, a0, 2 # x <<= 2 4: # Check the most significant bit srli t1, a0, 31 # t1 = x >> 31 bnez t1, 5f # if (t1 != 0) skip addi t0, t0, 1 # n += 1 5: mv a0, t0 # return n (number of leading zeros) jr ra # return # Case when x == 0 9: li a0, 32 # return 32 (all bits are zero) jr ra # UF8 decoding function --> value = (m<<e) + (((1<<e)-1)<<4) uf8_decode: andi t0, a0, 0x0F # m = b & 0x0F srli t1, a0, 4 # e = b >> 4 li t2, 1 sll t2, t2, t1 # 1<<e addi t2, t2, -1 # (1<<e)-1 slli t2, t2, 4 # offset = ((1<<e)-1)<<4 sll t0, t0, t1 # m<<e add a0, t0, t2 # value jr ra # UF8 encoding function uf8_encode: # Special case: value < 16 li t0, 16 blt a0, t0, encode_small addi sp, sp, -20 sw ra, 16(sp) sw s0, 12(sp) # value sw s1, 8(sp) # exponent sw s2, 4(sp) # overflow sw s3, 0(sp) # temporary variable mv s0, a0 # s0 = value li s1, 0 # exponent = 0 li s2, 0 # overflow = 0 # Estimate using CLZ mv a0, s0 jal clz li t0, 31 sub t0, t0, a0 # msb = 31 - clz(value) # If msb >= 5, estimate exponent li t1, 5 blt t0, t1, upscan_init # Estimate exponent = msb - 4 addi s1, t0, -4 # exponent = msb - 4 li t1, 15 ble s1, t1, calc_overflow li s1, 15 # clamp to max = 15 calc_overflow: # overflow = 16 * (2^exponent - 1) beqz s1, upscan_init # if exponent == 0 → overflow = 0 li t2, 0 # loop counter li s2, 0 # overflow = 0 build_loop: slli s2, s2, 1 # overflow << 1 addi s2, s2, 16 # overflow + 16 addi t2, t2, 1 blt t2, s1, build_loop j check_adjust check_adjust: beqz s1, upscan_init # if exponent == 0 → skip # If value < overflow, adjust downward bgeu s0, s2, upscan_init addi s2, s2, -16 # overflow - 16 srli s2, s2, 1 # (overflow - 16) >> 1 addi s1, s1, -1 # exponent-- j upscan_init # continue (only adjust once) upscan_init: # Start upward adjustment li s3, 15 # max_exponent = 15 upscan_loop: bge s1, s3, upscan_done # next_overflow = (overflow << 1) + 16 slli t0, s2, 1 addi t0, t0, 16 # Stop if value < next_overflow blt s0, t0, upscan_done # Move to next range mv s2, t0 # overflow = next_overflow addi s1, s1, 1 # exponent++ j upscan_loop upscan_done: # mantissa = (value - overflow) >> exponent sub t0, s0, s2 srl t0, t0, s1 # Limit mantissa to 15 li t1, 15 ble t0, t1, encode_pack li t0, 15 encode_pack: # Combine: (exponent << 4) | mantissa slli s1, s1, 4 or a0, s1, t0 lw ra, 16(sp) lw s0, 12(sp) lw s1, 8(sp) lw s2, 4(sp) lw s3, 0(sp) addi sp, sp, 20 jr ra encode_small: # For value < 16, return directly jr ra # Automated test function – test one value (returns 0=pass, 1=fail) test_single_value: # a0 = test value addi sp, sp, -16 sw ra, 12(sp) sw s0, 8(sp) # original value sw s1, 4(sp) # encoded sw s2, 0(sp) # decoded mv s0, a0 # save original # Show test info mv a0, s0 jal print_int la a0, arrow_msg jal print_string # Encode mv a0, s0 jal uf8_encode mv s1, a0 # save encoded result # Print encoded value mv a0, s1 jal print_int # Decode mv a0, s1 jal uf8_decode mv s2, a0 # save decoded result # Print decoded result la a0, decode_result_msg jal print_string mv a0, s2 jal print_int # Check decoded == original bne s2, s0, test_fail # Pass la a0, pass_msg jal print_string li a0, 0 j test_done test_fail: # Fail la a0, fail_msg jal print_string li a0, 1 test_done: lw ra, 12(sp) lw s0, 8(sp) lw s1, 4(sp) lw s2, 0(sp) addi sp, sp, 16 jr ra # Main – automated tests for three values main: addi sp, sp, -8 sw ra, 4(sp) sw s0, 0(sp) # test summary (0 = all passed) li s0, 0 # assume all pass la a0, test_start_msg jal print_string # Test 1: small value la a0, test1_msg jal print_string lw a0, test1 jal test_single_value or s0, s0, a0 # merge result # Test 2: medium value la a0, test2_msg jal print_string lw a0, test2 jal test_single_value or s0, s0, a0 # Test 3: large value la a0, test3_msg jal print_string lw a0, test3 jal test_single_value or s0, s0, a0 # Print summary la a0, separator jal print_string bnez s0, tests_failed # All passed la a0, all_pass_msg jal print_string j main_done tests_failed: # Some tests failed la a0, some_fail_msg jal print_string main_done: lw ra, 4(sp) lw s0, 0(sp) addi sp, sp, 8 jr ra ``` --- ## Problem `C` in [Quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol) ### Original C program <details> <summary><b>Open to see the complete C program </b></summary> ```c= #include <stdbool.h> #include <stdint.h> #include <string.h> typedef struct { uint16_t bits; } bf16_t; #define BF16_SIGN_MASK 0x8000U #define BF16_EXP_MASK 0x7F80U #define BF16_MANT_MASK 0x007FU #define BF16_EXP_BIAS 127 #define BF16_NAN() ((bf16_t) {.bits = 0x7FC0}) #define BF16_ZERO() ((bf16_t) {.bits = 0x0000}) static inline bool bf16_isnan(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && (a.bits & BF16_MANT_MASK); } static inline bool bf16_isinf(bf16_t a) { return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) && !(a.bits & BF16_MANT_MASK); } static inline bool bf16_iszero(bf16_t a) { return !(a.bits & 0x7FFF); } static inline bf16_t f32_to_bf16(float val) { uint32_t f32bits; memcpy(&f32bits, &val, sizeof(float)); if (((f32bits >> 23) & 0xFF) == 0xFF) return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF}; f32bits += ((f32bits >> 16) & 1) + 0x7FFF; return (bf16_t) {.bits = f32bits >> 16}; } static inline float bf16_to_f32(bf16_t val) { uint32_t f32bits = ((uint32_t) val.bits) << 16; float result; memcpy(&result, &f32bits, sizeof(float)); return result; } static inline bf16_t bf16_add(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; if (exp_a == 0xFF) { if (mant_a) return a; if (exp_b == 0xFF) return (mant_b || sign_a == sign_b) ? b : BF16_NAN(); return a; } if (exp_b == 0xFF) return b; if (!exp_a && !mant_a) return b; if (!exp_b && !mant_b) return a; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; int16_t exp_diff = exp_a - exp_b; uint16_t result_sign; int16_t result_exp; uint32_t result_mant; if (exp_diff > 0) { result_exp = exp_a; if (exp_diff > 8) return a; mant_b >>= exp_diff; } else if (exp_diff < 0) { result_exp = exp_b; if (exp_diff < -8) return b; mant_a >>= -exp_diff; } else { result_exp = exp_a; } if (sign_a == sign_b) { result_sign = sign_a; result_mant = (uint32_t) mant_a + mant_b; if (result_mant & 0x100) { result_mant >>= 1; if (++result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } } else { if (mant_a >= mant_b) { result_sign = sign_a; result_mant = mant_a - mant_b; } else { result_sign = sign_b; result_mant = mant_b - mant_a; } if (!result_mant) return BF16_ZERO(); while (!(result_mant & 0x80)) { result_mant <<= 1; if (--result_exp <= 0) return BF16_ZERO(); } } return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F), }; } static inline bf16_t bf16_sub(bf16_t a, bf16_t b) { b.bits ^= BF16_SIGN_MASK; return bf16_add(a, b); } static inline bf16_t bf16_mul(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_a == 0xFF) { if (mant_a) return a; if (!exp_b && !mant_b) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_b == 0xFF) { if (mant_b) return b; if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if ((!exp_a && !mant_a) || (!exp_b && !mant_b)) return (bf16_t) {.bits = result_sign << 15}; int16_t exp_adjust = 0; if (!exp_a) { while (!(mant_a & 0x80)) { mant_a <<= 1; exp_adjust--; } exp_a = 1; } else mant_a |= 0x80; if (!exp_b) { while (!(mant_b & 0x80)) { mant_b <<= 1; exp_adjust--; } exp_b = 1; } else mant_b |= 0x80; uint32_t result_mant = (uint32_t) mant_a * mant_b; int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust; if (result_mant & 0x8000) { result_mant = (result_mant >> 8) & 0x7F; result_exp++; } else result_mant = (result_mant >> 7) & 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) { if (result_exp < -6) return (bf16_t) {.bits = result_sign << 15}; result_mant >>= (1 - result_exp); result_exp = 0; } return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)}; } static inline bf16_t bf16_div(bf16_t a, bf16_t b) { uint16_t sign_a = (a.bits >> 15) & 1; uint16_t sign_b = (b.bits >> 15) & 1; int16_t exp_a = ((a.bits >> 7) & 0xFF); int16_t exp_b = ((b.bits >> 7) & 0xFF); uint16_t mant_a = a.bits & 0x7F; uint16_t mant_b = b.bits & 0x7F; uint16_t result_sign = sign_a ^ sign_b; if (exp_b == 0xFF) { if (mant_b) return b; /* Inf/Inf = NaN */ if (exp_a == 0xFF && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = result_sign << 15}; } if (!exp_b && !mant_b) { if (!exp_a && !mant_a) return BF16_NAN(); return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (exp_a == 0xFF) { if (mant_a) return a; return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; } if (!exp_a && !mant_a) return (bf16_t) {.bits = result_sign << 15}; if (exp_a) mant_a |= 0x80; if (exp_b) mant_b |= 0x80; uint32_t dividend = (uint32_t) mant_a << 15; uint32_t divisor = mant_b; uint32_t quotient = 0; for (int i = 0; i < 16; i++) { quotient <<= 1; if (dividend >= (divisor << (15 - i))) { dividend -= (divisor << (15 - i)); quotient |= 1; } } int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS; if (!exp_a) result_exp--; if (!exp_b) result_exp++; if (quotient & 0x8000) quotient >>= 8; else { while (!(quotient & 0x8000) && result_exp > 1) { quotient <<= 1; result_exp--; } quotient >>= 8; } quotient &= 0x7F; if (result_exp >= 0xFF) return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; if (result_exp <= 0) return (bf16_t) {.bits = result_sign << 15}; return (bf16_t) { .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F), }; } static inline bf16_t bf16_sqrt(bf16_t a) { uint16_t sign = (a.bits >> 15) & 1; int16_t exp = ((a.bits >> 7) & 0xFF); uint16_t mant = a.bits & 0x7F; /* Handle special cases */ if (exp == 0xFF) { if (mant) return a; /* NaN propagation */ if (sign) return BF16_NAN(); /* sqrt(-Inf) = NaN */ return a; /* sqrt(+Inf) = +Inf */ } /* sqrt(0) = 0 (handle both +0 and -0) */ if (!exp && !mant) return BF16_ZERO(); /* sqrt of negative number is NaN */ if (sign) return BF16_NAN(); /* Flush denormals to zero */ if (!exp) return BF16_ZERO(); /* Direct bit manipulation square root algorithm */ /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */ int32_t e = exp - BF16_EXP_BIAS; int32_t new_exp; /* Get full mantissa with implicit 1 */ uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */ /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */ if (e & 1) { m <<= 1; /* Double mantissa for odd exponent */ new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS; } else { new_exp = (e >> 1) + BF16_EXP_BIAS; } /* Now m is in range [128, 256) or [256, 512) if exponent was odd */ /* Binary search for integer square root */ /* We want result where result^2 = m * 128 (since 128 represents 1.0) */ uint32_t low = 90; /* Min sqrt (roughly sqrt(128)) */ uint32_t high = 256; /* Max sqrt (roughly sqrt(512)) */ uint32_t result = 128; /* Default */ /* Binary search for square root of m */ while (low <= high) { uint32_t mid = (low + high) >> 1; uint32_t sq = (mid * mid) / 128; /* Square and scale */ if (sq <= m) { result = mid; /* This could be our answer */ low = mid + 1; } else { high = mid - 1; } } /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */ /* But we need to adjust the scale */ /* Since m is scaled where 128=1.0, result should also be scaled same way */ /* Normalize to ensure result is in [128, 256) */ if (result >= 256) { result >>= 1; new_exp++; } else if (result < 128) { while (result < 128 && new_exp > 1) { result <<= 1; new_exp--; } } /* Extract 7-bit mantissa (remove implicit 1) */ uint16_t new_mant = result & 0x7F; /* Check for overflow/underflow */ if (new_exp >= 0xFF) return (bf16_t) {.bits = 0x7F80}; /* +Inf */ if (new_exp <= 0) return BF16_ZERO(); return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant}; } static inline bool bf16_eq(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return true; return a.bits == b.bits; } static inline bool bf16_lt(bf16_t a, bf16_t b) { if (bf16_isnan(a) || bf16_isnan(b)) return false; if (bf16_iszero(a) && bf16_iszero(b)) return false; bool sign_a = (a.bits >> 15) & 1, sign_b = (b.bits >> 15) & 1; if (sign_a != sign_b) return sign_a > sign_b; return sign_a ? a.bits > b.bits : a.bits < b.bits; } static inline bool bf16_gt(bf16_t a, bf16_t b) { return bf16_lt(b, a); } #include <stdio.h> #include <time.h> #define TEST_ASSERT(cond, msg) \ do { \ if (!(cond)) { \ printf("FAIL: %s\n", msg); \ return 1; \ } \ } while (0) static int test_basic_conversions(void) { printf("Testing basic conversions...\n"); float test_values[] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f, 0.5f, -0.5f, 3.14159f, -3.14159f, 1e10f, -1e10f}; for (size_t i = 0; i < sizeof(test_values) / sizeof(test_values[0]); i++) { float orig = test_values[i]; bf16_t bf = f32_to_bf16(orig); float conv = bf16_to_f32(bf); if (orig != 0.0f) { TEST_ASSERT((orig < 0) == (conv < 0), "Sign mismatch"); } if (orig != 0.0f && !bf16_isinf(f32_to_bf16(orig))) { float diff = (conv - orig); float rel_error = (diff < 0) ? -diff / orig : diff / orig; TEST_ASSERT(rel_error < 0.01f, "Relative error too large"); } } printf(" Basic conversions: PASS\n"); return 0; } static int test_special_values(void) { printf("Testing special values...\n"); bf16_t pos_inf = {.bits = 0x7F80}; /* +Infinity */ TEST_ASSERT(bf16_isinf(pos_inf), "Positive infinity not detected"); TEST_ASSERT(!bf16_isnan(pos_inf), "Infinity detected as NaN"); bf16_t neg_inf = {.bits = 0xFF80}; /* -Infinity */ TEST_ASSERT(bf16_isinf(neg_inf), "Negative infinity not detected"); bf16_t nan_val = BF16_NAN(); TEST_ASSERT(bf16_isnan(nan_val), "NaN not detected"); TEST_ASSERT(!bf16_isinf(nan_val), "NaN detected as infinity"); bf16_t zero = f32_to_bf16(0.0f); TEST_ASSERT(bf16_iszero(zero), "Zero not detected"); bf16_t neg_zero = f32_to_bf16(-0.0f); TEST_ASSERT(bf16_iszero(neg_zero), "Negative zero not detected"); printf(" Special values: PASS\n"); return 0; } static int test_arithmetic(void) { printf("Testing arithmetic operations...\n"); bf16_t a = f32_to_bf16(1.0f); bf16_t b = f32_to_bf16(2.0f); bf16_t c = bf16_add(a, b); float result = bf16_to_f32(c); float diff = result - 3.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Addition failed"); c = bf16_sub(b, a); result = bf16_to_f32(c); diff = result - 1.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "Subtraction failed"); a = f32_to_bf16(3.0f); b = f32_to_bf16(4.0f); c = bf16_mul(a, b); result = bf16_to_f32(c); diff = result - 12.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Multiplication failed"); a = f32_to_bf16(10.0f); b = f32_to_bf16(2.0f); c = bf16_div(a, b); result = bf16_to_f32(c); diff = result - 5.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.1f, "Division failed"); /* Test square root */ a = f32_to_bf16(4.0f); c = bf16_sqrt(a); result = bf16_to_f32(c); diff = result - 2.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(4) failed"); a = f32_to_bf16(9.0f); c = bf16_sqrt(a); result = bf16_to_f32(c); diff = result - 3.0f; TEST_ASSERT((diff < 0 ? -diff : diff) < 0.01f, "sqrt(9) failed"); printf(" Arithmetic: PASS\n"); return 0; } static int test_comparisons(void) { printf("Testing comparison operations...\n"); bf16_t a = f32_to_bf16(1.0f); bf16_t b = f32_to_bf16(2.0f); bf16_t c = f32_to_bf16(1.0f); TEST_ASSERT(bf16_eq(a, c), "Equality test failed"); TEST_ASSERT(!bf16_eq(a, b), "Inequality test failed"); TEST_ASSERT(bf16_lt(a, b), "Less than test failed"); TEST_ASSERT(!bf16_lt(b, a), "Not less than test failed"); TEST_ASSERT(!bf16_lt(a, c), "Equal not less than test failed"); TEST_ASSERT(bf16_gt(b, a), "Greater than test failed"); TEST_ASSERT(!bf16_gt(a, b), "Not greater than test failed"); bf16_t nan_val = BF16_NAN(); TEST_ASSERT(!bf16_eq(nan_val, nan_val), "NaN equality test failed"); TEST_ASSERT(!bf16_lt(nan_val, a), "NaN less than test failed"); TEST_ASSERT(!bf16_gt(nan_val, a), "NaN greater than test failed"); printf(" Comparisons: PASS\n"); return 0; } static int test_edge_cases(void) { printf("Testing edge cases...\n"); float tiny = 1e-45f; bf16_t bf_tiny = f32_to_bf16(tiny); float tiny_val = bf16_to_f32(bf_tiny); TEST_ASSERT(bf16_iszero(bf_tiny) || (tiny_val < 0 ? -tiny_val : tiny_val) < 1e-37f, "Tiny value handling"); float huge = 1e38f; bf16_t bf_huge = f32_to_bf16(huge); bf16_t bf_huge2 = bf16_mul(bf_huge, f32_to_bf16(10.0f)); TEST_ASSERT(bf16_isinf(bf_huge2), "Overflow should produce infinity"); bf16_t small = f32_to_bf16(1e-38f); bf16_t smaller = bf16_div(small, f32_to_bf16(1e10f)); float smaller_val = bf16_to_f32(smaller); TEST_ASSERT(bf16_iszero(smaller) || (smaller_val < 0 ? -smaller_val : smaller_val) < 1e-45f, "Underflow should produce zero or denormal"); printf(" Edge cases: PASS\n"); return 0; } static int test_rounding(void) { printf("Testing rounding behavior...\n"); float exact = 1.5f; bf16_t bf_exact = f32_to_bf16(exact); float back_exact = bf16_to_f32(bf_exact); TEST_ASSERT(back_exact == exact, "Exact representation should be preserved"); float val = 1.0001f; bf16_t bf = f32_to_bf16(val); float back = bf16_to_f32(bf); float diff2 = back - val; TEST_ASSERT((diff2 < 0 ? -diff2 : diff2) < 0.001f, "Rounding error should be small"); printf(" Rounding: PASS\n"); return 0; } #ifndef BFLOAT16_NO_MAIN int main(void) { printf("\n=== bfloat16 Test Suite ===\n\n"); int failed = 0; failed |= test_basic_conversions(); failed |= test_special_values(); failed |= test_arithmetic(); failed |= test_comparisons(); failed |= test_edge_cases(); failed |= test_rounding(); if (failed) { printf("\n=== TESTS FAILED ===\n"); return 1; } printf("\n=== ALL TESTS PASSED ===\n"); return 0; } #endif /* BFLOAT16_NO_MAIN */ ``` </details> ### Assembly code <h4> <span style="color:darkblue">Part<code>1</code> : Bfloat16 <-----> float32</span> </h4> Contains special value test. ```s .data # Test Case Description Strings --------------------------------------------- str_t1: .asciz " 1.0f" str_t_neg_simple: .asciz " -4.0f" str_t3: .asciz " 3.14159f" str_t4: .asciz " 0.1f" str_t_neg_round: .asciz " -2.7f" str_t5: .asciz " +0.0f" str_t6: .asciz " -0.0f" str_t7: .asciz " +Infinity" str_t8: .asciz " -Infinity" str_t9: .asciz " NaN" # Normal Value Tests (with corrected golden values) ------------------------ normal_test_strings: .word str_t1, str_t_neg_simple, str_t3, str_t4, str_t_neg_round normal_test_inputs: .word 0x3f800000, 0xc0800000, 0x40490fdb, 0x3dcccccd, 0xc02ccccd normal_test_golden: .word 0x3f80, 0xc080, 0x4049, 0x3dcd, 0xc02d # Special Value Tests (with corrected golden values) --------------------------- special_test_strings: .word str_t5, str_t6, str_t7, str_t8, str_t9 special_test_inputs: .word 0x00000000, 0x80000000, 0x7f800000, 0xff800000, 0x7fc00000 special_test_golden: .word 0x0000, 0x8000, 0x7f80, 0xff80, 0x7fc0 # tests message --------------------------------------------- str_header_normal: .asciz "\n--- Running Normal Value Test Cases ---\n" str_header_special: .asciz "\n--- Running Special Value Test Cases (Zero, Inf, NaN) ---\n" str_testing: .asciz "Testing" str_orig_label: .asciz "\n Original f32: 0x" str_bf16_label: .asciz " -> bf16: 0x" str_restored_label: .asciz " -> Restored f32: 0x" str_success: .asciz " [PASS]" str_fail: .asciz " [FAIL]" str_actual: .asciz " (Actual: 0x" str_expected: .asciz ", Expected: 0x" str_close_paren: .asciz ")\n" str_summary: .asciz "\n\n[Summary: " str_summary_middle: .asciz " / " str_summary_end: .asciz " Tests Passed]\n" newline: .asciz "\n" str_all_pass: .asciz "\n--- All tests passed! ---\n" str_some_fail: .asciz "\n--- Some tests failed! ---\n" .text .globl main # ==================== main tests ==================== main: # Function Prologue addi sp, sp, -8 sw ra, 4(sp) sw s0, 0(sp) # Reserve space on stack for s0 (total failure counter) # Initialize overall failure counter mv s0, zero # s0 = total_failures = 0 # ======== Run Normal Tests ======== la a0, str_header_normal jal ra, print_string la a0, normal_test_strings la a1, normal_test_inputs la a2, normal_test_golden addi a3, zero, 5 jal ra, run_test_suite add s0, s0, a0 # Accumulate failures from this test suite # ======== Run Special Tests ======== la a0, str_header_special jal ra, print_string la a0, special_test_strings la a1, special_test_inputs la a2, special_test_golden addi a3, zero, 5 jal ra, run_test_suite add s0, s0, a0 # Accumulate failures again # ======== Print Final Overall Summary ======== bnez s0, some_tests_failed all_tests_passed: la a0, str_all_pass jal ra, print_string j exit_program some_tests_failed: la a0, str_some_fail jal ra, print_string exit_program: # Function Epilogue lw s0, 0(sp) lw ra, 4(sp) addi sp, sp, 8 addi a7, zero, 10 # ecall 10: Exit ecall # ============================================================================= # run_test_suite: Loops through a set of tests, runs them, # and returns the number of failures. # Return Value: a0 = Number of failures in this suite. #============================================================================== run_test_suite: # Function Prologue: Allocate 28 bytes (1 ra + 6 s-regs) addi sp, sp, -28 sw ra, 24(sp) sw s0, 20(sp) sw s1, 16(sp) sw s2, 12(sp) sw s3, 8(sp) sw s4, 4(sp) sw s5, 0(sp) mv s0, a0 # s0 = string array mv s1, a1 # s1 = input array mv s2, a2 # s2 = golden value array mv s3, a3 # s3 = Loop counter mv s4, zero # s4 = Success counter mv s5, zero # s5 = Failure counter mv t5, a3 # t5 = Backup of total test count for summary test_loop: # Pass pointers to the single test runner mv a0, s0 mv a1, s1 mv a2, s2 jal ra, run_single_test # run_single_test returns 1 in a0 if it failed, 0 if success add s5, s5, a0 # Accumulate failure count # Advance all array pointers to the next test case addi s0, s0, 4 addi s1, s1, 4 addi s2, s2, 4 addi s3, s3, -1 bnez s3, test_loop # Continue if tests remain # ========= Print Suite Summary =========== sub s4, t5, s5 # Success count is total - failures la a0, str_summary jal ra, print_string mv a0, s4 jal ra, print_int la a0, str_summary_middle jal ra, print_string mv a0, t5 jal ra, print_int la a0, str_summary_end jal ra, print_string # Function Epilogue mv a0, s5 # Set return value to the failure count lw s5, 0(sp) lw s4, 4(sp) lw s3, 8(sp) lw s2, 12(sp) lw s1, 16(sp) lw s0, 20(sp) lw ra, 24(sp) addi sp, sp, 28 ret # ============================================================================= # run_single_test: Performs a round-trip conversion and checks the result. # Arguments: a0=str_ptr, a1=input_ptr, a2=golden_ptr # Return Value: a0 = 1 if fail, 0 if success. #============================================================================== run_single_test: # Function Prologue addi sp, sp, -20 sw ra, 16(sp) sw s0, 12(sp) sw s1, 8(sp) sw s2, 4(sp) sw s3, 0(sp) # Load test data into saved registers lw s0, 0(a0) # s0 = address of description string lw s1, 0(a1) # s1 = original f32 input value lw s2, 0(a2) # s2 = expected bf16 golden value # Perform Conversions mv a0, s1 # Set argument for f32_to_bf16 jal ra, f32_to_bf16 mv s3, a0 # s3 = actual_bf16_result jal ra, bf16_to_f32 mv t0, a0 # t0 = restored_f32_result # Calculate Golden Restored Value slli t1, s2, 16 # t1 = golden_restored_f32 # ===== Print Results on a Single Line ========== la a0, str_testing jal ra, print_string # Print "Testing" mv a0, s0 jal ra, print_string # Print " 1.0f" la a0, str_orig_label jal ra, print_string # Print "\n Original f32: 0x" mv a0, s1 jal ra, print_hex32 # Print original value la a0, str_bf16_label jal ra, print_string # Print " -> bf16: 0x" mv a0, s3 jal ra, print_hex32 # Print bf16 value la a0, str_restored_label jal ra, print_string # Print " -> Restored f32: 0x" mv a0, t0 jal ra, print_hex32 # Print restored value # Compare restored_f32 (t0) with golden_restored_f32 (t1) beq t0, t1, test_success test_fail: la a0, str_fail jal ra, print_string la a0, str_actual jal ra, print_string mv a0, t0 jal ra, print_hex32 la a0, str_expected jal ra, print_string mv a0, t1 jal ra, print_hex32 la a0, str_close_paren jal ra, print_string addi a0, zero, 1 # Return 1 for failure j end_single_test test_success: la a0, str_success jal ra, print_string la a0, newline jal ra, print_string addi a0, zero, 0 # Return 0 for success end_single_test: # Function Epilogue lw s3, 0(sp) lw s2, 4(sp) lw s1, 8(sp) lw s0, 12(sp) lw ra, 16(sp) addi sp, sp, 20 ret # ==================== bf16_t f32_to_bf16(float val) ======================== f32_to_bf16: # Check if the number is NaN or Infinity by inspecting the exponent. srli t0, a0, 23 # Isolate exponent and sign. andi t0, t0, 0xFF # Mask to get only the 8 exponent bits. addi t1, zero, 0xFF # Load 0xFF for comparison. beq t0, t1, is_nan_or_inf # If exponent is all 1s, jump to special handling. # ====== Normal Number Rounding Path ====== lui t0, 0x80000 # Load upper bits of the sign mask (0x80000000). and t1, a0, t0 # t1 = sign bit (0x80000000 or 0). not t0, t0 # sign mask = 0x7FFFFFFF. and t2, a0, t0 # t2 = magnitude. srli t3, t2, 16 # Shift magnitude to get the tie-breaking bit's position. andi t3, t3, 1 # Isolate the tie-breaking bit (0 or 1). lui t4, 0x8 # Load upper bits of 0x8000. addi t4, t4, -1 # Create the main rounding constant 0x7FFF. add t3, t3, t4 # t3 = final addend (0x7FFF or 0x8000). add t2, t2, t3 srli t2, t2, 16 # Truncate the rounded magnitude to 16 bits. srli t1, t1, 16 # Shift the original sign bit to its bf16 position (bit 15). or a0, t2, t1 ret is_nan_or_inf: # For NaN, Infinity, and Zero, the correct behavior is simple truncation. srli a0, a0, 16 # Shift right by 16 bits. ret # ================ float bf16_to_f32(bf16_t val) ====================== bf16_to_f32: slli a0, a0, 16 # Shift left by 16, padding lower bits with zeros. ret # =========== Helper Print Functions ========== print_string: addi a7, zero, 4 # Set ecall code for Print String. ecall ret print_hex32: addi a7, zero, 34 # Set ecall code for Print Hex. ecall ret print_int: addi a7, zero, 1 # Set ecall code for Print Integer. ecall ret ``` **1. Execution information** - <span style="color:darkblue">**$2075 \;\text{cycles}$**</span> ![截圖 2025-10-12 下午6.54.21](https://hackmd.io/_uploads/S1uNF-tpgx.png =350x) <h4> <span style="color:darkblue">Part<code>2</code> : Bfloat16 arithmetic_operations</span> </h4> Contains special value test. ```s .data # Addition test cases add_a1: .word 0x00003F80 # 1.0 in bfloat16 add_b1: .word 0x00004000 # 2.0 in bfloat16 add_exp1: .word 0x00004040 # 3.0 expected (1.0 + 2.0) add_a2: .word 0x00003FC0 # 1.5 in bfloat16 add_b2: .word 0x00003FC0 # 1.5 in bfloat16 add_exp2: .word 0x00004040 # 3.0 expected (1.5 + 1.5) add_a3: .word 0x00003F00 # 0.5 in bfloat16 add_b3: .word 0x00003F00 # 0.5 in bfloat16 add_exp3: .word 0x00003F80 # 1.0 expected (0.5 + 0.5) # Subtraction test cases sub_a1: .word 0x00004040 # 3.0 in bfloat16 sub_b1: .word 0x00004000 # 2.0 in bfloat16 sub_exp1: .word 0x00003F80 # 1.0 expected (3.0 - 2.0) sub_a2: .word 0x00004000 # 2.0 in bfloat16 sub_b2: .word 0x00003F80 # 1.0 in bfloat16 sub_exp2: .word 0x00003F80 # 1.0 expected (2.0 - 1.0) sub_a3: .word 0x00004040 # 3.0 in bfloat16 sub_b3: .word 0x00003FC0 # 1.5 in bfloat16 sub_exp3: .word 0x00003FC0 # 1.5 expected (3.0 - 1.5) # Multiplication test cases mul_a1: .word 0x00004040 # 3.0 in bfloat16 mul_b1: .word 0x00004080 # 4.0 in bfloat16 mul_exp1: .word 0x00004180 # 12.0 expected (3.0 * 4.0) mul_a2: .word 0x00004000 # 2.0 in bfloat16 mul_b2: .word 0x00004020 # 2.5 in bfloat16 mul_exp2: .word 0x00004140 # 5.0 expected (2.0 * 2.5) mul_a3: .word 0x00003FC0 # 1.5 in bfloat16 mul_b3: .word 0x00004000 # 2.0 in bfloat16 mul_exp3: .word 0x00004040 # 3.0 expected (1.5 * 2.0) # Division test cases div_a1: .word 0x00004140 # 5.0 in bfloat16 div_b1: .word 0x00004000 # 2.0 in bfloat16 div_exp1: .word 0x000040A0 # 2.5 expected (5.0 / 2.0) div_a2: .word 0x00004180 # 6.0 in bfloat16 div_b2: .word 0x00004000 # 2.0 in bfloat16 div_exp2: .word 0x00004040 # 3.0 expected (6.0 / 2.0) div_a3: .word 0x00004040 # 3.0 in bfloat16 div_b3: .word 0x00004000 # 2.0 in bfloat16 div_exp3: .word 0x00003FC0 # 1.5 expected (3.0 / 2.0) # Special values for testing edge cases test_nan: .word 0x00007FC0 # NaN value test_inf: .word 0x00007F80 # Positive infinity test_zero: .word 0x00000000 # Zero value # Result storage for test verification result: .word 0x00000000 # tests message --------------------------------------------- test_start: .string "bfloat16 Complete Test Suite\n\n" add_header: .string "=== Addition Tests ===\n" sub_header: .string "\n=== Subtraction Tests ===\n" mul_header: .string "\n=== Multiplication Tests ===\n" div_header: .string "\n=== Division Tests ===\n" special_header:.string "\n=== Special Values Tests ===\n" # Addition test descriptions add_test1: .string "1.0 + 2.0 = 3.0: " add_test2: .string "1.5 + 1.5 = 3.0: " add_test3: .string "0.5 + 0.5 = 1.0: " # Subtraction test descriptions sub_test1: .string "3.0 - 2.0 = 1.0: " sub_test2: .string "2.0 - 1.0 = 1.0: " sub_test3: .string "3.0 - 1.5 = 1.5: " # Multiplication test descriptions mul_test1: .string "3.0 * 4.0 = 12.0: " mul_test2: .string "2.0 * 2.5 = 5.0: " mul_test3: .string "1.5 * 2.0 = 3.0: " # Division test descriptions div_test1: .string "5.0 / 2.0 = 2.5: " div_test2: .string "6.0 / 2.0 = 3.0: " div_test3: .string "3.0 / 2.0 = 1.5: " # Special values test description special_test: .string "Special values (NaN, Inf, Zero): " # Test result messages pass_msg: .string "PASS\n" fail_msg: .string "FAIL\n" # Summary and statistics messages summary: .string "\n=== Test Summary ===\n" total_tests: .string "Total tests: " passed_tests: .string "Passed: " failed_tests: .string "Failed: " all_pass: .string "\nAll tests passed!\n" some_fail: .string "\nSome tests failed.\n" newline: .string "\n" .text .globl _start # ==================== main tests ==================== _start: la a0, test_start # Print test suite header addi a7, zero, 4 ecall addi s0, zero, 0 # s0 = pass counter addi s1, zero, 0 # s1 = total test counter # ==================== addition tests ==================== la a0, add_header addi a7, zero, 4 ecall # Test 1.1: 1.0 + 2.0 = 3.0 la a0, add_test1 addi a7, zero, 4 ecall la t0, add_a1 # Load test values and perform addition lw a0, 0(t0) la t0, add_b1 lw a1, 0(t0) jal bf16_add la t0, result # Store and verify result sw a0, 0(t0) la t0, add_exp1 lw a1, 0(t0) jal bf16_eq jal check_test_result # Check and display result # Test 1.2: 1.5 + 1.5 = 3.0 la a0, add_test2 addi a7, zero, 4 ecall la t0, add_a2 lw a0, 0(t0) la t0, add_b2 lw a1, 0(t0) jal bf16_add la t0, result sw a0, 0(t0) la t0, add_exp2 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 1.3: 0.5 + 0.5 = 1.0 la a0, add_test3 addi a7, zero, 4 ecall la t0, add_a3 lw a0, 0(t0) la t0, add_b3 lw a1, 0(t0) jal bf16_add la t0, result sw a0, 0(t0) la t0, add_exp3 lw a1, 0(t0) jal bf16_eq jal check_test_result # ==================== subtraction tests ==================== la a0, sub_header addi a7, zero, 4 ecall # Test 2.1: 3.0 - 2.0 = 1.0 la a0, sub_test1 addi a7, zero, 4 ecall la t0, sub_a1 lw a0, 0(t0) la t0, sub_b1 lw a1, 0(t0) jal bf16_sub la t0, result sw a0, 0(t0) la t0, sub_exp1 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 2.2: 2.0 - 1.0 = 1.0 la a0, sub_test2 addi a7, zero, 4 ecall la t0, sub_a2 lw a0, 0(t0) la t0, sub_b2 lw a1, 0(t0) jal bf16_sub la t0, result sw a0, 0(t0) la t0, sub_exp2 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 2.3: 3.0 - 1.5 = 1.5 la a0, sub_test3 addi a7, zero, 4 ecall la t0, sub_a3 lw a0, 0(t0) la t0, sub_b3 lw a1, 0(t0) jal bf16_sub la t0, result sw a0, 0(t0) la t0, sub_exp3 lw a1, 0(t0) jal bf16_eq jal check_test_result # ==================== multiplication tests ==================== la a0, mul_header addi a7, zero, 4 ecall # Test 3.1: 3.0 * 4.0 = 12.0 la a0, mul_test1 addi a7, zero, 4 ecall la t0, mul_a1 lw a0, 0(t0) la t0, mul_b1 lw a1, 0(t0) jal bf16_mul la t0, result sw a0, 0(t0) la t0, mul_exp1 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 3.2: 2.0 * 2.5 = 5.0 la a0, mul_test2 addi a7, zero, 4 ecall la t0, mul_a2 lw a0, 0(t0) la t0, mul_b2 lw a1, 0(t0) jal bf16_mul la t0, result sw a0, 0(t0) la t0, mul_exp2 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 3.3: 1.5 * 2.0 = 3.0 la a0, mul_test3 addi a7, zero, 4 ecall la t0, mul_a3 lw a0, 0(t0) la t0, mul_b3 lw a1, 0(t0) jal bf16_mul la t0, result sw a0, 0(t0) la t0, mul_exp3 lw a1, 0(t0) jal bf16_eq jal check_test_result # ==================== division tests ==================== la a0, div_header addi a7, zero, 4 ecall # Test 4.1: 5.0 / 2.0 = 2.5 la a0, div_test1 addi a7, zero, 4 ecall la t0, div_a1 lw a0, 0(t0) la t0, div_b1 lw a1, 0(t0) jal bf16_div la t0, result sw a0, 0(t0) la t0, div_exp1 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 4.2: 6.0 / 2.0 = 3.0 la a0, div_test2 addi a7, zero, 4 ecall la t0, div_a2 lw a0, 0(t0) la t0, div_b2 lw a1, 0(t0) jal bf16_div la t0, result sw a0, 0(t0) la t0, div_exp2 lw a1, 0(t0) jal bf16_eq jal check_test_result # Test 4.3: 3.0 / 2.0 = 1.5 la a0, div_test3 addi a7, zero, 4 ecall la t0, div_a3 lw a0, 0(t0) la t0, div_b3 lw a1, 0(t0) jal bf16_div la t0, result sw a0, 0(t0) la t0, div_exp3 lw a1, 0(t0) jal bf16_eq jal check_test_result # ==================== special value test ==================== la a0, special_header addi a7, zero, 4 ecall la a0, special_test addi a7, zero, 4 ecall # Test NaN detection la t0, test_nan lw a0, 0(t0) jal bf16_isnan beq a0, zero, special_fail # Test Infinity detection la t0, test_inf lw a0, 0(t0) jal bf16_isinf beq a0, zero, special_fail # Test zero detection la t0, test_zero lw a0, 0(t0) jal bf16_iszero beq a0, zero, special_fail # All special value tests passed addi s0, s0, 1 addi s1, s1, 1 la a0, pass_msg addi a7, zero, 4 ecall j print_summary special_fail: # Special value test failed addi s1, s1, 1 la a0, fail_msg addi a7, zero, 4 ecall # ==================== test summary ==================== print_summary: la a0, summary # Print test summary header addi a7, zero, 4 ecall sub s2, s1, s0 # s2 = failed tests = total - passed # Print total tests count la a0, total_tests addi a7, zero, 4 ecall add a0, zero, s1 addi a7, zero, 1 ecall la a0, newline addi a7, zero, 4 ecall # Print passed tests count la a0, passed_tests addi a7, zero, 4 ecall add a0, zero, s0 addi a7, zero, 1 ecall la a0, newline addi a7, zero, 4 ecall # Print failed tests count la a0, failed_tests addi a7, zero, 4 ecall add a0, zero, s2 addi a7, zero, 1 ecall la a0, newline addi a7, zero, 4 ecall # Print final result message beq s2, zero, all_passed_msg la a0, some_fail j exit all_passed_msg: la a0, all_pass exit: addi a7, zero, 4 ecall addi a7, zero, 10 # Exit system call ecall # ==================== TEST RESULT CHECKER SUBROUTINE ==================== check_test_result: addi sp, sp, -4 sw ra, 0(sp) addi s1, s1, 1 # Increment total test counter beq a0, zero, test_failed addi s0, s0, 1 # Increment pass counter la a0, pass_msg j test_end test_failed: la a0, fail_msg test_end: addi a7, zero, 4 ecall lw ra, 0(sp) addi sp, sp, 4 ret # ==================== BF16 Subtraction Function ==================== # Implements: a - b = a + (-b) # Input: a0 = bf16 a, a1 = bf16 b # Output: a0 = a - b in bf16 format bf16_sub: addi sp, sp, -8 sw ra, 0(sp) sw s0, 4(sp) mv s0, a1 # Save b # Flip the sign bit of b (b ^= BF16_SIGN_MASK) li t0, 0x8000 # BF16_SIGN_MASK xor a1, s0, t0 # b.bits ^= BF16_SIGN_MASK # Call bf16_add(a, -b) jal bf16_add lw ra, 0(sp) lw s0, 4(sp) addi sp, sp, 8 ret # ==================== BF16 Addition Function ==================== bf16_add: addi sp, sp, -16 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) sw s2, 12(sp) add s0, zero, a0 # Save input values add s1, zero, a1 # Extract sign, exponent, mantissa from a srli t0, s0, 15 # Extract sign bit (bit 15) andi t0, t0, 1 # Keep only the sign bit srli t1, s0, 7 # Extract exponent (bits 7-14) andi t1, t1, 255 # Keep only 8 exponent bits andi t2, s0, 127 # Extract mantissa (bits 0-6) # Extract sign, exponent, mantissa from b srli t3, s1, 15 andi t3, t3, 1 srli t4, s1, 7 andi t4, t4, 255 andi t5, s1, 127 # Add implicit 1 to mantissas for normalized numbers beq t1, zero, skip_impl_a ori t2, t2, 128 # Add implicit 1 (bit 7) skip_impl_a: beq t4, zero, skip_impl_b ori t5, t5, 128 skip_impl_b: # Align exponents by shifting smaller exponent's mantissa sub t6, t1, t4 # Calculate exponent difference bge t6, zero, a_greater_exp # b has larger exponent, shift a's mantissa sub t6, t4, t1 # Get shift amount srl t2, t2, t6 # Shift a's mantissa right add t1, zero, t4 # Use b's exponent j exponents_aligned a_greater_exp: # a has larger exponent, shift b's mantissa srl t5, t5, t6 exponents_aligned: # Perform addition or subtraction based on signs bne t0, t3, subtract # Same sign - add mantissas add t2, t2, t5 add s2, zero, t0 # Save result sign j normalize subtract: # Different signs - subtract smaller from larger bgeu t2, t5, a_greater_mant sub t2, t5, t2 # b - a add s2, zero, t3 # Use b's sign j normalize a_greater_mant: sub t2, t2, t5 # a - b add s2, zero, t0 # Use a's sign normalize: beq t2, zero, zero_result # Result is zero # Normalize mantissa to range [128, 255] (1.0 to 1.99) addi t3, zero, 1 slli t3, t3, 8 # t3 = 256 (threshold) normalize_loop: bgeu t2, t3, shift_right slli t2, t2, 1 # Shift left until normalized addi t1, t1, -1 # Decrement exponent bne t1, zero, normalize_loop j underflow # Exponent underflow shift_right: srli t2, t2, 1 # Shift right if mantissa too large addi t1, t1, 1 # Increment exponent bgeu t2, t3, shift_right addi t3, zero, 255 # Check for exponent overflow/underflow bge t1, t3, overflow bne t1, zero, pack_result underflow: add a0, zero, zero # Exponent underflow - return zero j done overflow: # Exponent overflow - return infinity addi a0, zero, 127 # Construct 0x7F80 (+inf) slli a0, a0, 8 ori a0, a0, 128 beq s2, zero, done addi a0, zero, 255 # Construct 0xFF80 (-inf) slli a0, a0, 8 ori a0, a0, 128 j done zero_result: add a0, zero, zero j done pack_result: andi t2, t2, 127 # Remove implicit bit (keep bits 0-6) slli a0, s2, 15 # Set sign bit (bit 15) slli t1, t1, 7 # Shift exponent to bits 7-14 or a0, a0, t1 # Combine sign and exponent or a0, a0, t2 # Combine with mantissa done: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) lw s2, 12(sp) addi sp, sp, 16 ret # ==================== BF16 Multiplication Function (Fixed) ==================== bf16_mul: addi sp, sp, -24 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) sw s2, 12(sp) sw s3, 16(sp) sw s4, 20(sp) mv s0, a0 # a mv s1, a1 # b # Extract sign bits srli s2, s0, 15 andi s2, s2, 1 # sign_a srli t0, s1, 15 andi t0, t0, 1 # sign_b xor s2, s2, t0 # result_sign = sign_a ^ sign_b # Extract exponents srli s3, s0, 7 andi s3, s3, 0xFF # exp_a srli t0, s1, 7 andi t0, t0, 0xFF # exp_b # Extract mantissas and add implicit bit andi s4, s0, 0x7F # mant_a andi t1, s1, 0x7F # mant_b # Check for special cases beq s3, zero, mul_zero_a li t2, 0xFF beq s3, t2, mul_inf_nan_a mul_check_b: beq t0, zero, mul_zero_b li t2, 0xFF beq t0, t2, mul_inf_nan_b # Add implicit 1 to mantissas for normalized numbers ori s4, s4, 0x80 ori t1, t1, 0x80 # Multiply mantissas (16-bit result) mul t2, s4, t1 # mantissa product (16 bits) # Calculate result exponent add s3, s3, t0 # exp_a + exp_b addi s3, s3, -127 # subtract bias # Normalize mantissa li t3, 0x4000 # 0x4000 = 1<<14 (check if product >= 0x4000) bgeu t2, t3, mul_normalize_shift_right # Need to shift left slli t2, t2, 1 addi s3, s3, -1 j mul_check_exp mul_normalize_shift_right: # Product is too large, shift right srli t2, t2, 7 # shift to get 7-bit mantissa andi t2, t2, 0x7F # keep only 7 bits j mul_pack_result mul_check_exp: # Check if we need to normalize more li t3, 0x4000 bgeu t2, t3, mul_normalize_shift_right # Get final 7-bit mantissa srli t2, t2, 7 andi t2, t2, 0x7F mul_pack_result: # Check exponent bounds ble s3, zero, mul_underflow li t3, 0xFF bge s3, t3, mul_overflow # Pack result slli a0, s2, 15 # sign slli t3, s3, 7 # exponent or a0, a0, t3 or a0, a0, t2 # mantissa j mul_done mul_zero_a: # a is zero beq t0, zero, mul_zero_result # 0 * 0 = 0 li t2, 0xFF beq t0, t2, mul_nan_result # 0 * inf = NaN j mul_zero_result mul_zero_b: # b is zero li t2, 0xFF beq s3, t2, mul_nan_result # inf * 0 = NaN j mul_zero_result mul_inf_nan_a: # a is inf or NaN beq s4, zero, mul_a_inf # a is infinity j mul_nan_result # a is NaN mul_inf_nan_b: # b is inf or NaN beq t1, zero, mul_b_inf # b is infinity j mul_nan_result # b is NaN mul_a_inf: beq t0, zero, mul_nan_result # inf * 0 = NaN j mul_inf_result mul_b_inf: beq s3, zero, mul_nan_result # 0 * inf = NaN j mul_inf_result mul_zero_result: li a0, 0 j mul_done mul_inf_result: li a0, 0x7F80 # +inf beq s2, zero, mul_done li a0, 0xFF80 # -inf j mul_done mul_nan_result: li a0, 0x7FC0 # NaN j mul_done mul_underflow: li a0, 0 # flush to zero j mul_done mul_overflow: li a0, 0x7F80 # +inf beq s2, zero, mul_done li a0, 0xFF80 # -inf mul_done: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) lw s2, 12(sp) lw s3, 16(sp) lw s4, 20(sp) addi sp, sp, 24 ret # ==================== BF16 Division Function (Fixed) ==================== bf16_div: addi sp, sp, -24 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) sw s2, 12(sp) sw s3, 16(sp) sw s4, 20(sp) mv s0, a0 # a (dividend) mv s1, a1 # b (divisor) # Extract sign bits srli s2, s0, 15 andi s2, s2, 1 # sign_a srli t0, s1, 15 andi t0, t0, 1 # sign_b xor s2, s2, t0 # result_sign = sign_a ^ sign_b # Extract exponents srli s3, s0, 7 andi s3, s3, 0xFF # exp_a srli t0, s1, 7 andi t0, t0, 0xFF # exp_b # Extract mantissas andi s4, s0, 0x7F # mant_a andi t1, s1, 0x7F # mant_b # Check for special cases # Division by zero beq t0, zero, div_by_zero_check # Check for NaN/infinity li t2, 0xFF beq s3, t2, div_inf_nan_a beq t0, t2, div_inf_nan_b # Check for zero dividend beq s3, zero, div_zero_dividend # Add implicit 1 to mantissas ori s4, s4, 0x80 ori t1, t1, 0x80 # Perform division using restoring division algorithm li t2, 0 # quotient li t3, 8 # counter mv t4, s4 # remainder (start with dividend mantissa) div_loop: slli t2, t2, 1 # shift quotient left slli t4, t4, 1 # shift remainder left # Compare remainder with divisor bltu t4, t1, div_skip_sub sub t4, t4, t1 # subtract divisor ori t2, t2, 1 # set quotient bit div_skip_sub: addi t3, t3, -1 bnez t3, div_loop # t2 now contains the quotient mantissa (8 bits) # Calculate result exponent sub s3, s3, t0 # exp_a - exp_b addi s3, s3, 127 # add bias # Normalize quotient if needed andi t3, t2, 0x80 # check if implicit bit is set bnez t3, div_normalized # Need to normalize - shift left and adjust exponent slli t2, t2, 1 andi t2, t2, 0xFF # keep 8 bits addi s3, s3, -1 div_normalized: # Get final 7-bit mantissa (remove implicit bit) andi t2, t2, 0x7F # Check exponent bounds ble s3, zero, div_underflow li t3, 0xFF bge s3, t3, div_overflow # Pack result slli a0, s2, 15 # sign slli t3, s3, 7 # exponent or a0, a0, t3 or a0, a0, t2 # mantissa j div_done div_by_zero_check: # Division by zero beq s3, zero, div_zero_by_zero # 0/0 = NaN li a0, 0x7F80 # +inf beq s2, zero, div_done li a0, 0xFF80 # -inf j div_done div_zero_by_zero: li a0, 0x7FC0 # NaN j div_done div_zero_dividend: # 0 / non-zero = 0 li a0, 0 j div_done div_inf_nan_a: # a is inf or NaN beq s4, zero, div_a_inf # a is infinity j div_nan_result # a is NaN div_inf_nan_b: # b is inf or NaN beq t1, zero, div_b_inf # b is infinity j div_nan_result # b is NaN div_a_inf: beq t0, zero, div_nan_result # inf / 0 = NaN li t2, 0xFF beq t0, t2, div_nan_result # inf / inf = NaN j div_inf_result div_b_inf: beq s3, zero, div_zero_result # 0 / inf = 0 j div_zero_result # finite / inf = 0 div_inf_result: li a0, 0x7F80 # +inf beq s2, zero, div_done li a0, 0xFF80 # -inf j div_done div_zero_result: li a0, 0 j div_done div_nan_result: li a0, 0x7FC0 # NaN j div_done div_underflow: li a0, 0 # flush to zero j div_done div_overflow: li a0, 0x7F80 # +inf beq s2, zero, div_done li a0, 0xFF80 # -inf div_done: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) lw s2, 12(sp) lw s3, 16(sp) lw s4, 20(sp) addi sp, sp, 24 ret # ==================== utility functions ==================== # Check if bfloat16 value is NaN bf16_isnan: srli t0, a0, 7 andi t0, t0, 255 addi t1, zero, 255 bne t0, t1, not_nan_ret andi t0, a0, 127 # Exponent is all 1s, check if mantissa is non-zero beq t0, zero, not_nan_ret # Infinity, not NaN addi a0, zero, 1 ret not_nan_ret: add a0, zero, zero ret # Check if bfloat16 value is Infinity bf16_isinf: srli t0, a0, 7 andi t0, t0, 255 addi t1, zero, 255 bne t0, t1, not_inf_ret andi t0, a0, 127 # Exponent is all 1s, check if mantissa is zero bne t0, zero, not_inf_ret # NaN, not Infinity addi a0, zero, 1 ret not_inf_ret: add a0, zero, zero ret # Check if bfloat16 value is zero bf16_iszero: andi t0, a0, 255 # Check lower 8 bits srli t1, a0, 8 # Check next 7 bits andi t1, t1, 127 or t0, t0, t1 # Combine results bne t0, zero, not_zero_ret addi a0, zero, 1 ret not_zero_ret: add a0, zero, zero ret # Check if two bfloat16 values are equal bf16_eq: beq a0, a1, equal_ret add a0, zero, zero ret equal_ret: addi a0, zero, 1 ret ``` **1. Execution information** - <span style="color:darkblue">**$1823 \;\text{cycles}$**</span> ![截圖 2025-10-12 下午6.56.55](https://hackmd.io/_uploads/HJ8jYWKpel.png =350x) --- ## [Leetcode`66`](https://leetcode.com/problems/plus-one/) ### Description You are given a large integer represented as an integer array `digits`, where each `digits[i]` is the ith digit of the integer. The digits are ordered from **most significant to least significant** in left-to-right order. The large integer does **not** contain any leading `0'`s. Increment the large integer by one and **return the resulting array of digits**. #### Example `1` ```vb Input: digits = [1,2,3] Output: [1,2,4] Explanation: The array represents the integer 123. Incrementing by one gives 123 + 1 = 124, so the result is [1,2,4]. ``` #### Example `2` ```vb Input: digits = [9,9,9] Output: [1,0,0,0] Explanation: The array represents the integer 999. Incrementing by one gives 1000, so the result is [1,0,0,0]. ``` ### Original C program ```c int* plusOne(int* digits, int digitsSize, int* returnSize) { // Allocate result array (max possible size is digitsSize + 1) int* result = (int*)malloc(sizeof(int) * (digitsSize + 1)); // Initialize result with zeros for (int i = 0; i <= digitsSize; i++) { result[i] = 0; } // Copy input to result for (int i = 0; i < digitsSize; i++) { result[i] = digits[i]; } int carry = 1; // Start with carry = 1 (adding 1) // Process from least significant digit for (int i = digitsSize - 1; i >= 0; i--) { if (carry) { result[i] += 1; if (result[i] >= 10) { result[i] = 0; carry = 1; } else { carry = 0; } } } // Check if we need to expand the array if (carry) { // Shift all elements right for (int i = digitsSize; i > 0; i--) { result[i] = result[i - 1]; } result[0] = 1; *returnSize = digitsSize + 1; } else { *returnSize = digitsSize; } return result; } ``` ### Assembly code <h4> <span style="color:darkblue">version <code>1</code></span> </h4> ```s # Plus One - RISC-V Assembly Implementation # Implements: digits = digits + 1 for large integers represented as arrays .data # Test data arrays test1: .word 1, 2, 3 # [1,2,3] -> [1,2,4] test1_len: .word 3 test1_exp: .word 1, 2, 4 # Expected result test2: .word 4, 3, 2, 1 # [4,3,2,1] -> [4,3,2,2] test2_len: .word 4 test2_exp: .word 4, 3, 2, 2 # Expected result test3: .word 9, 9, 9 # [9,9,9] -> [1,0,0,0] test3_len: .word 3 test3_exp: .word 1, 0, 0, 0 # Expected result (length changes!) # Output buffer for results result: .word 0, 0, 0, 0, 0 # 5 words buffer # Test messages test_start_msg: .string "=== Plus One Automated Test ===\n\n" test1_msg: .string "Test 1: [1,2,3] -> [1,2,4] " test2_msg: .string "Test 2: [4,3,2,1] -> [4,3,2,2] " test3_msg: .string "Test 3: [9,9,9] -> [1,0,0,0] " pass_msg: .string "PASS\n" fail_msg: .string "FAIL\n" separator: .string "\n====================\n" all_pass_msg: .string "All tests passed.\n" some_fail_msg: .string "Some tests failed.\n" newline: .string "\n" .text .globl main main: # Print test start message la a0, test_start_msg li a7, 4 ecall # Initialize test counters li s0, 0 # s0 = pass count li s1, 3 # s1 = total test count # Test 1: [1,2,3] -> [1,2,4] la a0, test1_msg li a7, 4 ecall la a0, test1 # input array lw a1, test1_len # array length jal ra, plus_one la a1, test1_exp # expected result lw a2, test1_len # expected length jal ra, compare_arrays beqz a0, test1_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test1_end test1_fail: la a0, fail_msg test1_end: li a7, 4 ecall # Test 2: [4,3,2,1] -> [4,3,2,2] la a0, test2_msg li a7, 4 ecall la a0, test2 # input array lw a1, test2_len # array length jal ra, plus_one la a1, test2_exp # expected result lw a2, test2_len # expected length jal ra, compare_arrays beqz a0, test2_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test2_end test2_fail: la a0, fail_msg test2_end: li a7, 4 ecall # Test 3: [9,9,9] -> [1,0,0,0] la a0, test3_msg li a7, 4 ecall la a0, test3 # input array lw a1, test3_len # array length jal ra, plus_one la a1, test3_exp # expected result li a2, 4 # expected length is 4 (changed due to carry) jal ra, compare_arrays beqz a0, test3_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test3_end test3_fail: la a0, fail_msg test3_end: li a7, 4 ecall # Print separator la a0, separator li a7, 4 ecall # Print final results beq s0, s1, all_passed la a0, some_fail_msg j print_final all_passed: la a0, all_pass_msg print_final: li a7, 4 ecall # Exit program li a7, 10 ecall # Plus One function # Corresponds to: int* plusOne(int* digits, int digitsSize, int* returnSize) plus_one: addi sp, sp, -20 sw ra, 0(sp) sw s0, 4(sp) # input array address sw s1, 8(sp) # original length sw s2, 12(sp) # result buffer address sw s3, 16(sp) # carry flag mv s0, a0 # s0 = input array address mv s1, a1 # s1 = original length la s2, result # s2 = result buffer address # Initialize result buffer to zeros la t0, result li t1, 5 # max possible length li t2, 0 init_loop: beqz t1, init_done sw t2, 0(t0) addi t0, t0, 4 addi t1, t1, -1 j init_loop init_done: # Copy input to result mv t0, s0 # source = input array mv t1, s2 # destination = result buffer mv t2, s1 # counter = length copy_loop: beqz t2, copy_done lw t3, 0(t0) sw t3, 0(t1) addi t0, t0, 4 addi t1, t1, 4 addi t2, t2, -1 j copy_loop copy_done: # Start from least significant digit with carry = 1 li s3, 1 # carry flag = true (initially adding 1) # Calculate pointer to last element la t0, result # result array addi t1, s1, -1 # index of last element slli t1, t1, 2 # convert to byte offset add t0, t0, t1 # point to last element mv t2, s1 # counter = length process_digits: beqz t2, check_final_carry # Check if we need to add (last digit or carry is set) beqz s3, no_addition # Add 1 to current digit lw t3, 0(t0) # load current digit addi t3, t3, 1 # add 1 # Check for carry li t4, 10 blt t3, t4, no_overflow # Handle overflow: set digit to 0, keep carry flag li t3, 0 li s3, 1 # carry remains true j store_digit no_overflow: # No overflow, clear carry flag li s3, 0 store_digit: sw t3, 0(t0) no_addition: # Move to next digit (more significant) addi t0, t0, -4 addi t2, t2, -1 j process_digits check_final_carry: # If we still have carry after processing all digits beqz s3, plus_one_done # Need to expand array: shift right and add 1 at front # Calculate new length addi s1, s1, 1 # Shift all elements right by one position la t0, result addi t1, s1, -2 # index of second last element in new array slli t1, t1, 2 add t2, t0, t1 # source pointer (starts at original last element) addi t3, t2, 4 # destination pointer (one position right) mv t4, s1 # counter = new length - 1 addi t4, t4, -1 shift_loop: beqz t4, shift_done lw t5, 0(t2) sw t5, 0(t3) addi t2, t2, -4 addi t3, t3, -4 addi t4, t4, -1 j shift_loop shift_done: # Add 1 as the new most significant digit li t0, 1 la t1, result sw t0, 0(t1) plus_one_done: mv a0, s1 # return new length lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) lw s2, 12(sp) lw s3, 16(sp) addi sp, sp, 20 ret # Compare arrays function # Input: a0 = actual length, a1 = expected array address, a2 = expected length # Output: a0 = 1 if arrays match, 0 otherwise compare_arrays: addi sp, sp, -12 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) mv s0, a1 # s0 = expected array mv s1, a2 # s1 = expected length # First check if lengths match bne a0, s1, compare_fail # Now compare each element la t0, result # actual result array mv t1, s0 # expected array mv t2, s1 # counter compare_loop: beqz t2, compare_success lw t3, 0(t0) # actual value lw t4, 0(t1) # expected value bne t3, t4, compare_fail addi t0, t0, 4 addi t1, t1, 4 addi t2, t2, -1 j compare_loop compare_success: li a0, 1 j compare_done compare_fail: li a0, 0 compare_done: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) addi sp, sp, 12 ret ``` **1. Execution information** - <span style="color:darkblue">**$843 \;\text{cycles}$**</span> ![截圖 2025-10-05 晚上11.03.19](https://hackmd.io/_uploads/HyVlKWgTge.png =360x) --- <h4> <span style="color:darkblue">version <code>2</code></span> </h4> **1. Improvement** * **simplify carry handling logic in** <span style="color:darkblue">**version**</span> **`1`**: ```s add_loop: # Load current digit lw t2, 0(t0) # Add carry add t2, t2, t1 li t1, 0 # reset carry # Check if digit >= 10 li t3, 10 blt t2, t3, no_carry # Handle carry: set digit to 0, set carry to 1 li t2, 0 li t1, 1 ``` * **add early exit mechanism in** <span style="color:darkblue">**version**</span> **`1`**: ```s beqz t1, add_done # Exit early if no carry ``` * **eliminated unnecessary initialization** Directly copies the input data without separate initialization. **2. Execution information** - <span style="color:darkblue">**$610 \;\text{cycles}$**</span> ![截圖 2025-10-06 凌晨4.32.03](https://hackmd.io/_uploads/r17JL8g6xg.png =350x) **3. why faster ?** * **reduced memory access** Uses only 2 saved registers (`s0-s1`) + temporary registers. * **branch prediction improvements** More straightforward logic with fewer conditional branches and consistent iteration behavior improves branch prediction accuracy. * **instruction count reduction** * Removed `O(n)` initialization loop that zeroed result buffer. * Reduced from `2 ~ 3` conditional branches to direct arithmetic * Added immediate termination when carry propagation stops. <h4> <span style="color:darkblue">complete version</span> </h4> ```s # Plus One - RISC-V Assembly Implementation # Implements: digits = digits + 1 for large integers represented as arrays .data # Test data arrays test1: .word 1, 2, 3 # [1,2,3] -> [1,2,4] test1_len: .word 3 test1_exp: .word 1, 2, 4 # Expected result test2: .word 4, 3, 2, 1 # [4,3,2,1] -> [4,3,2,2] test2_len: .word 4 test2_exp: .word 4, 3, 2, 2 # Expected result test3: .word 9, 9, 9 # [9,9,9] -> [1,0,0,0] test3_len: .word 3 test3_exp: .word 1, 0, 0, 0 # Expected result (length changes!) # Output buffer for results result: .word 0, 0, 0, 0, 0 # 5 words buffer # Test messages test_start_msg: .string "=== Plus One Automated Test ===\n\n" test1_msg: .string "Test 1: [1,2,3] -> [1,2,4] " test2_msg: .string "Test 2: [4,3,2,1] -> [4,3,2,2] " test3_msg: .string "Test 3: [9,9,9] -> [1,0,0,0] " pass_msg: .string "PASS\n" fail_msg: .string "FAIL\n" separator: .string "\n====================\n" all_pass_msg: .string "All tests passed.\n" some_fail_msg: .string "Some tests failed.\n" newline: .string "\n" .text .globl main main: # Print test start message la a0, test_start_msg li a7, 4 ecall # Initialize test counters li s0, 0 # s0 = pass count li s1, 3 # s1 = total test count # Test 1: [1,2,3] -> [1,2,4] la a0, test1_msg li a7, 4 ecall la a0, test1 # input array lw a1, test1_len # array length jal ra, plus_one la a1, test1_exp # expected result lw a2, test1_len # expected length jal ra, compare_arrays beqz a0, test1_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test1_end test1_fail: la a0, fail_msg test1_end: li a7, 4 ecall # Test 2: [4,3,2,1] -> [4,3,2,2] la a0, test2_msg li a7, 4 ecall la a0, test2 # input array lw a1, test2_len # array length jal ra, plus_one la a1, test2_exp # expected result lw a2, test2_len # expected length jal ra, compare_arrays beqz a0, test2_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test2_end test2_fail: la a0, fail_msg test2_end: li a7, 4 ecall # Test 3: [9,9,9] -> [1,0,0,0] la a0, test3_msg li a7, 4 ecall la a0, test3 # input array lw a1, test3_len # array length jal ra, plus_one la a1, test3_exp # expected result li a2, 4 # expected length is 4 (changed due to carry) jal ra, compare_arrays beqz a0, test3_fail addi s0, s0, 1 # increment pass count la a0, pass_msg j test3_end test3_fail: la a0, fail_msg test3_end: li a7, 4 ecall # Print separator la a0, separator li a7, 4 ecall # Print final results beq s0, s1, all_passed la a0, some_fail_msg j print_final all_passed: la a0, all_pass_msg print_final: li a7, 4 ecall # Exit program li a7, 10 ecall # Optimized plus_one function plus_one: addi sp, sp, -12 sw ra, 0(sp) sw s0, 4(sp) # result buffer address sw s1, 8(sp) # original length la s0, result # s0 = result buffer address mv s1, a1 # s1 = original length # Copy input to result mv t0, a0 # source mv t1, s0 # destination mv t2, a1 # counter copy_loop: beqz t2, copy_done lw t3, 0(t0) sw t3, 0(t1) addi t0, t0, 4 addi t1, t1, 4 addi t2, t2, -1 j copy_loop copy_done: # Add 1 starting from the last digit addi t0, s1, -1 # index of last element slli t0, t0, 2 # convert to byte offset add t0, s0, t0 # pointer to last element li t1, 1 # carry = 1 (we're adding 1) add_loop: # Load current digit lw t2, 0(t0) # Add carry add t2, t2, t1 li t1, 0 # reset carry # Check if digit >= 10 li t3, 10 blt t2, t3, no_carry # Handle carry: set digit to 0, set carry to 1 li t2, 0 li t1, 1 no_carry: # Store updated digit sw t2, 0(t0) # Check if we need to continue beqz t1, add_done # no carry, we're done # Move to previous digit addi t0, t0, -4 # Check if we've reached the beginning blt t0, s0, expand_array j add_loop add_done: mv a0, s1 # return original length j plus_one_exit expand_array: # All digits were 9, need to expand array # Shift all digits right by one position addi t0, s1, -1 # start from last index shift_loop: bltz t0, shift_done slli t1, t0, 2 # byte offset add t1, s0, t1 # source address lw t2, 0(t1) # load value addi t3, t1, 4 # destination address (one position right) sw t2, 0(t3) # store value addi t0, t0, -1 # move to previous index j shift_loop shift_done: # Set first digit to 1 li t0, 1 sw t0, 0(s0) addi a0, s1, 1 # return new length plus_one_exit: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) addi sp, sp, 12 ret # Compare arrays function # Input: a0 = actual length, a1 = expected array address, a2 = expected length # Output: a0 = 1 if arrays match, 0 otherwise compare_arrays: addi sp, sp, -12 sw ra, 0(sp) sw s0, 4(sp) sw s1, 8(sp) mv s0, a1 # s0 = expected array mv s1, a2 # s1 = expected length # First check if lengths match bne a0, s1, compare_fail # Now compare each element la t0, result # actual result array mv t1, s0 # expected array mv t2, s1 # counter compare_loop: beqz t2, compare_success lw t3, 0(t0) # actual value lw t4, 0(t1) # expected value bne t3, t4, compare_fail addi t0, t0, 4 addi t1, t1, 4 addi t2, t2, -1 j compare_loop compare_success: li a0, 1 j compare_done compare_fail: li a0, 0 compare_done: lw ra, 0(sp) lw s0, 4(sp) lw s1, 8(sp) addi sp, sp, 12 ret ``` ## Ripes Simulator Use <span style="color:blue">**Ripes**</span>, which is a graphical processor simulator and assembly editor for the **RISC-V** architecture ### 5-stage pipelined processor ![截圖 2025-10-06 清晨5.29.45](https://hackmd.io/_uploads/HJVumwgTlg.png) A pipelined processor divides instruction execution into five stages: * **Instruction Fetch (IF)** The Instruction Fetch stage is the first step in the pipeline process. Here, the CPU retrieves an instruction from the program memory. The primary tasks in this stage include: * `Program Counter (PC) Management` * `Instruction Memory Access` * `Buffering` Efficiency at this stage is paramount as delays here can **stall** the entire pipeline, leading to performance bottlenecks. **Advanced CPUs** often use techniques like prefetching to mitigate potential delays. * **Decode (ID)** Once the instruction is fetched, it enters the Instruction Decode stage. This stage involves interpreting the fetched instruction and preparing the necessary operands for execution. Key activities include: * `Opcode Decoding` * `Register Read` * `Instruction Classification` This stage is also responsible for **hazard detection and forwarding** to resolve data dependencies and prevent pipeline stalls. * **Execute (EX)** The Execute stage is where the actual computation or operation specified by the instruction takes place. This stage includes: * `ALU Operations` * `Address Calculation` * `Branch Evaluation` The execution stage is critical for the **overall performance**, as complex instructions can take multiple cycles, potentially causing pipeline stalls. * **Memory Access (MEM)** After execution, some instructions require access to memory to read or write data. The Memory Access stage handles these operations. Key functions include: * `Load Operations` * `Store Operations` * `Address Translation` Efficiency in this stage is achieved through techniques like caching, which **reduces latency** by storing frequently accessed data closer to the CPU. * **Write Back (WB)** The final stage in the pipeline is Write Back. Here, the results of the executed instructions are written back to the CPU's register file. This stage involves: * `Register Write` * `Completion Logging` Write Back is crucial for ensuring that subsequent instructions have access to the correct and updated data, **maintaining the integrity and consistency** of the processor’s state. --- **Reference** : * [leetcode 66](https://leetcode.com/problems/plus-one/) * [5 Stages of Pipeline in Computer Architecture](https://medium.com/@aylia.zulfiqar29/5-stages-of-pipeline-in-computer-architecture-dc9fca11784e) * [arch2025-quiz1-sol](https://hackmd.io/@sysprog/arch2025-quiz1-sol) * [@sysprog21/ca2025-quizzes](https://github.com/sysprog21/ca2025-quizzes/tree/main)