Merge pull request #1228 from freechipsproject/no-mul
Teach MulDiv to do div-only
This commit is contained in:
commit
36cba65e60
@ -33,8 +33,8 @@ object ALU
|
|||||||
|
|
||||||
def FN_MUL = FN_ADD
|
def FN_MUL = FN_ADD
|
||||||
def FN_MULH = FN_SL
|
def FN_MULH = FN_SL
|
||||||
def FN_MULHSU = FN_SLT
|
def FN_MULHSU = FN_SEQ
|
||||||
def FN_MULHU = FN_SLTU
|
def FN_MULHU = FN_SNE
|
||||||
|
|
||||||
def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0)
|
def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0)
|
||||||
def isSub(cmd: UInt) = cmd(3)
|
def isSub(cmd: UInt) = cmd(3)
|
||||||
|
@ -196,7 +196,7 @@ class CSRFileIO(implicit p: Parameters) extends CoreBundle
|
|||||||
val retire = UInt(INPUT, log2Up(1+retireWidth))
|
val retire = UInt(INPUT, log2Up(1+retireWidth))
|
||||||
val cause = UInt(INPUT, xLen)
|
val cause = UInt(INPUT, xLen)
|
||||||
val pc = UInt(INPUT, vaddrBitsExtended)
|
val pc = UInt(INPUT, vaddrBitsExtended)
|
||||||
val badaddr = UInt(INPUT, vaddrBitsExtended)
|
val tval = UInt(INPUT, vaddrBitsExtended)
|
||||||
val time = UInt(OUTPUT, xLen)
|
val time = UInt(OUTPUT, xLen)
|
||||||
val fcsr_rm = Bits(OUTPUT, FPConstants.RM_SZ)
|
val fcsr_rm = Bits(OUTPUT, FPConstants.RM_SZ)
|
||||||
val fcsr_flags = Valid(Bits(width = FPConstants.FLAGS_SZ)).flip
|
val fcsr_flags = Valid(Bits(width = FPConstants.FLAGS_SZ)).flip
|
||||||
@ -527,12 +527,6 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param
|
|||||||
assert(!reg_singleStepped || io.retire === UInt(0))
|
assert(!reg_singleStepped || io.retire === UInt(0))
|
||||||
|
|
||||||
val epc = formEPC(io.pc)
|
val epc = formEPC(io.pc)
|
||||||
val write_badaddr = exception && cause.isOneOf(Causes.illegal_instruction, Causes.breakpoint,
|
|
||||||
Causes.misaligned_load, Causes.misaligned_store,
|
|
||||||
Causes.load_access, Causes.store_access, Causes.fetch_access,
|
|
||||||
Causes.load_page_fault, Causes.store_page_fault, Causes.fetch_page_fault)
|
|
||||||
val badaddr_value = Mux(write_badaddr, io.badaddr, 0.U)
|
|
||||||
|
|
||||||
val noCause :: mCause :: hCause :: sCause :: uCause :: Nil = Enum(5)
|
val noCause :: mCause :: hCause :: sCause :: uCause :: Nil = Enum(5)
|
||||||
val xcause_dest = Wire(init = noCause)
|
val xcause_dest = Wire(init = noCause)
|
||||||
|
|
||||||
@ -549,7 +543,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param
|
|||||||
reg_sepc := epc
|
reg_sepc := epc
|
||||||
reg_scause := cause
|
reg_scause := cause
|
||||||
xcause_dest := sCause
|
xcause_dest := sCause
|
||||||
reg_sbadaddr := badaddr_value
|
reg_sbadaddr := io.tval
|
||||||
reg_mstatus.spie := reg_mstatus.sie
|
reg_mstatus.spie := reg_mstatus.sie
|
||||||
reg_mstatus.spp := reg_mstatus.prv
|
reg_mstatus.spp := reg_mstatus.prv
|
||||||
reg_mstatus.sie := false
|
reg_mstatus.sie := false
|
||||||
@ -558,7 +552,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param
|
|||||||
reg_mepc := epc
|
reg_mepc := epc
|
||||||
reg_mcause := cause
|
reg_mcause := cause
|
||||||
xcause_dest := mCause
|
xcause_dest := mCause
|
||||||
reg_mbadaddr := badaddr_value
|
reg_mbadaddr := io.tval
|
||||||
reg_mstatus.mpie := reg_mstatus.mie
|
reg_mstatus.mpie := reg_mstatus.mie
|
||||||
reg_mstatus.mpp := trimPrivilege(reg_mstatus.prv)
|
reg_mstatus.mpp := trimPrivilege(reg_mstatus.prv)
|
||||||
reg_mstatus.mie := false
|
reg_mstatus.mie := false
|
||||||
@ -808,7 +802,7 @@ class CSRFile(perfEventSets: EventSets = new EventSets(Seq()))(implicit p: Param
|
|||||||
t.priv := Cat(reg_debug, reg_mstatus.prv)
|
t.priv := Cat(reg_debug, reg_mstatus.prv)
|
||||||
t.cause := cause
|
t.cause := cause
|
||||||
t.interrupt := cause(xLen-1)
|
t.interrupt := cause(xLen-1)
|
||||||
t.tval := badaddr_value
|
t.tval := io.tval
|
||||||
}
|
}
|
||||||
|
|
||||||
def chooseInterrupt(masksIn: Seq[UInt]): (Bool, UInt) = {
|
def chooseInterrupt(masksIn: Seq[UInt]): (Bool, UInt) = {
|
||||||
|
@ -39,30 +39,35 @@ case class MulDivParams(
|
|||||||
class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
||||||
val io = new MultiplierIO(width, log2Up(nXpr))
|
val io = new MultiplierIO(width, log2Up(nXpr))
|
||||||
val w = io.req.bits.in1.getWidth
|
val w = io.req.bits.in1.getWidth
|
||||||
val mulw = (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll
|
val mulw = if (cfg.mulUnroll == 0) w else (w + cfg.mulUnroll - 1) / cfg.mulUnroll * cfg.mulUnroll
|
||||||
val fastMulW = w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0
|
val fastMulW = if (cfg.mulUnroll == 0) false else w/2 > cfg.mulUnroll && w % (2*cfg.mulUnroll) == 0
|
||||||
|
|
||||||
val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8)
|
val s_ready :: s_neg_inputs :: s_mul :: s_div :: s_dummy :: s_neg_output :: s_done_mul :: s_done_div :: Nil = Enum(UInt(), 8)
|
||||||
val state = Reg(init=s_ready)
|
val state = Reg(init=s_ready)
|
||||||
|
|
||||||
val req = Reg(io.req.bits)
|
val req = Reg(io.req.bits)
|
||||||
val count = Reg(UInt(width = log2Ceil((w/cfg.divUnroll + 1) max (w/cfg.mulUnroll))))
|
val count = Reg(UInt(width = log2Ceil(
|
||||||
|
((cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++
|
||||||
|
(cfg.mulUnroll != 0).option(mulw/cfg.mulUnroll)).reduce(_ max _))))
|
||||||
val neg_out = Reg(Bool())
|
val neg_out = Reg(Bool())
|
||||||
val isHi = Reg(Bool())
|
val isHi = Reg(Bool())
|
||||||
val resHi = Reg(Bool())
|
val resHi = Reg(Bool())
|
||||||
val divisor = Reg(Bits(width = w+1)) // div only needs w bits
|
val divisor = Reg(Bits(width = w+1)) // div only needs w bits
|
||||||
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
|
val remainder = Reg(Bits(width = 2*mulw+2)) // div only needs 2*w+1 bits
|
||||||
|
|
||||||
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
|
val mulDecode = List(
|
||||||
DecodeLogic(io.req.bits.fn, List(X, X, X, X), List(
|
|
||||||
FN_DIV -> List(N, N, Y, Y),
|
|
||||||
FN_REM -> List(N, Y, Y, Y),
|
|
||||||
FN_DIVU -> List(N, N, N, N),
|
|
||||||
FN_REMU -> List(N, Y, N, N),
|
|
||||||
FN_MUL -> List(Y, N, X, X),
|
FN_MUL -> List(Y, N, X, X),
|
||||||
FN_MULH -> List(Y, Y, Y, Y),
|
FN_MULH -> List(Y, Y, Y, Y),
|
||||||
FN_MULHU -> List(Y, Y, N, N),
|
FN_MULHU -> List(Y, Y, N, N),
|
||||||
FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool)
|
FN_MULHSU -> List(Y, Y, Y, N))
|
||||||
|
val divDecode = List(
|
||||||
|
FN_DIV -> List(N, N, Y, Y),
|
||||||
|
FN_REM -> List(N, Y, Y, Y),
|
||||||
|
FN_DIVU -> List(N, N, N, N),
|
||||||
|
FN_REMU -> List(N, Y, N, N))
|
||||||
|
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
|
||||||
|
DecodeLogic(io.req.bits.fn, List(X, X, X, X),
|
||||||
|
(if (cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.toBool)
|
||||||
|
|
||||||
require(w == 32 || w == 64)
|
require(w == 32 || w == 64)
|
||||||
def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32
|
def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32
|
||||||
@ -79,7 +84,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
|||||||
val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0))
|
val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0))
|
||||||
val negated_remainder = -result
|
val negated_remainder = -result
|
||||||
|
|
||||||
when (state === s_neg_inputs) {
|
if (cfg.divUnroll != 0) when (state === s_neg_inputs) {
|
||||||
when (remainder(w-1)) {
|
when (remainder(w-1)) {
|
||||||
remainder := negated_remainder
|
remainder := negated_remainder
|
||||||
}
|
}
|
||||||
@ -88,12 +93,12 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
|||||||
}
|
}
|
||||||
state := s_div
|
state := s_div
|
||||||
}
|
}
|
||||||
when (state === s_neg_output) {
|
if (cfg.divUnroll != 0) when (state === s_neg_output) {
|
||||||
remainder := negated_remainder
|
remainder := negated_remainder
|
||||||
state := s_done_div
|
state := s_done_div
|
||||||
resHi := false
|
resHi := false
|
||||||
}
|
}
|
||||||
when (state === s_mul) {
|
if (cfg.mulUnroll != 0) when (state === s_mul) {
|
||||||
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
|
val mulReg = Cat(remainder(2*mulw+1,w+1),remainder(w-1,0))
|
||||||
val mplierSign = remainder(w)
|
val mplierSign = remainder(w)
|
||||||
val mplier = mulReg(mulw-1,0)
|
val mplier = mulReg(mulw-1,0)
|
||||||
@ -116,7 +121,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
|||||||
resHi := isHi
|
resHi := isHi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (state === s_div) {
|
if (cfg.divUnroll != 0) when (state === s_div) {
|
||||||
val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
|
val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
|
||||||
// the special case for iteration 0 is to save HW, not for correctness
|
// the special case for iteration 0 is to save HW, not for correctness
|
||||||
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
|
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
|
||||||
@ -156,7 +161,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32) extends Module {
|
|||||||
state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div))
|
state := Mux(cmdMul, s_mul, Mux(lhs_sign || rhs_sign, s_neg_inputs, s_div))
|
||||||
isHi := cmdHi
|
isHi := cmdHi
|
||||||
resHi := false
|
resHi := false
|
||||||
count := Mux[UInt](Bool(fastMulW) && cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0)
|
count := (if (fastMulW) Mux[UInt](cmdMul && halfWidth(io.req.bits), w/cfg.mulUnroll/2, 0) else 0)
|
||||||
neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign)
|
neg_out := Mux(cmdHi, lhs_sign, lhs_sign =/= rhs_sign)
|
||||||
divisor := Cat(rhs_sign, rhs_in)
|
divisor := Cat(rhs_sign, rhs_in)
|
||||||
remainder := lhs_in
|
remainder := lhs_in
|
||||||
|
@ -545,7 +545,11 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p)
|
|||||||
csr.io.fcsr_flags := io.fpu.fcsr_flags
|
csr.io.fcsr_flags := io.fpu.fcsr_flags
|
||||||
csr.io.rocc_interrupt := io.rocc.interrupt
|
csr.io.rocc_interrupt := io.rocc.interrupt
|
||||||
csr.io.pc := wb_reg_pc
|
csr.io.pc := wb_reg_pc
|
||||||
csr.io.badaddr := encodeVirtualAddress(wb_reg_wdata, wb_reg_wdata)
|
val tval_valid = wb_xcpt && wb_cause.isOneOf(Causes.illegal_instruction, Causes.breakpoint,
|
||||||
|
Causes.misaligned_load, Causes.misaligned_store,
|
||||||
|
Causes.load_access, Causes.store_access, Causes.fetch_access,
|
||||||
|
Causes.load_page_fault, Causes.store_page_fault, Causes.fetch_page_fault)
|
||||||
|
csr.io.tval := Mux(tval_valid, encodeVirtualAddress(wb_reg_wdata, wb_reg_wdata), 0.U)
|
||||||
io.ptw.ptbr := csr.io.ptbr
|
io.ptw.ptbr := csr.io.ptbr
|
||||||
io.ptw.status := csr.io.status
|
io.ptw.status := csr.io.status
|
||||||
io.ptw.pmp := csr.io.pmp
|
io.ptw.pmp := csr.io.pmp
|
||||||
|
Loading…
Reference in New Issue
Block a user