1
0

Remove invalidation support from BTB

Validating the target PC in the pipeline is cheaper than maintaining
the valid bits and control logic needed to guarantee that the BTB never
mispredicts branch targets.
This commit is contained in:
Andrew Waterman 2016-07-02 14:27:29 -07:00
parent 663002ec0c
commit 5aa8ef1855
2 changed files with 46 additions and 62 deletions

View File

@ -36,15 +36,15 @@ class RAS(nras: Int) {
pos := nextPos pos := nextPos
} }
def peek: UInt = stack(pos) def peek: UInt = stack(pos)
def pop: Unit = when (!isEmpty) { def pop(): Unit = when (!isEmpty) {
count := count - 1 count := count - 1
pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1)) pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1))
} }
def clear: Unit = count := UInt(0) def clear(): Unit = count := UInt(0)
def isEmpty: Bool = count === UInt(0) def isEmpty: Bool = count === UInt(0)
private val count = Reg(init=UInt(0,log2Up(nras+1))) private val count = Reg(UInt(width = log2Up(nras+1)))
private val pos = Reg(init=UInt(0,log2Up(nras))) private val pos = Reg(UInt(width = log2Up(nras)))
private val stack = Reg(Vec(nras, UInt())) private val stack = Reg(Vec(nras, UInt()))
} }
@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule {
val btb_update = Valid(new BTBUpdate).flip val btb_update = Valid(new BTBUpdate).flip
val bht_update = Valid(new BHTUpdate).flip val bht_update = Valid(new BHTUpdate).flip
val ras_update = Valid(new RASUpdate).flip val ras_update = Valid(new RASUpdate).flip
val invalidate = Bool(INPUT)
} }
val idxValid = Reg(init=UInt(0, entries)) val idxs = Reg(Vec(entries, UInt(width=matchBits)))
val idxs = Mem(entries, UInt(width=matchBits)) val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
val idxPages = Mem(entries, UInt(width=log2Up(nPages))) val tgts = Reg(Vec(entries, UInt(width=matchBits)))
val tgts = Mem(entries, UInt(width=matchBits)) val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
val tgtPages = Mem(entries, UInt(width=log2Up(nPages))) val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits)))
val pages = Mem(nPages, UInt(width=vaddrBits-matchBits))
val pageValid = Reg(init=UInt(0, nPages))
val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0)) val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0))
val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0)) val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0))
val useRAS = Reg(Vec(entries, Bool())) val useRAS = Reg(UInt(width = entries))
val isJump = Reg(Vec(entries, Bool())) val isJump = Reg(UInt(width = entries))
val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth))) val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
private def page(addr: UInt) = addr >> matchBits private def page(addr: UInt) = addr >> matchBits
private def pageMatch(addr: UInt) = { private def pageMatch(addr: UInt) = {
val p = page(addr) val p = page(addr)
Vec(pages.map(_ === p)).toBits & pageValid Vec(pages.map(_ === p)).toBits
} }
private def tagMatch(addr: UInt, pgMatch: UInt) = { private def tagMatch(addr: UInt, pgMatch: UInt) = {
val idx = addr(matchBits-1,0) val idxMatch = idxs.map(_ === addr(matchBits-1,0))
val idxMatch = idxs.map(_ === idx).toBits val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR)
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits (idxPageMatch zip idxMatch) map { case (p, i) => p && i }
idxValid & idxMatch & idxPageMatch
} }
val r_btb_update = Pipe(io.btb_update) val r_btb_update = Pipe(io.btb_update)
val update_target = io.req.bits.addr val update_target = io.req.bits.addr
val pageHit = pageMatch(io.req.bits.addr) val pageHit = pageMatch(io.req.bits.addr)
val hits = tagMatch(io.req.bits.addr, pageHit) val hitsVec = tagMatch(io.req.bits.addr, pageHit)
val hits = hitsVec.toBits
val updatePageHit = pageMatch(r_btb_update.bits.pc) val updatePageHit = pageMatch(r_btb_update.bits.pc)
val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit) val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit)
val updateHit = r_btb_update.bits.prediction.valid val updateHit = r_btb_update.bits.prediction.valid
val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1 val nextRepl = Reg(UInt(width = log2Ceil(entries)))
when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) }
val nextPageRepl = Reg(UInt(width = log2Ceil(nPages)))
val useUpdatePageHit = updatePageHit.orR val useUpdatePageHit = updatePageHit.orR
val usePageHit = pageHit.orR
val doIdxPageRepl = !useUpdatePageHit val doIdxPageRepl = !useUpdatePageHit
val idxPageRepl = Wire(UInt(width = nPages)) val idxPageRepl = UIntToOH(nextPageRepl)
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl) val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit,
Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl))
val idxPageUpdate = OHToUInt(idxPageUpdateOH) val idxPageUpdate = OHToUInt(idxPageUpdateOH)
val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0)) val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0))
val samePage = page(r_btb_update.bits.pc) === page(update_target) val samePage = page(r_btb_update.bits.pc) === page(update_target)
val usePageHit = (pageHit & ~idxPageReplEn).orR
val doTgtPageRepl = !samePage && !usePageHit val doTgtPageRepl = !samePage && !usePageHit
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1)) val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1)))
val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl)) val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl))
val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0)) val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0))
val doPageRepl = doIdxPageRepl || doTgtPageRepl
val pageReplEn = idxPageReplEn | tgtPageReplEn when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) {
idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1) val both = doIdxPageRepl && doTgtPageRepl
val next = nextPageRepl + Mux[UInt](both, 2, 1)
nextPageRepl := Mux(next >= nPages, next(0), next)
}
when (r_btb_update.valid) { when (r_btb_update.valid) {
assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target") assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target")
val waddr = val waddr =
if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl) if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl)
else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl) else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl)
// invalidate entries if we stomp on pages they depend upon
val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits
val validateMask = UIntToOH(waddr)
idxValid := (idxValid & ~invalidateMask) | validateMask
idxs(waddr) := r_btb_update.bits.pc idxs(waddr) := r_btb_update.bits.pc
tgts(waddr) := update_target tgts(waddr) := update_target
idxPages(waddr) := idxPageUpdate idxPages(waddr) := idxPageUpdate
tgtPages(waddr) := tgtPageUpdate tgtPages(waddr) := tgtPageUpdate
useRAS(waddr) := r_btb_update.bits.isReturn val mask = UIntToOH(waddr)
isJump(waddr) := r_btb_update.bits.isJump useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask)
isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask)
if (fetchWidth == 1) { if (fetchWidth == 1) {
brIdx(waddr) := UInt(0) brIdx(waddr) := UInt(0)
} else { } else {
@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule {
} }
require(nPages % 2 == 0) require(nPages % 2 == 0)
val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR val idxWritesEven = !idxPageUpdate(0)
def writeBank(i: Int, mod: Int, en: Bool, data: UInt) = def writeBank(i: Int, mod: Int, en: UInt, data: UInt) =
for (i <- i until nPages by mod) for (i <- i until nPages by mod)
when (en && pageReplEn(i)) { pages(i) := data } when (en(i)) { pages(i) := data }
writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl), writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn),
Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target))) Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target)))
writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl), writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn),
Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc))) Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc)))
when (doPageRepl) { pageValid := pageValid | pageReplEn }
}
when (io.invalidate) {
idxValid := 0
pageValid := 0
} }
io.resp.valid := hits.orR io.resp.valid := hits.orR
io.resp.bits.taken := io.resp.valid io.resp.bits.taken := io.resp.valid
io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts)) io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts))
io.resp.bits.entry := OHToUInt(hits) io.resp.bits.entry := OHToUInt(hits)
io.resp.bits.bridx := brIdx(io.resp.bits.entry) io.resp.bits.bridx := brIdx(io.resp.bits.entry)
if (fetchWidth == 1) { io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
io.resp.bits.mask := UInt(1) SInt(-1)).toUInt
} else {
// note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
SInt(-1)).toUInt
}
if (nBHT > 0) { if (nBHT > 0) {
val bht = new BHT(nBHT) val bht = new BHT(nBHT)
val isBranch = !Mux1H(hits, isJump) val isBranch = !(hits & isJump).orR
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
val update_btb_hit = io.bht_update.bits.prediction.valid val update_btb_hit = io.bht_update.bits.prediction.valid
when (io.bht_update.valid && update_btb_hit) { when (io.bht_update.valid && update_btb_hit) {
@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
if (nRAS > 0) { if (nRAS > 0) {
val ras = new RAS(nRAS) val ras = new RAS(nRAS)
val doPeek = Mux1H(hits, useRAS) val doPeek = (hits & useRAS).orR
when (!ras.isEmpty && doPeek) { when (!ras.isEmpty && doPeek) {
io.resp.bits.target := ras.peek io.resp.bits.target := ras.peek
} }
@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule {
io.resp.bits.target := io.ras_update.bits.returnAddr io.resp.bits.target := io.ras_update.bits.returnAddr
} }
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) { }.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
ras.pop ras.pop()
} }
} }
when (io.invalidate) { ras.clear }
} }
} }

View File

@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa
btb.io.btb_update := io.cpu.btb_update btb.io.btb_update := io.cpu.btb_update
btb.io.bht_update := io.cpu.bht_update btb.io.bht_update := io.cpu.bht_update
btb.io.ras_update := io.cpu.ras_update btb.io.ras_update := io.cpu.ras_update
btb.io.invalidate := false
when (!stall && !icmiss) { when (!stall && !icmiss) {
btb.io.req.valid := true btb.io.req.valid := true
s2_btb_resp_valid := btb.io.resp.valid s2_btb_resp_valid := btb.io.resp.valid