Skip to content

Commit

Permalink
[FRONTEND] Support returning a named tuple (#6042)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Feb 27, 2025
1 parent f52d3fe commit 95bedd5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
14 changes: 13 additions & 1 deletion python/test/unit/language/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,17 @@ class Tensor(NamedTuple):
stride: tuple


@triton.jit
def _namedtuple_create_func0(shape, ptr, stride):
return Tensor(shape=shape, ptr=ptr, stride=stride)


@triton.jit
def _namedtuple_create_func1(shape, ptr, stride):
tensor = Tensor(shape=shape, ptr=ptr, stride=stride)
return tensor


@triton.jit
def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
Expand All @@ -127,7 +138,8 @@ def _namedtuple_mask_func(Tensor, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
def _namedtuple_kernel(closure, _X, Y, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
X = Tensor(shape=_X.shape, ptr=_X.ptr, stride=_X.stride)
X = _namedtuple_create_func0(_X.shape, _X.ptr, _X.stride)
Y = _namedtuple_create_func1(Y.shape, Y.ptr, Y.stride)
Xs = X.ptr + offs_m[:, None] * X.stride[0] + offs_n[None, :] * X.stride[1]
Ys = Y.ptr + offs_m[:, None] * Y.stride[0] + offs_n[None, :] * Y.stride[1]
x = tl.load(Xs, mask=_namedtuple_mask_func(X, BLOCK_M, BLOCK_N), other=0)
Expand Down
34 changes: 22 additions & 12 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def _check_fn_args(node, fn, args):
)


def _is_namedtuple(val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")


def _apply_to_tuple_values(value, fn):
if _is_namedtuple(type(value)):
fields = value._fields
elif isinstance(value, language.tuple):
fields = value.type.fields
else:
assert False, f"Unsupported type {type(value)}"

vals = [fn(v) for v in value]
types = [v.type for v in vals]
return language.tuple(vals, language.tuple_type(types, fields))


def flatten_values_to_ir(values: Iterable[base_value]):
handles = []
for v in values:
Expand Down Expand Up @@ -349,9 +366,6 @@ def _is_constexpr_global(self, name):

return False

def _is_namedtuple(self, val):
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields")

def _define_name_lookup(self):

def local_lookup(name: str, absent):
Expand All @@ -370,7 +384,7 @@ def global_lookup(name: str, absent):
getattr(val, "__triton_builtin__", False), #
getattr(val, "__module__", "").startswith("triton.language"), #
isinstance(val, language.dtype), #
self._is_namedtuple(val),
_is_namedtuple(val),
self._is_constexpr_global(name), #
# Allow accesses to globals while visiting an ast.arg
# because you should be able to do
Expand Down Expand Up @@ -451,7 +465,7 @@ def visit_Return(self, node):

def decay(value):
if isinstance(value, language.tuple):
return language.tuple([decay(v) for v in value.values])
return _apply_to_tuple_values(value, decay)
elif isinstance(value, (language.constexpr, int, float)):
return semantic.to_tensor(value, self.builder)
return value
Expand Down Expand Up @@ -575,13 +589,8 @@ def assignTarget(self, target, value):
def visit_Assign(self, node):
# construct values to assign
def _sanitize_value(value):
if self._is_namedtuple(type(value)):
vals = [_sanitize_value(v) for v in value]
types = [v.type for v in vals]
fields = type(value)._fields
return language.tuple(vals, language.tuple_type(types, fields))
if isinstance(value, language.tuple):
return language.tuple([_sanitize_value(v) for v in value.values])
return _apply_to_tuple_values(value, _sanitize_value)
native_nontensor_types = (language.dtype, language.tuple)
value = _unwrap_if_constexpr(value)
if value is not None and \
Expand Down Expand Up @@ -1253,7 +1262,8 @@ def visit_Call(self, node):

if fn in self.builtin_namespace.values():
args = map(_unwrap_if_constexpr, args)
return fn(*args, **kws)
ret = fn(*args, **kws)
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret

def visit_Constant(self, node):
return constexpr(node.value)
Expand Down

0 comments on commit 95bedd5

Please sign in to comment.