add optional early-out to mul/div
This commit is contained in:
parent
27ddff1adb
commit
fcd69dba98
@ -4,13 +4,13 @@ import Chisel._
|
||||
import Node._
|
||||
import Constants._
|
||||
|
||||
class rocketDivider(width: Int) extends Component {
|
||||
class rocketDivider(w: Int, earlyOut: Boolean = false) extends Component {
|
||||
val io = new ioMultiplier
|
||||
|
||||
val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() };
|
||||
val state = Reg(resetVal = s_ready);
|
||||
|
||||
val count = Reg() { UFix() };
|
||||
val count = Reg() { UFix(width = log2Up(w+1)) }
|
||||
val divby0 = Reg() { Bool() };
|
||||
val neg_quo = Reg() { Bool() };
|
||||
val neg_rem = Reg() { Bool() };
|
||||
@ -18,16 +18,14 @@ class rocketDivider(width: Int) extends Component {
|
||||
val rem = Reg() { Bool() };
|
||||
val half = Reg() { Bool() };
|
||||
|
||||
val divisor = Reg() { UFix() };
|
||||
val remainder = Reg() { UFix() };
|
||||
val subtractor = remainder(2*width, width).toUFix - divisor;
|
||||
val divisor = Reg() { Bits() }
|
||||
val remainder = Reg() { Bits(width = 2*w+1) }
|
||||
val subtractor = remainder(2*w,w) - divisor
|
||||
|
||||
val dw = io.req.bits.fn(io.req.bits.fn.width-1)
|
||||
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
|
||||
val tc = (fn === DIV_D) || (fn === DIV_R);
|
||||
|
||||
val do_kill = io.req_kill && Reg(io.req.ready) // kill on 1st cycle only
|
||||
|
||||
switch (state) {
|
||||
is (s_ready) {
|
||||
when (io.req.valid) {
|
||||
@ -35,13 +33,13 @@ class rocketDivider(width: Int) extends Component {
|
||||
}
|
||||
}
|
||||
is (s_neg_inputs) {
|
||||
state := Mux(do_kill, s_ready, s_busy)
|
||||
state := Mux(io.req_kill, s_ready, s_busy)
|
||||
}
|
||||
is (s_busy) {
|
||||
when (do_kill) {
|
||||
when (io.req_kill && Reg(io.req.ready)) {
|
||||
state := s_ready
|
||||
}
|
||||
.elsewhen (count === UFix(width)) {
|
||||
.elsewhen (count === UFix(w)) {
|
||||
state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done)
|
||||
}
|
||||
}
|
||||
@ -57,63 +55,69 @@ class rocketDivider(width: Int) extends Component {
|
||||
|
||||
// state machine
|
||||
|
||||
val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(width-1), io.req.bits.in0(width/2-1)).toBool
|
||||
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(width-1,width/2), Fill(width/2, lhs_sign))
|
||||
val lhs_in = Cat(lhs_hi, io.req.bits.in0(width/2-1,0))
|
||||
val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(w-1), io.req.bits.in0(w/2-1))
|
||||
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w-1,w/2), Fill(w/2, lhs_sign))
|
||||
val lhs_in = Cat(lhs_hi, io.req.bits.in0(w/2-1,0))
|
||||
|
||||
val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(width-1), io.req.bits.in1(width/2-1)).toBool
|
||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(width-1,width/2), Fill(width/2, rhs_sign))
|
||||
val rhs_in = Cat(rhs_hi, io.req.bits.in1(width/2-1,0))
|
||||
val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(w-1), io.req.bits.in1(w/2-1))
|
||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w-1,w/2), Fill(w/2, rhs_sign))
|
||||
val rhs_in = Cat(rhs_hi, io.req.bits.in1(w/2-1,0))
|
||||
|
||||
when ((state === s_ready) && io.req.valid) {
|
||||
count := UFix(0, log2Up(width+1));
|
||||
when (io.req.fire()) {
|
||||
count := UFix(0)
|
||||
half := (dw === DW_32);
|
||||
neg_quo := Bool(false);
|
||||
neg_rem := Bool(false);
|
||||
rem := (fn === DIV_R) || (fn === DIV_RU);
|
||||
reg_tag := io.req_tag;
|
||||
divby0 := Bool(true);
|
||||
divisor := rhs_in.toUFix;
|
||||
remainder := Cat(UFix(0,width+1), lhs_in).toUFix;
|
||||
divisor := rhs_in
|
||||
remainder := lhs_in
|
||||
}
|
||||
when (state === s_neg_inputs) {
|
||||
neg_rem := remainder(width-1).toBool;
|
||||
neg_quo := (remainder(width-1) != divisor(width-1));
|
||||
when (remainder(width-1).toBool) {
|
||||
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix;
|
||||
neg_rem := remainder(w-1)
|
||||
neg_quo := (remainder(w-1) != divisor(w-1))
|
||||
when (remainder(w-1)) {
|
||||
remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
|
||||
}
|
||||
when (divisor(width-1).toBool) {
|
||||
divisor := subtractor(width-1,0);
|
||||
when (divisor(w-1)) {
|
||||
divisor := subtractor(w-1,0)
|
||||
}
|
||||
}
|
||||
when (state === s_neg_outputs) {
|
||||
when (neg_rem && neg_quo && !divby0) {
|
||||
remainder := Cat(-remainder(2*width, width+1), remainder(width), -remainder(width-1,0)).toUFix;
|
||||
remainder := Cat(-remainder(2*w, w+1), remainder(w), -remainder(w-1,0))
|
||||
}
|
||||
.elsewhen (neg_quo && !divby0) {
|
||||
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix;
|
||||
remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
|
||||
}
|
||||
.elsewhen (neg_rem) {
|
||||
remainder := Cat(-remainder(2*width, width+1), remainder(width,0)).toUFix;
|
||||
}
|
||||
|
||||
when (divisor(width-1).toBool) {
|
||||
divisor := subtractor(width-1,0);
|
||||
remainder := Cat(-remainder(2*w, w+1), remainder(w,0))
|
||||
}
|
||||
}
|
||||
when (state === s_busy) {
|
||||
count := count + UFix(1);
|
||||
divby0 := divby0 && !subtractor(width).toBool;
|
||||
remainder := Mux(subtractor(width).toBool,
|
||||
Cat(remainder(2*width-1, width), remainder(width-1,0), ~subtractor(width)),
|
||||
Cat(subtractor(width-1, 0), remainder(width-1,0), ~subtractor(width))).toUFix;
|
||||
count := count + UFix(1)
|
||||
|
||||
val msb = subtractor(w)
|
||||
divby0 := divby0 && !msb
|
||||
remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
|
||||
|
||||
val divisorMSB = Log2(divisor, w)
|
||||
val dividendMSB = Log2(remainder(w-1,0), w)
|
||||
val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB
|
||||
val eOut = count === UFix(0) && eOutPos > dividendMSB && (divisorMSB != UFix(0) || divisor(0))
|
||||
when (Bool(earlyOut) && eOut) {
|
||||
val eOutDist = eOutPos - dividendMSB
|
||||
val shift = Mux(eOutDist >= UFix(w-1), UFix(w-1), eOutDist(log2Up(w)-1,0))
|
||||
remainder := remainder << shift
|
||||
count := shift
|
||||
}
|
||||
}
|
||||
|
||||
val result = Mux(rem, remainder(2*width, width+1), remainder(width-1,0));
|
||||
val result = Mux(rem, remainder(2*w, w+1), remainder(w-1,0))
|
||||
|
||||
io.resp_bits := Mux(half, Cat(Fill(width/2, result(width/2-1)), result(width/2-1,0)), result);
|
||||
io.resp_tag := reg_tag;
|
||||
io.resp_val := (state === s_done);
|
||||
|
||||
io.req.ready := (state === s_ready);
|
||||
io.resp_bits := Mux(half, Cat(Fill(w/2, result(w/2-1)), result(w/2-1,0)), result)
|
||||
io.resp_tag := reg_tag
|
||||
io.resp_val := state === s_done
|
||||
io.req.ready := state === s_ready
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ class rocketVUMultiplier(nwbq: Int) extends Component {
|
||||
io.vu.req <> io.cpu.req
|
||||
}
|
||||
|
||||
class rocketMultiplier(unroll: Int = 1) extends Component {
|
||||
class rocketMultiplier(unroll: Int = 1, earlyOut: Boolean = false) extends Component {
|
||||
val io = new ioMultiplier
|
||||
|
||||
val w0 = io.req.bits.in0.getWidth
|
||||
@ -90,8 +90,6 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
|
||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign))
|
||||
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0))
|
||||
|
||||
val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle
|
||||
|
||||
when (io.req.valid && io.req.ready) {
|
||||
r_val := Bool(true)
|
||||
r_cnt := UFix(0, log2Up(cycles+1))
|
||||
@ -102,25 +100,33 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
|
||||
r_prod:= rhs_in
|
||||
r_lsb := Bool(false)
|
||||
}
|
||||
.elsewhen (io.resp_val && io.resp_rdy || do_kill) { // can only kill on first cycle
|
||||
.elsewhen (io.resp_val && io.resp_rdy || io.req_kill && r_cnt === UFix(0)) { // can only kill on first cycle
|
||||
r_val := Bool(false)
|
||||
}
|
||||
|
||||
val eOutDist = (UFix(cycles)-r_cnt)*UFix(unroll)
|
||||
val outShift = Mux(r_fn === MUL_LO, UFix(0), Mux(r_dw === DW_64, UFix(64), UFix(32)))
|
||||
val shiftDist = Mux(r_cnt === UFix(cycles), outShift, eOutDist)
|
||||
val eOutMask = (UFix(1) << eOutDist) - UFix(1)
|
||||
val eOut = r_cnt != UFix(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toFix) & eOutMask).orR
|
||||
val shift = r_prod.toFix >> shiftDist
|
||||
|
||||
val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))
|
||||
when (r_val && (r_cnt != UFix(cycles))) {
|
||||
r_lsb := r_prod(unroll-1)
|
||||
r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
|
||||
r_cnt := r_cnt + UFix(1)
|
||||
when (eOut) {
|
||||
r_prod := shift
|
||||
r_cnt := UFix(cycles)
|
||||
}
|
||||
}
|
||||
|
||||
val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0))
|
||||
val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2))
|
||||
val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32)
|
||||
|
||||
val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)
|
||||
val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0))
|
||||
val out64 = shift(w0-1,0)
|
||||
|
||||
io.req.ready := !r_val
|
||||
io.resp_bits := mul_output;
|
||||
io.resp_bits := Mux(r_dw === DW_64, out64, out32)
|
||||
io.resp_tag := r_tag;
|
||||
io.resp_val := r_val && (r_cnt === UFix(cycles))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user