1
0

Simplify and correct integer multiplier

This commit is contained in:
Andrew Waterman 2013-05-21 19:35:08 -07:00
parent 11133d6d4c
commit fe9adfe71b

View File

@ -9,7 +9,7 @@ import Util._
class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Component { class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Component {
val io = new MultiplierIO val io = new MultiplierIO
val w = io.req.bits.in1.getWidth val w = io.req.bits.in1.getWidth
val mulw = (w+1+mulUnroll-1)/mulUnroll*mulUnroll val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll
val s_ready :: s_neg_inputs :: s_mul_busy :: s_div_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(7) { UFix() }; val s_ready :: s_neg_inputs :: s_mul_busy :: s_div_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(7) { UFix() };
val state = Reg(resetVal = s_ready); val state = Reg(resetVal = s_ready);
@ -19,7 +19,7 @@ class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: Rocke
val divby0 = Reg{Bool()} val divby0 = Reg{Bool()}
val neg_out = Reg{Bool()} val neg_out = Reg{Bool()}
val divisor = Reg{Bits(width = w+1)} // div only needs w bits val divisor = Reg{Bits(width = w+1)} // div only needs w bits
val remainder = Reg{Bits(width = 2*mulw+1)} // div only needs 2*w+1 bits val remainder = Reg{Bits(width = 2*mulw+2)} // div only needs 2*w+1 bits
def sext(x: Bits, cmds: Vec[Bits]) = { def sext(x: Bits, cmds: Vec[Bits]) = {
val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn) val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn)
@ -29,19 +29,21 @@ class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: Rocke
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM, FN_MULH, FN_MULHSU)) val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM, FN_MULH, FN_MULHSU))
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM, FN_MULH)) val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM, FN_MULH))
val subtractor = remainder(2*w,w) - divisor(w-1,0) val subtractor = remainder(2*w,w) - divisor(w,0)
val negated_remainder = -remainder(w-1,0)
when (state === s_neg_inputs) { when (state === s_neg_inputs) {
state := s_div_busy val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(req.fn)
when (remainder(w-1)) { state := Mux(isMul, s_mul_busy, s_div_busy)
remainder := -remainder(w-1,0) when (remainder(w-1) || isMul) {
remainder := negated_remainder
} }
when (divisor(w-1) && !AVec(FN_MULHU, FN_MULHSU).contains(req.fn)) { when (divisor(w-1) || isMul) {
divisor := subtractor(w-1,0) divisor := subtractor
} }
} }
when (state === s_neg_output) { when (state === s_neg_output) {
remainder := -remainder(w-1,0) remainder := negated_remainder
state := s_done state := s_done
} }
when (state === s_move_rem) { when (state === s_move_rem) {
@ -49,16 +51,13 @@ class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: Rocke
state := Mux(neg_out, s_neg_output, s_done) state := Mux(neg_out, s_neg_output, s_done)
} }
when (state === s_mul_busy) { when (state === s_mul_busy) {
val carryIn = remainder(w) val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
val mplier = Cat(remainder(2*mulw,w+1),remainder(w-1,0)).toFix val mplier = mulReg(mulw-1,0)
val accum = mulReg(2*mulw,mulw).toFix
val mpcand = divisor.toFix val mpcand = divisor.toFix
val prod0 = mplier(2*mulw-1,mulw) + val prod = mplier(mulUnroll-1,0) * mpcand + accum
(if (mulUnroll == 1) Mux(mplier(0), -Cat(mpcand < Fix(0), mpcand).toFix, Mux(carryIn, mpcand, Fix(0))) val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll))
else (mplier(mulUnroll-1,0) + carryIn.toUFix).toFix * mpcand) remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toFix
val prod = Mux(mplier(mulUnroll-1,0).andR && carryIn, mplier(2*mulw-1,mulw), prod0)
val sum = Cat(prod, mplier(mulw-1,mulUnroll))
val carryOut = mplier(mulUnroll-1)
remainder := Cat(sum(sum.getWidth-1,w), carryOut, sum(w-1,0)).toFix
count := count + 1 count := count + 1
when (count === mulw/mulUnroll-1) { when (count === mulw/mulUnroll-1) {
@ -81,7 +80,7 @@ class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: Rocke
divby0 := divby0 && !msb divby0 := divby0 && !msb
remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
val divisorMSB = Log2(divisor, w) val divisorMSB = Log2(divisor(w-1,0), w)
val dividendMSB = Log2(remainder(w-1,0), w) val dividendMSB = Log2(remainder(w-1,0), w)
val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB - dividendMSB val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB - dividendMSB
val eOut = count === UFix(0) && eOutPos > 0 && (divisorMSB != UFix(0) || divisor(0)) val eOut = count === UFix(0) && eOutPos > 0 && (divisorMSB != UFix(0) || divisor(0))
@ -101,12 +100,14 @@ class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: Rocke
when (io.req.fire()) { when (io.req.fire()) {
val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(io.req.bits.fn) val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(io.req.bits.fn)
val isRem = AVec(FN_REM, FN_REMU).contains(io.req.bits.fn) val isRem = AVec(FN_REM, FN_REMU).contains(io.req.bits.fn)
state := Mux(isMul, s_mul_busy, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div_busy)) val mulState = Mux(lhs_sign, s_neg_inputs, s_mul_busy)
val divState = Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div_busy)
state := Mux(isMul, mulState, divState)
count := UFix(0) count := UFix(0)
neg_out := !isMul && Mux(isRem, lhs_sign, lhs_sign != rhs_sign) neg_out := !isMul && Mux(isRem, lhs_sign, lhs_sign != rhs_sign)
divby0 := true divby0 := true
divisor := Cat(rhs_sign, rhs_in) divisor := Cat(rhs_sign, rhs_in)
remainder := Cat(Fill(mulw-w, isMul && lhs_sign), Bool(false), lhs_in) remainder := lhs_in
req := io.req.bits req := io.req.bits
} }