Clean up multiplier/divider stuff
This commit is contained in:
		| @@ -1,199 +0,0 @@ | ||||
| package rocket | ||||
|  | ||||
| import Chisel._ | ||||
| import ALU._ | ||||
| import Util._ | ||||
|  | ||||
| class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { | ||||
|   val io = new MultiplierIO | ||||
|   val w = io.req.bits.in1.getWidth | ||||
|   val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll | ||||
|    | ||||
|   val s_ready :: s_neg_inputs :: s_mul_busy :: s_div_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 7) | ||||
|   val state = Reg(init=s_ready) | ||||
|    | ||||
|   val req = Reg(io.req.bits.clone) | ||||
|   val count = Reg(UInt(width = log2Up(w+1))) | ||||
|   val divby0 = Reg(Bool()) | ||||
|   val neg_out = Reg(Bool()) | ||||
|   val divisor = Reg(Bits(width = w+1)) // div only needs w bits | ||||
|   val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits | ||||
|  | ||||
|   def sext(x: Bits, cmds: Vec[Bits]) = { | ||||
|     val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn) | ||||
|     val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) | ||||
|     (Cat(hi, x(w/2-1,0)), sign) | ||||
|   } | ||||
|   val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM, FN_MULH, FN_MULHSU)) | ||||
|   val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM, FN_MULH)) | ||||
|          | ||||
|   val subtractor = remainder(2*w,w) - divisor(w,0) | ||||
|   val negated_remainder = -remainder(w-1,0) | ||||
|  | ||||
|   when (state === s_neg_inputs) { | ||||
|     val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(req.fn) | ||||
|     state := Mux(isMul, s_mul_busy, s_div_busy) | ||||
|     when (remainder(w-1) || isMul) { | ||||
|       remainder := negated_remainder | ||||
|     } | ||||
|     when (divisor(w-1) || isMul) { | ||||
|       divisor := subtractor | ||||
|     } | ||||
|   } | ||||
|   when (state === s_neg_output) { | ||||
|     remainder := negated_remainder | ||||
|     state := s_done | ||||
|   } | ||||
|   when (state === s_move_rem) { | ||||
|     remainder := remainder(2*w, w+1) | ||||
|     state := Mux(neg_out, s_neg_output, s_done) | ||||
|   } | ||||
|   when (state === s_mul_busy) { | ||||
|     val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) | ||||
|     val mplier = mulReg(mulw-1,0) | ||||
|     val accum = mulReg(2*mulw,mulw).toSInt | ||||
|     val mpcand = divisor.toSInt | ||||
|     val prod = mplier(mulUnroll-1,0) * mpcand + accum | ||||
|     val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt | ||||
|     remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt | ||||
|  | ||||
|     count := count + 1 | ||||
|     when (count === mulw/mulUnroll-1) { | ||||
|       state := s_done | ||||
|       when (AVec(FN_MULH, FN_MULHU, FN_MULHSU) contains req.fn) { | ||||
|         state := s_move_rem | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   when (state === s_div_busy) { | ||||
|     when (count === UInt(w)) { | ||||
|       state := Mux(neg_out && !divby0, s_neg_output, s_done) | ||||
|       when (AVec(FN_REM, FN_REMU) contains req.fn) { | ||||
|         state := s_move_rem | ||||
|       } | ||||
|     } | ||||
|     count := count + UInt(1) | ||||
|  | ||||
|     val msb = subtractor(w) | ||||
|     divby0 := divby0 && !msb | ||||
|     remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) | ||||
|  | ||||
|     val divisorMSB = Log2(divisor(w-1,0), w) | ||||
|     val dividendMSB = Log2(remainder(w-1,0), w) | ||||
|     val eOutPos = UInt(w-1) + divisorMSB - dividendMSB | ||||
|     val eOutZero = divisorMSB > dividendMSB | ||||
|     val eOut = count === UInt(0) && (eOutPos > 0 || eOutZero) && (divisorMSB != UInt(0) || divisor(0)) | ||||
|     when (Bool(earlyOut) && eOut) { | ||||
|       val shift = Mux(eOutZero, UInt(w-1), eOutPos) | ||||
|       remainder := remainder(w-1,0) << shift | ||||
|       count := shift | ||||
|     } | ||||
|   } | ||||
|   when (io.resp.fire() || io.kill) { | ||||
|     state := s_ready | ||||
|   } | ||||
|   when (io.req.fire()) { | ||||
|     val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(io.req.bits.fn) | ||||
|     val isRem = AVec(FN_REM, FN_REMU).contains(io.req.bits.fn) | ||||
|     val mulState = Mux(lhs_sign, s_neg_inputs, s_mul_busy) | ||||
|     val divState = Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div_busy) | ||||
|     state := Mux(isMul, mulState, divState) | ||||
|     count := UInt(0) | ||||
|     neg_out := !isMul && Mux(isRem, lhs_sign, lhs_sign != rhs_sign) | ||||
|     divby0 := true | ||||
|     divisor := Cat(rhs_sign, rhs_in) | ||||
|     remainder := lhs_in | ||||
|     req := io.req.bits | ||||
|   } | ||||
|  | ||||
|   io.resp.bits := req | ||||
|   io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) | ||||
|   io.resp.valid := state === s_done | ||||
|   io.req.ready := state === s_ready | ||||
| } | ||||
|  | ||||
| class Divider(earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { | ||||
|   val io = new MultiplierIO | ||||
|   val w = io.req.bits.in1.getWidth | ||||
|    | ||||
|   val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6) | ||||
|   val state = Reg(init=s_ready) | ||||
|    | ||||
|   val count = Reg(UInt(width = log2Up(w+1))) | ||||
|   val divby0 = Reg(Bool()) | ||||
|   val neg_out = Reg(Bool()) | ||||
|   val r_req = Reg(io.req.bits) | ||||
|    | ||||
|   val divisor     = Reg(Bits()) | ||||
|   val remainder   = Reg(Bits(width = 2*w+1)) | ||||
|   val subtractor  = remainder(2*w,w) - divisor | ||||
|  | ||||
|   def sext(x: Bits, cmds: Vec[Bits]) = { | ||||
|     val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn) | ||||
|     val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) | ||||
|     (Cat(hi, x(w/2-1,0)), sign) | ||||
|   } | ||||
|   val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM)) | ||||
|   val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM)) | ||||
|  | ||||
|   val r_isRem = isMulFN(r_req.fn, FN_REM) || isMulFN(r_req.fn, FN_REMU) | ||||
|          | ||||
|   when (state === s_neg_inputs) { | ||||
|     state := s_busy | ||||
|     when (remainder(w-1)) { | ||||
|       remainder := -remainder(w-1,0) | ||||
|     } | ||||
|     when (divisor(w-1)) { | ||||
|       divisor := subtractor(w-1,0) | ||||
|     } | ||||
|   } | ||||
|   when (state === s_neg_output) { | ||||
|     remainder := -remainder(w-1,0) | ||||
|     state := s_done | ||||
|   } | ||||
|   when (state === s_move_rem) { | ||||
|     remainder := remainder(2*w, w+1) | ||||
|     state := Mux(neg_out, s_neg_output, s_done) | ||||
|   } | ||||
|   when (state === s_busy) { | ||||
|     when (count === UInt(w)) { | ||||
|       state := Mux(r_isRem, s_move_rem, Mux(neg_out && !divby0, s_neg_output, s_done)) | ||||
|     } | ||||
|     count := count + UInt(1) | ||||
|  | ||||
|     val msb = subtractor(w) | ||||
|     divby0 := divby0 && !msb | ||||
|     remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) | ||||
|  | ||||
|     val divisorMSB = Log2(divisor, w) | ||||
|     val dividendMSB = Log2(remainder(w-1,0), w) | ||||
|     val eOutPos = UInt(w-1, log2Up(2*w)) + divisorMSB - dividendMSB | ||||
|     val eOut = count === UInt(0) && eOutPos > 0 && (divisorMSB != UInt(0) || divisor(0)) | ||||
|     when (Bool(earlyOut) && eOut) { | ||||
|       val shift = eOutPos(log2Up(w)-1,0) | ||||
|       remainder := remainder(w-1,0) << shift | ||||
|       count := shift | ||||
|       when (eOutPos(log2Up(w))) { | ||||
|         remainder := remainder(w-1,0) << w-1 | ||||
|         count := w-1 | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   when (io.resp.fire() || io.kill) { | ||||
|     state := s_ready | ||||
|   } | ||||
|   when (io.req.fire()) { | ||||
|     state := Mux(lhs_sign || rhs_sign, s_neg_inputs, s_busy) | ||||
|     count := UInt(0) | ||||
|     neg_out := Mux(AVec(FN_REM, FN_REMU).contains(io.req.bits.fn), lhs_sign, lhs_sign != rhs_sign) | ||||
|     divby0 := true | ||||
|     divisor := rhs_in | ||||
|     remainder := lhs_in | ||||
|     r_req := io.req.bits | ||||
|   } | ||||
|  | ||||
|   io.resp.bits := r_req | ||||
|   io.resp.bits.data := Mux(r_req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) | ||||
|   io.resp.valid := state === s_done | ||||
|   io.req.ready := state === s_ready | ||||
| } | ||||
| @@ -2,6 +2,7 @@ package rocket | ||||
|  | ||||
| import Chisel._ | ||||
| import ALU._ | ||||
| import Util._ | ||||
|  | ||||
| class MultiplierReq(implicit conf: RocketConfiguration) extends Bundle { | ||||
|   val fn = Bits(width = SZ_ALU_FN) | ||||
| @@ -26,68 +27,113 @@ class MultiplierIO(implicit conf: RocketConfiguration) extends Bundle { | ||||
|   val resp = Decoupled(new MultiplierResp) | ||||
| } | ||||
|  | ||||
| class Multiplier(unroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { | ||||
| class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { | ||||
|   val io = new MultiplierIO | ||||
|   val w = io.req.bits.in1.getWidth | ||||
|   val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll | ||||
|   | ||||
|   val w0 = io.req.bits.in1.getWidth | ||||
|   val w = (w0+1+unroll-1)/unroll*unroll | ||||
|   val cycles = w/unroll | ||||
|   val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6) | ||||
|   val state = Reg(init=s_ready) | ||||
|   | ||||
|   val r_val = Reg(init=Bool(false)) | ||||
|   val r_prod = Reg(Bits(width = w*2)) | ||||
|   val r_lsb = Reg(Bits()) | ||||
|   val r_cnt = Reg(UInt(width = log2Up(cycles+1))) | ||||
|   val r_req = Reg(new MultiplierReq) | ||||
|   val r_lhs = Reg(Bits(width = w0+1)) | ||||
|   val req = Reg(io.req.bits) | ||||
|   val count = Reg(UInt(width = log2Up(w+1))) | ||||
|   val neg_out = Reg(Bool()) | ||||
|   val isMul = Reg(Bool()) | ||||
|   val isHi = Reg(Bool()) | ||||
|   val divisor = Reg(Bits(width = w+1)) // div only needs w bits | ||||
|   val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits | ||||
|  | ||||
|   val dw = io.req.bits.dw | ||||
|   val fn = io.req.bits.fn | ||||
|   val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil = | ||||
|     DecodeLogic(io.req.bits.fn, List(X, X, X, X), List( | ||||
|                    FN_DIV    -> List(N, N, Y, Y), | ||||
|                    FN_REM    -> List(N, Y, Y, Y), | ||||
|                    FN_DIVU   -> List(N, N, N, N), | ||||
|                    FN_REMU   -> List(N, Y, N, N), | ||||
|                    FN_MUL    -> List(Y, N, X, X), | ||||
|                    FN_MULH   -> List(Y, Y, Y, Y), | ||||
|                    FN_MULHU  -> List(Y, Y, N, N), | ||||
|                    FN_MULHSU -> List(Y, Y, Y, N))) | ||||
|  | ||||
|   val lhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool | ||||
|   val lhs_sign = (isMulFN(fn, FN_MULH) || isMulFN(fn, FN_MULHSU)) && lhs_msb | ||||
|   val lhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, lhs_sign)) | ||||
|   val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in1(w0/2-1,0)) | ||||
|   def sext(x: Bits, signed: Bool) = { | ||||
|     val sign = signed && Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) | ||||
|     val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) | ||||
|     (Cat(hi, x(w/2-1,0)), sign) | ||||
|   } | ||||
|   val (lhs_in, lhs_sign) = sext(io.req.bits.in1, lhsSigned) | ||||
|   val (rhs_in, rhs_sign) = sext(io.req.bits.in2, rhsSigned) | ||||
|    | ||||
|   val rhs_msb = Mux(dw === DW_64, io.req.bits.in2(w0-1), io.req.bits.in2(w0/2-1)).toBool | ||||
|   val rhs_sign = isMulFN(fn, FN_MULH) && rhs_msb | ||||
|   val rhs_hi = Mux(dw === DW_64, io.req.bits.in2(w0-1,w0/2), Fill(w0/2, rhs_sign)) | ||||
|   val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in2(w0/2-1,0)) | ||||
|   val subtractor = remainder(2*w,w) - divisor(w,0) | ||||
|   val less = subtractor(w) | ||||
|   val negated_remainder = -remainder(w-1,0) | ||||
|  | ||||
|   when (state === s_neg_inputs) { | ||||
|     when (remainder(w-1) || isMul) { | ||||
|       remainder := negated_remainder | ||||
|     } | ||||
|     when (divisor(w-1) || isMul) { | ||||
|       divisor := subtractor | ||||
|     } | ||||
|     state := s_busy | ||||
|   } | ||||
|  | ||||
|   when (state === s_neg_output) { | ||||
|     remainder := negated_remainder | ||||
|     state := s_done | ||||
|   } | ||||
|   when (state === s_move_rem) { | ||||
|     remainder := remainder(2*w, w+1) | ||||
|     state := Mux(neg_out, s_neg_output, s_done) | ||||
|   } | ||||
|   when (state === s_busy && isMul) { | ||||
|     val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) | ||||
|     val mplier = mulReg(mulw-1,0) | ||||
|     val accum = mulReg(2*mulw,mulw).toSInt | ||||
|     val mpcand = divisor.toSInt | ||||
|     val prod = mplier(mulUnroll-1,0) * mpcand + accum | ||||
|     val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt | ||||
|     remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt | ||||
|  | ||||
|     count := count + 1 | ||||
|     when (count === mulw/mulUnroll-1) { | ||||
|       state := Mux(isHi, s_move_rem, s_done) | ||||
|     } | ||||
|   } | ||||
|   when (state === s_busy && !isMul) { | ||||
|     when (count === w) { | ||||
|       state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done)) | ||||
|     } | ||||
|     count := count + 1 | ||||
|  | ||||
|     remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less) | ||||
|  | ||||
|     val divisorMSB = Log2(divisor(w-1,0), w) | ||||
|     val dividendMSB = Log2(remainder(w-1,0), w) | ||||
|     val eOutPos = UInt(w-1) + divisorMSB - dividendMSB | ||||
|     val eOutZero = divisorMSB > dividendMSB | ||||
|     val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero) | ||||
|     when (Bool(earlyOut) && eOut) { | ||||
|       val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0)) | ||||
|       remainder := remainder(w-1,0) << shift | ||||
|       count := shift | ||||
|     } | ||||
|     when (count === 0 && !less /* divby0 */) { neg_out := false } | ||||
|   } | ||||
|   when (io.resp.fire() || io.kill) { | ||||
|     state := s_ready | ||||
|   } | ||||
|   when (io.req.fire()) { | ||||
|     r_val := Bool(true) | ||||
|     r_cnt := UInt(0, log2Up(cycles+1)) | ||||
|     r_req := io.req.bits | ||||
|     r_lhs := lhs_in | ||||
|     r_prod:= rhs_in | ||||
|     r_lsb := Bool(false) | ||||
|   } | ||||
|   .elsewhen (io.resp.fire() || io.kill) { | ||||
|     r_val := Bool(false) | ||||
|     state := Mux(lhs_sign || rhs_sign && !cmdMul, s_neg_inputs, s_busy) | ||||
|     isMul := cmdMul | ||||
|     isHi := cmdHi | ||||
|     count := 0 | ||||
|     neg_out := !cmdMul && Mux(cmdHi, lhs_sign, lhs_sign != rhs_sign) | ||||
|     divisor := Cat(rhs_sign, rhs_in) | ||||
|     remainder := lhs_in | ||||
|     req := io.req.bits | ||||
|   } | ||||
|  | ||||
|   val eOutDist = (UInt(cycles)-r_cnt)*UInt(unroll) | ||||
|   val outShift = Mux(isMulFN(r_req.fn, FN_MUL), UInt(0), Mux(r_req.dw === DW_64, UInt(64), UInt(32))) | ||||
|   val shiftDist = Mux(r_cnt === UInt(cycles), outShift, eOutDist) | ||||
|   val eOutMask = (UInt(1) << eOutDist) - UInt(1) | ||||
|   val eOut = r_cnt != UInt(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toSInt) & eOutMask).orR | ||||
|   val shift = r_prod.toSInt >> shiftDist | ||||
|  | ||||
|   val sum = r_prod(2*w-1,w).toSInt + r_prod(unroll-1,0).toSInt * r_lhs.toSInt + Mux(r_lsb.toBool, r_lhs.toSInt, SInt(0)) | ||||
|   when (r_val && (r_cnt != UInt(cycles))) { | ||||
|     r_lsb  := r_prod(unroll-1) | ||||
|     r_prod := Cat(sum, r_prod(w-1,unroll)).toSInt | ||||
|     r_cnt := r_cnt + UInt(1) | ||||
|     when (eOut) { | ||||
|       r_prod := shift | ||||
|       r_cnt := UInt(cycles) | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0)) | ||||
|   val out64 = shift(w0-1,0) | ||||
|   | ||||
|   io.req.ready := !r_val | ||||
|   io.resp.bits := r_req | ||||
|   io.resp.bits.data := Mux(r_req.dw === DW_64, out64, out32) | ||||
|   io.resp.valid := r_val && (r_cnt === UInt(cycles)) | ||||
|   io.resp.bits := req | ||||
|   io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) | ||||
|   io.resp.valid := state === s_done | ||||
|   io.req.ready := state === s_ready | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user