Skip to content

Commit

Permalink
Expression: Add test for FindVariables that covers corner cases
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Aug 28, 2024
1 parent 90e9fe8 commit 22a62a1
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions loki/expression/tests/test_expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,68 @@ def test_expression_finder_retrieval_function(frontend, tmp_path):

# Make sure the first expression finder still works
assert find_ts.visit(source['other_routine'].body) == expected_ts


@pytest.mark.parametrize('frontend', available_frontends())
def test_find_variables(frontend, tmp_path):
""" Test that :any:`FindVariables` finds all symbol uses. """

fcode_external = """
module external_mod
implicit none
contains
subroutine rick(dave, never)
real(kind=8), intent(inout) :: dave, never
end subroutine rick
end module external_mod
"""
fcode = """
module test_mod
use external_mod, only: rick
type my_type
real(kind=8) :: never
real(kind=8), pointer :: give_you(:)
end type my_type
contains
subroutine test_routine(n, a, b, gonna)
integer, intent(in) :: n
real(kind=8), intent(inout) :: a, b(n)
type(my_type), intent(inout) :: gonna
integer :: i
associate(will=>gonna%never, up=>n)
do i=1, n
b(i) = b(i) + a
end do
call rick(will, never=gonna%give_you(up))
end associate
end subroutine test_routine
end module test_mod
"""
_ = Sourcefile.from_source(fcode_external, frontend=frontend, xmods=[tmp_path])
source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['test_routine']

# Test unique=True|False using the spec
expected = ['n', 'a', 'gonna', 'i', 'b(n)']
spec_vars = FindVariables(unique=True).visit(routine.spec)
assert len(spec_vars) == 5
assert all(v in spec_vars for v in expected)

spec_vars = FindVariables(unique=False).visit(routine.spec)
assert len(spec_vars) == 6
assert all(v in spec_vars for v in expected)
assert len([v for v in spec_vars if v == 'n']) == 2 # two occurences of 'n'

# Test retrieval with associates and keyword arg calls
expected = [
'will', 'gonna', 'gonna%never', 'up', 'n', 'i', 'b(i)', 'a',
'rick', 'gonna%give_you(up)'
]
body_vars = FindVariables(unique=True).visit(routine.body)
assert len(body_vars) == 10
assert all(v in body_vars for v in expected)

0 comments on commit 22a62a1

Please sign in to comment.