Add test and benchmark for explicit dot GEMM
brunomazzottiamd committed Sep 20, 2024
1 parent 96b3d37 commit c0c697a
Showing 2 changed files with 311 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/amd_perf_kernel_Integration_tests.yml
@@ -128,8 +128,10 @@ jobs:
          pytest -vvv ./python/perf-kernels/flash-attention.py
          pytest -vvvv ./python/perf-kernels/softmax.py
          pytest -vvv ./python/perf-kernels/rmsnorm.py
          pytest -vvv ./python/perf-kernels/multreduce_matmul_kernel.py
      - name: Run Perf Kernels Benchmark
        run: |
          python ./python/perf-kernels/flash-attention.py
          python ./python/perf-kernels/softmax.py
          python ./python/perf-kernels/rmsnorm.py
          python ./python/perf-kernels/multreduce_matmul_kernel.py bench
323 changes: 309 additions & 14 deletions python/perf-kernels/multreduce_matmul_kernel.py
100644 → 100755
@@ -1,45 +1,340 @@
#!/usr/bin/env python

# -*- coding: utf-8 -*-

# Imports:
# --------

import argparse
import sys
from typing import Any, Optional

import pytest
import torch
from torch import Tensor

import triton
import triton.language as tl

# Input generation:
# -----------------


def gen_input(M: int, N: int, K: int, use_bias: bool, device: str = "cuda") -> tuple[Tensor, Tensor, Optional[Tensor]]:
    assert M > 0, "M for input generation must be positive."
    assert M <= 8, "M for input generation must be less than or equal to 8."
    assert N > 0, "N for input generation must be positive."
    assert K > 0, "K for input generation must be positive."

    a: Tensor = torch.randn((M, K), dtype=torch.float16, device=device)
    b: Tensor = torch.randn((K, N), dtype=a.dtype, device=a.device)
    bias: Optional[Tensor] = torch.randn(M, dtype=a.dtype, device=a.device) if use_bias else None

    return a, b, bias


# PyTorch GEMM:
# -------------


def torch_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    c: Tensor = torch.matmul(a, b)
    if bias is not None:
        # Broadcast the (M,) bias across the N columns of the (M, N) product.
        c += bias[:, None]
    return c


# Triton GEMM:
# ------------

# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes.
# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch).
# A plain-PyTorch sketch of the multiply-reduce step follows the kernel below.

def get_triton_autotune_configs() -> list[triton.Config]:
    # yapf: disable
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 1
            }, num_warps=8, num_stages=0),
        triton.Config(
            {
                "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 2
            }, num_warps=8, num_stages=0),
        triton.Config(
            {
                "BLOCK_SIZE_M": 2, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 1
            }, num_warps=8, num_stages=0),
    ]
    # yapf: enable
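
# Note: `waves_per_eu`, `matrix_instr_nonkdim` and `kpack` are AMD/ROCm-specific
# tuning knobs consumed by the compiler backend; only the BLOCK_SIZE_* values
# are used directly by the kernel body below.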


@triton.autotune(configs=get_triton_autotune_configs(), key=["M", "N", "K"])
@triton.heuristics({
    "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
})
@triton.jit
def triton_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M: int, N: int, K: int, stride_am: int, stride_ak: int,
                         stride_bk: int, stride_bn: int, stride_cm: int, stride_cn: int, stride_bias: int,
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
                         USE_BIAS: tl.constexpr, EVEN_K: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if USE_BIAS:
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M, other=0)
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
        # Dot product implemented as explicit multiply-reduce:
        a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
        b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
        accumulator += tl.sum(a * b, axis=1)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if USE_BIAS:
        c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
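
# For intuition, each K-iteration of the loop above computes the partial dot
# product with a broadcast multiply and a sum reduction. A plain-PyTorch sketch
# of the same step (illustrative only, not used by the kernel; `a_blk` and
# `b_blk` stand for one fp16 tile of A and B):
#
#     a3 = a_blk.reshape(BLOCK_M, BLOCK_K, 1).float()   # (BLOCK_M, BLOCK_K, 1)
#     b3 = b_blk.reshape(1, BLOCK_K, BLOCK_N).float()   # (1, BLOCK_K, BLOCK_N)
#     acc += (a3 * b3).sum(dim=1)                       # reduce over K
#
# which matches a_blk @ b_blk up to floating-point accumulation order.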


def triton_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    M: int
    N: int
    K: int
    M, K = a.shape
    _, N = b.shape

    c: Tensor = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(args: dict[str, Any]) -> tuple[int]:
        return (triton.cdiv(M, args["BLOCK_SIZE_M"]) * triton.cdiv(N, args["BLOCK_SIZE_N"]), )

    triton_matmul_kernel[grid](
        # Data pointers
        a,
        b,
        c,
        bias,
        # Size of matrices
        M,
        N,
        K,
        # Strides
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        bias.stride(0) if bias is not None else 0,
        # Other kernel parameters
        USE_BIAS=bias is not None,
    )

    return c
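
# Usage sketch (assumes a ROCm/CUDA device; the shape values are illustrative):
#
#     a, b, bias = gen_input(2, 4096, 4096, use_bias=True)
#     c = triton_matmul(a, b, bias)  # fp16 result of shape (2, 4096)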


# Wrapper for calling PyTorch GEMM or Triton GEMM:
# ------------------------------------------------


def matmul(provider: str, a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    assert provider in ["torch", "triton"]

    assert a.is_cuda, "Matrix A must be on GPU."
    assert a.is_contiguous(), "Matrix A must be contiguous."
    assert b.is_cuda, "Matrix B must be on GPU."
    assert b.is_contiguous(), "Matrix B must be contiguous."
    assert a.device == b.device, "Matrix A and matrix B must be on the same GPU."
    assert a.dtype == b.dtype, "Matrix A and matrix B must have the same data type."
    assert a.dim() == b.dim() == 2, "Matrix A and matrix B must be two-dimensional tensors."
    assert a.shape[1] == b.shape[0], "Matrix A columns must be equal to matrix B rows."

    if bias is not None:
        assert bias.is_cuda, "Bias vector must be on GPU."
        assert bias.is_contiguous(), "Bias vector must be contiguous."
        assert bias.device == a.device, "Matrix A and bias vector must be on the same GPU."
        assert bias.dtype == a.dtype, "Matrix A and bias vector must have the same data type."
        assert bias.dim() == 1, "Bias vector must be a one-dimensional tensor."
        assert bias.shape == (a.shape[0], ), "Bias vector length must be equal to matrix A rows."

    if provider == "torch":
        return torch_matmul(a, b, bias)

    return triton_matmul(a, b, bias)


# Run Triton GEMM:
# ----------------
# This is useful for running the kernel in isolation, e.g. to collect performance traces.


def run_triton_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
    a: Tensor
    b: Tensor
    bias: Optional[Tensor]
    a, b, bias = gen_input(M, N, K, use_bias)
    matmul("triton", a, b, bias)


# Test Triton GEMM, comparing it to PyTorch GEMM reference implementation:
# ------------------------------------------------------------------------
# It's a pytest suite; run it with `pytest multreduce_matmul_kernel.py`.


def get_target_shapes() -> list[tuple[int, int, int]]:
    return [
        (1, 8192, 28672),
        (1, 6144, 6144),
        (1, 4096, 4096),
        (2, 16384, 16384),
    ]


@pytest.mark.parametrize("M, N, K", get_target_shapes())
@pytest.mark.parametrize("use_bias", [False, True])
def test_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
    a: Tensor
    b: Tensor
    bias: Optional[Tensor]
    a, b, bias = gen_input(M, N, K, use_bias)
    c_torch: Tensor = matmul("torch", a, b, bias)
    c_triton: Tensor = matmul("triton", a, b, bias)
    assert torch.allclose(c_torch, c_triton, atol=1e-3, rtol=1e-2), "PyTorch and Triton results don't match."


# Benchmark Triton GEMM, comparing it to PyTorch GEMM reference implementation:
# -----------------------------------------------------------------------------


def ms_to_gibps(M: int, N: int, K: int, milliseconds: float) -> float:
    read_elems: int = M * K + K * N
    write_elems: int = M * N
    transf_elems: int = read_elems + write_elems
    transf_bytes: int = 2 * transf_elems  # 2 bytes per element (fp16)
    transf_gibibytes: float = 2**-30 * transf_bytes
    seconds: float = 1e-3 * milliseconds
    return round(transf_gibibytes / seconds, 2)
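
# Worked example: for (M, N, K) = (1, 8192, 28672), reads plus writes total
# 1*28672 + 28672*8192 + 1*8192 = 234,917,888 fp16 elements = 469,835,776 bytes
# ≈ 0.4375 GiB, so a hypothetical p50 of 0.5 ms would score about
# 0.4375 / 5e-4 ≈ 875 GiB/s.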


# yapf: disable
@triton.testing.perf_report(triton.testing.Benchmark(
    x_names=["M", "N", "K"],
    x_vals=get_target_shapes(),
    line_arg="provider",
    line_vals=["torch", "triton"],
    line_names=["PyTorch (GiB/s)", "Triton (GiB/s)"],
    ylabel="GiB/s",
    args={},
    # Using empty `plot_name` because a 2D plot doesn't make sense for a subset of shapes.
    # Furthermore, rendering the PNG plot image takes a considerable amount of time.
    plot_name="",
))
# yapf: enable
def benchmark(M: int, N: int, K: int, provider: str) -> tuple[float, float, float]:
    a: Tensor
    b: Tensor
    a, b, _ = gen_input(M, N, K, False)

    p50_ms: float
    p20_ms: float
    p80_ms: float
    p50_ms, p20_ms, p80_ms = triton.testing.do_bench(lambda: matmul(provider, a, b, None), quantiles=[0.5, 0.2, 0.8])

    def perf(milliseconds: float) -> float:
        return ms_to_gibps(M, N, K, milliseconds)

    p50_gibps: float = perf(p50_ms)
    print(f"(M, N, K) = {(M, N, K)}, provider = {provider}, p50 = {p50_gibps} GiB/s")
    return p50_gibps, perf(p20_ms), perf(p80_ms)


def run_benchmark() -> None:
    print("Running benchmark...")
    benchmark.run(show_plots=False, print_data=True)
    print("Done.")


# Script entry point:
# -------------------
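# Example invocations (illustrative; requires a ROCm/CUDA GPU):
#
#     ./multreduce_matmul_kernel.py test
#     ./multreduce_matmul_kernel.py bench
#     ./multreduce_matmul_kernel.py check -M 2 -N 4096 -K 4096 --use-bias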


def positive_int(value: str) -> int:
    try:
        int_value = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not an integer.")
    if int_value <= 0:
        raise argparse.ArgumentTypeError(f"{value} is not a positive integer.")
    return int_value


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="C = A * B + BIAS matrix multiplication kernel for small matrices",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "mode", choices=["run", "check", "test", "bench"], help="mode of operation:\n"
        "  run: run Triton kernel for a given (M, N, K) shape\n"
        "  check: correctness check for a given (M, N, K) shape\n"
        "  test: full correctness check for target shapes\n"
        "  bench: benchmark performance for target shapes\n")
    parser.add_argument("--seed", type=int, default=42, help="random seed for input generation")
    shape_group = parser.add_argument_group("kernel shape arguments")
    shape_group.add_argument("-M", type=positive_int, help="rows of matrix A")
    shape_group.add_argument("-N", type=positive_int, help="columns of matrix B")
    shape_group.add_argument("-K", type=positive_int, help="columns of matrix A / rows of matrix B")
    shape_group.add_argument("--use-bias", default=False, action="store_true", help="use BIAS vector")
    args = parser.parse_args()
    if args.mode in ["run", "check"]:
        try:
            sizes: tuple[Optional[int], ...] = (args.M, args.N, args.K)
            if any(size is None for size in sizes):
                raise ValueError(f"(M, N, K) = {sizes}, all sizes must be specified together.")
            if args.M > 8:
                raise ValueError(f"M = {args.M} is too big, this kernel was designed for M ≤ 8.")
        except ValueError as arg_error:
            print(arg_error)
            sys.exit(1)
    return args


def main() -> int:
    args: argparse.Namespace = parse_args()
    torch.manual_seed(args.seed)
    status: int = 0
    match args.mode:
        case "run":
            run_triton_matmul(args.M, args.N, args.K, args.use_bias)
        case "check":
            try:
                test_matmul(args.M, args.N, args.K, args.use_bias)
            except AssertionError as assert_error:
                print(assert_error)
                status = 1
        case "test":
            status = pytest.main(["-vvv", __file__])
        case "bench":
            run_benchmark()
    return status


if __name__ == "__main__":
    sys.exit(main())
