1
0

Improve integer mul/div

- Signed integer multiplication latency is now deterministic (before,
it would take an extra cycle if the multiplier was negative).
- High-part multiplication is now one cycle faster.
- RV64 MULW now takes half as many cycles as MUL.
- Positive remainders are now one cycle faster.
This commit is contained in:
Andrew Waterman 2017-06-19 12:07:19 -07:00
parent ff1f0170dc
commit a6d9884cc0

View File

@ -40,15 +40,16 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val io = new MultiplierIO(width, log2Up(nXpr)) val io = new MultiplierIO(width, log2Up(nXpr))
val w = io.req.bits.in1.getWidth val w = io.req.bits.in1.getWidth
val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll
val fastMulW = w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0
val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6) val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8)
val state = Reg(init=s_ready) val state = Reg(init=s_ready)
val req = Reg(io.req.bits) val req = Reg(io.req.bits)
val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll)))) val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll))))
val neg_out = Reg(Bool()) val neg_out = Reg(Bool())
val isMul = Reg(Bool())
val isHi = Reg(Bool()) val isHi = Reg(Bool())
val resHi = Reg(Bool())
val divisor = Reg(Bits(width = w+1)) // div only needs w bits val divisor = Reg(Bits(width = w+1)) // div only needs w bits
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
@ -75,47 +76,47 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned) val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned)
val subtractor = remainder(2*w,w) - divisor val subtractor = remainder(2*w,w) - divisor
val negated_remainder = -remainder(w-1,0) val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0))
val negated_remainder = -result
when (state === s_neg_inputs) { when (state === s_neg_inputs) {
when (remainder(w-1) || isMul) { when (remainder(w-1)) {
remainder := negated_remainder remainder := negated_remainder
} }
when (divisor(w-1) || isMul) { when (divisor(w-1)) {
divisor := subtractor divisor := subtractor
} }
state := s_busy state := s_div
} }
when (state === s_neg_output) { when (state === s_neg_output) {
remainder := negated_remainder remainder := negated_remainder
state := s_done state := s_done_div
resHi := false
} }
when (state === s_move_rem) { when (state === s_mul) {
remainder := remainder(2*w, w+1)
state := Mux(neg_out, s_neg_output, s_done)
}
when (state === s_busy && isMul) {
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
val mplierSign = remainder(w)
val mplier = mulReg(mulw-1,0) val mplier = mulReg(mulw-1,0)
val accum = mulReg(2*mulw,mulw).asSInt val accum = mulReg(2*mulw,mulw).asSInt
val mpcand = divisor.asSInt val mpcand = divisor.asSInt
val prod = mplier(cfg.mulUnroll-1, 0) * mpcand + accum val prod = Cat(mplierSign, mplier(cfg.mulUnroll-1, 0)).asSInt * mpcand + accum
val nextMulReg = Cat(prod, mplier(mulw-1, cfg.mulUnroll)) val nextMulReg = Cat(prod, mplier(mulw-1, cfg.mulUnroll))
val nextMplierSign = count === mulw/cfg.mulUnroll-2 && neg_out
val eOutMask = (SInt(BigInt(-1) << mulw) >> (count * cfg.mulUnroll)(log2Up(mulw)-1,0))(mulw-1,0) val eOutMask = (SInt(BigInt(-1) << mulw) >> (count * cfg.mulUnroll)(log2Up(mulw)-1,0))(mulw-1,0)
val eOut = Bool(cfg.mulEarlyOut) && count =/= mulw/cfg.mulUnroll-1 && count =/= 0 && val eOut = Bool(cfg.mulEarlyOut) && count =/= mulw/cfg.mulUnroll-1 && count =/= 0 &&
!isHi && (mplier & ~eOutMask) === UInt(0) !isHi && (mplier & ~eOutMask) === UInt(0)
val eOutRes = (mulReg >> (mulw - count * cfg.mulUnroll)(log2Up(mulw)-1,0)) val eOutRes = (mulReg >> (mulw - count * cfg.mulUnroll)(log2Up(mulw)-1,0))
val nextMulReg1 = Cat(nextMulReg(2*mulw,mulw), Mux(eOut, eOutRes, nextMulReg)(mulw-1,0)) val nextMulReg1 = Cat(nextMulReg(2*mulw,mulw), Mux(eOut, eOutRes, nextMulReg)(mulw-1,0))
remainder := Cat(nextMulReg1 >> w, Bool(false), nextMulReg1(w-1,0)) remainder := Cat(nextMulReg1 >> w, nextMplierSign, nextMulReg1(w-1,0))
count := count + 1 count := count + 1
when (eOut || count === mulw/cfg.mulUnroll-1) { when (eOut || count === mulw/cfg.mulUnroll-1) {
state := Mux(isHi, s_move_rem, s_done) state := s_done_mul
resHi := isHi
} }
} }
when (state === s_busy && !isMul) { when (state === s_div) {
val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) => val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
// the special case for iteration 0 is to save HW, not for correctness // the special case for iteration 0 is to save HW, not for correctness
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0) val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
@ -125,7 +126,8 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
remainder := unrolls.last remainder := unrolls.last
when (count === w/cfg.divUnroll) { when (count === w/cfg.divUnroll) {
state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done)) state := Mux(neg_out, s_neg_output, s_done_div)
resHi := isHi
if (w % cfg.divUnroll < cfg.divUnroll - 1) if (w % cfg.divUnroll < cfg.divUnroll - 1)
remainder := unrolls(w % cfg.divUnroll) remainder := unrolls(w % cfg.divUnroll)
} }
@ -151,18 +153,21 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
state := s_ready state := s_ready
} }
when (io.req.fire()) { when (io.req.fire()) {
state := Mux(lhs_sign || rhs_sign && !cmdMul, s_neg_inputs, s_busy) state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div))
isMul := cmdMul
isHi := cmdHi isHi := cmdHi
count := 0 resHi := false
neg_out := !cmdMul && Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign) count := Mux[UInt](Bool(fastMulW) && cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0)
neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign)
divisor := Cat(rhs_sign, rhs_in) divisor := Cat(rhs_sign, rhs_in)
remainder := lhs_in remainder := lhs_in
req := io.req.bits req := io.req.bits
} }
val outMul = (state & (s_done_mul ^ s_done_div)) === (s_done_mul & ~s_done_div)
val loOut = Mux(Bool(fastMulW) && halfWidth(req) && outMul, result(w-1,w/2), result(w/2-1,0))
val hiOut = Mux(halfWidth(req), Fill(w/2, loOut(w/2-1)), result(w-1,w/2))
io.resp.bits <> req io.resp.bits <> req
io.resp.bits.data := Mux(halfWidth(req), Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) io.resp.bits.data := Cat(hiOut, loOut)
io.resp.valid := state === s_done io.resp.valid := (state === s_done_mul || state === s_done_div)
io.req.ready := state === s_ready io.req.ready := state === s_ready
} }