Skip to content

Commit

Permalink
Expression: Add an explicit test for SubstituteExpressions
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Aug 28, 2024
1 parent 22a62a1 commit cabc90c
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions loki/expression/tests/test_expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

import pytest

from loki import Sourcefile
from loki.expression import FindVariables, FindTypedSymbols
from loki import Sourcefile, Subroutine
from loki.expression import (
symbols as sym, parse_expr, FindVariables, FindTypedSymbols,
SubstituteExpressions
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down Expand Up @@ -117,3 +121,55 @@ def test_find_variables(frontend, tmp_path):
body_vars = FindVariables(unique=True).visit(routine.body)
assert len(body_vars) == 10
assert all(v in body_vars for v in expected)


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expressions(frontend):
""" Test symbol replacement with :any:`Expression` symbols. """

fcode = """
subroutine test_routine(n, a, b)
integer, intent(in) :: n
real(kind=8), intent(inout) :: a, b(n)
real(kind=8) :: c(n)
integer :: i
associate(d => a)
do i=1, n
c(i) = b(i) + a
end do
call another_routine(n, a, c(:), a2=d)
end associate
end subroutine test_routine
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

calls = FindNodes(ir.CallStatement).visit(routine.body)
assoc = FindNodes(ir.Associate).visit(routine.body)[0]
assert calls[0].arguments == ('n', 'a', 'c(:)')
assert calls[0].kwarguments == (('a2', 'd'),)

n = routine.variable_map['n']
i = routine.variable_map['i']
a = routine.variable_map['a']
b_i = parse_expr('b(i)', scope=routine)
c_r = parse_expr('c(:)', scope=routine)
d = parse_expr('d', scope=assoc)
expr_map = {
n: sym.Sum((n, sym.Product((-1, sym.Literal(1))))),
b_i: b_i.clone(dimensions=sym.Sum((i, sym.Literal(1)))),
c_r: c_r.clone(dimensions=sym.Range((sym.Literal(1), sym.Literal(2)))),
a: d,
d: a,
}
routine.body = SubstituteExpressions(expr_map).visit(routine.body)

loops = FindNodes(ir.Loop).visit(routine.body)
assert loops[0].bounds == '1:n-1'
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert assigns[0].lhs == 'c(i)' and assigns[0].rhs == 'b(i+1) + d'
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert calls[0].arguments == ('n - 1', 'd', 'c(1:2)')
assert calls[0].kwarguments == (('a2', 'a'),)

0 comments on commit cabc90c

Please sign in to comment.