diff --git a/src/main/scala/util/ShiftQueue.scala b/src/main/scala/util/ShiftQueue.scala index a64d3ee8..5914da89 100644 --- a/src/main/scala/util/ShiftQueue.scala +++ b/src/main/scala/util/ShiftQueue.scala @@ -20,20 +20,20 @@ class ShiftQueue[T <: Data](gen: T, private val valid = RegInit(Vec.fill(entries) { Bool(false) }) private val elts = Reg(Vec(entries, gen)) - private val do_enq = io.enq.fire() - private val do_deq = io.deq.fire() - for (i <- 0 until entries) { + def paddedValid(i: Int) = if (i == -1) true.B else if (i == entries) false.B else valid(i) + val wdata = if (i == entries-1) io.enq.bits else Mux(valid(i+1), elts(i+1), io.enq.bits) - val shiftDown = if (i == entries-1) false.B else io.deq.ready && valid(i+1) - val enqNew = io.enq.fire() && Mux(io.deq.ready, valid(i), !valid(i) && (if (i == 0) true.B else valid(i-1))) - when (shiftDown || enqNew) { elts(i) := wdata } - } + val wen = + Mux(io.deq.ready, + paddedValid(i+1) || io.enq.fire() && valid(i), + io.enq.fire() && paddedValid(i-1) && !valid(i)) + when (wen) { elts(i) := wdata } - val padded = Seq(true.B) ++ valid ++ Seq(false.B) - for (i <- 0 until entries) { - when ( do_enq && !do_deq && padded(i+0)) { valid(i) := true.B } - when (!do_enq && do_deq && !padded(i+2)) { valid(i) := false.B } + valid(i) := + Mux(io.deq.ready, + paddedValid(i+1) || io.enq.fire() && (Bool(i == 0 && !flow) || valid(i)), + io.enq.fire() && paddedValid(i-1) || valid(i)) } io.enq.ready := !valid(entries-1)