diff --git a/src/main/scala/rocket/BTB.scala b/src/main/scala/rocket/BTB.scala index 15cb5d24..1d9491bc 100644 --- a/src/main/scala/rocket/BTB.scala +++ b/src/main/scala/rocket/BTB.scala @@ -86,6 +86,15 @@ class BHT(nbht: Int)(implicit val p: Parameters) extends HasCoreParameters { val history = Reg(UInt(width = nbhtbits)) } +object CFIType { + def SZ = 2 + def apply() = UInt(width = SZ) + def branch = 0.U + def jump = 1.U + def call = 2.U + def ret = 3.U +} + // BTB update occurs during branch resolution (and only on a mispredict). // - "pc" is what future fetch PCs will tag match against. // - "br_pc" is the PC of the branch instruction. @@ -95,9 +104,8 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) { val target = UInt(width = vaddrBits) val taken = Bool() val isValid = Bool() - val isJump = Bool() - val isReturn = Bool() val br_pc = UInt(width = vaddrBits) + val cfiType = CFIType() } // BHT update occurs during branch resolution on all conditional branches. @@ -110,8 +118,7 @@ class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) { } class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) { - val isCall = Bool() - val isReturn = Bool() + val cfiType = CFIType() val returnAddr = UInt(width = vaddrBits) val prediction = Valid(new BTBResp) } @@ -121,6 +128,7 @@ class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) { // - "mask" provides a mask of valid instructions (instructions are // masked off by the predicted taken branch from the BTB). class BTBResp(implicit p: Parameters) extends BtbBundle()(p) { + val cfiType = CFIType() val taken = Bool() val mask = Bits(width = fetchWidth) val bridx = Bits(width = log2Up(fetchWidth)) @@ -154,8 +162,7 @@ class BTB(implicit p: Parameters) extends BtbModule { val pageValid = Reg(init = UInt(0, nPages)) val isValid = Reg(init = UInt(0, entries)) - val isReturn = Reg(UInt(width = entries)) - val isJump = Reg(UInt(width = entries)) + val cfiType = Reg(Vec(entries, CFIType())) val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) private def page(addr: UInt) = addr >> matchBits @@ -210,9 +217,8 @@ class BTB(implicit p: Parameters) extends BtbModule { tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes)) idxPages(waddr) := idxPageUpdate +& 1 // the +1 corresponds to the <<1 on io.resp.valid tgtPages(waddr) := tgtPageUpdate + cfiType(waddr) := r_btb_update.bits.cfiType isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask) - isReturn := Mux(r_btb_update.bits.isReturn, isReturn | mask, isReturn & ~mask) - isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask) if (fetchWidth > 1) brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes) @@ -236,6 +242,7 @@ class BTB(implicit p: Parameters) extends BtbModule { io.resp.bits.entry := OHToUInt(idxHit) io.resp.bits.bridx := (if (fetchWidth > 1) Mux1H(idxHit, brIdx) else UInt(0)) io.resp.bits.mask := Cat((UInt(1) << ~Mux(io.resp.bits.taken, ~io.resp.bits.bridx, UInt(0)))-1, UInt(1)) + io.resp.bits.cfiType := Mux1H(idxHit, cfiType) // if multiple entries for same PC land in BTB, zap them when (PopCountAtLeast(idxHit, 2)) { @@ -244,7 +251,7 @@ class BTB(implicit p: Parameters) extends BtbModule { if (nBHT > 0) { val bht = new BHT(nBHT) - val isBranch = !(idxHit & isJump).orR + val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) val update_btb_hit = io.bht_update.bits.prediction.valid when (io.bht_update.valid && update_btb_hit) { @@ -256,17 +263,14 @@ class BTB(implicit p: Parameters) extends BtbModule { if (nRAS > 0) { val ras = new RAS(nRAS) - val doPeek = (idxHit & isReturn).orR + val doPeek = (idxHit & cfiType.map(_ === CFIType.ret).asUInt).orR when (!ras.isEmpty && doPeek) { io.resp.bits.target := ras.peek } when (io.ras_update.valid) { - when (io.ras_update.bits.isCall) { + when (io.ras_update.bits.cfiType === CFIType.call) { ras.push(io.ras_update.bits.returnAddr) - when (doPeek) { - io.resp.bits.target := io.ras_update.bits.returnAddr - } - }.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) { + }.elsewhen (io.ras_update.bits.cfiType === CFIType.ret && io.ras_update.bits.prediction.valid) { ras.pop() } } diff --git a/src/main/scala/rocket/Frontend.scala b/src/main/scala/rocket/Frontend.scala index 126f2b25..27a02fd3 100644 --- a/src/main/scala/rocket/Frontend.scala +++ b/src/main/scala/rocket/Frontend.scala @@ -84,7 +84,9 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) val s2_speculative = Reg(init=Bool(false)) val s2_cacheable = Reg(init=Bool(false)) - val ntpc = ~(~s1_pc | (coreInstBytes*fetchWidth-1)) + UInt(coreInstBytes*fetchWidth) + val fetchBytes = coreInstBytes * fetchWidth + val s1_base_pc = ~(~s1_pc | (fetchBytes - 1)) + val ntpc = s1_base_pc + fetchBytes.U val predicted_npc = Wire(init = ntpc) val predicted_taken = Wire(init = Bool(false)) @@ -129,6 +131,14 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) predicted_npc := btb.io.resp.bits.target.sextTo(vaddrBitsExtended) predicted_taken := Bool(true) } + + // push RAS speculatively + btb.io.ras_update.valid := btb.io.req.valid && btb.io.resp.valid && btb.io.resp.bits.cfiType.isOneOf(CFIType.call, CFIType.ret) + val returnAddrLSBs = btb.io.resp.bits.bridx +& 1 + btb.io.ras_update.bits.returnAddr := + Mux(returnAddrLSBs(log2Ceil(fetchWidth)), ntpc, s1_base_pc | ((returnAddrLSBs << log2Ceil(coreInstBytes)) & (fetchBytes - 1))) + btb.io.ras_update.bits.cfiType := btb.io.resp.bits.cfiType + btb.io.ras_update.bits.prediction.valid := true } io.ptw <> tlb.io.ptw diff --git a/src/main/scala/rocket/Rocket.scala b/src/main/scala/rocket/Rocket.scala index b3c92cfe..e7005050 100644 --- a/src/main/scala/rocket/Rocket.scala +++ b/src/main/scala/rocket/Rocket.scala @@ -587,8 +587,10 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) io.imem.btb_update.valid := (mem_reg_replay && mem_reg_btb_hit) || (mem_reg_valid && !take_pc_wb && (((mem_cfi_taken || !mem_cfi) && mem_wrong_npc) || (Bool(fastJAL) && mem_ctrl.jal && !mem_reg_btb_hit))) io.imem.btb_update.bits.isValid := !mem_reg_replay && mem_cfi - io.imem.btb_update.bits.isJump := mem_ctrl.jal || mem_ctrl.jalr - io.imem.btb_update.bits.isReturn := mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01") + io.imem.btb_update.bits.cfiType := + Mux(mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01"), CFIType.ret, + Mux(mem_ctrl.jal || mem_ctrl.jalr, Mux(mem_waddr(0), CFIType.call, CFIType.jump), + CFIType.branch)) io.imem.btb_update.bits.target := io.imem.req.bits.pc io.imem.btb_update.bits.br_pc := (if (usingCompressed) mem_reg_pc + Mux(mem_reg_rvc, UInt(0), UInt(2)) else mem_reg_pc) io.imem.btb_update.bits.pc := ~(~io.imem.btb_update.bits.br_pc | (coreInstBytes*fetchWidth-1)) @@ -601,12 +603,6 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) io.imem.bht_update.bits.mispredict := mem_wrong_npc io.imem.bht_update.bits.prediction := io.imem.btb_update.bits.prediction - io.imem.ras_update.valid := mem_reg_valid && !take_pc_wb - io.imem.ras_update.bits.returnAddr := mem_int_wdata - io.imem.ras_update.bits.isCall := io.imem.btb_update.bits.isJump && mem_waddr(0) - io.imem.ras_update.bits.isReturn := io.imem.btb_update.bits.isReturn - io.imem.ras_update.bits.prediction := io.imem.btb_update.bits.prediction - io.fpu.valid := !ctrl_killd && id_ctrl.fp io.fpu.killx := ctrl_killx io.fpu.killm := killm_common