diff --git a/junctions/src/main/scala/addrmap.scala b/junctions/src/main/scala/addrmap.scala index 8eb6ddd7..a223f63a 100644 --- a/junctions/src/main/scala/addrmap.scala +++ b/junctions/src/main/scala/addrmap.scala @@ -16,7 +16,6 @@ case object PPNBits extends Field[Int] case object VPNBits extends Field[Int] case object GlobalAddrMap extends Field[AddrMap] -case object GlobalAddrHashMap extends Field[AddrHashMap] trait HasAddrMapParameters { implicit val p: Parameters @@ -30,23 +29,27 @@ trait HasAddrMapParameters { val pgLevelBits = p(PgLevelBits) val asIdBits = p(ASIdBits) - val addrMap = p(GlobalAddrHashMap) + val addrMap = p(GlobalAddrMap) } case class MemAttr(prot: Int, cacheable: Boolean = false) -abstract class MemRegion { - def align: BigInt +sealed abstract class MemRegion { + def start: BigInt def size: BigInt def numSlaves: Int + def attr: MemAttr + + def containsAddress(x: UInt) = UInt(start) <= x && x < UInt(start + size) } -case class MemSize(size: BigInt, align: BigInt, attr: MemAttr) extends MemRegion { +case class MemSize(size: BigInt, attr: MemAttr) extends MemRegion { + def start = 0 def numSlaves = 1 } -case class MemSubmap(size: BigInt, entries: AddrMap) extends MemRegion { - val numSlaves = entries.countSlaves - val align = entries.computeAlign + +case class MemRange(start: BigInt, size: BigInt, attr: MemAttr) extends MemRegion { + def numSlaves = 1 } object AddrMapProt { @@ -67,90 +70,80 @@ class AddrMapProt extends Bundle { case class AddrMapEntry(name: String, region: MemRegion) -case class AddrHashMapEntry(port: Int, start: BigInt, region: MemRegion) - -class AddrMap(entries: Seq[AddrMapEntry]) extends scala.collection.IndexedSeq[AddrMapEntry] { - private val hash = HashMap(entries.map(e => (e.name, e.region)):_*) - - def apply(index: Int): AddrMapEntry = entries(index) - - def length: Int = entries.size - - def countSlaves: Int = entries.map(_.region.numSlaves).foldLeft(0)(_ + _) - - def computeSize: BigInt = new AddrHashMap(this).size - - def computeAlign: BigInt = entries.map(_.region.align).foldLeft(BigInt(1))(_ max _) - - override def tail: AddrMap = new AddrMap(entries.tail) -} - object AddrMap { def apply(elems: AddrMapEntry*): AddrMap = new AddrMap(elems) } -class AddrHashMap(addrmap: AddrMap, start: BigInt = BigInt(0)) { - private val mapping = HashMap[String, AddrHashMapEntry]() - private val subMaps = HashMap[String, AddrHashMapEntry]() +class AddrMap(entriesIn: Seq[AddrMapEntry], val start: BigInt = BigInt(0)) extends MemRegion { + def isEmpty = entries.isEmpty + def length = entries.size + def numSlaves = entries.map(_.region.numSlaves).foldLeft(0)(_ + _) + def attr = ??? - private def genPairs(am: AddrMap, start: BigInt, startIdx: Int, prefix: String): (BigInt, Int) = { - var ind = startIdx + private val slavePorts = HashMap[String, Int]() + private val mapping = HashMap[String, MemRegion]() + + val (size: BigInt, entries: Seq[AddrMapEntry]) = { + var ind = 0 var base = start - am.foreach { ame => - val name = prefix + ame.name - base = (base + ame.region.align - 1) / ame.region.align * ame.region.align - ame.region match { - case r: MemSize => - mapping += name -> AddrHashMapEntry(ind, base, r) - base += r.size - ind += 1 - case r: MemSubmap => - subMaps += name -> AddrHashMapEntry(-1, base, r) - ind = genPairs(r.entries, base, ind, name + ":")._2 - base += r.size - }} - (base, ind) - } + var rebasedEntries = collection.mutable.ArrayBuffer[AddrMapEntry]() + for (AddrMapEntry(name, r) <- entriesIn) { + if (r.start != 0) { + val align = BigInt(1) << log2Ceil(r.size) + require(r.start >= base, s"region $name base address 0x${r.start.toString(16)} overlaps previous base 0x${base.toString(16)}") + require(r.start % align == 0, s"region $name base address 0x${r.start.toString(16)} not aligned to 0x${align.toString(16)}") + base = r.start + } else { + base = (base + r.size - 1) / r.size * r.size + } - val size = genPairs(addrmap, start, 0, "")._1 + r match { + case r: AddrMap => + val subMap = new AddrMap(r.entries, base) + rebasedEntries += AddrMapEntry(name, subMap) + mapping += name -> subMap + mapping ++= subMap.mapping.map { case (k, v) => s"$name:$k" -> v } + slavePorts ++= subMap.slavePorts.map { case (k, v) => s"$name:$k" -> (ind + v) } + case _ => + val e = MemRange(base, r.size, r.attr) + rebasedEntries += AddrMapEntry(name, e) + mapping += name -> e + slavePorts += name -> ind + } - val sortedEntries: Seq[(String, BigInt, MemSize)] = { - val arr = new Array[(String, BigInt, MemSize)](mapping.size) - mapping.foreach { case (name, AddrHashMapEntry(port, base, region)) => - arr(port) = (name, base, region.asInstanceOf[MemSize]) + ind += r.numSlaves + base += r.size } - arr.toSeq + (base - start, rebasedEntries) } - def nEntries: Int = mapping.size - def apply(name: String): AddrHashMapEntry = mapping.getOrElse(name, subMaps(name)) - def subMap(name: String): (BigInt, AddrMap) = { - val m = subMaps(name) - (m.start, m.region.asInstanceOf[MemSubmap].entries) + val flatten: Seq[(String, MemRange)] = { + val arr = new Array[(String, MemRange)](slavePorts.size) + for ((name, port) <- slavePorts) + arr(port) = (name, mapping(name).asInstanceOf[MemRange]) + arr } - def isInRegion(name: String, addr: UInt): Bool = { - val start = mapping(name).start - val size = mapping(name).region.size - UInt(start) <= addr && addr < UInt(start + size) - } + def apply(name: String): MemRegion = mapping(name) + def port(name: String): Int = slavePorts(name) + def subMap(name: String): AddrMap = mapping(name).asInstanceOf[AddrMap] + def isInRegion(name: String, addr: UInt): Bool = mapping(name).containsAddress(addr) def isCacheable(addr: UInt): Bool = { - sortedEntries.filter(_._3.attr.cacheable).map { case (_, base, region) => - UInt(base) <= addr && addr < UInt(base + region.size) + flatten.filter(_._2.attr.cacheable).map { case (_, region) => + region.containsAddress(addr) }.foldLeft(Bool(false))(_ || _) } def isValid(addr: UInt): Bool = { - sortedEntries.map { case (_, base, region) => - addr >= UInt(base) && addr < UInt(base + region.size) + flatten.map { case (_, region) => + region.containsAddress(addr) }.foldLeft(Bool(false))(_ || _) } def getProt(addr: UInt): AddrMapProt = { - val protForRegion = sortedEntries.map { case (_, base, region) => - val inRegion = addr >= UInt(base) && addr < UInt(base + region.size) - Mux(inRegion, UInt(region.attr.prot, AddrMapProt.SZ), UInt(0)) + val protForRegion = flatten.map { case (_, region) => + Mux(region.containsAddress(addr), UInt(region.attr.prot, AddrMapProt.SZ), UInt(0)) } new AddrMapProt().fromBits(protForRegion.reduce(_|_)) } diff --git a/junctions/src/main/scala/nasti.scala b/junctions/src/main/scala/nasti.scala index 2453e0e3..f82073bb 100644 --- a/junctions/src/main/scala/nasti.scala +++ b/junctions/src/main/scala/nasti.scala @@ -506,37 +506,29 @@ abstract class NastiInterconnect(implicit p: Parameters) extends NastiModule()(p lazy val io = new NastiInterconnectIO(nMasters, nSlaves) } -class NastiRecursiveInterconnect( - val nMasters: Int, val nSlaves: Int, - addrmap: AddrMap, base: BigInt) +class NastiRecursiveInterconnect(val nMasters: Int, addrMap: AddrMap) (implicit p: Parameters) extends NastiInterconnect()(p) { - val levelSize = addrmap.size + def port(name: String) = io.slaves(addrMap.port(name)) + val nSlaves = addrMap.numSlaves + val routeSel = (addr: UInt) => + Cat(addrMap.entries.map(e => addrMap(e.name).containsAddress(addr)).reverse) - val addrHashMap = new AddrHashMap(addrmap, base) - val routeSel = (addr: UInt) => { - Cat(addrmap.map { case entry => - val hashEntry = addrHashMap(entry.name) - addr >= UInt(hashEntry.start) && addr < UInt(hashEntry.start + hashEntry.region.size) - }.reverse) - } - - val xbar = Module(new NastiCrossbar(nMasters, levelSize, routeSel)) + val xbar = Module(new NastiCrossbar(nMasters, addrMap.length, routeSel)) xbar.io.masters <> io.masters - io.slaves <> addrmap.zip(xbar.io.slaves).flatMap { + io.slaves <> addrMap.entries.zip(xbar.io.slaves).flatMap { case (entry, xbarSlave) => { entry.region match { - case _: MemSize => - Some(xbarSlave) - case MemSubmap(_, submap) if submap.isEmpty => + case submap: AddrMap if submap.entries.isEmpty => val err_slave = Module(new NastiErrorSlave) err_slave.io <> xbarSlave None - case MemSubmap(_, submap) => - val subSlaves = submap.countSlaves - val ic = Module(new NastiRecursiveInterconnect(1, subSlaves, submap, addrHashMap(entry.name).start)) + case submap: AddrMap => + val ic = Module(new NastiRecursiveInterconnect(1, submap)) ic.io.masters.head <> xbarSlave ic.io.slaves + case r: MemRange => + Some(xbarSlave) } } }