1
0
rocket-chip/rocket/src/main/scala/multiplier.scala

153 lines
5.2 KiB
Scala
Raw Normal View History

2014-09-13 03:06:41 +02:00
// See LICENSE for license details.
package rocket
import Chisel._
import ALU._
2014-01-14 06:37:16 +01:00
import Util._
2015-10-06 06:48:05 +02:00
/** Request payload presented to the multiply/divide unit.
  *
  * @param dataBits width in bits of each operand (`in1`, `in2`)
  * @param tagBits  width in bits of the `tag` field
  */
class MultiplierReq(dataBits: Int, tagBits: Int) extends Bundle {
  // Operation selector (SZ_ALU_FN bits); MulDiv decodes it against the FN_* codes.
  val fn = Bits(width = SZ_ALU_FN)
  // Operand-width selector (SZ_DW bits); MulDiv compares it against DW_32.
  val dw = Bits(width = SZ_DW)

  // The two source operands, each dataBits wide.
  val in1 = Bits(width = dataBits)
  val in2 = Bits(width = dataBits)

  // Tag stored alongside the request; tagBits wide.
  val tag = UInt(width = tagBits)

  // The constructor takes parameters, so cloning must re-supply them explicitly.
  override def cloneType: this.type = {
    val fresh = new MultiplierReq(dataBits, tagBits)
    fresh.asInstanceOf[this.type]
  }
}
2015-10-06 06:48:05 +02:00
/** Response payload produced by the multiply/divide unit.
  *
  * @param dataBits width in bits of the `data` result field
  * @param tagBits  width in bits of the `tag` field
  */
class MultiplierResp(dataBits: Int, tagBits: Int) extends Bundle {
  // Result value, dataBits wide.
  val data = Bits(width = dataBits)

  // Tag field, tagBits wide.
  val tag = UInt(width = tagBits)

  // The constructor takes parameters, so cloning must re-supply them explicitly.
  override def cloneType: this.type = {
    val fresh = new MultiplierResp(dataBits, tagBits)
    fresh.asInstanceOf[this.type]
  }
}
2015-10-06 06:48:05 +02:00
/** Port bundle of the multiply/divide unit.
  *
  * @param dataBits operand/result width, forwarded to the request and response bundles
  * @param tagBits  tag width, forwarded to the request and response bundles
  */
class MultiplierIO(dataBits: Int, tagBits: Int) extends Bundle {
  // Incoming request channel: a Decoupled interface, flipped to face inward.
  val req = Decoupled(new MultiplierReq(dataBits, tagBits)).flip

  // Input that makes MulDiv return to its ready state.
  val kill = Bool(INPUT)

  // Outgoing response channel.
  val resp = Decoupled(new MultiplierResp(dataBits, tagBits))
}
2015-10-06 06:48:05 +02:00
/** Iterative multiply/divide unit driven by a small state machine.
  *
  * A request is accepted in `s_ready`, optionally passes through `s_neg_inputs`,
  * iterates in `s_busy`, optionally passes through `s_move_rem` and/or
  * `s_neg_output`, and presents its result in `s_done`.
  *
  * @param width    operand width in bits; must be 32 or 64 (enforced by `require`)
  * @param nXpr     sizes the tag field to `log2Up(nXpr)` bits
  * @param unroll   number of multiplier bits consumed per `s_busy` cycle when multiplying
  * @param earlyOut when true, enables the early-termination logic in both the
  *                 multiply and the divide paths
  */
