# Matrix multiplication using bfloat16

contributed by [Terry7Wei7](https://github.com/Terry7Wei7)

Common data types used in Machine Learning

We start with the basic understanding of different floating-point data types, which are also referred to as "precision" in the context of Machine Learning. The size of a model is determined by the number of its parameters, and their precision, typically one of float32, float16 or bfloat16 (image below from: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/).

![](https://hackmd.io/_uploads/Hyg2Ituea.png)

Float32 (FP32) stands for the standardized IEEE 32-bit floating-point representation. With this data type it is possible to represent a wide range of floating-point numbers. In FP32, 8 bits are reserved for the "exponent", 23 bits for the "mantissa" and 1 bit for the sign of the number. In addition to that, most of the hardware supports FP32 operations and instructions.

In the float16 (FP16) data type, 5 bits are reserved for the exponent and 10 bits are reserved for the mantissa. This makes the representable range of FP16 numbers much lower than FP32. This exposes FP16 numbers to the risk of overflowing (trying to represent a number that is very large) and underflowing (representing a number that is very small). For example, if you do 10k * 10k you end up with 100M which is not possible to represent in FP16, as the largest number possible is 64k. And thus you'd end up with NaN (Not a Number) result and if you have sequential computation like in neural networks, all the prior work is destroyed. Usually, loss scaling is used to overcome this issue, but it doesn't always work well.

A new format, bfloat16 (BF16), was created to avoid these constraints. In BF16, 8 bits are reserved for the exponent (which is the same as in FP32) and 7 bits are reserved for the fraction. This means that in BF16 we can retain the same dynamic range as FP32. But we lose 3 bits of precision with respect to FP16.
Now there is absolutely no problem with huge numbers, but the precision is worse than FP16 here.

# C code

```c
#include <stdio.h>
#include <stdint.h>

/* Union used to reinterpret the bits of a float as a 32-bit integer and back. */
typedef union {
    float f;     /* floating-point view */
    uint32_t i;  /* raw bit-pattern view */
} float_to_int;

/*
 * Round a single-precision float to bfloat16 precision.
 *
 * The result is still returned as a float, but its low 16 bits are zero,
 * so only the sign, the 8 exponent bits and the top 7 mantissa bits remain.
 * Zero, infinity and NaN are returned unchanged.
 *
 * Rounding is done by adding half of a bfloat16 ULP (same sign and exponent
 * as x, scaled down by 2^8) and then truncating the low 16 bits.
 */
float fp32_to_bf16(float x)
{
    /* Reading a different union member is the well-defined way to type-pun in C. */
    float_to_int in = { .f = x };
    uint32_t exp = in.i & 0x7F800000u;  /* exponent field */
    uint32_t man = in.i & 0x007FFFFFu;  /* mantissa field */

    /* Zero: nothing to round. */
    if (exp == 0 && man == 0)
        return x;
    /* Infinity or NaN: pass through unchanged. */
    if (exp == 0x7F800000u)
        return x;

    /* Keep only sign and exponent of x, i.e. +/-2^e, then divide by 256
     * to obtain +/-2^(e-8), which is half of the bfloat16 ULP at x. */
    float_to_int half = { .i = in.i & 0xFF800000u };
    half.f /= 0x100;

    /* Add the half ULP, then drop the 16 low mantissa bits. */
    float_to_int out = { .f = x + half.f };
    out.i &= 0xFFFF0000u;
    return out.f;
}

int main(void)
{
    /* Input matrices A and B (2x2) and the result matrix C, which starts at zero. */
    float A[2][2] = { {1.2345, 2.3456}, {3.4567, 4.5678} };
    float B[2][2] = { {0.1234, 0.2345}, {0.3456, 0.4567} };
    float C[2][2] = { {0.0, 0.0}, {0.0, 0.0} };

    /* C = A * B: each operand is rounded to bfloat16 precision first,
     * the products and the running sums are kept in float. */
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                C[i][j] += fp32_to_bf16(A[i][k]) * fp32_to_bf16(B[k][j]);
            }
        }
    }

    /* Print the result matrix C, one row per line. */
    printf("Result Matrix C:\n");
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%f ", C[i][j]);
        }
        printf("\n");
    }
    return 0;
}
```

# Assembly code

**RISC-V**

```
.data
.align 2
# Sample values, described by the author as bfloat16 (1.2345, 2.3456, 3.4567, 4.5678).
# NOTE(review): these are full 32-bit words and do not look like the encodings
# of the listed values (e.g. 1.2345f is 0x3F9E0419) - verify.
A: .word 0x3F9A2098, 0x3FA563B4, 0x3FCA3D70, 0x3FDB7866
.align 2
# Sample values, described by the author as bfloat16 (0.1234, 0.2345, 0.3456, 0.4567).
# NOTE(review): same concern as for A - verify the encodings.
B: .word 0x3E9BEC4F, 0x3EA1670A, 0x3EB851EB, 0x3EBB0192
.align 2
C: .word 0, 0, 0, 0 # result matrix C, initialised to zero

.text
.globl main
main:
# Matrix dimensions: A is 2x2, B is 2x2, C is 2x2
# Loop bounds
li s1, 2 # number of rows
li s2, 2 # number of columns
li s3, 2 # shared inner dimension
# Loop counters
li s4, 0 # i (row index)
loop_i:
li s5, 0 # j (column index)
loop_j:
# Initialise the result for C[i][j]
li s6, 0 # accumulator for C[i][j]
# Dot-product loop
li s7, 0 # k (inner loop counter)
loop_k:
# Calculate the memory addresses of A[i][k] and B[k][j]
slli t4, s4, 3 # t4 = s4 * 8, byte offset of row i (two 4-byte words per row)
slli t5, s7, 2 # t5 = s7 * 4, byte offset of column k
add t6, t4, t5 # t6 = t4 + t5, byte offset of A[i][k]
# NOTE(review): the offset of B[k][j] would be k*8 + j*4; it is not computed here.
# Load A[i][k] and B[k][j] using calculated memory addresses
# NOTE(review): both loads address the labels directly and never use t6,
# so they always read the first word of A and of B.
lw t0, A # intended: load A[i][k]
lw t1, B # intended: load B[k][j]
# Perform bfloat16 multiplication (simplified, no rounding)
# Instead of a mul instruction, a loop of repeated additions is used.
# NOTE(review): the loop below adds t0 to t2 exactly s7 (= k) times with integer
# adds, so t2 = t0 * k; t1 is never read, so this is not A[i][k] * B[k][j].
li t2, 0 # t2 = 0
addi t3, s7, 0 # t3 = s7 (inner loop counter)
# (the same two initialisation instructions are repeated here)
li t2, 0 # t2 = 0
addi t3, s7, 0 # t3 = s7 (inner loop counter)
multiply_loop:
beqz t3, multiply_done # leave the loop when t3 reaches zero
add t2, t2, t0 # t2 = t2 + t0
addi t3, t3, -1 # decrement t3 and keep looping
j multiply_loop # back to multiply_loop
multiply_done:
# Add the product to the accumulator
# NOTE(review): the sum goes into s5, which is the column index j; the register
# that is later stored to C, s6, is never updated after "li s6, 0".
add s5, s5, t2
# Advance the inner loop counter
addi s7, s7, 1
blt s7, s3, loop_k #
# Store the result in C[i][j]
slli s8, s4, 3 # s8 = s4 * 8, byte offset of row i
slli s9, s5, 2 # s9 = s5 * 4, byte offset of column j
add s10, s8, s9 # s10 = s8 + s9, byte offset of C[i][j]
sw s6, C(s10) # store s6 to C[i][j]
# Advance the column index
addi s5, s5, 1
blt s5, s2, loop_j
# Advance the row index
addi s4, s4, 1
blt s4, s1, loop_i
# Print the result
# (the registers set up below are not used, because the ecall that follows is commented out)
li t0, 0 # t0 = 0, intended for printing
li t1, 0 # t1 = 0, intended for printing
li t2, 0x100 # system call number intended for printing
li t3, 1 # file descriptor: standard output
# Trigger the system call that prints matrix C
#ecall
#exit
nop # stop(no tab)
Exit:
lw a0, 0(t0) # NOTE(review): t0 is 0 here, so this loads the word at address 0
li a7, 93 # Exit syscall number
#ecall
#li a7, 10
ecall
```

# **Analysis**

We test our code using the **Ripes simulator**.

**Pseudo instructions**

Put the code above into the editor and we will see that Ripes doesn’t execute it literally. Instead, it replaces each pseudo instruction with equivalent base instructions, and changes register names from ABI names to sequential ones (x0–x31).
The translated code looks like:

```
00000000 <main>:
    0:        00200493        addi x9 x0 2
    4:        00200913        addi x18 x0 2
    8:        00200993        addi x19 x0 2
    c:        00000a13        addi x20 x0 0

00000010 <loop_i>:
    10:        00000a93        addi x21 x0 0

00000014 <loop_j>:
    14:        00000b13        addi x22 x0 0
    18:        00000b93        addi x23 x0 0

0000001c <loop_k>:
    1c:        003a1e93        slli x29 x20 3
    20:        002b9f13        slli x30 x23 2
    24:        01ee8fb3        add x31 x29 x30
    28:        10000297        auipc x5 0x10000
    2c:        fd82a283        lw x5 -40 x5
    30:        10000317        auipc x6 0x10000
    34:        fe032303        lw x6 -32 x6
    38:        00000393        addi x7 x0 0
    3c:        000b8e13        addi x28 x23 0
    40:        00000393        addi x7 x0 0
    44:        000b8e13        addi x28 x23 0

00000048 <multiply_loop>:
    48:        000e0863        beq x28 x0 16 <multiply_done>
    4c:        005383b3        add x7 x7 x5
    50:        fffe0e13        addi x28 x28 -1
    54:        ff5ff06f        jal x0 -12 <multiply_loop>

00000058 <multiply_done>:
    58:        007a8ab3        add x21 x21 x7
    5c:        001b8b93        addi x23 x23 1
    60:        fb3bcee3        blt x23 x19 -68 <loop_k>
    64:        003a1c13        slli x24 x20 3
    68:        002a9c93        slli x25 x21 2
    6c:        019c0d33        add x26 x24 x25
    70:        10000d17        auipc x26 0x10000
    74:        fb6d2823        sw x22 -80 x26
    78:        001a8a93        addi x21 x21 1
    7c:        f92acce3        blt x21 x18 -104 <loop_j>
    80:        001a0a13        addi x20 x20 1
    84:        f89a46e3        blt x20 x9 -116 <loop_i>
    88:        00000293        addi x5 x0 0
    8c:        00000313        addi x6 x0 0
    90:        10000393        addi x7 x0 256
    94:        00100e13        addi x28 x0 1
    98:        00000013        addi x0 x0 0

0000009c <Exit>:
    9c:        0002a503        lw x10 0 x5
    a0:        05d00893        addi x17 x0 93
    a4:        00000073        ecall
```

Each row shows the address in instruction memory, the instruction’s machine code (in hex) and the instruction itself, respectively.

# **5-stage pipelined processor**

Now we can choose a processor to run this code. Ripes provides four kinds of processors for us to choose from: a single-cycle processor, a 5-stage processor, a 5-stage processor with hazard detection, and a 5-stage processor with forwarding and hazard detection. Here we choose the 5-stage processor. Its block diagram looks like this:

![](https://hackmd.io/_uploads/HJ7S9Fdea.png)

“5-stage” means that this processor uses a five-stage pipeline to parallelize instructions.
The stages are:

1. Instruction fetch (IF)
1. Instruction decode and register fetch (ID)
1. Execute (EX)
1. Memory access (MEM)
1. Register write back (WB)

You can see that each stage is separated by pipeline registers (the rectangular blocks with stage names on each side) in the block diagram. Instructions of different formats go through the 5 stages with different signals turned on. Let’s discuss each format in detail with an example.

**I-type format**

The instruction we follow is `addi x21 x0 0`, the first instruction of `loop_i`. Let’s see how it goes through each stage.

**1.Instruction fetch (IF)**

![](https://hackmd.io/_uploads/By2i9cOxa.png)

The instruction’s address is 0x00000010, and the next instruction’s address is 0x00000014 (PC+4). We obtain the instruction 0x00000a93 from the instruction memory.

**2.Instruction decode and register fetch (ID)**

![](https://hackmd.io/_uploads/HJBBko_x6.png)

Instruction 0x00000a93 is decoded into three parts:

- opcode = addi
- Wr idx = 0x14
- imm. = 0x00000000 (0x000 in the upper 12 bits)

Reg 1 and Reg 2 read their values from register 0x00, so their values are both 0x00000000 too. The current PC value (0x00000010) and the next PC value (0x0000000C) are just sent through this stage; we don’t use them.

**3.Execute (EX)**

![](https://hackmd.io/_uploads/rJ03eiugp.png)

The first-level multiplexers choose the values coming from Reg 1 and Reg 2, but since this is an I-type format instruction we don’t use them, so they are filtered out by the second-level multiplexers. Reg 1 and Reg 2 are also sent to the branch block, but no branch is taken. The second-level multiplexers choose the current PC value (upper one) and the immediate (lower one) as Op1 and Op2 of the ALU. The ALU adds the two operands together, so the Res is equal to 0x00000002. The next PC value (0x0000000c) and Wr idx (0x13) are just sent through this stage; we don’t use them.
**4.Memory access (MEM)**

![](https://hackmd.io/_uploads/S1UgQo_ea.png)

The Res from the ALU is sent three ways:

- It passes through this stage and goes on to the WB stage.
- It is sent back to the EX stage for the next instruction to use.
- It is used as the data memory address.

The memory reads data at address 0x00000004, so Read out is equal to 0x00200913. The table below shows the data section of memory.

![](https://hackmd.io/_uploads/rJByEsuxa.png)

**5.Register write back (WB)**

![](https://hackmd.io/_uploads/S191SjOxT.png)

The multiplexer chooses the Res from the ALU as the final output, so the output value is 0x00000002. The output value and Wr idx are sent back to the registers block. Finally, the value 0x00000002 will be written into register x9, whose ABI name is s1.

**After all these stages are done, the registers are updated like this:**

![](https://hackmd.io/_uploads/BkKaBidla.png)
![](https://hackmd.io/_uploads/rklJ8sdxp.png)

# **Base RISC-V Instruction Formats**

![](https://hackmd.io/_uploads/rJUIFK_ea.png)

# **src**

Floating point representation
https://huggingface.co/blog/hf-bitsandbytes-integration
https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Zeros_and_infinities
https://zhuanlan.zhihu.com/p/657886517
https://blog.csdn.net/qq_36533552/article/details/105885714
https://hackmd.io/@RinHizakura/SJUTL6E4w/%2F%40RinHizakura%2FBkE2mlOww

Instructions
https://marz.utk.edu/my-courses/cosc230/book/example-risc-v-assembly-programs/
https://blog.csdn.net/weixin_43914889/article/details/106108377
https://blog.csdn.net/q876507447/article/details/105477887
https://tclin914.github.io/16df19b4/
https://cloud.tencent.com/developer/article/1855372
https://zhuanlan.zhihu.com/p/430113452
https://fmrt.gitbooks.io/riscv-spec-v2-cn/content/rv32i/int-comp-instrs.html

RISC-V Assembly Programmer's Manual
https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md
https://hackmd.io/@w4K9apQGS8-NFtsnFXutfg/B1Re5uGa5
https://blog.csdn.net/qq_41019681/article/details/118544526
https://graphics.stanford.edu/~seander/bithacks.html#OperationCounting

RISC-V "F" extension
https://zhuanlan.zhihu.com/p/339497543

Notes on unsigned int vs. int types
https://cloud.tencent.com/developer/article/2152437