diff --git a/src/main/scala/rocket/Multiplier.scala b/src/main/scala/rocket/Multiplier.scala index 4bf566b0..400d8314 100644 --- a/src/main/scala/rocket/Multiplier.scala +++ b/src/main/scala/rocket/Multiplier.scala @@ -31,6 +31,7 @@ class MultiplierIO(dataBits: Int, tagBits: Int) extends Bundle { case class MulDivParams( mulUnroll: Int = 1, + divUnroll: Int = 1, mulEarlyOut: Boolean = false, divEarlyOut: Boolean = false ) @@ -44,7 +45,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { val state = Reg(init=s_ready) val req = Reg(io.req.bits) - val count = Reg(UInt(width = log2Up(w+1))) + val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll)))) val neg_out = Reg(Bool()) val isMul = Reg(Bool()) val isHi = Reg(Bool()) @@ -73,8 +74,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { val (lhs_in, lhs_sign) = sext(io.req.bits.in1, halfWidth(io.req.bits), lhsSigned) val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned) - val subtractor = remainder(2*w,w) - divisor(w,0) - val less = subtractor(w) + val subtractor = remainder(2*w,w) - divisor val negated_remainder = -remainder(w-1,0) when (state === s_neg_inputs) { @@ -116,24 +116,36 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { } } when (state === s_busy && !isMul) { - when (count === w) { + val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) => + // the special case for iteration 0 is to save HW, not for correctness + val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0) + val less = difference(w) + Cat(Mux(less, rem(2*w-1,w), difference(w-1,0)), rem(w-1,0), !less) + } tail + + remainder := unrolls.last + when (count === w/cfg.divUnroll) { state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done)) + if (w % cfg.divUnroll < cfg.divUnroll - 1) + remainder := unrolls(w % cfg.divUnroll) } count := count + 1 - remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less) - - val divisorMSB = Log2(divisor(w-1,0), w) - val dividendMSB = Log2(remainder(w-1,0), w) - val eOutPos = UInt(w-1) + divisorMSB - dividendMSB - val eOutZero = divisorMSB > dividendMSB - val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero) - when (Bool(cfg.divEarlyOut) && eOut) { - val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0)) - remainder := remainder(w-1,0) << shift - count := shift + val divby0 = count === 0 && !subtractor(w) + if (cfg.divEarlyOut) { + val divisorMSB = Log2(divisor(w-1,0), w) + val dividendMSB = Log2(remainder(w-1,0), w) + val eOutPos = UInt(w-1) + divisorMSB - dividendMSB + val eOutZero = divisorMSB > dividendMSB + val eOut = count === 0 && !divby0 && (eOutPos >= cfg.divUnroll || eOutZero) + when (eOut) { + val inc = Mux(eOutZero, UInt(w-1), eOutPos) >> log2Floor(cfg.divUnroll) + val shift = inc << log2Floor(cfg.divUnroll) + remainder := remainder(w-1,0) << shift + count := inc + } } - when (count === 0 && !less /* divby0 */ && !isHi) { neg_out := false } + when (divby0 && !isHi) { neg_out := false } } when (io.resp.fire() || io.kill) { state := s_ready