# Assignment1: RISC-V Assembly and Instruction Pipeline
contributed by <ukp66482>

[TOC]

## Leetcode No.1342 Number of Steps to Reduce a Number to Zero
### Problem Goal
The goal of this problem is to calculate the number of steps required to reduce a given integer to zero. In each step:
- If the number is even, divide it by 2.
- If the number is odd, subtract 1.

The process continues until the number becomes zero, and the program returns the total number of steps taken.

### Solution 1 Loop based
A simple while-loop repeatedly checks whether the number is even or odd and updates it accordingly until it reaches zero.
#### C Code
```c
#include <stdint.h>

int numberOfSteps_simple(uint32_t num) {
    int steps = 0;
    while (num != 0u) {
        if ((num & 1u) == 0u)
            num >>= 1;   // even → divide by 2
        else
            num -= 1u;   // odd → subtract 1
        steps++;
    }
    return steps;
}
```
#### Assembly Code
``` asm
NUMBEROFSTEPS_SIMPLE:
    addi sp, sp, -4
    sw ra, 0(sp)
    li t0, 0                # steps = 0
WHILE_LOOP:
    beqz a0, DONE           # while(num != 0)
    andi t1, a0, 1          # check LSB (num % 2)
    beqz t1, EVEN
    addi a0, a0, -1         # num -= 1
    addi t0, t0, 1          # steps++
    jal x0, WHILE_LOOP
EVEN:
    srli a0, a0, 1          # num /= 2
    addi t0, t0, 1          # steps++
    jal x0, WHILE_LOOP
DONE:
    mv a0, t0               # return steps
    lw ra, 0(sp)
    addi sp, sp, 4
    ret
```
### Solution 2 CLZ + POPCOUNT
The total number of steps can be expressed as:
**(bit length − 1) + number of 1 bits**
#### Why?
1. Each 1 bit causes exactly one "subtract 1" operation (odd → even).
2. Each division by 2 is a right shift. A number whose bit length is n needs n − 1 shifts to move its most significant 1 down to bit 0.
3. That last remaining 1 is removed by a subtraction (already counted in item 1), not by another shift, so the shift count is bit length − 1 rather than bit length.
#### C code
```c
int numberOfSteps(int num) {
    if (num == 0) return 0;
    int bits = 32 - __builtin_clz(num);
    int ones = __builtin_popcount(num);
    return (bits - 1) + ones;
}
```
The `numberOfSteps` function calculates how many steps are needed to reduce a number to zero.
It utilizes two GCC/Clang built-in functions:
* `__builtin_clz()` counts the number of leading zeros in the 32-bit binary form of the input, which gives the position of the most significant bit. Its result is undefined for 0, which is why the function returns early when `num == 0`.
* `__builtin_popcount()` counts how many bits are set to 1 in `num`, which equals the number of subtraction steps.

Both are compiler intrinsics provided by GCC and Clang, and they are often translated directly into efficient CPU instructions.
#### Assembly Code
#### CLZ
``` asm
#////////////////////////////////////////////
#
#   CLZ Function
#
#   input a0 = x, return a0 = orig x, a1 = clz_result
#
#////////////////////////////////////////////
CLZ:
    addi sp, sp, -8
    sw ra, 0(sp)
    sw a0, 4(sp)
    addi t0, x0, 32         # n = 32
    addi t1, x0, 16         # c = 16
CLZ_LOOP:
    srl t2, a0, t1          # t2 = y = x >> c
    beq t2, x0, CLZ_LOWER   # if (y == 0) goto lower
    sub t0, t0, t1          # n -= c
    mv a0, t2               # x = y
CLZ_LOWER:
    srli t1, t1, 1          # c >>= 1
    bne t1, x0, CLZ_LOOP
    sub a1, t0, a0          # return = n - x
    lw ra, 0(sp)
    lw a0, 4(sp)
    addi sp, sp, 8
    ret
```
##### Key Idea
This is a **binary-search-style** implementation of CLZ: it finds how many zeros precede the most significant 1 bit in five halving steps (c = 16, 8, 4, 2, 1) **without testing all 32 bits one by one**.
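The halving search above can be mirrored in C to make the register roles easier to follow. This is a minimal sketch, not the author's reference code: the names `clz32` and `clz32_self_check` are chosen here for illustration, with `n`, `c`, `x`, and `y` matching the comments in the assembly.

```c
#include <assert.h>
#include <stdint.h>

/* Count leading zeros by halving the probe width: 16, 8, 4, 2, 1.
 * n corresponds to t0, c to t1, y to t2 in the assembly routine. */
unsigned clz32(uint32_t x)
{
    unsigned n = 32, c = 16;
    do {
        uint32_t y = x >> c;   /* upper part of the current window */
        if (y != 0) {          /* the most significant 1 is in the upper part */
            n -= c;
            x = y;
        }
        c >>= 1;
    } while (c != 0);
    return n - x;              /* x has been reduced to 0 or 1 */
}

/* Hypothetical self-check helper; call it from main() if desired. */
void clz32_self_check(void)
{
    assert(clz32(0u) == 32);
    assert(clz32(1u) == 31);
    assert(clz32(0xFFu) == 24);
    assert(clz32(0x80000000u) == 0);
}
```

For the input `0xFF` used later in the pipeline section, the loop skips c = 16 and c = 8, then takes the upper half at c = 4, 2, and 1, leaving n = 25 and x = 1, so the result is 24.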
#### POPCOUNT
``` asm
#////////////////////////////////////////////
#
#   POPCOUNT Function
#
#   input a0 = x, return a0 = popcount result
#
#////////////////////////////////////////////
POPCOUNT:
    li t1, 0x55555555       # t1 = 0x55555555
    and t2, a0, t1          # t2 = u & 0x55555555
    srli t3, a0, 1          # t3 = u >> 1
    and t3, t3, t1          # t3 = (u >> 1) & 0x55555555
    add a0, t2, t3          # u = (u & 0x55555555) + ((u >> 1) & 0x55555555)

    li t1, 0x33333333       # t1 = 0x33333333
    and t2, a0, t1          # t2 = u & 0x33333333
    srli t3, a0, 2          # t3 = u >> 2
    and t3, t3, t1          # t3 = (u >> 2) & 0x33333333
    add a0, t2, t3          # u = (u & 0x33333333) + ((u >> 2) & 0x33333333)

    li t1, 0x0F0F0F0F       # t1 = 0x0F0F0F0F
    and t2, a0, t1          # t2 = u & 0x0F0F0F0F
    srli t3, a0, 4          # t3 = u >> 4
    and t3, t3, t1          # t3 = (u >> 4) & 0x0F0F0F0F
    add a0, t2, t3          # u = (u & 0x0F0F0F0F) + ((u >> 4) & 0x0F0F0F0F)

    li t1, 0x00FF00FF       # t1 = 0x00FF00FF
    and t2, a0, t1          # t2 = u & 0x00FF00FF
    srli t3, a0, 8          # t3 = u >> 8
    and t3, t3, t1          # t3 = (u >> 8) & 0x00FF00FF
    add a0, t2, t3          # u = (u & 0x00FF00FF) + ((u >> 8) & 0x00FF00FF)

    li t1, 0x0000FFFF       # t1 = 0x0000FFFF
    and t2, a0, t1          # t2 = u & 0x0000FFFF
    srli t3, a0, 16         # t3 = u >> 16
    and t3, t3, t1          # t3 = (u >> 16) & 0x0000FFFF
    add a0, t2, t3          # u = (u & 0x0000FFFF) + ((u >> 16) & 0x0000FFFF)
    ret
```
##### POPCOUNT Algorithm Explanation
It implements the parallel bit count (Hamming weight) algorithm, which sums bits in progressively wider groups **using only bitwise operations and additions, with no loops**.
1. Pairwise summation (every 2 bits): u = (u & 0x55555555) + ((u >> 1) & 0x55555555) → counts 1s in every 2-bit group.
2. Group 4 bits: u = (u & 0x33333333) + ((u >> 2) & 0x33333333) → accumulates results over 4-bit chunks.
3. Group 8 bits: u = (u & 0x0F0F0F0F) + ((u >> 4) & 0x0F0F0F0F)
4. Group 16 bits: u = (u & 0x00FF00FF) + ((u >> 8) & 0x00FF00FF)
5. Group 32 bits: u = (u & 0x0000FFFF) + ((u >> 16) & 0x0000FFFF) → final popcount result.
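The five steps listed above can be written directly in C and compared with a bit-by-bit count. This is a sketch for cross-checking only; `popcount32`, `popcount_naive`, and `popcount32_self_check` are illustrative names that do not appear in the original program.

```c
#include <assert.h>
#include <stdint.h>

/* Parallel bit count: each line doubles the width of the summed groups. */
uint32_t popcount32(uint32_t u)
{
    u = (u & 0x55555555u) + ((u >> 1)  & 0x55555555u);  /* 2-bit sums  */
    u = (u & 0x33333333u) + ((u >> 2)  & 0x33333333u);  /* 4-bit sums  */
    u = (u & 0x0F0F0F0Fu) + ((u >> 4)  & 0x0F0F0F0Fu);  /* 8-bit sums  */
    u = (u & 0x00FF00FFu) + ((u >> 8)  & 0x00FF00FFu);  /* 16-bit sums */
    u = (u & 0x0000FFFFu) + ((u >> 16) & 0x0000FFFFu);  /* 32-bit sum  */
    return u;
}

/* Reference implementation: test one bit per iteration. */
uint32_t popcount_naive(uint32_t u)
{
    uint32_t n = 0;
    while (u != 0) {
        n += u & 1u;
        u >>= 1;
    }
    return n;
}

/* Hypothetical self-check helper comparing both versions. */
void popcount32_self_check(void)
{
    const uint32_t samples[] = {0u, 1u, 123u, 0x80000001u, 0xFFFFFFFFu};
    for (unsigned i = 0; i < sizeof samples / sizeof samples[0]; i++)
        assert(popcount32(samples[i]) == popcount_naive(samples[i]));
}
```

The masks are needed at every step because the addition works on fields packed side by side in one word; masking keeps each field's sum from mixing with its neighbour.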
>Example: u = (u & 0x55555555) + ((u >> 1) & 0x55555555)
>
>1: (u & 0x55555555)
>
>Keeps only the bits in even positions (0, 2, 4, 6, ...). The mask 0x55555555 in binary is 0b_0101_0101_0101..., so every bit in an odd position is cleared to 0.
>
>2: (u >> 1)
>
>Shifts all bits in u right by 1 position. This moves each odd-position bit into the place of its neighboring even bit, so the two can be added together.
>
>3: ((u >> 1) & 0x55555555)
>
>After shifting, applies the same mask to keep only the shifted bits that were originally in odd positions.
>
>4: Add the two results together:
>
>Sums the two bits of each pair. Each 2-bit group now holds the number of 1s in those two bits (possible results per group: 00, 01, 10).

#### NumberOfSteps
``` asm
NUMBEROFSTEPS:
    addi sp, sp, -4
    sw ra, 0(sp)            # save return address
    beqz a0, RETURN_ZERO
    jal ra, CLZ             # a0 = original num, a1 = clz_result
    addi a2, x0, 32
    sub a1, a2, a1          # a1 = 32 - clz(num)
    jal ra, POPCOUNT        # a0 = popcount(num)
    add a0, a0, a1          # steps = popcount(num) + (32 - clz(num))
    addi a0, a0, -1         # steps -= 1
    lw ra, 0(sp)            # restore return address
    addi sp, sp, 4
    ret
RETURN_ZERO:
    lw ra, 0(sp)            # restore return address
    addi sp, sp, 4
    addi a0, x0, 0          # steps = 0
    ret
```
### Test Function
``` asm
#////////////////////////////////////////////
#
#   TEST Function
#
#////////////////////////////////////////////
TEST:
    addi sp, sp, -4         # allocate one stack slot (the stack grows downward)
    sw ra, 0(sp)            # save return address
    la s0, TEST_DATA
    la s1, TEST_RESULT
    li s2, 20               # number of test cases
    li s3, 0                # current test case index
TEST_LOOP:
    beq s3, s2, TEST_ALL_PASSED  # if (i == number of test cases) all passed
    lw a0, 0(s0)            # load test data
    jal ra, NUMBEROFSTEPS   # call numberOfSteps
    lw t0, 0(s1)            # load expected result
    bne a0, t0, TEST_FAIL   # if (result != expected) fail
    addi s0, s0, 4          # move to next test data
    addi s1, s1, 4          # move to next expected result
    addi s3, s3, 1          # i++
    jal x0, TEST_LOOP
TEST_ALL_PASSED:
    la a0, TEST_OK_MSG
    li a7, 4
    ecall
    lw ra, 0(sp)            # restore return address
    addi sp, sp, 4
    ret
TEST_FAIL:
    la a0, TEST_FAIL_MSG
    li a7, 4
    ecall
    lw ra, 0(sp)            # restore return address
    addi sp, sp, 4
    ret
```
### Test Case
| Test Input | Expected Output |
|------------|-----------------|
| 0          | 0               |
| 1          | 1               |
| 2          | 2               |
| 3          | 3               |
| 4          | 3               |
| 7          | 5               |
| 8          | 4               |
| 14         | 6               |
| 15         | 7               |
| 16         | 5               |
| 31         | 9               |
| 32         | 6               |
| 63         | 11              |
| 64         | 7               |
| 123        | 12              |
| 255        | 15              |
| 256        | 9               |
| 1023       | 19              |
| 1024       | 11              |
| 0xFFFFFFFF | 63              |

### Performance Comparison Between Two Solutions
Both Solution 1 (Loop-Based) and Solution 2 (CLZ + POPCOUNT) were tested under the same test function and identical test cases. This ensures a fair comparison, as both implementations were evaluated using the same input data, expected outputs, and verification process. Therefore, the observed difference in cycles, CPI, and IPC reflects only the efficiency of each algorithm's logic, not variations in testing conditions.
#### Solution 1: Loop Based
![image](https://hackmd.io/_uploads/S1Zr9exalg.png)
#### Solution 2: CLZ + POPCOUNT
![image](https://hackmd.io/_uploads/Sy0KbleTgg.png)
#### Conclusion
Solution 2 (CLZ + POPCOUNT) executes fewer cycles and has a lower CPI, showing higher efficiency. However, Solution 1 (loop-based) is simpler and easier to understand, though slower in performance.

## 5-Stage Pipeline CPU
In this section, we analyze how a RISC-V instruction is executed in a 5-stage pipelined CPU. The five classic pipeline stages are:
1. IF (Instruction Fetch) – Retrieve instruction from instruction memory.
2. ID (Instruction Decode / Register Fetch) – Decode the instruction and read source registers.
3. EX (Execute) – Perform arithmetic or logical operations in the ALU.
4. MEM (Memory Access) – Access data memory if needed (load/store).
5. WB (Write Back) – Write the computation result back to the destination register.

To demonstrate how each stage operates, I use the CLZ function as an example. The function uses R-type, I-type, S-type, and B-type instructions, covering logical shifts, arithmetic operations, loads and stores, and branches. By examining these instructions, we can observe how data flows through each pipeline stage, and how hazards may occur and be resolved.
### Assembly Code
Below is the assembly implementation of the CLZ function:
``` asm
main:
    addi a0, x0, 0xff
    jal ra, CLZ
    mv a0, a1
    li a7, 1
    ecall
    li a7, 10
    ecall

#////////////////////////////////////////////
#
#   CLZ Function
#
#   input a0 = x, output a0 = orig x, a1 = clz_result
#
#////////////////////////////////////////////
CLZ:
    addi sp, sp, -8
    sw ra, 0(sp)
    sw a0, 4(sp)
    addi t0, x0, 32         # n = 32
    addi t1, x0, 16         # c = 16
CLZ_LOOP:
    srl t2, a0, t1          # t2 = y = x >> c
    beq t2, x0, CLZ_LOWER   # if (y == 0) goto lower
    sub t0, t0, t1          # n -= c
    mv a0, t2               # x = y
CLZ_LOWER:
    srli t1, t1, 1          # c >>= 1
    bne t1, x0, CLZ_LOOP
    sub a1, t0, a0          # return = n - x
    lw ra, 0(sp)
    lw a0, 4(sp)
    addi sp, sp, 8
    ret
```
### Ripes Executable Code
The following machine-level disassembly was generated by Ripes after assembling the CLZ program. Each line shows the instruction address, the encoded 32-bit machine code, and the corresponding assembly instruction recognized by the simulator.
``` asm
00000000 <main>:
    0:        0ff00513        addi x10 x0 255
    4:        018000ef        jal x1 24 <CLZ>
    8:        00058513        addi x10 x11 0
    c:        00100893        addi x17 x0 1
    10:       00000073        ecall
    14:       00a00893        addi x17 x0 10
    18:       00000073        ecall

0000001c <CLZ>:
    1c:       ff810113        addi x2 x2 -8
    20:       00112023        sw x1 0 x2
    24:       00a12223        sw x10 4 x2
    28:       02000293        addi x5 x0 32
    2c:       01000313        addi x6 x0 16

00000030 <CLZ_LOOP>:
    30:       006553b3        srl x7 x10 x6
    34:       00038663        beq x7 x0 12 <CLZ_LOWER>
    38:       406282b3        sub x5 x5 x6
    3c:       00038513        addi x10 x7 0

00000040 <CLZ_LOWER>:
    40:       00135313        srli x6 x6 1
    44:       fe0316e3        bne x6 x0 -20 <CLZ_LOOP>
    48:       40a285b3        sub x11 x5 x10
    4c:       00012083        lw x1 0 x2
    50:       00412503        lw x10 4 x2
    54:       00810113        addi x2 x2 8
    58:       00008067        jalr x0 x1 0
```
### Instruction Analysis Example – sw ra, 0(sp)
The instruction analyzed here is taken from the CLZ function:
``` asm
sw ra, 0(sp)
```
It stores the return address (ra) into the memory location pointed to by the stack pointer (sp), so that the function can restore the return address before executing ret. It is an S-type (store) instruction, which writes a 32-bit word from a source register into memory using base-address-plus-offset addressing.
#### Instruction Type
- Format: `sw rs2, imm(rs1)`
- Type: `S-type (Store)`
- Main purpose: Save data from a register into memory.

Field layout and values:
```
 31       25 24   20 19   15 14  12 11       7 6       0
+-----------+-------+-------+------+----------+---------+
| imm[11:5] |  rs2  |  rs1  |funct3| imm[4:0] | opcode  |
+-----------+-------+-------+------+----------+---------+
|  0000000  |  x1   |  x2   | 010  |  00000   | 0100011 |
               ra      sp      SW     imm=0      store
```
#### IF
![image](https://hackmd.io/_uploads/SJM4AP5axx.png)

In this stage, the CPU fetches the instruction from memory using the current program counter (PC). From the diagram, the PC = `0x00000020`, and the instruction read from Instr. memory is `0x00112023`, which corresponds to `sw x1, 0(x2)` in the disassembly. The Compressed Decoder checks whether the instruction is a 16-bit compressed type.
Since the input flag is `0x1`, it indicates a normal 32-bit instruction, so no decompression is required. At the same time, the adder calculates the next PC as `0x00000024`, preparing for the next fetch. Finally, the IF/ID pipeline register stores the current PC `0x00000020`, the next PC `0x00000024`, and the fetched instruction `0x00112023`, which will be decoded in the next stage.
#### ID
![image](https://hackmd.io/_uploads/rkdYetcTgl.png)

In the Instruction Decode (ID) stage, the instruction `0x00112023` is decoded as `sw x1, 0(x2)`. The Decode unit extracts the opcode `0x23`, indicating a store-type instruction. It also checks the `funct3` field, which has the value `0x2`, confirming that this specific instruction is `sw` (store word). The Control unit then generates the control signals for the store operation, including enabling `MemWrite`. The lower control signal labeled `IMM` corresponds to ALUSrc, which tells the ALU to use the immediate value (offset) as its second operand instead of a second register value. In addition, the ALU operation control is set to `ADD`, so the ALU adds base and offset to compute the effective memory address.

The Register File reads the values of the two registers:
- rs1: x2 (sp) = `0x7FFFFFF0` → base address for memory access
- rs2: x1 (ra) = `0x00000008` → data to be stored

The Immediate Generator produces the offset value `0x00000000` from the instruction's immediate field. All of this information (register data, control signals, and immediate) is then passed into the ID/EX pipeline register, ready for the address computation in the EX stage.
#### EX
![image](https://hackmd.io/_uploads/rk8-ktc6ee.png)

In the Execute (EX) stage, the CPU performs the address calculation required for the store instruction.
From the diagram, the ALU receives two inputs:
- Op1 = x2 = `0x7FFFFFF0` (base address)
- Op2 = Immediate = `0x00000000` (offset)

The `IMM` control signal, generated in the ID stage, ensures that the ALU uses the immediate value instead of a second register. The ALU control signal specifies an add operation, producing the effective memory address
```
0x7FFFFFF0 + 0x00000000 = 0x7FFFFFF0.
```
The computed result `0x7FFFFFF0` will be used as the memory address in the next stage (`MEM`), while the data to be stored (x1 = `0x00000008`) is also passed along through the EX/MEM register.
#### MEM
![image](https://hackmd.io/_uploads/rkawMtcaee.png)

In the MEM stage, the CPU performs the actual data memory operation. From the diagram, the Data Memory module receives two inputs:
- Address = `0x7FFFFFF0`, which comes from the ALU result calculated in the previous EX stage.
- Data in = `0x00000008`, the value stored in register x1 (ra).

The MemWrite control signal is enabled (`WrEn` = 1), allowing the processor to write the data into memory at the specified address. Since this instruction is a store, no data is read from memory (`Read out` = `0x00000000`). The ALU result and memory output still pass into the MEM/WB pipeline register, but for `sw` the write-back path remains unused.
#### WB
![image](https://hackmd.io/_uploads/HJ8qrF56ee.png)

In the WB stage, the processor determines whether data should be written back to the register file. The `sw` (store word) instruction has no destination register, since it only writes data to memory and does not produce a value that needs to be stored in a register. From the diagram, the multiplexer selects the ALU result `0x7FFFFFF0` as the output path, but the RegWrite control signal is disabled (`0x0`). As a result, no data is written back to the register file.
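
The S-type field values quoted in this walkthrough can be reproduced by slicing the machine word `0x00112023` by hand. The following C sketch does that; `stype_fields`, `decode_stype`, and `decode_stype_self_check` are hypothetical helper names introduced only for this illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical container for the fields of an S-type instruction. */
typedef struct {
    uint32_t opcode;  /* bits 6:0   */
    uint32_t funct3;  /* bits 14:12 */
    uint32_t rs1;     /* bits 19:15 */
    uint32_t rs2;     /* bits 24:20 */
    int32_t  imm;     /* imm[11:5] | imm[4:0], sign-extended */
} stype_fields;

stype_fields decode_stype(uint32_t insn)
{
    stype_fields f;
    f.opcode = insn & 0x7Fu;
    f.funct3 = (insn >> 12) & 0x7u;
    f.rs1    = (insn >> 15) & 0x1Fu;
    f.rs2    = (insn >> 20) & 0x1Fu;
    /* Reassemble the split 12-bit immediate, then sign-extend it. */
    uint32_t imm12 = ((insn >> 25) << 5) | ((insn >> 7) & 0x1Fu);
    f.imm = (int32_t)((imm12 ^ 0x800u) - 0x800u);
    return f;
}

/* Checks the values from the walkthrough: sw x1, 0(x2). */
void decode_stype_self_check(void)
{
    stype_fields f = decode_stype(0x00112023u);
    assert(f.opcode == 0x23);
    assert(f.funct3 == 0x2);
    assert(f.rs1 == 2);   /* x2 = sp */
    assert(f.rs2 == 1);   /* x1 = ra */
    assert(f.imm == 0);
}
```

The same function applied to `0x00a12223` (the next instruction in the disassembly, `sw x10 4 x2`) yields rs2 = 10 and imm = 4, which shows how the offset lands in the `imm[4:0]` field.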
## UF8 ### CLZ ``` asm #//////////////////////////////////////////// # # CLZ Function # # input a0 = x, return a0 = clz_result, a1 = original x # #//////////////////////////////////////////// CLZ: addi sp, sp, -8 sw ra, 4(sp) sw a0, 0(sp) # input x addi t0, x0, 32 # t0 = n addi t1, x0, 16 # t1 = c CLZ_LOOP: srl t2, a0, t1 # t2 = x >> c = y beq t2, x0, CLZ_MSB_IN_LOWER_HALF # if y == 0 sub t0, t0, t1 # n = n - c add a0, t2, x0 # x = y CLZ_MSB_IN_LOWER_HALF: srli t1, t1, 1 bne t1, x0, CLZ_LOOP sub a0, t0, a0 # return lw ra, 4(sp) lw a1, 0(sp) # input x addi sp, sp, 8 ret # a0 clz_result, a1 input ``` ### UF8_ENCODE ``` asm #//////////////////////////////////////////// # # UF8_ENCODE Function # # input a0 = x, output a0 = encode_result # #//////////////////////////////////////////// UF8_ENCODE: addi sp, sp, -8 sw ra, 4(sp) sw a0, 0(sp) addi t0, x0, 16 blt a0, t0, ENCODE_RET # if x < 16 return a0 jal ra, CLZ addi t0, x0, 31 sub t0, t0, a0 # t0 = msb = 31 - clz_result lw a0, 0(sp) # reload a0 = value addi t1, x0, 0 # t1 = exponent = 0 addi t4, x0, 0 # t4 = overflow = 0 addi t3, x0, 5 blt t0, t3, ENCODE_FIND_E # if msb < 5 goto ENCODE_FIND addi t1, t0, -4 # exponent = msb - 4 addi t3, x0, 15 blt t1, t3, SKIP # if exponent < 15 goto SKIP addi t1, x0, 15 # exponent = 15 SKIP: addi t5, x0, 0 # t5 = e = 0 lw a0, 0(sp) # reload a0 = value ENCODE_CAL_OVERFLOW: bge t5, t1, ENCODE_ADJUST # if e >= exponent slli t4, t4, 1 # overflow = overflow << 1 addi t4, t4, 16 # overflow = overflow + 16 addi t5, t5, 1 # e++ jal x0, ENCODE_CAL_OVERFLOW ENCODE_ADJUST: bge x0, t1, ENCODE_FIND_E # if 0 >= exponent bge a0, t4, ENCODE_FIND_E # if value >= overflow addi t4, t4, -16 # overflow = overflow - 16 srli t4, t4, 1 # overflow = overflow >> 1 addi t1, t1, -1 # exponent-- jal x0, ENCODE_ADJUST ENCODE_FIND_E: addi t3, x0, 15 bge t1, t3, ENCODE_COMBINE # if exponent >= 15 break slli t6, t4, 1 # t6 = next_overflow = overflow << 1 addi t6, t6, 16 # next_overflow += 16 blt a0, t6, ENCODE_COMBINE # 
if value < next_overflow break addi t1, t1, 1 # exponent++ addi t4, t6, 0 # overflow = next_overflow jal x0, ENCODE_FIND_E ENCODE_COMBINE: sub t2, a0, t4 # t2 = value - overflow srl t2, t2, t1 # t2 = mantissa = (value - overflow) >> exponent slli t1, t1, 4 # t1 = exponent << 4 or a0, t1, t2 ENCODE_RET: lw ra, 4(sp) lw a1, 0(sp) addi sp, sp, 8 ret ``` ### UF8_DECODE ``` asm #//////////////////////////////////////////// # # UF8_DECODE Function # # input a0 = x, return a0 = decode_result # #//////////////////////////////////////////// UF8_DECODE: addi sp, sp, -8 sw ra, 4(sp) sw a0, 0(sp) andi t0, a0, 0x0f # t0 = mantissa srli t1, a0, 4 # t1 = exponent addi t2, x0, 15 sub t2, t2, t1 # t2 = 15 - exponent lui t3, 0x8 # 0x8 << 12 = 0x8000 upper 20 bits addi t3, t3, -1 # 0x8000 - 1 = 0x7FFF srl t3, t3, t2 # t3 = 0x7FFF >> (15 - exponent) slli t3, t3, 4 # t3 = t3 << 4 (offset) sll t0, t0, t1 add t0, t0, t3 addi a0, t0, 0 lw ra, 4(sp) lw a1, 0(sp) addi sp, sp, 8 ret # a0 decode_result, a1 input ``` ### Test Function ``` asm #//////////////////////////////////////////// # # Test for UF8_ENCODE and UF8_DECODE # (exhaustive for all 256 possible inputs) # #//////////////////////////////////////////// main: addi a3, x0, 0 # i = 0 addi a4, x0, 256 # limit = 256 addi s3, x0, 1 # passed = true (1) LOOP: bge a3, a4, END # if i >= 256 stop mv a0, a3 # fl = i jal ra, UF8_DECODE mv s1, a0 # value mv a0, s1 jal ra, UF8_ENCODE mv s2, a0 # fl2 bne s2, a3, SET_FAIL # if fl2 != i fail NEXT_ITER: addi a3, a3, 1 jal x0, LOOP SET_FAIL: addi s3, x0, 0 # passed = false addi a3, a4, 0 # force break (i = limit) jal x0, LOOP END: beq s3, x0, EXIT # if !passed skip message la a0, okmsg li a7, 4 ecall EXIT: li a7, 10 ecall ``` #### How the Test Works The goal of this program is to exhaustively test `UF8_ENCODE` and `UF8_DECODE` for all possible 8-bit inputs (0–255). The program ensures that: For every 8-bit input i, UF8_ENCODE(UF8_DECODE(i)) == i. 
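
The round-trip property can also be sketched in C. The decoder below follows the shifts used in `UF8_DECODE`; the encoder is a simplified linear search over the exponent (the loop at `ENCODE_FIND_E`) and deliberately omits the CLZ-based initial estimate, so it illustrates the property rather than transcribing the assembly. All function names here are chosen for illustration.

```c
#include <assert.h>
#include <stdint.h>

/* value = (mantissa << exponent) + offset, offset = 16 * (2^exponent - 1) */
uint32_t uf8_decode(uint8_t fl)
{
    uint32_t mantissa = fl & 0x0Fu;
    uint32_t exponent = fl >> 4;
    uint32_t offset = (0x7FFFu >> (15 - exponent)) << 4;
    return (mantissa << exponent) + offset;
}

/* Simplified encoder: walk the exponent upward from 0 until the next
 * range start (next_overflow) would exceed the value. */
uint8_t uf8_encode(uint32_t value)
{
    uint32_t exponent = 0, overflow = 0;
    while (exponent < 15) {
        uint32_t next_overflow = (overflow << 1) + 16;
        if (value < next_overflow)
            break;
        overflow = next_overflow;
        exponent++;
    }
    uint32_t mantissa = (value - overflow) >> exponent;
    return (uint8_t)((exponent << 4) | mantissa);
}

/* Exhaustive round trip over all 256 codes, like the assembly test. */
void uf8_check_roundtrip(void)
{
    for (uint32_t i = 0; i < 256; i++)
        assert(uf8_encode(uf8_decode((uint8_t)i)) == i);
}
```

Calling `uf8_check_roundtrip()` performs the same comparison as the `LOOP` in the assembly test for every `i` from 0 to 255.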
This verifies the correctness of the encode/decode pair across the full input space.
#### Test Result
![image](https://hackmd.io/_uploads/rJoRvHAhle.png)
![image](https://hackmd.io/_uploads/rkQXWll6ex.png)

## BFloat16
### Conversion Function
#### F32_TO_BF16
``` asm
#////////////////////////////////////////////
#
#   F32_TO_BF16 function
#
#   input: a0 = f32 bits, return bf16 bits in a0
#
#////////////////////////////////////////////
F32_TO_BF16:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    srli t0, a0, 23         # t0 = f32bits >> 23 (sign + exponent)
    andi t0, t0, 0xFF       # t0 = exponent (8 bits)
    addi t1, x0, 0xFF       # t1 = 0xFF
    beq t0, t1, F32_TO_BF16_INF_NAN  # if exponent == 0xFF, go to INF_NAN
    srli t0, a0, 16         # t0 = f32bits >> 16
    andi t0, t0, 1          # t0 = (f32bits >> 16) & 1 = tmp (LSB of the kept bits)
    add a0, a0, t0          # a0 = f32bits + tmp
    li t1, BF16_ALL_MASK    # t1 = 0x7FFF
    add a0, a0, t1          # a0 = f32bits + tmp + 0x7FFF (round to nearest even)
    srli a0, a0, 16         # a0 = rounded value >> 16
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
F32_TO_BF16_INF_NAN:
    srli a0, a0, 16         # a0 = f32bits >> 16
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
#### BF16_TO_F32
``` asm
#////////////////////////////////////////////
#
#   BF16_TO_F32 function
#
#   input: a0 = bf16 bits, return f32 bits in a0
#
#////////////////////////////////////////////
BF16_TO_F32:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    slli a0, a0, 16         # a0 = bf16 << 16
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
### Special Value Check Function
#### BF16_EXP_ALL1
``` asm
#////////////////////////////////////////////
#
#   BF16_EXP_ALL1 function
#
#   input: a0 = x, return 1 if exponent is all 1 else 0
#
#////////////////////////////////////////////
BF16_EXP_ALL1:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    li t0, BF16_EXP_MASK
    and t1, a0, t0          # t1 = a0 & BF16_EXP_MASK
    bne t1, t0, BF16_EXP_NOTALL1  # if (a0 & BF16_EXP_MASK) != BF16_EXP_MASK, go to NOTALL1
    addi a0, x0, 1          # exponent is all 1 return 1
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
BF16_EXP_NOTALL1:
    addi a0, x0, 0          # exponent is not all 1 return 0
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
#### BF16_MANT_NOT0
``` asm
#////////////////////////////////////////////
#
#   BF16_MANT_NOT0 function
#
#   input: a0 = x, return 1 if mantissa is not 0 else 0
#
#////////////////////////////////////////////
BF16_MANT_NOT0:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    li t0, BF16_MANT_MASK   # t0 = BF16_MANT_MASK
    and t1, a0, t0          # t1 = a0 & BF16_MANT_MASK
    beq t1, x0, BF16_MANT0
    addi a0, x0, 1          # mantissa is not 0 return 1
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
BF16_MANT0:
    lw ra, 4(sp)
    addi sp, sp, 8
    addi a0, x0, 0          # mantissa is 0 return 0
    ret
```
#### BF16_ISNAN
``` asm
#////////////////////////////////////////////
#
#   BF16_ISNAN function
#
#   input: a0 = x, return 1 if x is NaN else 0
#
#////////////////////////////////////////////
BF16_ISNAN:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    jal ra, BF16_EXP_ALL1
    beq a0, x0, BF16_NOTNAN # if exponent is not all 1, go to NOTNAN
    lw a0, 0(sp)            # load original a0
    jal ra, BF16_MANT_NOT0
    beq a0, x0, BF16_NOTNAN # if mantissa is 0, go to NOTNAN
    addi a0, x0, 1          # is NaN return 1
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
BF16_NOTNAN:
    addi a0, x0, 0          # not NaN return 0
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
#### BF16_ISINF
``` asm
#////////////////////////////////////////////
#
#   BF16_ISINF function
#
#   input: a0 = x, return 1 if x is Inf else 0
#
#////////////////////////////////////////////
BF16_ISINF:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    jal ra, BF16_EXP_ALL1
    beq a0, x0, BF16_NOTINF # if exponent is not all 1, go to NOTINF
    lw a0, 0(sp)            # load original a0
    jal ra, BF16_MANT_NOT0  # a0 = 1 if mantissa is not 0 else 0
    bne a0, x0, BF16_NOTINF # if mantissa is not 0, go to NOTINF
    addi a0, x0, 1          # is Inf return 1
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
BF16_NOTINF:
    addi a0, x0, 0          # not Inf return 0
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
#### BF16_ISZERO
``` asm
#////////////////////////////////////////////
#
#   BF16_ISZERO function
#
#   input: a0 = x, return 1 if x is zero else 0
#
#////////////////////////////////////////////
BF16_ISZERO:
    addi sp, sp, -8
    sw ra, 4(sp)
    sw a0, 0(sp)
    li t0, BF16_ALL_MASK
    and t1, a0, t0          # t1 = a0 & BF16_ALL_MASK
    bne t1, x0, BF16_NOTZERO  # if (a0 & BF16_ALL_MASK) != 0, go to NOTZERO
    addi a0, x0, 1          # is zero return 1
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
BF16_NOTZERO:
    addi a0, x0, 0          # not zero return 0
    lw ra, 4(sp)
    addi sp, sp, 8
    ret
```
### Arithmetic Function
#### BF16_ADD
``` asm
#////////////////////////////////////////////
#
#   BF16_ADD function
#
#   input: a0 = a, a1 = b
#   return: a0 = a + b
#
#////////////////////////////////////////////
BF16_ADD:
    addi sp, sp, -12
    sw ra, 8(sp)
    sw a1, 4(sp)            # 4(sp) = b
    sw a0, 0(sp)            # 0(sp) = a
    srli t1, a0, 15         # t1 = sign_a
    srli t2, a1, 15         # t2 = sign_b
    srli t3, a0, 7
    srli t4, a1, 7
    andi t3, t3, 0xFF       # t3 = exp_a
    andi t4, t4, 0xFF       # t4 = exp_b
    andi t5, a0, 0x7F       # t5 = mant_a
    andi t6, a1, 0x7F       # t6 = mant_b
    addi t0, x0, 0xFF
    beq t3, t0, 1f          # if exp_a == 0xFF, go to 1
    beq t4, t0, RETURN_B    # if exp_b == 0xFF, go to RETURN_B
    beq t3, x0, 3f          # if exp_a == 0, go to 3
ADD_A_NOT_ZERO:
    beq t4, x0, 4f          # if exp_b == 0, go to 4
ADD_B_NOT_ZERO:
    bne t3, x0, 5f          # if exp_a != 0, go to 5
ADD_EXP_A_NOT_ZERO:
    bne t4, x0, 6f          # if exp_b != 0, go to 6
    jal x0, ADD_EXP_B_NOT_ZERO
1:
    bne t5, x0, RETURN_A    # if mant_a != 0 (a is NaN), go to RETURN_A
    beq t4, t0, 2f          # if exp_b == 0xFF, go to 2 (compare exp_b, not mant_b)
    jal x0, RETURN_A        # else return a
2:
    bne t6, x0, RETURN_B    # if mant_b != 0, go to RETURN_B
    beq t1, t2, RETURN_B    # if sign_a == sign_b, go to RETURN_B
    jal x0, RETURN_NAN      # else return NAN
3:
    beq t5, x0, RETURN_B    # if mant_a == 0, go to RETURN_B
    jal x0, ADD_A_NOT_ZERO
4:
    beq t6, x0, RETURN_A    # if mant_b == 0, go to RETURN_A
    jal x0, ADD_B_NOT_ZERO
5:
    ori t5, t5, 0x80        # mant_a = mant_a | 0x80
    jal x0, ADD_EXP_A_NOT_ZERO
6:
    ori t6, t6, 0x80        # mant_b = mant_b | 0x80
    jal x0, ADD_EXP_B_NOT_ZERO
ADD_EXP_B_NOT_ZERO:
    sub a0, t3, t4          # a0 = exp_a - exp_b
    bgt a0, x0, 1f          # if exp_a > exp_b, go to 1
    blt a0, x0, 2f          # if exp_a < exp_b, go to 2
    addi a3, t3, 0          # a3 = result_exp = exp_a
    jal x0, ADD_EXP_DONE
1:
    addi a3, t3, 0          # a3 = result_exp = exp_a
    addi t0, x0, 8
    bgt a0, t0,
RETURN_A # if exp_a - exp_b > 8, go to RETURN_A srl t6, t6, a0 # mant_b = mant_b >> (exp_a - exp_b) jal x0, ADD_EXP_DONE 2: addi a3, t4, 0 addi t0, x0, -8 blt a0, t0, RETURN_B # if exp_a - exp_b < -8, go to RETURN_B sub a0, x0, a0 # a0 = -a0 srl t5, t5, a0 # mant_a = mant_a >> (exp_b - exp_a) jal x0, ADD_EXP_DONE ADD_EXP_DONE: bne t1, t2, ADD_DIFF_SIGN # if sign_a != sign_b, go to ADD_DIFF_SIGN ADD_SAME_SIGN: addi a2, t1, 0 # a2 = result_sign = sign_a add a4, t5, t6 # a4 = mant_a + mant_b = result_mant andi t0, a4, 0x100 # t0 = (mant_a + mant_b) & 0x100 bne t0, x0, 1f jal x0, RETURN_ADD 1: srli a4, a4, 1 # result_mant = result_mant >> 1 addi a3, a3, 1 # result_exp = result_exp + 1 addi t0, x0,0xFF beq a3, t0, 2f # if result_exp == 0xFF, go to 2 bgt a3, t0, 2f # if result_exp >= 0xFF, go to 2 jal x0, RETURN_ADD 2: slli a0, a2, 15 # return (bf16_t) {.bits = (result_sign << 15) | 0x7F80}; li t0, BF16_EXP_MASK or a0, a0, t0 lw ra, 8(sp) addi sp, sp, 12 ret ADD_DIFF_SIGN: beq t5, t6, 2f # if mant_a == mant_b, go to 2 bgt t5, t6, 2f # if mant_a > mant_b, go to 2 addi a2, t2, 0 # a2 = result_sign = sign_b sub a4, t6, t5 # a4 = mant_b - mant_a = result_mant 1: beq a4, x0, RETURN_ZERO # if result_mant == 0, go to RETURN_ZERO 3: andi t0, a4, 0x80 # t0 = result_mant & 0x80 beq t0, x0, 4f jal x0, RETURN_ADD 4: slli a4, a4, 1 # result_mant = result_mant << 1 addi a3, a3, -1 # result_exp = result_exp - 1 bgt a3, x0, 3b # if result_exp > 0, go to 3 jal x0, RETURN_ZERO # else go to RETURN_ZERO 2: addi a2, t1, 0 # a2 = result_sign = sign_a sub a4, t5, t6 # a4 = mant_a - mant_b = result_mant jal x0, 1b lw ra, 8(sp) addi sp, sp, 12 ret RETURN_A: lw a0, 0(sp) lw ra, 8(sp) addi sp, sp, 12 ret RETURN_B: lw a0, 4(sp) lw ra, 8(sp) addi sp, sp, 12 ret RETURN_NAN: li a0, 0x7FC0 lw ra, 8(sp) addi sp, sp, 12 ret RETURN_ZERO: addi a0, x0, 0 lw ra, 8(sp) addi sp, sp, 12 ret RETURN_ADD: slli a0, a2, 15 # a0 = result_sign << 15 andi t0, a3, 0xFF slli t0, t0, 7 or a0, a0, t0 # a0 = (result_sign << 
15) | (result_exp << 7)
    andi t0, a4, 0x7F
    or a0, a0, t0           # a0 = (result_sign << 15) | (result_exp << 7) | result_mant
    lw ra, 8(sp)
    addi sp, sp, 12
    ret
```
#### BF16_SUB
``` asm
#////////////////////////////////////////////
#
#   BF16_SUB function
#
#   input: a0 = a, a1 = b
#   return: a0 = a - b
#
#////////////////////////////////////////////
BF16_SUB:
    li t0, BF16_SIGN_MASK
    xor a1, a1, t0          # b = b ^ 0x8000
    jal x0, BF16_ADD
```
#### BF16_MUL
``` asm
#////////////////////////////////////////////
#
#   BF16_MUL function
#
#   input: a0 = a, a1 = b
#   return: a0 = a * b
#
#////////////////////////////////////////////
BF16_MUL:
    addi sp, sp, -12
    sw ra, 8(sp)
    sw a1, 4(sp)            # 4(sp) = b
    sw a0, 0(sp)            # 0(sp) = a
    srli t1, a0, 15         # t1 = sign_a
    srli t2, a1, 15         # t2 = sign_b
    srli t3, a0, 7
    srli t4, a1, 7
    andi t3, t3, 0xFF       # t3 = exp_a
    andi t4, t4, 0xFF       # t4 = exp_b
    andi t5, a0, 0x7F       # t5 = mant_a
    andi t6, a1, 0x7F       # t6 = mant_b
    xor a2, t1, t2          # a2 = result_sign = sign_a ^ sign_b
    addi t0, x0, 0xFF
    beq t3, t0, 1f          # if exp_a == 0xFF, go to 1
    jal x0, CHECK_B
1:                          # exp_a == 0xFF
    bne t5, x0, RETURN_A    # a is NaN
    beq t4, x0, 2f          # exp_b == 0 → check mant_b
    beq t4, t0, 3f          # exp_b == 0xFF → check mant_b
    jal x0, RETURN_INF
2:                          # exp_b == 0
    beq t6, x0, RETURN_NAN  # Inf * 0 -> NaN
    jal x0, RETURN_INF      # Inf * subnormal(nonzero) -> Inf
3:                          # exp_b == 0xFF
    bne t6, x0, RETURN_B    # b == NaN
    jal x0, RETURN_INF      # b == INF
CHECK_B:
    beq t4, t0, 1f          # if exp_b == 0xFF, go to 1
    jal x0, CHECK_ZERO
1:
    bne t6, x0, RETURN_B    # if mant_b != 0, go to RETURN_B
    beq t3, x0, 2f          # if exp_a == 0, go to 2 (compare with x0, not 0xFF)
    jal x0, RETURN_INF
2:
    beq t5, x0, RETURN_NAN  # if mant_a == 0 (0 * Inf), go to RETURN_NAN
    jal x0, RETURN_INF
CHECK_ZERO:
    beq t3, x0, CHECK_A_ZERO  # if exp_a == 0, check mant_a
    beq t4, x0, CHECK_B_ZERO  # if exp_b == 0, check mant_b
    jal x0, CONT_MUL
CHECK_A_ZERO:
    beq t5, x0, RETURN_ZERO_SIGN  # if mant_a == 0 -> return 0
    jal x0, CONT_MUL
CHECK_B_ZERO:
    beq t6, x0, RETURN_ZERO_SIGN  # if mant_b == 0 -> return 0
    jal x0, CONT_MUL
RETURN_ZERO_SIGN:
    slli a0, a2, 15         # a0 = result_sign << 15
    lw ra, 8(sp)
    addi sp, sp, 12
    ret
CONT_MUL:
    addi a3, x0, 0          # a3 = exp_adjust = 0
    beq t3, x0, 1f          # if exp_a == 0, go to 1
    ori t5, t5, 0x80        # mant_a = mant_a | 0x80
    jal x0, 3f
1:
    andi t0, t5, 0x80       # t0 = mant_a & 0x80
    beq t0, x0, 2f          # if mant_a & 0x80 == 0, go to 2
    addi t3, x0, 1          # exp_a = 1
    jal x0, 3f
2:
    slli t5, t5, 1          # mant_a = mant_a << 1
    addi a3, a3, -1         # exp_adjust = exp_adjust - 1
    jal x0, 1b              # go to 1
3:
    beq t4, x0, 4f          # if exp_b == 0, go to 4
    ori t6, t6, 0x80        # mant_b = mant_b | 0x80
    jal x0, SHIFT_ADD
4:
    andi t0, t6, 0x80       # t0 = mant_b & 0x80
    beq t0, x0, 5f          # if mant_b & 0x80 == 0, go to 5
    addi t4, x0, 1          # exp_b = 1
    jal x0, SHIFT_ADD
5:
    slli t6, t6, 1          # mant_b = mant_b << 1
    addi a3, a3, -1         # exp_adjust = exp_adjust - 1
    jal x0, 4b              # go to 4
SHIFT_ADD:
    addi a4, x0, 0          # a4 = result_mant = 0
    addi t0, x0, 8          # counter = 8 (mant_b is 8 bits wide)
SHIFT_LOOP:
    andi t1, t6, 1          # t1 = mant_b & 1
    beq t1, x0, NO_ADD      # if (mant_b & 1) == 0, skip add
    add a4, a4, t5          # result_mant += mant_a
NO_ADD:
    slli t5, t5, 1          # mant_a <<= 1
    srli t6, t6, 1          # mant_b >>= 1
    addi t0, t0, -1         # counter--
    bne t0, x0, SHIFT_LOOP  # loop until counter = 0
    add a3, a3, t3
    add a3, a3, t4          # result_exp = exp_a + exp_b + exp_adjust
    addi a3, a3, -BF16_EXP_BIAS  # result_exp -= 127
    li t0, BF16_SIGN_MASK
    and t0, t0, a4          # t0 = result_mant & 0x8000
    bne t0, x0, 1f
    srli a4, a4, 7          # result_mant >>= 7
    andi a4, a4, 0x7F       # result_mant &= 0x7F
    jal x0, 2f
1:
    srli a4, a4, 8          # result_mant >>= 8
    andi a4, a4, 0x7F       # result_mant &= 0x7F
    addi a3, a3, 1          # result_exp++
2:
    addi t0, x0, 0xFF
    beq a3, t0, RETURN_INF  # if result_exp == 0xFF, go to RETURN_INF
    bgt a3, t0, RETURN_INF  # if result_exp > 0xFF, go to RETURN_INF
3:
    beq a3, x0, 4f          # if result_exp == 0, go to 4
    blt a3, x0, 4f          # if result_exp < 0, go to 4
    jal x0, RETURN_MUL
4:
    addi a0, x0, -6         # a0 = -6
    blt a3, a0, RETURN_ZERO_SIGN  # if result_exp < -6, go to RETURN_ZERO_SIGN
    addi t0, x0, 1
    sub t0, t0, a3          # t0 = 1 - result_exp (computed before result_exp is cleared)
    srl a4, a4, t0          # result_mant >>= (1 - result_exp)
    addi a3, x0, 0          # result_exp = 0
    jal x0, RETURN_MUL      # do not fall through into RETURN_INF
RETURN_INF:                 # return (bf16_t) {.bits = (result_sign << 15) | 0x7F80};
    slli a0, a2, 15
    li t0, BF16_EXP_MASK
    or a0, a0, t0
    lw ra, 8(sp)
    addi sp, sp, 12
    ret
RETURN_MUL:
    slli a0, a2, 15         # a0 = result_sign << 15
    andi t0, a3, 0xFF
    slli t0, t0, 7
    or a0, a0, t0           # a0 = (result_sign << 15) | (result_exp << 7)
    andi t0, a4, 0x7F
    or a0, a0, t0           # a0 = (result_sign << 15) | (result_exp << 7) | result_mant
    lw ra, 8(sp)
    addi sp, sp, 12
    ret
```
Since RV32I has no hardware mul instruction, multiplication is implemented using a shift-add algorithm. In this method, the multiplier's least significant bit (LSB) is checked in each loop iteration:
* If it is 1, the multiplicand is added to the partial result.
* Then the multiplicand is shifted left (<< 1) and the multiplier is shifted right (>> 1).
* This repeats until all 8 bits of the multiplier have been processed.
#### BF16_DIV
``` asm
#////////////////////////////////////////////
#
#   BF16_DIV function
#
#   input: a0 = a (dividend), a1 = b (divisor)
#   return: a0 = a / b
#
#////////////////////////////////////////////
BF16_DIV:
    addi sp, sp, -12
    sw ra, 8(sp)
    sw a1, 4(sp)            # 4(sp) = b
    sw a0, 0(sp)            # 0(sp) = a
    srli t1, a0, 15         # t1 = sign_a
    srli t2, a1, 15         # t2 = sign_b
    srli t3, a0, 7
    srli t4, a1, 7
    andi t3, t3, 0xFF       # t3 = exp_a
    andi t4, t4, 0xFF       # t4 = exp_b
    andi t5, a0, 0x7F       # t5 = mant_a
    andi t6, a1, 0x7F       # t6 = mant_b
    xor a2, t1, t2          # a2 = result_sign = sign_a ^ sign_b
    addi t0, x0, 0xFF
    beq t4, t0, 1f          # if exp_b == 0xFF, go to 1
    beq t4, x0, 3f          # if exp_b == 0, go to 3
8:
    beq t3, t0, 6f          # if exp_a == 0xFF, go to 6
    beq t3, x0, 7f          # if exp_a == 0, go to 7
    jal x0, CONT_DIV
1:
    bne t6, x0, RETURN_B    # if mant_b != 0, go to RETURN_B
    beq t3, t0, 2f          # if exp_a == 0xFF, go to 2
    jal x0, RETURN_ZERO_SIGN  # return 0
2:
    bne t5, x0, RETURN_NAN  # if mant_a != 0, go to RETURN_NAN
    jal x0, RETURN_NAN
3:
    beq t6, x0, 4f          # if mant_b == 0, go to 4
    jal x0, 8b              # go to 8
4:
    beq t3, x0, 5f          # if exp_a == 0, go to 5
    jal x0,
RETURN_INF # return inf 5: beq t5, x0, RETURN_NAN # if mant_a == 0, go to return_nan jal x0, RETURN_INF # return inf 6: bne t5, x0, RETURN_A # if mant_a != 0, go to RETURN_A jal x0, RETURN_INF # return inf 7: beq t5, x0, RETURN_ZERO_SIGN # if mant_a == 0, go to RETURN_ZERO_SIGN CONT_DIV: bne t3, x0, 1f # if exp_a != 0, go to 1 3: bne t4, x0, 2f # if exp_b != 0, go to 2 jal x0, DIV_SHIFT 1: ori t5, t5, 0x80 # mant_a = mant_a | 0x80 jal x0, 3b 2: ori t6, t6, 0x80 # mant_b = mant_b | 0x80 DIV_SHIFT: slli t1, t5, 15 # t1 = dividend = mant_a << 15 addi t5, t6, 0 # t5 = divisor = mant_b addi t6, x0, 0 # t6 = quotient = 0 addi t0, x0, 0 # t0 = i = 0 DIV_LOOP: addi t2, x0, 16 beq t0, t2, DIV_END slli t6, t6, 1 # quotient <<= 1 addi t2, x0, 15 sub t2, t2, t0 # t2 = 15 - i sll t3, t5, t2 # t3 = divisor << (15 - i) blt t1, t3, DIV_SKIP sub t1, t1, t3 # dividend -= (divisor << (15 - i)) ori t6, t6, 1 # quotient |= 1 DIV_SKIP: addi t0, t0, 1 jal x0, DIV_LOOP DIV_END: srli t3, a0, 7 andi t3, t3, 0xFF # t3 = exp_a sub a3, t3, t4 # a3 = exp_a - exp_b addi a3, a3, BF16_EXP_BIAS # a3 = result_exp = exp_a - exp_b + 127 beq t3, x0, 1f # if exp_a == 0, go to 1 2: beq t4, x0, 3f # if exp_b == 0, go to 3 jal x0, DIV_EXP_DONE 1: addi a3, a3, -1 # result_exp-- jal x0, 2b # go to 2 3: addi a3, a3, 1 # result_exp++ DIV_EXP_DONE: li t0, BF16_SIGN_MASK and t1, t6, t0 # quotient & 0x8000 bne t1, x0, DIV_SHIFT_DONE 3: beq t6, x0, 1f jal x0, DIV_SHIFT_DONE 1: addi t1, x0, 1 # t1 = 1 bgt a3, t1, 2f # if result_exp > 1, go to 2 jal x0, DIV_SHIFT_DONE 2: slli t6, t6, 1 # quotient <<= 1 addi a3, a3, -1 # result_exp-- jal x0, 3b DIV_SHIFT_DONE: srli t6, t6, 8 # quotient >>= 8 andi t6, t6, 0x7F # t6 = result_mant = quotient & 0x7F addi t0, x0, 0xFF beq a3, t0, RETURN_INF # if result_exp == 0xFF, go to RETURN_INF bgt a3, t0, RETURN_INF # if result_exp > 0xFF, go to RETURN_INF beq a3, x0, RETURN_ZERO_SIGN # if result_exp == 0, go to RETURN_ZERO_SIGN blt a3, x0, RETURN_ZERO_SIGN # if result_exp < 0, go to 
RETURN_ZERO_SIGN RETURN_DIV: slli a0, a2, 15 # a0 = result_sign << 15 andi t0, a3, 0xFF slli t0, t0, 7 or a0, a0, t0 # a0 = (result_sign << 15) | (result_exp << 7) andi t0, t6, 0x7F # t0 = result_mant = quotient & 0x7F or a0, a0, t0 # a0 = (result_sign << 15) | (result_exp << 7) | result_mant lw ra, 8(sp) addi sp, sp, 12 ret ``` #### BF16_SQRT ``` asm #//////////////////////////////////////////// # # BF16_SQRT function # # input: a0(old), return: a0(result) # #//////////////////////////////////////////// BF16_SQRT: addi sp, sp, -8 sw ra, 4(sp) sw a0, 0(sp) # 0(sp) = a srli t1, a0, 15 # t1 = sign srli t2, a0, 7 # t2 = exp andi t2, t2, 0xFF # t2 = exp andi t3, a0, 0x7F # t3 = mant addi t0, x0, 0xFF beq t2, t0, 1f # if exp == 0xFF, go to 1 jal x0, 2f 1: bne t3, x0, RETURN_A_SQRT # if mant != 0, go to RETURN_A bne t1, x0, RETURN_NAN_SQRT # if sign != 0, go to RETURN_NAN jal x0, RETURN_A_SQRT # else return a 2: beq t2, x0, 3f # if exp == 0, go to 3 jal x0, 4f 3: beq t1, x0, RETURN_ZERO_SQRT # if sign == 0, go to RETURN_ZERO_SQRT 4: bne t1, x0, RETURN_NAN_SQRT # if sign != 0, go to RETURN_NAN_SQRT beq t2, x0, RETURN_ZERO_SQRT # if exp == 0, go to RETURN_ZERO_SQRT addi t0, x0, BF16_EXP_BIAS sub t2, t2, t0 # t2 = e = exp - 127 ori t3, t3, 0x80 # t3 = m = mant | 0x80 andi t4, t2, 1 # t4 = e & 1 bne t4, x0, 1f # if (e & 1) != 0, go to 1 jal x0, 2f # else go to 2 1: slli t3, t3, 1 # t3 = m = mant << 1 addi t2, t2, -1 # e = e - 1 srli t2, t2, 1 # e = e >> 1 addi t2, t2, BF16_EXP_BIAS # t2 = new_exp = e + 127 jal x0, ADJUST_DONE 2: srli t2, t2, 1 # e = e >> 1 addi t2, t2, BF16_EXP_BIAS # t2 = new_exp = e + 127 ADJUST_DONE: # t3 = m, t2 = new_exp, t1 = sign addi t0, x0, 90 # low addi t1, x0, 256 # high addi a2, x0, 128 # result BS_LOOP: bgt t0, t1, BS_END add t4, t0, t1 srli t4, t4, 1 # t4 = mid mv a3, t4 # a3 = mid addi a0, x0, 0 # acc = 0 mv t6, a3 # multiplicand = mid mv t5, a3 # multiplier = mid addi a1, x0, 9 # counter = 9 BS_MUL_LOOP: andi t4, t5, 1 # t4 = (multiplier & 1) 
    beq  t4, x0, BS_NO_ADD
    add  a0, a0, t6
BS_NO_ADD:
    slli t6, t6, 1
    srli t5, t5, 1
    addi a1, a1, -1
    bne  a1, x0, BS_MUL_LOOP
    srli a0, a0, 7              # sq = (mid * mid) >> 7
BS_UPDATE:
    bgeu t3, a0, BS_IF          # if sq <= m
    j    BS_ELSE
BS_IF:
    mv   a2, a3                 # result = mid
    addi t0, a3, 1              # low = mid + 1
    j    BS_LOOP
BS_ELSE:
    addi t1, a3, -1             # high = mid - 1
    j    BS_LOOP
BS_END:
    addi a0, x0, 256
    bgeu a2, a0, 1f             # if (result >= 256), go to 1
    addi a0, x0, 128
    addi a1, x0, 1
    blt  a2, a0, 2f             # if (result < 128), go to 2
    jal  x0, 5f
1:
    srli a2, a2, 1              # result >>= 1
    addi t2, t2, 1              # new_exp++
    jal  x0, 5f
2:
    blt  a2, a0, 3f             # if (result < 128), go to 3
    jal  x0, 5f
3:
    bgt  t2, a1, 4f             # if (new_exp > 1), go to 4
    jal  x0, 5f
4:
    slli a2, a2, 1              # result <<= 1
    addi t2, t2, -1             # new_exp--
    jal  x0, 2b
5:
    andi t3, a2, 0x7F           # result_mant = result & 0x7F
    addi a0, x0, 0xFF
    bge  t2, a0, RETURN_INF_SQRT  # if new_exp >= 0xFF, go to RETURN_INF_SQRT
    blt  t2, x0, RETURN_ZERO_SQRT # if new_exp < 0, go to RETURN_ZERO_SQRT
    beq  t2, x0, RETURN_ZERO_SQRT # if new_exp == 0, go to RETURN_ZERO_SQRT
    andi t2, t2, 0xFF           # new_exp = new_exp & 0xFF
    slli t2, t2, 7
    or   a0, t2, t3             # a0 = (new_exp << 7) | result_mant
    lw   ra, 4(sp)
    addi sp, sp, 8
    ret
RETURN_INF_SQRT:
    # return (bf16_t) {.bits = 0x7F80};  (sign is 0 on this path)
    li   a0, BF16_EXP_MASK
    lw   ra, 4(sp)
    addi sp, sp, 8
    ret
RETURN_A_SQRT:
    lw   a0, 0(sp)
    lw   ra, 4(sp)
    addi sp, sp, 8
    ret
RETURN_NAN_SQRT:
    li   a0, 0x7FC0
    lw   ra, 4(sp)
    addi sp, sp, 8
    ret
RETURN_ZERO_SQRT:
    addi a0, x0, 0
    lw   ra, 4(sp)
    addi sp, sp, 8
    ret
```

`BF16_SQRT` (the assembly counterpart of the C function `bf16_sqrt()`) computes the square root with a **binary search** over integer candidates for the result mantissa. Each iteration needs `mid * mid`, scaled down by 128, to compare against the adjusted mantissa `m`. Because RV32I has no multiply instruction, the product is formed with the same **shift-and-add method** used in `BF16_MUL`: the multiplicand is added whenever the current multiplier bit is 1, then both operands are shifted.
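The shift-and-add squaring described above can be sketched in C. This is only an illustrative model, not the author's code: `shift_add_mul` and `scaled_square` are hypothetical helper names, and the 9-iteration count is taken from the `counter = 9` setup in `BS_MUL_LOOP` (enough for `mid` up to 256).

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: multiply `a` by the low `bits` bits of `b`
 * using only shifts and adds, mirroring the structure of BS_MUL_LOOP. */
static uint32_t shift_add_mul(uint32_t a, uint32_t b, unsigned bits)
{
    assert(bits <= 32);
    uint32_t acc = 0;
    for (unsigned i = 0; i < bits; i++) {
        if (b & 1u)
            acc += a;   /* add the shifted multiplicand when the LSB is 1 */
        a <<= 1;        /* multiplicand moves up to the next bit weight */
        b >>= 1;        /* expose the next multiplier bit */
    }
    return acc;
}

/* Hypothetical helper: sq = (mid * mid) / 128, the value compared with m.
 * Nine iterations cover every mid in the search range [90, 256]. */
static uint32_t scaled_square(uint32_t mid)
{
    return shift_add_mul(mid, mid, 9) >> 7;
}
```

For example, `scaled_square(128)` yields 128, which is why a mantissa of `m = 128` (an exact power of four after exponent adjustment) settles on `result = 128`.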
For reference, the binary search loop that `BS_LOOP` implements looks like this in C:

```c
/* Binary search for square root of m */
while (low <= high) {
    uint32_t mid = (low + high) >> 1;
    uint32_t sq = (mid * mid) / 128;  /* Square and scale */
    if (sq <= m) {
        result = mid;                 /* This could be our answer */
        low = mid + 1;
    } else {
        high = mid - 1;
    }
}
```

### Comparison Function

#### BF16_EQ

``` asm
#////////////////////////////////////////////
#
# BF16_Equal function
#
# input: a0 = a, a1 = b, return 1 if a == b else 0
#
#////////////////////////////////////////////
BF16_EQ:
    addi sp, sp, -12
    sw   ra, 8(sp)
    sw   a1, 4(sp)              # 4(sp) = b
    sw   a0, 0(sp)              # 0(sp) = a
    jal  ra, BF16_ISNAN
    bne  a0, x0, 1f             # if a is NaN, go to Not Equal
    lw   a0, 4(sp)              # load b
    jal  ra, BF16_ISNAN
    bne  a0, x0, 1f             # if b is NaN, go to Not Equal
    lw   a0, 0(sp)              # load a
    jal  ra, BF16_ISZERO
    beq  a0, x0, 3f             # if a is not zero, go to compare a and b
    lw   a0, 4(sp)              # load b
    jal  ra, BF16_ISZERO
    beq  a0, x0, 3f             # if b is not zero, go to compare a and b
    jal  x0, 2f                 # both a and b are zero, go to Equal
1:
    addi a0, x0, 0              # Not Equal, return 0
    lw   ra, 8(sp)
    addi sp, sp, 12
    ret
2:
    addi a0, x0, 1              # Equal, return 1
    lw   ra, 8(sp)
    addi sp, sp, 12
    ret
3:
    lw   a0, 0(sp)              # load a
    lw   a1, 4(sp)              # load b
    beq  a0, a1, 2b             # if a == b, go to Equal
    jal  x0, 1b                 # else go to Not Equal
```

#### BF16_LT

``` asm
#////////////////////////////////////////////
#
# BF16_Less_Than function
#
# input: a0 = a, a1 = b, return 1 if a < b else 0
#
#////////////////////////////////////////////
BF16_LT:
    addi sp, sp, -12
    sw   ra, 8(sp)
    sw   a1, 4(sp)              # 4(sp) = b
    sw   a0, 0(sp)              # 0(sp) = a
    jal  ra, BF16_ISNAN
    bne  a0, x0, 1f             # if a is NaN, go to Not Less
    lw   a0, 4(sp)              # load b
    jal  ra, BF16_ISNAN
    bne  a0, x0, 1f             # if b is NaN, go to Not Less
    lw   a0, 0(sp)              # load a
    jal  ra, BF16_ISZERO
    beq  a0, x0, 3f             # if a is not zero, go to check the sign bits
    lw   a0, 4(sp)              # load b
    jal  ra, BF16_ISZERO
    beq  a0, x0, 3f             # if b is not zero, go to check the sign bits
    jal  x0, 1f                 # both a and b are zero, go to Not Less
1:
    addi a0, x0, 0              # Not Less, return 0
    lw   ra, 8(sp)
    addi sp, sp, 12
    ret
2:
    addi a0, x0, 1              # Less, return 1
    lw   ra, 8(sp)
    addi sp, sp, 12
    ret
3:
    lw   a0, 0(sp)              # load a
    lw   a1, 4(sp)              # load b
    srli t0, a0, 15             # t0 = sign_a
    srli t1, a1, 15             # t1 = sign_b
    bne  t0, t1, 4f             # if sign_a != sign_b, go to sign bit compare
    jal  x0, 5f                 # if sign_a == sign_b
4:
    bgt  t0, t1, 2b             # if sign_a > sign_b (a negative, b positive), go to Less
    jal  x0, 1b                 # else go to Not Less
5:
    beq  t0, x0, 7f             # if sign_a == 0, go to both positive
    jal  x0, 6f                 # if sign_a == 1, go to both negative
6:
    bgt  a0, a1, 2b             # both negative: a < b when bits(a) > bits(b)
    jal  x0, 1b                 # else go to Not Less
7:
    blt  a0, a1, 2b             # both positive: a < b when bits(a) < bits(b)
    jal  x0, 1b                 # else go to Not Less
```

#### BF16_GT

``` asm
#////////////////////////////////////////////
#
# BF16_Greater_Than function
#
# input: a0 = a, a1 = b, return 1 if a > b else 0
#
#////////////////////////////////////////////
BF16_GT:
    mv   t0, a0                 # swap a and b, then reuse BF16_LT (a > b iff b < a)
    mv   a0, a1
    mv   a1, t0
    jal  x0, BF16_LT
```

### Test Case

#### FP32 / BF16 Conversion

| # | FP32 (Hex) | FP32 Value | BF16 (Hex) | BF16 Value | BF16 → FP32 (Hex) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x00000000` | 0.0 | `0x0000` | 0.0 | `0x00000000` |
| 2 | `0x3F800000` | 1.0 | `0x3F80` | 1.0 | `0x3F800000` |
| 3 | `0xBF800000` | -1.0 | `0xBF80` | -1.0 | `0xBF800000` |
| 4 | `0x40000000` | 2.0 | `0x4000` | 2.0 | `0x40000000` |
| 5 | `0xC0000000` | -2.0 | `0xC000` | -2.0 | `0xC0000000` |
| 6 | `0x3F000000` | 0.5 | `0x3F00` | 0.5 | `0x3F000000` |
| 7 | `0xBF000000` | -0.5 | `0xBF00` | -0.5 | `0xBF000000` |
| 8 | `0x40490FD0` | 3.14159 | `0x4049` | 3.140625 | `0x40490000` |
| 9 | `0xC0490FD0` | -3.14159 | `0xC049` | -3.140625 | `0xC0490000` |
| 10 | `0x501502F9` | 1e10 | `0x5015` | ≈ 9.9992e9 | `0x50150000` |
| 11 | `0xD01502F9` | -1e10 | `0xD015` | ≈ -9.9992e9 | `0xD0150000` |

#### Special Values

| # | BF16 (Hex) | Description | isNaN | isInf | isZero |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x7F80` | +Inf | 0 | 1 | 0 |
| 2 | `0xFF80` | -Inf | 0 | 1 | 0 |
| 3 | `0x7FC0` | NaN | 1 | 0 | 0 |
| 4 | `0x0000` | +0.0 | 0 | 0 | 1 |
| 5 | `0x8000` | -0.0 | 0 | 0 | 1 |

---

#### BF16 ADD

| # | Operand A (Hex) | Operand B (Hex) | Expression | Expected Result (BF16 Hex) | Result (Value) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x3F80` | `0x4000` | 1.0 + 2.0 | `0x4040` | 3.0 |
| 2 | `0x4049` | `0x402E` | 3.140625 + 2.71875 | `0x40BB` | 5.84375 |
| 3 | `0x3F80` | `0xC000` | 1.0 + (-2.0) | `0xBF80` | -1.0 |
| 4 | `0xC000` | `0x3F80` | -2.0 + 1.0 | `0xBF80` | -1.0 |
| 5 | `0x0000` | `0x3F80` | 0.0 + 1.0 | `0x3F80` | 1.0 |
| 6 | `0x3F80` | `0x0000` | 1.0 + 0.0 | `0x3F80` | 1.0 |
| 7 | `0x7F80` | `0x3F80` | +Inf + 1.0 | `0x7F80` | +Inf |
| 8 | `0x3F80` | `0x7F80` | 1.0 + +Inf | `0x7F80` | +Inf |
| 9 | `0xFF80` | `0x3F80` | -Inf + 1.0 | `0xFF80` | -Inf |
| 10 | `0x3F80` | `0xFF80` | 1.0 + -Inf | `0xFF80` | -Inf |
| 11 | `0x7F62` | `0x7F62` | 3e38 + 3e38 (f32→bf16 overflow) | `0x7F80` | +Inf |

---

#### BF16 SUB

| # | Operand A (Hex) | Operand B (Hex) | Expression | Expected Result (BF16 Hex) | Result (Value) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x4000` | `0x3F80` | 2.0 − 1.0 | `0x3F80` | 1.0 |
| 2 | `0x4049` | `0x402E` | 3.140625 − 2.71875 | `0x3ED8` | 0.421875 |
| 3 | `0x3F80` | `0xC000` | 1.0 − (−2.0) | `0x4040` | 3.0 |

---

#### BF16 MUL

| # | Operand A (Hex) | Operand B (Hex) | Expression | Expected Result (BF16 Hex) | Result (Value) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x4040` | `0x4080` | 3.0 × 4.0 | `0x4140` | 12.0 |
| 2 | `0x4049` | `0x402E` | 3.140625 × 2.71875 | `0x4108` | 8.5 |
| 3 | `0x7F80` | `0x3F80` | +Inf × 1.0 | `0x7F80` | +Inf |
| 4 | `0x7F80` | `0x0000` | +Inf × 0.0 | `0x7FC0` | NaN |

---

#### BF16 DIV

| # | Operand A (Hex) | Operand B (Hex) | Expression | Expected Result (BF16 Hex) | Result (Value) |
|:-:|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x4120` | `0x4000` | 10.0 ÷ 2.0 | `0x40A0` | 5.0 |
| 2 | `0xC120` | `0x4000` | −10.0 ÷ 2.0 | `0xC0A0` | −5.0 |
| 3 | `0x4128` | `0x3F00` | 10.5 ÷ 0.5 | `0x41A8` | 21.0 |
| 4 | `0x7F80` | `0x7F80` | +Inf ÷ +Inf | `0x7FC0` | NaN |
| 5 | `0x3F80` | `0x7F80` | 1.0 ÷ +Inf | `0x0000` | 0.0 |
| 6 | `0x7F80` | `0x3F80` | +Inf ÷ 1.0 | `0x7F80` | +Inf |
| 7 | `0x3F80` | `0x0000` | 1.0 ÷ 0.0 | `0x7F80` | +Inf |

---

#### BF16 SQRT

| # | Operand (Hex) | Expression | Expected Result (BF16 Hex) | Result (Value) |
|:-:|:-:|:-:|:-:|:-:|
| 1 | `0x4080` | √(4.0) | `0x4000` | 2.0 |
| 2 | `0x4110` | √(9.0) | `0x4040` | 3.0 |
| 3 | `0x42A2` | √(81.0) | `0x4110` | 9.0 |

---

#### BF16 COMPARISON

##### Operands

| Symbol | Bit Pattern (Hex) | Value |
|:--:|:--:|:--:|
| a | `0x3F800000` (FP32) | 1.0 |
| b | `0x40000000` (FP32) | 2.0 |
| c | `0x3F800000` (FP32) | 1.0 |
| nan | `0x7FC0` (BF16) | NaN |

##### Equality Tests (`BF16_EQ`)

| Case | Expression | Expected Result |
|:--:|:--:|:--:|
| 1 | a == c → 1.0 == 1.0 | **True (1)** |
| 2 | a == b → 1.0 == 2.0 | **False (0)** |
| 3 | nan == nan | **False (0)** |

##### Less-Than Tests (`BF16_LT`)

| Case | Expression | Expected Result |
|:--:|:--:|:--:|
| 4 | a < b → 1.0 < 2.0 | **True (1)** |
| 5 | b < a → 2.0 < 1.0 | **False (0)** |
| 6 | a < c → 1.0 < 1.0 | **False (0)** |
| 7 | nan < a | **False (0)** |

##### Greater-Than Tests (`BF16_GT`)

| Case | Expression | Expected Result |
|:--:|:--:|:--:|
| 8 | b > a → 2.0 > 1.0 | **True (1)** |
| 9 | a > b → 1.0 > 2.0 | **False (0)** |
| 10 | nan > a | **False (0)** |

---

#### BF16 Edge Case Conversion

##### Input Operands

| Symbol | FP32 (Hex) | Description |
|:--:|:--:|:--:|
| v₁ | `0x00000001` | 1 × 10⁻⁴⁵ (smallest FP32 denormal) |
| v₂ | `0x7E967699` | 1 × 10³⁸ |
| v₃ | `0x006CE3EE` | 1 × 10⁻³⁸ |

##### Expected Results

| Case | Operation | Expected Result (FP32 Hex) | Meaning |
|:--:|:--:|:--:|:--:|
| 1 | `F32_TO_BF16(v₁)` | `0x00000000` | Underflow → 0 |
| 2 | `F32_TO_BF16(v₂)` → `× 10.0` → `BF16_TO_F32` | `0x7F800000` | Overflow → +Inf |
| 3 | `F32_TO_BF16(v₃)` ÷ `F32_TO_BF16(1e10)` → `BF16_TO_F32` | `0x00000000` | Underflow → 0 |

---

#### BF16 Rounding

##### Input Operands

| Symbol | FP32 (Hex) | Description |
|:--:|:--:|:--:|
| r₁ | `0x3FC00000` | 1.5 (exactly representable in BF16, so conversion must not change it) |
| r₂ | `0x3F800347` | 1.0001 (tiny fraction beyond 1.0) |

##### Expected Results

| Case | Operation | Expected Result (FP32 Hex) | Meaning |
|:--:|:--:|:--:|:--:|
| 1 | `F32_TO_BF16(r₁)` → `BF16_TO_F32` | `0x3FC00000` | Stays **1.5** (no rounding error) |
| 2 | `F32_TO_BF16(r₂)` → `BF16_TO_F32` | `0x3F800000` | Rounds down to **1.0** (nearest representable) |

---

### Test Result

![image](https://hackmd.io/_uploads/Sy4lR1e6lg.png)

![image](https://hackmd.io/_uploads/BkULblx6ex.png)
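The conversion and rounding rows in the tables above can be cross-checked on a host machine with a small C model. The sketch below is an assumption-laden illustration rather than the assignment's own routine: `f32_to_bf16_bits`, `bf16_to_f32_bits`, and `check_conversion_rows` are hypothetical names, and the model assumes round-to-nearest-even on the upper 16 bits with Inf/NaN passed through by truncation, which may differ from the assembly `F32_TO_BF16` in NaN handling.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical host-side model of FP32 -> BF16 on raw bit patterns.
 * Assumption: round-to-nearest-even; Inf/NaN keep their upper 16 bits. */
static uint16_t f32_to_bf16_bits(uint32_t f32bits)
{
    if (((f32bits >> 23) & 0xFFu) == 0xFFu)
        return (uint16_t)(f32bits >> 16);
    /* Add 0x7FFF plus the current LSB of the kept half: ties go to even. */
    f32bits += ((f32bits >> 16) & 1u) + 0x7FFFu;
    return (uint16_t)(f32bits >> 16);
}

/* BF16 -> FP32 is exact: the 16 bits become the upper half of the word. */
static uint32_t bf16_to_f32_bits(uint16_t bf16bits)
{
    return (uint32_t)bf16bits << 16;
}

/* Checks a few rows copied from the conversion, edge-case and rounding tables. */
static void check_conversion_rows(void)
{
    assert(f32_to_bf16_bits(0x40490FD0u) == 0x4049);  /* 3.14159 */
    assert(f32_to_bf16_bits(0x501502F9u) == 0x5015);  /* 1e10 */
    assert(f32_to_bf16_bits(0x00000001u) == 0x0000);  /* v1 underflows to 0 */
    assert(bf16_to_f32_bits(f32_to_bf16_bits(0x3FC00000u)) == 0x3FC00000u); /* r1 */
    assert(bf16_to_f32_bits(f32_to_bf16_bits(0x3F800347u)) == 0x3F800000u); /* r2 */
}
```

Calling `check_conversion_rows()` from a test `main` aborts on the first row that disagrees with the model, which gives a quick way to tell a table typo from an assembly bug.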