diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index 81a6d459..94f7a6ac 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -51,45 +51,34 @@ class BHTResp extends Bundle with BTBParameters { // - updated speculatively in fetch (if there's a BTB hit). // - on a mispredict, the history register is reset (again, only if BTB hit). // The counter table: -// - each PC has its own counter, updated when a branch resolves (and BTB hit). -// - the BTB provides the predicted branch PC, allowing us to properly index -// the counter table and provide the prediction for that specific branch. -// Critical path concerns may require only providing a single counter for -// the entire fetch packet, but that complicates how multiple branches -// update that line. -class BHT(nbht: Int, fetchwidth: Int) { +// - each counter corresponds with the "fetch pc" (not the PC of the branch). +// - updated when a branch resolves (and BTB was a hit for that branch). +// The updating branch must provide its "fetch pc" in addition to its own PC. +class BHT(nbht: Int) { val nbhtbits = log2Up(nbht) - private val logfw = if (fetchwidth == 1) 0 else log2Up(fetchwidth) - - def get(fetch_addr: UInt, bridx: UInt, update: Bool): BHTResp = { + def get(addr: UInt, bridx: UInt, update: Bool): BHTResp = { val res = new BHTResp - val aligned_addr = fetch_addr >> UInt(logfw + 2) - val index = aligned_addr ^ history - val counters = table(index) - res.value := (counters >> (bridx<<1)) & Bits(0x3) + val index = addr(nbhtbits+1,2) ^ history + res.value := table(index) res.history := history val taken = res.value(0) when (update) { history := Cat(taken, history(nbhtbits-1,1)) } res } def update(addr: UInt, d: BHTResp, taken: Bool, mispredict: Bool): Unit = { - val aligned_addr = addr >> UInt(logfw + 2) - val index = aligned_addr ^ d.history - val new_cntr = Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken)) - var bridx: UInt = null - if (logfw == 0) bridx = UInt(0) else bridx = addr(logfw+1,2) - val mask = Bits(0x3) << (bridx<<1) - table.write(index, new_cntr, mask) + val index = addr(nbhtbits+1,2) ^ d.history + table(index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken)) when (mispredict) { history := Cat(taken, d.history(nbhtbits-1,1)) } } - private val table = Mem(UInt(width = 2*fetchwidth), nbht) + private val table = Mem(UInt(width = 2), nbht) val history = Reg(UInt(width = nbhtbits)) } // BTB update occurs during branch resolution. // - "pc" is what future fetch PCs will tag match against. // - "br_pc" is the PC of the branch instruction. +// - "bridx" is the low-order PC bits of the predicted branch. // - "resp.mask" provides a mask of valid instructions (instructions are // masked off by the predicted taken branch). class BTBUpdate extends Bundle with BTBParameters { @@ -107,7 +96,8 @@ class BTBUpdate extends Bundle with BTBParameters { class BTBResp extends Bundle with BTBParameters { val taken = Bool() - val mask = Bits(width = log2Up(params(FetchWidth))) + val mask = Bits(width = params(FetchWidth)) + val bridx = Bits(width = log2Up(params(FetchWidth))) val target = UInt(width = vaddrBits) val entry = UInt(width = opaqueBits) val bht = new BHTResp @@ -232,13 +222,14 @@ class BTB extends Module with BTBParameters { io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts)) io.resp.bits.entry := OHToUInt(hits) io.resp.bits.mask := Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)) + io.resp.bits.bridx := brIdx(io.resp.bits.entry) if (nBHT > 0) { - val bht = new BHT(nBHT, params(FetchWidth)) + val bht = new BHT(nBHT) val res = bht.get(io.req.bits.addr, brIdx(io.resp.bits.entry), io.req.valid && hits.orR && !Mux1H(hits, isJump)) val update_btb_hit = io.update.bits.prediction.valid when (io.update.valid && update_btb_hit && !io.update.bits.isJump) { - bht.update(io.update.bits.br_pc, io.update.bits.prediction.bits.bht, + bht.update(io.update.bits.pc, io.update.bits.prediction.bits.bht, io.update.bits.taken, io.update.bits.incorrectTarget) } when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false }