Skip to content

Commit

Permalink
transformations: (convert-memref-to-ptr) add lower-func flag (#3820)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaylendog authored Feb 2, 2025
1 parent 3a9e8b3 commit dce19da
Show file tree
Hide file tree
Showing 2 changed files with 233 additions and 5 deletions.
42 changes: 42 additions & 0 deletions tests/filecheck/transforms/convert_memref_args_to_ptr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: xdsl-opt -p convert-memref-to-ptr{lower_func=true} --split-input-file --verify-diagnostics %s | filecheck %s

// CHECK: builtin.module {

// CHECK-NEXT: func.func @declaration(!ptr_xdsl.ptr) -> ()
func.func @declaration(%arg : memref<2x2xf32>)


// CHECK-NEXT: func.func @simple(%arg : !ptr_xdsl.ptr) {
// CHECK-NEXT: func.return
// CHECK-NEXT: }
func.func @simple(%arg : memref<2x2xf32>) {
func.return
}

// CHECK-NEXT: func.func @id(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr {
// CHECK-NEXT: func.return %arg : !ptr_xdsl.ptr
// CHECK-NEXT: }
func.func @id(%arg : memref<2x2xf32>) -> memref<2x2xf32> {
func.return %arg : memref<2x2xf32>
}

// CHECK-NEXT: func.func @id2(%arg : !ptr_xdsl.ptr) -> !ptr_xdsl.ptr {
// CHECK-NEXT: %res = func.call @id(%arg) : (!ptr_xdsl.ptr) -> !ptr_xdsl.ptr
// CHECK-NEXT: func.return %res : !ptr_xdsl.ptr
// CHECK-NEXT: }
func.func @id2(%arg : memref<2x2xf32>) -> memref<2x2xf32> {
%res = func.call @id(%arg) : (memref<2x2xf32>) -> memref<2x2xf32>
func.return %res : memref<2x2xf32>
}

// CHECK-NEXT: func.func @first(%arg : !ptr_xdsl.ptr) -> f32 {
// CHECK-NEXT: %res = ptr_xdsl.load %arg : !ptr_xdsl.ptr -> f32
// CHECK-NEXT: func.return %res : f32
// CHECK-NEXT: }
func.func @first(%arg : memref<2x2xf32>) -> f32 {
%pointer = ptr_xdsl.to_ptr %arg : memref<2x2xf32> -> !ptr_xdsl.ptr
%res = ptr_xdsl.load %pointer : !ptr_xdsl.ptr -> f32
func.return %res : f32
}

// CHECK-NEXT: }
196 changes: 191 additions & 5 deletions xdsl/transforms/convert_memref_to_ptr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import cast

from xdsl.context import MLContext
from xdsl.dialects import arith, builtin, memref, ptr
from xdsl.ir import Operation, SSAValue
from xdsl.dialects import arith, builtin, func, memref, ptr
from xdsl.ir import Attribute, Operation, SSAValue
from xdsl.irdl import Any
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -14,6 +14,7 @@
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.rewriter import InsertPoint
from xdsl.utils.exceptions import DiagnosticException


Expand Down Expand Up @@ -153,12 +154,197 @@ def match_and_rewrite(self, op: memref.LoadOp, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(ops, new_results=[load_result.res])


@dataclass
class LowerMemrefFuncOpPattern(RewritePattern):
"""
Rewrites function arguments of MemRefType to PtrType.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
# rewrite function declaration
new_input_types = [
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
for arg in op.function_type.inputs
]
new_output_types = [
ptr.PtrType() if isinstance(arg, builtin.MemRefType) else arg
for arg in op.function_type.outputs
]
op.function_type = func.FunctionType.from_lists(
new_input_types,
new_output_types,
)

if op.is_declaration:
return

insert_point = InsertPoint.at_start(op.body.blocks[0])

# rewrite arguments
for arg in op.args:
if not isinstance(arg_type := arg.type, memref.MemRefType):
continue

old_type = cast(memref.MemRefType[Attribute], arg_type)
arg.type = ptr.PtrType()

if not arg.uses:
continue

rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get([arg], [old_type]),
insert_point,
)
arg.replace_by_if(cast_op.results[0], lambda x: x.operation is not cast_op)


@dataclass
class LowerMemrefFuncReturnPattern(RewritePattern):
"""
Rewrites all `memref` arguments to `func.return` into `ptr.PtrType`
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.ReturnOp, rewriter: PatternRewriter, /):
if not any(isinstance(arg.type, memref.MemRefType) for arg in op.arguments):
return

insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

# insert `memref -> ptr` casts for memref return values
for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

rewriter.replace_matched_op(func.ReturnOp(*new_arguments))


@dataclass
class LowerMemrefFuncCallPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter, /):
if not any(
isinstance(arg.type, memref.MemRefType) for arg in op.arguments
) and not any(isinstance(type, memref.MemRefType) for type in op.result_types):
return

