戴均原
This project involves extending an existing RISC-V processor from its original 3-stage and 5-stage pipelined designs to support the complete RV32I instruction set and the B extension. Additionally, at least two RISC-V programs from the course exercises will be selected and rewritten to utilize the B extension, ensuring proper functionality on the enhanced processor. The final implementation will be published on GitHub.
The process explores the transition from a single-cycle design to a 3-stage pipeline and subsequently to a 5-stage pipeline, focusing on critical design considerations such as forwarding and pipeline optimization.
The three-stage design is a simplified processor pipeline with three stages: Instruction Fetch (IF), Instruction Decode (ID), and Execute (EX). It has no separate memory access or write-back stages, which makes it a foundational model for understanding pipeline operation.
Control.scala
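import chisel3._

class Control extends Module {
  val io = IO(new Bundle {
    val JumpFlag = Input(Bool())  // asserted when the instruction in EX takes a jump
    val Flush = Output(Bool())    // clears the instructions fetched after the jump
  })

  // The instructions behind a taken jump (in IF and ID) are on the wrong path.
  io.Flush := io.JumpFlag
}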
The five-stage stall design extends the pipeline with Memory Access (MEM) and Write Back (WB) stages. It handles data hazards by stalling. When the instruction in ID reads a register that an instruction in EX or MEM has not yet written back, the pipeline pauses until the value is available. This guarantees correctness at the cost of performance.
Control.scala
package riscv.core.fivestage_stall

import chisel3._
import riscv.Parameters

class Control extends Module {
  val io = IO(new Bundle {
    val jump_flag = Input(Bool()) // ex.io.if_jump_flag
    val rs1_id = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg1_read_address
    val rs2_id = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg2_read_address
    val rd_ex = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id2ex.io.output_regs_write_address
    val reg_write_enable_ex = Input(Bool()) // id2ex.io.output_regs_write_enable
    val rd_mem = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // ex2mem.io.output_regs_write_address
    val reg_write_enable_mem = Input(Bool()) // ex2mem.io.output_regs_write_enable

    val if_flush = Output(Bool())
    val id_flush = Output(Bool())
    val pc_stall = Output(Bool())
    val if_stall = Output(Bool())
  })

  io.if_flush := false.B
  io.id_flush := false.B
  io.pc_stall := false.B
  io.if_stall := false.B

  when(io.jump_flag) {
    io.if_flush := true.B
    io.id_flush := true.B
  }.elsewhen(
    (io.reg_write_enable_ex && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id) && io.rd_ex =/= 0.U)
      || (io.reg_write_enable_mem && (io.rd_mem === io.rs1_id || io.rd_mem === io.rs2_id) && io.rd_mem =/= 0.U)
  ) {
    io.id_flush := true.B
    io.pc_stall := true.B
    io.if_stall := true.B
  }
}
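The stall condition can be mirrored in plain Scala so that it can be checked by hand. The sketch below is hypothetical and not part of the project. StallInputs, ControlOutputs, and StallModel.stallControl are illustration-only names, and register numbers are plain Int values rather than Chisel UInts. It keeps the same priority as the Chisel code: a jump flush wins over a data-hazard stall.

```scala
// Hypothetical software model of the five-stage stall Control module.
// Field names mirror the Chisel IO ports; values are plain Scala types.
final case class StallInputs(
    jumpFlag: Boolean,
    rs1Id: Int,
    rs2Id: Int,
    rdEx: Int,
    regWriteEnableEx: Boolean,
    rdMem: Int,
    regWriteEnableMem: Boolean
)

final case class ControlOutputs(ifFlush: Boolean, idFlush: Boolean, pcStall: Boolean, ifStall: Boolean)

object StallModel {
  // True when a stage that will write `rd` produces a source register of the ID instruction.
  private def conflicts(writeEnable: Boolean, rd: Int, rs1: Int, rs2: Int): Boolean =
    writeEnable && rd != 0 && (rd == rs1 || rd == rs2)

  def stallControl(in: StallInputs): ControlOutputs =
    if (in.jumpFlag) {
      // Taken jump: flush the wrong-path instructions in IF and ID.
      ControlOutputs(ifFlush = true, idFlush = true, pcStall = false, ifStall = false)
    } else if (conflicts(in.regWriteEnableEx, in.rdEx, in.rs1Id, in.rs2Id) ||
               conflicts(in.regWriteEnableMem, in.rdMem, in.rs1Id, in.rs2Id)) {
      // RAW hazard: hold PC and IF, and send a bubble out of ID.
      ControlOutputs(ifFlush = false, idFlush = true, pcStall = true, ifStall = true)
    } else {
      ControlOutputs(ifFlush = false, idFlush = false, pcStall = false, ifStall = false)
    }
}
```

Consider add x5, x1, x2 in EX followed by sub x6, x5, x3 in ID. Here rd_ex = 5 matches rs1_id = 5, so the model reports a stall. A dependency on x0 never stalls.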
All four outputs default to false.B, so no flushing or stalling occurs unless one of the conditions below is met:
  io.if_flush := false.B
  io.id_flush := false.B
  io.pc_stall := false.B
  io.if_stall := false.B
Control hazards are checked first:
  when(io.jump_flag) {
    io.if_flush := true.B
    io.id_flush := true.B
  }
When io.jump_flag is asserted (true.B), a jump has been taken in EX. The instructions already fetched behind it are therefore on the wrong path. In response:
io.if_flush is set to true.B, flushing the instruction fetch (IF) stage.
io.id_flush is set to true.B, flushing the instruction decode (ID) stage.
Data hazards are checked only when no jump is taken:
  .elsewhen(
    (io.reg_write_enable_ex && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id) && io.rd_ex =/= 0.U)
      || (io.reg_write_enable_mem && (io.rd_mem === io.rs1_id || io.rd_mem === io.rs2_id) && io.rd_mem =/= 0.U)
  ) {
    io.id_flush := true.B
    io.pc_stall := true.B
    io.if_stall := true.B
  }
The stall branch fires when either of the following holds:
The instruction in EX will write a register (io.reg_write_enable_ex is true), its destination register (io.rd_ex) matches one of the current instruction's source registers (io.rs1_id or io.rs2_id), and that destination is not the zero register (io.rd_ex =/= 0.U).
The instruction in MEM will write a register (io.reg_write_enable_mem is true), its destination register (io.rd_mem) matches io.rs1_id or io.rs2_id, and that destination is not the zero register (io.rd_mem =/= 0.U).
In that case the module sets:
io.id_flush := true.B: flushes the ID stage so that a bubble, rather than the dependent instruction with a stale operand, moves into EX.
io.pc_stall := true.B: stalls the program counter (PC) to pause instruction fetching.
io.if_stall := true.B: stalls the instruction fetch (IF) stage so that no new instruction is fetched and the dependent instruction waits in ID.
The five-stage forwarding design builds on the stall design but adds forwarding (bypassing) to reduce stalls. It resolves most data hazards by forwarding results from the later MEM and WB stages directly to the EX stage, which improves performance compared to stalling.
Control.scala
package riscv.core.fivestage_forward

import chisel3._
import riscv.Parameters

class Control extends Module {
  val io = IO(new Bundle {
    val jump_flag = Input(Bool()) // ex.io.if_jump_flag
    val rs1_id = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg1_read_address
    val rs2_id = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id.io.regs_reg2_read_address
    val memory_read_enable_ex = Input(Bool()) // id2ex.io.output_memory_read_enable
    val rd_ex = Input(UInt(Parameters.PhysicalRegisterAddrWidth)) // id2ex.io.output_regs_write_address

    val if_flush = Output(Bool())
    val id_flush = Output(Bool())
    val pc_stall = Output(Bool())
    val if_stall = Output(Bool())
  })

  io.if_flush := false.B
  io.id_flush := false.B
  io.pc_stall := false.B
  io.if_stall := false.B

  when(io.jump_flag) {
    io.if_flush := true.B
    io.id_flush := true.B
  }.elsewhen(io.memory_read_enable_ex && io.rd_ex =/= 0.U && (io.rd_ex === io.rs1_id || io.rd_ex === io.rs2_id)) {
    io.id_flush := true.B
    io.pc_stall := true.B
    io.if_stall := true.B
  }
}
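Unlike the stall version, this Control module stalls only when the instruction in EX is a load (memory_read_enable_ex) and its destination is a source of the instruction in ID. The loaded value becomes available only at the end of MEM, so it cannot be forwarded to an instruction that enters EX in the very next cycle. Every other dependency is left to forwarding.

The sketch below is a hypothetical plain-Scala comparison of the two EX-stage checks. ExInstr and LoadUse are illustration-only names, and the stall version's MEM-stage check is left out for brevity.

```scala
// Hypothetical model of the instruction currently in EX.
final case class ExInstr(rd: Int, regWrite: Boolean, memRead: Boolean)

object LoadUse {
  // Does the ID instruction (sources rs1, rs2) read the EX instruction's destination?
  private def dependsOn(ex: ExInstr, rs1: Int, rs2: Int): Boolean =
    ex.rd != 0 && (ex.rd == rs1 || ex.rd == rs2)

  // Stall-only pipeline: any register write in EX that ID depends on causes a stall.
  def stallWithoutForwarding(ex: ExInstr, rs1: Int, rs2: Int): Boolean =
    ex.regWrite && dependsOn(ex, rs1, rs2)

  // Forwarding pipeline: only a load in EX causes a stall,
  // mirroring io.memory_read_enable_ex in the Control module.
  def stallWithForwarding(ex: ExInstr, rs1: Int, rs2: Int): Boolean =
    ex.memRead && dependsOn(ex, rs1, rs2)
}
```

For sub x6, x5, x7 in ID, a preceding add x5, ... stalls only in the stall-only pipeline. A preceding lw x5, ... stalls in both.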
object ForwardingType {
  val NoForward = 0.U(2.W)
  val ForwardFromMEM = 1.U(2.W)
  val ForwardFromWB = 2.U(2.W)
}
NoForward: no data forwarding is required (value 0, the default).
ForwardFromMEM: forward data from the MEM stage to the EX stage (value 1).
ForwardFromWB: forward data from the WB stage to the EX stage (value 2).
The forwarding source for the first source operand (rs1) of the instruction in EX is selected as follows:
when(io.reg_write_enable_mem && io.rs1_ex === io.rd_mem && io.rd_mem =/= 0.U) {
  io.reg1_forward_ex := ForwardingType.ForwardFromMEM
}.elsewhen(io.reg_write_enable_wb && io.rs1_ex === io.rd_wb && io.rd_wb =/= 0.U) {
  io.reg1_forward_ex := ForwardingType.ForwardFromWB
}.otherwise {
  io.reg1_forward_ex := ForwardingType.NoForward
}
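The selection can also be written as a small pure function, which makes the priority order easy to test. This is a hypothetical plain-Scala sketch in which ForwardingModel, select, and operand are illustration-only names. It models a single source operand, and the same function would be applied to each source register.

```scala
// Hypothetical software model of the forwarding-source selection for one operand.
object ForwardingModel {
  sealed trait Source
  case object NoForward extends Source
  case object ForwardFromMEM extends Source
  case object ForwardFromWB extends Source

  // rsEx: a source register of the instruction in EX.
  // rdMem/writeMem and rdWb/writeWb: destination and write enable in MEM and WB.
  def select(rsEx: Int, rdMem: Int, writeMem: Boolean, rdWb: Int, writeWb: Boolean): Source =
    if (writeMem && rdMem != 0 && rdMem == rsEx) ForwardFromMEM // checked first
    else if (writeWb && rdWb != 0 && rdWb == rsEx) ForwardFromWB
    else NoForward

  // The operand value EX would actually use for a given selection.
  def operand(src: Source, regFileValue: Int, memValue: Int, wbValue: Int): Int = src match {
    case NoForward      => regFileValue
    case ForwardFromMEM => memValue
    case ForwardFromWB  => wbValue
  }
}
```

Take add x5, x1, x2 in WB and add x5, x5, x3 in MEM. An instruction in EX that reads x5 must then receive the MEM value, and the first test checks exactly that case.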
If the instruction in MEM will write a register (reg_write_enable_mem is true), its destination register (rd_mem) matches rs1_ex, and rd_mem is not the zero register (rd_mem =/= 0.U), data is forwarded from the MEM stage.
Otherwise, if the instruction in WB will write a register (reg_write_enable_wb is true), its destination register (rd_wb) matches rs1_ex, and rd_wb is not the zero register (rd_wb =/= 0.U), data is forwarded from the WB stage.
MEM is checked first because, when both stages write the same register, the instruction in MEM is the more recent one and holds the value the program expects.
The final five-stage design combines stalling and forwarding to resolve both data hazards and control hazards. It is the most refined and efficient of the pipeline designs.
It has been recognized that there is need to officially standardize a "B" extension - that represents the collection of the Zba, Zbb, and Zbs extensions - for the sake of consistency and conciseness across toolchains and how they identify support for these bitmanip extensions (which, for example, are mandated in RVA and RVB profiles). In conjunction with this an official definition of the misa.B bit will be established - along the lines of misa.B=1 indicating support for at least these three extensions (and misa.B=0 indicating that one or more may not be supported).
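References:
RISC-V Bit-Manipulation ISA-extensions, Version 1.0.0: https://www.ece.lsu.edu/ee4720/doc/riscv-bitmanip-1.0.0.pdf
The RISC-V Reader, Chinese edition (v2.1): http://riscvbook.com/chinese/RISC-V-Reader-Chinese-v2p1.pdf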
The following register-register instructions from the B extension are implemented (rs1 is the first operand, rs2 the second):
sh1add:
Shifts the first operand left by 1 (multiplies it by 2) and adds the second operand: rd = (rs1 << 1) + rs2.
sh2add:
Like sh1add, but shifts the first operand left by 2 (multiplies it by 4) before the addition.
sh3add:
Like sh1add, but shifts the first operand left by 3 (multiplies it by 8) before the addition.
andn:
Bitwise AND of the first operand with the bitwise negation of the second operand: rd = rs1 & ~rs2.
orn:
Bitwise OR of the first operand with the bitwise negation of the second operand: rd = rs1 | ~rs2.
xnor:
Bitwise exclusive NOR, that is, a bitwise XOR followed by a NOT: rd = ~(rs1 ^ rs2).
min:
Compares the two operands as signed integers and returns the smaller one.
minu:
Like min, but compares the operands as unsigned integers.
max:
Compares the two operands as signed integers and returns the larger one.
maxu:
Like max, but compares the operands as unsigned integers.
bclr:
Clears one bit of the first operand. The bit index is the low 5 bits of the second operand (rs2 & 31 on RV32).
bext:
Extracts one bit of the first operand, selected the same way, and returns it in bit 0, so the result is 0 or 1.
binv:
Inverts (toggles) one bit of the first operand, selected the same way.
bset:
Sets one bit of the first operand to 1, selected the same way.
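These definitions can be captured in a small golden model, which is useful for generating expected values when testing the Chisel ALU. The sketch below is hypothetical and not part of the project, and BitmanipModel is an illustration-only name. It models RV32 registers as Scala Int values (two's complement, 32 bits). It follows the Bitmanip 1.0 rule that the Zbs instructions use only the low five bits of the second operand as the bit index.

```scala
// Hypothetical golden model of the implemented B-extension instructions on RV32.
object BitmanipModel {
  private def idx(rs2: Int): Int = rs2 & 31 // bit index: low 5 bits of rs2

  // Zba: shift rs1, then add rs2 (wraps modulo 2^32 like the hardware)
  def sh1add(rs1: Int, rs2: Int): Int = (rs1 << 1) + rs2
  def sh2add(rs1: Int, rs2: Int): Int = (rs1 << 2) + rs2
  def sh3add(rs1: Int, rs2: Int): Int = (rs1 << 3) + rs2

  // Zbb: logic with negate, signed and unsigned min/max
  def andn(rs1: Int, rs2: Int): Int = rs1 & ~rs2
  def orn(rs1: Int, rs2: Int): Int = rs1 | ~rs2
  def xnor(rs1: Int, rs2: Int): Int = ~(rs1 ^ rs2)
  def min(rs1: Int, rs2: Int): Int = math.min(rs1, rs2)
  def max(rs1: Int, rs2: Int): Int = math.max(rs1, rs2)
  def minu(rs1: Int, rs2: Int): Int = if (Integer.compareUnsigned(rs1, rs2) < 0) rs1 else rs2
  def maxu(rs1: Int, rs2: Int): Int = if (Integer.compareUnsigned(rs1, rs2) > 0) rs1 else rs2

  // Zbs: single-bit operations
  def bclr(rs1: Int, rs2: Int): Int = rs1 & ~(1 << idx(rs2))
  def bext(rs1: Int, rs2: Int): Int = (rs1 >>> idx(rs2)) & 1
  def binv(rs1: Int, rs2: Int): Int = rs1 ^ (1 << idx(rs2))
  def bset(rs1: Int, rs2: Int): Int = rs1 | (1 << idx(rs2))
}
```

Some examples: sh1add(3, 10) is 16, and bext(5, 2) is 1 because 5 is 101 in binary. The signed and unsigned variants differ on -1: min(-1, 1) is -1, but minu(-1, 1) is 1, because -1 is 0xFFFFFFFF when read as unsigned.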
ALU.scala
This file implements the Arithmetic Logic Unit (ALU) in Chisel, a hardware construction language embedded in Scala. The ALU supports the RISC-V RV32I instruction set and parts of the B extension (the Zba, Zbb, and Zbs instructions listed above).
ALUFunctions
object ALUFunctions extends ChiselEnum {
  val zero, add, sub, sll, slt, xor, or, and, srl, sra, sltu = Value
  val sh1add, sh2add, sh3add = Value                      // Zba
  val andn, orn, xnor, min, minu, max, maxu = Value       // Zbb
  val bclr, bext, binv, bset = Value                      // Zbs
}
ALU Class
// New cases inside the ALU's switch on the selected ALU function

// Zba: shift op1 left by 1, 2, or 3, then add op2
is(ALUFunctions.sh1add) {
  io.result := (io.op1 << 1) + io.op2
}
is(ALUFunctions.sh2add) {
  io.result := (io.op1 << 2) + io.op2
}
is(ALUFunctions.sh3add) {
  io.result := (io.op1 << 3) + io.op2
}

// Zbb: logic with negated operand, signed/unsigned min and max
is(ALUFunctions.andn) {
  io.result := io.op1 & ~io.op2
}
is(ALUFunctions.orn) {
  io.result := io.op1 | ~io.op2
}
is(ALUFunctions.xnor) {
  io.result := ~(io.op1 ^ io.op2)
}
is(ALUFunctions.min) {
  // Compare as signed, but return the original UInt operand
  io.result := Mux(io.op1.asSInt < io.op2.asSInt, io.op1, io.op2)
}
is(ALUFunctions.minu) {
  io.result := Mux(io.op1 < io.op2, io.op1, io.op2)
}
is(ALUFunctions.max) {
  io.result := Mux(io.op1.asSInt > io.op2.asSInt, io.op1, io.op2)
}
is(ALUFunctions.maxu) {
  io.result := Mux(io.op1 > io.op2, io.op1, io.op2)
}

// Zbs: the bit index is the low 5 bits of op2
is(ALUFunctions.bclr) {
  io.result := io.op1 & ~(1.U << io.op2(4, 0))
}
is(ALUFunctions.bext) {
  io.result := (io.op1 >> io.op2(4, 0)) & 1.U
}
is(ALUFunctions.binv) {
  io.result := io.op1 ^ (1.U << io.op2(4, 0))
}
is(ALUFunctions.bset) {
  io.result := io.op1 | (1.U << io.op2(4, 0))
}
The ALUControl module decodes a RISC-V instruction's opcode, funct3, and funct7. From the instruction type and these function codes, it selects the corresponding ALU function (alu_funct). For R-type instructions, the lookup key is the 10-bit concatenation of funct7 and funct3.
// Inside the switch on the opcode: all R-type instructions share opcode OP (InstructionTypes.RM)
is(InstructionTypes.RM) {
  io.alu_funct := MuxLookup(
    Cat(io.funct7, io.funct3), // 10-bit key: funct7 (7 bits) followed by funct3 (3 bits)
    ALUFunctions.zero,         // default for encodings that are not implemented
    IndexedSeq(
      // Widths are explicit so every table key is also 10 bits: Chisel may size an
      // unwidthed literal to its value, dropping leading zeros and misplacing funct7 in Cat().
      Cat("b0000000".U(7.W), InstructionsTypeR.add_sub.pad(3)) -> ALUFunctions.add,
      Cat("b0100000".U(7.W), InstructionsTypeR.add_sub.pad(3)) -> ALUFunctions.sub,
      Cat("b0000000".U(7.W), InstructionsTypeR.sll.pad(3)) -> ALUFunctions.sll,
      Cat("b0000000".U(7.W), InstructionsTypeR.slt.pad(3)) -> ALUFunctions.slt,
      Cat("b0000000".U(7.W), InstructionsTypeR.sltu.pad(3)) -> ALUFunctions.sltu,
      Cat("b0000000".U(7.W), InstructionsTypeR.xor.pad(3)) -> ALUFunctions.xor,
      Cat("b0000000".U(7.W), InstructionsTypeR.or.pad(3)) -> ALUFunctions.or,
      Cat("b0000000".U(7.W), InstructionsTypeR.and.pad(3)) -> ALUFunctions.and,
      Cat("b0000000".U(7.W), InstructionsTypeR.sr.pad(3)) -> ALUFunctions.srl,
      Cat("b0100000".U(7.W), InstructionsTypeR.sr.pad(3)) -> ALUFunctions.sra,
      // Zba
      Cat("b0010000".U(7.W), "b010".U(3.W)) -> ALUFunctions.sh1add,
      Cat("b0010000".U(7.W), "b100".U(3.W)) -> ALUFunctions.sh2add,
      Cat("b0010000".U(7.W), "b110".U(3.W)) -> ALUFunctions.sh3add,
      // Zbb
      Cat("b0100000".U(7.W), "b111".U(3.W)) -> ALUFunctions.andn,
      Cat("b0100000".U(7.W), "b110".U(3.W)) -> ALUFunctions.orn,
      Cat("b0100000".U(7.W), "b100".U(3.W)) -> ALUFunctions.xnor,
      Cat("b0000101".U(7.W), "b100".U(3.W)) -> ALUFunctions.min,
      Cat("b0000101".U(7.W), "b101".U(3.W)) -> ALUFunctions.minu,
      Cat("b0000101".U(7.W), "b110".U(3.W)) -> ALUFunctions.max,
      Cat("b0000101".U(7.W), "b111".U(3.W)) -> ALUFunctions.maxu,
      // Zbs
      Cat("b0100100".U(7.W), "b001".U(3.W)) -> ALUFunctions.bclr,
      Cat("b0100100".U(7.W), "b101".U(3.W)) -> ALUFunctions.bext,
      Cat("b0110100".U(7.W), "b001".U(3.W)) -> ALUFunctions.binv,
      Cat("b0010100".U(7.W), "b001".U(3.W)) -> ALUFunctions.bset
    )
  )
}
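The lookup can be mirrored in plain Scala to double-check the table against the Bitmanip specification or against objdump output. This is a hypothetical sketch: RTypeDecodeModel, key, encodeR, and decode are illustration-only names, and the table is transcribed from the MuxLookup entries. The fields follow the standard R-type layout: funct7 in bits 31..25, rs2 in 24..20, rs1 in 19..15, funct3 in 14..12, rd in 11..7, and the opcode in 6..0.

```scala
// Hypothetical software mirror of the {funct7, funct3} lookup in ALUControl.
object RTypeDecodeModel {
  val Op: Int = 0x33 // opcode OP (0110011), the RM instruction type

  // Same layout as Cat(funct7, funct3): funct7 in bits 9..3, funct3 in bits 2..0
  def key(funct7: Int, funct3: Int): Int = (funct7 << 3) | funct3

  val table: Map[Int, String] = Map(
    key(0x00, 0) -> "add", key(0x20, 0) -> "sub", key(0x00, 1) -> "sll",
    key(0x00, 2) -> "slt", key(0x00, 3) -> "sltu", key(0x00, 4) -> "xor",
    key(0x00, 5) -> "srl", key(0x20, 5) -> "sra", key(0x00, 6) -> "or",
    key(0x00, 7) -> "and",
    // Zba
    key(0x10, 2) -> "sh1add", key(0x10, 4) -> "sh2add", key(0x10, 6) -> "sh3add",
    // Zbb
    key(0x20, 7) -> "andn", key(0x20, 6) -> "orn", key(0x20, 4) -> "xnor",
    key(0x05, 4) -> "min", key(0x05, 5) -> "minu", key(0x05, 6) -> "max", key(0x05, 7) -> "maxu",
    // Zbs
    key(0x24, 1) -> "bclr", key(0x24, 5) -> "bext", key(0x34, 1) -> "binv", key(0x14, 1) -> "bset"
  )

  // Assemble an R-type instruction word from its fields.
  def encodeR(funct7: Int, rs2: Int, rs1: Int, funct3: Int, rd: Int): Int =
    (funct7 << 25) | (rs2 << 20) | (rs1 << 15) | (funct3 << 12) | (rd << 7) | Op

  // Returns the selected ALU function name; "zero" mirrors the MuxLookup default.
  def decode(insn: Int): String =
    if ((insn & 0x7f) != Op) "not OP"
    else table.getOrElse(key(insn >>> 25, (insn >>> 12) & 0x7), "zero")
}
```

Under this layout, sh1add a0, a0, a1 encodes to 0x20B52533. An M-extension mul falls through to the zero default because its funct7 (0000001) has no entry.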
Use the riscv32-unknown-elf or riscv64-unknown-elf toolchain to assemble the .s sources into .asmbin binaries.
Because the programs use B-extension instructions, a recent toolchain is required, so it is built from source:
git clone https://github.com/riscv-collab/riscv-gnu-toolchain.git
cd riscv-gnu-toolchain/
# B = Zba + Zbb + Zbs; bext and bset are Zbs instructions (Zbc is carry-less multiplication)
./configure --prefix=/opt/riscv --with-arch=rv32gc_zba_zbb_zbs
sudo make
export PATH=/opt/riscv/bin:$PATH  # same directory as --prefix above
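# Limit the ISA to what the core implements, so that no compressed (C) instructions are emitted
riscv32-unknown-elf-as -march=rv32i_zba_zbb_zbs -mabi=ilp32 -o test_B_extension.o test_B_extension.s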
riscv32-unknown-elf-ld -o test_B_extension.elf -T link.lds test_B_extension.o
riscv32-unknown-elf-objcopy -O binary test_B_extension.elf test_B_extension.asmbin
This program consists of two functions. logint takes N in a0 and returns floor(log2(N)) in a0. reverse takes N in a0 and n in a1, calls logint to obtain log2(N), and returns in a0 the low log2(N) bits of n in reversed order.
# Takes input N (a0), returns floor(log2(N)) in a0
logint:
    addi sp, sp, -4
    sw t0, 0(sp)
    add t0, a0, zero             # k = N
    add a0, zero, zero           # i = 0
logloop:
    beq t0, zero, logloop_end    # Exit if k == 0
    srai t0, t0, 1               # k >>= 1
    addi a0, a0, 1               # i++
    j logloop
logloop_end:
    addi a0, a0, -1              # Return i - 1
    lw t0, 0(sp)
    addi sp, sp, 4
    jr ra

# Takes N (a0) and n (a1); returns the low log2(N) bits of n reversed
reverse:
    addi sp, sp, -28
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s1, 8(sp)
    sw s2, 12(sp)
    sw s3, 16(sp)
    sw s4, 20(sp)
    sw s5, 24(sp)
    call logint                  # Now a0 has log2(N)
    addi s0, zero, 1             # j = 1
    add s1, zero, zero           # p = 0
forloop_reverse:
    bgt s0, a0, forloop_reverse_end
    sub s2, a0, s0               # s2 = log2(N) - j
    addi s3, zero, 1
    sll s3, s3, s2
    and s3, a1, s3               # s3 = n & (1 << s2)
    beq s3, zero, elses3         # Skip if that bit of n is clear
ifs3:
    addi s4, s0, -1              # s4 = j - 1
    addi s5, zero, 1
    sll s5, s5, s4
    or s1, s1, s5                # p |= 1 << (j - 1)
elses3:
    addi s0, s0, 1
    j forloop_reverse
forloop_reverse_end:
    add a0, s1, zero             # Return p
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s1, 8(sp)
    lw s2, 12(sp)
    lw s3, 16(sp)
    lw s4, 20(sp)
    lw s5, 24(sp)
    addi sp, sp, 28
    jr ra
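The tests compare against hard-coded register values, so it helps to compute the expected results in software first. The following is a hypothetical Scala transcription of logint and reverse (ReverseModel is an illustration-only name) that follows the assembly step by step. It assumes N > 0, because logint's arithmetic right shift never reaches zero for a negative input.

```scala
// Hypothetical Scala transcription of logint and reverse, for computing expected values.
object ReverseModel {
  // Mirrors logint: count right shifts (srai) until k reaches 0, then subtract 1.
  def logint(bigN: Int): Int = {
    var k = bigN
    var i = 0
    while (k != 0) { k >>= 1; i += 1 }
    i - 1
  }

  // Mirrors reverse(N in a0, n in a1): bit (log - j) of n is copied to bit (j - 1) of p.
  def reverse(bigN: Int, n: Int): Int = {
    val log = logint(bigN)
    var p = 0
    for (j <- 1 to log) {
      if ((n & (1 << (log - j))) != 0) p |= 1 << (j - 1)
    }
    p
  }
}
```

Under this model, reverse(8, 6) returns 3 (110 reversed is 011). reverse(5, 3) also returns 3, because log2(5) rounds down to 2 and the two low bits of 3 read 11 in either order.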
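# Original: test bit s2 of a1
addi s3, zero, 1
sll s3, s3, s2
and s3, a1, s3
# combine the above steps using bext
bext s3, a1, s2

# Original: set bit s4 of s1
addi s5, zero, 1
sll s5, s5, s4
or s1, s1, s5
# combine the above steps using bset
bset s1, s1, s4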
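The first rewrite is not bit-for-bit identical to the code it replaces. The mask sequence leaves the tested bit in place (for example 4), while bext moves it to bit 0 (giving 1). The program only compares s3 against zero with beq, so the branch decision is unchanged. The bset rewrite, by contrast, produces exactly the same value.

A hypothetical plain-Scala check over a few sample values and all 32 bit positions (ZbsRewriteCheck is an illustration-only name):

```scala
// Hypothetical check that the Zbs rewrites preserve the program's behavior.
object ZbsRewriteCheck {
  // Original: addi s3, zero, 1; sll s3, s3, s2; and s3, a1, s3
  def maskTest(a1: Int, s2: Int): Int = a1 & (1 << s2)
  // Replacement: bext s3, a1, s2 (returns the bit as 0 or 1)
  def bextTest(a1: Int, s2: Int): Int = (a1 >>> (s2 & 31)) & 1

  // Original: addi s5, zero, 1; sll s5, s5, s4; or s1, s1, s5
  def orMask(s1: Int, s4: Int): Int = s1 | (1 << s4)
  // Replacement: bset s1, s1, s4
  def bset(s1: Int, s4: Int): Int = s1 | (1 << (s4 & 31))
}
```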
# Takes input N (a0), returns floor(log2(N)) in a0
logint:
    addi sp, sp, -4
    sw t0, 0(sp)
    add t0, a0, zero             # k = N
    add a0, zero, zero           # i = 0
logloop:
    beq t0, zero, logloop_end    # Exit if k == 0
    srai t0, t0, 1               # k >>= 1
    addi a0, a0, 1               # i++
    j logloop
logloop_end:
    addi a0, a0, -1              # Return i - 1
    lw t0, 0(sp)
    addi sp, sp, 4
    jr ra

# Takes N (a0) and n (a1); returns the low log2(N) bits of n reversed
reverse:
    addi sp, sp, -28
    sw ra, 0(sp)
    sw s0, 4(sp)
    sw s1, 8(sp)
    sw s2, 12(sp)
    sw s3, 16(sp)
    sw s4, 20(sp)
    sw s5, 24(sp)
    call logint                  # Now a0 has log2(N)
    addi s0, zero, 1             # j = 1
    add s1, zero, zero           # p = 0
forloop_reverse:
    bgt s0, a0, forloop_reverse_end
    sub s2, a0, s0               # s2 = log2(N) - j
    bext s3, a1, s2              # modified: s3 = bit s2 of n (replaces addi/sll/and)
    beq s3, zero, elses3         # Skip if that bit of n is clear
ifs3:
    addi s4, s0, -1              # s4 = j - 1
    bset s1, s1, s4              # modified: p |= 1 << (j - 1) (replaces addi/sll/or)
elses3:
    addi s0, s0, 1
    j forloop_reverse
forloop_reverse_end:
    add a0, s1, zero             # Return p
    lw ra, 0(sp)
    lw s0, 4(sp)
    lw s1, 8(sp)
    lw s2, 12(sp)
    lw s3, 16(sp)
    lw s4, 20(sp)
    lw s5, 24(sp)
    addi sp, sp, 28
    jr ra
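The second program, find_binary, takes an integer in a0 and returns in a0 a number whose decimal digits spell out its binary representation (for example, 5 becomes 101). It works recursively. Each call saves the lowest bit (a0 % 2), recurses on a0 shifted right by one, and then returns 10 × (recursive result) + saved bit. The binary digits are therefore extracted starting from the least significant bit. Note that find_binary uses mul, which belongs to the M extension rather than RV32I. It assembles only when M is enabled, and it produces the right result only on a core that implements multiplication.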
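A direct Scala transcription of the recursion gives expected values for testing. This is a hypothetical sketch, and FindBinaryModel is an illustration-only name. It uses an unsigned shift, like srli. Because the result must fit in a 32-bit register, only inputs below 1024 work: 1023 gives 1111111111, while 1024 would need 10000000000.

```scala
// Hypothetical Scala transcription of find_binary.
object FindBinaryModel {
  // Returns an Int whose decimal digits are the binary digits of n.
  def findBinary(n: Int): Int =
    if (n == 0) 0                      // base case: beq a0, x0, postamble
    else {
      val bit = n & 1                  // andi s0, a0, 1
      10 * findBinary(n >>> 1) + bit   // srli, recursive call, mul by 10, add
    }
}
```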
find_binary:
    addi sp, sp, -8              # a0 holds the argument and the return value
    sw ra, 4(sp)                 # save the return address
    sw s0, 0(sp)                 # save s0, which will hold this call's "decimal % 2"
    beq a0, x0, postamble        # base case: a0 == 0 returns 0
    andi s0, a0, 1               # s0 = a0 % 2
    srli a0, a0, 1               # a0 = a0 / 2
    jal ra, find_binary          # recursive call
    li t0, 10
    mul a0, t0, a0               # a0 = 10 * (recursive result)
    add a0, a0, s0               # accumulate the saved "decimal % 2"
postamble:
    lw ra, 4(sp)                 # restore ra
    lw s0, 0(sp)                 # restore s0
    addi sp, sp, 8
end:
    jr ra
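andi s0, a0, 1
# use bext to replace; bext takes the bit index from a register, so x0 (zero) selects bit 0
# (an immediate index would be the I-type bexti, which the decoder changes shown above do not cover)
bext s0, a0, zero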
find_binary:
    addi sp, sp, -8              # Save stack space
    sw ra, 4(sp)                 # Save return address
    sw s0, 0(sp)                 # Save s0, which will hold "decimal % 2"
    beq a0, x0, postamble        # Base case: if a0 == 0, return
    bext s0, a0, zero            # Extract bit 0 (a0 % 2); x0 supplies the index 0
    srli a0, a0, 1               # Divide a0 by 2 (logical shift right)
    jal ra, find_binary          # Recursive call
    li t0, 10                    # Load 10 into t0
    mul a0, t0, a0               # Multiply a0 by 10
    add a0, a0, s0               # Accumulate the result with "decimal % 2"
postamble:
    lw ra, 4(sp)                 # Restore return address
    lw s0, 0(sp)                 # Restore s0
    addi sp, sp, 8               # Restore stack pointer
    jr ra                        # Return
// ThreeStage
// test logint
// input (a0) is 8; test whether the answer is 3
it should "test logint function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                        // run the program for 1000 cycles
    c.io.regs_debug_read_address.poke(5.U)    // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    c.io.regs_debug_read_data.expect(3.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[LogInt Test] Output = $realData")
    c.clock.step()
  }
}
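// test reverse
// input N(a0) = 5, n(a1) = 3, output = 5
// The inputs must be set by the assembly program itself: regs_debug_read_data is an
// output port, so the test can only read it (peek/expect), not poke it.
// This reads the same register after the same 1000 cycles as the logint test, so the
// two expectations (3 and 5) cannot both hold for one program.
it should "test reverse function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                        // run the program for 1000 cycles
    c.io.regs_debug_read_address.poke(5.U)    // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    c.io.regs_debug_read_data.expect(5.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[Reverse Test] Output = $realData")
    c.clock.step()
  }
}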
// FiveStage (only the final version is tested)
// test logint
// input (a0) is 8; test whether the answer is 3
it should "test logint function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                        // run the program for 1000 cycles
    c.io.regs_debug_read_address.poke(5.U)    // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    c.io.regs_debug_read_data.expect(3.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[LogInt Test] Output = $realData")
    c.clock.step()
  }
}
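// test reverse
// input N(a0) = 5, n(a1) = 3, output = 5
// The inputs must be set by the assembly program itself: regs_debug_read_data is an
// output port, so the test can only read it (peek/expect), not poke it.
// This reads the same register after the same 1000 cycles as the logint test, so the
// two expectations (3 and 5) cannot both hold for one program.
it should "test reverse function" in {
  test(new TestTopModule("test_1.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    c.clock.step(1000)                        // run the program for 1000 cycles
    c.io.regs_debug_read_address.poke(5.U)    // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    c.io.regs_debug_read_data.expect(5.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[Reverse Test] Output = $realData")
    c.clock.step()
  }
}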
// ThreeStage
it should "test find_binary function" in {
  test(new TestTopModule("test_find_binary.asmbin", ImplementationType.ThreeStage)).withAnnotations(TestAnnotations.annos) { c =>
    // Run the program for 1000 cycles
    c.clock.step(1000)
    // The input (5) must be set by the assembly program itself: regs_debug_read_data
    // is an output port, so the test can only read it (peek/expect), not poke it.
    val inputValue = 5
    c.io.regs_debug_read_address.poke(5.U) // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    // Expected output: 5 is 101 in binary, returned as the decimal number 101
    val expectedOutput = 101
    c.io.regs_debug_read_data.expect(expectedOutput.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[FindBinary Test] Input: $inputValue, Output: $realData")
    c.clock.step()
  }
}
// FiveStage (only the final version is tested)
it should "test find_binary function" in {
  test(new TestTopModule("test_find_binary.asmbin", ImplementationType.FiveStageFinal)).withAnnotations(TestAnnotations.annos) { c =>
    // Run the program for 1000 cycles
    c.clock.step(1000)
    // The input (5) must be set by the assembly program itself: regs_debug_read_data
    // is an output port, so the test can only read it (peek/expect), not poke it.
    val inputValue = 5
    c.io.regs_debug_read_address.poke(5.U) // read register x5 (t0 in the ABI; a0 is x10)
    c.clock.step()
    // Expected output: 5 is 101 in binary, returned as the decimal number 101
    val expectedOutput = 101
    c.io.regs_debug_read_data.expect(expectedOutput.U)
    val realData = c.io.regs_debug_read_data.peek().litValue
    println(s"[FindBinary Test] Input: $inputValue, Output: $realData")
    c.clock.step()
  }
}