Implement client-side DMA controller
This commit is contained in:
parent
01a3447989
commit
304d8b814a
337
rocket/src/main/scala/dma.scala
Normal file
337
rocket/src/main/scala/dma.scala
Normal file
@ -0,0 +1,337 @@
|
||||
package rocket
|
||||
|
||||
import Chisel._
|
||||
import uncore._
|
||||
import uncore.DmaRequest._
|
||||
import junctions.ParameterizedBundle
|
||||
import cde.Parameters
|
||||
|
||||
trait HasClientDmaParameters extends HasCoreParameters with HasDmaParameters {
|
||||
val dmaAddrBits = coreMaxAddrBits
|
||||
val dmaSizeBits = coreMaxAddrBits
|
||||
val dmaSegmentBits = 24
|
||||
}
|
||||
|
||||
abstract class ClientDmaBundle(implicit val p: Parameters)
|
||||
extends ParameterizedBundle()(p) with HasClientDmaParameters
|
||||
abstract class ClientDmaModule(implicit val p: Parameters)
|
||||
extends Module with HasClientDmaParameters
|
||||
|
||||
class ClientDmaRequest(implicit p: Parameters) extends ClientDmaBundle()(p) {
|
||||
val cmd = UInt(width = DMA_CMD_SZ)
|
||||
val src_start = UInt(width = dmaAddrBits)
|
||||
val dst_start = UInt(width = dmaAddrBits)
|
||||
val src_stride = UInt(width = dmaSizeBits)
|
||||
val dst_stride = UInt(width = dmaSizeBits)
|
||||
val segment_size = UInt(width = dmaSizeBits)
|
||||
val nsegments = UInt(width = dmaSegmentBits)
|
||||
}
|
||||
|
||||
object ClientDmaRequest {
|
||||
def apply(cmd: UInt,
|
||||
src_start: UInt,
|
||||
dst_start: UInt,
|
||||
segment_size: UInt,
|
||||
nsegments: UInt = UInt(1),
|
||||
src_stride: UInt = UInt(0),
|
||||
dst_stride: UInt = UInt(0))
|
||||
(implicit p: Parameters) = {
|
||||
val req = Wire(new ClientDmaRequest)
|
||||
req.cmd := cmd
|
||||
req.src_start := src_start
|
||||
req.dst_start := dst_start
|
||||
req.src_stride := src_stride
|
||||
req.dst_stride := dst_stride
|
||||
req.segment_size := segment_size
|
||||
req.nsegments := nsegments
|
||||
req
|
||||
}
|
||||
}
|
||||
|
||||
object ClientDmaResponse {
|
||||
val pagefault = UInt("b01")
|
||||
val outer_err = UInt("b10")
|
||||
|
||||
def apply(status: UInt = UInt(0))(implicit p: Parameters) = {
|
||||
val resp = Wire(new ClientDmaResponse)
|
||||
resp.status := status
|
||||
resp
|
||||
}
|
||||
}
|
||||
|
||||
class ClientDmaResponse(implicit p: Parameters) extends ClientDmaBundle {
|
||||
val status = UInt(width = dmaStatusBits)
|
||||
}
|
||||
|
||||
class ClientDmaIO(implicit p: Parameters) extends ParameterizedBundle()(p) {
|
||||
val req = Decoupled(new ClientDmaRequest)
|
||||
val resp = Valid(new ClientDmaResponse).flip
|
||||
}
|
||||
|
||||
class DmaFrontend(implicit val p: Parameters)
|
||||
extends Module with HasClientDmaParameters {
|
||||
val io = new Bundle {
|
||||
val cpu = (new ClientDmaIO).flip
|
||||
val dma = new DmaIO
|
||||
val ptw = new TLBPTWIO
|
||||
val busy = Bool(OUTPUT)
|
||||
}
|
||||
|
||||
private val pgSize = 1 << pgIdxBits
|
||||
|
||||
val priv = Mux(io.ptw.status.mprv, io.ptw.status.prv1, io.ptw.status.prv)
|
||||
val vm_enabled = io.ptw.status.vm(3) && priv <= UInt(PRV_S)
|
||||
|
||||
val cmd = Reg(UInt(width = DMA_CMD_SZ))
|
||||
|
||||
val segment_size = Reg(UInt(width = dmaSizeBits))
|
||||
val bytes_left = Reg(UInt(width = dmaSizeBits))
|
||||
val segments_left = Reg(UInt(width = dmaSegmentBits))
|
||||
|
||||
val src_vaddr = Reg(UInt(width = dmaAddrBits))
|
||||
val dst_vaddr = Reg(UInt(width = dmaAddrBits))
|
||||
val src_vpn = src_vaddr(dmaAddrBits - 1, pgIdxBits)
|
||||
val dst_vpn = dst_vaddr(dmaAddrBits - 1, pgIdxBits)
|
||||
val src_idx = src_vaddr(pgIdxBits - 1, 0)
|
||||
val dst_idx = dst_vaddr(pgIdxBits - 1, 0)
|
||||
val src_pglen = UInt(pgSize) - src_idx
|
||||
val dst_pglen = UInt(pgSize) - dst_idx
|
||||
|
||||
val src_stride = Reg(UInt(width = dmaSizeBits))
|
||||
val dst_stride = Reg(UInt(width = dmaSizeBits))
|
||||
|
||||
val src_ppn = Reg(UInt(width = ppnBits))
|
||||
val dst_ppn = Reg(UInt(width = ppnBits))
|
||||
|
||||
val src_paddr = Mux(vm_enabled, Cat(src_ppn, src_idx), src_vaddr)
|
||||
val dst_paddr = Mux(vm_enabled, Cat(dst_ppn, dst_idx), dst_vaddr)
|
||||
|
||||
val last_src_vpn = Reg(UInt(width = vpnBits))
|
||||
val last_dst_vpn = Reg(UInt(width = vpnBits))
|
||||
|
||||
val tx_len = Mux(!vm_enabled, bytes_left,
|
||||
Util.minUInt(src_pglen, dst_pglen, bytes_left))
|
||||
|
||||
val (dma_xact_id, _) = Counter(io.dma.req.fire(), nDmaXactsPerClient)
|
||||
val dma_busy = Reg(init = UInt(0, nDmaXactsPerClient))
|
||||
|
||||
val (s_idle :: s_translate :: s_dma_req :: s_dma_update ::
|
||||
s_prepare :: s_finish :: Nil) = Enum(Bits(), 6)
|
||||
val state = Reg(init = s_idle)
|
||||
|
||||
// lower bit is for src, higher bit is for dst
|
||||
val to_translate = Reg(init = UInt(0, 2))
|
||||
val ptw_sent = Reg(init = UInt(0, 2))
|
||||
val ptw_to_send = to_translate & ~ptw_sent
|
||||
val ptw_resp_id = Reg(init = UInt(0, 1))
|
||||
val resp_status = Reg(UInt(width = dmaStatusBits))
|
||||
|
||||
io.ptw.req.valid := ptw_to_send.orR && vm_enabled
|
||||
io.ptw.req.bits.addr := Mux(ptw_to_send(0), src_vpn, dst_vpn)
|
||||
io.ptw.req.bits.prv := io.ptw.status.prv
|
||||
io.ptw.req.bits.store := !ptw_to_send(0) // storing to destination
|
||||
io.ptw.req.bits.fetch := Bool(true)
|
||||
|
||||
when (io.ptw.req.fire()) {
|
||||
ptw_sent := ptw_sent | PriorityEncoderOH(ptw_to_send)
|
||||
}
|
||||
|
||||
when (io.ptw.resp.valid) {
|
||||
when (io.ptw.resp.bits.error) {
|
||||
resp_status := ClientDmaResponse.pagefault
|
||||
state := s_finish
|
||||
}
|
||||
val recv_choice = PriorityEncoderOH(to_translate)
|
||||
to_translate := to_translate & ~recv_choice
|
||||
|
||||
// getting the src translation
|
||||
// if this is a prefetch, dst_ppn and src_ppn should be equal
|
||||
when (recv_choice(0) || cmd(1)) {
|
||||
src_ppn := io.ptw.resp.bits.pte.ppn
|
||||
} .otherwise {
|
||||
dst_ppn := io.ptw.resp.bits.pte.ppn
|
||||
}
|
||||
}
|
||||
|
||||
io.cpu.req.ready := state === s_idle
|
||||
io.cpu.resp.valid := state === s_finish
|
||||
io.cpu.resp.bits := ClientDmaResponse(resp_status)
|
||||
io.dma.req.valid := state === s_dma_req && !dma_busy(dma_xact_id)
|
||||
io.dma.req.bits := DmaRequest(
|
||||
client_xact_id = dma_xact_id,
|
||||
cmd = cmd,
|
||||
source = src_paddr,
|
||||
dest = dst_paddr,
|
||||
length = tx_len)
|
||||
io.dma.resp.ready := Bool(true)
|
||||
|
||||
when (io.cpu.req.fire()) {
|
||||
val req = io.cpu.req.bits
|
||||
cmd := req.cmd
|
||||
src_vaddr := req.src_start
|
||||
dst_vaddr := req.dst_start
|
||||
src_stride := req.src_stride
|
||||
dst_stride := req.dst_stride
|
||||
segment_size := req.segment_size
|
||||
segments_left := req.nsegments - UInt(1)
|
||||
bytes_left := req.segment_size
|
||||
to_translate := Mux(req.cmd(1), UInt("b10"), UInt("b11"))
|
||||
ptw_sent := UInt(0)
|
||||
state := Mux(vm_enabled, s_translate, s_dma_req)
|
||||
}
|
||||
|
||||
when (state === s_translate && !to_translate.orR) {
|
||||
state := s_dma_req
|
||||
}
|
||||
|
||||
when (io.dma.req.fire()) {
|
||||
src_vaddr := src_vaddr + tx_len
|
||||
dst_vaddr := dst_vaddr + tx_len
|
||||
bytes_left := bytes_left - tx_len
|
||||
dma_busy := dma_busy | UIntToOH(dma_xact_id)
|
||||
state := s_dma_update
|
||||
}
|
||||
|
||||
when (io.dma.resp.fire()) {
|
||||
dma_busy := dma_busy & ~UIntToOH(io.dma.resp.bits.client_xact_id)
|
||||
}
|
||||
|
||||
when (state === s_dma_update) {
|
||||
when (bytes_left === UInt(0)) {
|
||||
when (segments_left === UInt(0)) {
|
||||
resp_status := UInt(0)
|
||||
state := s_finish
|
||||
} .otherwise {
|
||||
last_src_vpn := src_vpn
|
||||
last_dst_vpn := dst_vpn
|
||||
src_vaddr := src_vaddr + src_stride
|
||||
dst_vaddr := dst_vaddr + dst_stride
|
||||
bytes_left := segment_size
|
||||
segments_left := segments_left - UInt(1)
|
||||
state := Mux(vm_enabled, s_prepare, s_dma_req)
|
||||
}
|
||||
} .otherwise {
|
||||
to_translate := Cat(dst_idx === UInt(0), !cmd(1) && src_idx === UInt(0))
|
||||
ptw_sent := UInt(0)
|
||||
state := s_translate
|
||||
}
|
||||
}
|
||||
|
||||
when (state === s_prepare) {
|
||||
to_translate := Cat(
|
||||
dst_vpn =/= last_dst_vpn,
|
||||
src_vpn =/= last_src_vpn && !cmd(1))
|
||||
ptw_sent := UInt(0)
|
||||
state := s_translate
|
||||
}
|
||||
|
||||
when (state === s_finish) { state := s_idle }
|
||||
|
||||
io.busy := (state =/= s_idle) || dma_busy.orR
|
||||
}
|
||||
|
||||
object DmaCtrlRegNumbers {
|
||||
val SRC_STRIDE = 0
|
||||
val DST_STRIDE = 1
|
||||
val SEGMENT_SIZE = 2
|
||||
val NSEGMENTS = 3
|
||||
val RESP_STATUS = 4
|
||||
}
|
||||
import DmaCtrlRegNumbers._
|
||||
|
||||
class DmaCtrlRegFile(implicit p: Parameters) extends ClientDmaModule()(p) {
|
||||
private val nWriteRegs = 4
|
||||
private val nReadRegs = 1
|
||||
private val nRegs = nWriteRegs + nReadRegs
|
||||
|
||||
val io = new Bundle {
|
||||
val wen = Bool(INPUT)
|
||||
val addr = UInt(INPUT, log2Up(nRegs))
|
||||
val wdata = UInt(INPUT, dmaSizeBits)
|
||||
val rdata = UInt(OUTPUT, dmaSizeBits)
|
||||
|
||||
val src_stride = UInt(OUTPUT, dmaSizeBits)
|
||||
val dst_stride = UInt(OUTPUT, dmaSizeBits)
|
||||
val segment_size = UInt(OUTPUT, dmaSizeBits)
|
||||
val nsegments = UInt(OUTPUT, dmaSegmentBits)
|
||||
|
||||
val status = UInt(INPUT, dmaStatusBits)
|
||||
}
|
||||
|
||||
val regs = Reg(Vec(nWriteRegs, UInt(width = dmaSizeBits)))
|
||||
|
||||
io.src_stride := regs(SRC_STRIDE)
|
||||
io.dst_stride := regs(DST_STRIDE)
|
||||
io.segment_size := regs(SEGMENT_SIZE)
|
||||
io.nsegments := regs(NSEGMENTS)
|
||||
|
||||
when (io.wen && io.addr < UInt(nWriteRegs)) {
|
||||
regs.write(io.addr, io.wdata)
|
||||
}
|
||||
|
||||
io.rdata := MuxLookup(io.addr, regs(io.addr), Seq(
|
||||
UInt(RESP_STATUS) -> io.status))
|
||||
}
|
||||
|
||||
class DmaController(implicit p: Parameters) extends RoCC()(p)
|
||||
with HasClientDmaParameters {
|
||||
io.mem.req.valid := Bool(false)
|
||||
io.autl.acquire.valid := Bool(false)
|
||||
io.autl.grant.ready := Bool(false)
|
||||
io.iptw.req.valid := Bool(false)
|
||||
io.pptw.req.valid := Bool(false)
|
||||
|
||||
val cmd = Queue(io.cmd)
|
||||
val inst = cmd.bits.inst
|
||||
val is_transfer = inst.funct < UInt(4)
|
||||
val is_cr_write = inst.funct === UInt(4)
|
||||
val is_cr_read = inst.funct === UInt(5)
|
||||
val is_cr_access = is_cr_write || is_cr_read
|
||||
|
||||
val resp_rd = Reg(io.resp.bits.rd)
|
||||
val resp_data = Reg(io.resp.bits.data)
|
||||
|
||||
val s_idle :: s_resp :: Nil = Enum(Bits(), 2)
|
||||
val state = Reg(init = s_idle)
|
||||
|
||||
val reg_status = Reg(UInt(width = dmaStatusBits))
|
||||
val crfile = Module(new DmaCtrlRegFile)
|
||||
crfile.io.addr := cmd.bits.rs1
|
||||
crfile.io.wdata := cmd.bits.rs2
|
||||
crfile.io.wen := cmd.fire() && is_cr_write
|
||||
|
||||
val frontend = Module(new DmaFrontend)
|
||||
io.dma <> frontend.io.dma
|
||||
io.dptw <> frontend.io.ptw
|
||||
frontend.io.cpu.req.valid := cmd.valid && is_transfer
|
||||
frontend.io.cpu.req.bits := ClientDmaRequest(
|
||||
cmd = cmd.bits.inst.funct,
|
||||
src_start = cmd.bits.rs2,
|
||||
dst_start = cmd.bits.rs1,
|
||||
src_stride = crfile.io.src_stride,
|
||||
dst_stride = crfile.io.dst_stride,
|
||||
segment_size = crfile.io.segment_size,
|
||||
nsegments = crfile.io.nsegments)
|
||||
|
||||
cmd.ready := state === s_idle && (!is_transfer || frontend.io.cpu.req.ready)
|
||||
io.resp.valid := state === s_resp
|
||||
io.resp.bits.rd := resp_rd
|
||||
io.resp.bits.data := resp_data
|
||||
|
||||
when (cmd.fire()) {
|
||||
when (is_cr_read) {
|
||||
resp_rd := inst.rd
|
||||
resp_data := crfile.io.rdata
|
||||
state := s_resp
|
||||
}
|
||||
}
|
||||
|
||||
when (io.resp.fire()) { state := s_idle }
|
||||
|
||||
when (frontend.io.cpu.resp.valid) {
|
||||
reg_status := frontend.io.cpu.resp.bits.status
|
||||
}
|
||||
|
||||
io.busy := (state =/= s_idle) || cmd.valid || frontend.io.busy
|
||||
io.interrupt := Bool(false)
|
||||
}
|
@ -50,6 +50,8 @@ class RoCCInterface(implicit p: Parameters) extends Bundle {
|
||||
val fpu_req = Decoupled(new FPInput)
|
||||
val fpu_resp = Decoupled(new FPResult).flip
|
||||
val exception = Bool(INPUT)
|
||||
|
||||
val dma = new DmaIO
|
||||
}
|
||||
|
||||
abstract class RoCC(implicit p: Parameters) extends CoreModule()(p) {
|
||||
|
@ -14,7 +14,8 @@ case class RoccParameters(
|
||||
opcodes: OpcodeSet,
|
||||
generator: Parameters => RoCC,
|
||||
nMemChannels: Int = 0,
|
||||
useFPU: Boolean = false)
|
||||
useFPU: Boolean = false,
|
||||
useDma: Boolean = false)
|
||||
|
||||
abstract class Tile(resetSignal: Bool = null)
|
||||
(implicit p: Parameters) extends Module(_reset = resetSignal) {
|
||||
@ -22,6 +23,7 @@ abstract class Tile(resetSignal: Bool = null)
|
||||
val usingRocc = !buildRocc.isEmpty
|
||||
val nRocc = buildRocc.size
|
||||
val nFPUPorts = buildRocc.filter(_.useFPU).size
|
||||
val nDmaPorts = buildRocc.filter(_.useDma).size
|
||||
val nDCachePorts = 2 + nRocc
|
||||
val nPTWPorts = 2 + 3 * nRocc
|
||||
val nCachedTileLinkPorts = 1
|
||||
@ -31,6 +33,7 @@ abstract class Tile(resetSignal: Bool = null)
|
||||
val cached = Vec(nCachedTileLinkPorts, new ClientTileLinkIO)
|
||||
val uncached = Vec(nUncachedTileLinkPorts, new ClientUncachedTileLinkIO)
|
||||
val host = new HtifIO
|
||||
val dma = new DmaIO
|
||||
}
|
||||
}
|
||||
|
||||
@ -104,6 +107,14 @@ class RocketTile(resetSignal: Bool = null)(implicit p: Parameters) extends Tile(
|
||||
}
|
||||
}
|
||||
|
||||
if (nDmaPorts > 0) {
|
||||
val dmaArb = Module(new DmaArbiter(nDmaPorts))
|
||||
dmaArb.io.in <> roccs.zip(buildRocc)
|
||||
.filter { case (_, params) => params.useDma }
|
||||
.map { case (rocc, _) => rocc.io.dma }
|
||||
io.dma <> dmaArb.io.out
|
||||
}
|
||||
|
||||
core.io.rocc.busy := cmdRouter.io.busy || roccs.map(_.io.busy).reduce(_ || _)
|
||||
core.io.rocc.interrupt := roccs.map(_.io.interrupt).reduce(_ || _)
|
||||
respArb.io.in <> roccs.map(rocc => Queue(rocc.io.resp))
|
||||
@ -117,4 +128,9 @@ class RocketTile(resetSignal: Bool = null)(implicit p: Parameters) extends Tile(
|
||||
fpu.io.cp_resp.ready := Bool(false)
|
||||
}
|
||||
}
|
||||
|
||||
if (!usingRocc || nDmaPorts == 0) {
|
||||
io.dma.req.valid := Bool(false)
|
||||
io.dma.resp.ready := Bool(false)
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,12 @@ object Util {
|
||||
implicit def booleanToIntConv(x: Boolean) = new AnyRef {
|
||||
def toInt: Int = if (x) 1 else 0
|
||||
}
|
||||
|
||||
def minUInt(values: Seq[UInt]): UInt =
|
||||
values.reduce((a, b) => Mux(a < b, a, b))
|
||||
|
||||
def minUInt(first: UInt, rest: UInt*): UInt =
|
||||
minUInt(first +: rest.toSeq)
|
||||
}
|
||||
|
||||
import Util._
|
||||
|
Loading…
Reference in New Issue
Block a user