From 69441930b56f4d79dbc6944935c6f4887b457251 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Mon, 5 Feb 2018 17:50:01 -0800 Subject: [PATCH 1/3] Rationalize ALU function encoding MULHSU and MULHU should match their ISA funct3 encodings to slightly reduce HW cost. --- src/main/scala/rocket/ALU.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/rocket/ALU.scala b/src/main/scala/rocket/ALU.scala index 7791a4ee..537ffa51 100644 --- a/src/main/scala/rocket/ALU.scala +++ b/src/main/scala/rocket/ALU.scala @@ -33,8 +33,8 @@ object ALU def FN_MUL = FN_ADD def FN_MULH = FN_SL - def FN_MULHSU = FN_SLT - def FN_MULHU = FN_SLTU + def FN_MULHSU = FN_SEQ + def FN_MULHU = FN_SNE def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0) def isSub(cmd: UInt) = cmd(3) From a59fc3bdaa9bd208fbad18d662842efda29ae3a4 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Mon, 5 Feb 2018 17:49:33 -0800 Subject: [PATCH 2/3] Teach MulDiv to do either mul-only or div-only by setting unroll=0 --- src/main/scala/rocket/Multiplier.scala | 39 +++++++++++++++----------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/main/scala/rocket/Multiplier.scala b/src/main/scala/rocket/Multiplier.scala index 2083ffb4..b1ce951e 100644 --- a/src/main/scala/rocket/Multiplier.scala +++ b/src/main/scala/rocket/Multiplier.scala @@ -39,30 +39,35 @@ case class MulDivParams( class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { val io = new MultiplierIO(width, log2Up(nXpr)) val w = io.req.bits.in1.getWidth - val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll - val fastMulW = w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0 + val mulw = if (cfg.mulUnroll == 0) w else (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll + val fastMulW = if (cfg.mulUnroll == 0) false else w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0 val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8) val state = Reg(init=s_ready) val req = Reg(io.req.bits) - val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll)))) + val count = Reg(UInt(width = log2Ceil( + ((cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++ + (cfg.mulUnroll != 0).option(mulw/cfg.mulUnroll)).reduce(_ max _)))) val neg_out = Reg(Bool()) val isHi = Reg(Bool()) val resHi = Reg(Bool()) val divisor = Reg(Bits(width = w+1)) // div only needs w bits val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits + val mulDecode = List( + FN_MUL -> List(Y, N, X, X), + FN_MULH -> List(Y, Y, Y, Y), + FN_MULHU -> List(Y, Y, N, N), + FN_MULHSU -> List(Y, Y, Y, N)) + val divDecode = List( + FN_DIV -> List(N, N, Y, Y), + FN_REM -> List(N, Y, Y, Y), + FN_DIVU -> List(N, N, N, N), + FN_REMU -> List(N, Y, N, N)) val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil = - DecodeLogic(io.req.bits.fn, List(X, X, X, X), List( - FN_DIV -> List(N, N, Y, Y), - FN_REM -> List(N, Y, Y, Y), - FN_DIVU -> List(N, N, N, N), - FN_REMU -> List(N, Y, N, N), - FN_MUL -> List(Y, N, X, X), - FN_MULH -> List(Y, Y, Y, Y), - FN_MULHU -> List(Y, Y, N, N), - FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool) + DecodeLogic(io.req.bits.fn, List(X, X, X, X), + (if (cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.toBool) require(w == 32 || w == 64) def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32 @@ -79,7 +84,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0)) val negated_remainder = -result - when (state === s_neg_inputs) { + if (cfg.divUnroll != 0) when (state === s_neg_inputs) { when (remainder(w-1)) { remainder := negated_remainder } @@ -88,12 +93,12 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { } state := s_div } - when (state === s_neg_output) { + if (cfg.divUnroll != 0) when (state === s_neg_output) { remainder := negated_remainder state := s_done_div resHi := false } - when (state === s_mul) { + if (cfg.mulUnroll != 0) when (state === s_mul) { val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) val mplierSign = remainder(w) val mplier = mulReg(mulw-1,0) @@ -116,7 +121,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { resHi := isHi } } - when (state === s_div) { + if (cfg.divUnroll != 0) when (state === s_div) { val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) => // the special case for iteration 0 is to save HW, not for correctness val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0) @@ -156,7 +161,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div)) isHi := cmdHi resHi := false - count := Mux[UInt](Bool(fastMulW) && cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) + count := (if (fastMulW) Mux[UInt](cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) else 0) neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign) divisor := Cat(rhs_sign, rhs_in) remainder := lhs_in From efc6c9cbd33c46988461597c37545b1d82efc8df Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Tue, 6 Feb 2018 14:05:03 -0800 Subject: [PATCH 3/3] Let user of CSRFile decide when to set tval I also renamed badaddr to tval (the correct name). --- src/main/scala/rocket/CSR.scala | 14 ++++---------- src/main/scala/rocket/RocketCore.scala | 6 +++++- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/main/scala/rocket/CSR.scala b/src/main/scala/rocket/CSR.scala index 522b1a68..79b56e17 100644 --- a/src/main/scala/rocket/CSR.scala +++ b/src/main/scala/rocket/CSR.scala @@ -196,7 +196,7 @@ class CSRFileIO(implicit p: Parameters) extends CoreBundle val retire = UInt(INPUT, log2Up(1+retireWidth)) val cause = UInt(INPUT, xLen) val pc = UInt(INPUT, vaddrBitsExtended) - val badaddr = UInt(INPUT, vaddrBitsExtended) + val tval = UInt(INPUT, vaddrBitsExtended) val time = UInt(OUTPUT, xLen) val fcsr_rm = Bits(OUTPUT, FPConstants.RM_SZ) val fcsr_flags = Valid(Bits(width = FPConstants.FLAGS_SZ)).flip @@ -527,12 +527,6 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param assert(!reg_singleStepped || io.retire === UInt(0)) val epc = formEPC(io.pc) - val write_badaddr = exception && cause.isOneOf(Causes.illegal_instruction, Causes.breakpoint, - Causes.misaligned_load, Causes.misaligned_store, - Causes.load_access, Causes.store_access, Causes.fetch_access, - Causes.load_page_fault, Causes.store_page_fault, Causes.fetch_page_fault) - val badaddr_value = Mux(write_badaddr, io.badaddr, 0.U) - val noCause :: mCause :: hCause :: sCause :: uCause :: Nil = Enum(5) val xcause_dest = Wire(init = noCause) @@ -549,7 +543,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param reg_sepc := epc reg_scause := cause xcause_dest := sCause - reg_sbadaddr := badaddr_value + reg_sbadaddr := io.tval reg_mstatus.spie := reg_mstatus.sie reg_mstatus.spp := reg_mstatus.prv reg_mstatus.sie := false @@ -558,7 +552,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param reg_mepc := epc reg_mcause := cause xcause_dest := mCause - reg_mbadaddr := badaddr_value + reg_mbadaddr := io.tval reg_mstatus.mpie := reg_mstatus.mie reg_mstatus.mpp := trimPrivilege(reg_mstatus.prv) reg_mstatus.mie := false @@ -808,7 +802,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param t.priv := Cat(reg_debug, reg_mstatus.prv) t.cause := cause t.interrupt := cause(xLen-1) - t.tval := badaddr_value + t.tval := io.tval } def chooseInterrupt(masksIn: Seq[UInt]): (Bool, UInt) = { diff --git a/src/main/scala/rocket/RocketCore.scala b/src/main/scala/rocket/RocketCore.scala index efc86b9c..f3f59e7a 100644 --- a/src/main/scala/rocket/RocketCore.scala +++ b/src/main/scala/rocket/RocketCore.scala @@ -545,7 +545,11 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) csr.io.fcsr_flags := io.fpu.fcsr_flags csr.io.rocc_interrupt := io.rocc.interrupt csr.io.pc := wb_reg_pc - csr.io.badaddr := encodeVirtualAddress(wb_reg_wdata, wb_reg_wdata) + val tval_valid = wb_xcpt && wb_cause.isOneOf(Causes.illegal_instruction, Causes.breakpoint, + Causes.misaligned_load, Causes.misaligned_store, + Causes.load_access, Causes.store_access, Causes.fetch_access, + Causes.load_page_fault, Causes.store_page_fault, Causes.fetch_page_fault) + csr.io.tval := Mux(tval_valid, encodeVirtualAddress(wb_reg_wdata, wb_reg_wdata), 0.U) io.ptw.ptbr := csr.io.ptbr io.ptw.status := csr.io.status io.ptw.pmp := csr.io.pmp