diff --git a/rocket/src/main/scala/dpath.scala b/rocket/src/main/scala/dpath.scala index 84746631..f052f6f1 100644 --- a/rocket/src/main/scala/dpath.scala +++ b/rocket/src/main/scala/dpath.scala @@ -252,7 +252,7 @@ class rocketDpath extends Component io.ctrl.div_result_val := div.io.resp_val // multiplier - var mul_io = new rocketMultiplier().io + var mul_io = new rocketMultiplier(unroll = 6).io if (HAVE_VEC) { val vu_mul = new rocketVUMultiplier(nwbq = 1) diff --git a/rocket/src/main/scala/multiplier.scala b/rocket/src/main/scala/multiplier.scala index f56f3d56..f5e3445b 100644 --- a/rocket/src/main/scala/multiplier.scala +++ b/rocket/src/main/scala/multiplier.scala @@ -61,17 +61,12 @@ class rocketVUMultiplier(nwbq: Int) extends Component { io.vu.req <> io.cpu.req } -class rocketMultiplier extends Component { +class rocketMultiplier(unroll: Int = 1) extends Component { val io = new ioMultiplier - // w must be even (booth). - // we need an extra bit to handle signed vs. unsigned, - // so we need to add a second to keep w even. - val w = 64 + 2 - val unroll = 3 - require(w % 2 == 0 && (w/2) % unroll == 0) - - val cycles = w/unroll/2 + val w0 = io.req.bits.in0.getWidth + val w = (w0+1+unroll-1)/unroll*unroll + val cycles = w/unroll val r_val = Reg(resetVal = Bool(false)); val r_dw = Reg { Bits() } @@ -85,15 +80,15 @@ class rocketMultiplier extends Component { val dw = io.req.bits.fn(io.req.bits.fn.width-1) val fn = io.req.bits.fn(io.req.bits.fn.width-2,0) - val lhs_msb = Mux(dw === DW_64, io.req.bits.in0(63), io.req.bits.in0(31)).toBool + val lhs_msb = Mux(dw === DW_64, io.req.bits.in0(w0-1), io.req.bits.in0(w0/2-1)).toBool val lhs_sign = ((fn === MUL_H) || (fn === MUL_HSU)) && lhs_msb - val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(63,32), Fill(32, lhs_sign)) - val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in0(31,0)) + val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w0-1,w0/2), Fill(w0/2, lhs_sign)) + val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in0(w0/2-1,0)) - val rhs_msb = Mux(dw === DW_64, io.req.bits.in1(63), io.req.bits.in1(31)).toBool + val rhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool val rhs_sign = (fn === MUL_H) && rhs_msb - val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(63,32), Fill(32, rhs_sign)) - val rhs_in = Cat(rhs_sign, rhs_sign, rhs_hi, io.req.bits.in1(31,0)) + val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign)) + val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0)) val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle @@ -111,34 +106,16 @@ class rocketMultiplier extends Component { r_val := Bool(false) } - val lhs_sext = Cat(r_lhs(w-2), r_lhs(w-2), r_lhs).toUFix - val lhs_twice = Cat(r_lhs(w-2), r_lhs, Bits(0,1)).toUFix - - var prod = r_prod - var lsb = r_lsb - - for (i <- 0 until unroll) { - val addend = Mux(prod(0) != lsb, lhs_sext, - Mux(prod(0) != prod(1), lhs_twice, - UFix(0))); - val sub = prod(1) - val adder_lhs = Cat(prod(w*2-1), prod(w*2-1,w)).toUFix - val adder_rhs = Mux(sub, ~addend, addend) - val adder_out = (adder_lhs + adder_rhs + sub.toUFix)(w,0) - - lsb = prod(1) - prod = Cat(adder_out(w), adder_out, prod(w-1,2)) - } - + val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0)) when (r_val && (r_cnt != UFix(cycles))) { - r_lsb := lsb - r_prod := prod + r_lsb := r_prod(unroll-1) + r_prod := Cat(sum, r_prod(w-1,unroll)).toFix r_cnt := r_cnt + UFix(1) } - val mul_output64 = Mux(r_fn === MUL_LO, r_prod(63,0), r_prod(127,64)) - val mul_output32 = Mux(r_fn === MUL_LO, r_prod(31,0), r_prod(63,32)) - val mul_output32_ext = Cat(Fill(32, mul_output32(31)), mul_output32) + val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0)) + val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2)) + val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32) val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)