diff --git a/rocket/src/main/scala/btb.scala b/rocket/src/main/scala/btb.scala index 94f7a6ac..b323ef98 100644 --- a/rocket/src/main/scala/btb.scala +++ b/rocket/src/main/scala/btb.scala @@ -53,10 +53,10 @@ class BHTResp extends Bundle with BTBParameters { // The counter table: // - each counter corresponds with the "fetch pc" (not the PC of the branch). // - updated when a branch resolves (and BTB was a hit for that branch). -// The updating branch must provide its "fetch pc" in addition to its own PC. +// The updating branch must provide its "fetch pc". class BHT(nbht: Int) { val nbhtbits = log2Up(nbht) - def get(addr: UInt, bridx: UInt, update: Bool): BHTResp = { + def get(addr: UInt, update: Bool): BHTResp = { val res = new BHTResp val index = addr(nbhtbits+1,2) ^ history res.value := table(index) @@ -78,7 +78,8 @@ class BHT(nbht: Int) { // BTB update occurs during branch resolution. // - "pc" is what future fetch PCs will tag match against. // - "br_pc" is the PC of the branch instruction. -// - "bridx" is the low-order PC bits of the predicted branch. +// - "bridx" is the low-order PC bits of the predicted branch (after +// shifting off the lowest log(inst_bytes) bits off). // - "resp.mask" provides a mask of valid instructions (instructions are // masked off by the predicted taken branch). class BTBUpdate extends Bundle with BTBParameters { @@ -91,7 +92,7 @@ class BTBUpdate extends Bundle with BTBParameters { val isCall = Bool() val isReturn = Bool() val br_pc = UInt(width = vaddrBits) - val incorrectTarget = Bool() + val mispredict = Bool() } class BTBResp extends Bundle with BTBParameters { @@ -157,8 +158,8 @@ class BTB extends Module with BTBParameters { } val updateHit = r_update.bits.prediction.valid - val updateValid = r_update.bits.incorrectTarget || updateHit && Bool(nBHT > 0) - val updateTarget = updateValid && r_update.bits.incorrectTarget + val updateValid = r_update.bits.mispredict || updateHit && Bool(nBHT > 0) + val updateTarget = updateValid && r_update.bits.mispredict && r_update.bits.taken val useUpdatePageHit = updatePageHit.orR val doIdxPageRepl = updateTarget && !useUpdatePageHit @@ -194,7 +195,11 @@ class BTB extends Module with BTBParameters { tgtPages(waddr) := tgtPageUpdate useRAS(waddr) := r_update.bits.isReturn isJump(waddr) := r_update.bits.isJump - brIdx(waddr) := r_update.bits.br_pc + if (params(FetchWidth) == 1) { + brIdx(waddr) := UInt(0) + } else { + brIdx(waddr) := r_update.bits.br_pc >> log2Up(params(CoreInstBits)/8) + } } require(nPages % 2 == 0) @@ -226,11 +231,11 @@ class BTB extends Module with BTBParameters { if (nBHT > 0) { val bht = new BHT(nBHT) - val res = bht.get(io.req.bits.addr, brIdx(io.resp.bits.entry), io.req.valid && hits.orR && !Mux1H(hits, isJump)) + val res = bht.get(io.req.bits.addr, io.req.valid && hits.orR && !Mux1H(hits, isJump)) val update_btb_hit = io.update.bits.prediction.valid when (io.update.valid && update_btb_hit && !io.update.bits.isJump) { bht.update(io.update.bits.pc, io.update.bits.prediction.bits.bht, - io.update.bits.taken, io.update.bits.incorrectTarget) + io.update.bits.taken, io.update.bits.mispredict) } when (!res.value(0) && !Mux1H(hits, isJump)) { io.resp.bits.taken := false } io.resp.bits.bht := res diff --git a/rocket/src/main/scala/csr.scala b/rocket/src/main/scala/csr.scala index b1e76254..590752e5 100644 --- a/rocket/src/main/scala/csr.scala +++ b/rocket/src/main/scala/csr.scala @@ -113,10 +113,7 @@ class CSRFile extends Module val map = for ((v, i) <- CSRs.all.zipWithIndex) yield v -> UInt(BigInt(1) << i) val out = ROM(map)(addr) - val a = Array.fill(CSRs.all.max+1)(null.asInstanceOf[Bool]) - for (i <- 0 until CSRs.all.size) - a(CSRs.all(i)) = out(i) - a + Map((CSRs.all zip out.toBools):_*) } val wen = cpu_req_valid || host_pcr_req_fire && host_pcr_bits.rw diff --git a/rocket/src/main/scala/ctrl.scala b/rocket/src/main/scala/ctrl.scala index d72eba45..6ea50107 100644 --- a/rocket/src/main/scala/ctrl.scala +++ b/rocket/src/main/scala/ctrl.scala @@ -652,11 +652,11 @@ class Control extends Module Mux(replay_wb, PC_WB, // replay PC_MEM))) - io.imem.btb_update.valid := mem_reg_branch || mem_reg_jal || mem_reg_jalr + io.imem.btb_update.valid := (mem_reg_branch || io.imem.btb_update.bits.isJump) && !take_pc_wb io.imem.btb_update.bits.prediction.valid := mem_reg_btb_hit io.imem.btb_update.bits.prediction.bits := mem_reg_btb_resp - io.imem.btb_update.bits.taken := mem_reg_jal || mem_reg_branch && io.dpath.mem_br_taken - io.imem.btb_update.bits.incorrectTarget := take_pc_mem + io.imem.btb_update.bits.taken := mem_reg_branch && io.dpath.mem_br_taken || io.imem.btb_update.bits.isJump + io.imem.btb_update.bits.mispredict := take_pc_mem io.imem.btb_update.bits.isJump := mem_reg_jal || mem_reg_jalr io.imem.btb_update.bits.isCall := mem_reg_wen && io.dpath.mem_waddr(0) io.imem.btb_update.bits.isReturn := mem_reg_jalr && io.dpath.mem_rs1_ra diff --git a/rocket/src/main/scala/icache.scala b/rocket/src/main/scala/icache.scala index e5c760c8..0c7fbaba 100644 --- a/rocket/src/main/scala/icache.scala +++ b/rocket/src/main/scala/icache.scala @@ -114,7 +114,7 @@ class Frontend extends FrontendModule } val all_ones = UInt((1 << coreFetchWidth)-1) - val msk_pc = all_ones << s2_pc(log2Up(coreFetchWidth)-1+2,2) + val msk_pc = if (coreFetchWidth == 1) all_ones else all_ones << s2_pc(log2Up(coreFetchWidth) -1+2,2) io.cpu.resp.bits.mask := msk_pc & btb.io.resp.bits.mask io.cpu.resp.bits.xcpt_ma := s2_pc(log2Up(coreInstBytes)-1,0) != UInt(0)