diff --git a/rocket/src/main/scala/divider.scala b/rocket/src/main/scala/divider.scala index b016ee04..41268264 100644 --- a/rocket/src/main/scala/divider.scala +++ b/rocket/src/main/scala/divider.scala @@ -4,13 +4,13 @@ import Chisel._ import Node._ import Constants._ -class rocketDivider(width: Int) extends Component { +class rocketDivider(w: Int, earlyOut: Boolean = false) extends Component { val io = new ioMultiplier val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() }; val state = Reg(resetVal = s_ready); - val count = Reg() { UFix() }; + val count = Reg() { UFix(width = log2Up(w+1)) } val divby0 = Reg() { Bool() }; val neg_quo = Reg() { Bool() }; val neg_rem = Reg() { Bool() }; @@ -18,16 +18,14 @@ class rocketDivider(width: Int) extends Component { val rem = Reg() { Bool() }; val half = Reg() { Bool() }; - val divisor = Reg() { UFix() }; - val remainder = Reg() { UFix() }; - val subtractor = remainder(2*width, width).toUFix - divisor; + val divisor = Reg() { Bits() } + val remainder = Reg() { Bits(width = 2*w+1) } + val subtractor = remainder(2*w,w) - divisor val dw = io.req.bits.fn(io.req.bits.fn.width-1) val fn = io.req.bits.fn(io.req.bits.fn.width-2,0) val tc = (fn === DIV_D) || (fn === DIV_R); - val do_kill = io.req_kill && Reg(io.req.ready) // kill on 1st cycle only - switch (state) { is (s_ready) { when (io.req.valid) { @@ -35,13 +33,13 @@ class rocketDivider(width: Int) extends Component { } } is (s_neg_inputs) { - state := Mux(do_kill, s_ready, s_busy) + state := Mux(io.req_kill, s_ready, s_busy) } is (s_busy) { - when (do_kill) { + when (io.req_kill && Reg(io.req.ready)) { state := s_ready } - .elsewhen (count === UFix(width)) { + .elsewhen (count === UFix(w)) { state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done) } } @@ -57,63 +55,69 @@ class rocketDivider(width: Int) extends Component { // state machine - val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(width-1), io.req.bits.in0(width/2-1)).toBool - val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(width-1,width/2), Fill(width/2, lhs_sign)) - val lhs_in = Cat(lhs_hi, io.req.bits.in0(width/2-1,0)) + val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(w-1), io.req.bits.in0(w/2-1)) + val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w-1,w/2), Fill(w/2, lhs_sign)) + val lhs_in = Cat(lhs_hi, io.req.bits.in0(w/2-1,0)) - val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(width-1), io.req.bits.in1(width/2-1)).toBool - val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(width-1,width/2), Fill(width/2, rhs_sign)) - val rhs_in = Cat(rhs_hi, io.req.bits.in1(width/2-1,0)) + val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(w-1), io.req.bits.in1(w/2-1)) + val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w-1,w/2), Fill(w/2, rhs_sign)) + val rhs_in = Cat(rhs_hi, io.req.bits.in1(w/2-1,0)) - when ((state === s_ready) && io.req.valid) { - count := UFix(0, log2Up(width+1)); + when (io.req.fire()) { + count := UFix(0) half := (dw === DW_32); neg_quo := Bool(false); neg_rem := Bool(false); rem := (fn === DIV_R) || (fn === DIV_RU); reg_tag := io.req_tag; divby0 := Bool(true); - divisor := rhs_in.toUFix; - remainder := Cat(UFix(0,width+1), lhs_in).toUFix; + divisor := rhs_in + remainder := lhs_in } when (state === s_neg_inputs) { - neg_rem := remainder(width-1).toBool; - neg_quo := (remainder(width-1) != divisor(width-1)); - when (remainder(width-1).toBool) { - remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix; + neg_rem := remainder(w-1) + neg_quo := (remainder(w-1) != divisor(w-1)) + when (remainder(w-1)) { + remainder := Cat(remainder(2*w, w), -remainder(w-1,0)) } - when (divisor(width-1).toBool) { - divisor := subtractor(width-1,0); + when (divisor(w-1)) { + divisor := subtractor(w-1,0) } } when (state === s_neg_outputs) { when (neg_rem && neg_quo && !divby0) { - remainder := Cat(-remainder(2*width, width+1), remainder(width), -remainder(width-1,0)).toUFix; + remainder := Cat(-remainder(2*w, w+1), remainder(w), -remainder(w-1,0)) } .elsewhen (neg_quo && !divby0) { - remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix; + remainder := Cat(remainder(2*w, w), -remainder(w-1,0)) } .elsewhen (neg_rem) { - remainder := Cat(-remainder(2*width, width+1), remainder(width,0)).toUFix; - } - - when (divisor(width-1).toBool) { - divisor := subtractor(width-1,0); + remainder := Cat(-remainder(2*w, w+1), remainder(w,0)) } } when (state === s_busy) { - count := count + UFix(1); - divby0 := divby0 && !subtractor(width).toBool; - remainder := Mux(subtractor(width).toBool, - Cat(remainder(2*width-1, width), remainder(width-1,0), ~subtractor(width)), - Cat(subtractor(width-1, 0), remainder(width-1,0), ~subtractor(width))).toUFix; + count := count + UFix(1) + + val msb = subtractor(w) + divby0 := divby0 && !msb + remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb) + + val divisorMSB = Log2(divisor, w) + val dividendMSB = Log2(remainder(w-1,0), w) + val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB + val eOut = count === UFix(0) && eOutPos > dividendMSB && (divisorMSB != UFix(0) || divisor(0)) + when (Bool(earlyOut) && eOut) { + val eOutDist = eOutPos - dividendMSB + val shift = Mux(eOutDist >= UFix(w-1), UFix(w-1), eOutDist(log2Up(w)-1,0)) + remainder := remainder << shift + count := shift + } } - val result = Mux(rem, remainder(2*width, width+1), remainder(width-1,0)); + val result = Mux(rem, remainder(2*w, w+1), remainder(w-1,0)) - io.resp_bits := Mux(half, Cat(Fill(width/2, result(width/2-1)), result(width/2-1,0)), result); - io.resp_tag := reg_tag; - io.resp_val := (state === s_done); - - io.req.ready := (state === s_ready); + io.resp_bits := Mux(half, Cat(Fill(w/2, result(w/2-1)), result(w/2-1,0)), result) + io.resp_tag := reg_tag + io.resp_val := state === s_done + io.req.ready := state === s_ready } diff --git a/rocket/src/main/scala/multiplier.scala b/rocket/src/main/scala/multiplier.scala index f5e3445b..c52ef782 100644 --- a/rocket/src/main/scala/multiplier.scala +++ b/rocket/src/main/scala/multiplier.scala @@ -61,7 +61,7 @@ class rocketVUMultiplier(nwbq: Int) extends Component { io.vu.req <> io.cpu.req } -class rocketMultiplier(unroll: Int = 1) extends Component { +class rocketMultiplier(unroll: Int = 1, earlyOut: Boolean = false) extends Component { val io = new ioMultiplier val w0 = io.req.bits.in0.getWidth @@ -89,8 +89,6 @@ class rocketMultiplier(unroll: Int = 1) extends Component { val rhs_sign = (fn === MUL_H) && rhs_msb val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign)) val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0)) - - val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle when (io.req.valid && io.req.ready) { r_val := Bool(true) @@ -102,25 +100,33 @@ class rocketMultiplier(unroll: Int = 1) extends Component { r_prod:= rhs_in r_lsb := Bool(false) } - .elsewhen (io.resp_val && io.resp_rdy || do_kill) { // can only kill on first cycle + .elsewhen (io.resp_val && io.resp_rdy || io.req_kill && r_cnt === UFix(0)) { // can only kill on first cycle r_val := Bool(false) } + val eOutDist = (UFix(cycles)-r_cnt)*UFix(unroll) + val outShift = Mux(r_fn === MUL_LO, UFix(0), Mux(r_dw === DW_64, UFix(64), UFix(32))) + val shiftDist = Mux(r_cnt === UFix(cycles), outShift, eOutDist) + val eOutMask = (UFix(1) << eOutDist) - UFix(1) + val eOut = r_cnt != UFix(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toFix) & eOutMask).orR + val shift = r_prod.toFix >> shiftDist + val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0)) when (r_val && (r_cnt != UFix(cycles))) { r_lsb := r_prod(unroll-1) r_prod := Cat(sum, r_prod(w-1,unroll)).toFix r_cnt := r_cnt + UFix(1) + when (eOut) { + r_prod := shift + r_cnt := UFix(cycles) + } } - val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0)) - val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2)) - val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32) - - val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext) + val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0)) + val out64 = shift(w0-1,0) io.req.ready := !r_val - io.resp_bits := mul_output; + io.resp_bits := Mux(r_dw === DW_64, out64, out32) io.resp_tag := r_tag; io.resp_val := r_val && (r_cnt === UFix(cycles)) }