Skip to content

Commit

Permalink
DependencyTransform: Test header mod import and fix inline call impor…
Browse files Browse the repository at this point in the history
…t renaming
  • Loading branch information
reuterbal committed Sep 12, 2023
1 parent 0ce23a0 commit fdf59cd
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 19 deletions.
4 changes: 2 additions & 2 deletions loki/transform/dependency_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ def rename_imports(self, source, imports, **kwargs):
calls = ()
for routine in source.subroutines:
calls += as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(routine.body))
calls += as_tuple(str(c).upper() for c in FindInlineCalls().visit(routine.body))
calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(routine.body))
else:
calls = as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(source.body))
calls += as_tuple(str(c).upper() for c in FindInlineCalls().visit(source.body))
calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(source.body))

# Import statements still point to unmodified call names
calls = [call.replace(f'{self.suffix.upper()}', '') for call in calls]
Expand Down
111 changes: 94 additions & 17 deletions tests/test_transform_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, te
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend)
driver = Sourcefile.from_source(driver_fcode, frontend=frontend)

# Because the renaming is intended to be applied to the routines as well as the enclosing module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets=('kernel', 'some_const'))

Expand Down Expand Up @@ -160,8 +158,6 @@ def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_sc
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend)
driver = Sourcefile.from_source(driver_fcode, frontend=frontend)

# Because the renaming is intended to be applied to the routines as well as the enclosing module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver.apply(transformation, role='driver', targets=('kernel', 'some_const'))

Expand Down Expand Up @@ -285,8 +281,6 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir,
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend)
driver = Sourcefile.from_source(driver_fcode, frontend=frontend)

# Because the renaming is intended to also wrap the kernel in a module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets='kernel')

Expand Down Expand Up @@ -363,8 +357,6 @@ def test_dependency_transformation_replace_interface(frontend, use_scheduler, te
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend)
driver = Sourcefile.from_source(driver_fcode, frontend=frontend)

# Because the renaming is intended to also wrap the kernel in a module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets='kernel')

Expand Down Expand Up @@ -435,8 +427,6 @@ def test_dependency_transformation_inline_call(frontend):

# Apply injection transformation via C-style includes by giving `include_path`
transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod')
# Because the renaming is intended to also wrap the kernel in a module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets='kernel')

Expand Down Expand Up @@ -505,8 +495,6 @@ def test_dependency_transformation_inline_call_result_var(frontend):

# Apply injection transformation via C-style includes by giving `include_path`
transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod')
# Because the renaming is intended to also wrap the kernel in a module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets='kernel')

Expand Down Expand Up @@ -542,8 +530,8 @@ def test_dependency_transformation_inline_call_result_var(frontend):
def test_dependency_transformation_contained_member(frontend, use_scheduler, tempdir, config):
"""
The scheduler currently does not recognize or allow processing contained member routines as part
of the scheduler graph traversal. This test ensures that even with the transformation class functionality
to recurse into contained members enabled, the dependency injection is not applied.
of the scheduler graph traversal. This test ensures that the transformation class
does not recurse into contained members.
"""

kernel_fcode = """
Expand Down Expand Up @@ -582,7 +570,6 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem
END SUBROUTINE driver
""".strip()


transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')

if use_scheduler:
Expand All @@ -597,8 +584,6 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend)
driver = Sourcefile.from_source(driver_fcode, frontend=frontend)

# Because the renaming is intended to be applied to the routines as well as the enclosing module,
# we need to invoke the transformation on the full source file and activate recursion to contained nodes
kernel.apply(transformation, role='kernel', targets=('set_a', 'get_b'))
driver['driver'].apply(transformation, role='driver', targets=('kernel', 'some_const'))

Expand Down Expand Up @@ -627,3 +612,95 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem
calls = FindInlineCalls(unique=False).visit(kernel['kernel_test'].body)
assert len(calls) == 1
assert calls[0].name == 'get_b'


@pytest.mark.parametrize('frontend', available_frontends())
def test_dependency_transformation_item_filter(frontend, tempdir, config):
"""
Test that injection is not applied to modules that have no procedures
in the scheduler graph, even if they have other item members.
"""

driver_fcode = """
SUBROUTINE driver(a, b, c)
USE HEADER_MOD, ONLY: HEADER_VAR
USE KERNEL_MOD, ONLY: KERNEL
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b, c
a = kernel(a)
b = kernel(a)
c = kernel(c) + HEADER_VAR
END SUBROUTINE driver
""".strip()

kernel_fcode = """
MODULE kernel_mod
IMPLICIT NONE
CONTAINS
FUNCTION kernel(a) RESULT(ret)
INTEGER, INTENT(IN) :: a
INTEGER :: ret
ret = 2*a
END FUNCTION kernel
END MODULE kernel_mod
""".strip()

header_fcode = """
MODULE header_mod
IMPLICIT NONE
INTEGER :: HEADER_VAR
END MODULE header_mod
""".strip()

(tempdir/'kernel_mod.F90').write_text(kernel_fcode)
(tempdir/'header_mod.F90').write_text(header_fcode)
(tempdir/'driver.F90').write_text(driver_fcode)

# Create the scheduler such that it chases imports
config['default']['enable_imports'] = True
scheduler = Scheduler(paths=[tempdir], config=SchedulerConfig.from_dict(config), frontend=frontend)

# Make sure the header var item exists
assert 'header_mod#header_var' in scheduler.items

transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod')
scheduler.process(transformation, use_file_graph=True)

kernel = scheduler['kernel_mod#kernel'].source
header = scheduler['header_mod#header_var'].source
driver = scheduler['#driver'].source

# Check that the kernel mod has been changed
assert len(kernel.subroutines) == 0
assert len(kernel.all_subroutines) == 1
assert kernel.all_subroutines[0].name == 'kernel_test'
assert kernel['kernel_test'] == kernel.all_subroutines[0]
assert kernel['kernel_test'].is_function
assert len(kernel.modules) == 1
assert kernel.modules[0].name == 'kernel_test_mod'
assert kernel['kernel_test_mod'] == kernel.modules[0]

# Check that the header name has not been changed
assert len(header.modules) == 1
assert header.modules[0].name == 'header_mod'
assert header.modules[0].variables == ('header_var',)

# Check that the driver name has not changed
assert len(driver.modules) == 0
assert len(driver.subroutines) == 1
assert driver.subroutines[0].name == 'driver'

# Check that calls and imports have been diverted to the re-generated routine
calls = tuple(FindInlineCalls().visit(driver['driver'].body))
assert len(calls) == 2
calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body))
assert len(calls) == 3
assert all(call.name == 'kernel_test' for call in calls)
imports = FindNodes(Import).visit(driver['driver'].spec)
imports = driver['driver'].import_map
assert len(imports) == 2
assert 'header_var' in imports and imports['header_var'].module.lower() == 'header_mod'
assert 'kernel_test' in imports and imports['kernel_test'].module.lower() == 'kernel_test_mod'

0 comments on commit fdf59cd

Please sign in to comment.