Teach MulDiv to do either mul-only or div-only by setting unroll=0
This commit is contained in:
		| @@ -39,30 +39,35 @@ case class MulDivParams( | ||||
| class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { | ||||
|   val io = new MultiplierIO(width, log2Up(nXpr)) | ||||
|   val w = io.req.bits.in1.getWidth | ||||
|   val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll | ||||
|   val fastMulW = w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0 | ||||
|   val mulw = if (cfg.mulUnroll == 0) w else (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll | ||||
|   val fastMulW = if (cfg.mulUnroll == 0) false else w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0 | ||||
|   | ||||
|   val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8) | ||||
|   val state = Reg(init=s_ready) | ||||
|   | ||||
|   val req = Reg(io.req.bits) | ||||
|   val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll)))) | ||||
|   val count = Reg(UInt(width = log2Ceil( | ||||
|     ((cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++ | ||||
|      (cfg.mulUnroll != 0).option(mulw/cfg.mulUnroll)).reduce(_ max _)))) | ||||
|   val neg_out = Reg(Bool()) | ||||
|   val isHi = Reg(Bool()) | ||||
|   val resHi = Reg(Bool()) | ||||
|   val divisor = Reg(Bits(width = w+1)) // div only needs w bits | ||||
|   val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits | ||||
|  | ||||
|   val mulDecode = List( | ||||
|     FN_MUL    -> List(Y, N, X, X), | ||||
|     FN_MULH   -> List(Y, Y, Y, Y), | ||||
|     FN_MULHU  -> List(Y, Y, N, N), | ||||
|     FN_MULHSU -> List(Y, Y, Y, N)) | ||||
|   val divDecode = List( | ||||
|     FN_DIV    -> List(N, N, Y, Y), | ||||
|     FN_REM    -> List(N, Y, Y, Y), | ||||
|     FN_DIVU   -> List(N, N, N, N), | ||||
|     FN_REMU   -> List(N, Y, N, N)) | ||||
|   val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil = | ||||
|     DecodeLogic(io.req.bits.fn, List(X, X, X, X), List( | ||||
|                    FN_DIV    -> List(N, N, Y, Y), | ||||
|                    FN_REM    -> List(N, Y, Y, Y), | ||||
|                    FN_DIVU   -> List(N, N, N, N), | ||||
|                    FN_REMU   -> List(N, Y, N, N), | ||||
|                    FN_MUL    -> List(Y, N, X, X), | ||||
|                    FN_MULH   -> List(Y, Y, Y, Y), | ||||
|                    FN_MULHU  -> List(Y, Y, N, N), | ||||
|                    FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool) | ||||
|     DecodeLogic(io.req.bits.fn, List(X, X, X, X), | ||||
|       (if (cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.toBool) | ||||
|  | ||||
|   require(w == 32 || w == 64) | ||||
|   def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32 | ||||
| @@ -79,7 +84,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { | ||||
|   val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0)) | ||||
|   val negated_remainder = -result | ||||
|  | ||||
|   when (state === s_neg_inputs) { | ||||
|   if (cfg.divUnroll != 0) when (state === s_neg_inputs) { | ||||
|     when (remainder(w-1)) { | ||||
|       remainder := negated_remainder | ||||
|     } | ||||
| @@ -88,12 +93,12 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { | ||||
|     } | ||||
|     state := s_div | ||||
|   } | ||||
|   when (state === s_neg_output) { | ||||
|   if (cfg.divUnroll != 0) when (state === s_neg_output) { | ||||
|     remainder := negated_remainder | ||||
|     state := s_done_div | ||||
|     resHi := false | ||||
|   } | ||||
|   when (state === s_mul) { | ||||
|   if (cfg.mulUnroll != 0) when (state === s_mul) { | ||||
|     val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) | ||||
|     val mplierSign = remainder(w) | ||||
|     val mplier = mulReg(mulw-1,0) | ||||
| @@ -116,7 +121,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { | ||||
|       resHi := isHi | ||||
|     } | ||||
|   } | ||||
|   when (state === s_div) { | ||||
|   if (cfg.divUnroll != 0) when (state === s_div) { | ||||
|     val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) => | ||||
|       // the special case for iteration 0 is to save HW, not for correctness | ||||
|       val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0) | ||||
| @@ -156,7 +161,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module { | ||||
|     state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div)) | ||||
|     isHi := cmdHi | ||||
|     resHi := false | ||||
|     count := Mux[UInt](Bool(fastMulW) && cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) | ||||
|     count := (if (fastMulW) Mux[UInt](cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) else 0) | ||||
|     neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign) | ||||
|     divisor := Cat(rhs_sign, rhs_in) | ||||
|     remainder := lhs_in | ||||
|   | ||||
		Reference in New Issue
	
	Block a user