From e9ee21fc1d64590f941fb96bff717537c3f6b3a6 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 24 Feb 2022 16:10:42 +0800 Subject: [PATCH] support export hardsigmoid in torch<=1.8 (#169) * support export hardsigmoid in torch<=1.8 * fix lint --- mmdeploy/pytorch/ops/__init__.py | 4 +++- mmdeploy/pytorch/ops/hardsigmoid.py | 12 ++++++++++++ tests/test_pytorch/test_pytorch_ops.py | 7 +++++++ 3 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/pytorch/ops/hardsigmoid.py diff --git a/mmdeploy/pytorch/ops/__init__.py b/mmdeploy/pytorch/ops/__init__.py index 48f9c9013..0608aadf7 100644 --- a/mmdeploy/pytorch/ops/__init__.py +++ b/mmdeploy/pytorch/ops/__init__.py @@ -3,6 +3,7 @@ adaptive_avg_pool2d__default, adaptive_avg_pool3d__default) from .grid_sampler import grid_sampler__default +from .hardsigmoid import hardsigmoid__default from .instance_norm import instance_norm__tensorrt from .lstm import generic_rnn__ncnn from .squeeze import squeeze__default @@ -10,5 +11,6 @@ __all__ = [ 'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default', 'adaptive_avg_pool3d__default', 'grid_sampler__default', - 'instance_norm__tensorrt', 'generic_rnn__ncnn', 'squeeze__default' + 'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn', + 'squeeze__default' ] diff --git a/mmdeploy/pytorch/ops/hardsigmoid.py b/mmdeploy/pytorch/ops/hardsigmoid.py new file mode 100644 index 000000000..a4d14173e --- /dev/null +++ b/mmdeploy/pytorch/ops/hardsigmoid.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# Modified from: +# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py +from mmdeploy.core import SYMBOLIC_REWRITER + + +@SYMBOLIC_REWRITER.register_symbolic( + 'hardsigmoid', is_pytorch=True, arg_descriptors=['v']) +def hardsigmoid__default(ctx, g, self): + """Support export hardsigmoid This rewrite enable export hardsigmoid in + torch<=1.8.2.""" + return g.op('HardSigmoid', self, alpha_f=1 / 6) diff --git a/tests/test_pytorch/test_pytorch_ops.py b/tests/test_pytorch/test_pytorch_ops.py index 69c9e12ed..9a0314881 100644 --- a/tests/test_pytorch/test_pytorch_ops.py +++ b/tests/test_pytorch/test_pytorch_ops.py @@ -116,3 +116,10 @@ def test_squeeze(self): nodes = get_model_onnx_nodes(model, x) assert nodes[0].attribute[0].ints == [0] assert nodes[0].op_type == 'Squeeze' + + +def test_hardsigmoid(): + x = torch.rand(1, 2, 3, 4) + model = torch.nn.Hardsigmoid().eval() + nodes = get_model_onnx_nodes(model, x) + assert nodes[0].op_type == 'HardSigmoid'