[FRONTEND] make CompiledKernel metadata a namedtuple instead of a dict, and pass it to hook in lieu of kernel object (#2929)
ptillet authored Jan 12, 2024
1 parent d0cb667 commit f3e2d84
Showing 4 changed files with 30 additions and 21 deletions.
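
For hook authors, the practical effect of this change is that CompiledKernel.launch_enter_hook and launch_exit_hook are now handed the kernel's metadata namedtuple rather than the CompiledKernel object, so fields are read as attributes. Below is a minimal sketch of a hook written against the new convention; the launcher's exact call signature for hooks is not shown in this diff, so treating the metadata as the hook's first argument (and ignoring the rest) is an assumption, while the field names name, num_warps, and shared are taken from the diff:

from triton.compiler.compiler import CompiledKernel

def log_launch(metadata, *args):
    # `metadata` is assumed to be the KernelMetadata namedtuple built in
    # CompiledKernel.__init__ below; values are attributes, not dict keys.
    print(f"launching {metadata.name}: num_warps={metadata.num_warps}, "
          f"shared={metadata.shared} bytes")

# Hooks are plain class attributes on CompiledKernel (see the diff below).
CompiledKernel.launch_enter_hook = log_launch
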
31 changes: 19 additions & 12 deletions python/triton/compiler/compiler.py
@@ -118,7 +118,9 @@ def make_ir(self, options, context):

def metadata(self):
# TODO: remove once TMA support is cleaned up
return {"ids_of_folded_args": tuple([int(k) for k in self.attrs.ids_of_folded_args])}
return {
"ids_of_folded_args": tuple([int(k) for k in self.attrs.ids_of_folded_args]),
}

def parse_options(self):
return dict()
@@ -146,7 +148,7 @@ def make_ir(self, options, context):
return module

def metadata(self):
- return dict()
+ return {"ids_of_folded_args": tuple()}

def parse_options(self):
if self.ext == "ttgir":
@@ -208,6 +210,7 @@ def compile(src, target=None, options=None):
return CompiledKernel(src, metadata_group)
# initialize metadata
metadata = {
"hash": hash,
"target": target,
**options.__dict__,
**get_env_vars(),
@@ -249,15 +252,18 @@ class CompiledKernel:
launch_exit_hook = None

def __init__(self, src, metadata_group):
+ from collections import namedtuple
metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
self.metadata = json.loads(metadata_path.read_text())
self.metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in self.metadata['tensormaps_info']
] if 'tensormaps_info' in self.metadata else []
for i, _ in enumerate(self.metadata["tensormaps_info"]):
self.metadata["tensormaps_info"][i].ids_of_folded_args = tuple(self.metadata["ids_of_folded_args"])
- self.name = self.metadata["name"]
- for key, val in self.metadata.items():
- setattr(self, key, val)
+ self.metadata["tensormaps_info"] = tuple(self.metadata["tensormaps_info"])
+ KernelMetadata = namedtuple('KernelMetadata', sorted(list(self.metadata.keys())))
+ self.metadata = KernelMetadata(**self.metadata)
+
+ self.name = self.metadata.name
# create launcher
self.run = driver.launcher_cls(src, self.metadata)
# stores the text of each level of IR that was generated during compilation
@@ -279,11 +285,11 @@ def _init_handles(self):
device = driver.get_current_device()
# not enough shared memory to run the kernel
max_shared = driver.utils.get_device_properties(device)["max_shared_mem"]
- if self.shared > max_shared:
- raise OutOfResources(self.shared, max_shared, "shared memory")
+ if self.metadata.shared > max_shared:
+ raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
# TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary(
- self.name, self.kernel, self.shared, device)
+ self.name, self.kernel, self.metadata.shared, device)

def __getattribute__(self, name):
if name == 'run':
@@ -294,12 +300,13 @@ def __getitem__(self, grid):
self._init_handles()

def runner(*args, stream=None):
- args_expand = driver.assemble_tensormap_to_arg(self.tensormaps_info, args)
if stream is None:
device = driver.get_current_device()
stream = driver.get_current_stream(device)
- self.run(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.cluster_dims[0],
- self.cluster_dims[1], self.cluster_dims[2], self.shared, stream, self.function,
- CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand)
+ md = self.metadata
+ args_expand = driver.assemble_tensormap_to_arg(md.tensormaps_info, args)
+ self.run(grid[0], grid[1], grid[2], md.num_warps, md.num_ctas, md.cluster_dims[0], md.cluster_dims[1],
+ md.cluster_dims[2], md.shared, stream, self.function, CompiledKernel.launch_enter_hook,
+ CompiledKernel.launch_exit_hook, md, *args_expand)

return runner
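
The core of the __init__ change above is an ordinary dict-to-namedtuple conversion. A self-contained sketch of the same pattern, with made-up field values rather than real kernel metadata:

from collections import namedtuple

metadata_dict = {"name": "add_kernel", "num_warps": 4, "shared": 0}

# namedtuple fields are positional, so the keys are sorted (as the diff does)
# to give the generated type a deterministic field order.
KernelMetadata = namedtuple("KernelMetadata", sorted(metadata_dict.keys()))
metadata = KernelMetadata(**metadata_dict)

print(metadata.name, metadata.shared)  # attribute access replaces dict lookups
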
12 changes: 7 additions & 5 deletions python/triton/runtime/jit.py
@@ -400,11 +400,13 @@ def run(self, *args, grid, warmup, **kwargs):
kernel = self.cache[device][key]
if not warmup:
args = [arg.value for arg in args if not arg.param.is_constexpr]
- kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
- kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2], # cluster
- kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
- CompiledKernel.launch_exit_hook, kernel,
- *driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))
+ metadata = kernel.metadata
+ kernel.run(grid_0, grid_1, grid_2, metadata.num_warps,
+ metadata.num_ctas, # number of warps/ctas per instance
+ metadata.cluster_dims[0], metadata.cluster_dims[1], metadata.cluster_dims[2], # cluster
+ metadata.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
+ CompiledKernel.launch_exit_hook, metadata,
+ *driver.assemble_tensormap_to_arg(metadata.tensormaps_info, args))
return kernel

def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
2 changes: 1 addition & 1 deletion python/triton/tools/compile.py
@@ -132,7 +132,7 @@ def constexpr(s):
"arg_pointers": ", ".join([f"&{arg}" for arg in arg_names]),
"num_args": len(arg_names),
"kernel_docstring": doc_string,
"shared": ccinfo.shared,
"shared": ccinfo.metadata.shared,
"num_warps": args.num_warps,
"algo_info": '_'.join([const_sig, meta_sig]),
"gridX": grid[0],
6 changes: 3 additions & 3 deletions third_party/nvidia/backend/driver.py
@@ -330,9 +330,9 @@ class CudaLauncher(object):

def __init__(self, src, metadata):
ids = {
"ids_of_tensormaps": metadata.get("ids_of_tensormaps", tuple()), "ids_of_folded_args":
metadata.get("ids_of_folded_args",
tuple()), "ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()
"ids_of_tensormaps": metadata.ids_of_tensormaps,
"ids_of_folded_args": metadata.ids_of_folded_args,
"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()
}
constants = src.constants if hasattr(src, "constants") else dict()
enable_warp_specialization = False