diff --git a/compiler/include/byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h b/compiler/include/byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h index 852d41cdb..1988fa4f1 100644 --- a/compiler/include/byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h +++ b/compiler/include/byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h @@ -60,7 +60,6 @@ inline void registerAllMhloInferReturnTypeComponents() { //===----------------------------------------------------------------------===// void registerDynamicPartitionInferBoundedReturnTypeComponents(); void registerNonZeroInferBoundedReturnTypeComponents(); -void registerWhereInferBoundedReturnTypeComponents(); void registerScatterNdInferBoundedReturnTypeComponents(); void registerStridedSliceInferBoundedReturnTypeComponents(); void registerRepeatInferBoundedReturnTypeComponents(); @@ -68,7 +67,6 @@ void registerRepeatInferBoundedReturnTypeComponents(); inline void registerAllMhloInferBoundedReturnTypeComponents() { registerDynamicPartitionInferBoundedReturnTypeComponents(); registerNonZeroInferBoundedReturnTypeComponents(); - registerWhereInferBoundedReturnTypeComponents(); registerScatterNdInferBoundedReturnTypeComponents(); registerStridedSliceInferBoundedReturnTypeComponents(); registerRepeatInferBoundedReturnTypeComponents(); diff --git a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h index d1aea96f8..7ecaf32d0 100644 --- a/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h +++ b/compiler/include/byteir/Dialect/mhlo/Util/CustomCallUtil.h @@ -121,7 +121,6 @@ constexpr llvm::StringRef getDynamicMaskStitchName() { return TF_NAME_PREFIX "DynamicMaskStitch"; } -constexpr llvm::StringRef getWhereName() { return TF_NAME_PREFIX "Where"; } constexpr llvm::StringRef getScatterNdName() { return TF_NAME_PREFIX "ScatterNd"; } diff --git a/compiler/lib/Dialect/mhlo/CMakeLists.txt b/compiler/lib/Dialect/mhlo/CMakeLists.txt index 651937aee..5b7a3b0b3 100644 --- a/compiler/lib/Dialect/mhlo/CMakeLists.txt +++ b/compiler/lib/Dialect/mhlo/CMakeLists.txt @@ -35,7 +35,6 @@ add_mlir_dialect_library(ByteIRMhloDynamicShapeOpRegister DynamicShapeOpRegister/Softmax.cpp DynamicShapeOpRegister/AddN.cpp DynamicShapeOpRegister/TorchIndexSelect.cpp - DynamicShapeOpRegister/Where.cpp DynamicShapeOpRegister/ScatterNd.cpp DynamicShapeOpRegister/StridedSlice.cpp DynamicShapeOpRegister/BatchMatMul.cpp diff --git a/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/NonZero.cpp b/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/NonZero.cpp index 630893f4b..a0c6f2e84 100644 --- a/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/NonZero.cpp +++ b/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/NonZero.cpp @@ -37,9 +37,10 @@ void mlir::registerNonZeroInferBoundedReturnTypeComponents() { if (!inputShape || !inputShape.hasStaticShape()) return failure(); - Type type = RankedTensorType::get({inputShape.getNumElements()}, - IntegerType::get(context, 64)); + Type type = RankedTensorType::get( + {inputShape.getNumElements(), inputShape.getRank()}, + IntegerType::get(context, 64)); inferredReturnTypes.push_back(cast(type)); return success(); }); -} \ No newline at end of file +} diff --git a/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/Where.cpp b/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/Where.cpp deleted file mode 100644 index 1f6ca5cae..000000000 --- a/compiler/lib/Dialect/mhlo/DynamicShapeOpRegister/Where.cpp +++ /dev/null @@ -1,45 +0,0 @@ -//===- Where.cpp ----------------------------------------------*--- C++ -*-===// -// -// Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -//===----------------------------------------------------------------------===// - -#include "byteir/Dialect/mhlo/DynamicShapeOpRegister/Register.h" -#include "byteir/Dialect/mhlo/Util/CustomCallUtil.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" - -#define DEBUG_TYPE "dynamic-shape-op-register" - -using namespace mlir; - -/// See Where's signature on https://www.tensorflow.org/api_docs/python/tf/where -/// Bounded shape infer is the same as nonzero -void mlir::registerWhereInferBoundedReturnTypeComponents() { - static InferBoundedReturnTypeComponentsRegistration shapeRegister( - getWhereName(), - [](MLIRContext *context, std::optional, - ValueShapeRange operands, DictionaryAttr, RegionRange, - SmallVectorImpl &inferredReturnTypes) { - Value input = operands[0]; - ShapedType inputShape = dyn_cast(input.getType()); - if (!inputShape || !inputShape.hasStaticShape()) - return failure(); - Type type = RankedTensorType::get( - {inputShape.getNumElements(), inputShape.getRank()}, - IntegerType::get(context, 64)); - inferredReturnTypes.push_back(cast(type)); - return success(); - }); -} diff --git a/compiler/test/Transforms/boundedShapeInference.mlir b/compiler/test/Transforms/boundedShapeInference.mlir index 062221ab0..03cafd40f 100644 --- a/compiler/test/Transforms/boundedShapeInference.mlir +++ b/compiler/test/Transforms/boundedShapeInference.mlir @@ -30,21 +30,13 @@ func.func @several_ops(%arg0: tensor {byteir.bounded_shape = [8, 4]}, % //CHECK-NEXT: %3 = mhlo.add %0, %2 : tensor //CHECK-NEXT: return %3 : tensor -func.func @registered_shape_infer(%arg0 : tensor {byteir.bounded_shape = [8, 4]}) -> tensor { - %0 = "mhlo.custom_call"(%arg0) {call_target_name = "byteir.non_zero"} : (tensor) -> tensor - return %0 : tensor +func.func @registered_shape_infer(%arg0 : tensor {byteir.bounded_shape = [8, 4]}) -> tensor { + %0 = "mhlo.custom_call"(%arg0) {call_target_name = "byteir.non_zero"} : (tensor) -> tensor + return %0 : tensor } -//CHECK-LABEL: func.func @registered_shape_infer(%arg0: tensor {byteir.bounded_shape = [8, 4]}) -> tensor { -//CHECK-NEXT: %0 = mhlo.custom_call @byteir.non_zero(%arg0) : (tensor) -> tensor -//CHECK-NEXT: return %0 : tensor - -func.func @tf_where(%arg0 : tensor<1xi1>) -> tensor { - %0 = "mhlo.custom_call"(%arg0) { call_target_name = "tf.Where" } : (tensor<1xi1>) -> tensor - return %0 : tensor -} -//CHECK-LABEL: func.func @tf_where(%arg0: tensor<1xi1>) -> tensor { -//CHECK-NEXT: %0 = mhlo.custom_call @tf.Where(%arg0) : (tensor<1xi1>) -> tensor -//CHECK-NEXT: return %0 : tensor +//CHECK-LABEL: func.func @registered_shape_infer(%arg0: tensor {byteir.bounded_shape = [8, 4]}) -> tensor { +//CHECK-NEXT: %0 = mhlo.custom_call @byteir.non_zero(%arg0) : (tensor) -> tensor +//CHECK-NEXT: return %0 : tensor func.func @main_sub_0(%arg0: tensor {byteir.bounded_shape = [4, 4]}) -> tensor { %0 = mhlo.constant dense<-0.000000e+00> : tensor diff --git a/frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir b/frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir index 8daac43b0..9233184c4 100644 --- a/frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir +++ b/frontends/tf-frontend/tf_mlir_ext/tests/fuse_tf_ops.mlir @@ -147,3 +147,65 @@ func.func @replace_where_3D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24x8xf16 // CHECK-NEXT: %15 = "tf.Mul"(%13, %14) : (tensor, tensor) -> tensor // CHECK-NEXT: %16 = "tf.Sum"(%15, %cst_1) <{keep_dims = false}> : (tensor, tensor<1xi64>) -> tensor // CHECK-NEXT: return %16 : tensor + +func.func @replace_where_V2_2D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24xf16>) -> tensor { + %cst = "tf.Const"() <{value = dense<28800> : tensor}> : () -> tensor + %cst_1 = "tf.Const"() <{value = dense<86400> : tensor}> : () -> tensor + %cst_2 = "tf.Const"() <{value = dense<1.156330e-05> : tensor}> : () -> tensor + %cst_3 = "tf.Const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<2.400000e+01> : tensor}> : () -> tensor + %cst_5 = "tf.Const"() <{value = dense<24> : tensor}> : () -> tensor + %cst_6 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor + %cst_7 = "tf.Const"() <{value = dense<-1> : tensor<1xi32>}> : () -> tensor<1xi32> + %cst_8 = "tf.Const"() <{value = dense<6144> : tensor<1xi32>}> : () -> tensor<1xi32> + %cst_9 = "tf.Const"() <{value = dense<[6144, 8]> : tensor<2xi32>}> : () -> tensor<2xi32> + %cst_10 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst) {device = ""} : (tensor<256x1xi64>, tensor) -> tensor<256x1xi64> + %1 = "tf.FloorMod"(%0, %cst_1) {device = ""} : (tensor<256x1xi64>, tensor) -> tensor<256x1xi64> + %2 = "tf.Cast"(%1) <{Truncate = false}> {device = ""} : (tensor<256x1xi64>) -> tensor<256x1xf16> + %3 = "tf.Cast"(%2) <{Truncate = false}> {device = ""} : (tensor<256x1xf16>) -> tensor<256x1xf32> + %4 = "tf.Mul"(%3, %cst_2) {device = ""} : (tensor<256x1xf32>, tensor) -> tensor<256x1xf32> + %5 = "tf.Cast"(%4) <{Truncate = false}> {device = ""} : (tensor<256x1xf32>) -> tensor<256x1xf16> + %6 = "tf.FloorMod"(%5, %cst_3) {device = ""} : (tensor<256x1xf16>, tensor) -> tensor<256x1xf16> + %7 = "tf.Mul"(%6, %cst_4) {device = ""} : (tensor<256x1xf16>, tensor) -> tensor<256x1xf16> + %8 = "tf.Cast"(%7) <{Truncate = false}> {device = ""} : (tensor<256x1xf16>) -> tensor<256x1xi64> + %9 = "tf.Squeeze"(%8) <{squeeze_dims = [1]}> {device = ""} : (tensor<256x1xi64>) -> tensor<256xi64> + %10 = "tf.OneHot"(%9, %cst_5, %cst_3, %cst_6) <{axis = -1 : i64}> {device = ""} : (tensor<256xi64>, tensor, tensor, tensor) -> tensor<256x24xf16> + %11 = "tf.Reshape"(%10, %cst_7) {device = ""} : (tensor<256x24xf16>, tensor<1xi32>) -> tensor<6144xf16> + %12 = "tf.Cast"(%11) <{Truncate = false}> {device = ""} : (tensor<6144xf16>) -> tensor<6144xf32> + %13 = "tf.Where"(%12) {device = ""} : (tensor<6144xf32>) -> tensor + %14 = "tf.Squeeze"(%13) <{squeeze_dims = [1]}> {device = ""} : (tensor) -> tensor + %15 = "tf.Reshape"(%arg1, %cst_8) {device = ""} : (tensor<256x24xf16>, tensor<1xi32>) -> tensor<6144xf16> + %16 = "tf.GatherV2"(%15, %14, %cst_10) <{batch_dims = 0 : i64}> {device = ""} : (tensor<6144xf16>, tensor, tensor) -> tensor + return %16 : tensor +} +// CHECK-LABEL: func.func @replace_where_V2_2D(%arg0: tensor<256x1xi64>, %arg1: tensor<256x24xf16>) -> tensor { +// CHECK-NEXT: %cst = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK-NEXT: %cst_0 = "tf.Const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK-NEXT: %cst_1 = "tf.Const"() <{value = dense<28800> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_2 = "tf.Const"() <{value = dense<86400> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_3 = "tf.Const"() <{value = dense<1.156330e-05> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_4 = "tf.Const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_5 = "tf.Const"() <{value = dense<2.400000e+01> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_6 = "tf.Const"() <{value = dense<24> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_7 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK-NEXT: %cst_8 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK-NEXT: %0 = "tf.AddV2"(%arg0, %cst_1) {device = ""} : (tensor<256x1xi64>, tensor) -> tensor<256x1xi64> +// CHECK-NEXT: %1 = "tf.FloorMod"(%0, %cst_2) {device = ""} : (tensor<256x1xi64>, tensor) -> tensor<256x1xi64> +// CHECK-NEXT: %2 = "tf.Cast"(%1) <{Truncate = false}> {device = ""} : (tensor<256x1xi64>) -> tensor<256x1xf16> +// CHECK-NEXT: %3 = "tf.Cast"(%2) <{Truncate = false}> {device = ""} : (tensor<256x1xf16>) -> tensor<256x1xf32> +// CHECK-NEXT: %4 = "tf.Mul"(%3, %cst_3) {device = ""} : (tensor<256x1xf32>, tensor) -> tensor<256x1xf32> +// CHECK-NEXT: %5 = "tf.Cast"(%4) <{Truncate = false}> {device = ""} : (tensor<256x1xf32>) -> tensor<256x1xf16> +// CHECK-NEXT: %6 = "tf.FloorMod"(%5, %cst_4) {device = ""} : (tensor<256x1xf16>, tensor) -> tensor<256x1xf16> +// CHECK-NEXT: %7 = "tf.Mul"(%6, %cst_5) {device = ""} : (tensor<256x1xf16>, tensor) -> tensor<256x1xf16> +// CHECK-NEXT: %8 = "tf.Cast"(%7) <{Truncate = false}> {device = ""} : (tensor<256x1xf16>) -> tensor<256x1xi64> +// CHECK-NEXT: %9 = "tf.Squeeze"(%8) <{squeeze_dims = [1]}> {device = ""} : (tensor<256x1xi64>) -> tensor<256xi64> +// CHECK-NEXT: %10 = "tf.GreaterEqual"(%9, %cst) : (tensor<256xi64>, tensor<1xi64>) -> tensor<256xi1> +// CHECK-NEXT: %11 = "tf.Where"(%10) : (tensor<256xi1>) -> tensor +// CHECK-NEXT: %12 = "tf.Squeeze"(%11) <{squeeze_dims = [1]}> : (tensor) -> tensor +// CHECK-NEXT: %13 = "tf.GatherV2"(%9, %12, %cst_8) <{batch_dims = 0 : i64}> : (tensor<256xi64>, tensor, tensor) -> tensor +// CHECK-NEXT: %14 = "tf.OneHot"(%13, %cst_6, %cst_4, %cst_7) <{axis = -1 : i64}> : (tensor, tensor, tensor, tensor) -> tensor +// CHECK-NEXT: %15 = "tf.GatherV2"(%arg1, %12, %cst_8) <{batch_dims = 0 : i64}> : (tensor<256x24xf16>, tensor, tensor) -> tensor +// CHECK-NEXT: %16 = "tf.Mul"(%15, %14) : (tensor, tensor) -> tensor +// CHECK-NEXT: %17 = "tf.Sum"(%16, %cst_0) <{keep_dims = false}> : (tensor, tensor<1xi64>) -> tensor +// CHECK-NEXT: return %17 : tensor diff --git a/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir b/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir index 9db3b23ef..161ddaa3c 100644 --- a/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir +++ b/frontends/tf-frontend/tf_mlir_ext/tests/rewrite_to_custom_call.mlir @@ -183,6 +183,14 @@ func.func @addn_case0(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>, %arg2: // CHECK: mhlo.custom_call // CHECK-SAME: @byteir.addn +func.func @where_case0(%arg0: tensor<6144xf32>) -> tensor { + %0 = "tf.Where"(%arg0) : (tensor<6144xf32>) -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @where_case0 +// CHECK: mhlo.custom_call +// CHECK-SAME: @byteir.non_zero + func.func @layer_norm(%arg0: tensor<1x32x3xf32>) -> tensor<1x32x3xf32> { %cst = "tf.Const"() {value = dense<9.99999997E-7> : tensor} : () -> tensor %cst_0 = "tf.Const"() {value = dense<[0.0401659757, -0.11370486, 0.432680517]> : tensor<3xf32>} : () -> tensor<3xf32> diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc index 1600841df..095e6cfce 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.cc @@ -246,8 +246,10 @@ struct FuseTFOpsPass : public FuseTFOpsBase { patterns.add(std::make_unique(ctx)); if (replaceWhereToStatic) { patterns.add(std::make_unique(ctx)); + patterns.add(std::make_unique(ctx)); } else { patterns.add(std::make_unique(ctx)); + patterns.add(std::make_unique(ctx)); } if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.td b/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.td index df6d94ffe..cbc8c3ece 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.td +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/fuse_tf_ops.td @@ -188,4 +188,148 @@ def ReplaceWhereStatic : Pat< (TwoRank $after_where), (OneRank $after_squeeze1), (SameEleType $input1, $after_onehot)]>; +def ReplaceWhereDynamicV2 : Pat< + (TF_GatherV2Op + (TF_ReshapeOp:$after_reshape1 + $input1, + $input_shape + ), + (TF_SqueezeOp:$after_squeeze1 + (TF_WhereOp:$after_where + (TF_CastOp:$after_cast2 + (TF_ReshapeOp:$before_cast2 + (TF_OneHotOp:$after_onehot + (TF_SqueezeOp:$after_squeeze + (TF_CastOp:$after_cast1 + (TF_MulOp:$before_cast1 + (TF_FloorModOp + (TF_CastOp + (TF_MulOp + (TF_CastOp + (TF_CastOp:$after_cast + (TF_FloorModOp:$before_cast + (TF_AddV2Op + $input, + (TF_ConstOp:$addconst $addconst_attr) + ), + (TF_ConstOp:$floorconst $floorconst_attr) + ), + $truncate + ), + $truncate1 + ), + (TF_ConstOp:$mulconst $mulconst_attr) + ), + $truncate2 + ), + (TF_ConstOp:$floorconst1 $floorconst1_attr) + ), + (TF_ConstOp:$mulconst1 $mulconst1_attr) + ), + $truncate3 + ), + $squeeze_dims + ), + (TF_ConstOp:$depth $depth_attr), + (TF_ConstOp:$onvalue $onvalue_attr), + (TF_ConstOp:$offvalue $offvalue_attr), + $onehot_axis + ), + (TF_ConstOp:$shape $shape_attr) + ), + $truncate4 + ) + ), + $squeeze_dims1 + ), + (TF_ConstOp:$gatheraxis $gatheraxis_attr), + $gather_batch_dims + ), + (NativeCodeCall<"replaceWhereDynamic($_builder, $_loc, $0, $1, $2, $3, $4, $5, $6)"> $input1, $after_squeeze, $depth, $onvalue, $offvalue, $gatheraxis, $onehot_axis), + [(WhereValue2 $addconst_attr), (WhereValue3 $floorconst_attr), + (WhereValue0 $mulconst_attr), (FpSplatValueOne $floorconst1_attr), + (WhereValue4 $mulconst1_attr), (WhereValue1 $depth_attr), + (FpSplatValueOne $onvalue_attr), (FpSplatValueZero $offvalue_attr), + (IntSplatValueNegOne $shape_attr), (IntSplatValueZero $gatheraxis_attr), + (AxisAttrNegOne $onehot_axis), (AxisAttrZero $gather_batch_dims), + (IntegerEleType $before_cast), (FloatEleType $after_cast), + (FloatEleType $before_cast1), (IntegerEleType $after_cast1), + (FloatEleType $before_cast2), (FloatEleType $after_cast2), + (TwoRank $after_cast1), (OneRank $after_squeeze), + (TwoRank $after_onehot), (OneRank $before_cast2), + (TwoRank $after_where), (OneRank $after_squeeze1), + (SameEleType $input1, $after_onehot)]>; + +def ReplaceWhereStaticV2 : Pat< + (TF_GatherV2Op + (TF_ReshapeOp:$after_reshape1 + $input1, + $input_shape + ), + (TF_SqueezeOp:$after_squeeze1 + (TF_WhereOp:$after_where + (TF_CastOp:$after_cast2 + (TF_ReshapeOp:$before_cast2 + (TF_OneHotOp:$after_onehot + (TF_SqueezeOp:$after_squeeze + (TF_CastOp:$after_cast1 + (TF_MulOp:$before_cast1 + (TF_FloorModOp + (TF_CastOp + (TF_MulOp + (TF_CastOp + (TF_CastOp:$after_cast + (TF_FloorModOp:$before_cast + (TF_AddV2Op + $input, + (TF_ConstOp:$addconst $addconst_attr) + ), + (TF_ConstOp:$floorconst $floorconst_attr) + ), + $truncate + ), + $truncate1 + ), + (TF_ConstOp:$mulconst $mulconst_attr) + ), + $truncate2 + ), + (TF_ConstOp:$floorconst1 $floorconst1_attr) + ), + (TF_ConstOp:$mulconst1 $mulconst1_attr) + ), + $truncate3 + ), + $squeeze_dims + ), + (TF_ConstOp:$depth $depth_attr), + (TF_ConstOp:$onvalue $onvalue_attr), + (TF_ConstOp:$offvalue $offvalue_attr), + $onehot_axis + ), + (TF_ConstOp:$shape $shape_attr) + ), + $truncate4 + ) + ), + $squeeze_dims1 + ), + (TF_ConstOp:$gatheraxis $gatheraxis_attr), + $gather_batch_dims + ), + (NativeCodeCall<"replaceWhereStatic($_builder, $_loc, $0, $1)"> $input1, $after_onehot), + [(WhereValue2 $addconst_attr), (WhereValue3 $floorconst_attr), + (WhereValue0 $mulconst_attr), (FpSplatValueOne $floorconst1_attr), + (WhereValue4 $mulconst1_attr), (WhereValue1 $depth_attr), + (FpSplatValueOne $onvalue_attr), (FpSplatValueZero $offvalue_attr), + (IntSplatValueNegOne $shape_attr), (IntSplatValueZero $gatheraxis_attr), + (AxisAttrNegOne $onehot_axis), (AxisAttrZero $gather_batch_dims), + (IntegerEleType $before_cast), (FloatEleType $after_cast), + (FloatEleType $before_cast1), (IntegerEleType $after_cast1), + (FloatEleType $before_cast2), (FloatEleType $after_cast2), + (TwoRank $after_cast1), (OneRank $after_squeeze), + (TwoRank $after_onehot), (OneRank $before_cast2), + (TwoRank $after_where), (OneRank $after_squeeze1), + (SameEleType $input1, $after_onehot)]>; + #endif // FUSE_TF_OPS_PATTERN diff --git a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc index 16528895f..e033a50b4 100644 --- a/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc +++ b/frontends/tf-frontend/tf_mlir_ext/transforms/rewrite_to_custom_call.cc @@ -62,6 +62,7 @@ namespace { cb(addn, AddN, CALL_TARGET_NAME_PREFIX) \ cb(one_hot, OneHot, CALL_TARGET_NAME_PREFIX) \ cb(repeat, Repeat, CALL_TARGET_NAME_PREFIX) \ + cb(non_zero, Where, CALL_TARGET_NAME_PREFIX) \ cb(DynamicMaskStitch, DynamicMaskStitch, CALL_TF_TARGET_NAME_PREFIX) \ cb(DynamicPartition, DynamicPartition, CALL_TF_TARGET_NAME_PREFIX) \ cb(DynamicStitch, DynamicStitch, CALL_TF_TARGET_NAME_PREFIX) @@ -621,6 +622,35 @@ struct RewriteRepeat : public RewritePattern { } }; +//===----------------------------------------------------------------------===// +// Where Pattern +//===----------------------------------------------------------------------===// +struct RewriteWhere : public RewritePattern { + RewriteWhere(MLIRContext *context, PatternBenefit benefits = 1) + : RewritePattern("tf.Where", benefits, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // llvm::outs() << op->getName().getStringRef(); + assert(op->getName().getStringRef() == "tf.Where"); + RankedTensorType outType = + dyn_cast(op->getResult(0).getType()); + if (!outType) + return failure(); + llvm::SmallVector outTypes{outType}; + mhlo::CustomCallOp customCallOp = rewriter.create( + op->getLoc(), outTypes, op->getOperands(), getWhereNameWithPrefix(), + false, rewriter.getStringAttr(""), + mhlo::CustomCallApiVersion{ + mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL}, + rewriter.getArrayAttr(ArrayRef{}), + mhlo::CustomCallSchedule{mhlo::CustomCallSchedule::NONE}, nullptr, + nullptr, rewriter.getArrayAttr(ArrayRef{})); + customCallOp->setAttr(getByteIRAttrs(), getCleanAttr(op)); + rewriter.replaceOp(op, customCallOp->getResults()); + return success(); + } +}; + //===----------------------------------------------------------------------===// // SimpleReplace Pattern //===----------------------------------------------------------------------===// @@ -877,6 +907,8 @@ struct RewriteToCustomCallOpsPass std::make_unique(context, 1)); validCustomCallOpSet[getRepeatName()].emplace_back( std::make_unique(context, 1)); + validCustomCallOpSet[getWhereName()].emplace_back( + std::make_unique(context, 1)); RewritePatternSet patterns(context); for (auto op : opsSet) { diff --git a/runtime/lib/backends/cpu/providers/default/cpu_provider.cc b/runtime/lib/backends/cpu/providers/default/cpu_provider.cc index d5d9f49cd..dd7e17e05 100644 --- a/runtime/lib/backends/cpu/providers/default/cpu_provider.cc +++ b/runtime/lib/backends/cpu/providers/default/cpu_provider.cc @@ -17,11 +17,11 @@ #include "brt/backends/cpu/providers/default/cpu_provider.h" +#include "./custom_call/non_zero.h" #include "./custom_call/repeat.h" #include "./custom_call/tf_equal.h" #include "./custom_call/tf_select.h" #include "./custom_call/tf_string_to_number.h" -#include "./custom_call/tf_where.h" #include "./custom_call/topk.h" #include "./llvm/jit.h" #include "./math/elementwise_ops.h" @@ -77,11 +77,6 @@ BRT_STATIC_KERNEL_REGISTRATION( new cpu::Typecvt(info)); return kernel; }); - registry->Register( - "tf.Where", - [](const brt::OpKernelInfo &info) -> std::shared_ptr { - return std::make_shared(info); - }); }); // statcially register all CPU OpKernels @@ -128,6 +123,11 @@ BRT_STATIC_KERNEL_REGISTRATION( [](const brt::OpKernelInfo &info) -> std::shared_ptr { return std::make_shared(info); }); + registry->Register( + "byteir.non_zero", + [](const brt::OpKernelInfo &info) -> std::shared_ptr { + return std::make_shared(info); + }); RegisterCommonBuiltinOps(registry); }); diff --git a/runtime/lib/backends/cpu/providers/default/custom_call/tf_where.cc b/runtime/lib/backends/cpu/providers/default/custom_call/non_zero.cc similarity index 86% rename from runtime/lib/backends/cpu/providers/default/custom_call/tf_where.cc rename to runtime/lib/backends/cpu/providers/default/custom_call/non_zero.cc index c6bc2204f..ad718b096 100644 --- a/runtime/lib/backends/cpu/providers/default/custom_call/tf_where.cc +++ b/runtime/lib/backends/cpu/providers/default/custom_call/non_zero.cc @@ -1,4 +1,4 @@ -//===- tf_where.cc ---------------------------------------*--- C++ -*-===// +//===- non_zero.cc --------------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,7 +15,7 @@ // //===----------------------------------------------------------------------===// -#include "./tf_where.h" +#include "./non_zero.h" #include "brt/backends/cpu/device/llvm/jit.h" #include "brt/core/framework/op_accessor.h" #include "brt/core/ir/engine_util.h" @@ -27,7 +27,7 @@ namespace brt { namespace cpu { template -void TFWhereImpl(const OpAccessor &accessor, WorkQueue *work_queue, int op_id, +void NonZeroImpl(const OpAccessor &accessor, WorkQueue *work_queue, int op_id, const std::vector &dependency) { const auto &shape = accessor.GetArgShape(0); const int64_t num_elements = accessor.GetNumElementsOfShape(shape); @@ -48,7 +48,7 @@ void TFWhereImpl(const OpAccessor &accessor, WorkQueue *work_queue, int op_id, }); } -common::Status TFWhere::RunImpl(const ExecutionContext &ctx) { +common::Status NonZero::RunImpl(const ExecutionContext &ctx) { OpAccessor accessor(info_, ctx.exec_frame); // output dtype is constraint to int64 in tf_generated_ops.td by // let results = (outs @@ -56,12 +56,12 @@ common::Status TFWhere::RunImpl(const ExecutionContext &ctx) { // ); if (accessor.GetArgDTypeEnum(1) != DTypeEnum::Int64) return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL, - "tf.Where output tensor not int64 dtype"); + "byteir.non_zero output tensor not int64 dtype"); auto data_dtype = accessor.GetArgDTypeEnum(0); #define HANDLE_DTYPE(DType) \ if (data_dtype == DType) { \ - TFWhereImpl::type_t>( \ + NonZeroImpl::type_t>( \ accessor, ctx.work_queue, info_.GetOpId(), info_.GetDependency()); \ return common::Status::OK(); \ } @@ -77,7 +77,7 @@ common::Status TFWhere::RunImpl(const ExecutionContext &ctx) { #undef HANDLE_DTYPE return common::Status(common::StatusCategory::BRT, common::StatusCode::FAIL, - "tf.Where unsupported data type"); + "byteir.non_zero unsupported data type"); } // instantiate diff --git a/runtime/lib/backends/cpu/providers/default/custom_call/tf_where.h b/runtime/lib/backends/cpu/providers/default/custom_call/non_zero.h similarity index 85% rename from runtime/lib/backends/cpu/providers/default/custom_call/tf_where.h rename to runtime/lib/backends/cpu/providers/default/custom_call/non_zero.h index 19d963d9a..c7ccd8f25 100644 --- a/runtime/lib/backends/cpu/providers/default/custom_call/tf_where.h +++ b/runtime/lib/backends/cpu/providers/default/custom_call/non_zero.h @@ -1,4 +1,4 @@ -//===- tf_where.h ---------------------------------------------*--- C++ -*-===// +//===- non_zero.h ---------------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,9 +22,9 @@ namespace brt { namespace cpu { -class TFWhere final : public OpKernel { +class NonZero final : public OpKernel { public: - explicit TFWhere(const OpKernelInfo &info) : OpKernel(info) {} + explicit NonZero(const OpKernelInfo &info) : OpKernel(info) {} common::Status RunImpl(const ExecutionContext &ctx) override; }; diff --git a/runtime/test/backends/cpu/providers/default/kernel/tf_where_test.cc b/runtime/test/backends/cpu/providers/default/kernel/non_zero_test.cc similarity index 87% rename from runtime/test/backends/cpu/providers/default/kernel/tf_where_test.cc rename to runtime/test/backends/cpu/providers/default/kernel/non_zero_test.cc index dac83155b..09fc34327 100644 --- a/runtime/test/backends/cpu/providers/default/kernel/tf_where_test.cc +++ b/runtime/test/backends/cpu/providers/default/kernel/non_zero_test.cc @@ -1,4 +1,4 @@ -//===- tf_where_test.cc ---------------------------------------*--- C++ -*-===// +//===- non_zero_test.cc ---------------------------------------*--- C++ -*-===// // // Copyright 2022 ByteDance Ltd. and/or its affiliates. All rights reserved. // Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,7 +40,7 @@ using namespace std; namespace { template > -void CheckTFWhereSingle(const std::vector &shape, +void CheckNonZeroSingle(const std::vector &shape, const ContainerT &data, const std::vector &expect_result) { ByREBuilder byre_builder; @@ -51,7 +51,7 @@ void CheckTFWhereSingle(const std::vector &shape, BRT_TEST_CHECK_STATUS(status_cpu); auto status_load = session.LoadFromMemory( - CreateTFWhereOp(byre_builder, dtype_enum_v, shape), "byre"); + CreateNonZeroOp(byre_builder, dtype_enum_v, shape), "byre"); BRT_TEST_CHECK_STATUS(status_load); std::unique_ptr request; @@ -90,14 +90,14 @@ void CheckTFWhereSingle(const std::vector &shape, } } // namespace -TEST(CPUOpKerenlTest, TFWhereBasic) { +TEST(CPUOpKerenlTest, NonZeroBasic) { using half_float::half; - CheckTFWhereSingle({3}, {1.1f, 0.0f, 0.1f}, {0, 2}); - CheckTFWhereSingle({2, 2}, + CheckNonZeroSingle({3}, {1.1f, 0.0f, 0.1f}, {0, 2}); + CheckNonZeroSingle({2, 2}, {half(5.5f), half(0.1f), half(0.0f), half(2.4f)}, {0, 0, 0, 1, 1, 1}); - CheckTFWhereSingle({2, 1, 3}, {5ll, 0ll, 9ll, 0ll, 0ll, 2ll}, + CheckNonZeroSingle({2, 1, 3}, {5ll, 0ll, 9ll, 0ll, 0ll, 2ll}, {0, 0, 0, 0, 0, 2, 1, 0, 2}); - CheckTFWhereSingle>({4}, {true, false, false, true}, + CheckNonZeroSingle>({4}, {true, false, false, true}, {0, 3}); -} \ No newline at end of file +} diff --git a/runtime/test/common/models.cc b/runtime/test/common/models.cc index 4c4302296..eaeccc20a 100644 --- a/runtime/test/common/models.cc +++ b/runtime/test/common/models.cc @@ -1185,7 +1185,7 @@ const void *CreatePTXAddOp(brt::ir::ByREBuilder &byre_builder) { return m.getAsOpaquePointer(); } -const void *CreateTFWhereOp(brt::ir::ByREBuilder &byre_builder, +const void *CreateNonZeroOp(brt::ir::ByREBuilder &byre_builder, DTypeEnum input_dtype, const std::vector &shape) { @@ -1208,7 +1208,7 @@ const void *CreateTFWhereOp(brt::ir::ByREBuilder &byre_builder, // add entry function body mlir::Block *entry_block = func_op.addEntryBlock(); op_builder.setInsertionPointToStart(entry_block); - op_builder.create(UnknownLoc::get(ctx), "tf.Where", + op_builder.create(UnknownLoc::get(ctx), "byteir.non_zero", ValueRange{entry_block->getArgument(0)}, ValueRange{entry_block->getArgument(1)}); op_builder.create(UnknownLoc::get(ctx)); diff --git a/runtime/test/include/brt/test/common/models.h b/runtime/test/include/brt/test/common/models.h index c4e80cb99..e006ac5e8 100644 --- a/runtime/test/include/brt/test/common/models.h +++ b/runtime/test/include/brt/test/common/models.h @@ -136,7 +136,7 @@ const void *CreateRepeat(brt::ir::ByREBuilder &byre_builder, DTypeEnum dataType, // always cuda const void *CreatePTXAddOp(brt::ir::ByREBuilder &byre_builder); -const void *CreateTFWhereOp(brt::ir::ByREBuilder &byre_builder, +const void *CreateNonZeroOp(brt::ir::ByREBuilder &byre_builder, DTypeEnum input_dtype, const std::vector &shape);