Skip to content

Commit

Permalink
Improve naming of compiled variables
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed May 20, 2024
1 parent 27222e5 commit 5711037
Showing 1 changed file with 60 additions and 17 deletions.
77 changes: 60 additions & 17 deletions xdsl_pdl/passes/pdl_to_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
TypeOp,
)

from xdsl.dialects.builtin import SymbolRefAttr, ModuleOp
from xdsl.dialects.builtin import IntegerType, SymbolRefAttr, ModuleOp
from xdsl.dialects import irdl
from xdsl.utils.hints import isa
from xdsl_pdl.dialects.irdl_extension import CheckSubsetOp, EqOp, YieldOp


Expand All @@ -49,7 +50,10 @@ def add_missing_pdl_result(program: PatternOp):
results_found[use.index] = True
for index, found in enumerate(results_found):
if not found:
Rewriter.insert_op_after(op, ResultOp(index, op.op))
result_op = ResultOp(index, op.op)
if op.op.name_hint is not None:
result_op.val.name_hint = op.op.name_hint + f"_result_{index}_"
Rewriter.insert_op_after(op, result_op)


def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp:
Expand All @@ -65,6 +69,19 @@ def convert_pattern_to_check_subset(program: PatternOp) -> CheckSubsetOp:
program.body.clone_into(check_subset.lhs)
program.body.clone_into(check_subset.rhs)

for op1, op2, op3 in zip(
program.body.walk(),
check_subset.lhs.walk(),
check_subset.rhs.walk(),
strict=True,
):
for op1_res, op2_res, op3_res in zip(
op1.results, op2.results, op3.results, strict=True
):
if op1_res.name_hint is not None:
op2_res.name_hint = "match_" + op1_res.name_hint
op3_res.name_hint = "rewrite_" + op1_res.name_hint

# Remove the rewrite part of the lhs
assert check_subset.lhs.ops.last is not None
Rewriter.erase_op(check_subset.lhs.ops.last)
Expand Down Expand Up @@ -175,9 +192,15 @@ def match_and_rewrite(self, op: AttributeOp, rewriter: PatternRewriter, /):
if op.value is not None:
rewriter.replace_matched_op(irdl.IsOp(op.value))
return
# In the case of an untyped attribute, we can replace it with an `irdl.any`
# In the case of a typed attribute, we assume that it is an integer attribute
if op.value_type is not None:
if not isinstance(op.value_type, IntegerType):
raise Exception(
"Only typed attributes with integer types are supported"
)
value = irdl.AnyOp()
if op.output.name_hint is not None:
value.output.name_hint = op.output.name_hint + "_value"
rewriter.replace_matched_op(
[
value,
Expand Down Expand Up @@ -224,15 +247,21 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /):

# Clone all the operation constraints
cloned_op = irdl_op.clone()
for constraint in list(cloned_op.body.ops):
if isinstance(constraint, irdl.OperandsOp):
irdl_operands = constraint.args
for cloned_constraint, constraint in zip(
list(cloned_op.body.ops), irdl_op.body.ops
):
if isinstance(cloned_constraint, irdl.OperandsOp):
irdl_operands = cloned_constraint.args
continue
if isinstance(constraint, irdl.ResultsOp):
irdl_results = constraint.args
if isinstance(cloned_constraint, irdl.ResultsOp):
irdl_results = cloned_constraint.args
continue
constraint.detach()
rewriter.insert_op_before_matched_op(constraint)
cloned_constraint.detach()
rewriter.insert_op_before_matched_op(cloned_constraint)
if (op_hint := op.op.name_hint) is not None and (
hint := constraint.results[0].name_hint
) is not None:
cloned_constraint.results[0].name_hint = op_hint + "_" + hint

cloned_op.erase()

Expand All @@ -253,6 +282,23 @@ def match_and_rewrite(self, op: OperationOp, rewriter: PatternRewriter, /):
rewriter.erase_matched_op()


def get_zero_irdl(op: ApplyNativeRewriteOp, rewriter: PatternRewriter):
"""Return the IRDL constraint representing the PDL constraint `get_zero`."""
zero = irdl.AnyOp()
res = irdl.ParametricOp(
SymbolRefAttr("builtin", ["integer_attr"]), [zero.output, op.args[0]]
)
rewriter.replace_matched_op([zero, res])


def integer_attr_arithmetic_irdl(op: ApplyNativeRewriteOp, rewriter: PatternRewriter):
"""
Return the IRDL constraint representing the PDL
constraints doing arithmetic on integer attributes.
"""
rewriter.replace_matched_op([], new_results=[op.args[0]])


class PDLToIRDLNativeRewritePattern(RewritePattern):
"""
Replace `pdl.native_rewrite` operations with our hardcoded implementation.
Expand All @@ -261,13 +307,10 @@ class PDLToIRDLNativeRewritePattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ApplyNativeRewriteOp, rewriter: PatternRewriter, /):
if op.constraint_name.data == "get_zero":
# We do not currently support the 0 part of the rewrite
# We only say it is an integer_attr with the given type
zero = irdl.AnyOp()
res = irdl.ParametricOp(
SymbolRefAttr("builtin", ["integer_attr"]), [zero.output, op.args[0]]
)
rewriter.replace_matched_op([zero, res])
get_zero_irdl(op, rewriter)
return
if op.constraint_name.data == "addi":
integer_attr_arithmetic_irdl(op, rewriter)
return
raise Exception(f"Unknown native rewrite {op.constraint_name}")

Expand Down

0 comments on commit 5711037

Please sign in to comment.