Clean up multiplier/divider stuff
This commit is contained in:
parent
4d236979bd
commit
a50a1f7d50
@ -1,199 +0,0 @@
|
||||
package rocket
|
||||
|
||||
import Chisel._
|
||||
import ALU._
|
||||
import Util._
|
||||
|
||||
class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module {
|
||||
val io = new MultiplierIO
|
||||
val w = io.req.bits.in1.getWidth
|
||||
val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll
|
||||
|
||||
val s_ready :: s_neg_inputs :: s_mul_busy :: s_div_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 7)
|
||||
val state = Reg(init=s_ready)
|
||||
|
||||
val req = Reg(io.req.bits.clone)
|
||||
val count = Reg(UInt(width = log2Up(w+1)))
|
||||
val divby0 = Reg(Bool())
|
||||
val neg_out = Reg(Bool())
|
||||
val divisor = Reg(Bits(width = w+1)) // div only needs w bits
|
||||
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
|
||||
|
||||
def sext(x: Bits, cmds: Vec[Bits]) = {
|
||||
val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn)
|
||||
val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign))
|
||||
(Cat(hi, x(w/2-1,0)), sign)
|
||||
}
|
||||
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM, FN_MULH, FN_MULHSU))
|
||||
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM, FN_MULH))
|
||||
|
||||
val subtractor = remainder(2*w,w) - divisor(w,0)
|
||||
val negated_remainder = -remainder(w-1,0)
|
||||
|
||||
when (state === s_neg_inputs) {
|
||||
val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(req.fn)
|
||||
state := Mux(isMul, s_mul_busy, s_div_busy)
|
||||
when (remainder(w-1) || isMul) {
|
||||
remainder := negated_remainder
|
||||
}
|
||||
when (divisor(w-1) || isMul) {
|
||||
divisor := subtractor
|
||||
}
|
||||
}
|
||||
when (state === s_neg_output) {
|
||||
remainder := negated_remainder
|
||||
state := s_done
|
||||
}
|
||||
when (state === s_move_rem) {
|
||||
remainder := remainder(2*w, w+1)
|
||||
state := Mux(neg_out, s_neg_output, s_done)
|
||||
}
|
||||
when (state === s_mul_busy) {
|
||||
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
|
||||
val mplier = mulReg(mulw-1,0)
|
||||
val accum = mulReg(2*mulw,mulw).toSInt
|
||||
val mpcand = divisor.toSInt
|
||||
val prod = mplier(mulUnroll-1,0) * mpcand + accum
|
||||
val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt
|
||||
remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt
|
||||
|
||||
count := count + 1
|
||||
when (count === mulw/mulUnroll-1) {
|
||||
state := s_done
|
||||
when (AVec(FN_MULH, FN_MULHU, FN_MULHSU) contains req.fn) {
|
||||
state := s_move_rem
|
||||
}
|
||||
}
|
||||
}
|
||||
when (state === s_div_busy) {
|
||||
when (count === UInt(w)) {
|
||||
state := Mux(neg_out && !divby0, s_neg_output, s_done)
|
||||
when (AVec(FN_REM, FN_REMU) contains req.fn) {
|
||||
state := s_move_rem
|
||||
}
|
||||
}
|
||||
count := count + UInt(1)
|
||||
|
||||
val msb = subtractor(w)
|
||||
divby0 := divby0 && !msb
|
||||
remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
|
||||
|
||||
val divisorMSB = Log2(divisor(w-1,0), w)
|
||||
val dividendMSB = Log2(remainder(w-1,0), w)
|
||||
val eOutPos = UInt(w-1) + divisorMSB - dividendMSB
|
||||
val eOutZero = divisorMSB > dividendMSB
|
||||
val eOut = count === UInt(0) && (eOutPos > 0 || eOutZero) && (divisorMSB != UInt(0) || divisor(0))
|
||||
when (Bool(earlyOut) && eOut) {
|
||||
val shift = Mux(eOutZero, UInt(w-1), eOutPos)
|
||||
remainder := remainder(w-1,0) << shift
|
||||
count := shift
|
||||
}
|
||||
}
|
||||
when (io.resp.fire() || io.kill) {
|
||||
state := s_ready
|
||||
}
|
||||
when (io.req.fire()) {
|
||||
val isMul = AVec(FN_MUL, FN_MULH, FN_MULHU, FN_MULHSU).contains(io.req.bits.fn)
|
||||
val isRem = AVec(FN_REM, FN_REMU).contains(io.req.bits.fn)
|
||||
val mulState = Mux(lhs_sign, s_neg_inputs, s_mul_busy)
|
||||
val divState = Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div_busy)
|
||||
state := Mux(isMul, mulState, divState)
|
||||
count := UInt(0)
|
||||
neg_out := !isMul && Mux(isRem, lhs_sign, lhs_sign != rhs_sign)
|
||||
divby0 := true
|
||||
divisor := Cat(rhs_sign, rhs_in)
|
||||
remainder := lhs_in
|
||||
req := io.req.bits
|
||||
}
|
||||
|
||||
io.resp.bits := req
|
||||
io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0))
|
||||
io.resp.valid := state === s_done
|
||||
io.req.ready := state === s_ready
|
||||
}
|
||||
|
||||
class Divider(earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module {
|
||||
val io = new MultiplierIO
|
||||
val w = io.req.bits.in1.getWidth
|
||||
|
||||
val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6)
|
||||
val state = Reg(init=s_ready)
|
||||
|
||||
val count = Reg(UInt(width = log2Up(w+1)))
|
||||
val divby0 = Reg(Bool())
|
||||
val neg_out = Reg(Bool())
|
||||
val r_req = Reg(io.req.bits)
|
||||
|
||||
val divisor = Reg(Bits())
|
||||
val remainder = Reg(Bits(width = 2*w+1))
|
||||
val subtractor = remainder(2*w,w) - divisor
|
||||
|
||||
def sext(x: Bits, cmds: Vec[Bits]) = {
|
||||
val sign = Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) && cmds.contains(io.req.bits.fn)
|
||||
val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign))
|
||||
(Cat(hi, x(w/2-1,0)), sign)
|
||||
}
|
||||
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, AVec(FN_DIV, FN_REM))
|
||||
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, AVec(FN_DIV, FN_REM))
|
||||
|
||||
val r_isRem = isMulFN(r_req.fn, FN_REM) || isMulFN(r_req.fn, FN_REMU)
|
||||
|
||||
when (state === s_neg_inputs) {
|
||||
state := s_busy
|
||||
when (remainder(w-1)) {
|
||||
remainder := -remainder(w-1,0)
|
||||
}
|
||||
when (divisor(w-1)) {
|
||||
divisor := subtractor(w-1,0)
|
||||
}
|
||||
}
|
||||
when (state === s_neg_output) {
|
||||
remainder := -remainder(w-1,0)
|
||||
state := s_done
|
||||
}
|
||||
when (state === s_move_rem) {
|
||||
remainder := remainder(2*w, w+1)
|
||||
state := Mux(neg_out, s_neg_output, s_done)
|
||||
}
|
||||
when (state === s_busy) {
|
||||
when (count === UInt(w)) {
|
||||
state := Mux(r_isRem, s_move_rem, Mux(neg_out && !divby0, s_neg_output, s_done))
|
||||
}
|
||||
count := count + UInt(1)
|
||||
|
||||
val msb = subtractor(w)
|
||||
divby0 := divby0 && !msb
|
||||
remainder := Cat(Mux(msb, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !msb)
|
||||
|
||||
val divisorMSB = Log2(divisor, w)
|
||||
val dividendMSB = Log2(remainder(w-1,0), w)
|
||||
val eOutPos = UInt(w-1, log2Up(2*w)) + divisorMSB - dividendMSB
|
||||
val eOut = count === UInt(0) && eOutPos > 0 && (divisorMSB != UInt(0) || divisor(0))
|
||||
when (Bool(earlyOut) && eOut) {
|
||||
val shift = eOutPos(log2Up(w)-1,0)
|
||||
remainder := remainder(w-1,0) << shift
|
||||
count := shift
|
||||
when (eOutPos(log2Up(w))) {
|
||||
remainder := remainder(w-1,0) << w-1
|
||||
count := w-1
|
||||
}
|
||||
}
|
||||
}
|
||||
when (io.resp.fire() || io.kill) {
|
||||
state := s_ready
|
||||
}
|
||||
when (io.req.fire()) {
|
||||
state := Mux(lhs_sign || rhs_sign, s_neg_inputs, s_busy)
|
||||
count := UInt(0)
|
||||
neg_out := Mux(AVec(FN_REM, FN_REMU).contains(io.req.bits.fn), lhs_sign, lhs_sign != rhs_sign)
|
||||
divby0 := true
|
||||
divisor := rhs_in
|
||||
remainder := lhs_in
|
||||
r_req := io.req.bits
|
||||
}
|
||||
|
||||
io.resp.bits := r_req
|
||||
io.resp.bits.data := Mux(r_req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0))
|
||||
io.resp.valid := state === s_done
|
||||
io.req.ready := state === s_ready
|
||||
}
|
@ -2,6 +2,7 @@ package rocket
|
||||
|
||||
import Chisel._
|
||||
import ALU._
|
||||
import Util._
|
||||
|
||||
class MultiplierReq(implicit conf: RocketConfiguration) extends Bundle {
|
||||
val fn = Bits(width = SZ_ALU_FN)
|
||||
@ -26,68 +27,113 @@ class MultiplierIO(implicit conf: RocketConfiguration) extends Bundle {
|
||||
val resp = Decoupled(new MultiplierResp)
|
||||
}
|
||||
|
||||
class Multiplier(unroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module {
|
||||
class MulDiv(mulUnroll: Int = 1, earlyOut: Boolean = false)(implicit conf: RocketConfiguration) extends Module {
|
||||
val io = new MultiplierIO
|
||||
val w = io.req.bits.in1.getWidth
|
||||
val mulw = (w+mulUnroll-1)/mulUnroll*mulUnroll
|
||||
|
||||
val w0 = io.req.bits.in1.getWidth
|
||||
val w = (w0+1+unroll-1)/unroll*unroll
|
||||
val cycles = w/unroll
|
||||
val s_ready :: s_neg_inputs :: s_busy :: s_move_rem :: s_neg_output :: s_done :: Nil = Enum(UInt(), 6)
|
||||
val state = Reg(init=s_ready)
|
||||
|
||||
val r_val = Reg(init=Bool(false))
|
||||
val r_prod = Reg(Bits(width = w*2))
|
||||
val r_lsb = Reg(Bits())
|
||||
val r_cnt = Reg(UInt(width = log2Up(cycles+1)))
|
||||
val r_req = Reg(new MultiplierReq)
|
||||
val r_lhs = Reg(Bits(width = w0+1))
|
||||
val req = Reg(io.req.bits)
|
||||
val count = Reg(UInt(width = log2Up(w+1)))
|
||||
val neg_out = Reg(Bool())
|
||||
val isMul = Reg(Bool())
|
||||
val isHi = Reg(Bool())
|
||||
val divisor = Reg(Bits(width = w+1)) // div only needs w bits
|
||||
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
|
||||
|
||||
val dw = io.req.bits.dw
|
||||
val fn = io.req.bits.fn
|
||||
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
|
||||
DecodeLogic(io.req.bits.fn, List(X, X, X, X), List(
|
||||
FN_DIV -> List(N, N, Y, Y),
|
||||
FN_REM -> List(N, Y, Y, Y),
|
||||
FN_DIVU -> List(N, N, N, N),
|
||||
FN_REMU -> List(N, Y, N, N),
|
||||
FN_MUL -> List(Y, N, X, X),
|
||||
FN_MULH -> List(Y, Y, Y, Y),
|
||||
FN_MULHU -> List(Y, Y, N, N),
|
||||
FN_MULHSU -> List(Y, Y, Y, N)))
|
||||
|
||||
val lhs_msb = Mux(dw === DW_64, io.req.bits.in1(w0-1), io.req.bits.in1(w0/2-1)).toBool
|
||||
val lhs_sign = (isMulFN(fn, FN_MULH) || isMulFN(fn, FN_MULHSU)) && lhs_msb
|
||||
val lhs_hi = Mux(dw === DW_64, io.req.bits.in1(w0-1,w0/2), Fill(w0/2, lhs_sign))
|
||||
val lhs_in = Cat(lhs_sign, lhs_hi, io.req.bits.in1(w0/2-1,0))
|
||||
def sext(x: Bits, signed: Bool) = {
|
||||
val sign = signed && Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1))
|
||||
val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign))
|
||||
(Cat(hi, x(w/2-1,0)), sign)
|
||||
}
|
||||
val (lhs_in, lhs_sign) = sext(io.req.bits.in1, lhsSigned)
|
||||
val (rhs_in, rhs_sign) = sext(io.req.bits.in2, rhsSigned)
|
||||
|
||||
val rhs_msb = Mux(dw === DW_64, io.req.bits.in2(w0-1), io.req.bits.in2(w0/2-1)).toBool
|
||||
val rhs_sign = isMulFN(fn, FN_MULH) && rhs_msb
|
||||
val rhs_hi = Mux(dw === DW_64, io.req.bits.in2(w0-1,w0/2), Fill(w0/2, rhs_sign))
|
||||
val rhs_in = Cat(Fill(w-w0, rhs_sign), rhs_hi, io.req.bits.in2(w0/2-1,0))
|
||||
val subtractor = remainder(2*w,w) - divisor(w,0)
|
||||
val less = subtractor(w)
|
||||
val negated_remainder = -remainder(w-1,0)
|
||||
|
||||
when (state === s_neg_inputs) {
|
||||
when (remainder(w-1) || isMul) {
|
||||
remainder := negated_remainder
|
||||
}
|
||||
when (divisor(w-1) || isMul) {
|
||||
divisor := subtractor
|
||||
}
|
||||
state := s_busy
|
||||
}
|
||||
|
||||
when (state === s_neg_output) {
|
||||
remainder := negated_remainder
|
||||
state := s_done
|
||||
}
|
||||
when (state === s_move_rem) {
|
||||
remainder := remainder(2*w, w+1)
|
||||
state := Mux(neg_out, s_neg_output, s_done)
|
||||
}
|
||||
when (state === s_busy && isMul) {
|
||||
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
|
||||
val mplier = mulReg(mulw-1,0)
|
||||
val accum = mulReg(2*mulw,mulw).toSInt
|
||||
val mpcand = divisor.toSInt
|
||||
val prod = mplier(mulUnroll-1,0) * mpcand + accum
|
||||
val nextMulReg = Cat(prod, mplier(mulw-1,mulUnroll)).toUInt
|
||||
remainder := Cat(nextMulReg >> w, Bool(false), nextMulReg(w-1,0)).toSInt
|
||||
|
||||
count := count + 1
|
||||
when (count === mulw/mulUnroll-1) {
|
||||
state := Mux(isHi, s_move_rem, s_done)
|
||||
}
|
||||
}
|
||||
when (state === s_busy && !isMul) {
|
||||
when (count === w) {
|
||||
state := Mux(isHi, s_move_rem, Mux(neg_out, s_neg_output, s_done))
|
||||
}
|
||||
count := count + 1
|
||||
|
||||
remainder := Cat(Mux(less, remainder(2*w-1,w), subtractor(w-1,0)), remainder(w-1,0), !less)
|
||||
|
||||
val divisorMSB = Log2(divisor(w-1,0), w)
|
||||
val dividendMSB = Log2(remainder(w-1,0), w)
|
||||
val eOutPos = UInt(w-1) + divisorMSB - dividendMSB
|
||||
val eOutZero = divisorMSB > dividendMSB
|
||||
val eOut = count === 0 && less /* not divby0 */ && (eOutPos > 0 || eOutZero)
|
||||
when (Bool(earlyOut) && eOut) {
|
||||
val shift = Mux(eOutZero, UInt(w-1), eOutPos(log2Up(w)-1,0))
|
||||
remainder := remainder(w-1,0) << shift
|
||||
count := shift
|
||||
}
|
||||
when (count === 0 && !less /* divby0 */) { neg_out := false }
|
||||
}
|
||||
when (io.resp.fire() || io.kill) {
|
||||
state := s_ready
|
||||
}
|
||||
when (io.req.fire()) {
|
||||
r_val := Bool(true)
|
||||
r_cnt := UInt(0, log2Up(cycles+1))
|
||||
r_req := io.req.bits
|
||||
r_lhs := lhs_in
|
||||
r_prod:= rhs_in
|
||||
r_lsb := Bool(false)
|
||||
}
|
||||
.elsewhen (io.resp.fire() || io.kill) {
|
||||
r_val := Bool(false)
|
||||
state := Mux(lhs_sign || rhs_sign && !cmdMul, s_neg_inputs, s_busy)
|
||||
isMul := cmdMul
|
||||
isHi := cmdHi
|
||||
count := 0
|
||||
neg_out := !cmdMul && Mux(cmdHi, lhs_sign, lhs_sign != rhs_sign)
|
||||
divisor := Cat(rhs_sign, rhs_in)
|
||||
remainder := lhs_in
|
||||
req := io.req.bits
|
||||
}
|
||||
|
||||
val eOutDist = (UInt(cycles)-r_cnt)*UInt(unroll)
|
||||
val outShift = Mux(isMulFN(r_req.fn, FN_MUL), UInt(0), Mux(r_req.dw === DW_64, UInt(64), UInt(32)))
|
||||
val shiftDist = Mux(r_cnt === UInt(cycles), outShift, eOutDist)
|
||||
val eOutMask = (UInt(1) << eOutDist) - UInt(1)
|
||||
val eOut = r_cnt != UInt(0) && Bool(earlyOut) && !((r_prod(w-1,0) ^ r_lsb.toSInt) & eOutMask).orR
|
||||
val shift = r_prod.toSInt >> shiftDist
|
||||
|
||||
val sum = r_prod(2*w-1,w).toSInt + r_prod(unroll-1,0).toSInt * r_lhs.toSInt + Mux(r_lsb.toBool, r_lhs.toSInt, SInt(0))
|
||||
when (r_val && (r_cnt != UInt(cycles))) {
|
||||
r_lsb := r_prod(unroll-1)
|
||||
r_prod := Cat(sum, r_prod(w-1,unroll)).toSInt
|
||||
r_cnt := r_cnt + UInt(1)
|
||||
when (eOut) {
|
||||
r_prod := shift
|
||||
r_cnt := UInt(cycles)
|
||||
}
|
||||
}
|
||||
|
||||
val out32 = Cat(Fill(w0/2, shift(w0/2-1)), shift(w0/2-1,0))
|
||||
val out64 = shift(w0-1,0)
|
||||
|
||||
io.req.ready := !r_val
|
||||
io.resp.bits := r_req
|
||||
io.resp.bits.data := Mux(r_req.dw === DW_64, out64, out32)
|
||||
io.resp.valid := r_val && (r_cnt === UInt(cycles))
|
||||
io.resp.bits := req
|
||||
io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0))
|
||||
io.resp.valid := state === s_done
|
||||
io.req.ready := state === s_ready
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user