1
0

diplomacy: support given bits in AddressDecoder

This commit is contained in:
Wesley W. Terpstra 2017-01-22 00:58:55 -08:00 committed by Henry Cook
parent c0b6d31377
commit d4b3a0f0be

View File

@ -19,7 +19,7 @@ object AddressDecoder
// Find the minimum subset of bits needed to disambiguate port addresses.
// ie: inspecting only the bits in the output, you can look at an address
// and decide to which port (outer Seq) the address belongs.
def apply(ports: Ports): BigInt = if (ports.size <= 1) 0 else {
def apply(ports: Ports, givenBits: BigInt = BigInt(0)): BigInt = if (ports.size <= 1) givenBits else {
// Every port must have at least one address!
ports.foreach { p => require (!p.isEmpty) }
// Verify the user did not give us an impossible problem
@ -30,9 +30,11 @@ object AddressDecoder
}
val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1)
val bits = (0 until maxBits).map(BigInt(1) << _).toSeq
val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits)
val output = selected.reduceLeft(_ | _)
val (bitsToTry, bitsToTake) = (0 until maxBits).map(BigInt(1) << _).partition(b => (givenBits & b) == 0)
val partitions = Seq(ports.map(_.sorted).sorted(portOrder))
val givenPartitions = bitsToTake.foldLeft(partitions) { (p, b) => partitionPartitions(p, b) }
val selected = recurse(givenPartitions, bitsToTry.toSeq)
val output = selected.reduceLeft(_ | _) | givenBits
// Modify the AddressSets to allow the new wider match functions
val widePorts = ports.map { _.map { _.widen(~output) } }
@ -103,6 +105,7 @@ object AddressDecoder
// requirement: ports have sorted addresses and are sorted lexicographically
val debug = false
def recurse(partitions: Partitions, bits: Seq[BigInt]): Seq[BigInt] = {
if (partitions.map(_.size <= 1).reduce(_ && _)) Seq() else {
if (debug) {
println("Partitioning:")
partitions.foreach { partition =>
@ -121,10 +124,6 @@ object AddressDecoder
}
val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable))
if (debug) println("=> Selected bit 0x%x".format(bestBit))
if (bestScore(0) <= 1) {
if (debug) println("---")
Seq(bestBit)
} else {
bestBit +: recurse(bestPartitions, bits.filter(_ != bestBit))
}
}