From f2df6147dfb8c22abd20d7b06a272dc5776f0606 Mon Sep 17 00:00:00 2001 From: Rimas Avizienis Date: Mon, 28 Jan 2013 17:17:09 -0800 Subject: [PATCH] shuffled FPU control logic around to make functional unit retiming work better --- rocket/src/main/scala/fpu.scala | 209 +++++++++++++++++++++----------- 1 file changed, 141 insertions(+), 68 deletions(-) diff --git a/rocket/src/main/scala/fpu.scala b/rocket/src/main/scala/fpu.scala index 41fd4cb4..02aa28c0 100644 --- a/rocket/src/main/scala/fpu.scala +++ b/rocket/src/main/scala/fpu.scala @@ -370,8 +370,8 @@ class FPToFP(val latency: Int) extends Component class ioFMA(width: Int) extends Bundle { val valid = Bool(INPUT) - val cmd = Bits(INPUT, FCMD_WIDTH) - val rm = Bits(INPUT, 3) + val cmd = Bits(INPUT, 2) + val rm = Bits(INPUT, 2) val in1 = Bits(INPUT, width) val in2 = Bits(INPUT, width) val in3 = Bits(INPUT, width) @@ -382,82 +382,126 @@ class ioFMA(width: Int) extends Bundle { class FPUSFMAPipe(val latency: Int) extends Component { val io = new ioFMA(33) + + val r_cmd = Reg() { Bits() } + val r_rm = Reg() { Bits() } + val r_in1 = Reg() { Bits() } + val r_in2 = Reg() { Bits() } + val r_in3 = Reg() { Bits() } + + val out_reg = Reg() { Bits() } + val exc_reg = Reg() { Bits() } + + val valid_pipe_regs = Vec(latency) { Reg() { Bool() } } + val dout_pipe_regs = Vec(latency-2) { Reg() { Bits() } } + val exc_pipe_regs = Vec(latency-2) { Reg() { Bits() } } + + valid_pipe_regs(0) := io.valid + for (i <- 1 until latency) { + valid_pipe_regs(i) := valid_pipe_regs(i-1) + } - val cmd = Reg() { Bits() } - val rm = Reg() { Bits() } - val in1 = Reg() { Bits() } - val in2 = Reg() { Bits() } - val in3 = Reg() { Bits() } - - val cmd_fma = io.cmd === FCMD_MADD || io.cmd === FCMD_MSUB || - io.cmd === FCMD_NMADD || io.cmd === FCMD_NMSUB - val cmd_addsub = io.cmd === FCMD_ADD || io.cmd === FCMD_SUB - - val one = Bits("h80000000") - val zero = Cat(io.in1(32) ^ io.in2(32), Bits(0, 32)) - - val valid = Reg(io.valid) when (io.valid) { - cmd := Cat(io.cmd(1) & (cmd_fma || cmd_addsub), io.cmd(0)) - rm := io.rm - in1 := io.in1 - in2 := Mux(cmd_addsub, one, io.in2) - in3 := Mux(cmd_fma, io.in3, Mux(cmd_addsub, io.in2, zero)) + r_cmd := io.cmd + r_rm := io.rm + r_in1 := io.in1 + r_in2 := io.in2 + r_in3 := io.in3 } val fma = new hardfloat.mulAddSubRecodedFloatN(23, 9) - fma.io.op := cmd - fma.io.roundingMode := rm - fma.io.a := in1 - fma.io.b := in2 - fma.io.c := in3 + + fma.io.op := r_cmd + fma.io.roundingMode := r_rm + fma.io.a := r_in1 + fma.io.b := r_in2 + fma.io.c := r_in3 - io.out := Pipe(valid, fma.io.out, latency-1).bits - io.exc := Pipe(valid, fma.io.exceptionFlags, latency-1).bits + when (valid_pipe_regs(0)) { + dout_pipe_regs(0) := fma.io.out + exc_pipe_regs(0) := fma.io.exceptionFlags + } + + for (i <- 1 until latency-2) { + when (valid_pipe_regs(i)) { + dout_pipe_regs(i) := dout_pipe_regs(i-1) + exc_pipe_regs(i) := exc_pipe_regs(i-1) + } + } + + when (valid_pipe_regs(latency-2)) { + out_reg := dout_pipe_regs(latency-3) + exc_reg := exc_pipe_regs(latency-3) + } + + io.out := out_reg + io.exc := exc_reg } class FPUDFMAPipe(val latency: Int) extends Component { val io = new ioFMA(65) - val cmd = Reg() { Bits() } - val rm = Reg() { Bits() } - val in1 = Reg() { Bits() } - val in2 = Reg() { Bits() } - val in3 = Reg() { Bits() } - - val cmd_fma = io.cmd === FCMD_MADD || io.cmd === FCMD_MSUB || - io.cmd === FCMD_NMADD || io.cmd === FCMD_NMSUB - val cmd_addsub = io.cmd === FCMD_ADD || io.cmd === FCMD_SUB - - val one = Bits("h8000000000000000") - val zero = Cat(io.in1(64) ^ io.in2(64), Bits(0, 64)) - - val valid = Reg(io.valid) + val r_cmd = Reg() { Bits() } + val r_rm = Reg() { Bits() } + val r_in1 = Reg() { Bits() } + val r_in2 = Reg() { Bits() } + val r_in3 = Reg() { Bits() } + + val out_reg = Reg() { Bits() } + val exc_reg = Reg() { Bits() } + + val valid_pipe_regs = Vec(latency) { Reg() { Bool() } } + val dout_pipe_regs = Vec(latency-2) { Reg() { Bits() } } + val exc_pipe_regs = Vec(latency-2) { Reg() { Bits() } } + + valid_pipe_regs(0) := io.valid + for (i <- 1 until latency) { + valid_pipe_regs(i) := valid_pipe_regs(i-1) + } + when (io.valid) { - cmd := Cat(io.cmd(1) & (cmd_fma || cmd_addsub), io.cmd(0)) - rm := io.rm - in1 := io.in1 - in2 := Mux(cmd_addsub, one, io.in2) - in3 := Mux(cmd_fma, io.in3, Mux(cmd_addsub, io.in2, zero)) + r_cmd := io.cmd + r_rm := io.rm + r_in1 := io.in1 + r_in2 := io.in2 + r_in3 := io.in3 + } + + val fma = new hardfloat.mulAddSubRecodedFloatN(52, 12) + + fma.io.op := r_cmd + fma.io.roundingMode := r_rm + fma.io.a := r_in1 + fma.io.b := r_in2 + fma.io.c := r_in3 + + when (valid_pipe_regs(0)) { + dout_pipe_regs(0) := fma.io.out + exc_pipe_regs(0) := fma.io.exceptionFlags + } + + for (i <- 1 until latency-2) { + when (valid_pipe_regs(i)) { + dout_pipe_regs(i) := dout_pipe_regs(i-1) + exc_pipe_regs(i) := exc_pipe_regs(i-1) + } } - val fma = new hardfloat.mulAddSubRecodedFloatN(52, 12) - fma.io.op := cmd - fma.io.roundingMode := rm - fma.io.a := in1 - fma.io.b := in2 - fma.io.c := in3 + when (valid_pipe_regs(latency-2)) { + out_reg := dout_pipe_regs(latency-3) + exc_reg := exc_pipe_regs(latency-3) + } - io.out := Pipe(valid, fma.io.out, latency-1).bits - io.exc := Pipe(valid, fma.io.exceptionFlags, latency-1).bits + io.out := out_reg + io.exc := exc_reg } class FPU(sfma_latency: Int, dfma_latency: Int) extends Component { val io = new Bundle { - val ctrl = (new CtrlFPUIO).flip - val dpath = (new DpathFPUIO).flip + val ctrl = new CtrlFPUIO().flip + val dpath = new DpathFPUIO().flip val sfma = new ioFMA(33) val dfma = new ioFMA(65) } @@ -526,23 +570,52 @@ class FPU(sfma_latency: Int, dfma_latency: Int) extends Component val cmd_fma = mem_ctrl.cmd === FCMD_MADD || mem_ctrl.cmd === FCMD_MSUB || mem_ctrl.cmd === FCMD_NMADD || mem_ctrl.cmd === FCMD_NMSUB val cmd_addsub = mem_ctrl.cmd === FCMD_ADD || mem_ctrl.cmd === FCMD_SUB + + // RIMAS: refactoring for retiming + // moved recoding of cmd -> op outside of DFMA/SFMA blocks + // also moved muxing of operands based on command bits out of module + + // Single precision FMA val sfma = new FPUSFMAPipe(sfma_latency) + val sfma_cmd = Mux(io.sfma.valid, io.sfma.cmd, ctrl.cmd) + val sfma_cmd_fma = sfma_cmd === FCMD_MADD || sfma_cmd === FCMD_MSUB || + sfma_cmd === FCMD_NMADD || sfma_cmd === FCMD_NMSUB + val sfma_cmd_addsub = sfma_cmd === FCMD_ADD || sfma_cmd === FCMD_SUB + + val sfma_in1 = Mux(io.sfma.valid, io.sfma.in1, ex_rs1) + val sfma_in2 = Mux(io.sfma.valid, io.sfma.in2, ex_rs2) + val sfma_in3 = Mux(io.sfma.valid, io.sfma.in3, ex_rs3) + val sfma_one = Bits("h80000000") + val sfma_zero = Cat(sfma_in1(32) ^ sfma_in2(32), Bits(0, 32)) + sfma.io.valid := io.sfma.valid || ex_reg_valid && ctrl.fma && ctrl.single - sfma.io.in1 := Mux(io.sfma.valid, io.sfma.in1, ex_rs1) - sfma.io.in2 := Mux(io.sfma.valid, io.sfma.in2, ex_rs2) - sfma.io.in3 := Mux(io.sfma.valid, io.sfma.in3, ex_rs3) - sfma.io.cmd := Mux(io.sfma.valid, io.sfma.cmd, ctrl.cmd) - sfma.io.rm := Mux(io.sfma.valid, io.sfma.rm, ex_rm) + sfma.io.in1 := sfma_in1 + sfma.io.in2 := Mux(sfma_cmd_addsub, sfma_one, sfma_in2) + sfma.io.in3 := Mux(sfma_cmd_fma, sfma_in3, Mux(sfma_cmd_addsub, sfma_in2, sfma_zero)) + sfma.io.cmd := Cat(sfma_cmd(1) & (sfma_cmd_fma || sfma_cmd_addsub), sfma_cmd(0)) + sfma.io.rm := Mux(io.sfma.valid, io.sfma.rm(1,0), ex_rm(1,0)) io.sfma.out := sfma.io.out io.sfma.exc := sfma.io.exc - + + // Double precision FMA val dfma = new FPUDFMAPipe(dfma_latency) + val dfma_cmd = Mux(io.dfma.valid, io.dfma.cmd, ctrl.cmd) + val dfma_cmd_fma = dfma_cmd === FCMD_MADD || dfma_cmd === FCMD_MSUB || + dfma_cmd === FCMD_NMADD || dfma_cmd === FCMD_NMSUB + val dfma_cmd_addsub = dfma_cmd === FCMD_ADD || dfma_cmd === FCMD_SUB + + val dfma_in1 = Mux(io.dfma.valid, io.dfma.in1, ex_rs1) + val dfma_in2 = Mux(io.dfma.valid, io.dfma.in2, ex_rs2) + val dfma_in3 = Mux(io.dfma.valid, io.dfma.in3, ex_rs3) + val dfma_one = Bits("h8000000000000000") + val dfma_zero = Cat(dfma_in1(64) ^ dfma_in2(64), Bits(0, 64)) + dfma.io.valid := io.dfma.valid || ex_reg_valid && ctrl.fma && !ctrl.single - dfma.io.in1 := Mux(io.dfma.valid, io.dfma.in1, ex_rs1) - dfma.io.in2 := Mux(io.dfma.valid, io.dfma.in2, ex_rs2) - dfma.io.in3 := Mux(io.dfma.valid, io.dfma.in3, ex_rs3) - dfma.io.cmd := Mux(io.dfma.valid, io.dfma.cmd, ctrl.cmd) - dfma.io.rm := Mux(io.dfma.valid, io.dfma.rm, ex_rm) + dfma.io.in1 := dfma_in1 + dfma.io.in2 := Mux(dfma_cmd_addsub, dfma_one, dfma_in2) + dfma.io.in3 := Mux(dfma_cmd_fma, dfma_in3, Mux(dfma_cmd_addsub, dfma_in2, dfma_zero)) + dfma.io.cmd := Cat(dfma_cmd(1) & (dfma_cmd_fma || dfma_cmd_addsub), dfma_cmd(0)) + dfma.io.rm := Mux(io.dfma.valid, io.dfma.rm(1,0), ex_rm(1,0)) io.dfma.out := dfma.io.out io.dfma.exc := dfma.io.exc