[AMD] [FrontEnd] Check is_within_2gb only when buffer ops on (#5898)
Updates the AMD compiler backend to optimize the implementation of
`is_within_2gb` and gate its invocation behind buffer ops being enabled.
From internal testing we found that this check can be expensive, and since
it is part of the cache key it runs on every kernel invocation. It therefore
makes sense to add a fast path that skips the check entirely when buffer-ops
related optimizations are disabled.

This may also have the added benefit that a kernel which cannot benefit from
buffer operations no longer needs to be recompiled: its cache key is no longer
changed unnecessarily based on the 2GB threshold, so crossing that threshold
no longer causes a cache miss.
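
For illustration, here is a minimal sketch of the fast-path pattern this change introduces: an environment flag memoized with `functools.lru_cache` gates the per-argument size check, so the check never runs when buffer ops are off. The helpers below (`arg_specialization_suffix` and the simplified `is_within_2gb`) are illustrative stand-ins, not the backend's exact code.

import functools
import os


@functools.lru_cache()
def use_buffer_ops() -> bool:
    # Read the environment flag once; the result is memoized so the lookup
    # does not repeat on every kernel launch.
    return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"


def is_within_2gb(arg) -> bool:
    # Simplified stand-in for the potentially expensive per-argument check:
    # a tensor-like object fits within 2GB if its storage is <= 2**31 - 1 bytes.
    storage = getattr(arg, "untyped_storage", None)
    return storage is not None and storage().size() <= 2**31 - 1


def arg_specialization_suffix(arg, ty: str) -> str:
    # Fast path: when buffer ops are disabled, skip the size check entirely,
    # so the cache key is never perturbed by the 2GB threshold.
    if use_buffer_ops() and ty == "tensor" and is_within_2gb(arg):
        return "S"
    return ""

Memoizing the flag keeps the per-launch overhead in the disabled case to a single cached function call.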

---------

Co-authored-by: Nick Riasanovsky <[email protected]>
3 people authored Feb 28, 2025
1 parent 3f2fb59 commit 37ff43c
Showing 2 changed files with 56 additions and 28 deletions.
62 changes: 40 additions & 22 deletions python/test/unit/runtime/test_cache.py
@@ -1,5 +1,6 @@
 import importlib.util
 import itertools
+import os
 import shutil
 import pathlib
 
@@ -553,28 +554,45 @@ def compiled_hook(*args, **kwargs):

 @pytest.mark.skipif(reason="within_2g is a HIP specific optimization", condition=not is_hip())
 def test_within_2gb(device, fresh_triton_cache) -> None:
-
-    @triton.jit
-    def kernel_add(a):
-        tl.load(a)
-
-    # This is the attribute we want to test
-    pointer_range_32 = None
-
-    def cache_hook(*args, **kwargs):
-        nonlocal pointer_range_32
-        pointer_range_32 = [k for k, v in kwargs["compile"]["configs"][0].items() if ['tt.pointer_range', 32] in v]
-
-    JITFunction.cache_hook = cache_hook
-    # In warmup we assume that the pointer range is 32 bits
-    kernel_add.warmup(torch.float32, grid=(1, ))
-    assert pointer_range_32 == [(0, )]
-    # Torch tensor > 2GB
-    kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
-    assert len(pointer_range_32) == 0
-    # Torch tensor <= 2GB
-    kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
-    assert pointer_range_32 == [(0, )]
+    default_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0")
+    from triton.backends import backends
+
+    amd_backend = backends["amd"]
+    try:
+        use_buffer_ops_opts = ["1", "0"]
+        # The ranges should only be available when buffer ops are enabled
+        pointer_ranges = [[(0, )], []]
+        for use_buffer_ops, pointer_range in zip(use_buffer_ops_opts, pointer_ranges):
+            # Set AMDGCN_USE_BUFFER_OPS
+            amd_backend.compiler.use_buffer_ops.cache_clear()
+            os.environ["AMDGCN_USE_BUFFER_OPS"] = use_buffer_ops
+
+            @triton.jit
+            def kernel_add(a):
+                tl.load(a)
+
+            # This is the attribute we want to test
+            pointer_range_32 = None
+
+            def cache_hook(*args, **kwargs):
+                nonlocal pointer_range_32
+                pointer_range_32 = [
+                    k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v
+                ]
+
+            JITFunction.cache_hook = cache_hook
+            # In warmup we assume that the pointer range is 32 bits
+            kernel_add.warmup(torch.float32, grid=(1, ))
+            assert pointer_range_32 == pointer_range
+            # Torch tensor > 2GB
+            kernel_add[(1, 0)](torch.empty(2**31, dtype=torch.int8, device=device))
+            assert len(pointer_range_32) == 0
+            # Torch tensor <= 2GB
+            kernel_add[(1, 0)](torch.empty(2**31 - 1, dtype=torch.int8, device=device))
+            assert pointer_range_32 == pointer_range
+    finally:
+        amd_backend.compiler.use_buffer_ops.cache_clear()
+        os.environ["AMDGCN_USE_BUFFER_OPS"] = default_buffer_ops
 
 
 def test_function_arguments(device):
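
Outside of pytest, the cache-hook inspection that this test performs looks roughly like the sketch below. It is a sketch under assumptions: an AMD GPU with a ROCm build of PyTorch (where the device string is "cuda"), and `JITFunction` being importable from `triton.runtime.jit`; the kernel name and tensor size are illustrative.

import os

import torch
import triton
import triton.language as tl
from triton.runtime.jit import JITFunction

# Opt into buffer ops; the flag is read lazily at the first compilation, so
# setting it before the kernel launch is sufficient in a fresh process.
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"


@triton.jit
def kernel_load(a):
    tl.load(a)


captured = {}


def cache_hook(*args, **kwargs):
    # Record which arguments carry the 32-bit pointer-range hint in the
    # compile-time configs passed to the hook.
    captured["pointer_range_32"] = [
        k for k, v in kwargs["compile"]["configs"][0].items() if ["tt.pointer_range", 32] in v
    ]


JITFunction.cache_hook = cache_hook
# A tensor under the 2GB threshold should be reported as 32-bit addressable.
kernel_load[(1, )](torch.empty(1024, dtype=torch.int8, device="cuda"))
print(captured["pointer_range_32"])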
22 changes: 16 additions & 6 deletions third_party/amd/backend/compiler.py
@@ -134,17 +134,26 @@ def get_codegen_implementation(self, options):

     def get_module_map(self) -> Dict[str, ModuleType]:
         from triton.language.extra.hip import libdevice
+
         return {"triton.language.extra.libdevice": libdevice}
 
     def load_dialects(self, ctx):
         amd.load_dialects(ctx)
 
+    @staticmethod
+    @functools.lru_cache()
+    def use_buffer_ops():
+        return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
+
     @staticmethod
     def is_within_2gb(arg):
+        import torch
+
+        MAX_INT_32 = 2**31 - 1
         if hasattr(arg, "ptr_range"):
-            return arg.ptr_range() <= 2**31 - 1
-        if "torch.Tensor" in str(type(arg)) and hasattr(arg, "untyped_storage"):
-            return arg.untyped_storage().size() <= 2**31 - 1
+            return arg.ptr_range() <= MAX_INT_32
+        if isinstance(arg, torch.Tensor) and hasattr(arg, "untyped_storage"):
+            return arg.untyped_storage().size() <= MAX_INT_32
         return False
 
     @staticmethod
@@ -157,7 +166,9 @@ def parse_attr(desc):
     @staticmethod
     def get_arg_specialization(arg, ty, **kwargs):
         ret = BaseBackend.get_arg_specialization(arg, ty, **kwargs)
-        if ty == "tensor" and HIPBackend.is_within_2gb(arg):
+        # Only attempt to do buffer ops specialization if buffer ops are enabled.
+        # Otherwise the is_within_2gb check is unnecessary overhead.
+        if HIPBackend.use_buffer_ops() and ty == "tensor" and HIPBackend.is_within_2gb(arg):
             ret += "S"
         return ret
 
@@ -241,8 +252,7 @@ def make_ttgir(mod, metadata, options):
         if use_block_pingpong and options.num_stages == 2:
             amd.passes.ttgpuir.add_block_pingpong(pm)
 
-        use_buffer_ops = os.environ.get("AMDGCN_USE_BUFFER_OPS", "0") == "1"
-        if use_buffer_ops:
+        if HIPBackend.use_buffer_ops():
             amd.passes.ttgpuir.add_canonicalize_pointers(pm)
             passes.common.add_canonicalizer(pm)
             amd.passes.ttgpuir.add_convert_to_buffer_ops(pm, options.arch)
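
Because `use_buffer_ops()` is memoized with `functools.lru_cache`, toggling `AMDGCN_USE_BUFFER_OPS` inside a running process requires clearing that cache, which is exactly what the updated test does. A small sketch of that toggle pattern (the `set_buffer_ops` helper is illustrative):

import os

from triton.backends import backends

amd_backend = backends["amd"]


def set_buffer_ops(enabled: bool) -> None:
    # The flag is read once and memoized, so drop the cached value whenever
    # the environment variable changes (mirrors the test's setup/teardown).
    amd_backend.compiler.use_buffer_ops.cache_clear()
    os.environ["AMDGCN_USE_BUFFER_OPS"] = "1" if enabled else "0"


set_buffer_ops(True)   # later compilations may add the "S" specialization
set_buffer_ops(False)  # fast path: is_within_2gb is skipped entirely

Setting the variable before the process starts (for example, `AMDGCN_USE_BUFFER_OPS=1 python script.py`) sidesteps the cache-clearing step entirely.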
