diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index be71ffded..97426f677 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -11,10 +11,7 @@ BasicType, FindNodes, Subroutine, Module, fgen ) from loki.frontend import available_frontends, OMNI -from loki.ir import ( - nodes as ir, FindNodes, Assignment, Associate, CallStatement, - Conditional -) +from loki.ir import nodes as ir, FindNodes from loki.transformations.sanitise import ( resolve_associates, transform_sequence_association, @@ -43,18 +40,18 @@ def test_transform_associates_simple(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'a' and 'some_obj' not in assign.rhs assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'some_obj%a' assert assign.rhs.parent == 'some_obj' assert assign.rhs.type.dtype == BasicType.DEFERRED @@ -87,18 +84,18 @@ def test_transform_associates_nested(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 3 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.lhs == 'rick' and assign.rhs == 'a' assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'some_obj%never%gonna%give%you%up' @@ -129,18 +126,18 @@ def test_transform_associates_array_call(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(CallStatement).visit(routine.body)) == 1 - call = FindNodes(CallStatement).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_array(i)%n' assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(CallStatement).visit(routine.body)) == 1 - call = FindNodes(CallStatement).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_obj%some_array(i)%n' assert call.kwarguments[0][1].scope == routine assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED @@ -182,20 +179,20 @@ def test_transform_associates_nested_conditional(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Conditional).visit(routine.body)) == 2 - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(Assignment).visit(routine.body)) == 3 - assign = FindNodes(Assignment).visit(routine.body)[1] + assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2 + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + assign = FindNodes(ir.Assignment).visit(routine.body)[1] assert assign.rhs == 'a' and 'some_obj' not in assign.rhs assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Conditional).visit(routine.body)) == 2 - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 3 - assign = FindNodes(Assignment).visit(routine.body)[1] + assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2 + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + assign = FindNodes(ir.Assignment).visit(routine.body)[1] assert assign.rhs == 'some_obj%a' assert assign.rhs.parent == 'some_obj' assert assign.rhs.type.dtype == BasicType.DEFERRED @@ -299,7 +296,7 @@ def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path): transform_sequence_association(routine) - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert fgen(calls[0]).lower() == 'call sub_x(array(1:10, 1), 1)' assert fgen(calls[1]).lower() == 'call sub_x(array(2:10, 2), 2)' @@ -349,9 +346,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence, module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) routine = module['test_transformation_sanitise'] - assoc = FindNodes(Associate).visit(routine.body) + assoc = FindNodes(ir.Associate).visit(routine.body) assert len(assoc) == 1 - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[0] == 'a(1)' @@ -361,9 +358,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence, ) trafo.apply(routine) - assoc = FindNodes(Associate).visit(routine.body) + assoc = FindNodes(ir.Associate).visit(routine.body) assert len(assoc) == 0 if resolve_associate else 1 - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[0] == 'a(1:3)' if resolve_sequence else 'a(1)'