From 8ba213426d9645e014d4d33284d1601797622340 Mon Sep 17 00:00:00 2001 From: The jax_triton Authors Date: Tue, 22 Aug 2023 06:33:41 -0700 Subject: [PATCH] More flexible debugging of triton IRs Printing triton IRs to stdout is easy but if there are multiple kernels involved it's hard to tell which kernel corresponds to which IR. Instead we propose that the user decides how to deal with the IRs by providing a callback and dealing with the log however they want. PiperOrigin-RevId: 559094543 --- jax_triton/triton_lib.py | 50 ++++++++++++++++++++++++++++----------- tests/triton_call_test.py | 35 +++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 8e678d8f..55ab5296 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -143,7 +143,25 @@ def aval_size_bytes(aval): def ptx_get_kernel_name(module) -> str: - return tc.get_kernel_name(module, pattern='// .globl') + return tc.get_kernel_name(module, pattern="// .globl") + + +def _maybe_dump( + dump: Callable[[str, str], None] | bool, ir_name: str, ir_body: Any +): + """Do the right thing w.r.t. logging the IR. + + Args: + dump: if it is a callable use it to log the IR. If it's bool, its value + decieds whether to do nothing or dump the IR in stdout. + ir_name: the name of the IR. + ir_body: the text of the ir. Must implement `__str__()`. + """ + + if callable(dump): + dump(ir_name, str(ir_body)) + elif dump: + print(ir_body) def compile_ttir_to_ptx_inplace( @@ -151,13 +169,12 @@ def compile_ttir_to_ptx_inplace( device: int = 0, num_warps: int = 4, num_stages: Optional[int] = None, - dump: bool = False, + dump: Callable[[str, str], None] | bool = False, ) -> Tuple[str, str, int, int]: compute_capability = triton_kernel_call_lib.get_compute_capability(device) if num_stages is None: num_stages = 3 if compute_capability >= 75 else 2 - if dump: - print(ttir) + _maybe_dump(dump, "ttir", ttir) try: ttir = tc.optimize_ttir(ttir, compute_capability) ttgir = tc.ttir_to_ttgir(ttir, num_warps) @@ -165,8 +182,7 @@ def compile_ttir_to_ptx_inplace( except RuntimeError as e: ttir.dump() raise ValueError("TTIR->TTGIR pass failed!") from e - if dump: - print(ttgir) + _maybe_dump(dump, "ttgir", ttgir) extern_libs = {} try: llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability) @@ -174,11 +190,9 @@ def compile_ttir_to_ptx_inplace( ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e shared_mem_bytes = _triton.get_shared_memory_size(ttgir) - if dump: - print(llir) + _maybe_dump(dump, "llir", llir) ptx = tc.llir_to_ptx(llir, compute_capability) - if dump: - print(ptx) + _maybe_dump(dump, "ptx", ptx) name = ptx_get_kernel_name(ptx) return ptx, name, shared_mem_bytes, compute_capability @@ -194,7 +208,7 @@ def get_or_create_triton_kernel( num_warps, num_stages, metaparams, - dump: bool, + dump: Callable[[str, str], None] | bool, ) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]: signature = dict(enumerate(arg_dtypes)) # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers @@ -229,8 +243,14 @@ def get_or_create_triton_kernel( # general. device = 0 arch = triton_kernel_call_lib.get_compute_capability(device) + _maybe_dump(dump, "py", fn.src) module = code_gen.ast_to_ttir( - fn, signature, specialization, constants, debug=dump, arch=arch + fn, + signature, + specialization, + constants, + debug=dump is not None, + arch=arch, ) ttir = str(module) # `module`` is compiled in-place, so copy TTIR here. ptx, kernel_name, shared_mem_bytes, compute_capability = ( @@ -454,7 +474,7 @@ def triton_call( zeroed_outputs: Union[ Sequence[int], Callable[[Dict[str, Any]], Sequence[int]] ] = (), - debug: bool = False, + debug: Callable[[str, str], None] | bool = False, serialized_metadata: bytes = b"", **metaparams: Any, ) -> Any: @@ -529,7 +549,9 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: indices, for outputs that should be zeroed before the kernel is launched. num_warps: The number of warps used to execute the Triton kernel. num_stages: The number of stages emitted by the Triton compiler. - debug: Prints out intermediate IRs if True for debugging purposes. + debug: Passes the IRs to a callable `debug(ir_name, ir_body)` for debugging. + It could also be a bool where `True` means dump to stdout and `False` don't + dump anything. serialized_metadata: Arbitrary metadata that will be added into the serialized kernel call. **metaparams: Additional keyword arguments that will be provided to a `grid` diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index 71a59806..bbb3492d 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -549,6 +549,41 @@ def test_specialization(self): # specialize" leaving `stride_{bn,cn}`. self.assertEqual(specialization.equal_to_1, (8, 10)) + def test_debug_callable(self): + emitted_irs = dict() + + m, n, k = 128, 128, 128 + x, y = create_random_inputs([m, k], [k, n]) + + def intercept_ir(ir_name, ir_body): + ir_body = str(ir_body) + self.assertNotIn( + ir_name, emitted_irs, f"Attempted to overwrite {ir_name}" + ) + self.assertNotEmpty(ir_body, f"IR '{ir_name}' was empty.") + emitted_irs[ir_name] = ir_body + + block_size_m, block_size_n, block_size_k = 128, 128, 32 + _ = matmul( + x, + y, + debug=intercept_ir, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + K_EXACTLY_DIVISIBLE_BY_BLOCK=k % block_size_k == 0, + ) + + for ir_name in ["py", "ttir", "ttgir", "llir", "ptx"]: + self.assertIn(ir_name, emitted_irs, f"IR '{ir_name}' was not recorded.") + self.assertNotEmpty(emitted_irs[ir_name], f"IR '{ir_name}' was empty.") + + self.assertStartsWith( + emitted_irs["py"], + "def", + "Python code is emitted as the strigification of a different object", + ) + if __name__ == "__main__": os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"