diff --git a/src/main/scala/rocket/ALU.scala b/src/main/scala/rocket/ALU.scala index 9636bbed..7791a4ee 100644 --- a/src/main/scala/rocket/ALU.scala +++ b/src/main/scala/rocket/ALU.scala @@ -38,7 +38,7 @@ object ALU def isMulFN(fn: UInt, cmp: UInt) = fn(1,0) === cmp(1,0) def isSub(cmd: UInt) = cmd(3) - def isCmp(cmd: UInt) = cmd === FN_SEQ || cmd === FN_SNE || cmd >= FN_SLT + def isCmp(cmd: UInt) = cmd >= FN_SLT def cmpUnsigned(cmd: UInt) = cmd(1) def cmpInverted(cmd: UInt) = cmd(0) def cmpEq(cmd: UInt) = !cmd(3) @@ -64,10 +64,10 @@ class ALU(implicit p: Parameters) extends CoreModule()(p) { io.adder_out := io.in1 + in2_inv + isSub(io.fn) // SLT, SLTU - io.cmp_out := cmpInverted(io.fn) ^ - Mux(cmpEq(io.fn), in1_xor_in2 === UInt(0), + val slt = Mux(io.in1(xLen-1) === io.in2(xLen-1), io.adder_out(xLen-1), - Mux(cmpUnsigned(io.fn), io.in2(xLen-1), io.in1(xLen-1)))) + Mux(cmpUnsigned(io.fn), io.in2(xLen-1), io.in1(xLen-1))) + io.cmp_out := cmpInverted(io.fn) ^ Mux(cmpEq(io.fn), in1_xor_in2 === UInt(0), slt) // SLL, SRL, SRA val (shamt, shin_r) = @@ -88,7 +88,7 @@ class ALU(implicit p: Parameters) extends CoreModule()(p) { // AND, OR, XOR val logic = Mux(io.fn === FN_XOR || io.fn === FN_OR, in1_xor_in2, UInt(0)) | Mux(io.fn === FN_OR || io.fn === FN_AND, io.in1 & io.in2, UInt(0)) - val shift_logic = (isCmp(io.fn) && io.cmp_out) | logic | shout + val shift_logic = (isCmp(io.fn) && slt) | logic | shout val out = Mux(io.fn === FN_ADD || io.fn === FN_SUB, io.adder_out, shift_logic) io.out := out diff --git a/src/main/scala/rocket/RocketCore.scala b/src/main/scala/rocket/RocketCore.scala index f5c4b78f..7ac6b43d 100644 --- a/src/main/scala/rocket/RocketCore.scala +++ b/src/main/scala/rocket/RocketCore.scala @@ -152,6 +152,7 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) val mem_reg_raw_inst = Reg(UInt()) val mem_reg_wdata = Reg(Bits()) val mem_reg_rs2 = Reg(Bits()) + val mem_br_taken = Reg(Bool()) val take_pc_mem = Wire(Bool()) val wb_reg_valid = Reg(Bool()) @@ -361,7 +362,6 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) // memory stage val mem_pc_valid = mem_reg_valid || mem_reg_replay || mem_reg_xcpt_interrupt - val mem_br_taken = mem_reg_wdata(0) val mem_br_target = mem_reg_pc.asSInt + Mux(mem_ctrl.branch && mem_br_taken, ImmGen(IMM_SB, mem_reg_inst), Mux(mem_ctrl.jal, ImmGen(IMM_UJ, mem_reg_inst), @@ -403,6 +403,7 @@ class Rocket(implicit p: Parameters) extends CoreModule()(p) mem_reg_raw_inst := ex_reg_raw_inst mem_reg_pc := ex_reg_pc mem_reg_wdata := alu.io.out + mem_br_taken := alu.io.cmp_out when (ex_ctrl.rxs2 && (ex_ctrl.mem || ex_ctrl.rocc || ex_sfence)) { val typ = Mux(ex_ctrl.rocc, log2Ceil(xLen/8).U, ex_ctrl.mem_type)