
Split Hopper MMA by warp-tile before instruction tile #3642

Open: jacobhinkle wants to merge 12 commits into main

Conversation

@jacobhinkle (Collaborator) commented Dec 24, 2024

Currently we ignore the warp tile parameter when scheduling Hopper matmuls (see #3636). This PR introduces a test with different CTA, warp, and instruction tiles and modifies the Hopper scheduler to split by warp tile in addition to instruction tile. Note that the instruction-tile split results in two serial loop domains, so we wind up executing multiple mma instructions in each main loop iteration. In the included example, warp_tile is 64, 128, 16 and the macro is Hopper_64_8_16, so there are 128/8 = 16 instruction tiles per warp tile and the generated main loop looks like this:

  #pragma unroll 3
  for(nvfuser_index_t i33 = 0; i33 < i4; ++i33) {
    nvfuser_index_t i34;
    i34 = 48 + (16 * i33);
    nvfuser_index_t i35;
    i35 = (3 + i33) % 4;
    unsigned i36;
    i36 = i7 + (8192 * i35);
    unsigned i37;
    i37 = i10 + (4096 * i35);
    nvfuser_index_t i38;
    i38 = i33 % 4;
    unsigned i39;
    i39 = i13 + (4096 * i38);
    uint64_t i40;
    i40 = 4611686293305294848ULL | ((262143ULL & (uint64_t)(i39)) >> 4ULL);
    unsigned i41;
    i41 = i15 + (8192 * i38);
    if (((Hopper::electSync(4294967295U) && b22) && b23)) {
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 8192U);
      #pragma unroll
      for(nvfuser_index_t i31 = 0; i31 < 4; ++i31) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr5, (Array<nvfuser_index_t, 2, 1>{(i6 + (64 * i31)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i36 + (2048 * i31)));
      }
      mbarrier::arriveExpectTX(toSmem((&T8[((3LL + i33) % 4)])), 4096U);
      #pragma unroll
      for(nvfuser_index_t i32 = 0; i32 < 2; ++i32) {
        Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr8, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i32)), i34}), toSmem((&T8[((3LL + i33) % 4)])) }), (i37 + (2048 * i32)));
      }
    }
    mbarrier::waitParity(toSmem((&T8[(i33 % 4)])), (uint32_t)(((i33 / 4) % 2)));
    #pragma unroll
    for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) {
      unsigned i42;
      i42 = (i41 + (2048 * (i25 / 8))) + (16 * (i25 % 8));
      asm volatile(
        "{\n"
        "  .reg .pred p0; \n"
        "  setp.ne.b32 p0, %6, 0;\n"
        "  wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 {%0, %1, %2, %3}, %4, %5, p0, %7, %8, %9, %10;\n"
        "}\n"
        :"+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[0]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[1]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[2]),
         "+f"((*reinterpret_cast<Array<float, 4, 1>*>(&T2[(4 * i25)]))[3])
        :"l"(i40),
         "l"((4611686293305294848ULL | ((262143ULL & (uint64_t)(i42)) >> 4ULL))),
         "n"((uint32_t)(true)),
         "n"(1),
         "n"(1),
         "n"(1),
         "n"(1)
      );
    }
    __syncthreads();
    asm volatile("wgmma.commit_group.sync.aligned;\n");
    asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  }
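For reference, here is the tile-count arithmetic behind that 16-iteration wgmma loop as a minimal standalone C++ sketch (illustrative only, not nvFuser code; the constant names are made up):

#include <cstdint>

// With warp_tile = 64x128x16 and the Hopper_64_8_16 macro, each warp tile is
// covered by (64/64) * (128/8) * (16/16) = 16 instruction tiles, which matches
// the 16-iteration wgmma loop in the generated kernel above.
constexpr int64_t warp_m = 64, warp_n = 128, warp_k = 16;
constexpr int64_t inst_m = 64, inst_n = 8, inst_k = 16;
constexpr int64_t mma_per_warp_tile =
    (warp_m / inst_m) * (warp_n / inst_n) * (warp_k / inst_k);
static_assert(mma_per_warp_tile == 16, "16 mma instructions per main loop");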

Fixes #3636

@jacobhinkle (Collaborator, Author):

!test

@jacobhinkle (Collaborator, Author):

The bank conflict came from stmatrix scheduling, which needs to be updated; I will do that in a separate PR. For now, I've disabled the smem epilogue in the included test.

@jacobhinkle marked this pull request as ready for review on December 31, 2024 13:47
@jacobhinkle (Collaborator, Author):

!test

@jacobhinkle (Collaborator, Author) commented Dec 31, 2024

When I manually disable stmatrix but keep the TMA store, I still hit a bank conflict and a misaligned address in the smem read feeding the TMA store. The epilogue looks like this:

  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
    nvfuser_index_t i51;
    i51 = 4 * i50;
    #pragma unroll
    for(nvfuser_index_t i52 = 0; i52 < 2; ++i52) {
      nvfuser_index_t i53;
      i53 = i51 + (2 * i52);
      Array<__half, 2, 2> T6;
      #pragma unroll
      for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
        T6[i54]
           = __float2half(T2[(i53 + i54)]);
      }
      loadGeneric<__half, 2>( &T7[(i17 + (128 * i52))],  &T6[0]);
    }
    __syncthreads();
    asm volatile("fence.proxy.async;\n");
    if (b24) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i20 + (8 * i50)), i21}) }), i18);
    }
    __syncthreads();
    asm volatile("cp.async.bulk.commit_group;\n");
    asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
  }
  asm volatile("cp.async.bulk.commit_group;\n");
  asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");

The misaligned read happens with i20 = 1152, i50 = 0, i21 = 320, i18 = 3088. Note that we have

  threadIdx.y = 3;
  i11 = ((nvfuser_index_t)threadIdx.y) / 2; // =1
  i12 = 2048 * i11; // =2048
  i14 = ((nvfuser_index_t)threadIdx.y) % 2; // =1
  i18 = (toSmem(T7) + i12) + (16 * i14); // =toSmem(T7) + 2064

CUDA Exception: Warp Misaligned Address
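To make that arithmetic easier to inspect, here is a small host-side sketch (illustrative only; it just replays the quoted index expressions for each of the four warp groups, with a placeholder standing in for toSmem(T7)):

#include <cstdint>
#include <cstdio>

// Replays the quoted index arithmetic for threadIdx.y = 0..3. The smem base
// address toSmem(T7) is replaced by 0, so the printed values are offsets.
int main() {
  for (int64_t tidy = 0; tidy < 4; ++tidy) {
    const int64_t i11 = tidy / 2;
    const int64_t i12 = 2048 * i11;
    const int64_t i14 = tidy % 2;
    const int64_t i18 = i12 + 16 * i14;  // offset from toSmem(T7)
    std::printf("threadIdx.y=%lld -> smem offset %lld\n",
                (long long)tidy, (long long)i18);
  }
  return 0;
}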

@jacobhinkle (Collaborator, Author):

mma result before this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 8 -> iS59{32}, iS60{8}
  Merge: iS57{2} and iS59{32} -> ithreadIdx.y61{64}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y61{64}, iS58{64}, iS60{8}, rS52{16})

And after this PR:

T2_l_float[iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16}]
 root domain : (rS6{i0}, iS7{i1}, iS8{i6})
 logical domain : (iS7{i1}, iS8{i6}, rS6{i0})
 contiguity: t t n
  Split: iS7{i1} by factor 128 -> iblockIdx.y55{( ceilDiv(i1, 128) )}, iS56{128}
  Split: iS8{i6} by factor 256 -> iblockIdx.x53{( ceilDiv(i6, 256) )}, iS54{256}
  Split: rS6{i0} by factor 16 -> rS51{( ceilDiv(i0, 16) )}, rS52{16}
  Split: iS56{128} by factor 64 -> iS57{2}, iS58{64}
  Split: iS54{256} by factor 128 -> iS61{2}, iS62{128}
  Merge: iS57{2} and iS61{2} -> ithreadIdx.y65{4}
  Split: iS58{64} by factor 64 -> iS59{1}, iS60{64}
  Split: iS62{128} by factor 8 -> iS63{16}, iS64{8}
 loop domain : (iblockIdx.y55{( ceilDiv(i1, 128) )}, iblockIdx.x53{( ceilDiv(i6, 256) )}, rS51{( ceilDiv(i0, 16) )}, ithreadIdx.y65{4}, iS59{1}, iS63{16}, iS60{64}, iS64{8}, rS52{16})
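As a quick sanity check on the new loop domain, the extents above multiply back to the CTA tile (arithmetic only, not scheduler code; tile sizes are the ones from the PR description, CTA 128x256 with warp tile 64x128 and macro Hopper_64_8_16):

#include <cstdint>

// ithreadIdx.y65{4} = iS57{2} (M warp tiles) merged with iS61{2} (N warp tiles).
// Per warp group: M is covered by iS59{1} * iS60{64}, N by iS63{16} * iS64{8}.
constexpr int64_t m_per_warp_group = 1 * 64;   // iS59 * iS60
constexpr int64_t n_per_warp_group = 16 * 8;   // iS63 * iS64
static_assert(2 * m_per_warp_group == 128, "covers the 128-wide CTA M tile");
static_assert(2 * n_per_warp_group == 256, "covers the 256-wide CTA N tile");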

@jacobhinkle (Collaborator, Author):

Note that if I use Hopper_64_64_16 and disable stmatrix, I can enable the smem epilogue and the test passes.

Comment on lines 47 to 49
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
@jacobhinkle (Collaborator, Author):

@rdspring1 is this enough or is #3616 still needed?

@rdspring1 (Collaborator):

It is all that is required for scheduler changes.
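For concreteness, with the tile sizes used in the new test (cta_tile.k == warp_tile.k == 32 and a macro K of 16), those two splits turn the K axis of the block tile into [1, 2, 16] (a hedged arithmetic sketch, not scheduler code):

#include <cstdint>

// split(-1, warp_tile.k):  [32]     -> [1, 32]
// split(-1, getK(macro)):  [1, 32]  -> [1, 2, 16]
// i.e. warp_k / macro_k = 2 serial instruction tiles along K per warp-tile load.
constexpr int64_t cta_k = 32, warp_k = 32, macro_k = 16;
static_assert(cta_k / warp_k == 1, "outer K-split factor");
static_assert(warp_k / macro_k == 2, "instruction tiles along K");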

// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
if (is_mma_result) {
@jacobhinkle (Collaborator, Author):

TODO: since there is no code in common between these branches, we should split this into two separate functions.

@rdspring1 (Collaborator) left a comment:

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]

@rdspring1 (Collaborator) commented Jan 2, 2025

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

@jacobhinkle (Collaborator, Author):

I see C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/runtime/executor.cpp":1421, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. CUDA error: CUDA_ERROR_INVALID_VALUE failed with error invalid argument with warp specialization enabled in test HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile

I just checked this by modifying the test to expect 5 warp groups instead of 4 and by adding a WarpSpecialized(ParallelType::TIDy) argument to these circularBuffer calls:

  acw_smem->circularBuffer(
      params_->circular_buffer_options.smem_circular_buffer_stage,
      /*prefetch_distance=*/
      params_->circular_buffer_options.smem_circular_buffer_stage -
          params_->circular_buffer_options
              .smem_circular_buffer_prefetch_gap);
}
for (TensorView* bcw_smem : bcw_smems_) {
  bcw_smem->circularBuffer(
      params_->circular_buffer_options.smem_circular_buffer_stage,
      /*prefetch_distance=*/
      params_->circular_buffer_options.smem_circular_buffer_stage -
          params_->circular_buffer_options
              .smem_circular_buffer_prefetch_gap);

The result for me is a passing test but perf drops.
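For clarity, the change I tried looks roughly like this (a sketch only: the three-argument circularBuffer overload and the /*type=*/ parameter name are assumptions based on the WarpSpecialized helper mentioned above, and the enclosing loop over acw_smems_ is implied by the excerpt):

for (TensorView* acw_smem : acw_smems_) {
  acw_smem->circularBuffer(
      params_->circular_buffer_options.smem_circular_buffer_stage,
      /*prefetch_distance=*/
      params_->circular_buffer_options.smem_circular_buffer_stage -
          params_->circular_buffer_options.smem_circular_buffer_prefetch_gap,
      /*type=*/WarpSpecialized(ParallelType::TIDy));
}
for (TensorView* bcw_smem : bcw_smems_) {
  bcw_smem->circularBuffer(
      params_->circular_buffer_options.smem_circular_buffer_stage,
      /*prefetch_distance=*/
      params_->circular_buffer_options.smem_circular_buffer_stage -
          params_->circular_buffer_options.smem_circular_buffer_prefetch_gap,
      /*type=*/WarpSpecialized(ParallelType::TIDy));
}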

@jacobhinkle (Collaborator, Author):

Do we need to remove this limitation to handle all matmul parameter configurations?

CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]

I might be confused here. The K dimension is treated differently from the M and N dimensions in these tile definitions. The instruction tile's K dimension is clear, and the warp tile's K dimension (I think) signifies how much data we load at a time before looping over the loaded data to issue instructions. The CTA tile's M and N dimensions specify the tiling of the output, but what does cta_tile.k signify? This is why I was thinking we'd keep this restriction.

Note that this restriction cta_tile.k == warp_tile.k is enforced on Ampere as part of scheduleWarpTileWithReduction.

EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
@jacobhinkle (Collaborator, Author):

This test is nearly identical to the previous one, but it uses a MatmulOp instead of fusedMultiplySum. It is currently failing (it passes on main) with:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/pass/circular_buffer.cpp":160, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. No IfThenElse should exist yet:
IF ElectSync:
  MBarrierArriveExpectTx(T9_s[i408] view( T9 ), 4096)
  FOR i372 in iB28{16}:
    FOR i375 in iB34{2}:
      FOR i373 in iB31{4}:
        FOR i376 in iB35{2}:
          FOR i374 in iB33{8}:
            T3_s___half[iblockIdx.x24{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, bS22{1}, iS20{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, bS23{256}, iS26{1}, iB28{16}, iB34{2}, iB31{4}, iB35{2}, iB33{8}] ca_pos( 5 )
               = CpAsyncBulkTensorTile( T0_g___half[iS170{( (( (( getMetaData(T0) )).logical_size ))[0] )}, iS171{( (( (( getMetaData(T0) )).logical_size ))[1] )}] )

@jacobhinkle added the "on hold" label (This issue should be revisited in the future) on Jan 14, 2025
@jacobhinkle (Collaborator, Author):

This is on hold temporarily while I investigate decoupling the math warp groups by splitting by warp tile before the TMA/MMA scheduling. That is a different approach that would let us schedule the entire K loop of one math group before the next group's K loop, allowing some epilogue overlap between math groups in addition to overlapping the DMA warps.

@jacobhinkle (Collaborator, Author):

!test

github-actions bot commented Jan 17, 2025

PR Reviewer Guide 🔍

(Review updated until commit 1dccf22)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Split Logic

The new transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK methods split the tensor views differently. Verify that the logic is correct and consistent with the MMA output format.

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
    TensorView* tv) {
  // The input is originally block tiled so that the inner dims are the CTA tile
  // size
  //
  // We split this into warp tiles then instruction tiles
  // Original: [..., M, N, K]
  tv->split(-3, params_->tile_sizes.warp_tile.m);
  tv->split(-3, getM(params_->mma_macro));
  tv->split(-2, params_->tile_sizes.warp_tile.n);
  tv->split(-2, getN(params_->mma_macro));
  // K dimension is present for mma_result
  // We don't need to split by warp_tile.k, since we always have cta_tile.k==warp_tile.k
  tv->split(-1, getK(params_->mma_macro));
  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
  tv->reorder({
      {-8, -8}, // Mo
      {-7, -6}, // Mw
      {-6, -3}, // Mi
      {-5, -7}, // No
      {-4, -5}, // Nw
      {-3, -2}, // Ni
      {-2, -4}, // Kw
      {-1, -1}, // Ki
  });
  // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
  tv->merge(-8);
  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
  tv->axis(-7)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
}

void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
    TensorView* tv) {
  // TODO Add constraints

  // The input is originally block tiled so that the inner dims are the CTA tile
  // size
  // Original: [..., M, N]
  // We split this into warp tiles then instruction tiles
  tv->split(-2, params_->tile_sizes.warp_tile.m);
  tv->split(-2, getM(params_->mma_macro));
  tv->split(-1, params_->tile_sizes.warp_tile.n);
  tv->split(-1, getN(params_->mma_macro));
  // After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
  tv->reorder({
      {-3, -5},
      {-2, -3},
  });
  // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
  tv->merge(-6);
  // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
  tv->axis(-5)->parallelize(ParallelType::TIDy);
  // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
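As a worked example (assuming the configuration from the new test below: CTA tile 256x256x32, warp tile 128x128x32, macro Hopper_64_64_16), transformLikeMmaOutputWithK yields the loop structure [..., Mo*No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki] with these extents (arithmetic only, not scheduler code):

#include <cstdint>

constexpr int64_t Mo = 256 / 128;  // warp tiles along M -> 2
constexpr int64_t No = 256 / 128;  // warp tiles along N -> 2
constexpr int64_t Mw = 128 / 64;   // instruction tiles along M per warp tile -> 2
constexpr int64_t Nw = 128 / 64;   // instruction tiles along N per warp tile -> 2
constexpr int64_t Kw = 32 / 16;    // instruction tiles along K -> 2
static_assert(Mo * No == 4, "four warp groups parallelized on TIDy");
static_assert(Mw * Nw * Kw == 8, "wgmma issues per warp group per K slab");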
Parallelization

The parallelizeBlocks method is called in several places. Ensure that the parallelization is correct and efficient.

void HopperMultipleMatmulScheduler::scheduleOperands() {
  NVF_CHECK(
      params_->async_gmem_load_operands,
      "Hopper matmul scheduler currently requires TMA to be enabled");
  auto scheduleBranch = [&](const std::vector<TensorView*>& gmem_operands,
                            const std::vector<TensorView*>& smem_operands,
                            MmaOperand operand_type) {
    blockTileTensors(smem_operands);
    parallelizeBlocks(smem_operands);
    for (TensorView* tv : smem_operands) {
      if (params_->promote_prologue_smem_reuse) {
        tv->promoteReuse();
      }
      mma_utils::orderTiledConcreteIdAsMaybeAllocationDomain(tv);
      MmaInputSmemSwizzle swizzle_type = mma_utils::tmaSwizzleSharedMemory(tv);
      tv->applyMmaSwizzleForTMALoad(swizzle_type);
    }
  };
  scheduleBranch(as_, acw_smems_, MmaOperand::A);
  scheduleBranch(bs_, bcw_smems_, MmaOperand::B);
}

void HopperMultipleMatmulScheduler::parallelizeBlocks(
    const std::vector<TensorView*>& tvs) const {
  for (TensorView* tv : tvs) {
    switch (params_->cta_order) {
      // TODO: Should we instead check the roles of these dimensions to take the
      // outermost two M or N axes?
      case MatmulParams::TileRasterizationOrder::RowMajor:
        tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDx);
        tv->axis(num_device_and_batch_dims_ + 1)
            ->parallelize(ParallelType::BIDy);
        break;
      case MatmulParams::TileRasterizationOrder::ColumnMajor:
        tv->axis(num_device_and_batch_dims_)->parallelize(ParallelType::BIDy);
        tv->axis(num_device_and_batch_dims_ + 1)
            ->parallelize(ParallelType::BIDx);
        break;
      default:
        NVF_THROW("Invalid TileRasterizationOrder passed to Matmul scheduler");
    }
  }
}

void HopperMultipleMatmulScheduler::scheduleMmaResults() {
  GemmTile instruction_tile = getMmaOpShape(params_->mma_macro);
  NVF_CHECK(
      params_->tile_sizes.cta_tile.k == params_->tile_sizes.warp_tile.k,
      "CTA tile must match warp tile K dimension for Hopper matmul but found ",
      toString(params_->tile_sizes));
  // If cta_tile is not divisible by instruction tile the mma instruction will
  // be predicated.
  NVF_CHECK(
      params_->tile_sizes.cta_tile.m % instruction_tile.m == 0 &&
          params_->tile_sizes.cta_tile.n % instruction_tile.n == 0 &&
          params_->tile_sizes.cta_tile.k % instruction_tile.k == 0,
      "CTA tile must be divisible by macro size but found cta_tile: ",
      toString(params_->tile_sizes.cta_tile),
      " and macro: ",
      toString(params_->mma_macro));

  // Schedule mma results and propagate forward
  auto all_merged_roles = blockTileTensors(mma_results_);
  parallelizeBlocks(mma_results_);
  for (size_t i : c10::irange(mma_results_.size())) {
    TensorView*& mma_result = mma_results_[i];
    const std::vector<MatmulDimRole>& merged_roles = all_merged_roles[i];

    // Test that mma_result logical is MNK
    // TODO: This currently checks leaf domain only which does not necessarily
    // match logical
    // TODO: Lift this constraint. Use commitLeafToLogical if necessary. We
    // might just want to match using id_roles_
    NVF_ERROR(merged_roles.size() >= 3);
    const auto checkSingleDimRole =
        [&merged_roles](int64_t pos, MatmulDimRole expected_role) {
          if (pos < 0) {
            pos += (int64_t)merged_roles.size();
          }
          NVF_ERROR(pos >= 0);
          NVF_ERROR(pos < (int64_t)merged_roles.size());
          const auto& actual_role = merged_roles[(size_t)pos];
          NVF_ERROR(actual_role == expected_role);
        };
    checkSingleDimRole(-3, MatmulDimRole::M);
    checkSingleDimRole(-2, MatmulDimRole::N);
    checkSingleDimRole(-1, MatmulDimRole::K);

    // do split-K rFactor to define splitk_sum and smem_epilogue
    if (params_->splitk_factor != 1) {
      // Note that the split-K split is already done in blockTileTensors
      TensorView* splitk_sum = mma_result->rFactor({-4, -1});
      std::swap(splitk_sum, mma_result);
      splitk_sums_.push_back(splitk_sum);
    }

    transformLikeMmaOutputWithK(mma_result);
    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
        mma_result->getLoopDomain());
    mma_result->setAllocationDomain(s.as<IterDomain*>(), true);
    mma_result->axis(-1)->parallelize(ParallelType::Mma);
    mma_result->axis(-2)->parallelize(ParallelType::Mma);
    mma_result->axis(-3)->parallelize(ParallelType::Mma);
  }
}

void HopperMultipleMatmulScheduler::scheduleEpilogue() {
  std::vector<TensorView*> cached_tvs;

  // Propagate to (not including) the splitk output if there is a splitk
  // else this is just mma_results_
  std::vector<TensorView*> propagate_to =
      splitk_sums_.empty() ? mma_results_ : splitk_sums_;
  if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) {
    auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT);
    // Load/cache the epilogue inputs if there are any.
    for (auto* c : c_tvs) {
      cached_tvs.push_back(c->cacheAfter());
    }
    propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end());
  }

  if (!params_->use_smem_epilogue) {
    for (Val* dv : fusion_->outputs()) {
      auto* d = dv->as<TensorView>();
      NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());

      // Schedule the output TV and propagate it back to the outputs of the Mma
      // op.
      blockTileTensors({d});
      parallelizeBlocks({d});
      transformLikeMmaOutputWithoutK(d);

      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
          d->getLoopDomain());
      d->setLoopDomain(s.as<IterDomain*>());

      // TODO: We need to check bank conflicts in this path.
      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
          d,
          -1,
          propagate_to,
          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
              .propagateParallelType());

      // We don't respect vectorization_factor as yet. We vectorize the
      // inner-dim with extent 2.
      // TODO: support vectorization_factor.
      d->axis(-1)->parallelize(ParallelType::Vectorize);
      if (!cached_tvs.empty()) {
        scheduler_utils::parallelizeAllLike(d, -1, cached_tvs);
      }
    }
  } else {
    constexpr int64_t stmatrix_tile_m = 16;
    constexpr int64_t stmatrix_tile_n = 16;

    // TODO: Support tma tile sizes that are a multiple of mma_macro.
    // The wgmma operation creates an output matrix of mma_macro size. The TMA
    // tile is a multiple of the macro size because stmatrix stores results from
    // wgmma to shared memory. For maximum inlining and to reduce shared memory
    // usage, the tma tile is mma_macro size.
    const int64_t tma_m = params_->tile_sizes.warp_tile.m;
    const int64_t tma_n = params_->tile_sizes.warp_tile.n;

    fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
    fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
    fusion_->manage("st_matrix_m", tma_m);
    fusion_->manage("st_matrix_n", tma_n);

    // Manually schedule register cache and output TensorView
    for (Val* dv : fusion_->outputs()) {
      auto* d = dv->as<TensorView>();
      NVF_ERROR(d->definition() && d->definition()->isA<LoadStoreOp>());
      auto* dc = d->definition()->input(0)->as<TensorView>();

      // NOTE: cacheBefore does not work with blockTileTensors
      TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

      std::vector<TensorView*> tvs_to_schedule{d, d_smem};

      bool dc_in_mma_results =
          std::find(mma_results_.begin(), mma_results_.end(), dc) !=
          mma_results_.end();

      if (!dc_in_mma_results) {
        // Skip scheduling dc if it is an mma_result. This can happen if we are
        // not casting back to half-precision in the output
        tvs_to_schedule.push_back(dc);
      }

      // Set MemoryType
      dc->setMemoryType(MemoryType::Local);
      d_smem->setMemoryType(MemoryType::Shared);

      auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;

      if (store_with_stmatrix) {
        // Set LoadStoreOp
        d_smem->definition()->as<LoadStoreOp>()->setOpType(
            LoadStoreOpType::StMatrix);
      }
      d->definition()->as<LoadStoreOp>()->setOpType(
          LoadStoreOpType::CpAsyncBulkTensorTile);

      // Apply the common transforms to dc, d_smem, d
Test Coverage

New tests have been added, but it's essential to verify that they cover all the necessary scenarios and edge cases.

// This tests that we can use a small instruction tile with a medium size
// warpgroup tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 2048, N = 2048, K = 8192;
  const auto dtype = DataType::Half;

  auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
  auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = fusedMultiplySum(tv0, tv1, {0});

  // Reorder the accumulator as [M, N, K]
  // [K, M, N] -> [M, N, K]
  tv2->reorder({{-3, -1}});
  tv2->commitLeafToLogical();

  auto tv3 = castOp(DataType::Half, tv2);
  fusion.addOutput(tv3);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto a_ref = at::randn({K, M, 1}, options);
  auto b_ref = at::randn({K, 1, N}, options);
  auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

  MatMulTileOptions gemm_tile;
  // Regardless of the instruction, the 256x256 CTA tile with 128x128 warp
  // tiles should result in 4 warp groups, i.e. 512 threads
  gemm_tile.cta_tile = GemmTile(256, 256, 32);
  gemm_tile.warp_tile = GemmTile(128, 128, 32);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  // NOTE: disabling the smem epilogue for this test since we currently hit a
  // bank conflict.
  // TODO: enable smem epilogue once stmatrix is updated
  mparams.use_smem_epilogue = false;
  mparams.cluster_dims = {2, 1, 1};
  mparams.promote_prologue_smem_reuse = false;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  std::vector<c10::IValue> inputs = {a_ref, b_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  kir::Kernel* kernel = ke.compiledKernel()->kernel();
  ASSERT_TRUE(kernel != nullptr);
  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

  auto cg_outputs = ke.run(inputs);

  // Check number of launched threads matches what we expect
  EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
  EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
      << " expected 4 warp groups (BIDy==4) but found BIDy=="
      << ke.lastLaunchParams().bdimy();

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}
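The launch-dimension expectations in this test follow directly from the tile configuration (a quick check; each warp tile maps to one 128-thread warp group, with the threads on TIDx and the warp groups on TIDy):

#include <cstdint>

// CTA tile 256x256 with 128x128 warp tiles -> (256/128) * (256/128) = 4 warp
// tiles, each handled by one warp group, so bdimy == 4 and bdimx == 128.
constexpr int64_t warp_groups = (256 / 128) * (256 / 128);
static_assert(warp_groups == 4, "matches the bdimy expectation above");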

TEST_F(HopperMatmulTest, ScheduleWithTranslation) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t M = 2048, N = 2048, K = 8192;
  const auto dtype = DataType::Half;

  auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // M, K
  auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // K, N
  // Note tv1 has allocation domain
  // tv1->setAllocationDomain({tv1->axis(1), tv1->axis(0)}, true);
  fusion.addInput(tv0);
  fusion.addInput(tv1);

  auto tv2 = matmul(tv0, tv1);

  fusion.addOutput(tv2);

  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto a_ref = at::randn({M, K}, options);
  // auto b_ref = at::randn({N, K}, options).t();
  auto b_ref = at::randn({K, N}, options);
  auto out_ref = at::matmul(a_ref, b_ref);

  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = GemmTile(128, 256, 16);
  gemm_tile.warp_tile = GemmTile(64, 64, 16);

  MatmulParams mparams;
  mparams.supported_vec_size = {8, 8, 8};
  mparams.mma_macro = MmaMacro::Hopper_64_64_16;
  mparams.tile_sizes = gemm_tile;
  mparams.cta_order = MatmulParams::TileRasterizationOrder::RowMajor;
  mparams.async_gmem_load_operands = true;
  mparams.circular_buffer_options.circular_buffer_smem_write = true;
  mparams.circular_buffer_options.circular_buffer_smem_read = false;
  mparams.circular_buffer_options.smem_circular_buffer_stage = 3;
  mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
  mparams.splitk_factor = 1;
  mparams.use_smem_epilogue = true;
  mparams.cluster_dims = {1, 1, 1};
  mparams.promote_prologue_smem_reuse = true;

  SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
      ->schedule(&fusion, &mparams);

  std::vector<c10::IValue> inputs = {a_ref, b_ref};

  KernelExecutor ke;
  ke.compile(&fusion, inputs);
  kir::Kernel* kernel = ke.compiledKernel()->kernel();
  ASSERT_TRUE(kernel != nullptr);
  EXPECT_TRUE(getBankConflictInfo(kernel).empty());
  EXPECT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(kernel));

  auto cg_outputs = ke.run(inputs);

  // Relax tolerance for larger sum due to large K
  EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser

Labels: on hold (This issue should be revisited in the future)
Merging this pull request may close: Split by warp tile in Hopper matmul scheduler