Skip to content

Commit

Permalink
Merge pull request #401 from ecmwf-ifs/nams-inline-skip-elemental-fun…
Browse files Browse the repository at this point in the history
…ctions-array-args

Inline elemental functions: skip calls with args being array (slices)
  • Loading branch information
reuterbal authored Oct 18, 2024
2 parents 5b65e56 + 06a810e commit cc30c9e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 26 deletions.
22 changes: 18 additions & 4 deletions loki/transformations/inline/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from collections import ChainMap

from loki.expression import symbols as sym, ExpressionRetriever
from loki.logging import warning
from loki.expression import symbols as sym, ExpressionRetriever, ExpressionDimensionsMapper
from loki.ir import (
Transformer, FindNodes, FindVariables, Import, StatementFunction,
FindInlineCalls, ExpressionFinder, SubstituteExpressions,
Expand Down Expand Up @@ -87,6 +88,13 @@ def _inline_functions(routine, inline_elementals_only=False, functions=None):
inlined in the next call to this function.
"""

def is_array(expr):
"""
Check whether expr evaluates to an array.
E.g., for arr(:, :) return True, for arr(1, 1) or arr(jl, jk) return False.
"""
return any(d != '1' for d in ExpressionDimensionsMapper()(expr))

class ExpressionRetrieverSkipInlineCallParameters(ExpressionRetriever):
"""
Expression retriever skipping parameters of inline calls.
Expand All @@ -102,6 +110,11 @@ def __init__(self, query, recurse_query=None, inline_elementals_only=False,
def map_inline_call(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return
if not expr.procedure_type is BasicType.DEFERRED and expr.procedure_type.is_elemental:
if any(is_array(val) for val in expr.arg_map.values() if isinstance(val, sym.Array)):
warning(f"Call to elemental function '{expr.routine.name}' with array arguments."
f' There is currently no support to inline those calls!')
return
self.rec(expr.function, *args, **kwargs)
# SKIP parameters/args/kwargs on purpose
# under certain circumstances
Expand Down Expand Up @@ -142,9 +155,10 @@ class FindInlineCallsSkipInlineCallParameters(ExpressionFinder):
for call in calls:
if call.procedure_type is BasicType.DEFERRED or isinstance(call.routine, StatementFunction):
continue
if inline_elementals_only:
if not (call.procedure_type.is_function and call.procedure_type.is_elemental):
continue
if not call.procedure_type.is_function:
continue
if inline_elementals_only and not call.procedure_type.is_elemental:
continue
if functions:
if call.routine not in functions:
continue
Expand Down
92 changes: 70 additions & 22 deletions loki/transformations/inline/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# nor does it submit to any jurisdiction.

import pytest
import numpy as np

from loki import Module, Subroutine
from loki.build import jit_compile_lib, Builder, Obj
Expand Down Expand Up @@ -100,12 +101,8 @@ def test_transform_inline_elemental_functions(tmp_path, builder, frontend):
builder.clean()
(tmp_path/f'{routine.name}.f90').unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_inline_elemental_functions_extended(tmp_path, builder, frontend):
"""
Test correct inlining of elemental functions.
"""
@pytest.fixture(name='multiply_extended_mod', params=available_frontends())
def fixture_multiply_extended_mod(request, tmp_path):
fcode_module = """
module multiply_extended_mod
use iso_fortran_env, only: real64
Expand Down Expand Up @@ -144,8 +141,15 @@ def test_transform_inline_elemental_functions_extended(tmp_path, builder, fronte
end module multiply_extended_mod
"""

frontend = request.param
module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
return module, frontend

def test_transform_inline_elemental_functions_extended_scalar(multiply_extended_mod, builder, tmp_path):
module, frontend = multiply_extended_mod

fcode = """
subroutine transform_inline_elemental_functions_extended(v1, v2, v3)
subroutine transform_inline_elemental_functions_extended_scalar(v1, v2, v3)
use iso_fortran_env, only: real64
use multiply_extended_mod, only: multiply, multiply_single_line, add
real(kind=real64), intent(in) :: v1
Expand All @@ -154,44 +158,88 @@ def test_transform_inline_elemental_functions_extended(tmp_path, builder, fronte
v2 = multiply(v1, 6._real64) + multiply_single_line(v1, 3._real64)
v3 = add(param1, 200._real64) + add(150._real64, 150._real64) + multiply(6._real64, 11._real64)
end subroutine transform_inline_elemental_functions_extended
end subroutine transform_inline_elemental_functions_extended_scalar
"""

# Generate reference code, compile run and verify
module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

routine = Subroutine.from_source(fcode, frontend=frontend, definitions=[module], xmods=[tmp_path])
refname = f'ref_{routine.name}_{frontend}'
reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)

v2, v3 = reference.transform_inline_elemental_functions_extended(11.)
v2, v3 = reference.transform_inline_elemental_functions_extended_scalar(11.)
assert v2 == 99.
assert v3 == 666.

(tmp_path/f'{module.name}.f90').unlink()
(tmp_path/f'{routine.name}.f90').unlink()

# Now inline elemental functions
routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path])
inline_elemental_functions(routine)


