Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CK Tile Gemm API and heuristics changes #1809

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion example/ck_tile/03_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,24 @@
function (add_gemm_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()

target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
jakpiase marked this conversation as resolved.
Show resolved Hide resolved

set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template)

target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_gemm_example TARGET_NAME MAIN_SRC)

file(GLOB INSTANCE_SRCS instances/*.cpp)
add_gemm_example(tile_example_gemm_universal universal_gemm.cpp ${INSTANCE_SRCS})

add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
124 changes: 124 additions & 0 deletions example/ck_tile/03_gemm/gemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"

template <typename DataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};

using Types = GemmBasicTypeConfig<ck_tile::half_t>;

// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

/** \brief Struct used for specifying desired gemm details*/
struct gemm_traits
{
std::string data_type; /** Tensors datatype, can be set to either fp16 or bf16*/
jakpiase marked this conversation as resolved.
Show resolved Hide resolved
bool is_a_rowmajor; /** Whether A matrix is rowmajor */
bool is_b_rowmajor; /** Whether B matrix is rowmajor */
bool is_c_rowmajor; /** Whether C matrix is rowmajor */
};

template <typename ADataType_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t M_Tile_,
ck_tile::index_t N_Tile_,
ck_tile::index_t K_Tile_,
ck_tile::index_t M_Warp_,
ck_tile::index_t N_Warp_,
ck_tile::index_t K_Warp_,
ck_tile::index_t M_Warp_Tile_,
ck_tile::index_t N_Warp_Tile_,
ck_tile::index_t K_Warp_Tile_,
bool kPadM_,
bool kPadN_,
bool kPadK_>
struct gemm_traits_
{
using ADataType = ck_tile::remove_cvref_t<ADataType_>;
using BDataType = ck_tile::remove_cvref_t<BDataType_>;
using AccDataType = ck_tile::remove_cvref_t<AccDataType_>;
using CDataType = ck_tile::remove_cvref_t<CDataType_>;
using ALayout = ck_tile::remove_cvref_t<ALayout_>;
using BLayout = ck_tile::remove_cvref_t<BLayout_>;
using CLayout = ck_tile::remove_cvref_t<CLayout_>;
static constexpr ck_tile::index_t M_Tile = M_Tile_;
static constexpr ck_tile::index_t N_Tile = N_Tile_;
static constexpr ck_tile::index_t K_Tile = K_Tile_;
static constexpr ck_tile::index_t M_Warp = M_Warp_;
static constexpr ck_tile::index_t N_Warp = N_Warp_;
static constexpr ck_tile::index_t K_Warp = K_Warp_;
static constexpr ck_tile::index_t M_Warp_Tile = M_Warp_Tile_;
static constexpr ck_tile::index_t N_Warp_Tile = N_Warp_Tile_;
static constexpr ck_tile::index_t K_Warp_Tile = K_Warp_Tile_;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
};

// host API

template <typename Traits_>
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

/**
* \brief Invoke gemm function
*
* \param traits Gemm traits which are used for choosing best instance.
* \param args Runtime gemm host arguments.
* \param s Stream configuration.
* \return Time of execution.
*/
float gemm(const gemm_traits& traits,
const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& s);
29 changes: 26 additions & 3 deletions example/ck_tile/03_gemm/gemm_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
#include <string>
#include <tuple>

#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm.hpp"

template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
Expand Down Expand Up @@ -101,6 +100,30 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
return ave_time;
}

float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Row, Row>(args, s);
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Row, Col, Row>(args, s);
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Row, Row>(args, s);
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
return gemm_<Col, Col, Row>(args, s);
}
else
{
throw std::runtime_error("Wrong! Layouts not supported!\n");
}
}

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
99 changes: 0 additions & 99 deletions example/ck_tile/03_gemm/gemm_basic.hpp

This file was deleted.

Loading