-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FRONTEND] Support returning a named tuple #6042
Changes from all commits
b7d0842
f9be2cc
34c5fb5
7177ae4
d7ebf96
4941765
964a610
383b7ec
14b9058
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,6 +89,23 @@ def _check_fn_args(node, fn, args): | |
) | ||
|
||
|
||
def _is_namedtuple(val): | ||
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") | ||
|
||
|
||
def _apply_to_tuple_values(value, fn): | ||
if _is_namedtuple(type(value)): | ||
fields = value._fields | ||
elif isinstance(value, language.tuple): | ||
fields = value.type.fields | ||
else: | ||
assert False, f"Unsupported type {type(value)}" | ||
|
||
vals = [fn(v) for v in value] | ||
types = [v.type for v in vals] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While you're at it, it might be worth fixing this to work with tuples containing None (which will decay here and no longer have a type). Locally I have this as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do not think we even support I would defer it to another PR. Seems like there are more problems |
||
return language.tuple(vals, language.tuple_type(types, fields)) | ||
|
||
|
||
def flatten_values_to_ir(values: Iterable[base_value]): | ||
handles = [] | ||
for v in values: | ||
|
@@ -349,9 +366,6 @@ def _is_constexpr_global(self, name): | |
|
||
return False | ||
|
||
def _is_namedtuple(self, val): | ||
return isinstance(val, type) and issubclass(val, tuple) and hasattr(val, "_fields") | ||
|
||
def _define_name_lookup(self): | ||
|
||
def local_lookup(name: str, absent): | ||
|
@@ -370,7 +384,7 @@ def global_lookup(name: str, absent): | |
getattr(val, "__triton_builtin__", False), # | ||
getattr(val, "__module__", "").startswith("triton.language"), # | ||
isinstance(val, language.dtype), # | ||
self._is_namedtuple(val), | ||
_is_namedtuple(val), | ||
self._is_constexpr_global(name), # | ||
# Allow accesses to globals while visiting an ast.arg | ||
# because you should be able to do | ||
|
@@ -451,7 +465,7 @@ def visit_Return(self, node): | |
|
||
def decay(value): | ||
if isinstance(value, language.tuple): | ||
return language.tuple([decay(v) for v in value.values]) | ||
return _apply_to_tuple_values(value, decay) | ||
elif isinstance(value, (language.constexpr, int, float)): | ||
return semantic.to_tensor(value, self.builder) | ||
return value | ||
|
@@ -575,13 +589,8 @@ def assignTarget(self, target, value): | |
def visit_Assign(self, node): | ||
# construct values to assign | ||
def _sanitize_value(value): | ||
if self._is_namedtuple(type(value)): | ||
vals = [_sanitize_value(v) for v in value] | ||
types = [v.type for v in vals] | ||
fields = type(value)._fields | ||
return language.tuple(vals, language.tuple_type(types, fields)) | ||
if isinstance(value, language.tuple): | ||
return language.tuple([_sanitize_value(v) for v in value.values]) | ||
return _apply_to_tuple_values(value, _sanitize_value) | ||
native_nontensor_types = (language.dtype, language.tuple) | ||
value = _unwrap_if_constexpr(value) | ||
if value is not None and \ | ||
|
@@ -1253,7 +1262,8 @@ def visit_Call(self, node): | |
|
||
if fn in self.builtin_namespace.values(): | ||
args = map(_unwrap_if_constexpr, args) | ||
return fn(*args, **kws) | ||
ret = fn(*args, **kws) | ||
return _apply_to_tuple_values(ret, lambda x: x) if _is_namedtuple(type(ret)) else ret | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @peterbell10 Oh, I meant it's being converted here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed we convert every named tuple to a corresponding tl.tuple when created There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, so the |
||
|
||
def visit_Constant(self, node): | ||
return constexpr(node.value) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I understand that we can have a python named tuple instead of a
language.tuple
? That sounds like the real bug here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, python named tuple is supposed to be supported
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That doesn't sound right to me, it should be converted to a triton tuple type by the frontend