Skip to content

Commit

Permalink
fix for arg shape regarding module imports
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Sep 8, 2023
1 parent 0db337c commit 78dd3c9
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 1 deletion.
18 changes: 18 additions & 0 deletions transformations/tests/sources/projArgShape/driver_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
MODULE DRIVER_MOD
USE KERNEL_A_MOD, ONLY: KERNEL_A
USE KERNEL_B_MOD, ONLY: KERNEL_B
IMPLICIT NONE
CONTAINS
SUBROUTINE driver(nlon, nlev, a, b, c)
INTEGER, INTENT(IN) :: nlon, nlev ! Dimension sizes
INTEGER, PARAMETER :: n = 5
REAL, INTENT(INOUT) :: a(nlon)
REAL, INTENT(INOUT) :: b(nlon,nlev)
REAL, INTENT(INOUT) :: c(nlon,n)

call kernel_a(a, b, c)

call kernel_b(b, c)
END SUBROUTINE driver

END MODULE DRIVER_MOD
12 changes: 12 additions & 0 deletions transformations/tests/sources/projArgShape/kernel_a1_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
MODULE KERNEL_A1_MOD
IMPLICIT NONE
CONTAINS

SUBROUTINE kernel_a1(b, c)
! Second-level kernel call
REAL, INTENT(INOUT) :: b(:,:)
REAL, INTENT(INOUT) :: c(:,:)

END SUBROUTINE kernel_a1

END MODULE KERNEL_A1_MOD
14 changes: 14 additions & 0 deletions transformations/tests/sources/projArgShape/kernel_a_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
MODULE KERNEL_A_MOD
USE KERNEL_A1_MOD, ONLY: KERNEL_A1
IMPLICIT NONE
CONTAINS
SUBROUTINE kernel_a(a, b, c)
USE VAR_MODULE_MOD, only: n
REAL, INTENT(INOUT) :: a(:)
REAL, INTENT(INOUT) :: b(:,:)
REAL, INTENT(INOUT) :: c(:,:)

CALL kernel_a1(b, c)
END SUBROUTINE kernel_a

END MODULE KERNEL_A_MOD
13 changes: 13 additions & 0 deletions transformations/tests/sources/projArgShape/kernel_b_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
MODULE KERNEL_B_MOD
USE VAR_MODULE_MOD, only: n
IMPLICIT NONE
CONTAINS

SUBROUTINE kernel_b(b, c)
! USE VAR_MODULE_MOD, only: n
! Second-level kernel call
REAL, INTENT(INOUT) :: b(:,:)
REAL, INTENT(INOUT) :: c(:,:)

END SUBROUTINE kernel_b
END MODULE KERNEL_B_MOD
3 changes: 3 additions & 0 deletions transformations/tests/sources/projArgShape/var_module_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MODULE VAR_MODULE_MOD
INTEGER, PARAMETER :: n = 5
END MODULE VAR_MODULE_MOD
69 changes: 68 additions & 1 deletion transformations/tests/test_argument_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
import pytest

from conftest import available_frontends
from loki import CallStatement, FindNodes, OMNI, Subroutine
from loki import CallStatement, FindNodes, OMNI, Subroutine, Scheduler, Sourcefile, flatten
from transformations import ArgumentArrayShapeAnalysis, ExplicitArgumentArrayShapeTransformation

@pytest.fixture(scope='module', name='here')
def fixture_here():
return Path(__file__).parent

@pytest.mark.parametrize('frontend', available_frontends())
def test_argument_shape_simple(frontend):
Expand Down Expand Up @@ -353,3 +357,66 @@ def test_argument_shape_transformation(frontend):
assert (v, v) in FindNodes(CallStatement).visit(kernel_a2.body)[0].kwarguments
assert (v, v) in FindNodes(CallStatement).visit(driver.body)[0].kwarguments
assert (v, v) in FindNodes(CallStatement).visit(driver.body)[1].kwarguments


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI module type definitions not available')]))
def test_argument_shape_transformation_import(frontend, here):
"""
Test that ensures that explicit argument shapes are indeed inserted
in a multi-layered call tree.
"""

config = {
'default': {
'mode': 'idem',
'role': 'kernel',
'expand': True,
'strict': True
},
'routine': [{
'name': 'driver',
'role': 'driver'
}]
}

header = [here/'sources/projArgShape/var_module_mod.F90']
frontend_type = frontend
headers = [Sourcefile.from_file(filename=h, frontend=frontend_type) for h in header]
definitions = flatten(h.modules for h in headers)
scheduler = Scheduler(paths=here/'sources/projArgShape', config=config, frontend=frontend,
definitions=definitions)
scheduler.process(transformation=ArgumentArrayShapeAnalysis())
scheduler.process(transformation=ExplicitArgumentArrayShapeTransformation(), reverse=True)

item_map = {item.name: item for item in scheduler.items}
driver = item_map['driver_mod#driver'].source['driver']
kernel_a = item_map['kernel_a_mod#kernel_a'].source['kernel_a']
kernel_a1 = item_map['kernel_a1_mod#kernel_a1'].source['kernel_a1']
kernel_b = item_map['kernel_b_mod#kernel_b'].source['kernel_b']

# Check that argument shapes have been applied
assert kernel_a.arguments[0].dimensions == ('nlon',)
assert kernel_a.arguments[1].dimensions == ('nlon', 'nlev')
assert kernel_a.arguments[2].dimensions == ('nlon', 'n')
assert 'nlon' in kernel_a.arguments
assert 'nlon' in kernel_a.arguments
assert 'n' not in kernel_a.arguments

assert kernel_b.arguments[0].dimensions == ('nlon', 'nlev')
assert kernel_b.arguments[1].dimensions == ('nlon', 'n')
assert 'nlon' in kernel_b.arguments
assert 'nlon' in kernel_b.arguments
assert 'n' not in kernel_b.arguments

assert kernel_a1.arguments[0].dimensions == ('nlon', 'nlev')
assert kernel_a1.arguments[1].dimensions == ('nlon', 'n')
assert 'nlon' in kernel_a1.arguments
assert 'nlon' in kernel_a1.arguments
assert 'n' in kernel_a1.arguments

# And finally, check that scalar dimension size variables have been added to calls
for v in ('nlon', 'nlev'):
assert (v, v) in FindNodes(CallStatement).visit(driver.body)[0].kwarguments
assert (v, v) in FindNodes(CallStatement).visit(driver.body)[1].kwarguments
for v in ('nlon', 'nlev', 'n'):
assert (v, v) in FindNodes(CallStatement).visit(kernel_a.body)[0].kwarguments
4 changes: 4 additions & 0 deletions transformations/transformations/argument_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def transform_subroutine(self, routine, **kwargs): # pylint: disable=arguments-
continue

callee = call.routine
imported_symbols = callee.imported_symbols
if callee.parent is not None:
imported_symbols += callee.parent.imported_symbols

# Collect all potential dimension variables and filter for scalar integers
dims = set(d for arg in callee.arguments if isinstance(arg, Array) for d in arg.shape)
Expand All @@ -141,6 +144,7 @@ def transform_subroutine(self, routine, **kwargs): # pylint: disable=arguments-
# Add all new dimension arguments to the callee signature
new_args = tuple(d for d in dim_vars if d not in callee.arguments)
new_args = tuple(d for d in new_args if d.type.dtype == BasicType.INTEGER)
new_args = tuple(d for d in new_args if d not in imported_symbols)
new_args = tuple(d.clone(scope=routine, type=d.type.clone(intent='IN')) for d in new_args)
callee.arguments += new_args

Expand Down

0 comments on commit 78dd3c9

Please sign in to comment.