Skip to content

Commit

Permalink
Merge pull request #126 from ecmwf-ifs/naan-debugrules-fix
Browse files Browse the repository at this point in the history
Linter debug rule fixes
  • Loading branch information
reuterbal authored Sep 5, 2023
2 parents 3846c67 + e5d2c6d commit 400ebbb
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 32 deletions.
47 changes: 34 additions & 13 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates,
simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array,
symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten,
FindInlineCalls, Conditional, Transformer, FindExpressions, Comparison
FindInlineCalls, Conditional, FindExpressions, Comparison
)
from loki.lint import GenericRule, RuleType

Expand All @@ -21,6 +21,10 @@ class ArgSizeMismatchRule(GenericRule):

type = RuleType.WARN

config = {
'max_indirections': 2,
}

@staticmethod
def range_to_sum(lower, upper):
"""
Expand Down Expand Up @@ -106,6 +110,8 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
:any:`Subroutine`.
"""

max_indirections = config['max_indirections']

# first resolve associates
resolve_associates(subroutine)

Expand Down Expand Up @@ -183,9 +189,22 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

# if necessary, update dimension names and check
if not stat:
dummy_size = Product(SubstituteExpressions(assign_map).visit(dummy_arg_size))
# we check for a configurable number of indirections for the dummy and arg dimension names
for _ in range(max_indirections):
if stat:
break

# if necessary, update dummy arg dimension names and check
dummy_arg_size = SubstituteExpressions(assign_map).visit(dummy_arg_size)
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if stat:
break

# if necessary, update arg dimension names and check
arg_size = SubstituteExpressions(assign_map).visit(arg_size)
alt_arg_size = SubstituteExpressions(assign_map).visit(alt_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if not stat:
Expand Down Expand Up @@ -254,8 +273,8 @@ def fix_subroutine(cls, subroutine, rule_report, config):
ubound_checks = cls.get_ubound_checks(subroutine)
args = cls.get_assumed_shape_args(subroutine)

new_vars = ()
node_map = {}
var_map = {}

for arg in args:
checks = [c for c in ubound_checks if arg.name in c.arguments]
Expand All @@ -281,16 +300,18 @@ def fix_subroutine(cls, subroutine, rule_report, config):
else:
new_shape += as_tuple(cond.left)

vtype = arg.type.clone(shape=new_shape, scope=subroutine)
new_vars += as_tuple(arg.clone(type=vtype, dimensions=new_shape, scope=subroutine))
vtype = arg.type.clone(shape=new_shape)
var_map.update({arg: arg.clone(type=vtype, dimensions=new_shape)})

#TODO: add 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression'
# to enable case-insensitive search here
new_var_names = [v.name.lower() for v in new_vars]
subroutine.variables = [var for var in subroutine.variables if not var.name.lower() in new_var_names]
# update variable declarations
subroutine.spec = SubstituteExpressions(var_map).visit(subroutine.spec)
for decl in FindNodes(VariableDeclaration).visit(subroutine.spec):
if decl.dimensions:
if not all(sym.shape == decl.dimensions for sym in decl.symbols):
new_decls = as_tuple(VariableDeclaration(as_tuple(sym)) for sym in decl.symbols)
node_map.update({decl: new_decls})

subroutine.body = Transformer(node_map).visit(subroutine.body)
subroutine.variables += new_vars
return node_map

# Create the __all__ property of the module to contain only the rule names
__all__ = tuple(name for name in dir() if name.endswith('Rule') and name != 'GenericRule')
19 changes: 9 additions & 10 deletions lint_rules/lint_rules/ifs_coding_standards_2011.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,15 +472,6 @@ class Fortran90OperatorsRule(GenericRule): # Coding standards 4.15
'<': re.compile(r'(?P<f77>\.lt\.)|(?P<f90><(?!=))', re.I),
}

