diff --git a/src/main/scala/uncore/tilelink2/ToAXI4.scala b/src/main/scala/uncore/tilelink2/ToAXI4.scala index 5d26a547..75003eda 100644 --- a/src/main/scala/uncore/tilelink2/ToAXI4.scala +++ b/src/main/scala/uncore/tilelink2/ToAXI4.scala @@ -60,9 +60,12 @@ class TLToAXI4(beatBytes: Int, combinational: Boolean = true)(implicit p: Parame // Construct the source=>ID mapping table val idTable = Wire(Vec(edgeIn.client.endSourceId, out.aw.bits.id)) + var idCount = Array.fill(edgeOut.master.endId) { 0 } (edgeIn.client.clients zip edgeOut.master.masters) foreach { case (c, m) => for (i <- 0 until c.sourceId.size) { - idTable(c.sourceId.start + i) := UInt(m.id.start + (if (c.requestFifo) 0 else i)) + val id = m.id.start + (if (c.requestFifo) 0 else i) + idTable(c.sourceId.start + i) := UInt(id) + idCount(id) = idCount(id) + 1 } } @@ -139,11 +142,11 @@ class TLToAXI4(beatBytes: Int, combinational: Boolean = true)(implicit p: Parame arw.qos := UInt(0) // no QoS arw.user.foreach { _ := a_state } - // !!! Mix R-W stall here - in.a.ready := Mux(a_isPut, (doneAW || out_arw.ready) && out_w.ready, out_arw.ready) - out_arw.valid := in.a.valid && Mux(a_isPut, !doneAW && out_w.ready, Bool(true)) + val stall = Wire(Bool()) + in.a.ready := !stall && Mux(a_isPut, (doneAW || out_arw.ready) && out_w.ready, out_arw.ready) + out_arw.valid := !stall && in.a.valid && Mux(a_isPut, !doneAW && out_w.ready, Bool(true)) - out_w.valid := in.a.valid && a_isPut && (doneAW || out_arw.ready) + out_w.valid := !stall && in.a.valid && a_isPut && (doneAW || out_arw.ready) out_w.bits.data := in.a.bits.data out_w.bits.strb := in.a.bits.mask out_w.bits.last := a_last @@ -167,6 +170,31 @@ class TLToAXI4(beatBytes: Int, combinational: Boolean = true)(implicit p: Parame in.d.bits := Mux(r_wins, r_d, b_d) in.d.bits.data := out.r.bits.data // avoid a costly Mux + // We need to track if any reads or writes are inflight for a given ID. + // If the opposite type arrives, we must stall until it completes. + val a_sel = UIntToOH(arw.id, edgeOut.master.endId).toBools + val d_sel = UIntToOH(Mux(r_wins, out.r.bits.id, out.b.bits.id), edgeOut.master.endId).toBools + val d_last = Mux(r_wins, out.r.bits.last, Bool(true)) + val d_first = RegInit(Bool(true)) + when (in.d.fire()) { d_first := d_last } + val stalls = ((a_sel zip d_sel) zip idCount) filter { case (_, n) => n > 1 } map { case ((as, ds), n) => + val count = RegInit(UInt(0, width = log2Ceil(n + 1))) + val write = Reg(Bool()) + val idle = count === UInt(0) + + // Once we start getting the response, it's safe to already switch R/W + val inc = as && out_arw.fire() + val dec = ds && d_first && in.d.fire() + count := count + inc.asUInt - dec.asUInt + + assert (!dec || count =/= UInt(0)) // underflow + assert (!inc || count =/= UInt(n)) // overflow + + when (inc) { write := arw.wen } + !idle && write =/= arw.wen + } + stall := stalls.foldLeft(Bool(false))(_||_) + // Tie off unused channels in.b.valid := Bool(false) in.c.ready := Bool(true)