diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp index 7661114ef8..52e390e769 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -83,6 +83,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK{}; + + static constexpr bool IsHWSupportAtomicAdd = + (CDEShuffleBlockTransferScalarPerVectors{}[I0] > 1) || (sizeof(CDataType) > 2); + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< ALayout, @@ -153,6 +158,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1)) + { + throw std::runtime_error("wrong! AtomicAdd is not valid for SplitK"); + } + index_t gdx, gdy, gdz; std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); @@ -166,7 +176,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK DsSize; Argument arg_ = arg; @@ -238,12 +247,15 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); + if constexpr(IsHWSupportAtomicAdd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } } else { @@ -260,113 +272,118 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + if constexpr(IsHWSupportAtomicAdd) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, - TailNumber::Two>; + TailNumber::One>; Run(kernel); } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) { const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< GridwiseGemm, true, InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, - TailNumber::Three>; + TailNumber::Full>; Run(kernel); } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } } - } - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } } } } @@ -488,25 +505,28 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + if constexpr(IsHWSupportAtomicAdd) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } } } else @@ -537,25 +557,28 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + if constexpr(IsHWSupportAtomicAdd) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } } } else @@ -590,12 +613,15 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK 1) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); + if constexpr(IsHWSupportAtomicAdd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } } else {