Skip to content

Commit

Permalink
Batch: disable split strategy for FPGA to reduce gates
Browse files Browse the repository at this point in the history
Previously, we enabled splitting Batch data to handle a peak of valid data
in the same cycle (we call the data collected in the same cycle "step data").
The step data may be split according to the remaining space of the Batch
state, and part of it will be appended to the Batch output to make
full use of the Batch DPIC.

However, this split strategy needs some gates to handle the splitting
and appending behaviour. Since we now need to reduce the gate count of
Batch for FPGA, we disable the split strategy there.

According to the gate count reported by Palladium, disabling the split
strategy reduces the gates of Batch from 12M to 8M (in full Difftest).
  • Loading branch information
klin02 committed Feb 9, 2025
1 parent 153ee37 commit bbe919b
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 93 deletions.
199 changes: 106 additions & 93 deletions src/main/scala/Batch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ case class BatchParam(config: GatewayConfig, bundles: Seq[DifftestBundle]) {

// Truncate width when shifting to reduce useless gates
val TruncDataBitLen = math.min(MaxDataBitLen, StepDataBitLen)
val TruncInfoBitLen = math.min(MaxInfoBitLen, StepInfoBitLen)
}

class BatchIO(param: BatchParam) extends Bundle {
Expand Down Expand Up @@ -230,46 +229,108 @@ class BatchAssembler(
val data_bytes_avail = param.MaxDataByteLen.U -& state_status.data_bytes
// Always leave space for BatchFinish and BatchInterval, use MaxInfoSize - 2
val info_size_avail = (param.MaxInfoSize - 2).U -& state_status.info_size
val data_exceed_v = VecInit(delay_step_status.map(_.data_bytes > data_bytes_avail && delay_step_enable))
val info_exceed_v = VecInit(delay_step_status.map(_.info_size > info_size_avail && delay_step_enable))
val exceed_v = VecInit(data_exceed_v.zip(info_exceed_v).map { case (de, ie) => de | ie })
// Extract last non-exceed stats
// When no stats exceeds, return step_stats to append whole step for state flushing
val concat_mask = VecInit.tabulate(param.StepGroupSize - 1) { idx => exceed_v(idx) ^ exceed_v(idx + 1) }
val concat_stats = Mux(
!exceed_v.asUInt.orR,
delay_step_status.last,
VecInit(delay_step_status.dropRight(1).zip(concat_mask).map { case (stats, mask) =>
Mux(mask, stats.asUInt, 0.U)
}).reduceTree(_ | _).asTypeOf(new BatchStats(param)),
)
val remain_stats = Wire(new BatchStats(param))
remain_stats.data_bytes := delay_step_status.last.data_bytes -& concat_stats.data_bytes
remain_stats.info_size := delay_step_status.last.info_size -& concat_stats.info_size
assert(remain_stats.data_bytes <= param.MaxDataByteLen.U)
assert(remain_stats.info_size + 1.U <= param.MaxInfoSize.U)

val concat_data = (delay_step_data >> (remain_stats.data_bytes << 3).asUInt).asUInt
val concat_info = (delay_step_info >> (remain_stats.info_size * param.infoWidth.U)).asUInt
// Note we need only lowest bits to update state, truncate high bits to reduce gates
val remain_data = (~(~0.U(param.TruncDataBitLen.W) <<
(remain_stats.data_bytes << 3).asUInt)).asUInt & delay_step_data
val remain_info = (~(~0.U(param.TruncInfoBitLen.W) <<
(remain_stats.info_size * param.infoWidth.U))).asUInt & delay_step_info
val data_exceed = Wire(Bool())
val info_exceed = Wire(Bool())
val append_data = Wire(UInt(param.TruncDataBitLen.W))
val append_info = Wire(UInt(param.StepInfoBitLen.W))
val finish_step = Wire(UInt(config.stepWidth.W))
val next_state_step_cnt = Wire(UInt(config.stepWidth.W))
val next_state_data = Wire(UInt(param.MaxDataBitLen.W))
val next_state_info = Wire(UInt(param.MaxInfoBitLen.W))
val next_state_stats = Wire(new BatchStats(param))

// Use BatchInterval to update index of software buffer
val BatchInterval = Wire(new BatchInfo)
BatchInterval.id := Batch.getTemplate.length.U
BatchInterval.num := delay_step_status.last.info_size // unused, only for debugging
val BatchFinish = Wire(new BatchInfo)
BatchFinish.id := (Batch.getTemplate.length + 1).U
BatchFinish.num := finish_step

val step_exceed = delay_step_enable && (state_step_cnt === config.batchSize.U)
val cont_exceed = data_exceed || info_exceed
val state_flush = step_enable && step_status.last.data_bytes >= param.MaxDataByteLen.U // use Stage 1 bytes to flush ahead

if (config.batchSplit) {
val data_exceed_v = VecInit(delay_step_status.map(_.data_bytes > data_bytes_avail && delay_step_enable))
val info_exceed_v = VecInit(delay_step_status.map(_.info_size > info_size_avail && delay_step_enable))
data_exceed := data_exceed_v.asUInt.orR
info_exceed := info_exceed_v.asUInt.orR
val exceed_v = VecInit(data_exceed_v.zip(info_exceed_v).map { case (de, ie) => de | ie })

// Extract last non-exceed stats
// When no stats exceeds, return step_stats to append whole step for state flushing
val concat_mask = VecInit.tabulate(param.StepGroupSize - 1) { idx => exceed_v(idx) ^ exceed_v(idx + 1) }
val concat_stats = Mux(
!exceed_v.asUInt.orR,
delay_step_status.last,
VecInit(delay_step_status.dropRight(1).zip(concat_mask).map { case (stats, mask) =>
Mux(mask, stats.asUInt, 0.U)
}).reduceTree(_ | _).asTypeOf(new BatchStats(param)),
)
val remain_stats = Wire(new BatchStats(param))
remain_stats.data_bytes := delay_step_status.last.data_bytes -& concat_stats.data_bytes
remain_stats.info_size := delay_step_status.last.info_size -& concat_stats.info_size
assert(remain_stats.data_bytes <= param.MaxDataByteLen.U)
assert(remain_stats.info_size + 1.U <= param.MaxInfoSize.U)

val concat_data = (delay_step_data >> (remain_stats.data_bytes << 3).asUInt).asUInt
val concat_info = (delay_step_info >> (remain_stats.info_size * param.infoWidth.U)).asUInt
// Note we need only lowest bits to update state, truncate high bits to reduce gates
val remain_data = (~(~0.U(param.TruncDataBitLen.W) <<
(remain_stats.data_bytes << 3).asUInt)).asUInt & delay_step_data
val remain_info = (~(~0.U(param.StepInfoBitLen.W) <<
(remain_stats.info_size * param.infoWidth.U))).asUInt & delay_step_info

// Delay step can be partly appended to output for making full use of transmission param
// Avoid appending when step equals batchSize(delay_step_exceed), last appended data will overwrite first step data
val has_append = delay_step_enable && (state_flush || cont_exceed) && !exceed_v.asUInt.andR && !step_exceed
// When the whole step is appended to output, state_step should be 0, and output step + 1
val append_whole = has_append && !cont_exceed
finish_step := state_step_cnt + Mux(append_whole, 1.U, 0.U)

append_data := Mux(has_append, concat_data(param.TruncDataBitLen - 1, 0), 0.U)
val append_finish_map = Seq.tabulate(param.StepGroupSize) { g =>
(g.U, (BatchFinish.asUInt << (g * param.infoWidth)).asUInt)
}
append_info := Mux(
has_append,
Cat(concat_info | LookupTree(concat_stats.info_size, append_finish_map), BatchInterval.asUInt),
BatchFinish.asUInt,
)

next_state_step_cnt := Mux(has_append && append_whole, 0.U, 1.U)
next_state_data := Mux(has_append, remain_data, delay_step_data)
next_state_info := Mux(has_append, remain_info, Cat(delay_step_info, BatchInterval.asUInt))
next_state_stats.data_bytes := Mux(has_append, remain_stats.data_bytes, delay_step_status.last.data_bytes)
next_state_stats.info_size := Mux(has_append, remain_stats.info_size, delay_step_status.last.info_size + 1.U)
} else {
data_exceed := delay_step_enable && delay_step_status.last.data_bytes > data_bytes_avail
info_exceed := delay_step_enable && delay_step_status.last.info_size > info_size_avail
assert(delay_step_status.last.data_bytes <= param.MaxDataByteLen.U)
assert(delay_step_status.last.info_size <= param.MaxInfoSize.U)

finish_step := state_step_cnt
append_data := 0.U
append_info := BatchFinish.asUInt

next_state_step_cnt := 1.U
next_state_data := delay_step_data
next_state_info := Cat(delay_step_info, BatchInterval.asUInt)
next_state_stats.data_bytes := delay_step_status.last.data_bytes
next_state_stats.info_size := delay_step_status.last.info_size + 1.U
}

// Stage 2:
// update state
val step_exceed = delay_step_enable && (state_step_cnt === config.batchSize.U)
val cont_exceed = exceed_v.asUInt.orR
val trace_exceed = Option.when(config.hasReplay) {
delay_step_enable && (state_trace_size.get +& delay_step_trace_info.get.trace_size >= config.replaySize.U)
}
val state_flush = step_enable && step_status.last.data_bytes >= param.MaxDataByteLen.U // use Stage 1 bytes to flush ahead
val timeout_count = RegInit(0.U(32.W))
val timeout = timeout_count === 200000.U
if (config.hasBuiltInPerf) {
DifftestPerf("BatchExceed_data", data_exceed_v.asUInt.orR)
DifftestPerf("BatchExceed_info", info_exceed_v.asUInt.orR)
DifftestPerf("BatchExceed_data", data_exceed)
DifftestPerf("BatchExceed_info", info_exceed)
DifftestPerf("BatchExceed_step", step_exceed.asUInt)
DifftestPerf("BatchExceed_flush", state_flush.asUInt)
DifftestPerf("BatchExceed_timeout", timeout.asUInt)
Expand All @@ -284,85 +345,37 @@ class BatchAssembler(
}.otherwise {
timeout_count := 0.U
}
// Delay step can be partly appended to output for making full use of transmission param
// Avoid appending when step equals batchSize(delay_step_exceed), last appended data will overwrite first step data
val has_append =
delay_step_enable && (state_flush || cont_exceed) && !exceed_v.asUInt.andR && !step_exceed
// When the whole step is appended to output, state_step should be 0, and output step + 1
val append_whole = has_append && !cont_exceed
val finish_step = state_step_cnt + Mux(append_whole, 1.U, 0.U)
val BatchFinish = Wire(new BatchInfo)
BatchFinish.id := (Batch.getTemplate.length + 1).U
BatchFinish.num := finish_step
// Use BatchInterval to update index of software buffer
val BatchInterval = Wire(new BatchInfo)
BatchInterval.id := Batch.getTemplate.length.U
BatchInterval.num := delay_step_status.last.info_size // unused, only for debugging
val append_finish_map = Seq.tabulate(param.StepGroupSize) { g =>
(g.U, (BatchFinish.asUInt << (g * param.infoWidth)).asUInt)
}
val append_info = Cat(concat_info | LookupTree(concat_stats.info_size, append_finish_map), BatchInterval.asUInt)

out.io.data := state_data |
Mux(has_append, (concat_data(param.TruncDataBitLen - 1, 0) << (state_status.data_bytes << 3).asUInt).asUInt, 0.U)
out.io.info := state_info |
(Mux(has_append, append_info, BatchFinish.asUInt)(
param.TruncInfoBitLen - 1,
0,
) << (state_status.info_size * param.infoWidth.U)).asUInt

out.io.data := state_data | (append_data << (state_status.data_bytes << 3).asUInt).asUInt
out.io.info := state_info | (append_info << (state_status.info_size * param.infoWidth.U)).asUInt
out.enable := should_tick
out.step := Mux(out.enable, finish_step, 0.U)

val next_state_data_bytes = Mux(
delay_step_enable,
Mux(
should_tick,
Mux(has_append, remain_stats.data_bytes, delay_step_status.last.data_bytes),
state_status.data_bytes + delay_step_status.last.data_bytes,
),
0.U,
)
val next_state_info_size = Mux(
delay_step_enable,
Mux(
should_tick,
Mux(has_append, remain_stats.info_size, delay_step_status.last.info_size + 1.U),
state_status.info_size + delay_step_status.last.info_size + 1.U,
),
0.U,
)
val state_update = delay_step_enable || state_flush || timeout
when(state_update) {
state_status.data_bytes := next_state_data_bytes
state_status.info_size := next_state_info_size
when(delay_step_enable) {
when(should_tick) {
when(has_append) { // include state_flush with new-coming step
state_step_cnt := Mux(append_whole, 0.U, 1.U)
state_data := remain_data
state_info := remain_info
}.otherwise {
state_step_cnt := 1.U
state_data := delay_step_data
state_info := Cat(delay_step_info, BatchInterval.asUInt)
}
state_step_cnt := next_state_step_cnt
state_data := next_state_data
state_info := next_state_info
state_status := next_state_stats
if (config.hasReplay) state_trace_size.get := delay_step_trace_info.get.trace_size
}.otherwise {
state_step_cnt := state_step_cnt + 1.U
val append_step_data =
state_data := state_data |
(delay_step_data(param.TruncDataBitLen - 1, 0) << (state_status.data_bytes << 3).asUInt).asUInt
val append_step_info =
(Cat(delay_step_info(param.TruncInfoBitLen - 1, 0), BatchInterval.asUInt)
<< (state_status.info_size * param.infoWidth.U)).asUInt
state_data := state_data | append_step_data
state_info := state_info | append_step_info
state_info := state_info |
(Cat(delay_step_info, BatchInterval.asUInt) << (state_status.info_size * param.infoWidth.U)).asUInt
state_status.data_bytes := state_status.data_bytes + delay_step_status.last.data_bytes
state_status.info_size := state_status.info_size + delay_step_status.last.info_size + 1.U
if (config.hasReplay) state_trace_size.get := state_trace_size.get + delay_step_trace_info.get.trace_size
}
}.otherwise { // state_flush without new-coming step
state_step_cnt := 0.U
state_data := 0.U
state_info := 0.U
state_status.data_bytes := 0.U
state_status.info_size := 0.U
if (config.hasReplay) state_trace_size.get := 0.U
}
}
Expand Down
1 change: 1 addition & 0 deletions src/main/scala/Gateway.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ case class GatewayConfig(
def stepWidth: Int = log2Ceil(maxStep + 1)
def replayWidth: Int = log2Ceil(replaySize + 1)
def batchArgByteLen: (Int, Int) = if (isNonBlock || isFPGA) (3600, 400) else (7200, 800)
def batchSplit: Boolean = !isFPGA // Disable split for FPGA to reduce gates
def hasDeferredResult: Boolean = isNonBlock || hasInternalStep
def needTraceInfo: Boolean = hasReplay
def needEndpoint: Boolean =
Expand Down

0 comments on commit bbe919b

Please sign in to comment.