diff --git a/xdsl_pdl/passes/pdl_to_irdl.py b/xdsl_pdl/passes/pdl_to_irdl.py index 8851526..99f1e17 100644 --- a/xdsl_pdl/passes/pdl_to_irdl.py +++ b/xdsl_pdl/passes/pdl_to_irdl.py @@ -24,8 +24,9 @@ TypeOp, ) -from xdsl.dialects.builtin import SymbolRefAttr, ModuleOp +from xdsl.dialects.builtin import IntegerType, SymbolRefAttr, ModuleOp from xdsl.dialects import irdl +from xdsl.utils.hints import isa from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp @@ -49,7 +50,10 @@ def add_missing_pdl_result(program: PatternOp): results_found[use.index] = True for index, found in enumerate(results_found): if not found: - Rewriter.insert_op_after(op, ResultOp(index, op.op)) + result_op = ResultOp(index, op.op) + if op.op.name_hint is not None: + result_op.val.name_hint = op.op.name_hint + f"_result_{index}_" + Rewriter.insert_op_after(op, result_op) def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp: @@ -65,6 +69,19 @@ def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp: program.body.clone_into(check_subset.lhs) program.body.clone_into(check_subset.rhs) + for op1, op2, op3 in zip( + program.body.walk(), + check_subset.lhs.walk(), + check_subset.rhs.walk(), + strict=True, + ): + for op1_res, op2_res, op3_res in zip( + op1.results, op2.results, op3.results, strict=True + ): + if op1_res.name_hint is not None: + op2_res.name_hint = "match_" + op1_res.name_hint + op3_res.name_hint = "rewrite_" + op1_res.name_hint + # Remove the rewrite part of the lhs assert check_subset.lhs.ops.last is not None Rewriter.erase_op(check_subset.lhs.ops.last) @@ -175,9 +192,15 @@ def match_and_rewrite(self, op: AttributeOp, rewriter: PatternRewriter, /): if op.value is not None: rewriter.replace_matched_op(irdl.IsOp(op.value)) return - # In the case of an untyped attribute, we can replace it with an `irdl.any` + # In the case of a typed attribute, we assume that it is an integer attribute if op.value_type is not None: + if not isinstance(op.value_type, IntegerType): + raise Exception( + "Only typed attributes with integer types are supported" + ) value = irdl.AnyOp() + if op.output.name_hint is not None: + value.output.name_hint = op.output.name_hint + "_value" rewriter.replace_matched_op( [ value, @@ -224,15 +247,21 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /): # Clone all the operation constraints cloned_op = irdl_op.clone() - for constraint in list(cloned_op.body.ops): - if isinstance(constraint, irdl.OperandsOp): - irdl_operands = constraint.args + for cloned_constraint, constraint in zip( + list(cloned_op.body.ops), irdl_op.body.ops + ): + if isinstance(cloned_constraint, irdl.OperandsOp): + irdl_operands = cloned_constraint.args continue - if isinstance(constraint, irdl.ResultsOp): - irdl_results = constraint.args + if isinstance(cloned_constraint, irdl.ResultsOp): + irdl_results = cloned_constraint.args continue - constraint.detach() - rewriter.insert_op_before_matched_op(constraint) + cloned_constraint.detach() + rewriter.insert_op_before_matched_op(cloned_constraint) + if (op_hint := op.op.name_hint) is not None and ( + hint := constraint.results[0].name_hint + ) is not None: + cloned_constraint.results[0].name_hint = op_hint + "_" + hint cloned_op.erase() @@ -253,6 +282,23 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /): rewriter.erase_matched_op() +def get_zero_irdl(op: ApplyNativeRewriteOp, rewriter: PatternRewriter): + """Return the IRDL constraint representing the PDL constraint `get_zero`.""" + zero = irdl.AnyOp() + res = irdl.ParametricOp( + SymbolRefAttr("builtin", ["integer_attr"]), [zero.output, op.args[0]] + ) + rewriter.replace_matched_op([zero, res]) + + +def integer_attr_arithmetic_irdl(op: ApplyNativeRewriteOp, rewriter: PatternRewriter): + """ + Return the IRDL constraint representing the PDL + constraints doing arithmetic on integer attributes. + """ + rewriter.replace_matched_op([], new_results=[op.args[0]]) + + class PDLToIRDLNativeRewritePattern(RewritePattern): """ Replace `pdl.native_rewrite` operations with our hardcoded implementation. @@ -261,13 +307,10 @@ class PDLToIRDLNativeRewritePattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyNativeRewriteOp, rewriter: PatternRewriter, /): if op.constraint_name.data == "get_zero": - # We do not currently support the 0 part of the rewrite - # We only say it is an integer_attr with the given type - zero = irdl.AnyOp() - res = irdl.ParametricOp( - SymbolRefAttr("builtin", ["integer_attr"]), [zero.output, op.args[0]] - ) - rewriter.replace_matched_op([zero, res]) + get_zero_irdl(op, rewriter) + return + if op.constraint_name.data == "addi": + integer_attr_arithmetic_irdl(op, rewriter) return raise Exception(f"Unknown native rewrite {op.constraint_name}")