Improve frontend branch prediction
- Put correctness responsibility on the Frontend, not the IBuf, for improved separation of concerns. The Frontend must detect the case where the BTB predicts a taken branch in the middle of an instruction.
- Pass BTB information down the pipeline unconditionally, fixing a case that corrupts the branch history when the BTB misses and the instruction is misaligned.
- Remove the jumpInFrontend option; it is now unconditional.
- Default to one-bit counters in the BHT. For tiny BHTs like these, it is more resource-efficient to have a larger index space than to have hysteresis.
This commit is contained in:
@ -10,12 +10,18 @@ import freechips.rocketchip.coreplex.CacheBlockBytes
|
||||
import freechips.rocketchip.tile.HasCoreParameters
|
||||
import freechips.rocketchip.util._
|
||||
|
||||
/** Sizing parameters for the branch history table (BHT).
  *
  * - nEntries:      number of counters held in the BHT's table memory.
  * - counterLength: width in bits of each counter; the BHT's table update
  *                  logic only provides cases for widths 1 and 2.
  * - historyLength: width in bits of the global history register.
  * - historyBits:   number of bits the history is hashed down to before
  *                  being XORed into the table index (no hashing is applied
  *                  when it equals historyLength).
  */
case class BHTParams(
|
||||
nEntries: Int = 512,
|
||||
counterLength: Int = 1,
|
||||
historyLength: Int = 8,
|
||||
historyBits: Int = 3)
|
||||
|
||||
case class BTBParams(
|
||||
nEntries: Int = 28,
|
||||
nMatchBits: Int = 14,
|
||||
nPages: Int = 6,
|
||||
nRAS: Int = 6,
|
||||
nBHT: Int = 256,
|
||||
bhtParams: Option[BHTParams] = Some(BHTParams()),
|
||||
updatesOutOfOrder: Boolean = false)
|
||||
|
||||
trait HasBtbParameters extends HasCoreParameters {
|
||||
@ -51,9 +57,10 @@ class RAS(nras: Int) {
|
||||
}
|
||||
|
||||
class BHTResp(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
val history = UInt(width = log2Up(btbParams.nBHT).max(1))
|
||||
val value = UInt(width = 2)
|
||||
val taken = Bool()
|
||||
val history = UInt(width = btbParams.bhtParams.map(_.historyLength).getOrElse(1))
|
||||
val value = UInt(width = btbParams.bhtParams.map(_.counterLength).getOrElse(1))
|
||||
def taken = value(0)
|
||||
def strongly_taken = value === 1
|
||||
}
|
||||
|
||||
// BHT contains a table of counters (BHTParams.counterLength bits wide, 1 by default) and a global history register.
|
||||
@ -65,32 +72,43 @@ class BHTResp(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
// - each counter corresponds with the address of the fetch packet ("fetch pc").
|
||||
// - updated when a branch resolves (and BTB was a hit for that branch).
|
||||
// The updating branch must provide its "fetch pc".
|
||||
class BHT(nbht: Int)(implicit val p: Parameters) extends HasCoreParameters {
|
||||
val nbhtbits = log2Up(nbht)
|
||||
class BHT(params: BHTParams)(implicit val p: Parameters) extends HasCoreParameters {
|
||||
def index(addr: UInt, history: UInt) = {
|
||||
def hashHistory(hist: UInt) = if (params.historyLength == params.historyBits) hist else {
|
||||
val k = math.sqrt(3)/2
|
||||
val i = BigDecimal(k * math.pow(2, params.historyLength)).toBigInt
|
||||
(i.U * hist)(params.historyLength-1, params.historyLength-params.historyBits)
|
||||
}
|
||||
def hashAddr(addr: UInt) = {
|
||||
val hi = addr >> log2Ceil(fetchBytes)
|
||||
hi(log2Ceil(params.nEntries)-1, 0) ^ (hi >> log2Ceil(params.nEntries))(1, 0)
|
||||
}
|
||||
hashAddr(addr) ^ (hashHistory(history) << (log2Up(params.nEntries) - params.historyBits))
|
||||
}
|
||||
def get(addr: UInt): BHTResp = {
|
||||
val res = Wire(new BHTResp)
|
||||
val index = addr(nbhtbits+log2Up(coreInstBytes)-1, log2Up(coreInstBytes)) ^ history
|
||||
res.value := table(index)
|
||||
res.value := table(index(addr, history))
|
||||
res.history := history
|
||||
res.taken := res.value(0)
|
||||
res
|
||||
}
|
||||
def updateTable(addr: UInt, d: BHTResp, taken: Bool): Unit = {
|
||||
val index = addr(nbhtbits+log2Up(coreInstBytes)-1, log2Up(coreInstBytes)) ^ d.history
|
||||
table(index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken))
|
||||
table(index(addr, d.history)) := (params.counterLength match {
|
||||
case 1 => taken
|
||||
case 2 => Cat(taken ^ d.value(0), d.value === 1 || d.value(1) && taken)
|
||||
})
|
||||
}
|
||||
def resetHistory(d: BHTResp): Unit = {
|
||||
history := d.history
|
||||
}
|
||||
def updateHistory(addr: UInt, d: BHTResp, taken: Bool): Unit = {
|
||||
history := Cat(taken, d.history(nbhtbits-1,1))
|
||||
history := Cat(taken, d.history >> 1)
|
||||
}
|
||||
def advanceHistory(taken: Bool): Unit = {
|
||||
history := Cat(taken, history(nbhtbits-1,1))
|
||||
history := Cat(taken, history >> 1)
|
||||
}
|
||||
|
||||
private val table = Mem(nbht, UInt(width = 2))
|
||||
val history = Reg(UInt(width = nbhtbits))
|
||||
private val table = Mem(params.nEntries, UInt(width = params.counterLength))
|
||||
val history = Reg(UInt(width = params.historyLength))
|
||||
}
|
||||
|
||||
object CFIType {
|
||||
@ -106,7 +124,7 @@ object CFIType {
|
||||
// - "pc" is what future fetch PCs will tag match against.
|
||||
// - "br_pc" is the PC of the branch instruction.
|
||||
class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
val prediction = Valid(new BTBResp)
|
||||
val prediction = new BTBResp
|
||||
val pc = UInt(width = vaddrBits)
|
||||
val target = UInt(width = vaddrBits)
|
||||
val taken = Bool()
|
||||
@ -118,8 +136,9 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
// BHT update occurs during branch resolution on all conditional branches.
|
||||
// - "pc" is what future fetch PCs will tag match against.
|
||||
class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
val prediction = Valid(new BTBResp)
|
||||
val prediction = new BHTResp
|
||||
val pc = UInt(width = vaddrBits)
|
||||
val branch = Bool()
|
||||
val taken = Bool()
|
||||
val mispredict = Bool()
|
||||
}
|
||||
@ -127,7 +146,6 @@ class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) {
|
||||
val cfiType = CFIType()
|
||||
val returnAddr = UInt(width = vaddrBits)
|
||||
val prediction = Valid(new BTBResp)
|
||||
}
|
||||
|
||||
// - "bridx" is the low-order PC bits of the predicted branch (after
|
||||
@ -161,6 +179,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
val bht_advance = Valid(new BTBResp).flip
|
||||
val ras_update = Valid(new RASUpdate).flip
|
||||
val ras_head = Valid(UInt(width = vaddrBits))
|
||||
val flush = Bool().asInput
|
||||
}
|
||||
|
||||
val idxs = Reg(Vec(entries, UInt(width=matchBits - log2Up(coreInstBytes))))
|
||||
@ -195,7 +214,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
if (updatesOutOfOrder) {
|
||||
val updateHits = (pageHit << 1)(Mux1H(idxMatch(r_btb_update.bits.pc), idxPages))
|
||||
(updateHits.orR, OHToUInt(updateHits))
|
||||
} else (r_btb_update.bits.prediction.valid && r_btb_update.bits.prediction.bits.entry < entries, r_btb_update.bits.prediction.bits.entry)
|
||||
} else (r_btb_update.bits.prediction.entry < entries, r_btb_update.bits.prediction.entry)
|
||||
|
||||
val useUpdatePageHit = updatePageHit.orR
|
||||
val usePageHit = pageHit.orR
|
||||
@ -220,7 +239,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
|
||||
val repl = new PseudoLRU(entries)
|
||||
val waddr = Mux(updateHit, updateHitAddr, repl.replace)
|
||||
val r_resp = Pipe(io.req.valid && io.resp.valid, io.resp.bits)
|
||||
val r_resp = Pipe(io.resp)
|
||||
when (r_resp.valid && r_resp.bits.taken || r_btb_update.valid) {
|
||||
repl.access(Mux(r_btb_update.valid, waddr, r_resp.bits.entry))
|
||||
}
|
||||
@ -262,24 +281,25 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
when (PopCountAtLeast(idxHit, 2)) {
|
||||
isValid := isValid & ~idxHit
|
||||
}
|
||||
when (io.flush) {
|
||||
isValid := 0
|
||||
}
|
||||
|
||||
if (btbParams.nBHT > 0) {
|
||||
val bht = new BHT(btbParams.nBHT)
|
||||
if (btbParams.bhtParams.nonEmpty) {
|
||||
val bht = new BHT(btbParams.bhtParams.get)
|
||||
val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR
|
||||
val res = bht.get(io.req.bits.addr)
|
||||
when (io.req.valid && io.resp.valid && isBranch) {
|
||||
bht.advanceHistory(res.taken)
|
||||
}
|
||||
when (io.bht_advance.valid) {
|
||||
bht.advanceHistory(io.bht_advance.bits.bht.taken)
|
||||
}
|
||||
when (io.btb_update.valid) {
|
||||
bht.resetHistory(io.btb_update.bits.prediction.bits.bht)
|
||||
}
|
||||
when (io.bht_update.valid) {
|
||||
bht.updateTable(io.bht_update.bits.pc, io.bht_update.bits.prediction.bits.bht, io.bht_update.bits.taken)
|
||||
when (io.bht_update.bits.mispredict) {
|
||||
bht.updateHistory(io.bht_update.bits.pc, io.bht_update.bits.prediction.bits.bht, io.bht_update.bits.taken)
|
||||
when (io.bht_update.bits.branch) {
|
||||
bht.updateTable(io.bht_update.bits.pc, io.bht_update.bits.prediction, io.bht_update.bits.taken)
|
||||
when (io.bht_update.bits.mispredict) {
|
||||
bht.updateHistory(io.bht_update.bits.pc, io.bht_update.bits.prediction, io.bht_update.bits.taken)
|
||||
}
|
||||
}.elsewhen (io.bht_update.bits.mispredict) {
|
||||
bht.resetHistory(io.bht_update.bits.prediction)
|
||||
}
|
||||
}
|
||||
when (!res.taken && isBranch) { io.resp.bits.taken := false }
|
||||
@ -297,7 +317,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
when (io.ras_update.valid) {
|
||||
when (io.ras_update.bits.cfiType === CFIType.call) {
|
||||
ras.push(io.ras_update.bits.returnAddr)
|
||||
}.elsewhen (io.ras_update.bits.cfiType === CFIType.ret && io.ras_update.bits.prediction.valid) {
|
||||
}.elsewhen (io.ras_update.bits.cfiType === CFIType.ret) {
|
||||
ras.pop()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user