diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index 02ca111e..f6de407e 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -41,7 +41,9 @@ class RAS(nras: Int) { } class BHTResp extends Bundle with BTBParameters { + // TODO only carry history, not both index and history val index = UInt(width = log2Up(nBHT).max(1)) + val history = UInt(width = log2Up(nBHT).max(1)) val value = UInt(width = 2) } @@ -50,12 +52,19 @@ class BHT(nbht: Int) { def get(addr: UInt): BHTResp = { val res = new BHTResp res.index := addr(nbhtbits+1,2) ^ history + res.history := history res.value := table(res.index) + // TODO we actually want to include the final prediction result from the BTB + val taken = res.value(0) + // TODO only update history on an actual instruction fetch + history := Cat(taken, history(nbhtbits-1,1)) res } - def update(d: BHTResp, taken: Bool): Unit = { + def update(d: BHTResp, taken: Bool, mispredict: Bool): Unit = { table(d.index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken)) - history := Cat(taken, history(nbhtbits-1,1)) + when (mispredict) { + history := Cat(taken, d.history(nbhtbits-1,1)) + } } private val table = Mem(UInt(width = 2), nbht) @@ -197,7 +206,7 @@ class BTB extends Module with BTBParameters { if (nBHT > 0) { val bht = new BHT(nBHT) val res = bht.get(io.req) - when (update.valid && updateHit && !update.bits.isJump) { bht.update(update.bits.prediction.bits.bht, update.bits.taken) } + when (update.valid && updateHit && !update.bits.isJump) { bht.update(update.bits.prediction.bits.bht, update.bits.taken, update.bits.incorrectTarget) } when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false } io.resp.bits.bht := res }