_op_map = {
'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}

@classmethod
def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
'''Check for the use of Fortran 90 comparison operators.'''
Expand Down Expand Up @@ -513,8 +504,16 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
op_str = op if op != '!=' else '/='
line = [line for line in lines if op_str in strip_inline_comments(line.string)]
if not line:
_op_map = {
'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}
line = [line for line in lines
if op_str in strip_inline_comments(line.string.replace(cls._op_map[op_str], op_str))]
if op_str in strip_inline_comments(line.string.replace(_op_map[op_str], op_str))]

source_string = strip_inline_comments(line[0].string)
matches = cls._op_patterns[op].findall(source_string)
Expand Down
46 changes: 37 additions & 9 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from conftest import run_linter, available_frontends
from loki import Sourcefile, FindInlineCalls
from loki import Sourcefile, FindInlineCalls, FindNodes, VariableDeclaration
from loki.lint import DefaultHandler


Expand Down Expand Up @@ -42,14 +42,16 @@ def test_arg_size_array_slices(rules, frontend):
real, intent(in) :: var6(:,:), var7(:,:)
real, intent(inout) :: var0(klon, nblk), var1(klon, 138, nblk)
real(kind=jphook) :: zhook_handle
integer :: klev, ibl
integer :: klev, ibl, iproma, iend
if(lhook) call dr_hook('driver', 0, zhook_handle)
associate(nlev => klev)
nlev = 137
do ibl = 1, nblk
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:klon, 1:nlev), &
iproma = klon
iend = iproma
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:iend, 1:nlev), &
var3, var4(1:klon, 1:nlev+1), var5(:, 1:nlev+1), &
var6_d=var6, var7_d=var7(:,1:nlev))
enddo
Expand Down Expand Up @@ -84,7 +86,8 @@ def test_arg_size_array_slices(rules, frontend):

messages = []
handler = DefaultHandler(target=messages.append)
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], handlers=[handler], targets=['kernel',])
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], config={'ArgSizeMismatchRule': {'max_indirections': 3}},
handlers=[handler], targets=['kernel',])

assert len(messages) == 3
keyword = 'ArgSizeMismatchRule'
Expand Down Expand Up @@ -113,13 +116,15 @@ def test_arg_size_array_sequence(rules, frontend):
real(kind=jphook) :: zhook_handle
real, dimension(klon, 137) :: var4, var5
real :: var6
integer :: klev, ibl
integer :: klev, ibl, iproma, iend
if(lhook) call dr_hook('driver', 0, zhook_handle)
klev = 137
do ibl = 1, nblk
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1, 1), var3(1), &
iproma = klon
iend = iproma
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1:iend, 1), var3(1), &
var4(1, 1), var5, var6, 1, .true.)
enddo
Expand Down Expand Up @@ -173,12 +178,13 @@ def test_dynamic_ubound_checks(rules, frontend):
"""

fcode = """
subroutine kernel(klon, klev, nblk, var0, var1, var2)
subroutine kernel(klon, klev, nblk, var0, var1, var2, var3, var4)
use abort_mod
implicit none
integer, intent(in) :: klon, klev, nblk
real, dimension(:,:,:), intent(inout) :: var0, var1
real, dimension(:,:,:), intent(inout) :: var2
real, intent(inout) :: var3(:,:), var4(:,:,:)
if(ubound(var0, 1) < klon)then
call abort('kernel: first dimension of var0 too short')
Expand All @@ -198,7 +204,11 @@ def test_dynamic_ubound_checks(rules, frontend):
call abort('kernel: dimensions of var2 too short')
endif
call some_other_kernel(klon, klen, nblk, var0, var1, var2)
if(ubound(var4, 1) < klon .and. ubound(var4, 2) < klev .and. ubound(var4, 3) < nblk)then
call abort('kernel: dimensions of var4 too short')
endif
call some_other_kernel(klon, klen, nblk, var0, var1, var2, var3, var4)
end subroutine kernel
""".strip()
Expand All @@ -211,11 +221,12 @@ def test_dynamic_ubound_checks(rules, frontend):
_ = run_linter(kernel, [rules.DynamicUboundCheckRule], config={'fix': True}, handlers=[handler])

# check rule violations
assert len(messages) == 2
assert len(messages) == 3
assert all('DynamicUboundCheckRule' in msg for msg in messages)

assert 'var0' in messages[0]
assert 'var2' in messages[1]
assert 'var4' in messages[2]

# check fixed subroutine
routine = kernel['kernel']
Expand All @@ -228,5 +239,22 @@ def test_dynamic_ubound_checks(rules, frontend):

assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var4'].shape, shape))

arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4']
assert [arg.name.lower() for arg in routine.arguments] == arg_names

# check that variable declarations have not been duplicated
declarations = FindNodes(VariableDeclaration).visit(routine.spec)
symbols = [s.name.lower() for decl in declarations for s in decl.symbols]
assert len(symbols) == 8
assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4'}

# check number of declarations and symbols per declarations
assert len(declarations) == 5
assert len(declarations[0].symbols) == 3
for decl in declarations[1:4]:
assert len(decl.symbols) == 1
assert len(declarations[4].symbols) == 2

os.remove(kernel.path)

0 comments on commit 400ebbb

Please sign in to comment.