# Make sure there are no more inline calls in the routine body
assert not FindInlineCalls().visit(routine.body)

# Verify correct scope of inlined elements
assert all(v.scope is routine for v in FindVariables().visit(routine.body))

# Hack: rename routine to use a different filename in the build
routine.name = f'{routine.name}_'
kernel = jit_compile_lib([routine], path=tmp_path, name=routine.name, builder=builder)

v2, v3 = kernel.transform_inline_elemental_functions_extended_(11.)
kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder)
v2, v3 = kernel.transform_inline_elemental_functions_extended_scalar_(11.)
assert v2 == 99.
assert v3 == 666.

builder.clean()
(tmp_path/f'{routine.name}.f90').unlink()
(tmp_path/f'{module.name}.f90').unlink()

def test_transform_inline_elemental_functions_extended_arr(multiply_extended_mod, builder, tmp_path):
module, frontend = multiply_extended_mod

fcode_arr = """
subroutine transform_inline_elemental_functions_extended_array(v1, v2, v3, len)
use iso_fortran_env, only: real64
use multiply_extended_mod, only: multiply, multiply_single_line, add
integer, intent(in) :: len
real(kind=real64), intent(in) :: v1(len)
real(kind=real64), intent(inout) :: v2(len), v3(len)
real(kind=real64), parameter :: param1 = 100.
integer, parameter :: arr_index = 1
v2 = multiply(v1(:), 6._real64) + multiply_single_line(v1(:), 3._real64)
v3 = add(param1, 200._real64) + add(v1, 150._real64) + multiply(v1(arr_index), v2(1))
end subroutine transform_inline_elemental_functions_extended_array
"""

routine = Subroutine.from_source(fcode_arr, frontend=frontend, definitions=[module], xmods=[tmp_path])
refname = f'ref_{routine.name}_frontend'
reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder)
arr_len = 5
v1 = np.array([1.0, 2.0, 3.0, 5.0, 3.0], dtype=np.float64, order='F')
v2 = np.zeros((arr_len,), dtype=np.float64, order='F')
v3 = np.zeros((arr_len,), dtype=np.float64, order='F')
reference.transform_inline_elemental_functions_extended_array(v1, v2, v3, arr_len)
assert (v2 == np.array([9., 18., 27., 45., 27.], dtype=np.float64, order='F')).all()
assert (v3 == np.array([460., 461., 462., 464., 462.], dtype=np.float64, order='F')).all()

(tmp_path/f'{routine.name}.f90').unlink()

routine = Subroutine.from_source(fcode_arr, definitions=module, frontend=frontend, xmods=[tmp_path])
inline_elemental_functions(routine)
# TODO: Make sure there are no more inline calls in the routine body
# assert not FindInlineCalls().visit(routine.body)
# this is currently not achievable as calls to elemental functions with array arguments
# can't be properly inlined and therefore are skipped
# Verify correct scope of inlined elements
assert all(v.scope is routine for v in FindVariables().visit(routine.body))
# Hack: rename routine to use a different filename in the build
routine.name = f'{routine.name}_'
kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder)
v1 = np.array([1.0, 2.0, 3.0, 5.0, 3.0], dtype=np.float64, order='F')
v2 = np.zeros((arr_len,), dtype=np.float64, order='F')
v3 = np.zeros((arr_len,), dtype=np.float64, order='F')
kernel.transform_inline_elemental_functions_extended_array_(v1, v2, v3, arr_len)
assert (v2 == np.array([9., 18., 27., 45., 27.], dtype=np.float64, order='F')).all()
assert (v3 == np.array([460., 461., 462., 464., 462.], dtype=np.float64, order='F')).all()

builder.clean()
(tmp_path/f'{routine.name}.f90').unlink()
(tmp_path/f'{module.name}.f90').unlink()


@pytest.mark.parametrize('frontend', available_frontends(
Expand Down

0 comments on commit cc30c9e

Please sign in to comment.