# Pipelined RISC-V core

> 戴均原

## Project description

This project extends an existing RISC-V processor, starting from its original 3-stage and 5-stage pipelined designs, so that it supports the complete RV32I instruction set and the B extension. In addition, at least two RISC-V programs from the course exercises are rewritten to use B-extension instructions and checked for correct behavior on the extended processor. The final implementation will be published on GitHub.

The write-up follows the transition from a single-cycle design to a 3-stage pipeline and then to a 5-stage pipeline, focusing on key design issues such as forwarding and pipeline optimization.

## Stage Comparison

### three stage

A simplified processor pipeline with three stages: Instruction Fetch (IF), Instruction Decode (ID), and Execute (EX). It has no separate memory-access or write-back stages, which makes it a good foundational model for understanding how instructions flow through a pipeline. Its `Control` module only has to handle jumps: when the EX stage reports a jump (`JumpFlag`), `Flush` is asserted so that the instructions fetched after the jump are discarded.

`Control.scala`

```
import chisel3._

class Control extends Module {
  val io = IO(new Bundle {
    val JumpFlag = Input(Bool())
    val Flush    = Output(Bool())
  })

  io.Flush := io.JumpFlag
}
```

### five stage stall

Extends the pipeline with Memory Access (MEM) and Write Back (WB) stages. This version handles data hazards by stalling: when a dependency is detected, the flow of instructions is paused, which guarantees correctness at the cost of performance.
`Control.scala`

```
package riscv.core.fivestage_stall

import chisel3._
import riscv.Parameters

class Control extends Module {
  val io = IO(new Bundle {
    val jump_flag            = Input(Bool())                                     // ex.io.if_jump_flag
    val rs1_id               = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg1_read_address
    val rs2_id               = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg2_read_address
    val rd_ex                = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id2ex.io.output_regs_write_address
    val reg_write_enable_ex  = Input(Bool())                                     // id2ex.io.output_regs_write_enable
    val rd_mem               = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // ex2mem.io.output_regs_write_address
    val reg_write_enable_mem = Input(Bool())                                     // ex2mem.io.output_regs_write_enable

    val if_flush = Output(Bool())
    val id_flush = Output(Bool())
    val pc_stall = Output(Bool())
    val if_stall = Output(Bool())
  })

  io.if_flush := false.B
  io.id_flush := false.B
  io.pc_stall := false.B
  io.if_stall := false.B

  when(io.jump_flag) {
    io.if_flush := true.B
    io.id_flush := true.B
  }.elsewhen(
    (io.reg_write_enable_ex && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id) && io.rd_ex =/= 0.U) ||
      (io.reg_write_enable_mem && (io.rd_mem === io.rs1_id || io.rd_mem === io.rs2_id) && io.rd_mem =/= 0.U)
  ) {
    io.id_flush := true.B
    io.pc_stall := true.B
    io.if_stall := true.B
  }
}
```

```
io.if_flush := false.B
io.id_flush := false.B
io.pc_stall := false.B
io.if_stall := false.B
```

* All control signals default to `false.B`, so nothing is flushed or stalled unless one of the conditions below holds.

```
when(io.jump_flag) {
  io.if_flush := true.B
  io.id_flush := true.B
}
```

* `io.jump_flag` is asserted (`true.B`) when the EX stage detects a jump, so the instructions already fetched behind it are on the wrong path. Because this is the first `when` branch, a jump takes priority over the hazard check below. In response:
  * `io.if_flush` is set to `true.B`, flushing the instruction fetch (IF) stage.
  * `io.id_flush` is set to `true.B`, flushing the instruction decode (ID) stage.
```
.elsewhen(
  (io.reg_write_enable_ex && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id) && io.rd_ex =/= 0.U) ||
    (io.reg_write_enable_mem && (io.rd_mem === io.rs1_id || io.rd_mem === io.rs2_id) && io.rd_mem =/= 0.U)
) {
  io.id_flush := true.B
  io.pc_stall := true.B
  io.if_stall := true.B
}
```

* Conditions for hazard detection (either one triggers a stall):
  * The EX stage will write a register (`io.reg_write_enable_ex` is true), its destination (`io.rd_ex`) matches one of the current instruction's source registers (`io.rs1_id` or `io.rs2_id`), and the destination is not the zero register (`io.rd_ex =/= 0.U`).
  * The MEM stage will write a register (`io.reg_write_enable_mem` is true), its destination (`io.rd_mem`) matches `io.rs1_id` or `io.rs2_id`, and the destination is not the zero register (`io.rd_mem =/= 0.U`).
* Actions taken:
  * `io.id_flush := true.B`: clears what ID passes on to EX, so a bubble (NOP) enters EX instead of the instruction that is still waiting for its operand.
  * `io.pc_stall := true.B`: holds the program counter (PC), so no new instruction is fetched.
  * `io.if_stall := true.B`: holds the IF stage output, so the dependent instruction stays in ID and is decoded again in the next cycle.

### five stage forward

Builds on the five-stage model but adds forwarding (bypassing) to reduce stalls. Results still held in later stages (MEM, WB) are forwarded directly to the EX stage instead of waiting for write-back, which resolves most data hazards without stalling and improves performance compared with the stall-only design.
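The effect of forwarding on stalls can be illustrated with a small plain-Scala model of the two data-hazard conditions: the `elsewhen` condition of the stall-only `Control.scala` above, and the load-use condition used by the forwarding version's `Control.scala`. This is only a sketch for reasoning about the logic, not part of the project; `PipeSlot` and `HazardModel` are hypothetical names, and jump handling (the same in both designs) is left out.

```scala
// Hypothetical helper: one pipeline slot (EX or MEM) as seen by the Control module.
final case class PipeSlot(rd: Int, regWrite: Boolean, memRead: Boolean = false)

object HazardModel {
  // True if `slot` will write a non-zero register that the instruction in ID reads.
  private def writesSource(slot: PipeSlot, rs1: Int, rs2: Int): Boolean =
    slot.regWrite && slot.rd != 0 && (slot.rd == rs1 || slot.rd == rs2)

  // Mirrors fivestage_stall: stall on any pending write in EX or MEM.
  def stallOnly(rs1: Int, rs2: Int, ex: PipeSlot, mem: PipeSlot): Boolean =
    writesSource(ex, rs1, rs2) || writesSource(mem, rs1, rs2)

  // Mirrors fivestage_forward: stall only when the instruction in EX is a load
  // whose destination is a source of the instruction in ID (load-use hazard).
  def withForwarding(rs1: Int, rs2: Int, ex: PipeSlot): Boolean =
    ex.memRead && ex.rd != 0 && (ex.rd == rs1 || ex.rd == rs2)
}
```

For `add x5, x1, x2` in EX followed by `sub x6, x5, x3` in ID, `HazardModel.stallOnly(5, 3, PipeSlot(5, regWrite = true), PipeSlot(0, regWrite = false))` is `true`, whereas `HazardModel.withForwarding(5, 3, PipeSlot(5, regWrite = true))` is `false`: with forwarding, only a load in EX (`memRead = true`) still forces a stall.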
`Control.scala`

```
package riscv.core.fivestage_forward

import chisel3._
import riscv.Parameters

class Control extends Module {
  val io = IO(new Bundle {
    val jump_flag             = Input(Bool())                                     // ex.io.if_jump_flag
    val rs1_id                = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg1_read_address
    val rs2_id                = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg2_read_address
    val memory_read_enable_ex = Input(Bool())                                     // id2ex.io.output_memory_read_enable
    val rd_ex                 = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id2ex.io.output_regs_write_address

    val if_flush = Output(Bool())
    val id_flush = Output(Bool())
    val pc_stall = Output(Bool())
    val if_stall = Output(Bool())
  })

  io.if_flush := false.B
  io.id_flush := false.B
  io.pc_stall := false.B
  io.if_stall := false.B

  when(io.jump_flag) {
    io.if_flush := true.B
    io.id_flush := true.B
  }.elsewhen(io.memory_read_enable_ex && io.rd_ex =/= 0.U && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id)) {
    // Load-use hazard: the loaded value is not available in time to be forwarded
    io.id_flush := true.B
    io.pc_stall := true.B
    io.if_stall := true.B
  }
}
```

* Compared with the stall-only controller, a stall is now needed only for a load-use hazard: the instruction in EX is a load (`io.memory_read_enable_ex`) whose destination `io.rd_ex` is a source of the instruction in ID. The loaded value is not available until the load has finished its MEM stage, so the dependent instruction must wait one cycle; every other dependency is resolved by forwarding.

The forwarding logic encodes where each EX-stage operand comes from:

```
object ForwardingType {
  val NoForward      = 0.U(2.W)
  val ForwardFromMEM = 1.U(2.W)
  val ForwardFromWB  = 2.U(2.W)
}
```

* `NoForward` (0): no forwarding is needed; this is the default.
* `ForwardFromMEM` (1): forward the result held in the MEM stage to the EX stage.
* `ForwardFromWB` (2): forward the result held in the WB stage to the EX stage.

```
when(io.reg_write_enable_mem && io.rs1_ex === io.rd_mem && io.rd_mem =/= 0.U) {
  io.reg1_forward_ex := ForwardingType.ForwardFromMEM
}.elsewhen(io.reg_write_enable_wb && io.rs1_ex === io.rd_wb && io.rd_wb =/= 0.U) {
  io.reg1_forward_ex := ForwardingType.ForwardFromWB
}.otherwise {
  io.reg1_forward_ex := ForwardingType.NoForward
}
```

* Condition 1, forward from the MEM stage:
  * If the MEM stage is writing a register (`reg_write_enable_mem` is true), its destination (`rd_mem`) matches `rs1_ex`, and `rd_mem` is not the zero register (`=/= 0.U`), the operand is forwarded from MEM.
* Condition 2, forward from the WB stage:
  * If the WB stage is writing a register (`reg_write_enable_wb` is true), its destination (`rd_wb`) matches `rs1_ex`, and `rd_wb` is not the zero register (`=/= 0.U`), the operand is forwarded from WB.
* Condition 1 is checked first because, when MEM and WB both write the same register, the instruction in MEM is the more recent one and holds the up-to-date value.

### five stage final

Combines stalling and forwarding to handle both data and control hazards, making it the most refined and efficient of the pipeline designs.

# B extension

> https://github.com/riscv/riscv-b
>
> It has been recognized that there is need to officially standardize a "B" extension - that represents the collection of the **Zba**, **Zbb**, and **Zbs** extensions - for the sake of consistency and conciseness across toolchains and how they identify support for these bitmanip extensions (which, for example, are mandated in RVA and RVB profiles). In conjunction with this an official definition of the misa.B bit will be established - along the lines of misa.B=1 indicating support for at least these three extensions (and misa.B=0 indicating that one or more may not be supported).

> https://www.ece.lsu.edu/ee4720/doc/riscv-bitmanip-1.0.0.pdf
>
> http://riscvbook.com/chinese/RISC-V-Reader-Chinese-v2p1.pdf

### Zba

**sh1add:** Shifts the first operand left by 1 (multiplies it by 2) and adds the second operand: `rd = (rs1 << 1) + rs2`.

**sh2add:** Same as `sh1add`, but shifts the first operand left by 2 (multiplies it by 4) before the addition.

**sh3add:** Same as `sh1add`, but shifts the first operand left by 3 (multiplies it by 8) before the addition.

### Zbb

**andn:** Bitwise AND of the first operand with the bitwise negation of the second operand.

**orn:** Bitwise OR of the first operand with the bitwise negation of the second operand.

**xnor:** Bitwise XOR followed by a bitwise NOT (XNOR).

**min:** Compares the two operands as signed integers and returns the smaller one.

**minu:** Same as `min`, but compares the operands as unsigned integers.
**max:** Compares the two operands as signed integers and returns the larger one.

**maxu:** Same as `max`, but compares the operands as unsigned integers.

### Zbs

**bclr:** Clears one bit of the first operand; the bit index is the low 5 bits of the second operand (on RV32).

**bext:** Extracts one bit of the first operand at the index given by the second operand; the result is 0 or 1.

**binv:** Inverts one bit of the first operand at the index given by the second operand.

**bset:** Sets one bit of the first operand at the index given by the second operand.

## Extend the RISC-V Processor

### `ALU.scala`

Implements the Arithmetic Logic Unit (ALU) in Chisel, a hardware description language embedded in Scala. The ALU supports the RV32I instruction set and a subset of the B extension: all three Zba instructions, the Zbb logic-with-negate and min/max instructions, and the register forms of the Zbs single-bit instructions.

* `ALUFunctions`: enumerates every operation (instruction) the ALU supports.
* `ALU` class: implements the computation for each operation, including the base RV32I operations (e.g., addition, subtraction, shifts) and the B-extension operations (e.g., `andn`, `sh1add`, `bclr`).
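For testing the Chisel ALU it helps to have a software golden model of the same operations. The sketch below is a hypothetical `BitmanipRef` object (plain Scala, not part of the project) that restates the Zba, Zbb and Zbs semantics described above. 32-bit values are held in `Int`, whose arithmetic wraps modulo 2^32 like the hardware, and unsigned comparisons use `Integer.compareUnsigned`.

```scala
// Plain-Scala reference model of the B-extension operations implemented in the ALU.
object BitmanipRef {
  // Zba: rd = (rs1 << n) + rs2
  def sh1add(a: Int, b: Int): Int = (a << 1) + b
  def sh2add(a: Int, b: Int): Int = (a << 2) + b
  def sh3add(a: Int, b: Int): Int = (a << 3) + b

  // Zbb: logic with a negated operand, signed and unsigned min/max
  def andn(a: Int, b: Int): Int = a & ~b
  def orn(a: Int, b: Int): Int  = a | ~b
  def xnor(a: Int, b: Int): Int = ~(a ^ b)
  def min(a: Int, b: Int): Int  = math.min(a, b)
  def max(a: Int, b: Int): Int  = math.max(a, b)
  def minu(a: Int, b: Int): Int = if (Integer.compareUnsigned(a, b) < 0) a else b
  def maxu(a: Int, b: Int): Int = if (Integer.compareUnsigned(a, b) > 0) a else b

  // Zbs: only the low 5 bits of rs2 select the bit, like io.op2(4, 0) in the ALU
  private def idx(b: Int): Int = b & 31
  def bclr(a: Int, b: Int): Int = a & ~(1 << idx(b))
  def bext(a: Int, b: Int): Int = (a >>> idx(b)) & 1 // logical shift, result is 0 or 1
  def binv(a: Int, b: Int): Int = a ^ (1 << idx(b))
  def bset(a: Int, b: Int): Int = a | (1 << idx(b))
}
```

A chiseltest could drive the ALU with random operands and compare `io.result` against these functions. For example, `BitmanipRef.bext(5, 2)` is 1, and `BitmanipRef.minu(-1, 1)` is 1 because `-1` is `0xFFFFFFFF` when treated as unsigned.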
`ALUFunctions`

```
object ALUFunctions extends ChiselEnum {
  val zero, add, sub, sll, slt, xor, or, and, srl, sra, sltu = Value
  val sh1add, sh2add, sh3add = Value                // Zba
  val andn, orn, xnor, min, minu, max, maxu = Value // Zbb
  val bclr, bext, binv, bset = Value                // Zbs
}
```

`ALU` class (cases added to the ALU's `switch` statement)

```
// Zba: rd = (rs1 << n) + rs2
is(ALUFunctions.sh1add) { io.result := (io.op1 << 1) + io.op2 }
is(ALUFunctions.sh2add) { io.result := (io.op1 << 2) + io.op2 }
is(ALUFunctions.sh3add) { io.result := (io.op1 << 3) + io.op2 }
```

```
// Zbb: logic with a negated operand, signed and unsigned min/max
is(ALUFunctions.andn) { io.result := io.op1 & ~io.op2 }
is(ALUFunctions.orn)  { io.result := io.op1 | ~io.op2 }
is(ALUFunctions.xnor) { io.result := ~(io.op1 ^ io.op2) }
is(ALUFunctions.min)  { io.result := Mux(io.op1.asSInt < io.op2.asSInt, io.op1, io.op2) }
is(ALUFunctions.minu) { io.result := Mux(io.op1 < io.op2, io.op1, io.op2) }
is(ALUFunctions.max)  { io.result := Mux(io.op1.asSInt > io.op2.asSInt, io.op1, io.op2) }
is(ALUFunctions.maxu) { io.result := Mux(io.op1 > io.op2, io.op1, io.op2) }
```

```
// Zbs: the bit index is the low 5 bits of op2 (RV32)
is(ALUFunctions.bclr) { io.result := io.op1 & ~(1.U << io.op2(4, 0)) }
is(ALUFunctions.bext) { io.result := (io.op1 >> io.op2(4, 0)) & 1.U }
is(ALUFunctions.binv) { io.result := io.op1 ^ (1.U << io.op2(4, 0)) }
is(ALUFunctions.bset) { io.result := io.op1 | (1.U << io.op2(4, 0)) }
```

### `ALUControl.scala`

Implements the `ALUControl` module, which decodes an instruction's opcode, funct3 and funct7. Based on the instruction type and these function codes, it selects the corresponding ALU operation (`alu_funct`).
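How a combined funct7/funct3 key separates the instructions can be checked in software. The sketch below is a hypothetical `RTypeDecode` object (plain Scala, not part of the project): it computes the same 10-bit value as `Cat(io.funct7, io.funct3)` with `(funct7 << 3) | funct3`, lists the B-extension `(funct7, funct3)` pairs used in the lookup table, and encodes and decodes 32-bit R-type instruction words.

```scala
// Plain-Scala mirror of the R-type lookup in ALUControl.
object RTypeDecode {
  // Same 10-bit value that Cat(io.funct7, io.funct3) produces in hardware
  def key(funct7: Int, funct3: Int): Int = (funct7 << 3) | funct3

  // B-extension entries of the lookup table: key(funct7, funct3) -> operation
  val bTable: Map[Int, String] = Map(
    key(0x10, 0x2) -> "sh1add", key(0x10, 0x4) -> "sh2add", key(0x10, 0x6) -> "sh3add",
    key(0x20, 0x7) -> "andn",   key(0x20, 0x6) -> "orn",    key(0x20, 0x4) -> "xnor",
    key(0x05, 0x4) -> "min",    key(0x05, 0x5) -> "minu",
    key(0x05, 0x6) -> "max",    key(0x05, 0x7) -> "maxu",
    key(0x24, 0x1) -> "bclr",   key(0x24, 0x5) -> "bext",
    key(0x34, 0x1) -> "binv",   key(0x14, 0x1) -> "bset"
  )

  // Assemble an R-type word: funct7 | rs2 | rs1 | funct3 | rd | opcode 0110011 (0x33)
  def encode(funct7: Int, rs2: Int, rs1: Int, funct3: Int, rd: Int): Int =
    (funct7 << 25) | (rs2 << 20) | (rs1 << 15) | (funct3 << 12) | (rd << 7) | 0x33

  // Split an R-type word back into funct7/funct3 and look up the B-extension op
  def decode(instr: Int): Option[String] =
    if ((instr & 0x7f) != 0x33) None
    else bTable.get(key((instr >>> 25) & 0x7f, (instr >>> 12) & 0x7))
}
```

For example, `bext s3, a1, s2` (rd = x19, rs1 = x11, rs2 = x18) is `RTypeDecode.encode(0x24, 18, 11, 5, 19)`, which gives `0x4925D9B3`, and `decode` maps that word back to `"bext"`. `bTable.size` being 14 confirms that no two B-extension entries share a key.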
```
is(InstructionTypes.RM) {
  io.alu_funct := MuxLookup(
    Cat(io.funct7, io.funct3),
    ALUFunctions.zero,
    IndexedSeq(
      Cat("b0000000".U(7.W), InstructionsTypeR.add_sub) -> ALUFunctions.add,
      Cat("b0100000".U(7.W), InstructionsTypeR.add_sub) -> ALUFunctions.sub,
      Cat("b0000000".U(7.W), InstructionsTypeR.sll)     -> ALUFunctions.sll,
      Cat("b0000000".U(7.W), InstructionsTypeR.slt)     -> ALUFunctions.slt,
      Cat("b0000000".U(7.W), InstructionsTypeR.sltu)    -> ALUFunctions.sltu,
      Cat("b0000000".U(7.W), InstructionsTypeR.xor)     -> ALUFunctions.xor,
      Cat("b0000000".U(7.W), InstructionsTypeR.or)      -> ALUFunctions.or,
      Cat("b0000000".U(7.W), InstructionsTypeR.and)     -> ALUFunctions.and,
      Cat("b0000000".U(7.W), InstructionsTypeR.sr)      -> ALUFunctions.srl,
      Cat("b0100000".U(7.W), InstructionsTypeR.sr)      -> ALUFunctions.sra,
      // Zba
      Cat("b0010000".U(7.W), "b010".U(3.W)) -> ALUFunctions.sh1add,
      Cat("b0010000".U(7.W), "b100".U(3.W)) -> ALUFunctions.sh2add,
      Cat("b0010000".U(7.W), "b110".U(3.W)) -> ALUFunctions.sh3add,
      // Zbb
      Cat("b0100000".U(7.W), "b111".U(3.W)) -> ALUFunctions.andn,
      Cat("b0100000".U(7.W), "b110".U(3.W)) -> ALUFunctions.orn,
      Cat("b0100000".U(7.W), "b100".U(3.W)) -> ALUFunctions.xnor,
      Cat("b0000101".U(7.W), "b100".U(3.W)) -> ALUFunctions.min,
      Cat("b0000101".U(7.W), "b101".U(3.W)) -> ALUFunctions.minu,
      Cat("b0000101".U(7.W), "b110".U(3.W)) -> ALUFunctions.max,
      Cat("b0000101".U(7.W), "b111".U(3.W)) -> ALUFunctions.maxu,
      // Zbs
      Cat("b0100100".U(7.W), "b001".U(3.W)) -> ALUFunctions.bclr,
      Cat("b0100100".U(7.W), "b101".U(3.W)) -> ALUFunctions.bext,
      Cat("b0110100".U(7.W), "b001".U(3.W)) -> ALUFunctions.binv,
      Cat("b0010100".U(7.W), "b001".U(3.W)) -> ALUFunctions.bset
    )
  )
}
```

* `funct7` and `funct3` are concatenated with `Cat` into one 10-bit key because funct3 alone is ambiguous for the B extension: `xor`, `xnor`, `sh2add` and `min` all have funct3 = `100` and differ only in funct7. Each literal is given an explicit width (`7.W` for funct7, `3.W` for funct3) so that every key lines up bit for bit with `Cat(io.funct7, io.funct3)`; the `InstructionsTypeR` constants are assumed to be 3 bits wide as well.

# Test

### test tool

Use the `riscv32-unknown-elf` (or `riscv64-unknown-elf`) GNU toolchain to assemble `.s` files into `.asmbin` binaries. Because the test programs use B-extension instructions, a recent toolchain with Zba/Zbb/Zbs support is needed.
* install

```
git clone https://github.com/riscv-collab/riscv-gnu-toolchain.git
cd riscv-gnu-toolchain/
# bext/bset are Zbs instructions; Zbc (carry-less multiply) is not needed
./configure --prefix=/opt/riscv --with-arch=rv32gc_zba_zbb_zbs
sudo make
export PATH=/opt/riscv/bin:$PATH
```

* how to use

```
riscv32-unknown-elf-as -march=rv32i_zba_zbb_zbs -mabi=ilp32 -o test_B_extension.o test_B_extension.s
riscv32-unknown-elf-ld -o test_B_extension.elf -T link.lds test_B_extension.o
riscv32-unknown-elf-objcopy -O binary test_B_extension.elf test_B_extension.asmbin
```

`-march=rv32i_zba_zbb_zbs` restricts the assembler to RV32I plus Zba/Zbb/Zbs, the ISA string closest to what the core implements. With the toolchain's default `rv32gc...` ISA string, the assembler may emit compressed (C-extension) encodings, which the core does not decode.

### test 1

This program consists of two functions. `logint` takes N (a0) and returns `floor(log2(N))` in a0 (for N > 0). `reverse` takes N (a0) and n (a1), calls `logint` to obtain k = `floor(log2(N))`, and returns n with its lowest k bits in reversed order.

* original code

```s
# Takes input N (a0), returns its log base 2 in a0
logint:
    addi sp, sp, -4
    sw   t0, 0(sp)
    add  t0, a0, zero           # k = N
    add  a0, zero, zero         # i = 0
logloop:
    beq  t0, zero, logloop_end  # Exit if k == 0
    srai t0, t0, 1              # k >>= 1
    addi a0, a0, 1              # i++
    j    logloop
logloop_end:
    addi a0, a0, -1             # Return i - 1
    lw   t0, 0(sp)
    addi sp, sp, 4
    jr   ra

# Takes inputs N (a0) and n (a1), reverses the lowest log2(N) bits of n
reverse:
    addi sp, sp, -28
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)
    sw   s4, 20(sp)
    sw   s5, 24(sp)
    call logint                 # Now a0 has log2(N)
    addi s0, zero, 1            # j = 1
    add  s1, zero, zero         # p = 0
forloop_reverse:
    bgt  s0, a0, forloop_reverse_end
    sub  s2, a0, s0             # s2 = a0 - j
    addi s3, zero, 1
    sll  s3, s3, s2
    and  s3, a1, s3
    beq  s3, zero, elses3       # Skip if that bit of n is 0
ifs3:
    addi s4, s0, -1             # s4 = j - 1
    addi s5, zero, 1
    sll  s5, s5, s4
    or   s1, s1, s5
elses3:
    addi s0, s0, 1
    j    forloop_reverse
forloop_reverse_end:
    add  a0, s1, zero           # Return p
    lw   ra, 0(sp)
    lw   s0, 4(sp)
    lw   s1, 8(sp)
    lw   s2, 12(sp)
    lw   s3, 16(sp)
    lw   s4, 20(sp)
    lw   s5, 24(sp)
    addi sp, sp, 28
    jr   ra
```

* modify point

```s
addi s3, zero, 1
sll  s3, s3, s2
and  s3, a1, s3
# the three steps above (build the mask 1 << s2, then test bit s2 of a1) become one bext
bext s3, a1, s2
```

`bext` returns 0 or 1 instead of the masked value `a1 & (1 << s2)`, but the following `beq s3, zero, elses3` only tests for zero, so the branch behaves the same.

```s
addi s5, zero, 1
sll  s5, s5, s4
or   s1, s1, s5
# the three steps above (build the mask 1 << s4, then OR it into s1) become one bset
bset s1, s1, s4
```

* code after modified

```s
# Takes input N (a0), returns its log base 2 in a0
logint:
    addi sp, sp, -4
    sw   t0, 0(sp)
    add  t0, a0, zero           # k = N
    add  a0, zero, zero         # i = 0
logloop:
    beq  t0, zero, logloop_end  # Exit if k == 0
    srai t0, t0, 1              # k >>= 1
    addi a0, a0, 1              # i++
    j    logloop
logloop_end:
    addi a0, a0, -1             # Return i - 1
    lw   t0, 0(sp)
    addi sp, sp, 4
    jr   ra

# Takes inputs N (a0) and n (a1), reverses the lowest log2(N) bits of n
reverse:
    addi sp, sp, -28
    sw   ra, 0(sp)
    sw   s0, 4(sp)
    sw   s1, 8(sp)
    sw   s2, 12(sp)
    sw   s3, 16(sp)
    sw   s4, 20(sp)
    sw   s5, 24(sp)
    call logint                 # Now a0 has log2(N)
    addi s0, zero, 1            # j = 1
    add  s1, zero, zero         # p = 0
forloop_reverse:
    bgt  s0, a0, forloop_reverse_end
    sub  s2, a0, s0             # s2 = a0 - j
    bext s3, a1, s2             # modified: bit s2 of n (replaces addi/sll/and)
    beq  s3, zero, elses3       # Skip if that bit of n is 0
ifs3:
    addi s4, s0, -1             # s4 = j - 1
    bset s1, s1, s4             # modified: set bit j - 1 of p (replaces addi/sll/or)
elses3:
    addi s0, s0, 1
    j    forloop_reverse
forloop_reverse_end:
    add  a0, s1, zero           # Return p
    lw   ra, 0(sp)
    lw   s0, 4(sp)
    lw   s1, 8(sp)
    lw   s2, 12(sp)
    lw   s3, 16(sp)
    lw   s4, 20(sp)
    lw   s5, 24(sp)
    addi sp, sp, 28
    jr   ra
```

### test 2

`find_binary` takes a non-negative integer in a0 and returns, in a0, a decimal number whose digits are the binary digits of the input (for example, 5 becomes 101). It recurses on `a0 >> 1` and, on the way back, computes `10 * result + (a0 & 1)`, so the least significant bit ends up as the last decimal digit.
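The recursion can be stated compactly as `find_binary(n) = 10 * find_binary(n >> 1) + (n & 1)` with `find_binary(0) = 0`. The sketch below is a hypothetical plain-Scala model of that definition (`FindBinaryModel`, not part of the project), together with a version of the multiply by 10 that uses only shifts and adds, which is the form Zba's `sh2add` and `sh1add` compute.

```scala
object FindBinaryModel {
  // 10 * findBinary(n >> 1) + (n & 1), recursing until n == 0 (the assembly's base case)
  def findBinary(n: Int): Int =
    if (n == 0) 0
    else 10 * findBinary(n >>> 1) + (n & 1)

  // Multiply by 10 and add a bit with shifts and adds only:
  // (x << 2) + x = 5x, then (5x << 1) + bit = 10x + bit
  def times10Plus(x: Int, bit: Int): Int = {
    val fiveX = (x << 2) + x
    (fiveX << 1) + bit
  }
}
```

Since the result is a 32-bit integer, only inputs below 1024 (at most ten binary digits) give a correct result; the same limit applies to the assembly version.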
* original code

```s
find_binary:
    addi sp, sp, -8          # a0 will have arg and be where we return
    sw   ra, 4(sp)           # saving return address on the stack
    sw   s0, 0(sp)           # saving "decimal % 2" on the stack
    beq  a0, x0, postamble
    andi s0, a0, 1           # set s0 to a0 % 2
    srli a0, a0, 1
    jal  ra, find_binary     # recursive call
    li   t0, 10
    mul  a0, t0, a0
    add  a0, a0, s0          # accumulating the stored return value "decimal % 2"
postamble:
    lw   ra, 4(sp)           # Restore ra
    lw   s0, 0(sp)
    addi sp, sp, 8
end:
    jr   ra
```

* modify point

```s
andi s0, a0, 1
# replaced with bext; the index register x0 (zero) selects bit 0
bext s0, a0, zero
```

Writing `bext s0, a0, 0` with an immediate index is accepted by the GNU assembler as `bexti`, an OP-IMM instruction that is not among the R-type encodings added to `ALUControl`; using `zero` as the index register keeps the register form that the core decodes.

```s
li  t0, 10
mul a0, t0, a0
add a0, a0, s0
# mul belongs to the M extension, which this core does not implement;
# Zba computes 10 * a0 + s0 with two instructions instead
sh2add a0, a0, a0   # a0 = (a0 << 2) + a0 = 5 * a0
sh1add a0, a0, s0   # a0 = (a0 << 1) + s0 = 10 * a0 + s0
```

* code after modified

```s
find_binary:
    addi   sp, sp, -8          # Allocate stack space
    sw     ra, 4(sp)           # Save return address
    sw     s0, 0(sp)           # Save s0, which holds "decimal % 2"
    beq    a0, x0, postamble   # Base case: if a0 == 0, return 0
    bext   s0, a0, zero        # Extract bit 0 of a0 (a0 % 2)
    srli   a0, a0, 1           # Divide a0 by 2 (logical shift right)
    jal    ra, find_binary     # Recursive call
    sh2add a0, a0, a0          # a0 = 5 * a0
    sh1add a0, a0, s0          # a0 = 10 * a0 + "decimal % 2"
postamble:
    lw     ra, 4(sp)           # Restore return address
    lw     s0, 0(sp)           # Restore s0
    addi   sp, sp, 8           # Release stack space
    jr     ra                  # Return
```

# Verify the test

The inputs (a0, a1) are loaded by each test program itself: `regs_debug_read_data` is an output of `TestTopModule`, so the tests only select a register with `regs_debug_read_address` and check the value that comes back. a0 is register x10 in the RISC-V ABI.

### verify test 1

* test code

```
// ThreeStage
// test logint
// input N (a0) = 8, expected result (a0) = 3
it should "test logint function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                      // let the program run
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()
    c.io.regs_debug_read_data.expect(3.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[LogInt Test] Output = $realData")
  }
}

// test reverse
// input N (a0) = 5, n (a1) = 3, output = 5
it should "test reverse function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                      // let the program run
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()
    c.io.regs_debug_read_data.expect(5.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[Reverse Test] Output = $realData")
  }
}
```

```
// FiveStage (only the final version is tested)
// test logint
// input N (a0) = 8, expected result (a0) = 3
it should "test logint function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                      // let the program run
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()
    c.io.regs_debug_read_data.expect(3.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[LogInt Test] Output = $realData")
  }
}

// test reverse
// input N (a0) = 5, n (a1) = 3, output = 5
it should "test reverse function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                      // let the program run
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()
    c.io.regs_debug_read_data.expect(5.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[Reverse Test] Output = $realData")
  }
}
```

Both checks load the same `test_1.asmbin`, run it for 1000 cycles and read the same register, so they can only expect different values (3 and 5) if the two results end up in different registers or come from separate programs. Also, following the `reverse` code, N = 5 gives k = `floor(log2(5))` = 2, and reversing the lowest two bits of n = 3 (`0b11`) returns 3; an expected value of 5 corresponds, for example, to N = 8 and n = 5 (`0b101` reversed over three bits).

### verify test 2

* test code

```
// ThreeStage
it should "test find_binary function" in {
  test(new TestTopModule("test_find_binary.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    // Let the program run; its own startup code places the input in a0
    c.clock.step(1000)

    val inputValue = 5                      // input loaded by the program, used in the log message
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()

    // Expected output: 5 is 0b101, returned as the decimal number 101
    val expectedOutput = 101
    c.io.regs_debug_read_data.expect(expectedOutput.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[FindBinary Test] Input: $inputValue, Output: $realData")
  }
}
```

```
// FiveStage (only the final version is tested)
it should "test find_binary function" in {
  test(new TestTopModule("test_find_binary.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    // Let the program run; its own startup code places the input in a0
    c.clock.step(1000)

    val inputValue = 5                      // input loaded by the program, used in the log message
    c.io.regs_debug_read_address.poke(10.U) // a0 is register x10
    c.clock.step()

    // Expected output: 5 is 0b101, returned as the decimal number 101
    val expectedOutput = 101
    c.io.regs_debug_read_data.expect(expectedOutput.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[FindBinary Test] Input: $inputValue, Output: $realData")
  }
}
```
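The values expected by the test 1 checks can be derived from a software model of the two routines. The sketch below is a hypothetical `Test1Model` object (plain Scala, not part of the project) that follows the assembly step by step, including the arithmetic shift in `logint`. Like the assembly, `logint` is only meaningful for N > 0; for negative N the loop never terminates.

```scala
// Plain-Scala model of the test 1 routines, used to derive expected register values.
object Test1Model {
  // logint: count shifts until k reaches 0, then return i - 1 (floor(log2 N) for N > 0)
  def logint(n: Int): Int = {
    var k = n
    var i = 0
    while (k != 0) { k = k >> 1; i += 1 } // srai: arithmetic shift, as in the assembly
    i - 1
  }

  // reverse: return n with its lowest logint(bigN) bits in reversed order
  def reverse(bigN: Int, n: Int): Int = {
    val bits = logint(bigN)
    var p = 0
    for (j <- 1 to bits) {
      if (((n >>> (bits - j)) & 1) == 1) // bext s3, a1, s2
        p = p | (1 << (j - 1))          // bset s1, s1, s4
    }
    p
  }
}
```

For instance, `Test1Model.logint(8)` is 3 and `Test1Model.reverse(8, 1)` is 4 (`0b001` reversed over three bits). Evaluating the model for the N and n that the test program loads gives the value to put in the corresponding `expect`.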