Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 16, 2024
1 parent c699e77 commit ef6190b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
6 changes: 3 additions & 3 deletions tests/filecheck/pdl_to_irdl_check/mulsi_extended_bug.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ irdl.dialect @arith {
}

pdl.pattern @MulSIExtendedRHSOne : benefit(0) {
%t = pdl.type : index
%t = pdl.type
%x = pdl.operand : %t
%one = pdl.attribute : %t
pdl.apply_native_constraint "is_one"(%one : !pdl.attribute)
Expand All @@ -61,11 +61,11 @@ pdl.pattern @MulSIExtendedRHSOne : benefit(0) {

%root = pdl.operation "arith.mulsi_extended"(%x, %one_val : !pdl.value, !pdl.value) -> (%t, %t : !pdl.type, !pdl.type)
pdl.rewrite %root {
%zero = pdl.attribute = 0 : index
%zero = pdl.apply_native_rewrite "get_zero"(%t : !pdl.type) : !pdl.attribute
%zero_op = pdl.operation "arith.constant" {"value" = %zero} -> (%t : !pdl.type)
%zero_val = pdl.result 0 of %zero_op

%two = pdl.attribute = 2 : index
%two = pdl.attribute = 2 : i64
%i1 = pdl.type : i1
%cmpi_op = pdl.operation "arith.cmpi"(%x, %zero_val : !pdl.value, !pdl.value) {"predicate" = %two} -> (%i1 : !pdl.type)
%cmpi_val = pdl.result 0 of %cmpi_op
Expand Down
27 changes: 22 additions & 5 deletions xdsl_pdl/tools/pdl_to_irdl_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from xdsl.dialects.pdl import (
PDL,
ApplyNativeConstraintOp,
ApplyNativeRewriteOp,
AttributeOp,
OperandOp,
OperationOp,
Expand Down Expand Up @@ -122,11 +123,7 @@ def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp:
else:
Rewriter.insert_op_before(root.owner, op)
continue
if isinstance(op, TypeOp):
op.detach()
Rewriter.insert_op_before(root.owner, op)
continue
if isinstance(op, OperationOp):
if isinstance(op, TypeOp | OperationOp | ApplyNativeRewriteOp):
op.detach()
Rewriter.insert_op_before(root.owner, op)
continue
Expand Down Expand Up @@ -270,6 +267,25 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /):
rewriter.erase_matched_op()


class PDLToIRDLNativeRewritePattern(RewritePattern):
"""
Replace `pdl.native_rewrite` operations with our hardcoded implementation.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyNativeRewriteOp, rewriter: PatternRewriter, /):
if op.constraint_name.data == "get_zero":
# We do not currently support the 0 part of the rewrite
# We only say it is an integer_attr with the given type
zero = irdl.AnyOp()
res = irdl.ParametricOp(
SymbolRefAttr("builtin", ["integer_attr"]), [zero.output, op.args[0]]
)
rewriter.replace_matched_op([zero, res])
return
raise Exception(f"Unknown native rewrite {op.constraint_name}")


def convert_pdl_match_to_irdl_match(
program: Operation, irdl_ops: dict[str, irdl.OperationOp]
):
Expand All @@ -284,6 +300,7 @@ def convert_pdl_match_to_irdl_match(
PDLToIRDLAttributePattern(),
PDLToIRDLNativeConstraintPattern(),
PDLToIRDLOperationPattern(irdl_ops),
PDLToIRDLNativeRewritePattern(),
]
)
)
Expand Down

0 comments on commit ef6190b

Please sign in to comment.