diff --git a/src/main/scala/uncore/tilelink2/AddressDecoder.scala b/src/main/scala/uncore/tilelink2/AddressDecoder.scala index e937d60d..adddc8a5 100644 --- a/src/main/scala/uncore/tilelink2/AddressDecoder.scala +++ b/src/main/scala/uncore/tilelink2/AddressDecoder.scala @@ -25,13 +25,14 @@ object AddressDecoder // Verify the user did not give us an impossible problem ports.combinations(2).foreach { case Seq(x, y) => x.foreach { a => y.foreach { b => - require (!a.overlaps(b)) // it must be possible to disambiguate addresses! + require (!a.overlaps(b)) // it must be possible to disambiguate ports! } } } val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1) val bits = (0 until maxBits).map(BigInt(1) << _).toSeq val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits) selected.reduceLeft(_ | _) + // port validation via mask expansion } // A simpler version that works for a Seq[Int] @@ -51,11 +52,12 @@ object AddressDecoder // pick the bit which minimizes the number of ports in each partition // as a secondary goal, reduce the number of AddressSets within a partition - val bigValue = 100000 - def bitScore(partitions: Partitions): Int = { + def bitScore(partitions: Partitions): Seq[Int] = { val maxPortsPerPartition = partitions.map(_.size).max + val sumPortsPerPartition = partitions.map(_.size).sum val maxSetsPerPartition = partitions.map(_.map(_.size).sum).max - maxPortsPerPartition * bigValue + maxSetsPerPartition + val sumSetsPerPartition = partitions.map(_.map(_.size).sum).sum + Seq(maxPortsPerPartition, sumPortsPerPartition, maxSetsPerPartition, sumSetsPerPartition) } def partitionPort(port: Port, bit: BigInt): (Port, Port) = { @@ -77,8 +79,8 @@ object AddressDecoder def partitionPartitions(partitions: Partitions, bit: BigInt): Partitions = { val partitioned_partitions = partitions.map(p => partitionPorts(p, bit)) - val case_a_partitions = partitioned_partitions.map(_._1) - val case_b_partitions = partitioned_partitions.map(_._2) + val case_a_partitions = partitioned_partitions.map(_._1).filter(!_.isEmpty) + val case_b_partitions = partitioned_partitions.map(_._2).filter(!_.isEmpty) val new_partitions = (case_a_partitions ++ case_b_partitions).sorted(partitionOrder) // Prevent combinational memory explosion; if two partitions are equal, keep only one // Note: AddressSets in a port are sorted, and ports in a partition are sorted. @@ -106,9 +108,9 @@ object AddressDecoder val score = bitScore(result) (score, bit, result) } - val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Int, BigInt, Partitions), Int](_._1)) + val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable)) if (debug) println("=> Selected bit 0x%x".format(bestBit)) - if (bestScore < 2*bigValue) { + if (bestScore(0) <= 1) { if (debug) println("---") Seq(bestBit) } else {