diff --git a/rocket/src/main/scala/dpath_alu.scala b/rocket/src/main/scala/dpath_alu.scala index 6c4eadfd..2635f89e 100644 --- a/rocket/src/main/scala/dpath_alu.scala +++ b/rocket/src/main/scala/dpath_alu.scala @@ -19,73 +19,41 @@ class ioALU extends Bundle(){ class rocketDpathALU extends Component { val io = new ioALU(); - - val out64 = - MuxCase(Fix(0, 64), Array( - (io.fn === FN_ADD) -> (io.in1 + io.in2).toFix, - (io.fn === FN_SUB) -> (io.in1 - io.in2).toFix, - (io.fn === FN_SLT) -> (io.in1.toFix < io.in2.toFix), //(io.in1 < io.in2) - (io.fn === FN_SLTU) -> (io.in1 < io.in2).toFix, - (io.fn === FN_AND) -> (io.in1 & io.in2).toFix, - (io.fn === FN_OR) -> (io.in1 | io.in2).toFix, - (io.fn === FN_XOR) -> (io.in1 ^ io.in2).toFix, - (io.fn === FN_SL) -> (io.in1 << io.shamt).toFix, - (io.fn === FN_SR && io.dw === DW_64) -> (io.in1 >> io.shamt).toFix, - (io.fn === FN_SR && io.dw === DW_32) -> (Cat(Fix(0, 32),io.in1(31, 0)).toUFix >> io.shamt), - (io.fn === FN_SRA) -> (io.in1.toFix >>> io.shamt))); - - io.out := MuxLookup(io.dw, Fix(0, 64), Array( - DW_64 -> out64(63,0), - DW_32 -> Cat(Fill(32, out64(31)), out64(31,0)).toFix)).toUFix; + // ADD, SUB + val sub = (io.fn === FN_SUB) || (io.fn === FN_SLT) || (io.fn === FN_SLTU) + val adder_rhs = Mux(sub, ~io.in2, io.in2) + val adder_out = (io.in1 + adder_rhs + sub.toUFix)(63,0) + + // SLT, SLTU + val less = Mux(io.in1(63) === io.in2(63), adder_out(63), io.in1(63)) + val lessu = Mux(io.in1(63) === io.in2(63), adder_out(63), io.in2(63)) + + // SLL, SRL, SRA + val sra = (io.fn === FN_SRA) + val shright = sra || (io.fn === FN_SR) + val shin_hi_32 = Mux(sra, Fill(32, io.in1(31)), UFix(0,32)) + val shin_hi = Mux(io.dw === DW_64, io.in1(63,32), shin_hi_32) + val shin_r = Cat(shin_hi, io.in1(31,0)) + val shin = Mux(shright, shin_r, Reverse(shin_r)) + val shout_r = (Cat(sra & shin_r(63), shin).toFix >>> io.shamt)(63,0) + + val out64 = Wire { Bits(64) } + switch(io.fn) + { + is(FN_ADD) { out64 <== adder_out } + is(FN_SUB) { out64 <== adder_out } + is(FN_SLT) { out64 <== less } + is(FN_SLTU) { out64 <== lessu } + is(FN_AND) { out64 <== io.in1 & io.in2 } + is(FN_OR) { out64 <== io.in1 | io.in2 } + is(FN_XOR) { out64 <== io.in1 ^ io.in2 } + is(FN_SL) { out64 <== Reverse(shout_r) } + } + out64 <== shout_r + + val out_hi = Mux(io.dw === DW_64, out64(63,32), Fill(32, out64(31))) + io.out := Cat(out_hi, out64(31,0)).toUFix } -/* -class IoDpathALU extends Bundle { - val in0 = Bits(32,'input); - val in1 = Bits(32,'input); - val fn = Bits(4,'input); - val out = Bits(32,'output); -} - -class DpathALU extends Component { - val io = new IoDpathALU(); - - val adder_in0 = MuxCase(io.in0,Array( - ((io.fn === FN_SUB) | (io.fn === FN_SLT) | (io.fn === FN_SLTU)) -> (~io.in0))); - - val adder_in1 = io.in1; - val adder_cin = MuxCase(Bits(0),Array( - ((io.fn === FN_SUB) | (io.fn === FN_SLT) | (io.fn === FN_SLTU)) -> Bits(1))); - - // Need to make the same width? - val adder_out = Cat(Bits(0,1),adder_in1).toUFix + Cat(Bits(0,1),adder_in0).toUFix + adder_cin.toUFix; - //adder_out := (adder_in1.toUFix + adder_in0.toUFix + adder_cin.toUFix); - - // Determine if there is overflow - val overflow = (io.in0(31) ^ io.in1(31)) & (adder_out(32) != io.in0(31)); - - val compare_yes = MuxLookup(io.fn,Bits(0),Array( - // If unsigned, do subtraction, and if the result is negative, then slt=true - FN_SLTU -> ~adder_out(32), - // If signed, do subtraction, and if the result is negative, then slt=true as well - // But if there is bad overflow (operands same sign and result is a different sign), - // then need to flip - FN_SLT -> ~(adder_out(32) ^ overflow))); - - io.out := MuxLookup(io.fn,Fix(0),Array( - FN_ADD -> adder_out, - FN_SUB -> adder_out, - FN_SLT -> compare_yes, - FN_SLTU -> compare_yes, - FN_AND -> (io.in0 & io.in1), - FN_OR -> (io.in0 | io.in1), - FN_XOR -> (io.in0 ^ io.in1), - FN_SL -> (io.in1 << io.in0(4,0).toUFix), - FN_SR -> (io.in1 >> io.in0(4,0).toUFix), - FN_SRA -> (io.in1.toFix >> io.in0(4,0).toUFix) - )); -} -*/ - } diff --git a/rocket/src/main/scala/util.scala b/rocket/src/main/scala/util.scala index 07e7ae41..06d3a986 100644 --- a/rocket/src/main/scala/util.scala +++ b/rocket/src/main/scala/util.scala @@ -15,12 +15,23 @@ object FillInterleaved def apply(n: Int, in: Bits) = { var out = Fill(n, in(0)) - for (i <- 1 until in.width) + for (i <- 1 until in.getWidth) out = Cat(Fill(n, in(i)), out) out } } +object Reverse +{ + def apply(in: Bits) = + { + var out = in(in.getWidth-1) + for (i <- 1 until in.getWidth) + out = Cat(in(in.getWidth-i-1), out) + out + } +} + class Mux1H(n: Int, w: Int) extends Component { val io = new Bundle {