perf(StoreUnit): writeback at s2 #4068

Open · wants to merge 6 commits into base: master
2 changes: 1 addition & 1 deletion src/main/scala/xiangshan/Parameters.scala

@@ -166,7 +166,7 @@ case class XSCoreParameters
   VirtualLoadQueueSize: Int = 72,
   LoadQueueRARSize: Int = 64,
   LoadQueueRAWSize: Int = 32, // NOTE: make sure that LoadQueueRAWSize is power of 2.
-  RollbackGroupSize: Int = 8,
+  RollbackGroupSize: Int = 16,
   LoadQueueReplaySize: Int = 72,
   LoadUncacheBufferSize: Int = 4,
   LoadQueueNWriteBanks: Int = 8, // NOTE: make sure that LoadQueueRARSize/LoadQueueRAWSize is divided by LoadQueueNWriteBanks
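Note: the RollbackGroupSize bump pairs with the RAW-queue select-tree changes in the next file. The latency formula this parameter feeds (deleted from StoreUnit.scala later in this PR) was TotalSelectCycles = ceil(log2Ceil(LoadQueueRAWSize) / log2Ceil(RollbackGroupSize)) + 1. A quick standalone check of how it evaluates under the shipped LoadQueueRAWSize = 32 (a sanity-check sketch, not code from this PR):

    // Standalone Scala sketch of the removed latency formula.
    def log2Ceil(x: Int): Int = 32 - Integer.numberOfLeadingZeros(x - 1) // matches chisel3.util.log2Ceil for x >= 2
    def totalSelectCycles(rawSize: Int, groupSize: Int): Int =
      math.ceil(log2Ceil(rawSize).toFloat / log2Ceil(groupSize)).toInt + 1
    totalSelectCycles(32, 8)  // = 3 (old parameter)
    totalSelectCycles(32, 16) // = 3 (new parameter)

Both settings give 3 by the formula, so the cycle savings in this PR come from removing registers from the select path (below), not from the parameter itself; the wider group does halve the number of partial-select groups at the first tree level (2 groups of 16 instead of 4 groups of 8).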
24 changes: 12 additions & 12 deletions src/main/scala/xiangshan/mem/lsqueue/LoadQueueRAW.scala

@@ -80,7 +80,7 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
   numRead = LoadPipelineWidth,
   numWrite = LoadPipelineWidth,
   numWBank = LoadQueueNWriteBanks,
-  numWDelay = 2,
+  numWDelay = 1,
   numCamPort = StorePipelineWidth
 ))
 paddrModule.io := DontCare
@@ -90,7 +90,7 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
   numRead = LoadPipelineWidth,
   numWrite = LoadPipelineWidth,
   numWBank = LoadQueueNWriteBanks,
-  numWDelay = 2,
+  numWDelay = 1,
   numCamPort = StorePipelineWidth
 ))
 maskModule.io := DontCare
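Note: on my reading of the banked-CAM template, numWDelay is the number of register stages on each write port; cutting it from 2 to 1 for both the paddr and mask modules makes a newly allocated load entry visible to store CAM lookups one cycle sooner. A generic sketch of such an n-stage write delay (helper name is mine, not from this repo):

    import chisel3._
    // Delay a write-port payload by n register stages before it reaches the array.
    def delayWrite[T <: Data](w: T, n: Int): T =
      (0 until n).foldLeft(w)((v, _) => RegNext(v))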
@@ -276,9 +276,9 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
   // select logic
   if (valid.length <= SelectGroupSize) {
     val (selValid, selBits) = selectPartialOldest(valid, bits)
-    val selValidNext = GatedValidRegNext(selValid(0))
-    val selBitsNext = RegEnable(selBits(0), selValid(0))
-    (Seq(selValidNext && !selBitsNext.uop.robIdx.needFlush(RegNext(io.redirect))), Seq(selBitsNext))
+    val selValidNext = selValid(0)
+    val selBitsNext = selBits(0)
+    (Seq(selValidNext && !selBitsNext.uop.robIdx.needFlush(io.redirect)), Seq(selBitsNext))
   } else {
     val select = (0 until numSelectGroups).map(g => {
       val (selValid, selBits) = selectPartialOldest(selectValidGroups(g), selectBitsGroups(g))
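Note: the base case of the oldest-load select tree previously registered its result (GatedValidRegNext/RegEnable) and re-checked the redirect one cycle later; it now resolves combinationally in the same cycle as the CAM match. For reference, a minimal sketch of the pairwise "older of two" step such a tree reduces with (helper name is mine; RobPtr and isAfter are the circular-pointer type and compare from the XiangShan utils, import paths assumed):

    import chisel3._
    import chisel3.util._
    // Pick the older (earlier in program order) of two candidate entries.
    def pickOlder(aValid: Bool, a: RobPtr, bValid: Bool, b: RobPtr): (Bool, RobPtr) = {
      val takeA = aValid && (!bValid || isAfter(b, a)) // take a if b is invalid or younger
      (aValid || bValid, Mux(takeA, a, b))
    }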
@@ -293,13 +293,13 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
 val storeIn = io.storeIn

 def detectRollback(i: Int) = {
-  paddrModule.io.violationMdata(i) := genPartialPAddr(RegEnable(storeIn(i).bits.paddr, storeIn(i).valid))
-  maskModule.io.violationMdata(i) := RegEnable(storeIn(i).bits.mask, storeIn(i).valid)
+  paddrModule.io.violationMdata(i) := genPartialPAddr(storeIn(i).bits.paddr)
+  maskModule.io.violationMdata(i) := storeIn(i).bits.mask

   val addrMaskMatch = paddrModule.io.violationMmask(i).asUInt & maskModule.io.violationMmask(i).asUInt
-  val entryNeedCheck = GatedValidRegNext(VecInit((0 until LoadQueueRAWSize).map(j => {
+  val entryNeedCheck = VecInit((0 until LoadQueueRAWSize).map(j => {
     allocated(j) && storeIn(i).valid && isAfter(uop(j).robIdx, storeIn(i).bits.uop.robIdx) && datavalid(j) && !uop(j).robIdx.needFlush(io.redirect)
-  })))
+  }))
   val lqViolationSelVec = VecInit((0 until LoadQueueRAWSize).map(j => {
     addrMaskMatch(j) && entryNeedCheck(j)
   }))
@@ -334,10 +334,10 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
 val stFtqOffset = Wire(Vec(StorePipelineWidth, UInt(log2Up(PredictWidth).W)))
 for (w <- 0 until StorePipelineWidth) {
   val detectedRollback = detectRollback(w)
-  rollbackLqWb(w).valid := detectedRollback._1 && DelayN(storeIn(w).valid && !storeIn(w).bits.miss, TotalSelectCycles)
+  rollbackLqWb(w).valid := detectedRollback._1 && RegNext(storeIn(w).valid && !storeIn(w).bits.miss)
   rollbackLqWb(w).bits := detectedRollback._2
-  stFtqIdx(w) := DelayNWithValid(storeIn(w).bits.uop.ftqPtr, storeIn(w).valid, TotalSelectCycles)._2
-  stFtqOffset(w) := DelayNWithValid(storeIn(w).bits.uop.ftqOffset, storeIn(w).valid, TotalSelectCycles)._2
+  stFtqIdx(w) := RegNext(storeIn(w).bits.uop.ftqPtr)
+  stFtqOffset(w) := RegNext(storeIn(w).bits.uop.ftqOffset)
 }

 // select rollback (part2), generate rollback request, then fire rollback request
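Note: with detection and selection now single-cycle, the fixed TotalSelectCycles alignment delay collapses to one RegNext. As I understand the DelayNWithValid utility, it delays a (valid, bits) pair by n cycles while enabling the data registers only when valid is high, roughly (my paraphrase, not its actual source):

    import chisel3._
    import chisel3.util.RegEnable
    // n-cycle valid/data delay; data registers only load while valid is high.
    def delayNWithValid[T <: Data](valid: Bool, bits: T, n: Int): (Bool, T) =
      (0 until n).foldLeft((valid, bits)) { case ((v, b), _) =>
        (RegNext(v, false.B), RegEnable(b, v))
      }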
147 changes: 36 additions & 111 deletions src/main/scala/xiangshan/mem/pipeline/StoreUnit.scala

@@ -73,7 +73,7 @@ class StoreUnit(implicit p: Parameters) extends XSModule
   val s0_s1_valid = Output(Bool())
 })

-val s1_ready, s2_ready, s3_ready = WireInit(false.B)
+val s1_ready, s2_ready = WireInit(false.B)

 // Pipeline
 // --------------------------------------------------------------------------------
@@ -257,7 +257,7 @@ class StoreUnit(implicit p: Parameters) extends XSModule
 io.st_mask_out.valid := s0_use_flow_rs || s0_use_flow_vec
 io.st_mask_out.bits.mask := s0_out.mask
 io.st_mask_out.bits.sqIdx := s0_out.uop.sqIdx

 io.stin.ready := s1_ready && s0_use_flow_rs
 io.vecstin.ready := s1_ready && s0_use_flow_vec
 io.prefetch_req.ready := s1_ready && io.dcache.req.ready && !s0_iss_valid && !s0_vec_valid && !s0_ma_st_valid
@@ -419,15 +419,15 @@ class StoreUnit(implicit p: Parameters) extends XSModule
 val s2_in = RegEnable(s1_out, s1_fire)
 val s2_out = Wire(new LsPipelineBundle)
 val s2_kill = Wire(Bool())
-val s2_can_go = s3_ready
+val s2_can_go = io.stout.ready
 val s2_fire = s2_valid && !s2_kill && s2_can_go
 val s2_vecActive = RegEnable(s1_out.vecActive, true.B, s1_fire)
 val s2_frm_mabuf = s2_in.isFrmMisAlignBuf
 val s2_frm_mab_vec = RegEnable(s1_frm_mab_vec, true.B, s1_fire)
 val s2_pbmt = RegEnable(s1_pbmt, s1_fire)
 val s2_trigger_debug_mode = RegEnable(s1_trigger_debug_mode, false.B, s1_fire)

-s2_ready := !s2_valid || s2_kill || s3_ready
+s2_ready := !s2_valid || s2_kill || io.stout.ready
 when (s1_fire) { s2_valid := true.B }
 .elsewhen (s2_fire) { s2_valid := false.B }
 .elsewhen (s2_kill) { s2_valid := false.B }
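Note: s3_ready disappears along with stage 3 (see the next hunk), so s2's can-go condition and ready are wired straight to io.stout.ready. Writeback backpressure now reaches s2 combinationally instead of filtering back through the s3/sx ready chain.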
@@ -528,122 +528,47 @@ class StoreUnit(implicit p: Parameters) extends XSModule
 io.prefetch_train.bits.updateAddrValid := false.B
 io.prefetch_train.bits.hasException := false.B

-// Pipeline
-// --------------------------------------------------------------------------------
-// stage 3
-// --------------------------------------------------------------------------------
-// store write back
-val s3_valid = RegInit(false.B)
-val s3_in = RegEnable(s2_out, s2_fire)
-val s3_out = Wire(new MemExuOutput(isVector = true))
-val s3_kill = s3_in.uop.robIdx.needFlush(io.redirect)
-val s3_can_go = s3_ready
-val s3_fire = s3_valid && !s3_kill && s3_can_go
-val s3_vecFeedback = RegEnable(s2_vecFeedback, s2_fire)
-
-// store misalign will not writeback to rob now
-when (s2_fire) { s3_valid := (!s2_mmio || s2_exception) && !s2_out.isHWPrefetch && !s2_mis_align && !s2_frm_mabuf }
-.elsewhen (s3_fire) { s3_valid := false.B }
-.elsewhen (s3_kill) { s3_valid := false.B }
-
-// wb: writeback
-val SelectGroupSize = RollbackGroupSize
-val lgSelectGroupSize = log2Ceil(SelectGroupSize)
-val TotalSelectCycles = scala.math.ceil(log2Ceil(LoadQueueRAWSize).toFloat / lgSelectGroupSize).toInt + 1
-
-s3_out := DontCare
-s3_out.uop := s3_in.uop
-s3_out.data := DontCare
-s3_out.debug.isMMIO := s3_in.mmio
-s3_out.debug.isNC := s3_in.nc
-s3_out.debug.paddr := s3_in.paddr
-s3_out.debug.vaddr := s3_in.vaddr
-s3_out.debug.isPerfCnt := false.B
-
-// Pipeline
-// --------------------------------------------------------------------------------
-// stage x
-// --------------------------------------------------------------------------------
-// delay TotalSelectCycles - 2 cycle(s)
-val TotalDelayCycles = TotalSelectCycles - 2
-val sx_valid = Wire(Vec(TotalDelayCycles + 1, Bool()))
-val sx_ready = Wire(Vec(TotalDelayCycles + 1, Bool()))
-val sx_in = Wire(Vec(TotalDelayCycles + 1, new VecMemExuOutput(isVector = true)))
-val sx_in_vec = Wire(Vec(TotalDelayCycles +1, Bool()))
-
-// backward ready signal
-s3_ready := sx_ready.head
-for (i <- 0 until TotalDelayCycles + 1) {
-  if (i == 0) {
-    sx_valid(i) := s3_valid
-    sx_in(i).output := s3_out
-    sx_in(i).vecFeedback := s3_vecFeedback
-    sx_in(i).nc := s3_in.nc
-    sx_in(i).mmio := s3_in.mmio
-    sx_in(i).usSecondInv := s3_in.usSecondInv
-    sx_in(i).elemIdx := s3_in.elemIdx
-    sx_in(i).alignedType := s3_in.alignedType
-    sx_in(i).mbIndex := s3_in.mbIndex
-    sx_in(i).mask := s3_in.mask
-    sx_in(i).vaddr := s3_in.fullva
-    sx_in(i).vaNeedExt := s3_in.vaNeedExt
-    sx_in(i).gpaddr := s3_in.gpaddr
-    sx_in(i).isForVSnonLeafPTE := s3_in.isForVSnonLeafPTE
-    sx_in(i).vecTriggerMask := s3_in.vecTriggerMask
-    sx_in_vec(i) := s3_in.isvec
-    sx_ready(i) := !s3_valid(i) || sx_in(i).output.uop.robIdx.needFlush(io.redirect) || (if (TotalDelayCycles == 0) io.stout.ready else sx_ready(i+1))
-  } else {
-    val cur_kill = sx_in(i).output.uop.robIdx.needFlush(io.redirect)
-    val cur_can_go = (if (i == TotalDelayCycles) io.stout.ready else sx_ready(i+1))
-    val cur_fire = sx_valid(i) && !cur_kill && cur_can_go
-    val prev_fire = sx_valid(i-1) && !sx_in(i-1).output.uop.robIdx.needFlush(io.redirect) && sx_ready(i)
-
-    sx_ready(i) := !sx_valid(i) || cur_kill || (if (i == TotalDelayCycles) io.stout.ready else sx_ready(i+1))
-    val sx_valid_can_go = prev_fire || cur_fire || cur_kill
-    sx_valid(i) := RegEnable(Mux(prev_fire, true.B, false.B), false.B, sx_valid_can_go)
-    sx_in(i) := RegEnable(sx_in(i-1), prev_fire)
-    sx_in_vec(i) := RegEnable(sx_in_vec(i-1), prev_fire)
-  }
-}
-val sx_last_valid = sx_valid.takeRight(1).head
-val sx_last_ready = sx_ready.takeRight(1).head
-val sx_last_in = sx_in.takeRight(1).head
-val sx_last_in_vec = sx_in_vec.takeRight(1).head
-sx_last_ready := !sx_last_valid || sx_last_in.output.uop.robIdx.needFlush(io.redirect) || io.stout.ready
-
-// write back: normal store, nc store
-io.stout.valid := sx_last_valid && !sx_last_in.output.uop.robIdx.needFlush(io.redirect) && !sx_last_in_vec //isStore(sx_last_in.output.uop.fuType)
-io.stout.bits := sx_last_in.output
-io.stout.bits.uop.exceptionVec := ExceptionNO.selectByFu(sx_last_in.output.uop.exceptionVec, StaCfg)
-
-io.vecstout.valid := sx_last_valid && !sx_last_in.output.uop.robIdx.needFlush(io.redirect) && sx_last_in_vec //isVStore(sx_last_in.output.uop.fuType)
+val s2_can_writeback = (!s2_mmio || s2_exception) && !s2_out.isHWPrefetch && !s2_mis_align && !s2_frm_mabuf
+io.stout.valid := s2_valid && s2_can_writeback && io.feedback_slow.bits.hit && !s2_out.isvec
+io.stout.bits.uop := s2_out.uop
+io.stout.bits.data := DontCare
+io.stout.bits.debug.isMMIO := s2_out.mmio
+io.stout.bits.debug.isNC := s2_out.nc
+io.stout.bits.debug.paddr := s2_out.paddr
+io.stout.bits.debug.vaddr := s2_out.vaddr
+io.stout.bits.debug.isPerfCnt := false.B
+io.stout.bits.uop.exceptionVec := ExceptionNO.selectByFu(s2_out.uop.exceptionVec, StaCfg)
+io.stout.bits.isFromLoadUnit := false.B
+
+io.vecstout.valid := s2_valid && s2_can_writeback && s2_out.isvec
 // TODO: implement it!
-io.vecstout.bits.mBIndex := sx_last_in.mbIndex
-io.vecstout.bits.hit := sx_last_in.vecFeedback
+io.vecstout.bits.mBIndex := s2_out.mbIndex
+io.vecstout.bits.hit := io.feedback_slow.bits.hit
 io.vecstout.bits.isvec := true.B
 io.vecstout.bits.sourceType := RSFeedbackType.tlbMiss
 io.vecstout.bits.flushState := DontCare
-io.vecstout.bits.trigger := sx_last_in.output.uop.trigger
-io.vecstout.bits.nc := sx_last_in.nc
-io.vecstout.bits.mmio := sx_last_in.mmio
-io.vecstout.bits.exceptionVec := ExceptionNO.selectByFu(sx_last_in.output.uop.exceptionVec, VstuCfg)
-io.vecstout.bits.usSecondInv := sx_last_in.usSecondInv
-io.vecstout.bits.vecFeedback := sx_last_in.vecFeedback
-io.vecstout.bits.elemIdx := sx_last_in.elemIdx
-io.vecstout.bits.alignedType := sx_last_in.alignedType
-io.vecstout.bits.mask := sx_last_in.mask
-io.vecstout.bits.vaddr := sx_last_in.vaddr
-io.vecstout.bits.vaNeedExt := sx_last_in.vaNeedExt
-io.vecstout.bits.gpaddr := sx_last_in.gpaddr
-io.vecstout.bits.isForVSnonLeafPTE := sx_last_in.isForVSnonLeafPTE
-io.vecstout.bits.vstart := sx_last_in.output.uop.vpu.vstart
-io.vecstout.bits.vecTriggerMask := sx_last_in.vecTriggerMask
+io.vecstout.bits.trigger := s2_out.uop.trigger
+io.vecstout.bits.nc := s2_out.nc
+io.vecstout.bits.mmio := s2_out.mmio
+io.vecstout.bits.exceptionVec := ExceptionNO.selectByFu(s2_out.uop.exceptionVec, VstuCfg)
+io.vecstout.bits.usSecondInv := s2_out.usSecondInv
+io.vecstout.bits.vecFeedback := io.feedback_slow.bits.hit
+io.vecstout.bits.elemIdx := s2_out.elemIdx
+io.vecstout.bits.alignedType := s2_out.alignedType
+io.vecstout.bits.mask := s2_out.mask
+io.vecstout.bits.vaddr := s2_out.vaddr
+io.vecstout.bits.vaNeedExt := s2_out.vaNeedExt
+io.vecstout.bits.gpaddr := s2_out.gpaddr
+io.vecstout.bits.isForVSnonLeafPTE := s2_out.isForVSnonLeafPTE
+io.vecstout.bits.vstart := s2_out.uop.vpu.vstart
+io.vecstout.bits.vecTriggerMask := s2_out.vecTriggerMask
 // io.vecstout.bits.reg_offset.map(_ := DontCare)
-// io.vecstout.bits.elemIdx.map(_ := sx_last_in.elemIdx)
+// io.vecstout.bits.elemIdx.map(_ := s2_out.elemIdx)
 // io.vecstout.bits.elemIdxInsideVd.map(_ := DontCare)
 // io.vecstout.bits.vecdata.map(_ := DontCare)
 // io.vecstout.bits.mask.map(_ := DontCare)
-// io.vecstout.bits.alignedType.map(_ := sx_last_in.alignedType)
+// io.vecstout.bits.alignedType.map(_ := s2_out.alignedType)

 io.debug_ls := DontCare
 io.debug_ls.s1_robIdx := s1_in.uop.robIdx.value
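Note: net effect, by my count. Previously stout fired two cycles after s2 (stage 3, then TotalDelayCycles = TotalSelectCycles - 2 = 1 extra "sx" stage under the shipped parameters), and LoadQueueRAW matched that with DelayN(..., TotalSelectCycles = 3) from storeIn. With writeback at s2 and the RAW side down to a single RegNext, both paths shorten by the same two cycles, which keeps the rollback request and the store writeback aligned. One behavioral change worth flagging: io.stout.valid is now also gated by io.feedback_slow.bits.hit, so a store that misses the TLB does not write back at s2 and is instead replayed via the feedback path (my reading of the new gating, not stated in the diff).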