diff --git a/rocket/src/main/scala/fpu.scala b/rocket/src/main/scala/fpu.scala
index 721c9898..2bf12d96 100644
--- a/rocket/src/main/scala/fpu.scala
+++ b/rocket/src/main/scala/fpu.scala
@@ -182,6 +182,38 @@ class FPInput extends FPUCtrlSigs {
   val in3 = Bits(width = 65)
 }
 
+/** Classifies a recoded floating-point value into a 10-bit class mask.
+  *
+  * `in` is read as a sign bit (bit sigWidth+expWidth), an (expWidth+1)-bit
+  * exponent and a (sigWidth-1)-bit fraction. Exactly one result bit is set;
+  * MSB to LSB: quiet NaN, signaling NaN, +inf, +normal, +subnormal, +0, -0,
+  * -subnormal, -normal, -inf.
+  */
+object ClassifyRecFN {
+  def apply(expWidth: Int, sigWidth: Int, in: UInt) = {
+    val sign = in(sigWidth + expWidth)
+    val exp = in(sigWidth + expWidth - 1, sigWidth - 1)
+    val sig = in(sigWidth - 2, 0)
+
+    val code = exp(expWidth,expWidth-2)
+    val codeHi = code(2, 1)
+    val isSpecial = codeHi === UInt(3)
+
+    val isHighSubnormalIn = exp(expWidth-2, 0) < UInt(2)
+    val isSubnormal = code === UInt(1) || codeHi === UInt(1) && isHighSubnormalIn
+    val isNormal = codeHi === UInt(1) && !isHighSubnormalIn || codeHi === UInt(2)
+    val isZero = code === UInt(0)
+    val isInf = isSpecial && !exp(expWidth-2)
+    val isNaN = code.andR
+    val isSNaN = isNaN && !sig(sigWidth-2)
+    val isQNaN = isNaN && sig(sigWidth-2)
+
+    Cat(isQNaN, isSNaN, isInf && !sign, isNormal && !sign,
+        isSubnormal && !sign, isZero && !sign, isZero && sign,
+        isSubnormal && sign, isNormal && sign, isInf && sign)
+  }
+}
+
 class FPToInt extends Module
 {
   val io = new Bundle {
@@ -197,30 +229,60 @@ class FPToInt extends Module
 
   val in = Reg(new FPInput)
   val valid = Reg(next=io.in.valid)
+
+  def upconvert(x: UInt) = {
+    val s2d = Module(new hardfloat.RecFNToRecFN(8, 24, 11, 53))
+    s2d.io.in := x
+    s2d.io.roundingMode := UInt(0)
+    s2d.io.out
+  }
+
+  val in1_upconvert = upconvert(io.in.bits.in1)
+  val in2_upconvert = upconvert(io.in.bits.in2)
+
   when (io.in.valid) {
-    def upconvert(x: UInt) = hardfloat.recodedFloatNToRecodedFloatM(x, Bits(0), 23, 9, 52, 12)._1
     in := io.in.bits
-    when (io.in.bits.single && !io.in.bits.ldst && io.in.bits.cmd != FCMD_MV_XF) {
-      in.in1 := upconvert(io.in.bits.in1)
-      in.in2 := upconvert(io.in.bits.in2)
+    when (io.in.bits.single && !io.in.bits.ldst && io.in.bits.cmd != FCMD_MV_XF &&
+          // need to also check toint because CVT_IF and SQRT overlap
+          !(io.in.bits.cmd === FCMD_CVT_IF && io.in.bits.toint)) {
+      in.in1 := in1_upconvert
+      in.in2 := in2_upconvert
     }
   }
 
-  val unrec_s = hardfloat.recodedFloatNToFloatN(in.in1, 23, 9)
-  val unrec_d = hardfloat.recodedFloatNToFloatN(in.in1, 52, 12)
+  val unrec_s = hardfloat.fNFromRecFN(8, 24, in.in1)
+  val unrec_d = hardfloat.fNFromRecFN(11, 53, in.in1)
   val unrec_out = Mux(in.single, Cat(Fill(32, unrec_s(31)), unrec_s), unrec_d)
 
-  val classify_s = hardfloat.recodedFloatNClassify(in.in1, 23, 9)
-  val classify_d = hardfloat.recodedFloatNClassify(in.in1, 52, 12)
+  val classify_s = ClassifyRecFN(8, 24, in.in1)
+  val classify_d = ClassifyRecFN(11, 53, in.in1)
   val classify_out = Mux(in.single, classify_s, classify_d)
 
-  val dcmp = Module(new hardfloat.recodedFloatNCompare(52, 12))
+  val dcmp = Module(new hardfloat.CompareRecFN(11, 53))
   dcmp.io.a := in.in1
   dcmp.io.b := in.in2
-  val dcmp_out = (~in.rm & Cat(dcmp.io.a_lt_b, dcmp.io.a_eq_b)).orR
-  val dcmp_exc = (~in.rm & Cat(dcmp.io.a_lt_b_invalid, dcmp.io.a_eq_b_invalid)).orR << 4
+  // equality-only compares (rm(1) set, see dcmp_out) are quiet; lt/le signal on any NaN
+  dcmp.io.signaling := !in.rm(1)
+  val dcmp_out = (~in.rm & Cat(dcmp.io.lt, dcmp.io.eq)).orR
+  val dcmp_exc = dcmp.io.exceptionFlags
 
-  val d2i = hardfloat.recodedFloatNToAny(in.in1, in.rm, in.typ ^ 1, 52, 12, 64)
+  val s2l = Module(new hardfloat.RecFNToIN(8, 24, 64))
+  val s2w = Module(new hardfloat.RecFNToIN(8, 24, 32))
+  s2l.io.in := in.in1
+  s2l.io.roundingMode := in.rm
+  s2l.io.signedOut := in.typ(0) ^ 1
+  s2w.io.in := in.in1
+  s2w.io.roundingMode := in.rm
+  s2w.io.signedOut := in.typ(0) ^ 1
+
+  val d2l = Module(new hardfloat.RecFNToIN(11, 53, 64))
+  val d2w = Module(new hardfloat.RecFNToIN(11, 53, 32))
+  d2l.io.in := in.in1
+  d2l.io.roundingMode := in.rm
+  d2l.io.signedOut := in.typ(0) ^ 1
+  d2w.io.in := in.in1
+  d2w.io.roundingMode := in.rm
+  d2w.io.signedOut := in.typ(0) ^ 1
 
   io.out.bits.toint := Mux(in.rm(0), classify_out, unrec_out)
   io.out.bits.store := unrec_out
@@ -231,12 +293,19 @@ class FPToInt extends Module
     io.out.bits.exc := dcmp_exc
   }
   when (in.cmd === FCMD_CVT_IF) {
-    io.out.bits.toint := Mux(in.typ(1), d2i._1, d2i._1(31,0).toSInt).toUInt
-    io.out.bits.exc := d2i._2
+    when (in.single) {
+      io.out.bits.toint := Mux(in.typ(1), s2l.io.out, s2w.io.out.toSInt).toUInt
+      val sflags = Mux(in.typ(1), s2l.io.intExceptionFlags, s2w.io.intExceptionFlags)
+      io.out.bits.exc := Cat(sflags(2, 1).orR, UInt(0, 3), sflags(0))
+    } .otherwise {
+      io.out.bits.toint := Mux(in.typ(1), d2l.io.out, d2w.io.out.toSInt).toUInt
+      val dflags = Mux(in.typ(1), d2l.io.intExceptionFlags, d2w.io.intExceptionFlags)
+      io.out.bits.exc := Cat(dflags(2, 1).orR, UInt(0, 3), dflags(0))
+    }
   }
 
   io.out.valid := valid
-  io.out.bits.lt := dcmp.io.a_lt_b
+  io.out.bits.lt := dcmp.io.lt
   io.as_double := in
 }
 
@@ -251,20 +320,36 @@ class IntToFP(val latency: Int) extends Module
 
   val mux = Wire(new FPResult)
   mux.exc := Bits(0)
-  mux.data := hardfloat.floatNToRecodedFloatN(in.bits.in1, 52, 12)
+  mux.data := hardfloat.recFNFromFN(11, 53, in.bits.in1)
   when (in.bits.single) {
-    mux.data := Cat(SInt(-1, 32), hardfloat.floatNToRecodedFloatN(in.bits.in1, 23, 9))
+    mux.data := Cat(SInt(-1, 32), hardfloat.recFNFromFN(8, 24, in.bits.in1))
   }
 
+  val l2s = Module(new hardfloat.INToRecFN(64, 8, 24))
+  val w2s = Module(new hardfloat.INToRecFN(32, 8, 24))
+  l2s.io.signedIn := in.bits.typ(0) ^ 1
+  l2s.io.in := in.bits.in1
+  l2s.io.roundingMode := in.bits.rm
+  w2s.io.signedIn := in.bits.typ(0) ^ 1
+  w2s.io.in := in.bits.in1
+  w2s.io.roundingMode := in.bits.rm
+
+  val l2d = Module(new hardfloat.INToRecFN(64, 11, 53))
+  val w2d = Module(new hardfloat.INToRecFN(32, 11, 53))
+  l2d.io.signedIn := in.bits.typ(0) ^ 1
+  l2d.io.in := in.bits.in1
+  l2d.io.roundingMode := in.bits.rm
+  w2d.io.signedIn := in.bits.typ(0) ^ 1
+  w2d.io.in := in.bits.in1
+  w2d.io.roundingMode := in.bits.rm
+
   when (in.bits.cmd === FCMD_CVT_FI) {
     when (in.bits.single) {
-      val u = hardfloat.anyToRecodedFloatN(in.bits.in1(63,0), in.bits.rm, in.bits.typ ^ 1, 23, 9, 64)
-      mux.data := Cat(SInt(-1, 32), u._1)
-      mux.exc := u._2
+      mux.data := Cat(SInt(-1, 32), Mux(in.bits.typ(1), l2s.io.out, w2s.io.out))
+      mux.exc := Mux(in.bits.typ(1), l2s.io.exceptionFlags, w2s.io.exceptionFlags)
     }.otherwise {
-      val u = hardfloat.anyToRecodedFloatN(in.bits.in1(63,0), in.bits.rm, in.bits.typ ^ 1, 52, 12, 64)
-      mux.data := u._1
-      mux.exc := u._2
+      mux.data := Mux(in.bits.typ(1), l2d.io.out, w2d.io.out)
+      mux.exc := Mux(in.bits.typ(1), l2d.io.exceptionFlags, w2d.io.exceptionFlags)
     }
   }
 
@@ -289,8 +374,12 @@ class FPToFP(val latency: Int) extends Module
   val sign_d = fsgnjSign(in.bits.in1, in.bits.in2, 64, !in.bits.single && isSgnj, in.bits.rm)
   val fsgnj = Cat(sign_d, in.bits.in1(63,33), sign_s, in.bits.in1(31,0))
 
-  val s2d = hardfloat.recodedFloatNToRecodedFloatM(in.bits.in1, in.bits.rm, 23, 9, 52, 12)
-  val d2s = hardfloat.recodedFloatNToRecodedFloatM(in.bits.in1, in.bits.rm, 52, 12, 23, 9)
+  val s2d = Module(new hardfloat.RecFNToRecFN(8, 24, 11, 53))
+  val d2s = Module(new hardfloat.RecFNToRecFN(11, 53, 8, 24))
+  s2d.io.in := in.bits.in1
+  s2d.io.roundingMode := in.bits.rm
+  d2s.io.in := in.bits.in1
+  d2s.io.roundingMode := in.bits.rm
 
   val isnan1 = Mux(in.bits.single, in.bits.in1(31,29).andR, in.bits.in1(63,61).andR)
   val isnan2 = Mux(in.bits.single, in.bits.in2(31,29).andR, in.bits.in2(63,61).andR)
@@ -308,18 +397,18 @@ class FPToFP(val latency: Int) extends Module
   when (isSgnj || isLHS) { mux.data := fsgnj }
   when (in.bits.cmd === FCMD_CVT_FF) {
     when (in.bits.single) {
-      mux.data := Cat(SInt(-1, 32), d2s._1)
-      mux.exc := d2s._2
+      mux.data := Cat(SInt(-1, 32), d2s.io.out)
+      mux.exc := d2s.io.exceptionFlags
     }.otherwise {
-      mux.data := s2d._1
-      mux.exc := s2d._2
+      mux.data := s2d.io.out
+      mux.exc := s2d.io.exceptionFlags
     }
   }
 
   io.out <> Pipe(in.valid, mux, latency-1)
 }
 
