diff --git a/loki/transform/dependency_transform.py b/loki/transform/dependency_transform.py index 72d2d1e18..d70c2d4bb 100644 --- a/loki/transform/dependency_transform.py +++ b/loki/transform/dependency_transform.py @@ -206,10 +206,10 @@ def rename_imports(self, source, imports, **kwargs): calls = () for routine in source.subroutines: calls += as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(routine.body)) - calls += as_tuple(str(c).upper() for c in FindInlineCalls().visit(routine.body)) + calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(routine.body)) else: calls = as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(source.body)) - calls += as_tuple(str(c).upper() for c in FindInlineCalls().visit(source.body)) + calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(source.body)) # Import statements still point to unmodified call names calls = [call.replace(f'{self.suffix.upper()}', '') for call in calls] diff --git a/tests/test_transform_dependency.py b/tests/test_transform_dependency.py index 0a7132431..f2ccca5b5 100644 --- a/tests/test_transform_dependency.py +++ b/tests/test_transform_dependency.py @@ -86,8 +86,6 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, te kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - # Because the renaming is intended to be applied to the routines as well as the enclosing module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets=('kernel', 'some_const')) @@ -160,8 +158,6 @@ def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_sc kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - # Because the renaming is intended to be applied to the routines as well as the enclosing module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver.apply(transformation, role='driver', targets=('kernel', 'some_const')) @@ -285,8 +281,6 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir, kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - # Because the renaming is intended to also wrap the kernel in a module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets='kernel') @@ -363,8 +357,6 @@ def test_dependency_transformation_replace_interface(frontend, use_scheduler, te kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - # Because the renaming is intended to also wrap the kernel in a module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets='kernel') @@ -435,8 +427,6 @@ def test_dependency_transformation_inline_call(frontend): # Apply injection transformation via C-style includes by giving `include_path` transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - # Because the renaming is intended to also wrap the kernel in a module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets='kernel') @@ -505,8 +495,6 @@ def test_dependency_transformation_inline_call_result_var(frontend): # Apply injection transformation via C-style includes by giving `include_path` transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - # Because the renaming is intended to also wrap the kernel in a module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets='kernel') @@ -542,8 +530,8 @@ def test_dependency_transformation_inline_call_result_var(frontend): def test_dependency_transformation_contained_member(frontend, use_scheduler, tempdir, config): """ The scheduler currently does not recognize or allow processing contained member routines as part - of the scheduler graph traversal. This test ensures that even with the transformation class functionality - to recurse into contained members enabled, the dependency injection is not applied. + of the scheduler graph traversal. This test ensures that the transformation class + does not recurse into contained members. """ kernel_fcode = """ @@ -582,7 +570,6 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem END SUBROUTINE driver """.strip() - transformation = DependencyTransformation(suffix='_test', module_suffix='_mod') if use_scheduler: @@ -597,8 +584,6 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - # Because the renaming is intended to be applied to the routines as well as the enclosing module, - # we need to invoke the transformation on the full source file and activate recursion to contained nodes kernel.apply(transformation, role='kernel', targets=('set_a', 'get_b')) driver['driver'].apply(transformation, role='driver', targets=('kernel', 'some_const')) @@ -627,3 +612,95 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tem calls = FindInlineCalls(unique=False).visit(kernel['kernel_test'].body) assert len(calls) == 1 assert calls[0].name == 'get_b' + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_dependency_transformation_item_filter(frontend, tempdir, config): + """ + Test that injection is not applied to modules that have no procedures + in the scheduler graph, even if they have other item members. + """ + + driver_fcode = """ +SUBROUTINE driver(a, b, c) + USE HEADER_MOD, ONLY: HEADER_VAR + USE KERNEL_MOD, ONLY: KERNEL + IMPLICIT NONE + + INTEGER, INTENT(INOUT) :: a, b, c + + a = kernel(a) + b = kernel(a) + c = kernel(c) + HEADER_VAR +END SUBROUTINE driver + """.strip() + + kernel_fcode = """ +MODULE kernel_mod +IMPLICIT NONE +CONTAINS +FUNCTION kernel(a) RESULT(ret) + INTEGER, INTENT(IN) :: a + INTEGER :: ret + + ret = 2*a +END FUNCTION kernel +END MODULE kernel_mod + """.strip() + + header_fcode = """ +MODULE header_mod + IMPLICIT NONE + INTEGER :: HEADER_VAR +END MODULE header_mod + """.strip() + + (tempdir/'kernel_mod.F90').write_text(kernel_fcode) + (tempdir/'header_mod.F90').write_text(header_fcode) + (tempdir/'driver.F90').write_text(driver_fcode) + + # Create the scheduler such that it chases imports + config['default']['enable_imports'] = True + scheduler = Scheduler(paths=[tempdir], config=SchedulerConfig.from_dict(config), frontend=frontend) + + # Make sure the header var item exists + assert 'header_mod#header_var' in scheduler.items + + transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') + scheduler.process(transformation, use_file_graph=True) + + kernel = scheduler['kernel_mod#kernel'].source + header = scheduler['header_mod#header_var'].source + driver = scheduler['#driver'].source + + # Check that the kernel mod has been changed + assert len(kernel.subroutines) == 0 + assert len(kernel.all_subroutines) == 1 + assert kernel.all_subroutines[0].name == 'kernel_test' + assert kernel['kernel_test'] == kernel.all_subroutines[0] + assert kernel['kernel_test'].is_function + assert len(kernel.modules) == 1 + assert kernel.modules[0].name == 'kernel_test_mod' + assert kernel['kernel_test_mod'] == kernel.modules[0] + + # Check that the header name has not been changed + assert len(header.modules) == 1 + assert header.modules[0].name == 'header_mod' + assert header.modules[0].variables == ('header_var',) + + # Check that the driver name has not changed + assert len(driver.modules) == 0 + assert len(driver.subroutines) == 1 + assert driver.subroutines[0].name == 'driver' + + # Check that calls and imports have been diverted to the re-generated routine + calls = tuple(FindInlineCalls().visit(driver['driver'].body)) + assert len(calls) == 2 + calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body)) + assert len(calls) == 3 + assert all(call.name == 'kernel_test' for call in calls) + imports = FindNodes(Import).visit(driver['driver'].spec) + imports = driver['driver'].import_map + assert len(imports) == 2 + assert 'header_var' in imports and imports['header_var'].module.lower() == 'header_mod' + assert 'kernel_test' in imports and imports['kernel_test'].module.lower() == 'kernel_test_mod'