diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index f66f5fb14..f014eea04 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -15,7 +15,7 @@ from loki.batch import Transformation from loki.expression import Array, RangeIndex, LokiIdentityMapper from loki.ir import nodes as ir, FindNodes, Transformer -from loki.tools import as_tuple +from loki.tools import as_tuple, dict_override from loki.types import BasicType @@ -58,16 +58,22 @@ def transform_subroutine(self, routine, **kwargs): transform_sequence_association(routine) -def resolve_associates(routine): +def resolve_associates(routine, start_depth=0): """ Resolve :any:`Associate` mappings in the body of a given routine. + Optionally, partial resolution of only inner :any:`Associate` + mappings is supported when a ``start_depth`` is specified. + Parameters ---------- routine : :any:`Subroutine` The subroutine for which to resolve all associate blocks. + start_depth : int, optional + Starting depth for partial resolution of :any:`Associate` """ - routine.body = ResolveAssociatesTransformer().visit(routine.body) + transformer = ResolveAssociatesTransformer(start_depth=start_depth) + routine.body = transformer.visit(routine.body) # Ensure that all symbols have the appropriate scope attached. # This is needed, as the parent of a symbol might have changed, @@ -84,6 +90,10 @@ class ResolveAssociateMapper(LokiIdentityMapper): and replace it with the inverse of the associate mapping. """ + def __init__(self, *args, start_depth=0, **kwargs): + self.start_depth = start_depth + super().__init__(*args, **kwargs) + def map_scalar(self, expr, *args, **kwargs): # Skip unscoped expressions if not hasattr(expr, 'scope'): @@ -95,6 +105,13 @@ def map_scalar(self, expr, *args, **kwargs): scope = expr.scope + # Determine the depth of the symbol-defining associate + depth = len(tuple( + p for p in scope.parents if isinstance(p, ir.Associate) + )) + 1 + if depth <= self.start_depth: + return expr + # Recurse on parent first and propagate scope changes parent = self.rec(expr.parent, *args, **kwargs) if parent != expr.parent: @@ -136,17 +153,40 @@ class ResolveAssociatesTransformer(Transformer): Importantly, this :any:`Transformer` can also be applied over partial bodies of :any:`Associate` bodies. + + Optionally, partial resolution of only inner :any:`Associate` + mappings is supported when a ``start_depth`` is specified. + + Parameters + ---------- + start_depth : int, optional + Starting depth for partial resolution of :any:`Associate` """ # pylint: disable=unused-argument + def __init__(self, start_depth=0, **kwargs): + self.start_depth = start_depth + super().__init__(**kwargs) + def visit_Expression(self, o, **kwargs): - return ResolveAssociateMapper()(o) + return ResolveAssociateMapper(start_depth=self.start_depth)(o) def visit_Associate(self, o, **kwargs): """ Replaces an :any:`Associate` node with its transformed body """ - return self.visit(o.body, **kwargs) + + # Establish traversal depth in kwargs + depth = kwargs.get('depth', 1) + + # First head-recurse, so that all associate blocks beneath are resolved + with dict_override(kwargs, {'depth': depth + 1}): + body = self.visit(o.body, **kwargs) + + if depth <= self.start_depth: + return o.clone(body=body) + + return body def visit_CallStatement(self, o, **kwargs): arguments = self.visit(o.arguments, **kwargs) diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index f7881f9da..3c08a32c5 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -253,6 +253,55 @@ def test_transform_associates_partial_body(frontend): assert assigns[2].rhs == 'some_obj%b(i) + 1.' +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'OMNI does not handle missing type definitions')] +)) +def test_transform_associates_start_depth(frontend): + """ + Test resolving associated symbols, but only for a part of an + associate's body. + """ + fcode = """ +subroutine transform_associates_partial + use some_module, only: some_obj + implicit none + + integer :: i + real :: local_var + + associate (a=>some_obj%a, b=>some_obj%b) + associate (c=>a%b, d=>b%d) + local_var = a(1) + + do i=1, some_obj%n + c(i) = c(i) + 1. + d(i) = d(i) + 1. + end do + end associate + end associate +end subroutine transform_associates_partial +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + loops = FindNodes(ir.Loop).visit(routine.body) + assert len(loops) == 1 + + # Resolve all expect the outermost associate block + resolve_associates(routine, start_depth=1) + + # Check that associated symbols have been resolved in loop body only + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 3 + assert assigns[0].lhs == 'local_var' + assert assigns[0].rhs == 'a(1)' + assert assigns[1].lhs == 'a%b(i)' + assert assigns[1].rhs == 'a%b(i) + 1.' + assert assigns[2].lhs == 'b%d(i)' + assert assigns[2].rhs == 'b%d(i) + 1.' + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path): fcode = """