add optional early-out to mul/div
This commit is contained in:
parent
27ddff1adb
commit
fcd69dba98
@ -4,13 +4,13 @@ import Chisel._
|
|||||||
import Node._
|
import Node._
|
||||||
import Constants._
|
import Constants._
|
||||||
|
|
||||||
class rocketDivider(width: Int) extends Component {
|
class rocketDivider(w: Int, earlyOut: Boolean = false) extends Component {
|
||||||
val io = new ioMultiplier
|
val io = new ioMultiplier
|
||||||
|
|
||||||
val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() };
|
val s_ready :: s_neg_inputs :: s_busy :: s_neg_outputs :: s_done :: Nil = Enum(5) { UFix() };
|
||||||
val state = Reg(resetVal = s_ready);
|
val state = Reg(resetVal = s_ready);
|
||||||
|
|
||||||
val count = Reg() { UFix() };
|
val count = Reg() { UFix(width = log2Up(w+1)) }
|
||||||
val divby0 = Reg() { Bool() };
|
val divby0 = Reg() { Bool() };
|
||||||
val neg_quo = Reg() { Bool() };
|
val neg_quo = Reg() { Bool() };
|
||||||
val neg_rem = Reg() { Bool() };
|
val neg_rem = Reg() { Bool() };
|
||||||
@ -18,16 +18,14 @@ class rocketDivider(width: Int) extends Component {
|
|||||||
val rem = Reg() { Bool() };
|
val rem = Reg() { Bool() };
|
||||||
val half = Reg() { Bool() };
|
val half = Reg() { Bool() };
|
||||||
|
|
||||||
val divisor = Reg() { UFix() };
|
val divisor = Reg() { Bits() }
|
||||||
val remainder = Reg() { UFix() };
|
val remainder = Reg() { Bits(width = 2*w+1) }
|
||||||
val subtractor = remainder(2*width, width).toUFix - divisor;
|
val subtractor = remainder(2*w,w) - divisor
|
||||||
|
|
||||||
val dw = io.req.bits.fn(io.req.bits.fn.width-1)
|
val dw = io.req.bits.fn(io.req.bits.fn.width-1)
|
||||||
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
|
val fn = io.req.bits.fn(io.req.bits.fn.width-2,0)
|
||||||
val tc = (fn === DIV_D) || (fn === DIV_R);
|
val tc = (fn === DIV_D) || (fn === DIV_R);
|
||||||
|
|
||||||
val do_kill = io.req_kill && Reg(io.req.ready) // kill on 1st cycle only
|
|
||||||
|
|
||||||
switch (state) {
|
switch (state) {
|
||||||
is (s_ready) {
|
is (s_ready) {
|
||||||
when (io.req.valid) {
|
when (io.req.valid) {
|
||||||
@ -35,13 +33,13 @@ class rocketDivider(width: Int) extends Component {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
is (s_neg_inputs) {
|
is (s_neg_inputs) {
|
||||||
state := Mux(do_kill, s_ready, s_busy)
|
state := Mux(io.req_kill, s_ready, s_busy)
|
||||||
}
|
}
|
||||||
is (s_busy) {
|
is (s_busy) {
|
||||||
when (do_kill) {
|
when (io.req_kill && Reg(io.req.ready)) {
|
||||||
state := s_ready
|
state := s_ready
|
||||||
}
|
}
|
||||||
.elsewhen (count === UFix(width)) {
|
.elsewhen (count === UFix(w)) {
|
||||||
state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done)
|
state := Mux(neg_quo || neg_rem, s_neg_outputs, s_done)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -57,63 +55,69 @@ class rocketDivider(width: Int) extends Component {
|
|||||||
|
|
||||||
// state machine
|
// state machine
|
||||||
|
|
||||||
val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(width-1), io.req.bits.in0(width/2-1)).toBool
|
val lhs_sign = tc && Mux(dw === DW_64, io.req.bits.in0(w-1), io.req.bits.in0(w/2-1))
|
||||||
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(width-1,width/2), Fill(width/2, lhs_sign))
|
val lhs_hi = Mux(dw === DW_64, io.req.bits.in0(w-1,w/2), Fill(w/2, lhs_sign))
|
||||||
val lhs_in = Cat(lhs_hi, io.req.bits.in0(width/2-1,0))
|
val lhs_in = Cat(lhs_hi, io.req.bits.in0(w/2-1,0))
|
||||||
|
|
||||||
val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(width-1), io.req.bits.in1(width/2-1)).toBool
|
val rhs_sign = tc && Mux(dw === DW_64, io.req.bits.in1(w-1), io.req.bits.in1(w/2-1))
|
||||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(width-1,width/2), Fill(width/2, rhs_sign))
|
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w-1,w/2), Fill(w/2, rhs_sign))
|
||||||
val rhs_in = Cat(rhs_hi, io.req.bits.in1(width/2-1,0))
|
val rhs_in = Cat(rhs_hi, io.req.bits.in1(w/2-1,0))
|
||||||
|
|
||||||
when ((state === s_ready) && io.req.valid) {
|
when (io.req.fire()) {
|
||||||
count := UFix(0, log2Up(width+1));
|
count := UFix(0)
|
||||||
half := (dw === DW_32);
|
half := (dw === DW_32);
|
||||||
neg_quo := Bool(false);
|
neg_quo := Bool(false);
|
||||||
neg_rem := Bool(false);
|
neg_rem := Bool(false);
|
||||||
rem := (fn === DIV_R) || (fn === DIV_RU);
|
rem := (fn === DIV_R) || (fn === DIV_RU);
|
||||||
reg_tag := io.req_tag;
|
reg_tag := io.req_tag;
|
||||||
divby0 := Bool(true);
|
divby0 := Bool(true);
|
||||||
divisor := rhs_in.toUFix;
|
divisor := rhs_in
|
||||||
remainder := Cat(UFix(0,width+1), lhs_in).toUFix;
|
remainder := lhs_in
|
||||||
}
|
}
|
||||||
when (state === s_neg_inputs) {
|
when (state === s_neg_inputs) {
|
||||||
neg_rem := remainder(width-1).toBool;
|
neg_rem := remainder(w-1)
|
||||||
neg_quo := (remainder(width-1) != divisor(width-1));
|
neg_quo := (remainder(w-1) != divisor(w-1))
|
||||||
when (remainder(width-1).toBool) {
|
when (remainder(w-1)) {
|
||||||
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix;
|
remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
|
||||||
}
|
}
|
||||||
when (divisor(width-1).toBool) {
|
when (divisor(w-1)) {
|
||||||
divisor := subtractor(width-1,0);
|
divisor := subtractor(w-1,0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (state === s_neg_outputs) {
|
when (state === s_neg_outputs) {
|
||||||
when (neg_rem && neg_quo && !divby0) {
|
when (neg_rem && neg_quo && !divby0) {
|
||||||
remainder := Cat(-remainder(2*width, width+1), remainder(width), -remainder(width-1,0)).toUFix;
|
remainder := Cat(-remainder(2*w, w+1), remainder(w), -remainder(w-1,0))
|
||||||
}
|
}
|
||||||
.elsewhen (neg_quo && !divby0) {
|
.elsewhen (neg_quo && !divby0) {
|
||||||
remainder := Cat(remainder(2*width, width), -remainder(width-1,0)).toUFix;
|
remainder := Cat(remainder(2*w, w), -remainder(w-1,0))
|
||||||
}
|
}
|
||||||
.elsewhen (neg_rem) {
|
.elsewhen (neg_rem) {
|
||||||
remainder := Cat(-remainder(2*width, width+1), remainder(width,0)).toUFix;
|
remainder := Cat(-remainder(2*w, w+1), remainder(w,0))
|
||||||
}
|
|
||||||
|
|
||||||
when (divisor(width-1).toBool) {
|
|
||||||
divisor := subtractor(width-1,0);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (state === s_busy) {
|
when (state === s_busy) {
|
||||||
count := count + UFix(1);
|
count := count + UFix(1)
|
||||||
divby0 := divby0 && !subtractor(width).toBool;
|
|
||||||
remainder := Mux(subtractor(width).toBool,
|
val msb = subtractor(w)
|
||||||
Cat(remainder(2*width-1, width), remainder(width-1,0), ~subtractor(width)),
|
divby0 := divby0 && !msb
|
||||||
Cat(subtractor(width-1, 0), remainder(width-1,0), ~subtractor(width))).toUFix;
|
remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
|
||||||
|
|
||||||
|
val divisorMSB = Log2(divisor, w)
|
||||||
|
val dividendMSB = Log2(remainder(w-1,0), w)
|
||||||
|
val eOutPos = UFix(w-1, log2Up(2*w)) + divisorMSB
|
||||||
|
val eOut = count === UFix(0) && eOutPos > dividendMSB && (divisorMSB != UFix(0) || divisor(0))
|
||||||
|
when (Bool(earlyOut) && eOut) {
|
||||||
|
val eOutDist = eOutPos - dividendMSB
|
||||||
|
val shift = Mux(eOutDist >= UFix(w-1), UFix(w-1), eOutDist(log2Up(w)-1,0))
|
||||||
|
remainder := remainder << shift
|
||||||
|
count := shift
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val result = Mux(rem, remainder(2*width, width+1), remainder(width-1,0));
|
val result = Mux(rem, remainder(2*w, w+1), remainder(w-1,0))
|
||||||
|
|
||||||
io.resp_bits := Mux(half, Cat(Fill(width/2, result(width/2-1)), result(width/2-1,0)), result);
|
io.resp_bits := Mux(half, Cat(Fill(w/2, result(w/2-1)), result(w/2-1,0)), result)
|
||||||
io.resp_tag := reg_tag;
|
io.resp_tag := reg_tag
|
||||||
io.resp_val := (state === s_done);
|
io.resp_val := state === s_done
|
||||||
|
io.req.ready := state === s_ready
|
||||||
io.req.ready := (state === s_ready);
|
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ class rocketVUMultiplier(nwbq: Int) extends Component {
|
|||||||
io.vu.req <> io.cpu.req
|
io.vu.req <> io.cpu.req
|
||||||
}
|
}
|
||||||
|
|
||||||
class rocketMultiplier(unroll: Int = 1) extends Component {
|
class rocketMultiplier(unroll: Int = 1, earlyOut: Boolean = false) extends Component {
|
||||||
val io = new ioMultiplier
|
val io = new ioMultiplier
|
||||||
|
|
||||||
val w0 = io.req.bits.in0.getWidth
|
val w0 = io.req.bits.in0.getWidth
|
||||||
@ -90,8 +90,6 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
|
|||||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign))
|
val rhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, rhs_sign))
|
||||||
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0))
|
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in1(w0/2-1,0))
|
||||||
|
|
||||||
val do_kill = io.req_kill && r_cnt === UFix(0) // can only kill on 1st cycle
|
|
||||||
|
|
||||||
when (io.req.valid && io.req.ready) {
|
when (io.req.valid && io.req.ready) {
|
||||||
r_val := Bool(true)
|
r_val := Bool(true)
|
||||||
r_cnt := UFix(0, log2Up(cycles+1))
|
r_cnt := UFix(0, log2Up(cycles+1))
|
||||||
@ -102,25 +100,33 @@ class rocketMultiplier(unroll: Int = 1) extends Component {
|
|||||||
r_prod:= rhs_in
|
r_prod:= rhs_in
|
||||||
r_lsb := Bool(false)
|
r_lsb := Bool(false)
|
||||||
}
|
}
|
||||||
.elsewhen (io.resp_val && io.resp_rdy || do_kill) { // can only kill on first cycle
|
.elsewhen (io.resp_val && io.resp_rdy || io.req_kill && r_cnt === UFix(0)) { // can only kill on first cycle
|
||||||
r_val := Bool(false)
|
r_val := Bool(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val eOutDist = (UFix(cycles)-r_cnt)*UFix(unroll)
|
||||||
|
val outShift = Mux(r_fn === MUL_LO, UFix(0), Mux(r_dw === DW_64, UFix(64), UFix(32)))
|
||||||
|
val shiftDist = Mux(r_cnt === UFix(cycles), outShift, eOutDist)
|
||||||
|
val eOutMask = (UFix(1) << eOutDist) - UFix(1)
|
||||||
|
val eOut = r_cnt != UFix(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toFix) & eOutMask).orR
|
||||||
|
val shift = r_prod.toFix >> shiftDist
|
||||||
|
|
||||||
val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))
|
val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))
|
||||||
when (r_val && (r_cnt != UFix(cycles))) {
|
when (r_val && (r_cnt != UFix(cycles))) {
|
||||||
r_lsb := r_prod(unroll-1)
|
r_lsb := r_prod(unroll-1)
|
||||||
r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
|
r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
|
||||||
r_cnt := r_cnt + UFix(1)
|
r_cnt := r_cnt + UFix(1)
|
||||||
|
when (eOut) {
|
||||||
|
r_prod := shift
|
||||||
|
r_cnt := UFix(cycles)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val mul_output64 = Mux(r_fn === MUL_LO, r_prod(w0-1,0), r_prod(2*w0-1,w0))
|
val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0))
|
||||||
val mul_output32 = Mux(r_fn === MUL_LO, r_prod(w0/2-1,0), r_prod(w0-1,w0/2))
|
val out64 = shift(w0-1,0)
|
||||||
val mul_output32_ext = Cat(Fill(32, mul_output32(w0/2-1)), mul_output32)
|
|
||||||
|
|
||||||
val mul_output = Mux(r_dw === DW_64, mul_output64, mul_output32_ext)
|
|
||||||
|
|
||||||
io.req.ready := !r_val
|
io.req.ready := !r_val
|
||||||
io.resp_bits := mul_output;
|
io.resp_bits := Mux(r_dw === DW_64, out64, out32)
|
||||||
io.resp_tag := r_tag;
|
io.resp_tag := r_tag;
|
||||||
io.resp_val := r_val && (r_cnt === UFix(cycles))
|
io.resp_val := r_val && (r_cnt === UFix(cycles))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user