Skip to content

Commit

Permalink
Merge pull request #132 from ecmwf-ifs/naml-stmt-func-fixes
Browse files Browse the repository at this point in the history
Statement function parsing on-the-fly
  • Loading branch information
reuterbal authored Sep 5, 2023
2 parents 38e347a + a7b778f commit a987012
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 120 deletions.
57 changes: 50 additions & 7 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from loki.visitors import GenericVisitor, Transformer, FindNodes
from loki.frontend.source import Source
from loki.frontend.preprocessing import sanitize_registry
from loki.frontend.util import read_file, FP, inject_statement_functions, sanitize_ir
from loki.frontend.util import read_file, FP, sanitize_ir
from loki import ir
import loki.expression.symbols as sym
from loki.expression.operations import (
Expand All @@ -34,7 +34,7 @@
ExpressionDimensionsMapper, FindTypedSymbols, SubstituteExpressions, AttachScopesMapper
)
from loki.logging import debug, info, warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup
from loki.pragma_utils import (
attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached
)
Expand Down Expand Up @@ -1753,6 +1753,15 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
body = self.visit(body_ast, **kwargs)
body = sanitize_ir(body, FP, pp_registry=sanitize_registry[FP], pp_info=self.pp_info)

# Workaround for lost StatementFunctions:
# Since FParser has no means to identify StmtFuncs, the last set of them
# can get lumped in with the body, and we simply need to shift them over.
stmt_funcs = tuple(n for n in body.body if isinstance(n, ir.StatementFunction))
if stmt_funcs:
idx = body.body.index(stmt_funcs[-1]) + 1
spec._update(body=spec.body + body.body[:idx])
body._update(body=body.body[idx:])

# Another big hack: fparser allocates all comments before and after the
# spec to the spec. We remove them from the beginning to get the docstring.
comment_map = {}
Expand Down Expand Up @@ -1802,6 +1811,12 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
rescope_symbols=True, source=source, incomplete=False
)

# Once statement functions are in place, we need to update the original declaration symbol
for decl in FindNodes(ir.VariableDeclaration).visit(spec):
if any(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols):
assert all(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols)
decl._update(symbols=tuple(s.clone() for s in decl.symbols))

# Big, but necessary hack:
# For deferred array dimensions on allocatables, we infer the conceptual
# dimension by finding any `allocate(var(<dims>))` statements.
Expand All @@ -1811,9 +1826,6 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
with pragmas_attached(routine, ir.VariableDeclaration):
routine.spec = process_dimension_pragmas(routine.spec)

# Inject statement function definitions
inject_statement_functions(routine)

if isinstance(o, Fortran2003.Subroutine_Body):
# Return the subroutine object along with any clutter before it for interface declarations
return (*pre, routine)
Expand Down Expand Up @@ -2912,8 +2924,39 @@ def visit_Assignment_Stmt(self, o, **kwargs):
ptr = isinstance(o, Fortran2003.Pointer_Assignment_Stmt)
lhs = self.visit(o.items[0], **kwargs)
rhs = self.visit(o.items[2], **kwargs)
return ir.Assignment(lhs=lhs, rhs=rhs, ptr=ptr,
label=kwargs.get('label'), source=kwargs.get('source'))

# Special-case: Identify statement functions using our internal symbol table
symbol_attrs = kwargs['scope'].symbol_attrs
if isinstance(lhs, sym.Array) and lhs.name in symbol_attrs:

def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

if not symbol_attrs[lhs.name].shape and not symbol_attrs[lhs.name].intent:
# If the LHS array access is actually declared as a scalar,
# we are actually dealing with a statement function!
stmt_func = ir.StatementFunction(
variable=lhs.clone(dimensions=None), arguments=lhs.dimensions,
rhs=rhs, return_type=symbol_attrs[lhs.name],
label=kwargs.get('label'), source=kwargs.get('source')
)

# Update the type in the local scope and return stmt func node
symbol_attrs[str(stmt_func.variable)] = _create_stmt_func_type(stmt_func)
return stmt_func

# Return Assignment node if we don't have to deal with the stupid side of Fortran!
return ir.Assignment(
lhs=lhs, rhs=rhs, ptr=ptr, label=kwargs.get('label'), source=kwargs.get('source')
)

visit_Pointer_Assignment_Stmt = visit_Assignment_Stmt

Expand Down
113 changes: 3 additions & 110 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,18 @@
from pathlib import Path
import codecs

from loki.visitors import (
Transformer, NestedTransformer, FindNodes, PatternFinder, SequenceFinder
)
from loki.visitors import NestedTransformer, FindNodes, PatternFinder, SequenceFinder
from loki.ir import (
Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration,
Loop, Intrinsic, Pragma, StatementFunction
Loop, Intrinsic, Pragma
)
from loki.expression import SubstituteExpressions, Scalar, Array, InlineCall, FindVariables, ProcedureSymbol
from loki.types import ProcedureType, SymbolAttributes
from loki.tools import LazyNodeLookup
from loki.frontend.source import Source
from loki.logging import warning

