Remove invalidation support from BTB
Validating the target PC in the pipeline is cheaper than maintaining the valid bits and control logic needed to guarantee that the BTB never mispredicts a branch target.
This commit is contained in:
parent
663002ec0c
commit
5aa8ef1855
@ -36,15 +36,15 @@ class RAS(nras: Int) {
|
||||
pos := nextPos
|
||||
}
|
||||
def peek: UInt = stack(pos)
|
||||
def pop: Unit = when (!isEmpty) {
|
||||
def pop(): Unit = when (!isEmpty) {
|
||||
count := count - 1
|
||||
pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1))
|
||||
}
|
||||
def clear: Unit = count := UInt(0)
|
||||
def clear(): Unit = count := UInt(0)
|
||||
def isEmpty: Bool = count === UInt(0)
|
||||
|
||||
private val count = Reg(init=UInt(0,log2Up(nras+1)))
|
||||
private val pos = Reg(init=UInt(0,log2Up(nras)))
|
||||
private val count = Reg(UInt(width = log2Up(nras+1)))
|
||||
private val pos = Reg(UInt(width = log2Up(nras)))
|
||||
private val stack = Reg(Vec(nras, UInt()))
|
||||
}
|
||||
|
||||
@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
val btb_update = Valid(new BTBUpdate).flip
|
||||
val bht_update = Valid(new BHTUpdate).flip
|
||||
val ras_update = Valid(new RASUpdate).flip
|
||||
val invalidate = Bool(INPUT)
|
||||
}
|
||||
|
||||
val idxValid = Reg(init=UInt(0, entries))
|
||||
val idxs = Mem(entries, UInt(width=matchBits))
|
||||
val idxPages = Mem(entries, UInt(width=log2Up(nPages)))
|
||||
val tgts = Mem(entries, UInt(width=matchBits))
|
||||
val tgtPages = Mem(entries, UInt(width=log2Up(nPages)))
|
||||
val pages = Mem(nPages, UInt(width=vaddrBits-matchBits))
|
||||
val pageValid = Reg(init=UInt(0, nPages))
|
||||
val idxs = Reg(Vec(entries, UInt(width=matchBits)))
|
||||
val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
|
||||
val tgts = Reg(Vec(entries, UInt(width=matchBits)))
|
||||
val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
|
||||
val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits)))
|
||||
val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0))
|
||||
val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0))
|
||||
|
||||
val useRAS = Reg(Vec(entries, Bool()))
|
||||
val isJump = Reg(Vec(entries, Bool()))
|
||||
val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth)))
|
||||
val useRAS = Reg(UInt(width = entries))
|
||||
val isJump = Reg(UInt(width = entries))
|
||||
val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
|
||||
|
||||
private def page(addr: UInt) = addr >> matchBits
|
||||
private def pageMatch(addr: UInt) = {
|
||||
val p = page(addr)
|
||||
Vec(pages.map(_ === p)).toBits & pageValid
|
||||
Vec(pages.map(_ === p)).toBits
|
||||
}
|
||||
private def tagMatch(addr: UInt, pgMatch: UInt) = {
|
||||
val idx = addr(matchBits-1,0)
|
||||
val idxMatch = idxs.map(_ === idx).toBits
|
||||
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits
|
||||
idxValid & idxMatch & idxPageMatch
|
||||
val idxMatch = idxs.map(_ === addr(matchBits-1,0))
|
||||
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR)
|
||||
(idxPageMatch zip idxMatch) map { case (p, i) => p && i }
|
||||
}
|
||||
|
||||
val r_btb_update = Pipe(io.btb_update)
|
||||
val update_target = io.req.bits.addr
|
||||
|
||||
val pageHit = pageMatch(io.req.bits.addr)
|
||||
val hits = tagMatch(io.req.bits.addr, pageHit)
|
||||
val hitsVec = tagMatch(io.req.bits.addr, pageHit)
|
||||
val hits = hitsVec.toBits
|
||||
val updatePageHit = pageMatch(r_btb_update.bits.pc)
|
||||
val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit)
|
||||
|
||||
val updateHit = r_btb_update.bits.prediction.valid
|
||||
val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1
|
||||
val nextRepl = Reg(UInt(width = log2Ceil(entries)))
|
||||
when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) }
|
||||
val nextPageRepl = Reg(UInt(width = log2Ceil(nPages)))
|
||||
|
||||
val useUpdatePageHit = updatePageHit.orR
|
||||
val usePageHit = pageHit.orR
|
||||
val doIdxPageRepl = !useUpdatePageHit
|
||||
val idxPageRepl = Wire(UInt(width = nPages))
|
||||
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl)
|
||||
val idxPageRepl = UIntToOH(nextPageRepl)
|
||||
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit,
|
||||
Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl))
|
||||
val idxPageUpdate = OHToUInt(idxPageUpdateOH)
|
||||
val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0))
|
||||
|
||||
val samePage = page(r_btb_update.bits.pc) === page(update_target)
|
||||
val usePageHit = (pageHit & ~idxPageReplEn).orR
|
||||
val doTgtPageRepl = !samePage && !usePageHit
|
||||
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1))
|
||||
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1)))
|
||||
val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl))
|
||||
val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0))
|
||||
val doPageRepl = doIdxPageRepl || doTgtPageRepl
|
||||
|
||||
val pageReplEn = idxPageReplEn | tgtPageReplEn
|
||||
idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1)
|
||||
when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) {
|
||||
val both = doIdxPageRepl && doTgtPageRepl
|
||||
val next = nextPageRepl + Mux[UInt](both, 2, 1)
|
||||
nextPageRepl := Mux(next >= nPages, next(0), next)
|
||||
}
|
||||
|
||||
when (r_btb_update.valid) {
|
||||
assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target")
|
||||
|
||||
val waddr =
|
||||
if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl)
|
||||
if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl)
|
||||
else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl)
|
||||
|
||||
// invalidate entries if we stomp on pages they depend upon
|
||||
val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits
|
||||
val validateMask = UIntToOH(waddr)
|
||||
idxValid := (idxValid & ~invalidateMask) | validateMask
|
||||
|
||||
idxs(waddr) := r_btb_update.bits.pc
|
||||
tgts(waddr) := update_target
|
||||
idxPages(waddr) := idxPageUpdate
|
||||
tgtPages(waddr) := tgtPageUpdate
|
||||
useRAS(waddr) := r_btb_update.bits.isReturn
|
||||
isJump(waddr) := r_btb_update.bits.isJump
|
||||
val mask = UIntToOH(waddr)
|
||||
useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask)
|
||||
isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask)
|
||||
if (fetchWidth == 1) {
|
||||
brIdx(waddr) := UInt(0)
|
||||
} else {
|
||||
@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
}
|
||||
|
||||
require(nPages % 2 == 0)
|
||||
val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR
|
||||
val idxWritesEven = !idxPageUpdate(0)
|
||||
|
||||
def writeBank(i: Int, mod: Int, en: Bool, data: UInt) =
|
||||
def writeBank(i: Int, mod: Int, en: UInt, data: UInt) =
|
||||
for (i <- i until nPages by mod)
|
||||
when (en && pageReplEn(i)) { pages(i) := data }
|
||||
when (en(i)) { pages(i) := data }
|
||||
|
||||
writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl),
|
||||
writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn),
|
||||
Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target)))
|
||||
writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl),
|
||||
writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn),
|
||||
Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc)))
|
||||
|
||||
when (doPageRepl) { pageValid := pageValid | pageReplEn }
|
||||
}
|
||||
|
||||
when (io.invalidate) {
|
||||
idxValid := 0
|
||||
pageValid := 0
|
||||
}
|
||||
|
||||
io.resp.valid := hits.orR
|
||||
io.resp.bits.taken := io.resp.valid
|
||||
io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts))
|
||||
io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts))
|
||||
io.resp.bits.entry := OHToUInt(hits)
|
||||
io.resp.bits.bridx := brIdx(io.resp.bits.entry)
|
||||
if (fetchWidth == 1) {
|
||||
io.resp.bits.mask := UInt(1)
|
||||
} else {
|
||||
// note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case
|
||||
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
|
||||
SInt(-1)).toUInt
|
||||
}
|
||||
|
||||
if (nBHT > 0) {
|
||||
val bht = new BHT(nBHT)
|
||||
val isBranch = !Mux1H(hits, isJump)
|
||||
val isBranch = !(hits & isJump).orR
|
||||
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
|
||||
val update_btb_hit = io.bht_update.bits.prediction.valid
|
||||
when (io.bht_update.valid && update_btb_hit) {
|
||||
@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
|
||||
if (nRAS > 0) {
|
||||
val ras = new RAS(nRAS)
|
||||
val doPeek = Mux1H(hits, useRAS)
|
||||
val doPeek = (hits & useRAS).orR
|
||||
when (!ras.isEmpty && doPeek) {
|
||||
io.resp.bits.target := ras.peek
|
||||
}
|
||||
@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
||||
io.resp.bits.target := io.ras_update.bits.returnAddr
|
||||
}
|
||||
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
|
||||
ras.pop
|
||||
ras.pop()
|
||||
}
|
||||
}
|
||||
when (io.invalidate) { ras.clear }
|
||||
}
|
||||
}
|
||||
|
@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa
|
||||
btb.io.btb_update := io.cpu.btb_update
|
||||
btb.io.bht_update := io.cpu.bht_update
|
||||
btb.io.ras_update := io.cpu.ras_update
|
||||
btb.io.invalidate := false
|
||||
when (!stall && !icmiss) {
|
||||
btb.io.req.valid := true
|
||||
s2_btb_resp_valid := btb.io.resp.valid
|
||||
|
Loading…
Reference in New Issue
Block a user