From c0c697a738374a9c178fc8cea8f03833557bf30b Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Thu, 5 Sep 2024 14:26:18 +0000 Subject: [PATCH] Add test and benchmark for explicit dot GEMM --- .../amd_perf_kernel_Integration_tests.yml | 2 + .../perf-kernels/multreduce_matmul_kernel.py | 323 +++++++++++++++++- 2 files changed, 311 insertions(+), 14 deletions(-) mode change 100644 => 100755 python/perf-kernels/multreduce_matmul_kernel.py diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml index 61e44c4859d0..e5aa6ce902ae 100644 --- a/.github/workflows/amd_perf_kernel_Integration_tests.yml +++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml @@ -128,8 +128,10 @@ jobs: pytest -vvv ./python/perf-kernels/flash-attention.py pytest -vvvv ./python/perf-kernels/softmax.py pytest -vvv ./python/perf-kernels/rmsnorm.py + pytest -vvv ./python/perf-kernels/multreduce_matmul_kernel.py - name: Run Perf Kernels Benchmark run: | python ./python/perf-kernels/flash-attention.py python ./python/perf-kernels/softmax.py python ./python/perf-kernels/rmsnorm.py + python ./python/perf-kernels/multreduce_matmul_kernel.py bench diff --git a/python/perf-kernels/multreduce_matmul_kernel.py b/python/perf-kernels/multreduce_matmul_kernel.py old mode 100644 new mode 100755 index 61535d5bcdd3..3f0d8641bf31 --- a/python/perf-kernels/multreduce_matmul_kernel.py +++ b/python/perf-kernels/multreduce_matmul_kernel.py @@ -1,25 +1,96 @@ +#!/usr/bin/env python + +# -*- coding: utf-8 -*- + +# Imports: +# -------- + +import argparse +import sys +from typing import Any, Optional + +import pytest +import torch +from torch import Tensor + import triton import triton.language as tl +# Input generation: +# ----------------- + + +def gen_input(M: int, N: int, K: int, use_bias: bool, device: str = "cuda") -> tuple[Tensor, Tensor, Optional[Tensor]]: + assert M > 0, "M for input generation must be positive." + assert M <= 8, "M for input generation must be less or equal to 8." + assert N > 0, "N for input generation must be positive." + assert K > 0, "K for input generation must be positive." + + a: Tensor = torch.randn((M, K), dtype=torch.float16, device=device) + b: Tensor = torch.randn((K, N), dtype=a.dtype, device=a.device) + bias: Optional[Tensor] = torch.randn(M, dtype=a.dtype, device=a.device) if use_bias else None + + return a, b, bias + + +# PyTorch GEMM: +# ------------- + + +def torch_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor: + c: Tensor = torch.matmul(a, b) + if bias is not None: + c += bias[:, None] + return c + + +# Triton GEMM: +# ------------ -# Kernel that implements GEMM with explicit multiply-reduce instructions for small block sizes. -# Based on **tune_gemm** `matmul_kernel` from commit `cf44637` (see `triton-mlir` branch). 
+ +def get_triton_autotune_configs() -> list[triton.Config]: + # yapf: disable + return [ + triton.Config( + { + "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, "kpack": 1 + }, num_warps=8, num_stages=0), + triton.Config( + { + "BLOCK_SIZE_M": 1, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, "kpack": 2 + }, num_warps=8, num_stages=0), + triton.Config( + { + "BLOCK_SIZE_M": 2, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, "kpack": 1 + }, num_warps=8, num_stages=0), + ] + # yapf: enable + + +@triton.autotune(configs=get_triton_autotune_configs(), key=["M", "N", "K"]) +@triton.heuristics({ + "EVEN_K": lambda args: args["K"] % args["BLOCK_SIZE_K"] == 0, +}) @triton.jit -def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, - stride_cm, stride_cn, stride_bias, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, BIAS: tl.constexpr, EVEN_K: tl.constexpr): +def triton_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M: int, N: int, K: int, stride_am: int, stride_ak: int, + stride_bk: int, stride_bn: int, stride_cm: int, stride_cn: int, stride_bias: int, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + USE_BIAS: tl.constexpr, EVEN_K: tl.constexpr): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn - if BIAS: + if USE_BIAS: bias_ptrs = bias_ptr + offs_am * stride_bias - bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0) + bias = tl.load(bias_ptrs, mask=offs_am < M, other=0) acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): @@ -27,19 +98,243 @@ def multreduce_matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, a = tl.load(a_ptrs) b = tl.load(b_ptrs) else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - # Dot product implemented as explicit multiply-reduce: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0) a = tl.reshape(a, (BLOCK_SIZE_M, BLOCK_SIZE_K, 1)).to(acc_dtype) b = tl.reshape(b, (1, BLOCK_SIZE_K, BLOCK_SIZE_N)).to(acc_dtype) accumulator += tl.sum(a * b, axis=1) a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk c = accumulator.to(c_ptr.type.element_ty) - if BIAS: + if USE_BIAS: c += bias[:, None] offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_ptrs = c_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, 
mask=c_mask)
+
+
+def triton_matmul(a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
+    M: int
+    N: int
+    K: int
+    M, K = a.shape
+    _, N = b.shape
+
+    c: Tensor = torch.empty((M, N), device=a.device, dtype=a.dtype)
+
+    def grid(args: dict[str, Any]) -> tuple[int]:
+        return (triton.cdiv(M, args["BLOCK_SIZE_M"]) * triton.cdiv(N, args["BLOCK_SIZE_N"]), )
+
+    triton_matmul_kernel[grid](
+        # Data pointers
+        a,
+        b,
+        c,
+        bias,
+        # Size of matrices
+        M,
+        N,
+        K,
+        # Strides
+        a.stride(0),
+        a.stride(1),
+        b.stride(0),
+        b.stride(1),
+        c.stride(0),
+        c.stride(1),
+        bias.stride(0) if bias is not None else 0,
+        # Other kernel parameters
+        USE_BIAS=bias is not None,
+    )
+
+    return c
+
+
+# Wrapper for calling PyTorch GEMM or Triton GEMM:
+# ------------------------------------------------
+
+
+def matmul(provider: str, a: Tensor, b: Tensor, bias: Optional[Tensor]) -> Tensor:
+    assert provider in ["torch", "triton"]
+
+    assert a.is_cuda, "Matrix A must be on the GPU."
+    assert a.is_contiguous(), "Matrix A must be contiguous."
+    assert b.is_cuda, "Matrix B must be on the GPU."
+    assert b.is_contiguous(), "Matrix B must be contiguous."
+    assert a.device == b.device, "Matrix A and matrix B must be on the same GPU."
+    assert a.dtype == b.dtype, "Matrix A and matrix B must have the same data type."
+    assert a.dim() == b.dim() == 2, "Matrix A and matrix B must be two-dimensional tensors."
+    assert a.shape[1] == b.shape[0], "Matrix A columns must be equal to matrix B rows."
+
+    if bias is not None:
+        assert bias.is_cuda, "Bias vector must be on the GPU."
+        assert bias.is_contiguous(), "Bias vector must be contiguous."
+        assert bias.device == a.device, "Matrix A and bias vector must be on the same GPU."
+        assert bias.dtype == a.dtype, "Matrix A and bias vector must have the same data type."
+        assert bias.dim() == 1, "Bias vector must be a one-dimensional tensor."
+        assert bias.shape == (a.shape[0], ), "Bias vector length must be equal to matrix A rows."
+
+    if provider == "torch":
+        return torch_matmul(a, b, bias)
+
+    return triton_matmul(a, b, bias)
+
+
+# Run Triton GEMM:
+# ----------------
+# This is useful for running the kernel in isolation, for instance to collect performance traces.
+
+
+def run_triton_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
+    a: Tensor
+    b: Tensor
+    bias: Optional[Tensor]
+    a, b, bias = gen_input(M, N, K, use_bias)
+    matmul("triton", a, b, bias)
+
+
+# Test Triton GEMM, comparing it to PyTorch GEMM reference implementation:
+# ------------------------------------------------------------------------
+# It's a pytest suite; you can run it with `pytest multreduce_matmul_kernel.py`.
+
+
+def get_target_shapes() -> list[tuple[int, int, int]]:
+    return [
+        (1, 8192, 28672),
+        (1, 6144, 6144),
+        (1, 4096, 4096),
+        (2, 16384, 16384),
+    ]
+
+
+@pytest.mark.parametrize("M, N, K", get_target_shapes())
+@pytest.mark.parametrize("use_bias", [False, True])
+def test_matmul(M: int, N: int, K: int, use_bias: bool) -> None:
+    a: Tensor
+    b: Tensor
+    bias: Optional[Tensor]
+    a, b, bias = gen_input(M, N, K, use_bias)
+    c_torch: Tensor = matmul("torch", a, b, bias)
+    c_triton: Tensor = matmul("triton", a, b, bias)
+    assert torch.allclose(c_torch, c_triton, atol=1e-3, rtol=1e-2), "PyTorch and Triton results don't match."
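+
+
+# Reference for the explicit multiply-reduce (illustrative sketch only, not used by the
+# tests or the benchmark): the kernel broadcasts A to (M, K, 1) and B to (1, K, N),
+# multiplies elementwise in fp32, and reduces over the K axis, which is equivalent to
+# A @ B. The hypothetical helper below mirrors the tl.reshape / tl.sum sequence of
+# triton_matmul_kernel in plain PyTorch.
+def multreduce_torch_reference(a: Tensor, b: Tensor) -> Tensor:
+    # Broadcast multiply to shape (M, K, N) in fp32, then sum over the shared K axis.
+    acc: Tensor = (a[:, :, None].float() * b[None, :, :].float()).sum(dim=1)
+    return acc.to(a.dtype)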
+
+
+# Benchmark Triton GEMM, comparing it to PyTorch GEMM reference implementation:
+# -----------------------------------------------------------------------------
+
+
+def ms_to_gibps(M: int, N: int, K: int, milliseconds: float) -> float:
+    read_elems: int = M * K + K * N
+    write_elems: int = M * N
+    transf_elems: int = read_elems + write_elems
+    transf_bytes: int = 2 * transf_elems  # times 2 because each fp16 element is 2 bytes
+    transf_gibibytes: float = 2**-30 * transf_bytes
+    seconds: float = 1e-3 * milliseconds
+    return round(transf_gibibytes / seconds, 2)
+
+
+# yapf: disable
+@triton.testing.perf_report(triton.testing.Benchmark(
+    x_names=["M", "N", "K"],
+    x_vals=get_target_shapes(),
+    line_arg="provider",
+    line_vals=["torch", "triton"],
+    line_names=["PyTorch (GiB/s)", "Triton (GiB/s)"],
+    ylabel="GiB/s",
+    args={},
+    # Using an empty `plot_name` because a 2D plot doesn't make sense for this discrete set of shapes.
+    # Furthermore, rendering the PNG plot image takes a considerable amount of time.
+    plot_name="",
+))
+# yapf: enable
+def benchmark(M: int, N: int, K: int, provider: str) -> tuple[float, float, float]:
+    a: Tensor
+    b: Tensor
+    a, b, _ = gen_input(M, N, K, False)
+
+    p50_ms: float
+    p20_ms: float
+    p80_ms: float
+    p50_ms, p20_ms, p80_ms = triton.testing.do_bench(lambda: matmul(provider, a, b, None), quantiles=[0.5, 0.2, 0.8])
+
+    def perf(milliseconds: float) -> float:
+        return ms_to_gibps(M, N, K, milliseconds)
+
+    p50_gibps: float = perf(p50_ms)
+    print(f"(M, N, K) = {(M, N, K)}, provider = {provider}, p50 = {p50_gibps} GiB/s")
+    return p50_gibps, perf(p20_ms), perf(p80_ms)
+
+
+def run_benchmark() -> None:
+    print("Running benchmark...")
+    benchmark.run(show_plots=False, print_data=True)
+    print("Done.")
+
+
+# Script entry point:
+# -------------------
+
+
+def positive_int(value: str) -> int:
+    try:
+        int_value = int(value)
+    except ValueError:
+        raise argparse.ArgumentTypeError(f"{value} is not an integer.")
+    if int_value <= 0:
+        raise argparse.ArgumentTypeError(f"{value} is not a positive integer.")
+    return int_value
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="C = A * B + BIAS matrix multiplication kernel for small matrices",
+                                     formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument(
+        "mode", choices=["run", "check", "test", "bench"], help="mode of operation:\n"
+        " run: run Triton kernel for a given (M, N, K) shape\n"
+        " check: correctness check for a given (M, N, K) shape\n"
+        " test: full correctness check for target shapes\n"
+        " bench: benchmark performance for target shapes\n")
+    parser.add_argument("--seed", type=int, default=42, help="random seed for input generation")
+    shape_group = parser.add_argument_group("kernel shape arguments")
+    shape_group.add_argument("-M", type=positive_int, help="rows of matrix A")
+    shape_group.add_argument("-N", type=positive_int, help="columns of matrix B")
+    shape_group.add_argument("-K", type=positive_int, help="columns of matrix A / rows of matrix B")
+    shape_group.add_argument("--use-bias", default=False, action="store_true", help="use BIAS vector")
+    args = parser.parse_args()
+    if args.mode in ["run", "check"]:
+        try:
+            sizes: tuple[Optional[int], ...] = (args.M, args.N, args.K)
+            if any(size is None for size in sizes):
+                raise ValueError(f"(M, N, K) = {sizes}, all sizes must be specified together.")
+            if args.M > 8:
+                raise ValueError(f"M = {args.M} is too big, this kernel was designed for M ≤ 8.")
+        except ValueError as arg_error:
+            print(arg_error)
+            sys.exit(1)
+    return args
+
+
+def main() -> int:
+    args: argparse.Namespace = parse_args()
+    torch.manual_seed(args.seed)
+    status: int = 0
+    match args.mode:
+        case "run":
+            run_triton_matmul(args.M, args.N, args.K, args.use_bias)
+        case "check":
+            try:
+                test_matmul(args.M, args.N, args.K, args.use_bias)
+            except AssertionError as assert_error:
+                print(assert_error)
+                status = 1
+        case "test":
+            status = pytest.main(["-vvv", __file__])
+        case "bench":
+            run_benchmark()
+    return status
+
+
+if __name__ == "__main__":
+    sys.exit(main())
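+
+
+# Example invocations:
+# --------------------
+# For reference only: the path follows the CI workflow, the shapes are examples within
+# the kernel's supported range (M <= 8), and modes map to the argparse choices above.
+#
+#   python ./python/perf-kernels/multreduce_matmul_kernel.py run -M 1 -N 8192 -K 28672 --use-bias
+#   python ./python/perf-kernels/multreduce_matmul_kernel.py check -M 2 -N 4096 -K 4096
+#   python ./python/perf-kernels/multreduce_matmul_kernel.py test
+#   python ./python/perf-kernels/multreduce_matmul_kernel.py bench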