diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index c5fb0e78..ea30d668 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -36,15 +36,15 @@ class RAS(nras: Int) { pos := nextPos } def peek: UInt = stack(pos) - def pop: Unit = when (!isEmpty) { + def pop(): Unit = when (!isEmpty) { count := count - 1 pos := Mux(Bool(isPow2(nras)) || pos > 0, pos-1, UInt(nras-1)) } - def clear: Unit = count := UInt(0) + def clear(): Unit = count := UInt(0) def isEmpty: Bool = count === UInt(0) - private val count = Reg(init=UInt(0,log2Up(nras+1))) - private val pos = Reg(init=UInt(0,log2Up(nras))) + private val count = Reg(UInt(width = log2Up(nras+1))) + private val pos = Reg(UInt(width = log2Up(nras))) private val stack = Reg(Vec(nras, UInt())) } @@ -140,82 +140,80 @@ class BTB(implicit p: Parameters) extends BtbModule { val btb_update = Valid(new BTBUpdate).flip val bht_update = Valid(new BHTUpdate).flip val ras_update = Valid(new RASUpdate).flip - val invalidate = Bool(INPUT) } - val idxValid = Reg(init=UInt(0, entries)) - val idxs = Mem(entries, UInt(width=matchBits)) - val idxPages = Mem(entries, UInt(width=log2Up(nPages))) - val tgts = Mem(entries, UInt(width=matchBits)) - val tgtPages = Mem(entries, UInt(width=log2Up(nPages))) - val pages = Mem(nPages, UInt(width=vaddrBits-matchBits)) - val pageValid = Reg(init=UInt(0, nPages)) + val idxs = Reg(Vec(entries, UInt(width=matchBits))) + val idxPages = Reg(Vec(entries, UInt(width=log2Up(nPages)))) + val tgts = Reg(Vec(entries, UInt(width=matchBits))) + val tgtPages = Reg(Vec(entries, UInt(width=log2Up(nPages)))) + val pages = Reg(Vec(nPages, UInt(width=vaddrBits-matchBits))) val idxPagesOH = idxPages.map(UIntToOH(_)(nPages-1,0)) val tgtPagesOH = tgtPages.map(UIntToOH(_)(nPages-1,0)) - val useRAS = Reg(Vec(entries, Bool())) - val isJump = Reg(Vec(entries, Bool())) - val brIdx = Mem(entries, UInt(width=log2Up(fetchWidth))) + val useRAS = Reg(UInt(width = entries)) + val isJump = Reg(UInt(width = entries)) + val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) private def page(addr: UInt) = addr >> matchBits private def pageMatch(addr: UInt) = { val p = page(addr) - Vec(pages.map(_ === p)).toBits & pageValid + Vec(pages.map(_ === p)).toBits } private def tagMatch(addr: UInt, pgMatch: UInt) = { - val idx = addr(matchBits-1,0) - val idxMatch = idxs.map(_ === idx).toBits - val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR).toBits - idxValid & idxMatch & idxPageMatch + val idxMatch = idxs.map(_ === addr(matchBits-1,0)) + val idxPageMatch = idxPagesOH.map(_ & pgMatch).map(_.orR) + (idxPageMatch zip idxMatch) map { case (p, i) => p && i } } val r_btb_update = Pipe(io.btb_update) val update_target = io.req.bits.addr val pageHit = pageMatch(io.req.bits.addr) - val hits = tagMatch(io.req.bits.addr, pageHit) + val hitsVec = tagMatch(io.req.bits.addr, pageHit) + val hits = hitsVec.toBits val updatePageHit = pageMatch(r_btb_update.bits.pc) val updateHits = tagMatch(r_btb_update.bits.pc, updatePageHit) val updateHit = r_btb_update.bits.prediction.valid - val nextRepl = Counter(r_btb_update.valid && !updateHit, entries)._1 + val nextRepl = Reg(UInt(width = log2Ceil(entries))) + when (r_btb_update.valid && !updateHit) { nextRepl := Mux(nextRepl === entries-1 && Bool(!isPow2(entries)), 0, nextRepl + 1) } + val nextPageRepl = Reg(UInt(width = log2Ceil(nPages))) val useUpdatePageHit = updatePageHit.orR + val usePageHit = pageHit.orR val doIdxPageRepl = !useUpdatePageHit - val idxPageRepl = Wire(UInt(width = nPages)) - val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, idxPageRepl) + val idxPageRepl = UIntToOH(nextPageRepl) + val idxPageUpdateOH = Mux(useUpdatePageHit, updatePageHit, + Mux(usePageHit, Cat(pageHit(nPages-2,0), pageHit(nPages-1)), idxPageRepl)) val idxPageUpdate = OHToUInt(idxPageUpdateOH) val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0)) val samePage = page(r_btb_update.bits.pc) === page(update_target) - val usePageHit = (pageHit & ~idxPageReplEn).orR val doTgtPageRepl = !samePage && !usePageHit - val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1)) + val tgtPageRepl = Mux(samePage, idxPageUpdateOH, Cat(idxPageUpdateOH(nPages-2,0), idxPageUpdateOH(nPages-1))) val tgtPageUpdate = OHToUInt(Mux(usePageHit, pageHit, tgtPageRepl)) val tgtPageReplEn = Mux(doTgtPageRepl, tgtPageRepl, UInt(0)) - val doPageRepl = doIdxPageRepl || doTgtPageRepl - val pageReplEn = idxPageReplEn | tgtPageReplEn - idxPageRepl := UIntToOH(Counter(r_btb_update.valid && doPageRepl, nPages)._1) + when (r_btb_update.valid && (doIdxPageRepl || doTgtPageRepl)) { + val both = doIdxPageRepl && doTgtPageRepl + val next = nextPageRepl + Mux[UInt](both, 2, 1) + nextPageRepl := Mux(next >= nPages, next(0), next) + } when (r_btb_update.valid) { assert(io.req.bits.addr === r_btb_update.bits.target, "BTB request != I$ target") val waddr = - if (updatesOutOfOrder) Mux(updateHits.orR, OHToUInt(updateHits), nextRepl) + if (updatesOutOfOrder) Mux(updateHits.reduce(_|_), OHToUInt(updateHits), nextRepl) else Mux(updateHit, r_btb_update.bits.prediction.bits.entry, nextRepl) - // invalidate entries if we stomp on pages they depend upon - val invalidateMask = Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits - val validateMask = UIntToOH(waddr) - idxValid := (idxValid & ~invalidateMask) | validateMask - idxs(waddr) := r_btb_update.bits.pc tgts(waddr) := update_target idxPages(waddr) := idxPageUpdate tgtPages(waddr) := tgtPageUpdate - useRAS(waddr) := r_btb_update.bits.isReturn - isJump(waddr) := r_btb_update.bits.isJump + val mask = UIntToOH(waddr) + useRAS := Mux(r_btb_update.bits.isReturn, useRAS | mask, useRAS & ~mask) + isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask) if (fetchWidth == 1) { brIdx(waddr) := UInt(0) } else { @@ -223,41 +221,29 @@ class BTB(implicit p: Parameters) extends BtbModule { } require(nPages % 2 == 0) - val idxWritesEven = (idxPageUpdateOH & Fill(nPages/2, UInt(1,2))).orR + val idxWritesEven = !idxPageUpdate(0) - def writeBank(i: Int, mod: Int, en: Bool, data: UInt) = + def writeBank(i: Int, mod: Int, en: UInt, data: UInt) = for (i <- i until nPages by mod) - when (en && pageReplEn(i)) { pages(i) := data } + when (en(i)) { pages(i) := data } - writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl), + writeBank(0, 2, Mux(idxWritesEven, idxPageReplEn, tgtPageReplEn), Mux(idxWritesEven, page(r_btb_update.bits.pc), page(update_target))) - writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl), + writeBank(1, 2, Mux(idxWritesEven, tgtPageReplEn, idxPageReplEn), Mux(idxWritesEven, page(update_target), page(r_btb_update.bits.pc))) - - when (doPageRepl) { pageValid := pageValid | pageReplEn } - } - - when (io.invalidate) { - idxValid := 0 - pageValid := 0 } io.resp.valid := hits.orR io.resp.bits.taken := io.resp.valid - io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts)) + io.resp.bits.target := Cat(Mux1H(Mux1H(hitsVec, tgtPagesOH), pages), Mux1H(hitsVec, tgts)) io.resp.bits.entry := OHToUInt(hits) io.resp.bits.bridx := brIdx(io.resp.bits.entry) - if (fetchWidth == 1) { - io.resp.bits.mask := UInt(1) - } else { - // note: btb_resp is clock gated, so the mask is only relevant for the io.resp.valid case - io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt, - SInt(-1)).toUInt - } + io.resp.bits.mask := Mux(io.resp.bits.taken, Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)).toSInt, + SInt(-1)).toUInt if (nBHT > 0) { val bht = new BHT(nBHT) - val isBranch = !Mux1H(hits, isJump) + val isBranch = !(hits & isJump).orR val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) val update_btb_hit = io.bht_update.bits.prediction.valid when (io.bht_update.valid && update_btb_hit) { @@ -269,7 +255,7 @@ class BTB(implicit p: Parameters) extends BtbModule { if (nRAS > 0) { val ras = new RAS(nRAS) - val doPeek = Mux1H(hits, useRAS) + val doPeek = (hits & useRAS).orR when (!ras.isEmpty && doPeek) { io.resp.bits.target := ras.peek } @@ -280,9 +266,8 @@ class BTB(implicit p: Parameters) extends BtbModule { io.resp.bits.target := io.ras_update.bits.returnAddr } }.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) { - ras.pop + ras.pop() } } - when (io.invalidate) { ras.clear } } } diff --git a/rocket/src/main/scala/frontend.scala b/rocket/src/main/scala/frontend.scala index 4c9d3aaf..83189a94 100644 --- a/rocket/src/main/scala/frontend.scala +++ b/rocket/src/main/scala/frontend.scala @@ -81,7 +81,6 @@ class Frontend(implicit p: Parameters) extends CoreModule()(p) with HasL1CachePa btb.io.btb_update := io.cpu.btb_update btb.io.bht_update := io.cpu.bht_update btb.io.ras_update := io.cpu.ras_update - btb.io.invalidate := false when (!stall && !icmiss) { btb.io.req.valid := true s2_btb_resp_valid := btb.io.resp.valid