Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Refactor some methods using BlockInsertPoint #3704

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,22 +618,11 @@ def transformation(module: ModuleOp, rewriter: Rewriter) -> None:


def test_verify_inline_region():
block = Block()
region = Region(Block())

with pytest.raises(
ValueError, match="Cannot inline region before a block with no parent"
):
Rewriter.inline_region_before(region, block)

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_before(region, region.block)

with pytest.raises(
ValueError, match="Cannot inline region before a block with no parent"
):
Rewriter.inline_region_after(region, block)

with pytest.raises(ValueError, match="Cannot move region into itself."):
Rewriter.inline_region_after(region, region.block)

Expand Down
48 changes: 19 additions & 29 deletions xdsl/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from xdsl.dialects.builtin import ArrayAttr
from xdsl.ir import Attribute, Block, BlockArgument, Operation, OperationInvT, Region
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter


@dataclass(eq=False)
Expand Down Expand Up @@ -74,36 +74,38 @@ def insert(self, op: OperationInvT) -> OperationInvT:

return op

def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
def create_block(
self, insert_point: BlockInsertPoint, arg_types: Iterable[Attribute]
math-fehr marked this conversation as resolved.
Show resolved Hide resolved
) -> Block:
"""
Create a block before `insert_before`, and set
the insertion point at the end of the inserted block.
Create a block at the given location, and set the operation insertion point
at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
Rewriter.insert_block_before(block, insert_before)
Rewriter.insert_block(block, insert_point)

self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block

def create_block_before(
self, insert_before: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block before `insert_before`, and set
the insertion point at the end of the inserted block.
"""
return self.create_block(BlockInsertPoint.before(insert_before), arg_types)

def create_block_after(
self, insert_after: Block, arg_types: Iterable[Attribute] = ()
) -> Block:
"""
Create a block after `insert_after`, and set
the insertion point at the end of the inserted block.
"""

block = Block(arg_types=arg_types)
Rewriter.insert_block_after(block, insert_after)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.after(insert_after), arg_types)

def create_block_at_start(
self, region: Region, arg_types: Iterable[Attribute] = ()
Expand All @@ -112,13 +114,7 @@ def create_block_at_start(
Create a block at the start of `region`, and set
the insertion point at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
region.insert_block(block, 0)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.at_start(region), arg_types)

def create_block_at_end(
self, region: Region, arg_types: Iterable[Attribute] = ()
Expand All @@ -127,13 +123,7 @@ def create_block_at_end(
Create a block at the end of `region`, and set
the insertion point at the end of the inserted block.
"""
block = Block(arg_types=arg_types)
region.add_block(block)
self.insertion_point = InsertPoint.at_end(block)

self.handle_block_creation(block)

return block
return self.create_block(BlockInsertPoint.at_end(region), arg_types)

@staticmethod
def _region_no_args(func: Callable[[Builder], None]) -> Region:
Expand Down
19 changes: 10 additions & 9 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
SSAValue,
)
from xdsl.irdl import GenericAttrConstraint, base
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.rewriter import BlockInsertPoint, InsertPoint, Rewriter
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

Expand Down Expand Up @@ -351,25 +351,26 @@ def move_region_contents_to_new_regions(self, region: Region) -> Region:
self.has_done_action = True
return Rewriter.move_region_contents_to_new_regions(region)

def inline_region(self, region: Region, insertion_point: BlockInsertPoint) -> None:
"""Move the region blocks to the specified insertion point."""
self.has_done_action = True
Rewriter.inline_region(region, insertion_point)

def inline_region_before(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_before(region, target)
self.inline_region(region, BlockInsertPoint.before(target))

def inline_region_after(self, region: Region, target: Block) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_after(region, target)
self.inline_region(region, BlockInsertPoint.after(target))

def inline_region_at_start(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_at_start(region, target)
self.inline_region(region, BlockInsertPoint.at_start(target))

def inline_region_at_end(self, region: Region, target: Region) -> None:
"""Move the region blocks to an existing region."""
self.has_done_action = True
Rewriter.inline_region_at_end(region, target)
self.inline_region(region, BlockInsertPoint.at_end(target))


class RewritePattern(ABC):
Expand Down
56 changes: 28 additions & 28 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field

from xdsl.ir import Block, Operation, Region, SSAValue
Expand Down Expand Up @@ -227,21 +227,27 @@ def inline_block(
parent_region.detach_block(source)
source.erase()

@staticmethod
def insert_block(block: Block | Iterable[Block], insert_point: BlockInsertPoint):
"""
Insert one or multiple blocks at a given location.
The blocks to insert should be detached from any region.
The insertion point should not be contained in the block to insert.
"""
region = insert_point.region
if insert_point.insert_before is not None:
region.insert_block_before(block, insert_point.insert_before)
else:
region.add_block(block)

@staticmethod
def insert_block_after(block: Block | list[Block], target: Block):
"""
Insert one or multiple blocks after another block.
The blocks to insert should be detached from any region.
The target block should not be contained in the block to insert.
"""
if target.parent is None:
raise Exception("Cannot move a block after a toplevel op")
region = target.parent
block_list = block if isinstance(block, list) else [block]
if len(block_list) == 0:
return
pos = region.get_block_index(target)
region.insert_block(block_list, pos + 1)
Rewriter.insert_block(block, BlockInsertPoint.after(target))

@staticmethod
def insert_block_before(block: Block | list[Block], target: Block):
Expand All @@ -250,12 +256,7 @@ def insert_block_before(block: Block | list[Block], target: Block):
The blocks to insert should be detached from any region.
The target block should not be contained in the block to insert.
"""
if target.parent is None:
raise Exception("Cannot move a block after a toplevel op")
region = target.parent
block_list = block if isinstance(block, list) else [block]
pos = region.get_block_index(target)
region.insert_block(block_list, pos)
Rewriter.insert_block(block, BlockInsertPoint.before(target))

@staticmethod
def insert_op(
Expand All @@ -275,31 +276,30 @@ def move_region_contents_to_new_regions(region: Region) -> Region:
region.move_blocks(new_region)
return new_region

@staticmethod
def inline_region(region: Region, insertion_point: BlockInsertPoint) -> None:
"""Move the region blocks to a given location."""
if insertion_point.insert_before is not None:
region.move_blocks_before(insertion_point.insert_before)
else:
region.move_blocks(insertion_point.region)

@staticmethod
def inline_region_before(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, before `target`."""
region.move_blocks_before(target)
Rewriter.inline_region(region, BlockInsertPoint.before(target))

@staticmethod
def inline_region_after(region: Region, target: Block) -> None:
"""Move the region blocks to an existing region, after `target`."""
if target.next_block is not None:
Rewriter.inline_region_before(region, target.next_block)
else:
parent_region = target.parent
if parent_region is None:
raise ValueError("Cannot inline region before a block with no parent")
region.move_blocks(region)
Rewriter.inline_region(region, BlockInsertPoint.after(target))

@staticmethod
def inline_region_at_start(region: Region, target: Region) -> None:
"""Move the region blocks to the start of an existing region."""
if target.first_block is not None:
Rewriter.inline_region_before(region, target.first_block)
else:
Rewriter.inline_region_at_end(region, target)
Rewriter.inline_region(region, BlockInsertPoint.at_start(target))

@staticmethod
def inline_region_at_end(region: Region, target: Region) -> None:
"""Move the region blocks to the end of an existing region."""
region.move_blocks(target)
Rewriter.inline_region(region, BlockInsertPoint.at_end(target))
Loading