蔡承遠, 郭晏愷
The F extension adds 32 floating-point registers, f0–f31, each 32 bits wide, and a floating-point control and status register fcsr, which contains the operating mode and exception status of the floating-point unit.
Regs | ABI Name | Description |
---|---|---|
f0-7 | ft0-7 | Temporaries |
f8-9 | fs0-1 | Saved registers |
f10-11 | fa0-1 | Arguments / return values |
f12-17 | fa2-7 | Arguments |
f18-27 | fs2-11 | Saved registers |
f28-31 | ft8-11 | Temporaries |
The 2-bit floating-point format field fmt, in bits 26 and 25, is encoded as shown in the table. It is set to S (00) for all instructions in the F extension.
fmt | Mnemonic | Meaning |
---|---|---|
00 | S | 32-bit single-precision |
01 | D | 64-bit double-precision |
10 | - | reserved |
11 | Q | 128-bit quad-precision |
Some example instructions:
Inst. | Name | Description |
---|---|---|
fmadd.s | mul-add | rd = rs1 * rs2 + rs3 |
fadd.s | add | rd = rs1 + rs2 |
fmul.s | mul | rd = rs1 * rs2 |
fdiv.s | div | rd = rs1 / rs2 |
fsqrt.s | square root | rd = sqrt(rs1) |
flt.s | less than | rd = (rs1 < rs2) ? 1 : 0 |
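fcvt.w.s / fcvt.wu.s / fcvt.s.w / fcvt.s.wu | convert | converts between single-precision float and signed/unsigned 32-bit integer |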
[fmadd.s rd, rs1, rs2, rs3]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| rs3 | 00 | rs2 | rs1 | rm | rd | 1000011 |
[fadd.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0000000 | rs2 | rs1 | rm | rd | 1010011 |
[fmul.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0001000 | rs2 | rs1 | rm | rd | 1010011 |
[fdiv.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0001100 | rs2 | rs1 | rm | rd | 1010011 |
[fsqrt.s rd, rs1]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0101100 | 00000 | rs1 | rm | rd | 1010011 |
[flt.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 1010000 | rs2 | rs1 | 001 | rd | 1010011 |
[fcvt]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
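fcvt.w.s: | 1100000 | 00000 | rs1 | rm | rd | 1010011 |
fcvt.wu.s: | 1100000 | 00001 | rs1 | rm | rd | 1010011 |
fcvt.s.w: | 1101000 | 00000 | rs1 | rm | rd | 1010011 |
fcvt.s.wu: | 1101000 | 00001 | rs1 | rm | rd | 1010011 |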
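
The field layout in the encodings above can be checked with a small software decoder. This is only a sketch in plain Scala (not part of the Chisel design); the names `FDecode`, `Fields`, and `decode` are hypothetical and chosen for illustration.

```scala
object FDecode {
  // One entry per field of an OP-FP instruction word.
  // For R4-type instructions (fmadd.s etc.) funct5 holds rs3 instead.
  final case class Fields(opcode: Int, rd: Int, rm: Int, rs1: Int, rs2: Int, fmt: Int, funct5: Int)

  def decode(inst: Int): Fields = Fields(
    opcode = inst & 0x7F,            // bits 6:0
    rd     = (inst >>> 7) & 0x1F,    // bits 11:7
    rm     = (inst >>> 12) & 0x7,    // bits 14:12 (funct3 position)
    rs1    = (inst >>> 15) & 0x1F,   // bits 19:15
    rs2    = (inst >>> 20) & 0x1F,   // bits 24:20
    fmt    = (inst >>> 25) & 0x3,    // bits 26:25
    funct5 = (inst >>> 27) & 0x1F    // bits 31:27
  )
}
```

For example, the word `0x003100D3` has funct7 = 0000000, rs2 = 3, rs1 = 2, rm = 000, rd = 1 and opcode 1010011, which is the fadd.s layout shown above.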
The rm field (bits 14 to 12, the funct3 position) of an F instruction selects the rounding mode.
Rounding Mode | meaning | Mnemonic |
---|---|---|
0d0 | Round to Nearest, ties to Even. | RNE |
0d1 | Round towards Zero | RTZ |
0d2 | Round Down (towards −∞) | RDN |
0d3 | Round Up (towards +∞) | RUP |
0d4 | Round to Nearest, ties to Max Magnitude | RMM |
0d5 | Invalid. | |
0d6 | Invalid. | |
0d7 | In instruction’s rm field, selects dynamic rounding mode; In Rounding Mode register, Invalid. |
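
To make the difference between the static rounding modes concrete, the sketch below rounds a magnitude `value / 2^shift` to an integer, which is the situation that arises when extra mantissa bits are discarded. It is a plain Scala illustration under stated assumptions, not the rounding logic of this FPU; `Rounding` and `roundShift` are hypothetical names, and the sign is passed separately because RDN and RUP depend on it.

```scala
object Rounding {
  // Rounds the magnitude value / 2^shift according to rounding mode rm (0..4).
  // `negative` is the sign of the number whose magnitude is being rounded.
  def roundShift(value: Long, shift: Int, rm: Int, negative: Boolean): Long = {
    require(value >= 0 && shift >= 1 && shift < 62)
    val truncated = value >> shift                 // bits that are kept
    val remainder = value & ((1L << shift) - 1)    // bits that are discarded
    val half      = 1L << (shift - 1)              // exactly one half in discarded units
    val inexact   = remainder != 0
    val roundUp = rm match {
      case 0 => remainder > half || (remainder == half && (truncated & 1L) == 1L) // RNE
      case 1 => false                                                             // RTZ
      case 2 => negative && inexact                                               // RDN
      case 3 => !negative && inexact                                              // RUP
      case 4 => remainder >= half                                                 // RMM
      case _ => throw new IllegalArgumentException(s"invalid static rounding mode $rm")
    }
    if (roundUp) truncated + 1 else truncated
  }
}
```

With `value = 10` and `shift = 2` the exact magnitude is 2.5, a tie: RNE keeps the even neighbour 2, while RMM moves away from zero to 3.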
Meaning of each bit of rd written by the classify instruction (fclass.s); exactly one bit is set.
bit | meaning |
---|---|
0d0 | rs1 is −∞. |
0d1 | rs1 is a negative normal number. |
0d2 | rs1 is a negative subnormal number. |
0d3 | rs1 is −0. |
0d4 | rs1 is +0. |
0d5 | rs1 is a positive subnormal number. |
0d6 | rs1 is a positive normal number. |
0d7 | rs1 is +∞. |
0d8 | rs1 is a signaling NaN. |
0d9 | rs1 is a quiet NaN. |
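
The classification table can be expressed as a short reference function. The following plain Scala sketch works on the raw 32-bit pattern and returns the 10-bit mask; `FClass` and `fclass` are hypothetical names, and a NaN is treated as quiet when the top fraction bit (bit 22) is set.

```scala
object FClass {
  // Returns a mask with exactly one of bits 0..9 set, following the table above.
  def fclass(bits: Int): Int = {
    val sign = bits >>> 31
    val exp  = (bits >>> 23) & 0xFF
    val frac = bits & 0x7FFFFF
    val bit =
      if (exp == 0xFF && frac != 0) { if ((frac & 0x400000) != 0) 9 else 8 } // quiet / signaling NaN
      else if (exp == 0xFF)         { if (sign == 1) 0 else 7 }              // -inf / +inf
      else if (exp == 0 && frac == 0) { if (sign == 1) 3 else 4 }            // -0 / +0
      else if (exp == 0)            { if (sign == 1) 2 else 5 }              // subnormal
      else                          { if (sign == 1) 1 else 6 }              // normal
    1 << bit
  }
}
```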
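For 32-bit single-precision floating point, a normal number has the value (-1)^sign × 2^(exponent − 127) × 1.fraction. Sub-normal numbers (exponent = 0) have the value (-1)^sign × 2^(−126) × 0.fraction.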
Item | sign | exponent | fraction |
---|---|---|---|
bit number | 31 | 30-23 | 22-0 |
length (bits) | 1 | 8 | 23 |
special values | | | |
0 | 0 | 00000000 | all zero |
-0 | 1 | 00000000 | all zero |
1 | 0 | 01111111 | all zero |
-1 | 1 | 01111111 | all zero |
min. sub-normal number | * | 00000000 | 000 0000 0000 0000 0000 0001 |
max. sub-normal number | * | 00000000 | all one |
min. normal number | * | 00000001 | all zero |
max. normal number | * | 11111110 | all one |
-∞ | 1 | 11111111 | all zero |
∞ | 0 | 11111111 | all zero |
NaN | * | 11111111 | not all zero |
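
As a cross-check of the field layout and the special values in the table, the sketch below turns a raw bit pattern into its numeric value. It is plain Scala written for illustration; `Fp32Value` and `toDouble` are hypothetical names, and `java.lang.Math.scalb` is used because it scales by a power of two exactly.

```scala
object Fp32Value {
  // Interprets a 32-bit pattern as sign / exponent / fraction and returns its value.
  def toDouble(bits: Int): Double = {
    val sign = if ((bits >>> 31) == 1) -1.0 else 1.0
    val exp  = (bits >>> 23) & 0xFF
    val frac = bits & 0x7FFFFF
    if (exp == 0xFF) {
      if (frac != 0) Double.NaN else sign * Double.PositiveInfinity
    } else if (exp == 0) {
      // subnormal: 0.fraction * 2^-126 = fraction * 2^-149
      sign * java.lang.Math.scalb(frac.toDouble, -149)
    } else {
      // normal: 1.fraction * 2^(exp - 127) = (2^23 + fraction) * 2^(exp - 150)
      sign * java.lang.Math.scalb((frac | 0x800000).toDouble, exp - 150)
    }
  }
}
```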
This work is based on 5-Stage-RV32I from kinzafatim. The principle of implementing RV32F is to add a floating-point unit (FPU) in parallel with the arithmetic logic unit (ALU). An FPU control unit and an FPU register file are also needed. These three units mirror their integer counterparts: the ALU, the ALU control unit, and the integer register file.
The figure below is a quick sketch made in MATLAB Simulink to show the whole structure. Because Simulink does not support some features, such as a mux selector, those parts are drawn as red dotted lines in the picture.
src/main/scala/Pipeline/UNits/Control.scala
To select whether data is read and written through the FPU or the ALU, and to simplify the structure so that some wires (e.g. the data wire between memory and the register files) can be shared outside the ALU and FPU, we add an FPU enable port, fpu_en, which decides whether an operation uses the ALU or the FPU. Port fpu_operation carries the F instruction type decoded from the opcode. Because RV32F R4-type instructions have rs3, operand_C is added for operand C source selection.
class Control extends Module {
val io = IO(new Bundle {
val opcode = Input(UInt(7.W)) // 7-bit opcode
val mem_write = Output(Bool()) // whether a write to memory
val branch = Output(Bool()) // whether a branch instruction
val mem_read = Output(Bool()) // whether a read from memory
val reg_write = Output(Bool()) // whether a register write
val men_to_reg = Output(Bool()) // whether the value written to a register (for load instructions)
val alu_operation = Output(UInt(3.W))
val fpu_en = Output(Bool())
val fpu_operation = Output(UInt(3.W))
val operand_A = Output(UInt(2.W)) // Operand A source selection for the ALU
val operand_B = Output(Bool()) // Operand B source selection for the ALU
val operand_C = Output(Bool()) // Operand C source selection for the FPU
// Indicates the type of extension to be used (e.g., sign-extend, zero-extend)
val extend = Output(UInt(2.W))
val next_pc_sel = Output(UInt(2.W)) // next PC value (e.g., PC+4, branch target, jump target)
})
...
}
src/main/scala/Pipeline/UNits/FPU_Control.scala
When the enable signal fpu_enable is asserted, this module reads the instruction type from the fpu_op port and encodes the operation code in a way that depends on the instruction type. The encoded operation code is then sent to the FPU through the fpu_out port.
class FPU_Control extends Module {
val io = IO(new Bundle {
val fpu_op = Input(UInt(5.W))
val fpu_funct3 = Input(UInt(3.W))
val fpu_funct7 = Input(UInt(5.W)) //27~31
val fpu_op5 = Input(UInt(5.W)) //6~2
val fpu_out = Output(UInt(5.W))
val fpu_enable = Input(Bool())
})
io.fpu_out := 0.U
when(io.fpu_enable) {
when(io.fpu_op === 0.U) { //R type
io.fpu_out := io.fpu_funct7
}.elsewhen(io.fpu_op === 1.U) {//R4 type
io.fpu_out := io.fpu_op5
}.elsewhen(io.fpu_op === 2.U) {//I type
io.fpu_out := Cat("b010".U(3.W), io.fpu_funct3)
}.elsewhen(io.fpu_op === 3.U) {//S type
io.fpu_out := "b11111".U
}.otherwise {
io.fpu_out := 0.U
}
}
}
fpu_op reads the instruction type decoded by the control module.
fpu_funct3 reads the funct3 field of the instruction.
fpu_funct7 reads the upper five bits of funct7 (instruction bits 31 to 27).
fpu_op5 reads opcode bits 6 to 2 of the instruction.
Below is the lookup table of encoded operation codes, which is built into the FPU:
operation Code | instruction | sub-instruction identify in FPU |
---|---|---|
00 | fadd | |
01 | fsub | |
02 | fmul | |
03 | fdiv | |
04 | fsgn | rm(0/1) |
05 | fmin/fmax | rm(0/1) |
11 | fsqrt | |
16 | fmadd | |
17 | fmsub | |
18 | fnmsub | |
19 | fnmadd | |
20 | fle/flt/feq | rm(0/1/2) |
24 | fcvt.w.s/fcvt.wu.s | rs2(0/1) |
26 | fcvt.s.w/fcvt.s.wu | rs2(0/1) |
28 | fmv/fclass | rm(0/1) |
30 | fmv.s.x |
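
The operation codes in this table follow directly from the instruction bits: for R-type (OP-FP) instructions the code is instruction bits 31 to 27, and for R4-type instructions it is opcode bits 6 to 2. The plain Scala sketch below mirrors only those two branches of FPU_Control (the I-type and S-type branches are left out); `FpuOpModel` and `opCode` are hypothetical names.

```scala
object FpuOpModel {
  // Software mirror of the R-type and R4-type branches of FPU_Control.
  def opCode(inst: Int): Int = {
    val opcode = inst & 0x7F
    opcode match {
      case 0x53                      => (inst >>> 27) & 0x1F   // OP-FP: upper five bits of funct7
      case 0x43 | 0x47 | 0x4B | 0x4F => (opcode >>> 2) & 0x1F  // fmadd / fmsub / fnmsub / fnmadd
      case _                         => 0                      // not covered by this sketch
    }
  }
}
```

For instance, fsqrt.s has funct7 = 0101100, whose upper five bits are 01011 = 11, and fmadd.s has opcode 1000011, whose bits 6 to 2 are 10000 = 16, matching the table.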
src/main/scala/Pipeline/UNits/FPU.scala
fmt and rm are considered only here. The FPU includes embedded comparator modules, FP_COMP, whose source is located in src/main/scala/Pipeline/UNits/FP_COMP.scala.
The implemented floating-point operations include addition, subtraction, multiplication, division, square root, fused multiply-add, less-than comparison, and conversions between floating-point and integer types.
We use different io.fpu_Op values to select and execute specific floating-point operations in our FPU module. Each operation corresponds to a distinct block of logic within the switch statement, which selects the behavior based on the provided operation code.
class FPU extends Module{
val io = IO(new Bundle{
val A_data_in = Input(UInt(32.W))
val B_data_in = Input(UInt(32.W))
val C_data_in = Input(UInt(32.W))
val fpu_Op = Input(UInt(5.W))
val rm = Input(UInt(3.W))
val rs2 = Input(UInt(5.W))
val fmt = Input(UInt(2.W))
val out = Output(UInt(32.W))
})
val NAN_exp = "b11111111".U(8.W)
val COMP_0 = Module(new FP_COMP)
val COMP_16 = Module(new FP_COMP)
val COMP_20 = Module(new FP_COMP)
val COMP_21 = Module(new FP_COMP)
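// Intermediate signals are wires with a default value so that := can drive them
val result = WireDefault(0.U(32.W))
val result_sign = WireDefault(0.U(1.W))
val result_exp = WireDefault(0.U(8.W))
val result_frac = WireDefault(0.U(23.W))
val carry = WireDefault(false.B)
val exp_diff = WireDefault(0.U(8.W))
val A_sign = WireDefault(0.U(1.W))
val B_sign = WireDefault(0.U(1.W))
val C_sign = WireDefault(0.U(1.W))
// exponents are 8 bits wide
val A_exp = WireDefault(0.U(8.W))
val B_exp = WireDefault(0.U(8.W))
val C_exp = WireDefault(0.U(8.W))
// hidden leading 1 plus 23 fraction bits
val A_Mantissa = WireDefault(0.U(24.W))
val B_Mantissa = WireDefault(0.U(24.W))
val C_Mantissa = WireDefault(0.U(24.W))
val Temp_Exp = WireDefault(0.U(24.W))
// one extra bit (bit 24) holds the carry of the mantissa addition
val Temp_Mantissa = WireDefault(0.U(25.W))
val B_shift_mantissa = WireDefault(0.U(24.W))
val COMP = WireDefault(false.B)
val COMP_20_ = WireDefault(false.B)
val COMP_21_ = WireDefault(false.B)
//comparator
COMP_0.io.rs1 := io.A_data_in
COMP_0.io.rs2 := io.B_data_in
COMP := COMP_0.io.COMP_RESULT
// assuming COMP_RESULT is true when rs1 >= rs2, A_swap holds the larger operand
val A_swap = Mux(COMP, io.A_data_in, io.B_data_in)
val B_swap = Mux(COMP, io.B_data_in, io.A_data_in)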
switch(io.fpu_Op) {
is(FPU_FADD_S) {
//result := io.in_A + io.in_B
}
is(FPU_FSUB_S) {
//result := io.in_A - io.in_B
}
is(FPU_FMUL_S) {
//result := io.in_A * io.in_B
}
is(FPU_FSQRT_S){
//result := io.in_A^(1/2)
}
is(FPU_FDIV_S){
//result := io.in_A / io.in_B
}
is(FPU_FMADD_S){
//result := (io.in_A * io.in_B) + io.in_C
}
is(FPU_FCVT_W_S) {
//result := io.in_A.asSInt
}
is(FPU_FLT_S) {
//result := io.in_A < io.in_B
}
}
A_data_in, B_data_in, and C_data_in are the three operand inputs from outside the module.
COMP holds the result of comparing two values.
carry is a signal indicating whether a carry is generated during mantissa operations.
exp_diff is an 8-bit signal storing the difference between the exponents of the two input floating-point numbers, used for mantissa alignment.
A_sign, B_sign, and C_sign are three signals representing the sign bits of A_data_in, B_data_in, and C_data_in, respectively.
The adder (FADD) is designed to perform precise addition of two single-precision floating-point numbers. It aligns the mantissas by the exponent difference and tracks the carry bit to keep the result normalized and accurate.
// Addition Code
is(FPU_FADD_S) {
A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))
// shift mantissa
exp_diff := A_swap(30, 23) - B_swap(30, 23)
// add the two mantissa numbers and store the carry
B_shift_mantissa := B_Mantissa >> exp_diff
Temp_Mantissa := Mux(A_swap(31) ^ B_swap(31),
A_Mantissa - B_shift_mantissa,
A_Mantissa +& B_shift_mantissa) // +& keeps the carry in bit 24
carry := Temp_Mantissa(24)
// normalize the result
when(carry) {
Temp_Mantissa := Temp_Mantissa >> 1
when(result_exp < 255.U) {
result_exp := A_swap(30, 23) + 1.U
// result_exp := result_exp + 1.U
} .otherwise {
result_exp := 255.U
}
} .elsewhen(Temp_Mantissa === 0.U) {
// the operands cancelled exactly: the result is zero
Temp_Mantissa := 0.U
} .otherwise {
for(i <- 0 until 24) {
when(Temp_Mantissa(23) =/= 1.U && result_exp > 0.U) {
Temp_Mantissa := Temp_Mantissa << 1
result_exp := result_exp - 1.U
}
}
}
result_sign := A_swap(31)
result_frac := Temp_Mantissa(22, 0)
result := Cat(result_sign, result_exp, result_frac)
}
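The align, add, and normalize steps of the adder can be checked against a small software model. The sketch below is plain Scala, not the Chisel block above: it only handles two positive normal inputs and truncates the bits shifted out during alignment, so it agrees with hardware floating point only when no rounding is needed. `FaddModel` and `addPositiveNormals` are hypothetical names.

```scala
object FaddModel {
  // Adds two positive normal single-precision numbers given as raw bit patterns.
  def addPositiveNormals(a: Int, b: Int): Int = {
    // For positive floats the raw bit patterns order the same way as the values.
    val (big, small) = if (a >= b) (a, b) else (b, a)
    val expBig   = (big >>> 23) & 0xFF
    val expSmall = (small >>> 23) & 0xFF
    val manBig   = (big & 0x7FFFFF) | 0x800000    // restore the hidden 1
    val manSmall = (small & 0x7FFFFF) | 0x800000
    val shift    = expBig - expSmall              // exponent difference
    val aligned  = if (shift >= 24) 0 else manSmall >>> shift
    val sum      = manBig + aligned               // at most 25 bits
    val carry    = (sum >>> 24) & 1
    val man      = if (carry == 1) sum >>> 1 else sum
    val exp      = expBig + carry
    if (exp >= 255) 0x7F800000                    // overflow to +infinity
    else (exp << 23) | (man & 0x7FFFFF)
  }
}
```

For 1.0 + 2.0 the smaller mantissa is shifted right by one, the sum has no carry, and the result keeps the larger exponent, which gives the pattern of 3.0.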
The subtractor (FSUB) operates similarly to the adder (FADD); the key difference is how the sign bit of input B is handled. The aligned mantissas are either added or subtracted based on the XOR of the sign bits, so the two branches that compute Temp_Mantissa are swapped relative to the adder.
// Subtraction Code
is(FPU_FSUB_S) {
A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))
// shift mantissa
exp_diff := A_swap(30, 23) - B_swap(30, 23)
// add the two mantissa numbers and store the carry
B_shift_mantissa := B_Mantissa >> exp_diff
Temp_Mantissa := Mux(A_swap(31) ^ B_swap(31),
A_Mantissa +& B_shift_mantissa, // +& keeps the carry in bit 24
A_Mantissa - B_shift_mantissa)
carry := Temp_Mantissa(24)
when(carry) {
Temp_Mantissa := Temp_Mantissa >> 1
when(result_exp < 255.U) {
result_exp := A_swap(30, 23) + 1.U
} .otherwise {
result_exp := 255.U
}
} .elsewhen(Temp_Mantissa === 0.U) {
// the operands cancelled exactly: the result is zero
Temp_Mantissa := 0.U
} .otherwise {
for(i <- 0 until 24) {
when(Temp_Mantissa(23) =/= 1.U && result_exp > 0.U) {
Temp_Mantissa := Temp_Mantissa << 1
result_exp := result_exp - 1.U
}
}
}
result_sign := A_swap(31)
result_frac := Temp_Mantissa(22, 0)
result := Cat(result_sign, result_exp, result_frac)
}
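The multiplier (FMUL) multiplies the two 24-bit mantissas into a 48-bit Mantissa_product, and the MSB of that product (bit 47) decides the normalization. If the MSB of Mantissa_product is 1 (indicating overflow), the mantissa must be right-shifted by one bit and the exponent incremented by one to maintain normalization. If the MSB of Mantissa_product is 0, the exponent remains unchanged.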
//Multiplication Code
is(FPU_FMUL_S) {
// extract the sign bit, exponent, and mantissa
A_sign := A_swap(31)
B_sign := B_swap(31)
A_exp := A_swap(30, 23)
B_exp := B_swap(30, 23)
A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))
// multiply the signs
result_sign := A_sign ^ B_sign
// add the exponents
val exp_sum = A_exp + B_exp - 127.U
// multiply the mantissa
val Mantissa_product = A_Mantissa * B_Mantissa
// normalize
carry := Mantissa_product(47)
val normalized_Mantissa = Mux(carry, Mantissa_product(46, 24), Mantissa_product(45, 23))
val normalized_exp = Mux(carry, exp_sum + 1.U, exp_sum)
// exponent overflow or zero mantissa
when(Mantissa_product === 0.U) {
result_exp := 0.U
result_frac := 0.U
} .elsewhen(normalized_exp >= 255.U) {
result_exp := 255.U // infinity
result_frac := 0.U
} .elsewhen(normalized_exp <= 0.U) {
result_exp := 0.U // subnormal number
result_frac := normalized_Mantissa >> (1.U - normalized_exp)
} .otherwise {
result_exp := normalized_exp
result_frac := normalized_Mantissa
}
// final result
result := Cat(result_sign, result_exp, result_frac)
}
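The normalization rule of the multiplier can be exercised with a plain Scala model. This is a sketch under stated assumptions, not the Chisel block: inputs must be normal numbers, discarded product bits are truncated rather than rounded, and an exponent underflow is flushed to zero instead of producing a subnormal. `FmulModel` and `mulNormals` are hypothetical names.

```scala
object FmulModel {
  // Multiplies two normal single-precision numbers given as raw bit patterns.
  def mulNormals(a: Int, b: Int): Int = {
    val sign    = (a ^ b) & 0x80000000                           // XOR of the sign bits
    val expSum  = ((a >>> 23) & 0xFF) + ((b >>> 23) & 0xFF) - 127 // remove one bias
    val manA    = ((a & 0x7FFFFF) | 0x800000).toLong
    val manB    = ((b & 0x7FFFFF) | 0x800000).toLong
    val product = manA * manB                                    // 48-bit product
    val carry   = ((product >>> 47) & 1L).toInt                  // MSB of the product
    // carry = 1: keep bits 46..24, carry = 0: keep bits 45..23
    val frac    = ((product >>> (23 + carry)) & 0x7FFFFFL).toInt
    val exp     = expSum + carry
    if (exp >= 255) sign | 0x7F800000      // overflow to infinity
    else if (exp <= 0) sign                // underflow, flushed to zero in this sketch
    else sign | (exp << 23) | frac
  }
}
```

For 1.5 × 1.5 the product of the mantissas has bit 47 set, so the fraction is taken one bit higher and the exponent is incremented, which gives 2.25.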
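The target is to solve $f(x) = x^2 - N = 0$, where $N$ is the input number.
Newton's method uses the formula:
$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$
For $f(x) = x^2 - N$, the derivative is $f'(x) = 2x$. Substituting these:
$$x_{n+1} = x_n - \frac{x_n^2 - N}{2x_n} = \frac{1}{2}\left(x_n + \frac{N}{x_n}\right)$$
Start with an initial guess $x_0$, and repeat the iteration formula three times to converge to $\sqrt{N}$.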
// Square Root Code
is(FPU_FSQRT_S) {
// extract the sign bit, exponent, and mantissa
A_sign := A_swap(31)
A_exp := A_swap(30, 23)
A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
when(A_sign.asBool) {
// when input negative, output the canonical quiet NaN (0x7FC00000)
result := "b01111111110000000000000000000000".U
} .elsewhen(A_swap(30, 0) === 0.U) {
// when input zero, output zero
result := A_swap
} .otherwise {
// exponent
val exp_adjust = A_exp - 127.U
// odd or even number
val new_exp = Mux(exp_adjust(0),
(exp_adjust >> 1) + 127.U,
((exp_adjust - 1.U) >> 1) + 127.U
)
// Newton Iteration(three times)
val x0 = "b01000000000000000000000000000000".U(32.W) // x0 assumption = 1
val iter1 = (x0 + (A_Mantissa << 23) / x0) >> 1
val iter2 = (iter1 + (A_Mantissa << 23) / iter1) >> 1
val iter3 = (iter2 + (A_Mantissa << 23) / iter2) >> 1
val sqrt_Mantissa = iter3(45, 23) // use the third (final) iteration
// final result
result_sign := 0.U
result_exp := new_exp
result_frac := sqrt_Mantissa
result := Cat(result_sign, result_exp, result_frac)
}
}
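The square-root update, x_next = (x + N / x) / 2, is easy to try out in floating point before committing to a fixed-point hardware version. The sketch below is plain Scala on doubles, not the fixed-point arithmetic used in the block above; `NewtonSqrt` and `sqrtNewton` are hypothetical names.

```scala
object NewtonSqrt {
  // Applies x_{n+1} = (x_n + n / x_n) / 2 for a fixed number of steps.
  def sqrtNewton(n: Double, x0: Double, steps: Int): Double =
    (0 until steps).foldLeft(x0)((x, _) => 0.5 * (x + n / x))
}
```

Starting from x0 = 1 with N = 2, the iterates are 1.5, about 1.41667, and about 1.414216, so three steps already agree with the true square root to roughly five decimal places. For mantissas further from the initial guess, three steps give a coarser result.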
// Division Code
is(FPU_FDIV_S) {
// extract the sign bit, exponent, and mantissa
A_sign := A_swap(31)
B_sign := B_swap(31)
A_exp := A_swap(30, 23)
B_exp := B_swap(30, 23)
A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))
// divide the signs
val result_sign = A_sign ^ B_sign
// minus the exponents
val exp_diff = A_exp - B_exp + 127.U
// Newton Iteration(three times) 1 / B_Mantissa
val x0 = "b01000000000000000000000000000000".U(32.W) // x0 assumption = 1
val iter1 = (x0 + (1.U << 23) - ((B_Mantissa << 23) / x0)) >> 1
val iter2 = (iter1 + (1.U << 23) - ((B_Mantissa << 23) / iter1)) >> 1
val iter3 = (iter2 + (1.U << 23) - ((B_Mantissa << 23) / iter2)) >> 1
val reciprocal = iter3(45, 23)
// calculate mantissa A_Mantissa / B_Mantissa = A_Mantissa * reciprocal
val Mantissa_div = A_Mantissa * reciprocal
// normalize
carry := Mantissa_div(47)
val normalized_Mantissa = Mux(carry, Mantissa_div(46, 24), Mantissa_div(45, 23))
val normalized_exp = Mux(carry, exp_diff + 1.U, exp_diff)
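when(B_swap(30, 0) === 0.U) {
// B = 0: this bit pattern is +infinity (exponent all ones, fraction zero), not NaN;
// the sign of the quotient and the 0/0 = NaN case are not handled here
result := "b01111111100000000000000000000000".U // +infinity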
} .elsewhen(A_swap(30, 0) === 0.U) {
// A=0:0
result := 0.U
} .elsewhen(normalized_exp >= 255.U) {
// exponent overflow:exp=infinity, mantissa=0
result_exp := 255.U
result_frac := 0.U
result := Cat(result_sign, result_exp, result_frac)
} .elsewhen(normalized_exp <= 0.U) {
// subnormal:exp=0, mantissa right shift
result_exp := 0.U
result_frac := normalized_Mantissa >> (1.U - normalized_exp)
result := Cat(result_sign, result_exp, result_frac)
} .otherwise {
// final result
result_exp := normalized_exp
result_frac := normalized_Mantissa
result := Cat(result_sign, result_exp, result_frac)
}
}
Fused multiply-add (FMADD) combines multiplication and addition into a single operation to improve efficiency and minimize precision loss.
// Fused Multiply-Add Code
is(FPU_FMADD_S) {
// extract the sign bit, exponent, and mantissa
A_sign := io.A_data_in(31)
B_sign := io.B_data_in(31)
C_sign := io.C_data_in(31)
A_exp := io.A_data_in(30, 23)
B_exp := io.B_data_in(30, 23)
C_exp := io.C_data_in(30, 23)
A_Mantissa := Cat(1.U(1.W), io.A_data_in(22, 0))
B_Mantissa := Cat(1.U(1.W), io.B_data_in(22, 0))
C_Mantissa := Cat(1.U(1.W), io.C_data_in(22, 0))
// calculate A*B
val mul_sign = A_sign ^ B_sign
val mul_exp = A_exp + B_exp - 127.U
val mul_Mantissa = A_Mantissa * B_Mantissa
// normalize_A*B
val mul_carry = mul_Mantissa(47)
val mul_norm_Mantissa = Mux(mul_carry, mul_Mantissa(46, 24), mul_Mantissa(45, 23))
val mul_norm_exp = Mux(mul_carry, mul_exp + 1.U, mul_exp)
// comparator_A*B vs C
COMP_16.io.rs1 := mul_norm_Mantissa
COMP_16.io.rs2 := C_Mantissa
COMP := COMP_16.io.COMP_RESULT
// alogned number
val exp_diff = (mul_norm_exp - C_exp).asSInt()
val safe_shift = Mux(exp_diff < 0.S, -exp_diff.asUInt, exp_diff.asUInt)
val aligned_Mul_Mantissa = Mux(COMP, mul_norm_Mantissa, mul_norm_Mantissa >> safe_shift)
val aligned_C_Mantissa = Mux(COMP, C_Mantissa >> safe_shift, C_Mantissa)
val aligned_exp = Mux(COMP, mul_norm_exp, C_exp)
// ADD & SUB result
val add_sub_result = Mux(mul_sign === C_sign,
aligned_Mul_Mantissa + aligned_C_Mantissa,
aligned_Mul_Mantissa - aligned_C_Mantissa
)
val add_sub_sign = Mux(aligned_Mul_Mantissa >= aligned_C_Mantissa, mul_sign, C_sign)
// normalize result
val result_carry = add_sub_result(24)
val norm_result_Mantissa = Mux(result_carry, add_sub_result(23, 1), add_sub_result(22, 0))
val norm_result_exp = Mux(result_carry, aligned_exp + 1.U, aligned_exp)
// special cases
when((A_exp === 255.U && io.A_data_in(22, 0) =/= 0.U) ||
(B_exp === 255.U && io.B_data_in(22, 0) =/= 0.U) ||
(C_exp === 255.U && io.C_data_in(22, 0) =/= 0.U)) {
// A or B or C input = NaN (the fraction without the hidden 1 must be nonzero)
result := "b01111111110000000000000000000000".U(32.W) // Output NaN
} .elsewhen(io.A_data_in(30, 0) === 0.U || io.B_data_in(30, 0) === 0.U) {
// A or B = 0
result := io.C_data_in
} .elsewhen(io.C_data_in(30, 0) === 0.U) {
// C = 0
result := Cat(mul_sign, mul_norm_exp, mul_norm_Mantissa(22, 0))
} .elsewhen(norm_result_exp >= 255.U) {
// overflow
result := Cat(add_sub_sign, "b11111111".U(8.W), 0.U(23.W))
} .elsewhen(norm_result_exp <= 0.U) {
// subnormal
val subnormal_Mantissa = norm_result_Mantissa >> (1.U - norm_result_exp)
result := Cat(add_sub_sign, 0.U(8.W), subnormal_Mantissa(22, 0))
} .otherwise {
// normal result
result := Cat(add_sub_sign, norm_result_exp, norm_result_Mantissa(22, 0))
}
}
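The precision benefit of fusing can be shown with a concrete input where rounding the product first loses the answer entirely. The sketch below is plain Scala and relies on `java.lang.Math.fma`, which is available from Java 9 onward; it illustrates the single-rounding behaviour of a fused operation in general, not the exact behaviour of the block above. `FmaDemo` is a hypothetical name.

```scala
object FmaDemo {
  // a * a = 1 + 2^-11 + 2^-24 exactly; the last term does not fit in a float near 1.
  val a: Float = 1.0f + java.lang.Math.scalb(1.0f, -12)
  val b: Float = a
  val c: Float = -(1.0f + java.lang.Math.scalb(1.0f, -11))

  // Two roundings: the product is rounded to 1 + 2^-11 before c is added.
  val unfused: Float = a * b + c
  // One rounding: the exact product is added to c, then rounded once.
  val fused: Float = java.lang.Math.fma(a, b, c)
}
```

Here the unfused result is 0, because the 2^-24 term is rounded away (ties to even) before the addition, while the fused result keeps it.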
The comparison instructions use the comparator modules COMP_20 and COMP_21 to perform numerical comparisons, and io.rm selects which comparison is performed.
// Equal & Less Than & Less Than or Equal Code
is(FPU_FEQ_S, FPU_FLT_S, FPU_FLE_S) {
//result := Mux(io.in_A === io.in_B, 1.U, 0.U)
COMP_20.io.rs1 := io.A_data_in
COMP_20.io.rs2 := io.B_data_in
COMP_20_ := COMP_20.io.COMP_RESULT
COMP_21.io.rs1 := io.B_data_in
COMP_21.io.rs2 := io.A_data_in
COMP_21_ := COMP_21.io.COMP_RESULT
// funct3 (rm) selects the comparison: 000 = fle.s, 001 = flt.s, 010 = feq.s
switch(io.rm) {
is(0.U) {
result := Mux(COMP_21_, 1.U, 0.U) //B>=A -> A<=B
}
is(1.U) {
result := Mux(COMP_20_, 0.U, 1.U) //!(A>=B) -> A<B
}
is(2.U) {
result := Mux((COMP_20_ && COMP_21_), 1.U, 0.U) //A>=B && B>=A -> A=B
}
}
}
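A reference model is useful for checking the comparison block, in particular for NaN and signed-zero inputs. The sketch below is plain Scala and uses the host's floating-point comparisons; it assumes the RISC-V funct3 assignment 0 = fle.s, 1 = flt.s, 2 = feq.s. `FcmpModel` and `compare` are hypothetical names.

```scala
object FcmpModel {
  // Returns 1 when the selected comparison holds, 0 otherwise.
  // Any comparison with a NaN operand is false, so the result is 0.
  def compare(funct3: Int, a: Float, b: Float): Int = {
    val holds = funct3 match {
      case 0 => a <= b   // fle.s
      case 1 => a < b    // flt.s
      case 2 => a == b   // feq.s
      case _ => throw new IllegalArgumentException(s"unsupported funct3 $funct3")
    }
    if (holds) 1 else 0
  }
}
```

Note that -0 and +0 compare as equal, so flt.s on (-0, +0) gives 0 and feq.s gives 1; a comparator that works on raw bit patterns has to treat this case specially.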
src/main/scala/Pipeline/UNits/F_Reg.scala
An independent register file for the FPU, with a structure similar to the original register file in src/main/scala/Pipeline/UNits/RegisterFile.scala. Because RV32F R4-type instructions need an rs3 operand, we add the F_rs3 and F_rdata3 ports.
class F_Reg extends Module {
val io = IO(new Bundle {
val F_rs1 = Input(UInt(5.W))
val F_rs2 = Input(UInt(5.W))
val F_rs3 = Input(UInt(5.W))
val F_reg_write = Input(Bool())
val F_w_reg = Input(UInt(5.W))
val F_w_data = Input(SInt(32.W))
val F_rdata1 = Output(SInt(32.W))
val F_rdata2 = Output(SInt(32.W))
val F_rdata3 = Output(SInt(32.W))
})
val F_regfile = RegInit(VecInit(Seq.fill(32)(0.S(32.W))))
// f0 is an ordinary register (only integer x0 is hardwired to zero), so reads are not masked
io.F_rdata1 := F_regfile(io.F_rs1)
io.F_rdata2 := F_regfile(io.F_rs2)
io.F_rdata3 := F_regfile(io.F_rs3)
when(io.F_reg_write) {
F_regfile(io.F_w_reg) := io.F_w_data
}
}
src/main/scala/Pipeline/Main.scala
Compared to the original project, we add two parallel muxes and use ID_EX_.io.ctrl_FPU_en_out to decide whether the ALU or the FPU reads the data. The case ID_EX_.io.ctrl_OpA_out === "b01".U is for UJ-type instructions and JALR, so the FPU does not need to handle it.
Forwarding A (rs1):
val d = Wire(SInt(32.W))
when (ID_EX_.io.ctrl_OpA_out === "b01".U) {
ALU.io.in_A := ID_EX_.io.IFID_pc4_out.asSInt
}.elsewhen (ID_EX_.io.ctrl_FPU_en_out === 1.U) { //RV32F rs1 data
FPU.io.A_data_in := MuxLookup(Forwarding.io.forward_a, 0.U, Array(
(0.U) -> ID_EX_.io.rs1_data_out,
(1.U) -> d,
(2.U) -> EX_MEM_M.io.EXMEM_fpu_out,
(3.U) -> ID_EX_.io.rs1_data_out
))
}.otherwise {
// forwarding A
when(Forwarding.io.forward_a === "b00".U) {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}.elsewhen(Forwarding.io.forward_a === "b01".U) {
ALU.io.in_A := d
}.elsewhen(Forwarding.io.forward_a === "b10".U) {
ALU.io.in_A := EX_MEM_M.io.EXMEM_alu_out
}.otherwise {
ALU.io.in_A := ID_EX_.io.rs1_data_out
}
}
The rs2 value is selected by ID_EX_.io.ctrl_FPU_en_out and Forwarding.io.forward_b, which choose the wire to read. It is also connected to EX_MEM_M.io.IDEX_rs2. Because the B operand can also carry the immediate for both the ALU and the FPU, ID_EX_.io.ctrl_OpB_out selects whether the operand is the immediate or the rs2 value.
Forwarding B (rs2):
val RS2_value = Wire(SInt(32.W))
when (Forwarding.io.forward_b === 0.U) {
RS2_value := ID_EX_.io.rs2_data_out
}.elsewhen (Forwarding.io.forward_b === 1.U) {
RS2_value := d
}.elsewhen (Forwarding.io.forward_b === 2.U) {
RS2_value := Mux(ID_EX_.io.ctrl_FPU_en_out, EX_MEM_M.io.EXMEM_fpu_out, EX_MEM_M.io.EXMEM_alu_out)
}.otherwise {
RS2_value := 0.S
}
when(ID_EX_.io.ctrl_FPU_en_out === 1.U){
FPU.io.B_data_in := Mux(ID_EX_.io.ctrl_OpB_out, ID_EX_.io.imm_out, RS2_value)
}.otherwise{
ALU.io.in_B := Mux(ID_EX_.io.ctrl_OpB_out, ID_EX_.io.imm_out, RS2_value)
}
Forwarding C (rs3):
val RS3_value = Wire(SInt(32.W))
when (Forwarding.io.forward_c === 0.U) {
RS3_value := ID_EX_.io.rs3_data_out
}.elsewhen (Forwarding.io.forward_c === 1.U) {
RS3_value := d
}.elsewhen (Forwarding.io.forward_c === 2.U) {
RS3_value := EX_MEM_M.io.EXMEM_fpu_out
}.otherwise {
RS3_value := 0.S
}
FPU.io.C_data_in := Mux(ID_EX_.io.ctrl_OpC_out, RS3_value, 0.S)
Because R4-type instructions have rs3, a matching set of I/O ports for rs3 hazard handling is needed.
src/main/scala/Pipeline/Hazard Units/Forwarding.scala
The EX_HAZARD and MEM_HAZARD logic is the same for rs1 through rs3.
class Forwarding extends Module {
val io = IO(new Bundle {
val IDEX_rs1 = Input(UInt(5.W))
val IDEX_rs2 = Input(UInt(5.W))
val IDEX_rs3 = Input(UInt(5.W))
val EXMEM_rd = Input(UInt(5.W))
val EXMEM_regWr = Input(UInt(1.W))
val MEMWB_rd = Input(UInt(5.W))
val MEMWB_regWr = Input(UInt(1.W))
val forward_a = Output(UInt(2.W))
val forward_b = Output(UInt(2.W))
//RV32F rs3
val forward_c = Output(UInt(2.W))
})
io.forward_a := "b00".U
io.forward_b := "b00".U
io.forward_c := "b00".U
// EX HAZARD
when(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1.asUInt) && (io.EXMEM_rd === io.IDEX_rs2) && (io.EXMEM_rd === io.IDEX_rs3)) {
io.forward_a := "b10".U
io.forward_b := "b10".U
io.forward_c := "b10".U
}.elsewhen(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs3)) {
io.forward_c := "b10".U
}.elsewhen(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs2)) {
io.forward_b := "b10".U
}.elsewhen(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1)) {
io.forward_a := "b10".U
}
// MEM HAZARD
when((io.MEMWB_regWr === "b1".U) && (io.MEMWB_rd =/= "b00000".U) && (io.MEMWB_rd === io.IDEX_rs1) && (io.MEMWB_rd === io.IDEX_rs2) && (io.MEMWB_rd === io.IDEX_rs3) &&
~(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U &&
(io.EXMEM_rd === io.IDEX_rs1) && (io.EXMEM_rd === io.IDEX_rs2) && (io.EXMEM_rd === io.IDEX_rs3)) ) {
io.forward_a := "b01".U
io.forward_b := "b01".U
io.forward_c := "b01".U
}.elsewhen((io.MEMWB_regWr === "b1".U) && (io.MEMWB_rd =/= "b00000".U) && (io.MEMWB_rd === io.IDEX_rs3) &&
~(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U && (io.EXMEM_rd === io.IDEX_rs3))){
io.forward_c := "b01".U
}.elsewhen((io.MEMWB_regWr === "b1".U) && (io.MEMWB_rd =/= "b00000".U) && (io.MEMWB_rd === io.IDEX_rs2) &&
~(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U && (io.EXMEM_rd === io.IDEX_rs2))){
io.forward_b := "b01".U
}.elsewhen((io.MEMWB_regWr === "b1".U) && (io.MEMWB_rd =/= "b00000".U) && (io.MEMWB_rd === io.IDEX_rs1) &&
~(io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U && (io.EXMEM_rd === io.IDEX_rs1))){
io.forward_a := "b01".U
}
}
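Since the rule is the same for every source operand, the forwarding decision can also be modelled as one function applied to rs1, rs2, and rs3 independently. The plain Scala sketch below is an alternative formulation for checking the module, not the code used in the design; `ForwardModel` and `select` are hypothetical names. The rd != 0 test mirrors the module above, which suits integer x0; because f0 is an ordinary register, a floating-point-aware version might drop that test.

```scala
object ForwardModel {
  // 2 = forward from EX/MEM, 1 = forward from MEM/WB, 0 = use the register-file value.
  // The EX/MEM result is newer, so it takes priority over MEM/WB.
  def select(rs: Int, exMemRd: Int, exMemWr: Boolean, memWbRd: Int, memWbWr: Boolean): Int =
    if (exMemWr && exMemRd != 0 && exMemRd == rs) 2
    else if (memWbWr && memWbRd != 0 && memWbRd == rs) 1
    else 0

  // Applies the same rule to each of the three source registers.
  def selectAll(rs: Seq[Int], exMemRd: Int, exMemWr: Boolean, memWbRd: Int, memWbWr: Boolean): Seq[Int] =
    rs.map(select(_, exMemRd, exMemWr, memWbRd, memWbWr))
}
```

Evaluating each operand separately also covers the case where only two of the three source registers match the destination, for example rs1 = rs2 = 5 and rs3 = 7 with an EX/MEM write to register 5.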
src/main/scala/Pipeline/Hazard Units/HazardDetection.scala
To keep rs3 from triggering false hazards, rs3 is pre-processed here: if the instruction is not R4-type, rs3 hazard detection is bypassed.
class HazardDetection extends Module {
val io = IO(new Bundle {
val IF_ID_inst = Input(UInt(32.W))
val ID_EX_memRead = Input(Bool())
val ID_EX_rd = Input(UInt(5.W))
val pc_in = Input(SInt(32.W))
val current_pc = Input(SInt(32.W))
val ID_EX_operandC = Input(Bool())
val inst_forward = Output(Bool())
val pc_forward = Output(Bool())
val ctrl_forward = Output(Bool())
val inst_out = Output(UInt(32.W))
val pc_out = Output(SInt(32.W))
val current_pc_out = Output(SInt(32.W))
})
val Rs1 = io.IF_ID_inst(19, 15)
val Rs2 = io.IF_ID_inst(24, 20)
val Rs3 = io.IF_ID_inst(31, 27)
val rs3detect = io.ID_EX_operandC & (io.ID_EX_rd === Rs3)
when(io.ID_EX_memRead === 1.B && ((io.ID_EX_rd === Rs1) || (io.ID_EX_rd === Rs2) || rs3detect)) {
io.inst_forward := true.B
io.pc_forward := true.B
io.ctrl_forward := true.B
}.otherwise {
io.inst_forward := false.B
io.pc_forward := false.B
io.ctrl_forward := false.B
}
io.inst_out := io.IF_ID_inst
io.pc_out := io.pc_in
io.current_pc_out := io.current_pc
}
src/main/scala/Pipeline/Hazard Units/StructuralHazard.scala
Same method for rs1-rs3.
class StructuralHazard extends Module {
val io = IO(new Bundle {
val rs1 = Input(UInt(5.W))
val rs2 = Input(UInt(5.W))
val MEM_WB_regWr = Input(Bool())
val MEM_WB_Rd = Input(UInt(5.W))
val fwd_rs1 = Output(Bool())
val fwd_rs2 = Output(Bool())
//RV32F rs3
val rs3 = Input(UInt(5.W))
val fwd_rs3 = Output(Bool())
})
...
// Determine if forwarding is needed for rs3
when(io.MEM_WB_regWr && io.MEM_WB_Rd === io.rs3) {
io.fwd_rs3 := true.B
}.otherwise {
io.fwd_rs3 := false.B
}
}
src/main/scala/Pipeline/Pipelines/IF_ID.scala
Same as the original structure. No modification is needed.
The ID_EX module connects the control module and the register files, so ports for the new control signals are added. Ports related to rs3 are also added for R4-type instructions (including the operand C select, OpC). Ports decoded from the instruction, such as fmt, op5, and rm, feed the FPU and FPU_Control modules.
class ID_EX extends Module {
val io = IO(new Bundle {
val rs1_in = Input(UInt(5.W))
val rs2_in = Input(UInt(5.W))
val rs3_in = Input(UInt(5.W))
val rs1_data_in = Input(SInt(32.W))
val rs2_data_in = Input(SInt(32.W))
val rs3_data_in = Input(SInt(32.W))
val imm = Input(SInt(32.W))
val rd_in = Input(UInt(5.W))
val func3_in = Input(UInt(3.W))
val func7_in = Input(Bool())
val ctrl_MemWr_in = Input(Bool())
val ctrl_Branch_in = Input(Bool())
val ctrl_MemRd_in = Input(Bool())
val ctrl_Reg_W_in = Input(Bool())
val ctrl_MemToReg_in = Input(Bool())
val ctrl_AluOp_in = Input(UInt(3.W))
//RV32F
val ctrl_FPU_en_in = Input(Bool())
val ctrl_FPU_Op_in = Input(UInt(3.W))
val FPU_func7_in = Input(UInt(5.W))
val FPU_fmt_in = Input(UInt(2.W))
val FPU_op5_in = Input(UInt(5.W))
val FPU_rm_in = Input(UInt(3.W))
//
val ctrl_OpA_in = Input(UInt(2.W))
val ctrl_OpB_in = Input(Bool())
//RV32F opc
val ctrl_OpC_in = Input(Bool())
//
val ctrl_nextpc_in = Input(UInt(2.W))
val IFID_pc4_in = Input(UInt(32.W))
val rs1_out = Output(UInt(5.W))
val rs2_out = Output(UInt(5.W))
val rs3_out = Output(UInt(5.W))
val rs1_data_out = Output(SInt(32.W))
val rs2_data_out = Output(SInt(32.W))
val rs3_data_out = Output(SInt(32.W))
val rd_out = Output(UInt(5.W))
val imm_out = Output(SInt(32.W))
val func3_out = Output(UInt(3.W))
val func7_out = Output(Bool())
val ctrl_MemWr_out = Output(Bool())
val ctrl_Branch_out = Output(Bool())
val ctrl_MemRd_out = Output(Bool())
val ctrl_Reg_W_out = Output(Bool())
val ctrl_MemToReg_out = Output(Bool())
val ctrl_AluOp_out = Output(UInt(3.W))
//RV32F
val ctrl_FPU_en_out = Output(Bool())
val ctrl_FPU_Op_out = Output(UInt(3.W))
val FPU_func7_out = Output(UInt(5.W))
val FPU_fmt_out = Output(UInt(2.W))
val FPU_op5_out = Output(UInt(5.W))
val FPU_rm_out = Output(UInt(3.W))
//
val ctrl_OpA_out = Output(UInt(2.W))
val ctrl_OpB_out = Output(Bool())
//RV32F opc
val ctrl_OpC_out = Output(Bool())
//
val ctrl_nextpc_out = Output(UInt(2.W))
val IFID_pc4_out = Output(UInt(32.W))
})
...
}
In src/main/scala/Pipeline/Main.scala, we use the FPU enable signal to decide whether the ID_EX register outputs go to the ALU or the FPU, and which wires some inputs are read from.
ID_EX_.io.rs1_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs1, RegFile.io.rs1)
ID_EX_.io.rs2_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs2, RegFile.io.rs2)
ID_EX_.io.rs3_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs3, 0.U)
//ID_EX_.io.rs1_in := RegFile.io.rs1
//ID_EX_.io.rs2_in := RegFile.io.rs2
ID_EX_.io.imm := ImmValue
ID_EX_.io.func3_in := IF_ID_.io.SelectedInstr_out(14, 12)
ID_EX_.io.func7_in := IF_ID_.io.SelectedInstr_out(30)
ID_EX_.io.rd_in := IF_ID_.io.SelectedInstr_out(11, 7)
//RV32F, funct7 & fmt
ID_EX_.io.FPU_func7_in := IF_ID_.io.SelectedInstr_out(31, 27)
ID_EX_.io.FPU_fmt_in := IF_ID_.io.SelectedInstr_out(26, 25)
ID_EX_.io.FPU_op5_in := IF_ID_.io.SelectedInstr_out(6, 2)
EX_MEM_M.io.IDEX_rd := ID_EX_.io.rd_out
//FPU
when(ID_EX_.io.ctrl_FPU_en_out === 1.U){
FPU_Control.io.fpu_op := ID_EX_.io.ctrl_FPU_Op_out
FPU_Control.io.fpu_funct3 := ID_EX_.io.func3_out
FPU_Control.io.fpu_funct7 := ID_EX_.io.FPU_func7_out
FPU.io.fpu_Op := FPU_Control.io.fpu_out
FPU_Control.io.fpu_enable := ID_EX_.io.ctrl_FPU_en_out
}.otherwise{
ALU_Control.io.aluOp := ID_EX_.io.ctrl_AluOp_out // Alu op code
ALU.io.alu_Op := ALU_Control.io.out // Alu op code
ALU_Control.io.func3 := ID_EX_.io.func3_out // function 3
ALU_Control.io.func7 := ID_EX_.io.func7_out // function 7
}
src/main/scala/Pipeline/Pipelines/EX_MEM.scala
Because some wires, such as rd and rs2, are merged in the earlier stage and controlled by IDEX_fp_en, we add an fpu_out port to pass the result from the FPU, and an EXMEM_fp_en port to pass the IDEX_fp_en signal along.
class EX_MEM extends Module {
val io = IO(new Bundle {
val IDEX_MEMRD = Input(Bool())
val IDEX_MEMWR = Input(Bool())
val IDEX_MEMTOREG = Input(Bool())
val IDEX_REG_W = Input(Bool())
val IDEX_rs2 = Input(SInt(32.W))
val IDEX_rd = Input(UInt(5.W))
val alu_out = Input(SInt(32.W))
//RV32F
val fpu_out = Input(SInt(32.W))
val IDEX_fp_en = Input(Bool())
val EXMEM_memRd_out = Output(Bool())
val EXMEM_memWr_out = Output(Bool())
val EXMEM_memToReg_out = Output(Bool())
val EXMEM_reg_w_out = Output(Bool())
val EXMEM_rs2_out = Output(SInt(32.W))
val EXMEM_rd_out = Output(UInt(5.W))
val EXMEM_alu_out = Output(SInt(32.W))
//RV32F
val EXMEM_fpu_out = Output(SInt(32.W))
val EXMEM_fp_en = Output(Bool())
})
...
}
src/main/scala/Pipeline/Pipelines/MEM_WB.scala
Same as the EX_MEM module: we add an fpu_out port to pass the data, and MEMWB_fp_en to pass the EXMEM_fp_en signal along.
The data memory itself needs no addition or modification, but we change its address input so that it serves both ALU and FPU memory operations. Because the data input from io.EXMEM_rs2_out was already selected in the forwarding B section, it needs no modification.
In src/main/scala/Pipeline/Main.scala:
// Data memory inputs
DataMemory.io.mem_read := EX_MEM_M.io.EXMEM_memRd_out
DataMemory.io.mem_write := EX_MEM_M.io.EXMEM_memWr_out
DataMemory.io.dataIn := EX_MEM_M.io.EXMEM_rs2_out
DataMemory.io.addr := Mux(EX_MEM_M.io.EXMEM_fp_en === 0.U, EX_MEM_M.io.EXMEM_alu_out.asUInt, EX_MEM_M.io.EXMEM_fpu_out.asUInt)
The write-back data passes through wire d to the register files and to the rs1 to rs3 data selectors.
In src/main/scala/Pipeline/Main.scala:
// Write back data to registerfile writedata
when (MEM_WB_M.io.MEMWB_memToReg_out === 0.U) {
d := Mux(MEM_WB_M.io.MEMWB_fp_en, MEM_WB_M.io.MEMWB_fpu_out, MEM_WB_M.io.MEMWB_alu_out) // data from Alu Result or FPU result
}.elsewhen (MEM_WB_M.io.MEMWB_memToReg_out === 1.U) {
d := MEM_WB_M.io.MEMWB_dataMem_out // data from Data Memory
}.otherwise {
d := 0.S
}
//RV32F
when(MEM_WB_M.io.MEMWB_fp_en){
F_RegFile.io.F_w_data := d
}.otherwise{
RegFile.io.w_data := d
}
The integer register file and the F register file share a wire from MEMWB_reg_w_out to reg_write / F_reg_write, and a wire from MEMWB_rd_out to w_reg / F_w_reg. These two connections are controlled by MEM_WB_M.io.MEMWB_fp_en.
In src/main/scala/Pipeline/Main.scala:
// Register file connections
//RV32F
when(MEM_WB_M.io.MEMWB_fp_en){
F_RegFile.io.F_w_reg := MEM_WB_M.io.MEMWB_rd_out
F_RegFile.io.F_reg_write := MEM_WB_M.io.MEMWB_reg_w_out
//F_RegFile.io.w_reg := ID_EX_.io.rd_out
}.otherwise{
RegFile.io.w_reg := MEM_WB_M.io.MEMWB_rd_out
RegFile.io.reg_write := MEM_WB_M.io.MEMWB_reg_w_out
//RegFile.io.w_reg := ID_EX_.io.rd_out
}
https://people.eecs.berkeley.edu/~krste/papers/riscv-spec-2.0.pdf
https://github.com/nozomioshi/ChiselRiscV
https://github.com/chadyuu/riscv-chisel-book
https://github.com/kinzafatim/5-Stage-RV32I