1
0

Support unrolling the integer divider

This commit is contained in:
Andrew Waterman 2017-03-06 15:03:14 -08:00
parent 74d8d672bf
commit 7668827741

View File

@ -31,6 +31,7 @@ class MultiplierIO(dataBits: Int, tagBits: Int) extends Bundle {
case class MulDivParams( case class MulDivParams(
mulUnroll: Int = 1, mulUnroll: Int = 1,
divUnroll: Int = 1,
mulEarlyOut: Boolean = false, mulEarlyOut: Boolean = false,
divEarlyOut: Boolean = false divEarlyOut: Boolean = false
) )
@ -44,7 +45,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val state = Reg(init=s_ready) val state = Reg(init=s_ready)
val req = Reg(io.req.bits) val req = Reg(io.req.bits)
val count = Reg(UInt(width = log2Up(w+1))) val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll))))
val neg_out = Reg(Bool()) val neg_out = Reg(Bool())
val isMul = Reg(Bool()) val isMul = Reg(Bool())
val isHi = Reg(Bool()) val isHi = Reg(Bool())
@ -73,8 +74,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, halfWidth(io.req.bits), lhsSigned) val (lhs_in, lhs_sign) = sext(io.req.bits.in1, halfWidth(io.req.bits), lhsSigned)
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned) val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned)
val subtractor = remainder(2*w,w) - divisor(w,0) val subtractor = remainder(2*w,w) - divisor
val less = subtractor(w)
val negated_remainder = -remainder(w-1,0) val negated_remainder = -remainder(w-1,0)
when (state === s_neg_inputs) { when (state === s_neg_inputs) {
@ -116,24 +116,36 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
} }
} }
when (state === s_busy && !isMul) { when (state === s_busy && !isMul) {
when (count === w) { val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
// the special case for iteration 0 is to save HW, not for correctness
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
val less = difference(w)
Cat(Mux(less, rem(2*w-1,w), difference(w-1,0)), rem(w-1,0), !less)
} tail
remainder := unrolls.last
when (count === w/cfg.divUnroll) {
state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done)) state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done))
if (w % cfg.divUnroll < cfg.divUnroll - 1)
remainder := unrolls(w % cfg.divUnroll)
} }
count := count + 1 count := count + 1
remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less) val divby0 = count === 0 && !subtractor(w)
if (cfg.divEarlyOut) {
val divisorMSB = Log2(divisor(w-1,0), w) val divisorMSB = Log2(divisor(w-1,0), w)
val dividendMSB = Log2(remainder(w-1,0), w) val dividendMSB = Log2(remainder(w-1,0), w)
val eOutPos = UInt(w-1) + divisorMSB - dividendMSB val eOutPos = UInt(w-1) + divisorMSB - dividendMSB
val eOutZero = divisorMSB > dividendMSB val eOutZero = divisorMSB > dividendMSB
val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero) val eOut = count === 0 && !divby0 && (eOutPos >= cfg.divUnroll || eOutZero)
when (Bool(cfg.divEarlyOut) && eOut) { when (eOut) {
val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0)) val inc = Mux(eOutZero, UInt(w-1), eOutPos) >> log2Floor(cfg.divUnroll)
remainder := remainder(w-1,0) << shift val shift = inc << log2Floor(cfg.divUnroll)
count := shift remainder := remainder(w-1,0) << shift
count := inc
}
} }
when (count === 0 && !less /* divby0 */ && !isHi) { neg_out := false } when (divby0 && !isHi) { neg_out := false }
} }
when (io.resp.fire() || io.kill) { when (io.resp.fire() || io.kill) {
state := s_ready state := s_ready