Skip to content

Commit

Permalink
[tf-frontend] add new pattern to remove tf.ReshapeOp
Browse files Browse the repository at this point in the history
  • Loading branch information
heromapwrd committed Jul 18, 2024
1 parent cb89ed5 commit ed3b8fe
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
9 changes: 8 additions & 1 deletion frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,11 @@ func.func @replace_where_V2_2D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24xf1
// CHECK-NEXT: %15 = "tf.GatherV2"(%arg1, %12, %[[CST_8]]) <{batch_dims = 0 : i64}> : (tensor<256x24xf16>, tensor<?xi64>, tensor<i32>) -> tensor<?x24xf16>
// CHECK-NEXT: %16 = "tf.Mul"(%15, %14) : (tensor<?x24xf16>, tensor<?x24xf16>) -> tensor<?x24xf16>
// CHECK-NEXT: %17 = "tf.Sum"(%16, %[[CST]]) <{keep_dims = false}> : (tensor<?x24xf16>, tensor<1xi64>) -> tensor<?xf16>
// CHECK-NEXT: return %17 : tensor<?xf16>
// CHECK-NEXT: return %17 : tensor<?xf16>

// Verifies that an identity tf.Reshape — same rank, same dimension list, and
// a single dynamic dimension on both input and output — is removed entirely,
// leaving the function to return its argument directly.
func.func @replace_tf_reshape(%arg0: tensor<?x96xf16>, %arg1: tensor<2xi32>) -> tensor<?x96xf16> {
  %0 = "tf.Reshape"(%arg0, %arg1) {device = ""} : (tensor<?x96xf16>, tensor<2xi32>) -> tensor<?x96xf16>
  return %0 : tensor<?x96xf16>
}
// CHECK-LABEL: func.func @replace_tf_reshape(%arg0: tensor<?x96xf16>, %arg1: tensor<2xi32>) -> tensor<?x96xf16> {
// CHECK-NEXT: return %arg0 : tensor<?x96xf16>
39 changes: 39 additions & 0 deletions frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,44 @@ Value replaceWhereDynamic(PatternRewriter &rewriter, Location loc, Value input,
return replaceWhereStatic(rewriter, loc, inputGather, oneHotOutput);
}

// Removes a tf.Reshape that is provably an identity: input and output are
// ranked tensors of the same rank whose dimension lists are identical and
// contain at most one dynamic dimension.
//
// Why at most one dynamic dimension: tf.Reshape preserves the total element
// count, so when every static extent matches, a single dynamic extent is
// forced to match at runtime as well. With two or more dynamic dimensions
// (e.g. ?x? -> ?x?) equal types do NOT imply equal runtime shapes, so the
// pattern conservatively bails out.
struct RemoveTFReshapeOp : public OpRewritePattern<TF::ReshapeOp> {
  using OpRewritePattern<TF::ReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::ReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto input = reshapeOp.getTensor();
    auto inputType = input.getType().dyn_cast<RankedTensorType>();
    auto outputType =
        reshapeOp.getOutput().getType().dyn_cast<RankedTensorType>();
    // Both sides must be ranked before their shapes can be compared.
    if (!inputType || !outputType) {
      return failure();
    }
    // Fully static identity reshapes are left to existing canonicalization;
    // this pattern only targets the dynamic-shape case.
    if (inputType.hasStaticShape() || outputType.hasStaticShape()) {
      return failure();
    }
    if (inputType.getRank() != outputType.getRank()) {
      return failure();
    }
    // Every dimension — static extents and dynamic markers alike — must agree.
    // ArrayRef equality compares element-wise.
    if (inputType.getShape() != outputType.getShape()) {
      return failure();
    }
    // Bail out when more than one dimension is dynamic (see struct comment).
    int64_t dynamicDimNumber =
        llvm::count(inputType.getShape(), ShapedType::kDynamic);
    if (dynamicDimNumber > 1) {
      return failure();
    }
    // Replace the op itself (not merely its result's uses) so the dead
    // reshape is erased and the greedy driver cannot re-match it.
    rewriter.replaceOp(reshapeOp, input);
    return success();
  }
};

#include "tf_mlir_ext/transforms/fuse_tf_ops.inc"

struct FuseTFOpsPass : public FuseTFOpsBase<FuseTFOpsPass> {
Expand All @@ -244,6 +282,7 @@ struct FuseTFOpsPass : public FuseTFOpsBase<FuseTFOpsPass> {

patterns.add(std::make_unique<FuseDilatedConv3DPattern>(ctx));
patterns.add(std::make_unique<FuseSigmoid>(ctx));
patterns.add(std::make_unique<RemoveTFReshapeOp>(ctx));
if (replaceWhereToStatic) {
patterns.add(std::make_unique<ReplaceWhereStatic>(ctx));
patterns.add(std::make_unique<ReplaceWhereStaticV2>(ctx));
Expand Down

0 comments on commit ed3b8fe

Please sign in to comment.