1
0

Remove invalidation support from BTB

Validating the target PC in the pipeline is cheaper than maintaining
the valid bits and control logic to guarantee the BTB won't ever
mispredict branch targets.
This commit is contained in:
Andrew Waterman 2016-07-02 14:27:29 -07:00
parent 663002ec0c
commit 5aa8ef1855
2 changed files with 46 additions and 62 deletions

View File

@ -36,15 +36,15 @@ class RAS(nras: Int) {
pos := nextPos
}
def peek: UInt = stack(pos)
// Pops the top entry, but only when the stack is non-empty: decrements `count` and
// moves `pos` back by one slot.
// NOTE(review): this is a diff rendering, so both the parameterless `pop` signature and
// its `pop()` replacement appear below; presumably the first is the removed line.
def pop: Unit = when (!isEmpty) {
def pop(): Unit = when (!isEmpty) {
count := count - 1
// When nras is a power of two or pos > 0, pos-1 is used directly; otherwise (pos is 0
// and nras is not a power of two) pos wraps explicitly to the last slot, nras-1.
pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1))
}
// Empties the stack by setting `count` to zero; `pos` and the stored entries are not touched.
// NOTE(review): diff rendering - the `clear` and `clear()` lines are the two variants of one signature.
def clear: Unit = count := UInt(0)
def clear(): Unit = count := UInt(0)
def isEmpty: Bool = count === UInt(0)
// RAS state: `count` holds the number of entries (wide enough to represent nras itself),
// `pos` indexes the current top slot, and `stack` stores the entries.
// The first count/pos pair is declared with init=UInt(0, ...); the second pair declares
// only a width and no init value.
// NOTE(review): presumably the second pair replaces the first in this diff, leaving
// `count` (and therefore isEmpty) without a reset value - confirm that is intended.
private val count = Reg(init=UInt(0,log2Up(nras+1)))
private val pos = Reg(init=UInt(0,log2Up(nras)))
private val count = Reg(UInt(width = log2Up(nras+1)))
private val pos = Reg(UInt(width = log2Up(nras)))
private val stack = Reg(Vec(nras, UInt()))
}
@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule {
val btb_update = Valid(new BTBUpdate).flip
val bht_update = Valid(new BHTUpdate).flip
val ras_update = Valid(new RASUpdate).flip
val invalidate = Bool(INPUT)
}
val idxValid = Reg(init=UInt(0, entries))
val idxs = Mem(entries, UInt(width=matchBits))
val idxPages = Mem(entries, UInt(width=log2Up(nPages)))
val tgts = Mem(entries, UInt(width=matchBits))
val tgtPages = Mem(entries, UInt(width=log2Up(nPages)))
val pages = Mem(nPages, UInt(width=vaddrBits-matchBits))
val pageValid = Reg(init=UInt(0, nPages))
val idxs = Reg(Vec(entries, UInt(width=matchBits)))
val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
val tgts = Reg(Vec(entries, UInt(width=matchBits)))
val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits)))
val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0))
val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0))
val useRAS = Reg(Vec(entries, Bool()))
val isJump = Reg(Vec(entries, Bool()))
val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth)))
val useRAS = Reg(UInt(width = entries))
val isJump = Reg(UInt(width = entries))
val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
// Page portion of an address: everything above the low `matchBits` bits.
private def page(addr: UInt) = addr >> matchBits
// Compares the page portion of `addr` against every entry of `pages` and packs the
// per-slot equality results into a bit vector.
// Two result lines appear in this diff rendering: one masks the comparison with
// `pageValid`, the other returns the raw comparison bits with no valid mask.
private def pageMatch(addr: UInt) = {
val p = page(addr)
Vec(pages.map(_ === p)).toBits & pageValid
Vec(pages.map(_ === p)).toBits
}
// Per-entry tag comparison for `addr`. An entry matches when its stored index equals the
// low `matchBits` bits of `addr` and the one-hot page slot it references (idxPagesOH)
// overlaps the page-match vector `pgMatch`.
// Two variants of the body are interleaved in this diff rendering.
private def tagMatch(addr: UInt, pgMatch: UInt) = {
// Variant 1: packs both comparisons into bit vectors and also masks with `idxValid`.
val idx = addr(matchBits-1,0)
val idxMatch = idxs.map(_ === idx).toBits
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits
idxValid & idxMatch & idxPageMatch
// Variant 2: keeps the results as a per-entry sequence, ANDing page match with index
// match pairwise; there is no valid-bit mask.
val idxMatch = idxs.map(_ === addr(matchBits-1,0))
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR)
(idxPageMatch zip idxMatch) map { case (p, i) => p && i }
}
val r_btb_update = Pipe(io.btb_update)
val update_target = io.req.bits.addr
val pageHit = pageMatch(io.req.bits.addr)
val hits = tagMatch(io.req.bits.addr, pageHit)
val hitsVec = tagMatch(io.req.bits.addr, pageHit)
val hits = hitsVec.toBits
val updatePageHit = pageMatch(r_btb_update.bits.pc)
val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit)
val updateHit = r_btb_update.bits.prediction.valid
val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1
val nextRepl = Reg(UInt(width = log2Ceil(entries)))
when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) }
val nextPageRepl = Reg(UInt(width = log2Ceil(nPages)))
val useUpdatePageHit = updatePageHit.orR
val usePageHit = pageHit.orR
val doIdxPageRepl = !useUpdatePageHit
val idxPageRepl = Wire(UInt(width = nPages))
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl)
val idxPageRepl = UIntToOH(nextPageRepl)
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit,
Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl))
val idxPageUpdate = OHToUInt(idxPageUpdateOH)
val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0))
val samePage = page(r_btb_update.bits.pc) === page(update_target)
val usePageHit = (pageHit & ~idxPageReplEn).orR
val doTgtPageRepl = !samePage && !usePageHit
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1))
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1)))
val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl))
val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0))
val doPageRepl = doIdxPageRepl || doTgtPageRepl
val pageReplEn = idxPageReplEn | tgtPageReplEn
idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1)
// Advance the page-replacement pointer whenever a valid update replaces at least one page.
when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) {
// Two pages are consumed when both the index page and the target page are replaced.
val both = doIdxPageRepl && doTgtPageRepl
val next = nextPageRepl + Mux[UInt](both, 2, 1)
// Once `next` reaches nPages or beyond, restart from bit 0 of `next` (i.e. slot 0 or 1).
nextPageRepl := Mux(next >= nPages, next(0), next)
}
when (r_btb_update.valid) {
assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target")
val waddr =
if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl)
if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl)
else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl)
// invalidate entries if we stomp on pages they depend upon
val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits
val validateMask = UIntToOH(waddr)
idxValid := (idxValid & ~invalidateMask) | validateMask
idxs(waddr) := r_btb_update.bits.pc
tgts(waddr) := update_target
idxPages(waddr) := idxPageUpdate
tgtPages(waddr) := tgtPageUpdate
useRAS(waddr) := r_btb_update.bits.isReturn
isJump(waddr) := r_btb_update.bits.isJump
val mask = UIntToOH(waddr)
useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask)
isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask)
if (fetchWidth == 1) {
brIdx(waddr) := UInt(0)
} else {
@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule {
}
require(nPages % 2 == 0)
val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR
val idxWritesEven = !idxPageUpdate(0)
// Conditionally writes `data` into the page slots i, i+mod, i+2*mod, ... below nPages.
// Two variants are interleaved in this diff rendering:
//  - with `en: Bool`, slot i is written when `en` is set and bit i of pageReplEn is set;
//  - with `en: UInt`, slot i is written when bit i of `en` is set.
def writeBank(i: Int, mod: Int, en: Bool, data: UInt) =
def writeBank(i: Int, mod: Int, en: UInt, data: UInt) =
for (i <- i until nPages by mod)
when (en && pageReplEn(i)) { pages(i) := data }
when (en(i)) { pages(i) := data }
writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl),
writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn),
Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target)))
writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl),
writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn),
Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc)))
when (doPageRepl) { pageValid := pageValid | pageReplEn }
}
when (io.invalidate) {
idxValid := 0
pageValid := 0
}
io.resp.valid := hits.orR
io.resp.bits.taken := io.resp.valid
io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts))
io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts))
io.resp.bits.entry := OHToUInt(hits)
io.resp.bits.bridx := brIdx(io.resp.bits.entry)
if (fetchWidth == 1) {
io.resp.bits.mask := UInt(1)
} else {
// note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
SInt(-1)).toUInt
}
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
SInt(-1)).toUInt
if (nBHT > 0) {
val bht = new BHT(nBHT)
val isBranch = !Mux1H(hits, isJump)
val isBranch = !(hits & isJump).orR
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
val update_btb_hit = io.bht_update.bits.prediction.valid
when (io.bht_update.valid && update_btb_hit) {
@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
if (nRAS > 0) {
val ras = new RAS(nRAS)
val doPeek = Mux1H(hits, useRAS)
val doPeek = (hits & useRAS).orR
when (!ras.isEmpty && doPeek) {
io.resp.bits.target := ras.peek
}
@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule {
io.resp.bits.target := io.ras_update.bits.returnAddr
}
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
ras.pop
ras.pop()
}
}
when (io.invalidate) { ras.clear }
}
}

View File

@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa
btb.io.btb_update := io.cpu.btb_update
btb.io.bht_update := io.cpu.bht_update
btb.io.ras_update := io.cpu.ras_update
btb.io.invalidate := false
when (!stall && !icmiss) {
btb.io.req.valid := true
s2_btb_resp_valid := btb.io.resp.valid