From a50a1f7d50a74ad65f4d61f143cdc103f1769878 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Mon, 13 Jan 2014 21:37:16 -0800 Subject: [PATCH] Clean up multiplier/divider stuff --- rocket/src/main/scala/divider.scala | 199 ------------------------- rocket/src/main/scala/multiplier.scala | 154 ++++++++++++------- 2 files changed, 100 insertions(+), 253 deletions(-) delete mode 100644 rocket/src/main/scala/divider.scala diff --git a/rocket/src/main/scala/divider.scala b/rocket/src/main/scala/divider.scala deleted file mode 100644 index 55bd2aa5..00000000 --- a/rocket/src/main/scala/divider.scala +++ /dev/null @@ -1,199 +0,0 @@ -package rocket - -import Chisel._ -import ALU._ -import Util._ - -class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { - val io = new MultiplierIO - val w = io.req.bits.in1.getWidth - val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll - - val s_ready :: s_neg_inputs :: s_mul_busy :: s_div_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 7) - val state = Reg(init=s_ready) - - val req = Reg(io.req.bits.clone) - val count = Reg(UInt(width = log2Up(w+1))) - val divby0 = Reg(Bool()) - val neg_out = Reg(Bool()) - val divisor = Reg(Bits(width = w+1)) // div only needs w bits - val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits - - def sext(x: Bits, cmds: Vec[Bits]) = { - val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn) - val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) - (Cat(hi, x(w/2-1,0)), sign) - } - val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM, FN_MULH, FN_MULHSU)) - val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM, FN_MULH)) - - val subtractor = remainder(2*w,w) - divisor(w,0) - val negated_remainder = -remainder(w-1,0) - - when (state === s_neg_inputs) { - val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(req.fn) - state := Mux(isMul, s_mul_busy, s_div_busy) - when (remainder(w-1) || isMul) { - remainder := negated_remainder - } - when (divisor(w-1) || isMul) { - divisor := subtractor - } - } - when (state === s_neg_output) { - remainder := negated_remainder - state := s_done - } - when (state === s_move_rem) { - remainder := remainder(2*w, w+1) - state := Mux(neg_out, s_neg_output, s_done) - } - when (state === s_mul_busy) { - val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) - val mplier = mulReg(mulw-1,0) - val accum = mulReg(2*mulw,mulw).toSInt - val mpcand = divisor.toSInt - val prod = mplier(mulUnroll-1,0) * mpcand + accum - val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt - remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt - - count := count + 1 - when (count === mulw/mulUnroll-1) { - state := s_done - when (AVec(FN_MULH, FN_MULHU, FN_MULHSU) contains req.fn) { - state := s_move_rem - } - } - } - when (state === s_div_busy) { - when (count === UInt(w)) { - state := Mux(neg_out && !divby0, s_neg_output, s_done) - when (AVec(FN_REM, FN_REMU) contains req.fn) { - state := s_move_rem - } - } - count := count + UInt(1) - - val msb = subtractor(w) - divby0 := divby0 && !msb - remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) - - val divisorMSB = Log2(divisor(w-1,0), w) - val dividendMSB = Log2(remainder(w-1,0), w) - val eOutPos = UInt(w-1) + divisorMSB - dividendMSB - val eOutZero = divisorMSB > dividendMSB - val eOut = count === UInt(0) && (eOutPos > 0 || eOutZero) && (divisorMSB != UInt(0) || divisor(0)) - when (Bool(earlyOut) && eOut) { - val shift = Mux(eOutZero, UInt(w-1), eOutPos) - remainder := remainder(w-1,0) << shift - count := shift - } - } - when (io.resp.fire() || io.kill) { - state := s_ready - } - when (io.req.fire()) { - val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(io.req.bits.fn) - val isRem = AVec(FN_REM, FN_REMU).contains(io.req.bits.fn) - val mulState = Mux(lhs_sign, s_neg_inputs, s_mul_busy) - val divState = Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div_busy) - state := Mux(isMul, mulState, divState) - count := UInt(0) - neg_out := !isMul && Mux(isRem, lhs_sign, lhs_sign != rhs_sign) - divby0 := true - divisor := Cat(rhs_sign, rhs_in) - remainder := lhs_in - req := io.req.bits - } - - io.resp.bits := req - io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) - io.resp.valid := state === s_done - io.req.ready := state === s_ready -} - -class Divider(earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { - val io = new MultiplierIO - val w = io.req.bits.in1.getWidth - - val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6) - val state = Reg(init=s_ready) - - val count = Reg(UInt(width = log2Up(w+1))) - val divby0 = Reg(Bool()) - val neg_out = Reg(Bool()) - val r_req = Reg(io.req.bits) - - val divisor = Reg(Bits()) - val remainder = Reg(Bits(width = 2*w+1)) - val subtractor = remainder(2*w,w) - divisor - - def sext(x: Bits, cmds: Vec[Bits]) = { - val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn) - val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) - (Cat(hi, x(w/2-1,0)), sign) - } - val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM)) - val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM)) - - val r_isRem = isMulFN(r_req.fn, FN_REM) || isMulFN(r_req.fn, FN_REMU) - - when (state === s_neg_inputs) { - state := s_busy - when (remainder(w-1)) { - remainder := -remainder(w-1,0) - } - when (divisor(w-1)) { - divisor := subtractor(w-1,0) - } - } - when (state === s_neg_output) { - remainder := -remainder(w-1,0) - state := s_done - } - when (state === s_move_rem) { - remainder := remainder(2*w, w+1) - state := Mux(neg_out, s_neg_output, s_done) - } - when (state === s_busy) { - when (count === UInt(w)) { - state := Mux(r_isRem, s_move_rem, Mux(neg_out && !divby0, s_neg_output, s_done)) - } - count := count + UInt(1) - - val msb = subtractor(w) - divby0 := divby0 && !msb - remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) - - val divisorMSB = Log2(divisor, w) - val dividendMSB = Log2(remainder(w-1,0), w) - val eOutPos = UInt(w-1, log2Up(2*w)) + divisorMSB - dividendMSB - val eOut = count === UInt(0) && eOutPos > 0 && (divisorMSB != UInt(0) || divisor(0)) - when (Bool(earlyOut) && eOut) { - val shift = eOutPos(log2Up(w)-1,0) - remainder := remainder(w-1,0) << shift - count := shift - when (eOutPos(log2Up(w))) { - remainder := remainder(w-1,0) << w-1 - count := w-1 - } - } - } - when (io.resp.fire() || io.kill) { - state := s_ready - } - when (io.req.fire()) { - state := Mux(lhs_sign || rhs_sign, s_neg_inputs, s_busy) - count := UInt(0) - neg_out := Mux(AVec(FN_REM, FN_REMU).contains(io.req.bits.fn), lhs_sign, lhs_sign != rhs_sign) - divby0 := true - divisor := rhs_in - remainder := lhs_in - r_req := io.req.bits - } - - io.resp.bits := r_req - io.resp.bits.data := Mux(r_req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) - io.resp.valid := state === s_done - io.req.ready := state === s_ready -} diff --git a/rocket/src/main/scala/multiplier.scala b/rocket/src/main/scala/multiplier.scala index 95d1218b..ae4ba082 100644 --- a/rocket/src/main/scala/multiplier.scala +++ b/rocket/src/main/scala/multiplier.scala @@ -2,6 +2,7 @@ package rocket import Chisel._ import ALU._ +import Util._ class MultiplierReq(implicit conf: RocketConfiguration) extends Bundle { val fn = Bits(width = SZ_ALU_FN) @@ -26,68 +27,113 @@ class MultiplierIO(implicit conf: RocketConfiguration) extends Bundle { val resp = Decoupled(new MultiplierResp) } -class Multiplier(unroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { +class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module { val io = new MultiplierIO + val w = io.req.bits.in1.getWidth + val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll + + val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6) + val state = Reg(init=s_ready) + + val req = Reg(io.req.bits) + val count = Reg(UInt(width = log2Up(w+1))) + val neg_out = Reg(Bool()) + val isMul = Reg(Bool()) + val isHi = Reg(Bool()) + val divisor = Reg(Bits(width = w+1)) // div only needs w bits + val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits - val w0 = io.req.bits.in1.getWidth - val w = (w0+1+unroll-1)/unroll*unroll - val cycles = w/unroll - - val r_val = Reg(init=Bool(false)) - val r_prod = Reg(Bits(width = w*2)) - val r_lsb = Reg(Bits()) - val r_cnt = Reg(UInt(width = log2Up(cycles+1))) - val r_req = Reg(new MultiplierReq) - val r_lhs = Reg(Bits(width = w0+1)) + val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil = + DecodeLogic(io.req.bits.fn, List(X, X, X, X), List( + FN_DIV -> List(N, N, Y, Y), + FN_REM -> List(N, Y, Y, Y), + FN_DIVU -> List(N, N, N, N), + FN_REMU -> List(N, Y, N, N), + FN_MUL -> List(Y, N, X, X), + FN_MULH -> List(Y, Y, Y, Y), + FN_MULHU -> List(Y, Y, N, N), + FN_MULHSU -> List(Y, Y, Y, N))) - val dw = io.req.bits.dw - val fn = io.req.bits.fn - - val lhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool - val lhs_sign = (isMulFN(fn, FN_MULH) || isMulFN(fn, FN_MULHSU)) && lhs_msb - val lhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, lhs_sign)) - val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in1(w0/2-1,0)) - - val rhs_msb = Mux(dw === DW_64, io.req.bits.in2(w0-1), io.req.bits.in2(w0/2-1)).toBool - val rhs_sign = isMulFN(fn, FN_MULH) && rhs_msb - val rhs_hi = Mux(dw === DW_64, io.req.bits.in2(w0-1,w0/2), Fill(w0/2, rhs_sign)) - val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in2(w0/2-1,0)) - - when (io.req.fire()) { - r_val := Bool(true) - r_cnt := UInt(0, log2Up(cycles+1)) - r_req := io.req.bits - r_lhs := lhs_in - r_prod:= rhs_in - r_lsb := Bool(false) + def sext(x: Bits, signed: Bool) = { + val sign = signed && Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) + val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) + (Cat(hi, x(w/2-1,0)), sign) } - .elsewhen (io.resp.fire() || io.kill) { - r_val := Bool(false) + val (lhs_in, lhs_sign) = sext(io.req.bits.in1, lhsSigned) + val (rhs_in, rhs_sign) = sext(io.req.bits.in2, rhsSigned) + + val subtractor = remainder(2*w,w) - divisor(w,0) + val less = subtractor(w) + val negated_remainder = -remainder(w-1,0) + + when (state === s_neg_inputs) { + when (remainder(w-1) || isMul) { + remainder := negated_remainder + } + when (divisor(w-1) || isMul) { + divisor := subtractor + } + state := s_busy } - val eOutDist = (UInt(cycles)-r_cnt)*UInt(unroll) - val outShift = Mux(isMulFN(r_req.fn, FN_MUL), UInt(0), Mux(r_req.dw === DW_64, UInt(64), UInt(32))) - val shiftDist = Mux(r_cnt === UInt(cycles), outShift, eOutDist) - val eOutMask = (UInt(1) << eOutDist) - UInt(1) - val eOut = r_cnt != UInt(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toSInt) & eOutMask).orR - val shift = r_prod.toSInt >> shiftDist + when (state === s_neg_output) { + remainder := negated_remainder + state := s_done + } + when (state === s_move_rem) { + remainder := remainder(2*w, w+1) + state := Mux(neg_out, s_neg_output, s_done) + } + when (state === s_busy && isMul) { + val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0)) + val mplier = mulReg(mulw-1,0) + val accum = mulReg(2*mulw,mulw).toSInt + val mpcand = divisor.toSInt + val prod = mplier(mulUnroll-1,0) * mpcand + accum + val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt + remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt - val sum = r_prod(2*w-1,w).toSInt + r_prod(unroll-1,0).toSInt * r_lhs.toSInt + Mux(r_lsb.toBool, r_lhs.toSInt, SInt(0)) - when (r_val && (r_cnt != UInt(cycles))) { - r_lsb := r_prod(unroll-1) - r_prod := Cat(sum, r_prod(w-1,unroll)).toSInt - r_cnt := r_cnt + UInt(1) - when (eOut) { - r_prod := shift - r_cnt := UInt(cycles) + count := count + 1 + when (count === mulw/mulUnroll-1) { + state := Mux(isHi, s_move_rem, s_done) } } + when (state === s_busy && !isMul) { + when (count === w) { + state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done)) + } + count := count + 1 - val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0)) - val out64 = shift(w0-1,0) - - io.req.ready := !r_val - io.resp.bits := r_req - io.resp.bits.data := Mux(r_req.dw === DW_64, out64, out32) - io.resp.valid := r_val && (r_cnt === UInt(cycles)) + remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less) + + val divisorMSB = Log2(divisor(w-1,0), w) + val dividendMSB = Log2(remainder(w-1,0), w) + val eOutPos = UInt(w-1) + divisorMSB - dividendMSB + val eOutZero = divisorMSB > dividendMSB + val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero) + when (Bool(earlyOut) && eOut) { + val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0)) + remainder := remainder(w-1,0) << shift + count := shift + } + when (count === 0 && !less /* divby0 */) { neg_out := false } + } + when (io.resp.fire() || io.kill) { + state := s_ready + } + when (io.req.fire()) { + state := Mux(lhs_sign || rhs_sign && !cmdMul, s_neg_inputs, s_busy) + isMul := cmdMul + isHi := cmdHi + count := 0 + neg_out := !cmdMul && Mux(cmdHi, lhs_sign, lhs_sign != rhs_sign) + divisor := Cat(rhs_sign, rhs_in) + remainder := lhs_in + req := io.req.bits + } + + io.resp.bits := req + io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) + io.resp.valid := state === s_done + io.req.ready := state === s_ready }