Skip to content

Commit

Permalink
Transformations: Small test cleanup for test_sanitise.
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 5, 2024
1 parent b9b0cb6 commit 474a768
Showing 1 changed file with 32 additions and 35 deletions.
67 changes: 32 additions & 35 deletions loki/transformations/tests/test_sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
BasicType, FindNodes, Subroutine, Module, fgen
)
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, Assignment, Associate, CallStatement,
Conditional
)
from loki.ir import nodes as ir, FindNodes

from loki.transformations.sanitise import (
resolve_associates, transform_sequence_association,
Expand Down Expand Up @@ -43,18 +40,18 @@ def test_transform_associates_simple(frontend):
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(Associate).visit(routine.body)) == 1
assert len(FindNodes(Assignment).visit(routine.body)) == 1
assign = FindNodes(Assignment).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'a' and 'some_obj' not in assign.rhs
assert assign.rhs.type.dtype == BasicType.DEFERRED

# Now apply the association resolver
resolve_associates(routine)

assert len(FindNodes(Associate).visit(routine.body)) == 0
assert len(FindNodes(Assignment).visit(routine.body)) == 1
assign = FindNodes(Assignment).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'some_obj%a'
assert assign.rhs.parent == 'some_obj'
assert assign.rhs.type.dtype == BasicType.DEFERRED
Expand Down Expand Up @@ -87,18 +84,18 @@ def test_transform_associates_nested(frontend):
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(Associate).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(routine.body)) == 1
assign = FindNodes(Assignment).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 3
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.lhs == 'rick' and assign.rhs == 'a'
assert assign.rhs.type.dtype == BasicType.DEFERRED

# Now apply the association resolver
resolve_associates(routine)

assert len(FindNodes(Associate).visit(routine.body)) == 0
assert len(FindNodes(Assignment).visit(routine.body)) == 1
assign = FindNodes(Assignment).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'some_obj%never%gonna%give%you%up'


Expand Down Expand Up @@ -129,18 +126,18 @@ def test_transform_associates_array_call(frontend):

routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(Associate).visit(routine.body)) == 1
assert len(FindNodes(CallStatement).visit(routine.body)) == 1
call = FindNodes(CallStatement).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
call = FindNodes(ir.CallStatement).visit(routine.body)[0]
assert call.kwarguments[0][1] == 'some_array(i)%n'
assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED

# Now apply the association resolver
resolve_associates(routine)

assert len(FindNodes(Associate).visit(routine.body)) == 0
assert len(FindNodes(CallStatement).visit(routine.body)) == 1
call = FindNodes(CallStatement).visit(routine.body)[0]
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
call = FindNodes(ir.CallStatement).visit(routine.body)[0]
assert call.kwarguments[0][1] == 'some_obj%some_array(i)%n'
assert call.kwarguments[0][1].scope == routine
assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED
Expand Down Expand Up @@ -182,20 +179,20 @@ def test_transform_associates_nested_conditional(frontend):
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(Conditional).visit(routine.body)) == 2
assert len(FindNodes(Associate).visit(routine.body)) == 1
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assign = FindNodes(Assignment).visit(routine.body)[1]
assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assign = FindNodes(ir.Assignment).visit(routine.body)[1]
assert assign.rhs == 'a' and 'some_obj' not in assign.rhs
assert assign.rhs.type.dtype == BasicType.DEFERRED

# Now apply the association resolver
resolve_associates(routine)

assert len(FindNodes(Conditional).visit(routine.body)) == 2
assert len(FindNodes(Associate).visit(routine.body)) == 0
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assign = FindNodes(Assignment).visit(routine.body)[1]
assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assign = FindNodes(ir.Assignment).visit(routine.body)[1]
assert assign.rhs == 'some_obj%a'
assert assign.rhs.parent == 'some_obj'
assert assign.rhs.type.dtype == BasicType.DEFERRED
Expand Down Expand Up @@ -299,7 +296,7 @@ def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path):

transform_sequence_association(routine)

calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)

assert fgen(calls[0]).lower() == 'call sub_x(array(1:10, 1), 1)'
assert fgen(calls[1]).lower() == 'call sub_x(array(2:10, 2), 2)'
Expand Down Expand Up @@ -349,9 +346,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence,
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['test_transformation_sanitise']

assoc = FindNodes(Associate).visit(routine.body)
assoc = FindNodes(ir.Associate).visit(routine.body)
assert len(assoc) == 1
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[0] == 'a(1)'

Expand All @@ -361,9 +358,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence,
)
trafo.apply(routine)

assoc = FindNodes(Associate).visit(routine.body)
assoc = FindNodes(ir.Associate).visit(routine.body)
assert len(assoc) == 0 if resolve_associate else 1

calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[0] == 'a(1:3)' if resolve_sequence else 'a(1)'

0 comments on commit 474a768

Please sign in to comment.