Merge pull request #307 from ucb-bar/address-shrink
RR: undefined regs return zeros
This commit is contained in:
commit
503ce14c98
@ -25,13 +25,25 @@ object AddressDecoder
|
|||||||
// Verify the user did not give us an impossible problem
|
// Verify the user did not give us an impossible problem
|
||||||
ports.combinations(2).foreach { case Seq(x, y) =>
|
ports.combinations(2).foreach { case Seq(x, y) =>
|
||||||
x.foreach { a => y.foreach { b =>
|
x.foreach { a => y.foreach { b =>
|
||||||
require (!a.overlaps(b)) // it must be possible to disambiguate addresses!
|
require (!a.overlaps(b)) // it must be possible to disambiguate ports!
|
||||||
} }
|
} }
|
||||||
}
|
}
|
||||||
|
|
||||||
val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1)
|
val maxBits = log2Ceil(ports.map(_.map(_.max).max).max + 1)
|
||||||
val bits = (0 until maxBits).map(BigInt(1) << _).toSeq
|
val bits = (0 until maxBits).map(BigInt(1) << _).toSeq
|
||||||
val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits)
|
val selected = recurse(Seq(ports.map(_.sorted).sorted(portOrder)), bits)
|
||||||
selected.reduceLeft(_ | _)
|
val output = selected.reduceLeft(_ | _)
|
||||||
|
|
||||||
|
// Modify the AddressSets to allow the new wider match functions
|
||||||
|
val widePorts = ports.map { _.map { _.widen(~output) } }
|
||||||
|
// Verify that it remains possible to disambiguate all ports
|
||||||
|
widePorts.combinations(2).foreach { case Seq(x, y) =>
|
||||||
|
x.foreach { a => y.foreach { b =>
|
||||||
|
require (!a.overlaps(b))
|
||||||
|
} }
|
||||||
|
}
|
||||||
|
|
||||||
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
// A simpler version that works for a Seq[Int]
|
// A simpler version that works for a Seq[Int]
|
||||||
@ -51,11 +63,12 @@ object AddressDecoder
|
|||||||
// pick the bit which minimizes the number of ports in each partition
|
// pick the bit which minimizes the number of ports in each partition
|
||||||
// as a secondary goal, reduce the number of AddressSets within a partition
|
// as a secondary goal, reduce the number of AddressSets within a partition
|
||||||
|
|
||||||
val bigValue = 100000
|
def bitScore(partitions: Partitions): Seq[Int] = {
|
||||||
def bitScore(partitions: Partitions): Int = {
|
|
||||||
val maxPortsPerPartition = partitions.map(_.size).max
|
val maxPortsPerPartition = partitions.map(_.size).max
|
||||||
|
val sumPortsPerPartition = partitions.map(_.size).sum
|
||||||
val maxSetsPerPartition = partitions.map(_.map(_.size).sum).max
|
val maxSetsPerPartition = partitions.map(_.map(_.size).sum).max
|
||||||
maxPortsPerPartition * bigValue + maxSetsPerPartition
|
val sumSetsPerPartition = partitions.map(_.map(_.size).sum).sum
|
||||||
|
Seq(maxPortsPerPartition, sumPortsPerPartition, maxSetsPerPartition, sumSetsPerPartition)
|
||||||
}
|
}
|
||||||
|
|
||||||
def partitionPort(port: Port, bit: BigInt): (Port, Port) = {
|
def partitionPort(port: Port, bit: BigInt): (Port, Port) = {
|
||||||
@ -77,8 +90,8 @@ object AddressDecoder
|
|||||||
|
|
||||||
def partitionPartitions(partitions: Partitions, bit: BigInt): Partitions = {
|
def partitionPartitions(partitions: Partitions, bit: BigInt): Partitions = {
|
||||||
val partitioned_partitions = partitions.map(p => partitionPorts(p, bit))
|
val partitioned_partitions = partitions.map(p => partitionPorts(p, bit))
|
||||||
val case_a_partitions = partitioned_partitions.map(_._1)
|
val case_a_partitions = partitioned_partitions.map(_._1).filter(!_.isEmpty)
|
||||||
val case_b_partitions = partitioned_partitions.map(_._2)
|
val case_b_partitions = partitioned_partitions.map(_._2).filter(!_.isEmpty)
|
||||||
val new_partitions = (case_a_partitions ++ case_b_partitions).sorted(partitionOrder)
|
val new_partitions = (case_a_partitions ++ case_b_partitions).sorted(partitionOrder)
|
||||||
// Prevent combinational memory explosion; if two partitions are equal, keep only one
|
// Prevent combinational memory explosion; if two partitions are equal, keep only one
|
||||||
// Note: AddressSets in a port are sorted, and ports in a partition are sorted.
|
// Note: AddressSets in a port are sorted, and ports in a partition are sorted.
|
||||||
@ -106,9 +119,9 @@ object AddressDecoder
|
|||||||
val score = bitScore(result)
|
val score = bitScore(result)
|
||||||
(score, bit, result)
|
(score, bit, result)
|
||||||
}
|
}
|
||||||
val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Int, BigInt, Partitions), Int](_._1))
|
val (bestScore, bestBit, bestPartitions) = candidates.min(Ordering.by[(Seq[Int], BigInt, Partitions), Iterable[Int]](_._1.toIterable))
|
||||||
if (debug) println("=> Selected bit 0x%x".format(bestBit))
|
if (debug) println("=> Selected bit 0x%x".format(bestBit))
|
||||||
if (bestScore < 2*bigValue) {
|
if (bestScore(0) <= 1) {
|
||||||
if (debug) println("---")
|
if (debug) println("---")
|
||||||
Seq(bestBit)
|
Seq(bestBit)
|
||||||
} else {
|
} else {
|
||||||
|
@ -98,6 +98,9 @@ case class AddressSet(base: BigInt, mask: BigInt) extends Ordered[AddressSet]
|
|||||||
// A strided slave serves discontiguous ranges
|
// A strided slave serves discontiguous ranges
|
||||||
def strided = alignment1 != mask
|
def strided = alignment1 != mask
|
||||||
|
|
||||||
|
// Widen the match function to ignore all bits in imask
|
||||||
|
def widen(imask: BigInt) = AddressSet(base & ~imask, mask | imask)
|
||||||
|
|
||||||
// AddressSets have one natural Ordering (the containment order)
|
// AddressSets have one natural Ordering (the containment order)
|
||||||
def compare(x: AddressSet) = {
|
def compare(x: AddressSet) = {
|
||||||
val primary = (this.base - x.base).signum // smallest address first
|
val primary = (this.base - x.base).signum // smallest address first
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
package uncore.tilelink2
|
package uncore.tilelink2
|
||||||
|
|
||||||
import Chisel._
|
import chisel3._
|
||||||
import chisel3.util.{Irrevocable, IrrevocableIO}
|
import chisel3.util._
|
||||||
|
|
||||||
// A bus agnostic register interface to a register-based device
|
// A bus agnostic register interface to a register-based device
|
||||||
|
|
||||||
@ -36,6 +36,23 @@ object RegMapper
|
|||||||
regmap.combinations(2).foreach { case Seq((reg1, _), (reg2, _)) =>
|
regmap.combinations(2).foreach { case Seq((reg1, _), (reg2, _)) =>
|
||||||
require (reg1 != reg2)
|
require (reg1 != reg2)
|
||||||
}
|
}
|
||||||
|
// Don't be an asshole...
|
||||||
|
regmap.foreach { reg => require (reg._1 >= 0) }
|
||||||
|
// Make sure registers fit
|
||||||
|
val inParams = in.bits.params
|
||||||
|
val inBits = inParams.indexBits
|
||||||
|
assert (regmap.map(_._1).max < (1 << inBits))
|
||||||
|
|
||||||
|
val out = Wire(Irrevocable(new RegMapperOutput(inParams)))
|
||||||
|
val front = Wire(Irrevocable(new RegMapperInput(inParams)))
|
||||||
|
front.bits := in.bits
|
||||||
|
|
||||||
|
// Must this device pipeline the control channel?
|
||||||
|
val pipelined = regmap.map(_._2.map(_.pipelined)).flatten.reduce(_ || _)
|
||||||
|
val depth = concurrency.getOrElse(if (pipelined) 1 else 0)
|
||||||
|
require (depth >= 0)
|
||||||
|
require (!pipelined || depth > 0)
|
||||||
|
val back = if (depth > 0) Queue(front, depth, pipe = depth == 1) else front
|
||||||
|
|
||||||
// Convert to and from Bits
|
// Convert to and from Bits
|
||||||
def toBits(x: Int, tail: List[Boolean] = List.empty): List[Boolean] =
|
def toBits(x: Int, tail: List[Boolean] = List.empty): List[Boolean] =
|
||||||
@ -44,36 +61,33 @@ object RegMapper
|
|||||||
|
|
||||||
// Find the minimal mask that can decide the register map
|
// Find the minimal mask that can decide the register map
|
||||||
val mask = AddressDecoder(regmap.map(_._1))
|
val mask = AddressDecoder(regmap.map(_._1))
|
||||||
|
val maskMatch = ~UInt(mask, width = inBits)
|
||||||
val maskFilter = toBits(mask)
|
val maskFilter = toBits(mask)
|
||||||
val maskBits = maskFilter.filter(x => x).size
|
val maskBits = maskFilter.filter(x => x).size
|
||||||
|
|
||||||
// Calculate size and indexes into the register map
|
// Calculate size and indexes into the register map
|
||||||
val endIndex = 1 << log2Ceil(regmap.map(_._1).max+1)
|
|
||||||
val params = RegMapperParams(log2Up(endIndex), bytes, in.bits.params.extraBits)
|
|
||||||
val regSize = 1 << maskBits
|
val regSize = 1 << maskBits
|
||||||
def regIndexI(x: Int) = ofBits((maskFilter zip toBits(x)).filter(_._1).map(_._2))
|
def regIndexI(x: Int) = ofBits((maskFilter zip toBits(x)).filter(_._1).map(_._2))
|
||||||
def regIndexU(x: UInt) = if (maskBits == 0) UInt(0) else
|
def regIndexU(x: UInt) = if (maskBits == 0) UInt(0) else
|
||||||
Cat((maskFilter zip x.toBools).filter(_._1).map(_._2).reverse)
|
Cat((maskFilter zip x.toBools).filter(_._1).map(_._2).reverse)
|
||||||
|
|
||||||
|
// Protection flag for undefined registers
|
||||||
|
val iRightReg = Array.fill(regSize) { Bool(true) }
|
||||||
|
val oRightReg = Array.fill(regSize) { Bool(true) }
|
||||||
|
|
||||||
// Flatten the regmap into (RegIndex:Int, Offset:Int, field:RegField)
|
// Flatten the regmap into (RegIndex:Int, Offset:Int, field:RegField)
|
||||||
val flat = regmap.map { case (reg, fields) =>
|
val flat = regmap.map { case (reg, fields) =>
|
||||||
val offsets = fields.scanLeft(0)(_ + _.width).init
|
val offsets = fields.scanLeft(0)(_ + _.width).init
|
||||||
val index = regIndexI(reg)
|
val index = regIndexI(reg)
|
||||||
|
val uint = UInt(reg, width = inBits)
|
||||||
|
if (undefZero) {
|
||||||
|
iRightReg(index) = ((front.bits.index ^ uint) & maskMatch) === UInt(0)
|
||||||
|
oRightReg(index) = ((back .bits.index ^ uint) & maskMatch) === UInt(0)
|
||||||
|
}
|
||||||
// println("mapping 0x%x -> 0x%x for 0x%x/%d".format(reg, index, mask, maskBits))
|
// println("mapping 0x%x -> 0x%x for 0x%x/%d".format(reg, index, mask, maskBits))
|
||||||
(offsets zip fields) map { case (o, f) => (index, o, f) }
|
(offsets zip fields) map { case (o, f) => (index, o, f) }
|
||||||
}.flatten
|
}.flatten
|
||||||
|
|
||||||
val out = Wire(Irrevocable(new RegMapperOutput(params)))
|
|
||||||
val front = Wire(Irrevocable(new RegMapperInput(params)))
|
|
||||||
front.bits := in.bits
|
|
||||||
|
|
||||||
// Must this device pipeline the control channel?
|
|
||||||
val pipelined = flat.map(_._3.pipelined).reduce(_ || _)
|
|
||||||
val depth = concurrency.getOrElse(if (pipelined) 1 else 0)
|
|
||||||
require (depth >= 0)
|
|
||||||
require (!pipelined || depth > 0)
|
|
||||||
val back = if (depth > 0) Queue(front, depth, pipe = depth == 1) else front
|
|
||||||
|
|
||||||
// Forward declaration of all flow control signals
|
// Forward declaration of all flow control signals
|
||||||
val rivalid = Wire(Vec(flat.size, Bool()))
|
val rivalid = Wire(Vec(flat.size, Bool()))
|
||||||
val wivalid = Wire(Vec(flat.size, Bool()))
|
val wivalid = Wire(Vec(flat.size, Bool()))
|
||||||
@ -122,10 +136,10 @@ object RegMapper
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Is the selected register ready?
|
// Is the selected register ready?
|
||||||
val rifireMux = Vec(rifire.map(_.reduce(_ && _)))
|
val rifireMux = Vec(rifire.zipWithIndex.map { case (seq, i) => !iRightReg(i) || seq.reduce(_ && _)})
|
||||||
val wifireMux = Vec(wifire.map(_.reduce(_ && _)))
|
val wifireMux = Vec(wifire.zipWithIndex.map { case (seq, i) => !iRightReg(i) || seq.reduce(_ && _)})
|
||||||
val rofireMux = Vec(rofire.map(_.reduce(_ && _)))
|
val rofireMux = Vec(rofire.zipWithIndex.map { case (seq, i) => !oRightReg(i) || seq.reduce(_ && _)})
|
||||||
val wofireMux = Vec(wofire.map(_.reduce(_ && _)))
|
val wofireMux = Vec(wofire.zipWithIndex.map { case (seq, i) => !oRightReg(i) || seq.reduce(_ && _)})
|
||||||
val iindex = regIndexU(front.bits.index)
|
val iindex = regIndexU(front.bits.index)
|
||||||
val oindex = regIndexU(back .bits.index)
|
val oindex = regIndexU(back .bits.index)
|
||||||
val iready = Mux(front.bits.read, rifireMux(iindex), wifireMux(iindex))
|
val iready = Mux(front.bits.read, rifireMux(iindex), wifireMux(iindex))
|
||||||
@ -138,8 +152,8 @@ object RegMapper
|
|||||||
out.valid := back.valid && oready
|
out.valid := back.valid && oready
|
||||||
|
|
||||||
// Which register is touched?
|
// Which register is touched?
|
||||||
val frontSel = UIntToOH(iindex)
|
val frontSel = UIntToOH(iindex) & Cat(iRightReg.reverse)
|
||||||
val backSel = UIntToOH(oindex)
|
val backSel = UIntToOH(oindex) & Cat(oRightReg.reverse)
|
||||||
|
|
||||||
// Include the per-register one-hot selected criteria
|
// Include the per-register one-hot selected criteria
|
||||||
for (reg <- 0 until regSize) {
|
for (reg <- 0 until regSize) {
|
||||||
@ -159,9 +173,9 @@ object RegMapper
|
|||||||
}
|
}
|
||||||
|
|
||||||
out.bits.read := back.bits.read
|
out.bits.read := back.bits.read
|
||||||
out.bits.data := Vec(dataOut)(oindex)
|
out.bits.data := Mux(Vec(oRightReg)(oindex), Vec(dataOut)(oindex), UInt(0))
|
||||||
out.bits.extra := back.bits.extra
|
out.bits.extra := back.bits.extra
|
||||||
|
|
||||||
(endIndex, out)
|
out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ class TLRegisterNode(address: AddressSet, concurrency: Option[Int] = None, beatB
|
|||||||
val (sourceEnd, sourceOff) = (edge.bundle.sourceBits + sizeEnd, sizeEnd)
|
val (sourceEnd, sourceOff) = (edge.bundle.sourceBits + sizeEnd, sizeEnd)
|
||||||
val (addrLoEnd, addrLoOff) = (log2Up(beatBytes) + sourceEnd, sourceEnd)
|
val (addrLoEnd, addrLoOff) = (log2Up(beatBytes) + sourceEnd, sourceEnd)
|
||||||
|
|
||||||
val params = RegMapperParams(log2Up(address.mask+1), beatBytes, addrLoEnd)
|
val params = RegMapperParams(log2Up((address.mask+1)/beatBytes), beatBytes, addrLoEnd)
|
||||||
val in = Wire(Decoupled(new RegMapperInput(params)))
|
val in = Wire(Decoupled(new RegMapperInput(params)))
|
||||||
in.bits.read := a.bits.opcode === TLMessages.Get
|
in.bits.read := a.bits.opcode === TLMessages.Get
|
||||||
in.bits.index := a.bits.addr_hi
|
in.bits.index := a.bits.addr_hi
|
||||||
@ -36,10 +36,7 @@ class TLRegisterNode(address: AddressSet, concurrency: Option[Int] = None, beatB
|
|||||||
in.bits.extra := Cat(edge.addr_lo(a.bits), a.bits.source, a.bits.size)
|
in.bits.extra := Cat(edge.addr_lo(a.bits), a.bits.source, a.bits.size)
|
||||||
|
|
||||||
// Invoke the register map builder
|
// Invoke the register map builder
|
||||||
val (endIndex, out) = RegMapper(beatBytes, concurrency, undefZero, in, mapping:_*)
|
val out = RegMapper(beatBytes, concurrency, undefZero, in, mapping:_*)
|
||||||
|
|
||||||
// All registers must fit inside the device address space
|
|
||||||
require (address.mask >= (endIndex-1)*beatBytes)
|
|
||||||
|
|
||||||
// No flow control needed
|
// No flow control needed
|
||||||
in.valid := a.valid
|
in.valid := a.valid
|
||||||
|
@ -26,15 +26,15 @@ class RRTestCombinational(val bits: Int, rvalid: Bool => Bool, wready: Bool => B
|
|||||||
val wdata = UInt(INPUT, width = bits)
|
val wdata = UInt(INPUT, width = bits)
|
||||||
}
|
}
|
||||||
|
|
||||||
val rfire = io.rvalid && io.rready
|
|
||||||
val wfire = io.wvalid && io.wready
|
|
||||||
val reg = Reg(UInt(width = bits))
|
val reg = Reg(UInt(width = bits))
|
||||||
|
|
||||||
io.rvalid := rvalid(rfire)
|
val rvalid_s = rvalid(io.rready)
|
||||||
io.wready := wready(wfire)
|
val wready_s = wready(io.wvalid)
|
||||||
|
io.rvalid := rvalid_s
|
||||||
|
io.wready := wready_s
|
||||||
|
|
||||||
io.rdata := reg
|
io.rdata := reg
|
||||||
when (wfire) { reg := io.wdata }
|
when (io.wvalid && wready_s) { reg := io.wdata }
|
||||||
}
|
}
|
||||||
|
|
||||||
object RRTestCombinational
|
object RRTestCombinational
|
||||||
@ -43,19 +43,19 @@ object RRTestCombinational
|
|||||||
|
|
||||||
def always: Bool => Bool = _ => Bool(true)
|
def always: Bool => Bool = _ => Bool(true)
|
||||||
|
|
||||||
def random: Bool => Bool = { fire =>
|
def random: Bool => Bool = { ready =>
|
||||||
seed = seed + 1
|
seed = seed + 1
|
||||||
val lfsr = LFSR16Seed(seed)
|
val lfsr = LFSR16Seed(seed)
|
||||||
val reg = RegInit(Bool(true))
|
val valid = RegInit(Bool(true))
|
||||||
reg := Mux(reg, !fire, lfsr(0) && lfsr(1))
|
valid := Mux(valid, !ready, lfsr(0) && lfsr(1))
|
||||||
reg
|
valid
|
||||||
}
|
}
|
||||||
|
|
||||||
def delay(x: Int): Bool => Bool = { fire =>
|
def delay(x: Int): Bool => Bool = { ready =>
|
||||||
val reg = RegInit(UInt(0, width = log2Ceil(x+1)))
|
val reg = RegInit(UInt(0, width = log2Ceil(x+1)))
|
||||||
val ready = reg === UInt(0)
|
val valid = reg === UInt(0)
|
||||||
reg := Mux(fire, UInt(x), Mux(ready, UInt(0), reg - UInt(1)))
|
reg := Mux(ready && valid, UInt(x), Mux(valid, UInt(0), reg - UInt(1)))
|
||||||
ready
|
valid
|
||||||
}
|
}
|
||||||
|
|
||||||
def combo(bits: Int, rvalid: Bool => Bool, wready: Bool => Bool): RegField = {
|
def combo(bits: Int, rvalid: Bool => Bool, wready: Bool => Bool): RegField = {
|
||||||
|
Loading…
Reference in New Issue
Block a user