diff --git a/src/main/scala/uncore/tilelink2/AddressDecoder.scala b/src/main/scala/uncore/tilelink2/AddressDecoder.scala new file mode 100644 index 00000000..e937d60d --- /dev/null +++ b/src/main/scala/uncore/tilelink2/AddressDecoder.scala @@ -0,0 +1,118 @@ +// See LICENSE for license details. + +package uncore.tilelink2 + +import Chisel._ +import scala.math.{max,min} + +object AddressDecoder +{ + type Port = Seq[AddressSet] + type Ports = Seq[Port] + type Partition = Ports + type Partitions = Seq[Partition] + + val addressOrder = Ordering.ordered[AddressSet] + val portOrder = Ordering.Iterable(addressOrder) + val partitionOrder = Ordering.Iterable(portOrder) + + // Find the minimum subset of bits needed to disambiguate port addresses. + // ie: inspecting only the bits in the output, you can look at an address + // and decide to which port (outer Seq) the address belongs. + def apply(ports: Ports): BigInt = if (ports.size <= 1) 0 else { + // Every port must have at least one address! + ports.foreach { p => require (!p.isEmpty) } + // Verify the user did not give us an impossible problem + ports.combinations(2).foreach { case Seq(x, y) => + x.foreach { a => y.foreach { b => + require (!a.overlaps(b)) // it must be possible to disambiguate addresses! + } } + } + val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1) + val bits = (0 until maxBits).map(BigInt(1) << _).toSeq + val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits) + selected.reduceLeft(_ | _) + } + + // A simpler version that works for a Seq[Int] + def apply(keys: Seq[Int]): Int = { + val ports = keys.map(b => Seq(AddressSet(b, 0))) + apply(ports).toInt + } + + // The algorithm has a set of partitions, discriminated by the selected bits. + // Each partion has a set of ports, listing all addresses that lead to that port. + // Seq[Seq[Seq[AddressSet]]] + // ^^^^^^^^^^^^^^^ set of addresses that are routed out this port + // ^^^ the list of ports + // ^^^ cases already distinguished by the selected bits thus far + // + // Solving this problem is NP-hard, so we use a simple greedy heuristic: + // pick the bit which minimizes the number of ports in each partition + // as a secondary goal, reduce the number of AddressSets within a partition + + val bigValue = 100000 + def bitScore(partitions: Partitions): Int = { + val maxPortsPerPartition = partitions.map(_.size).max + val maxSetsPerPartition = partitions.map(_.map(_.size).sum).max + maxPortsPerPartition * bigValue + maxSetsPerPartition + } + + def partitionPort(port: Port, bit: BigInt): (Port, Port) = { + val addr_a = AddressSet(0, ~bit) + val addr_b = AddressSet(bit, ~bit) + // The addresses were sorted, so the filtered addresses are still sorted + val subset_a = port.filter(_.overlaps(addr_a)) + val subset_b = port.filter(_.overlaps(addr_b)) + (subset_a, subset_b) + } + + def partitionPorts(ports: Ports, bit: BigInt): (Ports, Ports) = { + val partitioned_ports = ports.map(p => partitionPort(p, bit)) + // because partitionPort dropped AddresSets, the ports might no longer be sorted + val case_a_ports = partitioned_ports.map(_._1).filter(!_.isEmpty).sorted(portOrder) + val case_b_ports = partitioned_ports.map(_._2).filter(!_.isEmpty).sorted(portOrder) + (case_a_ports, case_b_ports) + } + + def partitionPartitions(partitions: Partitions, bit: BigInt): Partitions = { + val partitioned_partitions = partitions.map(p => partitionPorts(p, bit)) + val case_a_partitions = partitioned_partitions.map(_._1) + val case_b_partitions = partitioned_partitions.map(_._2) + val new_partitions = (case_a_partitions ++ case_b_partitions).sorted(partitionOrder) + // Prevent combinational memory explosion; if two partitions are equal, keep only one + // Note: AddressSets in a port are sorted, and ports in a partition are sorted. + // This makes it easy to structurally compare two partitions for equality + val keep = (new_partitions.init zip new_partitions.tail) filter { case (a,b) => partitionOrder.compare(a,b) != 0 } map { _._2 } + new_partitions.head +: keep + } + + // requirement: ports have sorted addresses and are sorted lexicographically + val debug = false + def recurse(partitions: Partitions, bits: Seq[BigInt]): Seq[BigInt] = { + if (debug) { + println("Partitioning:") + partitions.foreach { partition => + println(" Partition:") + partition.foreach { port => + print(" ") + port.foreach { a => print(s" ${a}") } + println("") + } + } + } + val candidates = bits.map { bit => + val result = partitionPartitions(partitions, bit) + val score = bitScore(result) + (score, bit, result) + } + val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Int, BigInt, Partitions), Int](_._1)) + if (debug) println("=> Selected bit 0x%x".format(bestBit)) + if (bestScore < 2*bigValue) { + if (debug) println("---") + Seq(bestBit) + } else { + bestBit +: recurse(bestPartitions, bits.filter(_ != bestBit)) + } + } +} diff --git a/src/main/scala/uncore/tilelink2/Parameters.scala b/src/main/scala/uncore/tilelink2/Parameters.scala index 3c93e96b..6423db3c 100644 --- a/src/main/scala/uncore/tilelink2/Parameters.scala +++ b/src/main/scala/uncore/tilelink2/Parameters.scala @@ -78,7 +78,7 @@ object TransferSizes { // Base is the base address, and mask are the bits consumed by the manager // e.g: base=0x200, mask=0xff describes a device managing 0x200-0x2ff // e.g: base=0x1000, mask=0xf0f decribes a device managing 0x1000-0x100f, 0x1100-0x110f, ... -case class AddressSet(base: BigInt, mask: BigInt) +case class AddressSet(base: BigInt, mask: BigInt) extends Ordered[AddressSet] { // Forbid misaligned base address (and empty sets) require ((base & mask) == 0) @@ -97,6 +97,16 @@ case class AddressSet(base: BigInt, mask: BigInt) // A strided slave serves discontiguous ranges def strided = alignment1 != mask + + // AddressSets have one natural Ordering (the containment order) + def compare(x: AddressSet) = { + val primary = (this.base - x.base).signum // smallest address first + val secondary = (x.mask - this.mask).signum // largest mask first + if (primary != 0) primary else secondary + } + + // We always want to see things in hex + override def toString() = "AddressSet(0x%x, 0x%x)".format(base, mask) } case class TLManagerParameters( @@ -185,6 +195,7 @@ case class TLManagerPortParameters(managers: Seq[TLManagerParameters], beatBytes def findFifoId(address: UInt) = Mux1H(find(address), managers.map(m => UInt(m.fifoId.map(_+1).getOrElse(0)))) def hasFifoId(address: UInt) = Mux1H(find(address), managers.map(m => Bool(m.fifoId.isDefined))) + lazy val addressMask = AddressDecoder(managers.map(_.address)) // !!! need a cheaper version of find, where we assume a valid address match exists // Does this Port manage this ID/address?