Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single declarations for hoisted variables in recursive hoist transformation (fixes #143) #144

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions loki/transform/transform_hoist_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@
"""
from loki.expression import FindVariables, SubstituteExpressions
from loki.ir import CallStatement, Allocation, Deallocation
from loki.tools.util import is_iterable, as_tuple
from loki.tools.util import is_iterable, as_tuple, CaseInsensitiveDict
from loki.visitors import Transformer, FindNodes
from loki.transform.transformation import Transformation
from loki.transform.transform_utilities import single_variable_declaration
import loki.expression.symbols as sym


Expand Down Expand Up @@ -152,7 +153,7 @@ def transform_subroutine(self, routine, **kwargs):

calls = [call for call in FindNodes(CallStatement).visit(routine.body) if call.name
not in self.disable]
call_map = {str(call.name): call for call in calls}
call_map = CaseInsensitiveDict((str(call.name), call) for call in calls)

for child in successors:
arg_map = dict(call_map[child.routine.name].arg_iter())
Expand Down Expand Up @@ -249,8 +250,15 @@ def transform_subroutine(self, routine, **kwargs):
for var in item.trafo_data[self._key]["to_hoist"]:
self.driver_variable_declaration(routine, var)
else:
routine.arguments += as_tuple([var.clone(type=var.type.clone(intent='inout'),
scope=routine) for var in item.trafo_data[self._key]["to_hoist"]])
# We build the list of tempararies that are hoisted to the calling routine
# Because this requires adding an intent, we need to make sure they are not
# declared together with non-hoisted variables
hoisted_temporaries = tuple(
var.clone(type=var.type.clone(intent='inout'), scope=routine)
for var in item.trafo_data[self._key]['to_hoist']
)
single_variable_declaration(routine, variables=[var.clone(dimensions=None) for var in hoisted_temporaries])
routine.arguments += hoisted_temporaries

call_map = {}
calls = [_ for _ in FindNodes(CallStatement).visit(routine.body) if _.name not in self.disable]
Expand Down
116 changes: 112 additions & 4 deletions tests/test_transform_hoist_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
import numpy as np

from conftest import available_frontends, jit_compile_lib, clean_test
from loki import FindNodes, Scheduler, Builder
from loki import ir, is_iterable
from loki.transform import (HoistVariablesAnalysis, HoistVariablesTransformation,
HoistTemporaryArraysAnalysis, HoistTemporaryArraysTransformationAllocatable)
from loki import FindNodes, Scheduler, Builder, SchedulerConfig, OMNI
from loki import ir, is_iterable, gettempdir, normalize_range_indexing
from loki.transform import (
HoistVariablesAnalysis, HoistVariablesTransformation,
HoistTemporaryArraysAnalysis, HoistTemporaryArraysTransformationAllocatable
)


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -323,3 +325,109 @@ def test_hoist_allocatable(here, frontend, config):

check_arguments(scheduler=scheduler, subroutine_arguments=subroutine_arguments, call_arguments=call_arguments)
compile_and_test(scheduler=scheduler, here=here, a=(5, 10, 100), frontend=frontend, test_name="allocatable")


@pytest.mark.parametrize('frontend', available_frontends())
def test_hoist_mixed_variable_declarations(frontend, config):

fcode_driver = """
subroutine driver(NLON, NZ, NB, FIELD1, FIELD2)
use kernel_mod, only: kernel
implicit none
INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300)
INTEGER, INTENT(IN) :: NLON, NZ, NB
integer :: b
real(kind=jprb), intent(inout) :: field1(nlon, nb)
real(kind=jprb), intent(inout) :: field2(nlon, nz, nb)
do b=1,nb
call KERNEL(1, nlon, nlon, nz, 2, field1(:,b), field2(:,:,b))
end do
end subroutine driver
""".strip()
fcode_kernel = """
module kernel_mod
implicit none
contains
subroutine kernel(start, end, klon, klev, nclv, field1, field2)
use iso_c_binding, only : c_size_t
implicit none
integer, parameter :: jprb = selected_real_kind(13,300)
integer, intent(in) :: nclv
integer, intent(in) :: start, end, klon, klev
real(kind=jprb), intent(inout) :: field1(klon)
real(kind=jprb), intent(inout) :: field2(klon,klev)
real(kind=jprb) :: tmp1(klon)
real(kind=jprb) :: tmp2(klon, klev), tmp3(nclv)
real(kind=jprb) :: tmp4(2), tmp5(klon, nclv, klev)
integer :: jk, jl, jm

do jk=1,klev
tmp1(jl) = 0.0_jprb
do jl=start,end
tmp2(jl, jk) = field2(jl, jk)
tmp1(jl) = field2(jl, jk)
end do
field1(jl) = tmp1(jl)
end do

do jm=1,nclv
tmp3(jm) = 0._jprb
do jl=start,end
tmp5(jl, jm, :) = field1(jl)
enddo
enddo
end subroutine kernel
end module kernel_mod
""".strip()

basedir = gettempdir()/'test_hoist_mixed_variable_declarations'
basedir.mkdir(exist_ok=True)
(basedir/'driver.F90').write_text(fcode_driver)
(basedir/'kernel_mod.F90').write_text(fcode_kernel)

config = {
'default': {
'mode': 'idem',
'role': 'kernel',
'expand': True,
'strict': True
},
'routine': [{
'name': 'driver',
'role': 'driver',
}]
}

scheduler = Scheduler(paths=[basedir], config=SchedulerConfig.from_dict(config), frontend=frontend)

if frontend == OMNI:
for item in scheduler.items:
normalize_range_indexing(item.routine)

scheduler.process(transformation=HoistTemporaryArraysAnalysis(dim_vars=('klev',)), reverse=True)
scheduler.process(transformation=HoistTemporaryArraysTransformationAllocatable())

driver_variables = (
'jprb', 'nlon', 'nz', 'nb', 'b',
'field1(nlon, nb)', 'field2(nlon, nz, nb)',
'kernel_tmp2(:,:)', 'kernel_tmp5(:,:,:)'
)
kernel_arguments = (
'start', 'end', 'klon', 'klev', 'nclv',
'field1(klon)', 'field2(klon,klev)', 'tmp2(klon,klev)', 'tmp5(klon,nclv,klev)'
)

# Check hoisting and declaration in driver
assert scheduler['#driver'].routine.variables == driver_variables
assert scheduler['kernel_mod#kernel'].routine.arguments == kernel_arguments

# Check updated call signature
calls = FindNodes(ir.CallStatement).visit(scheduler['#driver'].routine.body)
assert len(calls) == 1
assert calls[0].arguments == (
'1', 'nlon', 'nlon', 'nz', '2', 'field1(:,b)', 'field2(:,:,b)',
'kernel_tmp2', 'kernel_tmp5'
)

# Check that fgen works
assert scheduler['kernel_mod#kernel'].source.to_fortran()