forked from triton-lang/triton
Rahul Batra committed on Sep 16, 2024
commit 674d526, 1 parent c4bd738
Showing 2 changed files with 297 additions and 0 deletions.
@@ -0,0 +1,295 @@
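# Layer-norm forward pass written in Triton: two functionally equivalent kernels, a
# host-side launch wrapper, a pytest correctness check against
# torch.nn.functional.layer_norm, and a GB/s benchmark built on triton.testing.perf_report.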
import argparse
import sys
import pytest

import torch
import triton
import triton.language as tl

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')

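# Autotuning note: @triton.autotune times every triton.Config below the first time a new
# (n_rows, n_cols) key is seen and caches the fastest one. 'waves_per_eu' is an
# AMD/HIP-specific occupancy hint; the CUDA configs only vary num_warps.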
def get_cuda_autotune_config():
    return [
        triton.Config({}, num_warps=4, num_stages=1),
        triton.Config({}, num_warps=8, num_stages=1),
        triton.Config({}, num_warps=16, num_stages=1),
    ]


def get_hip_autotune_config():
    return [
        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
    ]


def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()

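# Each Triton program handles one row of the input and makes three passes over its
# columns: accumulate the mean, accumulate the variance, then normalize and apply the
# affine transform (weight and bias).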
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel_my(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps,
                        BLOCK_SIZE: tl.constexpr):

    # program id
    row = tl.program_id(0)
    x_ptr_start = x_ptr + (row * x_row_stride)
    y_ptr_start = y_ptr + (row * y_row_stride)

    # calculate mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for b in range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
        _mean += x_block
    mean = tl.sum(_mean, axis=0) / n_cols

    # variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for b in range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        x_block = tl.load(x_ptr_start + col_offsets, mask=col_offsets < n_cols, other=0.).to(tl.float32)
        x_block = tl.where(col_offsets < n_cols, x_block - mean, 0.)
        _var += x_block * x_block
    var = tl.sum(_var, axis=0) / n_cols
    rstd = tl.rsqrt(var + eps)

    # normalize and store
    for b in range(0, n_cols, BLOCK_SIZE):
        col_offsets = b + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        w_block = tl.load(w_ptr + col_offsets, mask=mask)
        b_block = tl.load(b_ptr + col_offsets, mask=mask)
        x_block = tl.load(x_ptr_start + col_offsets, mask=mask)
        y_block = (x_block - mean) * rstd
        y_block = y_block * w_block + b_block
        tl.store(y_ptr_start + col_offsets, y_block, mask=mask)

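# Reference kernel: the same three-pass algorithm as layernorm_kernel_my above, written
# in the style of the upstream Triton layer-norm tutorial. Note that the layernorm()
# wrapper below launches this kernel, not layernorm_kernel_my.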
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        B,  # pointer to the biases
        x_stride,  # how much to increase the pointer when moving by 1 row
        y_stride,  # how much to increase the pointer when moving by 1 row
        n_rows,
        n_cols,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    N = n_cols
    row = tl.program_id(0)
    Y += row * y_stride
    X += row * x_stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

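# Host-side wrapper: launches one program per row. BLOCK_SIZE is the next power of two
# covering n_cols, capped at 64 KB worth of elements (MAX_FUSED_SIZE), so very wide rows
# are processed in several BLOCK_SIZE-wide chunks per loop.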
def layernorm(x, w, b, eps=1e-5):
    n_rows, n_cols = x.shape

    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    y = torch.empty_like(x)

    num_programs = n_rows

    grid = lambda meta: (num_programs, )
    layernorm_kernel[grid](
        x,
        y,
        w,
        b,
        x.stride(0),
        y.stride(0),
        n_rows,
        n_cols,
        eps,
        BLOCK_SIZE,
    )

    return y

def torch_layernorm(x, w, b):
    M, N = x.shape
    w_shape = (N, )
    y_torch = torch.nn.functional.layer_norm(x, w_shape, w, b, eps=1e-5)
    return y_torch

def run_layernorm(M, N):
    print(f"Running Layernorm on shape ({M},{N})")
    torch.manual_seed(0)
    x = torch.randn(M, N, device='cuda')
    w_shape = (N, )
    w = torch.rand(w_shape, device='cuda')
    b = torch.rand(w_shape, device='cuda')
    y_triton = layernorm(x, w, b)

    return y_triton

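# Correctness test: compares the Triton kernel against torch.nn.functional.layer_norm
# over a range of shapes, including single-row and single-column edge cases. Collected
# automatically by pytest (e.g. `pytest <this file>`).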
# pytest
@pytest.mark.parametrize('M, N', [(1823, 781), (2, 128), (1, 4), (128, 2), (1, 128), (8192, 8192), (4096, 8192),
                                  (359, 1), (1, 359), (1, 131072), (1, 89999)])
def test_layernorm(M, N, eps=1e-5):
    torch.manual_seed(0)
    x = torch.randn(M, N, device='cuda')
    w_shape = (N, )
    w = torch.rand(w_shape, device='cuda')
    b = torch.rand(w_shape, device='cuda')
    y_triton = layernorm(x, w, b, eps)
    y_torch = torch.nn.functional.layer_norm(x, w_shape, w, b, eps)

    print(y_triton)
    print(y_torch)
    assert torch.allclose(y_triton, y_torch, rtol=1e-05, atol=1e-06)

# Benchmark
arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}

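# The benchmark sweeps either M (rows, geometric steps of --M_step when --M_benchmark is
# set) or N (columns, linear steps of --N_step) and reports effective bandwidth in GB/s
# for both the Triton kernel and torch.nn.functional.layer_norm.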
def run_benchmark(args):
    config = []
    if args.M_benchmark:
        val = args.M_start
        x_vals_list = []
        while val <= args.M_end:
            x_vals_list.append(val)
            val *= args.M_step
        mn_args = {'N': args.N_start}
        plot_name = str("layernorm-performance_" + args.dtype + "_N" + str(args.N_start) + "_M" + str(args.M_start) +
                        "-" + str(args.M_end) + "-" + str(args.M_step))
        x_names = ['M']
    else:
        x_vals_list = [i for i in range(args.N_start, args.N_end, args.N_step)]
        mn_args = {'M': args.M_start}
        plot_name = str("layernorm-performance_" + args.dtype + "_M" + str(args.M_start) + "_N" + str(args.N_start) +
                        "-" + str(args.N_end) + "-" + str(args.N_step))
        x_names = ['N']
    dtype = arg_to_torch_dtype[args.dtype]

    print(plot_name)
    config.append(
        triton.testing.Benchmark(
            x_names=x_names,
            x_vals=x_vals_list,
            line_arg='provider',
            line_vals=['triton', 'torch'],
            line_names=[
                "Triton",
                "Torch",
            ],
            styles=[('blue', '-'), ('green', '-')],
            ylabel="GB/s",
            plot_name=plot_name,
            args=mn_args,
        ))

    @triton.testing.perf_report(config)
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=dtype)
        w_shape = (N, )
        w = torch.rand(w_shape, device='cuda', dtype=dtype)
        b = torch.rand(w_shape, device='cuda', dtype=dtype)
        stream = torch.cuda.Stream()
        torch.cuda.set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch_layernorm(x, w, b))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: layernorm(x, w, b))
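        # Effective-bandwidth model (approximate): one read of x plus one write of y,
        # i.e. 2 * numel * element_size bytes per call; the loads of w and b are ignored.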
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)

    benchmark.run(save_path=".", show_plots=True, print_data=True)

def parse_args():
    parser = argparse.ArgumentParser(
        prog="Benchmark Layernorm",
        allow_abbrev=False,
    )

    parser.add_argument('-M', "--M_start", default=1, type=int)
    parser.add_argument('-Ms', "--M_step", default=2, type=int)
    parser.add_argument('-Me', "--M_end", default=512, type=int)
    # store_true instead of type=bool: argparse's bool() would treat any non-empty
    # string (including "False") as True.
    parser.add_argument('-Mb', "--M_benchmark", action='store_true', help="Sweep over M instead of N")

    parser.add_argument('-N', "--N_start", default=1024, type=int)
    parser.add_argument('-Ns', "--N_step", default=2048, type=int)
    parser.add_argument('-Ne', "--N_end", default=65536, type=int)

    parser.add_argument('-d', "--dtype", default="fp16")
    parser.add_argument('-nb', "--no_benchmark", action='store_true', help="Run a single forward pass instead of the benchmark")

    return parser.parse_args()

def main():
    args = parse_args()
    if args.no_benchmark:
        run_layernorm(args.M_start, args.N_start)
    else:
        run_benchmark(args)

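# Example invocations (assuming this file is saved as layernorm.py):
#   python layernorm.py                      # sweep N from 1024 towards 65536 in steps of 2048 (fp16)
#   python layernorm.py -Mb -N 8192          # sweep M geometrically with N fixed at 8192
#   python layernorm.py -nb -M 4096 -N 8192  # single forward run, no benchmarking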
if __name__ == "__main__":
    sys.exit(main())