Skip to content

Commit

Permalink
ArgSizeMismatchRule: can now check for a configurable number of indir…
Browse files Browse the repository at this point in the history
…ections of arg size names
  • Loading branch information
awnawab committed Aug 29, 2023
1 parent 3473ccb commit 0711e9d
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
25 changes: 22 additions & 3 deletions lint_rules/lint_rules/debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class ArgSizeMismatchRule(GenericRule):

type = RuleType.WARN

config = {
'max_indirections': 2,
}

@staticmethod
def range_to_sum(lower, upper):
"""
Expand Down Expand Up @@ -106,6 +110,8 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
:any:`Subroutine`.
"""

max_indirections = config['max_indirections']

# first resolve associates
resolve_associates(subroutine)

Expand Down Expand Up @@ -183,9 +189,22 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs):
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

# if necessary, update dimension names and check
if not stat:
dummy_size = Product(SubstituteExpressions(assign_map).visit(dummy_arg_size))
# we check for a configurable number of indirections for the dummy and arg dimension names
for _ in range(max_indirections):
if stat:
break

# if necessary, update dummy arg dimension names and check
dummy_arg_size = SubstituteExpressions(assign_map).visit(dummy_arg_size)
dummy_size = Product(dummy_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if stat:
break

# if necessary, update arg dimension names and check
arg_size = SubstituteExpressions(assign_map).visit(arg_size)
alt_arg_size = SubstituteExpressions(assign_map).visit(alt_arg_size)
stat = cls.compare_sizes(arg_size, alt_arg_size, dummy_size)

if not stat:
Expand Down
15 changes: 10 additions & 5 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ def test_arg_size_array_slices(rules, frontend):
real, intent(in) :: var6(:,:), var7(:,:)
real, intent(inout) :: var0(klon, nblk), var1(klon, 138, nblk)
real(kind=jphook) :: zhook_handle
integer :: klev, ibl
integer :: klev, ibl, iproma, iend
if(lhook) call dr_hook('driver', 0, zhook_handle)
associate(nlev => klev)
nlev = 137
do ibl = 1, nblk
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:klon, 1:nlev), &
iproma = klon
iend = iproma
call kernel(klon, nlev, var0(:,ibl), var1(:,:,ibl), var2(1:iend, 1:nlev), &
var3, var4(1:klon, 1:nlev+1), var5(:, 1:nlev+1), &
var6_d=var6, var7_d=var7(:,1:nlev))
enddo
Expand Down Expand Up @@ -84,7 +86,8 @@ def test_arg_size_array_slices(rules, frontend):

messages = []
handler = DefaultHandler(target=messages.append)
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], handlers=[handler], targets=['kernel',])
_ = run_linter(driver_source, [rules.ArgSizeMismatchRule], config={'ArgSizeMismatchRule': {'max_indirections': 3}},
handlers=[handler], targets=['kernel',])

assert len(messages) == 3
keyword = 'ArgSizeMismatchRule'
Expand Down Expand Up @@ -113,13 +116,15 @@ def test_arg_size_array_sequence(rules, frontend):
real(kind=jphook) :: zhook_handle
real, dimension(klon, 137) :: var4, var5
real :: var6
integer :: klev, ibl
integer :: klev, ibl, iproma, iend
if(lhook) call dr_hook('driver', 0, zhook_handle)
klev = 137
do ibl = 1, nblk
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1, 1), var3(1), &
iproma = klon
iend = iproma
call kernel(klon, klev, var0(1,ibl), var1(1,1,ibl), var2(1:iend, 1), var3(1), &
var4(1, 1), var5, var6, 1, .true.)
enddo
Expand Down

0 comments on commit 0711e9d

Please sign in to comment.