Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linter debug rule fixes #126

Merged
merged 5 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates,
simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array,
symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten,
FindInlineCalls, Conditional, Transformer, FindExpressions, Comparison
FindInlineCalls, Conditional, FindExpressions, Comparison
)
from loki.lint import GenericRule, RuleType

Expand All @@ -21,6 +21,10 @@ class ArgSizeMismatchRule(GenericRule):

type = RuleType.WARN

config = {
'max_indirections': 2,
}

@staticmethod
def range_to_sum(lower, upper):
"""
Expand Down Expand Up @@ -106,6 +110,8 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
:any:`Subroutine`.
"""

max_indirections = config['max_indirections']

# first resolve associates
resolve_associates(subroutine)

Expand Down Expand Up @@ -183,9 +189,22 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

# if necessary, update dimension names and check
if not stat:
dummy_size = Product(SubstituteExpressions(assign_map).visit(dummy_arg_size))
# we check for a configurable number of indirections for the dummy and arg dimension names
for _ in range(max_indirections):
if stat:
break

# if necessary, update dummy arg dimension names and check
dummy_arg_size = SubstituteExpressions(assign_map).visit(dummy_arg_size)
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if stat:
break

# if necessary, update arg dimension names and check
arg_size = SubstituteExpressions(assign_map).visit(arg_size)
alt_arg_size = SubstituteExpressions(assign_map).visit(alt_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if not stat:
Expand Down Expand Up @@ -254,8 +273,8 @@ def fix_subroutine(cls, subroutine, rule_report, config):
ubound_checks = cls.get_ubound_checks(subroutine)
args = cls.get_assumed_shape_args(subroutine)

new_vars = ()
node_map = {}
var_map = {}

for arg in args:
checks = [c for c in ubound_checks if arg.name in c.arguments]
Expand All @@ -281,16 +300,18 @@ def fix_subroutine(cls, subroutine, rule_report, config):
else:
new_shape += as_tuple(cond.left)

vtype = arg.type.clone(shape=new_shape, scope=subroutine)
new_vars += as_tuple(arg.clone(type=vtype, dimensions=new_shape, scope=subroutine))
vtype = arg.type.clone(shape=new_shape)
var_map.update({arg: arg.clone(type=vtype, dimensions=new_shape)})

#TODO: add 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression'
# to enable case-insensitive search here
new_var_names = [v.name.lower() for v in new_vars]
subroutine.variables = [var for var in subroutine.variables if not var.name.lower() in new_var_names]
# update variable declarations
subroutine.spec = SubstituteExpressions(var_map).visit(subroutine.spec)
for decl in FindNodes(VariableDeclaration).visit(subroutine.spec):
if decl.dimensions:
if not all(sym.shape == decl.dimensions for sym in decl.symbols):
new_decls = as_tuple(VariableDeclaration(as_tuple(sym)) for sym in decl.symbols)
node_map.update({decl: new_decls})

subroutine.body = Transformer(node_map).visit(subroutine.body)
subroutine.variables += new_vars
return node_map

# Create the __all__ property of the module to contain only the rule names
__all__ = tuple(name for name in dir() if name.endswith('Rule') and name != 'GenericRule')
19 changes: 9 additions & 10 deletions lint_rules/lint_rules/ifs_coding_standards_2011.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,15 +472,6 @@ class Fortran90OperatorsRule(GenericRule): # Coding standards 4.15
'<': re.compile(r'(?P<f77>\.lt\.)|(?P<f90><(?!=))', re.I),
}

_op_map = {
'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}

@classmethod
def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
'''Check for the use of Fortran 90 comparison operators.'''
Expand Down Expand Up @@ -513,8 +504,16 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
op_str = op if op != '!=' else '/='
line = [line for line in lines if op_str in strip_inline_comments(line.string)]
if not line:
_op_map = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: any particular reason for moving the map here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As explained in #118, this ended up here because I misunderstood one of your earlier comments 😅

'==': '.eq.',
'/=': '.ne.',
'>=': '.ge.',
'<=': '.le.',
'>': '.gt.',
'<': '.lt.'
}
line = [line for line in lines
if op_str in strip_inline_comments(line.string.replace(cls._op_map[op_str], op_str))]
if op_str in strip_inline_comments(line.string.replace(_op_map[op_str], op_str))]

source_string = strip_inline_comments(line[0].string)
matches = cls._op_patterns[op].findall(source_string)
Expand Down
46 changes: 37 additions & 9 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest

