// rocket-chip/rocket/src/main/scala/multiplier.scala (94 lines, 3.1 KiB, Scala)

package rocket
import Chisel._
import ALU._
/** A multiply request: ALU function code, operand-width select, the two
 *  operands, and a tag that Multiplier copies into the matching response
 *  (it assigns `io.resp.bits := r_req`, so fields are matched by name).
 *  NOTE(review): field declaration order is kept as-is because it determines
 *  the bundle's element order / bit layout. */
class MultiplierReq(implicit conf: RocketConfiguration) extends Bundle {
// ALU function code; Multiplier tests it with isMulFN against FN_MUL, FN_MULH, FN_MULHSU.
val fn = Bits(width = SZ_ALU_FN)
// Operand-width select; Multiplier compares it against DW_64 to choose full-width or half-width operation.
val dw = Bits(width = SZ_DW)
// First operand (multiplicand), conf.xprlen bits.
val in1 = Bits(width = conf.xprlen)
// Second operand (multiplier), conf.xprlen bits.
val in2 = Bits(width = conf.xprlen)
// Request tag, returned unchanged in MultiplierResp.tag (presumably a destination register index, given its nxprbits width - confirm with callers).
val tag = UFix(width = conf.nxprbits)
// The bundle takes an implicit constructor parameter, so clone is overridden to let Chisel create fresh copies (e.g. for Reg{...} and FIFOIO).
override def clone = new MultiplierReq().asInstanceOf[this.type]
}
/** A multiply response: the result word and the tag copied from the request.
 *  NOTE(review): field declaration order is kept as-is because it determines
 *  the bundle's element order / bit layout. */
class MultiplierResp(implicit conf: RocketConfiguration) extends Bundle {
// Result, conf.xprlen bits (sign-extended from the low half for non-DW_64 operations).
val data = Bits(width = conf.xprlen)
// Tag of the request this response answers.
val tag = UFix(width = conf.nxprbits)
// The bundle takes an implicit constructor parameter, so clone is overridden to let Chisel create fresh copies.
override def clone = new MultiplierResp().asInstanceOf[this.type]
}
/** Port bundle of the multiplier: a decoupled request input, a kill input and
 *  a decoupled response output. */
class MultiplierIO(implicit conf: RocketConfiguration) extends Bundle {
// Incoming requests; flipped so the multiplier is the consumer of this FIFOIO.
val req = new FIFOIO()(new MultiplierReq).flip
// When asserted, Multiplier drops the in-flight operation (unless a new request fires in the same cycle).
val kill = Bool(INPUT)
// Outgoing results; the multiplier is the producer.
val resp = new FIFOIO()(new MultiplierResp)
}
/** Sequential (multi-cycle) integer multiplier.
 *
 *  One request is processed at a time: `io.req.ready` is low while an
 *  operation is in flight. The operands are sign- or zero-extended according
 *  to the function code. Then `unroll` multiplier bits are retired per cycle
 *  using radix-2^unroll Booth recoding. Each step adds the low `unroll` bits
 *  of the remaining multiplier, read as a signed number, times the
 *  multiplicand. It adds the multiplicand once more when the top bit of the
 *  previously consumed chunk was set.
 *
 *  @param unroll   multiplier bits consumed per cycle. An operation takes
 *                  `cycles = w/unroll` iterations, where `w` is the operand
 *                  width plus one, rounded up to a multiple of `unroll`.
 *  @param earlyOut if true, stop iterating as soon as every not-yet-consumed
 *                  multiplier bit equals the last consumed bit (all remaining
 *                  recoded digits are zero). The partial product is then
 *                  aligned with an arithmetic right shift.
 */
