From 715c3bfbd8526fc2376cfd56ed2934ebf3b9aaed Mon Sep 17 00:00:00 2001
From: John Welsh
Date: Tue, 29 Nov 2022 18:51:16 +0000
Subject: [PATCH] added flip converter

---
 torch2trt/converters/__init__.py |  1 +
 torch2trt/converters/flip.py     | 30 ++++++++++++++++
 torch2trt/tests/test_flip.py     | 62 ++++++++++++++++++++++++++++++++
 3 files changed, 93 insertions(+)
 create mode 100644 torch2trt/converters/flip.py
 create mode 100644 torch2trt/tests/test_flip.py

diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py
index 19eb3eb6..c4bf5a7d 100644
--- a/torch2trt/converters/__init__.py
+++ b/torch2trt/converters/__init__.py
@@ -34,6 +34,7 @@ from .expand import *
 from .example_plugin import *
 from .flatten import *
+from .flip import *
 from .floordiv import *
 from .gelu import *
 from .getitem import *
diff --git a/torch2trt/converters/flip.py b/torch2trt/converters/flip.py
new file mode 100644
index 00000000..61fec07d
--- /dev/null
+++ b/torch2trt/converters/flip.py
@@ -0,0 +1,30 @@
+from torch2trt import torch2trt, tensorrt_converter, get_arg, trt, make_size_wrapper
+
+
+@tensorrt_converter("torch.Tensor.flip")
+@tensorrt_converter("torch.flip")
+def convert_flip(ctx):
+
+    input = get_arg(ctx, 'input', 0, None)
+    dims = get_arg(ctx, 'dims', 1, None)
+    output = ctx.method_return
+
+    input_shape_trt = ctx.network.add_shape(input._trt).get_output(0)
+
+    offset = [0 for i in range(input.ndim)]
+    stride = [1 for i in range(input.ndim)]
+    shape = tuple(input.size())
+    for d in dims:
+        offset[d] = -1
+        stride[d] = -1
+
+    layer = ctx.network.add_slice(
+        input._trt,
+        offset,
+        shape,
+        stride
+    )
+    layer.set_input(2, input_shape_trt)
+    layer.mode = trt.SliceMode.WRAP
+
+    output._trt = layer.get_output(0)
diff --git a/torch2trt/tests/test_flip.py b/torch2trt/tests/test_flip.py
new file mode 100644
index 00000000..fac8935c
--- /dev/null
+++ b/torch2trt/tests/test_flip.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+from torch2trt import torch2trt
+
+
+class FlipModule(nn.Module):
+
+    def __init__(self, dims):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return torch.flip(x, self.dims)
+
+
+class FlipTensorModule(nn.Module):
+
+    def __init__(self, dims):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x):
+        return x.flip(self.dims)
+
+
+
+def test_torch_flip():
+
+    x = torch.randn(1, 2, 3).cuda()
+
+    model = FlipModule(dims=(1,)).cuda().eval()
+    model_trt = torch2trt(model, [x])
+
+    out = model(x)
+    out_trt = model_trt(x)
+
+    assert torch.allclose(out, out_trt, rtol=1e-4, atol=1e-4)
+
+def test_torch_flip_multidim():
+
+    x = torch.randn(1, 2, 3).cuda()
+
+    model = FlipTensorModule(dims=(1, 2)).cuda().eval()
+    model_trt = torch2trt(model, [x])
+
+    out = model(x)
+    out_trt = model_trt(x)
+
+    assert torch.allclose(out, out_trt, rtol=1e-4, atol=1e-4)
+
+def test_torch_flip_tensor():
+
+    x = torch.randn(1, 2, 3).cuda()
+
+    model = FlipTensorModule(dims=(1,)).cuda().eval()
+    model_trt = torch2trt(model, [x])
+
+    out = model(x)
+    out_trt = model_trt(x)
+
+    assert torch.allclose(out, out_trt, rtol=1e-4, atol=1e-4)