Remove invalidation support from BTB
Validating the predicted target PC in the pipeline is cheaper than maintaining the valid bits and the control logic needed to guarantee that the BTB never mispredicts a branch target.
This commit is contained in:
parent
663002ec0c
commit
5aa8ef1855
@ -36,15 +36,15 @@ class RAS(nras: Int) {
|
|||||||
pos := nextPos
|
pos := nextPos
|
||||||
}
|
}
|
||||||
def peek: UInt = stack(pos)
|
def peek: UInt = stack(pos)
|
||||||
def pop: Unit = when (!isEmpty) {
|
def pop(): Unit = when (!isEmpty) {
|
||||||
count := count - 1
|
count := count - 1
|
||||||
pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1))
|
pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1))
|
||||||
}
|
}
|
||||||
def clear: Unit = count := UInt(0)
|
def clear(): Unit = count := UInt(0)
|
||||||
def isEmpty: Bool = count === UInt(0)
|
def isEmpty: Bool = count === UInt(0)
|
||||||
|
|
||||||
private val count = Reg(init=UInt(0,log2Up(nras+1)))
|
private val count = Reg(UInt(width = log2Up(nras+1)))
|
||||||
private val pos = Reg(init=UInt(0,log2Up(nras)))
|
private val pos = Reg(UInt(width = log2Up(nras)))
|
||||||
private val stack = Reg(Vec(nras, UInt()))
|
private val stack = Reg(Vec(nras, UInt()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
val btb_update = Valid(new BTBUpdate).flip
|
val btb_update = Valid(new BTBUpdate).flip
|
||||||
val bht_update = Valid(new BHTUpdate).flip
|
val bht_update = Valid(new BHTUpdate).flip
|
||||||
val ras_update = Valid(new RASUpdate).flip
|
val ras_update = Valid(new RASUpdate).flip
|
||||||
val invalidate = Bool(INPUT)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val idxValid = Reg(init=UInt(0, entries))
|
val idxs = Reg(Vec(entries, UInt(width=matchBits)))
|
||||||
val idxs = Mem(entries, UInt(width=matchBits))
|
val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
|
||||||
val idxPages = Mem(entries, UInt(width=log2Up(nPages)))
|
val tgts = Reg(Vec(entries, UInt(width=matchBits)))
|
||||||
val tgts = Mem(entries, UInt(width=matchBits))
|
val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages))))
|
||||||
val tgtPages = Mem(entries, UInt(width=log2Up(nPages)))
|
val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits)))
|
||||||
val pages = Mem(nPages, UInt(width=vaddrBits-matchBits))
|
|
||||||
val pageValid = Reg(init=UInt(0, nPages))
|
|
||||||
val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0))
|
val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0))
|
||||||
val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0))
|
val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0))
|
||||||
|
|
||||||
val useRAS = Reg(Vec(entries, Bool()))
|
val useRAS = Reg(UInt(width = entries))
|
||||||
val isJump = Reg(Vec(entries, Bool()))
|
val isJump = Reg(UInt(width = entries))
|
||||||
val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth)))
|
val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth))))
|
||||||
|
|
||||||
private def page(addr: UInt) = addr >> matchBits
|
private def page(addr: UInt) = addr >> matchBits
|
||||||
private def pageMatch(addr: UInt) = {
|
private def pageMatch(addr: UInt) = {
|
||||||
val p = page(addr)
|
val p = page(addr)
|
||||||
Vec(pages.map(_ === p)).toBits & pageValid
|
Vec(pages.map(_ === p)).toBits
|
||||||
}
|
}
|
||||||
private def tagMatch(addr: UInt, pgMatch: UInt) = {
|
private def tagMatch(addr: UInt, pgMatch: UInt) = {
|
||||||
val idx = addr(matchBits-1,0)
|
val idxMatch = idxs.map(_ === addr(matchBits-1,0))
|
||||||
val idxMatch = idxs.map(_ === idx).toBits
|
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR)
|
||||||
val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits
|
(idxPageMatch zip idxMatch) map { case (p, i) => p && i }
|
||||||
idxValid & idxMatch & idxPageMatch
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val r_btb_update = Pipe(io.btb_update)
|
val r_btb_update = Pipe(io.btb_update)
|
||||||
val update_target = io.req.bits.addr
|
val update_target = io.req.bits.addr
|
||||||
|
|
||||||
val pageHit = pageMatch(io.req.bits.addr)
|
val pageHit = pageMatch(io.req.bits.addr)
|
||||||
val hits = tagMatch(io.req.bits.addr, pageHit)
|
val hitsVec = tagMatch(io.req.bits.addr, pageHit)
|
||||||
|
val hits = hitsVec.toBits
|
||||||
val updatePageHit = pageMatch(r_btb_update.bits.pc)
|
val updatePageHit = pageMatch(r_btb_update.bits.pc)
|
||||||
val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit)
|
val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit)
|
||||||
|
|
||||||
val updateHit = r_btb_update.bits.prediction.valid
|
val updateHit = r_btb_update.bits.prediction.valid
|
||||||
val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1
|
val nextRepl = Reg(UInt(width = log2Ceil(entries)))
|
||||||
|
when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) }
|
||||||
|
val nextPageRepl = Reg(UInt(width = log2Ceil(nPages)))
|
||||||
|
|
||||||
val useUpdatePageHit = updatePageHit.orR
|
val useUpdatePageHit = updatePageHit.orR
|
||||||
|
val usePageHit = pageHit.orR
|
||||||
val doIdxPageRepl = !useUpdatePageHit
|
val doIdxPageRepl = !useUpdatePageHit
|
||||||
val idxPageRepl = Wire(UInt(width = nPages))
|
val idxPageRepl = UIntToOH(nextPageRepl)
|
||||||
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl)
|
val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit,
|
||||||
|
Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl))
|
||||||
val idxPageUpdate = OHToUInt(idxPageUpdateOH)
|
val idxPageUpdate = OHToUInt(idxPageUpdateOH)
|
||||||
val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0))
|
val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0))
|
||||||
|
|
||||||
val samePage = page(r_btb_update.bits.pc) === page(update_target)
|
val samePage = page(r_btb_update.bits.pc) === page(update_target)
|
||||||
val usePageHit = (pageHit & ~idxPageReplEn).orR
|
|
||||||
val doTgtPageRepl = !samePage && !usePageHit
|
val doTgtPageRepl = !samePage && !usePageHit
|
||||||
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1))
|
val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1)))
|
||||||
val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl))
|
val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl))
|
||||||
val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0))
|
val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0))
|
||||||
val doPageRepl = doIdxPageRepl || doTgtPageRepl
|
|
||||||
|
|
||||||
val pageReplEn = idxPageReplEn | tgtPageReplEn
|
when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) {
|
||||||
idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1)
|
val both = doIdxPageRepl && doTgtPageRepl
|
||||||
|
val next = nextPageRepl + Mux[UInt](both, 2, 1)
|
||||||
|
nextPageRepl := Mux(next >= nPages, next(0), next)
|
||||||
|
}
|
||||||
|
|
||||||
when (r_btb_update.valid) {
|
when (r_btb_update.valid) {
|
||||||
assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target")
|
assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target")
|
||||||
|
|
||||||
val waddr =
|
val waddr =
|
||||||
if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl)
|
if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl)
|
||||||
else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl)
|
else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl)
|
||||||
|
|
||||||
// invalidate entries if we stomp on pages they depend upon
|
|
||||||
val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits
|
|
||||||
val validateMask = UIntToOH(waddr)
|
|
||||||
idxValid := (idxValid & ~invalidateMask) | validateMask
|
|
||||||
|
|
||||||
idxs(waddr) := r_btb_update.bits.pc
|
idxs(waddr) := r_btb_update.bits.pc
|
||||||
tgts(waddr) := update_target
|
tgts(waddr) := update_target
|
||||||
idxPages(waddr) := idxPageUpdate
|
idxPages(waddr) := idxPageUpdate
|
||||||
tgtPages(waddr) := tgtPageUpdate
|
tgtPages(waddr) := tgtPageUpdate
|
||||||
useRAS(waddr) := r_btb_update.bits.isReturn
|
val mask = UIntToOH(waddr)
|
||||||
isJump(waddr) := r_btb_update.bits.isJump
|
useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask)
|
||||||
|
isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask)
|
||||||
if (fetchWidth == 1) {
|
if (fetchWidth == 1) {
|
||||||
brIdx(waddr) := UInt(0)
|
brIdx(waddr) := UInt(0)
|
||||||
} else {
|
} else {
|
||||||
@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
}
|
}
|
||||||
|
|
||||||
require(nPages % 2 == 0)
|
require(nPages % 2 == 0)
|
||||||
val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR
|
val idxWritesEven = !idxPageUpdate(0)
|
||||||
|
|
||||||
def writeBank(i: Int, mod: Int, en: Bool, data: UInt) =
|
def writeBank(i: Int, mod: Int, en: UInt, data: UInt) =
|
||||||
for (i <- i until nPages by mod)
|
for (i <- i until nPages by mod)
|
||||||
when (en && pageReplEn(i)) { pages(i) := data }
|
when (en(i)) { pages(i) := data }
|
||||||
|
|
||||||
writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl),
|
writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn),
|
||||||
Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target)))
|
Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target)))
|
||||||
writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl),
|
writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn),
|
||||||
Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc)))
|
Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc)))
|
||||||
|
|
||||||
when (doPageRepl) { pageValid := pageValid | pageReplEn }
|
|
||||||
}
|
|
||||||
|
|
||||||
when (io.invalidate) {
|
|
||||||
idxValid := 0
|
|
||||||
pageValid := 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
io.resp.valid := hits.orR
|
io.resp.valid := hits.orR
|
||||||
io.resp.bits.taken := io.resp.valid
|
io.resp.bits.taken := io.resp.valid
|
||||||
io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts))
|
io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts))
|
||||||
io.resp.bits.entry := OHToUInt(hits)
|
io.resp.bits.entry := OHToUInt(hits)
|
||||||
io.resp.bits.bridx := brIdx(io.resp.bits.entry)
|
io.resp.bits.bridx := brIdx(io.resp.bits.entry)
|
||||||
if (fetchWidth == 1) {
|
|
||||||
io.resp.bits.mask := UInt(1)
|
|
||||||
} else {
|
|
||||||
// note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case
|
|
||||||
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
|
io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt,
|
||||||
SInt(-1)).toUInt
|
SInt(-1)).toUInt
|
||||||
}
|
|
||||||
|
|
||||||
if (nBHT > 0) {
|
if (nBHT > 0) {
|
||||||
val bht = new BHT(nBHT)
|
val bht = new BHT(nBHT)
|
||||||
val isBranch = !Mux1H(hits, isJump)
|
val isBranch = !(hits & isJump).orR
|
||||||
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
|
val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch)
|
||||||
val update_btb_hit = io.bht_update.bits.prediction.valid
|
val update_btb_hit = io.bht_update.bits.prediction.valid
|
||||||
when (io.bht_update.valid && update_btb_hit) {
|
when (io.bht_update.valid && update_btb_hit) {
|
||||||
@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
|
|
||||||
if (nRAS > 0) {
|
if (nRAS > 0) {
|
||||||
val ras = new RAS(nRAS)
|
val ras = new RAS(nRAS)
|
||||||
val doPeek = Mux1H(hits, useRAS)
|
val doPeek = (hits & useRAS).orR
|
||||||
when (!ras.isEmpty && doPeek) {
|
when (!ras.isEmpty && doPeek) {
|
||||||
io.resp.bits.target := ras.peek
|
io.resp.bits.target := ras.peek
|
||||||
}
|
}
|
||||||
@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule {
|
|||||||
io.resp.bits.target := io.ras_update.bits.returnAddr
|
io.resp.bits.target := io.ras_update.bits.returnAddr
|
||||||
}
|
}
|
||||||
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
|
}.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) {
|
||||||
ras.pop
|
ras.pop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
when (io.invalidate) { ras.clear }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa
|
|||||||
btb.io.btb_update := io.cpu.btb_update
|
btb.io.btb_update := io.cpu.btb_update
|
||||||
btb.io.bht_update := io.cpu.bht_update
|
btb.io.bht_update := io.cpu.bht_update
|
||||||
btb.io.ras_update := io.cpu.ras_update
|
btb.io.ras_update := io.cpu.ras_update
|
||||||
btb.io.invalidate := false
|
|
||||||
when (!stall && !icmiss) {
|
when (!stall && !icmiss) {
|
||||||
btb.io.req.valid := true
|
btb.io.req.valid := true
|
||||||
s2_btb_resp_valid := btb.io.resp.valid
|
s2_btb_resp_valid := btb.io.resp.valid
|
||||||
|
Loading…
Reference in New Issue
Block a user