[Dev] Add test case for bfloat16 and int4 gemm with mma #65

Merged on Feb 6, 2025 (22 commits)

Commits
03cf5b5
[Enhancement] Add VectorizeLoop function and update imports for compa…
LeiWang1999 Feb 3, 2025
73cb739
[CI][Test] Improve test cases for vectorization and fix typos in pars…
LeiWang1999 Feb 3, 2025
6b80e0e
lint fix
LeiWang1999 Feb 3, 2025
91d91a7
Fix incorrect module reference for VectorizeLoop transformation
LeiWang1999 Feb 3, 2025
e3b1856
Refactor vectorize_loop transformation by removing unused extent muta…
LeiWang1999 Feb 3, 2025
b6a1d81
[Enhancement] Add support for FP8 data types and global barriers in C…
LeiWang1999 Feb 4, 2025
6aef1f8
Fix formatting in CUDA FP8 header file for consistency
LeiWang1999 Feb 4, 2025
d0dbc46
Refactor CI workflow to use 'tilelang_ci' virtual environment and upd…
LeiWang1999 Feb 4, 2025
bbc3cd7
Update submodule 'tvm' to latest commit for improved functionality
LeiWang1999 Feb 4, 2025
22f41e0
Refactor execution backend references from 'dl_pack' to 'dlpack' for …
LeiWang1999 Feb 5, 2025
fffda93
Refactor CUDA code for improved readability; clean up formatting and …
LeiWang1999 Feb 5, 2025
22cc8aa
Refactor import statement in test_tilelang_kernel_dequantize_gemm.py …
LeiWang1999 Feb 5, 2025
b004e3c
Add CUDA requirements to FP8 test cases and update references for cla…
LeiWang1999 Feb 5, 2025
4b5bcb2
Add a blank line for improved readability in test_tilelang_kernel_fp8…
LeiWang1999 Feb 5, 2025
f8d9005
Fix data type in reference result calculation for consistency in test…
LeiWang1999 Feb 5, 2025
5b1c005
Add CUDA requirements and FP8 test cases for matmul and gemv simulations
LeiWang1999 Feb 6, 2025
226ac59
Remove debug print statements and use tilelang's testing assertion fo…
LeiWang1999 Feb 6, 2025
e03159f
Remove outdated comment regarding FP8 tests in test_tilelang_kernel_g…
LeiWang1999 Feb 6, 2025
fcb642e
Add BF16 support to matrix multiplication and introduce corresponding…
LeiWang1999 Feb 6, 2025
deeb142
Merge branch 'main' of https://github.com/tile-ai/tilelang into bitblas
LeiWang1999 Feb 6, 2025
d5b057b
Add a blank line for improved readability in BF16 GEMM test
LeiWang1999 Feb 6, 2025
4f99b7c
Update acknowledgements in README to include supervision by Zhi Yang …
LeiWang1999 Feb 6, 2025
README.md (2 changes: 1 addition & 1 deletion)
@@ -200,4 +200,4 @@ Welcome to join our Discord community for discussions, support, and collaboratio

## Acknowledgements

-We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410). Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
+We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) under the supervision of [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py (236 changes: 236 additions & 0 deletions)
@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
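    # A swizzled layout is only applied when one shared-memory row is exactly
    # 512 bits wide (e.g. 32 x 16-bit float16/bfloat16 elements), the tile
    # width the swizzle pattern targets; otherwise the identity layout is kept.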
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "bfloat16",
        "e4m3_float8",
        "e5m2_float8",
        "int8",
    ], "Currently only float16, bfloat16, float8 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
    if out_dtype == "int32" or is_float8:
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    mod, params = TL.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None

    def map_torch_type(intype):
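        # The FP8 names used by tilelang differ from torch's attribute names,
        # so they are mapped explicitly; every other dtype string resolves
        # directly via getattr(torch, ...).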
        typemap = {
            'e4m3_float8': torch.float8_e4m3fn,
            'e5m2_float8': torch.float8_e5m2,
        }
        if intype in typemap:
            return typemap[intype]
        else:
            return getattr(torch, intype)

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    if in_dtype in {torch.int8, torch.int32}:
        A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
        B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
    elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
        A = torch.randn(M, K).to(in_dtype).cuda()
        B = torch.randn(N, K).to(in_dtype).cuda()
    else:
        A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
        B = torch.randn(N, K).to(in_dtype).cuda() - 0.5

    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)

    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)

    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype)
    tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul_bfloat16():
    assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")


if __name__ == "__main__":
    tilelang.testing.main()
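For orientation, a minimal sketch of how the new helper can be driven directly; it is not part of the diff, and it assumes the file above is importable (e.g. run from testing/python/kernel) on a CUDA device with compute capability 8.0 or higher:

# Minimal sketch (assumed environment: sm_80+ GPU, tilelang installed, run
# from testing/python/kernel so the module above is importable). It drives the
# helper with the same configuration as test_assert_tl_matmul_bfloat16.
from test_tilelang_kernel_bf16_gemm_mma import assert_tl_matmul_correctness

assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")

Running the file directly instead goes through tilelang.testing.main(), which is expected to pick up test_assert_tl_matmul_bfloat16 together with its requires_cuda and compute-capability guards.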
@@ -42,6 +42,7 @@ def tl_matmul(
 ):
     assert in_dtype in [
         "float16",
+        "bfloat16",
         "e4m3_float8",
         "e5m2_float8",
         "int8",
@@ -230,6 +231,7 @@ def map_torch_type(intype):
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
+    assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")


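Both updated tests share the same verification pattern: inputs are generated in the low-precision type, the reference GEMM is accumulated in float32, and results are compared with loose tolerances. A standalone sketch of that check (not part of the diff; shapes, dtypes, and tolerances mirror the tests above, and a CUDA device is assumed):

import torch

# bfloat16 inputs; B is stored as (N, K) and used transposed, as in the kernels.
M, N, K = 128, 128, 128
A = torch.randn(M, K, device="cuda").to(torch.bfloat16)
B = torch.randn(N, K, device="cuda").to(torch.bfloat16)

# Reference accumulated in float32, mirroring assert_tl_matmul_correctness.
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32))

# A kernel result C (float32 accumulator) would be compared like this;
# here C is only a placeholder standing in for the kernel output.
C = ref_c.clone()
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)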