Finalize superscalar btb.
This commit is contained in:
		@@ -51,45 +51,34 @@ class BHTResp extends Bundle with BTBParameters {
 | 
				
			|||||||
//    - updated speculatively in fetch (if there's a BTB hit).
 | 
					//    - updated speculatively in fetch (if there's a BTB hit).
 | 
				
			||||||
//    - on a mispredict, the history register is reset (again, only if BTB hit).
 | 
					//    - on a mispredict, the history register is reset (again, only if BTB hit).
 | 
				
			||||||
// The counter table:
 | 
					// The counter table:
 | 
				
			||||||
//    - each PC has its own counter, updated when a branch resolves (and BTB hit).
 | 
					//    - each counter corresponds with the "fetch pc" (not the PC of the branch). 
 | 
				
			||||||
//    - the BTB provides the predicted branch PC, allowing us to properly index
 | 
					//    - updated when a branch resolves (and BTB was a hit for that branch).
 | 
				
			||||||
//      the counter table and provide the prediction for that specific branch.
 | 
					//      The updating branch must provide its "fetch pc" in addition to its own PC.
 | 
				
			||||||
//      Critical path concerns may require only providing a single counter for
 | 
					class BHT(nbht: Int) {
 | 
				
			||||||
//      the entire fetch packet, but that complicates how multiple branches
 | 
					 | 
				
			||||||
//      update that line.
 | 
					 | 
				
			||||||
class BHT(nbht: Int, fetchwidth: Int) {
 | 
					 | 
				
			||||||
  val nbhtbits = log2Up(nbht)
 | 
					  val nbhtbits = log2Up(nbht)
 | 
				
			||||||
  private val logfw = if (fetchwidth == 1) 0 else log2Up(fetchwidth)
 | 
					  def get(addr: UInt, bridx: UInt, update: Bool): BHTResp = {
 | 
				
			||||||
 | 
					 | 
				
			||||||
  def get(fetch_addr: UInt, bridx: UInt, update: Bool): BHTResp = {
 | 
					 | 
				
			||||||
    val res = new BHTResp
 | 
					    val res = new BHTResp
 | 
				
			||||||
    val aligned_addr = fetch_addr >> UInt(logfw + 2)
 | 
					    val index = addr(nbhtbits+1,2) ^ history
 | 
				
			||||||
    val index = aligned_addr ^ history
 | 
					    res.value := table(index)
 | 
				
			||||||
    val counters = table(index)
 | 
					 | 
				
			||||||
    res.value := (counters >> (bridx<<1)) & Bits(0x3)
 | 
					 | 
				
			||||||
    res.history := history
 | 
					    res.history := history
 | 
				
			||||||
    val taken = res.value(0)
 | 
					    val taken = res.value(0)
 | 
				
			||||||
    when (update) { history := Cat(taken, history(nbhtbits-1,1)) }
 | 
					    when (update) { history := Cat(taken, history(nbhtbits-1,1)) }
 | 
				
			||||||
    res
 | 
					    res
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  def update(addr: UInt, d: BHTResp, taken: Bool, mispredict: Bool): Unit = {
 | 
					  def update(addr: UInt, d: BHTResp, taken: Bool, mispredict: Bool): Unit = {
 | 
				
			||||||
    val aligned_addr = addr >> UInt(logfw + 2)
 | 
					    val index = addr(nbhtbits+1,2) ^ d.history
 | 
				
			||||||
    val index = aligned_addr ^ d.history
 | 
					    table(index) := Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken))
 | 
				
			||||||
    val new_cntr = Cat(taken, (d.value(1) & d.value(0)) | ((d.value(1) | d.value(0)) & taken))
 | 
					 | 
				
			||||||
    var bridx: UInt = null
 | 
					 | 
				
			||||||
    if (logfw == 0) bridx = UInt(0) else bridx = addr(logfw+1,2)
 | 
					 | 
				
			||||||
    val mask = Bits(0x3) << (bridx<<1)
 | 
					 | 
				
			||||||
    table.write(index, new_cntr, mask) 
 | 
					 | 
				
			||||||
    when (mispredict) { history := Cat(taken, d.history(nbhtbits-1,1)) }
 | 
					    when (mispredict) { history := Cat(taken, d.history(nbhtbits-1,1)) }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  private val table = Mem(UInt(width = 2*fetchwidth), nbht)
 | 
					  private val table = Mem(UInt(width = 2), nbht)
 | 
				
			||||||
  val history = Reg(UInt(width = nbhtbits))
 | 
					  val history = Reg(UInt(width = nbhtbits))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// BTB update occurs during branch resolution.
 | 
					// BTB update occurs during branch resolution.
 | 
				
			||||||
//  - "pc" is what future fetch PCs will tag match against.
 | 
					//  - "pc" is what future fetch PCs will tag match against.
 | 
				
			||||||
//  - "br_pc" is the PC of the branch instruction.
 | 
					//  - "br_pc" is the PC of the branch instruction.
 | 
				
			||||||
 | 
					//  - "bridx" is the low-order PC bits of the predicted branch.
 | 
				
			||||||
//  - "resp.mask" provides a mask of valid instructions (instructions are
 | 
					//  - "resp.mask" provides a mask of valid instructions (instructions are
 | 
				
			||||||
//      masked off by the predicted taken branch).
 | 
					//      masked off by the predicted taken branch).
 | 
				
			||||||
class BTBUpdate extends Bundle with BTBParameters {
 | 
					class BTBUpdate extends Bundle with BTBParameters {
 | 
				
			||||||
@@ -107,7 +96,8 @@ class BTBUpdate extends Bundle with BTBParameters {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class BTBResp extends Bundle with BTBParameters {
 | 
					class BTBResp extends Bundle with BTBParameters {
 | 
				
			||||||
  val taken = Bool()
 | 
					  val taken = Bool()
 | 
				
			||||||
  val mask = Bits(width = log2Up(params(FetchWidth)))
 | 
					  val mask = Bits(width = params(FetchWidth))
 | 
				
			||||||
 | 
					  val bridx = Bits(width = log2Up(params(FetchWidth)))
 | 
				
			||||||
  val target = UInt(width = vaddrBits)
 | 
					  val target = UInt(width = vaddrBits)
 | 
				
			||||||
  val entry = UInt(width = opaqueBits)
 | 
					  val entry = UInt(width = opaqueBits)
 | 
				
			||||||
  val bht = new BHTResp
 | 
					  val bht = new BHTResp
 | 
				
			||||||
@@ -232,13 +222,14 @@ class BTB extends Module with BTBParameters {
 | 
				
			|||||||
  io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts))
 | 
					  io.resp.bits.target := Cat(Mux1H(Mux1H(hits, tgtPagesOH), pages), Mux1H(hits, tgts))
 | 
				
			||||||
  io.resp.bits.entry := OHToUInt(hits)
 | 
					  io.resp.bits.entry := OHToUInt(hits)
 | 
				
			||||||
  io.resp.bits.mask := Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)) 
 | 
					  io.resp.bits.mask := Cat((UInt(1) << brIdx(io.resp.bits.entry))-1, UInt(1)) 
 | 
				
			||||||
 | 
					  io.resp.bits.bridx := brIdx(io.resp.bits.entry)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (nBHT > 0) {
 | 
					  if (nBHT > 0) {
 | 
				
			||||||
    val bht = new BHT(nBHT, params(FetchWidth))
 | 
					    val bht = new BHT(nBHT)
 | 
				
			||||||
    val res = bht.get(io.req.bits.addr, brIdx(io.resp.bits.entry), io.req.valid && hits.orR && !Mux1H(hits, isJump))
 | 
					    val res = bht.get(io.req.bits.addr, brIdx(io.resp.bits.entry), io.req.valid && hits.orR && !Mux1H(hits, isJump))
 | 
				
			||||||
    val update_btb_hit = io.update.bits.prediction.valid
 | 
					    val update_btb_hit = io.update.bits.prediction.valid
 | 
				
			||||||
    when (io.update.valid && update_btb_hit && !io.update.bits.isJump) {
 | 
					    when (io.update.valid && update_btb_hit && !io.update.bits.isJump) {
 | 
				
			||||||
      bht.update(io.update.bits.br_pc, io.update.bits.prediction.bits.bht,
 | 
					      bht.update(io.update.bits.pc, io.update.bits.prediction.bits.bht,
 | 
				
			||||||
                 io.update.bits.taken, io.update.bits.incorrectTarget)
 | 
					                 io.update.bits.taken, io.update.bits.incorrectTarget)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false }
 | 
					    when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false }
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user