from conftest import run_linter, available_frontends
from loki import Sourcefile, FindInlineCalls
from loki import Sourcefile, FindInlineCalls, FindNodes, VariableDeclaration
from loki.lint import DefaultHandler


Expand Down Expand Up @@ -42,14 +42,16 @@ def test_arg_size_array_slices(rules, frontend):
real, intent(in) :: var6(:,:), var7(:,:)
real, intent(inout) :: var0(klon, nblk), var1(klon, 138, nblk)
real(kind=jphook) :: zhook_handle
integer :: klev, ibl
integer :: klev, ibl, iproma, iend

if(lhook) call dr_hook('driver', 0, zhook_handle)

associate(nlev => klev)
nlev = 137
do ibl = 1, nblk
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:klon, 1:nlev), &
iproma = klon
iend = iproma
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:iend, 1:nlev), &
var3, var4(1:klon, 1:nlev+1), var5(:, 1:nlev+1), &
var6_d=var6, var7_d=var7(:,1:nlev))
enddo
Expand Down Expand Up @@ -84,7 +86,8 @@ def test_arg_size_array_slices(rules, frontend):

messages = []
handler = DefaultHandler(target=messages.append)
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], handlers=[handler], targets=['kernel',])
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], config={'ArgSizeMismatchRule': {'max_indirections': 3}},
handlers=[handler], targets=['kernel',])

assert len(messages) == 3
keyword = 'ArgSizeMismatchRule'
Expand Down Expand Up @@ -113,13 +116,15 @@ def test_arg_size_array_sequence(rules, frontend):
real(kind=jphook) :: zhook_handle
real, dimension(klon, 137) :: var4, var5
real :: var6
integer :: klev, ibl
integer :: klev, ibl, iproma, iend

if(lhook) call dr_hook('driver', 0, zhook_handle)

klev = 137
do ibl = 1, nblk
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1, 1), var3(1), &
iproma = klon
iend = iproma
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1:iend, 1), var3(1), &
var4(1, 1), var5, var6, 1, .true.)
enddo

Expand Down Expand Up @@ -173,12 +178,13 @@ def test_dynamic_ubound_checks(rules, frontend):
"""

fcode = """
subroutine kernel(klon, klev, nblk, var0, var1, var2)
subroutine kernel(klon, klev, nblk, var0, var1, var2, var3, var4)
use abort_mod
implicit none
integer, intent(in) :: klon, klev, nblk
real, dimension(:,:,:), intent(inout) :: var0, var1
real, dimension(:,:,:), intent(inout) :: var2
real, intent(inout) :: var3(:,:), var4(:,:,:)

if(ubound(var0, 1) < klon)then
call abort('kernel: first dimension of var0 too short')
Expand All @@ -198,7 +204,11 @@ def test_dynamic_ubound_checks(rules, frontend):
call abort('kernel: dimensions of var2 too short')
endif

call some_other_kernel(klon, klen, nblk, var0, var1, var2)
if(ubound(var4, 1) < klon .and. ubound(var4, 2) < klev .and. ubound(var4, 3) < nblk)then
call abort('kernel: dimensions of var4 too short')
endif

call some_other_kernel(klon, klen, nblk, var0, var1, var2, var3, var4)

end subroutine kernel
""".strip()
Expand All @@ -211,11 +221,12 @@ def test_dynamic_ubound_checks(rules, frontend):
_ = run_linter(kernel, [rules.DynamicUboundCheckRule], config={'fix': True}, handlers=[handler])

# check rule violations
assert len(messages) == 2
assert len(messages) == 3
assert all('DynamicUboundCheckRule' in msg for msg in messages)

assert 'var0' in messages[0]
assert 'var2' in messages[1]
assert 'var4' in messages[2]

# check fixed subroutine
routine = kernel['kernel']
Expand All @@ -228,5 +239,22 @@ def test_dynamic_ubound_checks(rules, frontend):

assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var4'].shape, shape))

arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4']
assert [arg.name.lower() for arg in routine.arguments] == arg_names

# check that variable declarations have not been duplicated
declarations = FindNodes(VariableDeclaration).visit(routine.spec)
symbols = [s.name.lower() for decl in declarations for s in decl.symbols]
assert len(symbols) == 8
assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4'}

# check number of declarations and symbols per declarations
assert len(declarations) == 5
assert len(declarations[0].symbols) == 3
for decl in declarations[1:4]:
assert len(decl.symbols) == 1
assert len(declarations[4].symbols) == 2

os.remove(kernel.path)