diff --git a/rocket/src/main/scala/multiplier.scala b/rocket/src/main/scala/multiplier.scala index ce063399..9770d632 100644 --- a/rocket/src/main/scala/multiplier.scala +++ b/rocket/src/main/scala/multiplier.scala @@ -58,13 +58,16 @@ class MulDiv( FN_MULHU -> List(Y, Y, N, N), FN_MULHSU -> List(Y, Y, Y, N))).map(_ toBool) - def sext(x: Bits, signed: Bool) = { - val sign = signed && Mux(io.req.bits.dw === DW_64, x(w-1), x(w/2-1)) - val hi = Mux(io.req.bits.dw === DW_64, x(w-1,w/2), Fill(w/2, sign)) + require(w == 32 || w == 64) + def halfWidth(req: MultiplierReq) = Bool(w > 32) && req.dw === DW_32 + + def sext(x: Bits, halfW: Bool, signed: Bool) = { + val sign = signed && Mux(halfW, x(w/2-1), x(w-1)) + val hi = Mux(halfW, Fill(w/2, sign), x(w-1,w/2)) (Cat(hi, x(w/2-1,0)), sign) } - val (lhs_in, lhs_sign) = sext(io.req.bits.in1, lhsSigned) - val (rhs_in, rhs_sign) = sext(io.req.bits.in2, rhsSigned) + val (lhs_in, lhs_sign) = sext(io.req.bits.in1, halfWidth(io.req.bits), lhsSigned) + val (rhs_in, rhs_sign) = sext(io.req.bits.in2, halfWidth(io.req.bits), rhsSigned) val subtractor = remainder(2*w,w) - divisor(w,0) val less = subtractor(w) @@ -143,7 +146,7 @@ class MulDiv( } io.resp.bits := req - io.resp.bits.data := Mux(req.dw === DW_32, Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) + io.resp.bits.data := Mux(halfWidth(req), Cat(Fill(w/2, remainder(w/2-1)), remainder(w/2-1,0)), remainder(w-1,0)) io.resp.valid := state === s_done io.req.ready := state === s_ready }