# Exploiting Data-Level Parallelism in BFloat16 Multiplication with SWAR

contributed by <[`yutingshih`](https://github.com/yutingshih)>

> [Source code](https://github.com/yutingshih/ca2023-hw1).

Because a BFloat16 (BF16) number is only 16 bits wide, two BF16 numbers fit in a single 32-bit scalar register. Multiplying both numbers at the same time exploits data-level parallelism (DLP) even though RV32I has no vector registers.

This work combines [Problem B](https://hackmd.io/@sysprog/arch2023-quiz1-sol#Problem-B) and [Problem C](https://hackmd.io/@sysprog/arch2023-quiz1-sol#Problem-C) of Quiz 1.

### BFloat16 Format

The [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) format encodes a normalized value as

$$
BF16 = (-1)^S \cdot 2^{E-127} \cdot (1 + 2^{-7}M)
$$

```
 15  14                     7 6                   0
+---+------------------------+---------------------+
| S |           E            |          M          |
+---+------------------------+---------------------+
  |              |                      |
  |              |                      +-- mantissa (7 bits)
  |              +-- exponent (8 bits)
  +-- sign (1 bit)
```

### SIMD Within A Register (SWAR)

[SWAR](https://en.wikipedia.org/wiki/SWAR) treats a single scalar register as a small vector of narrower lanes. In this assignment, we pack two BF16 floating-point numbers into one 32-bit scalar register to reduce memory usage and improve performance.
```
 31                            16 15                             0
+--------------------------------+--------------------------------+
|              BF16              |              BF16              |
+--------------------------------+--------------------------------+
```

### BFloat16 Multiplication with SWAR

## Conversion from FP32 to BF16

### C Code

```c=
float fp32_to_bf16(float x)
{
    float y = x;
    int *p = (int *) &y;
    unsigned int exp = *p & 0x7F800000;
    unsigned int man = *p & 0x007FFFFF;
    if (exp == 0 && man == 0) /* zero */
        return x;
    if (exp == 0x7F800000) /* infinity or NaN */
        return x;

    /* Normalized number */
    /* round to nearest */
    float r = x;
    int *pr = (int *) &r;
    *pr &= 0xFF800000;  /* r has the same exp as x */
    r /= 0x100;
    y = x + r;

    *p &= 0xFFFF0000;

    return y;
}
```

Since BF16 has only 7 bits in its mantissa, converting an FP32 number with a 23-bit mantissa to BF16 truncates bits 0 to 15 of the mantissa. To minimize the numerical error, we round the FP32 number to the closest BF16 number: if bit 15 is 0, the low bits are simply discarded; otherwise, the value is rounded up into bit 16.

The C code above does this with floating-point arithmetic. It builds `r`, which has the same sign and exponent as `x` but a zero mantissa, and divides it by 256, which lowers its exponent by 8. Relative to `x`, the implicit leading 1 of `r` is then aligned with bit 15 of the mantissa. If bit 15 of the original mantissa is 1, the floating-point addition `y = x + r` carries into bit 16; otherwise bit 16 is unchanged. The final mask discards the low 16 bits in both cases.

However, this method requires floating-point addition, which RV32I does not provide and which cannot be emulated with just a few simple instructions. Therefore, the assembly code instead adds `0x8000` (only bit 15 set) directly to the bit pattern with integer arithmetic, using the `li t2, 0x8000` and `add a0, s0, t2` instructions under the `fp32_to_bf16_ROUNDING` label of the following assembly code.
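Before the assembly version, the same integer-only rounding can be sketched in C. This is a minimal sketch, not the code from the repository: the names `fp32_bits_to_bf16_bits` and `fp32_to_bf16_sketch` are hypothetical, and `memcpy` is used for the bit reinterpretation instead of the pointer casts in the original C code.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: round an FP32 bit pattern to the nearest BF16,
 * returned as an FP32 bit pattern whose low 16 bits are zero. */
static uint32_t fp32_bits_to_bf16_bits(uint32_t bits)
{
    uint32_t exp = bits & 0x7F800000u;
    uint32_t man = bits & 0x007FFFFFu;

    if (exp == 0 && man == 0)   /* zero: nothing to round */
        return bits;
    if (exp == 0x7F800000u)     /* infinity or NaN: returned unchanged */
        return bits;

    /* Adding 0x8000 (bit 15) carries into bit 16 exactly when bit 15 was set.
     * A mantissa overflow carries into the exponent field, which is the
     * desired result (for example 1.9999999 rounds to 2.0). */
    return (bits + 0x8000u) & 0xFFFF0000u;
}

/* Hypothetical wrapper operating on float values. */
static float fp32_to_bf16_sketch(float x)
{
    uint32_t bits;
    assert(sizeof bits == sizeof x);   /* assumes a 32-bit IEEE 754 float */
    memcpy(&bits, &x, sizeof bits);
    bits = fp32_bits_to_bf16_bits(bits);
    memcpy(&x, &bits, sizeof x);
    return x;
}
```

For example, `fp32_bits_to_bf16_bits(0x3F99999A)` (1.2) yields `0x3F9A0000` (1.203125). Note that this rounds exact ties upward in magnitude rather than to even.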
### Assembly Code

```riscv=
fp32_to_bf16:
    addi sp, sp, -4
    sw s0, 0(sp)
    mv s0, a0

    srli t0, s0, 23     # t0: exponent
    andi t0, t0, 0xFF
    slli t1, s0, 9      # t1: mantissa

fp32_to_bf16_CHECK_ZERO:
    bnez t0, fp32_to_bf16_CHECK_INF_NAN
    beqz t1, fp32_to_bf16_EXIT

fp32_to_bf16_CHECK_INF_NAN:
    li t2, 0xFF         # t2: mask 0xFF
    beq t0, t2, fp32_to_bf16_EXIT

fp32_to_bf16_ROUNDING:
    li t2, 0x8000
    add a0, s0, t2      # round to nearest bf16
    li t2, 0xFFFF0000
    and a0, a0, t2

fp32_to_bf16_EXIT:
    lw s0, 0(sp)
    addi sp, sp, 4
    ret
```

## Packed BF16 Encoding and Decoding

### C Code

For convenience, we define the type `bf16_t` as an alias of `float` and reuse the bit encoding (sign, exponent, and mantissa) of the upper 16 bits of a `float`. This lets us use C's built-in floating-point operators on `bf16_t` values directly and, through a pointer cast, apply bitwise operations to the same storage. Additionally, we define `pbf16_t`, which stands for **packed bfloat16**: two `bf16_t` values packed into one 32-bit unsigned integer.

```c=
#include <stdint.h>

typedef float bf16_t;       // bfloat16
typedef uint32_t pbf16_t;   // packed bfloat16
```

```c=
/* a goes to the upper 16 bits, b to the lower 16 bits */
pbf16_t pbf16_encode(bf16_t a, bf16_t b)
{
    return (*(pbf16_t *)&a & 0xFFFF0000) | (*(pbf16_t *)&b >> 16);
}

void pbf16_decode(pbf16_t ab, bf16_t* a, bf16_t* b)
{
    *(pbf16_t*)a = ab & 0xFFFF0000;
    *(pbf16_t*)b = ab << 16;
}
```

### Assembly Code

```riscv=
# a0 = (a0 & 0xFFFF0000) | (a1 >> 16)
pbf16_encode:
    srli a1, a1, 16
    li t0, 0xFFFF0000
    and a0, a0, t0
    or a0, a0, a1
    ret

# a0 = a0 & 0xFFFF0000
# a1 = a0 << 16
pbf16_decode:
    slli a1, a0, 16
    li t0, 0xFFFF0000
    and a0, a0, t0
    ret
```

## Packed BF16 Multiplication

To implement packed BF16 multiplication, we can reuse and modify the code of [Problem C of Quiz 1](https://hackmd.io/@sysprog/arch2023-quiz1-sol#Problem-C).

We assume that neither operand is Infinity or NaN.
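The steps that the packed routine applies to two lanes at once are easier to follow on a single lane first. The sketch below is a hypothetical single-lane reference (the name `bf16_mul_ref` is not part of the original code). It assumes normalized, non-zero operands, no Infinity or NaN, and no exponent overflow or underflow, and it truncates the product instead of rounding it.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical single-lane reference: multiply two raw 16-bit BF16 encodings. */
static uint16_t bf16_mul_ref(uint16_t a, uint16_t b)
{
    uint32_t ea = (a >> 7) & 0xFFu;            /* biased exponents */
    uint32_t eb = (b >> 7) & 0xFFu;
    assert(ea != 0xFFu && eb != 0xFFu);        /* assumption: no Inf/NaN */

    uint32_t sr = (uint32_t)(a ^ b) & 0x8000u; /* sign of the product */

    uint32_t ma = (a & 0x7Fu) | 0x80u;         /* restore the implicit 1 */
    uint32_t mb = (b & 0x7Fu) | 0x80u;
    uint32_t prod = ma * mb;                   /* 1.m * 1.m, in [2^14, 2^16) */

    uint32_t msh = (prod >> 15) & 1u;          /* 1 if the product is >= 2.0 */
    uint32_t mr = (prod >> (7 + msh)) & 0x7Fu; /* normalized 7-bit mantissa */
    uint32_t er = ea + eb - 127u + msh;        /* remove one bias, add carry */

    return (uint16_t)(sr | ((er & 0xFFu) << 7) | mr);
}
```

For example, `bf16_mul_ref(0x3F9A, 0x4014)` (1.203125 × 2.3125) gives `0x4032` (2.78125), and `bf16_mul_ref(0x3FC0, 0x3FC0)` (1.5 × 1.5) exercises the normalization shift and gives `0x4010` (2.25).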
Here is my implementation:

### C Code

```c=
/* Returns a mask covering the trailing 1 bits of x */
uint32_t mask_lowest_zero(uint32_t x)
{
    uint32_t mask = x;
    mask &= (mask << 1) | 0x1;
    mask &= (mask << 2) | 0x3;
    mask &= (mask << 4) | 0xF;
    mask &= (mask << 8) | 0xFF;
    mask &= (mask << 16) | 0xFFFF;
    return mask;
}

/* Returns x + 1 without using an adder */
uint32_t inc(uint32_t x)
{
    if (~x == 0)
        return 0;
    /* TODO: Carry flag */
    uint32_t mask = mask_lowest_zero(x);
    uint32_t z1 = mask ^ ((mask << 1) | 1);
    return (x & ~mask) | z1;
}

/* Multiplies the low 8 bits of each 16-bit lane; each lane gets a 16-bit product */
uint32_t imul16(uint32_t a, uint32_t b)
{
    uint32_t r = 0;
    for (int i = 0; i < 8; i++)
        if ((b >> i) & 1)
            r += a << i;
    r &= 0xFFFF;
    b >>= 16;
    a &= 0xFFFF0000;
    for (int i = 0; i < 8; i++)
        if ((b >> i) & 1)
            r += a << i;
    return r;
}

pbf16_t pbf16_mul(pbf16_t a, pbf16_t b)
{
    uint32_t sr = (a ^ b) & 0x80008000;

    uint32_t ma = (a & 0x007F007F) | 0x00800080;
    uint32_t mb = (b & 0x007F007F) | 0x00800080;
    uint32_t mr = imul16(ma, mb) >> 7;
    /* Bit 8 of a lane is set when that lane's product is >= 2.0 */
    uint32_t msh = (mr >> 8) & 0x00010001;
    /* sel holds 0x7F in every lane whose mantissa must shift right by 1 */
    uint32_t sel = (msh << 7) - msh;
    mr = ((mr >> 1) & sel) | (mr & ~sel);

    uint32_t ea = (a >> 7) & 0x00FF00FF;
    uint32_t eb = (b >> 7) & 0x00FF00FF;
    uint32_t er = ea + eb - 0x007F007F;     // 127 = 0b1111111 = 0x7F
    if (msh & 0x00000001)                   /* low lane */
        er = inc(er);
    if (msh & 0x00010000)                   /* high lane */
        er = (inc(er >> 16) << 16) | (er & 0x0000FFFF);

    pbf16_t r = sr | ((er & 0x00FF00FF) << 7) | (mr & 0x007F007F);
    return r;
}
```

### Assembly Code

```riscv=
pbf16_mul:
    addi sp, sp, -44
    sw s0, 0(sp)        # sign
    sw s1, 4(sp)        # mantissa
    sw s2, 8(sp)        # exponent
    sw s3, 12(sp)       # input a
    sw s4, 16(sp)       # input b
    sw s5, 20(sp)       # mask 0x80008000
    sw s6, 24(sp)       # mask 0x00800080
    sw s7, 28(sp)       # mask 0x007F007F
    sw s8, 32(sp)       # mask 0x00FF00FF
    sw s9, 36(sp)       # shift
    sw ra, 40(sp)

    mv s3, a0
    mv s4, a1
    li s5, 0x80008000
    li s6, 0x00800080
    li s7, 0x007F007F
    li s8, 0x00FF00FF

    # sign s0
    xor s0, s3, s4
    and s0, s0, s5

    # mantissa s1
    and a0, s3, s7
    and a1, s4, s7
    or a0, a0, s6
    or a1, a1, s6
    jal ra, imul16
    srli s1, a0, 7      # bits 8 and 24 flag a product >= 2.0

    # shift s9 (one flag per lane)
    srli s9, s1, 8
    li t0, 0x00010001
    and s9, s9, t0
    slli t0, s9, 7
    sub t0, t0, s9      # t0: 0x7F in each lane that must shift
    srli t1, s1, 1
    and t1, t1, t0      # shifted mantissa of the flagged lanes
    not t0, t0
    and s1, s1, t0      # unshifted mantissa of the other lanes
    or s1, s1, t1

    # exponent s2
    srli a0, s3, 7
    srli a1, s4, 7
    and a0, a0, s8
    and a1, a1, s8
    add s2, a0, a1
    sub s2, s2, s7

    andi t0, s9, 1      # low lane needs exponent + 1?
    beqz t0, pbf16_mul_LOW_END
    mv a0, s2
    jal ra, inc
    mv s2, a0
pbf16_mul_LOW_END:
    srli t0, s9, 16     # high lane needs exponent + 1?
    beqz t0, pbf16_mul_HIGH_END
    srli a0, s2, 16
    jal ra, inc
    slli a0, a0, 16
    slli s2, s2, 16
    srli s2, s2, 16
    or s2, s2, a0
pbf16_mul_HIGH_END:

    # result a0
    mv a0, s0           # sign
    and s1, s1, s7
    or a0, a0, s1       # mantissa
    and s2, s2, s8
    slli s2, s2, 7
    or a0, a0, s2       # exponent

    lw s0, 0(sp)        # sign
    lw s1, 4(sp)        # mantissa
    lw s2, 8(sp)        # exponent
    lw s3, 12(sp)       # input a
    lw s4, 16(sp)       # input b
    lw s5, 20(sp)       # mask 0x80008000
    lw s6, 24(sp)       # mask 0x00800080
    lw s7, 28(sp)       # mask 0x007F007F
    lw s8, 32(sp)       # mask 0x00FF00FF
    lw s9, 36(sp)       # shift
    lw ra, 40(sp)
    addi sp, sp, 44
    ret

#######################
mask_lowest_zero:
    addi sp, sp, -4
    sw s0, 0(sp)

    slli s0, a0, 1
    ori s0, s0, 0x1
    and a0, a0, s0

    slli s0, a0, 2
    ori s0, s0, 0x3
    and a0, a0, s0

    slli s0, a0, 4
    ori s0, s0, 0xF
    and a0, a0, s0

    slli s0, a0, 8
    ori s0, s0, 0xFF
    and a0, a0, s0

    slli s0, a0, 16
    li t0, 0xFFFF
    or s0, s0, t0
    and a0, a0, s0

    lw s0, 0(sp)
    addi sp, sp, 4
    ret

#######################
inc:
    addi sp, sp, -12
    sw s0, 0(sp)
    sw s1, 4(sp)
    sw ra, 8(sp)

    not s0, a0
    beqz s0, inc_EXIT_FAIL      # x == 0xFFFFFFFF

    mv s0, a0                   # s0: x
    jal ra, mask_lowest_zero    # a0: mask
    slli s1, a0, 1
    ori s1, s1, 1
    xor s1, s1, a0              # s1: z1
    not a0, a0
    and a0, a0, s0
    or a0, a0, s1

inc_EXIT_OK:
    lw s0, 0(sp)
    lw s1, 4(sp)
    lw ra, 8(sp)
    addi sp, sp, 12
    ret

inc_EXIT_FAIL:
    li a0, 0
    lw s0, 0(sp)
    lw s1, 4(sp)
    lw ra, 8(sp)
    addi sp, sp, 12
    ret

#######################
# a0 = a0 * a1 (lane-wise, low 8 bits of each 16-bit lane)
imul16:
    addi sp, sp, -12
    sw s0, 0(sp)
    sw s1, 4(sp)
    sw s2, 8(sp)

    li s0, 0            # s0: current result

    li t0, 0            # t0: loop index
    li t1, 8            # t1: loop bound
imul16_LOOP1_BEGIN:
    bge t0, t1, imul16_LOOP1_END
    srl s1, a1, t0
    andi s1, s1, 1
    beqz s1, imul16_IF1_END
    sll s2, a0, t0
    add s0, s0, s2
imul16_IF1_END:
    addi t0, t0, 1
    j imul16_LOOP1_BEGIN
imul16_LOOP1_END:

    slli s0, s0, 16     # r &= 0xFFFF
    srli s0, s0, 16
    srli a1, a1, 16     # b >>= 16
    srli a0, a0, 16     # a &= 0xFFFF0000
    slli a0, a0, 16

    li t0, 0            # t0: loop index
    li t1, 8            # t1: loop bound
imul16_LOOP2_BEGIN:
    bge t0, t1, imul16_LOOP2_END
    srl s1, a1, t0
    andi s1, s1, 1
    beqz s1, imul16_IF2_END
    sll s2, a0, t0
    add s0, s0, s2
imul16_IF2_END:
    addi t0, t0, 1
    j imul16_LOOP2_BEGIN
imul16_LOOP2_END:

    mv a0, s0           # return the packed products
    lw s0, 0(sp)
    lw s1, 4(sp)
    lw s2, 8(sp)
    addi sp, sp, 12
    ret
```

## Testing Results

### Test Cases

Each test case has four `float` inputs `a`, `b`, `c`, and `d`. The program converts each of them to `bf16_t`, packs `a` and `b` into a `pbf16_t` value `p`, packs `c` and `d` into a `pbf16_t` value `q`, and finally computes `a * c` and `b * d` within a single register by applying `pbf16_mul` to `p` and `q`.

| Case | input `a` | input `b` | input `c` | input `d` | answer <br> `{ac, bd}` |
| --- | --- | --- | --- | --- | --- |
| 1 | 1.200000 (`0x3F99999A`) | 2.312500 (`0x40140000`) | 2.310000 (`0x4013D70A`) | 1.203125 (`0x3F9A0000`) | 2.781250, 2.781250 (`0x40324032`) |
| 2 | 0.000000 (`0x00000000`) | 2.312500 (`0x40140000`) | 2.310000 (`0x4013D70A`) | 1.203125 (`0x3F9A0000`) | 0.000000, 2.781250 (`0x00004032`) |
| 3 | 0.000000 (`0x00000000`) | 2.312500 (`0x40140000`) | -2.310000 (`0xC013D70A`) | -1.203125 (`0xBF9A0000`) | -0.000000, -2.781250 (`0x8000C032`) |

### C Output

- Case 1
![](https://hackmd.io/_uploads/SJujn0fba.png)
- Case 2
- Case 3

### Assembly Output

- Case 1
- Case 2
- Case 3

## Analysis

![](https://hackmd.io/_uploads/S1Pur6zWa.png)

### Pipeline Stage Explanation

[Ripes](https://github.com/mortbopet/Ripes/tree/master) provides several processor models to run the code. In this assignment, we choose the **5-stage processor** to execute the program.

![](https://hackmd.io/_uploads/S1o1Opz-T.png)

Take the instruction `jal x1 112 <fp32_to_bf16>` as an example.

The jump and link (JAL) instruction uses the J-type format, where the J-immediate encodes a signed offset in multiples of 2 bytes. The offset is sign-extended and added to the `pc` to form the jump target address. Jumps can therefore target a ±1 MiB range. JAL stores the address of the instruction following the jump (`pc+4`) into register `rd`. The standard software calling convention uses `x1` (`ra`) as the return address register and `x5` (`t0`) as an alternate link register.
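
The J-type field layout can be checked against the example instruction by decoding its machine word `0x070000ef` by hand. The following is a minimal sketch; the helper names `jal_offset` and `jal_rd` are hypothetical and only illustrate how the scattered immediate bits are reassembled.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: extract the sign-extended byte offset of a JAL word. */
static int32_t jal_offset(uint32_t insn)
{
    assert((insn & 0x7Fu) == 0x6Fu);               /* opcode must be JAL */

    uint32_t imm = (((insn >> 31) & 0x1u)   << 20) /* imm[20]    */
                 | (((insn >> 12) & 0xFFu)  << 12) /* imm[19:12] */
                 | (((insn >> 20) & 0x1u)   << 11) /* imm[11]    */
                 | (((insn >> 21) & 0x3FFu) << 1); /* imm[10:1]  */

    /* Sign-extend from bit 20 without relying on signed right shifts */
    return (int32_t)(imm ^ 0x100000u) - 0x100000;
}

/* Hypothetical helper: destination register index (bits 11:7). */
static uint32_t jal_rd(uint32_t insn)
{
    return (insn >> 7) & 0x1Fu;
}
```

With `insn = 0x070000ef`, `jal_offset` returns 112 and `jal_rd` returns 1 (`ra`). Adding the offset to the instruction's PC `0x1c` gives the jump target `0x8c`, and the link value is `0x1c + 4 = 0x20`.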

> reference: [RISC-V Spec v2.2 (p.15-16)](https://riscv.org/wp-content/uploads/2017/05/riscv-spec-v2.2.pdf)

<!-- ![p12](https://hackmd.io/_uploads/rkRgCyX-6.png) -->
<!-- ![p12](https://hackmd.io/_uploads/HypHCkmZT.png) -->
![p16](https://hackmd.io/_uploads/Sy9cAkXbp.png)

```
rd = pc + 4
pc += offset
```

#### Instruction Fetch (IF) Stage

![JAL IF](https://hackmd.io/_uploads/BJJJOlmZp.png =50%x)

1. The PC of `jal x1 112 <fp32_to_bf16>` in my program is `0x1c`.
2. The instruction fetched from memory is `0x070000ef`.
3. Both PC (`0x0000001c`) and PC+4 (`0x00000020`) are passed to the next stage.

#### Instruction Decode (ID) Stage

![](https://hackmd.io/_uploads/SyAAOlX-a.png =50%x)

1. From the instruction `jal x1 112` (`0x070000ef`), the opcode is `0b1101111`, the destination register is `ra` (`0x01`), and the immediate is 112 (`0x00000070`).
2. PC (`0x0000001c`), PC+4 (`0x00000020`), and rd (`0x01`) are passed to the next stage.

#### Execute (EX) Stage

![](https://hackmd.io/_uploads/SJRVilmb6.png =50%x)

1. The ALU adds the PC (`0x0000001c`) and the immediate (`0x00000070`) passed from the previous stage, and the result (`0x0000008c`) becomes the next PC.
2. PC+4 (`0x00000020`) and rd (`0x01`) are passed to the next stage.

![](https://hackmd.io/_uploads/H169axQZa.png)

#### Memory Access (MEM) Stage

![](https://hackmd.io/_uploads/HydnClXWp.png =50%x)

1. The JAL instruction performs no memory access, so the `Wr en` signal of the data memory is cleared.
2. PC+4 (`0x00000020`) and rd (`0x01`) are passed to the next stage.
3. The PC in the IF stage is now `0x0000008c`, which is the address of the first instruction of the `fp32_to_bf16` function.
4. The target address of the JAL instruction is determined in the EX stage, and by that time the next two sequential instructions have already entered the pipeline. It is therefore necessary to flush both the ID and EX stages to ensure correct execution of the program.

![](https://hackmd.io/_uploads/rJ4MGZ7Wp.png)

#### Write Back (WB) Stage

![](https://hackmd.io/_uploads/r1G2N-Xbp.png)

1. In the WB stage, PC+4 (`0x00000020`) and rd (`0x01`) passed from the previous stages are used to write the register file: PC+4 (`0x00000020`) is written into the `ra` (`0x01`) register.

<!-- ## Optimization
### Use temporary registers
### Avoid redundant branch
### Loop unrolling
### Obviate data hazards by reordering instructions
### Eliminate redundant branch instructions by strip mining like technique -->

## Future Work

- Handling Infinity and NaN

<!-- ## References -->