Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: preorder walk of blocks in nested operations #3367

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions tests/test_preorder_walk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest

from xdsl.context import MLContext
from xdsl.dialects.arith import Arith
from xdsl.dialects.builtin import Builtin
from xdsl.dialects.func import Func, FuncOp
from xdsl.dialects.scf import For, If, Scf
from xdsl.parser import Parser

test_prog = """
"func.func"() <{function_type = (i1, i32, i32) -> i32, sym_name = "example_func"}> ({
^bb0(%arg0: i1, %arg1: i32, %arg2: i32):
%0 = "scf.if"(%arg0) ({
%1 = "arith.constant"() <{value = 42 : i32}> : () -> i32
%2 = "arith.constant"() <{value = true}> : () -> i1
%3 = "scf.if"(%2) ({
%4 = "arith.constant"() <{value = 84 : i32}> : () -> i32
"scf.yield"(%4) : (i32) -> ()
}, {
%4 = "arith.constant"() <{value = 21 : i32}> : () -> i32
"scf.yield"(%4) : (i32) -> ()
}) : (i1) -> i32
"scf.yield"(%3) : (i32) -> ()
}, {
%1 = "arith.index_cast"(%arg1) : (i32) -> index
%2 = "arith.index_cast"(%arg2) : (i32) -> index
%3 = "arith.constant"() <{value = 0 : i32}> : () -> i32
"scf.for"(%1, %2, %1) ({
^bb0(%arg3: index):
%4 = "arith.index_cast"(%arg3) : (index) -> i32
%5 = "arith.constant"() <{value = 10 : i32}> : () -> i32
%6 = "arith.constant"() <{value = false}> : () -> i1
"scf.if"(%6) ({
%7 = "arith.constant"() <{value = 100 : i32}> : () -> i32
"scf.yield"() : () -> ()
}, {
%7 = "arith.constant"() <{value = 200 : i32}> : () -> i32
"scf.yield"() : () -> ()
}) : (i1) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"scf.yield"(%3) : (i32) -> ()
}) : (i1) -> i32
"func.return"(%0) : (i32) -> ()
}) : () -> ()
"""


def test_preorder_walk():
ctx = MLContext()
ctx.load_dialect(Builtin)
ctx.load_dialect(Arith)
ctx.load_dialect(Func)
ctx.load_dialect(Scf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please change this test to work with only the test dialect, to avoid spurious dependencies?


parser = Parser(ctx, test_prog)
op = parser.parse_op()
Comment on lines +94 to +95
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry to be annoying, but do you think you could change this from parsing the text to constructing it in Python? Something like this would make the test a little easier to read and understand, IMO. then we could just assert tuple(op.walk_blocks()) == (...)


assert isinstance(op, FuncOp)

first_if = op.body.block.ops.first
assert isinstance(first_if, If)
second_if = list(first_if.true_region.block.ops)[2]
assert isinstance(second_if, If)
for_loop = list(first_if.false_region.block.ops)[3]
assert isinstance(for_loop, For)
third_if = list(for_loop.body.block.ops)[3]
assert isinstance(third_if, If)

it = op.walk_blocks_preorder()
assert next(it) == op.body.block
assert next(it) == first_if.true_region.block
assert next(it) == second_if.true_region.block
assert next(it) == second_if.false_region.block
assert next(it) == first_if.false_region.block
assert next(it) == for_loop.body.block
assert next(it) == third_if.true_region.block
assert next(it) == third_if.false_region.block

with pytest.raises(StopIteration):
next(it)
7 changes: 7 additions & 0 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,13 @@ def walk(
if region_first:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps rename the post_order.py file to traversals.py and move it there?

yield self

def walk_blocks_preorder(self) -> Iterator[Block]:
for region in self.regions:
for block in region.blocks:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not what I would have expected a preorder traversal of blocks to do. Does llvm simply iterate over them in the order they appear?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it, it would probably be worth adding a walk method on Block that mirrors the API of the Operation walk, and to change this to forward to that one. I also prefer our walk_regions_first to Preorder as it's more declarative, maybe for blocks it could be walk_child_blocks_first? pre-post order doesn't make sense in the absence of left and right children

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I more meant that I would have expected a "traversal" to explore blocks via a depth first search of successors from each block, rather than just return each block in order

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no that's not what I would expect TBH. @gabrielrodcanal, if you don't feel like implementing reversed iteration and post-order then I think it would still be worth adding walk on Block, and walk_blocks on Operation, and we can add the optional parameters later without API breaking.

yield block
for op in block.ops:
yield from op.walk_blocks_preorder()

def get_attr_or_prop(self, name: str) -> Attribute | None:
"""
Get a named attribute or property.
Expand Down
Loading