Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] Support returning a named tuple #6042

Merged
merged 9 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion python/test/unit/language/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ class Tensor(NamedTuple):
stride: tuple


@triton.jit
def _namedtuple_create_func0(shape, ptr, stride):
return Tensor(shape=shape, ptr=ptr, stride=stride)


@triton.jit
def _namedtuple_create_func1(shape, ptr, stride):
tensor = Tensor(shape=shape, ptr=ptr, stride=stride)
return tensor


@triton.jit
def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
Expand All @@ -127,7 +138,8 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride)
Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride)
Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
Expand Down
34 changes: 22 additions & 12 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def _check_fn_args(node, fn, args):
)


def _is_namedtuple(val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")


def _apply_to_tuple_values(value, fn):
if _is_namedtuple(type(value)):
fields = value._fields
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand that we can have a python named tuple instead of a language.tuple? That sounds like the real bug here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand that we can have a python named tuple instead of a language.tuple? That sounds like the real bug here.

Yeah, python named tuple is supposed to be supported

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't sound right to me, it should be converted to a triton tuple type by the frontend

elif isinstance(value, language.tuple):
fields = value.type.fields
else:
assert False, f"Unsupported type {type(value)}"

vals = [fn(v) for v in value]
types = [v.type for v in vals]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're at it, it might be worth fixing this to work with tuples containing None (which will decay here and no longer have a type). Locally I have this as [v.type if v is not None else constexpr for v in vals] but I'm not entirely sure this is right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think we even support return None.

I would defer it to another PR. Seems like there are more problems

return language.tuple(vals, language.tuple_type(types, fields))


def flatten_values_to_ir(values: Iterable[base_value]):
handles = []
for v in values:
Expand Down Expand Up @@ -349,9 +366,6 @@ def _is_constexpr_global(self, name):

return False

def _is_namedtuple(self, val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")

def _define_name_lookup(self):

def local_lookup(name: str, absent):
Expand All @@ -370,7 +384,7 @@ def global_lookup(name: str, absent):
getattr(val, "__triton_builtin__", False), #
getattr(val, "__module__", "").startswith("triton.language"), #
isinstance(val, language.dtype), #
self._is_namedtuple(val),
_is_namedtuple(val),
self._is_constexpr_global(name), #
# Allow accesses to globals while visiting an ast.arg
# because you should be able to do
Expand Down Expand Up @@ -451,7 +465,7 @@ def visit_Return(self, node):

def decay(value):
if isinstance(value, language.tuple):
return language.tuple([decay(v) for v in value.values])
return _apply_to_tuple_values(value, decay)
elif isinstance(value, (language.constexpr, int, float)):
return semantic.to_tensor(value, self.builder)
return value
Expand Down Expand Up @@ -575,13 +589,8 @@ def assignTarget(self, target, value):
def visit_Assign(self, node):
# construct values to assign
def _sanitize_value(value):
if self._is_namedtuple(type(value)):
vals = [_sanitize_value(v) for v in value]
types = [v.type for v in vals]
fields = type(value)._fields
return language.tuple(vals, language.tuple_type(types, fields))
if isinstance(value, language.tuple):
return language.tuple([_sanitize_value(v) for v in value.values])
return _apply_to_tuple_values(value, _sanitize_value)
native_nontensor_types = (language.dtype, language.tuple)
value = _unwrap_if_constexpr(value)
if value is not None and \
Expand Down Expand Up @@ -1253,7 +1262,8 @@ def visit_Call(self, node):

if fn in self.builtin_namespace.values():
args = map(_unwrap_if_constexpr, args)
return fn(*args, **kws)
ret = fn(*args, **kws)
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@peterbell10 Oh, I meant it's being converted here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed we convert every named tuple to a corresponding tl.tuple when created

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so the _is_namedtuple path should never actually be hit from visit_Return only from here. Makes sense.


def visit_Constant(self, node):
return constexpr(node.value)
Expand Down
Loading