diff --git a/loki/expression/tests/test_expr_visitors.py b/loki/expression/tests/test_expr_visitors.py index 7960193eb..1745647ab 100644 --- a/loki/expression/tests/test_expr_visitors.py +++ b/loki/expression/tests/test_expr_visitors.py @@ -52,3 +52,68 @@ def test_expression_finder_retrieval_function(frontend, tmp_path): # Make sure the first expression finder still works assert find_ts.visit(source['other_routine'].body) == expected_ts + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_find_variables(frontend, tmp_path): + """ Test that :any:`FindVariables` finds all symbol uses. """ + + fcode_external = """ +module external_mod +implicit none +contains +subroutine rick(dave, never) + real(kind=8), intent(inout) :: dave, never +end subroutine rick +end module external_mod + """ + fcode = """ +module test_mod + use external_mod, only: rick + + type my_type + real(kind=8) :: never + real(kind=8), pointer :: give_you(:) + end type my_type + +contains + + subroutine test_routine(n, a, b, gonna) + integer, intent(in) :: n + real(kind=8), intent(inout) :: a, b(n) + type(my_type), intent(inout) :: gonna + integer :: i + + associate(will=>gonna%never, up=>n) + do i=1, n + b(i) = b(i) + a + end do + + call rick(will, never=gonna%give_you(up)) + end associate + end subroutine test_routine +end module test_mod + """ + _ = Sourcefile.from_source(fcode_external, frontend=frontend, xmods=[tmp_path]) + source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = source['test_routine'] + + # Test unique=True|False using the spec + expected = ['n', 'a', 'gonna', 'i', 'b(n)'] + spec_vars = FindVariables(unique=True).visit(routine.spec) + assert len(spec_vars) == 5 + assert all(v in spec_vars for v in expected) + + spec_vars = FindVariables(unique=False).visit(routine.spec) + assert len(spec_vars) == 6 + assert all(v in spec_vars for v in expected) + assert len([v for v in spec_vars if v == 'n']) == 2 # two occurences of 'n' + + # Test retrieval with associates and keyword arg calls + expected = [ + 'will', 'gonna', 'gonna%never', 'up', 'n', 'i', 'b(i)', 'a', + 'rick', 'gonna%give_you(up)' + ] + body_vars = FindVariables(unique=True).visit(routine.body) + assert len(body_vars) == 10 + assert all(v in body_vars for v in expected)