Skip to content

Commit

Permalink
revert remaining unnecessary changes
Browse files: browse the repository at this point in the history
  • Loading branch information
Nick Riasanovsky committed Feb 25, 2025
1 parent 5d9b54c commit ca6fc11
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 63 deletions.
1 change: 0 additions & 1 deletion python/test/unit/runtime/test_cache.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -179,7 +179,6 @@ def inc_counter(*args, **kwargs):
x = torch.empty(1, dtype=torch.int32, device=device)
function = {'enable': kernel, 'disable': kernel_nospec, 'disable_on_alignment': kernel_nospec_on_alignment}[mode]
target = {'enable': 3, 'disable': 1, 'disable_on_alignment': 2}[mode]
target = {"enable": 3, "disable": 1, "disable_on_alignment": 2}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1, )](x, i, BLOCK=512)
assert counter == target
Expand Down
103 changes: 41 additions & 62 deletions third_party/amd/backend/compiler.py
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,16 @@
import functools
from triton.backends.compiler import BaseBackend, GPUTarget
from triton._C.libtriton import ir, passes, llvm, amd
from dataclasses import dataclass
from typing import Any, Dict, Tuple
from types import ModuleType
import hashlib
import tempfile
import os
import re
import subprocess
import tempfile
from dataclasses import dataclass
import functools
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, Tuple

import torch
from triton._C.libtriton import amd, ir, llvm, passes
from triton.backends.compiler import BaseBackend, GPUTarget


def min_dot_size(target: GPUTarget):
Expand Down Expand Up @@ -40,7 +39,7 @@ class HIPOptions:
kpack: int = 1
allow_flush_denorm: bool = False
max_num_imprecise_acc_default: int = 0
backend_name: str = "hip"
backend_name: str = 'hip'

# The following option provides hints to the AMDGPU backend regarding instruction scheduling
# for all `tt.dot` operations in a kernel. The "none" variant preserves the default
Expand All @@ -59,52 +58,53 @@ class HIPOptions:
# Kernel library. Note, this variant requires the use of buffer load/store ops
# and a special software pipelining style - i.e., 1x LDS and 1x register
# prefetch buffers for each GEMM tile.
instruction_sched_variant: str = "none"
instruction_sched_variant: str = 'none'

def __post_init__(self):
default_libdir = Path(__file__).parent / "lib"
default_libdir = Path(__file__).parent / 'lib'
extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
# Ignore user-defined warp size for gfx9
warp_size = (32 if "gfx10" in self.arch or "gfx11" in self.arch or "gfx12" in self.arch else 64)
object.__setattr__(self, "warp_size", warp_size)
warp_size = 32 if 'gfx10' in self.arch or 'gfx11' in self.arch or 'gfx12' in self.arch else 64
object.__setattr__(self, 'warp_size', warp_size)
# Only kpack=1 is supported on gfx950
kpack = 1 if self.arch == "gfx950" else self.kpack
object.__setattr__(self, "kpack", kpack)
kpack = 1 if self.arch == 'gfx950' else self.kpack
object.__setattr__(self, 'kpack', kpack)
libs = ["ocml", "ockl"]
for lib in libs:
extern_libs[lib] = str(default_libdir / f"{lib}.bc")
object.__setattr__(self, "extern_libs", tuple(extern_libs.items()))
assert (self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0), "num_warps must be a power of 2"
extern_libs[lib] = str(default_libdir / f'{lib}.bc')
object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
"num_warps must be a power of 2"

def hash(self):
key = "_".join([f"{name}-{val}" for name, val in self.__dict__.items()])
key = '_'.join([f'{name}-{val}' for name, val in self.__dict__.items()])
return hashlib.sha256(key.encode("utf-8")).hexdigest()


class HIPBackend(BaseBackend):

@staticmethod
def supports_target(target: GPUTarget):
return target.backend == "hip"
return target.backend == 'hip'

def __init__(self, target: GPUTarget) -> None:
super().__init__(target)
assert isinstance(target.arch, str)
self.binary_ext = "hsaco"

def parse_options(self, opts) -> Any:
args = {"arch": os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}
args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)}

# Enable XF32 (TF32) for CDNA3 GPUs
if self.target.arch in ("gfx940", "gfx941", "gfx942"):
if self.target.arch in ('gfx940', 'gfx941', 'gfx942'):
allowed_dot_input_precisions = set(HIPOptions.allowed_dot_input_precisions)
allowed_dot_input_precisions.update({"tf32"})
allowed_dot_input_precisions.update({'tf32'})
args["allowed_dot_input_precisions"] = tuple(sorted(allowed_dot_input_precisions))

