diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 14c25ec8ab..066d0b2588 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1714,6 +1714,23 @@ def aten_ops_isnan( ) +@dynamo_tensorrt_converter(torch.ops.aten._local_scalar_dense.default) +def aten_ops_local_scalar_dense( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unary.local_scalar_dense( + ctx, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @dynamo_tensorrt_converter(operator.add, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor, supports_dynamic_shapes=True) @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar, supports_dynamic_shapes=True) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index beb13fca9b..dbdbb332c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -571,3 +571,26 @@ def isnan( ) return nan_values_mask + + +def local_scalar_dense( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + start = [0] * len(input.shape) + shape = [1] * len(input.shape) # Get one element from each dimension + stride = [1] * len(input.shape) # Step through each dimension by 1 + + layer = ctx.net.add_slice(input=input, start=start, shape=shape, stride=stride) + set_layer_name(layer, target, f"{name}_slice", source_ir) + + reshape_layer = ctx.net.add_shuffle(layer.get_output(0)) + reshape_layer.reshape_dims = [ + 1, + ] # Reshape to a single-element tensor + set_layer_name(reshape_layer, target, f"{name}_reshape", source_ir) + + return reshape_layer.get_output(0) diff --git a/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py new file mode 100644 index 0000000000..7817fc0ab7 --- /dev/null +++ b/tests/py/dynamo/conversion/test_local_scalar_dense_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from .harness import DispatchTestCase + + +class TestLocalScalarDenseConverter(DispatchTestCase): + @parameterized.expand( + [ + (torch.randn((5, 10, 5), dtype=torch.float32),), + (torch.randint(-10, 10, (5, 1, 15), dtype=torch.int32),), + (torch.randn((1), dtype=torch.float32),), + ((torch.tensor([-2.4])),), + ((torch.tensor([5.5, 3.5, 3.6])),), + ((torch.tensor([True])),), + ( + torch.tensor( + [ + float("nan"), + 1.23, + float("inf"), + ] + ), + ), + ( + torch.tensor( + [ + float("-inf"), + 1.23, + float("nan"), + ] + ), + ), + ((torch.tensor([float("inf")])),), + ] + ) + def test_local_scalar_dense(self, data): + class local_scalar_dense(nn.Module): + def forward(self, input): + return torch.ops.aten._local_scalar_dense.default(input) + + inputs = [data] + self.run_test( + local_scalar_dense(), + inputs, + ) + + +if __name__ == "__main__": + run_tests()