forked from triton-lang/triton
Add test and benchmark for explicit dot GEMM
1 parent 96b3d37 · commit c0c697a
Showing 2 changed files with 311 additions and 14 deletions.
python/perf-kernels/multreduce_matmul_kernel.py (mode 100644 → 100755): 309 additions & 14 deletions
@@ -1,45 +1,340 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Imports:
# --------

import argparse
import sys
from typing import Any, Optional

import pytest
import torch
from torch import Tensor

import triton
import triton.language as tl

# Input generation:
# -----------------

def gen_input(M: int, N: int, K: int, use_bias: bool, device: str = "cuda") -> tuple[Tensor, Tensor, Optional[Tensor]]:
    assert M > 0, "M for input generation must be positive."
    assert M <= 8, "M for input generation must be less than or equal to 8."
    assert N > 0, "N for input generation must be positive."
    assert K > 0, "K for input generation must be positive."

    a: Tensor = torch.randn((M, K), dtype=torch.float16, device=device)
    b: Tensor = torch.randn((K, N), dtype=a.dtype, device=a.device)
    bias: Optional[Tensor] = torch.randn(M, dtype=a.dtype, device=a.device) if use_bias else None

    return a, b, bias


# PyTorch GEMM:
# -------------


def torch_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    c: Tensor = torch.matmul(a, b)
    if bias is not None:
        c += bias[:, None]
    return c

# Triton GEMM:
# ------------

# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes.
# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch).
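# The idea, sketched in plain PyTorch for illustration only (not part of the kernel): for an
# (M, K) tile `a` and a (K, N) tile `b`, `(a[:, :, None] * b[None, :, :]).sum(dim=1)` equals
# `a @ b`. The kernel below applies the same broadcast-multiply-then-sum per K block via
# tl.reshape and tl.sum instead of tl.dot, which suits the tiny M values targeted here.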
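# Note: BLOCK_SIZE_* are the usual Triton tile sizes, autotuned over the (M, N, K) key; the
# extra keys waves_per_eu, matrix_instr_nonkdim and kpack are assumed here to be AMD-backend
# tuning knobs recognized by this fork.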
def get_triton_autotune_configs() -> list[triton.Config]:
    # yapf: disable
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 1
            }, num_warps=8, num_stages=0),
        triton.Config(
            {
                "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 2
            }, num_warps=8, num_stages=0),
        triton.Config(
            {
                "BLOCK_SIZE_M": 2, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0,
                "matrix_instr_nonkdim": 16, "kpack": 1
            }, num_warps=8, num_stages=0),
    ]
    # yapf: enable

@triton.autotune(configs=get_triton_autotune_configs(), key=["M", "N", "K"])
@triton.heuristics({
    "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0,
})
@triton.jit
def triton_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M: int, N: int, K: int, stride_am: int, stride_ak: int,
                         stride_bk: int, stride_bn: int, stride_cm: int, stride_cn: int, stride_bias: int,
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
                         USE_BIAS: tl.constexpr, EVEN_K: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
    a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
    if USE_BIAS:
        bias_ptrs = bias_ptr + offs_am * stride_bias
        bias = tl.load(bias_ptrs, mask=offs_am < M, other=0)
    acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0)
        # Dot product implemented as explicit multiply-reduce:
        # (BLOCK_SIZE_M, BLOCK_SIZE_K, 1) * (1, BLOCK_SIZE_K, BLOCK_SIZE_N) broadcasts to
        # (BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N); summing over axis 1 yields the block GEMM.
        a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype)
        b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype)
        accumulator += tl.sum(a * b, axis=1)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(c_ptr.type.element_ty)
    if USE_BIAS:
        c += bias[:, None]
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def triton_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    M: int
    N: int
    K: int
    M, K = a.shape
    _, N = b.shape

    c: Tensor = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(args: dict[str, Any]) -> tuple[int]:
        return (triton.cdiv(M, args["BLOCK_SIZE_M"]) * triton.cdiv(N, args["BLOCK_SIZE_N"]), )

    triton_matmul_kernel[grid](
        # Data pointers
        a,
        b,
        c,
        bias,
        # Size of matrices
        M,
        N,
        K,
        # Strides
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        bias.stride(0) if bias is not None else 0,
        # Other kernel parameters
        USE_BIAS=bias is not None,
    )

    return c


# Wrapper for calling PyTorch GEMM or Triton GEMM:
# ------------------------------------------------

def matmul(provider: str, a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
    assert provider in ["torch", "triton"]

    assert a.is_cuda, "Matrix A must be on GPU."
    assert a.is_contiguous(), "Matrix A must be contiguous."
    assert b.is_cuda, "Matrix B must be on GPU."
    assert b.is_contiguous(), "Matrix B must be contiguous."
    assert a.device == b.device, "Matrix A and matrix B must be on the same GPU."
    assert a.dtype == b.dtype, "Matrix A and matrix B must have the same data type."
    assert a.dim() == b.dim() == 2, "Matrix A and matrix B must be two-dimensional tensors."
    assert a.shape[1] == b.shape[0], "Matrix A columns must be equal to matrix B rows."

    if bias is not None:
        assert bias.is_cuda, "Bias vector must be on GPU."
        assert bias.is_contiguous(), "Bias vector must be contiguous."
        assert bias.device == a.device, "Matrix A and bias vector must be on the same GPU."
        assert bias.dtype == a.dtype, "Matrix A and bias vector must have the same data type."
        assert bias.dim() == 1, "Bias vector must be a one-dimensional tensor."
        assert bias.shape == (a.shape[0], ), "Bias vector length must be equal to matrix A rows."

    if provider == "torch":
        return torch_matmul(a, b, bias)

    return triton_matmul(a, b, bias)

# Run Triton GEMM:
# ----------------
# This is useful for running the kernel in isolation, for instance to collect performance traces.
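# Example invocation (any shape from get_target_shapes works; M must stay ≤ 8):
#     python multreduce_matmul_kernel.py run -M 1 -N 8192 -K 28672 --use-bias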


def run_triton_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
    a: Tensor
    b: Tensor
    bias: Optional[Tensor]
    a, b, bias = gen_input(M, N, K, use_bias)
    matmul("triton", a, b, bias)

# Test Triton GEMM, comparing it to the PyTorch GEMM reference implementation:
# ------------------------------------------------------------------------
# This is a pytest suite; you can run it with `pytest multreduce_matmul_kernel.py`.
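# The "test" and "check" CLI modes defined below wrap the same comparison, e.g.:
#     python multreduce_matmul_kernel.py check -M 1 -N 4096 -K 4096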


def get_target_shapes() -> list[tuple[int, int, int]]:
    return [
        (1, 8192, 28672),
        (1, 6144, 6144),
        (1, 4096, 4096),
        (2, 16384, 16384),
    ]


@pytest.mark.parametrize("M, N, K", get_target_shapes())
@pytest.mark.parametrize("use_bias", [False, True])
def test_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
    a: Tensor
    b: Tensor
    bias: Optional[Tensor]
    a, b, bias = gen_input(M, N, K, use_bias)
    c_torch: Tensor = matmul("torch", a, b, bias)
    c_triton: Tensor = matmul("triton", a, b, bias)
    assert torch.allclose(c_torch, c_triton, atol=1e-3, rtol=1e-2), "PyTorch and Triton results don't match."

# Benchmark Triton GEMM, comparing it to the PyTorch GEMM reference implementation:
# -----------------------------------------------------------------------------
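# Results are reported as effective bandwidth in GiB/s (see ms_to_gibps below); run with:
#     python multreduce_matmul_kernel.py bench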


def ms_to_gibps(M: int, N: int, K: int, milliseconds: float) -> float:
    read_elems: int = M * K + K * N
    write_elems: int = M * N
    transf_elems: int = read_elems + write_elems
    transf_bytes: int = 2 * transf_elems  # times 2 due to fp16
    transf_gibibytes: float = 2**-30 * transf_bytes
    seconds: float = 1e-3 * milliseconds
    return round(transf_gibibytes / seconds, 2)
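

# Worked example (illustrative numbers, not measured results): for (M, N, K) = (1, 8192, 28672),
# transf_bytes = 2 * (1 * 28672 + 28672 * 8192 + 1 * 8192) ≈ 4.7e8 bytes ≈ 0.44 GiB, so a run
# taking 0.5 ms corresponds to ms_to_gibps(1, 8192, 28672, 0.5) ≈ 875 GiB/s.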


# yapf: disable
@triton.testing.perf_report(triton.testing.Benchmark(
    x_names=["M", "N", "K"],
    x_vals=get_target_shapes(),
    line_arg="provider",
    line_vals=["torch", "triton"],
    line_names=["PyTorch (GiB/s)", "Triton (GiB/s)"],
    ylabel="GiB/s",
    args={},
    # Using empty `plot_name` because a 2D plot doesn't make sense for a subset of shapes.
    # Furthermore, rendering the PNG plot image takes a considerable amount of time.
    plot_name="",
))
# yapf: enable
def benchmark(M: int, N: int, K: int, provider: str) -> tuple[float, float, float]:
    a: Tensor
    b: Tensor
    a, b, _ = gen_input(M, N, K, False)

    p50_ms: float
    p20_ms: float
    p80_ms: float
    p50_ms, p20_ms, p80_ms = triton.testing.do_bench(lambda: matmul(provider, a, b, None), quantiles=[0.5, 0.2, 0.8])

    def perf(milliseconds: float) -> float:
        return ms_to_gibps(M, N, K, milliseconds)

    p50_gibps: float = perf(p50_ms)
    print(f"(M, N, K) = {(M, N, K)}, provider = {provider}, p50 = {p50_gibps} GiB/s")
    return p50_gibps, perf(p20_ms), perf(p80_ms)


def run_benchmark() -> None:
    print("Running benchmark...")
    benchmark.run(show_plots=False, print_data=True)
    print("Done.")


# Script entry point:
# -------------------


def positive_int(value: str) -> int:
    try:
        int_value = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{value} is not an integer.")
    if int_value <= 0:
        raise argparse.ArgumentTypeError(f"{value} is not a positive integer.")
    return int_value

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="C = A * B + BIAS matrix multiplication kernel for small matrices",
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "mode", choices=["run", "check", "test", "bench"], help="mode of operation:\n"
        "  run: run Triton kernel for a given (M, N, K) shape\n"
        "  check: correctness check for a given (M, N, K) shape\n"
        "  test: full correctness check for target shapes\n"
        "  bench: benchmark performance for target shapes\n")
    parser.add_argument("--seed", type=int, default=42, help="random seed for input generation")
    shape_group = parser.add_argument_group("kernel shape arguments")
    shape_group.add_argument("-M", type=positive_int, help="rows of matrix A")
    shape_group.add_argument("-N", type=positive_int, help="columns of matrix B")
    shape_group.add_argument("-K", type=positive_int, help="columns of matrix A / rows of matrix B")
    shape_group.add_argument("--use-bias", default=False, action="store_true", help="use BIAS vector")
    args = parser.parse_args()
    if args.mode in ["run", "check"]:
        try:
            sizes: tuple[Optional[int], ...] = tuple(size for size in (args.M, args.N, args.K))
            if any(size is None for size in sizes):
                raise ValueError(f"(M, N, K) = {sizes}, all sizes must be specified together.")
            if args.M > 8:
                raise ValueError(f"M = {args.M} is too big, this kernel was designed for M ≤ 8.")
        except ValueError as arg_error:
            print(arg_error)
            sys.exit(1)
    return args

def main() -> int:
    args: argparse.Namespace = parse_args()
    torch.manual_seed(args.seed)
    status: int = 0
    match args.mode:
        case "run":
            run_triton_matmul(args.M, args.N, args.K, args.use_bias)
        case "check":
            try:
                test_matmul(args.M, args.N, args.K, args.use_bias)
            except AssertionError as assert_error:
                print(assert_error)
                status = 1
        case "test":
            status = pytest.main(["-vvv", __file__])
        case "bench":
            run_benchmark()
    return status


if __name__ == "__main__":
    sys.exit(main())