__all__ = [
'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX',
'inline_comments', 'cluster_comments', 'read_file',
'combine_multiline_pragmas', 'inject_statement_functions', 'sanitize_ir'
'combine_multiline_pragmas', 'sanitize_ir'
]


Expand Down Expand Up @@ -185,108 +180,6 @@ def combine_multiline_pragmas(ir):
return NestedTransformer(pragma_mapper, invalidate_source=False).visit(ir)


def inject_statement_functions(routine):
"""
Identify statement function definitions and correct their
representation in the IR
Fparser misinterprets statement function definitions as array
assignments and may put them into the subroutine's body instead of
the spec. This function tries to identify them, correct the type of
the symbol_attrs representing statement functions (as :any:`ProcedureSymbol`)
and store their definition as :any:`StatementFunction`.
Parameters
----------
routine : :any:`Subroutine`
The subroutine object for which statement functions should be
injected
"""
def create_stmt_func(assignment):
arguments = assignment.lhs.dimensions
variable = assignment.lhs.clone(dimensions=None)
return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source)

def create_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=routine,
query=lambda x: [f for f in FindNodes(StatementFunction).visit(x.spec) if f.variable == name][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

# Only locally declared scalar variables are potential candidates
candidates = [str(v).lower() for v in routine.variables if isinstance(v, Scalar)]

# First suspects: Array assignments in the spec
spec_map = {}
for assignment in FindNodes(Assignment).visit(routine.spec):
if isinstance(assignment.lhs, Array) and assignment.lhs.name.lower() in candidates:
stmt_func = create_stmt_func(assignment)
spec_map[assignment] = stmt_func
routine.symbol_attrs[str(stmt_func.variable)] = create_type(stmt_func)

# Other suspects: Array assignments at the beginning of the body
spec_appendix = []
body_map = {}
for node in routine.body.body:
if isinstance(node, (Comment, CommentBlock)):
spec_appendix += [node]
if isinstance(node, Assignment) and isinstance(node.lhs, Array) and node.lhs.name.lower() in candidates:
stmt_func = create_stmt_func(node)
spec_appendix += [stmt_func]
body_map[node] = None
routine.symbol_attrs[str(stmt_func.variable)] = create_type(stmt_func)
else:
break

if spec_map or body_map:
# All statement functions
stmt_funcs = {node.lhs.name.lower() for node in spec_map}
stmt_funcs |= {node.lhs.name.lower() for node in body_map}

# Find any use of the statement functions in the body and replace
# them with function calls
expr_map_spec = {}
for variable in FindVariables().visit(routine.spec):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_spec[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters)
elif not isinstance(variable, ProcedureSymbol):
expr_map_spec[variable] = variable.clone()
expr_map_body = {}
for variable in FindVariables().visit(routine.body):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_body[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters)
elif not isinstance(variable, ProcedureSymbol):
expr_map_body[variable] = variable.clone()

# Make sure we remove comments from the body if we append them to spec
if any(isinstance(node, StatementFunction) for node in spec_appendix):
body_map.update({node: None for node in spec_appendix if isinstance(node, (Comment, CommentBlock))})

# Apply transformer with the built maps
if spec_map:
routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec)
if body_map:
routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body)
if spec_appendix:
routine.spec.append(spec_appendix)
if expr_map_spec:
routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec)
if expr_map_body:
routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body)

# And make sure all symbols have the right type
routine.rescope_symbols()


def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):
"""
Utility function to sanitize internal representation after creating it
Expand Down
8 changes: 5 additions & 3 deletions tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,7 @@ def test_subroutine_rescope_symbols(frontend):
subroutine nested_routine(a, n)
use some_mod, only: ext2
integer, parameter :: jpim = selected_int_kind(4)
integer, intent(inout) :: a
integer, intent(inout) :: a(n)
integer, intent(in) :: n
integer(kind=jpim) :: j
Expand Down Expand Up @@ -1418,7 +1418,7 @@ def test_subroutine_rescope_clone(frontend):
subroutine nested_routine(a, n)
use some_mod, only: ext2
integer, intent(inout) :: a
integer, intent(inout) :: a(n)
integer, intent(in) :: n
integer :: j
Expand Down Expand Up @@ -1492,14 +1492,16 @@ def test_subroutine_stmt_func(here, frontend):
implicit none
integer, intent(in) :: a
integer, intent(out) :: b
integer :: array(a)
integer :: i, j
integer :: plus, minus
plus(i, j) = i + j
minus(i, j) = i - j
integer :: mult
mult(i, j) = i * j
integer :: tmp
mult(i, j) = i * j
array(i) = i
tmp = plus(a, 5)
tmp = minus(tmp, 1)
b = mult(2, tmp)
Expand Down

0 comments on commit a987012

Please sign in to comment.