Skip to content

Commit

Permalink
Transformations: Add start_depth argument to resolve_associates
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 7, 2024
1 parent 4445dd0 commit dd993dd
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 5 deletions.
50 changes: 45 additions & 5 deletions loki/transformations/sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from loki.batch import Transformation
from loki.expression import Array, RangeIndex, LokiIdentityMapper
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.tools import as_tuple
from loki.tools import as_tuple, dict_override
from loki.types import BasicType


Expand Down Expand Up @@ -58,16 +58,22 @@ def transform_subroutine(self, routine, **kwargs):
transform_sequence_association(routine)


def resolve_associates(routine):
def resolve_associates(routine, start_depth=0):
"""
Resolve :any:`Associate` mappings in the body of a given routine.
Optionally, partial resolution of only inner :any:`Associate`
mappings is supported when a ``start_depth`` is specified.
Parameters
----------
routine : :any:`Subroutine`
The subroutine for which to resolve all associate blocks.
start_depth : int, optional
Starting depth for partial resolution of :any:`Associate`
"""
routine.body = ResolveAssociatesTransformer().visit(routine.body)
transformer = ResolveAssociatesTransformer(start_depth=start_depth)
routine.body = transformer.visit(routine.body)

# Ensure that all symbols have the appropriate scope attached.
# This is needed, as the parent of a symbol might have changed,
Expand All @@ -84,6 +90,10 @@ class ResolveAssociateMapper(LokiIdentityMapper):
and replace it with the inverse of the associate mapping.
"""

def __init__(self, start_depth=0, *args, **kwargs):
self.start_depth = start_depth
super().__init__(*args, **kwargs)

def map_scalar(self, expr, *args, **kwargs):
# Skip unscoped expressions
if not hasattr(expr, 'scope'):
Expand All @@ -95,6 +105,13 @@ def map_scalar(self, expr, *args, **kwargs):

scope = expr.scope

# Determine the depth of the symbol-defining associate
depth = len(tuple(
p for p in scope.parents if isinstance(p, ir.Associate)
)) + 1
if depth <= self.start_depth:
return expr

# Recurse on parent first and propagate scope changes
parent = self.rec(expr.parent, *args, **kwargs)
if parent != expr.parent:
Expand Down Expand Up @@ -136,17 +153,40 @@ class ResolveAssociatesTransformer(Transformer):
Importantly, this :any:`Transformer` can also be applied over partial
bodies of :any:`Associate` bodies.
Optionally, partial resolution of only inner :any:`Associate`
mappings is supported when a ``start_depth`` is specified.
Parameters
----------
start_depth : int, optional
Starting depth for partial resolution of :any:`Associate`
"""
# pylint: disable=unused-argument

def __init__(self, start_depth=0, **kwargs):
self.start_depth = start_depth
super().__init__(**kwargs)

def visit_Expression(self, o, **kwargs):
return ResolveAssociateMapper()(o)
return ResolveAssociateMapper(start_depth=self.start_depth)(o)

def visit_Associate(self, o, **kwargs):
"""
Replaces an :any:`Associate` node with its transformed body
"""
return self.visit(o.body, **kwargs)

# Establish traversal depth in kwargs
depth = kwargs.get('depth', 1)

# First head-recurse, so that all associate blocks beneath are resolved
with dict_override(kwargs, {'depth': depth + 1}):
body = self.visit(o.body, **kwargs)

if depth <= self.start_depth:
return o.clone(body=body)

return body

def visit_CallStatement(self, o, **kwargs):
arguments = self.visit(o.arguments, **kwargs)
Expand Down
49 changes: 49 additions & 0 deletions loki/transformations/tests/test_sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,55 @@ def test_transform_associates_partial_body(frontend):
assert assigns[2].rhs == 'some_obj%b(i) + 1.'


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_start_depth(frontend):
"""
Test resolving associated symbols, but only for a part of an
associate's body.
"""
fcode = """
subroutine transform_associates_partial
use some_module, only: some_obj
implicit none
integer :: i
real :: local_var
associate (a=>some_obj%a, b=>some_obj%b)
associate (c=>a%b, d=>b%d)
local_var = a(1)
do i=1, some_obj%n
c(i) = c(i) + 1.
d(i) = d(i) + 1.
end do
end associate
end associate
end subroutine transform_associates_partial
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
loops = FindNodes(ir.Loop).visit(routine.body)
assert len(loops) == 1

# Resolve all expect the outermost associate block
resolve_associates(routine, start_depth=1)

# Check that associated symbols have been resolved in loop body only
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 3
assert assigns[0].lhs == 'local_var'
assert assigns[0].rhs == 'a(1)'
assert assigns[1].lhs == 'a%b(i)'
assert assigns[1].rhs == 'a%b(i) + 1.'
assert assigns[2].lhs == 'b%d(i)'
assert assigns[2].rhs == 'b%d(i) + 1.'


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path):
fcode = """
Expand Down

0 comments on commit dd993dd

Please sign in to comment.