Implement RV32F

蔡承遠, 郭晏愷

GitHub

Introduction: "F" Extension

F Register State

The F extension adds 32 floating-point registers, f0–f31, each 32 bits wide, and a floating-point control and status register, fcsr, which holds the operating mode and exception status of the floating-point unit.

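| Registers | ABI name | Description |
| --- | --- | --- |
| f0-f7 | ft0-ft7 | Temporaries |
| f8-f9 | fs0-fs1 | Saved registers |
| f10-f11 | fa0-fa1 | Arguments / return values |
| f12-f17 | fa2-fa7 | Arguments |
| f18-f27 | fs2-fs11 | Saved registers |
| f28-f31 | ft8-ft11 | Temporaries |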

Instructions

The 2-bit floating-point format field fmt, in bits 26 and 25, is encoded as shown in the table below. It is set to S (00) for all instructions in the F extension.

| fmt | Mnemonic | Meaning |
| --- | --- | --- |
| 00 | S | 32-bit single-precision |
| 01 | D | 64-bit double-precision |
| 10 | - | reserved |
| 11 | Q | 128-bit quad-precision |

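Some example instructions:

| Instruction | Name | Description |
| --- | --- | --- |
| fmadd.s | multiply-add | rd = rs1 * rs2 + rs3 |
| fadd.s | add | rd = rs1 + rs2 |
| fmul.s | multiply | rd = rs1 * rs2 |
| fdiv.s | divide | rd = rs1 / rs2 |
| fsqrt.s | square root | rd = sqrt(rs1) |
| flt.s | less than | rd = (rs1 < rs2) ? 1 : 0 |
| fcvt.* | convert | convert between single-precision and signed/unsigned 32-bit integer |
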
[fmadd.s rd, rs1, rs2, rs3]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| rs3           | 00  | rs2          | rs1          | rm     | rd           | 1000011             |

[fadd.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0000000             | rs2          | rs1          | rm     | rd           | 1010011             | 

[fmul.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0001000             | rs2          | rs1          | rm     | rd           | 1010011             | 

[fdiv.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0001100             | rs2          | rs1          | rm     | rd           | 1010011             | 

[fsqrt.s rd, rs1]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 0101100             | 00000        | rs1          | rm     | rd           | 1010011             | 

[flt.s rd, rs1, rs2]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 1010000             | rs2          | rs1          | 001    | rd           | 1010011             | 

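[fcvt.w.s / fcvt.wu.s rd, rs1]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 1100000             | 0000x        | rs1          | rm     | rd           | 1010011             |

[fcvt.s.w / fcvt.s.wu rd, rs1]
| 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00 |
| 1101000             | 0000x        | rs1          | rm     | rd           | 1010011             |

In the rs2 field, x = 0 selects the signed form (fcvt.w.s, fcvt.s.w) and x = 1 selects the unsigned form (fcvt.wu.s, fcvt.s.wu).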
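
The encodings above can be cross-checked by packing and unpacking the fields in software. Below is a minimal sketch in plain Scala. It is not Chisel and not part of the project. The object FEncode and its method names are hypothetical, and the example words use rm = 7 (dynamic rounding).

```scala
object FEncode {
  val OpFp    = 0x53 // 1010011: OP-FP (fadd.s, fmul.s, fdiv.s, fsqrt.s, flt.s, fcvt.*)
  val OpFmadd = 0x43 // 1000011: fmadd.s (R4-type)

  // R-type layout: funct7 | rs2 | rs1 | rm | rd | opcode
  def rType(funct7: Int, rs2: Int, rs1: Int, rm: Int, rd: Int): Int =
    (funct7 << 25) | (rs2 << 20) | (rs1 << 15) | (rm << 12) | (rd << 7) | OpFp

  // R4-type layout: rs3 | fmt | rs2 | rs1 | rm | rd | opcode, with fmt = 00 (S)
  def r4Type(rs3: Int, rs2: Int, rs1: Int, rm: Int, rd: Int, opcode: Int = OpFmadd): Int =
    (rs3 << 27) | (0 << 25) | (rs2 << 20) | (rs1 << 15) | (rm << 12) | (rd << 7) | opcode

  // Field extraction, using the same bit positions as the diagrams above
  def rs3(inst: Int): Int    = (inst >>> 27) & 0x1f
  def fmt(inst: Int): Int    = (inst >>> 25) & 0x3
  def rs2(inst: Int): Int    = (inst >>> 20) & 0x1f
  def rs1(inst: Int): Int    = (inst >>> 15) & 0x1f
  def rm(inst: Int): Int     = (inst >>> 12) & 0x7
  def rd(inst: Int): Int     = (inst >>> 7) & 0x1f
  def opcode(inst: Int): Int = inst & 0x7f
}
```

For example, FEncode.rType(0x00, 3, 2, 7, 1) packs fadd.s f1, f2, f3 into 0x003170D3, and FEncode.r4Type(4, 3, 2, 7, 1) packs fmadd.s f1, f2, f3, f4 into 0x203170C3.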

The rm field (bits 14-12, the funct3 position) of an F instruction selects the rounding mode:

| Rounding mode | Mnemonic | Meaning |
| --- | --- | --- |
| 0d0 | RNE | Round to Nearest, ties to Even |
| 0d1 | RTZ | Round towards Zero |
| 0d2 | RDN | Round Down (towards −∞) |
| 0d3 | RUP | Round Up (towards +∞) |
| 0d4 | RMM | Round to Nearest, ties to Max Magnitude |
| 0d5 | - | Invalid |
| 0d6 | - | Invalid |
| 0d7 | DYN | In an instruction's rm field, selects the dynamic rounding mode; in the Rounding Mode register, invalid |
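
The table only names the modes. The hypothetical helper below sketches what each mode does when low-order bits are discarded, for example after aligning or multiplying mantissas. It shifts a non-negative magnitude right and rounds according to rm. It is plain Scala and assumes the sign is handled separately (passed in only for RDN/RUP). It rejects rm = 5, 6 and 7; a real FPU would replace 7 with the dynamic mode from fcsr.

```scala
object Rounding {
  // Shift a non-negative magnitude right by `shift` bits, rounding per the rm encoding.
  // `negative` is the sign of the full value, which decides the direction of RDN/RUP.
  def shiftRound(mag: Long, shift: Int, negative: Boolean, rm: Int): Long = {
    require(shift > 0 && shift < 63 && mag >= 0)
    val truncated = mag >>> shift
    val rem       = mag & ((1L << shift) - 1) // the discarded bits
    val half      = 1L << (shift - 1)         // exactly half of one unit in the last place
    val inexact   = rem != 0
    val roundUp = rm match {
      case 0 => rem > half || (rem == half && (truncated & 1) == 1) // RNE: ties to even
      case 1 => false                                               // RTZ: truncate
      case 2 => inexact && negative                                 // RDN: towards -inf
      case 3 => inexact && !negative                                // RUP: towards +inf
      case 4 => rem >= half                                         // RMM: ties away from zero
      case _ => throw new IllegalArgumentException(s"invalid or dynamic rm: $rm")
    }
    if (roundUp) truncated + 1 else truncated
  }
}
```

For example, 5 shifted right by 1 is 2.5. It rounds to 2 under RNE and RTZ, and to 3 under RMM and RUP.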

Meaning of each bit of rd written by the classify instruction fclass.s (exactly one bit is set):

| rd bit | Meaning |
| --- | --- |
| 0d0 | rs1 is −∞. |
| 0d1 | rs1 is a negative normal number. |
| 0d2 | rs1 is a negative subnormal number. |
| 0d3 | rs1 is −0. |
| 0d4 | rs1 is +0. |
| 0d5 | rs1 is a positive subnormal number. |
| 0d6 | rs1 is a positive normal number. |
| 0d7 | rs1 is +∞. |
| 0d8 | rs1 is a signaling NaN. |
| 0d9 | rs1 is a quiet NaN. |
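
The sketch below illustrates the classification table in plain Scala. The object FClass is hypothetical and is not the project's hardware. It inspects the sign, exponent and fraction of a bit pattern and sets exactly one bit of the 10-bit mask. As in RISC-V, it separates signaling from quiet NaNs by the most significant fraction bit.

```scala
object FClass {
  // 10-bit fclass.s mask for a single-precision bit pattern (exactly one bit set)
  def fclass(bits: Int): Int = {
    val sign = (bits >>> 31) == 1
    val exp  = (bits >>> 23) & 0xff
    val frac = bits & 0x7fffff
    val idx =
      if (exp == 0xff) {
        if (frac == 0) { if (sign) 0 else 7 }     // -inf / +inf
        else if ((frac & 0x400000) != 0) 9        // quiet NaN: fraction MSB set
        else 8                                    // signaling NaN
      } else if (exp == 0) {
        if (frac == 0) { if (sign) 3 else 4 }     // -0 / +0
        else { if (sign) 2 else 5 }               // negative / positive subnormal
      } else { if (sign) 1 else 6 }               // negative / positive normal
    1 << idx
  }

  // Convenience wrapper that classifies a JVM Float by its raw bits
  def of(f: Float): Int = fclass(java.lang.Float.floatToRawIntBits(f))
}
```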

Intro: IEEE 754

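For a 32-bit single-precision floating-point number, a normal value is $(-1)^{\text{sign}} \times 2^{\text{exponent} - 127} \times 1.\text{fraction}$. When the exponent field is 0 (zero and subnormal numbers), the value is $(-1)^{\text{sign}} \times 2^{-126} \times 0.\text{fraction}$.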

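| | sign | exponent | fraction |
| --- | --- | --- | --- |
| bit number | 31 | 30-23 | 22-0 |
| length (bits) | 1 | 8 | 23 |

Special values (* = either sign):

| Value | sign | exponent | fraction |
| --- | --- | --- | --- |
| 0 | 0 | 00000000 | all zero |
| -0 | 1 | 00000000 | all zero |
| 1 | 0 | 01111111 | all zero |
| -1 | 1 | 01111111 | all zero |
| min. subnormal number | * | 00000000 | 000 0000 0000 0000 0000 0001 |
| max. subnormal number | * | 00000000 | all one |
| min. normal number | * | 00000001 | all zero |
| max. normal number | * | 11111110 | all one |
| -∞ | 1 | 11111111 | all zero |
| +∞ | 0 | 11111111 | all zero |
| NaN | * | 11111111 | not all zero |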
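
The formulas and special values above can be checked against the JVM's own float conversion. This is a minimal sketch in plain Scala, and Ieee754.decode is a hypothetical name. It returns a Double, which represents every single-precision value exactly, including subnormals.

```scala
object Ieee754 {
  // Decode a single-precision bit pattern using the field formulas above
  def decode(bits: Int): Double = {
    val sign = if ((bits >>> 31) == 1) -1.0 else 1.0
    val exp  = (bits >>> 23) & 0xff
    val frac = bits & 0x7fffff
    if (exp == 0xff) {
      if (frac == 0) sign * Double.PositiveInfinity else Double.NaN // infinity or NaN
    } else if (exp == 0) {
      // zero / subnormal: 0.fraction * 2^-126 = fraction * 2^(-126 - 23)
      sign * java.lang.Math.scalb(frac.toDouble, -126 - 23)
    } else {
      // normal: 1.fraction * 2^(exponent - 127) = (2^23 + fraction) * 2^(exponent - 127 - 23)
      sign * java.lang.Math.scalb((frac | 0x800000).toDouble, exp - 127 - 23)
    }
  }
}
```

For example, Ieee754.decode(0x3f800000) is 1.0, and Ieee754.decode(0x00000001) is the minimum subnormal, 2^-149.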

Implementing RV32F in a 5-stage pipeline

structure overview

This design is based on the 5-stage RV32I pipeline from kinzafatim. The idea behind implementing RV32F is to add a floating-point unit (FPU) in parallel with the arithmetic logic unit (ALU). An FPU control unit and an FPU register file are also needed. These three units correspond to the ALU, the ALU control and the integer register file.

The figure below is a quick sketch made with MATLAB Simulink to show the whole structure. Simulink does not support some features, such as a mux selector, so those connections are drawn as red dotted lines.

[Figure: overall structure of RV32F in the 5-stage pipeline, drawn in MATLAB Simulink (image not available)]

Control

  • Located in src/main/scala/Pipeline/UNits/Control.scala
  • In the original project, this module decodes the opcode (instruction type) and decides the operation mode for the other modules.

We add an FPU enable port, fpu_en, which selects whether an operation uses the ALU or the FPU. This controls whether data is read and written through the FPU or the ALU. It also keeps the structure simple, because some wires (e.g., the data wire between memory and the register file) can be shared outside the ALU and FPU. The fpu_operation port carries the F-instruction type decoded from the opcode. RV32F R4-type instructions have a third source register, rs3, so an operand_C port for operand C source selection is also necessary.

class Control extends Module {
  val io = IO(new Bundle {
    val opcode = Input(UInt(7.W))         // 7-bit opcode
    val mem_write = Output(Bool())        // whether a write to memory
    val branch = Output(Bool())           // whether a branch instruction
    val mem_read = Output(Bool())         // whether a read from memory
    val reg_write = Output(Bool())        // whether a register write
    val men_to_reg = Output(Bool())       // whether the value written to a register (for load instructions)
    val alu_operation = Output(UInt(3.W))
    val fpu_en = Output(Bool())
    val fpu_operation = Output(UInt(3.W))
    val operand_A = Output(UInt(2.W))  // Operand A source selection for the ALU
    val operand_B = Output(Bool()) // Operand B source selection for the ALU
    val operand_C = Output(Bool()) // Operand C source selection for the FPU

    // Indicates the type of extension to be used (e.g., sign-extend, zero-extend)
    val extend = Output(UInt(2.W))   
    val next_pc_sel = Output(UInt(2.W)) // next PC value (e.g., PC+4, branch target, jump target)
  })
      ...
}
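
The opcode decoding that drives fpu_en, fpu_operation and operand_C can be modeled in software. The sketch below is plain Scala and is not the project's Control module. It assumes the type numbering used by the fpu_op cases in FPU_Control (0 = R, 1 = R4, 2 = I, 3 = S). The names FDecode, classify and usesRs3 are hypothetical. The opcode values are those of the RISC-V F extension: OP-FP, FMADD/FMSUB/FNMSUB/FNMADD, LOAD-FP and STORE-FP.

```scala
object FDecode {
  // Instruction types, numbered as in FPU_Control's fpu_op cases
  val RType  = 0
  val R4Type = 1
  val IType  = 2 // flw
  val SType  = 3 // fsw

  // F-instruction type for a 7-bit opcode, or None when fpu_en should stay 0
  def classify(opcode: Int): Option[Int] = opcode match {
    case 0x53                      => Some(RType)  // 1010011 OP-FP
    case 0x43 | 0x47 | 0x4b | 0x4f => Some(R4Type) // fmadd / fmsub / fnmsub / fnmadd
    case 0x07                      => Some(IType)  // 0000111 LOAD-FP
    case 0x27                      => Some(SType)  // 0100111 STORE-FP
    case _                         => None
  }

  // operand_C: only R4-type instructions read rs3
  def usesRs3(opcode: Int): Boolean = classify(opcode).contains(R4Type)
}
```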

FPU design

FPU control

  • Located in src/main/scala/Pipeline/UNits/FPU_Control.scala
  • Decodes RV32F instructions

When the enable signal fpu_enable is asserted, this module reads the instruction type from the fpu_op port. It then forms the operation code in a way that depends on that type and sends it to the FPU through the fpu_out port.

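import chisel3._
import chisel3.util._

class FPU_Control extends Module {
  val io = IO(new Bundle {
    val fpu_op     = Input(UInt(5.W))   // instruction type from Control: 0 = R, 1 = R4, 2 = I, 3 = S
    val fpu_funct3 = Input(UInt(3.W))   // inst[14:12]
    val fpu_funct7 = Input(UInt(5.W))   // inst[31:27], the upper five bits of funct7
    val fpu_op5    = Input(UInt(5.W))   // inst[6:2], the opcode without its two low bits
    val fpu_out    = Output(UInt(5.W))  // operation code sent to the FPU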
    val fpu_enable = Input(Bool())
  })

  io.fpu_out := 0.U
  when(io.fpu_enable) {
    
    when(io.fpu_op === 0.U) {       //R type 
      io.fpu_out := io.fpu_funct7    
    }.elsewhen(io.fpu_op === 1.U) {//R4 type
      io.fpu_out := io.fpu_op5   
    }.elsewhen(io.fpu_op === 2.U) {//I type
      io.fpu_out := Cat("b010".U(3.W), io.fpu_funct3)    
    }.elsewhen(io.fpu_op === 3.U) {//S type
      io.fpu_out := "b11111".U
    }.otherwise {
      io.fpu_out := 0.U
    }
  }
}  

fpu_op is the instruction type decoded by the Control module.
fpu_funct3 is the funct3 field of the instruction.
fpu_funct7 is the upper five bits of funct7 (bits 31-27).
fpu_op5 is opcode bits 6-2 of the instruction.
Below is the lookup table of encoded operation codes, which is built into the FPU:

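| Operation code | Instruction | Sub-instruction identified in FPU by |
| --- | --- | --- |
| 00 | fadd | |
| 01 | fsub | |
| 02 | fmul | |
| 03 | fdiv | |
| 04 | fsgnj/fsgnjn/fsgnjx | rm (0/1/2) |
| 05 | fmin/fmax | rm (0/1) |
| 11 | fsqrt | |
| 16 | fmadd | |
| 17 | fmsub | |
| 18 | fnmsub | |
| 19 | fnmadd | |
| 20 | fle/flt/feq | rm (0/1/2) |
| 24 | fcvt.w.s/fcvt.wu.s | rs2 (0/1) |
| 26 | fcvt.s.w/fcvt.s.wu | rs2 (0/1) |
| 28 | fmv.x.w/fclass | rm (0/1) |
| 30 | fmv.s.x (fmv.w.x) | |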
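
The codes are taken directly from instruction bits, so the table can be reproduced from an encoded instruction. This hypothetical plain-Scala helper mirrors the R-type and R4-type branches of FPU_Control, which use bits 31-27 and opcode bits 6-2 respectively. It does not model the I-type and S-type branches.

```scala
object FpuOpCode {
  // Operation code as formed by FPU_Control for R-type and R4-type F instructions
  def code(inst: Int): Option[Int] = (inst & 0x7f) match {
    case 0x53                      => Some((inst >>> 27) & 0x1f) // R-type: inst[31:27]
    case 0x43 | 0x47 | 0x4b | 0x4f => Some((inst >>> 2) & 0x1f)  // R4-type: inst[6:2]
    case _                         => None // I/S-type and non-F instructions not modelled
  }
}
```

For example, fsqrt.s f1, f2 (funct7 0101100) encodes as 0x580170D3 and maps to code 11, and fmadd.s (opcode 1000011) maps to code 16.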

FPU

  • Located in src/main/scala/Pipeline/UNits/FPU.scala
  • All floating-point operations are carried out here.
  • fmt and rm are only used in this module.

The FPU embeds comparator modules, FP_COMP, which are defined in src/main/scala/Pipeline/UNits/FP_COMP.scala.

The implemented floating-point operations include addition, subtraction, multiplication, division, square root, fused multiply-add, less-than comparison, and conversions between floating-point and integer types.

We used different io.fpu_Op values to select and execute specific floating-point operations in our FPU module. Each operation corresponds to a distinct block of logic within the switch statement, which dynamically determines the behavior based on the provided opcode.

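import chisel3._
import chisel3.util._

// FPU_FADD_S, FPU_FSUB_S, ... are 5-bit operation-code constants (definitions not shown)
class FPU extends Module{
    val io = IO(new Bundle{
        val A_data_in = Input(UInt(32.W))
        val B_data_in = Input(UInt(32.W))
        val C_data_in = Input(UInt(32.W))
        val fpu_Op = Input(UInt(5.W))
        val rm = Input(UInt(3.W))
        val rs2 = Input(UInt(5.W))
        val fmt = Input(UInt(2.W))
        val out = Output(UInt(32.W))
    })

    val NAN_exp = "b11111111".U(8.W)
    val COMP_0  =   Module(new FP_COMP)
    val COMP_16 =   Module(new FP_COMP)
    val COMP_20 =   Module(new FP_COMP)
    val COMP_21 =   Module(new FP_COMP)

    // Intermediate values are combinational wires with a default of 0, so they can be
    // driven with := inside the switch below (a literal such as 0.U cannot be assigned)
    val result      = WireDefault(0.U(32.W))
    val result_sign = WireDefault(0.U(1.W))
    val result_exp  = WireDefault(0.U(8.W))
    val result_frac = WireDefault(0.U(23.W))
    val carry       = WireDefault(0.U(1.W))
    val exp_diff    = WireDefault(0.U(8.W))

    val A_sign = WireDefault(0.U(1.W))
    val B_sign = WireDefault(0.U(1.W))
    val C_sign = WireDefault(0.U(1.W))
    val A_exp  = WireDefault(0.U(8.W))        // exponents are 8 bits wide
    val B_exp  = WireDefault(0.U(8.W))
    val C_exp  = WireDefault(0.U(8.W))
    val A_Mantissa = WireDefault(0.U(24.W))   // hidden 1 + 23-bit fraction
    val B_Mantissa = WireDefault(0.U(24.W))
    val C_Mantissa = WireDefault(0.U(24.W))
    val Temp_Exp = WireDefault(0.U(8.W))
    val Temp_Mantissa = WireDefault(0.U(25.W)) // bit 24 holds the carry
    val B_shift_mantissa = WireDefault(0.U(24.W))
    val COMP = Wire(Bool())
    val COMP_20_ = WireDefault(false.B)
    val COMP_21_ = WireDefault(false.B)

    //comparator
    COMP_0.io.rs1 := io.A_data_in
    COMP_0.io.rs2 := io.B_data_in
    COMP := COMP_0.io.COMP_RESULT.asBool

    // Scala has no ?: operator; Mux selects the operand order
    val A_swap = Mux(COMP, io.A_data_in, io.B_data_in)
    val B_swap = Mux(COMP, io.B_data_in, io.A_data_in)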

    switch(io.fpu_Op) {
        is(FPU_FADD_S) {
          //result := io.in_A + io.in_B
        }
        is(FPU_FSUB_S) {
          //result := io.in_A - io.in_B
        }
        is(FPU_FMUL_S) {
          //result := io.in_A * io.in_B
        }
        is(FPU_FSQRT_S){
          //result := io.in_A^(1/2)
        }
        is(FPU_FDIV_S){
          //result := io.in_A / io.in_B    
        }
        is(FPU_FMADD_S){
          //result := (io.in_A * io.in_B) + io.in_C
        }
        is(FPU_FCVT_W_S) {
          //result := io.in_A.asSInt
        }
        is(FPU_FLT_S) {
        //result := io.in_A < io.in_B
        }
    }

    io.out := result
}

A_data_in, B_data_in and C_data_in are the three operand inputs from the rest of the pipeline.
COMP is the output of the comparator COMP_0. It is used to order the two operands into A_swap and B_swap.
carry is a signal indicating whether the mantissa operation produced a carry.
exp_diff is an 8-bit signal holding the difference between the exponents of the two input floating-point numbers. It is used for mantissa alignment.
A_sign, B_sign and C_sign are the sign bits of A_data_in, B_data_in and C_data_in, respectively.


Addition

The adder (FADD) adds two single-precision floating-point numbers.

  1. Decomposes the input floating-point numbers into their respective sign, exponent, and fraction fields.
  2. Calculates the exponent difference to align the mantissas.
  3. Once aligned, the mantissas are added or subtracted based on the sign bits, and the result is normalized using carry to ensure accuracy.
// Addition Code

is(FPU_FADD_S) {
  A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
  B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))

  // align: shift the mantissa of the operand with the smaller exponent right
  exp_diff := A_swap(30, 23) - B_swap(30, 23)
  B_shift_mantissa := B_Mantissa >> exp_diff
  // add the mantissas (subtract if the signs differ); +& keeps the carry in bit 24
  Temp_Mantissa := Mux(A_swap(31) ^ B_swap(31),
                       A_Mantissa - B_shift_mantissa,
                       A_Mantissa +& B_shift_mantissa)
  carry := Temp_Mantissa(24)

  // normalize into a separate wire (a wire cannot be shifted into itself)
  val A_exp8        = A_swap(30, 23)
  val lead_zeros    = PriorityEncoder(Reverse(Temp_Mantissa(23, 0))) // leading zeros of the sum
  val norm_Mantissa = WireDefault(0.U(24.W))
  when(carry === 1.U) {
    norm_Mantissa := Temp_Mantissa(24, 1)
    result_exp := Mux(A_exp8 < 254.U, A_exp8 + 1.U, 255.U)
  } .elsewhen(Temp_Mantissa === 0.U) {
    norm_Mantissa := 0.U
    result_exp := 0.U
  } .otherwise {
    // shift left until the hidden 1 is back in bit 23
    norm_Mantissa := (Temp_Mantissa(23, 0) << lead_zeros)(23, 0)
    result_exp := A_exp8 - lead_zeros
  }
  result_sign := A_swap(31)
  result_frac := norm_Mantissa(22, 0)
  result := Cat(result_sign, result_exp, result_frac)
}
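
A software reference model helps check the adder before simulating the Chisel code. Below is a sketch in plain Scala (FAddModel is a hypothetical name) that follows the same steps. It orders the operands by magnitude, aligns them, adds or subtracts, and normalizes. Like the hardware above, it simply drops the shifted-out bits. It does not handle zero, subnormal, infinite or NaN inputs, so it matches the JVM's float addition only when the exact sum fits in 24 bits.

```scala
object FAddModel {
  def bits(f: Float): Int = java.lang.Float.floatToRawIntBits(f)

  // Add two normal single-precision numbers given as bit patterns (truncating).
  def fadd(a: Int, b: Int): Int = {
    // larger magnitude first: with the sign bit masked off, integer order = magnitude order
    val (x, y) = if ((a & 0x7fffffff) >= (b & 0x7fffffff)) (a, b) else (b, a)
    val sign = x >>> 31
    val ex   = (x >>> 23) & 0xff
    val ey   = (y >>> 23) & 0xff
    val mx   = (x & 0x7fffff) | 0x800000                      // 24-bit mantissas with hidden 1
    val my   = ((y & 0x7fffff) | 0x800000) >>> math.min(ex - ey, 31)
    val sum  = if (((x ^ y) >>> 31) == 0) mx + my else mx - my // add or subtract by sign
    if (sum == 0) 0                                            // exact cancellation gives +0
    else {
      var m = sum
      var e = ex
      if ((m & 0x1000000) != 0) { m >>>= 1; e += 1 }          // carry out of bit 23
      while ((m & 0x800000) == 0) { m <<= 1; e -= 1 }         // renormalize after subtraction
      if (e >= 0xff) (sign << 31) | 0x7f800000                 // overflow: infinity
      else (sign << 31) | (e << 23) | (m & 0x7fffff)
    }
  }
}
```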

Subtraction

The subtractor (FSUB) works like the adder (FADD). The key difference is how the sign bit of input B is handled. Since A − B = A + (−B), the aligned mantissas are added when the sign bits differ and subtracted when they are equal. Therefore, the two arms of the Temp_Mantissa selection are swapped relative to the adder.

// Subtraction Code

is(FPU_FSUB_S) {
  A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
  B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))

  // align: shift the mantissa of the operand with the smaller exponent right
  exp_diff := A_swap(30, 23) - B_swap(30, 23)
  B_shift_mantissa := B_Mantissa >> exp_diff
  // A - B: add the mantissas if the signs differ, subtract them otherwise
  Temp_Mantissa := Mux(A_swap(31) ^ B_swap(31),
                       A_Mantissa +& B_shift_mantissa,
                       A_Mantissa - B_shift_mantissa)
  carry := Temp_Mantissa(24)

  // normalize into a separate wire (a wire cannot be shifted into itself)
  val A_exp8        = A_swap(30, 23)
  val lead_zeros    = PriorityEncoder(Reverse(Temp_Mantissa(23, 0))) // leading zeros of the result
  val norm_Mantissa = WireDefault(0.U(24.W))
  when(carry === 1.U) {
    norm_Mantissa := Temp_Mantissa(24, 1)
    result_exp := Mux(A_exp8 < 254.U, A_exp8 + 1.U, 255.U)
  } .elsewhen(Temp_Mantissa === 0.U) {
    norm_Mantissa := 0.U
    result_exp := 0.U
  } .otherwise {
    // shift left until the hidden 1 is back in bit 23
    norm_Mantissa := (Temp_Mantissa(23, 0) << lead_zeros)(23, 0)
    result_exp := A_exp8 - lead_zeros
  }
  result_sign := A_swap(31)
  result_frac := norm_Mantissa(22, 0)
  result := Cat(result_sign, result_exp, result_frac)
}
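
IEEE 754 defines x − y as x + (−y), and negating a float only flips bit 31. The small plain-Scala check below uses hypothetical names and the JVM's float arithmetic to confirm this identity. It suggests FSUB could reuse the FADD path with B's sign bit inverted instead of duplicating the block.

```scala
object FSubViaFAdd {
  def bits(f: Float): Int  = java.lang.Float.floatToRawIntBits(f)
  def float(b: Int): Float = java.lang.Float.intBitsToFloat(b)

  // Negate by flipping only the sign bit (Int.MinValue == 0x80000000)
  def negate(b: Int): Int = b ^ Int.MinValue

  // fsub.s computed as an addition of rs1 and the sign-flipped rs2
  def fsubViaAdd(a: Float, b: Float): Float = a + float(negate(bits(b)))
}
```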

Multiplication
  1. Decomposes the input floating-point numbers into their respective sign, exponent, and fraction fields.
  2. Sign computation: Calculate the sign bit of the result.
  3. Exponent computation: Add the exponents and subtract the bias.
  4. Mantissa computation: Multiply the mantissas and perform normalization.

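If the MSB of Mantissa_product (bit 47) is 1, the product of the two mantissas lies in [2, 4). The mantissa must then be right-shifted and the exponent incremented to keep the result normalized:

$$\mathrm{Normalized\ Mantissa} = \mathrm{Mantissa\_product}[46{:}24]$$

$$\mathrm{Normalized\ Exponent} = \mathrm{Exponent\ Sum} + 1$$

If the MSB of Mantissa_product is 0, the exponent remains unchanged, and the mantissa is taken from Mantissa_product[45:23].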

  5. Special cases:
    5-1. If either operand is zero, the result is set to zero.
    5-2. If the exponent overflows, the result is set to infinity; if it underflows, the result becomes a subnormal number.
//Multiplication Code

is(FPU_FMUL_S) {
  // extract the sign bit, exponent, and mantissa
  A_sign := A_swap(31)
  B_sign := B_swap(31)
  A_exp := A_swap(30, 23)
  B_exp := B_swap(30, 23)
  A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
  B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))

  // multiply the signs
  result_sign := A_sign ^ B_sign
  // add the exponents and remove one bias; signed and widened so that
  // both overflow (>= 255) and underflow (<= 0) can be detected
  val exp_sum = (A_exp.zext +& B_exp.zext) - 127.S
  // multiply the mantissas (24 x 24 -> 48 bits)
  val Mantissa_product = A_Mantissa * B_Mantissa
  // normalize
  carry := Mantissa_product(47)
  val normalized_Mantissa = Mux(carry === 1.U, Mantissa_product(46, 24), Mantissa_product(45, 23))
  val normalized_exp = Mux(carry === 1.U, exp_sum + 1.S, exp_sum)

  // zero operand, exponent overflow, or underflow
  when(A_swap(30, 0) === 0.U || B_swap(30, 0) === 0.U) {
    result_exp := 0.U
    result_frac := 0.U
  } .elsewhen(normalized_exp >= 255.S) {
    result_exp := 255.U // infinity
    result_frac := 0.U
  } .elsewhen(normalized_exp <= 0.S) {
    result_exp := 0.U // subnormal number
    result_frac := normalized_Mantissa >> (1.S - normalized_exp).asUInt
  } .otherwise {
    result_exp := normalized_exp(7, 0)
    result_frac := normalized_Mantissa
  }
  result := Cat(result_sign, result_exp, result_frac)
}
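
The same steps can be written as a plain-Scala reference model for testing. This is only a sketch, and FMulModel is a hypothetical name. It truncates the extra product bits and assumes normal inputs. On exponent underflow it returns a signed zero instead of producing a subnormal.

```scala
object FMulModel {
  def bits(f: Float): Int = java.lang.Float.floatToRawIntBits(f)

  // Multiply two normal single-precision numbers given as bit patterns (truncating)
  def fmul(a: Int, b: Int): Int = {
    val sign    = (a ^ b) >>> 31                                  // XOR of the sign bits
    val expSum  = ((a >>> 23) & 0xff) + ((b >>> 23) & 0xff) - 127 // add exponents, remove one bias
    val ma      = ((a & 0x7fffff) | 0x800000).toLong              // 24-bit mantissas with hidden 1
    val mb      = ((b & 0x7fffff) | 0x800000).toLong
    val product = ma * mb                                         // up to 48 bits
    val carry   = (product >>> 47) & 1                            // MSB set: product in [2, 4)
    // carry: fraction = bits [46:24]; otherwise bits [45:23]
    val frac = (if (carry == 1) product >>> 24 else product >>> 23) & 0x7fffff
    val exp  = expSum + carry.toInt
    if (exp >= 0xff) (sign << 31) | 0x7f800000   // overflow: infinity
    else if (exp <= 0) sign << 31                // underflow: signed zero in this sketch
    else (sign << 31) | (exp << 23) | frac.toInt
  }
}
```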

Square Root
  1. Decomposes the input floating-point numbers into their respective sign, exponent, and fraction fields.
  2. Handle special cases of inputs: Manage cases for negative input, zero, or subnormal numbers.
  3. Adjust exponent: Correct the square root exponent based on its parity (odd or even).
  4. Newton's Iteration: Iteratively approximate the square root of the mantissa.

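The target is to solve $f(x) = x^2 - N = 0$, where $N$ is the input number.
Newton's method uses the formula:

$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$

For $f(x) = x^2 - N$, the derivative is $f'(x) = 2x$. Substituting these:

$$x_{n+1} = \frac{1}{2}\left(x_n + \frac{N}{x_n}\right)$$

Start with an initial guess $x_0 = 1$, and repeat the iteration formula three times to converge towards $\sqrt{N}$.

Newton's iteration method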

  5. Special cases:
    5-1. If the input is negative, the output is set to NaN.
    5-2. If the input is zero, the output is set to zero.
    5-3. If the input is infinite, the output is set to infinity.
  6. Combine result: Form the final result from the sign, normalized exponent, and mantissa.
// Square Root Code

is(FPU_FSQRT_S) {
  // extract the sign bit, exponent, and mantissa
  A_sign := A_swap(31)
  A_exp := A_swap(30, 23)
  A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))

  when(A_swap(30, 0) === 0.U) {
    // when input is +0 or -0, output it unchanged (IEEE 754: sqrt(-0) = -0)
    result := A_swap
  } .elsewhen(A_sign === 1.U) {
    // when input is negative, output the canonical quiet NaN 0x7FC00000
    result := "h7FC00000".U(32.W)
  } .otherwise {
    // unbiased exponent as a signed value, so negative exponents also work
    val exp_adjust = A_exp.zext - 127.S
    // halve the exponent; for an odd exponent, subtract 1 first
    val new_exp = (Mux(exp_adjust(0), exp_adjust - 1.S, exp_adjust) >> 1) + 127.S
    // Newton Iteration(three times)
    val x0 = "b01000000000000000000000000000000".U(32.W) // x0 assumption = 1
    val iter1 = (x0 + (A_Mantissa << 23) / x0) >> 1
    val iter2 = (iter1 + (A_Mantissa << 23) / iter1) >> 1
    val iter3 = (iter2 + (A_Mantissa << 23) / iter2) >> 1
    val sqrt_Mantissa = iter3(45, 23) // use the third iterate

    // final result
    result_sign := 0.U
    result_exp := new_exp(7, 0)
    result_frac := sqrt_Mantissa
    result := Cat(result_sign, result_exp, result_frac)
  }
}
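
A quick way to see how fast the iteration converges is to run it in double precision. The plain-Scala sketch below (hypothetical name NewtonSqrt) applies $x_{n+1} = \frac{1}{2}(x_n + N/x_n)$ from $x_0 = 1$. For $N = 2.25$, three steps reach about 1.5000077. For $N = 4$, three steps give only about 2.00061. Inputs near the top of the mantissa range may therefore need more steps, or a better initial guess, to reach single precision.

```scala
object NewtonSqrt {
  // x_{n+1} = (x_n + N / x_n) / 2, starting from x_0 = 1
  def sqrt(n: Double, iterations: Int): Double =
    (1 to iterations).foldLeft(1.0)((x, _) => (x + n / x) / 2)
}
```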

Division
  1. Decomposes the input floating-point numbers into their respective sign, exponent, and fraction fields.
  2. Compute the sign and adjust the exponent difference.
  3. Newton's Iteration: Iteratively approximate the reciprocal of the denominator.
  4. Compute the final mantissa: Multiply the numerator mantissa by the reciprocal.
  5. Normalize the result: Adjust the mantissa and exponent to conform to the IEEE 754 format.
  6. Special cases:
    6-1. If the denominator is zero, the output is set to NaN. (IEEE 754 specifies ±∞ for a nonzero numerator divided by zero, and NaN only for 0/0.)
    6-2. If the numerator is zero, the output is set to zero.
    6-3. If the output exponent overflows, the result is set to infinity.
    6-4. If the output is a subnormal number, the result exponent is set to zero, and the result mantissa is right-shifted.
// Division Code

is(FPU_FDIV_S) {
  // extract the sign bit, exponent, and mantissa
  A_sign := A_swap(31)
  B_sign := B_swap(31)
  A_exp := A_swap(30, 23)
  B_exp := B_swap(30, 23)
  A_Mantissa := Cat(1.U(1.W), A_swap(22, 0))
  B_Mantissa := Cat(1.U(1.W), B_swap(22, 0))

  // divide the signs
  val result_sign = A_sign ^ B_sign
  // subtract the exponents and add the bias back; signed and widened so that
  // both overflow (>= 255) and underflow (<= 0) can be detected
  val exp_diff = (A_exp.zext -& B_exp.zext) +& 127.S

  // Newton Iteration(three times) 1 / B_Mantissa
  val x0 = "b01000000000000000000000000000000".U(32.W) // x0 assumption = 1
  val iter1 = (x0 + (1.U << 23) - ((B_Mantissa << 23) / x0)) >> 1
  val iter2 = (iter1 + (1.U << 23) - ((B_Mantissa << 23) / iter1)) >> 1
  val iter3 = (iter2 + (1.U << 23) - ((B_Mantissa << 23) / iter2)) >> 1
  val reciprocal = iter3(45, 23)

  // calculate mantissa A_Mantissa / B_Mantissa = A_Mantissa * reciprocal
  val Mantissa_div = A_Mantissa * reciprocal

  // normalize (carry is a UInt, so compare it with 1.U to get a Bool for Mux)
  carry := Mantissa_div(47)
  val normalized_Mantissa = Mux(carry === 1.U, Mantissa_div(46, 24), Mantissa_div(45, 23))
  val normalized_exp = Mux(carry === 1.U, exp_diff + 1.S, exp_diff)

  when(B_swap(30, 0) === 0.U) {
    // B=0:NaN (canonical quiet NaN 0x7FC00000)
    result := "h7FC00000".U(32.W)
  } .elsewhen(A_swap(30, 0) === 0.U) {
    // A=0:0
    result := 0.U
  } .elsewhen(normalized_exp >= 255.S) {
    // exponent overflow:exp=infinity, mantissa=0
    result_exp := 255.U
    result_frac := 0.U
    result := Cat(result_sign, result_exp, result_frac)
  } .elsewhen(normalized_exp <= 0.S) {
    // subnormal:exp=0, mantissa right shift
    result_exp := 0.U
    result_frac := normalized_Mantissa >> (1.S - normalized_exp).asUInt
    result := Cat(result_sign, result_exp, result_frac)
  } .otherwise {
    // final result
    result_exp := normalized_exp(7, 0)
    result_frac := normalized_Mantissa
    result := Cat(result_sign, result_exp, result_frac)
  }
}
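
For comparison, the textbook Newton-Raphson iteration for a reciprocal is $x_{n+1} = x_n(2 - b\,x_n)$. It roughly doubles the number of correct bits at each step. Note that this is a swapped-in standard formula, not the exact expression used in the Chisel code above. The plain-Scala sketch below (hypothetical names) starts from $x_0 = 1$. This start converges for $b$ in $[1, 2)$ because $0 < x_0 < 2/b$. For $b = 1.5$, three steps give 0.6640625, compared with $2/3 \approx 0.6667$.

```scala
object NewtonReciprocal {
  // x_{n+1} = x_n * (2 - b * x_n), starting from x_0 = 1
  def reciprocal(b: Double, iterations: Int): Double =
    (1 to iterations).foldLeft(1.0)((x, _) => x * (2 - b * x))

  // a / b computed as a * (1 / b), following the FDIV steps above
  def divide(a: Double, b: Double, iterations: Int): Double = a * reciprocal(b, iterations)
}
```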

Fused Multiply-Add

Combines multiplication and addition into a single operation to improve efficiency and minimize precision loss.

$$(A \times B) + C$$

  1. Decomposes the input floating-point numbers into their respective sign, exponent, and fraction fields.
  2. Compute the sign, exponent, and mantissa of the product A × B, and normalize the product.
  3. Compare the product with C and compute the exponent difference.
  4. Adjust mantissas based on the exponent difference.
  5. Compute the final result based on the signs, addition or subtraction.
  6. Normalize the result: Adjust the mantissa and exponent to conform to the IEEE 754 format.
  7. Special cases:
    7-1. If any input is NaN, output is set to NaN.
    7-2. If A or B is zero, the output is C; if C is zero, the output is the product A × B.
    7-3. If the output exponent overflows, the result is set to infinity.
    7-4. If the output is a subnormal number, the exponent is set to zero and the mantissa is right-shifted.
// Fused Multiply-Add Code

is(FPU_FMADD_S) {
  // extract the sign bit, exponent, and mantissa
  A_sign := io.A_data_in(31)
  B_sign := io.B_data_in(31)
  C_sign := io.C_data_in(31)
  A_exp := io.A_data_in(30, 23)
  B_exp := io.B_data_in(30, 23)
  C_exp := io.C_data_in(30, 23)
  A_Mantissa := Cat(1.U(1.W), io.A_data_in(22, 0))
  B_Mantissa := Cat(1.U(1.W), io.B_data_in(22, 0))
  C_Mantissa := Cat(1.U(1.W), io.C_data_in(22, 0))

  // calculate A*B
  val mul_sign = A_sign ^ B_sign
  val mul_exp = A_exp + B_exp - 127.U
  val mul_Mantissa = A_Mantissa * B_Mantissa

  // normalize_A*B
  val mul_carry = mul_Mantissa(47)
  val mul_norm_Mantissa = Mux(mul_carry, mul_Mantissa(46, 24), mul_Mantissa(45, 23))
  val mul_norm_exp = Mux(mul_carry, mul_exp + 1.U, mul_exp)

  // comparator_A*B vs C
  COMP_16.io.rs1 := mul_norm_Mantissa
  COMP_16.io.rs2 := C_Mantissa
  COMP := COMP_16.io.COMP_RESULT

  // aligned numbers (signed exponent difference, no 8-bit wrap-around)
  val exp_diff = mul_norm_exp.zext - C_exp.zext
  val safe_shift = Mux(exp_diff < 0.S, (-exp_diff).asUInt, exp_diff.asUInt)
  val aligned_Mul_Mantissa = Mux(COMP, mul_norm_Mantissa, mul_norm_Mantissa >> safe_shift)
  val aligned_C_Mantissa = Mux(COMP, C_Mantissa >> safe_shift, C_Mantissa)
  val aligned_exp = Mux(COMP, mul_norm_exp, C_exp)

  // ADD & SUB result: +& keeps the carry; for unlike signs, subtract the smaller from the larger
  val mul_ge_C = aligned_Mul_Mantissa >= aligned_C_Mantissa
  val add_sub_result = Mux(mul_sign === C_sign,
    aligned_Mul_Mantissa +& aligned_C_Mantissa,
    Mux(mul_ge_C, aligned_Mul_Mantissa - aligned_C_Mantissa, aligned_C_Mantissa - aligned_Mul_Mantissa)
  )
  val add_sub_sign = Mux(mul_ge_C, mul_sign, C_sign)

  // normalize result
  val result_carry = add_sub_result(24)
  val norm_result_Mantissa = Mux(result_carry, add_sub_result(23, 1), add_sub_result(22, 0))
  val norm_result_exp = Mux(result_carry, aligned_exp + 1.U, aligned_exp)

  // special cases (NaN: exponent all ones and a nonzero stored fraction)
  when((A_exp === 255.U && io.A_data_in(22, 0) =/= 0.U) ||
      (B_exp === 255.U && io.B_data_in(22, 0) =/= 0.U) ||
      (C_exp === 255.U && io.C_data_in(22, 0) =/= 0.U)) {
    // A or B or C input = NaN
    result := "h7FC00000".U(32.W) // Output canonical quiet NaN
  } .elsewhen(io.A_data_in(30, 0) === 0.U || io.B_data_in(30, 0) === 0.U) {
    // A or B = 0
    result := io.C_data_in
  } .elsewhen(io.C_data_in(30, 0) === 0.U) {
    // C = 0
    result := Cat(mul_sign, mul_norm_exp, mul_norm_Mantissa(22, 0))
  } .elsewhen(norm_result_exp >= 255.U) {
    // overflow
    result := Cat(add_sub_sign, "b11111111".U(8.W), 0.U(23.W))
  } .elsewhen(norm_result_exp <= 0.U) {
    // subnormal
    val subnormal_Mantissa = norm_result_Mantissa >> (1.U - norm_result_exp)
    result := Cat(add_sub_sign, 0.U(8.W), subnormal_Mantissa(22, 0))
  } .otherwise {
    // normal result
    result := Cat(add_sub_sign, norm_result_exp, norm_result_Mantissa(22, 0))
  }
}
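
The precision benefit of a fused operation comes from rounding only once. The plain-Scala sketch below uses java.lang.Math.fma (available since Java 9) as a reference for a fused single-precision operation. With a = 1 + 2^-23, the exact product a × a is 1 + 2^-22 + 2^-46. A separate multiply rounds this to 1 + 2^-22, so adding c = -(a × a) gives 0, while the fused operation keeps 2^-46.

```scala
object FmaDemo {
  val a: Float = 1.0f + java.lang.Math.ulp(1.0f) // 1 + 2^-23, exactly representable
  val c: Float = -(a * a)                         // the rounded product, negated: -(1 + 2^-22)

  // Two roundings: a * a is rounded to float first, so the 2^-46 term is lost
  val separate: Float = a * a + c
  // One rounding: the exact product is added to c before rounding
  val fused: Float = java.lang.Math.fma(a, a, c)
}
```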

Equal & Less Than & Less Than or Equal

Utilizes the comparator modules COMP_20 and COMP_21 to perform numerical comparisons.

  1. Decompose inputs: Read the two floating-point numbers.
  2. Use COMP_20 (A >= B) and COMP_21 (B >= A) to evaluate the comparison in both directions.
  3. Use io.rm (the funct3 field) to select the operation:
    3-1. Mode 0 (fle.s): checks for less than or equal.
    3-2. Mode 1 (flt.s): checks for less than.
    3-3. Mode 2 (feq.s): checks for equality.
// Equal & Less Than & Less Than or Equal Code

is(FPU_FEQ_S, FPU_FLT_S, FPU_FLE_S) {
  COMP_20.io.rs1 := io.A_data_in
  COMP_20.io.rs2 := io.B_data_in
  COMP_20_ := COMP_20.io.COMP_RESULT.asBool // A >= B
  COMP_21.io.rs1 := io.B_data_in
  COMP_21.io.rs2 := io.A_data_in
  COMP_21_ := COMP_21.io.COMP_RESULT.asBool // B >= A
  // rm holds funct3: fle.s = 000, flt.s = 001, feq.s = 010
  switch(io.rm) {
    is(0.U) {
      result := Mux(COMP_21_, 1.U, 0.U) //B>=A -> A<=B
    }
    is(1.U) {
      result := Mux(COMP_20_, 0.U, 1.U) //!(A>=B) -> A<B
    }
    is(2.U) {
      result := Mux((COMP_20_ && COMP_21_), 1.U, 0.U) //A>=B && B>=A -> A=B
    }
  }
}
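
The comparator logic can be sanity-checked with a plain-Scala model. In this sketch, ge stands in for FP_COMP and is assumed to return true when rs1 >= rs2. The rm values follow the RISC-V funct3 encoding (fle.s = 0, flt.s = 1, feq.s = 2). The NaN check follows the RISC-V rule that these comparisons write 0 when either operand is NaN. Without that check, !(A >= B) would wrongly report NaN < 1.0.

```scala
object FCmpModel {
  // Stand-in for FP_COMP, assumed to output 1 when rs1 >= rs2
  def ge(a: Float, b: Float): Boolean = a >= b

  // rm = funct3: 0 = fle.s, 1 = flt.s, 2 = feq.s; any NaN operand gives 0
  def fcmp(rm: Int, a: Float, b: Float): Int =
    if (a.isNaN || b.isNaN) 0
    else {
      val geAB = ge(a, b)
      val geBA = ge(b, a)
      val r = rm match {
        case 0 => geBA          // a <= b  <=>  b >= a
        case 1 => !geAB         // a <  b  <=>  !(a >= b), valid once NaNs are excluded
        case 2 => geAB && geBA  // a == b
        case _ => false
      }
      if (r) 1 else 0
    }
}
```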

FPU register file

  • Located in src/main/scala/Pipeline/UNits/F_Reg.scala

This is an independent register file for the FPU, with a structure similar to the original register file in src/main/scala/Pipeline/UNits/RegisterFile.scala. RV32F R4-type instructions need a third source operand, rs3, so we add the F_rs3 input and F_rdata3 output ports.

import chisel3._

class F_Reg extends Module {
  val io = IO(new Bundle {
    val F_rs1       = Input(UInt(5.W))
    val F_rs2       = Input(UInt(5.W))
    val F_rs3       = Input(UInt(5.W))
    val F_reg_write = Input(Bool())
    val F_w_reg     = Input(UInt(5.W))
    val F_w_data    = Input(SInt(32.W))
    val F_rdata1    = Output(SInt(32.W))
    val F_rdata2    = Output(SInt(32.W))
    val F_rdata3    = Output(SInt(32.W))
  })
  val F_regfile = RegInit(VecInit(Seq.fill(32)(0.S(32.W))))

  // Unlike x0 in the integer file, f0 is an ordinary register, so there is no zero check
  io.F_rdata1 := F_regfile(io.F_rs1)
  io.F_rdata2 := F_regfile(io.F_rs2)
  io.F_rdata3 := F_regfile(io.F_rs3)

  when(io.F_reg_write) {
    F_regfile(io.F_w_reg) := io.F_w_data
  }
}
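
Below is a behavioral sketch of the three-read, one-write register file in plain Scala. The class name FRegModel is hypothetical. It reflects that f0, unlike x0, is an ordinary register in the F extension.

```scala
final class FRegModel {
  // 32 registers holding raw single-precision bit patterns, all reset to 0
  private val regs = Array.fill(32)(0)

  // Three read ports: rs1, rs2, and rs3 (rs3 is only meaningful for R4-type)
  def read(rs1: Int, rs2: Int, rs3: Int): (Int, Int, Int) = (regs(rs1), regs(rs2), regs(rs3))

  // One write port, gated by the write enable; f0 is writable
  def write(enable: Boolean, rd: Int, data: Int): Unit = if (enable) regs(rd) = data
}
```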

ALU/FPU rs1-rs3 value input

  • Located in src/main/scala/Pipeline/Main.scala
  • Forwarding A-C decide which data wire is read into rs1-rs3 as the operand value (see the forwarding section).
  • Wire d carries the write-back data (see the Write-back section).

Compared to the original project, we add two parallel muxes and use ID_EX_.io.ctrl_FPU_en_out to decide whether the ALU or the FPU reads the data. ID_EX_.io.ctrl_OpA_out === "b01".U (PC + 4 as operand A) is only used by UJ-type instructions and JALR, so it is unaffected by the FPU path.

Forwarding A (rs1):

    val d = Wire(SInt(32.W))

    when (ID_EX_.io.ctrl_OpA_out === "b01".U) {
        ALU.io.in_A := ID_EX_.io.IFID_pc4_out.asSInt
    }.elsewhen (ID_EX_.io.ctrl_FPU_en_out === 1.U) { //RV32F rs1 data
        FPU.io.A_data_in := MuxLookup(Forwarding.io.forward_a, 0.U, Array(
            (0.U) -> ID_EX_.io.rs1_data_out,
            (1.U) -> d,
            (2.U) -> EX_MEM_M.io.EXMEM_fpu_out,
            (3.U) -> ID_EX_.io.rs1_data_out
        ))
    }.otherwise {
        // forwarding A
        when(Forwarding.io.forward_a === "b00".U) {
            ALU.io.in_A := ID_EX_.io.rs1_data_out
        }.elsewhen(Forwarding.io.forward_a === "b01".U) {
            ALU.io.in_A := d
        }.elsewhen(Forwarding.io.forward_a === "b10".U) {
            ALU.io.in_A := EX_MEM_M.io.EXMEM_alu_out
        }.otherwise {
            ALU.io.in_A := ID_EX_.io.rs1_data_out
        }
      }

The rs2 value is selected by ID_EX_.io.ctrl_FPU_en_out and Forwarding.io.forward_b, which choose the wire to read. It is also connected to EX_MEM_M.io.IDEX_rs2. The rs2 operand input also carries the immediate for both the ALU and the FPU, so ID_EX_.io.ctrl_OpB_out selects between the immediate and the rs2 value.
Forwarding B (rs2):

    val RS2_value = Wire(SInt(32.W)) 
    when (Forwarding.io.forward_b === 0.U) {
      RS2_value := ID_EX_.io.rs2_data_out
    }.elsewhen (Forwarding.io.forward_b === 1.U) {
      RS2_value := d
    }.elsewhen (Forwarding.io.forward_b === 2.U) {
      RS2_value := Mux(ID_EX_.io.ctrl_FPU_en_out, EX_MEM_M.io.EXMEM_fpu_out, EX_MEM_M.io.EXMEM_alu_out)
    }.otherwise {
      RS2_value := 0.S
    }

    when(ID_EX_.io.ctrl_FPU_en_out === 1.U){
      FPU.io.B_data_in := Mux(ID_EX_.io.ctrl_OpB_out, ID_EX_.io.imm_out, RS2_value)
    }.otherwise{
      ALU.io.in_B := Mux(ID_EX_.io.ctrl_OpB_out, ID_EX_.io.imm_out, RS2_value)
    }

Forwarding C (rs3):

    val RS3_value = Wire(SInt(32.W))
    when (Forwarding.io.forward_c === 0.U) {
      RS3_value := ID_EX_.io.rs3_data_out
    }.elsewhen (Forwarding.io.forward_c === 1.U) {
      RS3_value := d
    }.elsewhen (Forwarding.io.forward_c === 2.U) {
      RS3_value := EX_MEM_M.io.EXMEM_fpu_out
    }.otherwise {
      RS3_value := 0.S
    }
    FPU.io.C_data_in := Mux(ID_EX_.io.ctrl_OpC_out, RS3_value, 0.S)

Hazard units

Because R4-type instructions have rs3, a set of I/O signals for rs3 hazards is needed.

forwarding

  • Located in src/main/scala/Pipeline/Hazard Units/Forwarding.scala

EX_HAZARD and MEM_HAZARD are handled the same way for rs1, rs2 and rs3.

class Forwarding extends Module {
    val io = IO(new Bundle {
        val IDEX_rs1 = Input(UInt(5.W))
        val IDEX_rs2 = Input(UInt(5.W))
        val IDEX_rs3 = Input(UInt(5.W))
        val EXMEM_rd = Input(UInt(5.W))
        val EXMEM_regWr = Input(UInt(1.W))
        val MEMWB_rd = Input(UInt(5.W))
        val MEMWB_regWr = Input(UInt(1.W))
        
        val forward_a = Output(UInt(2.W))
        val forward_b = Output(UInt(2.W))
        //RV32F rs3
        val forward_c = Output(UInt(2.W))
    })

    io.forward_a := "b00".U
    io.forward_b := "b00".U
    io.forward_c := "b00".U

    // An older instruction (EX/MEM or MEM/WB) will write a register that this instruction reads.
    // rd =/= 0 excludes x0; note that f0 is an ordinary register in the F register file.
    val exWrites  = io.EXMEM_regWr === "b1".U && io.EXMEM_rd =/= "b00000".U   // EX hazard source
    val memWrites = io.MEMWB_regWr === "b1".U && io.MEMWB_rd =/= "b00000".U   // MEM hazard source

    // Each operand is checked on its own, so rs1, rs2 and rs3 can all be forwarded in the same cycle.
    // EX/MEM holds the newer value, so it takes priority over MEM/WB.
    when(exWrites && io.EXMEM_rd === io.IDEX_rs1) {
        io.forward_a := "b10".U
    }.elsewhen(memWrites && io.MEMWB_rd === io.IDEX_rs1) {
        io.forward_a := "b01".U
    }

    when(exWrites && io.EXMEM_rd === io.IDEX_rs2) {
        io.forward_b := "b10".U
    }.elsewhen(memWrites && io.MEMWB_rd === io.IDEX_rs2) {
        io.forward_b := "b01".U
    }

    when(exWrites && io.EXMEM_rd === io.IDEX_rs3) {
        io.forward_c := "b10".U
    }.elsewhen(memWrites && io.MEMWB_rd === io.IDEX_rs3) {
        io.forward_c := "b01".U
    }
}
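The forwarding decision is easier to check outside of hardware. Below is a plain-Scala sketch (no Chisel dependency) of the intended priority. It assumes the same encodings as the forwarding unit: 0 keeps the register-file value, 1 selects MEM/WB and 2 selects EX/MEM. The `ForwardingModel` object and `Writers` case class are hypothetical names introduced only for this example.

```scala
// Plain-Scala reference model of the forwarding decision (not Chisel).
// Return codes mirror the forwarding unit: 0 = register file, 1 = MEM/WB, 2 = EX/MEM.
object ForwardingModel {
  // Hypothetical snapshot of the two older instructions seen by the forwarding unit.
  final case class Writers(exmemRegWr: Boolean, exmemRd: Int,
                           memwbRegWr: Boolean, memwbRd: Int)

  // Forwarding source for one source-register index.
  def select(w: Writers, rs: Int): Int =
    if (w.exmemRegWr && w.exmemRd != 0 && w.exmemRd == rs) 2      // newest value
    else if (w.memwbRegWr && w.memwbRd != 0 && w.memwbRd == rs) 1 // older value
    else 0                                                         // read the register file

  // (forward_a, forward_b, forward_c), each operand decided independently.
  def forwardAll(w: Writers, rs1: Int, rs2: Int, rs3: Int): (Int, Int, Int) =
    (select(w, rs1), select(w, rs2), select(w, rs3))
}
```

Take `fmadd.s f4, f1, f1, f2`, where EX/MEM writes f1 and MEM/WB writes f2. Here `forwardAll(Writers(true, 1, true, 2), 1, 1, 2)` returns `(2, 2, 1)`. Deciding each operand on its own is what lets rs1 and rs2 be forwarded together with rs3. A single `elsewhen` chain over the operands would forward only one of them.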

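Hazard detection

  • Located in src/main/scala/Pipeline/Hazard Units/HazardDetection.scala

To keep rs3 from causing false stalls, rs3 is pre-processed: it is only checked for R4-type instructions (operand C in use). For other instruction types, bits 31-27 are part of funct7 or an immediate rather than a register index, so rs3 hazard detection is bypassed.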

class HazardDetection extends Module {
  val io = IO(new Bundle {
    val IF_ID_inst = Input(UInt(32.W))
    val ID_EX_memRead = Input(Bool())
    val ID_EX_rd = Input(UInt(5.W))
    val pc_in = Input(SInt(32.W))
    val current_pc = Input(SInt(32.W))
    val ID_EX_operandC = Input(Bool())

    val inst_forward = Output(Bool())
    val pc_forward = Output(Bool())
    val ctrl_forward = Output(Bool())
    val inst_out = Output(UInt(32.W))
    val pc_out = Output(SInt(32.W))
    val current_pc_out = Output(SInt(32.W))
  })

  val Rs1 = io.IF_ID_inst(19, 15)
  val Rs2 = io.IF_ID_inst(24, 20)
  val Rs3 = io.IF_ID_inst(31, 27)
  val rs3detect = io.ID_EX_operandC & (io.ID_EX_rd === Rs3)
  
  when(io.ID_EX_memRead === 1.B && ((io.ID_EX_rd === Rs1) || (io.ID_EX_rd === Rs2) || rs3detect)) {
    io.inst_forward := true.B
    io.pc_forward := true.B
    io.ctrl_forward := true.B
  }.otherwise {
    io.inst_forward := false.B
    io.pc_forward := false.B
    io.ctrl_forward := false.B
  }
  io.inst_out := io.IF_ID_inst
  io.pc_out := io.pc_in
  io.current_pc_out := io.current_pc
}
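The stall condition can be sketched with the plain-Scala model below (not Chisel). It extracts rs1, rs2 and rs3 from the IF/ID instruction and gates rs3 with a flag. `LoadUseModel` and its `usesRs3` parameter are hypothetical, and `usesRs3` stands in for the operand-C signal.

```scala
// Plain-Scala reference model of the load-use stall check (not Chisel).
object LoadUseModel {
  // Extract bits hi..lo of a 32-bit instruction word.
  def field(inst: Int, hi: Int, lo: Int): Int =
    (inst >>> lo) & ((1 << (hi - lo + 1)) - 1)

  // True when the instruction in IF/ID must stall behind a load in ID/EX.
  // usesRs3 is a hypothetical stand-in for the operand-C (R4-type) flag.
  def mustStall(ifidInst: Int, idexMemRead: Boolean, idexRd: Int, usesRs3: Boolean): Boolean = {
    val rs1 = field(ifidInst, 19, 15)
    val rs2 = field(ifidInst, 24, 20)
    val rs3 = field(ifidInst, 31, 27) // only a register index for R4-type instructions
    idexMemRead && (idexRd == rs1 || idexRd == rs2 || (usesRs3 && idexRd == rs3))
  }
}
```

Consider `flw f0, 0(x1)` followed by `fadd.s f1, f2, f3`, where the `fadd.s` is encoded as 0x003170D3 with rm = 111. Bits 31-27 of that word are zero, so without the gate the rs3 field would match f0 and stall for no reason. `mustStall(0x003170D3, true, 0, false)` is `false`, while `mustStall(0x003170D3, true, 0, true)` is `true`.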

Structural Hazard

  • Located in src/main/scala/Pipeline/Hazard Units/StructuralHazard.scala
  • Determines whether forwarding is needed for rs1-rs3

The same check is used for rs1, rs2 and rs3. Only the added rs3 logic is shown below.

class StructuralHazard extends Module {
  val io = IO(new Bundle {
    val rs1 = Input(UInt(5.W))
    val rs2 = Input(UInt(5.W))
    val MEM_WB_regWr = Input(Bool())
    val MEM_WB_Rd = Input(UInt(5.W))
    val fwd_rs1 = Output(Bool())
    val fwd_rs2 = Output(Bool())
    //RV32F rs3
    val rs3 = Input(UInt(5.W))
    val fwd_rs3 = Output(Bool())
  }) 
      ...
      
  // Determine if forwarding is needed for rs3
  when(io.MEM_WB_regWr && io.MEM_WB_Rd === io.rs3) {
    io.fwd_rs3 := true.B
  }.otherwise {
    io.fwd_rs3 := false.B
  }
}
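The structural-hazard check covers the case where WB writes a register in the same cycle that ID reads it. Below is a plain-Scala sketch of that bypass that follows the rs3 condition above. It assumes a register array `regs` and the MEM-WB signals shown above, and the `RegReadBypassModel` name is hypothetical. When `fwd_rs3` would be true, the value being written is used instead of the stored one.

```scala
// Plain-Scala reference model of the same-cycle write/read bypass (not Chisel).
object RegReadBypassModel {
  // Value the ID stage should see for register rs while WB writes memwbRd in the same cycle.
  def read(regs: IndexedSeq[Int], rs: Int,
           memwbRegWr: Boolean, memwbRd: Int, wbData: Int): Int =
    if (memwbRegWr && memwbRd == rs) wbData // forward: take the value being written
    else regs(rs)                           // no forward: take the stored value
}
```

Like the rs3 condition above, the model has no rd =/= 0 test. That suits the F register file, where f0 is an ordinary register. An integer-register version would also need to exclude x0.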

Pipeline registers

IF-ID

  • These registers hold the instruction and program counter values between the IF and ID stages.
  • Located in src/main/scala/Pipeline/Pipelines/IF_ID.scala

Unchanged from the original structure; no modification is needed.

ID-EX

  • These registers pass decoded instruction signals and operands from the ID stage to the EX stage.
  • Located in src/main/scala/Pipeline/Pipelines/ID_EX.scala

The ID_EX module receives its inputs from the control module and the register file, so ports for the new control signals have to be added. Ports for rs3 are also added for R4-type instructions, including operand C (OpC). Some ports carry fields decoded from the instruction, such as fmt, op5 and rm, which are used by the FPU and FPU_Control modules.

class ID_EX extends Module {
  val io = IO(new Bundle {
    val rs1_in              = Input(UInt(5.W))
    val rs2_in              = Input(UInt(5.W))
    val rs3_in              = Input(UInt(5.W))
    val rs1_data_in         = Input(SInt(32.W))
    val rs2_data_in         = Input(SInt(32.W))
    val rs3_data_in         = Input(SInt(32.W))
    val imm                 = Input(SInt(32.W))
    val rd_in               = Input(UInt(5.W))
    val func3_in            = Input(UInt(3.W))
    val func7_in            = Input(Bool())
    val ctrl_MemWr_in       = Input(Bool())
    val ctrl_Branch_in      = Input(Bool())
    val ctrl_MemRd_in       = Input(Bool())
    val ctrl_Reg_W_in       = Input(Bool())
    val ctrl_MemToReg_in    = Input(Bool())
    val ctrl_AluOp_in       = Input(UInt(3.W))
    //RV32F
    val ctrl_FPU_en_in      = Input(Bool())
    val ctrl_FPU_Op_in      = Input(UInt(3.W))
    val FPU_func7_in        = Input(UInt(5.W))
    val FPU_fmt_in          = Input(UInt(2.W))
    val FPU_op5_in          = Input(UInt(5.W))
    val FPU_rm_in           = Input(UInt(3.W))
    //
    val ctrl_OpA_in         = Input(UInt(2.W))
    val ctrl_OpB_in         = Input(Bool())
    //RV32F opc
    val ctrl_OpC_in         = Input(Bool())
    //
    val ctrl_nextpc_in      = Input(UInt(2.W))
    val IFID_pc4_in         = Input(UInt(32.W))

    val rs1_out             = Output(UInt(5.W))
    val rs2_out             = Output(UInt(5.W))
    val rs3_out             = Output(UInt(5.W))
    val rs1_data_out        = Output(SInt(32.W))
    val rs2_data_out        = Output(SInt(32.W))
    val rs3_data_out        = Output(SInt(32.W))
    val rd_out              = Output(UInt(5.W))
    val imm_out             = Output(SInt(32.W))
    val func3_out           = Output(UInt(3.W))
    val func7_out           = Output(Bool())
    val ctrl_MemWr_out      = Output(Bool())
    val ctrl_Branch_out     = Output(Bool())
    val ctrl_MemRd_out      = Output(Bool())
    val ctrl_Reg_W_out      = Output(Bool())
    val ctrl_MemToReg_out   = Output(Bool())
    val ctrl_AluOp_out      = Output(UInt(3.W))
    //RV32F
    val ctrl_FPU_en_out     = Output(Bool())
    val ctrl_FPU_Op_out     = Output(UInt(3.W))
    val FPU_func7_out       = Output(UInt(5.W))
    val FPU_fmt_out         = Output(UInt(2.W))
    val FPU_op5_out         = Output(UInt(5.W))
    val FPU_rm_out          = Output(UInt(3.W))
    //
    val ctrl_OpA_out        = Output(UInt(2.W))
    val ctrl_OpB_out        = Output(Bool())
    //RV32F opc
    val ctrl_OpC_out        = Output(Bool())
    //
    val ctrl_nextpc_out     = Output(UInt(2.W))
    val IFID_pc4_out        = Output(UInt(32.W))
  })
      ...
}

In src/main/scala/Pipeline/Main.scala, the FPU enable signal (fpu_en) decides whether the ID-EX outputs go to the ALU or the FPU. It also selects whether some ID-EX inputs are read from the integer register file or the F register file.

    ID_EX_.io.rs1_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs1, RegFile.io.rs1)
    ID_EX_.io.rs2_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs2, RegFile.io.rs2)
    ID_EX_.io.rs3_in := Mux(control_module.io.fpu_en, F_RegFile.io.F_rs3, 0.U)
    //ID_EX_.io.rs1_in            := RegFile.io.rs1
    //ID_EX_.io.rs2_in            := RegFile.io.rs2
    ID_EX_.io.imm               := ImmValue 
    ID_EX_.io.func3_in          := IF_ID_.io.SelectedInstr_out(14, 12)
    ID_EX_.io.func7_in          := IF_ID_.io.SelectedInstr_out(30)
    ID_EX_.io.rd_in             := IF_ID_.io.SelectedInstr_out(11, 7)
    //RV32F: funct5 (bits 31-27), fmt (bits 26-25) and op5 (opcode bits 6-2)
    ID_EX_.io.FPU_func7_in      := IF_ID_.io.SelectedInstr_out(31, 27)
    ID_EX_.io.FPU_fmt_in        := IF_ID_.io.SelectedInstr_out(26, 25)
    ID_EX_.io.FPU_op5_in        := IF_ID_.io.SelectedInstr_out(6, 2)

    EX_MEM_M.io.IDEX_rd         := ID_EX_.io.rd_out
    //FPU
    when(ID_EX_.io.ctrl_FPU_en_out === 1.U){
      FPU_Control.io.fpu_Op := ID_EX_.io.ctrl_FPU_Op_out
      FPU_Control.io.fpu_funct3 := ID_EX_.io.func3_out
      FPU_Control.io.fpu_funct7 := ID_EX_.io.FPU_func7_out
      FPU.io.fpu_Op := FPU_Control.io.fpu_out
      FPU_Control.io.fpu_enable := ID_EX_.io.ctrl_FPU_en_out
    }.otherwise{
      ALU_Control.io.aluOp            := ID_EX_.io.ctrl_AluOp_out     // ALU op code
      ALU.io.alu_Op                   := ALU_Control.io.out           // ALU operation from ALU_Control
      ALU_Control.io.func3            := ID_EX_.io.func3_out          // function 3
      ALU_Control.io.func7            := ID_EX_.io.func7_out          // function 7
    }
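Here is a plain-Scala sketch that pulls out the fields carried by the ID-EX FPU ports from an instruction word. It uses the bit ranges from the F-extension instruction formats above, and `FpFieldsModel` and `Fields` are hypothetical names for illustration.

```scala
// Plain-Scala model of the F-extension fields sliced out of the instruction (not Chisel).
object FpFieldsModel {
  final case class Fields(op5: Int, rd: Int, rm: Int, rs1: Int, rs2: Int, fmt: Int, funct5: Int)

  // Extract bits hi..lo of a 32-bit instruction word.
  private def field(inst: Int, hi: Int, lo: Int): Int =
    (inst >>> lo) & ((1 << (hi - lo + 1)) - 1)

  def decode(inst: Int): Fields = Fields(
    op5    = field(inst, 6, 2),   // opcode without its constant low bits 11
    rd     = field(inst, 11, 7),
    rm     = field(inst, 14, 12), // rounding mode, same bits as funct3
    rs1    = field(inst, 19, 15),
    rs2    = field(inst, 24, 20),
    fmt    = field(inst, 26, 25), // 00 = S (single precision)
    funct5 = field(inst, 31, 27)  // holds rs3 for R4-type instructions
  )
}
```

`decode(0x103170D3)` (`fmul.s f1, f2, f3` with rm = 111) gives funct5 = 2, which is the top five bits of funct7 0001000. It also gives fmt = 0 (S) and op5 = 0x14. For the R4-type `fmadd.s f5, f1, f2, f3` (0x182082C3), the funct5 slot holds rs3 = 3.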

EX-MEM

  • These registers carry the results of the EX stage to the MA stage.
  • Located in src/main/scala/Pipeline/Pipelines/EX_MEM.scala

Some wires, such as rd and rs2, are already merged in the earlier stages and selected by the FPU enable signal, so they are reused as they are. We add an fpu_out port to pass the FPU result. We also add an fp_en port (IDEX_fp_en in, EXMEM_fp_en out) to pass the FPU enable signal on.

class EX_MEM extends Module {
  val io = IO(new Bundle {
    val IDEX_MEMRD          =   Input(Bool())
    val IDEX_MEMWR          =   Input(Bool())
    val IDEX_MEMTOREG       =   Input(Bool())
    val IDEX_REG_W          =   Input(Bool())
    val IDEX_rs2            =   Input(SInt(32.W))
    val IDEX_rd             =   Input(UInt(5.W))
    val alu_out             =   Input(SInt(32.W))
    //RV32F
    val fpu_out             =   Input(SInt(32.W))
    val IDEX_fp_en          =   Input(Bool())

    val EXMEM_memRd_out     = Output(Bool())
    val EXMEM_memWr_out     = Output(Bool())
    val EXMEM_memToReg_out  = Output(Bool())
    val EXMEM_reg_w_out     = Output(Bool())
    val EXMEM_rs2_out       = Output(SInt(32.W))
    val EXMEM_rd_out        = Output(UInt(5.W))
    val EXMEM_alu_out       = Output(SInt(32.W))
    //RV32F
    val EXMEM_fpu_out       = Output(SInt(32.W))
    val EXMEM_fp_en         = Output(Bool())
    })
    ...
}

MEM-WB

  • These registers transfer the data from the MA stage to the WB stage for final write-back to the register file.
  • Located in src/main/scala/Pipeline/Pipelines/MEM_WB.scala

This module is modified the same way as EX-MEM. We add an fpu_out port to pass the FPU result, and a MEMWB_fp_en output to pass the EXMEM_fp_en signal on.


Memory

The data memory module itself needs no changes. Only its address input is modified, so that the address can come from either the ALU or the FPU result. The data input, io.EXMEM_rs2_out, is already selected in the forwarding B section, so it needs no modification.

In src/main/scala/Pipeline/Main.scala :

    // Data memory inputs
    DataMemory.io.mem_read          := EX_MEM_M.io.EXMEM_memRd_out 
    DataMemory.io.mem_write         := EX_MEM_M.io.EXMEM_memWr_out
    DataMemory.io.dataIn            := EX_MEM_M.io.EXMEM_rs2_out
    // Address comes from the FPU result for FPU instructions, otherwise from the ALU result
    DataMemory.io.addr              := Mux(EX_MEM_M.io.EXMEM_fp_en, EX_MEM_M.io.EXMEM_fpu_out.asUInt, EX_MEM_M.io.EXMEM_alu_out.asUInt)

Write-back

The write-back data is carried on wire d to the register files and to the rs1-rs3 data selectors.
In src/main/scala/Pipeline/Main.scala :

    // Write back data to registerfile writedata
    when (MEM_WB_M.io.MEMWB_memToReg_out === 0.U) {
      d := Mux(MEM_WB_M.io.MEMWB_fp_en, MEM_WB_M.io.MEMWB_fpu_out, MEM_WB_M.io.MEMWB_alu_out)        // data from Alu Result or FPU result
    }.elsewhen (MEM_WB_M.io.MEMWB_memToReg_out === 1.U) {
      d := MEM_WB_M.io.MEMWB_dataMem_out    // data from Data Memory
    }.otherwise {
      d := 0.S
    }
    //RV32F: d is driven to both register files; only the file whose write enable is raised stores it
    F_RegFile.io.w_data := d
    RegFile.io.w_data := d

Register data write and read

The integer and F register files share two wires from MEM-WB. MEMWB_reg_w_out drives reg_write/F_reg_write, and MEMWB_rd_out drives w_reg/F_w_reg. Both connections are steered by MEM_WB_M.io.MEMWB_fp_en.

In src/main/scala/Pipeline/Main.scala :

    // Register file connections
    //RV32F: both files see the same rd; MEMWB_fp_en decides which write enable is raised
    F_RegFile.io.F_w_reg        := MEM_WB_M.io.MEMWB_rd_out
    F_RegFile.io.F_reg_write    := MEM_WB_M.io.MEMWB_reg_w_out && MEM_WB_M.io.MEMWB_fp_en
    RegFile.io.w_reg            := MEM_WB_M.io.MEMWB_rd_out
    RegFile.io.reg_write        := MEM_WB_M.io.MEMWB_reg_w_out && !MEM_WB_M.io.MEMWB_fp_en
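The combined effect of the write-back mux and the register-file routing can be summarised in a plain-Scala sketch (not Chisel). It assumes a single fp_en flag per instruction, as in the design above. `WriteBackModel` and `Ports` are hypothetical names.

```scala
// Plain-Scala reference model of write-back steering (not Chisel).
object WriteBackModel {
  // What the two register-file write ports receive in one cycle.
  final case class Ports(intWrite: Boolean, fpWrite: Boolean, rd: Int, data: Int)

  def steer(memToReg: Boolean, fpEn: Boolean, regW: Boolean, rd: Int,
            aluOut: Int, fpuOut: Int, memOut: Int): Ports = {
    // Wire d: loaded data, otherwise the FPU or ALU result.
    val d = if (memToReg) memOut else if (fpEn) fpuOut else aluOut
    // Both files see rd and d; fp_en decides which write enable is raised.
    Ports(intWrite = regW && !fpEn, fpWrite = regW && fpEn, rd = rd, data = d)
  }
}
```

For `flw`, memToReg and fp_en are both set, so the loaded word is written to the F register file. `steer(true, true, true, 4, 10, 20, 30)` yields `Ports(false, true, 4, 30)`.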

References

https://people.eecs.berkeley.edu/~krste/papers/riscv-spec-2.0.pdf
https://github.com/nozomioshi/ChiselRiscV
https://github.com/chadyuu/riscv-chisel-book
https://github.com/kinzafatim/5-Stage-RV32I