From b52a07c4288fe9b1dd6ecca23b050e8f7eeceec1 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 20 Feb 2025 16:14:01 +0100 Subject: [PATCH] Cleanup --- .../next/iterator/transforms/collapse_tuple.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 97c8f1ca02..923f6d1302 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -170,7 +170,7 @@ class Transformation(enum.Flag): PROPAGATE_NESTED_LET = enum.auto() #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() - #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> as_fieldop(λ(a, b) → ·a+·b)(a, b) + #: `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` FLATTEN_AS_FIELDOP_ARGS = enum.auto() #: `let(a, b[1])(a)` -> `b[1]` INLINE_TRIVIAL_TUPLE_LET_VAR = enum.auto() @@ -529,6 +529,7 @@ def transform_inline_trivial_tuple_let_var(self, node: ir.Node, **kwargs) -> Opt def transform_flatten_as_fieldop_args( self, node: itir.FunCall, **kwargs ) -> Optional[itir.Node]: + # `as_fieldop(λ(t) → ·t[0]+·t[1])({a, b})` -> `as_fieldop(λ(a, b) → ·a+·b)(a, b)` if not cpm.is_applied_as_fieldop(node): return None @@ -545,13 +546,15 @@ def transform_flatten_as_fieldop_args( new_body = stencil.expr domain = node.fun.args[1] if len(node.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop - orig_args_map: dict[itir.Sym, itir.Expr] = {} + remapped_args: dict[ + itir.Sym, itir.Expr + ] = {} # contains the arguments that are remapped, e.g. `{a, b}` new_params: list[itir.Sym] = [] new_args: list[itir.Expr] = [] for param, arg in zip(stencil.params, node.args, strict=True): if isinstance(arg.type, ts.TupleType): - ref_to_orig_arg = im.ref(f"__ct_flat_orig_arg_{len(orig_args_map)}", arg.type) - orig_args_map[im.sym(ref_to_orig_arg.id, arg.type)] = arg + ref_to_remapped_arg = im.ref(f"__ct_flat_remapped_{len(remapped_args)}", arg.type) + remapped_args[im.sym(ref_to_remapped_arg.id, arg.type)] = arg new_params_inner, lift_params = [], [] for i, type_ in enumerate(param.type.element_type.types): new_param = im.sym( @@ -565,8 +568,10 @@ def transform_flatten_as_fieldop_args( ) ) new_params_inner.append(new_param) - new_args.append(im.tuple_get(i, ref_to_orig_arg)) + new_args.append(im.tuple_get(i, ref_to_remapped_arg)) + # an iterator that substitutes the original (tuple) iterator, e.g. `t`. Built + # from the new parameters which are the elements of `t`. param_substitute = im.lift( im.lambda_(*lift_params)( im.make_tuple(*[im.deref(im.ref(p.id, p.type)) for p in lift_params]) @@ -576,6 +581,7 @@ def transform_flatten_as_fieldop_args( new_body = im.let(param.id, param_substitute)(new_body) # note: the lift is trivial so inlining it is not an issue with respect to tree size new_body = inline_lambda(new_body, force_inline_lift_args=True) + new_params.extend(new_params_inner) else: new_params.append(param) @@ -589,4 +595,4 @@ def transform_flatten_as_fieldop_args( new_body = self.visit(new_body, **kwargs) new_stencil = restore_scan(im.lambda_(*new_params)(new_body)) - return im.let(*orig_args_map.items())(im.as_fieldop(new_stencil, domain)(*new_args)) + return im.let(*remapped_args.items())(im.as_fieldop(new_stencil, domain)(*new_args))