From d4b3a0f0beaa1bea08b4b43b3b901cc7e4a4160c Mon Sep 17 00:00:00 2001 From: "Wesley W. Terpstra" Date: Sun, 22 Jan 2017 00:58:55 -0800 Subject: [PATCH] diplomacy: support given bits in AddressDecoder --- src/main/scala/diplomacy/AddressDecoder.scala | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/src/main/scala/diplomacy/AddressDecoder.scala b/src/main/scala/diplomacy/AddressDecoder.scala index 57902946..083eac11 100644 --- a/src/main/scala/diplomacy/AddressDecoder.scala +++ b/src/main/scala/diplomacy/AddressDecoder.scala @@ -19,7 +19,7 @@ object AddressDecoder // Find the minimum subset of bits needed to disambiguate port addresses. // ie: inspecting only the bits in the output, you can look at an address // and decide to which port (outer Seq) the address belongs. - def apply(ports: Ports): BigInt = if (ports.size <= 1) 0 else { + def apply(ports: Ports, givenBits: BigInt = BigInt(0)): BigInt = if (ports.size <= 1) givenBits else { // Every port must have at least one address! ports.foreach { p => require (!p.isEmpty) } // Verify the user did not give us an impossible problem @@ -30,9 +30,11 @@ object AddressDecoder } val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1) - val bits = (0 until maxBits).map(BigInt(1) << _).toSeq - val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits) - val output = selected.reduceLeft(_ | _) + val (bitsToTry, bitsToTake) = (0 until maxBits).map(BigInt(1) << _).partition(b => (givenBits & b) == 0) + val partitions = Seq(ports.map(_.sorted).sorted(portOrder)) + val givenPartitions = bitsToTake.foldLeft(partitions) { (p, b) => partitionPartitions(p, b) } + val selected = recurse(givenPartitions, bitsToTry.toSeq) + val output = selected.reduceLeft(_ | _) | givenBits // Modify the AddressSets to allow the new wider match functions val widePorts = ports.map { _.map { _.widen(~output) } } @@ -103,28 +105,25 @@ object AddressDecoder // requirement: ports have sorted addresses and are sorted lexicographically val debug = false def recurse(partitions: Partitions, bits: Seq[BigInt]): Seq[BigInt] = { - if (debug) { - println("Partitioning:") - partitions.foreach { partition => - println(" Partition:") - partition.foreach { port => - print(" ") - port.foreach { a => print(s" ${a}") } - println("") + if (partitions.map(_.size <= 1).reduce(_ && _)) Seq() else { + if (debug) { + println("Partitioning:") + partitions.foreach { partition => + println(" Partition:") + partition.foreach { port => + print(" ") + port.foreach { a => print(s" ${a}") } + println("") + } } } - } - val candidates = bits.map { bit => - val result = partitionPartitions(partitions, bit) - val score = bitScore(result) - (score, bit, result) - } - val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable)) - if (debug) println("=> Selected bit 0x%x".format(bestBit)) - if (bestScore(0) <= 1) { - if (debug) println("---") - Seq(bestBit) - } else { + val candidates = bits.map { bit => + val result = partitionPartitions(partitions, bit) + val score = bitScore(result) + (score, bit, result) + } + val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable)) + if (debug) println("=> Selected bit 0x%x".format(bestBit)) bestBit +: recurse(bestPartitions, bits.filter(_ != bestBit)) } }