Skip to content

Commit

Permalink
DynamicUboundCheckRule: fixer method now tries to preserve declaratio…
Browse files Browse the repository at this point in the history
…ns wherever possible
  • Loading branch information
awnawab committed Aug 29, 2023
1 parent 0711e9d commit e5d2c6d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 27 deletions.
31 changes: 11 additions & 20 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FindNodes, CallStatement, Assignment, Scalar, RangeIndex, resolve_associates,
simplify, Sum, Product, IntLiteral, as_tuple, SubstituteExpressions, Array,
symbolic_op, StringLiteral, is_constant, LogicLiteral, VariableDeclaration, flatten,
FindInlineCalls, Conditional, FindExpressions, Comparison, single_variable_declaration
FindInlineCalls, Conditional, FindExpressions, Comparison
)
from loki.lint import GenericRule, RuleType

Expand Down Expand Up @@ -273,8 +273,8 @@ def fix_subroutine(cls, subroutine, rule_report, config):
ubound_checks = cls.get_ubound_checks(subroutine)
args = cls.get_assumed_shape_args(subroutine)

new_vars = ()
node_map = {}
var_map = {}

for arg in args:
checks = [c for c in ubound_checks if arg.name in c.arguments]
Expand All @@ -300,25 +300,16 @@ def fix_subroutine(cls, subroutine, rule_report, config):
else:
new_shape += as_tuple(cond.left)

vtype = arg.type.clone(shape=new_shape, scope=subroutine)
new_vars += as_tuple(arg.clone(type=vtype, dimensions=new_shape, scope=subroutine))
vtype = arg.type.clone(shape=new_shape)
var_map.update({arg: arg.clone(type=vtype, dimensions=new_shape)})

# simplify variable declarations
single_variable_declaration(subroutine)

#TODO: 'VariableDeclaration.symbols' should be of type 'Variable' rather than 'Expression'
# to enable case-insensitive search here
new_var_names = [v.name.lower() for v in new_vars]

routine = subroutine.clone()
routine.variables = [var for var in routine.variables if not var.name.lower() in new_var_names]
routine.variables += new_vars

old_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(subroutine.spec)
if decl.symbols[0].name.lower() in new_var_names])
new_decls = as_tuple([decl for decl in FindNodes(VariableDeclaration).visit(routine.spec)
if decl.symbols[0].name.lower() in new_var_names])
node_map.update({old_decls: new_decls})
# update variable declarations
subroutine.spec = SubstituteExpressions(var_map).visit(subroutine.spec)
for decl in FindNodes(VariableDeclaration).visit(subroutine.spec):
if decl.dimensions:
if not all(sym.shape == decl.dimensions for sym in decl.symbols):
new_decls = as_tuple(VariableDeclaration(as_tuple(sym)) for sym in decl.symbols)
node_map.update({decl: new_decls})

return node_map

Expand Down
29 changes: 22 additions & 7 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,13 @@ def test_dynamic_ubound_checks(rules, frontend):
"""

fcode = """
subroutine kernel(klon, klev, nblk, var0, var1, var2)
subroutine kernel(klon, klev, nblk, var0, var1, var2, var3, var4)
use abort_mod
implicit none
integer, intent(in) :: klon, klev, nblk
real, dimension(:,:,:), intent(inout) :: var0, var1
real, dimension(:,:,:), intent(inout) :: var2
real, intent(inout) :: var3(:,:), var4(:,:,:)
if(ubound(var0, 1) < klon)then
call abort('kernel: first dimension of var0 too short')
Expand All @@ -203,7 +204,11 @@ def test_dynamic_ubound_checks(rules, frontend):
call abort('kernel: dimensions of var2 too short')
endif
call some_other_kernel(klon, klen, nblk, var0, var1, var2)
if(ubound(var4, 1) < klon .and. ubound(var4, 2) < klev .and. ubound(var4, 3) < nblk)then
call abort('kernel: dimensions of var4 too short')
endif
call some_other_kernel(klon, klen, nblk, var0, var1, var2, var3, var4)
end subroutine kernel
""".strip()
Expand All @@ -216,11 +221,12 @@ def test_dynamic_ubound_checks(rules, frontend):
_ = run_linter(kernel, [rules.DynamicUboundCheckRule], config={'fix': True}, handlers=[handler])

# check rule violations
assert len(messages) == 2
assert len(messages) == 3
assert all('DynamicUboundCheckRule' in msg for msg in messages)

assert 'var0' in messages[0]
assert 'var2' in messages[1]
assert 'var4' in messages[2]

# check fixed subroutine
routine = kernel['kernel']
Expand All @@ -233,13 +239,22 @@ def test_dynamic_ubound_checks(rules, frontend):

assert all(s.name == d for s, d in zip(routine.variable_map['var0'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var2'].shape, shape))
assert all(s.name == d for s, d in zip(routine.variable_map['var4'].shape, shape))

arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2']
arg_names = ['klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4']
assert [arg.name.lower() for arg in routine.arguments] == arg_names

# check that variable declarations have not been duplicated
symbols = [s.name.lower() for decl in FindNodes(VariableDeclaration).visit(routine.spec) for s in decl.symbols]
assert len(symbols) == 6
assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2'}
declarations = FindNodes(VariableDeclaration).visit(routine.spec)
symbols = [s.name.lower() for decl in declarations for s in decl.symbols]
assert len(symbols) == 8
assert set(symbols) == {'klon', 'klev', 'nblk', 'var0', 'var1', 'var2', 'var3', 'var4'}

# check number of declarations and symbols per declarations
assert len(declarations) == 5
assert len(declarations[0].symbols) == 3
for decl in declarations[1:4]:
assert len(decl.symbols) == 1
assert len(declarations[4].symbols) == 2

os.remove(kernel.path)

0 comments on commit e5d2c6d

Please sign in to comment.