# Assignment 1: RISC-V Assembly and Instruction Pipeline

contributed by < [`jtl`](https://github.com/just-ting) >

## Problem - BFloat16 floating-point matrix multiply

Reference: [Quiz 1 Problem B](https://hackmd.io/@sysprog/arch2024-quiz1-sol).

We implement bfloat16 matrix multiplication in C without floating-point operations, because C has no native bfloat16 type or arithmetic. To keep the problem small, we focus on multiplying 2x2 bfloat16 matrices.

### Introduction to bfloat16

Brain Floating Point Format (bfloat16) is a 16-bit floating-point format developed by Google Brain, an artificial intelligence research group at Google. It is designed for machine learning: it keeps the 8-bit exponent, and therefore the dynamic range, of the 32-bit IEEE 754 single-precision format (float32), but has only 7 mantissa bits, which makes it well suited to deep learning workloads.

#### `Format`

| type | sign bit | exponent bits | fraction/mantissa bits |
| ---- |:--------:|:-------------:|:----------------------:|
| FP32 | 1 | 8 | 23 |
| BF16 | 1 | 8 | 7 |

#### `IEEE 754 Encoding`

| Single-Precision | Exponent | Fraction | Value |
| ---------------- | -------- | -------- | ----- |
| Normalized number | 1 to 254 | Anything | $\pm(1.F)\times2^{E-127}$ |
| Denormalized number | 0 | nonzero | $\pm(0.F)\times2^{-126}$ |
| Zero | 0 | 0 | $\pm0$ |
| Infinity | 255 | 0 | $\pm\infty$ |
| NaN | 255 | nonzero | NaN |

#### `Decimal Expression`

$value_{FP32} = (-1)^{sign}\times2^{(exponent-127)}\times\left(1+\frac{fraction}{2^{23}}\right)$

$value_{BF16} = (-1)^{sign}\times2^{(exponent-127)}\times\left(1+\frac{fraction}{128}\right)$

## C code

Here we do not consider NaN, infinity, or denormalized numbers. We focus on the floating-point arithmetic itself while watching for potential overflow.

### `Idea 1`

To perform matrix multiplication with bfloat16 (bf16) in C, we can first convert the bf16 values to fp32 (32-bit floating point), perform the calculations, and then convert the results back to bf16.
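The decimal expression above can be checked directly in C. The sketch below is only an illustration: `bf16_decode` and `pow2i` are hypothetical helper names, not part of the original code, and only normalized values (exponent 1 to 254) are handled. It evaluates the formula field by field; Idea 1's `bf16_to_fp32` reaches the same value more cheaply, because a bf16 pattern is simply the upper 16 bits of an fp32 pattern.

```c
#include <stdint.h>

/* Hypothetical helper: exact 2^e for the small exponents used here. */
static double pow2i(int e)
{
    double r = 1.0;
    for (; e > 0; e--)
        r *= 2.0;
    for (; e < 0; e++)
        r /= 2.0;
    return r;
}

/* Hypothetical helper: evaluate
 * (-1)^sign * 2^(exponent - 127) * (1 + fraction / 128)
 * for a normalized bf16 bit pattern. Zero, denormals, Inf and NaN
 * (exponent field 0 or 255) are not handled. */
static double bf16_decode(uint16_t bits)
{
    int sign = bits >> 15;             /* bit 15 */
    int exponent = (bits >> 7) & 0xFF; /* bits 14..7 */
    int fraction = bits & 0x7F;        /* bits 6..0 */
    double value = pow2i(exponent - 127) * (1.0 + fraction / 128.0);
    return sign ? -value : value;
}
```

For example, `bf16_decode(0x3FC0)` gives 1.5 and `bf16_decode(0xBF90)` gives -1.125.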
```c
#include <stdint.h>
#include <stdio.h>

/* A bf16 value stored as its raw 16-bit pattern */
typedef struct {
    uint16_t bits;
} bf16_t;

/* fp32 -> bf16 with round-to-nearest-even; a NaN stays a (quiet) NaN */
static inline bf16_t fp32_to_bf16(float s)
{
    bf16_t h;
    union {
        float f;
        uint32_t i;
    } u = {.f = s};
    if ((u.i & 0x7fffffff) > 0x7f800000) { /* NaN */
        h.bits = (u.i >> 16) | 64;         /* force to quiet */
        return h;
    }
    /* add 0x7fff plus the lowest kept bit, then drop the low 16 bits */
    h.bits = (u.i + (0x7fff + ((u.i >> 0x10) & 1))) >> 0x10;
    return h;
}

/* bf16 -> fp32 is exact: bf16 is the upper half of an fp32 pattern */
static inline float bf16_to_fp32(bf16_t h)
{
    union {
        float f;
        uint32_t i;
    } u = {.i = (uint32_t)h.bits << 16};
    return u.f;
}

/* C (M x K) = A (M x N) * B (N x K), all row-major, accumulated in fp32 */
void matrix_multiply_bfloat(const bf16_t *A, const bf16_t *B, bf16_t *C,
                            int M, int N, int K)
{
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < K; j++) {
            float sum = 0.0f;
            for (int k = 0; k < N; k++) {
                float a = bf16_to_fp32(A[i * N + k]);
                float b = bf16_to_fp32(B[k * K + j]);
                sum += a * b;
            }
            C[i * K + j] = fp32_to_bf16(sum); /* round once at the end */
        }
    }
}
```

### `Idea 2`

With the first idea, however, the arithmetic is still carried out in fp32: C has no native bf16 type, so every floating-point operation is compiled as fp32 (or wider) arithmetic. To compute directly in bf16, we can use a divide-and-conquer approach: handle the sign, exponent, and mantissa of each bf16 value separately, and guard against overflow and underflow to obtain the final result.

#### bfloat16_multiply

The `bfloat16_multiply` function extracts the sign, exponent, and mantissa of each BFLOAT16 operand, multiplies the mantissas with the implicit leading 1 restored, adds the exponents and removes one bias, normalizes if the product is 2.0 or more, and finally assembles the resulting BFLOAT16 value.
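As a worked example of these steps (our own illustration, not one of the original tests), take $x = y = 1.5$, whose bf16 pattern is `0x3FC0` (sign 0, exponent 127, fraction `0x40`). The mantissas are multiplied as 8-bit integers with the implicit 1 restored, and bit 15 of the 16-bit product tells whether normalization is needed:

$$
\begin{aligned}
mantissa &= \texttt{0x80} \mid \texttt{0x40} = \texttt{0xC0} = 192 = 1.5 \times 2^{7}\\
192 \times 192 &= 36864 = \texttt{0x9000} = 2.25 \times 2^{14}\\
\texttt{0x9000} \mathbin{\&} \texttt{0x8000} \neq 0 &\Rightarrow product = \texttt{0x9000} \gg 1 = \texttt{0x4800},\quad exponent = 127 + 127 - 127 + 1 = 128\\
fraction &= (\texttt{0x4800} \gg 7) \mathbin{\&} \texttt{0x7F} = \texttt{0x10}\\
result &= (128 \ll 7) \mid \texttt{0x10} = \texttt{0x4010} = 2^{1}\times\left(1+\tfrac{16}{128}\right) = 2.25
\end{aligned}
$$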
```c
#include <stdint.h>
#include <stdio.h>

static inline uint16_t bfloat16_multiply(uint16_t x, uint16_t y)
{
    // extract sign(1), exponent(8), mantissa(7)
    uint16_t sign = (x ^ y) & 0x8000;   // sign of the product
    int exp1 = (x >> 7) & 0xFF;
    int exp2 = (y >> 7) & 0xFF;
    uint32_t mant1 = x & 0x007F;
    uint32_t mant2 = y & 0x007F;

    // zero (or a flushed denormal) times anything is zero
    if (exp1 == 0 || exp2 == 0)
        return sign;

    // restore the implicit leading 1: 1.fffffff as an 8-bit integer
    mant1 |= 0x80;
    mant2 |= 0x80;

    // 8-bit x 8-bit product in [0x4000, 0xFE01]: [1.0, 4.0) with 14 fraction bits
    uint32_t mant_product = mant1 * mant2;
    int ans_exp = exp1 + exp2 - 127;    // int, so underflow can go negative

    // normalize: a product >= 2.0 has bit 15 set
    if (mant_product & 0x8000) {
        mant_product >>= 1;
        ans_exp++;
    }

    // the leading 1 is now bit 14; keep the next 7 bits (truncation)
    uint16_t ans_mant = (mant_product >> 7) & 0x007F;

    if (ans_exp >= 255)                 // overflow: infinity
        return sign | 0x7F80;
    if (ans_exp <= 0)                   // underflow: flush to zero
        return sign;

    return sign | (uint16_t)(ans_exp << 7) | ans_mant;
}
```

#### bfloat16_add

`bfloat16_add` extracts the sign, exponent, and mantissa of each BFLOAT16 operand. If the signs match, it adds the magnitudes; otherwise it subtracts the smaller magnitude from the larger one. The key steps are aligning the exponents (shifting the smaller operand's mantissa right), renormalizing the result, and checking for exponent overflow.
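The alignment and carry steps can be seen in isolation in the sketch below. `bf16_add_positive` is a hypothetical helper, not part of the original code; it only handles two positive, normalized operands, truncates, and ignores overflow, so it covers just the same-sign branch of `bfloat16_add`.

```c
#include <stdint.h>

/* Hypothetical helper: add two positive, normalized bf16 values
 * (exponent field 1..254). Truncates; overflow is not handled. */
static uint16_t bf16_add_positive(uint16_t x, uint16_t y)
{
    if ((x & 0x7FFF) < (y & 0x7FFF)) { /* make x the larger operand */
        uint16_t t = x;
        x = y;
        y = t;
    }
    int exp = (x >> 7) & 0xFF;
    int shift = exp - ((y >> 7) & 0xFF); /* exponent difference >= 0 */
    uint32_t mx = (x & 0x7F) | 0x80;     /* restore implicit 1 */
    uint32_t my = (y & 0x7F) | 0x80;
    my = (shift >= 8) ? 0 : (my >> shift); /* align; >= 8 shifts all bits out */
    uint32_t sum = mx + my;              /* in [0x80, 0x1FE] */
    if (sum & 0x100) {                   /* carry: sum >= 2.0 */
        sum >>= 1;
        exp++;
    }
    return (uint16_t)((exp << 7) | (sum & 0x7F));
}
```

For example, for 2.0 + 1.5 the exponents differ by 1, so the mantissa of 1.5 (`0xC0`) is shifted to `0x60`; the sum `0xE0` needs no renormalization, and the result is `0x4060` (3.5).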
```c
static inline uint16_t bfloat16_add(uint16_t x, uint16_t y)
{
    // make x the operand with the larger magnitude
    if ((x & 0x7FFF) < (y & 0x7FFF)) {
        uint16_t t = x;
        x = y;
        y = t;
    }

    // extract sign(1), exponent(8), mantissa(7)
    uint16_t sign = x & 0x8000;          // result takes the larger operand's sign
    int exp1 = (x >> 7) & 0xFF;
    int exp2 = (y >> 7) & 0xFF;
    if (exp2 == 0)                       // y is zero (denormals are flushed)
        return x;
    uint32_t mant1 = (x & 0x007F) | 0x80; // restore the implicit 1
    uint32_t mant2 = (y & 0x007F) | 0x80;

    // align exponents: shift the smaller operand's mantissa right
    int shift = exp1 - exp2;
    mant2 = (shift >= 8) ? 0 : (mant2 >> shift);

    uint32_t mant;
    if ((x ^ y) & 0x8000) {              // different signs: subtract
        mant = mant1 - mant2;            // never negative, since |x| >= |y|
        if (mant == 0)
            return 0;                    // exact cancellation
        while (!(mant & 0x80)) {         // bring the leading 1 back to bit 7
            mant <<= 1;
            exp1--;
        }
        if (exp1 <= 0)                   // underflow: flush to zero
            return sign;
    } else {                             // same sign: add
        mant = mant1 + mant2;
        if (mant & 0x100) {              // carry: sum >= 2.0
            mant >>= 1;
            exp1++;
        }
        if (exp1 >= 255)                 // overflow: infinity
            return sign | 0x7F80;
    }
    return sign | (uint16_t)(exp1 << 7) | (mant & 0x7F);
}
```

#### bfloat16_matrix_multiply

The `bfloat16_matrix_multiply` function multiplies two 2x2 BFLOAT16 matrices. It iterates over the rows of the first matrix and the columns of the second, computing each dot product with `bfloat16_multiply` and `bfloat16_add`.
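To cross-check the bit-level routine, a reference can be sketched with Idea 1's approach: convert to fp32, accumulate, and round once at the end. The names `bf16_bits_to_float`, `float_to_bf16_bits`, and `reference_matmul_2x2` are hypothetical, and NaN handling is omitted. Because the bit-level version truncates after every multiply and add, its results may differ from this reference in the last fraction bit.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: bf16 pattern -> fp32 (exact). */
static float bf16_bits_to_float(uint16_t h)
{
    uint32_t u = (uint32_t)h << 16;
    float f;
    memcpy(&f, &u, sizeof f);
    return f;
}

/* Hypothetical helper: fp32 -> bf16 pattern, round-to-nearest-even
 * (NaN handling omitted). */
static uint16_t float_to_bf16_bits(float f)
{
    uint32_t u;
    memcpy(&u, &f, sizeof u);
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

/* Reference 2x2 product: accumulate in fp32, round once per element. */
static void reference_matmul_2x2(uint16_t A[2][2], uint16_t B[2][2],
                                 uint16_t C[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            float sum = 0.0f;
            for (int k = 0; k < 2; k++)
                sum += bf16_bits_to_float(A[i][k]) * bf16_bits_to_float(B[k][j]);
            C[i][j] = float_to_bf16_bits(sum);
        }
    }
}
```

For A = B = [[1.0, 1.5], [2.0, 2.5]] (bf16 `0x3F80`, `0x3FC0`, `0x4000`, `0x4020`), this reference gives `0x4080 0x40A8` / `0x40E0 0x4114`, that is 4.0, 5.25, 7.0 and 9.25.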
```c
void bfloat16_matrix_multiply(uint16_t A[2][2], uint16_t B[2][2], uint16_t C[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            C[i][j] = 0; // start each dot product at +0
            for (int k = 0; k < 2; k++) {
                C[i][j] = bfloat16_add(C[i][j], bfloat16_multiply(A[i][k], B[k][j]));
            }
        }
    }
}
```

#### print_matrix

The `print_matrix` function displays a 2x2 BFLOAT16 matrix in hexadecimal.

```c
// Function to print a BFLOAT16 matrix
void print_matrix(uint16_t matrix[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%04X ", matrix[i][j]); // Print in hex format
        }
        printf("\n");
    }
}
```

#### main

`main` initializes two example 2x2 BFLOAT16 matrices A and B, calls the multiplication function to store their product in matrix C, and finally prints C.

```c
int main(void)
{
    // Example 2x2 BFLOAT16 matrices
    uint16_t A[2][2] = {
        {0x3F80, 0x3FC0}, // 1.0, 1.5 in BFLOAT16
        {0x4000, 0x4020}  // 2.0, 2.5 in BFLOAT16
    };
    uint16_t B[2][2] = {
        {0x3F80, 0x3FC0}, // 1.0, 1.5 in BFLOAT16
        {0x4000, 0x4020}  // 2.0, 2.5 in BFLOAT16
    };
    uint16_t C[2][2]; // Result matrix

    bfloat16_matrix_multiply(A, B, C);

    printf("Result Matrix:\n");
    print_matrix(C);

    return 0;
}
```

:::danger
Can you improve the above instead of adding comments literally?
:::

### `Test`

#### test 1

```
uint16_t A[2][2] = {
    {0x3F9D, 0x401D},   // 1.234, 2.456 in BFLOAT16
    {0x406B, 0x409C}};  // 3.678, 4.890 in BFLOAT16
uint16_t B[2][2] = {
    {0x4007, 0x3F9C},   // 2.111, 1.222 in BFLOAT16
    {0x3EAA, 0x405C}};  // 0.333, 3.444 in BFLOAT16
```

#### result 1

<s>
![image](https://hackmd.io/_uploads/Sk-n7ksyyx.png =75%x)
</s>

:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::

#### test 2

```
uint16_t A[2][2] = {
    {0x3DFB, 0x3EE9},   // 0.123, 0.456 in BFLOAT16
    {0x3F49, 0x3EA4}};  // 0.789, 0.321 in BFLOAT16
uint16_t B[2][2] = {
    {0x40B2, 0x40DC},   // 5.567, 6.890 in BFLOAT16
    {0x40E3, 0x4103}};  // 7.123, 8.234 in BFLOAT16
```

#### result 2

![image](https://hackmd.io/_uploads/BJlWN1oyyx.png =75%x)

:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::

#### test 3

```
uint16_t A[2][2] = {
    {0x7FC0, 0x4018},   // NaN (0x7FC0; -1.125 would be 0xBF90), 2.376 in BFLOAT16
    {0x405F, 0x7FC0}};  // 3.487, NaN (0x7FC0; -4.158 would be 0xC085) in BFLOAT16
uint16_t B[2][2] = {
    {0x400F, 0xBFCB},   // 2.237, -1.587 in BFLOAT16
    {0x3EE9, 0x4069}};  // 0.456, 3.654 in BFLOAT16
```

#### result 3

![image](https://hackmd.io/_uploads/rka4VkiJ1l.png =75%x)

:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::

### C Compiled to Assembly

The listing below is the x86-64 assembly (not RISC-V) that MinGW-w64 GCC 14.2.0 produced for the C program (see the `.ident` directive near the end). Every local variable is kept in a stack slot, which indicates that the file was compiled without optimization.

```x86asm
    .file "bfloat_matmul.c"
    .text
    .def bfloat16_multiply; .scl 3; .type 32; .endef
    .seh_proc bfloat16_multiply
bfloat16_multiply:
    pushq %rbp
    .seh_pushreg %rbp
    movq %rsp, %rbp
    .seh_setframe %rbp, 0
    subq $32, %rsp
    .seh_stackalloc 32
    .seh_endprologue
    movl %edx, %eax
    movl %ecx, %edx
    movw %dx, 16(%rbp)
    movw %ax, 24(%rbp)
    movzwl 16(%rbp), %eax
    andw $-32768, %ax
    movw %ax, -10(%rbp)
    movzwl 24(%rbp), %eax
    andw $-32768, %ax
    movw %ax, -12(%rbp)
    movzwl 16(%rbp), %eax
    sarl $7, %eax
    andw $255, %ax
    movw %ax, -14(%rbp)
    movzwl 24(%rbp), %eax
    sarl $7, %eax
    andw $255, %ax
    movw %ax, -16(%rbp)
    movzwl 16(%rbp), %eax
    andl $127, %eax
    movw %ax, -18(%rbp)
    movzwl 24(%rbp), %eax
    andl $127, %eax
    movw %ax, -20(%rbp)
    movzwl -18(%rbp), %edx
    movzwl -20(%rbp), %eax
    imull %edx, %eax
    movl %eax, -4(%rbp)
    movzwl -14(%rbp), %edx
    movzwl -16(%rbp), %eax
    addl %edx, %eax
    subl $127, %eax
    movw %ax, -6(%rbp)
    movl -4(%rbp), %eax
    andl $131072, %eax
    testl %eax, %eax
    je .L2
    shrl -4(%rbp)
    movzwl -6(%rbp), %eax
    addl $1, %eax
    movw %ax, -6(%rbp)
.L2:
    movl -4(%rbp), %eax
    shrl $8, %eax
    andl $127, %eax
    movw %ax, -8(%rbp)
    cmpw $254, -6(%rbp)
    jbe .L3
    movw $255, -6(%rbp)
    movw $0, -8(%rbp)
    jmp .L4
.L3:
    cmpw $0, -6(%rbp)
    jne .L4
    movw $0, -6(%rbp)
    movw $0, -8(%rbp)
.L4:
    movzwl -10(%rbp), %eax
    xorw -12(%rbp), %ax
    movw %ax, -22(%rbp)
    movzwl -6(%rbp), %eax
    sall $8, %eax
    movl %eax, %edx
    movzwl -22(%rbp), %eax
    orl %edx, %eax
    movw %ax, -22(%rbp)
    movzwl -8(%rbp), %eax
    orw %ax, -22(%rbp)
    movzwl -22(%rbp), %eax
    addq $32, %rsp
    popq %rbp
    ret
    .seh_endproc
    .def bfloat16_add; .scl 3; .type 32; .endef
    .seh_proc bfloat16_add
bfloat16_add:
    pushq %rbp
    .seh_pushreg %rbp
    movq %rsp, %rbp
    .seh_setframe %rbp, 0
    subq $32, %rsp
    .seh_stackalloc 32
    .seh_endprologue
    movl %edx, %eax
    movl %ecx, %edx
    movw %dx, 16(%rbp)
    movw %ax, 24(%rbp)
    movzwl 16(%rbp), %eax
    andw $-32768, %ax
    movw %ax, -16(%rbp)
    movzwl 24(%rbp), %eax
    andw $-32768, %ax
    movw %ax, -18(%rbp)
    movzwl 16(%rbp), %eax
    sarl $7, %eax
    andw $255, %ax
    movw %ax, -2(%rbp)
    movzwl 24(%rbp), %eax
    sarl $7, %eax
    andw $255, %ax
    movw %ax, -20(%rbp)
    movzwl 16(%rbp), %eax
    andl $127, %eax
    movw %ax, -4(%rbp)
    movzwl 24(%rbp), %eax
    andl $127, %eax
    movw %ax, -6(%rbp)
    movzwl -16(%rbp), %eax
    cmpw -18(%rbp), %ax
    jne .L7
    movzwl -2(%rbp), %eax
    cmpw %ax, -20(%rbp)
    jnb .L8
    movzwl -6(%rbp), %r8d
    movzwl -2(%rbp), %eax
    movzwl -20(%rbp), %edx
    subl %edx, %eax
    movl %eax, %ecx
    sarl %cl, %r8d
    movl %r8d, %eax
    movw %ax, -6(%rbp)
    jmp .L9
.L8:
    movzwl -4(%rbp), %r8d
    movzwl -20(%rbp), %eax
    movzwl -2(%rbp), %edx
    subl %edx, %eax
    movl %eax, %ecx
    sarl %cl, %r8d
    movl %r8d, %eax
    movw %ax, -4(%rbp)
    movzwl -20(%rbp), %eax
    movw %ax, -2(%rbp)
.L9:
    movzwl -4(%rbp), %edx
    movzwl -6(%rbp), %eax
    addl %edx, %eax
    movl %eax, -12(%rbp)
    movl -12(%rbp), %eax
    andl $131072, %eax
    testl %eax, %eax
    je .L10
    shrl -12(%rbp)
    movzwl -2(%rbp), %eax
    addl $1, %eax
    movw %ax, -2(%rbp)
.L10:
    movl -12(%rbp), %eax
    shrl $8, %eax
    andl $127, %eax
    movw %ax, -14(%rbp)
    cmpw $254, -2(%rbp)
    jbe .L11
    movw $255, -2(%rbp)
    movw $0, -14(%rbp)
.L11:
    movzwl -2(%rbp), %eax
    sall $7, %eax
    movl %eax, %edx
    movzwl -16(%rbp), %eax
    orl %eax, %edx
    movzwl -14(%rbp), %eax
    orl %edx, %eax
    jmp .L12
.L7:
    movzwl -2(%rbp), %eax
    cmpw %ax, -20(%rbp)
    jnb .L13
    movzwl -6(%rbp), %r8d
    movzwl -2(%rbp), %eax
    movzwl -20(%rbp), %edx
    subl %edx, %eax
    movl %eax, %ecx
    sarl %cl, %r8d
    movl %r8d, %eax
    movw %ax, -6(%rbp)
    jmp .L14
.L13:
    movzwl -4(%rbp), %r8d
    movzwl -20(%rbp), %eax
    movzwl -2(%rbp), %edx
    subl %edx, %eax
    movl %eax, %ecx
    sarl %cl, %r8d
    movl %r8d, %eax
    movw %ax, -4(%rbp)
    movzwl -20(%rbp), %eax
    movw %ax, -2(%rbp)
.L14:
    movzwl -4(%rbp), %edx
    movzwl -6(%rbp), %eax
    subl %eax, %edx
    movl %edx, -24(%rbp)
    cmpl $0, -24(%rbp)
    jne .L15
    movl $0, %eax
    jmp .L12
.L15:
    cmpl $0, -24(%rbp)
    je .L16
    movl -24(%rbp), %eax
    shrl $8, %eax
    andl $127, %eax
    movw %ax, -28(%rbp)
    movzwl -2(%rbp), %eax
    sall $7, %eax
    movl %eax, %edx
    movzwl -16(%rbp), %eax
    orl %eax, %edx
    movzwl -28(%rbp), %eax
    orl %edx, %eax
    jmp .L12
.L16:
    movl -24(%rbp), %eax
    negl %eax
    shrl $8, %eax
    andl $127, %eax
    movw %ax, -26(%rbp)
    movzwl -2(%rbp), %eax
    sall $7, %eax
    movl %eax, %edx
    movzwl -18(%rbp), %eax
    orl %eax, %edx
    movzwl -26(%rbp), %eax
    orl %edx, %eax
.L12:
    addq $32, %rsp
    popq %rbp
    ret
    .seh_endproc
    .globl bfloat16_matrix_multiply
    .def bfloat16_matrix_multiply; .scl 2; .type 32; .endef
    .seh_proc bfloat16_matrix_multiply
bfloat16_matrix_multiply:
    pushq %rbp
    .seh_pushreg %rbp
    pushq %rbx
    .seh_pushreg %rbx
    subq $56, %rsp
    .seh_stackalloc 56
    leaq 48(%rsp), %rbp
    .seh_setframe %rbp, 48
    .seh_endprologue
    movq %rcx, 32(%rbp)
    movq %rdx, 40(%rbp)
    movq %r8, 48(%rbp)
    movl $0, -4(%rbp)
    jmp .L18
.L23:
    movl $0, -8(%rbp)
    jmp .L19
.L22:
    movl -4(%rbp), %eax
    cltq
    leaq 0(,%rax,4), %rdx
    movq 48(%rbp), %rax
    addq %rax, %rdx
    movl -8(%rbp), %eax
    cltq
    movw $0, (%rdx,%rax,2)
    movl $0, -12(%rbp)
    jmp .L20
.L21:
    movl -12(%rbp), %eax
    cltq
    leaq 0(,%rax,4), %rdx
    movq 40(%rbp), %rax
    addq %rax, %rdx
    movl -8(%rbp), %eax
    cltq
    movzwl (%rdx,%rax,2), %eax
    movzwl %ax, %edx
    movl -4(%rbp), %eax
    cltq
    leaq 0(,%rax,4), %rcx
    movq 32(%rbp), %rax
    addq %rax, %rcx
    movl -12(%rbp), %eax
    cltq
    movzwl (%rcx,%rax,2), %eax
    movzwl %ax, %eax
    movl %eax, %ecx
    call bfloat16_multiply
    movzwl %ax, %edx
    movl -4(%rbp), %eax
    cltq
    leaq 0(,%rax,4), %rcx
    movq 48(%rbp), %rax
    addq %rax, %rcx
    movl -8(%rbp), %eax
    cltq
    movzwl (%rcx,%rax,2), %eax
    movzwl %ax, %eax
    movl -4(%rbp), %ecx
    movslq %ecx, %rcx
    leaq 0(,%rcx,4), %r8
    movq 48(%rbp), %rcx
    leaq (%r8,%rcx), %rbx
    movl %eax, %ecx
    call bfloat16_add
    movl -8(%rbp), %edx
    movslq %edx, %rdx
    movw %ax, (%rbx,%rdx,2)
    addl $1, -12(%rbp)
.L20:
    cmpl $1, -12(%rbp)
    jle .L21
    addl $1, -8(%rbp)
.L19:
    cmpl $1, -8(%rbp)
    jle .L22
    addl $1, -4(%rbp)
.L18:
    cmpl $1, -4(%rbp)
    jle .L23
    nop
    nop
    addq $56, %rsp
    popq %rbx
    popq %rbp
    ret
    .seh_endproc
    .section .rdata,"dr"
.LC0:
    .ascii "%04X \0"
    .text
    .globl print_matrix
    .def print_matrix; .scl 2; .type 32; .endef
    .seh_proc print_matrix
print_matrix:
    pushq %rbp
    .seh_pushreg %rbp
    movq %rsp, %rbp
    .seh_setframe %rbp, 0
    subq $48, %rsp
    .seh_stackalloc 48
    .seh_endprologue
    movq %rcx, 16(%rbp)
    movl $0, -4(%rbp)
    jmp .L25
.L28:
    movl $0, -8(%rbp)
    jmp .L26
.L27:
    movl -4(%rbp), %eax
    cltq
    leaq 0(,%rax,4), %rdx
    movq 16(%rbp), %rax
    addq %rax, %rdx
    movl -8(%rbp), %eax
    cltq
    movzwl (%rdx,%rax,2), %eax
    movzwl %ax, %eax
    movl %eax, %edx
    leaq .LC0(%rip), %rax
    movq %rax, %rcx
    call printf
    addl $1, -8(%rbp)
.L26:
    cmpl $1, -8(%rbp)
    jle .L27
    movl $10, %ecx
    call putchar
    addl $1, -4(%rbp)
.L25:
    cmpl $1, -4(%rbp)
    jle .L28
    nop
    nop
    addq $48, %rsp
    popq %rbp
    ret
    .seh_endproc
    .section .rdata,"dr"
.LC1:
    .ascii "Result Matrix:\0"
    .text
    .globl main
    .def main; .scl 2; .type 32; .endef
    .seh_proc main
main:
    pushq %rbp
    .seh_pushreg %rbp
    movq %rsp, %rbp
    .seh_setframe %rbp, 0
    subq $64, %rsp
    .seh_stackalloc 64
    .seh_endprologue
    call __main
    movw $15360, -8(%rbp)
    movw $15616, -6(%rbp)
    movw $15872, -4(%rbp)
    movw $16128, -2(%rbp)
    movw $15360, -16(%rbp)
    movw $15616, -14(%rbp)
    movw $15872, -12(%rbp)
    movw $16128, -10(%rbp)
    leaq -24(%rbp), %rcx
    leaq -16(%rbp), %rdx
    leaq -8(%rbp), %rax
    movq %rcx, %r8
    movq %rax, %rcx
    call bfloat16_matrix_multiply
    leaq .LC1(%rip), %rax
    movq %rax, %rcx
    call puts
    leaq -24(%rbp), %rax
    movq %rax, %rcx
    call print_matrix
    movl $0, %eax
    addq $64, %rsp
    popq %rbp
    ret
    .seh_endproc
    .def __main; .scl 2; .type 32; .endef
    .ident "GCC: (x86_64-mcf-seh-rev0, Built by MinGW-Builds project) 14.2.0"
    .def printf; .scl 2; .type 32; .endef
    .def putchar; .scl 2; .type 32; .endef
    .def puts; .scl 2; .type 32; .endef
```

## RISCV32

### matmul

`matmul` is the entry point. Matrices A and B are stored as float32 words in the data section, and matrix C starts as zeros. `matmul` computes the four elements of C with `bfloat16_mul` and `bfloat16_add`, storing each bf16 result in its own word of C (row-major, so `A+4` is A[0][1] and `B+8` is B[1][0]). Finally, it calls `printResult` to print the contents of C.

```riscv32=
.data
.align 2
A:  .word 0x3F9DF3B6, 0x401D2F1B, 0x406B645A, 0x409C7AE1  # A[0][0] = 1.234, A[0][1] = 2.456, A[1][0] = 3.678, A[1][1] = 4.890 (float32)
.align 2
B:  .word 0x40071AA0, 0x3F9C6A7F, 0x3EAA7EFA, 0x405C6A7F  # B[0][0] = 2.111, B[0][1] = 1.222, B[1][0] = 0.333, B[1][1] = 3.444 (float32)
.align 2
C:  .word 0, 0, 0, 0
.align 2
sign_mask:        .word 0x80000000
exp_mask:         .word 0x7F800000
man_mask:         .word 0x007F0000
bfloat_sign_mask: .word 0x8000
bfloat_exp_mask:  .word 0x7F80
bfloat_man_mask:  .word 0x007F
str1: .string "C ="        # output message 1
str2: .string ", "         # output message 2
str3: .string "overflow!"  # output message 3

.text
matmul:
    la   a6, C             # a6 = &C[0][0]

    # calculate C[0][0] = A[0][0] * B[0][0] + A[0][1] * B[1][0]
    lw   a0, A             # a0 = A[0][0]
    lw   a1, B             # a1 = B[0][0]
    jal  bfloat16_mul      # a3 = a0 * a1
    mv   a4, a3
    lw   a0, A+4           # a0 = A[0][1]
    lw   a1, B+8           # a1 = B[1][0]
    jal  bfloat16_mul      # a3 = a0 * a1
    mv   a5, a3
    jal  bfloat16_add      # a3 = a4 + a5
    sw   a3, 0(a6)

    # calculate C[0][1] = A[0][0] * B[0][1] + A[0][1] * B[1][1]
    lw   a0, A             # a0 = A[0][0]
    lw   a1, B+4           # a1 = B[0][1]
    jal  bfloat16_mul
    mv   a4, a3
    lw   a0, A+4           # a0 = A[0][1]
    lw   a1, B+12          # a1 = B[1][1]
    jal  bfloat16_mul
    mv   a5, a3
    jal  bfloat16_add
    addi a6, a6, 4
    sw   a3, 0(a6)

    # calculate C[1][0] = A[1][0] * B[0][0] + A[1][1] * B[1][0]
    lw   a0, A+8           # a0 = A[1][0]
    lw   a1, B             # a1 = B[0][0]
    jal  bfloat16_mul
    mv   a4, a3
    lw   a0, A+12          # a0 = A[1][1]
    lw   a1, B+8           # a1 = B[1][0]
    jal  bfloat16_mul
    mv   a5, a3
    jal  bfloat16_add
    addi a6, a6, 4
    sw   a3, 0(a6)

    # calculate C[1][1] = A[1][0] * B[0][1] + A[1][1] * B[1][1]
    lw   a0, A+8           # a0 = A[1][0]
    lw   a1, B+4           # a1 = B[0][1]
    jal  bfloat16_mul
    mv   a4, a3
    lw   a0, A+12          # a0 = A[1][1]
    lw   a1, B+12          # a1 = B[1][1]
    jal  bfloat16_mul
    mv   a5, a3
    jal  bfloat16_add
    addi a6, a6, 4
    sw   a3, 0(a6)

    # Call the function to print the result
    jal  ra, printResult

    # finish
    li   a7, 10
    ecall
```

### bfloat16_mul

`bfloat16_mul` multiplies two float32 inputs and returns the product in bfloat16 format. It extracts the sign, exponent, and top 7 mantissa bits of each input, restores the implicit leading 1, multiplies the mantissas, and normalizes the product. If the exponent overflows, it prints an error message and exits.
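Before reading the assembly, it may help to see the same steps in C. The sketch below is a model written for checking, not part of the original program: `fp32_mul_to_bf16` is a hypothetical name; it keeps only the top 7 fraction bits of each fp32 input (as the `man_mask` load does), truncates the product, and does not handle zero, denormals, overflow, or underflow.

```c
#include <stdint.h>

/* Hypothetical C model of the RISC-V bfloat16_mul:
 * fp32 bit patterns in, bf16 bit pattern out (truncating). */
static uint16_t fp32_mul_to_bf16(uint32_t a, uint32_t b)
{
    uint32_t sign = ((a ^ b) & 0x80000000u) >> 16;          /* to bit 15 */
    int exp = (int)((a >> 23) & 0xFF) + (int)((b >> 23) & 0xFF) - 127;
    uint32_t ma = ((a & 0x007F0000u) >> 16) | 0x80;         /* 1.fffffff */
    uint32_t mb = ((b & 0x007F0000u) >> 16) | 0x80;
    uint32_t prod = ma * mb;                                /* [0x4000, 0xFE01] */
    if (prod & 0x8000) {                                    /* >= 2.0: normalize */
        prod >>= 1;
        exp++;
    }
    return (uint16_t)(sign | ((uint32_t)exp << 7) | ((prod >> 7) & 0x7F));
}
```

For the first product of C[0][0], `fp32_mul_to_bf16(0x3F9DF3B6, 0x40071AA0)` returns `0x4025` (about 2.578). Note that `printResult` prints bf16 results as decimal integers, so a pattern such as `0x4025` would appear as 16421.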
```riscv32=
# bfloat16 multiplier
# input:  a0, a1 (IEEE 754 float32 bit patterns)
# output: a3 (bfloat16)
# t0, t1: A_sign, B_sign
# t2, t3: A_exp, B_exp
# t4, t5: A_man, B_man
# t6: masks and scratch
bfloat16_mul:
    lw   t6, sign_mask     # t6 = 0x80000000
    and  t0, a0, t6        # A_sign
    and  t1, a1, t6        # B_sign
    xor  a3, t0, t1        # new sign (bit 31)
    srli a3, a3, 16        # move the sign to bit 15

    lw   t6, exp_mask      # t6 = 0x7F800000
    and  t2, a0, t6        # A_exp
    srli t2, t2, 23
    addi t2, t2, -127      # unbiased A_exp
    and  t3, a1, t6        # B_exp
    srli t3, t3, 23
    addi t3, t3, -127      # unbiased B_exp
    add  t2, t2, t3        # new exp
    addi t2, t2, 127       # re-bias

    lw   t6, man_mask      # t6 = 0x007F0000 (top 7 fraction bits)
    and  t4, a0, t6        # A_man
    srli t4, t4, 16
    addi t4, t4, 0x0080    # add the implicit 1.
    and  t5, a1, t6        # B_man
    srli t5, t5, 16
    addi t5, t5, 0x0080
    mul  t4, t4, t5        # 16-bit product in [1.0, 4.0), 14 fraction bits

    srli t6, t4, 15        # bit 15 set: product >= 2.0
    beq  t6, zero, mul_pack
    srli t4, t4, 1         # normalize the mantissa
    addi t2, t2, 1         # exp + 1

mul_pack:
    blez t2, mul_done      # underflow: return the signed zero in a3
    li   t6, 255
    bge  t2, t6, overflow  # exponent overflow
    srli t4, t4, 7         # keep the leading 1 and 7 fraction bits
    andi t4, t4, 0x7F      # drop the implicit 1
    slli t2, t2, 7         # exponent to bits 14..7
    or   a3, a3, t2
    or   a3, a3, t4
mul_done:
    jr   ra

overflow:
    la   a0, str3          # Load the address of the string
    li   a7, 4             # System call code for printing a string
    ecall                  # Print the string
    li   a7, 10            # Exit
    ecall
```

### bfloat16_add

`bfloat16_add` performs addition in bfloat16 format. It orders the operands by magnitude, extracts the sign, exponent, and mantissa, aligns the smaller mantissa to the larger exponent, and then adds or subtracts the mantissas depending on the signs before renormalizing the result.

```riscv32=
# bfloat16 adder
# input:  a4, a5 (bfloat16)
# output: a3 (bfloat16)
# t0, t1: A_sign, B_sign
# t2, t3: A_exp, B_exp (t3 later holds the exponent difference)
# t4, t5: A_man, B_man
# t6: masks and scratch
bfloat16_add:
    # make a4 the operand with the larger magnitude
    li   t6, 0x7FFF
    and  t0, a4, t6
    and  t1, a5, t6
    bgeu t0, t1, add_ordered
    mv   t6, a4            # swap a4 and a5
    mv   a4, a5
    mv   a5, t6
add_ordered:
    # sign
    lw   t6, bfloat_sign_mask
    and  t0, a4, t6        # A_sign (sign of the result)
    and  t1, a5, t6        # B_sign
    # exp
    lw   t6, bfloat_exp_mask
    and  t2, a4, t6        # A_exp
    and  t3, a5, t6        # B_exp
    srli t2, t2, 7
    srli t3, t3, 7
    # man
    lw   t6, bfloat_man_mask
    and  t4, a4, t6        # A_man
    and  t5, a5, t6        # B_man
    ori  t4, t4, 0x0080    # add the implicit 1.
    ori  t5, t5, 0x0080
    # align man: shift B_man right by A_exp - B_exp (>= 0)
    sub  t3, t2, t3
    li   t6, 8
    bltu t3, t6, align_shift
    li   t5, 0             # B is too small to affect the result
align_shift:
    srl  t5, t5, t3

    beq  t0, t1, same_sign # same sign: add directly
    sub  t4, t4, t5        # A_man - B_man >= 0 because |A| >= |B|
    beqz t4, add_zero      # exact cancellation
renormalize:
    andi t6, t4, 0x80      # leading 1 back at bit 7?
    bnez t6, combine
    slli t4, t4, 1         # shift man left
    addi t2, t2, -1        # exp - 1
    j    renormalize
same_sign:
    add  t4, t4, t5
    srli t6, t4, 8         # carry into bit 8: sum >= 2.0
    beqz t6, combine
    srli t4, t4, 1         # shift man right
    addi t2, t2, 1         # exp + 1
combine:
    andi t4, t4, 0x7F      # drop the implicit 1
    slli t2, t2, 7         # exponent back to bits 14..7
    or   a3, t0, t2
    or   a3, a3, t4
    jr   ra
add_zero:
    li   a3, 0
    jr   ra
```

### printResult

`printResult` prints the four elements of the result matrix C in row-major order (C[0][0], C[0][1], C[1][0], C[1][1]). Each element is printed as a decimal integer, that is, as the bf16 bit pattern rather than its real value.

```riscv32=
printResult:
    mv   t0, a6            # after matmul, a6 points at C[1][1]
    addi t0, t0, -12       # t0 = &C[0][0]

    la   a0, str1          # Load the address of the string "C ="
    li   a7, 4             # System call code for printing a string
    ecall                  # Print the string
    lw   a0, 0(t0)         # a0 = C[0][0]
    li   a7, 1             # System call code for printing an integer
    ecall                  # Print the integer

    la   a0, str2          # Load the address of the string ", "
    li   a7, 4
    ecall
    lw   a0, 4(t0)         # a0 = C[0][1]
    li   a7, 1
    ecall

    la   a0, str2
    li   a7, 4
    ecall
    lw   a0, 8(t0)         # a0 = C[1][0]
    li   a7, 1
    ecall

    la   a0, str2
    li   a7, 4
    ecall
    lw   a0, 12(t0)        # a0 = C[1][1]
    li   a7, 1
    ecall
    ret                    # Return to the caller
```

## Analysis

We analyze the program with the [Ripes](https://github.com/mortbopet/Ripes) simulator.

### Five-stage RISC pipeline

* #### Instruction fetch (IF)
    The instruction is fetched from memory.
* #### Instruction decode and register fetch (ID)
    The instruction is decoded and its source registers are read.
* #### Execute (EX)
    The operation specified by the instruction is performed.
* #### Memory access (MEM)
    Data is loaded from or stored to memory.
* #### Register write back (WB)
    The result is written to the destination register.

![image](https://hackmd.io/_uploads/Hksbv9911l.png)

Every instruction passes through these stages, as in the example below.

![image](https://hackmd.io/_uploads/HJliSca5Jyx.png)
![image](https://hackmd.io/_uploads/S1cI569Jkx.png)

#### IF

![image](https://hackmd.io/_uploads/Sy88jp5k1l.png =40%x)

* PC = PC + 4: the program counter (PC) gives the address of the next instruction to the instruction memory.
* The instruction `srli` is at memory address `0x130`.
* As the figure above shows, it is encoded as `0x017e5e13`, which is `srli t3, t3, 23`.

#### ID

![image](https://hackmd.io/_uploads/B1wHTa5yyg.png =50%x)

* Instruction `0x01f5fe33` (`and t3, a1, t6`) is decoded into opcode AND (`0x1c`), R1 idx (`0x0b`, `a1`), R2 idx (`0x1f`, `t6`), and Imm. `0xdeadbeef`; an R-type instruction has no immediate, so this value is unused.
* The values of the registers selected by R1 idx and R2 idx are read from the register file.

#### EX

![image](https://hackmd.io/_uploads/SyfsZ0cJ1e.png =50%x)

* The ALU adds its two operands, `r1_out` and `imm_out` (this stage holds `addi t2, t2, -127`).
* The PC and next-PC values pass through this stage unchanged.

#### MEM

![image](https://hackmd.io/_uploads/B1GsGC51kx.png =30%x)

* The `srli` instruction in this stage does not access memory, so `MEM` does nothing for it.

#### WB

![image](https://hackmd.io/_uploads/BJ4ZER9yJe.png =40%x)

* `0x3f800000` is written back through `Wr data`; it is the exponent field of A[0][0] (`0x3F9DF3B6 & 0x7F800000`) produced by `and t2, a0, t6`.

### Pipeline Data Hazard

Read After Write (RAW): an instruction depends on the result of a previous instruction that has not completed yet.
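The RAW condition above can be written as a simple check. The sketch below is a textbook model, not Ripes's implementation: `insn_regs`, `raw_dependent`, and `nops_needed` are hypothetical names, unused register fields are set to 0 (`x0`, which never creates a dependence), and it assumes a 5-stage pipeline without forwarding whose register file is written in the first half of a cycle and read in the second half. Under those assumptions, a consumer placed directly after its producer needs two `nop`s.

```c
#include <stdbool.h>

/* Hypothetical record of the register fields of one instruction.
 * Unused fields are 0 (x0), which never creates a dependence. */
typedef struct {
    int rd, rs1, rs2;
} insn_regs;

/* RAW: `later` reads the register that `earlier` writes. */
static bool raw_dependent(insn_regs earlier, insn_regs later)
{
    return earlier.rd != 0 &&
           (later.rs1 == earlier.rd || later.rs2 == earlier.rd);
}

/* nops to insert between two instructions `distance` slots apart
 * (1 = adjacent) on a 5-stage pipeline without forwarding, assuming a
 * register written in WB can be read in ID during the same cycle. */
static int nops_needed(insn_regs earlier, insn_regs later, int distance)
{
    if (!raw_dependent(earlier, later) || distance >= 3)
        return 0;
    return 3 - distance;
}
```

For example, with `or a3, a3, t4` (rd = x13) immediately followed by `mv a4, a3` (rs1 = x13), `nops_needed` returns 2.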
In this program, several instructions read registers written by instructions shortly before them. Without forwarding, such a read can happen before the earlier write has completed. There are two ways to handle this in software:

**1. Stall insertion:** Insert no-operation (`nop`) instructions between dependent instructions. This is simple but reduces performance.

**2. Instruction reordering:** Move independent instructions between the producer and the consumer, so that the result is ready by the time it is read.

**For example:** `bfloat16_mul` writes its result to `a3` shortly before returning, so `mv a4, a3` right after the call may read `a3` before that write has completed.

```
lw  a0, A           # a0 = A[0][0]
lw  a1, B           # a1 = B[0][0]
jal bfloat16_mul    # a3 = a0 * a1
mv  a4, a3          # Hazard: reads a3
```

When line 31 is in the IF stage, register `a3` has not been updated yet, so we add two `nop` instructions, shown in red below.

![image](https://hackmd.io/_uploads/B1gUxpcy1g.png)
![image](https://hackmd.io/_uploads/H1udeT9kye.png)
![image](https://hackmd.io/_uploads/S1e3lT5yyl.png)

### Performance

![image](https://hackmd.io/_uploads/SJ_Y3C5k1e.png =40%x)

## Reference

- [https://marz.utk.edu/my-courses/cosc230/book/example-risc-v-assembly-programs/#call_c](https://marz.utk.edu/my-courses/cosc230/book/example-risc-v-assembly-programs/#call_c)
- [https://en.wikipedia.org/wiki/IEEE_754](https://en.wikipedia.org/wiki/IEEE_754)
- [https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus)
- [https://hackmd.io/@sysprog/H1TpVYMdB](https://hackmd.io/@sysprog/H1TpVYMdB)