1
0

Implement new FP encoding proposal

Single-precision values are stored in the register file in
double-precision format, so that an FSD of a single-precision value
stores a proper double and a later FLD restores it as either a double
or a single.
This commit is contained in:
Andrew Waterman 2017-03-26 11:32:26 -07:00 committed by Andrew Waterman
parent 7180352067
commit e710e32f10

View File

@ -222,13 +222,11 @@ object ClassifyRecFN {
} }
object IsNaNRecFN { object IsNaNRecFN {
def apply(expWidth: Int, sigWidth: Int, in: UInt) = def apply(in: UInt, t: FType) = in(t.sig + t.exp - 1, t.sig + t.exp - 3).andR
in(sigWidth + expWidth - 1, sigWidth + expWidth - 3).andR
} }
object IsSNaNRecFN { object IsSNaNRecFN {
def apply(expWidth: Int, sigWidth: Int, in: UInt) = def apply(in: UInt, t: FType) = IsNaNRecFN(in, t) && !in(t.sig - 2)
IsNaNRecFN(expWidth, sigWidth, in) && !in(sigWidth - 2)
} }
class FType(val exp: Int, val sig: Int) class FType(val exp: Int, val sig: Int)
@ -273,6 +271,9 @@ trait HasFPUParameters {
val maxType = floatTypes.sortWith(_.exp > _.exp).head val maxType = floatTypes.sortWith(_.exp > _.exp).head
val maxExpWidth = maxType.exp val maxExpWidth = maxType.exp
val maxSigWidth = maxType.sig val maxSigWidth = maxType.sig
def expand(x: UInt, t: FType) = RecFNToRecFN_noncompliant(x, t.exp, t.sig, maxType.exp, maxType.sig)
def contract(x: UInt, t: FType) = RecFNToRecFN_noncompliant(x, maxType.exp, maxType.sig, t.exp, t.sig)
} }
abstract class FPUModule(implicit p: Parameters) extends CoreModule()(p) with HasFPUParameters abstract class FPUModule(implicit p: Parameters) extends CoreModule()(p) with HasFPUParameters
@ -291,20 +292,11 @@ class FPToInt(implicit p: Parameters) extends FPUModule()(p) {
val out = Valid(new Output) val out = Valid(new Output)
} }
val in = Reg(new FPInput) val in = RegEnable(io.in.bits, io.in.valid)
val valid = Reg(next=io.in.valid) val valid = Reg(next=io.in.valid)
def upconvert(x: UInt) = RecFNToRecFN_noncompliant(x, sExpWidth, sSigWidth, maxExpWidth, maxSigWidth) val in1_s = contract(in.in1, FType.S)
val unrec_s = hardfloat.fNFromRecFN(sExpWidth, sSigWidth, in1_s).sextTo(xLen)
when (io.in.valid) {
in := io.in.bits
if (fLen > 32) when (io.in.bits.single && !io.in.bits.ldst && io.in.bits.cmd =/= FCMD_MV_XF) {
in.in1 := upconvert(io.in.bits.in1)
in.in2 := upconvert(io.in.bits.in2)
}
}
val unrec_s = hardfloat.fNFromRecFN(sExpWidth, sSigWidth, in.in1).sextTo(xLen)
val unrec_mem = fLen match { val unrec_mem = fLen match {
case 32 => unrec_s case 32 => unrec_s
case 64 => case 64 =>
@ -316,7 +308,7 @@ class FPToInt(implicit p: Parameters) extends FPUModule()(p) {
case fLen => unrec_mem case fLen => unrec_mem
} }
val classify_s = ClassifyRecFN(sExpWidth, sSigWidth, in.in1) val classify_s = ClassifyRecFN(sExpWidth, sSigWidth, in1_s)
val classify_out = fLen match { val classify_out = fLen match {
case 32 => classify_s case 32 => classify_s
case 64 => case 64 =>
@ -367,7 +359,7 @@ class IntToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) {
val mux = Wire(new FPResult) val mux = Wire(new FPResult)
mux.exc := Bits(0) mux.exc := Bits(0)
mux.data := hardfloat.recFNFromFN(sExpWidth, sSigWidth, in.bits.in1) mux.data := expand(hardfloat.recFNFromFN(sExpWidth, sSigWidth, in.bits.in1), FType.S)
if (fLen > 32) when (!in.bits.single) { if (fLen > 32) when (!in.bits.single) {
mux.data := hardfloat.recFNFromFN(dExpWidth, dSigWidth, in.bits.in1) mux.data := hardfloat.recFNFromFN(dExpWidth, dSigWidth, in.bits.in1)
} }
@ -390,67 +382,51 @@ class IntToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) {
l2s.io.signedIn := ~in.bits.typ(0) l2s.io.signedIn := ~in.bits.typ(0)
l2s.io.in := intValue l2s.io.in := intValue
l2s.io.roundingMode := in.bits.rm l2s.io.roundingMode := in.bits.rm
mux.data := expand(l2s.io.out, FType.S)
mux.exc := l2s.io.exceptionFlags
fLen match { fLen match {
case 32 => case 32 =>
mux.data := l2s.io.out
mux.exc := l2s.io.exceptionFlags
case 64 => case 64 =>
val l2d = Module(new hardfloat.INToRecFN(xLen, dExpWidth, dSigWidth)) val l2d = Module(new hardfloat.INToRecFN(xLen, dExpWidth, dSigWidth))
l2d.io.signedIn := ~in.bits.typ(0) l2d.io.signedIn := ~in.bits.typ(0)
l2d.io.in := intValue l2d.io.in := intValue
l2d.io.roundingMode := in.bits.rm l2d.io.roundingMode := in.bits.rm
mux.data := Cat(l2d.io.out >> l2s.io.out.getWidth, l2s.io.out)
mux.exc := l2s.io.exceptionFlags
when (!in.bits.single) { when (!in.bits.single) {
mux.data := l2d.io.out mux.data := l2d.io.out
mux.exc := l2d.io.exceptionFlags mux.exc := l2d.io.exceptionFlags
} }
}
} }
io.out <> Pipe(in.valid, mux, latency-1)
} }
class FPToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) { io.out <> Pipe(in.valid, mux, latency-1)
val io = new Bundle { }
val in = Valid(new FPInput).flip
val out = Valid(new FPResult)
val lt = Bool(INPUT) // from FPToInt
}
val in = Pipe(io.in) class FPToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) {
val io = new Bundle {
val in = Valid(new FPInput).flip
val out = Valid(new FPResult)
val lt = Bool(INPUT) // from FPToInt
}
val signNum = Mux(in.bits.rm(1), in.bits.in1 ^ in.bits.in2, Mux(in.bits.rm(0), ~in.bits.in2, in.bits.in2)) val in = Pipe(io.in)
val fsgnj_s = Cat(signNum(32), in.bits.in1(31, 0))
val fsgnj = fLen match {
case 32 => fsgnj_s
case 64 => Mux(in.bits.single, Cat(in.bits.in1 >> 33, fsgnj_s),
Cat(signNum(64), in.bits.in1(63, 0)))
}
val mux = Wire(new FPResult)
mux.exc := UInt(0)
mux.data := fsgnj
when (in.bits.cmd === FCMD_MINMAX) { val signNum = Mux(in.bits.rm(1), in.bits.in1 ^ in.bits.in2, Mux(in.bits.rm(0), ~in.bits.in2, in.bits.in2))
def doMinMax(expWidth: Int, sigWidth: Int) = { val fsgnj = Cat(signNum(fLen), in.bits.in1(fLen-1, 0))
val isnan1 = IsNaNRecFN(expWidth, sigWidth, in.bits.in1)
val isnan2 = IsNaNRecFN(expWidth, sigWidth, in.bits.in2) val mux = Wire(new FPResult)
val issnan1 = IsSNaNRecFN(expWidth, sigWidth, in.bits.in1) mux.exc := UInt(0)
val issnan2 = IsSNaNRecFN(expWidth, sigWidth, in.bits.in2) mux.data := fsgnj
val invalid = issnan1 || issnan2
val isNaNOut = invalid || (isnan1 && isnan2) when (in.bits.cmd === FCMD_MINMAX) {
val cNaN = floatTypes.filter(_.exp >= expWidth).map(CanonicalNaN(_)).reduce(_+_) val isnan1 = IsNaNRecFN(in.bits.in1, maxType)
(isnan2 || in.bits.rm(0) =/= io.lt && !isnan1, invalid, isNaNOut, cNaN) val isnan2 = IsNaNRecFN(in.bits.in2, maxType)
} val isInvalid = IsSNaNRecFN(in.bits.in1, maxType) || IsSNaNRecFN(in.bits.in2, maxType)
val (isLHS, isInvalid, isNaNOut, cNaN) = fLen match { val isNaNOut = isInvalid || (isnan1 && isnan2)
case 32 => doMinMax(sExpWidth, sSigWidth) val isLHS = isnan2 || in.bits.rm(0) =/= io.lt && !isnan1
case 64 => MuxT(in.bits.single, doMinMax(sExpWidth, sSigWidth), doMinMax(dExpWidth, dSigWidth)) mux.exc := isInvalid << 4
} mux.data := Mux(isNaNOut, CanonicalNaN(maxType), Mux(isLHS, in.bits.in1, in.bits.in2))
mux.exc := isInvalid << 4 }
mux.data := Mux(isNaNOut, cNaN, Mux(isLHS, in.bits.in1, in.bits.in2))
}
fLen match { fLen match {
case 32 => case 32 =>
@ -461,11 +437,11 @@ class IntToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) {
d2s.io.roundingMode := in.bits.rm d2s.io.roundingMode := in.bits.rm
val s2d = Module(new hardfloat.RecFNToRecFN(sExpWidth, sSigWidth, dExpWidth, dSigWidth)) val s2d = Module(new hardfloat.RecFNToRecFN(sExpWidth, sSigWidth, dExpWidth, dSigWidth))
s2d.io.in := in.bits.in1 s2d.io.in := contract(in.bits.in1, FType.S)
s2d.io.roundingMode := in.bits.rm s2d.io.roundingMode := in.bits.rm
when (in.bits.single) { when (in.bits.single) {
mux.data := Cat(s2d.io.out >> d2s.io.out.getWidth, d2s.io.out) mux.data := expand(d2s.io.out, FType.S)
mux.exc := d2s.io.exceptionFlags mux.exc := d2s.io.exceptionFlags
}.otherwise { }.otherwise {
mux.data := s2d.io.out mux.data := s2d.io.out
@ -477,28 +453,29 @@ class IntToFP(val latency: Int)(implicit p: Parameters) extends FPUModule()(p) {
io.out <> Pipe(in.valid, mux, latency-1) io.out <> Pipe(in.valid, mux, latency-1)
} }
class FPUFMAPipe(val latency: Int, expWidth: Int, sigWidth: Int)(implicit p: Parameters) extends FPUModule()(p) { class FPUFMAPipe(val latency: Int, t: FType)(implicit p: Parameters) extends FPUModule()(p) {
val io = new Bundle { val io = new Bundle {
val in = Valid(new FPInput).flip val in = Valid(new FPInput).flip
val out = Valid(new FPResult) val out = Valid(new FPResult)
} }
val width = sigWidth + expWidth val width = t.sig + t.exp
val one = UInt(1) << (width-1) val one = UInt(1) << (width-1)
val zero = (io.in.bits.in1(width) ^ io.in.bits.in2(width)) << width val zero = (io.in.bits.in1(width) ^ io.in.bits.in2(width)) << width
val valid = Reg(next=io.in.valid) val valid = Reg(next=io.in.valid)
val in = Reg(new FPInput) val in = Reg(new FPInput)
when (io.in.valid) { when (io.in.valid) {
in := io.in.bits
val cmd_fma = io.in.bits.ren3 val cmd_fma = io.in.bits.ren3
val cmd_addsub = io.in.bits.swap23 val cmd_addsub = io.in.bits.swap23
in := io.in.bits
in.in1 := contract(io.in.bits.in1, t)
in.in2 := Mux(cmd_addsub, one, contract(io.in.bits.in2, t))
in.in3 := Mux(cmd_fma || cmd_addsub, contract(io.in.bits.in3, t), zero)
in.cmd := Cat(io.in.bits.cmd(1) & (cmd_fma || cmd_addsub), io.in.bits.cmd(0)) in.cmd := Cat(io.in.bits.cmd(1) & (cmd_fma || cmd_addsub), io.in.bits.cmd(0))
when (cmd_addsub) { in.in2 := one }
unless (cmd_fma || cmd_addsub) { in.in3 := zero }
} }
val fma = Module(new hardfloat.MulAddRecFN(expWidth, sigWidth)) val fma = Module(new hardfloat.MulAddRecFN(t.exp, t.sig))
fma.io.op := in.cmd fma.io.op := in.cmd
fma.io.roundingMode := in.rm fma.io.roundingMode := in.rm
fma.io.a := in.in1 fma.io.a := in.in1
@ -506,7 +483,7 @@ class FPUFMAPipe(val latency: Int, expWidth: Int, sigWidth: Int)(implicit p: Par
fma.io.c := in.in3 fma.io.c := in.in3
val res = Wire(new FPResult) val res = Wire(new FPResult)
res.data := fma.io.out res.data := expand(fma.io.out, t)
res.exc := fma.io.exceptionFlags res.exc := fma.io.exceptionFlags
io.out := Pipe(valid, res, latency-1) io.out := Pipe(valid, res, latency-1)
} }
@ -538,9 +515,6 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
val mem_ctrl = RegEnable(ex_ctrl, req_valid) val mem_ctrl = RegEnable(ex_ctrl, req_valid)
val wb_ctrl = RegEnable(mem_ctrl, mem_reg_valid) val wb_ctrl = RegEnable(mem_ctrl, mem_reg_valid)
def expand(x: UInt, t: FType) = RecFNToRecFN_noncompliant(x, t.exp, t.sig, maxType.exp, maxType.sig)
def contract(x: UInt, t: FType) = RecFNToRecFN_noncompliant(x, maxType.exp, maxType.sig, t.exp, t.sig)
// load response // load response
val load_wb = Reg(next=io.dmem_resp_val) val load_wb = Reg(next=io.dmem_resp_val)
val load_wb_single = RegEnable(!io.dmem_resp_type(0), io.dmem_resp_val) val load_wb_single = RegEnable(!io.dmem_resp_type(0), io.dmem_resp_val)
@ -551,7 +525,7 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
case 32 => rec_s case 32 => rec_s
case 64 => case 64 =>
val rec_d = hardfloat.recFNFromFN(dExpWidth, dSigWidth, load_wb_data) val rec_d = hardfloat.recFNFromFN(dExpWidth, dSigWidth, load_wb_data)
Mux(load_wb_single, rec_s | CanonicalNaN.signaling(maxType), rec_d) Mux(load_wb_single, expand(rec_s, FType.S), rec_d)
} }
// regfile // regfile
@ -592,7 +566,7 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
} }
} }
val sfma = Module(new FPUFMAPipe(cfg.sfmaLatency, sExpWidth, sSigWidth)) val sfma = Module(new FPUFMAPipe(cfg.sfmaLatency, FType.S))
sfma.io.in.valid := req_valid && ex_ctrl.fma && ex_ctrl.single sfma.io.in.valid := req_valid && ex_ctrl.fma && ex_ctrl.single
sfma.io.in.bits := req sfma.io.in.bits := req
@ -632,7 +606,7 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
Pipe(ifpu, ifpu.latency, (c: FPUCtrlSigs) => c.fromint, ifpu.io.out.bits), Pipe(ifpu, ifpu.latency, (c: FPUCtrlSigs) => c.fromint, ifpu.io.out.bits),
Pipe(sfma, sfma.latency, (c: FPUCtrlSigs) => c.fma && c.single, sfma.io.out.bits)) ++ Pipe(sfma, sfma.latency, (c: FPUCtrlSigs) => c.fma && c.single, sfma.io.out.bits)) ++
(fLen > 32).option({ (fLen > 32).option({
val dfma = Module(new FPUFMAPipe(cfg.dfmaLatency, dExpWidth, dSigWidth)) val dfma = Module(new FPUFMAPipe(cfg.dfmaLatency, FType.D))
dfma.io.in.valid := req_valid && ex_ctrl.fma && !ex_ctrl.single dfma.io.in.valid := req_valid && ex_ctrl.fma && !ex_ctrl.single
dfma.io.in.bits := req dfma.io.in.bits := req
Pipe(dfma, dfma.latency, (c: FPUCtrlSigs) => c.fma && !c.single, dfma.io.out.bits) Pipe(dfma, dfma.latency, (c: FPUCtrlSigs) => c.fma && !c.single, dfma.io.out.bits)
@ -677,17 +651,13 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
} }
val waddr = Mux(divSqrt_wen, divSqrt_waddr, wbInfo(0).rd) val waddr = Mux(divSqrt_wen, divSqrt_waddr, wbInfo(0).rd)
val wdata0 = Mux(divSqrt_wen, divSqrt_wdata, (pipes.map(_.res.data): Seq[UInt])(wbInfo(0).pipeid)) val wdata = Mux(divSqrt_wen, divSqrt_wdata, (pipes.map(_.res.data): Seq[UInt])(wbInfo(0).pipeid))
val wsingle = Mux(divSqrt_wen, divSqrt_single, wbInfo(0).single)
val wdata = fLen match {
case 32 => wdata0
case 64 => Mux(wsingle, wdata0(32, 0) | CanonicalNaN.signaling(maxType), wdata0)
}
val wexc = (pipes.map(_.res.exc): Seq[UInt])(wbInfo(0).pipeid) val wexc = (pipes.map(_.res.exc): Seq[UInt])(wbInfo(0).pipeid)
when ((!wbInfo(0).cp && wen(0)) || divSqrt_wen) { when ((!wbInfo(0).cp && wen(0)) || divSqrt_wen) {
regfile(waddr) := wdata regfile(waddr) := wdata
if (enableCommitLog) { if (enableCommitLog) {
val wdata_unrec_s = hardfloat.fNFromRecFN(sExpWidth, sSigWidth, wdata) val wsingle = Mux(divSqrt_wen, divSqrt_single, wbInfo(0).single)
val wdata_unrec_s = hardfloat.fNFromRecFN(sExpWidth, sSigWidth, contract(wdata, FType.S))
val unrec = fLen match { val unrec = fLen match {
case 32 => wdata_unrec_s case 32 => wdata_unrec_s
case 64 => case 64 =>
@ -757,7 +727,7 @@ class FPU(cfg: FPUParams)(implicit p: Parameters) extends FPUModule()(p) {
val divSqrt_toSingle = Module(new hardfloat.RecFNToRecFN(11, 53, 8, 24)) val divSqrt_toSingle = Module(new hardfloat.RecFNToRecFN(11, 53, 8, 24))
divSqrt_toSingle.io.in := divSqrt_wdata_double divSqrt_toSingle.io.in := divSqrt_wdata_double
divSqrt_toSingle.io.roundingMode := divSqrt_rm divSqrt_toSingle.io.roundingMode := divSqrt_rm
divSqrt_wdata := Mux(divSqrt_single, divSqrt_toSingle.io.out, divSqrt_wdata_double) divSqrt_wdata := Mux(divSqrt_single, expand(divSqrt_toSingle.io.out, FType.S), divSqrt_wdata_double)
divSqrt_flags := divSqrt_flags_double | Mux(divSqrt_single, divSqrt_toSingle.io.exceptionFlags, Bits(0)) divSqrt_flags := divSqrt_flags_double | Mux(divSqrt_single, divSqrt_toSingle.io.exceptionFlags, Bits(0))
} else { } else {
when (id_ctrl.div || id_ctrl.sqrt) { io.illegal_rm := true } when (id_ctrl.div || id_ctrl.sqrt) { io.illegal_rm := true }