diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 78a24e01f..8395d086e 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -704,6 +704,16 @@ def __repr__(self): return f'Conditional:: {self.name}' return 'Conditional::' + @property + def else_bodies(self): + """ + Return all nested node tuples in the ``ELSEIF``/``ELSE`` part + of the conditional chain. + """ + if self.has_elseif: + return (self.else_body[0].body,) + self.else_body[0].else_bodies + return (self.else_body,) if self.else_body else () + @dataclass_strict(frozen=True) class _PragmaRegionBase(): diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/tests/test_ir_nodes.py index cb17ea943..0fc8e66e6 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/tests/test_ir_nodes.py @@ -151,6 +151,54 @@ def test_conditional(scope, one, i, n, a_i): # TODO: Test inline, name, has_elseif +def test_multi_conditional(scope, one, i, n, a_i): + """ + Test nested chains of constructors of :any:`Conditional` to form + multi-conditional. + """ + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(1)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)), + else_body=ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0)) + ) + for idx in range(2, 4): + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(idx)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))), + else_body=multicond, has_elseif=True + ) + + # Check that we can recover all bodies from a nested else-if construct + else_bodies = multicond.else_bodies + assert len(else_bodies) == 3 + assert all(isinstance(b, tuple) for b in else_bodies) + assert isinstance(else_bodies[0][0], ir.Assignment) + assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0' + assert isinstance(else_bodies[1][0], ir.Assignment) + assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0' + assert isinstance(else_bodies[2][0], ir.Assignment) + assert else_bodies[2][0].lhs == 'a(i)' and else_bodies[2][0].rhs == '42.0' + + # Not try without the final else + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(1)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)), + ) + for idx in range(2, 4): + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(idx)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))), + else_body=multicond, has_elseif=True + ) + else_bodies = multicond.else_bodies + assert len(else_bodies) == 2 + assert all(isinstance(b, tuple) for b in else_bodies) + assert isinstance(else_bodies[0][0], ir.Assignment) + assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0' + assert isinstance(else_bodies[1][0], ir.Assignment) + assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0' + + def test_section(scope, one, i, n, a_n, a_i): """ Test constructors and behaviour of :any:`Section` nodes. diff --git a/loki/transformations/single_column/tests/test_scc_vector.py b/loki/transformations/single_column/tests/test_scc_vector.py index 66238d29b..96378d869 100644 --- a/loki/transformations/single_column/tests/test_scc_vector.py +++ b/loki/transformations/single_column/tests/test_scc_vector.py @@ -573,20 +573,30 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block """ fcode_kernel = """ - subroutine some_kernel(start, end, nlon, flag0, flag1) + subroutine some_kernel(start, end, nlon, flag0, flag1, flag2) implicit none integer, intent(in) :: nlon, start, end - logical, intent(in) :: flag0, flag1 + logical, intent(in) :: flag0, flag1, flag2 real, dimension(nlon) :: work integer :: jl - if(flag0)then + if (flag0) then call some_other_kernel() - elseif(flag1)then + elseif (flag1) then + do jl=start,end + work(jl) = 1. + enddo + elseif (flag2) then do jl=start,end work(jl) = 1. + work(jl) = 2. + enddo + else + do jl=start,end + work(jl) = 41. + work(jl) = 42. enddo endif @@ -595,7 +605,7 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block routine = Subroutine.from_source(fcode_kernel, frontend=frontend) - # check whether pipeline can be applied and works as expected + # check whether pipeline can be applied and works as expected scc_pipeline = SCCVectorPipeline( horizontal=horizontal, vertical=vertical, block_dim=blocking, directive='openacc', trim_vector_sections=trim_vector_sections @@ -611,3 +621,11 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block assert isinstance(conditional.else_body[0].body[0], ir.Comment) assert isinstance(conditional.else_body[0].body[1], ir.Loop) assert conditional.else_body[0].body[1].pragma[0].content.lower() == 'loop vector' + + # Check that all else-bodies have been wrapped + else_bodies = conditional.else_bodies + assert(len(else_bodies) == 3) + for body in else_bodies: + assert isinstance(body[0], ir.Comment) + assert isinstance(body[1], ir.Loop) + assert body[1].pragma[0].content.lower() == 'loop vector' diff --git a/loki/transformations/single_column/vector.py b/loki/transformations/single_column/vector.py index 2764c019f..0578e2b11 100644 --- a/loki/transformations/single_column/vector.py +++ b/loki/transformations/single_column/vector.py @@ -137,15 +137,11 @@ def extract_vector_sections(cls, section, horizontal): subsec_body = cls.extract_vector_sections(separator.body, horizontal) if subsec_body: subsections += subsec_body - # we need to prevent that the whole 'else_body' is wrapped in a section, - # as 'Conditional's rely on the fact that the first element of the 'else_body' + # we need to prevent that all (possibly nested) 'else_bodies' are completely wrapped as a section, + # as 'Conditional's rely on the fact that the first element of each 'else_body' # (if 'has_elseif') is a Conditional itself - if separator.has_elseif and separator.else_body: - subsec_else = cls.extract_vector_sections(separator.else_body[0].body, horizontal) - else: - subsec_else = cls.extract_vector_sections(separator.else_body, horizontal) - if subsec_else: - subsections += subsec_else + for ebody in separator.else_bodies: + subsections += cls.extract_vector_sections(ebody, horizontal) if isinstance(separator, ir.MultiConditional): for body in separator.bodies: