Update RAS speculatively from fetch stage
This commit is contained in:
parent
3b2c15b648
commit
f2d4cb8152
@ -86,6 +86,15 @@ class BHT(nbht: Int)(implicit val p: Parameters) extends HasCoreParameters {
|
|||||||
val history = Reg(UInt(width = nbhtbits))
|
val history = Reg(UInt(width = nbhtbits))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
object CFIType {
|
||||||
|
def SZ = 2
|
||||||
|
def apply() = UInt(width = SZ)
|
||||||
|
def branch = 0.U
|
||||||
|
def jump = 1.U
|
||||||
|
def call = 2.U
|
||||||
|
def ret = 3.U
|
||||||
|
}
|
||||||
|
|
||||||
// BTB update occurs during branch resolution (and only on a mispredict).
|
// BTB update occurs during branch resolution (and only on a mispredict).
|
||||||
// - "pc" is what future fetch PCs will tag match against.
|
// - "pc" is what future fetch PCs will tag match against.
|
||||||
// - "br_pc" is the PC of the branch instruction.
|
// - "br_pc" is the PC of the branch instruction.
|
||||||
@ -95,9 +104,8 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
|||||||
val target = UInt(width = vaddrBits)
|
val target = UInt(width = vaddrBits)
|
||||||
val taken = Bool()
|
val taken = Bool()
|
||||||
val isValid = Bool()
|
val isValid = Bool()
|
||||||
val isJump = Bool()
|
|
||||||
val isReturn = Bool()
|
|
||||||
val br_pc = UInt(width = vaddrBits)
|
val br_pc = UInt(width = vaddrBits)
|
||||||
|
val cfiType = CFIType()
|
||||||
}
|
}
|
||||||
|
|
||||||
// BHT update occurs during branch resolution on all conditional branches.
|
// BHT update occurs during branch resolution on all conditional branches.
|
||||||
@ -110,8 +118,7 @@ class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||||
val isCall = Bool()
|
val cfiType = CFIType()
|
||||||
val isReturn = Bool()
|
|
||||||
val returnAddr = UInt(width = vaddrBits)
|
val returnAddr = UInt(width = vaddrBits)
|
||||||
val prediction = Valid(new BTBResp)
|
val prediction = Valid(new BTBResp)
|
||||||
}
|
}
|
||||||
@ -121,6 +128,7 @@ class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
|||||||
// - "mask" provides a mask of valid instructions (instructions are
|
// - "mask" provides a mask of valid instructions (instructions are
|
||||||
// masked off by the predicted taken branch from the BTB).
|
// masked off by the predicted taken branch from the BTB).
|
||||||
class BTBResp(implicit p: Parameters) extends BtbBundle()(p) {
|
class BTBResp(implicit p: Parameters) extends BtbBundle()(p) {
|
||||||
|
val cfiType = CFIType()
|
||||||
val taken = Bool()
|
val taken = Bool()
|
||||||
val mask = Bits(width = fetchWidth)
|
val mask = Bits(width = fetchWidth)
|
||||||
val bridx = Bits(width = log2Up(fetchWidth))
|
val bridx = Bits(width = log2Up(fetchWidth))
|
||||||
@ -154,8 +162,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
val pageValid = Reg(init = UInt(0, nPages))
|
val pageValid = Reg(init = UInt(0, nPages))
|
||||||
|
|
||||||
val isValid = Reg(init = UInt(0, entries))
|
val isValid = Reg(init = UInt(0, entries))
|
||||||
val isReturn = Reg(UInt(width = entries))
|
val cfiType = Reg(Vec(entries, CFIType()))
|
||||||
val isJump = Reg(UInt(width = entries))
|
|
||||||
val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
|
val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
|
||||||
|
|
||||||
private def page(addr: UInt) = addr >> matchBits
|
private def page(addr: UInt) = addr >> matchBits
|
||||||
@ -210,9 +217,8 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes))
|
tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes))
|
||||||
idxPages(waddr) := idxPageUpdate +& 1 // the +1 corresponds to the <<1 on io.resp.valid
|
idxPages(waddr) := idxPageUpdate +& 1 // the +1 corresponds to the <<1 on io.resp.valid
|
||||||
tgtPages(waddr) := tgtPageUpdate
|
tgtPages(waddr) := tgtPageUpdate
|
||||||
|
cfiType(waddr) := r_btb_update.bits.cfiType
|
||||||
isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask)
|
isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask)
|
||||||
isReturn := Mux(r_btb_update.bits.isReturn, isReturn | mask, isReturn & ~mask)
|
|
||||||
isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask)
|
|
||||||
if (fetchWidth > 1)
|
if (fetchWidth > 1)
|
||||||
brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes)
|
brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes)
|
||||||
|
|
||||||
@ -236,6 +242,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
io.resp.bits.entry := OHToUInt(idxHit)
|
io.resp.bits.entry := OHToUInt(idxHit)
|
||||||
io.resp.bits.bridx := (if (fetchWidth > 1) Mux1H(idxHit, brIdx) else UInt(0))
|
io.resp.bits.bridx := (if (fetchWidth > 1) Mux1H(idxHit, brIdx) else UInt(0))
|
||||||
io.resp.bits.mask := Cat((UInt(1) << ~Mux(io.resp.bits.taken, ~io.resp.bits.bridx, UInt(0)))-1, UInt(1))
|
io.resp.bits.mask := Cat((UInt(1) << ~Mux(io.resp.bits.taken, ~io.resp.bits.bridx, UInt(0)))-1, UInt(1))
|
||||||
|
io.resp.bits.cfiType := Mux1H(idxHit, cfiType)
|
||||||
|
|
||||||
// if multiple entries for same PC land in BTB, zap them
|
// if multiple entries for same PC land in BTB, zap them
|
||||||
when (PopCountAtLeast(idxHit, 2)) {
|
when (PopCountAtLeast(idxHit, 2)) {
|
||||||
@ -244,7 +251,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
|
|
||||||
if (nBHT > 0) {
|
if (nBHT > 0) {
|
||||||
val bht = new BHT(nBHT)
|
val bht = new BHT(nBHT)
|
||||||
val isBranch = !(idxHit & isJump).orR
|
val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR
|
||||||
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
|
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
|
||||||
val update_btb_hit = io.bht_update.bits.prediction.valid
|
val update_btb_hit = io.bht_update.bits.prediction.valid
|
||||||
when (io.bht_update.valid && update_btb_hit) {
|
when (io.bht_update.valid && update_btb_hit) {
|
||||||
@ -256,17 +263,14 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
|
|
||||||
if (nRAS > 0) {
|
if (nRAS > 0) {
|
||||||
val ras = new RAS(nRAS)
|
val ras = new RAS(nRAS)
|
||||||
val doPeek = (idxHit & isReturn).orR
|
val doPeek = (idxHit & cfiType.map(_ === CFIType.ret).asUInt).orR
|
||||||
when (!ras.isEmpty && doPeek) {
|
when (!ras.isEmpty && doPeek) {
|
||||||
io.resp.bits.target := ras.peek
|
io.resp.bits.target := ras.peek
|
||||||
}
|
}
|
||||||
when (io.ras_update.valid) {
|
when (io.ras_update.valid) {
|
||||||
when (io.ras_update.bits.isCall) {
|
when (io.ras_update.bits.cfiType === CFIType.call) {
|
||||||
ras.push(io.ras_update.bits.returnAddr)
|
ras.push(io.ras_update.bits.returnAddr)
|
||||||
when (doPeek) {
|
}.elsewhen (io.ras_update.bits.cfiType === CFIType.ret && io.ras_update.bits.prediction.valid) {
|
||||||
io.resp.bits.target := io.ras_update.bits.returnAddr
|
|
||||||
}
|
|
||||||
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
|
|
||||||
ras.pop()
|
ras.pop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,7 +84,9 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer)
|
|||||||
val s2_speculative = Reg(init=Bool(false))
|
val s2_speculative = Reg(init=Bool(false))
|
||||||
val s2_cacheable = Reg(init=Bool(false))
|
val s2_cacheable = Reg(init=Bool(false))
|
||||||
|
|
||||||
val ntpc = ~(~s1_pc | (coreInstBytes*fetchWidth-1)) + UInt(coreInstBytes*fetchWidth)
|
val fetchBytes = coreInstBytes * fetchWidth
|
||||||
|
val s1_base_pc = ~(~s1_pc | (fetchBytes - 1))
|
||||||
|
val ntpc = s1_base_pc + fetchBytes.U
|
||||||
val predicted_npc = Wire(init = ntpc)
|
val predicted_npc = Wire(init = ntpc)
|
||||||
val predicted_taken = Wire(init = Bool(false))
|
val predicted_taken = Wire(init = Bool(false))
|
||||||
|
|
||||||
@ -129,6 +131,14 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer)
|
|||||||
predicted_npc := btb.io.resp.bits.target.sextTo(vaddrBitsExtended)
|
predicted_npc := btb.io.resp.bits.target.sextTo(vaddrBitsExtended)
|
||||||
predicted_taken := Bool(true)
|
predicted_taken := Bool(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// push RAS speculatively
|
||||||
|
btb.io.ras_update.valid := btb.io.req.valid && btb.io.resp.valid && btb.io.resp.bits.cfiType.isOneOf(CFIType.call, CFIType.ret)
|
||||||
|
val returnAddrLSBs = btb.io.resp.bits.bridx +& 1
|
||||||
|
btb.io.ras_update.bits.returnAddr :=
|
||||||
|
Mux(returnAddrLSBs(log2Ceil(fetchWidth)), ntpc, s1_base_pc | ((returnAddrLSBs << log2Ceil(coreInstBytes)) & (fetchBytes - 1)))
|
||||||
|
btb.io.ras_update.bits.cfiType := btb.io.resp.bits.cfiType
|
||||||
|
btb.io.ras_update.bits.prediction.valid := true
|
||||||
}
|
}
|
||||||
|
|
||||||
io.ptw <> tlb.io.ptw
|
io.ptw <> tlb.io.ptw
|
||||||
|
@ -587,8 +587,10 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p)
|
|||||||
|
|
||||||
io.imem.btb_update.valid := (mem_reg_replay && mem_reg_btb_hit) || (mem_reg_valid && !take_pc_wb && (((mem_cfi_taken || !mem_cfi) && mem_wrong_npc) || (Bool(fastJAL) && mem_ctrl.jal && !mem_reg_btb_hit)))
|
io.imem.btb_update.valid := (mem_reg_replay && mem_reg_btb_hit) || (mem_reg_valid && !take_pc_wb && (((mem_cfi_taken || !mem_cfi) && mem_wrong_npc) || (Bool(fastJAL) && mem_ctrl.jal && !mem_reg_btb_hit)))
|
||||||
io.imem.btb_update.bits.isValid := !mem_reg_replay && mem_cfi
|
io.imem.btb_update.bits.isValid := !mem_reg_replay && mem_cfi
|
||||||
io.imem.btb_update.bits.isJump := mem_ctrl.jal || mem_ctrl.jalr
|
io.imem.btb_update.bits.cfiType :=
|
||||||
io.imem.btb_update.bits.isReturn := mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01")
|
Mux(mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01"), CFIType.ret,
|
||||||
|
Mux(mem_ctrl.jal || mem_ctrl.jalr, Mux(mem_waddr(0), CFIType.call, CFIType.jump),
|
||||||
|
CFIType.branch))
|
||||||
io.imem.btb_update.bits.target := io.imem.req.bits.pc
|
io.imem.btb_update.bits.target := io.imem.req.bits.pc
|
||||||
io.imem.btb_update.bits.br_pc := (if (usingCompressed) mem_reg_pc + Mux(mem_reg_rvc, UInt(0), UInt(2)) else mem_reg_pc)
|
io.imem.btb_update.bits.br_pc := (if (usingCompressed) mem_reg_pc + Mux(mem_reg_rvc, UInt(0), UInt(2)) else mem_reg_pc)
|
||||||
io.imem.btb_update.bits.pc := ~(~io.imem.btb_update.bits.br_pc | (coreInstBytes*fetchWidth-1))
|
io.imem.btb_update.bits.pc := ~(~io.imem.btb_update.bits.br_pc | (coreInstBytes*fetchWidth-1))
|
||||||
@ -601,12 +603,6 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p)
|
|||||||
io.imem.bht_update.bits.mispredict := mem_wrong_npc
|
io.imem.bht_update.bits.mispredict := mem_wrong_npc
|
||||||
io.imem.bht_update.bits.prediction := io.imem.btb_update.bits.prediction
|
io.imem.bht_update.bits.prediction := io.imem.btb_update.bits.prediction
|
||||||
|
|
||||||
io.imem.ras_update.valid := mem_reg_valid && !take_pc_wb
|
|
||||||
io.imem.ras_update.bits.returnAddr := mem_int_wdata
|
|
||||||
io.imem.ras_update.bits.isCall := io.imem.btb_update.bits.isJump && mem_waddr(0)
|
|
||||||
io.imem.ras_update.bits.isReturn := io.imem.btb_update.bits.isReturn
|
|
||||||
io.imem.ras_update.bits.prediction := io.imem.btb_update.bits.prediction
|
|
||||||
|
|
||||||
io.fpu.valid := !ctrl_killd && id_ctrl.fp
|
io.fpu.valid := !ctrl_killd && id_ctrl.fp
|
||||||
io.fpu.killx := ctrl_killx
|
io.fpu.killx := ctrl_killx
|
||||||
io.fpu.killm := killm_common
|
io.fpu.killm := killm_common
|
||||||
|
Loading…
Reference in New Issue
Block a user