From 9be8f1bc3bc85505c2398755e0435859cc5a5cb9 Mon Sep 17 00:00:00 2001
From: Alejandro Gaston Alvarez Franceschi
Date: Thu, 4 Jan 2024 18:56:07 +0100
Subject: [PATCH 1/3] Add convolution for torch script

---
 .../converters/mil/frontend/torch/ops.py      | 32 ++++++--
 .../mil/frontend/torch/test/test_torch_ops.py | 82 +++++++++----------
 2 files changed, 67 insertions(+), 47 deletions(-)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 0755d4d9d..7bb90ab10 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -217,7 +217,10 @@ def get_bindings(alist) -> List[Any]:
 
     for i in alist:
         if isinstance(i, str):
-            results.append(context[i])
+            try:
+                results.append(context[i])
+            except ValueError:
+                results.append(None)
         elif isinstance(i, (list, tuple)) and all(isinstance(j, int) for j in i):
             results.append(mb.const(val=i))
         elif isinstance(i, (list, tuple)):
@@ -962,7 +965,7 @@ def linear(context, node):
     context.add(res, torch_name=node.name)
 
 
-@register_torch_op(torch_alias=["conv2d", "convolution"])
+@register_torch_op(torch_alias=["convolution", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"])
 def _convolution(context, node):
     inputs = _get_inputs(context, node)
 
@@ -980,11 +983,25 @@ def _convolution(context, node):
     # we require a (2 * n)-tuple, where n is the number of spatial dimensions, start and end for each spatial dimension
     pad = inputs[4].val
 
-    if len(weight.shape) in (3, 4):
-        # 1D and 2D: Need to explicitly state L-R, T-B pad
+    if type(pad) == str:
+        if pad == "same":
+            pad = 1
+        elif pad == "valid":
+            pad = 0
+        else:
+            raise ValueError(f"Unknown padding string value: '{pad}'")
+
+    if len(weight.shape) == 3:
+        # 1D padding: need to explicitly state the L-R pad for the x dim
         pad = _np.repeat(pad, 2)
+    elif len(weight.shape) == 4:
+        # 2D padding: need to explicitly state the L-R pad for the x, y dims
+        if type(pad) == int:
+            pad = _np.repeat(pad, 4)
+        elif len(pad) == 2:
+            pad = _np.repeat(pad, 2)
     elif len(weight.shape) == 5:
-        # 3D: Need to explicitly state F-Bk, L-R, T-B pad
+        # 3D padding: need to explicitly state the L-R pad for the x, y, z dims
         if type(pad) == int:
             pad = _np.repeat(pad, 6)
         elif len(pad) == 3:
@@ -1000,6 +1017,11 @@ def _convolution(context, node):
         transposed = inputs[6].val
         out_pad = inputs[7].val
         group = inputs[8]
+    elif len(inputs) == 8:
+        transposed = True
+        out_pad = inputs[5].val
+        dilations = inputs[7]
+        group = inputs[6]
     elif len(inputs) == 7:
         transposed = False
         group = inputs[6]
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index 242ec8740..b28c5d9c9 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -226,29 +226,6 @@ def forward(self, x):
             use_scripting=True,
         )
 
-    @pytest.mark.parametrize("compute_unit, backend", itertools.product(compute_units, backends))
-    def test_conv(self, compute_unit, backend):
-        pytest.xfail(
-            "rdar://88194776 ([Converter] coremltools is not working with scripted torch convolution model)"
-        )
-        model = torch.nn.Conv2d(
-            in_channels=2,
-            out_channels=3,
-            kernel_size=1,
-            padding="same",
-            stride=1,
-            dilation=1,
-            groups=1,
-            bias=False,
-        )
-        self.run_compare_torch(
-            (1, 2, 4, 5),
-            model,
-            backend=backend,
-            compute_unit=compute_unit,
-            use_scripting=True,
-        )
-
 
 class TestMean(TorchBaseTest):
     @pytest.mark.parametrize(
@@ -1456,6 +1433,7 @@ class TestConv(TorchBaseTest):
             [
                 "compute_unit",
                 "backend",
+                "scripting",
                 "padding",
                 "stride",
                 "length",
@@ -1467,10 +1445,11 @@ class TestConv(TorchBaseTest):
             ]
         ),
         [
-            (compute_unit, backend, padding, stride, *param)
-            for compute_unit, backend, padding, stride, param in itertools.product(
+            (compute_unit, backend, scripting, padding, stride, *param)
+            for compute_unit, backend, scripting, padding, stride, param in itertools.product(
                 [ct.ComputeUnit.CPU_ONLY],
                 backends,
+                [True, False],
                 ["same", "valid", 0, 1],
                 [1, 2, 3],
                 [
@@ -1490,6 +1469,7 @@ def test_convolution1d(
         self,
         compute_unit,
         backend,
+        scripting,
         padding,
         stride,
         length,
@@ -1503,6 +1483,7 @@ def test_convolution1d(
         if padding == "same" and stride != 1:
             # configuration not supported
             return
+
         model = nn.Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -1511,12 +1492,14 @@ def test_convolution1d(
             padding=padding,
             dilation=dilation,
             bias=bias,
+            groups=groups,
         )
         self.run_compare_torch(
             (1, in_channels, length),
             model,
             backend=backend,
             compute_unit=compute_unit,
+            use_scripting=scripting,
         )
 
     @pytest.mark.parametrize(
@@ -1524,6 +1507,7 @@ def test_convolution1d(
             [
                 "compute_unit",
                 "backend",
+                "scripting",
                 "padding",
                 "stride",
                 "height",
@@ -1536,10 +1520,11 @@ def test_convolution1d(
             ]
         ),
        [
-            (compute_unit, backend, padding, stride, *param)
-            for compute_unit, backend, padding, stride, param in itertools.product(
+            (compute_unit, backend, scripting, padding, stride, *param)
+            for compute_unit, backend, scripting, padding, stride, param in itertools.product(
                 [ct.ComputeUnit.CPU_ONLY],
                 backends,
+                [True, False],
                 ["same", "valid", 1, 0],
                 [1, 2, 3],
                 [
@@ -1559,6 +1544,7 @@ def test_convolution2d(
         self,
         compute_unit,
         backend,
+        scripting,
         padding,
         stride,
         height,
@@ -1571,7 +1557,9 @@ def test_convolution2d(
         groups=1,
     ):
         if padding == "same" and stride != 1:
+            # configuration not supported
             return
+
         model = nn.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -1580,12 +1568,14 @@ def test_convolution2d(
             padding=padding,
             dilation=dilation,
             bias=bias,
+            groups=groups,
         )
         self.run_compare_torch(
             (1, in_channels, height, width),
             model,
             backend=backend,
             compute_unit=compute_unit,
+            use_scripting=scripting,
         )
 
     @pytest.mark.parametrize(
@@ -1593,6 +1583,7 @@ def test_convolution2d(
             [
                 "compute_unit",
                 "backend",
+                "scripting",
                 "padding",
                 "stride",
                 "depth",
@@ -1606,10 +1597,11 @@ def test_convolution2d(
             ]
         ),
        [
-            (compute_unit, backend, padding, stride, *param)
-            for compute_unit, backend, padding, stride, param in itertools.product(
+            (compute_unit, backend, scripting, padding, stride, *param)
+            for compute_unit, backend, scripting, padding, stride, param in itertools.product(
                 [ct.ComputeUnit.CPU_ONLY],
                 backends,
+                [True, False],
                 ["same", "valid", 1, 0],
                 [1, 2, 3],
                 [
@@ -1629,6 +1621,7 @@ def test_convolution3d(
         self,
         compute_unit,
         backend,
+        scripting,
         padding,
         stride,
         depth,
@@ -1642,52 +1635,57 @@ def test_convolution3d(
         groups=1,
     ):
         if padding == "same" and stride != 1:
+            # configuration not supported
             return
+
         model = nn.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
+            bias=bias,
             stride=stride,
             padding=padding,
             dilation=dilation,
-            bias=bias,
+            groups=groups,
         )
         self.run_compare_torch(
             (1, in_channels, depth, height, width),
             model,
             backend=backend,
             compute_unit=compute_unit,
+            use_scripting=scripting,
         )
 
 
-class TestDynamicConv(TorchBaseTest):
+class TestFunctionalConv(TorchBaseTest):
     @pytest.mark.parametrize(
         ",".join(
             [
                 "compute_unit",
                 "backend",
+                "padding",
                 "width",
                 "in_channels",
                 "out_channels",
                 "kernel_size",
                 "stride",
-                "padding",
             ]
         ),
         [
-            (compute_unit, backend, *param)
-            for compute_unit, backend, param in itertools.product(
+            (compute_unit, backend, padding, *param)
+            for compute_unit, backend, padding, param in itertools.product(
                 compute_units,
                 backends,
+                ["same", "valid", 1, 0],
                 [
-                    (5, 1, 1, 1, 2, 1),
-                    (3, 1, 1, 1, 2, 3),
-                    (4, 3, 3, 1, 2, 1),
-                    (7, 3, 3, 1, 3, 1),
-                    (5, 3, 3, 2, 2, 1),
-                    (3, 3, 3, 1, 3, 1),
-                    (3, 3, 3, 1, 3, 3),
-                    (7, 3, 3, 3, 1, 3),
+                    (5, 1, 1, 1, 2),
+                    (3, 1, 1, 1, 2),
+                    (4, 3, 3, 1, 2),
+                    (7, 3, 3, 1, 3),
+                    (5, 3, 3, 2, 2),
+                    (3, 3, 3, 1, 3),
+                    (3, 3, 3, 1, 3),
+                    (7, 3, 3, 3, 1),
                 ],
             )
         ],

From f917759bce5e1004bb198f393acc3bb11b78fe75 Mon Sep 17 00:00:00 2001
From: Alejandro Gaston Alvarez Franceschi
Date: Thu, 4 Jan 2024 19:01:42 +0100
Subject: [PATCH 2/3] Improve Functional convolution tests

---
 .../mil/frontend/torch/test/test_torch_ops.py | 132 +++++++++++++++---
 1 file changed, 112 insertions(+), 20 deletions(-)

diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
index b28c5d9c9..09b207480 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -1694,21 +1694,25 @@ def test_convolution1d(
         self,
         compute_unit,
         backend,
+        padding,
         width,
         in_channels,
         out_channels,
         kernel_size,
         stride,
-        padding,
         groups=1,
     ):
-        class DynamicConv(nn.Module):
+        if padding == "same" and stride != 1:
+            # configuration not supported
+            return
+
+        class FunctionalConv1D(nn.Module):
             def forward(self, input_data, weights):
                 return nn.functional.conv1d(
-                    input_data, weights, stride=stride, padding=padding
+                    input_data, weights, stride=stride, padding=padding, groups=groups
                 )
 
-        model = DynamicConv()
+        model = FunctionalConv1D()
         input_shape = [
             (1, in_channels, width),
             (out_channels, int(in_channels / groups), kernel_size),
@@ -1725,29 +1729,30 @@ def forward(self, input_data, weights):
             [
                 "compute_unit",
                 "backend",
+                "padding",
                 "height",
                 "width",
                 "in_channels",
                 "out_channels",
                 "kernel_size",
                 "stride",
-                "padding",
             ]
         ),
         [
-            (compute_unit, backend, *param)
-            for compute_unit, backend, param in itertools.product(
+            (compute_unit, backend, padding, *param)
+            for compute_unit, backend, padding, param in itertools.product(
                 compute_units,
                 backends,
+                ["same", "valid", 1, 0],
                 [
-                    (5, 3, 1, 1, 1, 2, 0),
-                    (3, 3, 1, 1, 1, 2, 1),
-                    (4, 3, 3, 3, 1, 2, 0),
-                    (7, 3, 3, 3, 1, 3, 0),
-                    (5, 5, 3, 3, 2, 1, 0),
-                    (3, 5, 3, 3, 1, 3, 0),
-                    (3, 5, 3, 3, 1, 3, 1),
-                    (7, 5, 3, 3, 2, 3, 1),
+                    (5, 3, 1, 1, 1, 2),
+                    (3, 3, 1, 1, 1, 2),
+                    (4, 3, 3, 3, 1, 2),
+                    (7, 3, 3, 3, 1, 3),
+                    (5, 5, 3, 3, 2, 1),
+                    (3, 5, 3, 3, 1, 3),
+                    (3, 5, 3, 3, 1, 3),
+                    (7, 5, 3, 3, 2, 3),
                 ],
             )
         ],
@@ -1756,31 +1761,118 @@ def test_convolution2d(
         self,
         compute_unit,
         backend,
+        padding,
         height,
         width,
         in_channels,
         out_channels,
         kernel_size,
         stride,
-        padding,
         groups=1,
     ):
-        class DynamicConv(nn.Module):
+        if padding == "same" and stride != 1:
+            # configuration not supported
+            return
+
+        class FunctionalConv2D(nn.Module):
             def forward(self, input_data, weights):
                 return nn.functional.conv2d(
-                    input_data, weights, stride=stride, padding=padding
+                    input_data, weights, stride=stride, padding=padding, groups=groups
                 )
 
-        model = DynamicConv()
+        model = FunctionalConv2D()
         input_shape = [
             (1, in_channels, height, width),
             (out_channels, int(in_channels / groups), kernel_size, kernel_size),
         ]
         self.run_compare_torch(
-            input_shape, model, backend=backend, compute_unit=compute_unit
+            input_shape,
+            model,
+            backend=backend,
+            compute_unit=compute_unit,
         )
 
+    @pytest.mark.parametrize(
+        ",".join(
+            [
+                "compute_unit",
+                "backend",
+                "padding",
+                "depth",
+                "height",
+                "width",
+                "in_channels",
+                "out_channels",
+                "kernel_size",
+                "stride",
+            ]
+        ),
+        [
+            (compute_unit, backend, padding, *param)
+            for compute_unit, backend, padding, param in itertools.product(
+                compute_units,
+                backends,
+                ["same", "valid", 1, 0],
+                [
+                    (5, 3, 2, 1, 1, 1, 2),
+                    (3, 3, 1, 1, 1, 1, 2),
+                    (4, 3, 3, 3, 3, 1, 2),
+                    (7, 3, 4, 3, 3, 1, 3),
+                    (5, 5, 3, 3, 3, 2, 1),
+                    (3, 5, 1, 3, 3, 1, 3),
+                    (3, 5, 4, 3, 3, 1, 3),
+                    (7, 5, 6, 3, 3, 2, 3),
+                ],
+            )
+        ],
+    )
+    def test_convolution3d(
+        self,
+        compute_unit,
+        backend,
+        padding,
+        depth,
+        height,
+        width,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        groups=1,
+    ):
+        if padding == "same" and stride != 1:
+            # configuration not supported
+            return
+
+        class FunctionalConv3D(nn.Module):
+            def forward(self, input_data, weights):
+                return nn.functional.conv3d(
+                    input_data, weights, stride=stride, padding=padding, groups=groups
+                )
+
+        model = FunctionalConv3D()
+        input_shape = [
+            (1, in_channels, depth, height, width),
+            (out_channels, int(in_channels / groups), kernel_size, kernel_size, kernel_size),
+        ]
+
+        if "neuralnetwork" in backend:
+            with pytest.raises(ValueError, match="3D Convolution doesn't support dynamic weights."):
+                self.run_compare_torch(
+                    input_shape,
+                    model,
+                    backend=backend,
+                    compute_unit=compute_unit,
+                )
+        else:
+            self.run_compare_torch(
+                input_shape,
+                model,
+                backend=backend,
+                compute_unit=compute_unit,
+            )
+
 
 class TestConvTranspose(TorchBaseTest):
     @pytest.mark.parametrize(

From 79fd9705c7de2b098f25d415c2b5f8e284ff831d Mon Sep 17 00:00:00 2001
From: Alejandro Gaston Alvarez Franceschi
Date: Tue, 9 Jan 2024 18:52:52 +0100
Subject: [PATCH 3/3] Add comment about conv_transpose

---
 coremltools/converters/mil/frontend/torch/ops.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py
index 7bb90ab10..1ab8de2a0 100644
--- a/coremltools/converters/mil/frontend/torch/ops.py
+++ b/coremltools/converters/mil/frontend/torch/ops.py
@@ -965,7 +965,8 @@ def linear(context, node):
     context.add(res, torch_name=node.name)
 
 
-@register_torch_op(torch_alias=["convolution", "conv1d", "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d"])
+# NOTE: _convolution could also be aliased to ["conv_transpose1d", "conv_transpose2d", "conv_transpose3d"], but we lack tests for those
+@register_torch_op(torch_alias=["convolution", "conv1d", "conv2d", "conv3d"])
 def _convolution(context, node):
     inputs = _get_inputs(context, node)
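
For reference, a minimal usage sketch of the conversion path these patches enable (not part of the patch series; the model, input shape, and names below are illustrative assumptions, not the project's test code):

# Minimal sketch (illustrative): convert a *scripted* torch convolution, which exercises
# the "conv2d"/"convolution" aliases and the string-padding handling added to _convolution.
import torch
import coremltools as ct

conv = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, padding="same")
conv.eval()
scripted = torch.jit.script(conv)  # scripting, not tracing, is the case PATCH 1/3 targets

mlmodel = ct.convert(
    scripted,
    inputs=[ct.TensorType(name="x", shape=(1, 2, 8, 8))],  # example shape, assumed
    convert_to="mlprogram",
)

With the first patch applied, the string values padding="same" / padding="valid" reach _convolution and are mapped to numeric padding before the usual per-dimension expansion, instead of failing during conversion of scripted models.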