diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index 02ca111e..0f1ee962 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -8,7 +8,7 @@ import Node._ import uncore._ case object NBTBEntries extends Field[Int] -case object NRAS extends Field[Int] +case object NRAS extends Field[Int] abstract trait BTBParameters extends UsesParameters { val vaddrBits = params(VAddrBits) @@ -41,21 +41,25 @@ class RAS(nras: Int) { } class BHTResp extends Bundle with BTBParameters { - val index = UInt(width = log2Up(nBHT).max(1)) + val history = UInt(width = log2Up(nBHT).max(1)) val value = UInt(width = 2) } class BHT(nbht: Int) { - val nbhtbits = log2Up(nbht) - def get(addr: UInt): BHTResp = { + val nbhtbits = log2Up(nbht) + def get(addr: UInt, update: Bool): BHTResp = { val res = new BHTResp - res.index := addr(nbhtbits+1,2) ^ history - res.value := table(res.index) + val index = addr(nbhtbits+1,2) ^ history + res.value := table(index) + res.history := history + val taken = res.value(0) + when (update) { history := Cat(taken, history(nbhtbits-1,1)) } res } - def update(d: BHTResp, taken: Bool): Unit = { - table(d.index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken)) - history := Cat(taken, history(nbhtbits-1,1)) + def update(addr: UInt, d: BHTResp, taken: Bool, mispredict: Bool): Unit = { + val index = addr(nbhtbits+1,2) ^ d.history + table(index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken)) + when (mispredict) { history := Cat(taken, d.history(nbhtbits-1,1)) } } private val table = Mem(UInt(width = 2), nbht) @@ -81,10 +85,14 @@ class BTBResp extends Bundle with BTBParameters { val bht = new BHTResp } +class BTBReq extends Bundle with BTBParameters { + val addr = UInt(width = vaddrBits) +} + // fully-associative branch target buffer class BTB extends Module with BTBParameters { val io = new Bundle { - val req = UInt(INPUT, vaddrBits) + val req = Valid(new BTBReq).flip val resp = Valid(new BTBResp) val update = Valid(new BTBUpdate).flip val invalidate = Bool(INPUT) @@ -115,23 +123,23 @@ class BTB extends Module with BTBParameters { idxValid & idxMatch & idxPageMatch } - val update = Pipe(io.update) - val update_target = io.req + val r_update = Pipe(io.update) + val update_target = io.req.bits.addr - val pageHit = pageMatch(io.req) - val hits = tagMatch(io.req, pageHit) - val updatePageHit = pageMatch(update.bits.pc) - val updateHits = tagMatch(update.bits.pc, updatePageHit) + val pageHit = pageMatch(io.req.bits.addr) + val hits = tagMatch(io.req.bits.addr, pageHit) + val updatePageHit = pageMatch(r_update.bits.pc) + val updateHits = tagMatch(r_update.bits.pc, updatePageHit) - private var lfsr = LFSR16(update.valid) + private var lfsr = LFSR16(r_update.valid) def rand(width: Int) = { lfsr = lfsr(lfsr.getWidth-1,1) Random.oneHot(width, lfsr) } - val updateHit = update.bits.prediction.valid - val updateValid = update.bits.incorrectTarget || updateHit && Bool(nBHT > 0) - val updateTarget = updateValid && update.bits.incorrectTarget + val updateHit = r_update.bits.prediction.valid + val updateValid = r_update.bits.incorrectTarget || updateHit && Bool(nBHT > 0) + val updateTarget = updateValid && r_update.bits.incorrectTarget val useUpdatePageHit = updatePageHit.orR val doIdxPageRepl = updateTarget && !useUpdatePageHit @@ -140,7 +148,7 @@ class BTB extends Module with BTBParameters { val idxPageUpdate = OHToUInt(idxPageUpdateOH) val idxPageReplEn = Mux(doIdxPageRepl, idxPageRepl, UInt(0)) - val samePage = page(update.bits.pc) === page(update_target) + val samePage = page(r_update.bits.pc) === page(update_target) val usePageHit = (pageHit & ~idxPageReplEn).orR val doTgtPageRepl = updateTarget && !samePage && !usePageHit val tgtPageRepl = Mux(samePage, idxPageUpdateOH, idxPageUpdateOH(nPages-2,0) << 1 | idxPageUpdateOH(nPages-1)) @@ -149,24 +157,24 @@ class BTB extends Module with BTBParameters { val doPageRepl = doIdxPageRepl || doTgtPageRepl val pageReplEn = idxPageReplEn | tgtPageReplEn - idxPageRepl := UIntToOH(Counter(update.valid && doPageRepl, nPages)._1) + idxPageRepl := UIntToOH(Counter(r_update.valid && doPageRepl, nPages)._1) - when (update.valid && !(updateValid && !updateTarget)) { + when (r_update.valid && !(updateValid && !updateTarget)) { val nextRepl = Counter(!updateHit && updateValid, entries)._1 - val waddr = Mux(updateHit, update.bits.prediction.bits.entry, nextRepl) + val waddr = Mux(updateHit, r_update.bits.prediction.bits.entry, nextRepl) // invalidate entries if we stomp on pages they depend upon idxValid := idxValid & ~Vec.tabulate(entries)(i => (pageReplEn & (idxPagesOH(i) | tgtPagesOH(i))).orR).toBits idxValid(waddr) := updateValid when (updateTarget) { - assert(io.req === update.bits.target, "BTB request != I$ target") - idxs(waddr) := update.bits.pc + assert(io.req.bits.addr === r_update.bits.target, "BTB request != I$ target") + idxs(waddr) := r_update.bits.pc tgts(waddr) := update_target idxPages(waddr) := idxPageUpdate tgtPages(waddr) := tgtPageUpdate - useRAS(waddr) := update.bits.isReturn - isJump(waddr) := update.bits.isJump + useRAS(waddr) := r_update.bits.isReturn + isJump(waddr) := r_update.bits.isJump } require(nPages % 2 == 0) @@ -177,9 +185,9 @@ class BTB extends Module with BTBParameters { when (en && pageReplEn(i)) { pages(i) := data } writeBank(0, 2, Mux(idxWritesEven, doIdxPageRepl, doTgtPageRepl), - Mux(idxWritesEven, page(update.bits.pc), page(update_target))) + Mux(idxWritesEven, page(r_update.bits.pc), page(update_target))) writeBank(1, 2, Mux(idxWritesEven, doTgtPageRepl, doIdxPageRepl), - Mux(idxWritesEven, page(update_target), page(update.bits.pc))) + Mux(idxWritesEven, page(update_target), page(r_update.bits.pc))) when (doPageRepl) { pageValid := pageValid | pageReplEn } } @@ -196,8 +204,12 @@ class BTB extends Module with BTBParameters { if (nBHT > 0) { val bht = new BHT(nBHT) - val res = bht.get(io.req) - when (update.valid && updateHit && !update.bits.isJump) { bht.update(update.bits.prediction.bits.bht, update.bits.taken) } + val res = bht.get(io.req.bits.addr, io.req.valid && hits.orR && !Mux1H(hits, isJump)) + val update_btb_hit = io.update.bits.prediction.valid + when (io.update.valid && update_btb_hit && !io.update.bits.isJump) { + bht.update(io.update.bits.pc, io.update.bits.prediction.bits.bht, + io.update.bits.taken, io.update.bits.incorrectTarget) + } when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false } io.resp.bits.bht := res } diff --git a/rocket/src/main/scala/icache.scala b/rocket/src/main/scala/icache.scala index fb2dbb0c..45327e46 100644 --- a/rocket/src/main/scala/icache.scala +++ b/rocket/src/main/scala/icache.scala @@ -44,7 +44,7 @@ class Frontend extends FrontendModule val cpu = new CPUFrontendIO().flip val mem = new UncachedTileLinkIO } - + val btb = Module(new BTB) val icache = Module(new ICache) val tlb = Module(new TLB(params(NITLBEntries))) @@ -85,7 +85,8 @@ class Frontend extends FrontendModule s2_valid := Bool(false) } - btb.io.req := s1_pc & SInt(-coreInstBytes) + btb.io.req.valid := !stall && !icmiss + btb.io.req.bits.addr := s1_pc & SInt(-coreInstBytes) btb.io.update := io.cpu.btb_update btb.io.invalidate := io.cpu.invalidate || io.cpu.ptw.invalidate