# Custom RISC-V Instructions to Accelerate LLM Inference

> 章劉軒瑋

## LLM

### What is an LLM?

A large language model (LLM) is a type of machine learning model designed for natural language processing tasks such as language generation. LLMs are language models with many parameters, trained with self-supervised learning on vast amounts of text.

### Tokenization

Because machine learning algorithms process numbers rather than text, the text must first be converted to numbers. First, a vocabulary is decided upon; then integer indices are arbitrarily but uniquely assigned to each vocabulary entry; finally, an embedding is associated with each integer index. Tokenization also compresses the dataset. Because LLMs generally require their input to be a non-jagged array, shorter texts must be "padded" until they match the length of the longest one. How many tokens are needed per word, on average, depends on the language of the dataset. (A toy sketch of this pipeline appears at the end of this LLM overview.)

### The Steps of Training an LLM: Pre-training, Fine-tuning, and Reinforcement Learning

#### Pre-training

Teach the model the general structure, grammar, and meaning of language by training it on large-scale text data. The model is trained on massive datasets such as web pages, books, and Wikipedia. The training tasks often involve self-supervised learning, for example:

* Autoregressive generation: predict the next word based on the previous words (e.g., used in GPT models).
* Masked language modeling: randomly mask certain words in a sentence and train the model to predict them (e.g., used in BERT).

Through this training, the model learns patterns, word relationships, and contextual dependencies in the language.

#### Fine-tuning

Focus the model on a specific task (e.g., translation, question answering, or sentiment analysis) to improve its performance in that domain. The model is further trained on labeled datasets tailored to the target task, such as:

* Question-answering datasets (e.g., SQuAD) for answering questions.
* Translation datasets for language conversion tasks.

The model's parameters are adjusted based on these task-specific datasets, refining its understanding. This stage is more efficient than pre-training because it builds on the foundational knowledge already acquired.

#### Reinforcement Learning

Further optimize the model to produce outputs aligned with specific requirements, such as user preferences, content quality, or ethical standards. A reward model is introduced to evaluate the quality of the model's outputs, and reinforcement learning algorithms are used to adjust the model's behavior:

* For instance, in a chatbot application, the reward model might score responses based on coherence, helpfulness, or ethical considerations.

The model learns to maximize rewards, improving the quality of its generated content.

### Applications

#### Content Generation

Article writing and creative content: LLMs can automatically generate articles, news reports, technical documents, or creative writing (e.g., stories, poetry). This is highly useful in the content creation and media industries.

Ad copy and marketing: LLMs can automatically create catchy ad copy or social media content based on specific needs, saving time in content writing.

#### Machine Translation

LLMs can perform efficient language translation, supporting multilingual tasks such as cross-border business communication and international collaboration. Compared to traditional translation tools, LLMs can better understand complex context and generate more natural translations.
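To make the tokenization and padding steps described above concrete, here is the toy C sketch referenced earlier. The six-entry vocabulary and the `token_id` helper are purely hypothetical illustrations; real tokenizers learn subword vocabularies from data (e.g., with byte-pair encoding). Each text is mapped to integer indices, and the shorter one is padded so the batch forms a non-jagged array:

```c
#include <stdio.h>
#include <string.h>

/* Hypothetical toy vocabulary; real vocabularies are learned and hold subwords. */
static const char *vocab[] = { "<pad>", "the", "cat", "sat", "on", "mat" };
#define VOCAB_SIZE (sizeof(vocab) / sizeof(vocab[0]))
#define PAD_ID 0

/* Map a word to its integer index. */
static int token_id(const char *word) {
    for (int i = 0; i < (int)VOCAB_SIZE; i++)
        if (strcmp(vocab[i], word) == 0)
            return i;
    return -1; /* unknown word; real tokenizers have an <unk> token */
}

int main(void) {
    const char *short_text[] = { "the", "cat" };
    const char *long_text[]  = { "the", "cat", "sat", "on", "the", "mat" };
    const char **texts[] = { short_text, long_text };
    int lens[] = { 2, 6 };
    int max_len = 6;  /* pad every text to the length of the longest one */
    int ids[2][6];

    for (int t = 0; t < 2; t++)
        for (int j = 0; j < max_len; j++)
            ids[t][j] = (j < lens[t]) ? token_id(texts[t][j]) : PAD_ID;

    /* The batch is now a rectangular (non-jagged) array of token ids. */
    for (int t = 0; t < 2; t++) {
        for (int j = 0; j < max_len; j++)
            printf("%d ", ids[t][j]);
        printf("\n");
    }
    return 0;
}
```

Running it prints `1 2 0 0 0 0` for the padded short text and `1 2 3 4 1 5` for the long one; an embedding table would then map each of these indices to a vector.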
As we've explored, LLMs have diverse and impactful applications in areas such as content generation and machine translation. However, in training and running large language models, one key operation that consumes a significant amount of computational resources is matrix-vector multiplication, also known as the fully-connected or linear layer in deep learning. This operation applies the learned parameters across the model and often accounts for over 70% of the total computation during training. In the next section, we'll use a model based on llama2.c, Andrej Karpathy's open-source C implementation of the Llama 2 architecture (a GPT-style model released by Meta), to explore how such operations can be optimized in modern language models.

## llama2.c

### What is llama2.c?

llama2.c is a minimalistic implementation of the Llama 2 architecture, focusing on simplicity and educational value. It provides a full-stack solution for training and inference with a small Llama 2 model in pure C. The repository allows loading models trained on the TinyStories dataset and supports running them interactively with a C-based inference engine. It emphasizes ease of use, running models with up to 42M parameters efficiently on personal hardware.

### Matrix-Vector Multiplication in llama2.c

The `matmul` function in llama2.c performs matrix-vector multiplication, a fundamental operation in neural network computations:

* `x`: input vector of size (n,).
* `w`: weight matrix of size (d, n).
* `xout`: output vector of size (d,).

For each row of the weight matrix `w`, it computes the dot product with the vector `x` and stores the result in `xout`. OpenMP parallelization is used to divide the row-wise computations across multiple cores, enhancing performance. This function is the computational bottleneck of model inference.

```c
void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}
```

Let's take an example:

* The weight matrix `w` is a 2x3 matrix.
* The input vector `x` is a 3x1 vector.
* The output vector `xout` is a 2x1 vector.
```c
#include <stdio.h>
#include <stdlib.h>

// matmul as defined in the previous section

void print_vector(float* v, int size) {
    for (int i = 0; i < size; i++) {
        printf("%f\n", v[i]);
    }
}

int main() {
    // Define dimensions
    int n = 3;  // number of columns in W and size of vector x
    int d = 2;  // number of rows in W and size of vector xout

    // Allocate memory for vectors and matrix
    float* x = (float*)malloc(n * sizeof(float));
    float* w = (float*)malloc(d * n * sizeof(float));
    float* xout = (float*)malloc(d * sizeof(float));

    // Initialize input vector x
    x[0] = 1.0f; x[1] = 2.0f; x[2] = 3.0f;

    // Initialize weight matrix W
    w[0] = 1.0f; w[1] = 2.0f; w[2] = 3.0f;
    w[3] = 4.0f; w[4] = 5.0f; w[5] = 6.0f;

    // Perform matrix-vector multiplication
    matmul(xout, x, w, n, d);

    // Print the result
    printf("Result vector xout:\n");
    print_vector(xout, d);

    // Free allocated memory
    free(x);
    free(w);
    free(xout);
    return 0;
}
```

Here is a simple diagram illustrating the computation:

```
   w (2x3)      x (3x1)     xout (2x1)
[ 1  2  3 ]     [ 1 ]       [ 14 ]
[ 4  5  6 ]  x  [ 2 ]   =   [ 32 ]
                [ 3 ]
```

The result is:

```
Result vector xout:
14.000000
32.000000
```

`matmul` in assembly:

```asm
101fc: a069      j      10286 <matmul+0xae>
101fe: fe042423  sw     zero,-24(s0)
10202: fe042223  sw     zero,-28(s0)
10206: a881      j      10256 <matmul+0x7e>
10208: fec42783  lw     a5,-20(s0)
1020c: 873e      mv     a4,a5
1020e: fc442783  lw     a5,-60(s0)
10212: 02f707bb  mulw   a5,a4,a5
10216: 2781      sext.w a5,a5
10218: fe442703  lw     a4,-28(s0)
1021c: 9fb9      addw   a5,a5,a4
1021e: 2781      sext.w a5,a5
10220: 078a      slli   a5,a5,0x2
10222: fc843703  ld     a4,-56(s0)
10226: 97ba      add    a5,a5,a4
10228: 0007a707  flw    fa4,0(a5)
1022c: fe442783  lw     a5,-28(s0)
10230: 078a      slli   a5,a5,0x2
10232: fd043703  ld     a4,-48(s0)
10236: 97ba      add    a5,a5,a4
10238: 0007a787  flw    fa5,0(a5)
1023c: 10f777d3  fmul.s fa5,fa4,fa5
10240: fe842707  flw    fa4,-24(s0)
10244: 00f777d3  fadd.s fa5,fa4,fa5
10248: fef42427  fsw    fa5,-24(s0)
1024c: fe442783  lw     a5,-28(s0)
10250: 2785      addiw  a5,a5,1
10252: fef42223  sw     a5,-28(s0)
10256: fe442783  lw     a5,-28(s0)
1025a: 873e      mv     a4,a5
1025c: fc442783  lw     a5,-60(s0)
10260: 2701      sext.w a4,a4
10262: 2781      sext.w a5,a5
10264: faf742e3  blt    a4,a5,10208 <matmul+0x30>
10268: fec42783  lw     a5,-20(s0)
1026c: 078a      slli   a5,a5,0x2
1026e: fd843703  ld     a4,-40(s0)
10272: 97ba      add    a5,a5,a4
10274: fe842787  flw    fa5,-24(s0)
10278: 00f7a027  fsw    fa5,0(a5)
1027c: fec42783  lw     a5,-20(s0)
10280: 2785      addiw  a5,a5,1
10282: fef42623  sw     a5,-20(s0)
10286: fec42783  lw     a5,-20(s0)
1028a: 873e      mv     a4,a5
1028c: fc042783  lw     a5,-64(s0)
10290: 2701      sext.w a4,a4
10292: 2781      sext.w a5,a5
10294: f6f745e3  blt    a4,a5,101fe <matmul+0x26>
```

* `101fc`–`10294` outer loop control:
  * Uses `blt` (branch if less than) to compare indices, controlling the outer loop.
  * Writes the accumulated result back to memory after processing each matrix row.
* `10208`–`10250` inner loop calculation:
  * Performs the dot product of a matrix row and the vector.
  * Loads loop indices with `lw` (load word) and floating-point elements with `flw`.
  * Executes floating-point multiplication with `fmul.s` and accumulates with `fadd.s`.
  * Temporarily stores the running sum in a local variable.

Although each row of the matrix-vector product can be computed in parallel (thanks to OpenMP), the total computation still involves `d` rows, and each row involves `n` multiply-accumulate operations. The parallelization only shortens the wall-clock time of the row computations, not the amount of work: even with parallelism, the overall time complexity remains `O(d × n)`. The reduction happens only in the constant factor, not in the big-O complexity.

## Accelerating LLM Inference

In matrix-vector multiplication, the inner loop performs multiplication and accumulation one pair of elements at a time, requiring multiple CPU cycles per operation.
This approach is inefficient because each multiplication and addition is performed sequentially, leading to high overhead from repeated memory accesses, data loading, and computation. As a result, the CPU spends excessive time on individual operations, reducing overall performance.

To address these inefficiencies, I propose defining custom Matrix-Vector Multiplication (MVM) instructions inspired by the RISC-V Vector Extension (RVV). These instructions focus on parallelizing the computation, enabling operations such as simultaneous data loading, vectorized multiplication, and accumulation. Specifically, the design includes:

* Load vector registers: efficiently load multiple elements of the matrix row (`w`) and vector (`x`) into vector registers in a single operation.
* Vector multiply-accumulate: combine multiple `w[i * n + j] * x[j]` operations with accumulation, reducing loop overhead.
* Store results: write the computed row output back to memory in a vectorized fashion.

After adding these instructions, if each vector operation processes `l` elements in parallel:

* Computations per row: `⌈n/l⌉`.
* Total computations: `d × ⌈n/l⌉`.
* New time complexity: `O(d × ⌈n/l⌉)`.

The vectorized instructions theoretically reduce the work by a factor of `l`, enhancing performance by decreasing loop iterations and memory accesses.

### Matrix-Vector Multiplication Instructions

`vflw`: loads four consecutive 32-bit floats directly from memory into a vector register in a single operation. Its mnemonic representation would resemble:

```
vflw v1, offset(r1)
# R[v1][31:0]   = Mem[R[r1] + offset]
# R[v1][63:32]  = Mem[R[r1] + offset + 4]
# R[v1][95:64]  = Mem[R[r1] + offset + 8]
# R[v1][127:96] = Mem[R[r1] + offset + 12]
```

`vfsw`: stores the four 32-bit elements of a vector register back to consecutive memory locations. Its mnemonic representation would resemble:

```
vfsw v1, offset(r1)
# Mem[R[r1] + offset]      = R[v1][31:0]
# Mem[R[r1] + offset + 4]  = R[v1][63:32]
# Mem[R[r1] + offset + 8]  = R[v1][95:64]
# Mem[R[r1] + offset + 12] = R[v1][127:96]
```

`vmul`: each corresponding element of `v2` and `v3` is multiplied, and the result is stored in the corresponding position of `v1`.

```
vmul v1, v2, v3
# R[v1][31:0]   = R[v2][31:0]   * R[v3][31:0]
# R[v1][63:32]  = R[v2][63:32]  * R[v3][63:32]
# R[v1][95:64]  = R[v2][95:64]  * R[v3][95:64]
# R[v1][127:96] = R[v2][127:96] * R[v3][127:96]
```

`vadd`: performs element-wise addition between two vector registers and stores the result in a destination vector register.

```
vadd v1, v2, v3
# R[v1][31:0]   = R[v2][31:0]   + R[v3][31:0]
# R[v1][63:32]  = R[v2][63:32]  + R[v3][63:32]
# R[v1][95:64]  = R[v2][95:64]  + R[v3][95:64]
# R[v1][127:96] = R[v2][127:96] + R[v3][127:96]
```
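To see how these instructions would work together, here is a sketch of `matmul` rewritten around them. Since no real toolchain implements these mnemonics, the code below models the 4-lane semantics defined above in plain C: the `vreg_t` type and the `vflw`/`vmul`/`vadd` helper functions are hypothetical stand-ins for the vector registers and instructions, and for simplicity it assumes `n` is a multiple of the lane count.

```c
#define LANES 4  /* each custom instruction operates on 4 floats */

/* Hypothetical 128-bit vector register, modeled as 4 floats. */
typedef struct { float e[LANES]; } vreg_t;

/* C models of the custom instructions defined above. */
static vreg_t vflw(const float *mem) {      /* vflw v1, 0(r1)  */
    vreg_t v;
    for (int k = 0; k < LANES; k++) v.e[k] = mem[k];
    return v;
}
static vreg_t vmul(vreg_t a, vreg_t b) {    /* vmul v1, v2, v3 */
    vreg_t v;
    for (int k = 0; k < LANES; k++) v.e[k] = a.e[k] * b.e[k];
    return v;
}
static vreg_t vadd(vreg_t a, vreg_t b) {    /* vadd v1, v2, v3 */
    vreg_t v;
    for (int k = 0; k < LANES; k++) v.e[k] = a.e[k] + b.e[k];
    return v;
}

/* matmul where the inner loop advances 4 elements per step,
 * i.e., ceil(n/4) iterations per row instead of n.
 * Assumes n is a multiple of LANES for simplicity. */
void matmul_vec(float *xout, const float *x, const float *w, int n, int d) {
    for (int i = 0; i < d; i++) {
        vreg_t acc = {{0.0f, 0.0f, 0.0f, 0.0f}};
        for (int j = 0; j < n; j += LANES) {
            vreg_t vw = vflw(&w[i * n + j]);  /* load 4 matrix elements */
            vreg_t vx = vflw(&x[j]);          /* load 4 vector elements */
            acc = vadd(acc, vmul(vw, vx));    /* multiply-accumulate    */
        }
        /* Horizontal sum of the 4 accumulator lanes into the output. */
        xout[i] = acc.e[0] + acc.e[1] + acc.e[2] + acc.e[3];
    }
}
```

Note that each row ends with a horizontal sum across the accumulator lanes, a reduction step that the four instructions above do not cover and that a complete ISA design would also have to provide; it is modeled here with scalar additions.

## Adding a custom instruction to the RISC-V ISA

### Download and install the default RISC-V toolchain

In this first step, the default RISC-V toolchain is compiled, without modifications to the instruction set. Clone the toolchain repository and its submodules:

```=
$ git clone --recurse-submodules https://github.com/riscv/riscv-gnu-toolchain.git
```

:::warning
Around 7GB are needed to download all repositories.
:::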
The toolchain is built in `/opt/riscv_custom`:

```=
$ cd riscv-gnu-toolchain
$ ./configure --prefix=/opt/riscv_custom
$ make -j$(nproc)
```

The GCC cross-compiler version can be checked:

```=
$ /opt/riscv_custom/bin/riscv64-unknown-elf-gcc --version
```

```
riscv64-unknown-elf-gcc (g04696df096) 14.2.0
```

### Adding a custom instruction in the cross-compiler

To test the implementation, we first add a non-default `mod` (modulo) instruction to the base integer instruction set. Its mnemonic representation would resemble:

```
mod r1, r2, r3 # R[r1] = R[r2] % R[r3]
```

The opcode syntax would be:

```
mod rd rs1 rs2 31..25=1 14..12=0 6..2=2 1..0=3
```

Here `1..0=3` and `6..2=2` place the instruction in the *custom-0* opcode space (opcode `0x0b`), `14..12=0` sets funct3 to 0, and `31..25=1` sets funct7 to 1.

The `rv_i` file, located in the `riscv-opcodes/extensions` directory (which contains the opcode definitions for the RISC-V instruction extensions), is modified as follows:

```
  add rd rs1 rs2 31..25=0 14..12=0 6..2=0x0C 1..0=3
+ mod rd rs1 rs2 31..25=1 14..12=0 6..2=2 1..0=3
  sub rd rs1 rs2 31..25=32 14..12=0 6..2=0x0C 1..0=3
```

Then the opcode file is processed to generate the MATCH and MASK values:

```
$ make
```

This command generates the representation of the opcodes in several formats such as SystemVerilog, Chisel, and C (in the `encoding.out.h` file). MATCH is simply funct7, funct3, and the opcode assembled into their bit positions (`(1 << 25) | (0 << 12) | 0x0b`), and MASK selects exactly those three fields:

```
#define MATCH_MOD 0x200000b
#define MASK_MOD 0xfe00707f
```

### Binutils modification

Now binutils needs to be made aware of the new instruction. `riscv-gnu-toolchain/binutils/include/opcode/riscv-opc.h` is updated as follows:

```
  /* Instruction opcode macros. */
+ #define MATCH_MOD 0x200000b
+ #define MASK_MOD 0xfe00707f
  #define MATCH_SLLI_RV32 0x1013
```

```
  #endif /* RISCV_ENCODING_H */
  #ifdef DECLARE_INSN
+ DECLARE_INSN(mod, MATCH_MOD, MASK_MOD)
  DECLARE_INSN(slli_rv32, MATCH_SLLI_RV32, MASK_SLLI_RV32)
```

The related C file (`riscv-gnu-toolchain/binutils/opcodes/riscv-opc.c`) has to be modified as well:

```
  /* Basic RVI instructions and aliases. */
+ {"mod", 0, INSN_CLASS_I, "d,s,t", MATCH_MOD, MASK_MOD, match_opcode, 0 },
  {"unimp", 0, INSN_CLASS_C, "", 0, 0xffffU, match_opcode, INSN_ALIAS },
```

The fields of this entry are:

* `name`: name of the instruction.
* `xlen`: width of an integer register in bits (0 means any width).
* `isa`: ISA extension the instruction belongs to.
* `operands`: operand format string, based on the parsing available in `riscv-gnu-toolchain/riscv-binutils/gas/config/tc-riscv.c`:

```c
switch (*fmt++)
  {
  case 'd':
    INSERT_OPERAND (RD, insn, va_arg (args, int));
    continue;
  case 's':
    INSERT_OPERAND (RS1, insn, va_arg (args, int));
    continue;
  case 't':
    INSERT_OPERAND (RS2, insn, va_arg (args, int));
    continue;
```

* `match_func`: pointer to the function checking the funct7, funct3, and opcode fields of an encoded instruction against the MATCH value:

```c
static int
match_opcode (const struct riscv_opcode *op, insn_t insn)
{
  return ((insn ^ op->match) & op->mask) == 0;
}
```

For example, the encoding `0x02e7878b` (`mod a5,a5,a4`, seen in the objdump output below) matches because `(0x02e7878b ^ 0x200000b) & 0xfe00707f == 0`.

### Testing the new instruction

The final step is to rebuild the toolchain so that it recognizes the new instruction:

```=
$ make clean
$ make -j$(nproc)
```

Here is a sample C program using the freshly implemented `mod` instruction:

```c
#include <stdio.h>

int main(){
    int a, b, c;
    a = 5;
    b = 2;
    // Inline assembly: c = a % b via the custom mod instruction
    asm volatile (
        "mod %[z], %[x], %[y]\n\t"
        : [z] "=r" (c)
        : [x] "r" (a), [y] "r" (b)
    );
    if ( c != 1 ){
        printf("\n[[FAILED]]\n");
        return -1;
    }
    printf("\n[[PASSED]]\n");
    return 0;
}
```

Compile the C file and verify the presence of the `mod` instruction in the objdump output.
```
$ /opt/riscv_custom/bin/riscv64-unknown-elf-gcc main.c -o main
```

```
david@david-B660M-PG-Riptide:~/CA_final/riscv-gnu-toolchain$ /opt/riscv_custom/bin/riscv64-unknown-elf-objdump -D main | grep -n -A 20 "<main>:"
78:00000000000101d4 <main>:
79-   101d4: 1101      addi   sp,sp,-32
80-   101d6: ec06      sd     ra,24(sp)
81-   101d8: e822      sd     s0,16(sp)
82-   101da: 1000      addi   s0,sp,32
83-   101dc: 4795      li     a5,5
84-   101de: fef42623  sw     a5,-20(s0)
85-   101e2: 4789      li     a5,2
86-   101e4: fef42423  sw     a5,-24(s0)
87-   101e8: fec42783  lw     a5,-20(s0)
88-   101ec: fe842703  lw     a4,-24(s0)
89-   101f0: 02e7878b  mod    a5,a5,a4
90-   101f4: fef42223  sw     a5,-28(s0)
91-   101f8: fe442783  lw     a5,-28(s0)
92-   101fc: 0007871b  sext.w a4,a5
93-   10200: 4785      li     a5,1
94-   10202: 00f70963  beq    a4,a5,10214 <main+0x40>
95-   10206: 67c9      lui    a5,0x12
96-   10208: 65078513  addi   a0,a5,1616 # 12650 <__errno+0x8>
97-   1020c: 392000ef  jal    1059e <puts>
98-   10210: 57fd      li     a5,-1
```

We can observe the `mod` instruction at line 89 of the objdump output.

## Adding a custom instruction in Spike

Two tools need to be installed:

* Spike, the RISC-V ISA simulator itself.
* pk, the RISC-V proxy kernel, a piece of software that can host statically-linked RISC-V binaries.

Set the RISC-V tools path:

```=
export RISCV=/opt/riscv_custom
export PATH=$RISCV/bin:$PATH
```

Install Spike:

```=
git clone https://github.com/riscv-software-src/riscv-isa-sim
cd riscv-isa-sim
mkdir build
cd build
../configure --prefix=$RISCV
make -j$(nproc)
sudo make install
```

Install the proxy kernel (pk):

```=
git clone https://github.com/riscv-software-src/riscv-pk
cd riscv-pk
mkdir build
cd build
../configure --prefix=$RISCV --host=riscv64-unknown-elf
make -j$(nproc)
sudo make install
export PATH=$RISCV/riscv64-unknown-elf/bin:$PATH
```

### Adding an instruction in the simulator

The behavior of the new instruction is described by adding a file `riscv-isa-sim/riscv/insns/mod.h`, whose content is:

```
WRITE_RD(sext_xlen(RS1 % RS2));
```

In `riscv-isa-sim/riscv/encoding.h`, add `MATCH_MOD` and `MASK_MOD` as for the compiler:

```
  #define MATCH_ADD 0x33
  #define MASK_ADD 0xfe00707f
+ #define MATCH_MOD 0x200000b
+ #define MASK_MOD 0xfe00707f
```

```
  DECLARE_INSN(add, MATCH_ADD, MASK_ADD)
+ DECLARE_INSN(mod, MATCH_MOD, MASK_MOD)
  DECLARE_INSN(add_uw, MATCH_ADD_UW, MASK_ADD_UW)
```

Then the Makefile needs to compile the `mod` instruction. In `riscv-isa-sim/riscv/riscv.mk.in`:

```
  riscv_insn_ext_i = \
    add \
+   mod \
    addi \
```

The last file to be modified is `riscv-isa-sim/disasm/disasm.cc`, where the instruction types are defined:

```
  DEFINE_RTYPE(add);
+ DEFINE_RTYPE(mod);
  DEFINE_RTYPE(sub);
  DEFINE_RTYPE(sll);
```

The last step is to rebuild the simulator and test the program:

```
davi@david-B660M-PG-Riptide:~/CA$ riscv64-unknown-elf-gcc -o main main.c
davi@david-B660M-PG-Riptide:~/CA$ spike pk main

[[PASSED]]
```
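A caveat on the one-line `mod.h` handler above: in Spike, `RS1` and `RS2` evaluate to the unsigned register type `reg_t`, so `RS1 % RS2` computes an unsigned remainder (for the bit patterns of -5 and 2 it yields 1, whereas C's signed `%` gives -1), and a zero divisor would raise a host-side division exception in the simulator rather than produce a defined result. The sketch below is a signed, zero-safe variant modeled on Spike's handler for the standard `rem` instruction; treat it as an illustration of how the handler could be hardened, not as part of the original tutorial:

```c
// riscv-isa-sim/riscv/insns/mod.h -- hypothetical hardened variant.
// sreg_t is Spike's signed register type; the special cases mirror the
// standard rem handler: x % 0 returns x, and the INT64_MIN % -1 overflow
// case returns 0 instead of trapping on the host.
sreg_t lhs = sext_xlen(RS1);
sreg_t rhs = sext_xlen(RS2);
if (rhs == 0)
  WRITE_RD(lhs);
else if (lhs == INT64_MIN && rhs == -1)
  WRITE_RD(0);
else
  WRITE_RD(sext_xlen(lhs % rhs));
```

## Reference

* [Large language model](https://en.wikipedia.org/wiki/Large_language_model)
* [llama2.c](https://github.com/karpathy/llama2.c)
* [Adding custom instructions in the RISC-V ISA](https://pcotret.gitlab.io/riscv-custom/)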