From 9d38decc5297e561d7718a6edbf87beeccece919 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Tue, 27 Aug 2024 18:52:26 +0000 Subject: [PATCH] Expressions: Add SusbtituteStringExpressions, which uses `parse_expr` --- loki/expression/expr_visitors.py | 33 +++++++++++++- loki/expression/tests/test_expr_visitors.py | 49 ++++++++++++++++++++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index 47a9bc180..f2ed2a089 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -24,8 +24,9 @@ ) __all__ = [ - 'FindExpressions', 'FindVariables', 'FindTypedSymbols', 'FindInlineCalls', - 'FindLiterals', 'SubstituteExpressions', 'ExpressionFinder', 'AttachScopes' + 'FindExpressions', 'FindVariables', 'FindTypedSymbols', + 'FindInlineCalls', 'FindLiterals', 'SubstituteExpressions', + 'SubstituteStringExpressions', 'ExpressionFinder', 'AttachScopes' ] @@ -259,6 +260,34 @@ def visit_Import(self, o, **kwargs): visit_ProcedureDeclaration = visit_Import +class SubstituteStringExpressions(SubstituteExpressions): + """ + Extension to :any:`SubstituteExpressions` that allows symbol + substitution of pure string mappings via :any:`parse_expr`. + + In addition to the input string mapping this requires a :any:`Scope` + (eg. :any:`Subroutine` or :any:`Module`) to parse the respective strings. + + Parameters + ---------- + expr_map : dict + String-to-string mapping of expressions to apply to the expression tree. + scope : :any:`Scope` + The scope to which symbol names inside the expression belong + invalidate_source : bool, optional + By default the :attr:`source` property of nodes is discarded + when rebuilding the node, setting this to `False` allows to + retain that information + """ + def __init__(self, str_map, scope, invalidate_source=True): + from loki.expression.parser import parse_expr # pylint: disable=import-outside-toplevel,cyclic-import + expr_map = { + parse_expr(k, scope=scope): parse_expr(v, scope=scope) + for k, v in str_map.items() + } + super().__init__(expr_map=expr_map, invalidate_source=invalidate_source) + + class AttachScopes(Visitor): """ Scoping visitor that traverses the control flow tree and uses diff --git a/loki/expression/tests/test_expr_visitors.py b/loki/expression/tests/test_expr_visitors.py index e134577fd..29d634b35 100644 --- a/loki/expression/tests/test_expr_visitors.py +++ b/loki/expression/tests/test_expr_visitors.py @@ -10,7 +10,7 @@ from loki import Sourcefile, Subroutine from loki.expression import ( symbols as sym, parse_expr, FindVariables, FindTypedSymbols, - SubstituteExpressions + SubstituteExpressions, SubstituteStringExpressions ) from loki.frontend import available_frontends from loki.ir import nodes as ir, FindNodes @@ -173,3 +173,50 @@ def test_substitute_expressions(frontend): calls = FindNodes(ir.CallStatement).visit(routine.body) assert calls[0].arguments == ('n - 1', 'd', 'c(1:2)') assert calls[0].kwarguments == (('a2', 'a'),) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_substitute_string_expressions(frontend): + """ Test symbol replacement with symbol string mappping. """ + + fcode = """ +subroutine test_routine(n, a, b) + integer, intent(in) :: n + real(kind=8), intent(inout) :: a, b(n) + real(kind=8) :: c(n) + integer :: i + + associate(d => a) + do i=1, n + c(i) = b(i) + a + end do + + call another_routine(n, a, c(:), a2=d) + + end associate +end subroutine test_routine +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + calls = FindNodes(ir.CallStatement).visit(routine.body) + assoc = FindNodes(ir.Associate).visit(routine.body)[0] + assert calls[0].arguments == ('n', 'a', 'c(:)') + assert calls[0].kwarguments == (('a2', 'd'),) + + expr_map = { + 'n': 'n - 1', + 'b(i)': 'b(i+1)', + 'c(:)': 'c(1:2)', + 'a': 'd', + 'd': 'a', + } + # Note that we need to use the associate block here, as it defines 'd' + routine.body = SubstituteStringExpressions(expr_map, scope=assoc).visit(routine.body) + + loops = FindNodes(ir.Loop).visit(routine.body) + assert loops[0].bounds == '1:n-1' + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert assigns[0].lhs == 'c(i)' and assigns[0].rhs == 'b(i+1) + d' + calls = FindNodes(ir.CallStatement).visit(routine.body) + assert calls[0].arguments == ('n - 1', 'd', 'c(1:2)') + assert calls[0].kwarguments == (('a2', 'a'),)