From 9beae4714191f00b7ff12d00632cdd7d8cd94cd0 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 16 May 2024 00:53:30 +0100 Subject: [PATCH] Finish pdl to irdl conversion --- xdsl_pdl/dialects/irdl_extension.py | 21 +++- xdsl_pdl/tools/pdl_to_irdl_check.py | 158 +++++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 6 deletions(-) diff --git a/xdsl_pdl/dialects/irdl_extension.py b/xdsl_pdl/dialects/irdl_extension.py index dcbfded..8c14954 100644 --- a/xdsl_pdl/dialects/irdl_extension.py +++ b/xdsl_pdl/dialects/irdl_extension.py @@ -62,4 +62,23 @@ def __init__( ) -IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp]) +@irdl_op_definition +class EqOp(IRDLOperation): + name = "irdl_ext.eq" + + args = var_operand_def(AttributeType()) + + assembly_format = "attr-dict $args" + + def __init__( + self, + args: Sequence[SSAValue], + attr_dict: DictionaryAttr | None = None, + ): + super().__init__( + operands=[args], + attributes=attr_dict.data if attr_dict is not None else None, + ) + + +IRDLExtension = Dialect("irdl_ext", [CheckSubsetOp, YieldOp, EqOp]) diff --git a/xdsl_pdl/tools/pdl_to_irdl_check.py b/xdsl_pdl/tools/pdl_to_irdl_check.py index 8eb0bfb..29999a5 100644 --- a/xdsl_pdl/tools/pdl_to_irdl_check.py +++ b/xdsl_pdl/tools/pdl_to_irdl_check.py @@ -4,15 +4,23 @@ """ import argparse -from re import Pattern import sys -from xdsl.ir import MLContext, OpResult, Region, SSAValue +from xdsl.ir import MLContext, OpResult, Operation, Region, SSAValue, Use, dataclass +from xdsl.irdl import irdl_op_verify_arg_list from xdsl.parser import Parser from xdsl.rewriter import InsertPoint, Rewriter +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) from xdsl.dialects.pdl import ( PDL, + ApplyNativeConstraintOp, AttributeOp, OperandOp, OperationOp, @@ -24,8 +32,9 @@ ) from xdsl.dialects.builtin import Builtin -from xdsl.dialects.irdl import IRDL -from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, IRDLExtension, YieldOp +from xdsl.dialects.irdl import IRDL, DialectOp +from xdsl.dialects import irdl +from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, IRDLExtension, YieldOp def add_missing_pdl_result(program: PatternOp): @@ -144,6 +153,136 @@ def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp: return check_subset +class PDLToIRDLTypePattern(RewritePattern): + """ + Replace `pdl.type` to either `irdl.is` or `irdl.any`. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TypeOp, rewriter: PatternRewriter, /): + if op.constantType is None: + rewriter.replace_matched_op(irdl.AnyOp()) + return + rewriter.replace_matched_op(irdl.IsOp(op.constantType)) + + +class PDLToIRDLOperandPattern(RewritePattern): + """ + Replace `pdl.operand` to either `irdl.any`, or its type constraint. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: OperandOp, rewriter: PatternRewriter, /): + if op.value_type is None: + rewriter.replace_matched_op(irdl.AnyOp()) + return + rewriter.replace_matched_op([], new_results=[op.value_type]) + + +class PDLToIRDLAttributePattern(RewritePattern): + """ + Replace `pdl.attribute` to `irdl.is`, `irdl.any`, or a constraint over an + IntegerAttr. We assume that typed attributes are always IntegerAttr. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: AttributeOp, rewriter: PatternRewriter, /): + # In the case of a constant attribute, we can replace it with an `irdl.is` + # operation. + if op.value is not None: + rewriter.replace_matched_op(irdl.IsOp(op.value)) + return + # In the case of an untyped attribute, we can replace it with an `irdl.any` + if op.value_type is not None: + rewriter.replace_matched_op( + irdl.ParametricOp("builtin.integer_attr", [op.value_type]) + ) + return + # Otherwise, it could by anything + rewriter.replace_matched_op(irdl.AnyOp()) + + +class PDLToIRDLNativeConstraintPattern(RewritePattern): + """ + Remove `pdl.native_constraint` operations + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: ApplyNativeConstraintOp, rewriter: PatternRewriter, / + ): + rewriter.erase_matched_op() + + +@dataclass +class PDLToIRDLOperationPattern(RewritePattern): + """Replace `pdl.operation` to its constraints.""" + + irdl_ops: dict[str, irdl.OperationOp] + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /): + if op.opName is None: + raise Exception("All PDL operations are expected to have a name.") + if op.opName.data not in self.irdl_ops: + raise Exception("Operation not found in IRDL: " + op.opName.data) + irdl_op = self.irdl_ops[op.opName.data] + + # Grab the constraints corresponding to the operands and results + irdl_operands = [] + irdl_results = [] + + # Clone all the operation constraints + cloned_op = irdl_op.clone() + for constraint in list(cloned_op.body.ops): + if isinstance(constraint, irdl.OperandsOp): + irdl_operands = constraint.args + continue + if isinstance(constraint, irdl.ResultsOp): + irdl_results = constraint.args + continue + constraint.detach() + rewriter.insert_op_before_matched_op(constraint) + + cloned_op.erase() + + operand_matches = list(zip(irdl_operands, op.operand_values, strict=True)) + results_matches = list(zip(irdl_results, op.type_values, strict=True)) + + # Merge irdl_operand and pdl_operand + for irdl_operand, pdl_operand in [*operand_matches, *results_matches]: + merge_op = EqOp([irdl_operand, pdl_operand]) + rewriter.insert_op_before_matched_op(merge_op) + + for uses in list(op.op.uses): + if not isinstance(uses.operation, ResultOp): + raise Exception("Expected a `pdl.result` operation") + rewriter.replace_op( + uses.operation, [], new_results=[results_matches[uses.index][1]] + ) + rewriter.erase_matched_op() + + +def convert_pdl_match_to_irdl_match( + program: Operation, irdl_ops: dict[str, irdl.OperationOp] +): + """ + Convert PDL operations to IRDL operations in the given program. + """ + walker = PatternRewriteWalker( + GreedyRewritePatternApplier( + [ + PDLToIRDLTypePattern(), + PDLToIRDLOperandPattern(), + PDLToIRDLAttributePattern(), + PDLToIRDLNativeConstraintPattern(), + PDLToIRDLOperationPattern(irdl_ops), + ] + ) + ) + walker.rewrite_op(program) + + def main(): arg_parser = argparse.ArgumentParser( prog="pdl-to-irdl-check", @@ -181,10 +320,19 @@ def main(): ) exit(1) + irdl_ops: dict[str, irdl.OperationOp] = {} + for op in program.walk(): + if isinstance(op, irdl.OperationOp): + assert isinstance((parent := op.parent_op()), DialectOp) + name = parent.sym_name.data + "." + op.sym_name.data + irdl_ops[name] = op + # Add `pdl.result` operation for each `pdl.operation`. # This allows us to easily map the input values to the output values. add_missing_pdl_result(rewrite) - print(convert_pattern_to_check_subset(rewrite)) + check_subset = convert_pattern_to_check_subset(rewrite) + convert_pdl_match_to_irdl_match(check_subset, irdl_ops) + print(check_subset) if __name__ == "__main__":