diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index bc3799f015..4b939337dc 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,24 @@ +function (add_gemm_example TARGET_NAME MAIN_SRC) +message("adding ${TARGET_NAME}") +# not using add_example_executable() to add target, since we don't want this to have +# to be included in "make all/install/check" +add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + +foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) +endforeach() + +target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + +set(COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template) + +target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_gemm_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS}) + add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) -add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) diff --git a/example/ck_tile/03_gemm/gemm.hpp b/example/ck_tile/03_gemm/gemm.hpp new file mode 100644 index 0000000000..10a1934fc4 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm.hpp @@ -0,0 +1,124 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +using Types = GemmBasicTypeConfig; + +// Specific type aliases for easy access +using ADataType = Types::ADataType; +using BDataType = Types::BDataType; +using AccDataType = Types::AccDataType; +using CDataType = Types::CDataType; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +/** \brief Struct used for specifying desired gemm details*/ +struct gemm_traits +{ + std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/ + bool is_a_rowmajor; /** Whether A matrix is rowmajor */ + bool is_b_rowmajor; /** Whether B matrix is rowmajor */ + bool is_c_rowmajor; /** Whether C matrix is rowmajor */ +}; + +template +struct gemm_traits_ +{ + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using CDataType = ck_tile::remove_cvref_t; + using ALayout = ck_tile::remove_cvref_t; + using BLayout = ck_tile::remove_cvref_t; + using CLayout = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t M_Tile = M_Tile_; + static constexpr ck_tile::index_t N_Tile = N_Tile_; + static constexpr ck_tile::index_t K_Tile = K_Tile_; + static constexpr ck_tile::index_t M_Warp = M_Warp_; + static constexpr ck_tile::index_t N_Warp = N_Warp_; + static constexpr ck_tile::index_t K_Warp = K_Warp_; + static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_; + static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_; + static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; +}; + +// host API + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); + +/** + * \brief Invoke gemm function + * + * \param traits Gemm traits which are used for choosing best instance. + * \param args Runtime gemm host arguments. + * \param s Stream configuration. + * \return Time of execution. + */ +float gemm(const gemm_traits& traits, + const ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 16f1466dd3..fe2cef6730 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -9,11 +9,10 @@ #include #include -#include "ck_tile/host.hpp" -#include "gemm_basic.hpp" +#include "gemm.hpp" template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. constexpr bool kPadM = false; @@ -101,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; } +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + return gemm_(args, s); + } + else + { + throw std::runtime_error("Wrong! Layouts not supported!\n"); + } +} + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp deleted file mode 100644 index 4500e3b4fd..0000000000 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ /dev/null @@ -1,99 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -#define CK_TILE_PIPELINE_COMPUTE 1 -#define CK_TILE_PIPELINE_MEMORY 2 - -#ifndef CK_TILE_PIPELINE_DEFAULT -#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE -#endif - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) -#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 -#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 -#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave -#else -#error "unsupported CK_TILE_PIPELINE_DEFAULT value" -#endif - -template -struct GemmBasicTypeConfig; - -template <> -struct GemmBasicTypeConfig -{ - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; - // ToDo: Add more bias config to support different categories of GEMM. -}; - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -using Types = GemmBasicTypeConfig; - -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("b_layout", "R", "B tensor data layout - Row by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -// host API -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/instances/gemm_api.cpp b/example/ck_tile/03_gemm/instances/gemm_api.cpp new file mode 100644 index 0000000000..3075d1631b --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_api.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using FP32 = float; +using FP16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; + +float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else if(t.data_type.compare("bf16") == 0) + { + if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) + { + if(a.M > 512) + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + else + { + // clang-format off + // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK + return gemm_>(a, s); + // clang-format on + } + } + else + { + throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); + } + } + else + { + throw std::runtime_error("Wrong! DataTypes not supported!\n"); + } +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..fc350a5fd6 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..eeb1c35132 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..6c2fe38914 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..3aa33ca83f --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..ed695b3be9 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..cb975f33a0 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..bfc9fc6a97 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..1be99be0b6 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_comp_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp new file mode 100644 index 0000000000..387864a8d9 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_comp_instance_common.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTile2DPartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp new file mode 100644 index 0000000000..299924eb32 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_kn_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp new file mode 100644 index 0000000000..d28ce6e637 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_km_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp new file mode 100644 index 0000000000..aa8a772eec --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_kn_mn.cpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp new file mode 100644 index 0000000000..30871f99fa --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_bf16_bf16_bf16_mk_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp new file mode 100644 index 0000000000..611de23784 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_kn_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp new file mode 100644 index 0000000000..15b01460e3 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_km_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp new file mode 100644 index 0000000000..b9b6c9b263 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_kn_mn.cpp @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp new file mode 100644 index 0000000000..ee0703edb5 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_f16_f16_f16_mk_nk_mn.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_universal_mem_instance_common.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +template float gemm_>(const A&, const S&); +// clang-format on diff --git a/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp new file mode 100644 index 0000000000..e23b752a39 --- /dev/null +++ b/example/ck_tile/03_gemm/instances/gemm_universal_mem_instance_common.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include "gemm.hpp" + +using A = ck_tile::GemmHostArgs; +using S = ck_tile::stream_config; + +template +float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile::GemmTile2DPartitioner; + + using GemmEpilogue = + ck_tile::Default2DEpilogue>; + using GemmTraits = ck_tile::TileGemmTraits; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; + + constexpr int kBlockPerCu = 1; + + const ck_tile::index_t k_grain = args.k_batch * Traits_::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * Traits_::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< + ck_tile::UniversalGemmPipelineProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + if(has_hot_loop) + { + // Tail pipeline One to Seven + if(tail_num == ck_tile::TailNumber::One) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(BaseGemmPipeline::PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 3) + { + static_assert(BaseGemmPipeline::PrefetchStages > 3); + if(tail_num == ck_tile::TailNumber::Three) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(BaseGemmPipeline::PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + else + { + // Tail number always Full - #PrefetchStages + if(tail_num == ck_tile::TailNumber::Full) + { + Run(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "When there's no hot loop, this tail number \"" << tail_num + << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + return ave_time; +} diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index e29ba272f5..fad534fcf9 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -2,6 +2,29 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "R", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + template static constexpr inline auto is_row_major(Layout layout_) { @@ -55,8 +78,14 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + // TODO: Change datatypes in future to allow mixed precision gemms! + gemm_traits traits{DataTypeTraits{}.name, + std::is_same_v, + std::is_same_v, + std::is_same_v}; + + float ave_time = + gemm(traits, args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -224,9 +253,6 @@ int run_gemm_example(int argc, char* argv[]) if(!result) return -1; - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); @@ -238,16 +264,14 @@ int run_gemm_example(int argc, char* argv[]) { return run_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } - // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not - // work. - // else if(a_layout == "C" && b_layout == "C") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); - // } - // else if(a_layout == "C" && b_layout == "R") - // { - // return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); - // } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{}); + } else { throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index bff243d559..cd6c1dbfb0 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -10,194 +10,7 @@ #include #include "ck_tile/host.hpp" -#include "gemm_basic.hpp" - -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) -{ -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Memory friendly for Interwave scheduler - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 32; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 4; - constexpr ck_tile::index_t N_Warp = 1; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; - -#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) - // Compute friendly for Intrawave scheduler - constexpr ck_tile::index_t M_Tile = 256; - constexpr ck_tile::index_t N_Tile = 256; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = 32; - constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 16; -#endif - - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - // =============================================== - - using GemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTile2DPartitioner; - - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - - using Traits = ck_tile::TileGemmTraits; - - using GemmPipelineProblem = - ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; - - const ck_tile::index_t k_grain = args.k_batch * K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = GEMM_PIPELINE; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - if(has_hot_loop) - { - // Tail pipeline One to Seven - if(tail_num == ck_tile::TailNumber::One) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - } - else - { - // Tail number always Full - #PrefetchStages - if(tail_num == ck_tile::TailNumber::Full) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "When there's no hot loop, this tail number \"" << tail_num - << "\" is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } - - return ave_time; -} +#include "gemm.hpp" #include "run_gemm_example.inc"