-class FPUFMAPipe(val latency: Int, sigWidth: Int, expWidth: Int) extends Module
+class FPUFMAPipe(val latency: Int, expWidth: Int, sigWidth: Int) extends Module
 {
   val io = new Bundle {
     val in = Valid(new FPInput).flip
@@ -341,7 +430,7 @@ class FPUFMAPipe(val latency: Int, sigWidth: Int, expWidth: Int) extends Module
     unless (cmd_fma || cmd_addsub) { in.in3 := zero }
   }
 
-  val fma = Module(new hardfloat.mulAddSubRecodedFloatN(sigWidth, expWidth))
+  val fma = Module(new hardfloat.MulAddRecFN(expWidth, sigWidth))
   fma.io.op := in.cmd
   fma.io.roundingMode := in.rm
   fma.io.a := in.in1
@@ -377,8 +466,8 @@ class FPU(implicit p: Parameters) extends CoreModule()(p) {
   val load_wb_single = RegEnable(io.dmem_resp_type === MT_W || io.dmem_resp_type === MT_WU, io.dmem_resp_val)
   val load_wb_data = RegEnable(io.dmem_resp_data, io.dmem_resp_val)
   val load_wb_tag = RegEnable(io.dmem_resp_tag, io.dmem_resp_val)
-  val rec_s = hardfloat.floatNToRecodedFloatN(load_wb_data, 23, 9)
-  val rec_d = hardfloat.floatNToRecodedFloatN(load_wb_data, 52, 12)
+  val rec_s = hardfloat.recFNFromFN(8, 24, load_wb_data)
+  val rec_d = hardfloat.recFNFromFN(11, 53, load_wb_data)
   val load_wb_data_recoded = Mux(load_wb_single, Cat(SInt(-1, 32), rec_s), rec_d)
 
   // regfile
@@ -415,11 +504,11 @@ class FPU(implicit p: Parameters) extends CoreModule()(p) {
   req.in3 := ex_rs3
   req.typ := ex_reg_inst(21,20)
 
-  val sfma = Module(new FPUFMAPipe(p(SFMALatency), 23, 9))
+  val sfma = Module(new FPUFMAPipe(p(SFMALatency), 8, 24))
   sfma.io.in.valid := ex_reg_valid && ex_ctrl.fma && ex_ctrl.single
   sfma.io.in.bits := req
 
-  val dfma = Module(new FPUFMAPipe(p(DFMALatency), 52, 12))
+  val dfma = Module(new FPUFMAPipe(p(DFMALatency), 11, 53))
   dfma.io.in.valid := ex_reg_valid && ex_ctrl.fma && !ex_ctrl.single
   dfma.io.in.bits := req
 
@@ -490,8 +579,8 @@ class FPU(implicit p: Parameters) extends CoreModule()(p) {
   when (wen(0) || divSqrt_wen) {
     regfile(waddr) := wdata
     if (enableCommitLog) {
-      val wdata_unrec_s = hardfloat.recodedFloatNToFloatN(wdata(64,0), 23, 9)
-      val wdata_unrec_d = hardfloat.recodedFloatNToFloatN(wdata(64,0), 52, 12)
+      val wdata_unrec_s = hardfloat.fNFromRecFN(8, 24, wdata(64,0))
+      val wdata_unrec_d = hardfloat.fNFromRecFN(11, 53, wdata(64,0))
       val wb_single = (winfo(0) >> 5)(0)
       printf ("f%d p%d 0x%x\n", waddr, waddr+ UInt(32),
               Mux(wb_single, Cat(UInt(0,32), wdata_unrec_s), wdata_unrec_d))
@@ -525,7 +614,7 @@ class FPU(implicit p: Parameters) extends CoreModule()(p) {
     val divSqrt_flags_double = Reg(Bits())
     val divSqrt_wdata_double = Reg(Bits())
 
-    val divSqrt = Module(new hardfloat.divSqrtRecodedFloat64)
+    val divSqrt = Module(new hardfloat.DivSqrtRecF64)
     divSqrt_inReady := Mux(divSqrt.io.sqrtOp, divSqrt.io.inReady_sqrt, divSqrt.io.inReady_div)
     val divSqrt_outValid = divSqrt.io.outValid_div || divSqrt.io.outValid_sqrt
     divSqrt.io.inValid := mem_reg_valid && (mem_ctrl.div || mem_ctrl.sqrt)
@@ -549,8 +638,10 @@ class FPU(implicit p: Parameters) extends CoreModule()(p) {
       divSqrt_flags_double := divSqrt.io.exceptionFlags
     }
 
-    val divSqrt_toSingle = hardfloat.recodedFloatNToRecodedFloatM(divSqrt_wdata_double, ex_rm, 52, 12, 23, 9)
-    divSqrt_wdata := Mux(divSqrt_single, divSqrt_toSingle._1, divSqrt_wdata_double)
-    divSqrt_flags := divSqrt_flags_double | Mux(divSqrt_single, divSqrt_toSingle._2, Bits(0))
+    val divSqrt_toSingle = Module(new hardfloat.RecFNToRecFN(11, 53, 8, 24))
+    divSqrt_toSingle.io.in := divSqrt_wdata_double
+    divSqrt_toSingle.io.roundingMode := ex_rm
+    divSqrt_wdata := Mux(divSqrt_single, divSqrt_toSingle.io.out, divSqrt_wdata_double)
+    divSqrt_flags := divSqrt_flags_double | Mux(divSqrt_single, divSqrt_toSingle.io.exceptionFlags, Bits(0))
   }
 }