Skip to content

Commit

Permalink
Merge pull request #140 from ecmwf-ifs/133-expression-finder-retrieva…
Browse files Browse the repository at this point in the history
…l-function

Expression mappers: Remove recurse_to_parent option and recurse by default
  • Loading branch information
reuterbal authored Sep 7, 2023
2 parents 35f8284 + 7a1b195 commit 7094446
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 285 deletions.
1 change: 0 additions & 1 deletion docs/source/visitors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ visitors exist that apply :any:`ExpressionRetriever` to all expression trees.
loki.expression.expr_visitors.FindVariables
loki.expression.expr_visitors.FindInlineCalls
loki.expression.expr_visitors.FindLiterals
loki.expression.expr_visitors.FindExpressionRoot

For example, the following finds all function calls embedded in expressions
(:any:`InlineCall`, as opposed to subroutine calls in :any:`CallStatement`):
Expand Down
143 changes: 86 additions & 57 deletions lint_rules/lint_rules/ifs_coding_standards_2011.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
"""

from pathlib import Path
from collections import defaultdict
import re

from pymbolic.primitives import Expression

from loki import (
Visitor, FindNodes, ExpressionFinder, FindExpressionRoot, ExpressionRetriever,
Visitor, FindNodes, ExpressionFinder, ExpressionRetriever,
flatten, as_tuple, strip_inline_comments, Module, Subroutine, BasicType, ir
)
from loki.lint import GenericRule, RuleType
Expand Down Expand Up @@ -372,15 +373,20 @@ def check_kind_literals(subroutine, types, allowed_type_kinds, rule_report):
'''Helper function that carries out the check for explicit kind specification
on all literals.
'''
# Custom retriever that yields the literal types specified in config and stops
# recursion on loop ranges and array subscripts
# (to avoid warnings about integer constants in these cases)
excl_types = (sym.Array, sym.Range)
retriever = ExpressionRetriever(query=lambda e: isinstance(e, types),
recurse_query=lambda e: not isinstance(e, excl_types))
finder = ExpressionFinder(unique=False, retrieve=retriever.retrieve, with_ir_node=True)

for node, exprs in finder.visit(subroutine.ir):

class FindLiteralsWithKind(ExpressionFinder):
"""
Custom expression finder that that yields all literals of the types
specified in the config and stops recursion on loop ranges and array subscripts
(to avoid warnings about integer constants in these cases)
"""

retriever = ExpressionRetriever(
query=lambda e: isinstance(e, types),
recurse_query=lambda e: not isinstance(e, (sym.Array, sym.Range))
)

for node, exprs in FindLiteralsWithKind(unique=False, with_ir_node=True).visit(subroutine.ir):
for literal in exprs:
if not literal.kind:
rule_report.add(f'{literal} used without explicit KIND', node)
Expand Down Expand Up @@ -472,55 +478,78 @@ class Fortran90OperatorsRule(GenericRule): # Coding standards 4.15
'<': re.compile(r'(?P<f77>\.lt\.)|(?P<f90><(?!=))', re.I),
}

_op_map = {
'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}

class ComparisonRetriever(Visitor):
"""
Bespoke expression retriever that extracts 3-tuples containing
``(node, expression root, comparison)`` for all :any:`Comparison` nodes.
"""

retriever = ExpressionRetriever(lambda e: isinstance(e, sym.Comparison))

def visit_Node(self, o, **kwargs):
"""
Generic visitor method that will call the :any:`ExpressionRetriever`
only on :class:`pymbolic.primitives.Expression` children, collecting
``(node, expression root, comparison)`` tuples for all matches.
"""
retval = ()
for ch in flatten(o.children):
if isinstance(ch, Expression):
comparisons = self.retriever.retrieve(ch)
if comparisons:
retval += ((o, ch, comparisons),)
elif ch is not None:
retval += self.visit(ch, **kwargs)
return retval

def visit_tuple(self, o, **kwargs):
"""
Specialized handling of tuples to concatenate the nested tuples
returned by :meth:`visit_Node`.
"""
retval = ()
for ch in o:
if ch is not None:
retval += self.visit(ch, **kwargs)
return retval

visit_list = visit_tuple

@classmethod
def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
'''Check for the use of Fortran 90 comparison operators.'''
# We extract all `Comparison` expression nodes, grouped by the IR node they are in.
# Then we run through all such pairs and check the symbol used in the source string.
retriever = ExpressionRetriever(lambda e: isinstance(e, sym.Comparison))
finder = ExpressionFinder(unique=False, retrieve=retriever.retrieve, with_ir_node=True)
for node, expr_list in finder.visit(subroutine.ir):
# First, we group all the expressions found in this node by their expression root
# (This is mostly required for Conditionals/MultiConditionals, where the different
# if-elseif-cases or select values are on different source lines)
root_expr_map = defaultdict(list)
for expr in expr_list:
expr_root = FindExpressionRoot(expr).visit(node)[0]
if node.source and node.source.string:
# Include only if we have source string information for this node
root_expr_map[expr_root] += [expr]

# Then we look at the comparison operators for each expression root and match
# them directly in the source string
for expr_root, exprs in root_expr_map.items():
# find source lines for expression root
lstart, lend = node.source.find(str(expr_root))
lines = node.source.clone_lines((lstart, lend))

# For each comparison operator, check if F90 or F77 operators are matched
for op in sorted({op.operator for op in exprs}):

# find source line for operator
op_str = op if op != '!=' else '/='
line = [line for line in lines if op_str in strip_inline_comments(line.string)]
if not line:
_op_map = {
'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}
line = [line for line in lines
if op_str in strip_inline_comments(line.string.replace(_op_map[op_str], op_str))]

source_string = strip_inline_comments(line[0].string)
matches = cls._op_patterns[op].findall(source_string)
for f77, _ in matches:
if f77:
msg = f'Use Fortran 90 comparison operator "{op_str}" instead of "{f77}"'
rule_report.add(msg, node)
# Use the bespoke visitor to retrieve all comparison nodes alongside with their expression root
# and the IR node they belong to
for node, expr_root, expr_list in cls.ComparisonRetriever().visit(subroutine.ir):
# Use the string representation of the expression to find the source line
lstart, lend = node.source.find(str(expr_root))
lines = node.source.clone_lines((lstart, lend))

# For each comparison operator, use the original source code (because the frontends always
# translate them to F90 operators) to check if F90 or F77 operators were used
for op in sorted({op.operator for op in expr_list}):
# find source line for operator
op_str = op if op != '!=' else '/='
line = [line for line in lines if op_str in strip_inline_comments(line.string)]
if not line:
line = [line for line in lines
if op_str in strip_inline_comments(line.string.replace(cls._op_map[op_str], op_str))]

source_string = strip_inline_comments(line[0].string)
matches = cls._op_patterns[op].findall(source_string)
for f77, _ in matches:
if f77:
msg = f'Use Fortran 90 comparison operator "{op_str}" instead of "{f77}"'
rule_report.add(msg, node)

@classmethod
def fix_subroutine(cls, subroutine, rule_report, config):
Expand Down
20 changes: 14 additions & 6 deletions lint_rules/tests/test_ifs_coding_standards_2011.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,12 +724,20 @@ def test_fortran_90_operators(rules, frontend):
keywords = ('Fortran90OperatorsRule', '[4.15]', 'Use Fortran 90 comparison operator')
assert all(all(keyword in msg for keyword in keywords) for msg in messages)

f77_f90_line = (('.ne.', '/=', '7'), ('.eq.', '==', '7'),
('.lt.', '<', '6'), ('.gt.', '>', '6'),
('.le.', '<=', '5'), ('.ge.', '>=', '5'),
('.gt.', '>', '26'), ('.gt.', '>', '32'),
('.eq.', '==', '29'), ('.gt.', '>', '25'),
('.le.', '<=', '23'))
# Check that violations are reported in the right order
f77_f90_line = (
('.le.', '<=', '5'),
('.ge.', '>=', '5'),
('.lt.', '<', '6'),
('.gt.', '>', '6'),
('.ne.', '/=', '7'),
('.eq.', '==', '7'),
('.le.', '<=', '23'),
('.gt.', '>', '25'),
('.gt.', '>', '26'),
('.eq.', '==', '29'),
('.gt.', '>', '32'),
)

for keywords, message in zip(f77_f90_line, messages):
assert all(str(keyword) in message for keyword in keywords)
Loading

0 comments on commit 7094446

Please sign in to comment.