simplify and improve multiplier
This commit is contained in:
parent
1864e41361
commit
27ddff1adb
@ -252,7 +252,7 @@ class rocketDpath extends Component
|
|||||||
io.ctrl.div_result_val := div.io.resp_val
|
io.ctrl.div_result_val := div.io.resp_val
|
||||||
|
|
||||||
// multiplier
|
// multiplier
|
||||||
var mul_io = new rocketMultiplier().io
|
var mul_io = new rocketMultiplier(unroll = 6).io
|
||||||
if (HAVE_VEC)
|
if (HAVE_VEC)
|
||||||
{
|
{
|
||||||
val vu_mul = new rocketVUMultiplier(nwbq = 1)
|
val vu_mul = new rocketVUMultiplier(nwbq = 1)
|
||||||
|
@ -61,17 +61,12 @@ class rocketVUMultiplier(nwbq: Int) extends Component {
|
|||||||
io.vu.req <> io.cpu.req
|
io.vu.req <> io.cpu.req
|
||||||
}
|
}
|
||||||
|
|
||||||
class rocketMultiplier extends Component {
|
class rocketMultiplier(unroll: Int = 1) extends Component {
|
||||||
val io = new ioMultiplier
|
val io = new ioMultiplier
|
||||||
// w must be even (booth).
|
|
||||||
// we need an extra bit to handle signed vs. unsigned,
|
|
||||||
// so we need to add a second to keep w even.
|
|
||||||
val w = 64 + 2
|
|
||||||
val unroll = 3
|
|
||||||
|
|
||||||
require(w % 2 == 0 && (w/2) % unroll == 0)
|
val w0 = io.req.bits.in0.getWidth
|
||||||
|
val w = (w0+1+unroll-1)/unroll*unroll
|
||||||
val cycles = w/unroll/2
|
val cycles = w/unroll
|
||||||
|
|
||||||
val r_val = Reg(resetVal = Bool(false));
|
val r_val = Reg(resetVal = Bool(false));
|
||||||
val r_dw = Reg { Bits() }
|
val r_dw = Reg { Bits() }
|
||||||
@ -85,15 +80,15 @@ class rocketMultiplier extends Component {
|
|||||||
val dw = io.req.bits.fn(io.req.bits.fn.width-1)
|
val dw = io.req.bits.fn(io.req.bits.fn.width-1)
|
||||||
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
|
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
|
||||||
|
|
||||||
val lhs_msb = Mux(dw === DW_64, io.req.bits.in0(63), io.req.bits.in0(31)).toBool
|
val lhs_msb = Mux(dw === DW_64, io.req.bits.in0(w0-1), io.req.bits.in0(w0/2-1)).toBool
|
||||||
val lhs_sign = ((fn === MUL_H) || (fn === MUL_HSU)) && lhs_msb
|
val lhs_sign = ((fn === MUL_H) || (fn === MUL_HSU)) && lhs_msb
|
||||||
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(63,32), Fill(32, lhs_sign))
|
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w0-1,w0/2), Fill(w0/2, lhs_sign))
|
||||||
val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in0(31,0))
|
val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in0(w0/2-1,0))
|
||||||
|
|
||||||
val rhs_msb = Mux(dw === DW_64, io.req.bits.in1(63), io.req.bits.in1(31)).toBool
|
val rhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool
|
||||||
val rhs_sign = (fn === MUL_H) && rhs_msb
|
val rhs_sign = (fn === MUL_H) && rhs_msb
|
||||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(63,32), Fill(32, rhs_sign))
|
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign))
|
||||||
val rhs_in = Cat(rhs_sign, rhs_sign, rhs_hi, io.req.bits.in1(31,0))
|
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0))
|
||||||
|
|
||||||
val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle
|
val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle
|
||||||
|
|
||||||
@ -111,34 +106,16 @@ class rocketMultiplier extends Component {
|
|||||||
r_val := Bool(false)
|
r_val := Bool(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
val lhs_sext = Cat(r_lhs(w-2), r_lhs(w-2), r_lhs).toUFix
|
val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))
|
||||||
val lhs_twice = Cat(r_lhs(w-2), r_lhs, Bits(0,1)).toUFix
|
|
||||||
|
|
||||||
var prod = r_prod
|
|
||||||
var lsb = r_lsb
|
|
||||||
|
|
||||||
for (i <- 0 until unroll) {
|
|
||||||
val addend = Mux(prod(0) != lsb, lhs_sext,
|
|
||||||
Mux(prod(0) != prod(1), lhs_twice,
|
|
||||||
UFix(0)));
|
|
||||||
val sub = prod(1)
|
|
||||||
val adder_lhs = Cat(prod(w*2-1), prod(w*2-1,w)).toUFix
|
|
||||||
val adder_rhs = Mux(sub, ~addend, addend)
|
|
||||||
val adder_out = (adder_lhs + adder_rhs + sub.toUFix)(w,0)
|
|
||||||
|
|
||||||
lsb = prod(1)
|
|
||||||
prod = Cat(adder_out(w), adder_out, prod(w-1,2))
|
|
||||||
}
|
|
||||||
|
|
||||||
when (r_val && (r_cnt != UFix(cycles))) {
|
when (r_val && (r_cnt != UFix(cycles))) {
|
||||||
r_lsb := lsb
|
r_lsb := r_prod(unroll-1)
|
||||||
r_prod := prod
|
r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
|
||||||
r_cnt := r_cnt + UFix(1)
|
r_cnt := r_cnt + UFix(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
val mul_output64 = Mux(r_fn === MUL_LO, r_prod(63,0), r_prod(127,64))
|
val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0))
|
||||||
val mul_output32 = Mux(r_fn === MUL_LO, r_prod(31,0), r_prod(63,32))
|
val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2))
|
||||||
val mul_output32_ext = Cat(Fill(32, mul_output32(31)), mul_output32)
|
val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32)
|
||||||
|
|
||||||
val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)
|
val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user