Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Feb 20, 2025
1 parent d3957bd commit b52a07c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class Transformation(enum.Flag):
PROPAGATE_NESTED_LET = enum.auto()
#: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)`
INLINE_TRIVIAL_LET = enum.auto()
#: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b)
#: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)`
FLATTEN_AS_FIELDOP_ARGS = enum.auto()
#: `let(a, b[1])(a)` -> `b[1]`
INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto()
Expand Down Expand Up @@ -529,6 +529,7 @@ def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Opt
def transform_flatten_as_fieldop_args(
self, node: itir.FunCall, **kwargs
) -> Optional[itir.Node]:
# `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)`
if not cpm.is_applied_as_fieldop(node):
return None

Expand All @@ -545,13 +546,15 @@ def transform_flatten_as_fieldop_args(

new_body = stencil.expr
domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop
orig_args_map: dict[itir.Sym, itir.Expr] = {}
remapped_args: dict[
itir.Sym, itir.Expr
] = {} # contains the arguments that are remapped, e.g. `{a, b}`
new_params: list[itir.Sym] = []
new_args: list[itir.Expr] = []
for param, arg in zip(stencil.params, node.args, strict=True):
if isinstance(arg.type, ts.TupleType):
ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type)
orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg
ref_to_remapped_arg = im.ref(f"__ct_flat_remapped_{len(remapped_args)}", arg.type)
remapped_args[im.sym(ref_to_remapped_arg.id, arg.type)] = arg
new_params_inner, lift_params = [], []
for i, type_ in enumerate(param.type.element_type.types):
new_param = im.sym(
Expand All @@ -565,8 +568,10 @@ def transform_flatten_as_fieldop_args(
)
)
new_params_inner.append(new_param)
new_args.append(im.tuple_get(i, ref_to_orig_arg))
new_args.append(im.tuple_get(i, ref_to_remapped_arg))

# an iterator that substitutes the original (tuple) iterator, e.g. `t`. Built
# from the new parameters which are the elements of `t`.
param_substitute = im.lift(
im.lambda_(*lift_params)(
im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params])
Expand All @@ -576,6 +581,7 @@ def transform_flatten_as_fieldop_args(
new_body = im.let(param.id, param_substitute)(new_body)
# note: the lift is trivial so inlining it is not an issue with respect to tree size
new_body = inline_lambda(new_body, force_inline_lift_args=True)

new_params.extend(new_params_inner)
else:
new_params.append(param)
Expand All @@ -589,4 +595,4 @@ def transform_flatten_as_fieldop_args(
new_body = self.visit(new_body, **kwargs)
new_stencil = restore_scan(im.lambda_(*new_params)(new_body))

return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args))
return im.let(*remapped_args.items())(im.as_fieldop(new_stencil, domain)(*new_args))

0 comments on commit b52a07c

Please sign in to comment.