diff --git a/rocket/src/main/scala/dpath.scala b/rocket/src/main/scala/dpath.scala index 9feaaa16..a907faf6 100644 --- a/rocket/src/main/scala/dpath.scala +++ b/rocket/src/main/scala/dpath.scala @@ -292,6 +292,7 @@ class rocketDpath extends Component io.ctrl.mul_rdy := mul.io.mul_rdy io.ctrl.mul_result_val := mul.io.result_val; + mul.io.result_rdy := io.ctrl.mul_wb io.ctrl.ex_waddr := ex_reg_waddr; // for load/use hazard detection & bypass control diff --git a/rocket/src/main/scala/multiplier.scala b/rocket/src/main/scala/multiplier.scala index e1c17d8c..f3570bbf 100644 --- a/rocket/src/main/scala/multiplier.scala +++ b/rocket/src/main/scala/multiplier.scala @@ -10,63 +10,82 @@ class ioMultiplier(width: Int) extends Bundle { val mul_rdy = Bool('output); val dw = UFix(1, 'input); val mul_fn = UFix(2, 'input); - val mul_tag = UFix(5, 'input); + val mul_tag = UFix(CPU_TAG_BITS, 'input); val in0 = Bits(width, 'input); val in1 = Bits(width, 'input); // responses val result = Bits(width, 'output); - val result_tag = UFix(5, 'output); + val result_tag = UFix(CPU_TAG_BITS, 'output); val result_val = Bool('output); + val result_rdy = Bool('input); } class rocketMultiplier extends Component { val io = new ioMultiplier(64); + val width = 64 + 2 + val cycles = width/2 val r_val = Reg(resetVal = Bool(false)); - val r_dw = Reg(resetVal = UFix(0,1)); - val r_fn = Reg(resetVal = UFix(0,3)); - val r_tag = Reg(resetVal = UFix(0,5)); - val r_lhs = Reg(resetVal = Bits(0,65)); - val r_rhs = Reg(resetVal = Bits(0,65)); + val r_dw = Reg { UFix() } + val r_fn = Reg { UFix() } + val r_tag = Reg { UFix() } + val r_lhs = Reg { Bits() } + val r_prod= Reg { Bits(width = width*2) } + val r_lsb = Reg { Bits() } + val r_cnt = Reg { UFix(width = log2up(cycles+1)) } val lhs_msb = Mux(io.dw === DW_64, io.in0(63), io.in0(31)).toBool val lhs_sign = ((io.mul_fn === MUL_HS) || (io.mul_fn === MUL_HSU)) && lhs_msb val lhs_hi = Mux(io.dw === DW_64, io.in0(63,32), Fill(32, lhs_sign)) - val lhs = Cat(lhs_sign, lhs_hi, io.in0(31,0)) + val lhs_in = Cat(lhs_sign, lhs_hi, io.in0(31,0)) val rhs_msb = Mux(io.dw === DW_64, io.in1(63), io.in1(31)).toBool val rhs_sign = (io.mul_fn === MUL_HS) && rhs_msb val rhs_hi = Mux(io.dw === DW_64, io.in1(63,32), Fill(32, rhs_sign)) - val rhs = Cat(rhs_sign, rhs_hi, io.in1(31,0)) + val rhs_in = Cat(rhs_sign, rhs_sign, rhs_hi, io.in1(31,0)) - r_val <== io.mul_val; - when (io.mul_val) { + when (io.mul_val && io.mul_rdy) { + r_val <== Bool(true) + r_cnt <== UFix(0, log2up(cycles+1)) r_dw <== io.dw - r_fn <== io.mul_fn; - r_tag <== io.mul_tag; - r_lhs <== lhs; - r_rhs <== rhs; + r_fn <== io.mul_fn + r_tag <== io.mul_tag + r_lhs <== lhs_in + r_prod<== rhs_in + r_lsb <== Bool(false) + } + when (io.result_val && io.result_rdy) { + r_val <== Bool(false) } - - val mul_result = r_lhs.toFix * r_rhs.toFix; - val mul_output64 = Mux(r_fn === MUL_LO, mul_result(63,0), mul_result(127,64)) - val mul_output32 = Mux(r_fn === MUL_LO, mul_result(31,0), mul_result(63,31)) + val lhs_sext = Cat(r_lhs(width-2), r_lhs(width-2), r_lhs).toUFix + val lhs_twice = Cat(r_lhs(width-2), r_lhs, Bits(0,1)).toUFix + + val addend = Mux(r_prod(0) != r_lsb, lhs_sext, + Mux(r_prod(0) != r_prod(1), lhs_twice, + UFix(0))); + val sub = r_prod(1) + val adder_lhs = Cat(r_prod(width*2-1), r_prod(width*2-1,width)).toUFix + val adder_rhs = Mux(sub, ~addend, addend) + val adder_out = (adder_lhs + adder_rhs + sub.toUFix)(width,0) + + when (r_val && (r_cnt != UFix(cycles))) { + r_lsb <== r_prod(1) + r_prod <== Cat(adder_out(width), adder_out, r_prod(width-1,2)) + r_cnt <== r_cnt + UFix(1) + } + + val mul_output64 = Mux(r_fn === MUL_LO, r_prod(63,0), r_prod(127,64)) + val mul_output32 = Mux(r_fn === MUL_LO, r_prod(31,0), r_prod(63,31)) val mul_output32_ext = Cat(Fill(32, mul_output32(31)), mul_output32) val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext) - - // just a hack for now, this should be a parameterized number of stages - val r_result = Reg(Reg(Reg(mul_output))); - val r_result_tag = Reg(Reg(Reg(r_tag))); - val r_result_val = Reg(Reg(Reg(r_val))); - io.mul_rdy := Bool(true) - io.result := r_result; - io.result_tag := r_result_tag; - io.result_val := r_result_val; - + io.mul_rdy := !r_val + io.result := mul_output; + io.result_tag := r_tag; + io.result_val := r_val && (r_cnt === UFix(cycles)) } }