1
0

Teach MulDiv to do either mul-only or div-only by setting unroll=0

This commit is contained in:
Andrew Waterman 2018-02-05 17:49:33 -08:00
parent 69441930b5
commit a59fc3bdaa

View File

@ -39,30 +39,35 @@ case class MulDivParams(
class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val io = new MultiplierIO(width, log2Up(nXpr)) val io = new MultiplierIO(width, log2Up(nXpr))
val w = io.req.bits.in1.getWidth val w = io.req.bits.in1.getWidth
val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll val mulw = if (cfg.mulUnroll == 0) w else (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll
val fastMulW = w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0 val fastMulW = if (cfg.mulUnroll == 0) false else w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0
val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8) val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8)
val state = Reg(init=s_ready) val state = Reg(init=s_ready)
val req = Reg(io.req.bits) val req = Reg(io.req.bits)
val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll)))) val count = Reg(UInt(width = log2Ceil(
((cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++
(cfg.mulUnroll != 0).option(mulw/cfg.mulUnroll)).reduce(_ max _))))
val neg_out = Reg(Bool()) val neg_out = Reg(Bool())
val isHi = Reg(Bool()) val isHi = Reg(Bool())
val resHi = Reg(Bool()) val resHi = Reg(Bool())
val divisor = Reg(Bits(width = w+1)) // div only needs w bits val divisor = Reg(Bits(width = w+1)) // div only needs w bits
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
val mulDecode = List(
FN_MUL -> List(Y, N, X, X),
FN_MULH -> List(Y, Y, Y, Y),
FN_MULHU -> List(Y, Y, N, N),
FN_MULHSU -> List(Y, Y, Y, N))
val divDecode = List(
FN_DIV -> List(N, N, Y, Y),
FN_REM -> List(N, Y, Y, Y),
FN_DIVU -> List(N, N, N, N),
FN_REMU -> List(N, Y, N, N))
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil = val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
DecodeLogic(io.req.bits.fn, List(X, X, X, X), List( DecodeLogic(io.req.bits.fn, List(X, X, X, X),
FN_DIV -> List(N, N, Y, Y), (if (cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.toBool)
FN_REM -> List(N, Y, Y, Y),
FN_DIVU -> List(N, N, N, N),
FN_REMU -> List(N, Y, N, N),
FN_MUL -> List(Y, N, X, X),
FN_MULH -> List(Y, Y, Y, Y),
FN_MULHU -> List(Y, Y, N, N),
FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool)
require(w == 32 || w == 64) require(w == 32 || w == 64)
def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32 def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32
@ -79,7 +84,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0)) val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0))
val negated_remainder = -result val negated_remainder = -result
when (state === s_neg_inputs) { if (cfg.divUnroll != 0) when (state === s_neg_inputs) {
when (remainder(w-1)) { when (remainder(w-1)) {
remainder := negated_remainder remainder := negated_remainder
} }
@ -88,12 +93,12 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
} }
state := s_div state := s_div
} }
when (state === s_neg_output) { if (cfg.divUnroll != 0) when (state === s_neg_output) {
remainder := negated_remainder remainder := negated_remainder
state := s_done_div state := s_done_div
resHi := false resHi := false
} }
when (state === s_mul) { if (cfg.mulUnroll != 0) when (state === s_mul) {
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
val mplierSign = remainder(w) val mplierSign = remainder(w)
val mplier = mulReg(mulw-1,0) val mplier = mulReg(mulw-1,0)
@ -116,7 +121,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
resHi := isHi resHi := isHi
} }
} }
when (state === s_div) { if (cfg.divUnroll != 0) when (state === s_div) {
val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) => val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
// the special case for iteration 0 is to save HW, not for correctness // the special case for iteration 0 is to save HW, not for correctness
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0) val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
@ -156,7 +161,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div)) state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div))
isHi := cmdHi isHi := cmdHi
resHi := false resHi := false
count := Mux[UInt](Bool(fastMulW) && cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) count := (if (fastMulW) Mux[UInt](cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) else 0)
neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign) neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign)
divisor := Cat(rhs_sign, rhs_in) divisor := Cat(rhs_sign, rhs_in)
remainder := lhs_in remainder := lhs_in