From 83b4184cc05269320a8284bf5326b30079f348d0 Mon Sep 17 00:00:00 2001
From: fangfangssj <1135470306@qq.com>
Date: Sat, 28 Dec 2024 15:18:25 +0000
Subject: [PATCH 1/5] add converter

---
 .../transforms/tensorrt/trt_op_marker_pass.cc |  47 +++
 python/paddle/tensorrt/impls/ops.py           |  44 ++-
 test/tensorrt/test_converter_ops.py           | 286 ++++++++++++++++++
 3 files changed, 376 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 061d2488d7bdb8..4778b9d3f61647 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -92,6 +92,20 @@ DEFINE_GENERAL_PATTERN(Flip, paddle::dialect::FlipOp)
 DEFINE_GENERAL_PATTERN(Mish, paddle::dialect::MishOp)
 DEFINE_GENERAL_PATTERN(AssignValue, paddle::dialect::AssignValueOp)
 DEFINE_GENERAL_PATTERN(AssignValue_, paddle::dialect::AssignValue_Op)
+DEFINE_GENERAL_PATTERN(Exp, paddle::dialect::ExpOp)
+DEFINE_GENERAL_PATTERN(Abs, paddle::dialect::AbsOp)
+DEFINE_GENERAL_PATTERN(Abs_, paddle::dialect::Abs_Op)
+DEFINE_GENERAL_PATTERN(Sin, paddle::dialect::SinOp)
+DEFINE_GENERAL_PATTERN(Cos, paddle::dialect::CosOp)
+DEFINE_GENERAL_PATTERN(Sinh, paddle::dialect::SinhOp)
+DEFINE_GENERAL_PATTERN(Cosh, paddle::dialect::CoshOp)
+DEFINE_GENERAL_PATTERN(Asinh, paddle::dialect::AsinhOp)
+DEFINE_GENERAL_PATTERN(Acosh, paddle::dialect::AcoshOp)
+DEFINE_GENERAL_PATTERN(Atanh, paddle::dialect::AtanhOp)
+DEFINE_GENERAL_PATTERN(Ceil, paddle::dialect::CeilOp)
+DEFINE_GENERAL_PATTERN(Rsqrt, paddle::dialect::RsqrtOp)
+DEFINE_GENERAL_PATTERN(Reciprocal, paddle::dialect::ReciprocalOp)
+DEFINE_GENERAL_PATTERN(Erf, paddle::dialect::ErfOp)
 #undef DEFINE_GENERAL_PATTERN

 // Add ReduceCommonOpPattern base class to simplify code
@@ -552,6 +566,24 @@ class SignOpPattern : public pir::OpRewritePattern<paddle::dialect::SignOp> {
   }
 };

+class RoundOpPattern : public pir::OpRewritePattern<paddle::dialect::RoundOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::RoundOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::RoundOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+#if IS_TRT_VERSION_LT(8200)
+    VLOG(3) << "round op is only supported by tensorrt8.2 above ";
+    return false;
+#endif
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+
 class GroupNormOpPattern
     : public pir::OpRewritePattern<paddle::dialect::GroupNormOp> {
  public:
@@ -2157,6 +2189,20 @@
     ADD_PATTERN(Mish)
     ADD_PATTERN(AssignValue)
     ADD_PATTERN(AssignValue_)
+    ADD_PATTERN(Exp)
+    ADD_PATTERN(Abs)
+    ADD_PATTERN(Abs_)
+    ADD_PATTERN(Sin)
+    ADD_PATTERN(Cos)
+    ADD_PATTERN(Sinh)
+    ADD_PATTERN(Cosh)
+    ADD_PATTERN(Asinh)
+    ADD_PATTERN(Acosh)
+    ADD_PATTERN(Atanh)
+    ADD_PATTERN(Ceil)
+    ADD_PATTERN(Rsqrt)
+    ADD_PATTERN(Reciprocal)
+    ADD_PATTERN(Erf)
 #if IS_TRT_VERSION_GE(8600)
     ADD_PATTERN(Layer_norm)
 #endif
@@ -2167,6 +2213,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique<SignOpPattern>(context));
+    ps.Add(std::make_unique<RoundOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
diff --git a/python/paddle/tensorrt/impls/ops.py b/python/paddle/tensorrt/impls/ops.py
index 6416cb96e6af38..f8ad23189126fe 100644
--- a/python/paddle/tensorrt/impls/ops.py
+++ 
b/python/paddle/tensorrt/impls/ops.py @@ -19,13 +19,55 @@ "pd_op.sqrt": trt.UnaryOperation.SQRT, "pd_op.sqrt_": trt.UnaryOperation.SQRT, "pd_op.floor": trt.UnaryOperation.FLOOR, + "pd_op.exp": trt.UnaryOperation.EXP, + "pd_op.abs": trt.UnaryOperation.ABS, + "pd_op.abs_": trt.UnaryOperation.ABS, + "pd_op.sin": trt.UnaryOperation.SIN, + "pd_op.cos": trt.UnaryOperation.COS, + "pd_op.sinh": trt.UnaryOperation.SINH, + "pd_op.cosh": trt.UnaryOperation.COSH, + "pd_op.asinh": trt.UnaryOperation.ASINH, + "pd_op.acosh": trt.UnaryOperation.ACOSH, + "pd_op.atanh": trt.UnaryOperation.ATANH, + "pd_op.ceil": trt.UnaryOperation.CEIL, + "pd_op.reciprocal": trt.UnaryOperation.RECIP, + "pd_op.erf": trt.UnaryOperation.ERF, + "pd_op.sign": trt.UnaryOperation.SIGN, + "pd_op.round": trt.UnaryOperation.ROUND, } @converter_registry.register("pd_op.sqrt", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.sqrt_", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.floor", trt_version="8.x") -def sqrt_converter(network, paddle_op, inputs): +@converter_registry.register("pd_op.exp", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.abs", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.abs_", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.sin", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.cos", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.sinh", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.cosh", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.asinh", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.acosh", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.atanh", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.ceil", trt_version="trt_version_ge=8.0") +@converter_registry.register( + "pd_op.reciprocal", trt_version="trt_version_ge=8.0" +) +@converter_registry.register("pd_op.erf", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.sign", trt_version="trt_version_ge=8.2") +@converter_registry.register("pd_op.round", trt_version="trt_version_ge=8.2") +def Unary_Op_converter(network, paddle_op, inputs): input_tensor = inputs[0] layer = network.add_unary(input_tensor, ops_type_map[paddle_op.name()]) return layer.get_output(0) + + +@converter_registry.register("pd_op.rsqrt", trt_version="trt_version_ge=8.0") +def Rsqrt_Op_converter(network, paddle_op, inputs): + input_tensor = inputs[0] + sqrt_layer = network.add_unary(input_tensor, trt.UnaryOperation.SQRT) + rsqrt_layer = network.add_unary( + sqrt_layer.get_output(0), trt.UnaryOperation.RECIP + ) + return rsqrt_layer.get_output(0) diff --git a/test/tensorrt/test_converter_ops.py b/test/tensorrt/test_converter_ops.py index 8bc188e3e5514b..c7668a747b1dee 100644 --- a/test/tensorrt/test_converter_ops.py +++ b/test/tensorrt/test_converter_ops.py @@ -33,6 +33,9 @@ def setUp(self): def test_trt_result(self): self.check_trt_result() + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + class TestFloorFloatTRTPattern(TensorRTBaseTest): def setUp(self): @@ -47,6 +50,289 @@ def setUp(self): def test_trt_result(self): self.check_trt_result() + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestExpFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.exp + self.api_args = { + "x": np.random.randn(7, 
3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestExpIntTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.exp + self.api_args = { + "x": np.random.randn(7, 3).astype("int64"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestAbsFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.abs + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestAbsIntTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.abs + self.api_args = { + "x": np.random.randn(7, 3).astype("int64"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestSinFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.sin + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestCosFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.cos + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestSinhFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.sinh + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestCoshFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.cosh + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class TestAsinhFloatTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.asinh + self.api_args = { + "x": np.random.randn(7, 3).astype("float32"), + } + self.program_config = {"feed_list": ["x"]} + self.min_shape = {"x": [3, 3]} + self.max_shape = {"x": [10, 3]} + + def test_trt_result(self): + self.check_trt_result() + + def test_trt_result_fp16(self): + self.check_trt_result(precision_mode="fp16") + + +class 
TestAcoshFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.acosh
+        self.api_args = {
+            "x": np.random.uniform(1.0, 4.0, [7, 3]).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestCeilFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.ceil
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestRsqrtFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.rsqrt
+        self.api_args = {
+            "x": np.random.uniform(0.1, 4.0, [7, 3]).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestReciprocalFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.reciprocal
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestErfFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.erf
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestSignFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.sign
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+
+class TestSignIntTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.sign
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+
+
+class TestRoundFloatTRTPattern(TensorRTBaseTest):
+    def setUp(self):
+        self.python_api = paddle.round
+        self.api_args = {
+            "x": np.random.randn(7, 3).astype("float32"),
+        }
+        self.program_config = {"feed_list": ["x"]}
+        self.min_shape = {"x": [3, 3]}
+        self.max_shape = {"x": [10, 3]}
+
+    def test_trt_result(self):
+        self.check_trt_result()
+
+    def test_trt_result_fp16(self):
+        self.check_trt_result(precision_mode="fp16")
+

 if __name__ == '__main__':
     unittest.main()

From aaa308c8265aeee656f23afa1000c03a14989a01 Mon Sep 17 00:00:00 2001
From: fangfangssj <1135470306@qq.com>
Date: Mon, 30 Dec 2024 04:33:40 +0000
Subject: 
[PATCH 2/5] fix --- python/paddle/tensorrt/converter_utils.py | 30 ++++++++++++++-- python/paddle/tensorrt/impls/ops.py | 43 ++++------------------- test/tensorrt/test_converter_ops.py | 14 -------- 3 files changed, 33 insertions(+), 54 deletions(-) diff --git a/python/paddle/tensorrt/converter_utils.py b/python/paddle/tensorrt/converter_utils.py index 19ac9eb8698733..14f5af400e1aad 100644 --- a/python/paddle/tensorrt/converter_utils.py +++ b/python/paddle/tensorrt/converter_utils.py @@ -686,6 +686,29 @@ def squeeze_trt(network, input_tensor, axes): def unary_op_converter(network, paddle_op, inputs): from paddle.tensorrt import PrecisionMode + ops_type_map = { + "pd_op.sqrt": [trt.UnaryOperation.SQRT], + "pd_op.sqrt_": [trt.UnaryOperation.SQRT], + "pd_op.floor": [trt.UnaryOperation.FLOOR], + "pd_op.exp": [trt.UnaryOperation.EXP], + "pd_op.abs": [trt.UnaryOperation.ABS], + "pd_op.abs_": [trt.UnaryOperation.ABS], + "pd_op.sin": [trt.UnaryOperation.SIN], + "pd_op.cos": [trt.UnaryOperation.COS], + "pd_op.sinh": [trt.UnaryOperation.SINH], + "pd_op.cosh": [trt.UnaryOperation.COSH], + "pd_op.asinh": [trt.UnaryOperation.ASINH], + "pd_op.acosh": [trt.UnaryOperation.ACOSH], + "pd_op.atanh": [trt.UnaryOperation.ATANH], + "pd_op.ceil": [trt.UnaryOperation.CEIL], + "pd_op.reciprocal": [trt.UnaryOperation.RECIP], + "pd_op.erf": [trt.UnaryOperation.ERF], + "pd_op.sign": [trt.UnaryOperation.SIGN], + "pd_op.round": [trt.UnaryOperation.ROUND], + "pd_op.logical_not": [trt.UnaryOperation.NOT], + "pd_op.rsqrt": [trt.UnaryOperation.SQRT, trt.UnaryOperation.RECIP], + } + input_tensor = inputs[0] layer = None org_type = input_tensor.dtype @@ -707,9 +730,10 @@ def unary_op_converter(network, paddle_op, inputs): identity_layer.set_output_type(0, trt.float16) input_tensor = identity_layer.get_output(0) - if paddle_op.name() in ["pd_op.logical_not", "pd_op.logical_not_"]: - layer = network.add_unary(input_tensor, trt.UnaryOperation.NOT) - input_tensor = layer.get_output(0) + if paddle_op.name() in ops_type_map: + for trt_op in ops_type_map[paddle_op.name()]: + layer = network.add_unary(input_tensor, trt_op) + input_tensor = layer.get_output(0) else: raise NotImplementedError( f"Unsupported unary operation: {paddle_op.name()}" diff --git a/python/paddle/tensorrt/impls/ops.py b/python/paddle/tensorrt/impls/ops.py index f8ad23189126fe..7370f10edc1eeb 100644 --- a/python/paddle/tensorrt/impls/ops.py +++ b/python/paddle/tensorrt/impls/ops.py @@ -11,35 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorrt as trt +from paddle.tensorrt.converter_utils import unary_op_converter from paddle.tensorrt.register import converter_registry -ops_type_map = { - "pd_op.sqrt": trt.UnaryOperation.SQRT, - "pd_op.sqrt_": trt.UnaryOperation.SQRT, - "pd_op.floor": trt.UnaryOperation.FLOOR, - "pd_op.exp": trt.UnaryOperation.EXP, - "pd_op.abs": trt.UnaryOperation.ABS, - "pd_op.abs_": trt.UnaryOperation.ABS, - "pd_op.sin": trt.UnaryOperation.SIN, - "pd_op.cos": trt.UnaryOperation.COS, - "pd_op.sinh": trt.UnaryOperation.SINH, - "pd_op.cosh": trt.UnaryOperation.COSH, - "pd_op.asinh": trt.UnaryOperation.ASINH, - "pd_op.acosh": trt.UnaryOperation.ACOSH, - "pd_op.atanh": trt.UnaryOperation.ATANH, - "pd_op.ceil": trt.UnaryOperation.CEIL, - "pd_op.reciprocal": trt.UnaryOperation.RECIP, - "pd_op.erf": trt.UnaryOperation.ERF, - "pd_op.sign": trt.UnaryOperation.SIGN, - "pd_op.round": trt.UnaryOperation.ROUND, -} - @converter_registry.register("pd_op.sqrt", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.sqrt_", trt_version="trt_version_ge=8.0") -@converter_registry.register("pd_op.floor", trt_version="8.x") +@converter_registry.register("pd_op.floor", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.exp", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.abs", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.abs_", trt_version="trt_version_ge=8.0") @@ -55,19 +34,9 @@ "pd_op.reciprocal", trt_version="trt_version_ge=8.0" ) @converter_registry.register("pd_op.erf", trt_version="trt_version_ge=8.0") +@converter_registry.register("pd_op.rsqrt", trt_version="trt_version_ge=8.0") @converter_registry.register("pd_op.sign", trt_version="trt_version_ge=8.2") @converter_registry.register("pd_op.round", trt_version="trt_version_ge=8.2") -def Unary_Op_converter(network, paddle_op, inputs): - input_tensor = inputs[0] - layer = network.add_unary(input_tensor, ops_type_map[paddle_op.name()]) - return layer.get_output(0) - - -@converter_registry.register("pd_op.rsqrt", trt_version="trt_version_ge=8.0") -def Rsqrt_Op_converter(network, paddle_op, inputs): - input_tensor = inputs[0] - sqrt_layer = network.add_unary(input_tensor, trt.UnaryOperation.SQRT) - rsqrt_layer = network.add_unary( - sqrt_layer.get_output(0), trt.UnaryOperation.RECIP - ) - return rsqrt_layer.get_output(0) +def UnaryOpConverter(network, paddle_op, inputs): + layer_output = unary_op_converter(network, paddle_op, inputs) + return layer_output diff --git a/test/tensorrt/test_converter_ops.py b/test/tensorrt/test_converter_ops.py index c7668a747b1dee..d6f194c34a9dbc 100644 --- a/test/tensorrt/test_converter_ops.py +++ b/test/tensorrt/test_converter_ops.py @@ -71,20 +71,6 @@ def test_trt_result_fp16(self): self.check_trt_result(precision_mode="fp16") -class TestExpIntTRTPattern(TensorRTBaseTest): - def setUp(self): - self.python_api = paddle.exp - self.api_args = { - "x": np.random.randn(7, 3).astype("int64"), - } - self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [3, 3]} - self.max_shape = {"x": [10, 3]} - - def test_trt_result(self): - self.check_trt_result() - - class TestAbsFloatTRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.abs From 02d80515863438e6ff2287de250c843328f15629 Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Tue, 31 Dec 2024 06:37:05 +0000 Subject: [PATCH 3/5] add marker --- .../transforms/tensorrt/trt_op_marker_pass.cc | 68 ++++++++----------- 1 file changed, 28 insertions(+), 40 deletions(-) 
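Note for this series: on the converter side (see [PATCH 2/5]), all of these ops now go through a single table-driven helper in converter_utils.py, where each Paddle op name maps to a chain of TensorRT unary operations and pd_op.rsqrt is lowered as SQRT followed by RECIP. A minimal standalone sketch of that chained dispatch, using only the public TensorRT Python API (the table and function names below are illustrative, not the actual ones in the patch):

    import tensorrt as trt

    # Illustrative subset of the op table: most ops map to a single TensorRT
    # unary operation; rsqrt composes sqrt and reciprocal.
    UNARY_CHAINS = {
        "pd_op.exp": [trt.UnaryOperation.EXP],
        "pd_op.rsqrt": [trt.UnaryOperation.SQRT, trt.UnaryOperation.RECIP],
    }

    def convert_unary(network, op_name, input_tensor):
        # Feed each layer's output into the next unary op in the chain.
        out = input_tensor
        for unary_op in UNARY_CHAINS[op_name]:
            layer = network.add_unary(out, unary_op)
            out = layer.get_output(0)
        return out

Keeping single-op and composed lowerings on one code path is what allowed the separate Rsqrt_Op_converter to be dropped in that commit.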
diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 4778b9d3f61647..cf030197c91fea 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -106,6 +106,8 @@ DEFINE_GENERAL_PATTERN(Ceil, paddle::dialect::CeilOp)
 DEFINE_GENERAL_PATTERN(Rsqrt, paddle::dialect::RsqrtOp)
 DEFINE_GENERAL_PATTERN(Reciprocal, paddle::dialect::ReciprocalOp)
 DEFINE_GENERAL_PATTERN(Erf, paddle::dialect::ErfOp)
+DEFINE_GENERAL_PATTERN(Sign, paddle::dialect::SignOp)
+DEFINE_GENERAL_PATTERN(Round, paddle::dialect::RoundOp)
 #undef DEFINE_GENERAL_PATTERN

 // Add ReduceCommonOpPattern base class to simplify code
@@ -277,8 +279,30 @@ class ActOpPattern : public pir::OpRewritePattern<OpType> {
 };
 using TanhOpPattern = ActOpPattern<paddle::dialect::TanhOp>;
 using CeluOpPattern = ActOpPattern<paddle::dialect::CeluOp>;
-using LogicalNotOpPattern = ActOpPattern<paddle::dialect::LogicalNotOp>;
-using LogicalNot_OpPattern = ActOpPattern<paddle::dialect::LogicalNot_Op>;
+
+template <typename OpType>
+class Logical_NotOpPattern : public pir::OpRewritePattern<OpType> {
+ public:
+  using pir::OpRewritePattern<OpType>::OpRewritePattern;
+  bool MatchAndRewrite(OpType op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    pir::Value x = op.operand_source(0);
+    auto x_dtype = pir::GetDataTypeFromValue(x);
+    if (!x_dtype.isa<pir::BoolType>()) {
+      VLOG(3) << "logical_not op only supports bool input in tensorrt.";
+      return false;
+    }
+    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+    return true;
+  }
+};
+using LogicalNotOpPattern = Logical_NotOpPattern<paddle::dialect::LogicalNotOp>;
+using LogicalNot_OpPattern =
+    Logical_NotOpPattern<paddle::dialect::LogicalNot_Op>;

 class Pool2dOpPattern
     : public pir::OpRewritePattern<paddle::dialect::Pool2dOp> {
@@ -548,42 +572,6 @@ class ArangeOpPattern
   }
 };

-class SignOpPattern : public pir::OpRewritePattern<paddle::dialect::SignOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::SignOp>::OpRewritePattern;
-  bool MatchAndRewrite(paddle::dialect::SignOp op,
-                       pir::PatternRewriter &rewriter) const override {
-    if (op->HasAttribute(kCanRunTrtAttr) &&
-        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      return false;
-    }
-#if IS_TRT_VERSION_LT(8200)
-    VLOG(3) << "sign op is only supported by tensorrt8.2 above ";
-    return false;
-#endif
-    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
-    return true;
-  }
-};
-
-class RoundOpPattern : public pir::OpRewritePattern<paddle::dialect::RoundOp> {
- public:
-  using pir::OpRewritePattern<paddle::dialect::RoundOp>::OpRewritePattern;
-  bool MatchAndRewrite(paddle::dialect::RoundOp op,
-                       pir::PatternRewriter &rewriter) const override {
-    if (op->HasAttribute(kCanRunTrtAttr) &&
-        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
-      return false;
-    }
-#if IS_TRT_VERSION_LT(8200)
-    VLOG(3) << "round op is only supported by tensorrt8.2 above ";
-    return false;
-#endif
-    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
-    return true;
-  }
-};
-
 class GroupNormOpPattern
     : public pir::OpRewritePattern<paddle::dialect::GroupNormOp> {
  public:
@@ -2203,6 +2191,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ADD_PATTERN(Rsqrt)
     ADD_PATTERN(Reciprocal)
     ADD_PATTERN(Erf)
+    ADD_PATTERN(Sign)
+    ADD_PATTERN(Round)
 #if IS_TRT_VERSION_GE(8600)
     ADD_PATTERN(Layer_norm)
 #endif
@@ -2212,8 +2202,6 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
-    ps.Add(std::make_unique<SignOpPattern>(context));
-    ps.Add(std::make_unique<RoundOpPattern>(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));
     ps.Add(std::make_unique(context));

From 
097672256ebebdb98cfcd5ef8368851bf142c67a Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Thu, 2 Jan 2025 08:40:54 +0000 Subject: [PATCH 4/5] fix --- test/tensorrt/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensorrt/CMakeLists.txt b/test/tensorrt/CMakeLists.txt index 4735dc6def3345..201a1e02f2f3f7 100644 --- a/test/tensorrt/CMakeLists.txt +++ b/test/tensorrt/CMakeLists.txt @@ -14,7 +14,7 @@ if(NOT WIN32 AND TENSORRT_FOUND) set_tests_properties(test_converter_conv PROPERTIES TIMEOUT "300") set_tests_properties(test_export PROPERTIES TIMEOUT "500") set_tests_properties(test_converter_norm PROPERTIES TIMEOUT "300") - set_tests_properties(test_converter_ops PROPERTIES TIMEOUT "300") + set_tests_properties(test_converter_ops PROPERTIES TIMEOUT "500") set_tests_properties(test_converter_stat PROPERTIES TIMEOUT "300") set_tests_properties(test_converter_math PROPERTIES TIMEOUT "300") set_tests_properties(test_converter_activation PROPERTIES TIMEOUT "300") From 0121ca41986af287b6fdb9d44f4640ae8c2c077a Mon Sep 17 00:00:00 2001 From: fangfangssj <1135470306@qq.com> Date: Tue, 7 Jan 2025 13:22:05 +0000 Subject: [PATCH 5/5] fix --- test/tensorrt/test_converter_ops.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/tensorrt/test_converter_ops.py b/test/tensorrt/test_converter_ops.py index e2de2be2ad0384..155a93d2827a19 100644 --- a/test/tensorrt/test_converter_ops.py +++ b/test/tensorrt/test_converter_ops.py @@ -64,6 +64,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -81,6 +82,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -98,6 +100,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -112,6 +115,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -129,6 +133,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -146,6 +151,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -163,6 +169,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -180,6 +187,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -197,6 +205,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -214,6 +223,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ 
-231,6 +241,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -248,6 +259,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -265,6 +277,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -282,6 +295,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -296,6 +310,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self): @@ -313,6 +328,7 @@ def setUp(self): } self.program_config = {"feed_list": ["x"]} self.min_shape = {"x": [3, 3]} + self.opt_shape = {"x": [7, 3]} self.max_shape = {"x": [10, 3]} def test_trt_result(self):
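The opt_shape entries added in [PATCH 5/5] supply the optimum point of the dynamic-shape range that the TensorRT engine is tuned for, alongside the existing min_shape/max_shape bounds. For reference, a minimal sketch of the equivalent setup in the raw TensorRT Python API (illustrative only, not part of these patches):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    # Dynamic first dimension, matching the tests' inputs of shape [n, 3].
    x = network.add_input("x", trt.float32, (-1, 3))

    config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # min/opt/max mirror self.min_shape / self.opt_shape / self.max_shape.
    profile.set_shape("x", min=(3, 3), opt=(7, 3), max=(10, 3))
    config.add_optimization_profile(profile)

TensorRT picks kernels that are fastest at the opt shape while keeping the engine valid for any shape between min and max, which is why these tests use the actual input shape [7, 3] as the optimum.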