diff --git a/docs/op_matrix.md b/docs/op_matrix.md
index 72ace2a9..13ffd160 100644
--- a/docs/op_matrix.md
+++ b/docs/op_matrix.md
@@ -76,6 +76,7 @@ Operators that are implemented in Python
 | `aten::flip` | |
 | `aten::floor` | |
 | `aten::floor_divide` | floor_divide for floats is not supported<br>floor_divide for negative numbers is not supported |
+| `aten::frobenius_norm` | |
 | `aten::gather` | |
 | `aten::ge` | |
 | `aten::gelu` | |
diff --git a/tests/converter_op_test.py b/tests/converter_op_test.py
index 856233fc..5c706c98 100644
--- a/tests/converter_op_test.py
+++ b/tests/converter_op_test.py
@@ -5231,6 +5231,54 @@ def model(x):
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
 
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input).view(1)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm_with_dim(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x, dim=0)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm_with_dim_keepdim(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x, dim=0, keepdim=True)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
     @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
     def test_norm_p1(self):
         dummy_input = torch.randn(10, 10, dtype=torch.float32)
diff --git a/tinynn/converter/operators/torch/__init__.py b/tinynn/converter/operators/torch/__init__.py
index cc7d781a..e20edaf0 100644
--- a/tinynn/converter/operators/torch/__init__.py
+++ b/tinynn/converter/operators/torch/__init__.py
@@ -168,6 +168,7 @@
     "aten::roll": ATenRollOperator,
     "aten::round": ATenRoundOperator,
     "aten::norm": ATenNormOperator,
+    "aten::frobenius_norm": ATenFrobeniusNormOperator,
     "aten::scatter_": ATenScatterOperator,
     "aten::abs": ATenAbsOperator,
     "aten::im2col": ATenIm2colOperator,
diff --git a/tinynn/converter/operators/torch/aten.py b/tinynn/converter/operators/torch/aten.py
index 4fcbd90d..695a9b2f 100644
--- a/tinynn/converter/operators/torch/aten.py
+++ b/tinynn/converter/operators/torch/aten.py
@@ -3770,6 +3770,20 @@ def parse(self, node, attrs, args, graph_converter):
         self.parse_common(node, attrs, args, graph_converter)
 
 
+class ATenFrobeniusNormOperator(ATenFrobeniusNormSchema):
+    def parse_common(self, node, attrs, args, graph_converter):
+        # frobenius_norm has no `p` argument; it is always the 2-norm, so inject p=2 and reuse aten::norm
+        assert 'p' not in args
+        self.input_tensors.insert(1, 2)
+        ATenNormOperator.parse_common(self, node, attrs, args, graph_converter)
+
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+        self.parse_common(node, attrs, args, graph_converter)
+
+
 class ATenLinalgVectorNormOperator(ATenLinalgVectorNormSchema):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
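Note on the approach: `ATenFrobeniusNormOperator` does not introduce a new lowering. Because `aten::frobenius_norm` carries no `p` argument and is always the 2-norm, the new operator injects `2` into `input_tensors` and delegates to `ATenNormOperator.parse_common`, so all p-norm handling stays in one code path. The sketch below is a plain-PyTorch illustration of that equivalence (it uses no TinyNN APIs and is not part of the converter), mirroring the three new tests:

```python
# Plain-PyTorch sketch of the equivalence the delegation relies on:
# the Frobenius norm is the p=2 norm over the reduced elements.
import torch

x = torch.randn(10, 10, dtype=torch.float32)

# Full reduction, as in test_frobenius_norm.
assert torch.allclose(torch.norm(x), torch.sqrt((x * x).sum()))

# Reduction over dim=0, as in test_frobenius_norm_with_dim.
assert torch.allclose(torch.norm(x, dim=0), torch.sqrt((x * x).sum(dim=0)))

# keepdim=True, as in test_frobenius_norm_with_dim_keepdim.
assert torch.allclose(
    torch.norm(x, dim=0, keepdim=True),
    torch.sqrt((x * x).sum(dim=0, keepdim=True)),
)
```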