邱柏穎, 黃詩哲
Cloning the repository from 5-Stage-RV32I
$ git clone https://github.com/kinzafatim/5-stage-RV32I.git
$ cd 5-stage-RV32I
The original author hard-coded an absolute path to the instruction file, so we need to change it manually.
Modify the path in "./src/main/scala/Pipeline/Main.scala"
Original:
val InstMemory = Module(new InstMem ("/home/kinzaa/Desktop/5-Stage-RV32I/src/main/scala/Pipeline/test.txt"))
Replace it with:
val InstMemory = Module(new InstMem ("./src/main/scala/Pipeline/test.txt"))
Then run the processor simulation with sbt test; the output should be:
$ sbt test
Elaborating design...
Elaborating design...
Done elaborating.
Done elaborating.
test PIPELINE Success: 0 tests passed in 202 cycles in 0.088404 seconds 2284.97 Hz
test DecoupledGcd Success: 0 tests passed in 841 cycles in 0.469659 seconds 1790.66 Hz
[info] TOPTest:
[info] - 5-Stage test
[info] GCDSpec:
[info] - Gcd should calculate proper greatest common denominator
[info] Run completed in 1 second, 828 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 2, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 7 s, completed Jan 2, 2025, 11:13:43 PM
The sbt version is pinned in project/build.properties:
sbt.version=1.9.1
Following the "Hello, World" page of the sbt Reference Manual, we can quickly set up a simple Hello World build for our own 5-Stage-RV32I:
$ sbt new sbt/scala-seed.g8
The tree should be
$ tree
.
├── build.sbt
├── project
│   ├── build.properties
│   └── Dependencies.scala
└── src
    ├── main
    │   └── scala
    │       └── example
    │           └── Hello.scala
    └── test
        └── scala
            └── example
                └── HelloSpec.scala
After that, we modify the build.sbt file to include Chisel as a dependency:
import Dependencies._
ThisBuild / scalaVersion := "2.12.8"
lazy val root = (project in file("."))
  .settings(
    name := "our_riscv_5_stage",
    libraryDependencies ++= Seq(
      "edu.berkeley.cs" %% "chisel3" % "3.4.3",
      "edu.berkeley.cs" %% "chiseltest" % "0.3.3" % Test,
      "org.scalatest" %% "scalatest" % "3.1.4" % Test
    ),
    scalacOptions ++= Seq(
      "-Xsource:2.11",
      "-language:reflectiveCalls",
      "-deprecation",
      "-feature",
      "-Xcheckinit",
      // Enables autoclonetype2 in 3.4.x (on by default in 3.5)
      "-P:chiselplugin:useBundlePlugin"
    ),
    addCompilerPlugin("edu.berkeley.cs" % "chisel3-plugin" % "3.4.3" cross CrossVersion.full),
    addCompilerPlugin("org.scalamacros" % "paradise" % "2.1.1" cross CrossVersion.full)
  )
To test our 5-stage pipelined processor, we use riscv-tests. By default, riscv-tests places programs at start address 0x80000000, but for convenience we change it to 0x00000000.
$ vim /opt/riscv/riscv-tests/env/p/link.ld
SECTIONS
{
  . = 0x80000000; /* change it to 0x00000000 */
  ...
}
Then build and install riscv-tests:
$ cd /opt/riscv/riscv-tests
$ autoconf
$ ./configure --prefix=/src/target
$ make
$ make install
The test programs are installed to /src/target/share/riscv-tests/isa. For instance, we can check one with:
$ file /src/target/share/riscv-tests/isa/rv32ui-p-add
/src/target/share/riscv-tests/isa/rv32ui-p-add: ELF 32-bit LSB executable, UCB RISC-V, soft-float ABI, version 1 (SYSV), statically linked, not stripped
We then convert the ELF file into a raw .bin file:
$ riscv64-unknown-elf-objcopy -O binary /src/target/share/riscv-tests/isa/rv32ui-p-add rv32ui-p-add.bin
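and then into a .hex file with one byte per line (using > rather than >> so that rerunning the command does not append a second copy):
$ od -An -tx1 -w1 -v rv32ui-p-add.bin > rv32ui-p-add.hex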
The resulting .hex file looks like this:
6f
00
00
05
73
2f
20
34
..
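If od is not available, the same conversion can be sketched in plain Scala. The object and method names below (BinToHex, toHexLines, convert) are hypothetical and not part of the repository; the sketch writes one two-digit lowercase hex byte per line, the same layout as the od output above without od's leading space.

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

// Hypothetical helper: turn a flat binary image into one-byte-per-line hex text
object BinToHex {
  // Format each byte as two lowercase hex digits, e.g. 0x6f -> "6f"
  def toHexLines(bytes: Array[Byte]): Seq[String] =
    bytes.toSeq.map(b => f"${b & 0xff}%02x")

  // Read binPath and (over)write hexPath with one hex byte per line
  def convert(binPath: String, hexPath: String): Unit = {
    val bytes = Files.readAllBytes(Paths.get(binPath))
    val text = toHexLines(bytes).mkString("", "\n", "\n")
    Files.write(Paths.get(hexPath), text.getBytes(StandardCharsets.US_ASCII))
  }
}
```

For example, BinToHex.convert("rv32ui-p-add.bin", "rv32ui-p-add.hex") would play the role of the od command above.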
Ref: riscv-chisel-book - Chapter 20
To streamline the process of managing and deploying the project, we use Docker to package everything. Follow the steps below to build and run the Docker container:
Run the following command in your project directory to create a Docker image:
docker build . -t riscv/our_riscv
This will build a Docker image named riscv/our_riscv.
Once the image is built, use the command below to start the Docker container and mount the current directory to /src inside the container:
docker run -it -v ./:/src riscv/our_riscv
Dockerfile:
FROM ubuntu:22.04
ENV RISCV=/opt/riscv
ENV PATH=$RISCV/bin:$PATH
ENV MAKEFLAGS=-j4
WORKDIR $RISCV
# Install dependencies
RUN apt update && \
    apt install -y autoconf automake autotools-dev curl libmpc-dev libmpfr-dev libgmp-dev gawk build-essential bison flex texinfo gperf libtool patchutils bc zlib1g-dev libexpat-dev pkg-config git libusb-1.0-0-dev device-tree-compiler default-jdk gnupg vim
# riscv-gnu-toolchain
RUN git clone --recursive --single-branch https://github.com/riscv-collab/riscv-gnu-toolchain
RUN cd riscv-gnu-toolchain && mkdir build && cd build && ../configure --prefix=${RISCV} --enable-multilib && make
# riscv-tests
RUN git clone -b master --single-branch https://github.com/riscv/riscv-tests && \
    cd riscv-tests && git submodule update --init --recursive
# sbt
RUN echo "deb https://repo.scala-sbt.org/scalasbt/debian all main" | tee -a /etc/apt/sources.list.d/sbt.list && \
    echo "deb https://repo.scala-sbt.org/scalasbt/debian /" | tee /etc/apt/sources.list.d/sbt_old.list && \
    curl -sL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x2EE0EA64E40A89B84B2DF73499E82A75642AC823" | apt-key add && \
    apt-get update && apt-get install -y sbt
The following image illustrates the five-stage pipelined datapath. The folder structure of the source code is as follows:
$ tree
.
├── Hazard Units
│   ├── BranchForward.scala
│   ├── Forwarding.scala
│   ├── HazardDetection.scala
│   └── StructuralHazard.scala
├── Main.scala
├── Memory
│   ├── DataMemory.scala
│   └── InstMem.scala
├── Pipelines
│   ├── EX_MEM.scala
│   ├── ID_EX.scala
│   ├── IF_ID.scala
│   └── MEM_WB.scala
├── test.txt
└── UNits
    ├── Alu_Control.scala
    ├── Alu.scala
    ├── BRANCH.scala
    ├── Control.scala
    ├── ImmGenerator.scala
    ├── JALR.scala
    ├── PC4.scala
    ├── PC.scala
    └── RegisterFile.scala
The following sections introduce the contents of each module.
Memory units, which are accessed during execution.
The Data Memory has five I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
MemAddress | addr | Specifies the memory address to read data from or write data to. |
MemWriteData | dataIn | The data to be written into the memory. |
MemREn | mem_read | Indicates whether a read operation is enabled. |
MemWEn | mem_write | Indicates whether a write operation is enabled. |
MemReadData | dataOut | Outputs the data read from the memory. |
val io = IO(new Bundle {
  val addr = Input(UInt(32.W)) // Address input
  val dataIn = Input(SInt(32.W)) // Data to be written
  val mem_read = Input(Bool()) // Memory read enable
  val mem_write = Input(Bool()) // Memory write enable
  val dataOut = Output(SInt(32.W)) // Data output
})
After declaring the I/O ports, we create the memory with
val Dmemory = Mem(1024, SInt(32.W))
Here 1024 is the number of memory locations, and SInt(32.W) makes each location a 32-bit signed integer.
The total capacity is therefore 32 bits * 1024 = 32768 bits = 4096 bytes (4 KB).
Chisel provides two kinds of read-write memory: SyncReadMem and Mem.
SyncReadMem is a synchronous-read, synchronous-write memory: the read address is sampled on a clock edge, so the read data is only valid in the following cycle. Mem is an asynchronous-read, synchronous-write memory: the read data is available combinationally, in the same cycle the address is provided. For our 5-stage RISC-V pipeline we choose Mem because it integrates more simply, although SyncReadMem is closer to how real hardware, such as FPGA block RAM, behaves.
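The timing difference between the two memory kinds can be illustrated with a cycle-level model in plain Scala. This is a sketch only: AsyncReadModel and SyncReadModel are hypothetical names, not Chisel classes, and each call to clockEdge stands for one rising clock edge. For a same-address read during a write, the sync model returns the old data, a choice made here for simplicity; Chisel leaves that case undefined for SyncReadMem.

```scala
// Mem-like: combinational read, write committed at the clock edge
final class AsyncReadModel(depth: Int) {
  private val cells = Array.fill(depth)(0)
  def read(addr: Int): Int = cells(addr) // visible in the same cycle
  def clockEdge(wen: Boolean, addr: Int, data: Int): Unit =
    if (wen) cells(addr) = data
}

// SyncReadMem-like: the read address is sampled at the clock edge,
// so the data appears on dataOut only after that edge
final class SyncReadModel(depth: Int) {
  private val cells = Array.fill(depth)(0)
  private var readData = 0
  def dataOut: Int = readData
  def clockEdge(raddr: Int, wen: Boolean, waddr: Int, wdata: Int): Unit = {
    readData = cells(raddr) // read sees the pre-edge contents (old data)
    if (wen) cells(waddr) = wdata
  }
}
```

In the sync model, data requested at one edge is only observable after it, which is the extra cycle a pipeline built on SyncReadMem would have to account for.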
Finally, we give the output a default value of 0 and perform the write or read depending on io.mem_write and io.mem_read:
when(io.mem_write) {
  Dmemory.write(io.addr, io.dataIn)
}
when(io.mem_read) {
  io.dataOut := Dmemory.read(io.addr)
}
The complete code of DataMemory.scala is:
package Pipeline
import chisel3._
import chisel3.util._
class DataMemory extends Module {
  val io = IO(new Bundle {
    val addr = Input(UInt(32.W)) // Address input
    val dataIn = Input(SInt(32.W)) // Data to be written
    val mem_read = Input(Bool()) // Memory read enable
    val mem_write = Input(Bool()) // Memory write enable
    val dataOut = Output(SInt(32.W)) // Data output
  })
  val Dmemory = Mem(1024, SInt(32.W))
  io.dataOut := 0.S
  when(io.mem_write) {
    Dmemory.write(io.addr, io.dataIn)
  }
  when(io.mem_read) {
    io.dataOut := Dmemory.read(io.addr)
  }
}
The Instruction Memory has two I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
PC address | addr | Byte address (PC) of the instruction to fetch. |
inst | data | The fetched instruction. |
Like the data memory, the original instruction memory is declared with Mem, with one 32-bit word per location:
val imem = Mem(1024, UInt(32.W))
We implement this module as a read-only memory, so it has only one input and one output.
val io = IO(new Bundle {
  val addr = Input(UInt(32.W)) // Address input to fetch instruction
  val data = Output(UInt(32.W)) // Output instruction
})
Note that we call the loadMemoryFromFile function to initialize the memory:
loadMemoryFromFile(imem, initFile)
This function initializes the memory from a text file (hexadecimal by default). Here, the module loads test.txt, whose path (./src/main/scala/Pipeline/test.txt) is passed in as initFile, into the InstMem module. Finally, we drive the output signal with:
io.data := imem(io.addr/4.U)
We divide io.addr by 4 because the PC is a byte address, while each memory location holds a 32-bit word.
2025/01/09
To let our module load the .hex files generated from riscv-tests (one byte per line), we modified it as follows:
val imem = Mem(4096, UInt(8.W))
loadMemoryFromFile(imem, initFile)
io.data := Cat(
  imem(io.addr+3.U(32.W)),
  imem(io.addr+2.U(32.W)),
  imem(io.addr+1.U(32.W)),
  imem(io.addr)
)
The memory (imem) is now 4096 x 8 bits, so each location holds a single byte and the total capacity stays the same (4 KB).
The module fetches four consecutive bytes and concatenates them, with the byte at the highest address as the most significant, to form a 32-bit little-endian instruction.
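The byte order of this Cat can be checked against the .hex listing above with a small plain-Scala model. LittleEndianFetch is a hypothetical name used only for illustration; it mirrors the Cat ordering, so the first four bytes 6f 00 00 05 assemble to 0x0500006f, whose low seven bits 1101111 are the JAL opcode.

```scala
// Hypothetical software model of InstMem's byte-addressed fetch
object LittleEndianFetch {
  // Assemble a 32-bit word from the bytes at addr..addr+3;
  // the byte at addr+3 becomes bits 31:24, as in the Cat above
  def fetch(mem: IndexedSeq[Int], addr: Int): Long =
    (((mem(addr + 3).toLong & 0xff) << 24) |
     ((mem(addr + 2).toLong & 0xff) << 16) |
     ((mem(addr + 1).toLong & 0xff) << 8) |
      (mem(addr).toLong & 0xff))
}
```

Fetching at address 4 from the same listing gives the next instruction word, 0x34202f73, since the PC advances by 4 bytes.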
Putting it all together:
package Pipeline
import chisel3._
import chisel3.util._
import chisel3.util.experimental.loadMemoryFromFile
class InstMem(initFile: String) extends Module {
  val io = IO(new Bundle {
    val addr = Input(UInt(32.W)) // Address input to fetch instruction
    val data = Output(UInt(32.W)) // Output instruction
  })
  // 4096 one-byte locations (4 KB), filled from a one-byte-per-line hex file
  val imem = Mem(4096, UInt(8.W))
  loadMemoryFromFile(imem, initFile)
  // Little-endian fetch: the byte at addr+3 becomes bits 31:24
  io.data := Cat(
    imem(io.addr+3.U(32.W)),
    imem(io.addr+2.U(32.W)),
    imem(io.addr+1.U(32.W)),
    imem(io.addr)
  )
  // For txt input
  // val imem = Mem(1024, UInt(32.W))
  // io.data := imem(io.addr/4.U)
}
Pipeline registers, which store the results of the previous stage.
The IF/ID register has two groups of I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
PC (I/O) | pc_in (pc_out), pc4_in (pc4_out), SelectedPC (SelectedPC_out) | PC of the instruction, PC+4, and the selected PC. |
inst (I/O) | SelectedInstr (SelectedInstr_out) | The selected instruction. |
The PC (I/O) and PC+4 (I/O) ports correspond to the I/O ports of the PC module, while the SelectedPC (I/O) and Selected Instruction (I/O) ports correspond to the I/O ports of the instruction module.
This module serves as a register that holds the program counter and the fetched instruction. We forward PC and PC+4 together so that later stages can use PC+4 without recomputing it.
Four registers are declared and initialized using RegInit, which sets their reset values:
val Pc_In = RegInit (0.S (32.W))
val Pc4_In = RegInit (0.U (32.W))
val S_pc = RegInit (0.S (32.W))
val S_instr = RegInit (0.U (32.W))
These reset values are applied during a system reset, ensuring the hardware starts in a known state.
During normal operation, the registers are updated with the input signals every clock cycle:
Pc_In := io.pc_in
Pc4_In := io.pc4_in
S_pc := io.SelectedPC
S_instr := io.SelectedInstr
io.pc_out := Pc_In
io.pc4_out := Pc4_In
io.SelectedPC_out := S_pc
io.SelectedInstr_out := S_instr
This ensures that the register values are synchronized with the input signals on each clock edge.
Putting it all together:
package Pipeline
import chisel3._
import chisel3.util._
class IF_ID extends Module {
  val io = IO(new Bundle {
    val pc_in = Input (SInt(32.W)) // PC in
    val pc4_in = Input (UInt(32.W)) // PC4 in
    val SelectedPC = Input (SInt(32.W))
    val SelectedInstr = Input (UInt(32.W))
    val pc_out = Output (SInt(32.W)) // PC out
    val pc4_out = Output (UInt(32.W)) // PC + 4 out
    val SelectedPC_out = Output (SInt(32.W))
    val SelectedInstr_out = Output (UInt(32.W))
  })
  val Pc_In = RegInit (0.S (32.W))
  val Pc4_In = RegInit (0.U (32.W))
  val S_pc = RegInit (0.S (32.W))
  val S_instr = RegInit (0.U (32.W))
  Pc_In := io.pc_in
  Pc4_In := io.pc4_in
  S_pc := io.SelectedPC
  S_instr := io.SelectedInstr
  io.pc_out := Pc_In
  io.pc4_out := Pc4_In
  io.SelectedPC_out := S_pc
  io.SelectedInstr_out := S_instr
}
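For the ID_EX, EX_MEM, and MEM_WB registers, we use RegNext instead: each output is its input delayed by one clock cycle, so the input signals represent the registers' next state and the output signals reflect their current state. Unlike RegInit, RegNext(x) without an init argument has no reset value.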
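The one-cycle delay that these registers introduce can be modelled in plain Scala. In this sketch, PipelineReg and PipelineDemo are hypothetical names: each register exposes its current state as out and latches a new value on clockEdge, and chaining two of them (like IF_ID followed by ID_EX) delays a value by two cycles.

```scala
// Hypothetical model of a pipeline register (RegInit(init) / RegNext(in))
final class PipelineReg[T](init: T) {
  private var state: T = init
  def out: T = state                               // current state: what io.*_out sees
  def clockEdge(next: T): Unit = { state = next }  // latch io.*_in at the edge
}

object PipelineDemo {
  // Feed `inputs` through `stages` chained registers (stages >= 1) and
  // record the last register's output after each clock edge
  def trace(inputs: Seq[Int], stages: Int): Seq[Int] = {
    val regs = Vector.fill(stages)(new PipelineReg[Int](0))
    inputs.map { in =>
      // Update from the last stage backwards so every register latches
      // the value its predecessor held before this edge
      for (i <- (stages - 1) to 1 by -1) regs(i).clockEdge(regs(i - 1).out)
      regs(0).clockEdge(in)
      regs(stages - 1).out
    }
  }
}
```

With two stages, the inputs 1, 2, 3, 4 come out as 0, 1, 2, 3: the first output is the reset value, and each input appears one edge later than it would with a single register.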
The ID/EX register has four groups of I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
PC (I/O) | IFID_pc4_in (IFID_pc4_out) | PC+4 value passed to the next stage. |
RegRead Data 1 (I/O) | rs1_data_in (rs1_data_out) | Data stored in the rs1 register. |
RegRead Data 2 (I/O) | rs2_data_in (rs2_data_out) | Data stored in the rs2 register. |
inst (I/O) | rs1_in (rs1_out), rs2_in (rs2_out), rd_in (rd_out), func3_in (func3_out), func7_in (func7_out) | Indices of the selected registers, and the decoded fields of the instruction. |
So we have:
package Pipeline
import chisel3._
import chisel3.util._
class ID_EX extends Module {
  val io = IO(new Bundle {
    val rs1_in = Input(UInt(5.W))
    val rs2_in = Input(UInt(5.W))
    val rs1_data_in = Input(SInt(32.W))
    val rs2_data_in = Input(SInt(32.W))
    val imm = Input(SInt(32.W))
    val rd_in = Input(UInt(5.W))
    val func3_in = Input(UInt(3.W))
    val func7_in = Input(Bool())
    val ctrl_MemWr_in = Input(Bool())
    val ctrl_Branch_in = Input(Bool())
    val ctrl_MemRd_in = Input(Bool())
    val ctrl_Reg_W_in = Input(Bool())
    val ctrl_MemToReg_in = Input(Bool())
    val ctrl_AluOp_in = Input(UInt(3.W))
    val ctrl_OpA_in = Input(UInt(2.W))
    val ctrl_OpB_in = Input(Bool())
    val ctrl_nextpc_in = Input(UInt(2.W))
    val IFID_pc4_in = Input(UInt(32.W))
    val rs1_out = Output(UInt(5.W))
    val rs2_out = Output(UInt(5.W))
    val rs1_data_out = Output(SInt(32.W))
    val rs2_data_out = Output(SInt(32.W))
    val rd_out = Output(UInt(5.W))
    val imm_out = Output(SInt(32.W))
    val func3_out = Output(UInt(3.W))
    val func7_out = Output(Bool())
    val ctrl_MemWr_out = Output(Bool())
    val ctrl_Branch_out = Output(Bool())
    val ctrl_MemRd_out = Output(Bool())
    val ctrl_Reg_W_out = Output(Bool())
    val ctrl_MemToReg_out = Output(Bool())
    val ctrl_AluOp_out = Output(UInt(3.W))
    val ctrl_OpA_out = Output(UInt(2.W))
    val ctrl_OpB_out = Output(Bool())
    val ctrl_nextpc_out = Output(UInt(2.W))
    val IFID_pc4_out = Output(UInt(32.W))
  })
  io.rs1_out := RegNext(io.rs1_in)
  io.rs2_out := RegNext(io.rs2_in)
  io.rs1_data_out := RegNext(io.rs1_data_in)
  io.rs2_data_out := RegNext(io.rs2_data_in)
  io.imm_out := RegNext(io.imm)
  io.rd_out := RegNext(io.rd_in)
  io.func3_out := RegNext(io.func3_in)
  io.func7_out := RegNext(io.func7_in)
  io.ctrl_MemWr_out := RegNext(io.ctrl_MemWr_in)
  io.ctrl_Branch_out := RegNext(io.ctrl_Branch_in)
  io.ctrl_MemRd_out := RegNext(io.ctrl_MemRd_in)
  io.ctrl_Reg_W_out := RegNext(io.ctrl_Reg_W_in)
  io.ctrl_MemToReg_out := RegNext(io.ctrl_MemToReg_in)
  io.ctrl_AluOp_out := RegNext(io.ctrl_AluOp_in)
  io.ctrl_OpA_out := RegNext(io.ctrl_OpA_in)
  io.ctrl_OpB_out := RegNext(io.ctrl_OpB_in)
  io.ctrl_nextpc_out := RegNext(io.ctrl_nextpc_in)
  io.IFID_pc4_out := RegNext(io.IFID_pc4_in)
}
The EX/MEM register has four groups of I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
PC (I/O) | None | Not passed through this register. |
ALU Out (I/O) | alu_out (EXMEM_alu_out) | Result computed by the ALU. |
RegRead Data 2 (I/O) | IDEX_rs2 (EXMEM_rs2_out) | Data read from register rs2 (the value to be stored). |
inst (I/O) | IDEX_rd (EXMEM_rd_out) | Index of the destination register rd. |
So we have:
package Pipeline
import chisel3._
import chisel3.util._
class EX_MEM extends Module {
  val io = IO(new Bundle {
    val IDEX_MEMRD = Input(Bool())
    val IDEX_MEMWR = Input(Bool())
    val IDEX_MEMTOREG = Input(Bool())
    val IDEX_REG_W = Input(Bool())
    val IDEX_rs2 = Input(SInt(32.W))
    val IDEX_rd = Input(UInt(5.W))
    val alu_out = Input(SInt(32.W))
    val EXMEM_memRd_out = Output(Bool())
    val EXMEM_memWr_out = Output(Bool())
    val EXMEM_memToReg_out = Output(Bool())
    val EXMEM_reg_w_out = Output(Bool())
    val EXMEM_rs2_out = Output(SInt(32.W))
    val EXMEM_rd_out = Output(UInt(5.W))
    val EXMEM_alu_out = Output(SInt(32.W))
  })
  io.EXMEM_memRd_out := RegNext(io.IDEX_MEMRD)
  io.EXMEM_memWr_out := RegNext(io.IDEX_MEMWR)
  io.EXMEM_memToReg_out := RegNext(io.IDEX_MEMTOREG)
  io.EXMEM_reg_w_out := RegNext(io.IDEX_REG_W)
  io.EXMEM_rs2_out := RegNext(io.IDEX_rs2)
  io.EXMEM_rd_out := RegNext(io.IDEX_rd)
  io.EXMEM_alu_out := RegNext(io.alu_out)
}
The MEM/WB register has two groups of I/O ports:
I/O Port | Variable Name | Description |
---|---|---|
Reg Write Data (I/O) | in_dataMem_out (MEMWB_dataMem_out), in_alu_out (MEMWB_alu_out) | Data read from memory and the ALU result; one of them is written back to the register file. |
inst (I/O) | EXMEM_rd (MEMWB_rd_out) | Index of the register that will store the data. |
So we have:
package Pipeline
import chisel3._
import chisel3.util._
class MEM_WB extends Module {
  val io = IO(new Bundle {
    val EXMEM_MEMTOREG = Input(Bool())
    val EXMEM_REG_W = Input(Bool())
    val EXMEM_MEMRD = Input(Bool())
    val EXMEM_rd = Input(UInt(5.W))
    val in_dataMem_out = Input(SInt(32.W))
    val in_alu_out = Input(SInt(32.W))
    val MEMWB_memToReg_out = Output(Bool())
    val MEMWB_reg_w_out = Output(Bool())
    val MEMWB_memRd_out = Output(Bool())
    val MEMWB_rd_out = Output(UInt(5.W))
    val MEMWB_dataMem_out = Output(SInt(32.W))
    val MEMWB_alu_out = Output(SInt(32.W))
  })
  io.MEMWB_memToReg_out := RegNext(io.EXMEM_MEMTOREG)
  io.MEMWB_reg_w_out := RegNext(io.EXMEM_REG_W)
  io.MEMWB_memRd_out := RegNext(io.EXMEM_MEMRD)
  io.MEMWB_rd_out := RegNext(io.EXMEM_rd)
  io.MEMWB_dataMem_out := RegNext(io.in_dataMem_out)
  io.MEMWB_alu_out := RegNext(io.in_alu_out)
}
To generate the control signal for the ALU, we design the ALUDecode unit. It takes funct3, funct7, and an instruction type code (the Type input, whose encoding matches the control unit's alu_operation output) as inputs; together these describe the instruction's function fields and operation type.
val io = IO(new Bundle {
  val Type = Input(UInt(3.W))
  val funct7 = Input(UInt(1.W)) // Only funct7[5] matters
  val funct3 = Input(UInt(3.W))
  val ALUSel = Output(UInt(4.W))
})
To determine which operation the ALU module needs to perform, we first identify the instruction type, then generate the ALUSel signal from funct3 (combined with funct7 for R-type instructions and I-type shifts). The ALUSel signal is generated as follows:
package Pipeline
import chisel3._
import chisel3.util._
class ALUDecode extends Module {
  val io = IO(new Bundle {
    val Type = Input(UInt(3.W))
    val funct7 = Input(UInt(1.W)) // Only funct7[5] matters
    val funct3 = Input(UInt(3.W))
    val ALUSel = Output(UInt(4.W))
  })
  // ALUSel table
  val ADD = 0.U(4.W)
  val SUB = 1.U(4.W)
  val AND = 2.U(4.W)
  val OR = 3.U(4.W)
  val XOR = 4.U(4.W)
  val SLL = 5.U(4.W)
  val SRL = 6.U(4.W)
  val SRA = 7.U(4.W)
  val X = 8.U(4.W) // X means don't care
  // Instruction types
  val Rtype = 0.U(3.W)
  val Itype = 1.U(3.W)
  val Ltype = 4.U(3.W)
  val Stype = 5.U(3.W)
  val ConditionalBranch = 2.U(3.W)
  val UnconditionalBranch = 3.U(3.W)
  val AUIPC = 7.U(3.W)
  val LUI = 6.U(3.W)
  val temp = Cat(io.funct7, io.funct3) // Combine funct7 and funct3
  // Select ALUSel by instruction type; unmatched types and keys give X
  io.ALUSel := MuxLookup(io.Type, X, Array(
    Stype -> ADD,
    Ltype -> ADD,
    AUIPC -> ADD,
    Rtype -> MuxLookup(temp, X, Array(
      "b0111".U -> AND,
      "b0110".U -> OR,
      "b1101".U -> SRA,
      "b0101".U -> SRL,
      "b0100".U -> XOR,
      "b0001".U -> SLL,
      "b1000".U -> SUB,
      "b0000".U -> ADD
    )),
    Itype -> MuxLookup(temp, X, Array(
      "b1101".U -> SRA,
      "b0101".U -> SRL,
      "b0001".U -> SLL,
      "b0111".U -> AND,
      "b0110".U -> OR,
      "b0100".U -> XOR,
      "b0000".U -> ADD
    )),
    ConditionalBranch -> X,
    UnconditionalBranch -> X,
    LUI -> X
  ))
}
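A plain-Scala reference model of this lookup is handy as a golden model when writing tests. The sketch below (AluDecodeModel is a hypothetical name) copies the ALUSel and type codes from the module and treats funct7 as the single bit instr[30]. It also makes one behaviour visible: for I-type arithmetic such as addi, instr[30] belongs to the immediate, so if funct7 is driven from instr[30] an immediate with that bit set (for example -1) gives the key 1000 and falls through to X. Whether that can happen depends on how Main.scala drives funct7, which is not shown here.

```scala
// Hypothetical software reference for ALUDecode (plain Scala, not Chisel)
object AluDecodeModel {
  // ALUSel codes, copied from the Chisel module
  val ADD = 0; val SUB = 1; val AND = 2; val OR = 3
  val XOR = 4; val SLL = 5; val SRL = 6; val SRA = 7; val X = 8
  // Instruction type codes (the Type input)
  val Rtype = 0; val Itype = 1; val Ltype = 4; val Stype = 5; val AUIPC = 7

  // funct7 is the single bit instr[30]; key is Cat(funct7, funct3)
  def aluSel(tpe: Int, funct7: Int, funct3: Int): Int = {
    val key = (funct7 << 3) | funct3
    val rTable = Map(0x7 -> AND, 0x6 -> OR, 0xD -> SRA, 0x5 -> SRL,
                     0x4 -> XOR, 0x1 -> SLL, 0x8 -> SUB, 0x0 -> ADD)
    val iTable = Map(0xD -> SRA, 0x5 -> SRL, 0x1 -> SLL, 0x7 -> AND,
                     0x6 -> OR, 0x4 -> XOR, 0x0 -> ADD)
    tpe match {
      case Stype | Ltype | AUIPC => ADD
      case Rtype => rTable.getOrElse(key, X)
      case Itype => iTable.getOrElse(key, X)
      case _ => X // branches, jumps, LUI
    }
  }
}
```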
With this decoder in place, the ALU module itself is simple.
I/O Port | Variable Name | Description |
---|---|---|
A | in_A | Input data A. |
B | in_B | Input data B. |
ALU Select | ALUSel | Selects which operation to compute. |
ALU Output | out | The result of the ALU. |
In general, an ALU performs arithmetic operations (such as addition, subtraction, multiplication, and division) and logical operations (such as AND, OR, NOT, and XOR). Ours supports addition, subtraction, AND, OR, XOR, and shifts, encoded in ALUSel as follows:
// ALUSel table
val ADD = 0.U(4.W)
val SUB = 1.U(4.W)
val AND = 2.U(4.W)
val OR = 3.U(4.W)
val XOR = 4.U(4.W)
val SLL = 5.U(4.W)
val SRL = 6.U(4.W)
val SRA = 7.U(4.W)
val X = 8.U(4.W) // X means don't care
The ALU then executes the operation selected by ALUSel:
switch(io.ALUSel) {
  is(ADD) { io.out := io.in_A + io.in_B }
  is(SUB) { io.out := io.in_A - io.in_B }
  is(AND) { io.out := io.in_A & io.in_B }
  is(OR) { io.out := io.in_A | io.in_B }
  is(XOR) { io.out := io.in_A ^ io.in_B }
  is(SLL) { io.out := io.in_A << io.in_B(4, 0) } // Limit shift amount to 5 bits
  is(SRL) { io.out := (io.in_A.asUInt >> io.in_B(4, 0)).asSInt } // Logical right shift: shift as UInt to zero-fill
  is(SRA) { io.out := io.in_A >> io.in_B(4, 0) } // Arithmetic right shift: >> on SInt sign-fills
  is(X) { io.out := 0.S }
}
Putting everything together:
package Pipeline
import chisel3._
import chisel3.util._
class ALU extends Module {
  val io = IO(new Bundle {
    val in_A = Input(SInt(32.W))
    val in_B = Input(SInt(32.W))
    val ALUSel = Input(UInt(4.W))
    val out = Output(SInt(32.W))
  })
  // ALUSel table
  val ADD = 0.U(4.W)
  val SUB = 1.U(4.W)
  val AND = 2.U(4.W)
  val OR = 3.U(4.W)
  val XOR = 4.U(4.W)
  val SLL = 5.U(4.W)
  val SRL = 6.U(4.W)
  val SRA = 7.U(4.W)
  val X = 8.U(4.W) // X means don't care
  // Default value for out
  io.out := 0.S
  switch(io.ALUSel) {
    is(ADD) { io.out := io.in_A + io.in_B }
    is(SUB) { io.out := io.in_A - io.in_B }
    is(AND) { io.out := io.in_A & io.in_B }
    is(OR) { io.out := io.in_A | io.in_B }
    is(XOR) { io.out := io.in_A ^ io.in_B }
    is(SLL) { io.out := io.in_A << io.in_B(4, 0) } // Limit shift amount to 5 bits
    is(SRL) { io.out := (io.in_A.asUInt >> io.in_B(4, 0)).asSInt } // Logical right shift: shift as UInt to zero-fill
    is(SRA) { io.out := io.in_A >> io.in_B(4, 0) } // Arithmetic right shift: >> on SInt sign-fills
    is(X) { io.out := 0.S }
  }
}
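A plain-Scala reference model of the ALU can serve as a golden model when testing the Chisel module. AluModel is a hypothetical name, not part of the repository. Scala's Int is 32-bit two's complement, so it wraps the same way as SInt(32.W). The two right shifts are the interesting case: in Chisel, >> on an SInt is an arithmetic shift, so a logical SRL has to shift the operand as a UInt; the model mirrors the distinction with >>> versus >>.

```scala
// Hypothetical 32-bit reference model of the ALU (plain Scala)
object AluModel {
  def compute(sel: Int, a: Int, b: Int): Int = {
    val shamt = b & 0x1f       // io.in_B(4, 0)
    sel match {
      case 0 => a + b          // ADD (wraps at 32 bits)
      case 1 => a - b          // SUB
      case 2 => a & b          // AND
      case 3 => a | b          // OR
      case 4 => a ^ b          // XOR
      case 5 => a << shamt     // SLL, result truncated to 32 bits
      case 6 => a >>> shamt    // SRL: logical, zero-fill
      case 7 => a >> shamt     // SRA: arithmetic, sign-fill
      case _ => 0              // X
    }
  }
}
```

For example, shifting -8 right by 1 gives -4 with SRA but 0x7ffffffc with SRL, which is a quick test vector for telling the two apart.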
A branch instruction is identified by its opcode, and funct3 indicates which branch it is:
instruction | funct3 (binary) | funct3 (decimal) |
---|---|---|
beq | 000 | 0 |
bge | 101 | 5 |
bgeu | 111 | 7 |
blt | 100 | 4 |
bltu | 110 | 6 |
bne | 001 | 1 |
Examining the I/O ports makes it clear how this module works:
class Branch extends Module {
  val io = IO(new Bundle {
    val fnct3 = Input(UInt(3.W))
    val branch = Input(Bool())
    val arg_x = Input(SInt(32.W))
    val arg_y = Input(SInt(32.W))
    val br_taken = Output(Bool())
  })
  // ... decision logic shown below
}
branch indicates whether the instruction is a B-type instruction, arg_x and arg_y are the values read from rs1 and rs2, fnct3 indicates which B-type instruction it is, and br_taken is the output signal computed from these inputs.
The branch decision logic:
when(io.branch) {
  // beq
  when(io.fnct3 === 0.U) {
    io.br_taken := io.arg_x === io.arg_y
  }
  // bne
  .elsewhen(io.fnct3 === 1.U) {
    io.br_taken := io.arg_x =/= io.arg_y
  }
  // blt
  .elsewhen(io.fnct3 === 4.U) {
    io.br_taken := io.arg_x < io.arg_y
  }
  // bge
  .elsewhen(io.fnct3 === 5.U) {
    io.br_taken := io.arg_x >= io.arg_y
  }
  // bltu (unsigned less than)
  .elsewhen(io.fnct3 === 6.U) {
    io.br_taken := io.arg_x.asUInt < io.arg_y.asUInt
  }
  // bgeu (unsigned greater than or equal)
  .elsewhen(io.fnct3 === 7.U) {
    io.br_taken := io.arg_x.asUInt >= io.arg_y.asUInt
  }
}
Putting them all together:
package Pipeline
import chisel3._
import chisel3.util._
class Branch extends Module {
  val io = IO(new Bundle {
    val fnct3 = Input(UInt(3.W))
    val branch = Input(Bool())
    val arg_x = Input(SInt(32.W))
    val arg_y = Input(SInt(32.W))
    val br_taken = Output(Bool())
  })
  io.br_taken := false.B
  when(io.branch) {
    // beq
    when(io.fnct3 === 0.U) {
      io.br_taken := io.arg_x === io.arg_y
    }
    // bne
    .elsewhen(io.fnct3 === 1.U) {
      io.br_taken := io.arg_x =/= io.arg_y
    }
    // blt
    .elsewhen(io.fnct3 === 4.U) {
      io.br_taken := io.arg_x < io.arg_y
    }
    // bge
    .elsewhen(io.fnct3 === 5.U) {
      io.br_taken := io.arg_x >= io.arg_y
    }
    // bltu (unsigned less than)
    .elsewhen(io.fnct3 === 6.U) {
      io.br_taken := io.arg_x.asUInt < io.arg_y.asUInt
    }
    // bgeu (unsigned greater than or equal)
    .elsewhen(io.fnct3 === 7.U) {
      io.br_taken := io.arg_x.asUInt >= io.arg_y.asUInt
    }
  }
}
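The signed versus unsigned comparisons are the part most worth testing. A plain-Scala reference model (BranchModel is a hypothetical name) can use Java's Integer.compareUnsigned for bltu and bgeu, mirroring the asUInt comparisons in the Chisel module.

```scala
// Hypothetical reference model of the Branch module (plain Scala)
object BranchModel {
  def taken(branch: Boolean, funct3: Int, x: Int, y: Int): Boolean =
    branch && (funct3 match {
      case 0 => x == y                               // beq
      case 1 => x != y                               // bne
      case 4 => x < y                                // blt (signed)
      case 5 => x >= y                               // bge (signed)
      case 6 => Integer.compareUnsigned(x, y) < 0    // bltu
      case 7 => Integer.compareUnsigned(x, y) >= 0   // bgeu
      case _ => false                                // funct3 2 and 3 are not branches
    })
}
```

For instance, with x = -1 and y = 1, blt is taken but bltu is not, because -1 reinterpreted as unsigned is 0xFFFFFFFF.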
The control signals are derived directly from the opcode. The table below shows how each opcode maps to an instruction type:
opcode | Instruction Type |
---|---|
011 0011 | R-type |
110 0011 | B-type |
001 0011 | I-type |
010 0011 | S-type |
000 0011 | L-type |
001 0111 | AUIPC |
011 0111 | LUI |
110 1111 | JAL |
110 0111 | JALR |
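The Control module matches on opcodes written as decimal literals (51, 19, 35, ...), while the table above lists them in binary. The small plain-Scala sketch below (OpcodeTable is a hypothetical name) converts each binary opcode to the decimal value used in the switch, which makes it easy to cross-check the two.

```scala
// Hypothetical helper: binary opcodes from the table -> decimal switch literals
object OpcodeTable {
  val types: Seq[(String, String)] = Seq(
    "0110011" -> "R-type", "1100011" -> "B-type", "0010011" -> "I-type",
    "0100011" -> "S-type", "0000011" -> "L-type", "0010111" -> "AUIPC",
    "0110111" -> "LUI", "1101111" -> "JAL", "1100111" -> "JALR")

  // Instruction type -> opcode as a decimal number
  val decimalByType: Map[String, Int] =
    types.map { case (bits, name) => name -> Integer.parseInt(bits, 2) }.toMap
}
```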
The control signals are then set for each instruction type. We set them with Chisel's switch syntax (the opcodes appear as decimal literals, e.g. 011 0011 = 51):
package Pipeline
import chisel3._
import chisel3.util._
class Control extends Module {
  val io = IO(new Bundle {
    val opcode = Input(UInt(7.W)) // 7-bit opcode
    val mem_write = Output(Bool()) // whether a write to memory
    val branch = Output(Bool()) // whether a branch instruction
    val mem_read = Output(Bool()) // whether a read from memory
    val reg_write = Output(Bool()) // whether a register write
    val men_to_reg = Output(Bool()) // whether the register write-back value comes from memory (load instructions)
    val alu_operation = Output(UInt(3.W))
    val operand_A = Output(UInt(2.W)) // Operand A source selection for the ALU
    val operand_B = Output(Bool()) // Operand B source selection for the ALU
    // Indicates the type of extension to be used (e.g., sign-extend, zero-extend)
    val extend = Output(UInt(2.W))
    val next_pc_sel = Output(UInt(2.W)) // next PC value (e.g., PC+4, branch target, jump target)
  })
  io.mem_write := 0.B
  io.branch := 0.B
  io.mem_read := 0.B
  io.reg_write := 0.B
  io.men_to_reg := 0.B
  io.alu_operation := 0.U
  io.operand_A := 0.U
  io.operand_B := 0.B
  io.extend := 0.U
  io.next_pc_sel := 0.U
  switch(io.opcode) {
    // R type instructions (e.g., add, sub)
    is(51.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 0.U
      io.operand_A := 0.U
      io.operand_B := 0.B
      io.extend := 0.U
      io.next_pc_sel := 0.U
    }
    // I type instructions (e.g., immediate operations)
    is(19.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 1.U
      io.operand_A := 0.U
      io.operand_B := 1.B
      io.extend := 0.U
      io.next_pc_sel := 0.U
    }
    // S type instructions (e.g., store operations)
    is(35.U) {
      io.mem_write := 1.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 0.B
      io.men_to_reg := 0.B
      io.alu_operation := 5.U
      io.operand_A := 0.U
      io.operand_B := 1.B
      io.extend := 1.U
      io.next_pc_sel := 0.U
    }
    // Load instructions (e.g., load data from memory)
    is(3.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 1.B
      io.reg_write := 1.B
      io.men_to_reg := 1.B
      io.alu_operation := 4.U
      io.operand_A := 0.U
      io.operand_B := 1.B
      io.extend := 0.U
      io.next_pc_sel := 0.U
    }
    // SB type instructions (e.g., conditional branch)
    is(99.U) {
      io.mem_write := 0.B
      io.branch := 1.B
      io.mem_read := 0.B
      io.reg_write := 0.B
      io.men_to_reg := 0.B
      io.alu_operation := 2.U
      io.operand_A := 0.U
      io.operand_B := 0.B
      io.extend := 0.U
      io.next_pc_sel := 1.U
    }
    // UJ type instructions (e.g., jump and link)
    is(111.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 3.U
      io.operand_A := 1.U
      io.operand_B := 0.B
      io.extend := 0.U
      io.next_pc_sel := 2.U
    }
    // Jalr instruction (e.g., jump and link register)
    is(103.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 3.U
      io.operand_A := 1.U
      io.operand_B := 0.B
      io.extend := 0.U
      io.next_pc_sel := 3.U
    }
    // U type (LUI) instructions (e.g., load upper immediate)
    is(55.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 6.U
      io.operand_A := 3.U
      io.operand_B := 1.B
      io.extend := 2.U
      io.next_pc_sel := 0.U
    }
    // U type (AUIPC) instructions (e.g., add immediate to PC)
    is(23.U) {
      io.mem_write := 0.B
      io.branch := 0.B
      io.mem_read := 0.B
      io.reg_write := 1.B
      io.men_to_reg := 0.B
      io.alu_operation := 7.U
      io.operand_A := 2.U
      io.operand_B := 1.B
      io.extend := 2.U
      io.next_pc_sel := 0.U
    }
  }
}
Each instruction format places the immediate bits differently, so the immediate has to be reassembled per format. This module outputs the immediate for every format from the same input instruction.
bits position | 31-25 | 24-20 | 19-15 | 14-12 | 11-7 | 6-0 |
---|---|---|---|---|---|---|
I | imm[11:5] | imm[4:0] | rs1 | funct3 | rd | opcode |
For I-type, we extract inst[31:20] and sign-extend it with inst[31] (the MSB).
io.I_type := Cat(Fill(20, io.instr(31)), io.instr(31, 20)).asSInt
bits position | 31-25 | 24-20 | 19-15 | 14-12 | 11-7 | 6-0 |
---|---|---|---|---|---|---|
S | imm[11:5] | rs2 | rs1 | funct3 | imm[4:0] | opcode |
For S-type, we concatenate inst[31:25] with inst[11:7] and sign-extend the result with inst[31].
io.S_type := Cat(Fill(20, io.instr(31)), io.instr(31, 25), io.instr(11, 7)).asSInt
bits position | 31-25 | 24-20 | 19-15 | 14-12 | 11-7 | 6-0 |
---|---|---|---|---|---|---|
B | imm[12\|10:5] | rs2 | rs1 | funct3 | imm[4:1\|11] | opcode |
For B-type, we concatenate inst[31], inst[7], inst[30:25], inst[11:8], and a 0 bit, and sign-extend the result with inst[31]. The module then adds the program counter, so SB_type is the branch target address rather than the raw offset.
val sbImm = Cat(Fill(19, io.instr(31)), io.instr(31), io.instr(7), io.instr(30, 25), io.instr(11, 8), 0.U(1.W)).asSInt
io.SB_type := sbImm + io.pc.asSInt
bits position | 31-25 | 24-20 | 19-15 | 14-12 | 11-7 | 6-0 |
---|---|---|---|---|---|---|
U | imm[31:25] | imm[24:20] | imm[19:15] | imm[14:12] | rd | opcode |
For U-type, we extract inst[31:12] and fill the lower 12 bits with 0.
io.U_type := Cat(io.instr(31, 12), Fill(12, 0.U)).asSInt
bits position | 31-25 | 24-20 | 19-15 | 14-12 | 11-7 | 6-0 |
---|---|---|---|---|---|---|
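UJ | imm[20\|10:5] | imm[4:1\|11] | imm[19:15] | imm[14:12] | rd | opcode |
For UJ-type, we concatenate inst[31], inst[19:12], inst[20], inst[30:21], and a 0 bit, and sign-extend the result with inst[31]. As with branches, the module then adds the program counter, so UJ_type is the jump target address.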
val ujImm = Cat(Fill(11, io.instr(31)), io.instr(31), io.instr(19, 12), io.instr(20), io.instr(30, 21), 0.U(1.W)).asSInt
io.UJ_type := ujImm + io.pc.asSInt
The complete code is shown below.
package Pipeline
import chisel3._
import chisel3.util._
class ImmGenerator extends Module {
  val io = IO(new Bundle {
    val instr = Input(UInt(32.W))
    val pc = Input(UInt(32.W))
    val I_type = Output(SInt(32.W))
    val S_type = Output(SInt(32.W))
    val SB_type = Output(SInt(32.W))
    val U_type = Output(SInt(32.W))
    val UJ_type = Output(SInt(32.W))
  })
  // I-Type Immediate: [31:20] sign-extended to 32 bits
  io.I_type := Cat(Fill(20, io.instr(31)), io.instr(31, 20)).asSInt
  // S-Type Immediate: [31:25][11:7] sign-extended to 32 bits
  io.S_type := Cat(Fill(20, io.instr(31)), io.instr(31, 25), io.instr(11, 7)).asSInt
  // Branch-Type Immediate: [31][7][30:25][11:8] sign-extended to 32 bits; output is PC + offset
  val sbImm = Cat(Fill(19, io.instr(31)), io.instr(31), io.instr(7), io.instr(30, 25), io.instr(11, 8), 0.U(1.W)).asSInt
  io.SB_type := sbImm + io.pc.asSInt
  // U-Type Immediate: [31:12] shifted left by 12 bits
  io.U_type := Cat(io.instr(31, 12), Fill(12, 0.U)).asSInt
  // UJ-Type Immediate: [31][19:12][20][30:21] sign-extended to 32 bits, shifted left by 1 bit; output is PC + offset
  val ujImm = Cat(Fill(11, io.instr(31)), io.instr(31), io.instr(19, 12), io.instr(20), io.instr(30, 21), 0.U(1.W)).asSInt
  io.UJ_type := ujImm + io.pc.asSInt
}
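The bit shuffling is easy to get wrong, so a plain-Scala reference model with hand-checked encodings is useful for testing. The sketch below (ImmModel is a hypothetical name) returns the raw sign-extended offsets; unlike the Chisel module, it does not add the PC, so the branch and jump targets would be pc + bImm(instr) and pc + jImm(instr).

```scala
// Hypothetical reference model of the immediate decoding (plain Scala)
object ImmModel {
  // Sign-extend the low `bits` bits of v to a 32-bit Int
  private def sext(v: Int, bits: Int): Int = (v << (32 - bits)) >> (32 - bits)
  // Extract instr[hi:lo] as an unsigned value (field width below 32 bits)
  private def field(i: Int, hi: Int, lo: Int): Int = (i >>> lo) & ((1 << (hi - lo + 1)) - 1)

  def iImm(i: Int): Int = i >> 20 // arithmetic shift sign-extends inst[31]
  def sImm(i: Int): Int = sext((field(i, 31, 25) << 5) | field(i, 11, 7), 12)
  def bImm(i: Int): Int = sext(
    (field(i, 31, 31) << 12) | (field(i, 7, 7) << 11) |
    (field(i, 30, 25) << 5) | (field(i, 11, 8) << 1), 13)
  def uImm(i: Int): Int = i & 0xfffff000
  def jImm(i: Int): Int = sext(
    (field(i, 31, 31) << 20) | (field(i, 19, 12) << 12) |
    (field(i, 20, 20) << 11) | (field(i, 30, 21) << 1), 21)
}
```

As a check against the .hex listing earlier, the first word 0x0500006f is a jal x0 whose offset decodes to 80.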
This module does not appear in the five-stage pipelined datapath image, but implementing it as a separate component makes the circuit easier to understand.
The module implements the JALR instruction by adding the generated immediate to the value read from rs1 to compute the destination address:
val computedAddr = io.imme + io.rdata1
Then it clears the least significant bit with a bitwise AND, as the RISC-V specification requires for JALR targets:
io.out := computedAddr & "hFFFFFFFE".U
The complete code is shown below.
package Pipeline
import chisel3._
import chisel3.util._
class Jalr extends Module {
val io = IO(new Bundle {
val imme = Input(UInt(32.W))
val rdata1 = Input(UInt(32.W))
val out = Output(UInt(32.W))
})
val computedAddr = io.imme + io.rdata1
// Align the address by masking the least significant bit (LSB) to 0
io.out := computedAddr & "hFFFFFFFE".U
}
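The same computation can be written as a plain Scala function, which is handy for predicting what Jalr should output in a test. JalrModel is a hypothetical helper, not part of the repository. The mask "hFFFFFFFE".U is simply ~1 on 32 bits. Chisel's `+` on two 32-bit UInts keeps the width at 32 bits. A negative immediate passed as `ImmValue.asUInt` therefore wraps around in the same way as Int addition does here.

```scala
// Hypothetical plain-Scala helper mirroring the Jalr module (not Chisel).
object JalrModel {
  // (rs1 + imm) with bit 0 cleared; "hFFFFFFFE".U equals ~1 on 32 bits
  def target(rdata1: Int, imme: Int): Int = (rdata1 + imme) & ~1
}
```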
The PC and PC4 modules have a similar interface (one input and one output), but they behave differently. PC is a register that outputs the current program counter and loads io.in at every clock edge. PC4 is purely combinational and outputs io.pc + 4.
PC:
package Pipeline
import chisel3._
import chisel3.util._
class PC extends Module {
val io = IO (new Bundle {
val in = Input(SInt(32.W))
val out = Output(SInt(32.W))
})
val PC = RegInit(0.S(32.W))
io.out := PC
PC := io.in
}
PC4:
package Pipeline
import chisel3._
import chisel3.util._
class PC4 extends Module {
val io = IO (new Bundle {
val pc = Input(UInt(32.W))
val out = Output(UInt(32.W))
})
io.out := 0.U
io.out := io.pc + 4.U(32.W)
}
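The following is a minimal cycle-level sketch of how PC and PC4 interact when nothing redirects the fetch. PC holds its value for one cycle, PC4 adds 4 combinationally, and the sum is loaded into PC at the next clock edge. PcModel is a hypothetical name, and branches and stalls are deliberately ignored.

```scala
// Hypothetical cycle-level model of PC + PC4 with no branches or stalls.
object PcModel {
  // PC values seen on PC.io.out during the first `cycles` cycles
  def sequentialPcs(cycles: Int): List[Int] = {
    var pc = 0                               // RegInit(0.S(32.W))
    val trace = List.newBuilder[Int]
    for (_ <- 0 until cycles) {
      trace += pc                            // io.out := PC during this cycle
      val pc4 = pc + 4                       // PC4: io.out := io.pc + 4.U
      pc = pc4                               // PC := io.in at the clock edge
    }
    trace.result()
  }
}
```

For example, `PcModel.sequentialPcs(4)` yields `List(0, 4, 8, 12)`.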
For the register file, we use RegInit together with VecInit to create 32 registers, each initialized to 0.
val regfile = RegInit(VecInit(Seq.fill(32)(0.S(32.W))))
The RegisterFile module accepts two register addresses for reading and one address for writing.
When reading a register, if the address is 0, the output is always 0. Otherwise, it outputs the data stored at the specified address.
io.rdata1 := Mux(io.rs1 === 0.U, 0.S, regfile(io.rs1))
io.rdata2 := Mux(io.rs2 === 0.U, 0.S, regfile(io.rs2))
When writing to a register, the module first checks the write enable signal (reg_write) and ensures that the target address is not 0. If both conditions are met, the data is written to the specified register.
when(io.reg_write && io.w_reg =/= 0.U) {
regfile(io.w_reg) := io.w_data
}
The complete code is shown below.
package Pipeline
import chisel3._
import chisel3.util._
class RegisterFile extends Module {
val io = IO(new Bundle {
val rs1 = Input(UInt(5.W))
val rs2 = Input(UInt(5.W))
val reg_write = Input(Bool())
val w_reg = Input(UInt(5.W))
val w_data = Input(SInt(32.W))
val rdata1 = Output(SInt(32.W))
val rdata2 = Output(SInt(32.W))
})
val regfile = RegInit(VecInit(Seq.fill(32)(0.S(32.W))))
io.rdata1 := Mux(io.rs1 === 0.U, 0.S, regfile(io.rs1))
io.rdata2 := Mux(io.rs2 === 0.U, 0.S, regfile(io.rs2))
when(io.reg_write && io.w_reg =/= 0.U) {
regfile(io.w_reg) := io.w_data
}
}
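The sketch below shows the timing this implies. RegFileModel is a hypothetical name. The model treats reads as combinational and writes as taking effect only at the clock edge. A read in the same cycle as a write to the same register therefore still sees the old value. This appears to be the case that the StructuralHazard forwarding described later is meant to cover.

```scala
// Hypothetical cycle-level model of RegisterFile (plain Scala, not Chisel).
object RegFileModel {
  // 32 registers, all reset to 0 (RegInit(VecInit(Seq.fill(32)(0.S(32.W)))))
  val init: Vector[Int] = Vector.fill(32)(0)

  // Combinational read port: x0 always reads as 0
  def read(regs: Vector[Int], addr: Int): Int =
    if (addr == 0) 0 else regs(addr)

  // One clock edge: returns the register contents after the write (if any)
  def clock(regs: Vector[Int], regWrite: Boolean, wReg: Int, wData: Int): Vector[Int] =
    if (regWrite && wReg != 0) regs.updated(wReg, wData) else regs
}
```

Before the edge, `RegFileModel.read(RegFileModel.init, 5)` still returns 0. After `clock(init, true, 5, 42)`, it returns 42. A write to x0 is ignored.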
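The HazardDetection module is combinational logic that decides whether forwarding is needed, that is, whether the instruction, PC, and control signals must be held back because of a load-use hazard. IF_ID_inst, ID_EX_memRead, and ID_EX_rd are inputs from the pipeline registers. pc_in and current_pc are only passed through and take no part in the decision.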
val io = IO(new Bundle {
val IF_ID_inst = Input(UInt(32.W))
val ID_EX_memRead = Input(Bool())
val ID_EX_rd = Input(UInt(5.W))
val pc_in = Input(SInt(32.W))
val current_pc = Input(SInt(32.W))
val inst_forward = Output(Bool())
val pc_forward = Output(Bool())
val ctrl_forward = Output(Bool())
val inst_out = Output(UInt(32.W))
val pc_out = Output(SInt(32.W))
val current_pc_out = Output(SInt(32.W))
})
If a load (L-type) instruction (i.e., io.ID_EX_memRead === 1.B) is followed by an instruction that reads the loaded register, such as an R-type, S-type, or B-type instruction whose register fields overlap (i.e., (io.ID_EX_rd === Rs1) || (io.ID_EX_rd === Rs2)), we forward the instruction, the PC, and certain control signals. For example:
lw t0, 0(t1) // instruction 1
add s0, t0, s1 // instruction 2
In this scenario, we set inst_forward, pc_forward, and ctrl_forward to true. Otherwise, they are false.
when(io.ID_EX_memRead === 1.B && ((io.ID_EX_rd === Rs1) || (io.ID_EX_rd === Rs2))) {
io.inst_forward := true.B
io.pc_forward := true.B
io.ctrl_forward := true.B
}.otherwise {
io.inst_forward := false.B
io.pc_forward := false.B
io.ctrl_forward := false.B
}
Finally, the instruction and PC values are passed through the module unchanged:
io.inst_out := io.IF_ID_inst
io.pc_out := io.pc_in
io.current_pc_out := io.current_pc
The implementation is shown below:
package Pipeline
import chisel3._
import chisel3.util._
class HazardDetection extends Module {
val io = IO(new Bundle {
val IF_ID_inst = Input(UInt(32.W))
val ID_EX_memRead = Input(Bool())
val ID_EX_rd = Input(UInt(5.W))
val pc_in = Input(SInt(32.W))
val current_pc = Input(SInt(32.W))
val inst_forward = Output(Bool())
val pc_forward = Output(Bool())
val ctrl_forward = Output(Bool())
val inst_out = Output(UInt(32.W))
val pc_out = Output(SInt(32.W))
val current_pc_out = Output(SInt(32.W))
})
val Rs1 = io.IF_ID_inst(19, 15)
val Rs2 = io.IF_ID_inst(24, 20)
when(io.ID_EX_memRead === 1.B && ((io.ID_EX_rd === Rs1) || (io.ID_EX_rd === Rs2))) {
io.inst_forward := true.B
io.pc_forward := true.B
io.ctrl_forward := true.B
}.otherwise {
io.inst_forward := false.B
io.pc_forward := false.B
io.ctrl_forward := false.B
}
io.inst_out := io.IF_ID_inst
io.pc_out := io.pc_in
io.current_pc_out := io.current_pc
}
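The stall condition can be written as a plain Scala predicate on the raw instruction bits. LoadUseModel is a hypothetical name. The predicate mirrors the condition in HazardDetection, including the fact that it does not check whether ID_EX_rd is x0. For example, `add s0, t0, s1` encodes as 0x00928433, with rs1 = 5 (t0) and rs2 = 9 (s1).

```scala
// Hypothetical plain-Scala version of the HazardDetection stall condition.
object LoadUseModel {
  private def rs1(inst: Int): Int = (inst >>> 15) & 0x1F // inst[19:15]
  private def rs2(inst: Int): Int = (inst >>> 20) & 0x1F // inst[24:20]

  // true when inst_forward, pc_forward and ctrl_forward would all be asserted
  def stall(ifIdInst: Int, idExMemRead: Boolean, idExRd: Int): Boolean =
    idExMemRead && (idExRd == rs1(ifIdInst) || idExRd == rs2(ifIdInst))
}
```

With `lw t0, 0(t1)` in ID/EX, `LoadUseModel.stall(0x00928433, true, 5)` is true. If the instruction in ID/EX is not a load, no stall is requested.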
The BranchForward module handles data hazards involving B-type (and JALR) instructions. We first explain the meaning of the IO ports, then the logic of the code.
Name | I/O | Meaning |
---|---|---|
ID_EX_RD | Input | Index of the destination register of the instruction in the ID/EX stage. |
EX_MEM_RD | Input | Index of the destination register of the instruction in the EX/MEM stage. |
MEM_WB_RD | Input | Index of the destination register of the instruction in the MEM/WB stage. |
ID_EX_memRd | Input | Whether the instruction in the ID/EX stage is a load (L-type). |
EX_MEM_memRd | Input | Whether the instruction in the EX/MEM stage is a load (L-type). |
MEM_WB_memRd | Input | Whether the instruction in the MEM/WB stage is a load (L-type). |
rs1 | Input | rs1 register of the B-type (or JALR) instruction. |
rs2 | Input | rs2 register of the B-type instruction. |
ctrl_branch | Input | Whether the instruction is B-type (1); otherwise the JALR forwarding logic is used. |
forward_rs1 | Output | Where the rs1 value comes from. |
forward_rs2 | Output | Where the rs2 value comes from. |
This module handles the following cases.
B-type, ALU hazard
add x3, x1, x2
beq x3, x4, label
Here beq needs x3 while add is still in the EX stage. The circuit handling this situation:
when(io.ID_EX_RD =/= 0.U && io.ID_EX_memRd =/= 1.U) {
when(io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2) {
io.forward_rs1 := "b0001".U
io.forward_rs2 := "b0001".U
}.elsewhen(io.ID_EX_RD === io.rs1) {
io.forward_rs1 := "b0001".U
}.elsewhen(io.ID_EX_RD === io.rs2) {
io.forward_rs2 := "b0001".U
}
}
Here, io.forward_rs1 := "b0001".U indicates that the data is forwarded directly from the ALU output, i.e., the result of the instruction currently in the EX stage (held in the ID/EX register).
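B-type, EX/MEM hazard
add x3, x1, x2
addi x5, x6, 1 // unrelated instruction
beq x3, x4, label
Here add is two instructions ahead of beq and sits in the EX/MEM register. Loads are excluded from this case (io.EX_MEM_memRd =/= 1.U). The circuit handling this situation: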
// EX/MEM Hazard
when(io.EX_MEM_RD =/= 0.U && io.EX_MEM_memRd =/= 1.U) {
when(io.EX_MEM_RD === io.rs1 && io.EX_MEM_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2)) {
io.forward_rs1 := "b0010".U
io.forward_rs2 := "b0010".U
}.elsewhen(io.EX_MEM_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1)) {
io.forward_rs1 := "b0010".U
}.elsewhen(io.EX_MEM_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs2)) {
io.forward_rs2 := "b0010".U
}
}
Here, io.forward_rs1 := "b0010".U means the data is forwarded from the EX/MEM pipeline register (the ALU result of the instruction in the MEM stage).
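B-type, MEM/WB hazard
addi x1, x2, 42
add x5, x6, x7 // unrelated instruction
sub x8, x6, x7 // unrelated instruction
beq x1, x4, label
Here addi is three instructions ahead of beq and sits in the MEM/WB register. The code accordingly: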
// MEM/WB Hazard
when(io.MEM_WB_RD =/= 0.U && io.MEM_WB_memRd =/= 1.U) {
when(io.MEM_WB_RD === io.rs1 && io.MEM_WB_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1 && io.EX_MEM_RD === io.rs2)) {
io.forward_rs1 := "b0011".U
io.forward_rs2 := "b0011".U
}.elsewhen(io.MEM_WB_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1)) {
io.forward_rs1 := "b0011".U
}.elsewhen(io.MEM_WB_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs2) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs2)) {
io.forward_rs2 := "b0011".U
}
}
Here, io.forward_rs1 := "b0011".U means the data is forwarded from the WB stage (the value being written back to the register file, RegFile.io.w_data).
JALR hazard
lw x1, 0(x2)
jalr x2, x1, 42 // target address depends on the previous instruction
Different types of instructions in different stages (ALU instructions as well as loads) can cause this hazard, so forward_rs1 is set to a different value for each case.
The following table shows what each value of forward_rs1 means:
Value of forward_rs1 | Type | Where data are forwarded from |
---|---|---|
0000 | No hazard | Register file |
0001 | ALU hazard | ID/EX (ALU output) |
0010 | EX/MEM hazard | EX/MEM |
0011 | MEM/WB hazard | MEM/WB |
0110 | JALR (ALU result) | ID/EX (ALU output) |
0111 | JALR (ALU result) | EX/MEM |
1001 | JALR (load) | EX/MEM (data memory output) |
1000 | JALR (ALU result) | MEM/WB |
1010 | JALR (load) | MEM/WB |
The full implementation:
package Pipeline
import chisel3._
import chisel3.util._
class BranchForward extends Module {
val io = IO(new Bundle {
val ID_EX_RD = Input(UInt(5.W))
val EX_MEM_RD = Input(UInt(5.W))
val MEM_WB_RD = Input(UInt(5.W))
val ID_EX_memRd = Input(UInt(1.W))
val EX_MEM_memRd = Input(UInt(1.W))
val MEM_WB_memRd = Input(UInt(1.W))
val rs1 = Input(UInt(5.W))
val rs2 = Input(UInt(5.W))
val ctrl_branch = Input(UInt(1.W))
val forward_rs1 = Output(UInt(4.W))
val forward_rs2 = Output(UInt(4.W))
})
io.forward_rs1 := "b0000".U
io.forward_rs2 := "b0000".U
// Branch forwarding logic
when(io.ctrl_branch === 1.U) {
// ALU Hazard
when(io.ID_EX_RD =/= 0.U && io.ID_EX_memRd =/= 1.U) {
when(io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2) {
io.forward_rs1 := "b0001".U
io.forward_rs2 := "b0001".U
}.elsewhen(io.ID_EX_RD === io.rs1) {
io.forward_rs1 := "b0001".U
}.elsewhen(io.ID_EX_RD === io.rs2) {
io.forward_rs2 := "b0001".U
}
}
// EX/MEM Hazard
when(io.EX_MEM_RD =/= 0.U && io.EX_MEM_memRd =/= 1.U) {
when(io.EX_MEM_RD === io.rs1 && io.EX_MEM_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2)) {
io.forward_rs1 := "b0010".U
io.forward_rs2 := "b0010".U
}.elsewhen(io.EX_MEM_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1)) {
io.forward_rs1 := "b0010".U
}.elsewhen(io.EX_MEM_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs2)) {
io.forward_rs2 := "b0010".U
}
}
// MEM/WB Hazard
when(io.MEM_WB_RD =/= 0.U && io.MEM_WB_memRd =/= 1.U) {
when(io.MEM_WB_RD === io.rs1 && io.MEM_WB_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1 && io.ID_EX_RD === io.rs2) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1 && io.EX_MEM_RD === io.rs2)) {
io.forward_rs1 := "b0011".U
io.forward_rs2 := "b0011".U
}.elsewhen(io.MEM_WB_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1)) {
io.forward_rs1 := "b0011".U
}.elsewhen(io.MEM_WB_RD === io.rs2 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs2) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs2)) {
io.forward_rs2 := "b0011".U
}
}
// Jalr forwarding logic
}.elsewhen(io.ctrl_branch === 0.U) {
when(io.ID_EX_RD =/= 0.U && io.ID_EX_memRd =/= 1.U && io.ID_EX_RD === io.rs1) {
io.forward_rs1 := "b0110".U
}.elsewhen(io.EX_MEM_RD =/= 0.U && io.EX_MEM_memRd =/= 1.U && io.EX_MEM_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1)) {
io.forward_rs1 := "b0111".U
}.elsewhen(io.EX_MEM_RD =/= 0.U && io.EX_MEM_memRd === 1.U && io.EX_MEM_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1)) {
io.forward_rs1 := "b1001".U
}.elsewhen(io.MEM_WB_RD =/= 0.U && io.MEM_WB_memRd =/= 1.U && io.MEM_WB_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1)) {
io.forward_rs1 := "b1000".U
}.elsewhen(io.MEM_WB_RD =/= 0.U && io.MEM_WB_memRd === 1.U && io.MEM_WB_RD === io.rs1 && !(io.ID_EX_RD =/= 0.U && io.ID_EX_RD === io.rs1) && !(io.EX_MEM_RD =/= 0.U && io.EX_MEM_RD === io.rs1)) {
io.forward_rs1 := "b1010".U
}
}
}
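The sketch below reduces this logic to one operand at a time to show the priority it encodes. BranchForwardModel and Stage are hypothetical names. The nearest stage whose destination register matches decides the code. On the B-type path, a matching load yields 0. On the JALR path, loads (9, 10) are distinguished from ALU results (7, 8). The Chisel code evaluates rs1 and rs2 together on its B-type path, so corner cases where the two operands match different stages may not line up exactly with this simplification.

```scala
// Hypothetical per-operand model of BranchForward (plain Scala, not Chisel).
object BranchForwardModel {
  // Destination register of the instruction in one pipeline register,
  // and whether that instruction is a load (memRd = 1)
  final case class Stage(rd: Int, memRead: Boolean)
  val Idle: Stage = Stage(0, memRead = false)

  private def matches(s: Stage, rs: Int): Boolean = s.rd != 0 && s.rd == rs

  // B-type path (ctrl_branch = 1): the nearest matching stage decides;
  // loads are not forwarded on this path, so a matching load gives 0
  def branchCode(rs: Int, idEx: Stage, exMem: Stage, memWb: Stage): Int =
    if (matches(idEx, rs)) { if (idEx.memRead) 0 else 1 }
    else if (matches(exMem, rs)) { if (exMem.memRead) 0 else 2 }
    else if (matches(memWb, rs)) { if (memWb.memRead) 0 else 3 }
    else 0

  // JALR path (ctrl_branch = 0): loads in EX/MEM or MEM/WB get their own codes
  def jalrCode(rs1: Int, idEx: Stage, exMem: Stage, memWb: Stage): Int =
    if (matches(idEx, rs1)) { if (idEx.memRead) 0 else 6 }
    else if (matches(exMem, rs1)) { if (exMem.memRead) 9 else 7 }
    else if (matches(memWb, rs1)) { if (memWb.memRead) 10 else 8 }
    else 0
}
```

For the `lw x1, 0(x2)` / `jalr` example, the load is in EX/MEM after the one-cycle stall. `jalrCode(1, Idle, Stage(1, true), Idle)` then gives 9, which selects the data memory output.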
The Forwarding module (Forwarding.scala) handles data hazards for the ALU operands in the EX stage.
It handles two types of hazard, which we explain with examples.
The EX hazard is shown below:
add x3, x1, x2
sub x4, x3, x5
For the first instruction, add, the result needs to be stored in x3. However, x3 is also required as an input by the sub instruction on the next line. This situation creates an EX hazard.
The situation is detected by
io.EXMEM_regWr === "b1".U &&
io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1.asUInt)&&
(io.EXMEM_rd === io.IDEX_rs2)
io.EXMEM_regWr === "b1".U detects whether the write-back signal is active, which in this case means that the result of add will be written back to x3.
io.EXMEM_rd =/= "b00000".U ensures that the target register is valid (not x0, which always holds the value 0 in the RISC-V architecture).
(io.EXMEM_rd === io.IDEX_rs1) and (io.EXMEM_rd === io.IDEX_rs2) compare the destination register x3 (in add) with the source registers x3 and x5 (in sub).
When the hazard is detected, the unit outputs the signal b10, allowing the processor to forward the result for x3 directly from the EX/MEM stage to the next instruction before it is written back to the register file.
The complete code for handling the EX hazard is:
// EX HAZARD
when(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1.asUInt) && (io.EXMEM_rd === io.IDEX_rs2)) {
io.forward_a := "b10".U
io.forward_b := "b10".U
}.elsewhen(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs2)) {
io.forward_b := "b10".U
}.elsewhen(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1)) {
io.forward_a := "b10".U
}
The MEM hazard is shown below:
lw x3, 0(x1)
sub x4, x3, x5
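For the first instruction, lw, the result needs to be stored in x3. However, the next instruction, sub, also requires the data in x3. After the one-cycle load-use stall inserted by HazardDetection, lw is in the MEM/WB register while sub is in EX, so this situation creates a MEM hazard.
The situation is detected by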
(io.MEMWB_regWr === "b1".U) &&
(io.MEMWB_rd =/= "b00000".U) &&
(io.MEMWB_rd === io.IDEX_rs1) &&
(io.MEMWB_rd === io.IDEX_rs2) &&
~(io.EXMEM_regWr === "b1".U &&
io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1) &&
(io.EXMEM_rd === io.IDEX_rs2))
(io.MEMWB_regWr === "b1".U) detects whether the write-back signal is active, which in this case means that the result of lw will be written back to x3.
io.MEMWB_rd =/= "b00000".U ensures that the target register is valid (not x0, which always holds the value 0 in the RISC-V architecture).
(io.MEMWB_rd === io.IDEX_rs1) && (io.MEMWB_rd === io.IDEX_rs2) compares the destination register with the source registers, which in our case means comparing x3 (in lw) with x3 and x5 (in sub).
The last term excludes the EX hazard, so that a newer result in EX/MEM takes priority:
~(io.EXMEM_regWr === "b1".U &&
io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1) &&
(io.EXMEM_rd === io.IDEX_rs2))
This logic ensures that the hazard is detected and sends the signal b01, allowing the processor to forward the result of x3 directly from the MEM/WB stage to the next instruction, even before it is written back to the register file.
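For comparison, the conventional per-operand form of this forwarding rule is sketched below in plain Scala. ForwardingModel is a hypothetical name. The newer EX/MEM result takes priority over MEM/WB, and x0 is never forwarded. The codes follow forward_a/forward_b (2 = b10, 1 = b01, 0 = b00). The Chisel code above spells out the case where both operands match separately, so this is a simplification of that logic rather than a copy.

```scala
// Hypothetical per-operand sketch of the EX-stage forwarding rule (plain Scala).
object ForwardingModel {
  // Code for one ALU operand: 2 = "b10" (EX/MEM ALU result),
  // 1 = "b01" (MEM/WB write-back data), 0 = "b00" (value from ID/EX)
  def code(rs: Int, exMemRegWr: Boolean, exMemRd: Int,
           memWbRegWr: Boolean, memWbRd: Int): Int =
    if (exMemRegWr && exMemRd != 0 && exMemRd == rs) 2       // EX hazard
    else if (memWbRegWr && memWbRd != 0 && memWbRd == rs) 1  // MEM hazard
    else 0
}
```

Take `add x3, x1, x2` followed by `sub x4, x3, x5`. The model returns 2 for rs1 = x3 and 0 for rs2 = x5.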
Since most hazards are already handled by the modules above, this module may be redundant.
Once we have completed all the modules, we can integrate them into a system.
First, we instantiate every module.
// Pipes of stages
val IF_ID_ = Module(new IF_ID)
val ID_EX_ = Module(new ID_EX)
val EX_MEM_M = Module(new EX_MEM)
val MEM_WB_M = Module(new MEM_WB)
// PC / PC+4
val PC = Module(new PC)
val PC4 = Module(new PC4)
// Memory
val InstMemory = Module(new InstMem ("./src/riscv/test.hex"))
val DataMemory = Module(new DataMemory)
// Helping Units
val control_module = Module(new Control)
val ImmGen = Module(new ImmGenerator)
val RegFile = Module(new RegisterFile)
val ALU_Control = Module(new AluControl)
dontTouch(ALU_Control.io)
val ALU = Module(new ALU)
dontTouch(ALU.io)
val Branch_M = Module(new Branch)
val JALR = Module(new Jalr)
// hazard units
val Forwarding = Module(new Forwarding)
val HazardDetect = Module(new HazardDetection)
val Branch_Forward = Module(new BranchForward)
val Structural = Module(new StructuralHazard)
Note that we use dontTouch to prevent Chisel from removing these signals during optimization. Ref: DontTouch
To fetch the instruction, we select the correct program counter (PC) based on the signal from the HazardDetect module.
val PC_F = MuxLookup (HazardDetect.io.pc_forward, 0.S, Array (
(0.U) -> PC4.io.out.asSInt,
(1.U) -> HazardDetect.io.pc_out))
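If pc_forward is 0, the PC+4 value from the PC4 module is selected. If it is 1, the PC provided by HazardDetect is selected instead. This is the PC+4 stored with the instruction in IF/ID, i.e., the address currently being fetched, so the same instruction is fetched again during a stall. Then we update the current program counter with PC.io.in := PC_F, compute the next PC with PC4.io.pc := PC.io.out.asUInt, and fetch the instruction with InstMemory.io.addr := PC.io.out.asUInt.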
PC.io.in := PC_F // PC_in input
PC4.io.pc := PC.io.out.asUInt // PC4_in input <- PC_out
InstMemory.io.addr := PC.io.out.asUInt // Address to fetch instruction
For PC and instruction forwarding, we choose the instruction and PC using the HazardDetect module.
val PC_for = MuxLookup (HazardDetect.io.inst_forward, 0.S, Array (
(0.U) -> PC.io.out,
(1.U) -> HazardDetect.io.current_pc_out))
val Instruction_F = MuxLookup (HazardDetect.io.inst_forward, 0.U, Array (
(0.U) -> InstMemory.io.data,
(1.U) -> HazardDetect.io.inst_out))
Finally, we pass the PC and instruction to the IF/ID pipeline register (the IF_ID module). There, PC.io.out is the current PC, PC4.io.out is the next sequential PC (PC + 4), and PC_for is the PC selected by the HazardDetect module.
IF_ID_.io.pc_in := PC.io.out // PC out from pc
IF_ID_.io.pc4_in := PC4.io.out // PC4 out from pc4
IF_ID_.io.SelectedPC := PC_for // Selected PC
IF_ID_.io.SelectedInstr := Instruction_F // Selected Instruction
First, we pass the selected instruction and PC into the ImmGenerator module, and pass the opcode (inst[6:0]) to the control module.
ImmGen.io.instr := IF_ID_.io.SelectedInstr_out // Instruction used to generate the 32-bit immediate
ImmGen.io.pc := IF_ID_.io.SelectedPC_out.asUInt // PC to add
// Decode connections (Control unit RegFile)
control_module.io.opcode := IF_ID_.io.SelectedInstr_out(6, 0) // Opcode to determine the instruction type
If the instruction is R-type (opcode = 51), I-type (opcode = 19), S-type (opcode = 35), I-type load (opcode = 3), SB-type branch (opcode = 99), or JALR (opcode = 103), we decode rs1 (inst[19:15]). Otherwise, rs1 is set to 0.
RegFile.io.rs1 := Mux(
control_module.io.opcode === 51.U || // R-type
control_module.io.opcode === 19.U || // I-type
control_module.io.opcode === 35.U || // S-type
control_module.io.opcode === 3.U || // I-type (load instructions)
control_module.io.opcode === 99.U || // SB-type (branch)
control_module.io.opcode === 103.U, // JALR instruction
IF_ID_.io.SelectedInstr_out(19, 15), 0.U )
Similarly, we decode rs2 (inst[24:20]) only if the instruction is R-type (opcode = 51), S-type (opcode = 35), or SB-type branch (opcode = 99).
RegFile.io.rs2 := Mux(control_module.io.opcode === 51.U || // R-type
control_module.io.opcode === 35.U || // S-type
control_module.io.opcode === 99.U, // SB-type (branch)
IF_ID_.io.SelectedInstr_out(24, 20), 0.U)
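Then we connect the register file's write enable with RegFile.io.reg_write := control_module.io.reg_write. Note that the write-back stage also connects RegFile.io.reg_write. When a port is connected more than once in Chisel, the last connection takes effect.
Finally, we select the immediate value from the ImmGenerator outputs according to the control signal extend: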
val ImmValue = MuxLookup (control_module.io.extend, 0.S, Array (
(0.U) -> ImmGen.io.I_type,
(1.U) -> ImmGen.io.S_type,
(2.U) -> ImmGen.io.U_type))
To handle structural hazards, we extract rs1 and rs2 from the instruction in the IF_ID register and pass them to the Structural module. This allows the module to detect whether the instruction in ID reads a register that is being written back in the same cycle.
Structural.io.rs1 := IF_ID_.io.SelectedInstr_out(19, 15)
Structural.io.rs2 := IF_ID_.io.SelectedInstr_out(24, 20)
The module also receives the write-back information used for forwarding from the MEM_WB register:
Structural.io.MEM_WB_regWr := MEM_WB_M.io.EXMEM_REG_W
Structural.io.MEM_WB_Rd := MEM_WB_M.io.MEMWB_rd_out
Then, based on the Structural module's signals, we decide whether each source operand is read from the register file or forwarded from the write-back data:
// rs1_data
when (Structural.io.fwd_rs1 === 0.U) {
S_rs1DataIn := RegFile.io.rdata1
}.elsewhen (Structural.io.fwd_rs1 === 1.U) {
S_rs1DataIn := RegFile.io.w_data
}.otherwise {
S_rs1DataIn := 0.S
}
// rs2_data
when (Structural.io.fwd_rs2 === 0.U) {
S_rs2DataIn := RegFile.io.rdata2
}.elsewhen (Structural.io.fwd_rs2 === 1.U) {
S_rs2DataIn := RegFile.io.w_data
}.otherwise {
S_rs2DataIn := 0.S
}
//ID_EX_ inputs
ID_EX_.io.rs1_data_in := S_rs1DataIn
ID_EX_.io.rs2_data_in := S_rs2DataIn
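The StructuralHazard module's own code is not shown here, so the following is only a sketch under an assumption. It assumes that fwd_rs1/fwd_rs2 is 1 when the MEM/WB instruction writes the register being read in ID (excluding x0). The second function mirrors the when/elsewhen chain above. WriteBypassModel is a hypothetical name.

```scala
// Hypothetical sketch of the register-file write bypass (plain Scala).
object WriteBypassModel {
  // ASSUMED StructuralHazard rule (its source is not shown): forward when the
  // instruction in MEM/WB writes the register read in ID, excluding x0
  def fwd(rs: Int, memWbRegWr: Boolean, memWbRd: Int): Int =
    if (memWbRegWr && memWbRd != 0 && memWbRd == rs) 1 else 0

  // Mirrors the when/elsewhen chain that drives S_rs1DataIn / S_rs2DataIn
  def operand(fwdCode: Int, rdata: Int, wData: Int): Int = fwdCode match {
    case 0 => rdata // RegFile.io.rdata1 / rdata2
    case 1 => wData // RegFile.io.w_data
    case _ => 0
  }
}
```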
To detect load-use hazards, we pass the relevant data from the IF_ID and ID_EX registers into the HazardDetect module.
// Hazard detection Unit inputs
HazardDetect.io.IF_ID_inst := IF_ID_.io.SelectedInstr_out
HazardDetect.io.ID_EX_memRead := ID_EX_.io.ctrl_MemRd_out
HazardDetect.io.ID_EX_rd := ID_EX_.io.rd_out
HazardDetect.io.pc_in := IF_ID_.io.pc4_out.asSInt
HazardDetect.io.current_pc := IF_ID_.io.SelectedPC_out
Then, if a stall is needed (as detected by the HazardDetect module), we insert a bubble by setting all control signals entering the ID_EX register to 0.
// Stall when forward
when(HazardDetect.io.ctrl_forward === "b1".U) {
ID_EX_.io.ctrl_MemWr_in := 0.U
ID_EX_.io.ctrl_MemRd_in := 0.U
ID_EX_.io.ctrl_MemToReg_in := 0.U
ID_EX_.io.ctrl_Reg_W_in := 0.U
ID_EX_.io.ctrl_AluOp_in := 0.U
ID_EX_.io.ctrl_OpB_in := 0.U
ID_EX_.io.ctrl_Branch_in := 0.U
ID_EX_.io.ctrl_nextpc_in := 0.U
}.otherwise {
ID_EX_.io.ctrl_MemWr_in := control_module.io.mem_write
ID_EX_.io.ctrl_MemRd_in := control_module.io.mem_read
ID_EX_.io.ctrl_MemToReg_in := control_module.io.men_to_reg
ID_EX_.io.ctrl_Reg_W_in := control_module.io.reg_write
ID_EX_.io.ctrl_AluOp_in := control_module.io.alu_operation
ID_EX_.io.ctrl_OpB_in := control_module.io.operand_B
ID_EX_.io.ctrl_Branch_in := control_module.io.branch
ID_EX_.io.ctrl_nextpc_in := control_module.io.next_pc_sel
}
Before passing the data into the ID_EX register, we handle the branch and jump (jal/jalr) operations separately. They are explained in their own sections.
Then we pass the data into the ID_EX register:
// ID_EX PIPELINE
ID_EX_.io.rs1_in := RegFile.io.rs1
ID_EX_.io.rs2_in := RegFile.io.rs2
ID_EX_.io.imm := ImmValue
ID_EX_.io.func3_in := IF_ID_.io.SelectedInstr_out(14, 12)
ID_EX_.io.func7_in := IF_ID_.io.SelectedInstr_out(30)
ID_EX_.io.rd_in := IF_ID_.io.SelectedInstr_out(11, 7)
ID_EX_.io.ctrl_OpA_in := control_module.io.operand_A // Operand A selection
ID_EX_.io.IFID_pc4_in := IF_ID_.io.pc4_out // pc+4 from Decode to execute
For branches and JALR, we pass data from the IF_ID, ID_EX, EX_MEM, and MEM_WB registers into the BranchForward module to determine where rs1 and rs2 should come from.
Branch_Forward.io.ID_EX_RD := ID_EX_.io.rd_out
Branch_Forward.io.EX_MEM_RD := EX_MEM_M.io.EXMEM_rd_out
Branch_Forward.io.MEM_WB_RD := MEM_WB_M.io.MEMWB_rd_out
Branch_Forward.io.ID_EX_memRd := ID_EX_.io.ctrl_MemRd_out
Branch_Forward.io.EX_MEM_memRd := EX_MEM_M.io.EXMEM_memRd_out
Branch_Forward.io.MEM_WB_memRd := MEM_WB_M.io.MEMWB_memRd_out
Branch_Forward.io.rs1 := IF_ID_.io.SelectedInstr_out(19, 15)
Branch_Forward.io.rs2 := IF_ID_.io.SelectedInstr_out(24, 20)
Branch_Forward.io.ctrl_branch := control_module.io.branch
We use the BranchForward outputs to select the rs1 and rs2 values and pass them to the Branch module, which decides whether the branch is taken.
// Branch X
Branch_M.io.arg_x := MuxLookup (Branch_Forward.io.forward_rs1, 0.S, Array (
(0.U) -> RegFile.io.rdata1,
(1.U) -> ALU.io.out,
(2.U) -> EX_MEM_M.io.EXMEM_alu_out,
(3.U) -> RegFile.io.w_data,
(4.U) -> DataMemory.io.dataOut,
(5.U) -> RegFile.io.w_data,
(6.U) -> RegFile.io.rdata1,
(7.U) -> RegFile.io.rdata1,
(8.U) -> RegFile.io.rdata1,
(9.U) -> RegFile.io.rdata1,
(10.U) -> RegFile.io.rdata1))
// Branch Y
Branch_M.io.arg_y := MuxLookup (Branch_Forward.io.forward_rs2, 0.S, Array (
(0.U) -> RegFile.io.rdata2,
(1.U) -> ALU.io.out,
(2.U) -> EX_MEM_M.io.EXMEM_alu_out,
(3.U) -> RegFile.io.w_data,
(4.U) -> DataMemory.io.dataOut,
(5.U) -> RegFile.io.w_data))
Branch_M.io.fnct3 := IF_ID_.io.SelectedInstr_out(14, 12) // Fun3 for(beq,bne....)
Branch_M.io.branch := control_module.io.branch // Branch instr yes
Similarly, forward_rs1 selects the base register value for the Jalr module, which computes the jump target.
// for JALR
JALR.io.rdata1 := MuxLookup (Branch_Forward.io.forward_rs1, 0.U, Array (
(0.U) -> RegFile.io.rdata1.asUInt,
(1.U) -> RegFile.io.rdata1.asUInt,
(2.U) -> RegFile.io.rdata1.asUInt,
(3.U) -> RegFile.io.rdata1.asUInt,
(4.U) -> RegFile.io.rdata1.asUInt,
(5.U) -> RegFile.io.rdata1.asUInt,
(6.U) -> ALU.io.out.asUInt,
(7.U) -> EX_MEM_M.io.EXMEM_alu_out.asUInt,
(8.U) -> RegFile.io.w_data.asUInt,
(9.U) -> DataMemory.io.dataOut.asUInt,
(10.U) -> RegFile.io.w_data.asUInt))
JALR.io.imme := ImmValue.asUInt
The Jalr module outputs the target address with its least significant bit cleared.
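Finally, we update the PC based on the HazardDetect and control module outputs. When a branch is taken or a jump occurs, the IF_ID inputs are set to 0 to flush the instruction fetched after it.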
when(HazardDetect.io.pc_forward === 1.B) {
PC.io.in := HazardDetect.io.pc_out
}.otherwise {
when(control_module.io.next_pc_sel === "b01".U) {
when(Branch_M.io.br_taken === 1.B && control_module.io.branch === 1.B) {
PC.io.in := ImmGen.io.SB_type
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
PC.io.in := PC4.io.out.asSInt
}
}.elsewhen(control_module.io.next_pc_sel === "b10".U) {
PC.io.in := ImmGen.io.UJ_type
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.elsewhen(control_module.io.next_pc_sel === "b11".U) {
PC.io.in := JALR.io.out.asSInt
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
PC.io.in := PC4.io.out.asSInt
}
}
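The priority of the PC update above can be summarized as a pure function. NextPcModel is a hypothetical name. The next_pc_sel values 1, 2, and 3 correspond to "b01", "b10", and "b11". A stall from HazardDetect overrides everything else. A taken branch or a jump redirects the PC and flushes IF_ID. All other cases fall through to PC + 4.

```scala
// Hypothetical pure-function summary of the PC update logic (plain Scala).
object NextPcModel {
  // Returns (next PC, whether the IF_ID inputs are zeroed to flush)
  def next(pcForward: Boolean, nextPcSel: Int, brTaken: Boolean, branch: Boolean,
           pc4: Int, hazardPc: Int, sbTarget: Int, ujTarget: Int,
           jalrTarget: Int): (Int, Boolean) =
    if (pcForward) (hazardPc, false)            // stall: HazardDetect.io.pc_out
    else nextPcSel match {
      case 1 => // "b01": conditional branch
        if (brTaken && branch) (sbTarget, true) else (pc4, false)
      case 2 => (ujTarget, true)                // "b10": JAL
      case 3 => (jalrTarget, true)              // "b11": JALR
      case _ => (pc4, false)                    // sequential fetch
    }
}
```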
So far, we have completed the IF and ID stages as well as the branch and JALR handling. The correct register values are ready for computation, and the PC is correct.
First, we connect the ALU control signals (the ALU operation, funct3, and funct7 from the ID_EX register) to the ALU control unit, whose output selects the ALU operation:
ALU_Control.io.aluOp := ID_EX_.io.ctrl_AluOp_out // Alu op code
ALU.io.alu_Op := ALU_Control.io.out // Alu op code
ALU_Control.io.func3 := ID_EX_.io.func3_out // function 3
ALU_Control.io.func7 := ID_EX_.io.func7_out // function 7
Next, to resolve data hazards in the EX stage, we pass the source and destination registers into the Forwarding module:
Forwarding.io.IDEX_rs1 := ID_EX_.io.rs1_out
Forwarding.io.IDEX_rs2 := ID_EX_.io.rs2_out
Forwarding.io.EXMEM_rd := EX_MEM_M.io.EXMEM_rd_out
Forwarding.io.EXMEM_regWr := EX_MEM_M.io.EXMEM_reg_w_out
Forwarding.io.MEMWB_rd := MEM_WB_M.io.MEMWB_rd_out
Forwarding.io.MEMWB_regWr := MEM_WB_M.io.MEMWB_reg_w_out
Then the Forwarding outputs decide which values are passed into the ALU (d is the write-back data selected in the WB stage, shown later):
when (ID_EX_.io.ctrl_OpA_out === "b01".U) {
ALU.io.in_A := ID_EX_.io.IFID_pc4_out.asSInt
}.otherwise {
// forwarding A
when(Forwarding.io.forward_a === "b00".U) {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}.elsewhen(Forwarding.io.forward_a === "b01".U) {
ALU.io.in_A := d
}.elsewhen(Forwarding.io.forward_a === "b10".U) {
ALU.io.in_A := EX_MEM_M.io.EXMEM_alu_out
}.otherwise {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}
}
// forwarding B
val RS2_value = Wire(SInt(32.W))
when (Forwarding.io.forward_b === 0.U) {
RS2_value := ID_EX_.io.rs2_data_out
}.elsewhen (Forwarding.io.forward_b === 1.U) {
RS2_value := d
}.elsewhen (Forwarding.io.forward_b === 2.U) {
RS2_value := EX_MEM_M.io.EXMEM_alu_out
}.otherwise {
RS2_value := 0.S
}
when (ID_EX_.io.ctrl_OpB_out === 0.U) {
ALU.io.in_B := RS2_value
}.otherwise {
ALU.io.in_B := ID_EX_.io.imm_out
}
Then we pass the data into the EX_MEM register:
// Execute
EX_MEM_M.io.IDEX_rd := ID_EX_.io.rd_out
EX_MEM_M.io.IDEX_MEMRD := ID_EX_.io.ctrl_MemRd_out
EX_MEM_M.io.IDEX_MEMWR := ID_EX_.io.ctrl_MemWr_out
EX_MEM_M.io.IDEX_MEMTOREG := ID_EX_.io.ctrl_MemToReg_out
EX_MEM_M.io.IDEX_REG_W := ID_EX_.io.ctrl_Reg_W_out
EX_MEM_M.io.IDEX_rs2 := RS2_value
EX_MEM_M.io.alu_out := ALU.io.out
The only thing we have to do in the MEM stage is to pass the data into the DataMemory module:
DataMemory.io.mem_read := EX_MEM_M.io.EXMEM_memRd_out
DataMemory.io.mem_write := EX_MEM_M.io.EXMEM_memWr_out
DataMemory.io.dataIn := EX_MEM_M.io.EXMEM_rs2_out
DataMemory.io.addr := EX_MEM_M.io.EXMEM_alu_out.asUInt
Then we update the MEM_WB register:
MEM_WB_M.io.EXMEM_MEMRD := EX_MEM_M.io.EXMEM_memRd_out // 0/ 1: data read from memory
MEM_WB_M.io.EXMEM_MEMTOREG := EX_MEM_M.io.EXMEM_memToReg_out
MEM_WB_M.io.EXMEM_REG_W := EX_MEM_M.io.EXMEM_reg_w_out
MEM_WB_M.io.EXMEM_rd := EX_MEM_M.io.EXMEM_rd_out
MEM_WB_M.io.in_dataMem_out := DataMemory.io.dataOut // data from Data Memory
MEM_WB_M.io.in_alu_out := EX_MEM_M.io.EXMEM_alu_out // data from Alu Result
For the write-back stage, we connect the destination register and the write-enable signal from the MEM_WB register to the register file:
RegFile.io.w_reg := MEM_WB_M.io.MEMWB_rd_out
RegFile.io.reg_write := MEM_WB_M.io.MEMWB_reg_w_out
Finally, we pass the data and determine which value should be written back based on the control signal.
when (MEM_WB_M.io.MEMWB_memToReg_out === 0.U) {
d := MEM_WB_M.io.MEMWB_alu_out // data from Alu Result
}.elsewhen (MEM_WB_M.io.MEMWB_memToReg_out === 1.U) {
d := MEM_WB_M.io.MEMWB_dataMem_out // data from Data Memory
}.otherwise {
d := 0.S
}
RegFile.io.w_data := d // Write back data
At this point, we have completed the 5-stage RISC-V pipeline. The complete code is shown below.
package Pipeline
import chisel3._
import chisel3.util._
class PIPELINE extends Module {
val io = IO(new Bundle {
val out = Output (SInt(4.W))
})
// Pipes of stages
val IF_ID_ = Module(new IF_ID)
val ID_EX_ = Module(new ID_EX)
val EX_MEM_M = Module(new EX_MEM)
val MEM_WB_M = Module(new MEM_WB)
// PC / PC+4
val PC = Module(new PC)
val PC4 = Module(new PC4)
// Memory
val InstMemory = Module(new InstMem ("./src/riscv/rv32ui-p-add.hex"))
val DataMemory = Module(new DataMemory)
// Helping Units
val control_module = Module(new Control)
val ImmGen = Module(new ImmGenerator)
val RegFile = Module(new RegisterFile)
val ALU_Control = Module(new AluControl)
dontTouch(ALU_Control.io)
val ALU = Module(new ALU)
dontTouch(ALU.io)
val Branch_M = Module(new Branch)
val JALR = Module(new Jalr)
// hazard units
val Forwarding = Module(new Forwarding)
val HazardDetect = Module(new HazardDetection)
val Branch_Forward = Module(new BranchForward)
val Structural = Module(new StructuralHazard)
val PC_F = MuxLookup (HazardDetect.io.pc_forward, 0.S, Array (
(0.U) -> PC4.io.out.asSInt,
(1.U) -> HazardDetect.io.pc_out))
PC.io.in := PC_F // PC_in input
PC4.io.pc := PC.io.out.asUInt // PC4_in input <- PC_out
InstMemory.io.addr := PC.io.out.asUInt // Address to fetch instruction
val PC_for = MuxLookup (HazardDetect.io.inst_forward, 0.S, Array (
(0.U) -> PC.io.out,
(1.U) -> HazardDetect.io.current_pc_out))
val Instruction_F = MuxLookup (HazardDetect.io.inst_forward, 0.U, Array (
(0.U) -> InstMemory.io.data,
(1.U) -> HazardDetect.io.inst_out))
// Fetch decode pipe connections
IF_ID_.io.pc_in := PC.io.out // PC out from pc
IF_ID_.io.pc4_in := PC4.io.out // PC4 out from pc4
IF_ID_.io.SelectedPC := PC_for // Selected PC
IF_ID_.io.SelectedInstr := Instruction_F // Selected Instruction
//ImmGenerator Inputs
ImmGen.io.instr := IF_ID_.io.SelectedInstr_out // Instruction used to generate the 32-bit immediate
ImmGen.io.pc := IF_ID_.io.SelectedPC_out.asUInt // PC to add
// Decode connections (Control unit RegFile)
control_module.io.opcode := IF_ID_.io.SelectedInstr_out(6, 0) // Opcode to determine the instruction type
// Registerfile inputs
RegFile.io.rs1 := Mux(
control_module.io.opcode === 51.U || // R-type
control_module.io.opcode === 19.U || // I-type
control_module.io.opcode === 35.U || // S-type
control_module.io.opcode === 3.U || // I-type (load instructions)
control_module.io.opcode === 99.U || // SB-type (branch)
control_module.io.opcode === 103.U, // JALR instruction
IF_ID_.io.SelectedInstr_out(19, 15), 0.U )
RegFile.io.rs2 := Mux(
control_module.io.opcode === 51.U || // R-type
control_module.io.opcode === 35.U || // S-type
control_module.io.opcode === 99.U, // SB-type (branch)
IF_ID_.io.SelectedInstr_out(24, 20), 0.U)
RegFile.io.reg_write := control_module.io.reg_write
val ImmValue = MuxLookup (control_module.io.extend, 0.S, Array (
(0.U) -> ImmGen.io.I_type,
(1.U) -> ImmGen.io.S_type,
(2.U) -> ImmGen.io.U_type))
// Structural hazard inputs
Structural.io.rs1 := IF_ID_.io.SelectedInstr_out(19, 15)
Structural.io.rs2 := IF_ID_.io.SelectedInstr_out(24, 20)
Structural.io.MEM_WB_regWr := MEM_WB_M.io.EXMEM_REG_W
Structural.io.MEM_WB_Rd := MEM_WB_M.io.MEMWB_rd_out
val S_rs1DataIn = Wire(SInt(32.W))
val S_rs2DataIn = Wire(SInt(32.W))
// rs1_data
when (Structural.io.fwd_rs1 === 0.U) {
S_rs1DataIn := RegFile.io.rdata1
}.elsewhen (Structural.io.fwd_rs1 === 1.U) {
S_rs1DataIn := RegFile.io.w_data
}.otherwise {
S_rs1DataIn := 0.S
}
// rs2_data
when (Structural.io.fwd_rs2 === 0.U) {
S_rs2DataIn := RegFile.io.rdata2
}.elsewhen (Structural.io.fwd_rs2 === 1.U) {
S_rs2DataIn := RegFile.io.w_data
}.otherwise {
S_rs2DataIn := 0.S
}
//ID_EX_ inputs
ID_EX_.io.rs1_data_in := S_rs1DataIn
ID_EX_.io.rs2_data_in := S_rs2DataIn
// Stall when forward
when(HazardDetect.io.ctrl_forward === "b1".U) {
ID_EX_.io.ctrl_MemWr_in := 0.U
ID_EX_.io.ctrl_MemRd_in := 0.U
ID_EX_.io.ctrl_MemToReg_in := 0.U
ID_EX_.io.ctrl_Reg_W_in := 0.U
ID_EX_.io.ctrl_AluOp_in := 0.U
ID_EX_.io.ctrl_OpB_in := 0.U
ID_EX_.io.ctrl_Branch_in := 0.U
ID_EX_.io.ctrl_nextpc_in := 0.U
}.otherwise {
ID_EX_.io.ctrl_MemWr_in := control_module.io.mem_write
ID_EX_.io.ctrl_MemRd_in := control_module.io.mem_read
ID_EX_.io.ctrl_MemToReg_in := control_module.io.men_to_reg
ID_EX_.io.ctrl_Reg_W_in := control_module.io.reg_write
ID_EX_.io.ctrl_AluOp_in := control_module.io.alu_operation
ID_EX_.io.ctrl_OpB_in := control_module.io.operand_B
ID_EX_.io.ctrl_Branch_in := control_module.io.branch
ID_EX_.io.ctrl_nextpc_in := control_module.io.next_pc_sel
}
// Hazard detection Unit inputs
HazardDetect.io.IF_ID_inst := IF_ID_.io.SelectedInstr_out
HazardDetect.io.ID_EX_memRead := ID_EX_.io.ctrl_MemRd_out
HazardDetect.io.ID_EX_rd := ID_EX_.io.rd_out
HazardDetect.io.pc_in := IF_ID_.io.pc4_out.asSInt
HazardDetect.io.current_pc := IF_ID_.io.SelectedPC_out
MEM_WB_M.io.EXMEM_MEMRD := EX_MEM_M.io.EXMEM_memRd_out // 0/ 1: data read from memory
// Branch forward Unit inputs
Branch_Forward.io.ID_EX_RD := ID_EX_.io.rd_out
Branch_Forward.io.EX_MEM_RD := EX_MEM_M.io.EXMEM_rd_out
Branch_Forward.io.MEM_WB_RD := MEM_WB_M.io.MEMWB_rd_out
Branch_Forward.io.ID_EX_memRd := ID_EX_.io.ctrl_MemRd_out
Branch_Forward.io.EX_MEM_memRd := EX_MEM_M.io.EXMEM_memRd_out
Branch_Forward.io.MEM_WB_memRd := MEM_WB_M.io.MEMWB_memRd_out
Branch_Forward.io.rs1 := IF_ID_.io.SelectedInstr_out(19, 15)
Branch_Forward.io.rs2 := IF_ID_.io.SelectedInstr_out(24, 20)
Branch_Forward.io.ctrl_branch := control_module.io.branch
// Branch X
Branch_M.io.arg_x := MuxLookup (Branch_Forward.io.forward_rs1, 0.S, Array (
(0.U) -> RegFile.io.rdata1,
(1.U) -> ALU.io.out,
(2.U) -> EX_MEM_M.io.EXMEM_alu_out,
(3.U) -> RegFile.io.w_data,
(4.U) -> DataMemory.io.dataOut,
(5.U) -> RegFile.io.w_data,
(6.U) -> RegFile.io.rdata1,
(7.U) -> RegFile.io.rdata1,
(8.U) -> RegFile.io.rdata1,
(9.U) -> RegFile.io.rdata1,
(10.U) -> RegFile.io.rdata1))
// for JALR
JALR.io.rdata1 := MuxLookup (Branch_Forward.io.forward_rs1, 0.U, Array (
(0.U) -> RegFile.io.rdata1.asUInt,
(1.U) -> RegFile.io.rdata1.asUInt,
(2.U) -> RegFile.io.rdata1.asUInt,
(3.U) -> RegFile.io.rdata1.asUInt,
(4.U) -> RegFile.io.rdata1.asUInt,
(5.U) -> RegFile.io.rdata1.asUInt,
(6.U) -> ALU.io.out.asUInt,
(7.U) -> EX_MEM_M.io.EXMEM_alu_out.asUInt,
(8.U) -> RegFile.io.w_data.asUInt,
(9.U) -> DataMemory.io.dataOut.asUInt,
(10.U) -> RegFile.io.w_data.asUInt))
JALR.io.imme := ImmValue.asUInt
// Branch Y
Branch_M.io.arg_y := MuxLookup (Branch_Forward.io.forward_rs2, 0.S, Array (
(0.U) -> RegFile.io.rdata2,
(1.U) -> ALU.io.out,
(2.U) -> EX_MEM_M.io.EXMEM_alu_out,
(3.U) -> RegFile.io.w_data,
(4.U) -> DataMemory.io.dataOut,
(5.U) -> RegFile.io.w_data))
Branch_M.io.fnct3 := IF_ID_.io.SelectedInstr_out(14, 12) // funct3 selects the condition (beq, bne, ...)
Branch_M.io.branch := control_module.io.branch // High for branch instructions
when(HazardDetect.io.pc_forward === 1.B) {
PC.io.in := HazardDetect.io.pc_out
}.otherwise {
when(control_module.io.next_pc_sel === "b01".U) {
when(Branch_M.io.br_taken === 1.B && control_module.io.branch === 1.B) {
PC.io.in := ImmGen.io.SB_type
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
PC.io.in := PC4.io.out.asSInt
}
}.elsewhen(control_module.io.next_pc_sel === "b10".U) {
PC.io.in := ImmGen.io.UJ_type
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.elsewhen(control_module.io.next_pc_sel === "b11".U) {
PC.io.in := JALR.io.out.asSInt
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.SelectedPC:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
PC.io.in := PC4.io.out.asSInt
}
}
// ID_EX PIPELINE
ID_EX_.io.rs1_in := RegFile.io.rs1
ID_EX_.io.rs2_in := RegFile.io.rs2
ID_EX_.io.imm := ImmValue
ID_EX_.io.func3_in := IF_ID_.io.SelectedInstr_out(14, 12)
ID_EX_.io.func7_in := IF_ID_.io.SelectedInstr_out(30)
ID_EX_.io.rd_in := IF_ID_.io.SelectedInstr_out(11, 7)
ALU_Control.io.aluOp := ID_EX_.io.ctrl_AluOp_out // Alu op code
ALU.io.alu_Op := ALU_Control.io.out // Alu op code
ALU_Control.io.func3 := ID_EX_.io.func3_out // function 3
ALU_Control.io.func7 := ID_EX_.io.func7_out // function 7
EX_MEM_M.io.IDEX_rd := ID_EX_.io.rd_out
// Forwarding Inputs
Forwarding.io.IDEX_rs1 := ID_EX_.io.rs1_out
Forwarding.io.IDEX_rs2 := ID_EX_.io.rs2_out
Forwarding.io.EXMEM_rd := EX_MEM_M.io.EXMEM_rd_out
Forwarding.io.EXMEM_regWr := EX_MEM_M.io.EXMEM_reg_w_out
Forwarding.io.MEMWB_rd := MEM_WB_M.io.MEMWB_rd_out
Forwarding.io.MEMWB_regWr := MEM_WB_M.io.MEMWB_reg_w_out
ID_EX_.io.ctrl_OpA_in := control_module.io.operand_A // Operand A selection
ID_EX_.io.IFID_pc4_in := IF_ID_.io.pc4_out // pc+4 from Decode to execute
val d = Wire(SInt(32.W))
when (ID_EX_.io.ctrl_OpA_out === "b01".U) {
ALU.io.in_A := ID_EX_.io.IFID_pc4_out.asSInt
}.otherwise {
// forwarding A
when(Forwarding.io.forward_a === "b00".U) {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}.elsewhen(Forwarding.io.forward_a === "b01".U) {
ALU.io.in_A := d
}.elsewhen(Forwarding.io.forward_a === "b10".U) {
ALU.io.in_A := EX_MEM_M.io.EXMEM_alu_out
}.otherwise {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}
}
// forwarding B
val RS2_value = Wire(SInt(32.W))
when (Forwarding.io.forward_b === 0.U) {
RS2_value := ID_EX_.io.rs2_data_out
}.elsewhen (Forwarding.io.forward_b === 1.U) {
RS2_value := d
}.elsewhen (Forwarding.io.forward_b === 2.U) {
RS2_value := EX_MEM_M.io.EXMEM_alu_out
}.otherwise {
RS2_value := 0.S
}
when (ID_EX_.io.ctrl_OpB_out === 0.U) {
ALU.io.in_B := RS2_value
}.otherwise {
ALU.io.in_B := ID_EX_.io.imm_out
}
// Execute
EX_MEM_M.io.IDEX_MEMRD := ID_EX_.io.ctrl_MemRd_out
EX_MEM_M.io.IDEX_MEMWR := ID_EX_.io.ctrl_MemWr_out
EX_MEM_M.io.IDEX_MEMTOREG := ID_EX_.io.ctrl_MemToReg_out
EX_MEM_M.io.IDEX_REG_W := ID_EX_.io.ctrl_Reg_W_out
EX_MEM_M.io.IDEX_rs2 := RS2_value
EX_MEM_M.io.alu_out := ALU.io.out
// Data memory inputs
DataMemory.io.mem_read := EX_MEM_M.io.EXMEM_memRd_out
DataMemory.io.mem_write := EX_MEM_M.io.EXMEM_memWr_out
DataMemory.io.dataIn := EX_MEM_M.io.EXMEM_rs2_out
DataMemory.io.addr := EX_MEM_M.io.EXMEM_alu_out.asUInt
MEM_WB_M.io.EXMEM_MEMTOREG := EX_MEM_M.io.EXMEM_memToReg_out
MEM_WB_M.io.EXMEM_REG_W := EX_MEM_M.io.EXMEM_reg_w_out
MEM_WB_M.io.EXMEM_rd := EX_MEM_M.io.EXMEM_rd_out
MEM_WB_M.io.in_dataMem_out := DataMemory.io.dataOut // data from Data Memory
MEM_WB_M.io.in_alu_out := EX_MEM_M.io.EXMEM_alu_out // data from Alu Result
// Register file connections
RegFile.io.w_reg := MEM_WB_M.io.MEMWB_rd_out
RegFile.io.reg_write := MEM_WB_M.io.MEMWB_reg_w_out
// Write back data to registerfile writedata
when (MEM_WB_M.io.MEMWB_memToReg_out === 0.U) {
d := MEM_WB_M.io.MEMWB_alu_out // data from Alu Result
}.elsewhen (MEM_WB_M.io.MEMWB_memToReg_out === 1.U) {
d := MEM_WB_M.io.MEMWB_dataMem_out // data from Data Memory
}.otherwise {
d := 0.S
}
RegFile.io.w_data := d // Write back data
io.out := 0.S
// printf(p"pc : 0x${Hexadecimal(IF_ID_.io.SelectedPC)}\n")
// printf(p"inst : 0x${Hexadecimal(InstMemory.io.data)}\n")
}
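To make the two selection chains in Main.scala easier to follow, here is a small C++ sketch of them as plain functions. selectNextPc mirrors the when/elsewhen chain that drives PC.io.in, including which cases drive zeros into the IF_ID inputs (a flush). forwardSelect is an assumption: the Forwarding module's source is not listed here, so the sketch uses the textbook rule in which the EX/MEM result takes priority over MEM/WB and x0 is never forwarded; only the 0b00/0b01/0b10 encoding is taken from the Chisel code. All names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// forward_a / forward_b encodings as used by the ALU operand muxes in Main.scala.
enum Forward : uint8_t {
    kFromIdEx  = 0b00,  // value read in decode (ID_EX rs*_data_out)
    kFromMemWb = 0b01,  // write-back value d
    kFromExMem = 0b10   // EX_MEM ALU result
};

// ASSUMPTION: textbook forwarding rule; the real Forwarding module is not shown.
// The newer EX/MEM result wins over MEM/WB, and x0 is never forwarded.
Forward forwardSelect(uint32_t idexRs, bool exmemRegWr, uint32_t exmemRd,
                      bool memwbRegWr, uint32_t memwbRd) {
    if (exmemRegWr && exmemRd != 0 && exmemRd == idexRs) return kFromExMem;
    if (memwbRegWr && memwbRd != 0 && memwbRd == idexRs) return kFromMemWb;
    return kFromIdEx;
}

// next_pc_sel encodings checked in Main.scala.
enum NextPcSel : uint8_t { kPc4 = 0b00, kBranch = 0b01, kJal = 0b10, kJalr = 0b11 };

struct NextPc {
    uint32_t pc;     // value driven onto PC.io.in
    bool flushIfId;  // true when Main.scala drives zeros into the IF_ID inputs
};

// Mirrors the when/elsewhen chain that drives PC.io.in.
NextPc selectNextPc(bool hazardPcForward, uint32_t hazardPc, uint8_t nextPcSel,
                    bool brTaken, bool isBranch, uint32_t pc4,
                    uint32_t sbTarget, uint32_t ujTarget, uint32_t jalrTarget) {
    if (hazardPcForward) return {hazardPc, false};  // HazardDetect overrides every other source
    switch (nextPcSel) {
        case kBranch:
            if (brTaken && isBranch) return {sbTarget, true};  // taken branch: flush IF_ID
            return {pc4, false};                               // not taken
        case kJal:  return {ujTarget, true};
        case kJalr: return {jalrTarget, true};
        default:    return {pc4, false};                       // kPc4
    }
}
```

Writing the selection logic out this way makes it easy to see that jumps always flush IF_ID, while a branch flushes only when Branch reports it as taken.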
In this section, we test each module with ChiselTest, which is based on ScalaTest, and with Verilator, a tool that compiles Verilog modules into C++ simulation models.
Ref : Chisel Cookbook (Migrating from ChiselTest to ChiselSim)
Chisel provides tools to convert Scala code into Verilog. For example:
scala-cli generate.scala
Once the Verilog module is generated, Verilator can translate it into a C++ model, for example:
verilator --cc DataMemory.v --exe DataMemory_tb.cpp --trace
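The testbench that drives this model is written in C++ (e.g. DataMemory_tb.cpp). Verilator writes the generated sources to obj_dir/; adding --build to the command above (or running make -C obj_dir -f VDataMemory.mk VDataMemory afterwards) compiles them into the executable obj_dir/VDataMemory.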
For the DataMemory module, we test Write Data and Read Data separately.
First, initialize the inputs: select address 0, present 42 on dataIn, and keep both mem_write and mem_read disabled.
c.io.addr.poke(0.U) // Address 0
c.io.dataIn.poke(42.S) // Write data = 42
c.io.mem_write.poke(false.B) // Disable write initially
c.io.mem_read.poke(false.B) // Disable read initially
c.clock.step(1) // Step clock
To test Write Data, we set mem_write to true, step the clock, and disable it again.
c.io.mem_write.poke(true.B) // Enable write
c.clock.step(1) // Step clock
c.io.mem_write.poke(false.B) // Disable write after one cycle
To test Read Data, we set mem_read to true, step the clock, and expect the output to be 42.
c.io.mem_read.poke(true.B) // Enable read
c.clock.step(1) // Step clock
c.io.dataOut.expect(42.S) // Expect dataOut = 42
c.io.mem_read.poke(false.B) // Disable read
Finally, we test Write Data and Read Data together: write -15 to address 1, then read it back from the memory.
c.io.addr.poke(1.U) // Address 1
c.io.dataIn.poke(-15.S) // Write data = -15
c.io.mem_write.poke(true.B) // Enable write
c.clock.step(1) // Step clock
c.io.mem_write.poke(false.B) // Disable write
c.io.mem_read.poke(true.B) // Enable read
c.clock.step(1) // Step clock
c.io.dataOut.expect(-15.S) // Expect dataOut = -15
The complete DataMemoryTest.scala:
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class DataMemoryTester extends AnyFlatSpec with ChiselScalatestTester {
behavior of "DataMemory"
it should "Correct" in {
test(new DataMemory) { c =>
// Initialize
c.io.addr.poke(0.U) // Address 0
c.io.dataIn.poke(42.S) // Write data = 42
c.io.mem_write.poke(false.B) // Disable write initially
c.io.mem_read.poke(false.B) // Disable read initially
c.clock.step(1) // Advance clock
// Write data
c.io.mem_write.poke(true.B) // Enable write
c.clock.step(1) // Advance clock
c.io.mem_write.poke(false.B) // Disable write after one cycle
// Read data
c.io.mem_read.poke(true.B) // Enable read
c.clock.step(1) // Advance clock
c.io.dataOut.expect(42.S) // Expect dataOut = 42
c.io.mem_read.poke(false.B) // Disable read
// Write and read from another address
c.io.addr.poke(1.U) // Address 1
c.io.dataIn.poke(-15.S) // Write data = -15
c.io.mem_write.poke(true.B) // Enable write
c.clock.step(1) // Advance clock
c.io.mem_write.poke(false.B) // Disable write
c.io.mem_read.poke(true.B) // Enable read
c.clock.step(1) // Advance clock
c.io.dataOut.expect(-15.S) // Expect dataOut = -15
}
}
}
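The code for generating Verilog (it uses the CIRCT-based ChiselStage from newer Chisel releases, not the chisel3 3.4.3 dependency in build.sbt):
import circt.stage.ChiselStage // provides emitSystemVerilog(gen, args, firtoolOpts)
object Main extends App {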
println(
ChiselStage.emitSystemVerilog(
gen = new DataMemory,
firtoolOpts = Array("-disable-all-randomization", "-strip-debug-info")
)
)
}
Running it prints the generated Verilog:
module Dmemory_1024x32(
input [9:0] R0_addr,
input R0_en,
R0_clk,
output [31:0] R0_data,
input [9:0] W0_addr,
input W0_en,
W0_clk,
input [31:0] W0_data
);
reg [31:0] Memory[0:1023];
always @(posedge W0_clk) begin
if (W0_en & 1'h1)
Memory[W0_addr] <= W0_data;
end // always @(posedge)
assign R0_data = R0_en ? Memory[R0_addr] : 32'bx;
endmodule
module DataMemory(
input clock,
reset,
input [31:0] io_addr,
io_dataIn,
input io_mem_read,
io_mem_write,
output [31:0] io_dataOut
);
wire [31:0] _Dmemory_ext_R0_data;
Dmemory_1024x32 Dmemory_ext (
.R0_addr (io_addr[9:0]),
.R0_en (io_mem_read),
.R0_clk (clock),
.R0_data (_Dmemory_ext_R0_data),
.W0_addr (io_addr[9:0]),
.W0_en (io_mem_write),
.W0_clk (clock),
.W0_data (io_dataIn)
);
assign io_dataOut = io_mem_read ? _Dmemory_ext_R0_data : 32'h0;
endmodule
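The generated Verilog pins down the memory's behaviour: 1024 32-bit words, a write on the rising clock edge when io_mem_write is high, and a combinational read that returns 0 when io_mem_read is low. Only io_addr[9:0] is used, so the address serves directly as the word index and wraps every 1024 entries. A C++ golden model of that behaviour, which a Verilator testbench could compare against, might look like this (the class name is hypothetical):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical golden model of the generated DataMemory module.
class DataMemoryModel {
public:
    // Combinational read port: io_dataOut = io_mem_read ? Memory[io_addr[9:0]] : 0
    uint32_t dataOut(uint32_t addr, bool memRead) const {
        return memRead ? mem_[addr & 0x3FFu] : 0u;
    }

    // Rising clock edge: Memory[io_addr[9:0]] <= io_dataIn when io_mem_write is high
    void posedge(uint32_t addr, uint32_t dataIn, bool memWrite) {
        if (memWrite) mem_[addr & 0x3FFu] = dataIn;
    }

private:
    // Zero-initialised for convenience; the generated RTL memory has no reset.
    std::array<uint32_t, 1024> mem_{};
};
```

In a testbench, one instance of this model could be updated next to every tick() call and its dataOut() compared against dut->io_dataOut, which also catches a later write that clobbers an earlier address.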
We wrote a testbench to test the module's functionality:
#include "VDataMemory.h"
#include "verilated.h"
#include "verilated_vcd_c.h"
#include <stdint.h>
#include <cstdio> // printf
#include <cstdlib> // rand
#include <iostream>
// clock
void tick(VDataMemory* dut, VerilatedVcdC* tfp, int tickcount) {
dut->clock = 0; // negedge
dut->eval();
if (tfp) tfp->dump(tickcount * 10 - 5); // dump negedge result
dut->clock = 1; // posedge
dut->eval();
if (tfp) tfp->dump(tickcount * 10); // dump posedge result
if (tfp) tfp->flush();
}
int main(int argc, char** argv) {
Verilated::commandArgs(argc, argv);
Verilated::traceEverOn(true);
VDataMemory* dut = new VDataMemory;
VerilatedVcdC* tfp = new VerilatedVcdC;
dut->trace(tfp, 99); // 99 for tracing all signals
tfp->open("DataMemory.vcd");
// init
dut->reset = 1;
int tickcount = 0;
int failures = 0;
tick(dut, tfp, ++tickcount);
dut->reset = 0;
for(int i = 0; i < 30; i++)
{
// write into memory
uint32_t addr = rand() % 1024;
uint32_t data = (uint32_t)rand(); // unsigned, like io_dataOut
dut->io_mem_write = 1;
dut->io_mem_read = 0;
dut->io_addr = addr;
dut->io_dataIn = data;
tick(dut, tfp, ++tickcount);
// read from memory
dut->io_mem_write = 0;
dut->io_mem_read = 1;
dut->io_addr = addr;
tick(dut, tfp, ++tickcount);
// verify output
if (dut->io_dataOut == data) {
printf("Test passed: Read value is %u\n", dut->io_dataOut);
} else {
printf("Test failed: Expected %u, ", data);
printf("but got %u\n", dut->io_dataOut);
failures++;
}
}
dut->final();
if (tfp) tfp->close();
delete dut;
delete tfp;
return failures == 0 ? 0 : 1; // non-zero exit code if any read-back mismatched
}
Running the testbench prints:
Test passed: Read value is 846930886
Test passed: Read value is 1714636915
Test passed: Read value is 424238335
Test passed: Read value is 1649760492
Test passed: Read value is 1189641421
Test passed: Read value is 1350490027
Test passed: Read value is 1102520059
Test passed: Read value is 1967513926
Test passed: Read value is 1540383426
Test passed: Read value is 1303455736
Test passed: Read value is 521595368
Test passed: Read value is 1726956429
Test passed: Read value is 861021530
Test passed: Read value is 233665123
Test passed: Read value is 468703135
Test passed: Read value is 1801979802
Test passed: Read value is 635723058
Test passed: Read value is 1125898167
Test passed: Read value is 2089018456
Test passed: Read value is 1656478042
Test passed: Read value is 1653377373
Test passed: Read value is 1914544919
Test passed: Read value is 756898537
Test passed: Read value is 1973594324
Test passed: Read value is 2038664370
Test passed: Read value is 184803526
Test passed: Read value is 1424268980
Test passed: Read value is 749241873
Test passed: Read value is 42999170
Test passed: Read value is 13549728
For the InstMem module, we use rv32ui-p-add.hex (the add test from the riscv-tests suite), which stores one byte per line. Its first 16 bytes are:
6f
00
00
05
73
2f
20
34
93
0f
80
00
63
08
ff
03
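Since RISC-V is little-endian, each group of four bytes forms one instruction word with the first byte as the least significant. The first four instructions are therefore 0x0500006f, 0x34202f73, 0x00800f93, and 0x03ff0863.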
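The byte-to-word mapping can be checked in software. The sketch below (bytesToWords is a hypothetical helper, not part of InstMem) packs a byte-per-line image into 32-bit words in little-endian order; applied to the 16 bytes above, it yields exactly the four instructions used in the test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Packs a byte-per-entry image (as in rv32ui-p-add.hex) into 32-bit
// instruction words. RISC-V is little-endian: the byte at the lowest
// address becomes the least significant byte of the word.
std::vector<uint32_t> bytesToWords(const std::vector<uint8_t>& bytes) {
    std::vector<uint32_t> words;
    for (std::size_t i = 0; i + 3 < bytes.size(); i += 4) {
        words.push_back(static_cast<uint32_t>(bytes[i]) |
                        static_cast<uint32_t>(bytes[i + 1]) << 8 |
                        static_cast<uint32_t>(bytes[i + 2]) << 16 |
                        static_cast<uint32_t>(bytes[i + 3]) << 24);
    }
    return words;  // a trailing partial word (fewer than 4 bytes) is dropped
}
```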
Using a Seq of (address, expected instruction) pairs, we can generate the test cases for the module:
val testCases = Seq(
(0.U, "h0500006f".U), // Address 0, Expect 0x0500006f
(4.U, "h34202f73".U), // Address 4, Expect 0x34202f73
(8.U, "h00800f93".U), // Address 8, Expect 0x00800f93
(12.U, "h03ff0863".U) // Address 12, Expect 0x03ff0863
)
Then we test the module with poke and expect:
for ((addr, expectedData) <- testCases) {
c.io.addr.poke(addr)
c.clock.step(1)
println(s"Address: ${addr.litValue}, Expected: 0x${expectedData.litValue.toString(16)}, Actual: 0x${c.io.data.peek().litValue.toString(16)}")
c.io.data.expect(expectedData)
}
Chisel string literals use h as the prefix for hexadecimal values (e.g. "h0500006f".U) instead of the 0x used in C/C++.
Ref : Chisel Cookbook (Chisel Data Types)
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class InstMem_Tester extends AnyFlatSpec with ChiselScalatestTester {
behavior of "InstMem"
it should "Read data from file" in {
test(new InstMem("./src/riscv/rv32ui-p-add.hex")) { c =>
val testCases = Seq(
(0.U, "h0500006f".U), // Address 0, Expect 0x0500006f
(4.U, "h34202f73".U), // Address 4, Expect 0x34202f73
(8.U, "h00800f93".U), // Address 8, Expect 0x00800f93
(12.U, "h03ff0863".U) // Address 12, Expect 0x03ff0863
)
for ((addr, expectedData) <- testCases) {
c.io.addr.poke(addr)
c.clock.step(1)
c.io.data.expect(expectedData)
}
}
}
}
To test the pipeline registers, we drive their inputs, step the clock, and check that the outputs hold those values one cycle later.
The ChiselTest procedure is the same for every pipeline register, so to keep the article short, the complete code is given without further explanation.
For the IF_ID module, we also check the initial values of the registers.
c.io.pc_out.expect(0.S)
c.io.pc4_out.expect(0.U)
c.io.SelectedPC_out.expect(0.S)
c.io.SelectedInstr_out.expect(0.U)
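However, the registers already read as zero at the start of the test (ChiselTest resets the DUT before the test body runs), so we are not sure whether this test really verifies the initialization. A stronger check would first load non-zero values, then assert reset for a cycle and confirm that the outputs return to zero.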
The complete code:
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class IF_ID_Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "IF_ID"
it should "Initialize registers with default values" in {
test(new IF_ID) { c =>
c.io.pc_out.expect(0.S)
c.io.pc4_out.expect(0.U)
c.io.SelectedPC_out.expect(0.S)
c.io.SelectedInstr_out.expect(0.U)
}
}
it should "Pass inputs to outputs after one clock cycle" in {
test(new IF_ID) { c =>
val testPcIn = 42.S
val testPc4In = 46.U
val testSelectedPC = 100.S
val testSelectedInstr = "h12345678".U
c.io.pc_in.poke(testPcIn)
c.io.pc4_in.poke(testPc4In)
c.io.SelectedPC.poke(testSelectedPC)
c.io.SelectedInstr.poke(testSelectedInstr)
c.clock.step(1)
c.io.pc_out.expect(testPcIn)
c.io.pc4_out.expect(testPc4In)
c.io.SelectedPC_out.expect(testSelectedPC)
c.io.SelectedInstr_out.expect(testSelectedInstr)
}
}
}
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class ID_EX_Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "ID_EX"
it should "Pass inputs to outputs after one clock cycle" in {
test(new ID_EX) { c =>
// Initialize inputs
c.io.rs1_in.poke(1.U)
c.io.rs2_in.poke(2.U)
c.io.rs1_data_in.poke(10.S)
c.io.rs2_data_in.poke(20.S)
c.io.imm.poke(100.S)
c.io.rd_in.poke(5.U)
c.io.func3_in.poke(3.U)
c.io.func7_in.poke(true.B)
c.io.ctrl_MemWr_in.poke(true.B)
c.io.ctrl_Branch_in.poke(false.B)
c.io.ctrl_MemRd_in.poke(true.B)
c.io.ctrl_Reg_W_in.poke(true.B)
c.io.ctrl_MemToReg_in.poke(false.B)
c.io.ctrl_AluOp_in.poke(2.U)
c.io.ctrl_OpA_in.poke(1.U)
c.io.ctrl_OpB_in.poke(true.B)
c.io.ctrl_nextpc_in.poke(1.U)
c.io.IFID_pc4_in.poke(4.U)
// Clock step to register inputs
c.clock.step(1)
// Validate outputs
c.io.rs1_out.expect(1.U)
c.io.rs2_out.expect(2.U)
c.io.rs1_data_out.expect(10.S)
c.io.rs2_data_out.expect(20.S)
c.io.imm_out.expect(100.S)
c.io.rd_out.expect(5.U)
c.io.func3_out.expect(3.U)
c.io.func7_out.expect(true.B)
c.io.ctrl_MemWr_out.expect(true.B)
c.io.ctrl_Branch_out.expect(false.B)
c.io.ctrl_MemRd_out.expect(true.B)
c.io.ctrl_Reg_W_out.expect(true.B)
c.io.ctrl_MemToReg_out.expect(false.B)
c.io.ctrl_AluOp_out.expect(2.U)
c.io.ctrl_OpA_out.expect(1.U)
c.io.ctrl_OpB_out.expect(true.B)
c.io.ctrl_nextpc_out.expect(1.U)
c.io.IFID_pc4_out.expect(4.U)
}
}
}
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class EX_MEM_Test extends AnyFlatSpec with ChiselScalatestTester {
behavior of "EX_MEM"
it should "Pass inputs to outputs after one clock cycle" in {
test(new EX_MEM) { c =>
c.io.IDEX_MEMRD.poke(false.B)
c.io.IDEX_MEMWR.poke(false.B)
c.io.IDEX_MEMTOREG.poke(false.B)
c.io.IDEX_REG_W.poke(false.B)
c.io.IDEX_rs2.poke(0.S)
c.io.IDEX_rd.poke(0.U)
c.io.alu_out.poke(0.S)
c.clock.step(1)
c.io.EXMEM_memRd_out.expect(false.B)
c.io.EXMEM_memWr_out.expect(false.B)
c.io.EXMEM_memToReg_out.expect(false.B)
c.io.EXMEM_reg_w_out.expect(false.B)
c.io.EXMEM_rs2_out.expect(0.S)
c.io.EXMEM_rd_out.expect(0.U)
c.io.EXMEM_alu_out.expect(0.S)
c.io.IDEX_MEMRD.poke(true.B)
c.io.IDEX_MEMWR.poke(true.B)
c.io.IDEX_MEMTOREG.poke(true.B)
c.io.IDEX_REG_W.poke(true.B)
c.io.IDEX_rs2.poke(42.S)
c.io.IDEX_rd.poke(5.U)
c.io.alu_out.poke(123.S)
c.clock.step(1)
c.io.EXMEM_memRd_out.expect(true.B)
c.io.EXMEM_memWr_out.expect(true.B)
c.io.EXMEM_memToReg_out.expect(true.B)
c.io.EXMEM_reg_w_out.expect(true.B)
c.io.EXMEM_rs2_out.expect(42.S)
c.io.EXMEM_rd_out.expect(5.U)
c.io.EXMEM_alu_out.expect(123.S)
c.io.IDEX_MEMRD.poke(false.B)
c.io.IDEX_MEMWR.poke(false.B)
c.io.IDEX_MEMTOREG.poke(false.B)
c.io.IDEX_REG_W.poke(false.B)
c.io.IDEX_rs2.poke(-100.S)
c.io.IDEX_rd.poke(10.U)
c.io.alu_out.poke(-50.S)
c.clock.step(1)
c.io.EXMEM_memRd_out.expect(false.B)
c.io.EXMEM_memWr_out.expect(false.B)
c.io.EXMEM_memToReg_out.expect(false.B)
c.io.EXMEM_reg_w_out.expect(false.B)
c.io.EXMEM_rs2_out.expect(-100.S)
c.io.EXMEM_rd_out.expect(10.U)
c.io.EXMEM_alu_out.expect(-50.S)
}
}
}
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class MEM_WBTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "MEM_WB"
it should "Pass inputs to outputs after one clock cycle" in {
test(new MEM_WB) { c =>
c.io.EXMEM_MEMTOREG.poke(true.B)
c.io.EXMEM_REG_W.poke(true.B)
c.io.EXMEM_MEMRD.poke(false.B)
c.io.EXMEM_rd.poke(5.U)
c.io.in_dataMem_out.poke(123.S)
c.io.in_alu_out.poke(456.S)
c.clock.step(1)
c.io.MEMWB_memToReg_out.expect(true.B)
c.io.MEMWB_reg_w_out.expect(true.B)
c.io.MEMWB_memRd_out.expect(false.B)
c.io.MEMWB_rd_out.expect(5.U)
c.io.MEMWB_dataMem_out.expect(123.S)
c.io.MEMWB_alu_out.expect(456.S)
c.io.EXMEM_MEMTOREG.poke(false.B)
c.io.EXMEM_REG_W.poke(false.B)
c.io.EXMEM_MEMRD.poke(true.B)
c.io.EXMEM_rd.poke(10.U)
c.io.in_dataMem_out.poke(789.S)
c.io.in_alu_out.poke(1011.S)
c.clock.step(1)
c.io.MEMWB_memToReg_out.expect(false.B)
c.io.MEMWB_reg_w_out.expect(false.B)
c.io.MEMWB_memRd_out.expect(true.B)
c.io.MEMWB_rd_out.expect(10.U)
c.io.MEMWB_dataMem_out.expect(789.S)
c.io.MEMWB_alu_out.expect(1011.S)
}
}
}
For the ALU Control module, we test every aluOp value to ensure the output code is concatenated correctly.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class AluControlTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "AluControl"
it should "generate correct control signals" in {
test(new AluControl) { c =>
// Test R type
c.io.func3.poke("b010".U)
c.io.func7.poke(true.B)
c.io.aluOp.poke(0.U)
c.io.out.expect("b001010".U)
c.io.func3.poke("b000".U)
c.io.func7.poke(false.B)
c.io.aluOp.poke(0.U)
c.io.out.expect("b000000".U)
// Test I type
c.io.func3.poke("b101".U)
c.io.func7.poke(false.B)
c.io.aluOp.poke(1.U)
c.io.out.expect("b00101".U)
c.io.func3.poke("b011".U)
c.io.aluOp.poke(1.U)
c.io.out.expect("b00011".U)
// Test SB type
c.io.func3.poke("b110".U)
c.io.aluOp.poke(2.U)
c.io.out.expect("b010110".U)
c.io.func3.poke("b001".U)
c.io.aluOp.poke(2.U)
c.io.out.expect("b010001".U)
// Test JAL / JALR (aluOp = 3)
c.io.aluOp.poke(3.U)
c.io.out.expect("b11111".U)
// Test Loads, S type, U type (lui), U type (auipc)
c.io.aluOp.poke(4.U)
c.io.out.expect("b00000".U)
c.io.aluOp.poke(5.U)
c.io.out.expect("b00000".U)
c.io.aluOp.poke(6.U)
c.io.out.expect("b00000".U)
c.io.aluOp.poke(7.U)
c.io.out.expect("b00000".U)
}
}
}
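The expected codes in this test follow a simple concatenation pattern. The C++ sketch below reconstructs that pattern purely from the test vectors (aluControlModel is hypothetical, and the real AluControl source is not reproduced here): aluOp 0 yields {0, func7, func3}, aluOp 1 yields func3, aluOp 2 yields {010, func3}, aluOp 3 yields 11111, and every other aluOp yields 0. The tests never set func7 for aluOp 1, so the sketch ignores it there; whether the real module treats SRAI specially is not visible from the test.

```cpp
#include <cassert>
#include <cstdint>

// Reconstructed ONLY from the expected values in AluControlTest.
uint32_t aluControlModel(uint32_t aluOp, uint32_t func3, bool func7) {
    func3 &= 0x7;  // func3 is a 3-bit field
    switch (aluOp) {
        case 0: return (static_cast<uint32_t>(func7) << 3) | func3;  // R-type: {0, func7, func3}
        case 1: return func3;                                       // I-type: {00, func3}, func7 untested
        case 2: return (0b010u << 3) | func3;                       // SB-type: {010, func3}
        case 3: return 0b11111u;                                    // JAL / JALR
        default: return 0;                                          // loads, S, LUI, AUIPC
    }
}
```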
For the ALU module, we test all operations to ensure they produce correct results.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
import AluOpCode._
class ALUTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "ALU"
it should "perform correct operations" in {
test(new ALU) { c =>
// Test ALU_ADD and ALU_ADDI
c.io.in_A.poke(10.S)
c.io.in_B.poke(15.S)
c.io.alu_Op.poke(ALU_ADD)
c.io.out.expect(25.S)
c.io.alu_Op.poke(ALU_ADDI)
c.io.out.expect(25.S)
// Test ALU_SUB
c.io.alu_Op.poke(ALU_SUB)
c.io.out.expect(-5.S)
// Test ALU_SLL and ALU_SLLI
c.io.in_A.poke(3.S) // 3 = 0b11
c.io.in_B.poke(2.S) // Shift left by 2
c.io.alu_Op.poke(ALU_SLL)
c.io.out.expect(12.S) // 12 = 0b1100
c.io.alu_Op.poke(ALU_SLLI)
c.io.out.expect(12.S)
// Test ALU_SLT and ALU_SLTI
c.io.in_A.poke(5.S)
c.io.in_B.poke(10.S)
c.io.alu_Op.poke(ALU_SLT)
c.io.out.expect(1.S)
c.io.alu_Op.poke(ALU_SLTI)
c.io.out.expect(1.S)
// Test ALU_SLTU and ALU_SLTUI
c.io.in_A.poke(-1.S) // Treated as unsigned: max value
c.io.in_B.poke(0.S)
c.io.alu_Op.poke(ALU_SLTU)
c.io.out.expect(0.S)
c.io.alu_Op.poke(ALU_SLTUI)
c.io.out.expect(0.S)
// Test ALU_XOR and ALU_XORI
c.io.in_A.poke(6.S) // 6 = 0b110
c.io.in_B.poke(3.S) // 3 = 0b011
c.io.alu_Op.poke(ALU_XOR)
c.io.out.expect(5.S) // 5 = 0b101
c.io.alu_Op.poke(ALU_XORI)
c.io.out.expect(5.S)
// Test ALU_OR and ALU_ORI
c.io.alu_Op.poke(ALU_OR)
c.io.out.expect(7.S) // 7 = 0b111
c.io.alu_Op.poke(ALU_ORI)
c.io.out.expect(7.S)
// Test ALU_AND and ALU_ANDI
c.io.alu_Op.poke(ALU_AND)
c.io.out.expect(2.S) // 2 = 0b010
c.io.alu_Op.poke(ALU_ANDI)
c.io.out.expect(2.S)
// Test ALU_SRL and ALU_SRLI
c.io.in_A.poke(16.S) // 16 = 0b10000
c.io.in_B.poke(2.S)
c.io.alu_Op.poke(ALU_SRL)
c.io.out.expect(4.S) // 4 = 0b100
c.io.alu_Op.poke(ALU_SRLI)
c.io.out.expect(4.S)
// Test ALU_SRA and ALU_SRAI
c.io.in_A.poke(-16.S) // -16 = 0xFFFFFFF0 (32-bit two's complement)
c.io.alu_Op.poke(ALU_SRA)
c.io.out.expect(-4.S) // -4 = 0xFFFFFFFC (sign bits shifted in)
c.io.alu_Op.poke(ALU_SRAI)
c.io.out.expect(-4.S)
// Test ALU_JAL and ALU_JALR
c.io.in_A.poke(42.S)
c.io.alu_Op.poke(ALU_JAL)
c.io.out.expect(42.S)
c.io.alu_Op.poke(ALU_JALR)
c.io.out.expect(42.S)
}
}
}
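Shifts and comparisons are where signedness bugs usually hide, so it helps to have reference results computed independently of the Chisel code. The following C++ functions follow the RV32I specification (operands as 32-bit patterns, shift amounts taken from the low 5 bits of the second operand). They are a sketch for generating expected values; the test above uses only a shift amount of 2, so it does not show whether the project's ALU masks the shift amount.

```cpp
#include <cassert>
#include <cstdint>

// RV32I reference semantics; values are carried as 32-bit patterns.
uint32_t sll(uint32_t a, uint32_t b) { return a << (b & 31); }  // shamt = b[4:0]
uint32_t srl(uint32_t a, uint32_t b) { return a >> (b & 31); }  // logical: shifts in zeros

// Arithmetic right shift written without relying on >> of a negative
// signed value (implementation-defined before C++20).
uint32_t sra(uint32_t a, uint32_t b) {
    uint32_t sh = b & 31;
    if ((a & 0x80000000u) == 0) return a >> sh;
    return ~(~a >> sh);  // negative input: shift in ones from the left
}

// Signed and unsigned set-less-than.
uint32_t slt(uint32_t a, uint32_t b) {
    return static_cast<int32_t>(a) < static_cast<int32_t>(b) ? 1u : 0u;
}
uint32_t sltu(uint32_t a, uint32_t b) { return a < b ? 1u : 0u; }
```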
For the Branch module, we define a helper function that applies one test case and checks br_taken:
def testBranch(fnct3: Int, branch: Boolean, x: Int, y: Int, expected: Boolean): Unit = {
c.io.fnct3.poke(fnct3.U)
c.io.branch.poke(branch.B)
c.io.arg_x.poke(x.S)
c.io.arg_y.poke(y.S)
c.clock.step()
c.io.br_taken.expect(expected.B)
}
Then we test every branch condition:
// beq (fnct3 = 0)
testBranch(0, branch = true, x = 10, y = 10, expected = true) // Equal
testBranch(0, branch = true, x = 10, y = 5, expected = false) // Not equal
// bne (fnct3 = 1)
testBranch(1, branch = true, x = 10, y = 10, expected = false) // Equal
testBranch(1, branch = true, x = 10, y = 5, expected = true) // Not equal
// blt (fnct3 = 4)
testBranch(4, branch = true, x = 5, y = 10, expected = true) // Less than
testBranch(4, branch = true, x = 10, y = 5, expected = false) // Not less than
// bge (fnct3 = 5)
testBranch(5, branch = true, x = 10, y = 5, expected = true) // Greater than or equal
testBranch(5, branch = true, x = 5, y = 10, expected = false) // Not greater than or equal
// bltu (fnct3 = 6)
testBranch(6, branch = true, x = -1, y = 10, expected = false) // Unsigned: -1 is large
testBranch(6, branch = true, x = 5, y = 10, expected = true) // Unsigned less than
// bgeu (fnct3 = 7)
testBranch(7, branch = true, x = -1, y = 10, expected = true) // Unsigned: -1 is large
testBranch(7, branch = true, x = 5, y = 10, expected = false) // Unsigned not greater than or equal
// branch = false (should always be false)
testBranch(0, branch = false, x = 10, y = 10, expected = false)
testBranch(1, branch = false, x = 10, y = 5, expected = false)
testBranch(4, branch = false, x = 5, y = 10, expected = false)
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class BranchTester extends AnyFlatSpec with ChiselScalatestTester {
behavior of "Branch"
it should "correctly evaluate branch conditions" in {
test(new Branch) { c =>
// Function to test a single case
def testBranch(fnct3: Int, branch: Boolean, x: Int, y: Int, expected: Boolean): Unit = {
c.io.fnct3.poke(fnct3.U)
c.io.branch.poke(branch.B)
c.io.arg_x.poke(x.S)
c.io.arg_y.poke(y.S)
c.clock.step()
c.io.br_taken.expect(expected.B)
}
// beq (fnct3 = 0)
testBranch(0, branch = true, x = 10, y = 10, expected = true) // Equal
testBranch(0, branch = true, x = 10, y = 5, expected = false) // Not equal
// bne (fnct3 = 1)
testBranch(1, branch = true, x = 10, y = 10, expected = false) // Equal
testBranch(1, branch = true, x = 10, y = 5, expected = true) // Not equal
// blt (fnct3 = 4)
testBranch(4, branch = true, x = 5, y = 10, expected = true) // Less than
testBranch(4, branch = true, x = 10, y = 5, expected = false) // Not less than
// bge (fnct3 = 5)
testBranch(5, branch = true, x = 10, y = 5, expected = true) // Greater than or equal
testBranch(5, branch = true, x = 5, y = 10, expected = false) // Not greater than or equal
// bltu (fnct3 = 6)
testBranch(6, branch = true, x = -1, y = 10, expected = false) // Unsigned: -1 is large
testBranch(6, branch = true, x = 5, y = 10, expected = true) // Unsigned less than
// bgeu (fnct3 = 7)
testBranch(7, branch = true, x = -1, y = 10, expected = true) // Unsigned: -1 is large
testBranch(7, branch = true, x = 5, y = 10, expected = false) // Unsigned not greater than or equal
// branch = false (should always be false)
testBranch(0, branch = false, x = 10, y = 10, expected = false)
testBranch(1, branch = false, x = 10, y = 5, expected = false)
testBranch(4, branch = false, x = 5, y = 10, expected = false)
}
}
}
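A C++ reference for the same table can generate further test vectors. The sketch uses the RV32I funct3 encoding for the six branch instructions; branchTaken is a hypothetical name, and returning false for funct3 = 2 or 3 (which are not branch encodings) is this sketch's choice, not something the test above checks.

```cpp
#include <cassert>
#include <cstdint>

// Reference for br_taken: branch = false always means "not taken".
bool branchTaken(uint32_t funct3, bool branch, int32_t x, int32_t y) {
    if (!branch) return false;
    uint32_t ux = static_cast<uint32_t>(x);  // same bits, unsigned view
    uint32_t uy = static_cast<uint32_t>(y);
    switch (funct3 & 0x7) {
        case 0: return x == y;    // beq
        case 1: return x != y;    // bne
        case 4: return x < y;     // blt  (signed)
        case 5: return x >= y;    // bge  (signed)
        case 6: return ux < uy;   // bltu (unsigned)
        case 7: return ux >= uy;  // bgeu (unsigned)
        default: return false;    // funct3 = 2, 3: not branch encodings
    }
}
```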
For the Control module, we test each opcode with a helper that takes the opcode and a tuple of the expected control signals:
testOpcode(Opcode,(memWrite, branch, memRead, regWrite, menToReg, aluOp, opA, opB, ext, nextPcSel))
// R-type instruction (opcode 51)
testOpcode(51, (false.B, false.B, false.B, true.B, false.B, 0.U, 0.U, false.B, 0.U, 0.U))
// I-type instruction (opcode 19)
testOpcode(19, (false.B, false.B, false.B, true.B, false.B, 1.U, 0.U, true.B, 0.U, 0.U))
// S-type instruction (opcode 35)
testOpcode(35, (true.B, false.B, false.B, false.B, false.B, 5.U, 0.U, true.B, 1.U, 0.U))
// Load instruction (opcode 3)
testOpcode(3, (false.B, false.B, true.B, true.B, true.B, 4.U, 0.U, true.B, 0.U, 0.U))
// SB-type instruction (opcode 99)
testOpcode(99, (false.B, true.B, false.B, false.B, false.B, 2.U, 0.U, false.B, 0.U, 1.U))
// UJ-type instruction (opcode 111)
testOpcode(111, (false.B, false.B, false.B, true.B, false.B, 3.U, 1.U, false.B, 0.U, 2.U))
// Jalr instruction (opcode 103)
testOpcode(103, (false.B, false.B, false.B, true.B, false.B, 3.U, 1.U, false.B, 0.U, 3.U))
// U-type (LUI) instruction (opcode 55)
testOpcode(55, (false.B, false.B, false.B, true.B, false.B, 6.U, 3.U, true.B, 2.U, 0.U))
// U-type (AUIPC) instruction (opcode 23)
testOpcode(23, (false.B, false.B, false.B, true.B, false.B, 7.U, 2.U, true.B, 2.U, 0.U))
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class ControlTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "Control"
it should "generate correct signals for each opcode" in {
test(new Control) { c =>
// Function to test opcodes
def testOpcode(opcode: Int, expectedSignals: (Bool, Bool, Bool, Bool, Bool, UInt, UInt, Bool, UInt, UInt)) = {
c.io.opcode.poke(opcode.U)
c.clock.step(1)
val (memWrite, branch, memRead, regWrite, menToReg, aluOp, opA, opB, ext, nextPcSel) = expectedSignals
c.io.mem_write.expect(memWrite)
c.io.branch.expect(branch)
c.io.mem_read.expect(memRead)
c.io.reg_write.expect(regWrite)
c.io.men_to_reg.expect(menToReg)
c.io.alu_operation.expect(aluOp)
c.io.operand_A.expect(opA)
c.io.operand_B.expect(opB)
c.io.extend.expect(ext)
c.io.next_pc_sel.expect(nextPcSel)
}
// R-type instruction (opcode 51)
testOpcode(51, (false.B, false.B, false.B, true.B, false.B, 0.U, 0.U, false.B, 0.U, 0.U))
// I-type instruction (opcode 19)
testOpcode(19, (false.B, false.B, false.B, true.B, false.B, 1.U, 0.U, true.B, 0.U, 0.U))
// S-type instruction (opcode 35)
testOpcode(35, (true.B, false.B, false.B, false.B, false.B, 5.U, 0.U, true.B, 1.U, 0.U))
// Load instruction (opcode 3)
testOpcode(3, (false.B, false.B, true.B, true.B, true.B, 4.U, 0.U, true.B, 0.U, 0.U))
// SB-type instruction (opcode 99)
testOpcode(99, (false.B, true.B, false.B, false.B, false.B, 2.U, 0.U, false.B, 0.U, 1.U))
// UJ-type instruction (opcode 111)
testOpcode(111, (false.B, false.B, false.B, true.B, false.B, 3.U, 1.U, false.B, 0.U, 2.U))
// Jalr instruction (opcode 103)
testOpcode(103, (false.B, false.B, false.B, true.B, false.B, 3.U, 1.U, false.B, 0.U, 3.U))
// U-type (LUI) instruction (opcode 55)
testOpcode(55, (false.B, false.B, false.B, true.B, false.B, 6.U, 3.U, true.B, 2.U, 0.U))
// U-type (AUIPC) instruction (opcode 23)
testOpcode(23, (false.B, false.B, false.B, true.B, false.B, 7.U, 2.U, true.B, 2.U, 0.U))
}
}
}
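The decimal opcodes above are the RV32I base opcodes (bits [6:0] of the instruction). The sketch below names them and also encodes which instructions read rs1 and rs2, which is the same condition Main.scala uses to gate RegFile.io.rs1 and RegFile.io.rs2; the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// RV32I base opcodes, in the same decimal form the tests use.
enum Opcode : uint32_t {
    kLoad = 3, kOpImm = 19, kAuipc = 23, kStore = 35, kOp = 51,
    kLui = 55, kBranch = 99, kJalr = 103, kJal = 111
};

// Opcode field of an instruction word: bits [6:0].
uint32_t opcodeOf(uint32_t instr) { return instr & 0x7F; }

// Mirrors the Mux conditions on RegFile.io.rs1 / RegFile.io.rs2 in Main.scala.
bool readsRs1(uint32_t opcode) {
    switch (opcode) {
        case kOp: case kOpImm: case kStore: case kLoad: case kBranch: case kJalr:
            return true;
        default:
            return false;  // LUI, AUIPC, JAL have no rs1
    }
}
bool readsRs2(uint32_t opcode) {
    return opcode == kOp || opcode == kStore || opcode == kBranch;
}
```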
For the ImmGenerator module, we feed in instructions of each format and examine the generated immediates.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class ImmGeneratorTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "ImmGenerator"
it should "generate correct immediate values for all types" in {
test(new ImmGenerator) { c =>
// Helper function to sign-extend a value
def signExtend(value: BigInt, bits: Int): BigInt = {
val shift = 32 - bits
(value << shift) >> shift
}
// Test Case 1: I-type immediate
c.io.instr.poke("h00000093".U) // ADDI x1, x0, 0 -> Immediate: 0x000
c.clock.step(1)
c.io.I_type.expect(0.S)
c.io.instr.poke("hFFF00093".U) // ADDI x1, x0, -1 -> Immediate: 0xFFF
c.clock.step(1)
c.io.I_type.expect(-1.S)
// Test Case 2: S-type immediate
c.io.instr.poke("h00F02023".U) // SW x15, 0(x0) -> Immediate: 0x000
c.clock.step(1)
c.io.S_type.expect(0.S)
c.io.instr.poke("hF8002023".U) // SW x0, -128(x0) -> Immediate: 0xF80
c.clock.step(1)
c.io.S_type.expect(-128.S)
// Test Case 3: SB-type immediate
c.io.instr.poke("h00008063".U) // BEQ x1, x0, 0 -> Immediate: 0x000
c.io.pc.poke(0.U)
c.clock.step(1)
c.io.SB_type.expect(0.S)
c.io.instr.poke("hFE008EE3".U) // BEQ x1, x0, -4 -> Immediate: -4, so SB_type = PC + imm = 8 - 4 = 4
c.io.pc.poke(8.U)
c.clock.step(1)
c.io.SB_type.expect(4.S)
// Test Case 4: U-type immediate
c.io.instr.poke("h000000B7".U) // LUI x1, 0 -> Immediate: 0x00000000
c.clock.step(1)
c.io.U_type.expect(0.S)
// Test Case 5: UJ-type immediate
c.io.instr.poke("h0000006F".U) // JAL x0, 0 -> Immediate: 0x00000000
c.io.pc.poke(0.U)
c.clock.step(1)
c.io.UJ_type.expect(0.S)
c.io.instr.poke("hFF00006F".U) // JAL x0, -1046544 -> Immediate: 0xFFF007F0, so UJ_type = 16 - 1046544
c.io.pc.poke(16.U)
c.clock.step(1)
c.io.UJ_type.expect(-1046528.S)
}
}
}
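The instruction comments in this test are easy to get wrong, so it can help to decode the same words with a small C++ reference model and compare. This is a sketch based on the RV32I immediate layouts; the helper names are ours, and, following the values the test expects, we assume SB_type and UJ_type already include the current PC.

```cpp
#include <cassert>
#include <cstdint>

// Extract bits [hi:lo] of an instruction word (field width must be < 32).
static uint32_t field(uint32_t inst, int hi, int lo) {
    assert(hi - lo + 1 < 32);
    return (inst >> lo) & ((1u << (hi - lo + 1)) - 1u);
}

// Sign-extend the low `width` bits of `value` (width < 32).
static int32_t sext(uint32_t value, int width) {
    const uint32_t sign = 1u << (width - 1);
    return (int32_t)((value ^ sign) - sign);
}

int32_t imm_i(uint32_t inst) { return sext(field(inst, 31, 20), 12); }

int32_t imm_s(uint32_t inst) {
    return sext((field(inst, 31, 25) << 5) | field(inst, 11, 7), 12);
}

// B-type: imm[12|10:5] = inst[31:25], imm[4:1|11] = inst[11:7], imm[0] = 0.
int32_t imm_b(uint32_t inst) {
    const uint32_t v = (field(inst, 31, 31) << 12) | (field(inst, 7, 7) << 11) |
                       (field(inst, 30, 25) << 5) | (field(inst, 11, 8) << 1);
    return sext(v, 13);
}

int32_t imm_u(uint32_t inst) { return (int32_t)(inst & 0xFFFFF000u); }

// J-type: imm[20|10:1|11|19:12] = inst[31:12], imm[0] = 0.
int32_t imm_j(uint32_t inst) {
    const uint32_t v = (field(inst, 31, 31) << 20) | (field(inst, 19, 12) << 12) |
                       (field(inst, 20, 20) << 11) | (field(inst, 30, 21) << 1);
    return sext(v, 21);
}

// SB_type / UJ_type as the Chisel test expects them: offset + PC (mod 2^32).
int32_t sb_type(uint32_t inst, uint32_t pc) { return (int32_t)((uint32_t)imm_b(inst) + pc); }
int32_t uj_type(uint32_t inst, uint32_t pc) { return (int32_t)((uint32_t)imm_j(inst) + pc); }
```

For example, `0xFF00006F` decodes to a J-type offset of -1046544 rather than -16, and adding PC 16 reproduces the `-1046528.S` that the test expects.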
For the JALR module, we test different inputs and check the output.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class JalrTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "Jalr"
it should "compute the correct jump address with alignment" in {
test(new Jalr) { c =>
// Test case 1: imme = 0, rdata1 = 0
c.io.imme.poke(0.U)
c.io.rdata1.poke(0.U)
c.clock.step(1)
c.io.out.expect(0.U)
// Test case 2: imme = 4, rdata1 = 8
c.io.imme.poke(4.U)
c.io.rdata1.poke(8.U)
c.clock.step(1)
c.io.out.expect(12.U) // 4 + 8 = 12 (no masking required)
// Test case 3: imme = 5, rdata1 = 10 (unaligned address)
c.io.imme.poke(5.U)
c.io.rdata1.poke(10.U)
c.clock.step(1)
c.io.out.expect(14.U) // (5 + 10 = 15, aligned to 14)
// Test case 4: imme = 0xFFFFFFFE, rdata1 = 1 (boundary test)
c.io.imme.poke("hFFFFFFFE".U)
c.io.rdata1.poke(1.U)
c.clock.step(1)
c.io.out.expect("hFFFFFFFE".U) // 0xFFFFFFFE + 1 = 0xFFFFFFFF, aligned to 0xFFFFFFFE
// Test case 5: imme = 0x1234, rdata1 = 0x5678
c.io.imme.poke("h1234".U)
c.io.rdata1.poke("h5678".U)
c.clock.step(1)
c.io.out.expect("h68AC".U) // 0x1234 + 0x5678 = 0x68AC (already aligned)
}
}
}
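The expected values in this test follow the RISC-V rule for the JALR target: add rs1 and the immediate, then clear bit 0. Below is a minimal C++ sketch of that rule for cross-checking. `jalr_target` is a hypothetical helper, and it assumes the Jalr module adds `imme` and `rdata1` as plain 32-bit values.

```cpp
#include <cassert>
#include <cstdint>

// JALR target: (rs1 + imm) with the least-significant bit cleared.
// Unsigned 32-bit arithmetic wraps modulo 2^32, like the hardware adder.
uint32_t jalr_target(uint32_t rs1, uint32_t imm) {
    return (rs1 + imm) & ~1u;
}
```

The wraparound case (`rs1 = 2`, `imm = 0xFFFFFFFF`, sum `0x1_00000001`) is a useful extra boundary check beyond the five cases above.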
For the PC4 module, we set the program counter to 0, 4, and 100, and the output should be pc + 4.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class PC4Test extends AnyFlatSpec with ChiselScalatestTester {
"PC4" should "correctly compute PC + 4" in {
test(new PC4) { c =>
// Test case 1: Input PC = 0
c.io.pc.poke(0.U)
c.clock.step(1)
c.io.out.expect(4.U)
// Test case 2: Input PC = 4
c.io.pc.poke(4.U)
c.clock.step(1)
c.io.out.expect(8.U)
// Test case 3: Input PC = 100
c.io.pc.poke(100.U)
c.clock.step(1)
c.io.out.expect(104.U)
}
}
}
For the PC module, we update the program counter to 4, 100, and -8 to check its behavior, and also check that the program counter holds its value when the input stays the same.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class PCTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "PC"
it should "update and hold the correct value" in {
test(new PC) { c =>
// Initial state: PC register should initialize to 0
c.io.out.expect(0.S)
// Test case 1: Update PC to 4
c.io.in.poke(4.S)
c.clock.step(1) // Advance one clock cycle
c.io.out.expect(4.S)
// Test case 2: Update PC to 100
c.io.in.poke(100.S)
c.clock.step(1)
c.io.out.expect(100.S)
// Test case 3: Update PC to -8
c.io.in.poke(-8.S)
c.clock.step(1)
c.io.out.expect(-8.S)
// Test case 4: Hold PC value (no change to input)
c.io.in.poke(-8.S)
c.clock.step(1)
c.io.out.expect(-8.S)
}
}
}
For the RegisterFile module, we test the write and read functions and, most importantly, ensure that x0 is always zero.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class RegisterFileTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "RegisterFile"
it should "initialize and perform read/write correctly" in {
test(new RegisterFile) { c =>
// initialize
for (i <- 0 until 32) {
c.io.rs1.poke(i.U)
c.io.rs2.poke(i.U)
c.io.reg_write.poke(false.B)
c.io.w_reg.poke(0.U)
c.io.w_data.poke(0.S)
c.clock.step(1)
// check the initialize data
c.io.rdata1.expect(0.S)
c.io.rdata2.expect(0.S)
}
// write and read
c.io.reg_write.poke(true.B)
c.io.w_reg.poke(5.U)
c.io.w_data.poke(42.S)
c.clock.step(1)
// check
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.reg_write.poke(false.B)
c.clock.step(1)
c.io.rdata1.expect(42.S)
c.io.rdata2.expect(42.S)
// check write x0
c.io.reg_write.poke(true.B)
c.io.w_reg.poke(0.U)
c.io.w_data.poke(123.S)
c.clock.step(1)
// check x0
c.io.rs1.poke(0.U)
c.io.rs2.poke(0.U)
c.io.reg_write.poke(false.B)
c.clock.step(1)
c.io.rdata1.expect(0.S)
c.io.rdata2.expect(0.S)
}
}
}
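The contract this test checks (registers start at zero, writes are committed at the clock edge, and writes to x0 are ignored) can be written as a small behavioral model in C++. This is a sketch: `RegFileModel` is a hypothetical name, and we assume a synchronous write with a combinational read.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Behavioral model of a 32 x 32-bit register file with x0 hard-wired to zero.
class RegFileModel {
public:
    // Combinational read: x0 always reads as 0.
    int32_t read(unsigned idx) const {
        assert(idx < 32);
        return idx == 0 ? 0 : regs_[idx];
    }
    // One rising clock edge: commit the write if enabled and rd is not x0.
    void tick(bool reg_write, unsigned rd, int32_t data) {
        assert(rd < 32);
        if (reg_write && rd != 0) regs_[rd] = data;
    }
private:
    std::array<int32_t, 32> regs_{};  // value-initialized to 0
};
```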
For the BranchForward module, we test ALU hazards, EX/MEM hazards, MEM/WB hazards, and JALR forwarding.
add x5, x1, x2
beq x5, x5, label
Here, we set the rd of the ID_EX register to x5 (c.io.ID_EX_RD.poke(5.U)) and set both rs1 and rs2 to x5.
it should "handle ALU hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(5.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
// ALU hazard forwarding
c.clock.step(1)
c.io.forward_rs1.expect("b0001".U)
c.io.forward_rs2.expect("b0001".U)
}
}
Here, we set the rd of the EX_MEM register to x5 (c.io.EX_MEM_RD.poke(5.U)) and set both rs1 and rs2 to x5.
it should "handle EX/MEM hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(0.U)
c.io.EX_MEM_RD.poke(5.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0010".U)
c.io.forward_rs2.expect("b0010".U)
}
}
Here, we set the rd of the MEM_WB register to x5 (c.io.MEM_WB_RD.poke(5.U)) and set both rs1 and rs2 to x5.
it should "handle MEM/WB hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(0.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(5.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0011".U)
c.io.forward_rs2.expect("b0011".U)
}
}
add x5, x1, x2
jalr x0, x5, 0
Here, we set ctrl_branch to 0, which means the instruction is a JALR.
it should "handle Jalr forwarding logic correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(5.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.ctrl_branch.poke(0.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0110".U)
}
}
The complete test is:
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class BranchForwardTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "BranchForward"
it should "handle ALU hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(5.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
// ALU hazard forwarding
c.clock.step(1)
c.io.forward_rs1.expect("b0001".U)
c.io.forward_rs2.expect("b0001".U)
}
}
it should "handle EX/MEM hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(0.U)
c.io.EX_MEM_RD.poke(5.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0010".U)
c.io.forward_rs2.expect("b0010".U)
}
}
it should "handle MEM/WB hazards correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(0.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(5.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.rs2.poke(5.U)
c.io.ctrl_branch.poke(1.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0011".U)
c.io.forward_rs2.expect("b0011".U)
}
}
it should "handle Jalr forwarding logic correctly" in {
test(new BranchForward) { c =>
c.io.ID_EX_RD.poke(5.U)
c.io.EX_MEM_RD.poke(0.U)
c.io.MEM_WB_RD.poke(0.U)
c.io.ID_EX_memRd.poke(0.U)
c.io.EX_MEM_memRd.poke(0.U)
c.io.MEM_WB_memRd.poke(0.U)
c.io.rs1.poke(5.U)
c.io.ctrl_branch.poke(0.U)
c.clock.step(1)
c.io.forward_rs1.expect("b0110".U)
}
}
}
For the remaining hazard units, we applied the same testing methodology, so only the code is shown here.
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class ForwardingTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "Forwarding"
it should "handle EX Hazard correctly" in {
test(new Forwarding) { c =>
c.io.IDEX_rs1.poke(5.U)
c.io.IDEX_rs2.poke(5.U)
c.io.EXMEM_rd.poke(5.U)
c.io.EXMEM_regWr.poke(1.U)
c.io.MEMWB_rd.poke(0.U)
c.io.MEMWB_regWr.poke(0.U)
c.clock.step(1)
c.io.forward_a.expect("b10".U)
c.io.forward_b.expect("b10".U)
c.io.IDEX_rs1.poke(5.U)
c.io.IDEX_rs2.poke(6.U)
c.io.EXMEM_rd.poke(5.U)
c.io.EXMEM_regWr.poke(1.U)
c.io.MEMWB_rd.poke(0.U)
c.io.MEMWB_regWr.poke(0.U)
c.clock.step(1)
c.io.forward_a.expect("b10".U)
c.io.forward_b.expect("b00".U)
c.io.IDEX_rs1.poke(6.U)
c.io.IDEX_rs2.poke(5.U)
c.io.EXMEM_rd.poke(5.U)
c.io.EXMEM_regWr.poke(1.U)
c.io.MEMWB_rd.poke(0.U)
c.io.MEMWB_regWr.poke(0.U)
c.clock.step(1)
c.io.forward_a.expect("b00".U)
c.io.forward_b.expect("b10".U)
}
}
it should "handle MEM Hazard correctly" in {
test(new Forwarding) { c =>
c.io.IDEX_rs1.poke(5.U)
c.io.IDEX_rs2.poke(5.U)
c.io.EXMEM_rd.poke(0.U)
c.io.EXMEM_regWr.poke(0.U)
c.io.MEMWB_rd.poke(5.U)
c.io.MEMWB_regWr.poke(1.U)
c.clock.step(1)
c.io.forward_a.expect("b01".U)
c.io.forward_b.expect("b01".U)
// Case 2: MEM Hazard for rs1 only
c.io.IDEX_rs1.poke(5.U)
c.io.IDEX_rs2.poke(6.U)
c.io.EXMEM_rd.poke(0.U)
c.io.EXMEM_regWr.poke(0.U)
c.io.MEMWB_rd.poke(5.U)
c.io.MEMWB_regWr.poke(1.U)
c.clock.step(1)
c.io.forward_a.expect("b01".U)
c.io.forward_b.expect("b00".U)
// Case 3: MEM Hazard for rs2 only
c.io.IDEX_rs1.poke(6.U)
c.io.IDEX_rs2.poke(5.U)
c.io.EXMEM_rd.poke(0.U)
c.io.EXMEM_regWr.poke(0.U)
c.io.MEMWB_rd.poke(5.U)
c.io.MEMWB_regWr.poke(1.U)
c.clock.step(1)
c.io.forward_a.expect("b00".U)
c.io.forward_b.expect("b01".U)
}
}
it should "handle no hazards correctly" in {
test(new Forwarding) { c =>
c.io.IDEX_rs1.poke(5.U)
c.io.IDEX_rs2.poke(6.U)
c.io.EXMEM_rd.poke(0.U)
c.io.EXMEM_regWr.poke(0.U)
c.io.MEMWB_rd.poke(0.U)
c.io.MEMWB_regWr.poke(0.U)
c.clock.step(1)
c.io.forward_a.expect("b00".U)
c.io.forward_b.expect("b00".U)
}
}
}
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
class HazardDetectionTest extends AnyFlatSpec with ChiselScalatestTester {
"HazardDetection" should "detect hazards correctly" in {
test(new HazardDetection) { c =>
// Case 1: Hazard detected (Rs1 matches ID_EX_rd)
c.io.IF_ID_inst.poke("h00028033".U) // add x0, x5, x0 -> rs1 = x5, rs2 = x0
c.io.ID_EX_memRead.poke(true.B)
c.io.ID_EX_rd.poke(5.U)
c.io.pc_in.poke(100.S)
c.io.current_pc.poke(96.S)
c.clock.step(1)
c.io.inst_forward.expect(true.B)
c.io.pc_forward.expect(true.B)
c.io.ctrl_forward.expect(true.B)
c.io.inst_out.expect("h00028033".U)
c.io.pc_out.expect(100.S)
c.io.current_pc_out.expect(96.S)
// Case 2: Hazard detected (rs2 matches ID_EX_rd)
c.io.IF_ID_inst.poke("h00500033".U) // add x0, x0, x5 -> rs1 = x0, rs2 = x5
c.io.ID_EX_memRead.poke(true.B)
c.io.ID_EX_rd.poke(5.U)
c.io.pc_in.poke(104.S)
c.io.current_pc.poke(100.S)
c.clock.step(1)
c.io.inst_forward.expect(true.B)
c.io.pc_forward.expect(true.B)
c.io.ctrl_forward.expect(true.B)
c.io.inst_out.expect("h00500033".U)
c.io.pc_out.expect(104.S)
c.io.current_pc_out.expect(100.S)
}
}
}
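The conditions these two units are tested against can be summarized in C++. The sketch below uses the textbook (Patterson and Hennessy) forwarding and load-use conditions with the same `0b10` / `0b01` / `0b00` encoding the Forwarding tests expect. We assume, without having checked it against the Chisel source, that the modules follow these conditions. All function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Forwarding selector: 0b10 = from EX/MEM, 0b01 = from MEM/WB, 0b00 = register file.
// EX/MEM is checked first because it holds the most recent value.
uint8_t forward_sel(unsigned idex_rs, bool exmem_regwr, unsigned exmem_rd,
                    bool memwb_regwr, unsigned memwb_rd) {
    if (exmem_regwr && exmem_rd != 0 && exmem_rd == idex_rs) return 0b10;
    if (memwb_regwr && memwb_rd != 0 && memwb_rd == idex_rs) return 0b01;
    return 0b00;
}

// Register-source fields of an RV32I instruction word.
unsigned rs1_of(uint32_t inst) { return (inst >> 15) & 0x1Fu; }
unsigned rs2_of(uint32_t inst) { return (inst >> 20) & 0x1Fu; }

// Load-use hazard: a load in ID/EX writes a register that the instruction
// in IF/ID reads, so the pipeline must stall for one cycle.
bool load_use_stall(bool idex_mem_read, unsigned idex_rd, uint32_t ifid_inst) {
    return idex_mem_read &&
           (idex_rd == rs1_of(ifid_inst) || idex_rd == rs2_of(ifid_inst));
}
```

Decoding the test words this way shows that both `0x00028033` (add x0, x5, x0) and `0x0002A033` (slt x0, x5, x0) have rs1 = x5 and rs2 = x0. To exercise the rs2 path, an instruction such as `0x00500033` (add x0, x0, x5) is needed. The case where EX/MEM and MEM/WB both match is not covered by the tests above, and it is the one where the priority order matters.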
For the test of Main, we use riscv-tests as our test bench. riscv-tests contains many test suites; we chose rv32ui-p-* to test our 5-stage pipelined RISC-V CPU, where rv32ui stands for the RV32 user-level integer base instruction set and p denotes the test environment with virtual memory disabled.
For every test program, if the CPU passes all test cases, the global pointer (gp) is set to 1. Otherwise, gp is set to (N << 1) | 1, where N is the number of the failing test case, so the value is greater than 1.
Take rv32ui-p-add as an example:
0000018c <test_2>:
18c: 00200193 li gp,2
190: 00000593 li a1,0
194: 00000613 li a2,0
198: 00c58733 add a4,a1,a2
19c: 00000393 li t2,0
1a0: 4c771663 bne a4,t2,66c <fail>
000001a4 <test_3>:
1a4: 00300193 li gp,3
1a8: 00100593 li a1,1
1ac: 00100613 li a2,1
...
00000650 <test_38>:
650: 02600193 li gp,38
654: 01000093 li ra,16
658: 01e00113 li sp,30
65c: 00208033 add zero,ra,sp
660: 00000393 li t2,0
664: 00701463 bne zero,t2,66c <fail>
668: 02301063 bne zero,gp,688 <pass>
0000066c <fail>:
66c: 0ff0000f fence
670: 00018063 beqz gp,670 <fail+0x4>
674: 00119193 slli gp,gp,0x1
678: 0011e193 ori gp,gp,1
67c: 05d00893 li a7,93
680: 00018513 mv a0,gp
684: 00000073 ecall
00000688 <pass>:
688: 0ff0000f fence
68c: 00100193 li gp,1
690: 05d00893 li a7,93
694: 00000513 li a0,0
698: 00000073 ecall
69c: c0001073 unimp
6a0: 0000 .insn 2, 0x
If the test fails, the PC will jump to fail; otherwise, the program will complete all the tests and pass.
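Because the fail stub computes gp with slli and ori, the number of the failing test can be recovered from the final gp value. Here is a small C++ sketch of a decoder a Verilator harness could call once the program has reached ecall. `failing_test` is a hypothetical helper.

```cpp
#include <cassert>
#include <cstdint>

// Interpret the final gp (x3) value of a riscv-tests program:
//   <pass>: gp = 1
//   <fail>: gp = (N << 1) | 1, where N was loaded by `li gp,N`
// Only meaningful after `ecall`: while the tests run, gp holds the current N.
// Returns 0 for pass, N for a failure, or -1 if gp is not a final value.
int failing_test(uint32_t gp) {
    if (gp == 1) return 0;                       // <pass>
    if ((gp & 1u) != 0) return (int)(gp >> 1);   // <fail>
    return -1;                                   // not a final value
}
```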
To check gp, we added a debug output to the RegisterFile module (i.e., io.reg_debug1 := regfile(3), where 3 is the index of gp) and print it:
printf(p"============================= \n")
printf(p"pc : 0x${Hexadecimal(IF_ID_.io.SelectedPC)}\n")
printf(p"inst : 0x${Hexadecimal(IF_ID_.io.SelectedInstr)}\n")
printf(p"gp : 0x${Hexadecimal(RegFile.io.reg_debug)}\n")
In the test, we step the clock for 600 cycles so that the processor can finish all the instructions.
package Pipeline
import chisel3._
import org.scalatest.flatspec.AnyFlatSpec
import chiseltest._
class MainTest extends AnyFlatSpec with ChiselScalatestTester{
behavior of "5-Stage test"
it should "Go through" in {
test(new PIPELINE){c =>
c.clock.step(600)
}
}
}
We ran the rv32ui-p tests for the other instructions, and after each program finishes the gp register shows gp: 0x00000001, indicating that the test passed.
We did not pass the fence and jalr tests; the reasons are explained below.
The fence instruction serves as a memory barrier that enforces an ordering of loads and stores. Since we did not implement this functionality, we do not run rv32ui-p-fence.
For jalr, we checked the .dump file and noticed that t1 appears to be set to 0x00000010, while the target of test 2 is at 0x000001a4, so we first suspected that the test sets the jalr destination incorrectly. However, in the listing below, auipc t1,0x0 at 0x194 first sets t1 to the PC (0x194), and addi t1,t1,16 then yields 0x1a4, which is exactly target_2 (the dump annotates it as # 1a4 <target_2>). The test is therefore consistent. If our processor ends up with t1 = 0x10, the auipc result was most likely lost before the addi, which points to our own datapath (for example, forwarding of the auipc result) rather than to the test.
0000018c <test_2>:
18c: 00200193 li gp,2
190: 00000293 li t0,0
194: 00000317 auipc t1,0x0
198: 01030313 addi t1,t1,16 # 1a4 <target_2>
19c: 000302e7 jalr t0,t1
000001a0 <linkaddr_2>:
1a0: 0e00006f j 280 <fail>
000001a4 <target_2>:
1a4: 00000317 auipc t1,0x0
1a8: ffc30313 addi t1,t1,-4 # 1a0 <linkaddr_2>
1ac: 0c629a63 bne t0,t1,280 <fail>
We test the PIPELINE module with the following C++ Verilator harness. As mentioned above, gp (register index 3) is set to 1 once all test cases have passed. The following is the C++ code:
#include "verilated.h"
#include "VPIPELINE.h"
#include "verilated_vcd_c.h"
void tick(VPIPELINE* dut)
{
dut->clock = 0;
dut->eval();
dut->clock = 1;
dut->eval();
}
int main(int argc, char** argv, char** env)
{
Verilated::commandArgs(argc, argv);
Verilated::traceEverOn(true);
VPIPELINE* dut = new VPIPELINE;
dut->reset = 1;
for (int i = 0; i < 5; i++) {
tick(dut);
}
dut->reset = 0;
int cycle_count = 0;
for (size_t i = 0; i < 600; i++)
{
tick(dut);
cycle_count++;
if(dut->PIPELINE__DOT__RegFile__DOT__regfile_3 == 1)
{
printf("passed, cycle count : %d\n", cycle_count);
return 0;
}
}
printf("failed\n");
return 1; // non-zero exit status so scripts can detect the failure
}
and the execution result:
~/5-Stage-RV32I/generated$ ./obj_dir/VPIPELINE
passed, cycle count : 521
We wrote an FSM, a 2-bit saturating counter, that decides whether to predict a branch as taken. This module is placed in the instruction fetch stage and is updated once the actual branch outcome has been computed.
package Pipeline
import chisel3._
import chisel3.util._
import Branch_predict_state._
object Branch_predict_state{
val STRONG_TAKEN = 0.U(2.W)
val WEAK_TAKEN = 1.U(2.W)
val STRONG_NOT_TAKEN = 2.U(2.W)
val WEAK_NOT_TAKEN = 3.U(2.W)
def stateToString(state: UInt): String = {
state.litValue.toInt match {
case 0 => "STRONG_TAKEN"
case 1 => "WEAK_TAKEN"
case 2 => "STRONG_NOT_TAKEN"
case 3 => "WEAK_NOT_TAKEN"
case _ => "UNKNOWN"
}
}
}
class branch_predict extends Module{
val io = IO(new Bundle{
val taken = Input(Bool())
val branch_predict = Output(Bool())
})
val current_state = RegInit(STRONG_NOT_TAKEN)
val next_state = Wire(UInt(2.W))
next_state := current_state
when(current_state === STRONG_TAKEN){
next_state := Mux(io.taken, STRONG_TAKEN, WEAK_TAKEN)
}.elsewhen(current_state === WEAK_TAKEN){
next_state := Mux(io.taken, STRONG_TAKEN, WEAK_NOT_TAKEN)
}.elsewhen(current_state === STRONG_NOT_TAKEN){
next_state := Mux(io.taken, WEAK_NOT_TAKEN, STRONG_NOT_TAKEN)
}.elsewhen(current_state === WEAK_NOT_TAKEN){
next_state := Mux(io.taken, WEAK_TAKEN, STRONG_NOT_TAKEN)
}.otherwise{
next_state := current_state
}
current_state := next_state
printf(p"taken: ${io.taken}, current_state: ${current_state}, next_state: ${next_state}, predict: ${(next_state === STRONG_TAKEN) || (next_state === WEAK_TAKEN)}\n")
io.branch_predict := (next_state === STRONG_TAKEN) || (next_state === WEAK_TAKEN)
}
We then test it with:
package Pipeline
import chisel3._
import chiseltest._
import org.scalatest.flatspec.AnyFlatSpec
import Branch_predict_state._
class BranchPredictTest extends AnyFlatSpec with ChiselScalatestTester {
behavior of "branch_predict"
it should "correctly predict branch behavior" in {
test(new branch_predict) { c =>
val testCases = Seq(
// (taken, expected prediction); the predictor starts in STRONG_NOT_TAKEN
(true.B, false.B),
(true.B, true.B),
(false.B, false.B),
(true.B, true.B),
(true.B, true.B),
(true.B, true.B),
(false.B, true.B)
)
for ((taken, expectedPrediction) <- testCases) {
c.io.taken.poke(taken)
c.io.branch_predict.expect(expectedPrediction)
c.clock.step()
}
}
}
}
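To cross-check the expected predictions, the FSM can be mirrored in C++ with the same state encoding. This sketch uses our own names. Like the Chisel module, it derives the prediction from the next state and loads that state into the register at the clock edge.

```cpp
#include <cassert>
#include <cstdint>

// Same encoding as Branch_predict_state.
enum State : uint8_t {
    STRONG_TAKEN = 0, WEAK_TAKEN = 1, STRONG_NOT_TAKEN = 2, WEAK_NOT_TAKEN = 3
};

// Next-state function of the 2-bit saturating counter.
State next_state(State s, bool taken) {
    switch (s) {
        case STRONG_TAKEN:     return taken ? STRONG_TAKEN : WEAK_TAKEN;
        case WEAK_TAKEN:       return taken ? STRONG_TAKEN : WEAK_NOT_TAKEN;
        case STRONG_NOT_TAKEN: return taken ? WEAK_NOT_TAKEN : STRONG_NOT_TAKEN;
        case WEAK_NOT_TAKEN:   return taken ? WEAK_TAKEN : STRONG_NOT_TAKEN;
    }
    return s;  // unreachable for valid encodings
}

bool predicts_taken(State s) { return s == STRONG_TAKEN || s == WEAK_TAKEN; }

// One clock cycle: compute next_state from `taken`, output the prediction
// derived from it, and load it into the state register.
struct PredictorModel {
    State current = STRONG_NOT_TAKEN;  // RegInit(STRONG_NOT_TAKEN)
    bool step(bool taken) {
        const State n = next_state(current, taken);
        current = n;
        return predicts_taken(n);
    }
};
```

Because the prediction comes from the next state, it already depends on the same cycle's `taken` input. If the prediction has to be available at fetch, before the outcome is known, it would instead be read from the current state.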
The prediction must be compared with the actual outcome to determine whether the fetched instructions should be flushed. To support this in the 5-stage pipelined RISC-V processor, we rewrote the Branch module as well; its new interface is:
val io = IO(new Bundle {
val fnct3 = Input(UInt(3.W))
val branch = Input(Bool())
val arg_x = Input(SInt(32.W))
val arg_y = Input(SInt(32.W))
val pred = Input(Bool()) // predicted result of whether branch is taken or not
val actual = Output(Bool()) // actual result of whether branch is taken or not
val flush = Output(Bool())
})
The complete module is as follows:
package Pipeline
import chisel3._
import chisel3.util._
class Branch extends Module {
val io = IO(new Bundle {
val fnct3 = Input(UInt(3.W))
val branch = Input(Bool())
val arg_x = Input(SInt(32.W))
val arg_y = Input(SInt(32.W))
val pred = Input(Bool()) // predicted result of whether branch is taken or not
val actual = Output(Bool()) // actual result of whether branch is taken or not
val flush = Output(Bool())
})
io.actual := false.B
io.flush := false.B
val temp = WireDefault(false.B)
when(io.branch) {
// beq
when(io.fnct3 === 0.U) {
io.actual := io.arg_x === io.arg_y
temp := io.arg_x === io.arg_y
}
// bne
.elsewhen(io.fnct3 === 1.U) {
io.actual := io.arg_x =/= io.arg_y
temp := io.arg_x =/= io.arg_y
}
// blt
.elsewhen(io.fnct3 === 4.U) {
io.actual := io.arg_x < io.arg_y
temp := io.arg_x < io.arg_y
}
// bge
.elsewhen(io.fnct3 === 5.U) {
io.actual := io.arg_x >= io.arg_y
temp := io.arg_x >= io.arg_y
}
// bltu (unsigned less than)
.elsewhen(io.fnct3 === 6.U) {
io.actual := io.arg_x.asUInt < io.arg_y.asUInt
temp := io.arg_x.asUInt < io.arg_y.asUInt // unsigned compare, matching io.actual
}
// bgeu (unsigned greater than or equal)
.elsewhen(io.fnct3 === 7.U) {
io.actual := io.arg_x.asUInt >= io.arg_y.asUInt
temp := io.arg_x.asUInt >= io.arg_y.asUInt // unsigned compare, matching io.actual
}
io.flush := io.pred ^ temp
}
}
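A C++ reference model of the branch comparison and flush decision helps show why bltu and bgeu must compare raw bit patterns. With x = -1 and y = 1, blt is taken but bltu is not, so a flush signal derived from a signed comparison would be wrong for the unsigned branches. This is a sketch with hypothetical names, using the funct3 values from the module above.

```cpp
#include <cassert>
#include <cstdint>

// Outcome of an RV32I conditional branch.
// funct3: 0 beq, 1 bne, 4 blt, 5 bge, 6 bltu, 7 bgeu.
bool branch_taken(unsigned funct3, int32_t x, int32_t y) {
    const uint32_t ux = (uint32_t)x, uy = (uint32_t)y;
    switch (funct3) {
        case 0: return x == y;
        case 1: return x != y;
        case 4: return x < y;
        case 5: return x >= y;
        case 6: return ux < uy;   // unsigned: compare the raw bit patterns
        case 7: return ux >= uy;
        default: return false;    // funct3 2 and 3 are not branches
    }
}

// A flush is needed when the prediction disagrees with the actual outcome.
bool needs_flush(bool is_branch, bool predicted, unsigned funct3, int32_t x, int32_t y) {
    return is_branch && (predicted != branch_taken(funct3, x, y));
}
```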
Besides the predictor, we added two more modules to compute the desired program counter. The BTB module computes the predicted program counter, and the PCselector decides which program counter to take. These two modules are explained below.
BTB
The BTB module takes the current PC and the current instruction, calculates the branch target address, and decides whether the instruction is B-type.
package Pipeline
import chisel3._
import chisel3.util._
class BTB extends Module {
val io = IO(new Bundle {
val inst = Input(UInt(32.W))
val PC = Input(UInt(32.W))
val isBtype = Output(Bool())
val target = Output(UInt(32.W))
})
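// Compute the B-type immediate {inst[31], inst[7], inst[30:25], inst[11:8], 0}, sign-extended to 32 bits
val imm = Cat(Fill(20, io.inst(31)), io.inst(7), io.inst(30, 25), io.inst(11, 8), 0.U(1.W))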
// Compute target
io.target := Mux(io.inst(6, 0) === "b1100011".U, io.PC + imm.asUInt, io.PC + 4.U)
// Determine if instruction is B-type
io.isBtype := io.inst(6, 0) === "b1100011".U
}
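The B-type immediate is scattered across the instruction word, so a software reference for the BTB's target computation is useful when testing it. This C++ sketch follows the RV32I encoding, and the helper names are hypothetical. For example, `0xFE008EE3` (beq x1, x0, -4) fetched at PC 8 should give a predicted target of 4 when the branch is predicted taken.

```cpp
#include <cassert>
#include <cstdint>

// B-type immediate: imm[12|10:5] in inst[31:25], imm[4:1|11] in inst[11:7], imm[0] = 0.
int32_t b_imm(uint32_t inst) {
    const uint32_t v = (((inst >> 31) & 1u) << 12) | (((inst >> 7) & 1u) << 11) |
                       (((inst >> 25) & 0x3Fu) << 5) | (((inst >> 8) & 0xFu) << 1);
    return (int32_t)((v ^ 0x1000u) - 0x1000u);  // sign-extend from bit 12
}

bool is_btype(uint32_t inst) { return (inst & 0x7Fu) == 0x63u; }  // opcode 1100011

// Predicted next PC: branch target if B-type and predicted taken, else PC + 4.
uint32_t predicted_next_pc(uint32_t pc, uint32_t inst, bool predict_taken) {
    if (is_btype(inst) && predict_taken) return pc + (uint32_t)b_imm(inst);
    return pc + 4u;
}
```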
PC selection
Once the predictor is introduced, setting the new program counter becomes complicated, so we redesigned the program-counter selection logic:
when(HazardDetect.io.pc_forward === 1.B) {
// If load type instruction happens, stall for one cycle
PC.io.in := HazardDetect.io.pc_out
}.otherwise {
when(control_module.io.next_pc_sel === "b01".U) {
when(Branch_M.io.flush === 1.B && control_module.io.branch === 1.B) {
//conditional jump, check if flush is needed
PC.io.in := Mux(Branch_M.io.actual, IF_ID_.io.target_old, IF_ID_.io.pc4_out.asSInt)
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.target:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
// decide next PC using predictor
PC.io.in := Mux(btb.io.isBtype, Mux(predictor.io.prediction, btb.io.target, PC4.io.out.asSInt), PC4.io.out.asSInt)
}
}.elsewhen(control_module.io.next_pc_sel === "b10".U) {
// unconditional jump, flush unconditionally
PC.io.in := ImmGen.io.UJ_type
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.target:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.elsewhen(control_module.io.next_pc_sel === "b11".U) {
// unconditional jump, flush unconditionally
PC.io.in := JALR.io.out.asSInt
IF_ID_.io.pc_in := 0.S
IF_ID_.io.pc4_in := 0.U
IF_ID_.io.target:= 0.S
IF_ID_.io.SelectedInstr := 0.U
}.otherwise {
// decide next PC using predictor
PC.io.in := Mux(btb.io.isBtype, Mux(predictor.io.prediction, btb.io.target, PC4.io.out.asSInt), PC4.io.out.asSInt)
}
}
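The priority order of the selection logic above can be restated as a plain C++ function. A load-use stall has the highest priority, followed by misprediction recovery, then JAL and JALR, and finally the predictor's choice. This is a sketch: the struct and field names are hypothetical stand-ins for the Chisel signals named in the comments.

```cpp
#include <cassert>
#include <cstdint>

// next_pc_sel: 0b01 conditional branch, 0b10 JAL, 0b11 JALR, otherwise sequential.
struct PcInputs {
    bool stall;                // HazardDetect.io.pc_forward
    uint32_t stall_pc;         // HazardDetect.io.pc_out
    unsigned next_pc_sel;      // control_module.io.next_pc_sel
    bool is_branch;            // control_module.io.branch
    bool flush;                // Branch_M.io.flush
    bool actual_taken;         // Branch_M.io.actual
    uint32_t branch_target;    // IF_ID_.io.target_old
    uint32_t branch_fallthru;  // IF_ID_.io.pc4_out
    uint32_t jal_target;       // ImmGen.io.UJ_type
    uint32_t jalr_target;      // JALR.io.out
    uint32_t predicted_pc;     // BTB target or PC + 4, chosen by the predictor
};

struct PcDecision { uint32_t next_pc; bool flush_if_id; };

PcDecision select_pc(const PcInputs& in) {
    if (in.stall) return {in.stall_pc, false};  // hold the PC for one cycle
    switch (in.next_pc_sel) {
        case 0b01:
            if (in.flush && in.is_branch)  // misprediction: redirect and squash IF/ID
                return {in.actual_taken ? in.branch_target : in.branch_fallthru, true};
            return {in.predicted_pc, false};
        case 0b10: return {in.jal_target, true};   // JAL: always flush
        case 0b11: return {in.jalr_target, true};  // JALR: always flush
        default:   return {in.predicted_pc, false};
    }
}
```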
We test our processor with branch prediction using the riscv-tests mentioned above, which set the register gp to 1 when all test cases pass:
#include "verilated.h"
#include "VCPU.h"
#include "verilated_vcd_c.h"
void tick(VCPU* dut)
{
dut->clock = 0;
dut->eval();
dut->clock = 1;
dut->eval();
}
int main(int argc, char** argv, char** env)
{
Verilated::commandArgs(argc, argv);
Verilated::traceEverOn(true);
VCPU* dut = new VCPU;
dut->reset = 1;
for (int i = 0; i < 5; i++) {
tick(dut);
}
dut->reset = 0;
int cycle_count = 0;
for (size_t i = 0; i < 600; i++)
{
tick(dut);
cycle_count++;
if(i % 50 == 0)printf("gp register = 0x%08X\n", dut->CPU__DOT__RegFile__DOT__regfile_3);
if(dut->CPU__DOT__RegFile__DOT__regfile_3 == 1)
{
printf("gp register = 0x%08X\n", dut->CPU__DOT__RegFile__DOT__regfile_3);
printf("passed, cycle count : %d\n", cycle_count);
return 0;
}
}
printf("failed\n");
return 1; // non-zero exit status so scripts can detect the failure
}
and the execution result:
gp register = 0x00000001
passed, cycle count : 521
To evaluate the branch prediction module, we compare the number of cycles required to execute the same program on the original processor and on the one with branch prediction. The test program is as follows:
.section .text
.global _start
_start:
# initialize
li t0, 0
li t1, 10
# ==========================
# Test 1: Always Taken
# ==========================
always_taken:
addi t0, t0, 1 # t0 += 1
beq t0, t1, branch_end
beq zero, zero, always_taken
branch_end:
# ==========================
# Test 2: Not Taken
# ==========================
not_taken:
li t0, 0
li t1, 10
not_taken_loop:
addi t0, t0, 1
bne t0, t1, not_taken_loop
# ==========================
# Test 3: Alternating
# ==========================
alternating:
li t0, 0
li t1, 10
li t2, 0
alternating_loop:
addi t0, t0, 1
beq t2, zero, alt_taken
j alt_not_taken
alt_taken:
li t2, 1
j alternating_loop
alt_not_taken:
li t2, 0
bne t0, t1, alternating_loop
done:
li gp, 1
and we used Verilator to count how many cycles are required:
#include "verilated.h"
#include "VCPU.h"
#include "verilated_vcd_c.h"
void tick(VCPU* dut)
{
dut->clock = 0;
dut->eval();
dut->clock = 1;
dut->eval();
}
int main(int argc, char** argv, char** env)
{
Verilated::commandArgs(argc, argv);
Verilated::traceEverOn(true);
VCPU* dut = new VCPU;
dut->reset = 1;
for (int i = 0; i < 5; i++) {
tick(dut);
}
dut->reset = 0;
int cycle_count = 0;
int branch_count = 0;
int flush_count = 0;
for (size_t i = 0; i < 1500; i++)
{
tick(dut);
cycle_count++;
if(dut->CPU__DOT__control_module_io_branch)
{
branch_count++;
flush_count = dut->CPU__DOT__Branch_M_io_flush ? flush_count + 1: flush_count;
}
if(dut->CPU__DOT__RegFile__DOT__regfile_3 == 1)
{
printf("gp register = 0x%08X\n", dut->CPU__DOT__RegFile__DOT__regfile_3);
printf("passed, cycle count : %d\n", cycle_count);
printf("number of flushes : %d\n", flush_count);
return 0;
}
}
printf("failed\n");
return 1; // non-zero exit status so scripts can detect the failure
}
and the execution results:
without branch prediction
gp register = 0x00000001
passed, cycle count : 137
number of flushes : 22
with branch prediction
gp register = 0x00000001
passed, cycle count : 171
number of flushes : 37
We also test our processor with a program that performs shift-and-add multiplication:
.section .text
.global _start
_start:
addi a0,zero, -79 # multiplicand (shifted left each iteration)
addi a1,zero, 2 # multiplier (its bits are tested from the LSB)
addi t2,zero, 0 # result
loop:
addi t0,zero, 1 # mask for the LSB of a1
beq a1, zero, done
and t0, a1, t0
beq t0, zero, next
add t2, a0, t2
next:
srli a1, a1, 1
slli a0, a0, 1
j loop
done:
li gp, 1
Result:
with branch prediction
gp register = 0x00000001
passed, cycle count : 83
number of flushes : 10
without branch prediction
gp register = 0x00000001
passed, cycle count : 92
number of flushes : 5
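The multiplication program above is a shift-and-add loop. Each iteration tests the LSB of a1, adds a0 into t2 when that bit is set, then shifts a1 right and a0 left. Mirroring the loop in C++ gives the value t2 should hold when gp is set (for -79 and 2, that is -158). This is a sketch, and `shift_add_mul` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// C++ mirror of the loop: a0 = multiplicand, a1 = multiplier, t2 = result.
// Unsigned 32-bit arithmetic matches RV32I registers: srli is a logical
// shift and additions wrap modulo 2^32.
int32_t shift_add_mul(int32_t a0_init, int32_t a1_init) {
    uint32_t a0 = (uint32_t)a0_init, a1 = (uint32_t)a1_init, t2 = 0;
    while (a1 != 0) {           // beq a1, zero, done
        if (a1 & 1u) t2 += a0;  // and t0, a1, t0 ; beq t0, zero, next ; add t2, a0, t2
        a1 >>= 1;               // srli a1, a1, 1
        a0 <<= 1;               // slli a0, a0, 1
    }
    return (int32_t)t2;
}
```

The loop runs once per bit position up to the highest set bit of the multiplier, so with a1 = 2 it takes two iterations. A negative multiplier would take 32 iterations but still produces the correct 32-bit product.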
We also tried running bubble sort on our processor; the test program is:
.section .text
.global _start
.data
arr:
.word 456
.word 78
.word -796
.word 456785
.word 3
.word -12345
.word 98765
.word 4321
.word -654
.word 0
_start:
li sp, 0x7ff
addi sp,sp,-16
sw ra,12(sp)
sw s0,8(sp)
addi s0,sp,16
li a1,10
lui a5,%hi(arr)
addi a0,a5,%lo(arr)
call bubble_sort
li a5,0
mv a0,a5
lw ra,12(sp)
lw s0,8(sp)
addi sp,sp,16
li gp, 1
jr ra
bubble_sort:
addi sp,sp,-48
sw ra,44(sp)
sw s0,40(sp)
addi s0,sp,48
sw a0,-36(s0)
sw a1,-40(s0)
sw zero,-20(s0)
j .L2
.L8:
sw zero,-28(s0)
sw zero,-24(s0)
j .L3
.L5:
lw a5,-24(s0)
slli a5,a5,2
lw a4,-36(s0)
add a5,a4,a5
lw a4,0(a5)
lw a5,-24(s0)
addi a5,a5,1
slli a5,a5,2
lw a3,-36(s0)
add a5,a3,a5
lw a5,0(a5)
ble a4,a5,.L4
lw a5,-24(s0)
addi a5,a5,1
slli a5,a5,2
lw a4,-36(s0)
add a5,a4,a5
lw a5,0(a5)
sw a5,-32(s0)
lw a5,-24(s0)
slli a5,a5,2
lw a4,-36(s0)
add a4,a4,a5
lw a5,-24(s0)
addi a5,a5,1
slli a5,a5,2
lw a3,-36(s0)
add a5,a3,a5
lw a4,0(a4)
sw a4,0(a5)
lw a5,-24(s0)
slli a5,a5,2
lw a4,-36(s0)
add a5,a4,a5
lw a4,-32(s0)
sw a4,0(a5)
.L4:
lw a5,-24(s0)
addi a5,a5,1
sw a5,-24(s0)
.L3:
lw a4,-40(s0)
lw a5,-20(s0)
sub a5,a4,a5
addi a5,a5,-1
lw a4,-24(s0)
blt a4,a5,.L5
lw a5,-28(s0)
bne a5,zero,.L9
lw a5,-20(s0)
addi a5,a5,1
sw a5,-20(s0)
.L2:
lw a4,-20(s0)
lw a5,-40(s0)
blt a4,a5,.L8
j .L10
.L9:
nop
.L10:
nop
lw ra,44(sp)
lw s0,40(sp)
addi sp,sp,48
jr ra
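We could not pass this one, and we suspect the memory map. Our first assumption is a stack overflow: because we did not allocate enough space when designing the data memory module, the lw ra,44(sp) instruction does not return the saved value, so the function cannot return correctly. Our second assumption is that, because we did not specify a memory map when linking, the stack and the array overlap and the saved return address is overwritten during sorting. A few details in the listing may also contribute: li sp, 0x7ff gives a stack pointer that is not word-aligned, so sw ra,12(sp) stores to 0x7fb; the .data directive comes before _start, so _start and bubble_sort are assembled into the .data section rather than .text; and call and jr ra rely on jalr, which also failed its riscv-test. We will try to fix this problem as soon as possible.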
Brute Force solution
We store the array elements with sw instructions, placing the head of the array at address 0:
# a0 start at the bottom of memory
# a0 = {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
li a0, 0
mv t0, a0
li t2, 9
sw t2, 0(t0)
addi t0, t0, 4
li t2, 8
sw t2, 0(t0)
addi t0, t0, 4
li t2, 7
sw t2, 0(t0)
addi t0, t0, 4
li t2, 6
sw t2, 0(t0)
addi t0, t0, 4
li t2, 5
sw t2, 0(t0)
addi t0, t0, 4
li t2, 4
sw t2, 0(t0)
addi t0, t0, 4
li t2, 3
sw t2, 0(t0)
addi t0, t0, 4
li t2, 2
sw t2, 0(t0)
addi t0, t0, 4
li t2, 1
sw t2, 0(t0)
addi t0, t0, 4
li t2, 0
sw t2, 0(t0)
With the array initialized this way, the program became runnable.
with predictor
execution result
gp register = 0x00000001
passed, cycle count : 66
without predictor
execution result
gp register = 0x00000001
passed, cycle count : 70
number of branch instructions : 2, predictor hit : 2
Although this method solves the problem, it is too inefficient, so we are trying to solve it with a linker script instead.
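Overall, the branch predictor reduces the cycle count for the multiplication program (83 vs. 92 cycles) and the array program (66 vs. 70 cycles), but increases it for the branch-pattern test program (171 vs. 137 cycles), so its benefit depends on the workload.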