1
0

add optional early-out to mul/div

This commit is contained in:
Andrew Waterman 2012-10-09 18:29:50 -07:00
parent 27ddff1adb
commit fcd69dba98
2 changed files with 64 additions and 54 deletions

View File

@ -4,13 +4,13 @@ import Chisel._
import Node._ import Node._
import Constants._ import Constants._
class rocketDivider(width: Int) extends Component { class rocketDivider(w: Int, earlyOut: Boolean = false) extends Component {
val io = new ioMultiplier val io = new ioMultiplier
val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() }; val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() };
val state = Reg(resetVal = s_ready); val state = Reg(resetVal = s_ready);
val count = Reg() { UFix() }; val count = Reg() { UFix(width = log2Up(w+1)) }
val divby0 = Reg() { Bool() }; val divby0 = Reg() { Bool() };
val neg_quo = Reg() { Bool() }; val neg_quo = Reg() { Bool() };
val neg_rem = Reg() { Bool() }; val neg_rem = Reg() { Bool() };
@ -18,16 +18,14 @@ class rocketDivider(width: Int) extends Component {
val rem = Reg() { Bool() }; val rem = Reg() { Bool() };
val half = Reg() { Bool() }; val half = Reg() { Bool() };
val divisor = Reg() { UFix() }; val divisor = Reg() { Bits() }
val remainder = Reg() { UFix() }; val remainder = Reg() { Bits(width = 2*w+1) }
val subtractor = remainder(2*width, width).toUFix - divisor; val subtractor = remainder(2*w,w) - divisor
val dw = io.req.bits.fn(io.req.bits.fn.width-1) val dw = io.req.bits.fn(io.req.bits.fn.width-1)
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0) val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
val tc = (fn === DIV_D) || (fn === DIV_R); val tc = (fn === DIV_D) || (fn === DIV_R);
val do_kill = io.req_kill && Reg(io.req.ready) // kill on 1st cycle only
switch (state) { switch (state) {
is (s_ready) { is (s_ready) {
when (io.req.valid) { when (io.req.valid) {
@ -35,13 +33,13 @@ class rocketDivider(width: Int) extends Component {
} }
} }
is (s_neg_inputs) { is (s_neg_inputs) {
state := Mux(do_kill, s_ready, s_busy) state := Mux(io.req_kill, s_ready, s_busy)
} }
is (s_busy) { is (s_busy) {
when (do_kill) { when (io.req_kill && Reg(io.req.ready)) {
state := s_ready state := s_ready
} }
.elsewhen (count === UFix(width)) { .elsewhen (count === UFix(w)) {
state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done) state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done)
} }
} }
@ -57,63 +55,69 @@ class rocketDivider(width: Int) extends Component {
// state machine // state machine
val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(width-1), io.req.bits.in0(width/2-1)).toBool val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(w-1), io.req.bits.in0(w/2-1))
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(width-1,width/2), Fill(width/2, lhs_sign)) val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w-1,w/2), Fill(w/2, lhs_sign))
val lhs_in = Cat(lhs_hi, io.req.bits.in0(width/2-1,0)) val lhs_in = Cat(lhs_hi, io.req.bits.in0(w/2-1,0))
val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(width-1), io.req.bits.in1(width/2-1)).toBool val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(w-1), io.req.bits.in1(w/2-1))
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(width-1,width/2), Fill(width/2, rhs_sign)) val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w-1,w/2), Fill(w/2, rhs_sign))
val rhs_in = Cat(rhs_hi, io.req.bits.in1(width/2-1,0)) val rhs_in = Cat(rhs_hi, io.req.bits.in1(w/2-1,0))
when ((state === s_ready) && io.req.valid) { when (io.req.fire()) {
count := UFix(0, log2Up(width+1)); count := UFix(0)
half := (dw === DW_32); half := (dw === DW_32);
neg_quo := Bool(false); neg_quo := Bool(false);
neg_rem := Bool(false); neg_rem := Bool(false);
rem := (fn === DIV_R) || (fn === DIV_RU); rem := (fn === DIV_R) || (fn === DIV_RU);
reg_tag := io.req_tag; reg_tag := io.req_tag;
divby0 := Bool(true); divby0 := Bool(true);
divisor := rhs_in.toUFix; divisor := rhs_in
remainder := Cat(UFix(0,width+1), lhs_in).toUFix; remainder := lhs_in
} }
when (state === s_neg_inputs) { when (state === s_neg_inputs) {
neg_rem := remainder(width-1).toBool; neg_rem := remainder(w-1)
neg_quo := (remainder(width-1) != divisor(width-1)); neg_quo := (remainder(w-1) != divisor(w-1))
when (remainder(width-1).toBool) { when (remainder(w-1)) {
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix; remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
} }
when (divisor(width-1).toBool) { when (divisor(w-1)) {
divisor := subtractor(width-1,0); divisor := subtractor(w-1,0)
} }
} }
when (state === s_neg_outputs) { when (state === s_neg_outputs) {
when (neg_rem && neg_quo && !divby0) { when (neg_rem && neg_quo && !divby0) {
remainder := Cat(-remainder(2*width, width+1), remainder(width), -remainder(width-1,0)).toUFix; remainder := Cat(-remainder(2*w, w+1), remainder(w), -remainder(w-1,0))
} }
.elsewhen (neg_quo && !divby0) { .elsewhen (neg_quo && !divby0) {
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix; remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
} }
.elsewhen (neg_rem) { .elsewhen (neg_rem) {
remainder := Cat(-remainder(2*width, width+1), remainder(width,0)).toUFix; remainder := Cat(-remainder(2*w, w+1), remainder(w,0))
}
when (divisor(width-1).toBool) {
divisor := subtractor(width-1,0);
} }
} }
when (state === s_busy) { when (state === s_busy) {
count := count + UFix(1); count := count + UFix(1)
divby0 := divby0 && !subtractor(width).toBool;
remainder := Mux(subtractor(width).toBool, val msb = subtractor(w)
Cat(remainder(2*width-1, width), remainder(width-1,0), ~subtractor(width)), divby0 := divby0 && !msb
Cat(subtractor(width-1, 0), remainder(width-1,0), ~subtractor(width))).toUFix; remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
val divisorMSB = Log2(divisor, w)
val dividendMSB = Log2(remainder(w-1,0), w)
val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB
val eOut = count === UFix(0) && eOutPos > dividendMSB && (divisorMSB != UFix(0) || divisor(0))
when (Bool(earlyOut) && eOut) {
val eOutDist = eOutPos - dividendMSB
val shift = Mux(eOutDist >= UFix(w-1), UFix(w-1), eOutDist(log2Up(w)-1,0))
remainder := remainder << shift
count := shift
}
} }
val result = Mux(rem, remainder(2*width, width+1), remainder(width-1,0)); val result = Mux(rem, remainder(2*w, w+1), remainder(w-1,0))
io.resp_bits := Mux(half, Cat(Fill(width/2, result(width/2-1)), result(width/2-1,0)), result); io.resp_bits := Mux(half, Cat(Fill(w/2, result(w/2-1)), result(w/2-1,0)), result)
io.resp_tag := reg_tag; io.resp_tag := reg_tag
io.resp_val := (state === s_done); io.resp_val := state === s_done
io.req.ready := state === s_ready
io.req.ready := (state === s_ready);
} }

