diff --git a/rocket/src/main/scala/dpath_alu.scala b/rocket/src/main/scala/dpath_alu.scala index e26bff8a..841f0ec0 100644 --- a/rocket/src/main/scala/dpath_alu.scala +++ b/rocket/src/main/scala/dpath_alu.scala @@ -12,12 +12,12 @@ object ALU val FN_X = BitPat("b????") val FN_ADD = UInt(0) val FN_SL = UInt(1) + val FN_SEQ = UInt(2) + val FN_SNE = UInt(3) val FN_XOR = UInt(4) + val FN_SR = UInt(5) val FN_OR = UInt(6) val FN_AND = UInt(7) - val FN_SR = UInt(5) - val FN_SEQ = UInt(8) - val FN_SNE = UInt(9) val FN_SUB = UInt(10) val FN_SRA = UInt(11) val FN_SLT = UInt(12) @@ -35,11 +35,12 @@ object ALU val FN_MULHSU = FN_SLT val FN_MULHU = FN_SLTU - def isMulFN(fn: Bits, cmp: Bits) = fn(1,0) === cmp(1,0) - def isSub(cmd: Bits) = cmd(3) - def cmpUnsigned(cmd: Bits) = cmd(1) - def cmpInverted(cmd: Bits) = cmd(0) - def cmpEq(cmd: Bits) = !cmd(2) + def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0) + def isSub(cmd: UInt) = cmd(3) + def isCmp(cmd: UInt) = cmd === FN_SEQ || cmd === FN_SNE || cmd >= FN_SLT + def cmpUnsigned(cmd: UInt) = cmd(1) + def cmpInverted(cmd: UInt) = cmd(0) + def cmpEq(cmd: UInt) = !cmd(3) } import ALU._ @@ -51,43 +52,42 @@ class ALU(implicit p: Parameters) extends CoreModule()(p) { val in1 = UInt(INPUT, xLen) val out = UInt(OUTPUT, xLen) val adder_out = UInt(OUTPUT, xLen) + val cmp_out = Bool(OUTPUT) } // ADD, SUB - val sum = io.in1 + Mux(isSub(io.fn), -io.in2, io.in2) + val in2_inv = Mux(isSub(io.fn), ~io.in2, io.in2) + val in1_xor_in2 = io.in1 ^ in2_inv + io.adder_out := io.in1 + in2_inv + isSub(io.fn) // SLT, SLTU - val cmp = cmpInverted(io.fn) ^ - Mux(cmpEq(io.fn), sum === UInt(0), - Mux(io.in1(xLen-1) === io.in2(xLen-1), sum(xLen-1), + io.cmp_out := cmpInverted(io.fn) ^ + Mux(cmpEq(io.fn), in1_xor_in2 === UInt(0), + Mux(io.in1(xLen-1) === io.in2(xLen-1), io.adder_out(xLen-1), Mux(cmpUnsigned(io.fn), io.in2(xLen-1), io.in1(xLen-1)))) // SLL, SRL, SRA - val full_shamt = io.in2(log2Up(xLen)-1,0) - val (shamt, shin_r) = - if (xLen == 32) (full_shamt, io.in1) + if (xLen == 32) (io.in2(4,0), io.in1) else { require(xLen == 64) val shin_hi_32 = Fill(32, isSub(io.fn) && io.in1(31)) val shin_hi = Mux(io.dw === DW_64, io.in1(63,32), shin_hi_32) - val shamt = Cat(full_shamt(5) & (io.dw === DW_64), full_shamt(4,0)) + val shamt = Cat(io.in2(5) & (io.dw === DW_64), io.in2(4,0)) (shamt, Cat(shin_hi, io.in1(31,0))) } val shin = Mux(io.fn === FN_SR || io.fn === FN_SRA, shin_r, Reverse(shin_r)) val shout_r = (Cat(isSub(io.fn) & shin(xLen-1), shin).toSInt >> shamt)(xLen-1,0) val shout_l = Reverse(shout_r) + val shout = Mux(io.fn === FN_SR || io.fn === FN_SRA, shout_r, UInt(0)) | + Mux(io.fn === FN_SL, shout_l, UInt(0)) - val out = - Mux(io.fn === FN_ADD || io.fn === FN_SUB, sum, - Mux(io.fn === FN_SR || io.fn === FN_SRA, shout_r, - Mux(io.fn === FN_SL, shout_l, - Mux(io.fn === FN_AND, io.in1 & io.in2, - Mux(io.fn === FN_OR, io.in1 | io.in2, - Mux(io.fn === FN_XOR, io.in1 ^ io.in2, - /* all comparisons */ cmp)))))) + // AND, OR, XOR + val logic = Mux(io.fn === FN_XOR || io.fn === FN_OR, in1_xor_in2, UInt(0)) | + Mux(io.fn === FN_OR || io.fn === FN_AND, io.in1 & io.in2, UInt(0)) + val shift_logic = (isCmp(io.fn) && io.cmp_out) | logic | shout + val out = Mux(io.fn === FN_ADD || io.fn === FN_SUB, io.adder_out, shift_logic) - io.adder_out := sum io.out := out if (xLen > 32) { require(xLen == 64)