# Matrix Multiplication Accelerator Implemented with Chisel

> 黃啟維 [GitHub](https://github.com/rhway666/CA2024_final)

## Introduction

Matrix multiplication is a critical operation in deep learning and scientific computing. This project develops a high-performance systolic array accelerator in Chisel for efficient General Matrix Multiplication (GEMM). The design is modular and scalable, featuring a systolic array composed of 4x4 Processing Elements (PEs) grouped into 2x2 Tiles.

### Advantages of a Hardware Accelerator

- Higher Throughput
  - Multiple operations are computed in parallel, achieving higher operations-per-second than a purely CPU-bound approach
- Scalable and Modular
  - The base 4x4 array can be replicated or tiled to handle arbitrarily large matrices
  - A single design can address a wide range of matrix sizes through a flexible tiling strategy
- Efficient Data Reuse
  - Weight-stationary or output-stationary dataflows minimize redundant memory accesses
- Data Type Flexibility
  - Supports both int4 and int8 data types for diverse workloads

## Environment Setup

Following the [chisel-bootcamp](https://github.com/freechipsproject/chisel-bootcamp/blob/master/Install.md) install instructions on macOS:

- Chisel version: 3.4.4
- Java version: 11.0.21
- Scala version (Properties): 2.12.10

## Systolic Array Hardware System Design in Chisel

### Dataflow Design

#### Weight Stationary

Weights stay fixed in the PEs while input matrix rows flow horizontally across the array.

- Minimizes memory movement for weights, saving energy and bandwidth
- Ideal for deep learning inference workloads with large, mostly static weights

#### Output Stationary

Output accumulations stay fixed in the PEs while input matrix rows and weight matrix columns flow through the array.

- Keeps partial sums local, reducing memory bandwidth requirements for intermediate data

### Systolic Array Design

![image](https://hackmd.io/_uploads/HyXMEtJv1x.png)

For a matrix multiplication $C = A \times B + D$:

- PE Module (see the Chisel sketch below)
  - Output-stationary PE ![PE](https://hackmd.io/_uploads/H129QeKvJe.png)
  - Input
    - a_input (default 8 bits)
    - w_input (default 8 bits)
  - Control
    - valid (1 bit): asserted when the incoming operands start a new partial sum
  - Registers
    - PS_reg (3x the default width, i.e., 24 bits for 8-bit inputs)
    - W_reg (default width)
    - A_reg (default width)
  - Output
    - partial sum (for outputting the result)
    - fwd_a (passed to the PE on the right)
    - fwd_w (passed to the PE below)
  - TODO
    - Core compute unit for matrix multiplication (MAC operations)
    - Dynamically supports int4 and int8 data types
    - Handles weight-stationary and output-stationary dataflows
- Tile (4x4 PE grid)
- Top-Level Module (2x2 Tile grid)
  - Organizes 2x2 Tiles into a systolic array for larger matrices (e.g., 16x16)
  - Row and Column Buffers
    - Horizontal buffers propagate matrix A rows
    - Vertical buffers propagate matrix B columns
  - Coordinates the systolic array's operation, including data movement, initialization, and control logic
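To make the PE description concrete, here is a minimal Chisel sketch of an output-stationary MAC cell. The port names follow the ones used in the tests later in this report (`in_a`, `in_w`, `outPS`, `fwd_a`, `fwd_w`); the parameter names (`dataWidth`, `accWidth`) and the exact register behaviour are assumptions for illustration and may differ from the implementation in the repository.

```scala=
import chisel3._

// Illustrative output-stationary PE: accumulates a*w locally and forwards
// the operands to the neighbouring PEs one cycle later.
class PE(dataWidth: Int = 8, accWidth: Int = 24) extends Module {
  val io = IO(new Bundle {
    val in_a  = Input(UInt(dataWidth.W))  // activation arriving from the left
    val in_w  = Input(UInt(dataWidth.W))  // weight arriving from the top
    val valid = Input(Bool())             // high when the operands begin a new partial sum
    val fwd_a = Output(UInt(dataWidth.W)) // activation forwarded to the PE on the right
    val fwd_w = Output(UInt(dataWidth.W)) // weight forwarded to the PE below
    val outPS = Output(UInt(accWidth.W))  // locally accumulated partial sum
  })

  // A_reg / W_reg: delay the operands by one cycle before forwarding them
  val aReg  = RegInit(0.U(dataWidth.W))
  val wReg  = RegInit(0.U(dataWidth.W))
  // PS_reg: 24-bit accumulator (3x the 8-bit input width), as in the Constraints section
  val psReg = RegInit(0.U(accWidth.W))

  aReg := io.in_a
  wReg := io.in_w

  // Multiply-accumulate. When `valid` marks the start of a new output,
  // the running sum is replaced by the fresh product instead of accumulated.
  val product = io.in_a * io.in_w
  psReg := Mux(io.valid, product, psReg + product)

  io.fwd_a := aReg
  io.fwd_w := wReg
  io.outPS := psReg
}
```

Under these assumptions, driving `in_a = in_w = 255` and stepping one clock cycle yields `outPS = 65025` and forwards 255 on `fwd_a`/`fwd_w`, which is the behaviour checked in the PE module test below.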
### Supported Data Types

Input data type: int8 (TODO: int4)

### Interface Requirements

#### 1. Data Input Interface

- Matrix A rows: stream horizontally across the array
- Matrix W columns: stream vertically into the array
- Output stationary
  - W is transposed and padded
  - A is padded
- Weight stationary
  - W is padded
  - A is transposed and padded
- Delay for input data
  - Output stationary
    - (3N - 2) cycles to finish
  - Weight stationary
    - Weights are preloaded
    - (2N - 1) cycles to finish

##### Software Solution

During the data preparation phase, the data is rearranged into the required format by zero-padding and skewing each row:

```scala=
val activations = Seq(
  Seq(1, 2, 3, 4, 0, 0, 0, 0),
  Seq(0, 5, 6, 7, 8, 0, 0, 0),
  Seq(0, 0, 9, 10, 11, 12, 0, 0),
  Seq(0, 0, 0, 13, 14, 15, 16, 0)
)
val weights = Seq(
  Seq(1, 5, 9, 13, 0, 0, 0, 0),
  Seq(0, 2, 6, 10, 14, 0, 0, 0),
  Seq(0, 0, 3, 7, 11, 15, 0, 0),
  Seq(0, 0, 0, 4, 8, 12, 16, 0)
)
```

##### Hardware Solution

Add shift registers to delay the inputs:

![image](https://hackmd.io/_uploads/Bk8TJaawye.png)

Add a padding control for the shift registers:

![image](https://hackmd.io/_uploads/S1ReuD0PJg.png)

#### 2. Control Signals

- valid (Bool): for reset
- start (Bool): starts the systolic array
- done (Bool): asserted when the computation is complete and the result can be read out

![SA4x4](https://hackmd.io/_uploads/S1T6ftAPye.png)

## GEMM Result Verification

### Test

#### PE Module Test

1. UInt8 input test values
   - a_fwd
   - w_fwd

   ![image](https://hackmd.io/_uploads/rkGOzsCP1e.png)
2. Register values
   - a_reg
   - w_reg
3. Testing the partial-sum register with the maximum UInt8 inputs

```scala=
dut.io.in_a.poke(255.U) // Max value for 8-bit
dut.io.in_w.poke(255.U) // Max value for 8-bit
dut.clock.step(1)

// First operation: 255 * 255 = 65025
assert(dut.io.outPS.peek().litValue == 65025, "Partial sum should be 65025 after the first operation")
assert(dut.io.fwd_a.peek().litValue == 255, "fwd_a should forward 255")
assert(dut.io.fwd_w.peek().litValue == 255, "fwd_w should forward 255")
println(s"Partial sum for overflow test: ${dut.io.outPS.peek().litValue}")
```

![image](https://hackmd.io/_uploads/HyFUmoCvyg.png)

#### Systolic Array Input Data (Software Solution)

Finishes in 10 cycles (3N - 2, where N is the size of the systolic array).

![image](https://hackmd.io/_uploads/S1t5QPCwJg.png)

$$
W = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \\ 13 & 14 & 15 & 16 \end{bmatrix},
\quad
A = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \\ 13 & 14 & 15 & 16 \end{bmatrix}
$$

$$
W \times A = SP = \begin{bmatrix} 90 & 100 & 110 & 120 \\ 202 & 228 & 254 & 280 \\ 314 & 356 & 398 & 440 \\ 426 & 484 & 542 & 600 \end{bmatrix}
$$

##### Using ChiselTest

![image](https://hackmd.io/_uploads/SJjVLtAPye.png)

### Constraints

- Matrix size
  - (4x4) x (4x4) or smaller
- Input data type
  - UInt8 for W and A
- Output partial sum
  - 24-bit register for the partial sum in each PE

### Software Verification

- Compare hardware results with a software GEMM implementation
- Use ChiselTest for verification

### Test Cases

Can support n x n GEMM:

1. 2x2 small matrices: padding module
2. 4x4 small-scale matrices: UInt8 vec input
3. 8x8 or larger N x N matrices: tiling module (TODO)

## Future Directions

1. Floating-Point Support: add FP16/FP32 operations for broader workloads
2. Sparse Matrix Optimization: introduce techniques for handling sparse matrices efficiently
3. Scalability: extend to N x N systolic arrays for even larger matrix operations

## Reference

- https://github.com/freechipsproject/chisel-bootcamp
- https://github.com/ucb-bar/gemmini
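## Appendix: Software Golden Model Sketch

As a rough illustration of the software side described above (a golden GEMM model for result checking, plus the zero-padded skewed input format from the Data Input Interface section), a plain Scala sketch might look like the following. The object and function names (`GemmGolden`, `gemm`, `skew`) are assumptions for illustration and are not taken from the repository.

```scala=
// Illustrative sketch only: a golden GEMM model and the input "skew" helper
// described in the Data Input Interface section. Names are hypothetical.
object GemmGolden {

  // Reference result SP = W x A, used to check the hardware partial sums.
  def gemm(w: Seq[Seq[Int]], a: Seq[Seq[Int]]): Seq[Seq[Int]] =
    w.map(row => a.transpose.map(col => row.zip(col).map { case (x, y) => x * y }.sum))

  // Skew row i by i leading zeros and pad to a common length of 2N,
  // reproducing the staggered input stream shown in the Software Solution.
  def skew(m: Seq[Seq[Int]]): Seq[Seq[Int]] = {
    val n = m.length
    m.zipWithIndex.map { case (row, i) =>
      Seq.fill(i)(0) ++ row ++ Seq.fill(n - i)(0)
    }
  }

  def main(args: Array[String]): Unit = {
    val w = Seq(Seq(1, 2, 3, 4), Seq(5, 6, 7, 8), Seq(9, 10, 11, 12), Seq(13, 14, 15, 16))
    val a = w // same values as the 4x4 example above

    println(gemm(w, a).head)   // List(90, 100, 110, 120), first row of SP
    println(skew(a).head)      // List(1, 2, 3, 4, 0, 0, 0, 0), matches `activations`
    println(skew(w.transpose)) // matches the `weights` sequence above
  }
}
```

In a ChiselTest testbench, the values read back from the array's partial-sum outputs can then be compared element by element against `gemm(w, a)`, which is the software verification strategy described in the GEMM Result Verification section.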