[BUG] Logic issue in nondeterministic reduction mode of Stream-K tile scheduler. #2027
Thanks for pointing this out. You're right: the logic seems to have been corrupted during the CUTLASS 3.4 release. We'll fix it in the next release. In the meantime, if you'd like to try out the correct version now, the following diff should suffice:

```diff
diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
index b5e62164..4aad9da7 100644
--- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
+++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
@@ -497,8 +497,18 @@ public:
       BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
     }
     else {
-      // Wait until the preceding split added its accumulators
-      BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
+      if (params.reduction_mode_ == ReductionMode::Deterministic) {
+        // Wait until the preceding split added its accumulators
+        BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
+      }
+      else {
+        // Wait until the first split has stored its accumulators. Note that the first split will have
+        // accumulated a value into the lock potentially greater than one (since the locked value is
+        // incremented by work_tile_info.k_tile_count below for both the deterministic and
+        // non-deterministic cases). For non-deterministic reductions, all that splits other than the
+        // first and last care about is whether the first split has stored its data, so we only wait
+        // while the locked value is less than 1.
+        BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1);
+      }

       // Perform reduction in workspace
       BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
@@ -512,18 +522,8 @@ public:
       BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment);
     }
     else {
-      if (
-        params.reduction_mode_ == ReductionMode::Deterministic
-      ) {
-
-        // Wait until the preceding split added its accumulators
-        BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
-
-      }
-      else {
-        // Wait until the first split has stored its accumulators
-        BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1);
-      }
+      // Wait until the preceding split added its accumulators
+      BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);

       // The block computing the final split for the tile adds previously-reduced partials
       // to its accumulators and computes the epilogue.
```
Describe the bug
The nondeterministic reduction mode of the Stream-K tile scheduler, as described here, is supposed to have all CTAs collaborating on a tile wait for the first one to store its data (to initialize the workspace), and to have the final CTA wait for all others (so that it can load the data from the workspace and compute the epilogue). This appears not to happen in the `fixup` code. Instead, all non-final CTAs (the `!compute_epilogue` branch) besides the initial one wait for the previous CTA, as in the deterministic mode, while the final CTA (the `else` branch) only waits for the initial CTA. In particular, the final CTA can compute the epilogue before all non-initial, non-final CTAs have performed their reduction, leading to incorrect results. It seems that some of the branches in `fixup` got swapped around, so it should be fairly simple to fix.
Steps/Code to reproduce bug
Below is a reproducing example based on CUTLASS example 49. I believe the issue should trigger when the scheduler assigns at least 3 worktiles to an SM, which is going to depend in part on the specific device being used; on an H100 PCIe with 114 SMs, it triggered consistently for me on 1024x1024xK GEMMs with K >= 4096.