From d68cecad21dc2ac0e20340bebf3b6b98a6d12876 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 00:59:05 +0000 Subject: [PATCH 01/10] Use smaller width for lse_accum dist tensor --- ...ha_fwd_splitkv_combine_pipeline_default_policy.hpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 3327d4af87..4f78648c0c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -134,15 +134,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = max(Problem::kMaxSplits, get_warp_size()); + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kM0; - constexpr index_t NThreads = get_warp_size(); + constexpr index_t NThreads = 4; constexpr index_t NPerThread = kNPerBlock / NThreads; - constexpr index_t MThreads = kBlockSize / NThreads; - constexpr index_t MPerThread = kMPerBlock / MThreads; + constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; + constexpr index_t MWarps = kMPerBlock / MThreadPerWarp; + constexpr index_t MThreads = MThreadPerWarp * MWarps; + constexpr index_t MPerThread = kMPerBlock / MThreads; + static_assert(kBlockSize % MWarps == 0); static_assert(NThreads * NPerThread == kNPerBlock); static_assert(MThreads * MPerThread == kMPerBlock); From f92661a9dae4b3a6eb1763729cb2bb98e6bed9ca Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 01:06:55 +0000 Subject: [PATCH 02/10] Update pipeline comment --- .../pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 1afe0feab3..a945b3cfbd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -138,9 +138,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_accum = make_static_distributed_tensor( Policy::template MakeLSEaccRegTileDistribution()); - // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)]) - // this will extend the distributed tensor width so that each thread in wave have data to - // reduce. + // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits]) + // and fill the elements outside the [kM0, num_splits] region with -INF. { constexpr auto spans = decltype(lse_accum)::get_distributed_spans(); sweep_tile_span(spans[number<0>{}], [&](auto idx0) { From 253c9ba23252a8f0671aa42f3998136991b17ac8 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 01:40:38 +0000 Subject: [PATCH 03/10] Fix wrong distribution for lse_accum --- ..._fmha_fwd_splitkv_combine_pipeline_default_policy.hpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 4f78648c0c..f3290a7969 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -149,14 +149,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy static_assert(NThreads * NPerThread == kNPerBlock); static_assert(MThreads * MPerThread == kMPerBlock); + // duplicate MWarps if less than (kBlockSize / get_warp_size()) return 
make_static_tile_distribution( tile_distribution_encoding< sequence<1>, - tuple, sequence>, - tuple, sequence<2>>, - tuple, sequence<0>>, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 1>>, sequence<1, 2>, - sequence<1, 1>>{}); + sequence<2, 1>>{}); } template From 3e15078fe6ec14eaa084a840e2dd21f0ae8393dd Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 07:48:27 +0000 Subject: [PATCH 04/10] Remove duplicate dim in lse_accum dist encoding --- ..._fmha_fwd_splitkv_combine_pipeline_default_policy.hpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index f3290a7969..f8a47ed751 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -132,8 +132,6 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kM0; @@ -145,17 +143,16 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy constexpr index_t MThreads = MThreadPerWarp * MWarps; constexpr index_t MPerThread = kMPerBlock / MThreads; - static_assert(kBlockSize % MWarps == 0); + static_assert(MWarps <= 4); static_assert(NThreads * NPerThread == kNPerBlock); static_assert(MThreads * MPerThread == kMPerBlock); - // duplicate MWarps if less than (kBlockSize / get_warp_size()) return make_static_tile_distribution( tile_distribution_encoding< sequence<1>, tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<0, 1>>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 1>>, sequence<1, 2>, 
sequence<2, 1>>{}); } From 81cb4c4f56317b986fe5363446cc959168cd3ed8 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 07:58:14 +0000 Subject: [PATCH 05/10] Decide fmha splitkv combine kernel kBlockSize by kM0 --- .../ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index d254f07e2d..1846664e7d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem using ODataType = remove_cvref_t; using Traits = remove_cvref_t; - static constexpr index_t kBlockSize = 256; + static constexpr index_t kNumWarps = kM0_ / (get_warp_size() / 4); + static constexpr index_t kBlockSize = kNumWarps * get_warp_size(); static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr index_t kHeadDimV = HeadDimV_; From 9df2983cc0ae8cb45400a003693b1c28a1111059 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Wed, 16 Oct 2024 08:32:31 +0000 Subject: [PATCH 06/10] Remove assumption of MPerThread=1 --- ...ha_fwd_splitkv_combine_pipeline_default_policy.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index f8a47ed751..56ec85888b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ -132,20 +132,21 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution() { + 
constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::kMaxSplits; constexpr index_t kMPerBlock = Problem::kM0; constexpr index_t NThreads = 4; constexpr index_t NPerThread = kNPerBlock / NThreads; - constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; - constexpr index_t MWarps = kMPerBlock / MThreadPerWarp; - constexpr index_t MThreads = MThreadPerWarp * MWarps; + constexpr index_t MThreads = kBlockSize / NThreads; constexpr index_t MPerThread = kMPerBlock / MThreads; + constexpr index_t MWarps = kBlockSize / get_warp_size(); + constexpr index_t MThreadPerWarp = get_warp_size() / NThreads; - static_assert(MWarps <= 4); static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MThreads * MPerThread == kMPerBlock); + static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock); return make_static_tile_distribution( tile_distribution_encoding< From f262becb68ba0b5332e6a47a4ec230a148fc3906 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Thu, 17 Oct 2024 07:03:23 +0000 Subject: [PATCH 07/10] Add log2<4> & log2<8> specializations --- .../block_fmha_fwd_splitkv_combine_pipeline.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index a945b3cfbd..1dc27488a9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -12,6 +12,16 @@ namespace detail { template struct log2; +template <> +struct log2<4> : std::integral_constant +{ +}; + +template <> +struct log2<8> : std::integral_constant +{ +}; + template <> struct log2<16> : std::integral_constant { From 88e5de43b7d99c2459f15dd98854a3e8b20698d4 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Thu, 17 Oct 2024 09:23:57 +0000 Subject: [PATCH 08/10] Enlarge occupancy 
array --- .../block_fmha_fwd_splitkv_combine_pipeline.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 1dc27488a9..7c49fce99a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -82,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline { if constexpr(kHeadDimV <= 32) { - constexpr std::array occupancy{3, 3, 3, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{3, 3, 3, 3, 3, 1}; + return occupancy[detail::log2::value - 2]; } else if constexpr(kHeadDimV <= 128) { - constexpr std::array occupancy{3, 3, 2, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{3, 3, 3, 3, 2, 1}; + return occupancy[detail::log2::value - 2]; } else if constexpr(kHeadDimV <= 256) { - constexpr std::array occupancy{2, 2, 2, 1}; - return occupancy[detail::log2::value - 4]; + constexpr std::array occupancy{2, 2, 2, 2, 2, 1}; + return occupancy[detail::log2::value - 2]; } } }(); From 598ee5cd15c843f9b8a38ef55d6659cb16d52a04 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Thu, 17 Oct 2024 10:13:11 +0000 Subject: [PATCH 09/10] Fix vector size for small tile --- ...plitkv_combine_pipeline_default_policy.hpp | 47 ++++++++++++++----- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp index 56ec85888b..ebd69c0cf8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp @@ 
-10,11 +10,26 @@ namespace ck_tile { struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy { + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile() + { + constexpr index_t PixelsPerThread = (M * N) / BlockSize; + static_assert(0 < PixelsPerThread); + + constexpr index_t MaxNPerThread = 16 / sizeof(DataType); + constexpr index_t NPerThread = min(MaxNPerThread, PixelsPerThread); + + return NPerThread; + } + + // alignment for dram lse tile (shape=[kMaxSplits, kM0]) template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE() { - using LSEDataType = remove_cvref_t; - return 16 / sizeof(LSEDataType); + return GetVectorSizeForTile(); } template @@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy MakeLSEaccLdsBlockDescriptor().get_element_space_size(); } + // shape=[kMaxSplits, kM0] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution() { using LSEDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNumWarps = Problem::kNumWarps; constexpr index_t kNPerBlock = Problem::kM0; constexpr index_t kMPerBlock = Problem::kMaxSplits; - constexpr index_t NPerThread = 16 / sizeof(LSEDataType); - constexpr index_t NThreads = kNPerBlock / NPerThread; + constexpr index_t NPerThread = + GetVectorSizeForTile(); + constexpr index_t NThreads = kNPerBlock / NPerThread; constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads; - constexpr index_t TotalWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (TotalWarps * MThreadsPerWarp); + constexpr index_t MPerThread = kMPerBlock / (kNumWarps * MThreadsPerWarp); static_assert(NThreads * NPerThread == kNPerBlock); - static_assert(MPerThread * TotalWarps * MThreadsPerWarp == kMPerBlock); + static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock); return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<1, 2>>, 
tuple, sequence<2, 0>>, @@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy sequence<0, 1>>{}); } - // 3d + padding, [kMaxSplits, kM0] + // 3d + padding, shape=[kMaxSplits, kM0] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsStoreBlockDescriptor() { using LSEDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kNPerBlock = Problem::kM0; - constexpr index_t NPack = 16 / sizeof(LSEDataType); + constexpr index_t NPack = + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy return lse_acc_lds_block_desc; } - // 3d + padding, [kM0, kMaxSplits] + // 3d + padding, shape=[kM0, kMaxSplits] template CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccLdsBlockDescriptor() { using LSEDataType = remove_cvref_t; + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::kMaxSplits; constexpr index_t kNPerBlock = Problem::kM0; - constexpr index_t NPack = 16 / sizeof(LSEDataType); + constexpr index_t NPack = + GetVectorSizeForTile(); constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), From 5776e16c7647fb7dfc05252e9840297bce67b930 Mon Sep 17 00:00:00 2001 From: "Po Yen, Chen" Date: Thu, 17 Oct 2024 10:17:57 +0000 Subject: [PATCH 10/10] Add support for kMaxSplits=8 --- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 82cf3a5ab2..2eebd02a30 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -191,7 +191,9 @@ 
template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if (a.num_splits <= 16) {{ + if (a.num_splits <= 8) {{ + kernel_runner<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ kernel_runner<4>::run(s, a); }} else if (a.num_splits <= 32) {{ kernel_runner<5>::run(s, a);