View File

@ -61,7 +61,7 @@ class rocketVUMultiplier(nwbq: Int) extends Component {
io.vu.req <> io.cpu.req io.vu.req <> io.cpu.req
} }
class rocketMultiplier(unroll: Int = 1) extends Component { class rocketMultiplier(unroll: Int = 1, earlyOut: Boolean = false) extends Component {
val io = new ioMultiplier val io = new ioMultiplier
val w0 = io.req.bits.in0.getWidth val w0 = io.req.bits.in0.getWidth
@ -90,8 +90,6 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign)) val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign))
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0)) val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0))
val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle
when (io.req.valid && io.req.ready) { when (io.req.valid && io.req.ready) {
r_val := Bool(true) r_val := Bool(true)
r_cnt := UFix(0, log2Up(cycles+1)) r_cnt := UFix(0, log2Up(cycles+1))
@ -102,25 +100,33 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
r_prod:= rhs_in r_prod:= rhs_in
r_lsb := Bool(false) r_lsb := Bool(false)
} }
.elsewhen (io.resp_val && io.resp_rdy || do_kill) { // can only kill on first cycle .elsewhen (io.resp_val && io.resp_rdy || io.req_kill && r_cnt === UFix(0)) { // can only kill on first cycle
r_val := Bool(false) r_val := Bool(false)
} }
val eOutDist = (UFix(cycles)-r_cnt)*UFix(unroll)
val outShift = Mux(r_fn === MUL_LO, UFix(0), Mux(r_dw === DW_64, UFix(64), UFix(32)))
val shiftDist = Mux(r_cnt === UFix(cycles), outShift, eOutDist)
val eOutMask = (UFix(1) << eOutDist) - UFix(1)
val eOut = r_cnt != UFix(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toFix) & eOutMask).orR
val shift = r_prod.toFix >> shiftDist
val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0)) val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))
when (r_val && (r_cnt != UFix(cycles))) { when (r_val && (r_cnt != UFix(cycles))) {
r_lsb := r_prod(unroll-1) r_lsb := r_prod(unroll-1)
r_prod := Cat(sum, r_prod(w-1,unroll)).toFix r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
r_cnt := r_cnt + UFix(1) r_cnt := r_cnt + UFix(1)
when (eOut) {
r_prod := shift
r_cnt := UFix(cycles)
}
} }
val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0)) val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0))
val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2)) val out64 = shift(w0-1,0)
val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32)
val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)
io.req.ready := !r_val io.req.ready := !r_val
io.resp_bits := mul_output; io.resp_bits := Mux(r_dw === DW_64, out64, out32)
io.resp_tag := r_tag; io.resp_tag := r_tag;
io.resp_val := r_val && (r_cnt === UFix(cycles)) io.resp_val := r_val && (r_cnt === UFix(cycles))
} }