
[Dev][AMD] Implement LDS Async Copy for CDNA Arch #246

Merged
37 commits, merged Nov 15, 2024

Changes from all commits (37)
c4853ec
Refactor Simplify function to handle multiple functions in IRModule
LeiWang1999 Oct 16, 2024
9a21acf
Update submodule commit reference
LeiWang1999 Oct 17, 2024
f8d046b
Add CUDA_DEVICE_ORDER environment variable to bashrc
LeiWang1999 Oct 17, 2024
c1371dd
test fix
LeiWang1999 Oct 17, 2024
416cad2
lint fix
LeiWang1999 Oct 17, 2024
9209d1e
Refactor test_general_matmul_bf16.py to use bitblas.testing.main()
LeiWang1999 Oct 17, 2024
1cf7570
Update submodule commit reference
LeiWang1999 Oct 17, 2024
5fec040
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
4e1a0d2
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
fa85f8c
Update submodule commit reference
LeiWang1999 Oct 19, 2024
429d5b5
Update submodule commit reference
LeiWang1999 Oct 19, 2024
4003509
Update submodule commit reference
LeiWang1999 Oct 20, 2024
1d86582
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 20, 2024
df3af0d
Update submodule commit reference
LeiWang1999 Oct 28, 2024
1f1e027
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 28, 2024
732dda6
Update submodule commit reference
LeiWang1999 Oct 29, 2024
ebffbfa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 29, 2024
ff227fa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 4, 2024
ac62936
[Dev] Update subproject commit for TVM
LeiWang1999 Nov 7, 2024
a7a239c
ignore profiler directories.
LeiWang1999 Nov 7, 2024
dcedbde
MFMA Support
LeiWang1999 Nov 7, 2024
e0b36f5
lint fix
LeiWang1999 Nov 7, 2024
fe668f9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 7, 2024
3579c6b
MFMA Fixed.
LeiWang1999 Nov 8, 2024
e60ccd9
merge upstream
LeiWang1999 Nov 8, 2024
d4df21c
update
LeiWang1999 Nov 8, 2024
e4ff7f3
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 8, 2024
57e3cf9
Fix MFMA Layout Related issue
LeiWang1999 Nov 8, 2024
c3398f5
lint fix
LeiWang1999 Nov 8, 2024
ddd0219
amd hip update
LeiWang1999 Nov 8, 2024
754294f
Block GEMM Example
LeiWang1999 Nov 13, 2024
e041d91
fix amd
YangWang92 Nov 15, 2024
2910b3c
mi300 update
YangWang92 Nov 15, 2024
cf934c1
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
YangWang92 Nov 15, 2024
baf4abd
fix and enhance
YangWang92 Nov 15, 2024
c3605e8
lintfix
LeiWang1999 Nov 15, 2024
8ee2c63
enhance amd installation
LeiWang1999 Nov 15, 2024
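
For orientation, the asynchronous global-to-LDS copy path on CDNA is driven by the TL lowering rather than called directly from Python. The sketch below is illustrative only: it mirrors the block GEMM test added in this PR, and the T.Pipelined loop with num_stages is an assumption about how multi-stage copies are expressed in tvm.tl, not something this diff adds.

# Illustrative sketch (not part of this PR): a pipelined variant of the block
# GEMM test below. With num_stages > 1 the HIP backend would be the one to
# lower the global->shared copies to the new async LDS copies on CDNA.
from bitblas import tvm as tvm
import tvm.tl.language as T


def pipelined_matmul(M, N, K, block_M, block_N, block_K, threads):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), "float16"),
            B: T.Buffer((N, K), "float16"),
            C: T.Buffer((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float16", scope="shared")
            B_shared = T.alloc_shared((block_N, block_K), "float16", scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            # The assumed T.Pipelined construct overlaps copy and compute stages.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, False, True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# Lowering would then follow the checked-in test:
# mod, params = tl.lower(pipelined_matmul(256, 256, 256, 128, 128, 32, 256), target="hip")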
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 7b325a to 4a2e00
73 changes: 62 additions & 11 deletions install_amd.sh
@@ -8,20 +8,44 @@ pip install -r requirements.txt

# determine if root
USER_IS_ROOT=false
if [ "$EUID" -e 0 ]; then
if [ "$EUID" -eq 0 ]; then
USER_IS_ROOT=true
fi

if $USER_IS_ROOT; then
# Fetch the GPG key for the LLVM repository and add it to the trusted keys
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
apt-get install llvm-16
else

# Check if the repository is already present in the sources.list
if ! grep -q "http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" /etc/apt/sources.list; then
# Add the LLVM repository to sources.list
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" >> /etc/apt/sources.list
else
# Print a message if the repository is already added
echo "The repository is already added."
fi

# Update package lists and install llvm-16
apt-get update
apt-get install -y llvm-16
else
# Fetch the GPG key for the LLVM repository and add it to the trusted keys using sudo
wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee /etc/apt/sources.list
sudo apt-get install llvm-16

# Check if the repository is already present in the sources.list
if ! grep -q "http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" /etc/apt/sources.list; then
# Add the LLVM repository to sources.list using sudo
echo "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee -a /etc/apt/sources.list
echo "deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-16 main" | sudo tee -a /etc/apt/sources.list
else
# Print a message if the repository is already added
echo "The repository is already added."
fi

# Update package lists and install llvm-16 using sudo
sudo apt-get update
sudo apt-get install -y llvm-16
fi

# clone and build tvm
@@ -38,7 +62,34 @@ echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/r

cmake .. && make -j && cd ../../..

echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc
echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc
echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
# Define the lines to be added
TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm"
BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH"
CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID"

# Check and add the first line if not already present
if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then
echo "$TVM_HOME_ENV" >> ~/.bashrc
echo "Added TVM_HOME to ~/.bashrc"
else
echo "TVM_HOME is already set in ~/.bashrc"
fi

# Check and add the second line if not already present
if ! grep -qxF "$BITBLAS_PYPATH_ENV" ~/.bashrc; then
echo "$BITBLAS_PYPATH_ENV" >> ~/.bashrc
echo "Added PYTHONPATH to ~/.bashrc"
else
echo "PYTHONPATH is already set in ~/.bashrc"
fi

# Check and add the third line if not already present
if ! grep -qxF "$CUDA_DEVICE_ORDER_ENV" ~/.bashrc; then
echo "$CUDA_DEVICE_ORDER_ENV" >> ~/.bashrc
echo "Added CUDA_DEVICE_ORDER to ~/.bashrc"
else
echo "CUDA_DEVICE_ORDER is already set in ~/.bashrc"
fi

# Reload ~/.bashrc to apply the changes
source ~/.bashrc
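
After the script finishes (and a new shell picks up the ~/.bashrc changes), the freshly built TVM submodule can be sanity-checked from Python. The snippet below is a minimal sketch, not part of the PR; it only assumes the standard tvm.runtime.enabled query.

# Minimal post-install sanity check (not part of this PR): confirm the TVM
# submodule on PYTHONPATH is the one built above and that LLVM/ROCm were
# compiled in (config.cmake sets USE_LLVM and USE_ROCM).
import tvm

print(tvm.__file__)                 # should point into <repo>/3rdparty/tvm/python
print(tvm.runtime.enabled("llvm"))  # expected: True
print(tvm.runtime.enabled("rocm"))  # expected: True on a ROCm machine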
178 changes: 178 additions & 0 deletions integration/ComposableKernel/test_block_gemm.py
@@ -0,0 +1,178 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
from tvm import tl


@tvm.register_func("tvm_callback_hip_postproc", override=True)
def tvm_callback_hip_postproc(code, _):
print(code)
# code = '''
# #include <hip/hip_runtime.h>
# #include <tl_templates/hip/gemm.h>
# #include <tl_templates/hip/copy.h>
# #include <tl_templates/hip/reduce.h>
# #include <tl_templates/hip/ldsm.h>
# #include <tl_templates/hip/threadblock_swizzle.h>

# extern "C" __global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, half_t* __restrict__ B, half_t* __restrict__ C) {
# float C_local[128];
# __shared__ half_t A_shared[4096];
# __shared__ half_t B_shared[4096];
# #pragma unroll
# for (int i = 0; i < 64; ++i) {
# *(float2*)(C_local + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
# }
# #pragma unroll
# for (int i_1 = 0; i_1 < 4; ++i_1) {
# *(uint4*)(A_shared + ((((i_1 * 1024) + ((((int)threadIdx.x) >> 2) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8))) = *(uint4*)(A + ((i_1 * 1024) + (((int)threadIdx.x) * 8)));
# }
# #pragma unroll
# for (int i_2 = 0; i_2 < 4; ++i_2) {
# *(uint4*)(B_shared + ((((i_2 * 1024) + ((((int)threadIdx.x) >> 2) * 32)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 16)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 8))) = *(uint4*)(B + ((i_2 * 1024) + (((int)threadIdx.x) * 8)));
# }
# __syncthreads();
# tl::gemm_ss<128, 128, 32, 2, 2, 0, 1>((&(A_shared[0])), (&(B_shared[0])), (&(C_local[0])));
# if(threadIdx.x == 0){
# for (size_t i = 0; i < 128; i++) {
# printf("%f ", C_local[i]);
# }
# }
# #pragma unroll
# for (int i_3 = 0; i_3 < 64; ++i_3) {
# uint1 __1;
# float2 v_ = *(float2*)(C_local + (i_3 * 2));
# ((half2*)(&(__1.x)))->x = (half_t)(v_.x);
# ((half2*)(&(__1.x)))->y = (half_t)(v_.y);
# *(uint1*)(C + (((((((((i_3 & 7) >> 1) * 4096) + (((((int)threadIdx.x) & 63) >> 5) * 2048)) + ((i_3 & 1) * 1024)) + (((((int)threadIdx.x) & 31) >> 2) * 128)) + ((i_3 >> 3) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = __1;
# }
# }
# '''
return code


def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
accum_dtype,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

import tvm.tl.language as T

@T.prim_func
def main(
A: T.Buffer(A_shape, dtypeAB),
B: T.Buffer(B_shape, dtypeAB),
C: T.Buffer((M, N), dtypeC),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope="shared")
B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope="shared")
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.serial(T.ceildiv(K, block_K)):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, False, True)
T.copy(C_local, C[by * block_M, bx * block_N])

return main


def run_gemm(
M,
N,
K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
block_M,
block_N,
block_K,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
dtypeAB,
dtypeC,
dtypeAccum,
num_threads,
)
mod, params = tl.lower(program, target="hip")
# print(mod.imported_modules[0].get_source())
mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer)
import torch
torch.random.manual_seed(0)
a = torch.randn((M, K), dtype=torch.__getattribute__(dtypeAB)).to("cuda")
# b = torch.randn((N, K), dtype=torch.__getattribute__(dtypeAB)).to("cuda")
# a = torch.ones((M, K), dtype=torch.__getattribute__(dtypeAB)).to("cuda")
b = torch.ones((N, K), dtype=torch.__getattribute__(dtypeAB)).to("cuda")
c = torch.zeros((M, N), dtype=torch.__getattribute__(dtypeC)).to("cuda")
print(f"{a=}")
print(f"{b=}")
mod(a, b, c)

print(c)

ref_c = torch.matmul(a, b.T).to(torch.__getattribute__(dtypeC))
print(ref_c)

latency = mod.do_bench(mod.func, profiler="tvm")
print(f"Latency: {latency}")

torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
# run_gemm(
# 64,
# 16,
# 16,
# False,
# True,
# "float16",
# "float32",
# "float32",
# 64,
# 16,
# 16,
# 128,
# )

run_gemm(
256,
256,
256,
False,
True,
"float16",
"float32",
"float32",
128,
128,
32,
256,
)
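
run_gemm already wraps lowering, execution, and the torch.matmul comparison, so other problem sizes can be swept by importing it. The snippet below is a usage sketch only, assuming it is launched from integration/ComposableKernel with the ROCm build of the TVM submodule on PYTHONPATH.

# Usage sketch (not part of this PR): reuse run_gemm from the test above to
# sweep extra problem sizes with the same 128x128x32 tile and 256 threads.
from test_block_gemm import run_gemm

for M, N, K in [(256, 256, 256), (512, 512, 512)]:
    # (M, N, K, trans_A, trans_B, dtypeAB, dtypeC, dtypeAccum, block_M, block_N, block_K, threads)
    run_gemm(M, N, K, False, True, "float16", "float32", "float32", 128, 128, 32, 256)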
77 changes: 77 additions & 0 deletions integration/ComposableKernel/test_layout.py
@@ -0,0 +1,77 @@
# Brute-force the four boolean flags passed to the two Repeat() calls in
# makeGemmFragmentCCDNA: patch gemm_layouts.cc, rebuild TVM, run
# test_block_gemm.py, and capture each combination's output in its own log.
import subprocess

layouts = [
[False, False, False, False],
[False, False, False, True],
[False, False, True, False],
[False, False, True, True],
[False, True, False, False],
[False, True, False, True],
[False, True, True, False],
[False, True, True, True],
[True, False, False, False],
[True, False, False, True],
[True, False, True, False],
[True, False, True, True],
[True, True, False, False],
[True, True, False, True],
[True, True, True, False],
[True, True, True, True],
]

raw_func = '''Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) LOG(FATAL) << "Not supported";
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragmentCDNA16x16()->Repeat({1, 1}, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, block_n / warp_n}, false, false);
auto block_layout = warp_layout->Repeat({warp_m / 16, warp_n / 16}, true, true);
return block_layout;
}'''
file_path = "/home/aiscuser/leiwang/BitBLAS/3rdparty/tvm/src/tl/layout/gemm_layouts.cc"

for layout in layouts:
block_layout_0 = "false" if not layout[0] else "true"
block_layout_1 = "false" if not layout[1] else "true"
warp_layout_0 = "false" if not layout[2] else "true"
warp_layout_1 = "false" if not layout[3] else "true"

log_path = f"block_{block_layout_0}_{block_layout_1}_warp_{warp_layout_0}_{warp_layout_1}.log"

new_func = raw_func.replace(
"base_layout->Repeat({block_m / warp_m, block_n / warp_n}, false, false);",
f"base_layout->Repeat({{block_m / warp_m, block_n / warp_n}}, {block_layout_0}, {block_layout_1});"
).replace(
"warp_layout->Repeat({warp_m / 16, warp_n / 16}, true, true);",
f"warp_layout->Repeat({{warp_m / 16, warp_n / 16}}, {warp_layout_0}, {warp_layout_1});")
print(new_func)
with open(file_path, "r") as f:
content = f.read()
content = content.replace(raw_func, new_func)
with open(file_path, "w") as f:
f.write(content)

with open(log_path, "w") as log_file:
# build tvm
subprocess.run(["make", "-j8"],
cwd="/home/aiscuser/leiwang/BitBLAS/3rdparty/tvm/build",
stdout=log_file,
stderr=log_file)

# Execute Test log
subprocess.run([
"python",
"/home/aiscuser/leiwang/BitBLAS/integration/ComposableKernel/test_block_gemm.py"
],
cwd="/home/aiscuser/leiwang/BitBLAS/integration/ComposableKernel",
stdout=log_file,
stderr=log_file)

# Recover
content = content.replace(new_func, raw_func)

with open(file_path, "w") as f:
f.write(content)
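
Each sweep iteration writes its build and test output to block_<...>_warp_<...>.log, so the passing flag combinations can be read back afterwards. A small hypothetical helper for that (not included in the PR) might look like:

# Hypothetical helper (not part of this PR): scan the per-layout logs produced
# by test_layout.py and report which flag combinations passed the
# torch.testing.assert_close check in test_block_gemm.py.
import glob

for path in sorted(glob.glob("block_*_warp_*.log")):
    with open(path) as f:
        text = f.read()
    failed = "AssertionError" in text or "Traceback" in text
    print(f"{'FAIL' if failed else 'ok  '}  {path}")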