Skip to content

Commit

Permalink
Finish pdl to irdl conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 15, 2024
1 parent d4fdb0b commit 9beae47
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 6 deletions.
21 changes: 20 additions & 1 deletion xdsl_pdl/dialects/irdl_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,23 @@ def __init__(
)


IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp])
@irdl_op_definition
class EqOp(IRDLOperation):
name = "irdl_ext.eq"

args = var_operand_def(AttributeType())

assembly_format = "attr-dict $args"

def __init__(
self,
args: Sequence[SSAValue],
attr_dict: DictionaryAttr | None = None,
):
super().__init__(
operands=[args],
attributes=attr_dict.data if attr_dict is not None else None,
)


IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp, EqOp])
158 changes: 153 additions & 5 deletions xdsl_pdl/tools/pdl_to_irdl_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@
"""

import argparse
from re import Pattern
import sys

from xdsl.ir import MLContext, OpResult, Region, SSAValue
from xdsl.ir import MLContext, OpResult, Operation, Region, SSAValue, Use, dataclass
from xdsl.irdl import irdl_op_verify_arg_list
from xdsl.parser import Parser
from xdsl.rewriter import InsertPoint, Rewriter
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)

from xdsl.dialects.pdl import (
PDL,
ApplyNativeConstraintOp,
AttributeOp,
OperandOp,
OperationOp,
Expand All @@ -24,8 +32,9 @@
)

from xdsl.dialects.builtin import Builtin
from xdsl.dialects.irdl import IRDL
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, IRDLExtension, YieldOp
from xdsl.dialects.irdl import IRDL, DialectOp
from xdsl.dialects import irdl
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, IRDLExtension, YieldOp


def add_missing_pdl_result(program: PatternOp):
Expand Down Expand Up @@ -144,6 +153,136 @@ def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp:
return check_subset


class PDLToIRDLTypePattern(RewritePattern):
"""
Replace `pdl.type` to either `irdl.is` or `irdl.any`.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: TypeOp, rewriter: PatternRewriter, /):
if op.constantType is None:
rewriter.replace_matched_op(irdl.AnyOp())
return
rewriter.replace_matched_op(irdl.IsOp(op.constantType))


class PDLToIRDLOperandPattern(RewritePattern):
"""
Replace `pdl.operand` to either `irdl.any`, or its type constraint.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: OperandOp, rewriter: PatternRewriter, /):
if op.value_type is None:
rewriter.replace_matched_op(irdl.AnyOp())
return
rewriter.replace_matched_op([], new_results=[op.value_type])


class PDLToIRDLAttributePattern(RewritePattern):
"""
Replace `pdl.attribute` to `irdl.is`, `irdl.any`, or a constraint over an
IntegerAttr. We assume that typed attributes are always IntegerAttr.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: AttributeOp, rewriter: PatternRewriter, /):
# In the case of a constant attribute, we can replace it with an `irdl.is`
# operation.
if op.value is not None:
rewriter.replace_matched_op(irdl.IsOp(op.value))
return
# In the case of an untyped attribute, we can replace it with an `irdl.any`
if op.value_type is not None:
rewriter.replace_matched_op(
irdl.ParametricOp("builtin.integer_attr", [op.value_type])
)
return
# Otherwise, it could by anything
rewriter.replace_matched_op(irdl.AnyOp())


class PDLToIRDLNativeConstraintPattern(RewritePattern):
"""
Remove `pdl.native_constraint` operations
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: ApplyNativeConstraintOp, rewriter: PatternRewriter, /
):
rewriter.erase_matched_op()


@dataclass
class PDLToIRDLOperationPattern(RewritePattern):
"""Replace `pdl.operation` to its constraints."""

irdl_ops: dict[str, irdl.OperationOp]

@op_type_rewrite_pattern
def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /):
if op.opName is None:
raise Exception("All PDL operations are expected to have a name.")
if op.opName.data not in self.irdl_ops:
raise Exception("Operation not found in IRDL: " + op.opName.data)
irdl_op = self.irdl_ops[op.opName.data]

# Grab the constraints corresponding to the operands and results
irdl_operands = []
irdl_results = []

# Clone all the operation constraints
cloned_op = irdl_op.clone()
for constraint in list(cloned_op.body.ops):
if isinstance(constraint, irdl.OperandsOp):
irdl_operands = constraint.args
continue
if isinstance(constraint, irdl.ResultsOp):
irdl_results = constraint.args
continue
constraint.detach()
rewriter.insert_op_before_matched_op(constraint)

cloned_op.erase()

operand_matches = list(zip(irdl_operands, op.operand_values, strict=True))
results_matches = list(zip(irdl_results, op.type_values, strict=True))

# Merge irdl_operand and pdl_operand
for irdl_operand, pdl_operand in [*operand_matches, *results_matches]:
merge_op = EqOp([irdl_operand, pdl_operand])
rewriter.insert_op_before_matched_op(merge_op)

for uses in list(op.op.uses):
if not isinstance(uses.operation, ResultOp):
raise Exception("Expected a `pdl.result` operation")
rewriter.replace_op(
uses.operation, [], new_results=[results_matches[uses.index][1]]
)
rewriter.erase_matched_op()


def convert_pdl_match_to_irdl_match(
program: Operation, irdl_ops: dict[str, irdl.OperationOp]
):
"""
Convert PDL operations to IRDL operations in the given program.
"""
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
PDLToIRDLTypePattern(),
PDLToIRDLOperandPattern(),
PDLToIRDLAttributePattern(),
PDLToIRDLNativeConstraintPattern(),
PDLToIRDLOperationPattern(irdl_ops),
]
)
)
walker.rewrite_op(program)


def main():
arg_parser = argparse.ArgumentParser(
prog="pdl-to-irdl-check",
Expand Down Expand Up @@ -181,10 +320,19 @@ def main():
)
exit(1)

irdl_ops: dict[str, irdl.OperationOp] = {}
for op in program.walk():
if isinstance(op, irdl.OperationOp):
assert isinstance((parent := op.parent_op()), DialectOp)
name = parent.sym_name.data + "." + op.sym_name.data
irdl_ops[name] = op

# Add `pdl.result` operation for each `pdl.operation`.
# This allows us to easily map the input values to the output values.
add_missing_pdl_result(rewrite)
print(convert_pattern_to_check_subset(rewrite))
check_subset = convert_pattern_to_check_subset(rewrite)
convert_pdl_match_to_irdl_match(check_subset, irdl_ops)
print(check_subset)


if __name__ == "__main__":
Expand Down

0 comments on commit 9beae47

Please sign in to comment.