diff --git a/tests/filecheck/transforms/liveness.mlir b/tests/filecheck/transforms/liveness.mlir new file mode 100644 index 0000000000..fbd42629d7 --- /dev/null +++ b/tests/filecheck/transforms/liveness.mlir @@ -0,0 +1,487 @@ +// RUN: xdsl-opt -t liveness %s --split-input-file | filecheck %s + +func.func @func_empty() { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK-NEXT: EndCurrentlyLive + return +} + +// ----- + +func.func @func_simpleBranch(%arg0: i32, %arg1 : i32) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg0@0 arg1@0 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: cf.br + // CHECK-SAME: arg0@0 arg1@0 + // CHECK-NEXT: EndCurrentlyLive + cf.br ^exit +^exit: + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK: val_2 + // CHECK-NEXT: %result = arith.addi + // CHECK-NEXT: return + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg1@0 val_2 + // CHECK: return + // CHECK-SAME: val_2 + // CHECK-NEXT:EndCurrentlyLive + %result = arith.addi %arg0, %arg1 : i32 + return %result : i32 +} + +// ----- + +func.func @func_condBranch(%cond : i1, %arg1: i32, %arg2 : i32) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg1@0 arg2@0 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: cf.cond_br + // CHECK-SAME: arg0@0 arg1@0 arg2@0 + // CHECK-NEXT: EndCurrentlyLive + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg1@0 arg2@0 + // CHECK-NEXT: LiveOut: arg1@0 arg2@0 + // CHECK: BeginCurrentlyLive + // CHECK: cf.br + // COM: arg0@0 had its last user in the previous block. + // CHECK-SAME: arg1@0 arg2@0 + // CHECK-NEXT: EndCurrentlyLive + cf.br ^exit +^bb2: + // CHECK: Block: 2 + // CHECK-NEXT: LiveIn: arg1@0 arg2@0 + // CHECK-NEXT: LiveOut: arg1@0 arg2@0 + // CHECK: BeginCurrentlyLive + // CHECK: cf.br + // CHECK-SAME: arg1@0 arg2@0 + // CHECK-NEXT: EndCurrentlyLive + cf.br ^exit +^exit: + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg1@0 arg2@0 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK: val_3 + // CHECK-NEXT: %result = arith.addi + // CHECK-NEXT: return + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg1@0 arg2@0 val_3 + // CHECK: return + // CHECK-SAME: val_3 + // CHECK-NEXT: EndCurrentlyLive + %result = arith.addi %arg1, %arg2 : i32 + return %result : i32 +} + +// ----- + +func.func @func_loop(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg1@0 + // CHECK: BeginCurrentlyLive + // CHECK: arith.constant + // CHECK-SAME: arg0@0 arg1@0 val_2 + // CHECK: cf.br + // CHECK-SAME: arg0@0 arg1@0 val_2 + // CHECK-NEXT: EndCurrentlyLive + %const0 = arith.constant 0 : i32 + cf.br ^loopHeader(%const0, %arg0 : i32, i32) +^loopHeader(%counter : i32, %i : i32): + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg1@0 + // CHECK-NEXT: LiveOut: arg1@0 arg0@1 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_5 + // CHECK-NEXT: %lessThan = arith.cmpi + // CHECK-NEXT: cf.cond_br + // CHECK-NEXT: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.cmpi + // CHECK-SAME: arg1@0 arg0@1 arg1@1 val_5 + // CHECK: cf.cond_br + // CHECK-SAME: arg1@0 arg0@1 arg1@1 val_5 + // CHECK-NEXT: EndCurrentlyLive + %lessThan = arith.cmpi slt, %counter, %arg1 : i32 + cf.cond_br %lessThan, ^loopBody(%i : i32), ^exit(%i : i32) +^loopBody(%val : i32): + // CHECK: Block: 2 + // CHECK-NEXT: LiveIn: arg1@0 arg0@1 + // CHECK-NEXT: LiveOut: arg1@0 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_7 + // CHECK-NEXT: %c + // CHECK-NEXT: %inc = arith.addi + // CHECK-NEXT: %inc2 = arith.addi + // CHECK-NEXT: val_8 + // CHECK-NEXT: %inc = arith.addi + // CHECK-NEXT: %inc2 = arith.addi + // CHECK-NEXT: cf.br + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.constant + // CHECK-SAME: arg1@0 arg0@1 arg0@2 val_7 + // CHECK: arith.addi + // CHECK-SAME: arg1@0 arg0@1 arg0@2 val_7 val_8 + // CHECK: arith.addi + // CHECK-SAME: arg1@0 arg0@1 val_7 val_8 val_9 + // CHECK: cf.br + // CHECK-SAME: arg1@0 val_8 val_9 + // CHECK-NEXT: EndCurrentlyLive + %const1 = arith.constant 1 : i32 + %inc = arith.addi %val, %const1 : i32 + %inc2 = arith.addi %counter, %const1 : i32 + cf.br ^loopHeader(%inc, %inc2 : i32, i32) +^exit(%sum : i32): + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg1@0 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg1@0 arg0@3 val_11 + // CHECK: return + // CHECK-SAME: val_11 + // CHECK-NEXT: EndCurrentlyLive + %result = arith.addi %sum, %arg1 : i32 + return %result : i32 +} + +// ----- + +func.func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg2@0 val_9 val_10 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_4 + // CHECK-NEXT: %0 = arith.addi + // CHECK-NEXT: %c + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: %3 = arith.muli + // CHECK-NEXT: val_5 + // CHECK-NEXT: %c + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: %3 = arith.muli + // CHECK-NEXT: %4 = arith.muli + // CHECK-NEXT: %5 = arith.addi + // CHECK-NEXT: val_6 + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: %3 = arith.muli + // CHECK-NEXT: val_7 + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: %3 = arith.muli + // CHECK-NEXT: %4 = arith.muli + // CHECK: val_8 + // CHECK-NEXT: %3 = arith.muli + // CHECK-NEXT: %4 = arith.muli + // CHECK-NEXT: val_9 + // CHECK-NEXT: %4 = arith.muli + // CHECK-NEXT: %5 = arith.addi + // CHECK-NEXT: cf.cond_br + // CHECK-NEXT: %c + // CHECK-NEXT: %6 = arith.muli + // CHECK-NEXT: %7 = arith.muli + // CHECK-NEXT: %8 = arith.addi + // CHECK-NEXT: val_10 + // CHECK-NEXT: %5 = arith.addi + // CHECK-NEXT: cf.cond_br + // CHECK-NEXT: %7 + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg3@0 val_4 + // CHECK: arith.constant + // CHECK-SAME: arg0@0 arg2@0 arg3@0 val_4 val_5 + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg2@0 arg3@0 val_4 val_5 val_6 + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg2@0 arg3@0 val_4 val_5 val_6 val_7 + // CHECK: arith.muli + // CHECK-SAME: arg0@0 arg2@0 val_4 val_5 val_6 val_7 val_8 + // CHECK: arith.muli + // CHECK-SAME: arg0@0 arg2@0 val_5 val_7 val_8 val_9 + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg2@0 val_5 val_9 val_10 + // CHECK: cf.cond_br + // CHECK-SAME: arg0@0 arg2@0 val_9 val_10 + // CHECK-NEXT: EndCurrentlyLive + %0 = arith.addi %arg1, %arg2 : i32 + %const1 = arith.constant 1 : i32 + %1 = arith.addi %const1, %arg2 : i32 + %2 = arith.addi %const1, %arg3 : i32 + %3 = arith.muli %0, %1 : i32 + %4 = arith.muli %3, %2 : i32 + %5 = arith.addi %4, %const1 : i32 + cf.cond_br %cond, ^bb1, ^bb2 + +^bb1: + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg2@0 val_9 + // CHECK-NEXT: LiveOut: arg2@0 + // CHECK: BeginCurrentlyLive + // CHECK: arith.constant + // CHECK-SAME: arg2@0 val_9 + // CHECK: arith.muli + // CHECK-SAME: arg2@0 val_9 + // CHECK: cf.br + // CHECK-SAME: arg2@0 + // CHECK-NEXT: EndCurrentlyLive + %const4 = arith.constant 4 : i32 + %6 = arith.muli %4, %const4 : i32 + cf.br ^exit(%6 : i32) + +^bb2: + // CHECK: Block: 2 + // CHECK-NEXT: LiveIn: arg2@0 val_9 val_10 + // CHECK-NEXT: LiveOut: arg2@0 + // CHECK: BeginCurrentlyLive + // CHECK: arith.muli + // CHECK-SAME: arg2@0 val_9 val_10 + // CHECK: arith.addi + // CHECK-SAME: arg2@0 + // CHECK: cf.br + // CHECK-SAME: arg2@0 + // CHECK: EndCurrentlyLive + %7 = arith.muli %4, %5 : i32 + %8 = arith.addi %4, %arg2 : i32 + cf.br ^exit(%8 : i32) + +^exit(%sum : i32): + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg2@0 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg2@0 + // CHECK: return + // CHECK-NOT: arg2@0 + // CHECK: EndCurrentlyLive + %result = arith.addi %sum, %arg2 : i32 + return %result : i32 +} + +// ----- + + +func.func @nested_region( + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = arith.addi + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // %2 = arith.addi + // CHECK-NEXT: %3 = arith.addi + // CHECK-NEXT: val_8 + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // func.return %1 + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg3@0 arg4@0 arg5@0 arg6@0 val_7 + // CHECK: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg4@0 arg5@0 arg6@0 val_7 val_8 + // CHECK: scf.for + // CHECK-NEXT: arith.addi + // CHECK-NEXT: arith.addi + // CHECK-NEXT: memref.store + // CHECK-NEXT: arg5@0 arg6@0 val_7 val_8 + // CHECK: return + // CHECK-SAME: val_8 + // CHECK-NEXT: EndCurrentlyLive + %0 = arith.addi %arg3, %arg4 : i32 + %1 = arith.addi %arg4, %arg5 : i32 + scf.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg5@0 arg6@0 val_7 arg0@1 val_10 + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg6@0 val_7 val_10 val_11 + // CHECK-NEXT: memref.store + // CHECK-SAME: arg6@0 val_11 + // CHECK-NEXT: EndCurrentlyLive + %2 = arith.addi %0, %arg5 : i32 + %3 = arith.addi %2, %0 : i32 + memref.store %3, %buffer[] : memref + } + return %1 : i32 +} + +// ----- + + +func.func @nested_region2( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = arith.addi + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // %2 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // %3 = arith.addi + // CHECK-NEXT: val_8 + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // func.return %1 + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg3@0 arg4@0 arg5@0 arg6@0 val_7 + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg4@0 arg5@0 arg6@0 val_7 val_8 + // CHECK-NEXT: scf.for {{.*}} + // CHECK-NEXT: arith.addi + // CHECK-NEXT: scf.for {{.*}} { + // CHECK-NEXT: arith.addi + // CHECK-NEXT: memref.store + // CHECK-NEXT: } + // CHECK-NEXT: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7 val_8 + // CHECK-NEXT: return + // CHECK-SAME: val_8 + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref) -> i32 { + %0 = arith.addi %arg3, %arg4 : i32 + %1 = arith.addi %arg4, %arg5 : i32 + scf.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_10 + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // %3 = arith.addi + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7 arg0@1 val_10 + // CHECK-NEXT: scf.for {{.*}} + // CHECK-NEXT: arith.addi + // CHECK-NEXT: memref.store + // CHECK-NEXT: arg0@0 arg1@0 arg2@0 arg6@0 val_7 + %2 = arith.addi %0, %arg5 : i32 + scf.for %arg7 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 2 + // CHECK: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg6@0 val_7 val_10 arg0@2 val_12 + // CHECK-NEXT: memref.store + // CHECK-SAME: arg6@0 val_12 + // CHECK: EndCurrentlyLive + %3 = arith.addi %2, %0 : i32 + memref.store %3, %buffer[] : memref + } + } + return %1 : i32 +} + +// ----- + + +func.func @nested_region3( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: BeginLivenessIntervals + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = arith.addi + // CHECK-NEXT: %1 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // cf.br ^0 + // CHECK-NEXT: %2 = arith.addi + // CHECK-NEXT: scf.for + // CHECK: // %3 = arith.addi + // CHECK: EndLivenessIntervals + // CHECK-NEXT: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg3@0 arg4@0 arg5@0 arg6@0 val_7 + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg4@0 arg5@0 arg6@0 val_7 val_8 + // CHECK-NEXT: scf.for + // COM: Skipping the body of the scf.for... + // CHECK: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7 val_8 + // CHECK-NEXT: cf.br + // CHECK-SAME: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: EndCurrentlyLive + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref) -> i32 { + %0 = arith.addi %arg3, %arg4 : i32 + %1 = arith.addi %arg4, %arg5 : i32 + scf.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg5@0 arg6@0 val_7 arg0@1 val_10 + // CHECK-NEXT: memref.store + // CHECK-SAME: arg6@0 val_10 + // CHECK-NEXT: EndCurrentlyLive + %2 = arith.addi %0, %arg5 : i32 + memref.store %2, %buffer[] : memref + } + cf.br ^exit + +^exit: + // CHECK: Block: 2 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK: scf.for + // CHECK: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: return + // CHECK-SAME: val_8 + // CHECK-NEXT: EndCurrentlyLive + scf.for %arg7 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg6@0 val_7 val_8 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK: BeginCurrentlyLive + // CHECK-NEXT: arith.addi + // CHECK-SAME: arg6@0 val_7 val_8 arg0@3 val_12 + // CHECK-NEXT: memref.store + // CHECK-SAME: arg6@0 val_12 + // CHECK-NEXT: EndCurrentlyLive + %2 = arith.addi %0, %1 : i32 + memref.store %2, %buffer[] : memref + } + return %1 : i32 +} diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index 22f915207a..2c7288eac2 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -951,6 +951,16 @@ def drop_all_references(self) -> None: for region in self.regions: region.drop_all_references() + def get_parent_of_type(self, parent_type: type[Operation]) -> Operation | None: + current_op = self + + while parent := current_op.parent_op(): + if isinstance(parent, parent_type): + return parent + current_op = parent + + return None + def walk( self, *, reverse: bool = False, region_first: bool = False ) -> Iterator[Operation]: @@ -967,6 +977,13 @@ def walk( if region_first: yield self + def walk_blocks_preorder(self) -> Iterator[Block]: + for region in self.regions: + for block in region.blocks: + yield block + for op in block.ops: + yield from op.walk_blocks_preorder() + def get_attr_or_prop(self, name: str) -> Attribute | None: """ Get a named attribute or property. @@ -978,6 +995,17 @@ def get_attr_or_prop(self, name: str) -> Attribute | None: return self.attributes[name] return None + def is_before_in_block(self, other_op: Operation) -> bool: + parent_block = self.parent_block() + assert isinstance(parent_block, Block) + + if parent_block.get_operation_index(self) < parent_block.get_operation_index( + other_op + ): + return True + else: + return False + def verify(self, verify_nested_ops: bool = True) -> None: for operand in self.operands: if isinstance(operand, ErasedSSAValue): @@ -1407,6 +1435,11 @@ def predecessors(self) -> tuple[Block, ...]: p for use in self.uses if (p := use.operation.parent_block()) is not None ) + def successors(self) -> tuple[Block, ...]: + terminator = self.last_op + assert isinstance(terminator, Operation) + return tuple(successor for successor in terminator.successors) + def parent_op(self) -> Operation | None: return self.parent.parent if self.parent else None @@ -1721,6 +1754,19 @@ def erase(self, safe_erase: bool = True) -> None: for op in self.ops: op.erase(safe_erase=safe_erase, drop_references=False) + def get_terminator(self) -> Operation | None: + if self.last_op and self.last_op.has_trait(IsTerminator): + return self.last_op + else: + return None + + def find_ancestor_op_in_block(self, other_op: Operation) -> Operation | None: + for op in self.ops: + if op.is_ancestor(other_op): + return op + + return None + def is_structurally_equivalent( self, other: IRNode, @@ -1945,6 +1991,17 @@ def parent_region(self) -> Region | None: else None ) + def find_ancestor_block_in_region(self, block: Block) -> Block | None: + curr_block = block + while curr_block.parent_region() != self: + parent_op = curr_block.parent_op() + if not parent_op or not parent_op.parent_block(): + return None + curr_block = parent_op.parent_block() + assert isinstance(curr_block, Block) + + return curr_block + @property def blocks(self) -> RegionBlocks: """ diff --git a/xdsl/transforms/experimental/liveness.py b/xdsl/transforms/experimental/liveness.py new file mode 100644 index 0000000000..28680d17b4 --- /dev/null +++ b/xdsl/transforms/experimental/liveness.py @@ -0,0 +1,446 @@ +from dataclasses import dataclass, field +from typing import IO, cast + +from xdsl.dialects import builtin, func +from xdsl.ir import Block, BlockArgument, Operation, Region, SSAValue +from xdsl.printer import Printer + + +class BlockInfoBuilder: + def __init__(self, block: Block): + self.out_values: set[SSAValue] = set() + self.def_values: set[SSAValue] = set() + self.use_values: set[SSAValue] = set() + self.in_values: set[SSAValue] = set() + self.block = block + + def gather_out_values(value: SSAValue): + # Check whether this value will be in the outValues set (its uses escape + # this block). Due to the SSA properties of the program, the uses must + # occur after the definition. Therefore, we do not have to check + # additional conditions to detect an escaping value. + for use_op in [use.operation for use in value.uses]: + owner_block = use_op.parent_block() + # Find an owner block in the current region. Note that a value does not + # escape this block if it is used in a nested region. + parent_region = block.parent_region() + assert isinstance(parent_region, Region) + assert isinstance(owner_block, Block) + owner_block = parent_region.find_ancestor_block_in_region(owner_block) + assert owner_block + assert "Use leaves the current parent region" + if owner_block != block: + self.out_values.add(value) + break + + # Mark all block arguments (phis) as defined + for argument in block.args: + # Insert value into the set of defined values + self.def_values.add(argument) + + # Gather all out values of all arguments in the current block. + gather_out_values(argument) + + # Gather out values of all operations in the current block. + for operation in block.ops: + for result in operation.results: + gather_out_values(result) + + # Mark all nested operation results as defined, and nested operation + # operands as used. All defined value will be removed from the used set + # at the end. + for op in block.walk(): + for result in op.results: + self.def_values.add(result) + for operand in op.operands: + self.use_values.add(operand) + self.use_values = self.use_values.difference(self.def_values) + + # Updates live-in information of the current block. To do so it uses the + # default liveness-computation formula: newIn = use union out \ def. The + # methods returns true, if the set has changed (newIn != in), false + # otherwise. + def update_livein(self): + new_in = self.use_values + new_in = new_in.union(self.out_values) + new_in = new_in.difference(self.def_values) + + # It is sufficient to check the set sizes (instead of their contents) since + # the live-in set can only grow monotonically during all update operations. + if len(new_in) == len(self.in_values): + return False + + self.in_values = new_in.copy() + return True + + # Updates live-out information of the current block. It iterates over all + # successors and unifies their live-in values with the current live-out + # values. + def update_liveout(self, builders: dict[Block, "BlockInfoBuilder"]): + for succ in self.block.successors(): + builder = builders[succ] + self.out_values = self.out_values.union(builder.in_values) + + +# Builds the internal liveness block mapping. +def build_block_mapping(operation: Operation) -> dict[Block, BlockInfoBuilder]: + to_process: set[Block] = set() + builders: dict[Block, BlockInfoBuilder] = dict() + + # for op in operation.walk(): + # for block in [block for region in op.regions for block in region.blocks]: + for block in operation.walk_blocks_preorder(): + assert isinstance(block, Block) + if block not in builders: + builders[block] = BlockInfoBuilder(block) + + builder = builders[block] + + if builder.update_livein(): + list( + map( + lambda x: to_process.add(x), [pred for pred in block.predecessors()] + ) + ) + + # Propagate the in and out-value sets (fixpoint iteration). + while to_process: + current = to_process.pop() + builder = builders[current] + + # Update the current out values. + builder.update_liveout(builders) + + # Compute (potentially) updated live in values. + if builder.update_livein(): + list( + map( + lambda x: to_process.add(x), + [pred for pred in current.predecessors()], + ) + ) + + return builders + + +# ===----------------------------------------------------------------------===// +# LivenessBlockInfo +# ===----------------------------------------------------------------------===// + + +# This class represents liveness information on block level. +@dataclass +class LivenessBlockInfo: + in_values: set[SSAValue] = field(default_factory=set) + out_values: set[SSAValue] = field(default_factory=set) + + def __init__(self, block: Block): + self.block = block + + # Returns True if the given value is in the live-in set. + def is_livein(self, value: SSAValue): + return value in self.in_values + + # Returns True if the given vlaue is in the live-out set. + def is_liveout(self, value: SSAValue): + return value in self.out_values + + # Gets the start operation for the given value (must be referenced in this block). + def get_start_operation(self, value: SSAValue): + defining_op = value.owner if isinstance(value.owner, Operation) else None + # The given value is either live-in or is defined in the scope of this block + if self.is_livein(value) or not defining_op: + return self.block.first_op + + return defining_op + + # Gets the end operation for the given value using the start operation provided ( + # must be referenced in this block) + def get_end_operation(self, value: SSAValue, start_operation: Operation): + # The given value is either dying in this block or live-out. + if self.is_liveout(value): + return self.block.last_op + + # Resolve the last operation (must exist by definition). + end_operation = start_operation + for use_op in [use.operation for use in value.uses]: + use_op = self.block.find_ancestor_op_in_block(use_op) + # Check whether the use is in our block and after the current end operation. + if use_op and end_operation.is_before_in_block(use_op): + end_operation = use_op + + return end_operation + + # Return the values that are currently live as of the given operation. + def currently_live_values(self, op: Operation, output: IO[str]): + live_set: set[SSAValue] = set() + + # Given a value, check which ops are within its live range. For each of + # those ops, add the value to the set of live values as-of that op + def add_value_to_currently_live_sets(value: SSAValue): + start_of_live_range = ( + value.owner if isinstance(value.owner, Operation) else None + ) + end_of_live_range = None + + # If it's a live in or a block argument, then the start is the beginning of + # the block. + if self.is_livein(value) or isinstance(value, BlockArgument): + start_of_live_range = self.block.first_op + else: + assert isinstance(start_of_live_range, Operation) + start_of_live_range = self.block.find_ancestor_op_in_block( + start_of_live_range + ) + + # If it's a live out, then the end is the back of the block. + if self.is_liveout(value): + end_of_live_range = self.block.last_op + + # We must have at least a start_of_live_range at this point. Given this, we can + # use the existing get_end_operation to find the end of the live range. + if start_of_live_range and not end_of_live_range: + end_of_live_range = self.get_end_operation(value, start_of_live_range) + + assert end_of_live_range + assert "Must have end_of_live_range at this point!" + # If this op is within the live range, insert the value into the set. + assert isinstance(start_of_live_range, Operation) + if not ( + op.is_before_in_block(start_of_live_range) + or end_of_live_range.is_before_in_block(op) + ): + live_set.add(value) + + # Handle block arguments if any. + for arg in self.block.args: + add_value_to_currently_live_sets(arg) + + # Handle live-ins. Between the live ins and all the op results that gives us every value + # in the block. + for in_val in self.in_values: + add_value_to_currently_live_sets(in_val) + + # Now walk the block and handle all the values used in the block and values defined by the + # block. + for _op in self.block.ops: + for result in _op.results: + add_value_to_currently_live_sets(result) + + return live_set + + +# ===----------------------------------------------------------------------===// +# Liveness +# ===----------------------------------------------------------------------===// +block_mapping: dict[Block, LivenessBlockInfo] = dict() + + +# Creates a new Liveness analysis that computes liveness information for all +# associated regions. +@dataclass +class Liveness: + operation: Operation + + def __init__(self, op: Operation): + self.operation = op + self.build(op) + + # Initializes the internal mappings + def build(self, op: Operation): + # Build internal block mapping + builders: dict[Block, BlockInfoBuilder] = dict() + builders = build_block_mapping(op) + + # Store internal block data + for block in builders: + builder = builders[block] + block_mapping[block] = LivenessBlockInfo(block) + + block_mapping[block].block = builder.block + block_mapping[block].in_values = builder.in_values.copy() + block_mapping[block].out_values = builder.out_values.copy() + + # Gets liveness info (if any) for the given value. + def resolve_liveness(self, value: SSAValue, output: IO[str]) -> list[Operation]: + to_process: list[Block] = [] + visited: set[Block] = set() + result: list[Operation] = [] + + # Start with the defining block + if isinstance(def_op := value.owner, Operation): + current_block = def_op.parent_block() + else: + assert isinstance(value, Block) + current_block = cast(BlockArgument, value).owner + + assert isinstance(current_block, Block) + to_process.append(current_block) + visited.add(current_block) + + # Start with all associated blocks. + for use in value.uses: + use_block = use.operation.parent_block() + assert isinstance(use_block, Block) + if use_block not in visited: + to_process.append(use_block) + visited.add(use_block) + while to_process: + # Get block and block liveness information + block = to_process[-1] + to_process.pop() + block_info = self.get_liveness(block) + + # Note that start and end will be in the same block. + start = block_info.get_start_operation(value) + assert isinstance(start, Operation) + end = block_info.get_end_operation(value, start) + + assert start + result.append(start) + while start != end: + start = start.next_op + assert isinstance(start, Operation) + result.append(start) + + for successor in block.successors(): + if ( + self.get_liveness(successor).is_livein(value) + and successor not in visited + ): + to_process.append(successor) + visited.add(successor) + return result + + # Gets liveness info (if any) for the block. + def get_liveness(self, block: Block): + it = block_mapping[block] + + # FIXME: fix for the case when there is not info + return it + + # Returns a reference to a set containing live-in values. + def get_livein(self, block: Block): + self.get_liveness(block).in_values + + # Returns a reference to a set containing live-out values. + def get_liveoiut(self, block: Block): + self.get_liveness(block).out_values + + # Returns true if `value` is not live after `operation`. + def is_dead_after(self, value: SSAValue, operation: Operation): + block = operation.parent_block() + assert isinstance(block, Block) + block_info = self.get_liveness(block) + + # The given value escapes the associated block. + if block_info.is_liveout(value): + return False + + end_operation = block_info.get_end_operation(value, operation) + assert isinstance(end_operation, Operation) + # If the operation is a real user of `value` the first check is sufficient. + # If not, we will have to test whether the end operation is executed before + # the given operation in the block. + return end_operation == operation or end_operation.is_before_in_block(operation) + + # Dumps the liveness information in a human readable format. + # TODO: dump() + + # Dumps the liveness information to the given stream. + def print(self, output: IO[str], printer: Printer): + print("// ---- Liveness ----", file=output) + + # Builds unique block/value mappings for testing purposes. + block_ids: dict[Block, int] = dict() + operation_ids: dict[Operation, int] = dict() + value_ids: dict[SSAValue, int] = dict() + + for block in self.operation.walk_blocks_preorder(): + assert isinstance(block, Block) + block_ids[block] = len(block_ids) + + for argument in block.args: + value_ids[argument] = len(value_ids) + + for operation in block.ops: + operation_ids[operation] = len(operation_ids) + for result in operation.results: + value_ids[result] = len(value_ids) + + # Local printing helpers + def print_value_ref(value: SSAValue): + if isinstance(value.owner, Operation): + print(f"val_{value_ids[value]}", file=output, end="") + else: + block_arg = cast(BlockArgument, value) + print( + f"arg{block_arg.index}@{block_ids[block_arg.owner]}", + file=output, + end="", + ) + + print(" ", file=output, end="") + + def print_value_refs(values: set[SSAValue]): + ordered_values: list[SSAValue] = list(values) + + ordered_values.sort(key=lambda x: value_ids[x]) + for value in ordered_values: + print_value_ref(value) + + # Dump information about in and out values. + for block in self.operation.walk_blocks_preorder(): + assert isinstance(block, Block) + print(f"// - Block: {block_ids[block]}", file=output) + liveness = self.get_liveness(block) + print("// --- LiveIn: ", file=output, end="") + print_value_refs(liveness.in_values) + print("\n// --- LiveOut: ", file=output, end="") + print_value_refs(liveness.out_values) + print("\n", file=output, end="") + + # Print liveness intervals. + print("// --- BeginLivenessIntervals", file=output, end="") + for op in block.ops: + if len(op.results) < 1: + continue + print("", file=output) + for result in op.results: + print("//", file=output, end="") + print_value_ref(result) + print(":", file=output, end="") + live_operations = self.resolve_liveness(result, output) + live_operations.sort(key=lambda x: operation_ids[x]) + + for operation in live_operations: + print("\n// ", file=output, end="") + printer.print_op(operation) + + print("\n// --- EndLivenessIntervals", file=output) + + # Print currently live values. + print("// --- BeginCurrentlyLive", file=output) + for op in block.ops: + currently_live = liveness.currently_live_values(op, output) + if not currently_live: + continue + print("// ", file=output, end="") + printer.print_op(op) + print(" [", file=output, end="") + print_value_refs(currently_live) + print("\b]\n", file=output, end="") + + print("// --- EndCurrentlyLive", file=output) + + print("// -------------------", file=output) + + +def print_liveness(program: builtin.ModuleOp, output: IO[str]): + printer = Printer( + stream=output, + ) + + for func_op in filter(lambda x: isinstance(x, func.FuncOp), program.walk()): + liveness = Liveness(func_op) + liveness.print(output, printer) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index ee46e56ac3..02d16c644c 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -240,12 +240,19 @@ def _print_to_csl(prog: ModuleOp, output: IO[str]): print_to_csl(prog, output) + def _print_liveness(prog: ModuleOp, output: IO[str]): + from xdsl.transforms.experimental.liveness import print_liveness + + print_liveness(prog, output) + _output_mlir(prog, output) + self.available_targets["mlir"] = _output_mlir self.available_targets["riscv-asm"] = _output_riscv_asm self.available_targets["x86-asm"] = _output_x86_asm self.available_targets["riscemu"] = _emulate_riscv self.available_targets["wat"] = _output_wat self.available_targets["csl"] = _print_to_csl + self.available_targets["liveness"] = _print_liveness def setup_pipeline(self): """