Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CK-Tile Grouped GEMM refactor and post PR fixes #1756

Merged
merged 29 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
44fae39
Grouped gemm simple code refactor
mozga-amd Dec 17, 2024
dc48a14
Offset invoker
mozga-amd Jan 5, 2025
5df59d5
Merge remote-tracking branch 'origin/develop' into mozga-amd/gemm_ref…
mozga-amd Jan 5, 2025
ae36a63
Invoke generic Run, and replace name of parrtitioner variable
mozga-amd Jan 6, 2025
cbba680
Tests fix type
mozga-amd Jan 6, 2025
b781628
Removed namespaces
mozga-amd Jan 7, 2025
774f903
Add template param to avoid implicit cast
mozga-amd Jan 7, 2025
0c8a579
Remove generic function
mozga-amd Jan 7, 2025
2f80a6a
Constant value
mozga-amd Jan 7, 2025
e8da31e
underline enum to int16_t
mozga-amd Jan 14, 2025
bdc17fb
Generalize partitioner function
mozga-amd Jan 15, 2025
56c1916
Remove whitespaces
mozga-amd Jan 15, 2025
5aa63ce
Rename function
mozga-amd Jan 15, 2025
a0cffd8
Using support
mozga-amd Jan 15, 2025
414328c
Clang-format
mozga-amd Jan 15, 2025
2ac3b7f
Clang-format
mozga-amd Jan 15, 2025
b72d199
Fn-partitioner description fn
mozga-amd Jan 16, 2025
f70d888
Merge remote-tracking branch 'origin/develop' into mozga-amd/gemm_ref…
mozga-amd Jan 16, 2025
faad6fc
Typo
mozga-amd Jan 16, 2025
3fc0b87
Typo 2
mozga-amd Jan 16, 2025
60ee8fa
Better description
mozga-amd Jan 16, 2025
997bce8
Merge remote-tracking branch 'origin/develop' into mozga-amd/gemm_ref…
mozga-amd Jan 20, 2025
ccc19d6
Better description
mozga-amd Jan 20, 2025
189cfa7
Refactor after review
mozga-amd Jan 20, 2025
66d8b6b
Use ctr instead of set fn
mozga-amd Jan 20, 2025
d78b2df
Inovke ctr and typo
mozga-amd Jan 20, 2025
ceaf540
Comments
mozga-amd Jan 21, 2025
92a369e
Remove unnecessary comment
mozga-amd Jan 21, 2025
4735476
Review, remove modulo
mozga-amd Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
#include "utils.hpp"

namespace {

Expand Down Expand Up @@ -102,7 +101,7 @@ using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
GemmEpilogue<CLayout>>;
}; // namespace

std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
{
return ::Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
}
Expand Down
8 changes: 4 additions & 4 deletions example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}

std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs);
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs);

float grouped_gemm_calc(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* p_workspace_);
20 changes: 10 additions & 10 deletions example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ float invoke_gemm(int n_warmup,
{

ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
gemm_workspace.Realloc(get_workspace_size(args));

float ave_time = grouped_gemm<ALayout, BLayout, CLayout>(
args,
Expand Down Expand Up @@ -100,16 +100,16 @@ int run_grouped_gemm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];

stride_As[i] = f_get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout);
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout);
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{});

a_m_k_tensors.push_back(
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(
ck_tile::HostTensor<BDataType>(f_host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout)));
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout)));
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{})));

std::cout << "gemm[" << i << "]"
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
Expand Down Expand Up @@ -150,7 +150,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; ++i)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
Expand Down
38 changes: 0 additions & 38 deletions example/ck_tile/17_grouped_gemm/utils.hpp
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved

This file was deleted.

1 change: 0 additions & 1 deletion include/ck_tile/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
Expand Down
26 changes: 25 additions & 1 deletion include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -109,4 +109,28 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif
}

#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved

// Converts a pointer qualified with CK_CONSTANT_ADDRESS_SPACE into a plain T*.
// Device-only helper (declared __device__).
//
// p: pointer carrying the CK_CONSTANT_ADDRESS_SPACE qualifier.
// Returns the same pointer value typed as an unqualified T*.
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// Cast a pointer in the "Constant" address space (4) to the "Generic" address space (0).
// Only a C-style pointer cast seems to be able to compile here, so the
// -Wold-style-cast diagnostic is suppressed locally around the cast.
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

// Converts a plain T* into a pointer qualified with CK_CONSTANT_ADDRESS_SPACE.
// Callable from both host and device code (declared __host__ __device__).
//
// p: unqualified pointer.
// Returns the same pointer value typed as T CK_CONSTANT_ADDRESS_SPACE*.
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// Cast a pointer in the "Generic" address space (0) to the "Constant" address space (4).
// Only a C-style pointer cast seems to be able to compile here, so the
// -Wold-style-cast diagnostic is suppressed locally around the cast.
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved

} // namespace ck_tile
37 changes: 0 additions & 37 deletions include/ck_tile/core/utility/amd_address_space.hpp

This file was deleted.

