[BUG] Logic issue in nondeterministic reduction mode of Stream-K tile scheduler. #2027

Open
allispaul opened this issue Jan 7, 2025 · 2 comments
Labels: ? - Needs Triage, bug

@allispaul

Describe the bug
The nondeterministic reduction mode of the Stream-K tile scheduler, as described here, is supposed to have all CTAs collaborating on a tile wait for the first one to store its data (to initialize the workspace), and to have the final CTA wait for all others (so that it can load the data from the workspace and compute the epilogue). This appears not to happen in the fixup code. Instead, all non-final CTAs (the !compute_epilogue branch) besides the initial one wait for the previous CTA, as in the deterministic mode; while the final CTA (the else branch) only waits for the initial CTA. In particular, the final CTA can compute the epilogue before all non-initial, non-final CTAs have performed their reduction, leading to incorrect results. It just seems like some of the branches in fixup got swapped around, so it should be pretty simple to fix.
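
To make the described race concrete, below is a minimal sketch of what the current fixup code does for one output tile in nondeterministic mode, as I read it. It is a paraphrase, not the actual CUTLASS source: the helper names (first_split, wait_eq, wait_lt, arrive_inc, and so on) are illustrative stand-ins for the BarrierManager and BlockStripedReduce calls, and only the wait targets matter.

// Sketch (paraphrased) of the current nondeterministic fixup behaviour for one tile.
if (!compute_epilogue) {              // this CTA does not compute the epilogue
  if (first_split) {
    store_partials_to_workspace();    // initial CTA initializes the workspace
  }
  else {
    wait_eq(lock, K_idx);             // middle CTAs wait for the preceding split,
                                      // as in deterministic mode (intended: wait
                                      // only for the initial CTA's store)
    reduce_partials_into_workspace();
  }
  arrive_inc(lock, k_tile_count);     // publish this CTA's contribution
}
else {
  wait_lt(lock, 1);                   // final CTA waits only for the initial CTA
                                      // (intended: wait for all other CTAs), so it
                                      // can run the epilogue before the middle CTAs
                                      // have reduced their partials
  load_workspace_and_compute_epilogue();
}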

Steps/Code to reproduce bug
Below is a reproducing example based on CUTLASS example 49. I believe the issue should trigger when the scheduler assigns at least 3 worktiles to an SM, which will depend in part on the specific device being used; on an H100 PCIe with 114 SMs, it triggered consistently for me on 1024x1024xK GEMMs with K >= 4096.

#include <cmath>
#include <cstdlib>
#include <iostream>

#include "cute/tensor.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"

#include "cutlass/util/command_line.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
#include "cutlass/util/reference/device/tensor_fill.h"

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////

struct Options {
  int m, n, k, l;
  float alpha, beta;
  int seed;

  Options():
    m(1024), n(1024), k(4096), l(1),
    alpha(1.f), beta(0.f),
    seed(0)
  { }

  void parse(int argc, char const **args) {
    cutlass::CommandLine cmd(argc, args);

    cmd.get_cmd_line_argument("m", m, 1024);
    cmd.get_cmd_line_argument("n", n, 1024);
    cmd.get_cmd_line_argument("k", k, 4096);
    cmd.get_cmd_line_argument("l", l, 1);
    cmd.get_cmd_line_argument("seed", seed, 0);
    cmd.get_cmd_line_argument("alpha", alpha, 1.f);
    cmd.get_cmd_line_argument("beta", beta, 0.f);
  }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
template <
  class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto,
  class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
  class StageCountType = cutlass::gemm::collective::StageCountAuto,
  class TileSchedulerType = cutlass::gemm::PersistentScheduler,
  bool Deterministic = true
>
struct ExampleRunner {

  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::ColumnMajor;
  using LayoutD = cutlass::layout::ColumnMajor;

  using ElementA = float;
  using ElementB = float;
  using ElementC = float;
  using ElementD = float;
  using ElementAccumulator = float;
  using ElementCompute = float;
  using ElementScalar = float;

  static constexpr int AlignmentA = 16 / sizeof(ElementA);
  static constexpr int AlignmentB = 16 / sizeof(ElementB);
  static constexpr int AlignmentC = 16 / sizeof(ElementC);
  static constexpr int AlignmentD = 16 / sizeof(ElementD);
  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

  using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      Shape<_128,_128,_64>, Shape<_1,_1,_1>,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementCompute,
      ElementC, LayoutC, AlignmentC,
      ElementD, LayoutD, AlignmentD,
      EpilogueScheduleType,
      DefaultOperation
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementA, LayoutA, AlignmentA,
      ElementB, LayoutB, AlignmentB,
      ElementAccumulator,
      Shape<_128,_128,_64>, Shape<_2,_1,_1>,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      MainloopScheduleType
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveMainloop,
      CollectiveEpilogue,
      TileSchedulerType
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t<StrideA>;
  using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t<StrideB>;
  using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideC>;
  using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t<StrideD>;

  StrideA stride_A;
  StrideB stride_B;
  StrideC stride_C;
  StrideD stride_D;

  cutlass::DeviceAllocation<ElementA> block_A;
  cutlass::DeviceAllocation<ElementB> block_B;
  cutlass::DeviceAllocation<ElementC> block_C;
  cutlass::DeviceAllocation<ElementD> block_D;
  cutlass::DeviceAllocation<ElementD> block_ref_D;

  bool verify(const ProblemShapeType& problem_size, float alpha, float beta) {
    auto [M, N, K, L] = problem_size;

    cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K}));
    cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N}));
    cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N}));
    cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N}));

    cutlass::reference::device::GemmComplex(
          {M, N, K},
          ElementScalar(alpha),
          ref_A,
          cutlass::ComplexTransform::kNone,
          ref_B,
          cutlass::ComplexTransform::kNone,
          ElementScalar(beta),
          ref_C,
          ref_D,
          ElementAccumulator(0),
          L,     // batch_count
          M * K, // batch_stride_A
          K * N, // batch_stride_B
          M * N, // batch_stride_C
          M * N  // batch_stride_D
        );

    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Reference kernel failed. Last CUDA error: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    ElementD* hD =     static_cast<ElementD*>(malloc(M * N * L * sizeof(ElementD)));
    ElementD* hD_ref = static_cast<ElementD*>(malloc(M * N * L * sizeof(ElementD)));
    cudaMemcpy(hD, block_D.get(), M * N * L * sizeof(ElementD), cudaMemcpyDeviceToHost);
    cudaMemcpy(hD_ref, block_ref_D.get(), M * N * L * sizeof(ElementD), cudaMemcpyDeviceToHost);

    float max_diff = 0.0f;
    int max_idx = 0;
    for (int i = 0; i < M * N * L; ++i) {
      float this_diff = std::abs(static_cast<float>(hD[i]) - static_cast<float>(hD_ref[i]));
      if (this_diff > max_diff) {
        max_diff = this_diff;
        max_idx = i;
      }
    }
    bool passed = true;
    if (max_diff > 0.0f) {
      passed = false;
      std::cerr.precision(4);
      std::cerr << "Max absolute difference: " << max_diff << " at index " << max_idx
                << ", reference = " << hD_ref[max_idx] << ", obtained = " << hD[max_idx] << std::endl;
    }

    free(hD);
    free(hD_ref);
    return passed;
  }

  void initialize(const ProblemShapeType& problem_size, uint64_t seed) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto [M, N, K, L] = problem_shape_MNKL;

    stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

    block_A.reset(M * K * L);
    block_B.reset(K * N * L);
    block_C.reset(M * N * L);
    block_D.reset(M * N * L);
    block_ref_D.reset(M * N * L);

    ElementA* hA;
    ElementB* hB;
    ElementC* hC;

    hA = static_cast<ElementA*>(malloc(M * K * L * sizeof(ElementA)));
    hB = static_cast<ElementB*>(malloc(N * K * L * sizeof(ElementB)));
    hC = static_cast<ElementC*>(malloc(M * N * L * sizeof(ElementC)));

    srand(seed);
    for (int i = 0; i < M * K * L; ++i)
      hA[i] = static_cast<ElementA>(1.0);
      // hA[i] = static_cast<ElementA>(static_cast<double>(rand()) / RAND_MAX - 1);
    for (int i = 0; i < N * K * L; ++i)
      hB[i] = static_cast<ElementB>(1.0);
      // hB[i] = static_cast<ElementB>(static_cast<double>(rand()) / RAND_MAX - 1);
    for (int i = 0; i < M * N * L; ++i)
      hC[i] = static_cast<ElementC>(1.0);
      // hC[i] = static_cast<ElementC>(static_cast<double>(rand()) / RAND_MAX - 1);

    block_A.copy_from_host(hA);
    block_B.copy_from_host(hB);
    block_C.copy_from_host(hC);

    free(hA);
    free(hB);
    free(hC);
  }

  bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) {
    ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

    initialize(problem_size, static_cast<uint64_t>(options.seed));

    typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {block_A.get(), stride_A, block_B.get(), stride_B},
      {{}, // epilogue.thread
       block_C.get(), stride_C, block_D.get(), stride_D},
      hw_info
    };

    arguments.epilogue.thread.alpha = options.alpha;
    arguments.epilogue.thread.beta = options.beta;
    using SchedulerParams = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams;

    if constexpr (Deterministic) {
      arguments.scheduler.reduction_mode = SchedulerParams::ReductionMode::Deterministic;
    } else {
      arguments.scheduler.reduction_mode = SchedulerParams::ReductionMode::Nondeterministic;
    }
    // Force a Stream-K schedule
    arguments.scheduler.decomposition_mode = SchedulerParams::DecompositionMode::StreamK;

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "This kernel is not supported. Last CUDA error is: "
                << cudaGetErrorString(cudaGetLastError()) << std::endl;
      return false;
    }

    gemm_op.initialize(arguments, workspace.get());
    gemm_op.run();
    cudaError_t result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: "
                << cudaGetErrorString(result) << std::endl;
      return false;
    }

    bool passed = verify(problem_size, options.alpha, options.beta);
    if (!passed) {
      std::cerr << "Reference check failed" << std::endl;
    }

    return passed;
  }

};

#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

///////////////////////////////////////////////////////////////////////////////////////////////////

void print_result(const std::string& description, bool passed) {
  std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl;
}

///////////////////////////////////////////////////////////////////////////////////////////////////

int main(int argc, char const **args) {

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  Options options;

  options.parse(argc, args);

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
  cutlass::KernelHardwareInfo hw_info;

  hw_info.device_id = 0;
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  bool passed;

  ExampleRunner<
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    cutlass::epilogue::TmaWarpSpecializedCooperative,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::StreamKScheduler,
    /*Deterministic=*/false> ws_cooperative_stream_k_schedule_auto_stage_runner_nd;
  passed = ws_cooperative_stream_k_schedule_auto_stage_runner_nd.run(options, hw_info);
  print_result("Nondeterministic", passed);

  ExampleRunner<
    cutlass::gemm::KernelTmaWarpSpecializedCooperative,
    cutlass::epilogue::TmaWarpSpecializedCooperative,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::StreamKScheduler,
    /*Deterministic=*/true> ws_cooperative_stream_k_schedule_auto_stage_runner;
  passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info);
  print_result("Deterministic", passed);

#endif

  return 0;
}

allispaul added the ? - Needs Triage and bug labels on Jan 7, 2025
@hwu36 (Collaborator) commented Jan 8, 2025:

@jackkosaian

@jackkosaian (Contributor)

Thanks for pointing this out. You're right: the logic seems to have been corrupted during the CUTLASS 3.4 release. We'll fix the logic in the next release.

In the meantime, if you'd like to try out the correct version now, the following diff should suffice:

diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
index b5e62164..4aad9da7 100644
--- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
+++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp
@@ -497,8 +497,18 @@ public:
         BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
       }
       else {
-        // Wait until the preceding split added its accumulators
-        BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
+        if (params.reduction_mode_ == ReductionMode::Deterministic) {
+          // Wait until the preceding split added its accumulators
+          BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
+        }
+        else {
+          // Wait until the first split has stored its accumulators. Note that the first split will have
+          // accumulated a value into the lock potentially greater than one (since the locked value is
+        // incremented by work_tile_info.k_tile_count below for both the deterministic and non-deterministic
+        // cases). For non-deterministic reductions, all that non-first, non-last splits care about is whether
+          // the first split has been written, so we only wait while the locked value is less than 1.
+          BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1);
+        }

         // Perform reduction in workspace
         BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx);
@@ -512,18 +522,8 @@ public:
       BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, increment);
     }
     else {
-      if (
-        params.reduction_mode_ == ReductionMode::Deterministic
-      ) {
-
-        // Wait until the preceding split added its accumulators
-        BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);
-
-      }
-      else {
-        // Wait until the first split has stored its accumulators
-        BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1);
-      }
+      // Wait until the preceding split added its accumulators
+      BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx);

       // The block computing the final split for the tile adds previously-reduced partials
       // to its accumulators and computes the epilogue.
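
For readers who would rather see the end state than piece it together from the two hunks, here is a condensed view of the fixup wait logic once the diff above is applied. It is a paraphrase rather than the verbatim CUTLASS source: branch conditions are abbreviated and call arguments are elided, but the identifiers follow the diff.

// Condensed fixup logic after applying the diff (paraphrased; arguments elided).
if (!compute_epilogue) {
  if (first_split) {
    BlockStripedReduceT::store(...);                         // initialize the reduction workspace
  }
  else {
    if (params.reduction_mode_ == ReductionMode::Deterministic) {
      BarrierManager::wait_eq(..., work_tile_info.K_idx);    // wait for the preceding split
    }
    else {
      BarrierManager::wait_lt(..., 1);                       // only wait for the first split's store
    }
    BlockStripedReduceT::reduce(...);                        // add this split's partials
  }
  BarrierManager::arrive_inc(..., increment);                // publish this split's contribution
}
else {
  // Final split: in both modes it now waits until every preceding split has arrived
  // (the lock reaches work_tile_info.K_idx) before loading the workspace and running
  // the epilogue, which is the wait the pre-fix code only applied in deterministic mode.
  BarrierManager::wait_eq(..., work_tile_info.K_idx);
  ...
}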
