
Commit

Merge pull request #574 from jaybdub/roll_converter
added converter for torch.roll
jaybdub authored Jun 16, 2021
2 parents 9457b8c + b604881 commit 0963e49
Showing 3 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -2,6 +2,7 @@

## [Master]

- Added converter for ``torch.roll``
- Added converter for ``torch.nn.functional.layer_norm``
- Added converter for ``torch.nn.functional.gelu``
- Added converter for ``torch.nn.functional.linear``
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
@@ -51,6 +51,7 @@
from .prod import *
from .relu import *
from .relu6 import *
from .roll import *
from .sigmoid import *
from .silu import *
from .softmax import *
74 changes: 74 additions & 0 deletions torch2trt/converters/roll.py
@@ -0,0 +1,74 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.roll')
@tensorrt_converter('torch.Tensor.roll')
def convert_roll(ctx):
    input = get_arg(ctx, 'input', 0, None)
    shifts = get_arg(ctx, 'shifts', 1, None)
    dims = get_arg(ctx, 'dims', 2, None)
    output = ctx.method_return

    assert dims is not None, "roll converter only supports roll when dims is specified"

    ndim = input.ndim

    input_trt = add_missing_trt_tensors(ctx.network, [input])[0]

    try:
        iter(shifts)
    except:
        shifts = (shifts,)
        dims = (dims,)

    start = [0] * ndim
    shape = tuple([int(d) for d in input.shape])
    stride = [1] * ndim

    for s, d in zip(shifts, dims):
        start[d] = (-s) % shape[d]

    start = tuple(start[1:])
    shape = tuple(shape[1:])
    stride = tuple(stride[1:])

    layer = ctx.network.add_slice(
        input_trt,
        start,  # [1:] to exclude batch
        shape,
        stride
    )
    layer.mode = trt.SliceMode.WRAP

    output._trt = layer.get_output(0)


class Roll(torch.nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        return torch.roll(x, *self.args, **self.kwargs)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 4)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int():
    return Roll(1, 1)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 4, 5)])
@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_int_dim():
    return Roll(1, -2)


@add_module_test(torch.float32, torch.device('cuda'), [(1, 3, 4, 5)])
def test_roll_tuple():
    return Roll((2, 3), (1, 3))
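The core trick in `convert_roll` is that rolling by `s` along dimension `d` is the same as a slice that starts at `(-s) % size` and wraps around the end of the tensor, which TensorRT's `trt.SliceMode.WRAP` provides. A minimal plain-PyTorch sketch of that equivalence (no TensorRT involved; the variable names are illustrative):

```python
import torch

# Rolling by s equals a wrap-around read starting at (-s) % size.
x = torch.arange(6)
s = 2
start = (-s) % x.shape[0]                      # 4
wrapped = torch.cat([x[start:], x[:start]])    # emulate SliceMode.WRAP on one axis
assert torch.equal(wrapped, torch.roll(x, s))  # both give [4, 5, 0, 1, 2, 3]
```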

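With the converter registered, a module that calls `torch.roll` with explicit `dims` should convert like any other supported op. A usage sketch (the `ShiftBlock` module and the input shape are illustrative, not part of this commit):

```python
import torch
from torch2trt import torch2trt

class ShiftBlock(torch.nn.Module):
    # Hypothetical module for illustration: rolls activations along two fixed dims.
    def forward(self, x):
        return torch.roll(x, shifts=(2, 3), dims=(1, 3))

model = ShiftBlock().cuda().eval()
x = torch.randn(1, 3, 4, 5).cuda()

model_trt = torch2trt(model, [x])                      # build the TensorRT engine
print(torch.max(torch.abs(model(x) - model_trt(x))))   # expect a value near zero
```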