Skip to content

Commit

Permalink
Expressions: Add SusbtituteStringExpressions, which uses parse_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Aug 28, 2024
1 parent cabc90c commit 9d38dec
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
33 changes: 31 additions & 2 deletions loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
)

__all__ = [
'FindExpressions', 'FindVariables', 'FindTypedSymbols', 'FindInlineCalls',
'FindLiterals', 'SubstituteExpressions', 'ExpressionFinder', 'AttachScopes'
'FindExpressions', 'FindVariables', 'FindTypedSymbols',
'FindInlineCalls', 'FindLiterals', 'SubstituteExpressions',
'SubstituteStringExpressions', 'ExpressionFinder', 'AttachScopes'
]


Expand Down Expand Up @@ -259,6 +260,34 @@ def visit_Import(self, o, **kwargs):
visit_ProcedureDeclaration = visit_Import


class SubstituteStringExpressions(SubstituteExpressions):
"""
Extension to :any:`SubstituteExpressions` that allows symbol
substitution of pure string mappings via :any:`parse_expr`.
In addition to the input string mapping this requires a :any:`Scope`
(eg. :any:`Subroutine` or :any:`Module`) to parse the respective strings.
Parameters
----------
expr_map : dict
String-to-string mapping of expressions to apply to the expression tree.
scope : :any:`Scope`
The scope to which symbol names inside the expression belong
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""
def __init__(self, str_map, scope, invalidate_source=True):
from loki.expression.parser import parse_expr # pylint: disable=import-outside-toplevel,cyclic-import
expr_map = {
parse_expr(k, scope=scope): parse_expr(v, scope=scope)
for k, v in str_map.items()
}
super().__init__(expr_map=expr_map, invalidate_source=invalidate_source)


class AttachScopes(Visitor):
"""
Scoping visitor that traverses the control flow tree and uses
Expand Down
49 changes: 48 additions & 1 deletion loki/expression/tests/test_expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loki import Sourcefile, Subroutine
from loki.expression import (
symbols as sym, parse_expr, FindVariables, FindTypedSymbols,
SubstituteExpressions
SubstituteExpressions, SubstituteStringExpressions
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
Expand Down Expand Up @@ -173,3 +173,50 @@ def test_substitute_expressions(frontend):
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert calls[0].arguments == ('n - 1', 'd', 'c(1:2)')
assert calls[0].kwarguments == (('a2', 'a'),)


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_string_expressions(frontend):
""" Test symbol replacement with symbol string mappping. """

fcode = """
subroutine test_routine(n, a, b)
integer, intent(in) :: n
real(kind=8), intent(inout) :: a, b(n)
real(kind=8) :: c(n)
integer :: i
associate(d => a)
do i=1, n
c(i) = b(i) + a
end do
call another_routine(n, a, c(:), a2=d)
end associate
end subroutine test_routine
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

calls = FindNodes(ir.CallStatement).visit(routine.body)
assoc = FindNodes(ir.Associate).visit(routine.body)[0]
assert calls[0].arguments == ('n', 'a', 'c(:)')
assert calls[0].kwarguments == (('a2', 'd'),)

expr_map = {
'n': 'n - 1',
'b(i)': 'b(i+1)',
'c(:)': 'c(1:2)',
'a': 'd',
'd': 'a',
}
# Note that we need to use the associate block here, as it defines 'd'
routine.body = SubstituteStringExpressions(expr_map, scope=assoc).visit(routine.body)

loops = FindNodes(ir.Loop).visit(routine.body)
assert loops[0].bounds == '1:n-1'
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert assigns[0].lhs == 'c(i)' and assigns[0].rhs == 'b(i+1) + d'
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert calls[0].arguments == ('n - 1', 'd', 'c(1:2)')
assert calls[0].kwarguments == (('a2', 'a'),)

0 comments on commit 9d38dec

Please sign in to comment.