# Assignment 1: RISC-V Assembly and Instruction Pipeline
contributed by < [`jtl`](https://github.com/just-ting) >
## Problem - BFloat16 floating-point matrix multiply
Reference: [Quiz 1 Problem B](https://hackmd.io/@sysprog/arch2024-quiz1-sol). We implement bfloat16 matrix multiplication in C without relying on floating-point operations, since standard C has no bfloat16 type or bfloat16 arithmetic. To keep the problem small, we focus on multiplying 2x2 bfloat16 matrices.
### Introduction to bfloat16
Brain Floating Point Format (bfloat16) is a 16-bit floating-point format developed by Google Brain, an artificial intelligence research group at Google. It is designed for machine learning: it keeps the 8-bit exponent of the 32-bit IEEE 754 single-precision format (float32), so it has the same dynamic range, but it has only 7 mantissa bits, a trade-off that suits deep learning workloads.
#### `Format`
| type | sign bit | exponent bit | fraction/ mantissa bit |
| ---- |:--------:|:--------:|:----------:|
| FP32 | 1 | 8 | 23 |
| BF16 | 1 | 8 | 7 |
#### `IEEE 754 Encoding`
| Single-Precision | Exponent | Fraction | Value |
| -------- | -------- | -------- | -------- |
| Normalized Number | 1 to 254 | Anything | $\pm(1.F)\times2^{E-127}$ |
| Denormalized Number | 0 | nonzero | $\pm(0.F)\times2^{-126}$ |
| Zero | 0 | 0 | $\pm0$ |
| Infinity | 255 | 0 | $\pm\infty$ |
| NaN | 255 | nonzero | NaN |
#### `Decimal Expression`
$value_{FP32} = (-1)^{sign}\times2^{(exponent-127)}\times\left(1+\frac{fraction}{2^{23}}\right)$
$value_{BF16} = (-1)^{sign}\times2^{(exponent-127)}\times\left(1+\frac{fraction}{128}\right)$
## C code
We do not handle NaN, infinity, or denormalized numbers here; we focus on arithmetic with normalized values while still checking for overflow.
### `Idea 1`
To multiply bfloat16 (bf16) matrices in C, we can first convert the bf16 values to fp32 (32-bit floating point), perform the calculations in fp32, and then convert the results back to bf16.
```c
#include <stdint.h>
#include <stdio.h>

/* bf16 bit pattern (the upper 16 bits of an IEEE 754 float32) */
typedef struct {
    uint16_t bits;
} bf16_t;

/* float32 -> bf16 with round-to-nearest-even; NaN stays a quiet NaN */
static inline bf16_t fp32_to_bf16(float s)
{
    bf16_t h;
    union {
        float f;
        uint32_t i;
    } u = {.f = s};
    if ((u.i & 0x7fffffff) > 0x7f800000) { /* NaN */
        h.bits = (u.i >> 16) | 64;          /* force to quiet */
        return h;
    }
    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
    return h;
}

/* bf16 -> float32 is exact: the 16 bits become the upper half */
static inline float bf16_to_fp32(bf16_t h)
{
    union {
        float f;
        uint32_t i;
    } u = {.i = (uint32_t)h.bits << 16};
    return u.f;
}

/* C (M x K) = A (M x N) * B (N x K), all row-major; accumulates in fp32 */
void matrix_multiply_bfloat(const bf16_t *A, const bf16_t *B, bf16_t *C,
                            int M, int N, int K)
{
    for (int i = 0; i < M; i++) {
        for (int j = 0; j < K; j++) {
            float sum = 0.0f;
            for (int k = 0; k < N; k++) {
                float a = bf16_to_fp32(A[i * N + k]);
                float b = bf16_to_fp32(B[k * K + j]);
                sum += a * b;
            }
            C[i * K + j] = fp32_to_bf16(sum); /* every element is written here */
        }
    }
}
```
### `Idea 2`
However, with the first idea the arithmetic is still carried out in fp32: standard C has no bf16 arithmetic type, so the products and sums are computed as `float`. To compute directly in bf16, we can handle the sign, exponent, and mantissa of each bf16 value separately with integer operations, checking for overflow and underflow along the way.
#### bfloat16_multiply
The `bfloat16_multiply` function extracts the sign, exponent, and mantissa of each BFLOAT16 operand, restores the implicit leading 1, multiplies the mantissas, adds the exponents (removing one bias of 127), normalizes if necessary, and finally assembles the resulting BFLOAT16 value.
```c
#include <stdint.h>

/* Multiply two bfloat16 values (normalized inputs only; NaN/Inf/denormals ignored).
 * The result mantissa is truncated, not rounded. */
static inline uint16_t bfloat16_multiply(uint16_t x, uint16_t y)
{
    // extract sign(1), exponent(8)
    uint16_t sign = (x ^ y) & 0x8000;
    int32_t exp1 = (x >> 7) & 0xFF;
    int32_t exp2 = (y >> 7) & 0xFF;
    if (exp1 == 0 || exp2 == 0)     // zero operand (denormals treated as zero)
        return sign;
    // mantissa(7) with the implicit 1 restored: 1.mmmmmmm as an integer in [128, 255]
    uint32_t mant1 = (x & 0x007F) | 0x0080;
    uint32_t mant2 = (y & 0x007F) | 0x0080;
    // 8-bit x 8-bit product fits in 16 bits; value = product / 2^14, in [1, 4)
    uint32_t mant_product = mant1 * mant2;
    int32_t ans_exp = exp1 + exp2 - 127; // signed, so underflow can be detected
    // normalize: product >= 2.0 (bit 15 set) -> shift right, bump exponent
    if (mant_product & 0x8000) {
        mant_product >>= 1;
        ans_exp++;
    }
    // the leading 1 is now bit 14; keep the next 7 bits (truncation)
    uint16_t ans_mant = (mant_product >> 7) & 0x007F;
    if (ans_exp >= 255) {           // overflow -> infinity
        ans_exp = 255;
        ans_mant = 0;
    } else if (ans_exp <= 0) {      // underflow -> flush to zero
        ans_exp = 0;
        ans_mant = 0;
    }
    return sign | (uint16_t)(ans_exp << 7) | ans_mant;
}
```
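As a worked example of the field arithmetic this function needs, consider $1.5 \times 1.5$ with 0x3FC0 for both operands (exponent field 127, fraction 64). With the implicit 1 restored, each mantissa is the 8-bit integer $128 + 64 = 192$, i.e. 1.5 scaled by $2^7$, so the product is scaled by $2^{14}$. Since both factors lie in $[1, 2)$, the product lies in $[1, 4)$ and at most one right shift is needed:

$$
\begin{aligned}
m_x m_y &= 192 \times 192 = 36864 = \texttt{0x9000} \ge 2^{15} \;\Rightarrow\; \text{shift right by 1: } \texttt{0x4800},\ E \mathrel{+}= 1\\
E &= 127 + 127 - 127 + 1 = 128\\
\text{fraction} &= \lfloor \texttt{0x4800} / 2^{7} \rfloor - 128 = 144 - 128 = 16 = \texttt{0x10}\\
\text{result} &= 128 \cdot 2^{7} + 16 = \texttt{0x4010} = \left(1 + \tfrac{16}{128}\right)\times 2^{128-127} = 2.25
\end{aligned}
$$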
#### bfloat16_add
`bfloat16_add` extracts the sign, exponent, and mantissa of each BFLOAT16 operand. It aligns the mantissas to the larger exponent, then adds them if the signs are equal or subtracts the smaller magnitude from the larger one otherwise, and finally normalizes the result and checks for overflow.
```c
#include <stdint.h>

/* Add two bfloat16 values (normalized inputs only; NaN/Inf/denormals ignored).
 * Mantissas keep 8 extra low bits during alignment; the result is truncated. */
static inline uint16_t bfloat16_add(uint16_t x, uint16_t y)
{
    // extract sign(1), exponent(8)
    uint16_t sign1 = x & 0x8000;
    uint16_t sign2 = y & 0x8000;
    int32_t exp1 = (x >> 7) & 0xFF;
    int32_t exp2 = (y >> 7) & 0xFF;
    if (exp1 == 0)          // x is zero: result is y
        return y;
    if (exp2 == 0)          // y is zero: result is x
        return x;
    // mantissa(7) plus the implicit 1, moved to bits 15..8 to leave room for alignment
    uint32_t mant1 = (uint32_t)((x & 0x007F) | 0x0080) << 8;
    uint32_t mant2 = (uint32_t)((y & 0x007F) | 0x0080) << 8;
    // align the smaller operand to the larger exponent; shifts >= 16 leave nothing
    if (exp1 >= exp2) {
        int32_t d = exp1 - exp2;
        mant2 = (d < 16) ? mant2 >> d : 0;
    } else {
        int32_t d = exp2 - exp1;
        mant1 = (d < 16) ? mant1 >> d : 0;
        exp1 = exp2;
    }
    uint16_t sign;
    uint32_t mant;
    if (sign1 == sign2) {
        // same sign: add magnitudes
        sign = sign1;
        mant = mant1 + mant2;
        if (mant & 0x10000) {       // carry out: value >= 2.0
            mant >>= 1;
            exp1++;
        }
    } else {
        // different signs: subtract the smaller magnitude from the larger
        if (mant1 == mant2)
            return 0;               // exact cancellation gives +0
        if (mant1 > mant2) {
            sign = sign1;
            mant = mant1 - mant2;
        } else {
            sign = sign2;
            mant = mant2 - mant1;
        }
        while (!(mant & 0x8000)) {  // renormalize: bring the leading 1 back to bit 15
            mant <<= 1;
            exp1--;
        }
    }
    if (exp1 >= 255)                // overflow -> infinity
        return sign | 0x7F80;
    if (exp1 <= 0)                  // underflow -> flush to zero
        return sign;
    return sign | (uint16_t)(exp1 << 7) | ((mant >> 8) & 0x007F);
}
```
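A worked example of the alignment and normalization steps, using 8-bit mantissas with the implicit 1 restored (values chosen so that no bits are lost):

$$
\begin{aligned}
2.5 + 0.5:&\quad 2.5 = \texttt{0x4020}\ (E=128,\ m=160),\qquad 0.5 = \texttt{0x3F00}\ (E=126,\ m=128)\\
&\quad m_{0.5} \gg (128-126) = 32,\qquad 160 + 32 = 192 < 256\ (\text{no carry})\\
&\quad \text{result} = 128 \cdot 2^{7} + (192 - 128) = \texttt{0x4040} = 1.5 \times 2^{1} = 3.0\\[4pt]
1.5 - 1.0:&\quad m = 192 - 128 = 64\ (\text{leading 1 lost}) \;\Rightarrow\; m \ll 1 = 128,\ E = 127 - 1 = 126\\
&\quad \text{result} = 126 \cdot 2^{7} + 0 = \texttt{0x3F00} = 0.5
\end{aligned}
$$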
#### bfloat16_matrix_multiply
The `bfloat16_matrix_multiply` function multiplies two 2x2 BFLOAT16 matrices: it iterates over the rows of the first matrix and the columns of the second, computing each dot product with `bfloat16_multiply` and `bfloat16_add`.
```c
void bfloat16_matrix_multiply(uint16_t A[2][2], uint16_t B[2][2], uint16_t C[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            C[i][j] = 0; // Initialize the result element (+0.0)
            for (int k = 0; k < 2; k++) {
                C[i][j] = bfloat16_add(C[i][j], bfloat16_multiply(A[i][k], B[k][j]));
            }
        }
    }
}
```
#### print_matrix
The `print_matrix` function prints a 2x2 BFLOAT16 matrix as hexadecimal bit patterns.
```c
// Function to print a BFLOAT16 matrix
void print_matrix(uint16_t matrix[2][2])
{
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%04X ", matrix[i][j]); // Print in hex format
        }
        printf("\n");
    }
}
```
#### main
`main` initializes two example 2x2 BFLOAT16 matrices A and B, calls `bfloat16_matrix_multiply` to compute their product into matrix C, and finally prints C.
```c
int main(void)
{
    // Example 2x2 BFLOAT16 matrices
    uint16_t A[2][2] = { {0x3F80, 0x3FC0},   // 1.0, 1.5 in BFLOAT16
                         {0x4000, 0x4020} }; // 2.0, 2.5 in BFLOAT16
    uint16_t B[2][2] = { {0x3F80, 0x3FC0},   // 1.0, 1.5 in BFLOAT16
                         {0x4000, 0x4020} }; // 2.0, 2.5 in BFLOAT16
    uint16_t C[2][2];                        // Result matrix
    bfloat16_matrix_multiply(A, B, C);
    printf("Result Matrix:\n");
    print_matrix(C);
    return 0;
}
```
:::danger
Can you improve the above instead of adding comments literally?
:::
### `Test`
#### test 1
```
uint16_t A[2][2] = { {0x3F9D, 0x401D}, // 1.234, 2.456 in BFLOAT16
{0x406B, 0x409C}}; // 3.678, 4.890 in BFLOAT16
uint16_t B[2][2] = { {0x4007, 0x3F9C}, // 2.111, 1.222 in BFLOAT16
{0x3EAA, 0x405C}}; // 0.333, 3.444 in BFLOAT16
```
#### result 1
:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::
#### test 2
```
uint16_t A[2][2] = { {0x3DFB, 0x3EE9},   // 0.123, 0.456 in BFLOAT16
                     {0x3F49, 0x3EA4}};  // 0.789, 0.321 in BFLOAT16
uint16_t B[2][2] = { {0x40B2, 0x40DC},   // 5.567, 6.890 in BFLOAT16
                     {0x40E3, 0x4103}};  // 7.123, 8.234 in BFLOAT16
```
#### result 2

:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::
#### test 3
```
uint16_t A[2][2] = { {0xBF90, 0x4018},   // -1.125, 2.376 in BFLOAT16
                     {0x405F, 0xC085} }; // 3.487, -4.158 in BFLOAT16
uint16_t B[2][2] = { {0x400F, 0xBFCB},   // 2.237, -1.587 in BFLOAT16
                     {0x3EE9, 0x4069} }; // 0.456, 3.654 in BFLOAT16
```
#### result 3

:::danger
Do not use screenshots for plain text content, as this is inaccessible to visually impaired users.
:::
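### C Compiled to Assembly
The following listing is the x86-64 assembly that GCC 14.2.0 (MinGW-Builds, per the `.ident` directive at the end) produced for the Idea 2 C code, apparently without optimization. It is host x86-64 code, not RISC-V.
```x86asm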
.file "bfloat_matmul.c"
.text
.def bfloat16_multiply; .scl 3; .type 32; .endef
.seh_proc bfloat16_multiply
bfloat16_multiply:
pushq %rbp
.seh_pushreg %rbp
movq %rsp, %rbp
.seh_setframe %rbp, 0
subq $32, %rsp
.seh_stackalloc 32
.seh_endprologue
movl %edx, %eax
movl %ecx, %edx
movw %dx, 16(%rbp)
movw %ax, 24(%rbp)
movzwl 16(%rbp), %eax
andw $-32768, %ax
movw %ax, -10(%rbp)
movzwl 24(%rbp), %eax
andw $-32768, %ax
movw %ax, -12(%rbp)
movzwl 16(%rbp), %eax
sarl $7, %eax
andw $255, %ax
movw %ax, -14(%rbp)
movzwl 24(%rbp), %eax
sarl $7, %eax
andw $255, %ax
movw %ax, -16(%rbp)
movzwl 16(%rbp), %eax
andl $127, %eax
movw %ax, -18(%rbp)
movzwl 24(%rbp), %eax
andl $127, %eax
movw %ax, -20(%rbp)
movzwl -18(%rbp), %edx
movzwl -20(%rbp), %eax
imull %edx, %eax
movl %eax, -4(%rbp)
movzwl -14(%rbp), %edx
movzwl -16(%rbp), %eax
addl %edx, %eax
subl $127, %eax
movw %ax, -6(%rbp)
movl -4(%rbp), %eax
andl $131072, %eax
testl %eax, %eax
je .L2
shrl -4(%rbp)
movzwl -6(%rbp), %eax
addl $1, %eax
movw %ax, -6(%rbp)
.L2:
movl -4(%rbp), %eax
shrl $8, %eax
andl $127, %eax
movw %ax, -8(%rbp)
cmpw $254, -6(%rbp)
jbe .L3
movw $255, -6(%rbp)
movw $0, -8(%rbp)
jmp .L4
.L3:
cmpw $0, -6(%rbp)
jne .L4
movw $0, -6(%rbp)
movw $0, -8(%rbp)
.L4:
movzwl -10(%rbp), %eax
xorw -12(%rbp), %ax
movw %ax, -22(%rbp)
movzwl -6(%rbp), %eax
sall $8, %eax
movl %eax, %edx
movzwl -22(%rbp), %eax
orl %edx, %eax
movw %ax, -22(%rbp)
movzwl -8(%rbp), %eax
orw %ax, -22(%rbp)
movzwl -22(%rbp), %eax
addq $32, %rsp
popq %rbp
ret
.seh_endproc
.def bfloat16_add; .scl 3; .type 32; .endef
.seh_proc bfloat16_add
bfloat16_add:
pushq %rbp
.seh_pushreg %rbp
movq %rsp, %rbp
.seh_setframe %rbp, 0
subq $32, %rsp
.seh_stackalloc 32
.seh_endprologue
movl %edx, %eax
movl %ecx, %edx
movw %dx, 16(%rbp)
movw %ax, 24(%rbp)
movzwl 16(%rbp), %eax
andw $-32768, %ax
movw %ax, -16(%rbp)
movzwl 24(%rbp), %eax
andw $-32768, %ax
movw %ax, -18(%rbp)
movzwl 16(%rbp), %eax
sarl $7, %eax
andw $255, %ax
movw %ax, -2(%rbp)
movzwl 24(%rbp), %eax
sarl $7, %eax
andw $255, %ax
movw %ax, -20(%rbp)
movzwl 16(%rbp), %eax
andl $127, %eax
movw %ax, -4(%rbp)
movzwl 24(%rbp), %eax
andl $127, %eax
movw %ax, -6(%rbp)
movzwl -16(%rbp), %eax
cmpw -18(%rbp), %ax
jne .L7
movzwl -2(%rbp), %eax
cmpw %ax, -20(%rbp)
jnb .L8
movzwl -6(%rbp), %r8d
movzwl -2(%rbp), %eax
movzwl -20(%rbp), %edx
subl %edx, %eax
movl %eax, %ecx
sarl %cl, %r8d
movl %r8d, %eax
movw %ax, -6(%rbp)
jmp .L9
.L8:
movzwl -4(%rbp), %r8d
movzwl -20(%rbp), %eax
movzwl -2(%rbp), %edx
subl %edx, %eax
movl %eax, %ecx
sarl %cl, %r8d
movl %r8d, %eax
movw %ax, -4(%rbp)
movzwl -20(%rbp), %eax
movw %ax, -2(%rbp)
.L9:
movzwl -4(%rbp), %edx
movzwl -6(%rbp), %eax
addl %edx, %eax
movl %eax, -12(%rbp)
movl -12(%rbp), %eax
andl $131072, %eax
testl %eax, %eax
je .L10
shrl -12(%rbp)
movzwl -2(%rbp), %eax
addl $1, %eax
movw %ax, -2(%rbp)
.L10:
movl -12(%rbp), %eax
shrl $8, %eax
andl $127, %eax
movw %ax, -14(%rbp)
cmpw $254, -2(%rbp)
jbe .L11
movw $255, -2(%rbp)
movw $0, -14(%rbp)
.L11:
movzwl -2(%rbp), %eax
sall $7, %eax
movl %eax, %edx
movzwl -16(%rbp), %eax
orl %eax, %edx
movzwl -14(%rbp), %eax
orl %edx, %eax
jmp .L12
.L7:
movzwl -2(%rbp), %eax
cmpw %ax, -20(%rbp)
jnb .L13
movzwl -6(%rbp), %r8d
movzwl -2(%rbp), %eax
movzwl -20(%rbp), %edx
subl %edx, %eax
movl %eax, %ecx
sarl %cl, %r8d
movl %r8d, %eax
movw %ax, -6(%rbp)
jmp .L14
.L13:
movzwl -4(%rbp), %r8d
movzwl -20(%rbp), %eax
movzwl -2(%rbp), %edx
subl %edx, %eax
movl %eax, %ecx
sarl %cl, %r8d
movl %r8d, %eax
movw %ax, -4(%rbp)
movzwl -20(%rbp), %eax
movw %ax, -2(%rbp)
.L14:
movzwl -4(%rbp), %edx
movzwl -6(%rbp), %eax
subl %eax, %edx
movl %edx, -24(%rbp)
cmpl $0, -24(%rbp)
jne .L15
movl $0, %eax
jmp .L12
.L15:
cmpl $0, -24(%rbp)
je .L16
movl -24(%rbp), %eax
shrl $8, %eax
andl $127, %eax
movw %ax, -28(%rbp)
movzwl -2(%rbp), %eax
sall $7, %eax
movl %eax, %edx
movzwl -16(%rbp), %eax
orl %eax, %edx
movzwl -28(%rbp), %eax
orl %edx, %eax
jmp .L12
.L16:
movl -24(%rbp), %eax
negl %eax
shrl $8, %eax
andl $127, %eax
movw %ax, -26(%rbp)
movzwl -2(%rbp), %eax
sall $7, %eax
movl %eax, %edx
movzwl -18(%rbp), %eax
orl %eax, %edx
movzwl -26(%rbp), %eax
orl %edx, %eax
.L12:
addq $32, %rsp
popq %rbp
ret
.seh_endproc
.globl bfloat16_matrix_multiply
.def bfloat16_matrix_multiply; .scl 2; .type 32; .endef
.seh_proc bfloat16_matrix_multiply
bfloat16_matrix_multiply:
pushq %rbp
.seh_pushreg %rbp
pushq %rbx
.seh_pushreg %rbx
subq $56, %rsp
.seh_stackalloc 56
leaq 48(%rsp), %rbp
.seh_setframe %rbp, 48
.seh_endprologue
movq %rcx, 32(%rbp)
movq %rdx, 40(%rbp)
movq %r8, 48(%rbp)
movl $0, -4(%rbp)
jmp .L18
.L23:
movl $0, -8(%rbp)
jmp .L19
.L22:
movl -4(%rbp), %eax
cltq
leaq 0(,%rax,4), %rdx
movq 48(%rbp), %rax
addq %rax, %rdx
movl -8(%rbp), %eax
cltq
movw $0, (%rdx,%rax,2)
movl $0, -12(%rbp)
jmp .L20
.L21:
movl -12(%rbp), %eax
cltq
leaq 0(,%rax,4), %rdx
movq 40(%rbp), %rax
addq %rax, %rdx
movl -8(%rbp), %eax
cltq
movzwl (%rdx,%rax,2), %eax
movzwl %ax, %edx
movl -4(%rbp), %eax
cltq
leaq 0(,%rax,4), %rcx
movq 32(%rbp), %rax
addq %rax, %rcx
movl -12(%rbp), %eax
cltq
movzwl (%rcx,%rax,2), %eax
movzwl %ax, %eax
movl %eax, %ecx
call bfloat16_multiply
movzwl %ax, %edx
movl -4(%rbp), %eax
cltq
leaq 0(,%rax,4), %rcx
movq 48(%rbp), %rax
addq %rax, %rcx
movl -8(%rbp), %eax
cltq
movzwl (%rcx,%rax,2), %eax
movzwl %ax, %eax
movl -4(%rbp), %ecx
movslq %ecx, %rcx
leaq 0(,%rcx,4), %r8
movq 48(%rbp), %rcx
leaq (%r8,%rcx), %rbx
movl %eax, %ecx
call bfloat16_add
movl -8(%rbp), %edx
movslq %edx, %rdx
movw %ax, (%rbx,%rdx,2)
addl $1, -12(%rbp)
.L20:
cmpl $1, -12(%rbp)
jle .L21
addl $1, -8(%rbp)
.L19:
cmpl $1, -8(%rbp)
jle .L22
addl $1, -4(%rbp)
.L18:
cmpl $1, -4(%rbp)
jle .L23
nop
nop
addq $56, %rsp
popq %rbx
popq %rbp
ret
.seh_endproc
.section .rdata,"dr"
.LC0:
.ascii "%04X \0"
.text
.globl print_matrix
.def print_matrix; .scl 2; .type 32; .endef
.seh_proc print_matrix
print_matrix:
pushq %rbp
.seh_pushreg %rbp
movq %rsp, %rbp
.seh_setframe %rbp, 0
subq $48, %rsp
.seh_stackalloc 48
.seh_endprologue
movq %rcx, 16(%rbp)
movl $0, -4(%rbp)
jmp .L25
.L28:
movl $0, -8(%rbp)
jmp .L26
.L27:
movl -4(%rbp), %eax
cltq
leaq 0(,%rax,4), %rdx
movq 16(%rbp), %rax
addq %rax, %rdx
movl -8(%rbp), %eax
cltq
movzwl (%rdx,%rax,2), %eax
movzwl %ax, %eax
movl %eax, %edx
leaq .LC0(%rip), %rax
movq %rax, %rcx
call printf
addl $1, -8(%rbp)
.L26:
cmpl $1, -8(%rbp)
jle .L27
movl $10, %ecx
call putchar
addl $1, -4(%rbp)
.L25:
cmpl $1, -4(%rbp)
jle .L28
nop
nop
addq $48, %rsp
popq %rbp
ret
.seh_endproc
.section .rdata,"dr"
.LC1:
.ascii "Result Matrix:\0"
.text
.globl main
.def main; .scl 2; .type 32; .endef
.seh_proc main
main:
pushq %rbp
.seh_pushreg %rbp
movq %rsp, %rbp
.seh_setframe %rbp, 0
subq $64, %rsp
.seh_stackalloc 64
.seh_endprologue
call __main
movw $15360, -8(%rbp)
movw $15616, -6(%rbp)
movw $15872, -4(%rbp)
movw $16128, -2(%rbp)
movw $15360, -16(%rbp)
movw $15616, -14(%rbp)
movw $15872, -12(%rbp)
movw $16128, -10(%rbp)
leaq -24(%rbp), %rcx
leaq -16(%rbp), %rdx
leaq -8(%rbp), %rax
movq %rcx, %r8
movq %rax, %rcx
call bfloat16_matrix_multiply
leaq .LC1(%rip), %rax
movq %rax, %rcx
call puts
leaq -24(%rbp), %rax
movq %rax, %rcx
call print_matrix
movl $0, %eax
addq $64, %rsp
popq %rbp
ret
.seh_endproc
.def __main; .scl 2; .type 32; .endef
.ident "GCC: (x86_64-mcf-seh-rev0, Built by MinGW-Builds project) 14.2.0"
.def printf; .scl 2; .type 32; .endef
.def putchar; .scl 2; .type 32; .endef
.def puts; .scl 2; .type 32; .endef
```
## RISCV32
### matmul
The main routine `matmul` reads matrices A and B (stored as float32 words in `.data`, of which only the upper 16 bits are used as bfloat16), computes the four elements of matrix C with `bfloat16_mul` and `bfloat16_add`, stores them into C, and finally calls `printResult` to print the contents of C.
```riscv32=
.data
.align 2
# A and B hold float32 bit patterns; bfloat16_mul uses only their upper 16 bits
A: .word 0x3F9DF3B6, 0x401D2F1B, 0x406B645A, 0x409C7AE1 # A[0][0] = 1.234, A[0][1] = 2.456, A[1][0] = 3.678, A[1][1] = 4.890
.align 2
B: .word 0x40071AA0, 0x3F9C6A7F, 0x3EAA7EFA, 0x405C6A7F # B[0][0] = 2.111, B[0][1] = 1.222, B[1][0] = 0.333, B[1][1] = 3.444
.align 2
C: .word 0, 0, 0, 0
.align 2
sign_mask: .word 0x80000000
exp_mask: .word 0x7F800000
man_mask: .word 0x007F0000
bfloat_sign_mask: .word 0x8000
bfloat_exp_mask: .word 0x7F80
bfloat_man_mask: .word 0x007F
str1: .string "C =" # output message 1
str2: .string ", " # output message 2
str3: .string "overflow!" # output message 3
.text
matmul:
    la s0, A                # s0 = &A[0][0]; element [i][j] is at byte offset 4*(2*i+j)
    la s1, B                # s1 = &B[0][0]
    la a6, C                # a6 = &C[0][0]
    # C[0][0] = A[0][0] * B[0][0] + A[0][1] * B[1][0]
    lw a0, 0(s0)            # A[0][0]
    lw a1, 0(s1)            # B[0][0]
    jal bfloat16_mul        # a3 = a0 * a1
    mv a4, a3
    lw a0, 4(s0)            # A[0][1]
    lw a1, 8(s1)            # B[1][0]
    jal bfloat16_mul
    mv a5, a3
    jal bfloat16_add        # a3 = a4 + a5
    sw a3, 0(a6)
    # C[0][1] = A[0][0] * B[0][1] + A[0][1] * B[1][1]
    lw a0, 0(s0)            # A[0][0]
    lw a1, 4(s1)            # B[0][1]
    jal bfloat16_mul
    mv a4, a3
    lw a0, 4(s0)            # A[0][1]
    lw a1, 12(s1)           # B[1][1]
    jal bfloat16_mul
    mv a5, a3
    jal bfloat16_add
    addi a6, a6, 4
    sw a3, 0(a6)
    # C[1][0] = A[1][0] * B[0][0] + A[1][1] * B[1][0]
    lw a0, 8(s0)            # A[1][0]
    lw a1, 0(s1)            # B[0][0]
    jal bfloat16_mul
    mv a4, a3
    lw a0, 12(s0)           # A[1][1]
    lw a1, 8(s1)            # B[1][0]
    jal bfloat16_mul
    mv a5, a3
    jal bfloat16_add
    addi a6, a6, 4
    sw a3, 0(a6)
    # C[1][1] = A[1][0] * B[0][1] + A[1][1] * B[1][1]
    lw a0, 8(s0)            # A[1][0]
    lw a1, 4(s1)            # B[0][1]
    jal bfloat16_mul
    mv a4, a3
    lw a0, 12(s0)           # A[1][1]
    lw a1, 12(s1)           # B[1][1]
    jal bfloat16_mul
    mv a5, a3
    jal bfloat16_add
    addi a6, a6, 4
    sw a3, 0(a6)            # a6 now points at C[1][1]
    # Call the function to print the result
    jal ra, printResult
    # finish
    li a7, 10
    ecall
```
### bfloat16_mul
`bfloat16_mul` multiplies two values in bfloat16 precision. It takes float32 words, uses only their upper 16 bits (sign, exponent, and the top 7 mantissa bits), multiplies the mantissas with the implicit 1 restored, normalizes the product, and packs the bfloat16 result into a3. On exponent overflow it prints an error message and exits.
```riscv32=
# bfloat16 multiplier
# input: a0, a1 (IEEE 754 float32; only the upper 16 bits are used, i.e. truncation)
# output: a3 (bfloat16)
# t0, t1: A_sign, B_sign
# t2, t3: A_exp, B_exp
# t4, t5: A_man, B_man
# t6: sign_mask, exp_mask, man_mask, scratch
bfloat16_mul:
    lw t6, sign_mask        # t6 = sign_mask
    and t0, a0, t6          # A_sign
    and t1, a1, t6          # B_sign
    xor a3, t0, t1          # calculate new sign
    srli a3, a3, 16         # move the sign to bit 15
    lw t6, exp_mask         # t6 = exp_mask
    and t2, a0, t6          # A_exp
    srli t2, t2, 23
    addi t2, t2, -127       # unbiased A_exp
    and t3, a1, t6          # B_exp
    srli t3, t3, 23
    addi t3, t3, -127       # unbiased B_exp
    li t6, -127             # raw exponent 0: treat the operand as zero
    beq t2, t6, mul_zero
    beq t3, t6, mul_zero
    add t2, t2, t3          # calculate new exp
    addi t2, t2, 127        # re-bias
    lw t6, man_mask         # t6 = man_mask
    and t4, a0, t6          # A_man
    srli t4, t4, 16
    ori t4, t4, 0x80        # restore the implicit 1 (1.mmmmmmm as 8 bits)
    and t5, a1, t6          # B_man
    srli t5, t5, 16
    ori t5, t5, 0x80
    mul t4, t4, t5          # 16-bit product, value = t4 / 2^14, in [1, 4)
    srli t6, t4, 15         # bit 15 set -> product >= 2.0
    beqz t6, mul_pack
    srli t4, t4, 1          # normalize the mantissa
    addi t2, t2, 1          # and adjust the exponent
mul_pack:
    blez t2, mul_zero       # exponent underflow -> signed zero
    li t6, 255
    bge t2, t6, overflow    # exponent overflow
    srli t4, t4, 7          # leading 1 is now at bit 7
    andi t4, t4, 0x7F       # keep 7 fraction bits (truncation)
    slli t2, t2, 7          # exponent field
    or a3, a3, t2
    or a3, a3, t4
    jr ra
mul_zero:
    jr ra                   # a3 holds only the sign bit
overflow:
    la a0, str3             # Load the address of the string
    li a7, 4                # System call code for printing a string
    ecall                   # Print the string
    li a7, 10
    ecall
```
### bfloat16_add
`bfloat16_add` adds two bfloat16 values passed in a4 and a5 and returns the bfloat16 sum in a3. It extracts the sign, exponent, and mantissa of each operand, aligns the mantissas, adds or subtracts them depending on the signs, and normalizes the result.
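```riscv32=
# bfloat16 adder
# input: a4, a5 (bfloat16)
# output: a3 (bfloat16)
# t0, t1: A_sign, B_sign
# t2, t3: A_exp, B_exp
# t4, t5: A_man, B_man (implicit 1 restored, shifted left by 8 to keep alignment bits)
# t6, a2: scratch
# note: exponent overflow/underflow of the sum is not checked
bfloat16_add:
    # sign
    lw t6, bfloat_sign_mask
    and t0, a4, t6              # A_sign
    and t1, a5, t6              # B_sign
    # exp (already biased, no adjustment needed)
    lw t6, bfloat_exp_mask
    and t2, a4, t6
    srli t2, t2, 7              # A_exp
    and t3, a5, t6
    srli t3, t3, 7              # B_exp
    # a zero operand returns the other operand
    bnez t2, a_nonzero
    mv a3, a5
    jr ra
a_nonzero:
    bnez t3, b_nonzero
    mv a3, a4
    jr ra
b_nonzero:
    # man: restore the implicit 1 and make room for alignment bits
    lw t6, bfloat_man_mask
    and t4, a4, t6              # A_man
    ori t4, t4, 0x80
    slli t4, t4, 8
    and t5, a5, t6              # B_man
    ori t5, t5, 0x80
    slli t5, t5, 8
    # align the smaller operand to the larger exponent (t2 = result exp)
    beq t2, t3, add_man         # A_exp = B_exp
    blt t2, t3, shift_a         # A_exp < B_exp
    sub t6, t2, t3              # d = A_exp - B_exp
    srl t5, t5, t6              # shift B_man right
    sltiu a2, t6, 16            # a2 = 1 if d < 16
    neg a2, a2                  # all ones if d < 16, else 0
    and t5, t5, a2              # d >= 16: B_man is shifted out entirely
    j add_man
shift_a:
    sub t6, t3, t2              # d = B_exp - A_exp
    srl t4, t4, t6              # shift A_man right
    sltiu a2, t6, 16
    neg a2, a2
    and t4, t4, a2
    mv t2, t3                   # result exp = B_exp
add_man:
    beq t0, t1, same_sign       # same sign: add directly
    beq t4, t5, add_zero        # exact cancellation
    bltu t4, t5, sub_man        # |A| < |B|
    sub t4, t4, t5              # A_man - B_man, result sign = A_sign
    j normalize_left
sub_man:
    sub t4, t5, t4              # B_man - A_man
    mv t0, t1                   # result sign = B_sign
normalize_left:
    li t6, 0x8000
norm_loop:
    and a2, t4, t6
    bnez a2, combine            # leading 1 is back at bit 15
    slli t4, t4, 1
    addi t2, t2, -1
    j norm_loop
same_sign:
    add t4, t4, t5
    li t6, 0x10000
    and a2, t4, t6
    beqz a2, combine            # no carry out of bit 15
    srli t4, t4, 1              # carry: shift right
    addi t2, t2, 1              # and increment the exponent
combine:
    srli t4, t4, 8              # drop the alignment bits (truncation)
    andi t4, t4, 0x7F           # drop the implicit 1
    slli t2, t2, 7
    or a3, t0, t2
    or a3, a3, t4
    jr ra
add_zero:
    li a3, 0
    jr ra
```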
### printResult
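`printResult` prints the four elements of the result matrix C in row-major order (C[0][0], C[0][1], C[1][0], C[1][1]). It expects a6 to point at C[1][1], as left by `matmul`, and prints each bfloat16 bit pattern as a decimal integer with system call 1.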
```riscv32=
printResult:
    mv t0, a6           # a6 points at C[1][1] after matmul
    addi t0, t0, -12    # t0 = &C[0][0]
    la a0, str1         # Load the address of the string "C ="
    li a7, 4            # System call code for printing a string
    ecall               # Print the string
    lw a0, 0(t0)        # C[0][0]
    li a7, 1            # System call code for printing an integer
    ecall               # Print the bit pattern as a decimal integer
    la a0, str2         # Load the address of the string ", "
    li a7, 4
    ecall
    lw a0, 4(t0)        # C[0][1]
    li a7, 1
    ecall
    la a0, str2
    li a7, 4
    ecall
    lw a0, 8(t0)        # C[1][0]
    li a7, 1
    ecall
    la a0, str2
    li a7, 4
    ecall
    lw a0, 12(t0)       # C[1][1]
    li a7, 1
    ecall
    ret                 # Return to the caller
```
## Analysis
We analyze the program with the [Ripes](https://github.com/mortbopet/Ripes) simulator.
### Five stage RISC pipeline
* #### Instruction fetch (IF)
  The instruction is fetched from instruction memory.
* #### Instruction decode and register fetch (ID)
  The instruction is decoded and its source registers are read.
* #### Execute (EX)
  The ALU performs the operation specified by the instruction.
* #### Memory access (MEM)
  Loads read data from memory and stores write data to memory.
* #### Register write back (WB)
  The result is written to the destination register.

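Every instruction goes through these stages. Because the stages overlap, each stage holds a different instruction in the same cycle; the example below examines one such cycle.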


#### IF

* The Program Counter (PC) sends the address of the current instruction to the instruction memory; the next PC is PC + 4.
* The fetched instruction, `srli t3, t3, 23`, is at memory address `0x130`.
* `srli t3, t3, 23` is encoded as `0x017e5e13`.
#### ID

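* The instruction in ID, `0x01f5fe33` (`and t3, a1, t6`), is decoded into:
  opcode AND, destination register index 0x1c (t3)
  R1 idx 0x0b (a1) & R2 idx 0x1f (t6)
  Imm. 0xdeadbeef (unused, since an R-type instruction has no immediate)
* The values of the registers at R1 idx and R2 idx are read.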
#### EX

* The ALU adds its two operands, r1_out (a register value) and imm_out (an immediate), as for an `addi` instruction.
* The PC and next-PC values pass through this stage unchanged.
#### MEM

* The `srli` instruction does not access memory, so nothing happens in `MEM`; its result simply passes through.
#### WB

* The value `0x3f800000` is placed on `Wr data` and written back to the destination register.
### Pipeline Data Hazard
Read After Write (RAW): an instruction depends on the result of a previous instruction that has not yet been written back. In this program, several instructions read registers whose previous writes have not finished.
There are several ways to resolve it:
**1. Stall insertion:** Insert no-operation (`nop`) instructions between dependent instructions, at the cost of performance.
**2. Instruction reordering:** Move independent instructions between the producer and the consumer so that the result is ready in time.
**For example:**
For example, `mv a4, a3` reads a3, which `bfloat16_mul` writes just before returning; if that write has not completed, `mv a4, a3` reads a stale value:
```
lw a0, A # a0 = A[0][0]
lw a1, B # a1 = B[0][0]
jal bfloat16_mul # a0 * a1
mv a4, a3 # Hazard
```
Without forwarding, a3 has not been written back yet when the instruction that reads it is fetched, so two `nop` instructions must be inserted before it.



### Performance

## Reference
[https://marz.utk.edu/my-courses/cosc230/book/example-risc-v-assembly-programs/#call_c](https://marz.utk.edu/my-courses/cosc230/book/example-risc-v-assembly-programs/#call_c)
[https://en.wikipedia.org/wiki/IEEE_754](https://en.wikipedia.org/wiki/IEEE_754)
[https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus)
[https://hackmd.io/@sysprog/H1TpVYMdB](https://hackmd.io/@sysprog/H1TpVYMdB)