if "supported_fp8_dtypes" not in opts:
supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes)
if self.target.arch in ("gfx940", "gfx941", "gfx942", "gfx950"):
supported_fp8_dtypes.update({"fp8e4nv", "fp8e4b8", "fp8e5b16"})
if self.target.arch in ('gfx940', 'gfx941', 'gfx942', 'gfx950'):
supported_fp8_dtypes.update({'fp8e4nv', 'fp8e4b8', 'fp8e5b16'})
args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

if "enable_fp_fusion" not in opts:
Expand Down Expand Up @@ -204,13 +204,8 @@ def make_ttir(mod, metadata, options):
def make_ttgir(mod, metadata, options):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(
pm,
f"hip:{options.arch}",
options.num_warps,
options.warp_size,
options.num_ctas,
)
passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
options.num_ctas)
pm.run(mod)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
Expand Down Expand Up @@ -305,9 +300,9 @@ def make_llir(src, metadata, options):
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
amd.attach_target_triple(llvm_mod)
target_features = ""
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = "+xnack"
target_features = '+xnack'
llvm.attach_datalayout(llvm_mod, amd.TARGET_TRIPLE, options.arch, target_features)

# Set various control constants on the LLVM module so that device
Expand Down Expand Up @@ -344,18 +339,18 @@ def make_llir(src, metadata, options):
amd.set_all_fn_arg_inreg(fns[0])

if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
default_libdir = Path(__file__).parent / "lib"
default_libdir = Path(__file__).parent / 'lib'
paths = [
str(default_libdir / "asanrtl.bc"),
str(default_libdir / 'asanrtl.bc'),
str(default_libdir / "ocml.bc"),
str(default_libdir / "ockl.bc"),
str(default_libdir / "ockl.bc")
]
llvm.link_extern_libs(llvm_mod, paths)
elif options.extern_libs:
paths = [path for (name, path) in options.extern_libs if amd.need_extern_lib(llvm_mod, name)]
llvm.link_extern_libs(llvm_mod, paths)

llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, "", [], options.enable_fp_fusion)
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion)

# Get some metadata
metadata["shared"] = src.get_int_attr("ttg.shared")
Expand All @@ -375,42 +370,26 @@ def make_amdgcn(src, metadata, options):
assert len(names) == 1
metadata["name"] = names[0]
# llvm -> hsaco
amdgcn = llvm.translate_to_asm(
src,
amd.TARGET_TRIPLE,
options.arch,
"",
[],
options.enable_fp_fusion,
False,
)
amdgcn = llvm.translate_to_asm(src, amd.TARGET_TRIPLE, options.arch, '', [], options.enable_fp_fusion, False)
if os.environ.get("AMDGCN_ENABLE_DUMP", "0") == "1":
print("// -----// AMDGCN Dump //----- //")
print(amdgcn)
return amdgcn

@staticmethod
def make_hsaco(src, metadata, options):
target_features = ""
target_features = ''
if os.environ.get("TRITON_ENABLE_ASAN", "0") == "1":
target_features = "+xnack"
target_features = '+xnack'
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)

rocm_path = HIPBackend.path_to_rocm_lld()
with tempfile.NamedTemporaryFile() as tmp_out:
with tempfile.NamedTemporaryFile() as tmp_in:
with open(tmp_in.name, "wb") as fd_in:
with open(tmp_in.name, 'wb') as fd_in:
fd_in.write(hsaco)
subprocess.check_call([
rocm_path,
"-flavor",
"gnu",
"-shared",
tmp_in.name,
"-o",
tmp_out.name,
])
with open(tmp_out.name, "rb") as fd_out:
subprocess.check_call([rocm_path, '-flavor', 'gnu', '-shared', tmp_in.name, '-o', tmp_out.name])
with open(tmp_out.name, 'rb') as fd_out:
ret = fd_out.read()
return ret

Expand All @@ -423,5 +402,5 @@ def add_stages(self, stages, options):

@functools.lru_cache()
def hash(self):
version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding="utf-8")
return f"{version}-{self.target}"
version = subprocess.check_output([HIPBackend.path_to_rocm_lld(), "--version"], encoding='utf-8')
return f'{version}-{self.target}'

0 comments on commit ca6fc11

Please sign in to comment.