diff --git a/src/main/scala/rocket/FPU.scala b/src/main/scala/rocket/FPU.scala index a13a76e3..6d23a81b 100644 --- a/src/main/scala/rocket/FPU.scala +++ b/src/main/scala/rocket/FPU.scala @@ -248,14 +248,21 @@ object RecFNToRecFN_noncompliant { } } +object CanonicalNaN { + def apply(expWidth: Int, sigWidth: Int): UInt = + UInt((BigInt(7) << (expWidth + sigWidth - 3)) + (BigInt(1) << (sigWidth - 2)), expWidth + sigWidth + 1) +} + trait HasFPUParameters { val fLen: Int val (sExpWidth, sSigWidth) = (8, 24) val (dExpWidth, dSigWidth) = (11, 53) - val (maxExpWidth, maxSigWidth) = fLen match { - case 32 => (sExpWidth, sSigWidth) - case 64 => (dExpWidth, dSigWidth) + val floatWidths = fLen match { + case 32 => List((sExpWidth, sSigWidth)) + case 64 => List((sExpWidth, sSigWidth), (dExpWidth, dSigWidth)) } + val maxExpWidth = floatWidths.map(_._1).max + val maxSigWidth = floatWidths.map(_._2).max } abstract class FPUModule(implicit p: Parameters) extends CoreModule()(p) with HasFPUParameters @@ -419,15 +426,17 @@ class IntToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) { val isnan2 = IsNaNRecFN(expWidth, sigWidth, in.bits.in2) val issnan1 = IsSNaNRecFN(expWidth, sigWidth, in.bits.in1) val issnan2 = IsSNaNRecFN(expWidth, sigWidth, in.bits.in2) - (isnan2 || in.bits.rm(0) =/= io.lt && !isnan1, issnan1 || issnan2) + val invalid = issnan1 || issnan2 + val isNaNOut = invalid || (isnan1 && isnan2) + val cNaN = floatWidths.filter(_._1 >= expWidth).map(x => CanonicalNaN(x._1, x._2)).reduce(_+_) + (isnan2 || in.bits.rm(0) =/= io.lt && !isnan1, invalid, isNaNOut, cNaN) } - val (isLHS, isInvalid) = fLen match { + val (isLHS, isInvalid, isNaNOut, cNaN) = fLen match { case 32 => doMinMax(sExpWidth, sSigWidth) case 64 => MuxT(in.bits.single, doMinMax(sExpWidth, sSigWidth), doMinMax(dExpWidth, dSigWidth)) } mux.exc := isInvalid << 4 - mux.data := in.bits.in1 - when (!isLHS) { mux.data := in.bits.in2 } + mux.data := Mux(isNaNOut, cNaN, Mux(isLHS, in.bits.in1, in.bits.in2)) } fLen match { diff --git a/src/main/scala/util/Misc.scala b/src/main/scala/util/Misc.scala index cf93be0d..d754d24b 100644 --- a/src/main/scala/util/Misc.scala +++ b/src/main/scala/util/Misc.scala @@ -37,6 +37,9 @@ object MuxT { def apply[T <: Data, U <: Data, W <: Data](cond: Bool, con: (T, U, W), alt: (T, U, W)): (T, U, W) = (Mux(cond, con._1, alt._1), Mux(cond, con._2, alt._2), Mux(cond, con._3, alt._3)) + + def apply[T <: Data, U <: Data, W <: Data, X <: Data](cond: Bool, con: (T, U, W, X), alt: (T, U, W, X)): (T, U, W, X) = + (Mux(cond, con._1, alt._1), Mux(cond, con._2, alt._2), Mux(cond, con._3, alt._3), Mux(cond, con._4, alt._4)) } /** Creates a cascade of n MuxTs to search for a key value. */