From 304d8b814abf37a2e5ae241fd70c02a8c5bd57e8 Mon Sep 17 00:00:00 2001 From: Howard Mao Date: Tue, 17 Nov 2015 18:14:30 -0800 Subject: [PATCH] Implement client-side DMA controller --- rocket/src/main/scala/dma.scala | 337 +++++++++++++++++++++++++++++++ rocket/src/main/scala/rocc.scala | 2 + rocket/src/main/scala/tile.scala | 18 +- rocket/src/main/scala/util.scala | 6 + 4 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 rocket/src/main/scala/dma.scala diff --git a/rocket/src/main/scala/dma.scala b/rocket/src/main/scala/dma.scala new file mode 100644 index 00000000..333ff95d --- /dev/null +++ b/rocket/src/main/scala/dma.scala @@ -0,0 +1,337 @@ +package rocket + +import Chisel._ +import uncore._ +import uncore.DmaRequest._ +import junctions.ParameterizedBundle +import cde.Parameters + +trait HasClientDmaParameters extends HasCoreParameters with HasDmaParameters { + val dmaAddrBits = coreMaxAddrBits + val dmaSizeBits = coreMaxAddrBits + val dmaSegmentBits = 24 +} + +abstract class ClientDmaBundle(implicit val p: Parameters) + extends ParameterizedBundle()(p) with HasClientDmaParameters +abstract class ClientDmaModule(implicit val p: Parameters) + extends Module with HasClientDmaParameters + +class ClientDmaRequest(implicit p: Parameters) extends ClientDmaBundle()(p) { + val cmd = UInt(width = DMA_CMD_SZ) + val src_start = UInt(width = dmaAddrBits) + val dst_start = UInt(width = dmaAddrBits) + val src_stride = UInt(width = dmaSizeBits) + val dst_stride = UInt(width = dmaSizeBits) + val segment_size = UInt(width = dmaSizeBits) + val nsegments = UInt(width = dmaSegmentBits) +} + +object ClientDmaRequest { + def apply(cmd: UInt, + src_start: UInt, + dst_start: UInt, + segment_size: UInt, + nsegments: UInt = UInt(1), + src_stride: UInt = UInt(0), + dst_stride: UInt = UInt(0)) + (implicit p: Parameters) = { + val req = Wire(new ClientDmaRequest) + req.cmd := cmd + req.src_start := src_start + req.dst_start := dst_start + req.src_stride := src_stride + req.dst_stride := dst_stride + req.segment_size := segment_size + req.nsegments := nsegments + req + } +} + +object ClientDmaResponse { + val pagefault = UInt("b01") + val outer_err = UInt("b10") + + def apply(status: UInt = UInt(0))(implicit p: Parameters) = { + val resp = Wire(new ClientDmaResponse) + resp.status := status + resp + } +} + +class ClientDmaResponse(implicit p: Parameters) extends ClientDmaBundle { + val status = UInt(width = dmaStatusBits) +} + +class ClientDmaIO(implicit p: Parameters) extends ParameterizedBundle()(p) { + val req = Decoupled(new ClientDmaRequest) + val resp = Valid(new ClientDmaResponse).flip +} + +class DmaFrontend(implicit val p: Parameters) + extends Module with HasClientDmaParameters { + val io = new Bundle { + val cpu = (new ClientDmaIO).flip + val dma = new DmaIO + val ptw = new TLBPTWIO + val busy = Bool(OUTPUT) + } + + private val pgSize = 1 << pgIdxBits + + val priv = Mux(io.ptw.status.mprv, io.ptw.status.prv1, io.ptw.status.prv) + val vm_enabled = io.ptw.status.vm(3) && priv <= UInt(PRV_S) + + val cmd = Reg(UInt(width = DMA_CMD_SZ)) + + val segment_size = Reg(UInt(width = dmaSizeBits)) + val bytes_left = Reg(UInt(width = dmaSizeBits)) + val segments_left = Reg(UInt(width = dmaSegmentBits)) + + val src_vaddr = Reg(UInt(width = dmaAddrBits)) + val dst_vaddr = Reg(UInt(width = dmaAddrBits)) + val src_vpn = src_vaddr(dmaAddrBits - 1, pgIdxBits) + val dst_vpn = dst_vaddr(dmaAddrBits - 1, pgIdxBits) + val src_idx = src_vaddr(pgIdxBits - 1, 0) + val dst_idx = dst_vaddr(pgIdxBits - 1, 0) + val src_pglen = UInt(pgSize) - src_idx + val dst_pglen = UInt(pgSize) - dst_idx + + val src_stride = Reg(UInt(width = dmaSizeBits)) + val dst_stride = Reg(UInt(width = dmaSizeBits)) + + val src_ppn = Reg(UInt(width = ppnBits)) + val dst_ppn = Reg(UInt(width = ppnBits)) + + val src_paddr = Mux(vm_enabled, Cat(src_ppn, src_idx), src_vaddr) + val dst_paddr = Mux(vm_enabled, Cat(dst_ppn, dst_idx), dst_vaddr) + + val last_src_vpn = Reg(UInt(width = vpnBits)) + val last_dst_vpn = Reg(UInt(width = vpnBits)) + + val tx_len = Mux(!vm_enabled, bytes_left, + Util.minUInt(src_pglen, dst_pglen, bytes_left)) + + val (dma_xact_id, _) = Counter(io.dma.req.fire(), nDmaXactsPerClient) + val dma_busy = Reg(init = UInt(0, nDmaXactsPerClient)) + + val (s_idle :: s_translate :: s_dma_req :: s_dma_update :: + s_prepare :: s_finish :: Nil) = Enum(Bits(), 6) + val state = Reg(init = s_idle) + + // lower bit is for src, higher bit is for dst + val to_translate = Reg(init = UInt(0, 2)) + val ptw_sent = Reg(init = UInt(0, 2)) + val ptw_to_send = to_translate & ~ptw_sent + val ptw_resp_id = Reg(init = UInt(0, 1)) + val resp_status = Reg(UInt(width = dmaStatusBits)) + + io.ptw.req.valid := ptw_to_send.orR && vm_enabled + io.ptw.req.bits.addr := Mux(ptw_to_send(0), src_vpn, dst_vpn) + io.ptw.req.bits.prv := io.ptw.status.prv + io.ptw.req.bits.store := !ptw_to_send(0) // storing to destination + io.ptw.req.bits.fetch := Bool(true) + + when (io.ptw.req.fire()) { + ptw_sent := ptw_sent | PriorityEncoderOH(ptw_to_send) + } + + when (io.ptw.resp.valid) { + when (io.ptw.resp.bits.error) { + resp_status := ClientDmaResponse.pagefault + state := s_finish + } + val recv_choice = PriorityEncoderOH(to_translate) + to_translate := to_translate & ~recv_choice + + // getting the src translation + // if this is a prefetch, dst_ppn and src_ppn should be equal + when (recv_choice(0) || cmd(1)) { + src_ppn := io.ptw.resp.bits.pte.ppn + } .otherwise { + dst_ppn := io.ptw.resp.bits.pte.ppn + } + } + + io.cpu.req.ready := state === s_idle + io.cpu.resp.valid := state === s_finish + io.cpu.resp.bits := ClientDmaResponse(resp_status) + io.dma.req.valid := state === s_dma_req && !dma_busy(dma_xact_id) + io.dma.req.bits := DmaRequest( + client_xact_id = dma_xact_id, + cmd = cmd, + source = src_paddr, + dest = dst_paddr, + length = tx_len) + io.dma.resp.ready := Bool(true) + + when (io.cpu.req.fire()) { + val req = io.cpu.req.bits + cmd := req.cmd + src_vaddr := req.src_start + dst_vaddr := req.dst_start + src_stride := req.src_stride + dst_stride := req.dst_stride + segment_size := req.segment_size + segments_left := req.nsegments - UInt(1) + bytes_left := req.segment_size + to_translate := Mux(req.cmd(1), UInt("b10"), UInt("b11")) + ptw_sent := UInt(0) + state := Mux(vm_enabled, s_translate, s_dma_req) + } + + when (state === s_translate && !to_translate.orR) { + state := s_dma_req + } + + when (io.dma.req.fire()) { + src_vaddr := src_vaddr + tx_len + dst_vaddr := dst_vaddr + tx_len + bytes_left := bytes_left - tx_len + dma_busy := dma_busy | UIntToOH(dma_xact_id) + state := s_dma_update + } + + when (io.dma.resp.fire()) { + dma_busy := dma_busy & ~UIntToOH(io.dma.resp.bits.client_xact_id) + } + + when (state === s_dma_update) { + when (bytes_left === UInt(0)) { + when (segments_left === UInt(0)) { + resp_status := UInt(0) + state := s_finish + } .otherwise { + last_src_vpn := src_vpn + last_dst_vpn := dst_vpn + src_vaddr := src_vaddr + src_stride + dst_vaddr := dst_vaddr + dst_stride + bytes_left := segment_size + segments_left := segments_left - UInt(1) + state := Mux(vm_enabled, s_prepare, s_dma_req) + } + } .otherwise { + to_translate := Cat(dst_idx === UInt(0), !cmd(1) && src_idx === UInt(0)) + ptw_sent := UInt(0) + state := s_translate + } + } + + when (state === s_prepare) { + to_translate := Cat( + dst_vpn =/= last_dst_vpn, + src_vpn =/= last_src_vpn && !cmd(1)) + ptw_sent := UInt(0) + state := s_translate + } + + when (state === s_finish) { state := s_idle } + + io.busy := (state =/= s_idle) || dma_busy.orR +} + +object DmaCtrlRegNumbers { + val SRC_STRIDE = 0 + val DST_STRIDE = 1 + val SEGMENT_SIZE = 2 + val NSEGMENTS = 3 + val RESP_STATUS = 4 +} +import DmaCtrlRegNumbers._ + +class DmaCtrlRegFile(implicit p: Parameters) extends ClientDmaModule()(p) { + private val nWriteRegs = 4 + private val nReadRegs = 1 + private val nRegs = nWriteRegs + nReadRegs + + val io = new Bundle { + val wen = Bool(INPUT) + val addr = UInt(INPUT, log2Up(nRegs)) + val wdata = UInt(INPUT, dmaSizeBits) + val rdata = UInt(OUTPUT, dmaSizeBits) + + val src_stride = UInt(OUTPUT, dmaSizeBits) + val dst_stride = UInt(OUTPUT, dmaSizeBits) + val segment_size = UInt(OUTPUT, dmaSizeBits) + val nsegments = UInt(OUTPUT, dmaSegmentBits) + + val status = UInt(INPUT, dmaStatusBits) + } + + val regs = Reg(Vec(nWriteRegs, UInt(width = dmaSizeBits))) + + io.src_stride := regs(SRC_STRIDE) + io.dst_stride := regs(DST_STRIDE) + io.segment_size := regs(SEGMENT_SIZE) + io.nsegments := regs(NSEGMENTS) + + when (io.wen && io.addr < UInt(nWriteRegs)) { + regs.write(io.addr, io.wdata) + } + + io.rdata := MuxLookup(io.addr, regs(io.addr), Seq( + UInt(RESP_STATUS) -> io.status)) +} + +class DmaController(implicit p: Parameters) extends RoCC()(p) + with HasClientDmaParameters { + io.mem.req.valid := Bool(false) + io.autl.acquire.valid := Bool(false) + io.autl.grant.ready := Bool(false) + io.iptw.req.valid := Bool(false) + io.pptw.req.valid := Bool(false) + + val cmd = Queue(io.cmd) + val inst = cmd.bits.inst + val is_transfer = inst.funct < UInt(4) + val is_cr_write = inst.funct === UInt(4) + val is_cr_read = inst.funct === UInt(5) + val is_cr_access = is_cr_write || is_cr_read + + val resp_rd = Reg(io.resp.bits.rd) + val resp_data = Reg(io.resp.bits.data) + + val s_idle :: s_resp :: Nil = Enum(Bits(), 2) + val state = Reg(init = s_idle) + + val reg_status = Reg(UInt(width = dmaStatusBits)) + val crfile = Module(new DmaCtrlRegFile) + crfile.io.addr := cmd.bits.rs1 + crfile.io.wdata := cmd.bits.rs2 + crfile.io.wen := cmd.fire() && is_cr_write + + val frontend = Module(new DmaFrontend) + io.dma <> frontend.io.dma + io.dptw <> frontend.io.ptw + frontend.io.cpu.req.valid := cmd.valid && is_transfer + frontend.io.cpu.req.bits := ClientDmaRequest( + cmd = cmd.bits.inst.funct, + src_start = cmd.bits.rs2, + dst_start = cmd.bits.rs1, + src_stride = crfile.io.src_stride, + dst_stride = crfile.io.dst_stride, + segment_size = crfile.io.segment_size, + nsegments = crfile.io.nsegments) + + cmd.ready := state === s_idle && (!is_transfer || frontend.io.cpu.req.ready) + io.resp.valid := state === s_resp + io.resp.bits.rd := resp_rd + io.resp.bits.data := resp_data + + when (cmd.fire()) { + when (is_cr_read) { + resp_rd := inst.rd + resp_data := crfile.io.rdata + state := s_resp + } + } + + when (io.resp.fire()) { state := s_idle } + + when (frontend.io.cpu.resp.valid) { + reg_status := frontend.io.cpu.resp.bits.status + } + + io.busy := (state =/= s_idle) || cmd.valid || frontend.io.busy + io.interrupt := Bool(false) +} diff --git a/rocket/src/main/scala/rocc.scala b/rocket/src/main/scala/rocc.scala index f2b2decd..a0d68abd 100644 --- a/rocket/src/main/scala/rocc.scala +++ b/rocket/src/main/scala/rocc.scala @@ -50,6 +50,8 @@ class RoCCInterface(implicit p: Parameters) extends Bundle { val fpu_req = Decoupled(new FPInput) val fpu_resp = Decoupled(new FPResult).flip val exception = Bool(INPUT) + + val dma = new DmaIO } abstract class RoCC(implicit p: Parameters) extends CoreModule()(p) { diff --git a/rocket/src/main/scala/tile.scala b/rocket/src/main/scala/tile.scala index 86582cff..233b5b8b 100644 --- a/rocket/src/main/scala/tile.scala +++ b/rocket/src/main/scala/tile.scala @@ -14,7 +14,8 @@ case class RoccParameters( opcodes: OpcodeSet, generator: Parameters => RoCC, nMemChannels: Int = 0, - useFPU: Boolean = false) + useFPU: Boolean = false, + useDma: Boolean = false) abstract class Tile(resetSignal: Bool = null) (implicit p: Parameters) extends Module(_reset = resetSignal) { @@ -22,6 +23,7 @@ abstract class Tile(resetSignal: Bool = null) val usingRocc = !buildRocc.isEmpty val nRocc = buildRocc.size val nFPUPorts = buildRocc.filter(_.useFPU).size + val nDmaPorts = buildRocc.filter(_.useDma).size val nDCachePorts = 2 + nRocc val nPTWPorts = 2 + 3 * nRocc val nCachedTileLinkPorts = 1 @@ -31,6 +33,7 @@ abstract class Tile(resetSignal: Bool = null) val cached = Vec(nCachedTileLinkPorts, new ClientTileLinkIO) val uncached = Vec(nUncachedTileLinkPorts, new ClientUncachedTileLinkIO) val host = new HtifIO + val dma = new DmaIO } } @@ -104,6 +107,14 @@ class RocketTile(resetSignal: Bool = null)(implicit p: Parameters) extends Tile( } } + if (nDmaPorts > 0) { + val dmaArb = Module(new DmaArbiter(nDmaPorts)) + dmaArb.io.in <> roccs.zip(buildRocc) + .filter { case (_, params) => params.useDma } + .map { case (rocc, _) => rocc.io.dma } + io.dma <> dmaArb.io.out + } + core.io.rocc.busy := cmdRouter.io.busy || roccs.map(_.io.busy).reduce(_ || _) core.io.rocc.interrupt := roccs.map(_.io.interrupt).reduce(_ || _) respArb.io.in <> roccs.map(rocc => Queue(rocc.io.resp)) @@ -117,4 +128,9 @@ class RocketTile(resetSignal: Bool = null)(implicit p: Parameters) extends Tile( fpu.io.cp_resp.ready := Bool(false) } } + + if (!usingRocc || nDmaPorts == 0) { + io.dma.req.valid := Bool(false) + io.dma.resp.ready := Bool(false) + } } diff --git a/rocket/src/main/scala/util.scala b/rocket/src/main/scala/util.scala index 4050be5b..a6ac1ad5 100644 --- a/rocket/src/main/scala/util.scala +++ b/rocket/src/main/scala/util.scala @@ -21,6 +21,12 @@ object Util { implicit def booleanToIntConv(x: Boolean) = new AnyRef { def toInt: Int = if (x) 1 else 0 } + + def minUInt(values: Seq[UInt]): UInt = + values.reduce((a, b) => Mux(a < b, a, b)) + + def minUInt(first: UInt, rest: UInt*): UInt = + minUInt(first +: rest.toSeq) } import Util._