class MulDiv(
width: Int,
nXpr: Int = 32,
unroll: Int = 1,
earlyOut: Boolean = false) extends Module {
val io = new MultiplierIO(width, log2Up(nXpr))
2014-01-14 06:37:16 +01:00
// Operand width, read back from the port (same value as `width`).
val w = io.req.bits.in1.getWidth
2015-10-06 06:48:05 +02:00
// w rounded up to the next multiple of unroll.
val mulw = (w+unroll-1)/unroll*unroll
2014-01-14 06:37:16 +01:00
// State encoding and architectural registers of the unit.
val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6)
val state = Reg(init=s_ready)
// Copy of the accepted request; drives the response fields and halfWidth(req).
val req = Reg(io.req.bits)
// Iteration counter for s_busy.
val count = Reg(UInt(width = log2Up(w+1)))
// Set when the division result must be negated in s_neg_output.
val neg_out = Reg(Bool())
// Latched cmdMul / cmdHi decode bits of the accepted request.
val isMul = Reg(Bool())
val isHi = Reg(Bool())
val divisor = Reg(Bits(width = w+1)) // div only needs w bits
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
2011-12-20 13:18:28 +01:00
2014-01-14 06:37:16 +01:00
// Decode fn into four control bits, in column order:
// cmdMul, cmdHi, lhsSigned, rhsSigned. X marks a don't-care entry.
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
DecodeLogic(io.req.bits.fn, List(X, X, X, X), List(
FN_DIV -> List(N, N, Y, Y),
FN_REM -> List(N, Y, Y, Y),
FN_DIVU -> List(N, N, N, N),
FN_REMU -> List(N, Y, N, N),
FN_MUL -> List(Y, N, X, X),
FN_MULH -> List(Y, Y, Y, Y),
FN_MULHU -> List(Y, Y, N, N),
FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool)
2011-12-17 16:20:00 +01:00
require(w == 32 || w == 64)
// True when the datapath is wider than 32 bits and the request selects DW_32.
def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32
// Returns (extended operand, sign). The sign is taken from bit w/2-1 for a
// half-width operand, else from bit w-1, and is forced to 0 when unsigned.
// For half-width operands the upper half is replaced by copies of the sign.
def sext(x: Bits, halfW: Bool, signed: Bool) = {
val sign = signed && Mux(halfW, x(w/2-1), x(w-1))
val hi = Mux(halfW, Fill(w/2, sign), x(w-1,w/2))
2014-01-14 06:37:16 +01:00
(Cat(hi, x(w/2-1,0)), sign)
}
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, halfWidth(io.req.bits), lhsSigned)
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned)
2014-01-14 06:37:16 +01:00
// Shared datapath: upper part of remainder minus divisor, its top bit
// (used as the "less than" flag), and the negated low word of remainder.
val subtractor = remainder(2*w,w) - divisor(w,0)
val less = subtractor(w)
val negated_remainder = -remainder(w-1,0)
2014-01-14 06:37:16 +01:00
// s_neg_inputs: one cycle to negate operands before iterating.
// remainder is negated if its bit w-1 is set (or always for a multiply);
// divisor is replaced by subtractor if its bit w-1 is set (or always for a multiply).
// NOTE(review): subtractor equals -divisor here only if remainder(2*w,w) is
// zero at this point — presumably true right after a request loads lhs_in; confirm.
when (state === s_neg_inputs) {
when (remainder(w-1) || isMul) {
remainder := negated_remainder
}
when (divisor(w-1) || isMul) {
divisor := subtractor
}
state := s_busy
}
2011-12-17 16:20:00 +01:00
2014-01-14 06:37:16 +01:00
// s_neg_output: negate the low word of remainder, then finish.
when (state === s_neg_output) {
remainder := negated_remainder
state := s_done
}
2014-01-14 06:37:16 +01:00
// s_move_rem: move bits 2*w..w+1 of remainder down to the low bits,
// then negate if neg_out is set, otherwise finish.
when (state === s_move_rem) {
remainder := remainder(2*w, w+1)
state := Mux(neg_out, s_neg_output, s_done)
2011-12-20 12:49:07 +01:00
}
2014-01-14 06:37:16 +01:00
// s_busy, multiply: each cycle consumes `unroll` low multiplier bits.
when (state === s_busy && isMul) {
// The multiply view of remainder skips bit w (see the write-back below,
// which re-inserts a zero at that position).
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
// Low mulw bits: multiplier bits; bits 2*mulw..mulw: signed accumulator.
val mplier = mulReg(mulw-1,0)
val accum = mulReg(2*mulw,mulw).asSInt
val mpcand = divisor.asSInt
2015-10-06 06:48:05 +02:00
// Partial product of the low `unroll` multiplier bits, added to the accumulator;
// the multiplier is shifted right by `unroll` as the new sum is placed above it.
val prod = mplier(unroll-1,0) * mpcand + accum
val nextMulReg = Cat(prod, mplier(mulw-1,unroll))
2015-10-06 06:48:05 +02:00
// Early-out: a mask of ones shifted in from the top by count*unroll positions;
// ~eOutMask therefore selects the remaining low bits of mplier.
// NOTE(review): this appears to detect "all not-yet-consumed multiplier bits
// are zero" — confirm against the intended algorithm.
val eOutMask = (SInt(BigInt(-1) << mulw) >> (count * unroll)(log2Up(mulw)-1,0))(mulw-1,0)
2016-01-14 22:57:45 +01:00
// Early-out is never taken on the first or the last iteration, nor when the
// high part of the product is requested (isHi).
val eOut = Bool(earlyOut) && count =/= mulw/unroll-1 && count =/= 0 &&
!isHi && (mplier & ~eOutMask) === UInt(0)
2015-10-06 06:48:05 +02:00
// On early-out, mulReg shifted right by (mulw - count*unroll) supplies the low mulw bits.
val eOutRes = (mulReg >> (mulw - count * unroll)(log2Up(mulw)-1,0))
val nextMulReg1 = Cat(nextMulReg(2*mulw,mulw), Mux(eOut, eOutRes, nextMulReg)(mulw-1,0))
2015-08-01 00:42:10 +02:00
// Write back, re-inserting a zero at bit w of remainder.
remainder := Cat(nextMulReg1 >> w, Bool(false), nextMulReg1(w-1,0))
2014-01-14 06:37:16 +01:00
count := count + 1
2015-10-06 06:48:05 +02:00
// Finish on early-out or after mulw/unroll iterations; the isHi case goes
// through s_move_rem first.
when (eOut || count === mulw/unroll-1) {
2014-01-14 06:37:16 +01:00
state := Mux(isHi, s_move_rem, s_done)
}
}
// s_busy, divide: one subtract-and-shift step per cycle.
when (state === s_busy && !isMul) {
// After the step with count === w, leave s_busy: isHi goes through
// s_move_rem; otherwise negate if neg_out is set, else finish.
when (count === w) {
state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done))
}
count := count + 1
2011-12-20 12:49:07 +01:00
2014-01-14 06:37:16 +01:00
// Keep the old upper word if the subtraction went negative (less), otherwise
// take the difference; shift left by one and append !less as the new low bit.
remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less)
2012-10-10 03:29:50 +02:00
2014-01-14 06:37:16 +01:00
// Early-out for division: compare the MSB positions of divisor and dividend
// on the first cycle (count === 0).
val divisorMSB = Log2(divisor(w-1,0), w)
val dividendMSB = Log2(remainder(w-1,0), w)
val eOutPos = UInt(w-1) + divisorMSB - dividendMSB
val eOutZero = divisorMSB > dividendMSB
val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero)
// When taken, this overrides the remainder/count updates above: the dividend
// is pre-shifted left by `shift` and count jumps ahead by the same amount.
when (Bool(earlyOut) && eOut) {
val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0))
remainder := remainder(w-1,0) << shift
count := shift
2012-10-10 03:29:50 +02:00
}
// On the first cycle, if !less and the quotient (not isHi) is requested,
// suppress output negation.
when (count === 0 && !less /* divby0 */ && !isHi) { neg_out := false }
2014-01-14 06:37:16 +01:00
}
// Return to s_ready once the response is taken or the operation is killed.
// This block follows the state-machine blocks above, so its assignment to
// state takes precedence over theirs in the same cycle.
when (io.resp.fire() || io.kill) {
state := s_ready
}
// Accept a new request: latch the decode bits and load the operand registers.
when (io.req.fire()) {
// Enter s_neg_inputs if lhs is negative, or if rhs is negative and this is
// not a multiply (&& binds tighter than ||).
state := Mux(lhs_sign || rhs_sign && !cmdMul, s_neg_inputs, s_busy)
isMul := cmdMul
isHi := cmdHi
count := 0
2016-01-14 22:57:45 +01:00
// Division only: when cmdHi is set the output sign follows lhs; otherwise
// the output is negated when the operand signs differ.
neg_out := !cmdMul && Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign)
2014-01-14 06:37:16 +01:00
divisor := Cat(rhs_sign, rhs_in)
remainder := lhs_in
req := io.req.bits
2011-12-20 12:49:07 +01:00
}
2014-01-14 06:37:16 +01:00
// Response: bulk-connect from the latched request, then override data with the
// low word of remainder (sign-extended from bit w/2-1 for half-width requests).
io.resp.bits := req
io.resp.bits.data := Mux(halfWidth(req), Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0))
2014-01-14 06:37:16 +01:00
io.resp.valid := state === s_done
io.req.ready := state === s_ready
}