class Multiplier(unroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Component {
  val io = new MultiplierIO

  // w0: operand width.
  // w:  w0 plus one sign bit (so unsigned operands can be handled as signed),
  //     rounded up to a multiple of `unroll`.
  val w0 = io.req.bits.in1.getWidth
  val w = (w0+1+unroll-1)/unroll*unroll
  val cycles = w/unroll

  val r_val = Reg(resetVal = Bool(false))
  // r_prod: the upper w bits accumulate the partial product.
  // The lower w bits start as the multiplier and shift right by `unroll` each
  // step; the freed top positions fill with low-order product bits.
  val r_prod= Reg { Bits(width = w*2) }
  // MSB of the multiplier chunk consumed in the previous step (Booth carry-in).
  val r_lsb = Reg { Bits() }
  // Iterations completed so far; counts 0..cycles.
  val r_cnt = Reg { UFix(width = log2Up(cycles+1)) }
  val r_req = Reg{new MultiplierReq}
  // Multiplicand, extended to w0+1 bits.
  val r_lhs = Reg{Bits(width = w0+1)}

  val dw = io.req.bits.dw
  val fn = io.req.bits.fn

  // in1 is treated as signed for MULH and MULHSU.
  // When dw != DW_64, only its low w0/2 bits are used; the upper half is
  // replaced by the extension bit.
  val lhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool
  val lhs_sign = (isMulFN(fn, FN_MULH) || isMulFN(fn, FN_MULHSU)) && lhs_msb
  val lhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, lhs_sign))
  val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in1(w0/2-1,0))

  // in2 is treated as signed only for MULH; it is extended to the full w bits.
  val rhs_msb = Mux(dw === DW_64, io.req.bits.in2(w0-1), io.req.bits.in2(w0/2-1)).toBool
  val rhs_sign = isMulFN(fn, FN_MULH) && rhs_msb
  val rhs_hi = Mux(dw === DW_64, io.req.bits.in2(w0-1,w0/2), Fill(w0/2, rhs_sign))
  val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in2(w0/2-1,0))

  when (io.req.fire()) {
    r_val := Bool(true)
    r_cnt := UFix(0, log2Up(cycles+1))
    r_req := io.req.bits
    r_lhs := lhs_in
    r_prod:= rhs_in
    r_lsb := Bool(false)
  } .elsewhen (io.resp.fire() || io.kill) {
    // The response was accepted, or the in-flight operation was killed.
    r_val := Bool(false)
  }

  // Number of multiplier bits not yet consumed (they sit in the low bits of r_prod).
  val eOutDist = (UFix(cycles)-r_cnt)*UFix(unroll)
  // After the last iteration r_prod holds the full product.
  // FN_MUL returns the low word.
  // The other functions return the high word: w0 bits up for DW_64, w0/2 otherwise.
  // These offsets are derived from w0 instead of a hard-coded 64/32 so the
  // unit stays correct when xprlen != 64.
  val outShift = Mux(isMulFN(r_req.fn, FN_MUL), UFix(0), Mux(r_req.dw === DW_64, UFix(w0), UFix(w0/2)))
  val shiftDist = Mux(r_cnt === UFix(cycles), outShift, eOutDist)
  val eOutMask = (UFix(1) << eOutDist) - UFix(1)
  // Early out: every remaining multiplier bit equals r_lsb, so all remaining
  // recoded digits are zero and further iterations would add nothing.
  val eOut = r_cnt != UFix(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toFix) & eOutMask).orR
  // Arithmetic shift: used both for early-out alignment and for output word selection.
  val shift = r_prod.toFix >> shiftDist

  // One iteration: accumulator + signed(low multiplier chunk) * multiplicand,
  // plus the multiplicand again if the previous chunk's MSB was set.
  val sum = r_prod(2*w-1,w).toFix + r_prod(unroll-1,0).toFix * r_lhs.toFix + Mux(r_lsb, r_lhs.toFix, Fix(0))

  when (r_val && (r_cnt != UFix(cycles))) {
    r_lsb := r_prod(unroll-1)
    r_prod := Cat(sum, r_prod(w-1,unroll)).toFix
    r_cnt := r_cnt + UFix(1)
    when (eOut) {
      // Skip the remaining iterations: move the partial product to its final
      // position and mark the operation complete. These connections override
      // the ones above (last connect wins).
      r_prod := shift
      r_cnt := UFix(cycles)
    }
  }

  // Non-DW_64 results are sign-extended from bit w0/2-1.
  val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0))
  val out64 = shift(w0-1,0)

  io.req.ready := !r_val
  // Copies the matching fields (tag) from the request; data is overridden below.
  io.resp.bits := r_req
  io.resp.bits.data := Mux(r_req.dw === DW_64, out64, out32)
  io.resp.valid := r_val && (r_cnt === UFix(cycles))
}