From b7d0842a2651b2721ad653076e30100c83a18c88 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 26 Feb 2025 23:54:16 -0500 Subject: [PATCH 1/9] Update --- python/test/unit/language/test_tuple.py | 9 +++++-- python/triton/compiler/code_generator.py | 31 +++++++++++++++--------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py index cd819221df6e..77dbd01ed39e 100644 --- a/python/test/unit/language/test_tuple.py +++ b/python/test/unit/language/test_tuple.py @@ -53,7 +53,7 @@ def _tuple_assign(XPtrs, YPtrs, values): @pytest.mark.interpreter def test_assign(device): - vals = (2., 3.) + vals = (2.0, 3.0) x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) _tuple_assign[(1, )](x, y, vals) @@ -115,6 +115,11 @@ class Tensor(NamedTuple): stride: tuple +@triton.jit +def _namedtuple_create_func(shape, ptr, stride): + return Tensor(shape=shape, ptr=ptr, stride=stride) + + @triton.jit def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) @@ -127,7 +132,7 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride) + X = _namedtuple_create_func(_X.shape, _X.ptr, _X.stride) Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 532e84b52325..7550834c07d4 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -89,6 +89,23 @@ def _check_fn_args(node, fn, args): ) +def _is_namedtuple(val): + return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") + + +def _apply_to_tuple_values(value, fn): + if _is_namedtuple(type(value)): + fields = value._fields + elif isinstance(value, language.tuple): + fields = value.type.fields + else: + assert False, f"Unsupported type {type(value)}" + + vals = [fn(v) for v in value] + types = [v.type for v in vals] + return language.tuple(vals, language.tuple_type(types, fields)) + + def flatten_values_to_ir(values: Iterable[base_value]): handles = [] for v in values: @@ -349,9 +366,6 @@ def _is_constexpr_global(self, name): return False - def _is_namedtuple(self, val): - return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") - def _define_name_lookup(self): def local_lookup(name: str, absent): @@ -451,7 +465,7 @@ def visit_Return(self, node): def decay(value): if isinstance(value, language.tuple): - return language.tuple([decay(v) for v in value.values]) + return _apply_to_tuple_values(value, decay) elif isinstance(value, (language.constexpr, int, float)): return semantic.to_tensor(value, self.builder) return value @@ -575,13 +589,8 @@ def assignTarget(self, target, value): def visit_Assign(self, node): # construct values to assign def _sanitize_value(value): - if self._is_namedtuple(type(value)): - vals = [_sanitize_value(v) for v in value] - types = [v.type for v in vals] - fields = type(value)._fields - return language.tuple(vals, language.tuple_type(types, fields)) - if isinstance(value, language.tuple): - return language.tuple([_sanitize_value(v) for v in value.values]) + if self._is_namedtuple(type(value)) or isinstance(value, language.tuple): + return _apply_to_tuple_values(value, _sanitize_value) native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) if value is not None and \ From f9be2cc028f501653e2603f2159ca7db3edcae8b Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:07:32 -0500 Subject: [PATCH 2/9] Update --- python/triton/compiler/code_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 7550834c07d4..2575cb8054db 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -384,7 +384,7 @@ def global_lookup(name: str, absent): getattr(val, "__triton_builtin__", False), # getattr(val, "__module__", "").startswith("triton.language"), # isinstance(val, language.dtype), # - self._is_namedtuple(val), + _is_namedtuple(val), self._is_constexpr_global(name), # # Allow accesses to globals while visiting an ast.arg # because you should be able to do @@ -589,7 +589,7 @@ def assignTarget(self, target, value): def visit_Assign(self, node): # construct values to assign def _sanitize_value(value): - if self._is_namedtuple(type(value)) or isinstance(value, language.tuple): + if _is_namedtuple(type(value)) or isinstance(value, language.tuple): return _apply_to_tuple_values(value, _sanitize_value) native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) From 34c5fb5d55ea24285a0dd480f103f24cd142e058 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:19:31 -0500 Subject: [PATCH 3/9] Update --- python/test/unit/language/test_tuple.py | 11 +++++++++-- python/triton/compiler/code_generator.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py index 77dbd01ed39e..97d652903234 100644 --- a/python/test/unit/language/test_tuple.py +++ b/python/test/unit/language/test_tuple.py @@ -116,10 +116,16 @@ class Tensor(NamedTuple): @triton.jit -def _namedtuple_create_func(shape, ptr, stride): +def _namedtuple_create_func0(shape, ptr, stride): return Tensor(shape=shape, ptr=ptr, stride=stride) +@triton.jit +def _namedtuple_create_func1(shape, ptr, stride): + tensor = Tensor(shape=shape, ptr=ptr, stride=stride) + return tensor + + @triton.jit def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) @@ -132,7 +138,8 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - X = _namedtuple_create_func(_X.shape, _X.ptr, _X.stride) + X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride) + Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride) Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 2575cb8054db..ac37ad93b77a 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -464,7 +464,7 @@ def visit_Return(self, node): handles = [] def decay(value): - if isinstance(value, language.tuple): + if _is_namedtuple(type(value)) or isinstance(value, language.tuple): return _apply_to_tuple_values(value, decay) elif isinstance(value, (language.constexpr, int, float)): return semantic.to_tensor(value, self.builder) From 7177ae4bd57075106a567c32f04866dfec1a90fa Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:22:25 -0500 Subject: [PATCH 4/9] Update --- python/test/unit/language/test_tuple.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py index 97d652903234..9c753b15889a 100644 --- a/python/test/unit/language/test_tuple.py +++ b/python/test/unit/language/test_tuple.py @@ -116,14 +116,9 @@ class Tensor(NamedTuple): @triton.jit -def _namedtuple_create_func0(shape, ptr, stride): - return Tensor(shape=shape, ptr=ptr, stride=stride) - - -@triton.jit -def _namedtuple_create_func1(shape, ptr, stride): - tensor = Tensor(shape=shape, ptr=ptr, stride=stride) - return tensor +def _namedtuple_create_func(shape0, ptr0, stride0, shape1, ptr1, stride1): + tensor0 = Tensor(shape=shape0, ptr=ptr0, stride=stride0) + return tensor0, Tensor(shape=shape1, ptr=ptr1, stride=stride1) @triton.jit @@ -138,8 +133,7 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride) - Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride) + X, Y = _namedtuple_create_func(_X.shape, _X.ptr, _X.stride, Y.shape, Y.ptr, Y.stride) Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) From d7ebf9606211660333f56973f10d42591f207948 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:51:48 -0500 Subject: [PATCH 5/9] Update --- python/triton/compiler/code_generator.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index ac37ad93b77a..15e06e5bd416 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -464,7 +464,7 @@ def visit_Return(self, node): handles = [] def decay(value): - if _is_namedtuple(type(value)) or isinstance(value, language.tuple): + if isinstance(value, language.tuple): return _apply_to_tuple_values(value, decay) elif isinstance(value, (language.constexpr, int, float)): return semantic.to_tensor(value, self.builder) @@ -589,7 +589,7 @@ def assignTarget(self, target, value): def visit_Assign(self, node): # construct values to assign def _sanitize_value(value): - if _is_namedtuple(type(value)) or isinstance(value, language.tuple): + if isinstance(value, language.tuple): return _apply_to_tuple_values(value, _sanitize_value) native_nontensor_types = (language.dtype, language.tuple) value = _unwrap_if_constexpr(value) @@ -1262,7 +1262,10 @@ def visit_Call(self, node): if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) - return fn(*args, **kws) + ret = fn(*args, **kws) + if _is_namedtuple(ret): + return _apply_to_tuple_values(ret, lambda x: x) + return ret def visit_Constant(self, node): return constexpr(node.value) From 4941765bb3c7326f82ab61d3cf18c883b85d969d Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:53:34 -0500 Subject: [PATCH 6/9] Update --- python/triton/compiler/code_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 15e06e5bd416..d808039e6691 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1263,7 +1263,7 @@ def visit_Call(self, node): if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) ret = fn(*args, **kws) - if _is_namedtuple(ret): + if _is_namedtuple(type(ret)): return _apply_to_tuple_values(ret, lambda x: x) return ret From 964a6105b9a7bbe827a54c1258c74e1e8f539faa Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:56:55 -0500 Subject: [PATCH 7/9] Update --- python/test/unit/language/test_tuple.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py index 9c753b15889a..97d652903234 100644 --- a/python/test/unit/language/test_tuple.py +++ b/python/test/unit/language/test_tuple.py @@ -116,9 +116,14 @@ class Tensor(NamedTuple): @triton.jit -def _namedtuple_create_func(shape0, ptr0, stride0, shape1, ptr1, stride1): - tensor0 = Tensor(shape=shape0, ptr=ptr0, stride=stride0) - return tensor0, Tensor(shape=shape1, ptr=ptr1, stride=stride1) +def _namedtuple_create_func0(shape, ptr, stride): + return Tensor(shape=shape, ptr=ptr, stride=stride) + + +@triton.jit +def _namedtuple_create_func1(shape, ptr, stride): + tensor = Tensor(shape=shape, ptr=ptr, stride=stride) + return tensor @triton.jit @@ -133,7 +138,8 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - X, Y = _namedtuple_create_func(_X.shape, _X.ptr, _X.stride, Y.shape, Y.ptr, Y.stride) + X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride) + Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride) Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1] Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1] x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0) From 383b7ecb2b31eb22b98f877e2f151dd009619b94 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 00:59:27 -0500 Subject: [PATCH 8/9] Update --- python/triton/compiler/code_generator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index d808039e6691..9b184795743e 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -1263,9 +1263,7 @@ def visit_Call(self, node): if fn in self.builtin_namespace.values(): args = map(_unwrap_if_constexpr, args) ret = fn(*args, **kws) - if _is_namedtuple(type(ret)): - return _apply_to_tuple_values(ret, lambda x: x) - return ret + return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret def visit_Constant(self, node): return constexpr(node.value) From 14b90582ac1f1b3f84c1918ed590945f12281edb Mon Sep 17 00:00:00 2001 From: Jokeren Date: Thu, 27 Feb 2025 01:11:54 -0500 Subject: [PATCH 9/9] Update --- python/test/unit/language/test_tuple.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/test/unit/language/test_tuple.py b/python/test/unit/language/test_tuple.py index 97d652903234..0c08413c380a 100644 --- a/python/test/unit/language/test_tuple.py +++ b/python/test/unit/language/test_tuple.py @@ -53,7 +53,7 @@ def _tuple_assign(XPtrs, YPtrs, values): @pytest.mark.interpreter def test_assign(device): - vals = (2.0, 3.0) + vals = (2., 3.) x = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(2)]) y = tuple([torch.zeros((1, ), dtype=torch.float32, device=device) for _ in range(3)]) _tuple_assign[(1, )](x, y, vals)