diff --git a/src/main/scala/diplomacy/Nodes.scala b/src/main/scala/diplomacy/Nodes.scala index 2d67584c..40da6ab3 100644 --- a/src/main/scala/diplomacy/Nodes.scala +++ b/src/main/scala/diplomacy/Nodes.scala @@ -199,6 +199,11 @@ trait OutwardNode[DO, UO, BO <: Data] extends BaseNode with OutwardNodeHandle[DO protected[diplomacy] val oParams: Seq[DO] } +abstract class CycleException(kind: String, loop: Seq[String]) extends Exception(s"Diplomatic ${kind} cycle detected involving ${loop}") +case class StarCycleException(loop: Seq[String] = Nil) extends CycleException("star", loop) +case class DownwardCycleException(loop: Seq[String] = Nil) extends CycleException("downward", loop) +case class UpwardCycleException(loop: Seq[String] = Nil) extends CycleException("upward", loop) + case class Edges[EI, EO](in: EI, out: EO) sealed abstract class MixedNode[DI, UI, EI, BI <: Data, DO, UO, EO, BO <: Data]( inner: InwardNodeImp [DI, UI, EI, BI], @@ -212,31 +217,37 @@ sealed abstract class MixedNode[DI, UI, EI, BI <: Data, DO, UO, EO, BO <: Data]( protected[diplomacy] def mapParamsD(n: Int, p: Seq[DI]): Seq[DO] protected[diplomacy] def mapParamsU(n: Int, p: Seq[UO]): Seq[UI] + private var starCycleGuard = false protected[diplomacy] lazy val (oPortMapping, iPortMapping, oStar, iStar) = { - val oStars = oBindings.filter { case (_,_,b,_,_) => b == BIND_STAR }.size - val iStars = iBindings.filter { case (_,_,b,_,_) => b == BIND_STAR }.size - val oKnown = oBindings.map { case (_, n, b, _, _) => b match { - case BIND_ONCE => 1 - case BIND_QUERY => n.iStar - case BIND_STAR => 0 }}.foldLeft(0)(_+_) - val iKnown = iBindings.map { case (_, n, b, _, _) => b match { - case BIND_ONCE => 1 - case BIND_QUERY => n.oStar - case BIND_STAR => 0 }}.foldLeft(0)(_+_) - val (iStar, oStar) = resolveStar(iKnown, oKnown, iStars, oStars) - val oSum = oBindings.map { case (_, n, b, _, _) => b match { - case BIND_ONCE => 1 - case BIND_QUERY => n.iStar - case BIND_STAR => oStar }}.scanLeft(0)(_+_) - val iSum = iBindings.map { case (_, n, b, _, _) => b match { - case BIND_ONCE => 1 - case BIND_QUERY => n.oStar - case BIND_STAR => iStar }}.scanLeft(0)(_+_) - val oTotal = oSum.lastOption.getOrElse(0) - val iTotal = iSum.lastOption.getOrElse(0) - require(numPO.contains(oTotal), s"${name} has ${oTotal} outputs, expected ${numPO}${lazyModule.line}") - require(numPI.contains(iTotal), s"${name} has ${iTotal} inputs, expected ${numPI}${lazyModule.line}") - (oSum.init zip oSum.tail, iSum.init zip iSum.tail, oStar, iStar) + try { + if (starCycleGuard) throw StarCycleException() + val oStars = oBindings.filter { case (_,_,b,_,_) => b == BIND_STAR }.size + val iStars = iBindings.filter { case (_,_,b,_,_) => b == BIND_STAR }.size + val oKnown = oBindings.map { case (_, n, b, _, _) => b match { + case BIND_ONCE => 1 + case BIND_QUERY => n.iStar + case BIND_STAR => 0 }}.foldLeft(0)(_+_) + val iKnown = iBindings.map { case (_, n, b, _, _) => b match { + case BIND_ONCE => 1 + case BIND_QUERY => n.oStar + case BIND_STAR => 0 }}.foldLeft(0)(_+_) + val (iStar, oStar) = resolveStar(iKnown, oKnown, iStars, oStars) + val oSum = oBindings.map { case (_, n, b, _, _) => b match { + case BIND_ONCE => 1 + case BIND_QUERY => n.iStar + case BIND_STAR => oStar }}.scanLeft(0)(_+_) + val iSum = iBindings.map { case (_, n, b, _, _) => b match { + case BIND_ONCE => 1 + case BIND_QUERY => n.oStar + case BIND_STAR => iStar }}.scanLeft(0)(_+_) + val oTotal = oSum.lastOption.getOrElse(0) + val iTotal = iSum.lastOption.getOrElse(0) + require(numPO.contains(oTotal), s"${name} has ${oTotal} outputs, expected ${numPO}${lazyModule.line}") + require(numPI.contains(iTotal), s"${name} has ${iTotal} inputs, expected ${numPI}${lazyModule.line}") + (oSum.init zip oSum.tail, iSum.init zip iSum.tail, oStar, iStar) + } catch { + case c: StarCycleException => throw c.copy(loop = s"${name}${lazyModule.line}" +: c.loop) + } } lazy val oPorts = oBindings.flatMap { case (i, n, _, p, s) => @@ -248,15 +259,30 @@ sealed abstract class MixedNode[DI, UI, EI, BI <: Data, DO, UO, EO, BO <: Data]( (start until end) map { j => (j, n, p, s) } } + private var oParamsCycleGuard = false protected[diplomacy] lazy val oParams: Seq[DO] = { - val o = mapParamsD(oPorts.size, iPorts.map { case (i, n, _, _) => n.oParams(i) }) - require (o.size == oPorts.size, s"Bug in diplomacy; ${name} has ${o.size} != ${oPorts.size} down/up outer parameters${lazyModule.line}") - o.map(outer.mixO(_, this)) + try { + if (oParamsCycleGuard) throw DownwardCycleException() + oParamsCycleGuard = true + val o = mapParamsD(oPorts.size, iPorts.map { case (i, n, _, _) => n.oParams(i) }) + require (o.size == oPorts.size, s"Bug in diplomacy; ${name} has ${o.size} != ${oPorts.size} down/up outer parameters${lazyModule.line}") + o.map(outer.mixO(_, this)) + } catch { + case c: DownwardCycleException => throw c.copy(loop = s"${name}${lazyModule.line}" +: c.loop) + } } + + private var iParamsCycleGuard = false protected[diplomacy] lazy val iParams: Seq[UI] = { - val i = mapParamsU(iPorts.size, oPorts.map { case (o, n, _, _) => n.iParams(o) }) - require (i.size == iPorts.size, s"Bug in diplomacy; ${name} has ${i.size} != ${iPorts.size} up/down inner parameters${lazyModule.line}") - i.map(inner.mixI(_, this)) + try { + if (iParamsCycleGuard) throw UpwardCycleException() + iParamsCycleGuard = true + val i = mapParamsU(iPorts.size, oPorts.map { case (o, n, _, _) => n.iParams(o) }) + require (i.size == iPorts.size, s"Bug in diplomacy; ${name} has ${i.size} != ${iPorts.size} up/down inner parameters${lazyModule.line}") + i.map(inner.mixI(_, this)) + } catch { + case c: UpwardCycleException => throw c.copy(loop = s"${name}${lazyModule.line}" +: c.loop) + } } protected[diplomacy] def gco = if (iParams.size != 1) None else inner.getO(iParams(0))