diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index f014eea04..c82b5df79 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -14,13 +14,13 @@ from loki.batch import Transformation from loki.expression import Array, RangeIndex, LokiIdentityMapper -from loki.ir import nodes as ir, FindNodes, Transformer +from loki.ir import nodes as ir, FindNodes, Transformer, NestedTransformer from loki.tools import as_tuple, dict_override from loki.types import BasicType __all__ = [ - 'SanitiseTransformation', 'resolve_associates', + 'SanitiseTransformation', 'resolve_associates', 'merge_associates', 'ResolveAssociatesTransformer', 'transform_sequence_association', 'transform_sequence_association_append_map' ] @@ -194,6 +194,86 @@ def visit_CallStatement(self, o, **kwargs): return o._rebuild(arguments=arguments, kwarguments=kwarguments) +def merge_associates(routine, max_parents=None): + """ + Moves associate mappings in :any:`Associate` within a + :any:`Subroutine` to the outermost parent scope. + + Please see :any:`MergeAssociatesTransformer` for mode details. + + Note + ---- + This method can be combined with :any:`resolve_associates` to + create a more unified look-and-feel for nested ASSOCIATE blocks. + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine for which to resolve all associate blocks. + max_parents : int, optional + Maximum number of parent symbols for valid selector to have. + """ + transformer = MergeAssociatesTransformer(max_parents=max_parents) + routine.body = transformer.visit(routine.body) + + +class MergeAssociatesTransformer(NestedTransformer): + """ + :any:`NestedTransformer` that moves associate mappings in + :any:`Associate` to parent nodes. + + If a selector expression depends on a symbol from a parent + :any:`Associate` exists, it does not get moved. + + Additionally, a maximum parent-depth can be specified for the + selector to prevent overly long symbols to be moved up. + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine for which to resolve all associate blocks. + max_parents : int, optional + Maximum number of parent symbols for valid selector to have. + """ + + def __init__(self, max_parents=None, **kwargs): + self.max_parents = max_parents + super().__init__(**kwargs) + + def visit_Associate(self, o, **kwargs): + body = self.visit(o.body, **kwargs) + + if not o.parent or not isinstance(o.parent, ir.Associate): + return o._rebuild(body=body) + + # Find all associate mapping that can be moved up + to_move = tuple( + (expr, name) for expr, name in o.associations + if not expr.scope == o.parent + ) + + if self.max_parents: + # Optionally filter by depth of symbol-parentage + to_move = tuple( + (expr, name) for expr, name in to_move + if not len(expr.parents) > self.max_parents + ) + + # Move up to parent ... + parent_assoc = tuple( + (expr, name) for expr, name in to_move + if (expr, name) not in o.parent.associations + ) + o.parent._update(associations=o.parent.associations + parent_assoc) + + # ... and remove from this associate node + new_assocs = tuple( + (expr, name) for expr, name in o.associations + if (expr, name) not in to_move + ) + return o._rebuild(body=body, associations=new_assocs) + + def check_if_scalar_syntax(arg, dummy): """ Check if an array argument, arg, diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index 3c08a32c5..b2ecb1255 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -14,8 +14,9 @@ from loki.ir import nodes as ir, FindNodes from loki.transformations.sanitise import ( - resolve_associates, transform_sequence_association, - ResolveAssociatesTransformer, SanitiseTransformation + resolve_associates, merge_associates, + transform_sequence_association, ResolveAssociatesTransformer, + SanitiseTransformation ) @@ -302,6 +303,59 @@ def test_transform_associates_start_depth(frontend): assert assigns[2].rhs == 'b%d(i) + 1.' +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'OMNI does not handle missing type definitions')] +)) +def test_merge_associates_nested(frontend): + """ + Test association merging for nested mappings. + """ + fcode = """ +subroutine merge_associates_simple(base) + use some_module, only: some_type + implicit none + + type(some_type), intent(inout) :: base + integer :: i + real :: local_var + + associate(a => base%a) + associate(b => base%other%symbol, c => a%more) + associate(d => base%other%symbol%really%deep, & + & a => base%a) + do i=1, 5 + call another_routine(i, n=b(c)%n) + + d(i) = 42.0 + end do + end associate + end associate + end associate +end subroutine merge_associates_simple +""" + + routine = Subroutine.from_source(fcode, frontend=frontend) + + assocs = FindNodes(ir.Associate).visit(routine.body) + assert len(assocs) == 3 + assert len(assocs[0].associations) == 1 + assert len(assocs[1].associations) == 2 + assert len(assocs[2].associations) == 2 + + # Move associate mapping around + merge_associates(routine, max_parents=2) + + assocs = FindNodes(ir.Associate).visit(routine.body) + assert len(assocs) == 3 + assert len(assocs[0].associations) == 2 + assert assocs[0].associations[0] == ('base%a', 'a') + assert assocs[0].associations[1] == ('base%other%symbol', 'b') + assert len(assocs[1].associations) == 1 + assert assocs[1].associations[0] == ('a%more', 'c') + assert len(assocs[2].associations) == 1 + assert assocs[2].associations[0] == ('base%other%symbol%really%deep', 'd') + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path): fcode = """