From 0d3d9fca2530ca82cdf0ec0f4ee1cdde1cf9a73c Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Fri, 29 Jul 2016 15:01:05 -0700 Subject: [PATCH] [rocket] Allow zapping of BTB entries This is necessary to guarantee forward progress with RVC, since if the BTB keeps mispredicting, the processor might never successfully fetch both halves of a misaligned instruction. --- rocket/src/main/scala/btb.scala | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index 9f4add7f..01e5dead 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -91,6 +91,7 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) { val pc = UInt(width = vaddrBits) val target = UInt(width = vaddrBits) val taken = Bool() + val isValid = Bool() val isJump = Bool() val isReturn = Bool() val br_pc = UInt(width = vaddrBits) @@ -151,7 +152,8 @@ class BTB(implicit p: Parameters) extends BtbModule { val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0)) val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0)) - val useRAS = Reg(UInt(width = entries)) + val isValid = Reg(init = UInt(0, entries)) + val isReturn = Reg(UInt(width = entries)) val isJump = Reg(UInt(width = entries)) val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) @@ -163,7 +165,7 @@ class BTB(implicit p: Parameters) extends BtbModule { private def tagMatch(addr: UInt, pgMatch: UInt) = { val idxMatch = idxs.map(_ === addr(matchBits-1, log2Up(coreInstBytes))).toBits val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits - idxMatch & idxPageMatch + idxMatch & idxPageMatch & isValid } val r_btb_update = Pipe(io.btb_update) @@ -177,11 +179,7 @@ class BTB(implicit p: Parameters) extends BtbModule { val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit) val updateHit = if (updatesOutOfOrder) updateHits.orR else r_btb_update.bits.prediction.valid val updateHitAddr = if (updatesOutOfOrder) OHToUInt(updateHits) else r_btb_update.bits.prediction.bits.entry - - // guarantee one-hotness of idx after reset - val resetting = Reg(init = Bool(true)) - val (nextRepl, wrap) = Counter(resetting || (r_btb_update.valid && !updateHit), entries) - when (wrap) { resetting := false } + val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1 val useUpdatePageHit = updatePageHit.orR val usePageHit = pageHit.orR @@ -204,17 +202,15 @@ class BTB(implicit p: Parameters) extends BtbModule { nextPageRepl := Mux(next >= nPages, next(0), next) } - when (r_btb_update.valid || resetting) { - assert(resetting || io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target") - - val waddr = Mux(updateHit && !resetting, updateHitAddr, nextRepl) + when (r_btb_update.valid) { + val waddr = Mux(updateHit, updateHitAddr, nextRepl) val mask = UIntToOH(waddr) - val newIdx = r_btb_update.bits.pc(matchBits-1, log2Up(coreInstBytes)) - idxs(waddr) := Mux(resetting, Cat(newIdx >> log2Ceil(entries), nextRepl), newIdx) + idxs(waddr) := r_btb_update.bits.pc(matchBits-1, log2Up(coreInstBytes)) tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes)) idxPages(waddr) := idxPageUpdate tgtPages(waddr) := tgtPageUpdate - useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask) + isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask) + isReturn := Mux(r_btb_update.bits.isReturn, isReturn | mask, isReturn & ~mask) isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask) if (fetchWidth > 1) brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes) @@ -234,7 +230,7 @@ class BTB(implicit p: Parameters) extends BtbModule { } io.resp.valid := hits.orR - io.resp.bits.taken := io.resp.valid + io.resp.bits.taken := true io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts) << log2Up(coreInstBytes)) io.resp.bits.entry := OHToUInt(hits) io.resp.bits.bridx := Mux1H(hitsVec, brIdx) @@ -254,7 +250,7 @@ class BTB(implicit p: Parameters) extends BtbModule { if (nRAS > 0) { val ras = new RAS(nRAS) - val doPeek = (hits & useRAS).orR + val doPeek = (hits & isReturn).orR when (!ras.isEmpty && doPeek) { io.resp.bits.target := ras.peek }