Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused GEMM+GEMM #351

Merged
merged 26 commits into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
89a5e84
initial stub for gemm_gemm_xdl_cshuffle
rosenrodt Jul 11, 2022
68b7153
set up example code
rosenrodt Jul 12, 2022
047cee2
compiles
rosenrodt Jul 20, 2022
237371a
prevent integer overflow
rosenrodt Jul 27, 2022
b57c387
harmonize interface between ref_gemm and ref_batched_gemm
rosenrodt Jul 27, 2022
408ba59
batched_gemm_gemm
rosenrodt Jul 27, 2022
b790e44
fix example
rosenrodt Aug 1, 2022
caf2b2e
host tensor gen: diagonal pattern in lowest two-dimensions only
rosenrodt Aug 1, 2022
4ee3402
make c descriptors containing only integral constants
rosenrodt Aug 3, 2022
eceea10
clean up
rosenrodt Aug 3, 2022
98e4c0c
add BlockwiseGemmXdlops_v2 while exploring an unified approach
rosenrodt Aug 3, 2022
5f94555
implement proper interface
rosenrodt Aug 4, 2022
ed42497
tidy up example
rosenrodt Aug 4, 2022
e55b67a
fix compilation warnings
rosenrodt Aug 10, 2022
c9bef1c
coarsely controlled 2nd gemm padding
rosenrodt Aug 10, 2022
00331ee
remove rocm-cmake's hard requirement for certain revision
rosenrodt Aug 4, 2022
edc494d
clang-format
rosenrodt Aug 4, 2022
3c5a50f
resolve merge conflict
rosenrodt Aug 8, 2022
8672733
fix compilation error on gfx10
rosenrodt Aug 10, 2022
51fc99a
adds acc0 elementwise op to interface
rosenrodt Aug 11, 2022
8aa44bc
add gemm_gemm instances and tests
rosenrodt Aug 11, 2022
b64a286
avoid LDS data hazard
rosenrodt Aug 11, 2022
000eefb
Merge remote-tracking branch 'origin/develop' into fused-gemm
Aug 13, 2022
8bea6b2
fix build
Aug 13, 2022
2564c49
Merge remote-tracking branch 'origin/develop' into fused-gemm
Aug 13, 2022
7ea9c9c
Merge branch 'fix_0813' into fused-gemm
Aug 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
enable_testing()

set(ROCM_SYMLINK_LIBS OFF)
find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm)
find_package(ROCM REQUIRED PATHS /opt/rocm)

include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
Expand Down
10 changes: 8 additions & 2 deletions example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceBatchedGemmInstance =
ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
CDataType,
ReduceAccDataType,
AElementOp,
BElementOp,
CElementOp>;

int main(int argc, char* argv[])
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu
< ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
ReferenceBatchedGemm<ADataType, BDataType, EDataType, AElementOp, BElementOp, CDEElementOp>;
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;

int main(int argc, char* argv[])
{
Expand Down
1 change: 1 addition & 0 deletions example/31_batched_gemm_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
Loading