33 changes: 33 additions & 0 deletions include/ck_tile/host/host_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,4 +678,37 @@ struct HostTensor
Descriptor mDesc;
Data mData;
};

// Builds a HostTensorDescriptor with lengths {row, col} for the given layout tag.
//
// row, col: the two lengths passed to the descriptor.
// stride:   the non-unit stride; the other stride is always 1.
// layout:   tag object; only its type is inspected.
//
// For tensor_layout::gemm::RowMajor the strides are {stride, 1}; for any
// other layout type they are {1, stride}.
template <typename TLayout>
auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
    using namespace ck_tile::literals;

    // decltype(layout) keeps the tag parameter referenced although no value is read from it.
    constexpr bool is_row_major =
        std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>;

    if constexpr(!is_row_major)
    {
        return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
    }
    else
    {
        return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
    }
}
// Resolves the stride to use for a row x col tensor.
//
// A non-zero `stride` is returned unchanged. A `stride` of 0 means "pick the
// default": `col` when the layout type is tensor_layout::gemm::RowMajor,
// otherwise `row`.
//
// layout: tag object; only its type is inspected.
template <typename TLayout>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout)
{
    // An explicitly supplied stride always wins.
    if(stride != 0)
    {
        return stride;
    }

    // decltype(layout) keeps the tag parameter referenced although no value is read from it.
    constexpr bool is_row_major =
        std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>;

    return is_row_major ? col : row;
}

} // namespace ck_tile
1 change: 1 addition & 0 deletions include/ck_tile/ops/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_offset_block.hpp"
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
Expand Down
49 changes: 49 additions & 0 deletions include/ck_tile/ops/gemm/kernel/gemm_offset_block.hpp
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {
// Maps a linear (1-D) block id onto a 2-D tile coordinate for a problem of
// size M x N, using the tile sizes of the given tile partitioner.
//
// TilePartitioner_ must expose MPerBlock and NPerBlock; they determine how many
// tiles cover M and N (rounded up with integer_divide_ceil).
template <typename TilePartitioner_>
struct OffsettedBlockToCTileMap
{
    using tile_partitioner_type = TilePartitioner_;

    // B2CkTileMap: group size applied along the M-tile axis when remapping
    //              (presumably the "M01" grouping factor - TODO confirm).
    // M, N:        problem sizes used to derive the tile counts.
    __host__ __device__ OffsettedBlockToCTileMap(ck_tile::index_t B2CkTileMap,
                                                 ck_tile::index_t M,
                                                 ck_tile::index_t N)
        : B2CkTileMap_{B2CkTileMap}, M_{M}, N_{N}
    {
    }

    // Converts the 1-D block id `idx_top` into a tuple of two tile indices.
    //
    // The id is first wrapped into [0, M0 * N0), so ids beyond one M x N tile
    // grid map back into it. The first tuple element is built from the M-tile
    // coordinate and the second from the regrouped N coordinate.
    //
    // NOTE(review): M0 * N0 and B2CkTileMap_ are used as divisors, so the
    // caller presumably guarantees they are non-zero - confirm.
    __host__ __device__ constexpr auto CalculateBottomIndex(const ck_tile::index_t idx_top) const
    {
        const auto M0 = ck_tile::integer_divide_ceil(M_, tile_partitioner_type::MPerBlock);
        const auto N0 = ck_tile::integer_divide_ceil(N_, tile_partitioner_type::NPerBlock);

        // Wrap into a single tile grid. One modulo is sufficient; the operation is
        // idempotent, so repeating it would only add redundant work.
        const ck_tile::index_t block_1d_id = idx_top % (M0 * N0);

        const ck_tile::index_t idx_N0 = block_1d_id % N0;
        const ck_tile::index_t idx_M0 = block_1d_id / N0;

        // The last group along M may be shorter than B2CkTileMap_ when M0 is not a
        // multiple of it; use the remainder as the group height in that case.
        const auto M01_adapt = (idx_M0 < M0 - M0 % B2CkTileMap_) ? B2CkTileMap_ : M0 % B2CkTileMap_;

        const ck_tile::index_t idx_M00          = idx_M0 / B2CkTileMap_;
        const ck_tile::index_t idx_M01          = idx_M0 % B2CkTileMap_;
        const ck_tile::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * B2CkTileMap_,
                          idx_N0_M01_local / M01_adapt);
    }

    ck_tile::index_t B2CkTileMap_;
    ck_tile::index_t M_;
    ck_tile::index_t N_;
};
} // namespace ck_tile
8 changes: 3 additions & 5 deletions include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ struct GemmTile1DPartitioner
return integer_divide_ceil(K, KPerBlock);
}

CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
CK_TILE_DEVICE auto operator()()
{
index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) /
GetNBlock(NBlockSize) * MPerBlock);
index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) %
GetNBlock(NBlockSize) * NPerBlock);
index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * MPerBlock);
index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.x * NPerBlock);
mozga-amd marked this conversation as resolved.
Show resolved Hide resolved
return make_tuple(iM, iN);
}
};
Expand Down
Loading