# Assignment1: RISC-V Assembly and Instruction Pipeline

contributed by < [JimmyCh1025](https://github.com/JimmyCh1025) >

[TOC]

###### tags: `RISC-V` `Computer Architure 2025`

## Problem B

### UF8

UF8 is a logarithmic 8-bit codec. It is suitable for representing level-of-detail (LOD) distances and fog density values, but not for financial calculations. It maps 20-bit unsigned integers to 8-bit symbols using logarithmic quantization, giving 2.5:1 compression with a relative error of at most 6.25%.

UF8 format:

| exponent | mantissa |
|----------|----------|
| 4 bits   | 4 bits   |

Decoding

$$
\begin{gather*}
D(b) = m \cdot 2^e + (2^e - 1) \cdot 16
\end{gather*}
$$

* The exponent and mantissa are each at most 15 (4 bits), so the largest base value is 15 × 2¹⁵ = 491,520. The offset brings the largest decoded value close to the top of the 20-bit input range: 15 × 2¹⁵ + (2¹⁵ − 1) × 16 = 1,015,792, compared with 2²⁰ − 1 = 1,048,575.
* The offset also keeps the value ranges of different exponent groups from overlapping.
    * Without offset:
        * e = 0 → [0 × 2⁰, 15 × 2⁰] = [0, 15]
        * e = 1 → [0 × 2¹, 15 × 2¹] = [0, 30] → overlaps with e = 0
    * With offset:
        * e = 0 → [0, 15]
        * e = 1 → [16, 46]

Encoding

$$
\begin{gather*}
E(v) = \begin{cases}
v, & \text{if } v < 16 \\
16e + \lfloor(v - \text{offset}(e))/2^e\rfloor, & \text{otherwise}
\end{cases}
\end{gather*}
$$

where $\text{offset}(e) = (2^e - 1) \cdot 16$.

* For inputs below 16, UF8 encoding is lossless (E(v) = v): the 4-bit mantissa represents these values directly with e = 0, without any exponent shift.
* If the MSB position `msb` is at least 5, the exponent is first estimated as msb − 4, which leaves 4 bits for the mantissa, and is capped at 15. Because both fields are 4 bits wide, UF8 only covers inputs up to about one million. Note that the reference `uf8_encode` only caps the exponent and does not saturate the mantissa, so inputs of 1,048,560 or more produce an out-of-range mantissa and should be clamped by the caller.
* Other values are encoded by finding the exponent group that contains the value, subtracting that group's offset, and storing the remainder, shifted right by the exponent, in the mantissa field.
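To make the two formulas concrete, the following C sketch implements D(b) and E(v) directly from their definitions, using a plain linear search over the exponent instead of the CLZ-based estimate in the reference code. The helper names `uf8_offset`, `uf8_decode_ref`, and `uf8_encode_ref` are hypothetical, and saturating out-of-range inputs to 0xFF is an assumption added here; the original `uf8_encode` does not do it.

```c
#include <stdint.h>

/* offset(e) = (2^e - 1) * 16: the smallest value decoded in exponent group e */
static uint32_t uf8_offset(unsigned e)
{
    return ((1u << e) - 1u) << 4;
}

/* D(b) = m * 2^e + offset(e) */
static uint32_t uf8_decode_ref(uint8_t b)
{
    unsigned e = b >> 4;   /* upper 4 bits: exponent */
    unsigned m = b & 0x0F; /* lower 4 bits: mantissa */
    return ((uint32_t) m << e) + uf8_offset(e);
}

/* E(v) = 16e + floor((v - offset(e)) / 2^e), where e is the largest
 * exponent with offset(e) <= v. Hypothetical helper: unlike the reference
 * uf8_encode, it saturates inputs beyond the last group to 0xFF. */
static uint8_t uf8_encode_ref(uint32_t v)
{
    unsigned e = 0;
    while (e < 15 && v >= uf8_offset(e + 1))
        e++; /* for v < 16 the loop never runs, so E(v) = v */

    uint32_t m = (v - uf8_offset(e)) >> e;
    if (m > 15)
        m = 15; /* assumption: clamp instead of overflowing into the exponent */
    return (uint8_t) ((e << 4) | m);
}
```

With these helpers, `uf8_decode_ref(0x10)` is 16 and `uf8_decode_ref(0x1F)` is 46, matching the e = 1 range listed above; inputs that fall between two decodable points are rounded down by the floor.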
### C code

```c=
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

typedef uint8_t uf8;

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Decode uf8 to uint32_t */
uint32_t uf8_decode(uf8 fl)
{
    uint32_t mantissa = fl & 0x0f;
    uint8_t exponent = fl >> 4;
    uint32_t offset = (0x7FFF >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Encode uint32_t to uf8 */
uf8 uf8_encode(uint32_t value)
{
    /* Use CLZ for fast exponent calculation */
    if (value < 16)
        return value;

    /* Find appropriate exponent using CLZ hint */
    int lz = clz(value);
    int msb = 31 - lz;

    /* Start from a good initial guess */
    uint8_t exponent = 0;
    uint32_t overflow = 0;

    if (msb >= 5) {
        /* Estimate exponent - the formula is empirical */
        exponent = msb - 4;
        if (exponent > 15)
            exponent = 15;

        /* Calculate overflow for estimated exponent */
        for (uint8_t e = 0; e < exponent; e++)
            overflow = (overflow << 1) + 16;

        /* Adjust if estimate was off */
        while (exponent > 0 && value < overflow) {
            overflow = (overflow - 16) >> 1;
            exponent--;
        }
    }

    /* Find exact exponent */
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }

    uint8_t mantissa = (value - overflow) >> exponent;
    return (exponent << 4) | mantissa;
}

/* Test encode/decode round-trip */
static bool test(void)
{
    int32_t previous_value = -1;
    bool passed = true;

    for (int i = 0; i < 256; i++) {
        uint8_t fl = i;
        int32_t value = uf8_decode(fl);
        uint8_t fl2 = uf8_encode(value);

        if (fl != fl2) {
            printf("%02x: produces value %d but encodes back to %02x\n", fl,
                   value, fl2);
            passed = false;
        }

        if (value <= previous_value) {
            printf("%02x: value %d <= previous_value %d\n", fl, value,
                   previous_value);
            passed = false;
        }

        previous_value = value;
    }

    return passed;
}

int main(void)
{
    if (test()) {
        printf("All tests passed.\n");
        return 0;
    }
    return 1;
}
```

### Assembly code

:::spoiler More detailed information
```assembly=
#=======================================================================================
# File   : uf8.s
# Author : Jimmy Chen
# Date   : 2025-10-08
# Brief  : Implements encoding and decoding for the custom 8-bit UF8 format.
#=======================================================================================
.data
# string
all_tests_passed_str:  .string "All tests passed.\n"
mismatch_prod_val_str: .string ": produces value "
mismatch_encode_str:   .string " but encodes back to "
not_incr_val_str:      .string ": value "
not_incr_prev_val_str: .string " <= previous_value "
endline_str:           .string "\n"

.text
.global main

# =======================================================
# Function    : main()
# Parameter   : none
# Variable    :
# Description : Executes test() and prints a pass message if it returns true.
# Return      : 0 on success, 1 on failure (exit program)
# =======================================================
main:
    jal ra, test                # call test
    beq a0, x0, main_return1    # return 1

    # printf "All tests passed.\n"
    la a0, all_tests_passed_str
    li a7, 4                    # System call number 4 (print string)
    ecall

    # return 0
    li a7, 10
    li a0, 0
    ecall

main_return1:
    # return 1
    li a7, 10
    addi a0, a0, 1
    ecall

# =======================================================
# Function    : test()
# Parameter   : none
# Variable    :
# Description : Performs round-trip tests on UF8 encoding and decoding to ensure correctness.
# Return      : 1 (true) or 0 (false)
# =======================================================
test:
    # save ra and every callee-saved register used below (s0-s5)
    addi sp, sp, -32
    sw ra, 28(sp)
    sw s0, 24(sp)
    sw s1, 20(sp)
    sw s2, 16(sp)
    sw s3, 12(sp)
    sw s4, 8(sp)
    sw s5, 4(sp)

    addi s0, x0, -1             # previous_value = -1
    addi s1, x0, 0              # i = 0
    addi s2, x0, 1              # passed = true

test_loop:
    # if i > 255, break
    li s3, 0xFF
    blt s3, s1, test_done

    # uint8_t fl = i
    andi s3, s1, 0xFF
    addi a0, s3, 0

    # value = uf8_decode(fl)
    jal ra, uf8_decode
    addi s4, a0, 0

    # fl2 = uf8_encode(value)
    addi a0, s4, 0
    jal ra, uf8_encode
    andi s5, a0, 0xFF

    # if (fl != fl2)
    bne s3, s5, mismatch

check_increasing:
    # if (value <= previous_value), checked even after a mismatch, as in the C code
    ble s4, s0, not_increasing
    j continue_loop

mismatch:
    # printf("%02x: produces value %d but encodes back to %02x\n", fl, value, fl2);
    addi a0, s3, 0              # fl
    li a7, 34                   # print as hex
    ecall
    la a0, mismatch_prod_val_str
    li a7, 4
    ecall
    addi a0, s4, 0              # value
    li a7, 1
    ecall
    la a0, mismatch_encode_str
    li a7, 4
    ecall
    addi a0, s5, 0              # fl2
    li a7, 34                   # print as hex
    ecall
    la a0, endline_str
    li a7, 4
    ecall
    add s2, x0, x0              # passed = false
    j check_increasing

not_increasing:
    # printf("%02x: value %d <= previous_value %d\n", fl, value, previous_value);
    addi a0, s3, 0              # fl
    li a7, 34                   # print as hex
    ecall
    la a0, not_incr_val_str
    li a7, 4
    ecall
    addi a0, s4, 0              # value
    li a7, 1
    ecall
    la a0, not_incr_prev_val_str
    li a7, 4
    ecall
    addi a0, s0, 0              # previous_value
    li a7, 1
    ecall
    la a0, endline_str
    li a7, 4
    ecall
    add s2, x0, x0              # passed = false

continue_loop:
    # previous_value = value
    addi s0, s4, 0
    # i++
    addi s1, s1, 1
    j test_loop

test_done:
    # return passed
    addi a0, s2, 0
    lw s5, 4(sp)
    lw s4, 8(sp)
    lw s3, 12(sp)
    lw s2, 16(sp)
    lw s1, 20(sp)
    lw s0, 24(sp)
    lw ra, 28(sp)
    addi sp, sp, 32
    jalr x0, x1, 0

# =======================================================
# Function    : uf8_decode()
# Parameter   : uf8 fl
# Variable    :
# Description : Decodes an 8-bit UF8 value into a 32-bit unsigned integer.
# Return      : uint32_t decoded value
# =======================================================
uf8_decode:
    # u32 mantissa = fl & 0x0F
    andi t0, a0, 0x0F
    # u8 exponent = fl >> 4
    srli t1, a0, 4
    andi t1, t1, 0xFF
    # u32 offset = (0x7FFF >> (15-exponent)) << 4
    li t2, 15
    li t3, 0x7FFF
    sub t2, t2, t1              # (15-exponent)
    srl t3, t3, t2              # (0x7FFF >> (15-exponent))
    slli t4, t3, 4              # (0x7FFF >> (15-exponent)) << 4
    # return (mantissa << exponent) + offset;
    sll a0, t0, t1
    add a0, a0, t4
    jalr x0, x1, 0

# =======================================================
# Function    : uf8_encode()
# Parameter   : uint32_t value
# Variable    :
# Description : Encodes a 32-bit unsigned integer into an 8-bit UF8 representation.
# Return      : uf8 encoded value
# =======================================================
uf8_encode:
    li t0, 16
    bltu a0, t0, uf8_encode_return_val  # value < 16 (unsigned, value is uint32_t)
    # assign t0 = value
    addi t0, a0, 0
    # call clz
    addi sp, sp, -8
    sw ra, 4(sp)
    sw t0, 0(sp)
    jal ra, clz
    lw t0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    # lz = clz(value)
    add t1, a0, x0
    # msb = 31 - lz
    li t2, 31
    sub t2, t2, t1
    # exponent = 0
    li t3, 0
    # overflow = 0
    li t4, 0
    # if msb < 5
    li t5, 5
    blt t2, t5, uf8_find_extra_exp

uf8_encode_bge5:
    # exponent = msb - 4
    addi t3, t2, -4
    andi t3, t3, 0xFF
    li t5, 15
    # e = 0
    li t6, 0
    # if 15 < exp
    blt t5, t3, uf8_exp_bg15
    j uf8_calculate_overflow

uf8_exp_bg15:
    # exp = 15
    li t3, 15

uf8_calculate_overflow:
    # e < exp
    bge t6, t3, uf8_adjust_if_off
    # (overflow << 1)+16
    slli t4, t4, 1
    addi t4, t4, 16
    # e++
    addi t6, t6, 1
    j uf8_calculate_overflow

uf8_adjust_if_off:
    # 0 >= exponent
    bge x0, t3, uf8_find_extra_exp
    # value >= overflow (unsigned)
    bgeu t0, t4, uf8_find_extra_exp
    # overflow = (overflow-16) >> 1
    addi t4, t4, -16
    srli t4, t4, 1
    # exp--
    addi t3, t3, -1
    j uf8_adjust_if_off

uf8_find_extra_exp:
    # while exp < 15
    li t5, 15
    # if exp >= 15, return value
    bge t3, t5, uf8_encode_return
    # next_overflow = (overflow << 1)+16
    slli t5, t4, 1
    addi t5, t5, 16
    # value < next_overflow (unsigned)
    bltu t0, t5, uf8_encode_return
    # overflow = next_overflow
    add t4, t5, x0
    # exp++
    addi t3, t3, 1
    j uf8_find_extra_exp

uf8_encode_return:
    # (value - overflow) >> exponent
    sub t5, t0, t4
    srl t5, t5, t3
    andi t5, t5, 0xFF
    # (exponent << 4)|mantissa
    slli t1, t3, 4
    andi a0, t1, 0xFF
    or a0, a0, t5
    jalr x0, x1, 0

uf8_encode_return_val:
    # return value(a0)
    jalr x0, x1, 0

# =======================================================
# Function    : clz()
# Parameter   : uint32_t x
# Variable    :
# Description : Counts the number of leading zeros in a 32-bit unsigned integer.
# Return      : unsigned val
# =======================================================
clz:
    li t0, 32                   # t0 = n
    li t1, 16                   # t1 = c
clz_whileLoop:
    srl t2, a0, t1              # y = t2, y = x >> c
    beq t2, x0, shift_right_1bit  # if y == 0, jump to shift_right_1bit
    sub t0, t0, t1              # n = n - c
    addi a0, t2, 0              # x = y
shift_right_1bit:
    srli t1, t1, 1              # c = c >> 1
    bne t1, x0, clz_whileLoop   # if c != 0, jump to clz_whileLoop
    sub a0, t0, a0              # x = n - x
    jalr x0, x1, 0              # return x
```
:::

### Result

![image](https://hackmd.io/_uploads/r1vnsJ4aex.png)

----

## Problem C

### Float

| sign  | exponent | mantissa |
| ----- | -------- | -------- |
| 1 bit | 8 bits   | 23 bits  |

### BF16

| sign  | exponent | mantissa |
| ----- | -------- | -------- |
| 1 bit | 8 bits   | 7 bits   |

* Normalized: $±(1.\text{mantissa}) × 2^{\text{exponent}-127}$
* Denormalized (subnormal): $±(0.\text{mantissa}) × 2^{-126}$
* NaN: the exponent bits are all 1 and the mantissa bits are not all 0.
* INF: the exponent bits are all 1 and the mantissa bits are all 0.
* ZERO: the exponent bits are all 0 and the mantissa bits are all 0.
* F32toB16:
    * The line `f32bits += ((f32bits >> 16) & 1) + 0x7FFF` implements round to nearest, ties to even. Adding 0x7FFF carries into bit 16 whenever the 16 discarded bits exceed 0x8000 (more than half of the last kept place). Bit 16, the lowest bit kept in the BF16 result, decides exact ties (discarded bits equal to 0x8000): if it is 1, the extra 1 makes the sum carry, so the result rounds up to an even value; if it is 0, the discarded bits are simply dropped and the result is already even.
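The rounding rule is easiest to check on raw bit patterns. The C sketch below isolates the same expression on a `uint32_t` so that individual tie cases can be fed in directly. `bf16_round_bits` and `bf16_round_float` are hypothetical helper names; they mirror the rounding step of `f32_to_bf16` (including truncation of NaN and Inf) rather than replace it.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: the f32 -> bf16 step of f32_to_bf16, on raw bits. */
static uint16_t bf16_round_bits(uint32_t f32bits)
{
    if (((f32bits >> 23) & 0xFF) == 0xFF) /* Inf or NaN: plain truncation */
        return (uint16_t) (f32bits >> 16);

    uint32_t lsb = (f32bits >> 16) & 1; /* lowest bit kept in the bf16 result */
    /* Carries into bit 16 when the dropped half is > 0x8000,
     * or exactly 0x8000 and lsb is 1 (ties to even). */
    f32bits += 0x7FFF + lsb;
    return (uint16_t) (f32bits >> 16);
}

/* Same rounding applied to a float value. */
static uint16_t bf16_round_float(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof bits);
    return bf16_round_bits(bits);
}
```

For example, 0x3F808000 is an exact tie whose kept part 0x3F80 is already even, so it stays 0x3F80; the tie 0x3F818000 has an odd kept part and rounds up to 0x3F82; and 0x3F808001 lies just above the halfway point, so it rounds to 0x3F81.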
### C code

```c=
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

typedef struct {
    uint16_t bits;
} bf16_t;

#define BF16_SIGN_MASK 0x8000U
#define BF16_EXP_MASK 0x7F80U
#define BF16_MANT_MASK 0x007FU
#define BF16_EXP_BIAS 127

#define BF16_NAN() ((bf16_t) {.bits = 0x7FC0})
#define BF16_ZERO() ((bf16_t) {.bits = 0x0000})

static inline bool bf16_isnan(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           (a.bits & BF16_MANT_MASK);
}

static inline bool bf16_isinf(bf16_t a)
{
    return ((a.bits & BF16_EXP_MASK) == BF16_EXP_MASK) &&
           !(a.bits & BF16_MANT_MASK);
}

static inline bool bf16_iszero(bf16_t a)
{
    return !(a.bits & 0x7FFF);
}

static inline bf16_t f32_to_bf16(float val)
{
    uint32_t f32bits;
    memcpy(&f32bits, &val, sizeof(float));
    if (((f32bits >> 23) & 0xFF) == 0xFF)
        return (bf16_t) {.bits = (f32bits >> 16) & 0xFFFF};
    f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    return (bf16_t) {.bits = f32bits >> 16};
}

static inline float bf16_to_f32(bf16_t val)
{
    uint32_t f32bits = ((uint32_t) val.bits) << 16;
    float result;
    memcpy(&result, &f32bits, sizeof(float));
    return result;
}

static inline bf16_t bf16_add(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (exp_b == 0xFF)
            return (mant_b || sign_a == sign_b) ? b : BF16_NAN();
        return a;
    }
    if (exp_b == 0xFF)
        return b;
    if (!exp_a && !mant_a)
        return b;
    if (!exp_b && !mant_b)
        return a;
    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    int16_t exp_diff = exp_a - exp_b;
    uint16_t result_sign;
    int16_t result_exp;
    uint32_t result_mant;

    if (exp_diff > 0) {
        result_exp = exp_a;
        if (exp_diff > 8)
            return a;
        mant_b >>= exp_diff;
    } else if (exp_diff < 0) {
        result_exp = exp_b;
        if (exp_diff < -8)
            return b;
        mant_a >>= -exp_diff;
    } else {
        result_exp = exp_a;
    }

    if (sign_a == sign_b) {
        result_sign = sign_a;
        result_mant = (uint32_t) mant_a + mant_b;

        if (result_mant & 0x100) {
            result_mant >>= 1;
            if (++result_exp >= 0xFF)
                return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
        }
    } else {
        if (mant_a >= mant_b) {
            result_sign = sign_a;
            result_mant = mant_a - mant_b;
        } else {
            result_sign = sign_b;
            result_mant = mant_b - mant_a;
        }

        if (!result_mant)
            return BF16_ZERO();
        while (!(result_mant & 0x80)) {
            result_mant <<= 1;
            if (--result_exp <= 0)
                return BF16_ZERO();
        }
    }

    return (bf16_t) {
        .bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                (result_mant & 0x7F),
    };
}

static inline bf16_t bf16_sub(bf16_t a, bf16_t b)
{
    b.bits ^= BF16_SIGN_MASK;
    return bf16_add(a, b);
}

static inline bf16_t bf16_mul(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        if (!exp_b && !mant_b)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
        return (bf16_t) {.bits = result_sign << 15};

    int16_t exp_adjust = 0;
    if (!exp_a) {
        while (!(mant_a & 0x80)) {
            mant_a <<= 1;
            exp_adjust--;
        }
        exp_a = 1;
    } else
        mant_a |= 0x80;
    if (!exp_b) {
        while (!(mant_b & 0x80)) {
            mant_b <<= 1;
            exp_adjust--;
        }
        exp_b = 1;
    } else
        mant_b |= 0x80;

    uint32_t result_mant = (uint32_t) mant_a * mant_b;
    int32_t result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust;

    if (result_mant & 0x8000) {
        result_mant = (result_mant >> 8) & 0x7F;
        result_exp++;
    } else
        result_mant = (result_mant >> 7) & 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0) {
        if (result_exp < -6)
            return (bf16_t) {.bits = result_sign << 15};
        result_mant >>= (1 - result_exp);
        result_exp = 0;
    }

    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (result_mant & 0x7F)};
}

static inline bf16_t bf16_div(bf16_t a, bf16_t b)
{
    uint16_t sign_a = (a.bits >> 15) & 1;
    uint16_t sign_b = (b.bits >> 15) & 1;
    int16_t exp_a = ((a.bits >> 7) & 0xFF);
    int16_t exp_b = ((b.bits >> 7) & 0xFF);
    uint16_t mant_a = a.bits & 0x7F;
    uint16_t mant_b = b.bits & 0x7F;

    uint16_t result_sign = sign_a ^ sign_b;

    if (exp_b == 0xFF) {
        if (mant_b)
            return b;
        /* Inf/Inf = NaN */
        if (exp_a == 0xFF && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = result_sign << 15};
    }
    if (!exp_b && !mant_b) {
        if (!exp_a && !mant_a)
            return BF16_NAN();
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (exp_a == 0xFF) {
        if (mant_a)
            return a;
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    }
    if (!exp_a && !mant_a)
        return (bf16_t) {.bits = result_sign << 15};

    if (exp_a)
        mant_a |= 0x80;
    if (exp_b)
        mant_b |= 0x80;

    uint32_t dividend = (uint32_t) mant_a << 15;
    uint32_t divisor = mant_b;
    uint32_t quotient = 0;

    for (int i = 0; i < 16; i++) {
        quotient <<= 1;
        if (dividend >= (divisor << (15 - i))) {
            dividend -= (divisor << (15 - i));
            quotient |= 1;
        }
    }

    int32_t result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS;

    if (!exp_a)
        result_exp--;
    if (!exp_b)
        result_exp++;

    if (quotient & 0x8000)
        quotient >>= 8;
    else {
        while (!(quotient & 0x8000) && result_exp > 1) {
            quotient <<= 1;
            result_exp--;
        }
        quotient >>= 8;
    }
    quotient &= 0x7F;

    if (result_exp >= 0xFF)
        return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    if (result_exp <= 0)
        return (bf16_t) {.bits = result_sign << 15};
    return (bf16_t) {.bits = (result_sign << 15) | ((result_exp & 0xFF) << 7) |
                             (quotient & 0x7F)};
}

static inline bf16_t bf16_sqrt(bf16_t a)
{
    uint16_t sign = (a.bits >> 15) & 1;
    int16_t exp = ((a.bits >> 7) & 0xFF);
    uint16_t mant = a.bits & 0x7F;

    /* Handle special cases */
    if (exp == 0xFF) {
        if (mant)
            return a; /* NaN propagation */
        if (sign)
            return BF16_NAN(); /* sqrt(-Inf) = NaN */
        return a;              /* sqrt(+Inf) = +Inf */
    }

    /* sqrt(0) = 0 (handle both +0 and -0) */
    if (!exp && !mant)
        return BF16_ZERO();

    /* sqrt of negative number is NaN */
    if (sign)
        return BF16_NAN();

    /* Flush denormals to zero */
    if (!exp)
        return BF16_ZERO();

    /* Direct bit manipulation square root algorithm */
    /* For sqrt: new_exp = (old_exp - bias) / 2 + bias */
    int32_t e = exp - BF16_EXP_BIAS;
    int32_t new_exp;

    /* Get full mantissa with implicit 1 */
    uint32_t m = 0x80 | mant; /* Range [128, 256) representing [1.0, 2.0) */

    /* Adjust for odd exponents: sqrt(2^odd * m) = 2^((odd-1)/2) * sqrt(2*m) */
    if (e & 1) {
        m <<= 1; /* Double mantissa for odd exponent */
        new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS;
    } else {
        new_exp = (e >> 1) + BF16_EXP_BIAS;
    }

    /* Now m is in range [128, 256) or [256, 512) if exponent was odd */
    /* Binary search for integer square root */
    /* We want result where result^2 = m * 128 (since 128 represents 1.0) */
    uint32_t low = 90;     /* Min sqrt (roughly sqrt(128)) */
    uint32_t high = 256;   /* Max sqrt (roughly sqrt(512)) */
    uint32_t result = 128; /* Default */

    /* Binary search for square root of m */
    while (low <= high) {
        uint32_t mid = (low + high) >> 1;
        uint32_t sq = (mid * mid) / 128; /* Square and scale */
        if (sq <= m) {
            result = mid; /* This could be our answer */
            low = mid + 1;
        } else {
            high = mid - 1;
        }
    }

    /* result now contains sqrt(m) * sqrt(128) / sqrt(128) = sqrt(m) */
    /* But we need to adjust the scale */
    /* Since m is scaled where 128=1.0, result should also be scaled same way */
    /* Normalize to ensure result is in [128, 256) */
    if (result >= 256) {
        result >>= 1;
        new_exp++;
    } else if (result < 128) {
        while (result < 128 && new_exp > 1) {
            result <<= 1;
            new_exp--;
        }
    }

    /* Extract 7-bit mantissa (remove implicit 1) */
    uint16_t new_mant = result & 0x7F;

    /* Check for overflow/underflow */
    if (new_exp >= 0xFF)
        return (bf16_t) {.bits = 0x7F80}; /* +Inf */
    if (new_exp <= 0)
        return BF16_ZERO();

    return (bf16_t) {.bits = ((new_exp & 0xFF) << 7) | new_mant};
}
```

### Assembly code

:::spoiler More detailed information
```assembly=
#=======================================================================================
# File   : bfloat16.s
# Author : Jimmy Chen
# Date   : 2025-10-08
# Brief  : Implementation of bfloat16 arithmetic operations (add, sub, mul, div, sqrt)
#          including conversion between float and bfloat16, with IEEE 754 support.
#=======================================================================================
.data
# define
.equ BF16_SIGN_MASK, 0x8000
.equ BF16_EXP_MASK,  0x7F80
.equ BF16_MANT_MASK, 0x007F
.equ BF16_EXP_BIAS,  127
.equ BF16_NAN,       0x7FC0
.equ BF16_ZERO,      0x0000

# input data
input_a:     .half 0x0000, 0x8000, 0x3f80, 0x4000, 0x4040, 0xbf80, 0x7f80, 0xff80, 0x7fc1, 0x4110, 0xc080, 0x0001, 0x0a4b, 0x40a0, 0x7f00
input_b:     .half 0x0000, 0x0000, 0x4000, 0x3f80, 0x4000, 0x3f80, 0x3f80, 0x7f80, 0x3f80, 0x0000, 0x0000, 0x0001, 0x0a4b, 0x0000, 0x7f00
input_float: .word 0x3F800000, 0x3FC00000, 0x3F81AE14, 0x00000000, 0x80000000, 0x477FE000, 0x7F800000, 0xFF800000, 0x7FC00000, 0x00000001, 0xC0200000, 0x40490FDB, 0xC2F6E979, 0x3EAAAAAB, 0x2F06C6D6

# ans
isNanAns:    .half 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
isInfAns:    .half 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
isZeroAns:   .half 0x0001, 0x0001, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000
f32tob16Ans: .half 0x3f80, 0x3fc0, 0x3f82, 0x0000, 0x8000, 0x4780, 0x7f80, 0xff80, 0x7fc0, 0x0000, 0xc020, 0x4049, 0xc2f7, 0x3eab, 0x2f07
b16tof32Ans: .word 0x00000000, 0x80000000, 0x3f800000, 0x40000000, 0x40400000, 0xbf800000, 0x7f800000, 0xff800000, 0x7fc10000, 0x41100000, 0xc0800000, 0x00010000, 0x0a4b0000, 0x40a00000, 0x7f000000
addAns:      .half 0x0000, 0x0000, 0x4040, 0x4040, 0x40a0, 0x0000, 0x7f80, 0x7fc0, 0x7fc1, 0x4110, 0xc080, 0x0002, 0x0acb, 0x40a0, 0x7f80
subAns:      .half 0x8000, 0x8000, 0xbf80, 0x3f80, 0x3f80, 0xc000, 0x7f80, 0xff80, 0x7fc1, 0x4110, 0xc080, 0x0000, 0x0000, 0x40a0, 0x0000
mulAns:      .half 0x0000, 0x8000, 0x4000, 0x4000, 0x40c0, 0xbf80, 0x7f80, 0xff80, 0x7fc1, 0x0000, 0x8000, 0x0000, 0x0000, 0x0000, 0x7f80
divAns:      .half 0x7fc0, 0x7fc0, 0x3f00, 0x4000, 0x3fc0, 0xbf80, 0x7f80, 0x7fc0, 0x7fc1, 0x7f80, 0xff80, 0x3f80, 0x3f80, 0x7f80, 0x3f80
sqrtAns:     .half 0x0000, 0x0000, 0x3f80, 0x3fb5, 0x3fdd, 0x7fc0, 0x7f80, 0x7fc0, 0x7fc1, 0x4040, 0x7fc0, 0x0000, 0x24e4, 0x400f, 0x5f35

# string
msg_test_nan:      .string "======Test NAN======\n"
msg_test_inf:      .string "======Test INF======\n"
msg_test_zero:     .string "======Test ZERO======\n"
msg_test_f32tob16: .string "======Test F32ToB16======\n"
msg_test_b16tof32: .string "======Test B16ToF32======\n"
msg_test_add:      .string "======Test ADD======\n"
msg_test_sub:      .string "======Test SUB======\n"
msg_test_mul:      .string "======Test MUL======\n"
msg_test_div:      .string "======Test DIV======\n"
msg_test_sqrt:     .string "======Test SQRT======\n"
pass_str:          .string " => Pass\n"
fail_str:          .string " => Fail\n"
output_str:        .string "Output = "
answer_str:        .string ",Answer = "

.text
.global main

# =======================================================
# Function    : main()
# Parameter   : none
# Variable    : i = s0, boundary = 15
# Description : Executes all functions and prints the results.
# Return      : 0 (exit program)
# =======================================================
main:
    # i = 0, boundary = 15
    add s0, x0, x0
    j main_for

main_for:
    addi t0, x0, 15
    # if i >= 15, return exit
    bge s0, t0, main_exit
    j main_for_run_nan

#======================nan=====================
main_for_run_nan:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i] to bf16_isnan()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lhu a0, 0(t1)               # halfword load: only a[i], always aligned
    # call bf16_isnan
    jal ra, bf16_isnan
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_nan

main_for_printf_nan:
    # print test nan
    la a0, msg_test_nan
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, isNanAns
    slli t0, s0, 1
    add t0, s3, t0
    lhu a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_nan_pass
    j main_for_nan_fail

main_for_nan_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_inf

main_for_nan_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_inf

#======================inf=====================
main_for_run_inf:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i] to bf16_isinf()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lhu a0, 0(t1)               # halfword load: only a[i], always aligned
    # call bf16_isinf
    jal ra, bf16_isinf
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_inf

main_for_printf_inf:
    # print test inf
    la a0, msg_test_inf
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, isInfAns
    slli t0, s0, 1
    add t0, s3, t0
    lhu a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_inf_pass
    j main_for_inf_fail

main_for_inf_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_zero

main_for_inf_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_zero

#======================zero=====================
main_for_run_zero:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i] to bf16_iszero()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lhu a0, 0(t1)               # halfword load: only a[i], always aligned
    # call bf16_iszero
    jal ra, bf16_iszero
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_zero

main_for_printf_zero:
    # print test zero
    la a0, msg_test_zero
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, isZeroAns
    slli t0, s0, 1
    add t0, s3, t0
    lhu a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_zero_pass
    j main_for_zero_fail

main_for_zero_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_f32tob16

main_for_zero_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_f32tob16

#======================f32tob16=====================
main_for_run_f32tob16:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load input_float[i] to f32_to_bf16()
    slli t0, s0, 2
    la t1, input_float
    add t1, t1, t0
    lw a0, 0(t1)
    # call f32_to_bf16
    jal ra, f32_to_bf16
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_f32tob16

main_for_printf_f32tob16:
    # print test f32tob16
    la a0, msg_test_f32tob16
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, f32tob16Ans
    slli t0, s0, 1
    add t0, s3, t0
    lhu a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_f32tob16_pass
    j main_for_f32tob16_fail
main_for_f32tob16_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_b16tof32
main_for_f32tob16_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_b16tof32
#======================b16tof32=====================
main_for_run_b16tof32:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i] to bf16_to_f32()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    # call bf16_to_f32
    jal ra, bf16_to_f32
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_b16tof32
main_for_printf_b16tof32:
    # print test b16tof32
    la a0, msg_test_b16tof32
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    add a0, s3, x0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, b16tof32Ans
    slli t0, s0, 2
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    ecall
    beq t5, t6, main_for_b16tof32_pass
    j main_for_b16tof32_fail
main_for_b16tof32_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_add
main_for_b16tof32_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_add
#======================add=====================
main_for_run_add:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i], b[i] to bf16_add()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    la t2, input_b
    add t2, t2, t0
    lw a1, 0(t2)
    # call bf16_add
    jal ra, bf16_add
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_add
main_for_printf_add:
    # print test add
    la a0, msg_test_add
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, addAns
    slli t0, s0, 1
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_add_pass
    j main_for_add_fail
main_for_add_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_sub
main_for_add_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_sub
#======================sub=====================
main_for_run_sub:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i], b[i] to bf16_sub()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    la t2, input_b
    add t2, t2, t0
    lw a1, 0(t2)
    # call bf16_sub
    jal ra, bf16_sub
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_sub
main_for_printf_sub:
    # print test sub
    la a0, msg_test_sub
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, subAns
    slli t0, s0, 1
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_sub_pass
    j main_for_sub_fail
main_for_sub_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_mul
main_for_sub_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_mul
#======================mul=====================
main_for_run_mul:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i], b[i] to bf16_mul()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    la t2, input_b
    add t2, t2, t0
    lw a1, 0(t2)
    # call bf16_mul
    jal ra, bf16_mul
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_mul
main_for_printf_mul:
    # print test mul
    la a0, msg_test_mul
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, mulAns
    slli t0, s0, 1
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_mul_pass
    j main_for_mul_fail
main_for_mul_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_div
main_for_mul_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_div
#======================div=====================
main_for_run_div:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i], b[i] to bf16_div()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    la t2, input_b
    add t2, t2, t0
    lw a1, 0(t2)
    # call bf16_div
    jal ra, bf16_div
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_div
main_for_printf_div:
    # print test div
    la a0, msg_test_div
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, divAns
    slli t0, s0, 1
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_div_pass
    j main_for_div_fail
main_for_div_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    j main_for_run_sqrt
main_for_div_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    j main_for_run_sqrt
#======================sqrt=====================
main_for_run_sqrt:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw s0, 0(sp)
    # load a[i] to bf16_sqrt()
    slli t0, s0, 1
    la t1, input_a
    add t1, t1, t0
    lw a0, 0(t1)
    # call bf16_sqrt
    jal ra, bf16_sqrt
    # value of return stores in s3
    add s3, a0, x0
    lw s0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    j main_for_printf_sqrt
main_for_printf_sqrt:
    # print test sqrt
    la a0, msg_test_sqrt
    li a7, 4
    ecall
    # print output string
    la a0, output_str
    ecall
    # print output value
    li t0, 0xFFFF
    and a0, s3, t0
    li a7, 34
    # for compare
    add t5, a0, x0
    ecall
    # print answer string
    la a0, answer_str
    li a7, 4
    ecall
    # print ans value
    la s3, sqrtAns
    slli t0, s0, 1
    add t0, s3, t0
    lw a0, 0(t0)
    li a7, 34
    # for compare
    add t6, a0, x0
    li t4, 0xffff
    and a0, t6, t4
    and t6, t6, t4
    ecall
    beq t5, t6, main_for_sqrt_pass
    j main_for_sqrt_fail
main_for_sqrt_pass:
    # print pass
    la a0, pass_str
    li a7, 4
    ecall
    # ++i
    addi s0, s0, 1
    j main_for
main_for_sqrt_fail:
    # print fail
    la a0, fail_str
    li a7, 4
    ecall
    # ++i
    addi s0, s0, 1
    j main_for
#======================main exit=====================
main_exit:
    li a7, 10
    ecall
# =======================================================
# Function : bf16_isnan()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if a is NaN; otherwise, returns false
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_isnan:
    # t1 = (a.bits & BF16_EXP_MASK)
    li t0, BF16_EXP_MASK
    and t1, a0, t0
    # if (a.bits & BF16_EXP_MASK) != BF16_EXP_MASK, return 0
    bne t1, t0, bf16_isnan_ret0
    # if (a.bits & BF16_MANT_MASK) == 0, return 0 (Inf, not NaN)
    li t0, BF16_MANT_MASK
    and t1, a0, t0
    beq t1, x0, bf16_isnan_ret0
    # return 1
    addi a0, x0, 1
    jalr x0, ra, 0
bf16_isnan_ret0:
    # return 0
    add a0, x0, x0
    jalr x0, ra, 0
# =======================================================
# Function : bf16_isinf()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if the input is +Infinity or -Infinity.
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_isinf:
    # t1 = (a.bits & BF16_EXP_MASK)
    li t0, BF16_EXP_MASK
    and t1, a0, t0
    # if (a.bits & BF16_EXP_MASK) != BF16_EXP_MASK, return 0
    bne t1, t0, bf16_isinf_ret0
    # if (a.bits & BF16_MANT_MASK) != 0, return 0 (NaN, not Inf)
    li t0, BF16_MANT_MASK
    and t1, a0, t0
    bne t1, x0, bf16_isinf_ret0
    # return 1
    addi a0, x0, 1
    jalr x0, ra, 0
bf16_isinf_ret0:
    # return 0
    add a0, x0, x0
    jalr x0, ra, 0
# =======================================================
# Function : bf16_iszero()
# Parameter : bf16_t a
# Variable :
# Description : Returns true if the input is positive or negative zero.
# Return : 1(true) or 0(false)
# =======================================================
# test ok
bf16_iszero:
    # t1 = (a.bits & 0x7FFF)
    li t0, 0x7FFF
    and t1, a0, t0
    bne t1, x0, bf16_iszero_ret0
    # return 1
    addi a0, x0, 1
    jalr x0, ra, 0
bf16_iszero_ret0:
    # return 0
    add a0, x0, x0
    jalr x0, ra, 0
# =======================================================
# Function : f32_to_bf16()
# Parameter : float val
# Variable :
# Description : Converts a 32-bit float to bfloat16 by rounding to nearest even and keeping the upper 16 bits (Inf and NaN are truncated without rounding).
# Return : bf16_t value(a0)
# =======================================================
# test ok
f32_to_bf16:
    # u32 t0 = f32bits
    # memcpy(&f32bits, &val, sizeof(float));
    add t0, a0, x0
    # if (((f32bits >> 23) & 0xFF) == 0xFF)
    srli t1, t0, 23
    andi t2, t1, 0xFF
    li t3, 0xFF
    bne t2, t3, f32_to_bf16_ret_exp_not_allOne
    # exponent all ones (Inf/NaN): return (f32bits >> 16) & 0xFFFF
    srli t1, t0, 16
    li t2, 0xFFFF
    and a0, t1, t2
    jalr x0, ra, 0
f32_to_bf16_ret_exp_not_allOne:
    # exponent not all ones: round to nearest even
    # f32bits += ((f32bits >> 16) & 1) + 0x7FFF;
    srli t1, t0, 16
    andi t2, t1, 1
    li t3, 0x7FFF
    add t4, t2, t3
    add t0, t0, t4
    # return (bf16_t) (f32bits >> 16)
    srli a0, t0, 16
    li t1, 0xFFFF
    and a0, a0, t1
    jalr x0, ra, 0
# =======================================================
# Function : bf16_to_f32()
# Parameter : bf16_t a
# Variable :
# Description : Converts a bfloat16 value to a 32-bit float by placing its bits in the upper 16 bits (the lower 16 bits become zero).
# Return : float value(a0)
# =======================================================
# test ok
bf16_to_f32:
    # u32 f32bits = ((u32) val.bits) << 16;
    slli t0, a0, 16
    # memcpy(&result, &f32bits, sizeof(float))
    add a0, t0, x0
    # return result
    jalr x0, ra, 0
# =======================================================
# Function : bf16_add()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 addition with proper handling of special cases (NaN, Inf, zero).
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_add:
    addi sp, sp, -24
    sw s6, 20(sp)
    sw s5, 16(sp)
    sw s4, 12(sp)
    sw s3, 8(sp)
    sw s2, 4(sp)
    sw s1, 0(sp)
    # sign_a = (a.bits >> 15) & 1
    srli s1, a0, 15
    andi s1, s1, 1
    # sign_b = (b.bits >> 15) & 1;
    srli s2, a1, 15
    andi s2, s2, 1
    # exp_a = ((a.bits >> 7) & 0xFF)
    srli s3, a0, 7
    andi s3, s3, 0xFF
    # exp_b = ((b.bits >> 7) & 0xFF)
    srli s4, a1, 7
    andi s4, s4, 0xFF
    # mant_a = a.bits & 0x7F
    andi s5, a0, 0x7F
    # mant_b = b.bits & 0x7F
    andi s6, a1, 0x7F
    # if exp_a == 0xFF
    li t0, 0xFF
    beq s3, t0, bf16_add_exp_a_allOne
    # if exp_b == 0xFF
    beq s4, t0, bf16_add_ret_b
    # if (!exp_a && !mant_a) <=> exp and mant = 0
    or t6, s3, s5
    beq t6, x0, bf16_add_ret_b
    # if (!exp_b && !mant_b) <=> exp and mant = 0
    or t6, s4, s6
    beq t6, x0, bf16_add_ret_a
    # if (exp_a)
    bne s3, x0, bf16_add_mant_a_or0x80
    # if (exp_b)
    bne s4, x0, bf16_add_mant_b_or0x80
    j bf16_add_dif
bf16_add_mant_a_or0x80:
    ori s5, s5, 0x80
bf16_add_exp_b_not0:
    # if (exp_b)
    bne s4, x0, bf16_add_mant_b_or0x80
    j bf16_add_dif
bf16_add_mant_b_or0x80:
    ori s6, s6, 0x80
bf16_add_dif:
    # exp_diff = exp_a - exp_b;
    sub t0, s3, s4
    # if (exp_diff > 0)
    blt x0, t0, bf16_add_exp_dif_bgt0
    # if (exp_diff < 0)
    blt t0, x0, bf16_add_exp_dif_blt0
    # if (exp_diff == 0)
    beq x0, t0, bf16_add_exp_dif_beq0
    # impossible
    j bf16_add_check_sign
bf16_add_exp_dif_bgt0:
    # result_exp = exp_a
    add t2, s3, x0
    # if (exp_diff > 8)
    li t6, 8
    blt t6, t0, bf16_add_ret_a
    # mant_b >>= exp_diff
    srl s6, s6, t0
    # jump to if (sign_a == sign_b)
    j bf16_add_check_sign
bf16_add_exp_dif_blt0:
    # result_exp = exp_b
    add t2, s4, x0
    # if (exp_diff < -8), return b
    li t6, -8
    blt t0, t6, bf16_add_ret_b
    # mant_a >>= -exp_diff
    sub t6, x0, t0
    srl s5, s5, t6
    # jump to if (sign_a == sign_b)
    j bf16_add_check_sign
bf16_add_exp_dif_beq0:
    # result_exp = exp_a
    add t2, s3, x0
bf16_add_check_sign:
    # if (sign_a == sign_b), jump to bf16_add_check_sign_eq
    beq s1, s2, bf16_add_check_sign_eq
    # else
    # if (mant_a >= mant_b), jump to bf16_add_check_mant_gn
    bge s5, s6, bf16_add_check_mant_gn
    # else (mant_a < mant_b)
    # result_sign = sign_b
    add t1, s2, x0
    # result_mant = mant_b - mant_a
    sub t3, s6, s5
    # jump bf16_add_check_result_mant
    j bf16_add_check_result_mant
bf16_add_check_mant_gn:
    # result_sign = sign_a
    add t1, s1, x0
    # result_mant = mant_a - mant_b
    sub t3, s5, s6
bf16_add_check_result_mant:
    # if (!result_mant)
    beq t3, x0, bf16_add_ret0
bf16_add_check_result_mant_while:
    # while (!(result_mant & 0x80))
    andi t5, t3, 0x80
    bne t5, x0, bf16_add_ret
    # result_mant <<= 1
    slli t3, t3, 1
    # if (--result_exp <= 0)
    addi t2, t2, -1
    bge x0, t2, bf16_add_ret0
    j bf16_add_check_result_mant_while
bf16_add_check_sign_eq:
    # result_sign = sign_a
    add t1, s1, x0
    # result_mant = (uint32_t) mant_a + mant_b
    add t3, s5, s6
    # if (result_mant & 0x100) == 0, jump to return
    andi t6, t3, 0x100
    beq t6, x0, bf16_add_ret
    # result_mant >>= 1
    srli t3, t3, 1
    # if (++result_exp >= 0xFF)
    addi t2, t2, 1
    # if result_exp < 0xFF, jump to return
    li t6, 0xFF
    blt t2, t6, bf16_add_ret
    # else overflow: return ((result_sign << 15) | 0x7F80)
    li t6, 0x7F80
    slli t1, t1, 15
    or a0, t1, t6
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_add_exp_a_allOne:
    # if mant_a != 0
    bne s5, x0, bf16_add_ret_a
    # if (exp_b == 0xFF)
    beq s4, t0, bf16_add_exp_b_allOne
    j bf16_add_ret_a
bf16_add_exp_b_allOne:
    # return (mant_b || sign_a == sign_b) ? b : BF16_NAN()
    # t0 = (sign_a == sign_b)
    xor t0, s1, s2
    sltiu t0, t0, 1
    # (mant_b || sign_a == sign_b)
    or t1, s6, t0
    # if true, return b; otherwise return BF16_NAN
    bne t1, x0, bf16_add_ret_b
    # return BF16_NAN
    li a0, BF16_NAN
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_add_ret_a:
    # return a
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_add_ret_b:
    # return b
    add a0, a1, x0
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_add_ret0:
    li a0, BF16_ZERO
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_add_ret:
    # (result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F)
    # (result_sign << 15)
    slli a0, t1, 15
    # ((result_exp & 0xFF) << 7)
    andi t5, t2, 0xFF
    slli t5, t5, 7
    # (result_mant & 0x7F)
    andi t6, t3, 0x7F
    or a0, a0, t5
    or a0, a0, t6
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
# =======================================================
# Function : bf16_sub()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 subtraction by flipping the sign of the second operand and adding.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_sub:
    addi sp, sp, -4
    sw ra, 0(sp)
    # b.bits ^= BF16_SIGN_MASK
    li t0, BF16_SIGN_MASK
    xor a1, a1, t0
    # call bf16_add
    jal ra, bf16_add
    lw ra, 0(sp)
    addi sp, sp, 4
    jalr x0, ra, 0
# =======================================================
# Function : bf16_mul()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 multiplication with normalization and special case handling.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_mul:
    addi sp, sp, -24
    sw s6, 20(sp)
    sw s5, 16(sp)
    sw s4, 12(sp)
    sw s3, 8(sp)
    sw s2, 4(sp)
    sw s1, 0(sp)
    # sign_a = (a.bits >> 15) & 1
    srli s1, a0, 15
    andi s1, s1, 1
    # sign_b = (b.bits >> 15) & 1;
    srli s2, a1, 15
    andi s2, s2, 1
    # exp_a = ((a.bits >> 7) & 0xFF)
    srli s3, a0, 7
    andi s3, s3, 0xFF
    # exp_b = ((b.bits >> 7) & 0xFF)
    srli s4, a1, 7
    andi s4, s4, 0xFF
    # mant_a = a.bits & 0x7F
    andi s5, a0, 0x7F
    # mant_b = b.bits & 0x7F
    andi s6, a1, 0x7F
    # result_sign = sign_a ^ sign_b
    xor t1, s1, s2
    # if (exp_a == 0xFF), exp_a all one, jump to bf16_mul_a_exp_allOne
    li t6, 0xFF
    beq s3, t6, bf16_mul_a_exp_allOne
    # if (exp_b == 0xFF), exp_b all one, jump to bf16_mul_b_exp_allOne
    beq s4, t6, bf16_mul_b_exp_allOne
    # if ((!exp_a && !mant_a) || (!exp_b && !mant_b))
    or t5, s3, s5
    or t6, s4, s6
    beq t5, x0, bf16_mul_retSign_slli15
    beq t6, x0, bf16_mul_retSign_slli15
    # exp_adjust = 0
    add t4, x0, x0
    # if (!exp_a), a exp is zero, jump bf16_mul_a_exp_zero
    beq s3, x0, bf16_mul_a_exp_zero
    # mant_a |= 0x80
    ori s5, s5, 0x80
    # if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
    beq s4, x0, bf16_mul_b_exp_zero
    # mant_b |= 0x80
    ori s6, s6, 0x80
    j bf16_mul_result_exp_mant
bf16_mul_a_exp_zero:
    # (mant_a & 0x80)
    andi t6, s5, 0x80
    # while (!(mant_a & 0x80))
    beq t6, x0, bf16_mul_a_exp_zero_while
    # exp_a = 1
    addi s3, x0, 1
    # if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
    beq s4, x0, bf16_mul_b_exp_zero
    # mant_b |= 0x80
    ori s6, s6, 0x80
    j bf16_mul_result_exp_mant
bf16_mul_a_exp_zero_while:
    # mant_a <<= 1
    slli s5, s5, 1
    # exp_adjust--
    addi t4, t4, -1
    # (mant_a & 0x80)
    andi t6, s5, 0x80
    # while (!(mant_a & 0x80))
    beq t6, x0, bf16_mul_a_exp_zero_while
    # exp_a = 1
    addi s3, x0, 1
    # if (!exp_b), b exp is zero, jump bf16_mul_b_exp_zero
    beq s4, x0, bf16_mul_b_exp_zero
    # mant_b |= 0x80
    ori s6, s6, 0x80
    j bf16_mul_result_exp_mant
bf16_mul_b_exp_zero:
    # (mant_b & 0x80)
    andi t6, s6, 0x80
    # while (!(mant_b & 0x80))
    beq t6, x0, bf16_mul_b_exp_zero_while
    # exp_b = 1
    addi s4, x0, 1
    j bf16_mul_result_exp_mant
bf16_mul_b_exp_zero_while:
    # mant_b <<= 1
    slli s6, s6, 1
    # exp_adjust--
    addi t4, t4, -1
    # (mant_b & 0x80)
    andi t6, s6, 0x80
    # while (!(mant_b & 0x80))
    beq t6, x0, bf16_mul_b_exp_zero_while
    # exp_b = 1
    addi s4, x0, 1
bf16_mul_result_exp_mant:
    # result_mant = (uint32_t) mant_a * mant_b
    mul t3, s5, s6
    # result_exp = (int32_t) exp_a + exp_b - BF16_EXP_BIAS + exp_adjust
    li t6, BF16_EXP_BIAS
    # result_exp = exp_a + exp_b
    add t2, s3, s4
    # result_exp = result_exp - BF16_EXP_BIAS
    sub t2, t2, t6
    # result_exp = result_exp + exp_adjust
    add t2, t2, t4
    # if (result_mant & 0x8000)
    li t6, 0x8000
    and t6, t3, t6
    beq t6, x0, bf16_mul_set_result_mant_srl7
    # result_mant = (result_mant >> 8) & 0x7F
    srli t3, t3, 8
    andi t3, t3, 0x7F
    # result_exp++
    addi t2, t2, 1
    j bf16_mul_check_result_exp
bf16_mul_set_result_mant_srl7:
    # result_mant = (result_mant >> 7) & 0x7F
    srli t3, t3, 7
    andi t3, t3, 0x7F
bf16_mul_check_result_exp:
    # if (result_exp >= 0xFF)
    li t6, 0xFF
    bge t2, t6, bf16_mul_result_exp_bg0xFF_ret
    # if (result_exp <= 0); if result_exp > 0, jump to ret
    blt x0, t2, bf16_mul_ret
    # if (result_exp < -6)
    li t6, -6
    blt t2, t6, bf16_mul_result_exp_smNeg6_ret
    # result_mant >>= (1 - result_exp)
    li t6, 1
    sub t5, t6, t2
    srl t3, t3, t5
    # result_exp = 0
    add t2, x0, x0
    # return
    j bf16_mul_ret
bf16_mul_result_exp_bg0xFF_ret:
    # return ((result_sign << 15) | 0x7F80)
    li t6, 0x7F80
    slli t1, t1, 15
    or a0, t1, t6
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_result_exp_smNeg6_ret:
    # return (result_sign << 15)
    slli a0, t1, 15
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_a_exp_allOne:
    # if (mant_a), mant_a isn't zero, jump to return a
    bne s5, x0, bf16_mul_reta
    # if (!exp_b && !mant_b) <=> b exp and mant equal 0
    or t5, s4, s6
    beq t5, x0, bf16_mul_retNan
    # return ((result_sign << 15) | 0x7F80)
    li t6, 0x7F80
    slli t1, t1, 15
    or a0, t1, t6
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_b_exp_allOne:
    # if (mant_b), mant_b isn't zero, jump to return b
    bne s6, x0, bf16_mul_retb
    # if (!exp_a && !mant_a) <=> a exp and mant equal 0
    or t5, s3, s5
    beq t5, x0, bf16_mul_retNan
    # return ((result_sign << 15) | 0x7F80)
    li t6, 0x7F80
    slli t1, t1, 15
    or a0, t1, t6
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_retSign_slli15:
    # return result_sign << 15
    slli a0, t1, 15
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_retNan:
    # return BF16_NAN
    li a0, BF16_NAN
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_reta:
    # return a (a0 still holds a)
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_retb:
    # a0 = a1
    add a0, a1, x0
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
bf16_mul_ret:
    # return ((result_sign << 15) | ((result_exp & 0xFF) << 7) | (result_mant & 0x7F))
    slli t1, t1, 15
    andi t2, t2, 0xFF
    slli t2, t2, 7
    andi t3, t3, 0x7F
    or a0, t1, t2
    or a0, a0, t3
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
# =======================================================
# Function : bf16_div()
# Parameter : bf16_t a, bf16_t b
# Variable :
# Description : Performs bfloat16 division using bit-level integer division and handles edge cases.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_div:
    addi sp, sp, -24
    sw s6, 20(sp)
    sw s5, 16(sp)
    sw s4, 12(sp)
    sw s3, 8(sp)
    sw s2, 4(sp)
    sw s1, 0(sp)
    # sign_a = (a.bits >> 15) & 1
    srli s1, a0, 15
    andi s1, s1, 1
    # sign_b = (b.bits >> 15) & 1;
    srli s2, a1, 15
    andi s2, s2, 1
    # exp_a = ((a.bits >> 7) & 0xFF)
    srli s3, a0, 7
    andi s3, s3, 0xFF
    # exp_b = ((b.bits >> 7) & 0xFF)
    srli s4, a1, 7
    andi s4, s4, 0xFF
    # mant_a = a.bits & 0x7F
    andi s5, a0, 0x7F
    # mant_b = b.bits & 0x7F
    andi s6, a1, 0x7F
    # result_sign = sign_a ^ sign_b
    xor t1, s1, s2
    # if (exp_b == 0xFF)
    li t6, 0xFF
    beq s4, t6, bf16_div_b_exp_allOne
    # if (!exp_b && !mant_b)
    or t5, s4, s6
    beq t5, x0, bf16_div_b_exp_mant_zero
    # if (exp_a == 0xFF)
    beq s3, t6, bf16_div_a_exp_allOne
    # if (!exp_a && !mant_a), a exp and mant all zero
    or t5, s3, s5
    beq t5, x0, bf16_div_a_exp_mant_zero
    # if (exp_a)
    bne s3, x0, bf16_div_mant_a_or0x80
    # if (exp_b)
    bne s4, x0, bf16_div_mant_b_or0x80
    j bf16_div_run
bf16_div_mant_a_or0x80:
    # mant_a |= 0x80
    ori s5, s5, 0x80
    # if (exp_b)
    bne s4, x0, bf16_div_mant_b_or0x80
    j bf16_div_run
bf16_div_mant_b_or0x80:
    # mant_b |= 0x80
    ori s6, s6, 0x80
    j bf16_div_run
bf16_div_run:
    # dividend = (uint32_t) mant_a << 15
    slli t3, s5, 15
    # divisor = mant_b
    add t4, s6, x0
    # uint32_t quotient = 0
    add t5, x0, x0
    # set i = 0
    add t0, x0, x0
    li t6, 16
    j bf16_div_for
bf16_div_for:
    # for (int i = 0; i < 16; i++)
    bge t0, t6, bf16_div_result_exp
    # quotient <<= 1
    slli t5, t5, 1
    # (divisor << (15 - i))
    li t6, 15
    sub t6, t6, t0
    sll t6, t4, t6
    # if (dividend >= (divisor << (15 - i)))
    bge t3, t6, bf16_div_for_divdend_divsor
    li t6, 16
    # i++
    addi t0, t0, 1
    j bf16_div_for
bf16_div_for_divdend_divsor:
    # dividend -= (divisor << (15 - i))
    sub t3, t3, t6
    # quotient |= 1
    ori t5, t5, 1
    li t6, 16
    # i++
    addi t0, t0, 1
    j bf16_div_for
bf16_div_result_exp:
    # result_exp = (int32_t) exp_a - exp_b + BF16_EXP_BIAS
    li t6, BF16_EXP_BIAS
    sub t2, s3, s4
    add t2, t2, t6
    # if (!exp_a)
    beq s3, x0, bf16_div_result_exp_minus1
    # if (!exp_b)
    beq s4, x0, bf16_div_result_exp_plus1
    j bf16_div_quotient
bf16_div_result_exp_minus1:
    # result_exp--
    addi t2, t2, -1
    # if (!exp_b)
    beq s4, x0, bf16_div_result_exp_plus1
    j bf16_div_quotient
bf16_div_result_exp_plus1:
    # result_exp++
    addi t2, t2, 1
    j bf16_div_quotient
bf16_div_quotient:
    # if (quotient & 0x8000) != 0, jump to quotient >>= 8
    li t6, 0x8000
    and t6, t5, t6
    bne t6, x0, bf16_div_quot_srl8
    # else jump to the normalization loop
    j bf16_div_quot_while
bf16_div_quot_while:
    # while (!(quotient & 0x8000) && result_exp > 1)
    bne t6, x0, bf16_div_quot_srl8
    li t0, 1
    bge t0, t2, bf16_div_quot_srl8
    # quotient <<= 1
    slli t5, t5, 1
    # result_exp--
    addi t2, t2, -1
    # quotient & 0x8000
    li t6, 0x8000
    and t6, t5, t6
    j bf16_div_quot_while
bf16_div_quot_srl8:
    # quotient >>= 8
    srli t5, t5, 8
    j bf16_div_result_exp_ret
bf16_div_result_exp_ret:
    # quotient &= 0x7F
    andi t5, t5, 0x7F
    # if (result_exp >= 0xFF)
    li t6, 0xFF
    bge t2, t6, bf16_div_ret_sign_sll15_or7F80
    # if (result_exp <= 0)
    bge x0, t2, bf16_div_ret_sign_sll15
    j bf16_div_ret
bf16_div_a_exp_allOne:
    # if (mant_a), mant_a != 0, return a
    bne s5, x0, bf16_div_reta
    # return ((result_sign << 15) | 0x7F80)
    j bf16_div_ret_sign_sll15_or7F80
bf16_div_a_exp_mant_zero:
    # return (result_sign << 15)
    j bf16_div_ret_sign_sll15
bf16_div_b_exp_allOne:
    # if (mant_b), b mant != 0, return b
    bne s6, x0, bf16_div_retb
    # if (exp_a == 0xFF && !mant_a)
    li t6, 0xFF
    beq s3, t6, bf16_div_b_check_a_NAN
    # return result_sign << 15
    j bf16_div_ret_sign_sll15
bf16_div_b_exp_mant_zero:
    # if (!exp_a && !mant_a), a exp and mant all zero, return NAN
    or t5, s3, s5
    beq t5, x0, bf16_div_retNAN
    # return ((result_sign << 15) | 0x7F80)
    j bf16_div_ret_sign_sll15_or7F80
bf16_div_b_check_a_NAN:
    # if (exp_a == 0xFF && !mant_a), i.e. Inf / Inf, return NAN
    beq s5, x0, bf16_div_retNAN
    # else (mant_a != 0)
    # return result_sign << 15
    j bf16_div_ret_sign_sll15
bf16_div_ret_sign_sll15_or7F80:
    # return ((result_sign << 15) | 0x7F80)
    slli a0, t1, 15
    li t6, 0x7F80
    or a0, a0, t6
    j bf16_div_return
bf16_div_ret_sign_sll15:
    # return result_sign << 15
    slli a0, t1, 15
    j bf16_div_return
bf16_div_retNAN:
    # set a0 = BF16_NAN
    li a0, BF16_NAN
    j bf16_div_return
bf16_div_reta:
    # return a (a0 still holds a)
    j bf16_div_return
bf16_div_retb:
    # set a0 = b
    add a0, a1, x0
    j bf16_div_return
bf16_div_ret:
    # ((result_sign << 15) | ((result_exp & 0xFF) << 7) | (quotient & 0x7F))
    slli t1, t1, 15
    andi t2, t2, 0xFF
    slli t2, t2, 7
    andi t5, t5, 0x7F
    or a0, t1, t2
    or a0, a0, t5
    j bf16_div_return
bf16_div_return:
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s4, 12(sp)
    lw s5, 16(sp)
    lw s6, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
# =======================================================
# Function : bf16_sqrt()
# Parameter : bf16_t a
# Variable :
# Description : Computes the square root of a bfloat16 number using bitwise operations and binary search.
# Return : bf16_t value(a0)
# =======================================================
# test ok
bf16_sqrt:
    # s0, s4 and s5 are also used below, so save them with s1-s3 (callee-saved)
    addi sp, sp, -24
    sw s5, 20(sp)
    sw s4, 16(sp)
    sw s0, 12(sp)
    sw s3, 8(sp)
    sw s2, 4(sp)
    sw s1, 0(sp)
    # sign = (a.bits >> 15) & 1
    srli s1, a0, 15
    andi s1, s1, 1
    # exp = ((a.bits >> 7) & 0xFF)
    srli s2, a0, 7
    andi s2, s2, 0xFF
    # mant = a.bits & 0x7F
    andi s3, a0, 0x7F
    # if (exp == 0xFF), exp all one, jump bf16_sqrt_exp_allOne
    li t6, 0xFF
    beq s2, t6, bf16_sqrt_exp_allOne
    # if (!exp && !mant), exp and mant all zeros, return zero
    or t0, s2, s3
    beq t0, x0, bf16_sqrt_retZero
    # if (sign), sign = 1 => negative, return nan
    bne s1, x0, bf16_sqrt_retNAN
    # if (!exp), exp = 0, return zero
    beq s2, x0, bf16_sqrt_retZero
    # e = exp - BF16_EXP_BIAS
    li t6, BF16_EXP_BIAS
    sub t1, s2, t6
    # m = 0x80 | mant
    ori t3, s3, 0x80
    j bf16_sqrt_adjust_odd_exp
bf16_sqrt_adjust_odd_exp:
    # if (e & 1), odd, jump bf16_sqrt_adjust_odd
    andi t5, t1, 1
    bne t5, x0, bf16_sqrt_adjust_odd
    # else jump bf16_sqrt_adjust_even
    j bf16_sqrt_adjust_even
bf16_sqrt_adjust_odd:
    # m <<= 1
    slli t3, t3, 1
    # new_exp = ((e - 1) >> 1) + BF16_EXP_BIAS
    addi t2, t1, -1
    srai t2, t2, 1
    li t6, BF16_EXP_BIAS
    add t2, t2, t6
    j bf16_sqrt_search_square
bf16_sqrt_adjust_even:
    # new_exp = (e >> 1) + BF16_EXP_BIAS
    srai t2, t1, 1
    li t6, BF16_EXP_BIAS
    add t2, t2, t6
    j bf16_sqrt_search_square
bf16_sqrt_search_square:
    # low = 90
    li t4, 90
    # high = 256
    li t5, 256
    # result = 128
    li t6, 128
    j bf16_sqrt_search_square_while
bf16_sqrt_search_square_while:
    # while (low <= high); if low > high, jump to ensure result
    blt t5, t4, bf16_sqrt_norm_ensure_result
    # mid = (low + high) >> 1
    add t0, t4, t5
    srli t0, t0, 1
    j bf16_sqrt_pow_init
bf16_sqrt_search_square_pow:
    # sq = (mid * mid) / 128
    srli t1, t1, 7
    # if (sq <= m)
    bge t3, t1, bf16_sqrt_search_square_while_midPlus1
    # else
    j bf16_sqrt_search_square_while_midMinus1
bf16_sqrt_pow_init:
    # shift-and-add multiply: sq = 0, s4 = mid, s5 = mid
    add s4, t0, x0
    add s5, t0, x0
    add t1, x0, x0
    j bf16_sqrt_pow_forLoop
bf16_sqrt_pow_forLoop:
    # if s5 == 0, the product is done: jump to bf16_sqrt_search_square_pow
    beq s5, x0, bf16_sqrt_search_square_pow
    andi s0, s5, 1
    # if (s5 & 1) == 0, jump to bf16_sqrt_pow_lsbZero
    beq s0, x0, bf16_sqrt_pow_lsbZero
    j bf16_sqrt_pow_lsbOne
bf16_sqrt_pow_lsbOne:
    # sq = sq + s4
    add t1, t1, s4
    # s4 <<= 1
    slli s4, s4, 1
    # s5 >>= 1
    srli s5, s5, 1
    j bf16_sqrt_pow_forLoop
bf16_sqrt_pow_lsbZero:
    # s4 <<= 1
    slli s4, s4, 1
    # s5 >>= 1
    srli s5, s5, 1
    j bf16_sqrt_pow_forLoop
bf16_sqrt_search_square_while_midPlus1:
    # result = mid;
    add t6, t0, x0
    # low = mid + 1;
    addi t4, t0, 1
    j bf16_sqrt_search_square_while
bf16_sqrt_search_square_while_midMinus1:
    # high = mid - 1;
    addi t5, t0, -1
    j bf16_sqrt_search_square_while
bf16_sqrt_norm_ensure_result:
    # if (result >= 256)
    li t0, 256
    bge t6, t0, bf16_sqrt_norm_ensure_result_bge256
    # else if (result < 128)
    li t0, 128
    blt t6, t0, bf16_sqrt_norm_ensure_result_blt128
    j bf16_sqrt_extract_mantissa
bf16_sqrt_norm_ensure_result_bge256:
    # result >>= 1
    srli t6, t6, 1
    # new_exp++
    addi t2, t2, 1
    j bf16_sqrt_extract_mantissa
bf16_sqrt_norm_ensure_result_blt128:
    li t0, 128
    li t1, 1
    j bf16_sqrt_norm_ensure_result_blt128_while
bf16_sqrt_norm_ensure_result_blt128_while:
    # while (result < 128 && new_exp > 1)
    # result < 128
    bge t6, t0, bf16_sqrt_extract_mantissa
    # new_exp > 1
    bge t1, t2, bf16_sqrt_extract_mantissa
    # result <<= 1
    slli t6, t6, 1
    # new_exp--
    addi t2, t2, -1
    j bf16_sqrt_norm_ensure_result_blt128_while
bf16_sqrt_extract_mantissa:
    # new_mant = result & 0x7F
    andi t4, t6, 0x7F
    li t5, 0xFF
    # if (new_exp >= 0xFF)
    bge t2, t5, bf16_sqrt_retBge0xFF
    # if (new_exp <= 0)
    bge x0, t2, bf16_sqrt_retZero
    j bf16_sqrt_ret
bf16_sqrt_exp_allOne:
    # if (mant), mant != 0, return a
    bne s3, x0, bf16_sqrt_reta
    # if (sign), sign != 0, return nan
    bne s1, x0, bf16_sqrt_retNAN
    # return a
    j bf16_sqrt_reta
bf16_sqrt_retBge0xFF:
    # return 0x7F80
    li a0, 0x7F80
    j bf16_sqrt_return
bf16_sqrt_retZero:
    # return zero
    li a0, BF16_ZERO
    j bf16_sqrt_return
bf16_sqrt_retNAN:
    # return NAN
    li a0, BF16_NAN
    j bf16_sqrt_return
bf16_sqrt_reta:
    # return a (a0 still holds a)
    j bf16_sqrt_return
bf16_sqrt_ret:
    # return ((new_exp & 0xFF) << 7) | new_mant
    andi a0, t2, 0xFF
    slli a0, a0, 7
    or a0, a0, t4
    j bf16_sqrt_return
bf16_sqrt_return:
    lw s1, 0(sp)
    lw s2, 4(sp)
    lw s3, 8(sp)
    lw s0, 12(sp)
    lw s4, 16(sp)
    lw s5, 20(sp)
    addi sp, sp, 24
    jalr x0, ra, 0
```
:::

### Result
![image](https://hackmd.io/_uploads/SyG7ztQaxg.png)

----

## Problem B in [quiz1](https://hackmd.io/@sysprog/arch2025-quiz1-sol)

#### clz
The clz (Count Leading Zeros) function finds the most significant 1 bit of a 32-bit input and returns how many 0 bits precede it. Instead of testing one bit at a time, it uses a binary-search-like approach. At each step, it right-shifts the current value x by c bits and checks whether the result y is zero:

* If y is zero, the upper half of the window being examined contains only zeros, so x is kept unchanged and the next step examines the lower half.
* If y is non-zero, there is at least one 1 in the upper half. In that case, x is replaced by y and the leading-zero count n is reduced by c.

The shift amount c is halved after each iteration (16, 8, 4, 2, 1), narrowing down the range that contains the first 1. After the last step x is either 0 or 1, so the function returns n - x; for an input of 0 this gives 32.
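The halving can be followed on a concrete input. The sketch below is an instrumented copy of the clz loop described above; the names `clz_trace` and `struct clz_step` are hypothetical, introduced only for this illustration. It records c, y, n and x after every step. For x = 5 (binary 101), the shifts by 16, 8 and 4 all give y = 0; the shift by 2 gives y = 1, so n becomes 30 and x becomes 1; the shift by 1 gives y = 0 again, and the function returns 30 - 1 = 29, the number of zero bits in front of the highest 1 of 5.

```c
#include <assert.h>
#include <stdint.h>

/* State after one halving step: shift c, shifted value y, and n, x after
 * the step. Hypothetical type, used only for this illustration. */
struct clz_step {
    int c;
    uint32_t y;
    int n;
    uint32_t x;
};

/* Same loop as clz(), but records each of the five steps in steps[].
 * Hypothetical helper; returns the same value as clz(x). */
static unsigned clz_trace(uint32_t x, struct clz_step steps[5])
{
    int n = 32, c = 16, i = 0;
    do {
        uint32_t y = x >> c;   /* upper half of the window still being searched */
        if (y) {               /* a 1 bit is in the upper half: keep that half */
            n -= c;
            x = y;
        }
        steps[i].c = c;
        steps[i].y = y;
        steps[i].n = n;
        steps[i].x = x;
        ++i;
        c >>= 1;               /* 16, 8, 4, 2, 1 */
    } while (c);
    assert(x <= 1);            /* after the c = 1 step, only bit 0 can remain */
    return n - x;
}
```

Calling `clz_trace(5, steps)` fills `steps[3]` with c = 2, y = 1, n = 30, x = 1; an input of 0 never takes the `if (y)` branch and returns 32.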
### C code of clz
```c=
#include <stdint.h>

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}
```
### Assembly code of clz
```Assembly=
clz:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw t0, 8(sp)
    sw t1, 4(sp)
    sw t2, 0(sp)
    li t0, 32                       # t0 = n
    li t1, 16                       # t1 = c
clz_whileLoop:
    srl t2, a0, t1                  # y = t2, y = x >> c
    beq t2, x0, shift_right_1bit    # if y == 0, jump to shift_right_1bit
    sub t0, t0, t1                  # n = n - c
    addi a0, t2, 0                  # x = y
shift_right_1bit:
    srli t1, t1, 1                  # c = c >> 1
    bne t1, x0, clz_whileLoop       # if c != 0, jump to clz_whileLoop
    sub a0, t0, a0                  # return value a0 = n - x
    lw t2, 0(sp)
    lw t1, 4(sp)
    lw t0, 8(sp)
    lw ra, 12(sp)
    addi sp, sp, 16
    jalr x0, ra, 0                  # return
```

---

## Problem [Leetcode 3370. Smallest Number With All Set Bits](https://leetcode.com/problems/smallest-number-with-all-set-bits/description/)

>You are given a positive number n.
>
>Return the smallest number x greater than or equal to n, such that the binary representation of x contains only set bits.
>
>A set bit refers to a bit in the binary representation of a number that has a value of 1.
>
>### Example 1:
>
>Input: n = 5
>Output: 7
>Explanation:
>The binary representation of 7 is "111".
>
>### Constraints:
>* 1 <= n <= 1000

## Solution
### Idea for problem solving
The problem asks for the smallest number x such that x >= n and every bit in the binary representation of x is 1. Such numbers have the form 2^k - 1, so we are looking for the smallest k with 2^k - 1 >= n.

A straightforward solution starts from x = 1 and keeps shifting it left (i.e., multiplying by 2) until x >= n. At that point x is the smallest power of two that is at least n, so the answer is x - 1, or 2x - 1 when x == n (n itself is a power of two). This approach is simple but takes O(log n) iterations.

To optimize, we can use the clz (Count Leading Zeros) routine above to find the position of the most significant 1 bit in n directly. n has 32 - clz(n) significant bits, so the answer is x = (1 << (32 - clz(n))) - 1. Because the software clz always runs a fixed five iterations on a 32-bit input, this takes constant time.
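Before translating either version to assembly, the two approaches can be cross-checked over the whole constraint range. The following is a minimal sketch, assuming the `clz()` from Problem B (repeated so the block compiles on its own); `smallest_loop`, `smallest_clz` and `check_range` are hypothetical names introduced for this check, not part of the original solution.

```c
#include <assert.h>
#include <stdint.h>

/* Software clz from Problem B. */
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Shift-left approach: smallest power of two >= n, then adjust. */
static int smallest_loop(int n)
{
    int x = 1;
    while (x < n)
        x <<= 1;
    return (x == n) ? (x << 1) - 1 : x - 1;
}

/* clz approach: n has 32 - clz(n) significant bits; set all of them. */
static int smallest_clz(int n)
{
    return (1 << (32 - clz((uint32_t) n))) - 1;
}

/* Compare both versions for every n in [lo, hi]; returns how many were checked. */
static int check_range(int lo, int hi)
{
    int checked = 0;
    for (int n = lo; n <= hi; ++n) {
        int r = smallest_clz(n);
        assert(r == smallest_loop(n));
        assert(r >= n && (r & (r + 1)) == 0);  /* r >= n and r has the form 2^k - 1 */
        ++checked;
    }
    return checked;
}
```

`check_range(1, 1000)` covers every input the problem allows. The constraint also keeps the shift well away from the sign bit: for n >= 2^30, `1 << (32 - clz(n))` would overflow a 32-bit `int`.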
This is a typical bit manipulation problem, where understanding binary patterns and bitwise operations helps simplify the logic and improve performance.

---

#### Original:
##### C code
```c=
int smallestNumber(int n)
{
    int x = 1;
    while (x < n)
        x <<= 1;
    if (x == n)
        return (x << 1) - 1;
    return x - 1;
}
```
##### Assembly code
```Assembly=
smallestNumber:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw t0, 0(sp)
    li t0, 1                        # x = 1
sml_while_loop:
    bge t0, a0, sml_cmp_xandn       # while (x < n)
    slli t0, t0, 1                  # x <<= 1
    j sml_while_loop
sml_cmp_xandn:
    bne t0, a0, sml_x_shift         # if (x == n)
    slli t0, t0, 1                  # x <<= 1
sml_x_shift:
    addi a0, t0, -1                 # return x - 1
    lw t0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    jalr x0, ra, 0
```

----

### Using clz:
#### C code
```c=
#include <stdint.h>

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

int smallestNumber(int n)
{
    int bit_len = (1 << (32 - clz(n))) - 1;
    return bit_len;
}
```
#### Assembly code
```Assembly=
smallestNumber:
    addi sp, sp, -12
    sw ra, 8(sp)
    sw t0, 4(sp)
    sw t1, 0(sp)
    # call clz
    addi sp, sp, -8
    # store ra, a0(n)
    sw ra, 4(sp)
    sw a0, 0(sp)
    jal ra, clz
    li t0, 32                       # bit_len = 32
    li t1, 1
    sub t0, t0, a0                  # bit_len = 32 - clz(n)
    lw a0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    sll t1, t1, t0
    addi a0, t1, -1
    lw t1, 0(sp)
    lw t0, 4(sp)
    lw ra, 8(sp)
    addi sp, sp, 12
    jalr x0, ra, 0
clz:
    addi sp, sp, -16
    sw ra, 12(sp)
    sw t0, 8(sp)
    sw t1, 4(sp)
    sw t2, 0(sp)
    li t0, 32                       # t0 = n
    li t1, 16                       # t1 = c
clz_whileLoop:
    srl t2, a0, t1                  # y = t2, y = x >> c
    beq t2, x0, shift_right_1bit    # if y == 0, jump to shift_right_1bit
    sub t0, t0, t1                  # n = n - c
    addi a0, t2, 0                  # x = y
shift_right_1bit:
    srli t1, t1, 1                  # c = c >> 1
    bne t1, x0, clz_whileLoop       # if c != 0, jump to clz_whileLoop
    sub a0, t0, a0                  # return value a0 = n - x
    lw t2, 0(sp)
    lw t1, 4(sp)
    lw t0, 8(sp)
    lw ra, 12(sp)
    addi sp, sp, 16
    jalr x0, ra, 0
```

----

### Using clz in Leetcode
#### C code
```c=
#include <stdio.h>
#include <stdint.h>

static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

int smallestNumber(int n)
{
    int bit_len = (1 << (32 - clz(n))) - 1;
    return bit_len;
}

int main()
{
    int input[] = {1, 509, 1000}, output;
    int ans[] = {1, 511, 1023};
    for (int i = 0; i < 3; ++i) {
        output = smallestNumber(input[i]);
        printf("output = %d, answer = %d\n", output, ans[i]);
        if (output == ans[i])
            printf("True\n");
        else
            printf("False\n");
    }
    return 0;
}
```
#### Assembly code
```Assembly=
.data
input:       .word 1, 509, 1000
ans:         .word 1, 511, 1023
output_str:  .string "output = "
answer_str:  .string ", answer = "
endline_str: .string "\n"
true_str:    .string "True\n"
false_str:   .string "False\n"

.text
.global main
main:
    # load input array
    la s0, input
    # load answer array
    la s1, ans
    # i = 0
    li t0, 0
    # for loop boundary 3
    li t1, 3
main_for:
    # for (int i = 0 ; i < 3 ; ++i)
    bge t0, t1, main_exit
    # load input[i] to a0
    slli t2, t0, 2
    add t3, s0, t2
    lw a0, 0(t3)
    # call smallestNumber
    addi sp, sp, -16
    sw ra, 12(sp)
    sw t2, 8(sp)
    sw t1, 4(sp)
    sw t0, 0(sp)
    jal ra, smallestNumber
    # output = smallestNumber(input[i])
    add s2, a0, x0
    lw t0, 0(sp)
    lw t1, 4(sp)
    lw t2, 8(sp)
    lw ra, 12(sp)
    addi sp, sp, 16
    j print_result
print_result:
    # print "output = "
    la a0, output_str
    li a7, 4
    ecall
    # print output
    addi a0, s2, 0
    li a7, 1
    ecall
    # print ", answer = "
    la a0, answer_str
    li a7, 4
    ecall
    # load ans[i] to t2
    # print answer
    add t2, s1, t2
    lw a0, 0(t2)
    add t2, a0, x0
    li a7, 1
    ecall
    # print "\n"
    la a0, endline_str
    li a7, 4
    ecall
    # if (output == ans[i])
    beq s2, t2, print_true
    # else
    j print_false
print_true:
    # print "True\n"
    la a0, true_str
    li a7, 4
    ecall
    # ++i
    addi t0, t0, 1
    j main_for
print_false:
    # print "False\n"
    la a0, false_str
    li a7, 4
    ecall
    # ++i
    addi t0, t0, 1
    j main_for
smallestNumber:
    addi sp, sp, -12
    sw ra, 8(sp)
    sw t0, 4(sp)
    sw t1, 0(sp)
    # call clz
    addi sp, sp, -8
    # store ra, a0(n)
    sw ra, 4(sp)
    sw a0, 0(sp)
    jal ra, clz
    li t0, 32                       # bit_len = 32
    li t1, 1
    sub t0, t0, a0                  # bit_len = 32 - clz(n)
    lw a0, 0(sp)
    lw ra, 4(sp)
    addi sp, sp, 8
    sll t1, t1, t0
    addi a0, 
t1, -1 lw t1, 0(sp) lw t0, 4(sp) lw ra, 8(sp) addi sp, sp, 12 jalr x0, x1, 0 clz: li t0, 32 # t0 = n li t1, 16 # t1 = c j clz_whileLoop clz_whileLoop: srl t2, a0, t1 # y = t2, y = x >> c beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit sub t0, t0, t1 # n = n - c addi a0, t2, 0 # x = y j shift_right_1bit shift_right_1bit: srli t1, t1, 1 # c = c >> 1 bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop sub a0, t0, a0 # x = n - x jalr x0, x1, 0 main_exit: li a7, 10 ecall ``` #### Result ![image](https://hackmd.io/_uploads/S19KQLZTxe.png) ---- ### Optimizing assembly code #### The above method can optimize the assembly code for the three types of hazards. 1. structure hazard : add more hardware 2. data hazard : data forwarding or reordering 3. control hazard : delayed slot(ripes unsupported(auto flush)) or branch prediction(1 bit、2 bits、3 bits) * Loop unrolling can also solve data and control hazards. #### loop unrolling **I use loop unrolling to reduce the number of branches.** :::spoiler More detailed information ```assembly= .data input: .word 1, 509, 1000 ans: .word 1, 511, 1023 output_str: .string "output = " answer_str: .string ", answer = " endline_str: .string "\n" true_str: .string "True\n" false_str: .string "False\n" .text .global main main: # load input array la s0, input # load answer array la s1, ans main_for: # loop unrolling # load input[0] to a0 addi t3, s0, 0 lw a0, 0(t3) # call smallestNumber addi sp, sp, -4 sw ra, 0(sp) jal ra, smallestNumber # output = smallestNumber(input[0]) add s2, a0, x0 # load input[1] to a0 addi t3, s0, 4 lw a0, 0(t3) # call smallestNumber jal ra, smallestNumber # output = smallestNumber(input[1]) add s3, a0, x0 # load input[2] to a0 addi t3, s0, 8 lw a0, 0(t3) # call smallestNumber jal ra, smallestNumber # output = smallestNumber(input[1]) add s4, a0, x0 lw ra, 0(sp) addi sp, sp, 4 j print_result_0 print_result_0: # print "output = " la a0, output_str li a7, 4 ecall # print output addi a0, s2, 0 li 
a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 0 lw a0, 0(t2) add t2, a0, x0 li a7, 1 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output[i] == ans[i]) beq s2, t2, print_true_0 # else j print_false_0 print_true_0: # print "True\n" la a0, true_str li a7, 4 ecall j print_result_1 print_false_0: # print "False\n" la a0, false_str li a7, 4 ecall j print_result_1 print_result_1: # print "output = " la a0, output_str li a7, 4 ecall # print output addi a0, s3, 0 li a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 4 lw a0, 0(t2) add t2, a0, x0 li a7, 1 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output == ans[1]) beq s3, t2, print_true_1 # else j print_false_1 print_true_1: # print "True\n" la a0, true_str li a7, 4 ecall j print_result_2 print_false_1: # print "False\n" la a0, false_str li a7, 4 ecall j print_result_2 print_result_2: # print "output = " la a0, output_str li a7, 4 ecall # print output addi a0, s4, 0 li a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 8 lw a0, 0(t2) add t2, a0, x0 li a7, 1 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output == ans[2]) beq s4, t2, print_true_2 # else j print_false_2 print_true_2: # print "True\n" la a0, true_str li a7, 4 ecall j main_exit print_false_2: # print "False\n" la a0, false_str li a7, 4 ecall j main_exit smallestNumber: addi sp, sp, -12 sw ra, 8(sp) sw t0, 4(sp) sw t1, 0(sp) # call clz addi sp, sp, -8 # store ra, a0(n) sw ra, 4(sp) sw a0, 0(sp) jal ra, clz li t0, 32 # bit_len = 32 li t1, 1 sub t0, t0, a0 # bit_len = 32 - clz(n) lw a0, 0(sp) lw ra, 4(sp) addi sp, sp, 8 sll t1, t1, t0 addi a0, t1, -1 lw t1, 0(sp) lw t0, 4(sp) lw ra, 8(sp) addi sp, sp, 12 jalr x0, x1, 0 clz: li t0, 32 # t0 = n li t1, 16 # t1 = c j clz_whileLoop clz_whileLoop: srl t2, a0, t1 # y = t2, y = x >> c beq t2, x0, 
shift_right_1bit # if y == 0, jump to shift_right_1bit sub t0, t0, t1 # n = n - c addi a0, t2, 0 # x = y j shift_right_1bit shift_right_1bit: srli t1, t1, 1 # c = c >> 1 bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop sub a0, t0, a0 # x = n - x jalr x0, x1, 0 main_exit: li a7, 10 ecall ``` ::: #### Loop unrolling + Reorder(load-use) :::spoiler More detailed information ```assembly= .data input: .word 1, 509, 1000 ans: .word 1, 511, 1023 output_str: .string "output = " answer_str: .string ", answer = " endline_str: .string "\n" true_str: .string "True\n" false_str: .string "False\n" .text .global main main: # load input array la s0, input # load answer array la s1, ans main_for: # loop unrolling # load input[0] to a0 addi t3, s0, 0 lw a0, 0(t3) # call smallestNumber addi sp, sp, -4 sw ra, 0(sp) jal ra, smallestNumber # output = smallestNumber(input[0]) add s2, a0, x0 # load input[1] to a0 addi t3, s0, 4 lw a0, 0(t3) # call smallestNumber jal ra, smallestNumber # output = smallestNumber(input[1]) add s3, a0, x0 # load input[2] to a0 addi t3, s0, 8 lw a0, 0(t3) # call smallestNumber jal ra, smallestNumber # output = smallestNumber(input[1]) add s4, a0, x0 lw ra, 0(sp) addi sp, sp, 4 j print_result_0 print_result_0: # print "output = " la a0, output_str li a7, 4 ecall # print output addi a0, s2, 0 li a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 0 lw a0, 0(t2) # reorder li and add li a7, 1 add t2, a0, x0 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output[i] == ans[i]) beq s2, t2, print_true_0 # else j print_false_0 print_true_0: # print "True\n" la a0, true_str ecall j print_result_1 print_false_0: # print "False\n" la a0, false_str ecall j print_result_1 print_result_1: # print "output = " la a0, output_str ecall # print output addi a0, s3, 0 li a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 4 lw a0, 0(t2) # reorder 
li and add li a7, 1 add t2, a0, x0 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output == ans[1]) beq s3, t2, print_true_1 # else j print_false_1 print_true_1: # print "True\n" la a0, true_str ecall j print_result_2 print_false_1: # print "False\n" la a0, false_str ecall j print_result_2 print_result_2: # print "output = " la a0, output_str ecall # print output addi a0, s4, 0 li a7, 1 ecall # ", answer = " la a0, answer_str li a7, 4 ecall # load ans[0] to t3 # print answer addi t2, s1, 8 lw a0, 0(t2) # reorder li and add li a7, 1 add t2, a0, x0 ecall # print "\n" la a0, endline_str li a7, 4 ecall # if (output == ans[2]) beq s4, t2, print_true_2 # else j print_false_2 print_true_2: # print "True\n" la a0, true_str ecall j main_exit print_false_2: # print "False\n" la a0, false_str ecall j main_exit smallestNumber: addi sp, sp, -12 sw ra, 8(sp) sw t0, 4(sp) sw t1, 0(sp) # call clz addi sp, sp, -8 # store ra, a0(n) sw ra, 4(sp) sw a0, 0(sp) jal ra, clz li t0, 32 # bit_len = 32 li t1, 1 sub t0, t0, a0 # bit_len = 32 - clz(n) lw a0, 0(sp) lw ra, 4(sp) addi sp, sp, 8 sll t1, t1, t0 addi a0, t1, -1 lw t1, 0(sp) lw t0, 4(sp) lw ra, 8(sp) addi sp, sp, 12 jalr x0, x1, 0 clz: li t0, 32 # t0 = n li t1, 16 # t1 = c j clz_whileLoop clz_whileLoop: srl t2, a0, t1 # y = t2, y = x >> c beq t2, x0, shift_right_1bit # if y == 0, jump to shift_right_1bit sub t0, t0, t1 # n = n - c addi a0, t2, 0 # x = y j shift_right_1bit shift_right_1bit: srli t1, t1, 1 # c = c >> 1 bne t1, x0, clz_whileLoop # if c != 0, jump to clz_whileLoop sub a0, t0, a0 # x = n - x jalr x0, x1, 0 main_exit: li a7, 10 ecall ``` ::: #### Performance and CPU Cycle Counts of Leetcode Problems in Assembly: |C |Assembly|Unrolling | |--------|--------|--------| |![image](https://hackmd.io/_uploads/rJND3Nb6xl.png =220x)|![image](https://hackmd.io/_uploads/SJ2D6y-alx.png =200x)|![image](https://hackmd.io/_uploads/HJggpr-6xl.png =220x)| |Unrolling+Reorder| |--------| 
|![image](https://hackmd.io/_uploads/rywUa6WTex.png)|

By implementing the problem in assembly, **the number of CPU cycles was reduced from 10785 to 437, a reduction of about 96% in execution time.** Further optimization with **unrolling and reordering brought the count down to 383 cycles, an additional improvement of about 12%**.

## Analysis

### RISC-V operation process in Ripes

Ripes offers multiple processor models. **I selected the 5-stage pipeline with data forwarding and hazard detection** to run my program.

![image](https://hackmd.io/_uploads/SJd-qhgTee.png)

"5-stage" means the processor uses a five-stage pipeline to overlap the execution of consecutive instructions. The stages are:

1. Instruction fetch (IF)
2. Instruction decode (ID)
3. Execute (EX)
4. Memory access (MEM)
5. Write back (WB)

---

### R-type format

|31 - 25|24 - 20|19 - 15|14 - 12|11 - 7|6 - 0|
|---------|-------|-------|--------|------|---------|
| funct7 | rs2 | rs1 | funct3 | rd | opcode |

1. opcode: operation code
2. rd: destination register number
3. funct3: 3-bit function code
4. rs1: first source register number
5. rs2: second source register number
6. funct7: 7-bit function code

Example: `add x28, x8, x7`

| funct7 |rs2 (x7)|rs1 (x8)| funct3 |rd (x28)| opcode |
|---------|-------|-------|--------|------|---------|
| 0000000 | 00111 | 01000 | 000 | 11100 | 0110011 |

=> hex = 0x00740e33

---

#### IF (Instruction fetch)

![image](https://hackmd.io/_uploads/rkCORngTle.png)

1. The program counter (PC) holds the address of the current instruction in memory.
2. Since no branch occurs, the PC is updated to PC + 4: the multiplexer in front of the PC selects the adder output as the next address.
3. The instruction memory fetches the instruction at that address, and the compressed decoder (if enabled) expands any compressed instruction to the 32-bit format.

#### ID (Instruction decode)

![image](https://hackmd.io/_uploads/r1o203gpex.png)

1. Program Counter (PC)
    * 0x00000000 is the PC of the instruction being decoded (`add x28, x8, x7`, the first instruction of the program).
    * 0x00000004 is PC + 4, the address of the next sequential instruction. Each uncompressed RISC-V instruction is 4 bytes, so the sequential PC advances by 4.
2. Decoder
    The decoder takes the binary instruction fetched from memory and extracts the following fields:
    * rs1: the first source register
    * rs2: the second source register
    * rd: the destination register
    * opcode: the kind of instruction (e.g., arithmetic, load, branch)
3. Registers
    * The register file uses the decoded rs1 and rs2 to read the values of the corresponding registers.
    * These values are passed on to the next stage (Execute).
4. Immediate
    * No immediate is used here, because `add` is an R-type instruction.
    * For an I-type, S-type, or other format with an immediate, the immediate value would be extracted in this stage.

#### EX (Execute)

![image](https://hackmd.io/_uploads/SkpC0nx6gl.png)

1. Multiplexers
    * A 3-to-1 multiplexer selects the operand from either the register value read in the ID stage or a result forwarded from a later stage (EX/MEM or MEM/WB). Since this program has only one instruction, there is nothing to forward, so the MUX selects the value from the ID stage.
    * Two 2-to-1 multiplexers then choose the ALU inputs:
        * The first selects between the Program Counter (PC) and rs1.
        * The second selects between the immediate value and rs2.
2. ALU
    * The control unit derives the ALU operation from the opcode (together with funct3 and funct7 for R-type instructions) and sends it to the ALU.
    * The ALU result is also routed back to the IF stage, so the PC can choose between PC + 4 and the ALU result (e.g., for branches or jumps).
3. Branch
    * The branch unit compares the values of rs1 and rs2. Since no branch instruction is executed, the branch logic can be ignored here.

#### MEM (Memory access)

![image](https://hackmd.io/_uploads/Hyreypg6ge.png)

1. Data memory
    * Normally, the data memory uses the ALU result as the address for a load or store. Since this program contains no `lw` or `sw` instructions, the data memory is not used.
2. The ALU result is also forwarded back to the EX stage so that the next instruction can use it.

![image](https://hackmd.io/_uploads/B12BlyW6ex.png)

#### WB (Write back)

![image](https://hackmd.io/_uploads/BJ7zk6gTee.png)

* The multiplexer selects the ALU result as the value to write back, so the output value is 0x00000000.
* The value is also forwarded to the EX stage in case a later instruction needs this instruction's rd register.
* Either way, the value 0x00000000 is written into register x28.

After all these stages are done, the register file is updated as follows:

![image](https://hackmd.io/_uploads/rknpz1Z6ll.png =40%x) ![image](https://hackmd.io/_uploads/SkOjmkZ6ge.png =40%x)

## Reference

* [RISC-V Instruction Set Manual](https://riscv.org/specifications/ratified/)
* [Quiz1 of Computer Architecture (2025 Fall)](https://hackmd.io/@sysprog/arch2025-quiz1-sol)
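Before porting `smallestNumber` to assembly, the loop version and the clz version from the sections above can be cross-checked in plain C over the whole LeetCode input range (1 to 1000). The sketch below does that. The function names `smallest_number_loop`, `smallest_number_clz`, `smallest_number_builtin` and `check_all` are hypothetical, and the `__builtin_clz` variant assumes a GCC or Clang toolchain with a 32-bit `unsigned int` (it is compiled out otherwise).

```c
#include <assert.h>
#include <stdint.h>

/* Binary-search clz from the note above; returns 32 for x == 0. */
static inline unsigned clz(uint32_t x)
{
    int n = 32, c = 16;
    do {
        uint32_t y = x >> c;
        if (y) {
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c);
    return n - x;
}

/* Loop version from the "Original" section. */
int smallest_number_loop(int n)
{
    int x = 1;
    while (x < n)
        x <<= 1;
    if (x == n)
        return (x << 1) - 1;
    return x - 1;
}

/* clz version: 32 - clz(n) is the bit length of n.
 * Assumption: 1 <= n < 2^30, so the shift never reaches the sign bit. */
int smallest_number_clz(int n)
{
    return (1 << (32 - clz((uint32_t) n))) - 1;
}

#if defined(__GNUC__) || defined(__clang__)
/* Hypothetical variant using the compiler builtin.
 * __builtin_clz(0) is undefined, so n must be >= 1. */
int smallest_number_builtin(int n)
{
    return (1 << (32 - __builtin_clz((unsigned) n))) - 1;
}
#endif

/* Assert that every version agrees for all n in 1..1000. */
void check_all(void)
{
    for (int n = 1; n <= 1000; ++n) {
        int expected = smallest_number_loop(n);
        assert(smallest_number_clz(n) == expected);
#if defined(__GNUC__) || defined(__clang__)
        assert(smallest_number_builtin(n) == expected);
#endif
    }
}
```

Calling `check_all()` from a `main` function runs the comparison; it aborts on the first mismatch. The software `clz` mirrors the assembly routine step by step, while the builtin usually compiles to a single instruction on targets that have one (such as RISC-V with the Zbb extension).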
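The R-type table in the Analysis section can be turned into a small encoder to double-check the hand-assembled `add x28, x8, x7` example. This is a sketch: `encode_rtype` and `check_add_example` are hypothetical helper names, and the field values (funct7 = 0, funct3 = 0, opcode = 0x33 for ADD) are taken from the worked example.

```c
#include <assert.h>
#include <stdint.h>

/* Pack the six R-type fields into one 32-bit word using the bit positions
 * from the table: funct7[31:25] rs2[24:20] rs1[19:15] funct3[14:12]
 * rd[11:7] opcode[6:0]. Each field is masked to its width first. */
uint32_t encode_rtype(uint32_t funct7, uint32_t rs2, uint32_t rs1,
                      uint32_t funct3, uint32_t rd, uint32_t opcode)
{
    return ((funct7 & 0x7Fu) << 25) |
           ((rs2    & 0x1Fu) << 20) |
           ((rs1    & 0x1Fu) << 15) |
           ((funct3 & 0x07u) << 12) |
           ((rd     & 0x1Fu) << 7)  |
            (opcode & 0x7Fu);
}

/* Worked example: add x28, x8, x7 should encode to 0x00740e33. */
void check_add_example(void)
{
    uint32_t insn = encode_rtype(0x00, 7, 8, 0x0, 28, 0x33);
    assert(insn == 0x00740e33u);
}
```

Changing `funct7` to 0x20 gives `sub x28, x8, x7` (0x40740e33), which shows that funct7 is what distinguishes ADD from SUB when opcode and funct3 are identical.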
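The decoder in the ID stage does the reverse: it slices fixed bit ranges out of the instruction word. A minimal C model of that field extraction, assuming an R-type instruction (the names `struct rtype_fields`, `decode_rtype` and `check_decode_example` are hypothetical), could look like this:

```c
#include <assert.h>
#include <stdint.h>

/* Fields the decoder pulls out of a 32-bit R-type instruction. */
struct rtype_fields {
    uint32_t opcode, rd, funct3, rs1, rs2, funct7;
};

/* Shift each field down to bit 0, then mask it to its width. */
struct rtype_fields decode_rtype(uint32_t insn)
{
    struct rtype_fields f;
    f.opcode = insn & 0x7Fu;
    f.rd     = (insn >> 7) & 0x1Fu;
    f.funct3 = (insn >> 12) & 0x07u;
    f.rs1    = (insn >> 15) & 0x1Fu;
    f.rs2    = (insn >> 20) & 0x1Fu;
    f.funct7 = (insn >> 25) & 0x7Fu;
    return f;
}

/* Decoding 0x00740e33 should give back add x28, x8, x7. */
void check_decode_example(void)
{
    struct rtype_fields f = decode_rtype(0x00740e33u);
    assert(f.opcode == 0x33 && f.rd == 28 && f.funct3 == 0);
    assert(f.rs1 == 8 && f.rs2 == 7 && f.funct7 == 0);
}
```

Because RISC-V keeps rs1, rs2 and rd at the same bit positions in every format that has them, the register file can be read in ID at the same time as the opcode is being decoded, without first knowing the instruction type.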
