[converter] add aten::frobenius_norm support (#234)
* [converter] add aten::frobenius_norm support

* reuse code
Juelianqvq authored Jul 11, 2023
1 parent 21dd01c commit 4be57e2
Showing 4 changed files with 66 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/op_matrix.md
@@ -76,6 +76,7 @@ Operators that are implemented in Python
 | `aten::flip` | |
 | `aten::floor` | |
 | `aten::floor_divide` | floor_divide for floats is not supported<br>floor_divide for negative numbers is not supported |
+| `aten::frobenius_norm` | |
 | `aten::gather` | |
 | `aten::ge` | |
 | `aten::gelu` | |
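For context, `aten::frobenius_norm` computes the p=2 norm over the reduced dimensions, and `torch.norm` defaults to `p='fro'`. A minimal sketch of the identity, in plain PyTorch (independent of the converter):

    import torch

    x = torch.randn(10, 10)

    # The Frobenius norm is the square root of the sum of squared elements.
    assert torch.allclose(torch.norm(x), x.pow(2).sum().sqrt())
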
50 changes: 50 additions & 0 deletions tests/converter_op_test.py
@@ -5231,6 +5231,56 @@ def model(x):
         tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
         assert_close(dummy_output, tfl_output)
 
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input).view(1)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        print(tfl_output.shape)
+        print(dummy_output, tfl_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm_with_dim(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x, dim=0)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
+    @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
+    def test_frobenius_norm_with_dim_keepdim(self):
+        dummy_input = torch.randn(10, 10, dtype=torch.float32)
+
+        def model(x):
+            return torch.norm(x, dim=0, keepdim=True)
+
+        model_path = get_model_path()
+
+        converter = TFLiteConverter(model, dummy_input, model_path, nchw_transpose=False)
+        converter.convert()
+
+        dummy_output = model(dummy_input)
+        tfl_output = tfl_run_model(model_path, dummy_input, dummy_output)
+        assert_close(dummy_output, tfl_output, atol=256.0, rtol=256.0)
+
     @unittest.skipIf(not hasattr(torch, 'norm'), "Norm is not supported")
     def test_norm_p1(self):
         dummy_input = torch.randn(10, 10, dtype=torch.float32)
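A note on `test_frobenius_norm` above: `torch.norm(x)` with no `dim` returns a 0-d tensor, and the `.view(1)` reshape presumably aligns it with the rank-1 output of the converted TFLite model before comparison. The shapes involved:

    import torch

    scalar = torch.norm(torch.randn(10, 10))
    print(scalar.shape)          # torch.Size([]) -- a 0-d tensor
    print(scalar.view(1).shape)  # torch.Size([1])
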
1 change: 1 addition & 0 deletions tinynn/converter/operators/torch/__init__.py
@@ -168,6 +168,7 @@
"aten::roll": ATenRollOperator,
"aten::round": ATenRoundOperator,
"aten::norm": ATenNormOperator,
"aten::frobenius_norm": ATenFrobeniusNormOperator,
"aten::scatter_": ATenScatterOperator,
"aten::abs": ATenAbsOperator,
"aten::im2col": ATenIm2colOperator,
14 changes: 14 additions & 0 deletions tinynn/converter/operators/torch/aten.py
@@ -3770,6 +3770,20 @@ def parse(self, node, attrs, args, graph_converter):
         self.parse_common(node, attrs, args, graph_converter)
 
 
+class ATenFrobeniusNormOperator(ATenFrobeniusNormSchema):
+    def parse_common(self, node, attrs, args, graph_converter):
+
+        assert 'p' not in args
+        self.input_tensors.insert(1, 2)
+        ATenNormOperator.parse_common(self, node, attrs, args, graph_converter)
+
+    def parse(self, node, attrs, args, graph_converter):
+        super().parse(node, attrs, args, graph_converter)
+
+        self.run(node)
+        self.parse_common(node, attrs, args, graph_converter)
+
+
 class ATenLinalgVectorNormOperator(ATenLinalgVectorNormSchema):
     def parse(self, node, attrs, args, graph_converter):
         super().parse(node, attrs, args, graph_converter)
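The handler works by splicing in the argument that `aten::frobenius_norm` lacks: the op has no `p` input (hence the `assert 'p' not in args`), so `self.input_tensors.insert(1, 2)` injects a constant `2` as the second input and delegates to `ATenNormOperator.parse_common`, lowering it like an ordinary p=2 norm. A sketch of the identity this reuse relies on, in plain PyTorch (the converter internals are not reproduced here):

    import torch

    x = torch.randn(10, 10)

    # frobenius_norm(x, dim) is exactly norm(x, p=2, dim), so reusing the
    # generic norm lowering with p fixed to 2 gives the same result.
    assert torch.allclose(torch.norm(x, dim=0), torch.norm(x, p=2, dim=0))
    assert torch.allclose(
        torch.norm(x, dim=0, keepdim=True),
        (x ** 2).sum(dim=0, keepdim=True).sqrt(),
    )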
