1
0

Improve frontend branch prediction

- Put correctness responsibility on Frontend, not IBuf, for improved
  separation of concerns.  Frontend must detect case that the BTB
  predicts a taken branch in the middle of an instruction.

- Pass BTB information down pipeline unconditionally, fixing case that
  screws up the branch history when the BTB misses and the instruction
  is misaligned.

- Remove jumpInFrontend option; it's now unconditional.

- Default to one-bit counters in the BHT.  For tiny BHTs like these, it's
  more resource efficient to have a larger index space than to have
  hysteresis.
This commit is contained in:
Andrew Waterman
2017-11-08 17:23:25 -08:00
parent 989eeb78f9
commit d0c6cbba6b
5 changed files with 177 additions and 157 deletions

View File

@ -10,12 +10,18 @@ import freechips.rocketchip.coreplex.CacheBlockBytes
import freechips.rocketchip.tile.HasCoreParameters
import freechips.rocketchip.util._
case class BHTParams(
nEntries: Int = 512,
counterLength: Int = 1,
historyLength: Int = 8,
historyBits: Int = 3)
case class BTBParams(
nEntries: Int = 28,
nMatchBits: Int = 14,
nPages: Int = 6,
nRAS: Int = 6,
nBHT: Int = 256,
bhtParams: Option[BHTParams] = Some(BHTParams()),
updatesOutOfOrder: Boolean = false)
trait HasBtbParameters extends HasCoreParameters {
@ -51,9 +57,10 @@ class RAS(nras: Int) {
}
class BHTResp(implicit p: Parameters) extends BtbBundle()(p) {
val history = UInt(width = log2Up(btbParams.nBHT).max(1))
val value = UInt(width = 2)
val taken = Bool()
val history = UInt(width = btbParams.bhtParams.map(_.historyLength).getOrElse(1))
val value = UInt(width = btbParams.bhtParams.map(_.counterLength).getOrElse(1))
def taken = value(0)
def strongly_taken = value === 1
}
// BHT contains table of 2-bit counters and a global history register.
@ -65,32 +72,43 @@ class BHTResp(implicit p: Parameters) extends BtbBundle()(p) {
// - each counter corresponds with the address of the fetch packet ("fetch pc").
// - updated when a branch resolves (and BTB was a hit for that branch).
// The updating branch must provide its "fetch pc".
class BHT(nbht: Int)(implicit val p: Parameters) extends HasCoreParameters {
val nbhtbits = log2Up(nbht)
class BHT(params: BHTParams)(implicit val p: Parameters) extends HasCoreParameters {
def index(addr: UInt, history: UInt) = {
def hashHistory(hist: UInt) = if (params.historyLength == params.historyBits) hist else {
val k = math.sqrt(3)/2
val i = BigDecimal(k * math.pow(2, params.historyLength)).toBigInt
(i.U * hist)(params.historyLength-1, params.historyLength-params.historyBits)
}
def hashAddr(addr: UInt) = {
val hi = addr >> log2Ceil(fetchBytes)
hi(log2Ceil(params.nEntries)-1, 0) ^ (hi >> log2Ceil(params.nEntries))(1, 0)
}
hashAddr(addr) ^ (hashHistory(history) << (log2Up(params.nEntries) - params.historyBits))
}
def get(addr: UInt): BHTResp = {
val res = Wire(new BHTResp)
val index = addr(nbhtbits+log2Up(coreInstBytes)-1, log2Up(coreInstBytes)) ^ history
res.value := table(index)
res.value := table(index(addr, history))
res.history := history
res.taken := res.value(0)
res
}
def updateTable(addr: UInt, d: BHTResp, taken: Bool): Unit = {
val index = addr(nbhtbits+log2Up(coreInstBytes)-1, log2Up(coreInstBytes)) ^ d.history
table(index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken))
table(index(addr, d.history)) := (params.counterLength match {
case 1 => taken
case 2 => Cat(taken ^ d.value(0), d.value === 1 || d.value(1) && taken)
})
}
def resetHistory(d: BHTResp): Unit = {
history := d.history
}
def updateHistory(addr: UInt, d: BHTResp, taken: Bool): Unit = {
history := Cat(taken, d.history(nbhtbits-1,1))
history := Cat(taken, d.history >> 1)
}
def advanceHistory(taken: Bool): Unit = {
history := Cat(taken, history(nbhtbits-1,1))
history := Cat(taken, history >> 1)
}
private val table = Mem(nbht, UInt(width = 2))
val history = Reg(UInt(width = nbhtbits))
private val table = Mem(params.nEntries, UInt(width = params.counterLength))
val history = Reg(UInt(width = params.historyLength))
}
object CFIType {
@ -106,7 +124,7 @@ object CFIType {
// - "pc" is what future fetch PCs will tag match against.
// - "br_pc" is the PC of the branch instruction.
class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) {
val prediction = Valid(new BTBResp)
val prediction = new BTBResp
val pc = UInt(width = vaddrBits)
val target = UInt(width = vaddrBits)
val taken = Bool()
@ -118,8 +136,9 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) {
// BHT update occurs during branch resolution on all conditional branches.
// - "pc" is what future fetch PCs will tag match against.
class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) {
val prediction = Valid(new BTBResp)
val prediction = new BHTResp
val pc = UInt(width = vaddrBits)
val branch = Bool()
val taken = Bool()
val mispredict = Bool()
}
@ -127,7 +146,6 @@ class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) {
class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) {
val cfiType = CFIType()
val returnAddr = UInt(width = vaddrBits)
val prediction = Valid(new BTBResp)
}
// - "bridx" is the low-order PC bits of the predicted branch (after
@ -161,6 +179,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
val bht_advance = Valid(new BTBResp).flip
val ras_update = Valid(new RASUpdate).flip
val ras_head = Valid(UInt(width = vaddrBits))
val flush = Bool().asInput
}
val idxs = Reg(Vec(entries, UInt(width=matchBits - log2Up(coreInstBytes))))
@ -195,7 +214,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
if (updatesOutOfOrder) {
val updateHits = (pageHit << 1)(Mux1H(idxMatch(r_btb_update.bits.pc), idxPages))
(updateHits.orR, OHToUInt(updateHits))
} else (r_btb_update.bits.prediction.valid && r_btb_update.bits.prediction.bits.entry < entries, r_btb_update.bits.prediction.bits.entry)
} else (r_btb_update.bits.prediction.entry < entries, r_btb_update.bits.prediction.entry)
val useUpdatePageHit = updatePageHit.orR
val usePageHit = pageHit.orR
@ -220,7 +239,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
val repl = new PseudoLRU(entries)
val waddr = Mux(updateHit, updateHitAddr, repl.replace)
val r_resp = Pipe(io.req.valid && io.resp.valid, io.resp.bits)
val r_resp = Pipe(io.resp)
when (r_resp.valid && r_resp.bits.taken || r_btb_update.valid) {
repl.access(Mux(r_btb_update.valid, waddr, r_resp.bits.entry))
}
@ -262,24 +281,25 @@ class BTB(implicit p: Parameters) extends BtbModule {
when (PopCountAtLeast(idxHit, 2)) {
isValid := isValid & ~idxHit
}
when (io.flush) {
isValid := 0
}
if (btbParams.nBHT > 0) {
val bht = new BHT(btbParams.nBHT)
if (btbParams.bhtParams.nonEmpty) {
val bht = new BHT(btbParams.bhtParams.get)
val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR
val res = bht.get(io.req.bits.addr)
when (io.req.valid && io.resp.valid && isBranch) {
bht.advanceHistory(res.taken)
}
when (io.bht_advance.valid) {
bht.advanceHistory(io.bht_advance.bits.bht.taken)
}
when (io.btb_update.valid) {
bht.resetHistory(io.btb_update.bits.prediction.bits.bht)
}
when (io.bht_update.valid) {
bht.updateTable(io.bht_update.bits.pc, io.bht_update.bits.prediction.bits.bht, io.bht_update.bits.taken)
when (io.bht_update.bits.mispredict) {
bht.updateHistory(io.bht_update.bits.pc, io.bht_update.bits.prediction.bits.bht, io.bht_update.bits.taken)
when (io.bht_update.bits.branch) {
bht.updateTable(io.bht_update.bits.pc, io.bht_update.bits.prediction, io.bht_update.bits.taken)
when (io.bht_update.bits.mispredict) {
bht.updateHistory(io.bht_update.bits.pc, io.bht_update.bits.prediction, io.bht_update.bits.taken)
}
}.elsewhen (io.bht_update.bits.mispredict) {
bht.resetHistory(io.bht_update.bits.prediction)
}
}
when (!res.taken && isBranch) { io.resp.bits.taken := false }
@ -297,7 +317,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
when (io.ras_update.valid) {
when (io.ras_update.bits.cfiType === CFIType.call) {
ras.push(io.ras_update.bits.returnAddr)
}.elsewhen (io.ras_update.bits.cfiType === CFIType.ret && io.ras_update.bits.prediction.valid) {
}.elsewhen (io.ras_update.bits.cfiType === CFIType.ret) {
ras.pop()
}
}