diff --git a/test/tensorrt/test_converter_activation.py b/test/tensorrt/test_converter_activation.py index 3e385bd57e993..f07a1a304f26f 100644 --- a/test/tensorrt/test_converter_activation.py +++ b/test/tensorrt/test_converter_activation.py @@ -20,7 +20,7 @@ import paddle -class TestEluTRTPatternCase1(TensorRTBaseTest): +class TestEluTRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.nn.functional.elu self.api_args = { @@ -35,21 +35,8 @@ def setUp(self): def test_trt_result(self): self.check_trt_result() - -class TestEluTRTPatternCase2(TensorRTBaseTest): - def setUp(self): - self.python_api = paddle.nn.functional.elu - self.api_args = { - "x": np.random.randn(3).astype("float16"), - "alpha": 1.0, - } - self.program_config = {"feed_list": ["x"]} - self.min_shape = {"x": [1]} - self.opt_shape = {"x": [1]} - self.max_shape = {"x": [5]} - - def test_trt_result(self): - self.check_trt_result() + def test_trt_result_fp16(self): + self.check_trt_result(rtol=1e-3, atol=1e-3, precision_mode="fp16") class TestHardSigmoidTRTPattern(TensorRTBaseTest):