Skip to content

Commit

Permalink
Improve MLIR analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 11, 2024
1 parent 6a489b2 commit aa7d7b6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 161 deletions.
36 changes: 17 additions & 19 deletions xdsl_pdl/analysis/mlir_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,11 @@
from xdsl.ir import MLContext, Operation, Region, Block
from xdsl.printer import Printer

from xdsl.dialects.builtin import ModuleOp, StringAttr
from xdsl.dialects.builtin import FunctionType, ModuleOp, StringAttr
from xdsl.dialects.pdl import PatternOp
from xdsl.dialects.test import TestOp
from xdsl.dialects.func import FuncOp

from xdsl_pdl.fuzzing.generate_pdl_matches import (
create_dag_in_region,
generate_all_dags,
pdl_to_operations,
put_operations_in_region,
)
from xdsl_pdl.fuzzing.generate_pdl_matches import get_all_matches


@dataclass
Expand Down Expand Up @@ -81,18 +76,21 @@ def analyze_with_mlir(
Run the pattern on multiple examples with MLIR.
If MLIR returns an error in any of the examples, returns the error.
"""
pattern = pattern.clone()
all_dags = generate_all_dags(5)
try:
for _ in range(0, 10):
region, ops = pdl_to_operations(pattern, ctx, randgen)
dag = all_dags[randgen.randrange(0, len(all_dags))]
create_dag_in_region(region, dag, ctx)
for populated_region in put_operations_in_region(dag, region, ops, ctx):
cloned_region = Region()
populated_region.clone_into(cloned_region)
program = TestOp.create(regions=[cloned_region])
run_with_mlir(program, pattern, mlir_executable_path)
pattern = pattern.clone()
for populated_region in get_all_matches(
pattern, Region([Block()]), randgen, ctx
):
cloned_region = Region()
populated_region.clone_into(cloned_region)
program = FuncOp(
"test",
FunctionType.from_lists(
[arg.type for arg in cloned_region.blocks[0].args], []
),
cloned_region,
)
run_with_mlir(program, pattern, mlir_executable_path)
except MLIRFailure as e:
return e
except MLIRInfiniteLoop as e:
Expand Down
188 changes: 46 additions & 142 deletions xdsl_pdl/fuzzing/generate_pdl_matches.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

from itertools import chain, combinations
from dataclasses import dataclass, field
from random import Random
from typing import Generator, Iterable, cast
from typing import Generator, Iterable

from xdsl.ir import Attribute, Block, MLContext, OpResult, Operation, Region, SSAValue
from xdsl.ir import Attribute, MLContext, Operation, Region, SSAValue
from xdsl.dialects.builtin import (
IntegerAttr,
IntegerType,
Expand All @@ -22,89 +21,6 @@
)


@dataclass
class SingleEntryDAGStructure:
"""
A DAG structure, represented by a reverse adjency list.
A reverse adjency list is a list of sets, where the i-th set contains
the indices of the nodes that points to the i-th node.
"""

size: int = field(default=0)
reverse_adjency_list: list[set[int]] = field(default_factory=list)

def add_node(self, parents: set[int]):
if all(parent >= self.size for parent in parents):
raise Exception("Can't add a node without non-self parents")
for parent in parents:
if parent > self.size:
raise Exception(
"Can't add a node with parents that are not " "yet in the DAG"
)
self.reverse_adjency_list.append(parents)
self.size += 1

def get_adjency_list(self) -> list[set[int]]:
adjency_list: list[set[int]] = [set() for _ in range(self.size)]
for i, parents in enumerate(self.reverse_adjency_list):
for parent in parents:
adjency_list[parent].add(i)
return adjency_list

def copy(self) -> SingleEntryDAGStructure:
return SingleEntryDAGStructure(self.size, self.reverse_adjency_list.copy())

def get_dominance_list(self) -> list[set[int]]:
dominance_list = [set[int]()]
for i in range(1, self.size):
reverse_strict_adjency = self.reverse_adjency_list[i].copy()
if i in reverse_strict_adjency:
reverse_strict_adjency.remove(i)
assert len(self.reverse_adjency_list[i]) > 0
if len(reverse_strict_adjency) == 1:
parent = list(reverse_strict_adjency)[0]
dominance_list.append(dominance_list[parent] | {parent})
else:
first_parent = list(reverse_strict_adjency)[0]
dominance = dominance_list[first_parent] | {first_parent}
for parent in reverse_strict_adjency:
dominance = dominance & (dominance_list[parent] | {parent})
dominance_list.append(dominance)

return dominance_list


def powerset(iterable: Iterable[int]) -> chain[tuple[int, ...]]:
"powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))


def generate_all_dags(num_blocks: int = 5) -> list[SingleEntryDAGStructure]:
"""
Create all possible single-entry DAGs with the given number of blocks.
First, generate all possible DAGs, then remove the ones that are not
single-entry.
There should be only 2 ^ (n*(n-1)/2) possible DAGs, so this should be fine.
This is because a DAG can always be represented as a lower triangular
matrix.
"""
if num_blocks < 1:
raise Exception("Can't generate a DAG with less than 1 block")
if num_blocks == 1:
return [SingleEntryDAGStructure(1, [set()]), SingleEntryDAGStructure(1, [{0}])]
previous_dags = generate_all_dags(num_blocks - 1)
res: list[SingleEntryDAGStructure] = []
for dag in previous_dags:
for parents in powerset(range(num_blocks)):
if all(parent >= num_blocks - 1 for parent in parents):
continue
new_dag = dag.copy()
new_dag.add_node(set(parents))
res.append(new_dag)
return res


@dataclass
class PDLSynthContext:
"""
Expand All @@ -129,10 +45,12 @@ def possible_values_of_type(self, type: Attribute) -> list[SSAValue]:


def pdl_to_operations(
pattern: PatternOp, ctx: MLContext, randgen: Random
pattern: PatternOp, region: Region, ctx: MLContext, randgen: Random
) -> tuple[Region, list[Operation]]:
assert len(region.blocks) == 1
assert len(region.blocks[0].ops) == 0
assert len(region.blocks[0].args) == 0
pattern_ops = pattern.body.ops
region = Region([Block()])
synth_ops: list[Operation] = []
pdl_context = PDLSynthContext()

Expand Down Expand Up @@ -212,59 +130,45 @@ def pdl_to_operations(
return region, synth_ops


def create_dag_in_region(region: Region, dag: SingleEntryDAGStructure, ctx: MLContext):
def get_all_interleavings(
ops: list[Operation],
) -> Generator[list[Operation], None, None]:
"""
Generate all possible interleavings of the given operations,
while respecting dominance order.
"""
if ops == []:
yield []
return
for i in range(len(ops)):
dominating_ops = {
operand.owner
for operand in ops[i].operands
if isinstance(operand.owner, Operation)
}
if dominating_ops.isdisjoint(ops):
for interleaving in get_all_interleavings(ops[:i] + ops[i + 1 :]):
yield [ops[i]] + interleaving
return


def get_all_matches(
pattern: PatternOp, region: Region, randgen: Random, ctx: MLContext
) -> Iterable[Region]:
"""
Generate all possible matches of the pattern in the given region with a
single empty block.
"""
assert len(region.blocks) == 1
blocks: list[Block] = []
for _ in range(dag.size):
block = Block()
region.add_block(block)
blocks.append(block)

region.blocks[0].add_op(ctx.get_op("test.entry").create(successors=[blocks[0]]))

for i, adjency_set in enumerate(dag.get_adjency_list()):
block = blocks[i]
successors = [blocks[j] for j in adjency_set]
branch_op = ctx.get_op("test.branch")
block.add_op(branch_op.create(successors=successors))


def put_operations_in_region(
dag: SingleEntryDAGStructure, region: Region, ops: list[Operation], ctx: MLContext
) -> Generator[Region, None, None]:
block_to_idx: dict[Block, int] = {}
for i, block in enumerate(region.blocks[1:]):
block_to_idx[block] = i
dominance_list = dag.get_dominance_list()

def rec(i: int, ops: list[Operation]) -> Generator[Region, None, None]:
# Finished placing all operations.
if len(ops) == 0:
yield region
return
# No more blocks to place operations in.
if i == dag.size:
return

# Try to place operations in next blocks
yield from rec(i + 1, ops)

# Check if we can place the first operation in this block
operands_index = set(
block_to_idx[cast(Block, operand.owner.parent_block())]
for operand in ops[0].operands
if isinstance(operand, OpResult)
)
if operands_index.issubset(dominance_list[i]):
# Place the operation, and recurse
block = region.blocks[i + 1]
assert block.ops.last is not None
block.insert_op_before(ops[0], block.ops.last)
block.insert_op_before(
ctx.get_op("test.use_op").create(operands=ops[0].results),
block.ops.last,
assert len(region.blocks[0].ops) == 0

region, ops = pdl_to_operations(pattern, region, ctx, randgen)
for interleaving in get_all_interleavings(ops):
for op in interleaving:
region.blocks[0].add_op(op)
region.blocks[0].add_op(
ctx.get_op("test.use_op").create(operands=op.results)
)
yield from rec(i, ops[1:])
ops[0].detach()

yield from rec(0, ops)
yield region
for op in region.blocks[0].ops:
region.blocks[0].detach_op(op)

0 comments on commit aa7d7b6

Please sign in to comment.