From e8705b509cf64e8b4f30a79077b057933de4c029 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 9 Oct 2024 07:48:02 +0000 Subject: [PATCH 1/4] IR: Add `Conditional.else_bodies` property to gather nested bodies --- loki/ir/nodes.py | 10 ++++++++++ loki/ir/tests/test_ir_nodes.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 78a24e01f..aed0e064e 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -704,6 +704,16 @@ def __repr__(self): return f'Conditional:: {self.name}' return 'Conditional::' + @property + def else_bodies(self): + """ + Return all nested node tuples in the ``ELSEIF``/``ELSE`` part + of the conditional chain. + """ + if self.has_elseif: + return (self.else_body[0].body,) + self.else_body[0].else_bodies + return (self.else_body,) + @dataclass_strict(frozen=True) class _PragmaRegionBase(): diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/tests/test_ir_nodes.py index cb17ea943..b9a59dac5 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/tests/test_ir_nodes.py @@ -151,6 +151,35 @@ def test_conditional(scope, one, i, n, a_i): # TODO: Test inline, name, has_elseif +def test_multi_conditional(scope, one, i, n, a_i): + """ + Test nested chains of constructors of :any:`Conditional` to form + multi-conditional. + """ + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(1)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)), + else_body=ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0)) + ) + for idx in range(2, 4): + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(idx)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))), + else_body=multicond, has_elseif=True + ) + + # Check that we can recover all bodies from a nested else-if construct + else_bodies = multicond.else_bodies + assert len(else_bodies) == 3 + assert all(isinstance(b, tuple) for b in else_bodies) + assert isinstance(else_bodies[0][0], ir.Assignment) + assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0' + assert isinstance(else_bodies[1][0], ir.Assignment) + assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0' + assert isinstance(else_bodies[2][0], ir.Assignment) + assert else_bodies[2][0].lhs == 'a(i)' and else_bodies[2][0].rhs == '42.0' + + def test_section(scope, one, i, n, a_n, a_i): """ Test constructors and behaviour of :any:`Section` nodes. From 832505026d4487a2a8e7dbdf296586b0ca31ccf4 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 9 Oct 2024 09:57:21 +0000 Subject: [PATCH 2/4] SingleColumn: Fix vectorisation of nested else-bodies in conditionals --- .../single_column/tests/test_scc_vector.py | 28 +++++++++++++++---- loki/transformations/single_column/vector.py | 8 ++---- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/loki/transformations/single_column/tests/test_scc_vector.py b/loki/transformations/single_column/tests/test_scc_vector.py index 66238d29b..96378d869 100644 --- a/loki/transformations/single_column/tests/test_scc_vector.py +++ b/loki/transformations/single_column/tests/test_scc_vector.py @@ -573,20 +573,30 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block """ fcode_kernel = """ - subroutine some_kernel(start, end, nlon, flag0, flag1) + subroutine some_kernel(start, end, nlon, flag0, flag1, flag2) implicit none integer, intent(in) :: nlon, start, end - logical, intent(in) :: flag0, flag1 + logical, intent(in) :: flag0, flag1, flag2 real, dimension(nlon) :: work integer :: jl - if(flag0)then + if (flag0) then call some_other_kernel() - elseif(flag1)then + elseif (flag1) then + do jl=start,end + work(jl) = 1. + enddo + elseif (flag2) then do jl=start,end work(jl) = 1. + work(jl) = 2. + enddo + else + do jl=start,end + work(jl) = 41. + work(jl) = 42. enddo endif @@ -595,7 +605,7 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block routine = Subroutine.from_source(fcode_kernel, frontend=frontend) - # check whether pipeline can be applied and works as expected + # check whether pipeline can be applied and works as expected scc_pipeline = SCCVectorPipeline( horizontal=horizontal, vertical=vertical, block_dim=blocking, directive='openacc', trim_vector_sections=trim_vector_sections @@ -611,3 +621,11 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block assert isinstance(conditional.else_body[0].body[0], ir.Comment) assert isinstance(conditional.else_body[0].body[1], ir.Loop) assert conditional.else_body[0].body[1].pragma[0].content.lower() == 'loop vector' + + # Check that all else-bodies have been wrapped + else_bodies = conditional.else_bodies + assert(len(else_bodies) == 3) + for body in else_bodies: + assert isinstance(body[0], ir.Comment) + assert isinstance(body[1], ir.Loop) + assert body[1].pragma[0].content.lower() == 'loop vector' diff --git a/loki/transformations/single_column/vector.py b/loki/transformations/single_column/vector.py index 2764c019f..12c8dcaa4 100644 --- a/loki/transformations/single_column/vector.py +++ b/loki/transformations/single_column/vector.py @@ -140,12 +140,8 @@ def extract_vector_sections(cls, section, horizontal): # we need to prevent that the whole 'else_body' is wrapped in a section, # as 'Conditional's rely on the fact that the first element of the 'else_body' # (if 'has_elseif') is a Conditional itself - if separator.has_elseif and separator.else_body: - subsec_else = cls.extract_vector_sections(separator.else_body[0].body, horizontal) - else: - subsec_else = cls.extract_vector_sections(separator.else_body, horizontal) - if subsec_else: - subsections += subsec_else + for ebody in separator.else_bodies: + subsections += cls.extract_vector_sections(ebody, horizontal) if isinstance(separator, ir.MultiConditional): for body in separator.bodies: From 96cfb7d7021f7960219481adfd0019eaa8189da4 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 9 Oct 2024 10:42:39 +0000 Subject: [PATCH 3/4] IR: Ensure `Conditional.else_bodies` does not return empty tuples --- loki/ir/nodes.py | 2 +- loki/ir/tests/test_ir_nodes.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index aed0e064e..8395d086e 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -712,7 +712,7 @@ def else_bodies(self): """ if self.has_elseif: return (self.else_body[0].body,) + self.else_body[0].else_bodies - return (self.else_body,) + return (self.else_body,) if self.else_body else () @dataclass_strict(frozen=True) diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/tests/test_ir_nodes.py index b9a59dac5..0fc8e66e6 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/tests/test_ir_nodes.py @@ -179,6 +179,25 @@ def test_multi_conditional(scope, one, i, n, a_i): assert isinstance(else_bodies[2][0], ir.Assignment) assert else_bodies[2][0].lhs == 'a(i)' and else_bodies[2][0].rhs == '42.0' + # Not try without the final else + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(1)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)), + ) + for idx in range(2, 4): + multicond = ir.Conditional( + condition=sym.Comparison(i, '==', sym.IntLiteral(idx)), + body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))), + else_body=multicond, has_elseif=True + ) + else_bodies = multicond.else_bodies + assert len(else_bodies) == 2 + assert all(isinstance(b, tuple) for b in else_bodies) + assert isinstance(else_bodies[0][0], ir.Assignment) + assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0' + assert isinstance(else_bodies[1][0], ir.Assignment) + assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0' + def test_section(scope, one, i, n, a_n, a_i): """ From 53b208050bd674801a5bd5cab62ffcbae4c0d209 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter <6384870+reuterbal@users.noreply.github.com> Date: Thu, 10 Oct 2024 10:59:40 +0200 Subject: [PATCH 4/4] Update comment Co-authored-by: Michael Staneker <50531288+MichaelSt98@users.noreply.github.com> --- loki/transformations/single_column/vector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loki/transformations/single_column/vector.py b/loki/transformations/single_column/vector.py index 12c8dcaa4..0578e2b11 100644 --- a/loki/transformations/single_column/vector.py +++ b/loki/transformations/single_column/vector.py @@ -137,8 +137,8 @@ def extract_vector_sections(cls, section, horizontal): subsec_body = cls.extract_vector_sections(separator.body, horizontal) if subsec_body: subsections += subsec_body - # we need to prevent that the whole 'else_body' is wrapped in a section, - # as 'Conditional's rely on the fact that the first element of the 'else_body' + # we need to prevent that all (possibly nested) 'else_bodies' are completely wrapped as a section, + # as 'Conditional's rely on the fact that the first element of each 'else_body' # (if 'has_elseif') is a Conditional itself for ebody in separator.else_bodies: subsections += cls.extract_vector_sections(ebody, horizontal)