Skip to content

Commit

Permalink
IR: Add Conditional.else_bodies property to gather nested bodies
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 10, 2024
1 parent 3834903 commit e8705b5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
10 changes: 10 additions & 0 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,16 @@ def __repr__(self):
return f'Conditional:: {self.name}'
return 'Conditional::'

@property
def else_bodies(self):
"""
Return all nested node tuples in the ``ELSEIF``/``ELSE`` part
of the conditional chain.
"""
if self.has_elseif:
return (self.else_body[0].body,) + self.else_body[0].else_bodies
return (self.else_body,)


@dataclass_strict(frozen=True)
class _PragmaRegionBase():
Expand Down
29 changes: 29 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ def test_conditional(scope, one, i, n, a_i):
# TODO: Test inline, name, has_elseif


def test_multi_conditional(scope, one, i, n, a_i):
"""
Test nested chains of constructors of :any:`Conditional` to form
multi-conditional.
"""
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(1)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)),
else_body=ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
)
for idx in range(2, 4):
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(idx)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))),
else_body=multicond, has_elseif=True
)

# Check that we can recover all bodies from a nested else-if construct
else_bodies = multicond.else_bodies
assert len(else_bodies) == 3
assert all(isinstance(b, tuple) for b in else_bodies)
assert isinstance(else_bodies[0][0], ir.Assignment)
assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0'
assert isinstance(else_bodies[1][0], ir.Assignment)
assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0'
assert isinstance(else_bodies[2][0], ir.Assignment)
assert else_bodies[2][0].lhs == 'a(i)' and else_bodies[2][0].rhs == '42.0'


def test_section(scope, one, i, n, a_n, a_i):
"""
Test constructors and behaviour of :any:`Section` nodes.
Expand Down

0 comments on commit e8705b5

Please sign in to comment.