[Dev] Add test case for bfloat16 and int4 gemm with mma (#65)
* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py

* Add BF16 support to matrix multiplication and introduce corresponding test cases

* Add a blank line for improved readability in BF16 GEMM test

* Update acknowledgements in README to include supervision by Zhi Yang at Peking University
LeiWang1999 authored Feb 6, 2025
1 parent b6c9c19 commit 560b1d8
Showing 4 changed files with 239 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
@@ -200,4 +200,4 @@ Welcome to join our Discord community for discussions, support, and collaboratio

## Acknowledgements

We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410). Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
We learned a lot from the [TVM](https://github.com/apache/tvm) community and would like to thank them for their contributions. The initial version of this project is mainly contributed by [LeiWang1999](https://github.com/LeiWang1999), [chengyupku](https://github.com/chengyupku) and [nox-410](https://github.com/nox-410) under the supervision of [Zhi Yang](https://yangzhihome.github.io) at Peking University. Part of this work was done during the internship at Microsoft Research, under the supervision of Dr. Lingxiao Ma, Dr. Yuqing Xia, Dr. Jilong Xue, and Dr. Fan Yang.
236 changes: 236 additions & 0 deletions testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "bfloat16",
        "e4m3_float8",
        "e5m2_float8",
        "int8",
    ], "Currently only float16, bfloat16, float8 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
    if out_dtype == "int32" or is_float8:
        micro_size_k = 32

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32
    chunk = 32 if in_dtype == "float16" else 64
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 32
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMA Wrapper to Auto Generate Code for MMA
    mma_emitter = TensorCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mma_emitter.mma(A_local, B_local, C_local)

            # Perform STMatrix
            mma_emitter.stmatrix(
                C_local,
                C_shared,
            )

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    i // micro_size_x,
                    j // micro_size_y,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return main


def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    mod, params = TL.lower(matmul)
    src_code = mod.imported_modules[0].get_source()
    # src_code is the generated cuda source
    assert src_code is not None

    def map_torch_type(intype):
        typemap = {
            'e4m3_float8': torch.float8_e4m3fn,
            'e5m2_float8': torch.float8_e5m2,
        }
        if intype in typemap:
            return typemap[intype]
        else:
            return getattr(torch, intype)

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)
    accum_dtype = map_torch_type(accum_dtype)

    if in_dtype in {torch.int8, torch.int32}:
        A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
        B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
    elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
        A = torch.randn(M, K).to(in_dtype).cuda()
        B = torch.randn(N, K).to(in_dtype).cuda()
    else:
        A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
        B = torch.randn(N, K).to(in_dtype).cuda() - 0.5

    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)

    mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)

    mod(A, B, C)

    latency = mod.do_bench(mod.func, warmup=25)

    # Ensure that the latency is not None
    assert latency is not None

    # Get Reference Result
    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype)
    tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul_bfloat16():
    assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")


if __name__ == "__main__":
    tilelang.testing.main()
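
Note: the following is a minimal usage sketch, not part of the commit. It drives the BF16 kernel above with the same TL.lower / TL.Profiler APIs the test uses, and assumes it is run from testing/python/kernel (so the test module imports by filename) on a CUDA GPU with compute capability 8.0 or newer.

import torch
import tilelang as TL
from test_tilelang_kernel_bf16_gemm_mma import tl_matmul

M = N = K = 128
func = tl_matmul(M, N, K, "bfloat16", "float32", "float32")
mod, params = TL.lower(func)  # lower the TileLang kernel to a CUDA module
prof = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)

A = (torch.randn(M, K).cuda() - 0.5).to(torch.bfloat16)
B = (torch.randn(N, K).cuda() - 0.5).to(torch.bfloat16)  # B is stored (N, K), i.e. transposed
C = torch.zeros(M, N, device="cuda", dtype=torch.float32)
prof(A, B, C)  # launch the kernel once
print("latency (ms):", prof.do_bench(prof.func, warmup=25))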
@@ -42,6 +42,7 @@ def tl_matmul(
):
    assert in_dtype in [
        "float16",
        "bfloat16",
        "e4m3_float8",
        "e5m2_float8",
        "int8",
@@ -230,6 +231,7 @@ def map_torch_type(intype):
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
    assert_tl_matmul_correctness(128, 128, 128, "bfloat16", "float32", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")


