Update RAS speculatively from fetch stage
This commit is contained in:
		
				
					committed by
					
						 Andrew Waterman
						Andrew Waterman
					
				
			
			
				
	
			
			
			
						parent
						
							3b2c15b648
						
					
				
				
					commit
					f2d4cb8152
				
			| @@ -86,6 +86,15 @@ class BHT(nbht: Int)(implicit val p: Parameters) extends HasCoreParameters { | |||||||
|   val history = Reg(UInt(width = nbhtbits)) |   val history = Reg(UInt(width = nbhtbits)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | object CFIType { | ||||||
|  |   def SZ = 2 | ||||||
|  |   def apply() = UInt(width = SZ) | ||||||
|  |   def branch = 0.U | ||||||
|  |   def jump = 1.U | ||||||
|  |   def call = 2.U | ||||||
|  |   def ret = 3.U | ||||||
|  | } | ||||||
|  |  | ||||||
| // BTB update occurs during branch resolution (and only on a mispredict). | // BTB update occurs during branch resolution (and only on a mispredict). | ||||||
| //  - "pc" is what future fetch PCs will tag match against. | //  - "pc" is what future fetch PCs will tag match against. | ||||||
| //  - "br_pc" is the PC of the branch instruction. | //  - "br_pc" is the PC of the branch instruction. | ||||||
| @@ -95,9 +104,8 @@ class BTBUpdate(implicit p: Parameters) extends BtbBundle()(p) { | |||||||
|   val target = UInt(width = vaddrBits) |   val target = UInt(width = vaddrBits) | ||||||
|   val taken = Bool() |   val taken = Bool() | ||||||
|   val isValid = Bool() |   val isValid = Bool() | ||||||
|   val isJump = Bool() |  | ||||||
|   val isReturn = Bool() |  | ||||||
|   val br_pc = UInt(width = vaddrBits) |   val br_pc = UInt(width = vaddrBits) | ||||||
|  |   val cfiType = CFIType() | ||||||
| } | } | ||||||
|  |  | ||||||
| // BHT update occurs during branch resolution on all conditional branches. | // BHT update occurs during branch resolution on all conditional branches. | ||||||
| @@ -110,8 +118,7 @@ class BHTUpdate(implicit p: Parameters) extends BtbBundle()(p) { | |||||||
| } | } | ||||||
|  |  | ||||||
| class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) { | class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) { | ||||||
|   val isCall = Bool() |   val cfiType = CFIType() | ||||||
|   val isReturn = Bool() |  | ||||||
|   val returnAddr = UInt(width = vaddrBits) |   val returnAddr = UInt(width = vaddrBits) | ||||||
|   val prediction = Valid(new BTBResp) |   val prediction = Valid(new BTBResp) | ||||||
| } | } | ||||||
| @@ -121,6 +128,7 @@ class RASUpdate(implicit p: Parameters) extends BtbBundle()(p) { | |||||||
| //  - "mask" provides a mask of valid instructions (instructions are | //  - "mask" provides a mask of valid instructions (instructions are | ||||||
| //     masked off by the predicted taken branch from the BTB). | //     masked off by the predicted taken branch from the BTB). | ||||||
| class BTBResp(implicit p: Parameters) extends BtbBundle()(p) { | class BTBResp(implicit p: Parameters) extends BtbBundle()(p) { | ||||||
|  |   val cfiType = CFIType() | ||||||
|   val taken = Bool() |   val taken = Bool() | ||||||
|   val mask = Bits(width = fetchWidth) |   val mask = Bits(width = fetchWidth) | ||||||
|   val bridx = Bits(width = log2Up(fetchWidth)) |   val bridx = Bits(width = log2Up(fetchWidth)) | ||||||
| @@ -154,8 +162,7 @@ class BTB(implicit p: Parameters) extends BtbModule { | |||||||
|   val pageValid = Reg(init = UInt(0, nPages)) |   val pageValid = Reg(init = UInt(0, nPages)) | ||||||
|  |  | ||||||
|   val isValid = Reg(init = UInt(0, entries)) |   val isValid = Reg(init = UInt(0, entries)) | ||||||
|   val isReturn = Reg(UInt(width = entries)) |   val cfiType = Reg(Vec(entries, CFIType())) | ||||||
|   val isJump = Reg(UInt(width = entries)) |  | ||||||
|   val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) |   val brIdx = Reg(Vec(entries, UInt(width=log2Up(fetchWidth)))) | ||||||
|  |  | ||||||
|   private def page(addr: UInt) = addr >> matchBits |   private def page(addr: UInt) = addr >> matchBits | ||||||
| @@ -210,9 +217,8 @@ class BTB(implicit p: Parameters) extends BtbModule { | |||||||
|     tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes)) |     tgts(waddr) := update_target(matchBits-1, log2Up(coreInstBytes)) | ||||||
|     idxPages(waddr) := idxPageUpdate +& 1 // the +1 corresponds to the <<1 on io.resp.valid |     idxPages(waddr) := idxPageUpdate +& 1 // the +1 corresponds to the <<1 on io.resp.valid | ||||||
|     tgtPages(waddr) := tgtPageUpdate |     tgtPages(waddr) := tgtPageUpdate | ||||||
|  |     cfiType(waddr) := r_btb_update.bits.cfiType | ||||||
|     isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask) |     isValid := Mux(r_btb_update.bits.isValid, isValid | mask, isValid & ~mask) | ||||||
|     isReturn := Mux(r_btb_update.bits.isReturn, isReturn | mask, isReturn & ~mask) |  | ||||||
|     isJump := Mux(r_btb_update.bits.isJump, isJump | mask, isJump & ~mask) |  | ||||||
|     if (fetchWidth > 1) |     if (fetchWidth > 1) | ||||||
|       brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes) |       brIdx(waddr) := r_btb_update.bits.br_pc >> log2Up(coreInstBytes) | ||||||
|  |  | ||||||
| @@ -236,6 +242,7 @@ class BTB(implicit p: Parameters) extends BtbModule { | |||||||
|   io.resp.bits.entry := OHToUInt(idxHit) |   io.resp.bits.entry := OHToUInt(idxHit) | ||||||
|   io.resp.bits.bridx := (if (fetchWidth > 1) Mux1H(idxHit, brIdx) else UInt(0)) |   io.resp.bits.bridx := (if (fetchWidth > 1) Mux1H(idxHit, brIdx) else UInt(0)) | ||||||
|   io.resp.bits.mask := Cat((UInt(1) << ~Mux(io.resp.bits.taken, ~io.resp.bits.bridx, UInt(0)))-1, UInt(1)) |   io.resp.bits.mask := Cat((UInt(1) << ~Mux(io.resp.bits.taken, ~io.resp.bits.bridx, UInt(0)))-1, UInt(1)) | ||||||
|  |   io.resp.bits.cfiType := Mux1H(idxHit, cfiType) | ||||||
|  |  | ||||||
|   // if multiple entries for same PC land in BTB, zap them |   // if multiple entries for same PC land in BTB, zap them | ||||||
|   when (PopCountAtLeast(idxHit, 2)) { |   when (PopCountAtLeast(idxHit, 2)) { | ||||||
| @@ -244,7 +251,7 @@ class BTB(implicit p: Parameters) extends BtbModule { | |||||||
|  |  | ||||||
|   if (nBHT > 0) { |   if (nBHT > 0) { | ||||||
|     val bht = new BHT(nBHT) |     val bht = new BHT(nBHT) | ||||||
|     val isBranch = !(idxHit & isJump).orR |     val isBranch = (idxHit & cfiType.map(_ === CFIType.branch).asUInt).orR | ||||||
|     val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) |     val res = bht.get(io.req.bits.addr, io.req.valid && io.resp.valid && isBranch) | ||||||
|     val update_btb_hit = io.bht_update.bits.prediction.valid |     val update_btb_hit = io.bht_update.bits.prediction.valid | ||||||
|     when (io.bht_update.valid && update_btb_hit) { |     when (io.bht_update.valid && update_btb_hit) { | ||||||
| @@ -256,17 +263,14 @@ class BTB(implicit p: Parameters) extends BtbModule { | |||||||
|  |  | ||||||
|   if (nRAS > 0) { |   if (nRAS > 0) { | ||||||
|     val ras = new RAS(nRAS) |     val ras = new RAS(nRAS) | ||||||
|     val doPeek = (idxHit & isReturn).orR |     val doPeek = (idxHit & cfiType.map(_ === CFIType.ret).asUInt).orR | ||||||
|     when (!ras.isEmpty && doPeek) { |     when (!ras.isEmpty && doPeek) { | ||||||
|       io.resp.bits.target := ras.peek |       io.resp.bits.target := ras.peek | ||||||
|     } |     } | ||||||
|     when (io.ras_update.valid) { |     when (io.ras_update.valid) { | ||||||
|       when (io.ras_update.bits.isCall) { |       when (io.ras_update.bits.cfiType === CFIType.call) { | ||||||
|         ras.push(io.ras_update.bits.returnAddr) |         ras.push(io.ras_update.bits.returnAddr) | ||||||
|         when (doPeek) { |       }.elsewhen (io.ras_update.bits.cfiType === CFIType.ret && io.ras_update.bits.prediction.valid) { | ||||||
|           io.resp.bits.target := io.ras_update.bits.returnAddr |  | ||||||
|         } |  | ||||||
|       }.elsewhen (io.ras_update.bits.isReturn && io.ras_update.bits.prediction.valid) { |  | ||||||
|         ras.pop() |         ras.pop() | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   | |||||||
| @@ -84,7 +84,9 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) | |||||||
|   val s2_speculative = Reg(init=Bool(false)) |   val s2_speculative = Reg(init=Bool(false)) | ||||||
|   val s2_cacheable = Reg(init=Bool(false)) |   val s2_cacheable = Reg(init=Bool(false)) | ||||||
|  |  | ||||||
|   val ntpc = ~(~s1_pc | (coreInstBytes*fetchWidth-1)) + UInt(coreInstBytes*fetchWidth) |   val fetchBytes = coreInstBytes * fetchWidth | ||||||
|  |   val s1_base_pc = ~(~s1_pc | (fetchBytes - 1)) | ||||||
|  |   val ntpc = s1_base_pc + fetchBytes.U | ||||||
|   val predicted_npc = Wire(init = ntpc) |   val predicted_npc = Wire(init = ntpc) | ||||||
|   val predicted_taken = Wire(init = Bool(false)) |   val predicted_taken = Wire(init = Bool(false)) | ||||||
|  |  | ||||||
| @@ -129,6 +131,14 @@ class FrontendModule(outer: Frontend) extends LazyModuleImp(outer) | |||||||
|       predicted_npc := btb.io.resp.bits.target.sextTo(vaddrBitsExtended) |       predicted_npc := btb.io.resp.bits.target.sextTo(vaddrBitsExtended) | ||||||
|       predicted_taken := Bool(true) |       predicted_taken := Bool(true) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // update RAS speculatively: push on BTB-predicted calls, pop on BTB-predicted returns | ||||||
|  |     btb.io.ras_update.valid := btb.io.req.valid && btb.io.resp.valid && btb.io.resp.bits.cfiType.isOneOf(CFIType.call, CFIType.ret) | ||||||
|  |     val returnAddrLSBs = btb.io.resp.bits.bridx +& 1 | ||||||
|  |     btb.io.ras_update.bits.returnAddr := | ||||||
|  |       Mux(returnAddrLSBs(log2Ceil(fetchWidth)), ntpc, s1_base_pc | ((returnAddrLSBs << log2Ceil(coreInstBytes)) & (fetchBytes - 1))) | ||||||
|  |     btb.io.ras_update.bits.cfiType := btb.io.resp.bits.cfiType | ||||||
|  |     btb.io.ras_update.bits.prediction.valid := true | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   io.ptw <> tlb.io.ptw |   io.ptw <> tlb.io.ptw | ||||||
|   | |||||||
| @@ -587,8 +587,10 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) | |||||||
|  |  | ||||||
|   io.imem.btb_update.valid := (mem_reg_replay && mem_reg_btb_hit) || (mem_reg_valid && !take_pc_wb && (((mem_cfi_taken || !mem_cfi) && mem_wrong_npc) || (Bool(fastJAL) && mem_ctrl.jal && !mem_reg_btb_hit))) |   io.imem.btb_update.valid := (mem_reg_replay && mem_reg_btb_hit) || (mem_reg_valid && !take_pc_wb && (((mem_cfi_taken || !mem_cfi) && mem_wrong_npc) || (Bool(fastJAL) && mem_ctrl.jal && !mem_reg_btb_hit))) | ||||||
|   io.imem.btb_update.bits.isValid := !mem_reg_replay && mem_cfi |   io.imem.btb_update.bits.isValid := !mem_reg_replay && mem_cfi | ||||||
|   io.imem.btb_update.bits.isJump := mem_ctrl.jal || mem_ctrl.jalr |   io.imem.btb_update.bits.cfiType := | ||||||
|   io.imem.btb_update.bits.isReturn := mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01") |     Mux(mem_ctrl.jalr && mem_reg_inst(19,15) === BitPat("b00?01"), CFIType.ret, | ||||||
|  |     Mux(mem_ctrl.jal || mem_ctrl.jalr, Mux(mem_waddr(0), CFIType.call, CFIType.jump), | ||||||
|  |     CFIType.branch)) | ||||||
|   io.imem.btb_update.bits.target := io.imem.req.bits.pc |   io.imem.btb_update.bits.target := io.imem.req.bits.pc | ||||||
|   io.imem.btb_update.bits.br_pc := (if (usingCompressed) mem_reg_pc + Mux(mem_reg_rvc, UInt(0), UInt(2)) else mem_reg_pc) |   io.imem.btb_update.bits.br_pc := (if (usingCompressed) mem_reg_pc + Mux(mem_reg_rvc, UInt(0), UInt(2)) else mem_reg_pc) | ||||||
|   io.imem.btb_update.bits.pc := ~(~io.imem.btb_update.bits.br_pc | (coreInstBytes*fetchWidth-1)) |   io.imem.btb_update.bits.pc := ~(~io.imem.btb_update.bits.br_pc | (coreInstBytes*fetchWidth-1)) | ||||||
| @@ -601,12 +603,6 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) | |||||||
|   io.imem.bht_update.bits.mispredict := mem_wrong_npc |   io.imem.bht_update.bits.mispredict := mem_wrong_npc | ||||||
|   io.imem.bht_update.bits.prediction := io.imem.btb_update.bits.prediction |   io.imem.bht_update.bits.prediction := io.imem.btb_update.bits.prediction | ||||||
|  |  | ||||||
|   io.imem.ras_update.valid := mem_reg_valid && !take_pc_wb |  | ||||||
|   io.imem.ras_update.bits.returnAddr := mem_int_wdata |  | ||||||
|   io.imem.ras_update.bits.isCall := io.imem.btb_update.bits.isJump && mem_waddr(0) |  | ||||||
|   io.imem.ras_update.bits.isReturn := io.imem.btb_update.bits.isReturn |  | ||||||
|   io.imem.ras_update.bits.prediction := io.imem.btb_update.bits.prediction |  | ||||||
|  |  | ||||||
|   io.fpu.valid := !ctrl_killd && id_ctrl.fp |   io.fpu.valid := !ctrl_killd && id_ctrl.fp | ||||||
|   io.fpu.killx := ctrl_killx |   io.fpu.killx := ctrl_killx | ||||||
|   io.fpu.killm := killm_common |   io.fpu.killm := killm_common | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user