From 5aa8ef1855091c7d82a7cc06a758805ff31f0021 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Sat, 2 Jul 2016 14:27:29 -0700 Subject: [PATCH] Remove invalidation support from BTB Validating the target PC in the pipeline is cheaper than maintaining the valid bits and control logic to guarantee the BTB won't ever mispredict branch targets. --- rocket/src/main/scala/btb.scala | 107 ++++++++++++--------------- rocket/src/main/scala/frontend.scala | 1 - 2 files changed, 46 insertions(+), 62 deletions(-) diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index c5fb0e78..ea30d668 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -36,15 +36,15 @@ class RAS(nras: Int) { pos := nextPos } def peek: UInt = stack(pos) - def pop: Unit = when (!isEmpty) { + def pop(): Unit = when (!isEmpty) { count := count - 1 pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1)) } - def clear: Unit = count := UInt(0) + def clear(): Unit = count := UInt(0) def isEmpty: Bool = count === UInt(0) - private val count = Reg(init=UInt(0,log2Up(nras+1))) - private val pos = Reg(init=UInt(0,log2Up(nras))) + private val count = Reg(UInt(width = log2Up(nras+1))) + private val pos = Reg(UInt(width = log2Up(nras))) private val stack = Reg(Vec(nras, UInt())) } @@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule { val btb_update = Valid(new BTBUpdate).flip val bht_update = Valid(new BHTUpdate).flip val ras_update = Valid(new RASUpdate).flip - val invalidate = Bool(INPUT) } - val idxValid = Reg(init=UInt(0, entries)) - val idxs = Mem(entries, UInt(width=matchBits)) - val idxPages = Mem(entries, UInt(width=log2Up(nPages))) - val tgts = Mem(entries, UInt(width=matchBits)) - val tgtPages = Mem(entries, UInt(width=log2Up(nPages))) - val pages = Mem(nPages, UInt(width=vaddrBits-matchBits)) - val pageValid = Reg(init=UInt(0, nPages)) + val idxs = Reg(Vec(entries, UInt(width=matchBits))) + val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages)))) + val tgts = Reg(Vec(entries, UInt(width=matchBits))) + val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages)))) + val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits))) val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0)) val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0)) - val useRAS = Reg(Vec(entries, Bool())) - val isJump = Reg(Vec(entries, Bool())) - val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth))) + val useRAS = Reg(UInt(width = entries)) + val isJump = Reg(UInt(width = entries)) + val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) private def page(addr: UInt) = addr >> matchBits private def pageMatch(addr: UInt) = { val p = page(addr) - Vec(pages.map(_ === p)).toBits & pageValid + Vec(pages.map(_ === p)).toBits } private def tagMatch(addr: UInt, pgMatch: UInt) = { - val idx = addr(matchBits-1,0) - val idxMatch = idxs.map(_ === idx).toBits - val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits - idxValid & idxMatch & idxPageMatch + val idxMatch = idxs.map(_ === addr(matchBits-1,0)) + val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR) + (idxPageMatch zip idxMatch) map { case (p, i) => p && i } } val r_btb_update = Pipe(io.btb_update) val update_target = io.req.bits.addr val pageHit = pageMatch(io.req.bits.addr) - val hits = tagMatch(io.req.bits.addr, pageHit) + val hitsVec = tagMatch(io.req.bits.addr, pageHit) + val hits = hitsVec.toBits val updatePageHit = pageMatch(r_btb_update.bits.pc) val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit) val updateHit = r_btb_update.bits.prediction.valid - val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1 + val nextRepl = Reg(UInt(width = log2Ceil(entries))) + when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) } + val nextPageRepl = Reg(UInt(width = log2Ceil(nPages))) val useUpdatePageHit = updatePageHit.orR + val usePageHit = pageHit.orR val doIdxPageRepl = !useUpdatePageHit - val idxPageRepl = Wire(UInt(width = nPages)) - val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl) + val idxPageRepl = UIntToOH(nextPageRepl) + val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, + Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl)) val idxPageUpdate = OHToUInt(idxPageUpdateOH) val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0)) val samePage = page(r_btb_update.bits.pc) === page(update_target) - val usePageHit = (pageHit & ~idxPageReplEn).orR val doTgtPageRepl = !samePage && !usePageHit - val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1)) + val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1))) val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl)) val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0)) - val doPageRepl = doIdxPageRepl || doTgtPageRepl - val pageReplEn = idxPageReplEn | tgtPageReplEn - idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1) + when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) { + val both = doIdxPageRepl && doTgtPageRepl + val next = nextPageRepl + Mux[UInt](both, 2, 1) + nextPageRepl := Mux(next >= nPages, next(0), next) + } when (r_btb_update.valid) { assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target") val waddr = - if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl) + if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl) else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl) - // invalidate entries if we stomp on pages they depend upon - val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits - val validateMask = UIntToOH(waddr) - idxValid := (idxValid & ~invalidateMask) | validateMask - idxs(waddr) := r_btb_update.bits.pc tgts(waddr) := update_target idxPages(waddr) := idxPageUpdate tgtPages(waddr) := tgtPageUpdate - useRAS(waddr) := r_btb_update.bits.isReturn - isJump(waddr) := r_btb_update.bits.isJump + val mask = UIntToOH(waddr) + useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask) + isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask) if (fetchWidth == 1) { brIdx(waddr) := UInt(0) } else { @@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule { } require(nPages % 2 == 0) - val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR + val idxWritesEven = !idxPageUpdate(0) - def writeBank(i: Int, mod: Int, en: Bool, data: UInt) = + def writeBank(i: Int, mod: Int, en: UInt, data: UInt) = for (i <- i until nPages by mod) - when (en && pageReplEn(i)) { pages(i) := data } + when (en(i)) { pages(i) := data } - writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl), + writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn), Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target))) - writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl), + writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn), Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc))) - - when (doPageRepl) { pageValid := pageValid | pageReplEn } - } - - when (io.invalidate) { - idxValid := 0 - pageValid := 0 } io.resp.valid := hits.orR io.resp.bits.taken := io.resp.valid - io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts)) + io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts)) io.resp.bits.entry := OHToUInt(hits) io.resp.bits.bridx := brIdx(io.resp.bits.entry) - if (fetchWidth == 1) { - io.resp.bits.mask := UInt(1) - } else { - // note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case - io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt, - SInt(-1)).toUInt - } + io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt, + SInt(-1)).toUInt if (nBHT > 0) { val bht = new BHT(nBHT) - val isBranch = !Mux1H(hits, isJump) + val isBranch = !(hits & isJump).orR val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) val update_btb_hit = io.bht_update.bits.prediction.valid when (io.bht_update.valid && update_btb_hit) { @@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule { if (nRAS > 0) { val ras = new RAS(nRAS) - val doPeek = Mux1H(hits, useRAS) + val doPeek = (hits & useRAS).orR when (!ras.isEmpty && doPeek) { io.resp.bits.target := ras.peek } @@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule { io.resp.bits.target := io.ras_update.bits.returnAddr } }.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) { - ras.pop + ras.pop() } } - when (io.invalidate) { ras.clear } } } diff --git a/rocket/src/main/scala/frontend.scala b/rocket/src/main/scala/frontend.scala index 4c9d3aaf..83189a94 100644 --- a/rocket/src/main/scala/frontend.scala +++ b/rocket/src/main/scala/frontend.scala @@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa btb.io.btb_update := io.cpu.btb_update btb.io.bht_update := io.cpu.bht_update btb.io.ras_update := io.cpu.ras_update - btb.io.invalidate := false when (!stall && !icmiss) { btb.io.req.valid := true s2_btb_resp_valid := btb.io.resp.valid