Skip to content

Commit

Permalink
More flexible debugging of triton IRs
Browse files Browse the repository at this point in the history
Printing triton IRs to stdout is easy but if there are multiple kernels involved
it's hard to tell which kernel corresponds to which IR. Instead we propose that
the user decides how to deal with the IRs by providing a callback and dealing
with the log however they want.

PiperOrigin-RevId: 559094543
  • Loading branch information
The jax_triton Authors committed Sep 15, 2023
1 parent 170511d commit 8ba2134
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 14 deletions.
50 changes: 36 additions & 14 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,42 +143,56 @@ def aval_size_bytes(aval):


def ptx_get_kernel_name(module) -> str:
return tc.get_kernel_name(module, pattern='// .globl')
return tc.get_kernel_name(module, pattern="// .globl")


def _maybe_dump(
dump: Callable[[str, str], None] | bool, ir_name: str, ir_body: Any
):
"""Do the right thing w.r.t. logging the IR.
Args:
dump: if it is a callable use it to log the IR. If it's bool, its value
decieds whether to do nothing or dump the IR in stdout.
ir_name: the name of the IR.
ir_body: the text of the ir. Must implement `__str__()`.
"""

if callable(dump):
dump(ir_name, str(ir_body))
elif dump:
print(ir_body)


def compile_ttir_to_ptx_inplace(
ttir,
device: int = 0,
num_warps: int = 4,
num_stages: Optional[int] = None,
dump: bool = False,
dump: Callable[[str, str], None] | bool = False,
) -> Tuple[str, str, int, int]:
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
if dump:
print(ttir)
_maybe_dump(dump, "ttir", ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps)
ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
print(ttgir)
_maybe_dump(dump, "ttgir", ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
_maybe_dump(dump, "llir", llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
if dump:
print(ptx)
_maybe_dump(dump, "ptx", ptx)
name = ptx_get_kernel_name(ptx)
return ptx, name, shared_mem_bytes, compute_capability

Expand All @@ -194,7 +208,7 @@ def get_or_create_triton_kernel(
num_warps,
num_stages,
metaparams,
dump: bool,
dump: Callable[[str, str], None] | bool,
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
signature = dict(enumerate(arg_dtypes))
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
Expand Down Expand Up @@ -229,8 +243,14 @@ def get_or_create_triton_kernel(
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
_maybe_dump(dump, "py", fn.src)
module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump, arch=arch
fn,
signature,
specialization,
constants,
debug=dump is not None,
arch=arch,
)
ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
Expand Down Expand Up @@ -454,7 +474,7 @@ def triton_call(
zeroed_outputs: Union[
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
] = (),
debug: bool = False,
debug: Callable[[str, str], None] | bool = False,
serialized_metadata: bytes = b"",
**metaparams: Any,
) -> Any:
Expand Down Expand Up @@ -529,7 +549,9 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
indices, for outputs that should be zeroed before the kernel is launched.
num_warps: The number of warps used to execute the Triton kernel.
num_stages: The number of stages emitted by the Triton compiler.
debug: Prints out intermediate IRs if True for debugging purposes.
debug: Passes the IRs to a callable `debug(ir_name, ir_body)` for debugging.
It could also be a bool where `True` means dump to stdout and `False` don't
dump anything.
serialized_metadata: Arbitrary metadata that will be added into the
serialized kernel call.
**metaparams: Additional keyword arguments that will be provided to a `grid`
Expand Down
35 changes: 35 additions & 0 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,41 @@ def test_specialization(self):
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.equal_to_1, (8, 10))

def test_debug_callable(self):
emitted_irs = dict()

m, n, k = 128, 128, 128
x, y = create_random_inputs([m, k], [k, n])

def intercept_ir(ir_name, ir_body):
ir_body = str(ir_body)
self.assertNotIn(
ir_name, emitted_irs, f"Attempted to overwrite {ir_name}"
)
self.assertNotEmpty(ir_body, f"IR '{ir_name}' was empty.")
emitted_irs[ir_name] = ir_body

block_size_m, block_size_n, block_size_k = 128, 128, 32
_ = matmul(
x,
y,
debug=intercept_ir,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
K_EXACTLY_DIVISIBLE_BY_BLOCK=k % block_size_k == 0,
)

for ir_name in ["py", "ttir", "ttgir", "llir", "ptx"]:
self.assertIn(ir_name, emitted_irs, f"IR '{ir_name}' was not recorded.")
self.assertNotEmpty(emitted_irs[ir_name], f"IR '{ir_name}' was empty.")

self.assertStartsWith(
emitted_irs["py"],
"def",
"Python code is emitted as the strigification of a different object",
)


if __name__ == "__main__":
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
Expand Down

0 comments on commit 8ba2134

Please sign in to comment.