# rewrite arguments
insert_point = InsertPoint.before(op)
new_arguments: list[SSAValue] = []

# insert `memref -> ptr` casts for memref arguments values
for argument in op.arguments:
if isinstance(argument.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[argument], [ptr.PtrType()]
),
insert_point,
)
new_arguments.append(cast_op.results[0])
else:
new_arguments.append(argument)

insert_point = InsertPoint.after(op)
new_results: list[SSAValue] = []

# insert `ptr -> memref` casts for return values
for result in op.results:
if isinstance(result.type, memref.MemRefType):
rewriter.insert_op(
cast_op := builtin.UnrealizedConversionCastOp.get(
[result],
# TODO: annoying pyright warnings - Sasha, pls help
[result.type], # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
),
insert_point,
)
new_results.append(cast_op.results[0])
else:
new_results.append(result)

new_return_types = [
ptr.PtrType() if isinstance(type, memref.MemRefType) else type
for type in op.result_types
]

rewriter.replace_matched_op(
func.CallOp(op.callee, new_arguments, new_return_types)
)


class ReconcileUnrealizedPtrCasts(RewritePattern):
"""
Eliminates two variants of unrealized ptr casts:
- `ptr_xdsl.ptr -> memref.MemRef -> ptr_xdsl.ptr`;
- `ptr_xdsl.ptr -> memref.memref` where all uses are `ToPtrOp` operations.
"""

@op_type_rewrite_pattern
def match_and_rewrite(
self, op: builtin.UnrealizedConversionCastOp, rewriter: PatternRewriter, /
):
# preconditions
if (
len(op.inputs) != 1
or len(op.outputs) != 1
or not isinstance(op.inputs[0].type, ptr.PtrType)
or not isinstance(op.outputs[0].type, memref.MemRefType)
):
return

# erase ptr -> memref -> ptr cast pairs
uses = tuple(use for use in op.outputs[0].uses)
for use in uses:
if (
isinstance(use.operation, builtin.UnrealizedConversionCastOp)
and isinstance(use.operation.inputs[0].type, memref.MemRefType)
and isinstance(use.operation.outputs[0].type, ptr.PtrType)
):
use.operation.outputs[0].replace_by(op.inputs[0])
rewriter.erase_op(use.operation)

# erase this cast entirely if all remaining uses are by ToPtr operations
cast_ops = [use.operation for use in op.outputs[0].uses]
if not all(isinstance(op, ptr.ToPtrOp) for op in cast_ops):
return

for cast_op in cast_ops:
cast_op.results[0].replace_by(op.inputs[0])
rewriter.erase_op(cast_op)

rewriter.erase_op(op)


@dataclass(frozen=True)
class ConvertMemrefToPtr(ModulePass):
name = "convert-memref-to-ptr"

lower_func: bool = False

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
the_one_pass = PatternRewriteWalker(
PatternRewriteWalker(
GreedyRewritePatternApplier([ConvertStoreOp(), ConvertLoadOp()])
)
the_one_pass.rewrite_module(op)
).rewrite_module(op)

if self.lower_func:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
LowerMemrefFuncOpPattern(),
LowerMemrefFuncCallPattern(),
LowerMemrefFuncReturnPattern(),
ReconcileUnrealizedPtrCasts(),
]
)
).rewrite_module(op)

0 comments on commit dce19da

Please sign in to comment.