diff --git a/uncore/src/main/scala/cache.scala b/uncore/src/main/scala/cache.scala index 31fbd0fa..0fed70b0 100644 --- a/uncore/src/main/scala/cache.scala +++ b/uncore/src/main/scala/cache.scala @@ -11,6 +11,7 @@ case object NSets extends Field[Int] case object NWays extends Field[Int] case object RowBits extends Field[Int] case object Replacer extends Field[() => ReplacementPolicy] +case object L2Replacer extends Field[() => SeqReplacementPolicy] case object AmoAluOperandBits extends Field[Int] case object NPrimaryMisses extends Field[Int] case object NSecondaryMisses extends Field[Int] @@ -58,6 +59,68 @@ class RandomReplacement(ways: Int) extends ReplacementPolicy { def hit = {} } +abstract class SeqReplacementPolicy { + def access(set: UInt): Unit + def update(valid: Bool, hit: Bool, set: UInt, way: UInt): Unit + def way: UInt +} + +class SeqRandom(n_ways: Int) extends SeqReplacementPolicy { + val logic = new RandomReplacement(n_ways) + def access(set: UInt) = { } + def update(valid: Bool, hit: Bool, set: UInt, way: UInt) = { + when (valid && !hit) { logic.miss } + } + def way = logic.way +} + +class PseudoLRU(n: Int) +{ + val state_reg = Reg(Bits(width = n)) + def access(way: UInt) { + state_reg := get_next_state(state_reg,way) + } + def get_next_state(state: Bits, way: UInt) = { + var next_state = state + var idx = UInt(1,1) + for (i <- log2Up(n)-1 to 0 by -1) { + val bit = way(i) + val mask = (UInt(1,n) << idx)(n-1,0) + next_state = next_state & ~mask | Mux(bit, UInt(0), mask) + //next_state.bitSet(idx, !bit) + idx = Cat(idx, bit) + } + next_state + } + def replace = get_replace_way(state_reg) + def get_replace_way(state: Bits) = { + var idx = UInt(1,1) + for (i <- 0 until log2Up(n)) + idx = Cat(idx, state(idx)) + idx(log2Up(n)-1,0) + } +} + +class SeqPLRU(n_sets: Int, n_ways: Int) extends SeqReplacementPolicy { + val state = SeqMem(Bits(width = n_ways-1), n_sets) + val logic = new PseudoLRU(n_ways) + val current_state = Wire(Bits()) + val plru_way = logic.get_replace_way(current_state) + val next_state = Wire(Bits()) + + def access(set: UInt) = { + current_state := Cat(state.read(set), Bits(0, width = 1)) + } + + def update(valid: Bool, hit: Bool, set: UInt, way: UInt) = { + val update_way = Mux(hit, way, plru_way) + next_state := logic.get_next_state(current_state, update_way) + when (valid) { state.write(set, next_state(n_ways-1,1)) } + } + + def way = plru_way +} + abstract class Metadata(implicit p: Parameters) extends CacheBundle()(p) { val tag = Bits(width = tagBits) val coh: CoherenceMetadata @@ -212,16 +275,22 @@ class L2MetadataArray(implicit p: Parameters) extends L2HellaCacheModule()(p) { val s1_clk_en = Reg(next = io.read.fire()) val s1_tag_eq_way = wayMap((w: Int) => meta.io.resp(w).tag === s1_tag) val s1_tag_match_way = wayMap((w: Int) => s1_tag_eq_way(w) && meta.io.resp(w).coh.outer.isValid()).toBits + val s1_idx = RegEnable(io.read.bits.idx, io.read.valid) // deal with stalls? val s2_tag_match_way = RegEnable(s1_tag_match_way, s1_clk_en) val s2_tag_match = s2_tag_match_way.orR val s2_hit_coh = Mux1H(s2_tag_match_way, wayMap((w: Int) => RegEnable(meta.io.resp(w).coh, s1_clk_en))) - val replacer = p(Replacer)() + val replacer = p(L2Replacer)() + val s1_hit_way = Wire(Bits()) + s1_hit_way := Bits(0) + (0 until nWays).foreach(i => when (s1_tag_match_way(i)) { s1_hit_way := Bits(i) }) + replacer.access(io.read.bits.idx) + replacer.update(s1_clk_en, s1_tag_match_way.orR, s1_idx, s1_hit_way) + val s1_replaced_way_en = UIntToOH(replacer.way) val s2_replaced_way_en = UIntToOH(RegEnable(replacer.way, s1_clk_en)) val s2_repl_meta = Mux1H(s2_replaced_way_en, wayMap((w: Int) => RegEnable(meta.io.resp(w), s1_clk_en && s1_replaced_way_en(w))).toSeq) - when(!s2_tag_match) { replacer.miss } io.resp.valid := Reg(next = s1_clk_en) io.resp.bits.id := RegEnable(s1_id, s1_clk_en)