diff --git a/rocket/src/main/scala/rocket.scala b/rocket/src/main/scala/rocket.scala index 482c7dd9..013fc2a8 100644 --- a/rocket/src/main/scala/rocket.scala +++ b/rocket/src/main/scala/rocket.scala @@ -417,19 +417,18 @@ class Rocket extends CoreModule io.imem.ras_update.bits.returnAddr := mem_int_wdata // stall for RAW/WAW hazards on CSRs, loads, AMOs, and mul/div in execute stage. - val id_renx1_not0 = id_ctrl.rxs1 && id_raddr1 != UInt(0) - val id_renx2_not0 = id_ctrl.rxs2 && id_raddr2 != UInt(0) - val id_wen_not0 = id_ctrl.wxd && id_waddr != UInt(0) + val hazard_targets = Seq((id_ctrl.rxs1 && id_raddr1 != UInt(0), id_raddr1), + (id_ctrl.rxs2 && id_raddr2 != UInt(0), id_raddr2), + (id_ctrl.wxd && id_waddr != UInt(0), id_waddr)) + val fp_hazard_targets = Seq((io.fpu.dec.ren1, id_raddr1), + (io.fpu.dec.ren2, id_raddr2), + (io.fpu.dec.ren3, id_raddr3), + (io.fpu.dec.wen, id_waddr)) + + val id_sboard_hazard = checkHazards(hazard_targets, sboard.readBypassed _) val ex_cannot_bypass = ex_ctrl.csr != CSR.N || ex_ctrl.jalr || ex_ctrl.mem || ex_ctrl.div || ex_ctrl.fp || ex_ctrl.rocc - val data_hazard_ex = ex_ctrl.wxd && - (id_renx1_not0 && id_raddr1 === ex_waddr || - id_renx2_not0 && id_raddr2 === ex_waddr || - id_wen_not0 && id_waddr === ex_waddr) - val fp_data_hazard_ex = ex_ctrl.wfd && - (io.fpu.dec.ren1 && id_raddr1 === ex_waddr || - io.fpu.dec.ren2 && id_raddr2 === ex_waddr || - io.fpu.dec.ren3 && id_raddr3 === ex_waddr || - io.fpu.dec.wen && id_waddr === ex_waddr) + val data_hazard_ex = ex_ctrl.wxd && checkHazards(hazard_targets, _ === ex_waddr) + val fp_data_hazard_ex = ex_ctrl.wfd && checkHazards(fp_hazard_targets, _ === ex_waddr) val id_ex_hazard = ex_reg_valid && (data_hazard_ex && ex_cannot_bypass || fp_data_hazard_ex) // stall for RAW/WAW hazards on CSRs, LB/LH, and mul/div in memory stage. @@ -437,35 +436,16 @@ class Rocket extends CoreModule if (params(FastLoadWord)) Bool(!params(FastLoadByte)) && mem_reg_slow_bypass else Bool(true) val mem_cannot_bypass = mem_ctrl.csr != CSR.N || mem_ctrl.mem && mem_mem_cmd_bh || mem_ctrl.div || mem_ctrl.fp || mem_ctrl.rocc - val data_hazard_mem = mem_ctrl.wxd && - (id_renx1_not0 && id_raddr1 === mem_waddr || - id_renx2_not0 && id_raddr2 === mem_waddr || - id_wen_not0 && id_waddr === mem_waddr) - val fp_data_hazard_mem = mem_ctrl.wfd && - (io.fpu.dec.ren1 && id_raddr1 === mem_waddr || - io.fpu.dec.ren2 && id_raddr2 === mem_waddr || - io.fpu.dec.ren3 && id_raddr3 === mem_waddr || - io.fpu.dec.wen && id_waddr === mem_waddr) + val data_hazard_mem = mem_ctrl.wxd && checkHazards(hazard_targets, _ === mem_waddr) + val fp_data_hazard_mem = mem_ctrl.wfd && checkHazards(fp_hazard_targets, _ === mem_waddr) val id_mem_hazard = mem_reg_valid && (data_hazard_mem && mem_cannot_bypass || fp_data_hazard_mem) id_load_use := mem_reg_valid && data_hazard_mem && mem_ctrl.mem // stall for RAW/WAW hazards on load/AMO misses and mul/div in writeback. - val data_hazard_wb = wb_ctrl.wxd && - (id_renx1_not0 && id_raddr1 === wb_waddr || - id_renx2_not0 && id_raddr2 === wb_waddr || - id_wen_not0 && id_waddr === wb_waddr) - val fp_data_hazard_wb = wb_ctrl.wfd && - (io.fpu.dec.ren1 && id_raddr1 === wb_waddr || - io.fpu.dec.ren2 && id_raddr2 === wb_waddr || - io.fpu.dec.ren3 && id_raddr3 === wb_waddr || - io.fpu.dec.wen && id_waddr === wb_waddr) + val data_hazard_wb = wb_ctrl.wxd && checkHazards(hazard_targets, _ === wb_waddr) + val fp_data_hazard_wb = wb_ctrl.wfd && checkHazards(fp_hazard_targets, _ === wb_waddr) val id_wb_hazard = wb_reg_valid && (data_hazard_wb && wb_set_sboard || fp_data_hazard_wb) - val id_sboard_hazard = - (id_renx1_not0 && sboard.readBypassed(id_raddr1) || - id_renx2_not0 && sboard.readBypassed(id_raddr2) || - id_wen_not0 && sboard.readBypassed(id_waddr)) - sboard.set(wb_set_sboard && wb_wen, wb_waddr) val id_stall_fpu = if (!params(BuildFPU).isEmpty) { @@ -531,6 +511,9 @@ class Rocket extends CoreModule def checkExceptions(x: Seq[(Bool, UInt)]) = (x.map(_._1).reduce(_||_), PriorityMux(x)) + def checkHazards(targets: Seq[(Bool, UInt)], cond: UInt => Bool) = + targets.map(h => h._1 && cond(h._2)).reduce(_||_) + def imm(sel: Bits, inst: Bits) = { val sign = Mux(sel === IMM_Z, SInt(0), inst(31).toSInt) val b30_20 = Mux(sel === IMM_U, inst(30,20).toSInt, sign)