TF 2.11.0 MaxPool2D padding incorrect dim calculation bug prevents models from being loaded #258
I can confirm that this also occurs on TensorFlow 2.11 (which is the Android dependency version for …).
I have narrowed it down to a problem with how MaxPool2D gets prepared:
Testing with TensorFlow 2.11.0, this fails:
With:
Does this mean that any model using MaxPool2D with padding is bound to fail on flutter-tflite? I'll see if I can come up with a workaround.
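The issue title points at the output-dimension calculation for padded MaxPool2D. As a rough reference, here is a minimal sketch of the per-axis arithmetic, assuming dilation 1 and `ceil_mode=False`. TFLite's `MAX_POOL_2D` op only accepts `SAME` or `VALID` padding rather than explicit pad amounts. Because of that, `nn.MaxPool2d(kernel_size=5, stride=1, padding=2)` has to be expressed either as `SAME` pooling or as an explicit pad followed by `VALID` pooling. The helper names are hypothetical, and these are the textbook formulas, not the actual TFLite prepare code.

```python
import math


def torch_pool_out(size, kernel, stride, padding):
    """Output length along one axis for PyTorch pooling (dilation=1, ceil_mode=False)."""
    return (size + 2 * padding - kernel) // stride + 1


def tf_pool_out(size, kernel, stride, mode):
    """Output length along one axis for TensorFlow-style 'SAME' or 'VALID' padding."""
    if mode == "SAME":
        return math.ceil(size / stride)
    if mode == "VALID":
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError(f"unknown padding mode: {mode}")


def tf_same_pad_total(size, kernel, stride):
    """Total padding that 'SAME' adds along one axis (split between both sides)."""
    out = math.ceil(size / stride)
    return max((out - 1) * stride + kernel - size, 0)


# MaxPool2d(kernel_size=5, stride=1, padding=2) on a 10-pixel axis:
print(torch_pool_out(10, 5, 1, 2))     # 10
print(tf_pool_out(10, 5, 1, "SAME"))   # 10, and tf_same_pad_total(10, 5, 1) == 4 (2 per side)
print(tf_pool_out(10, 5, 1, "VALID"))  # 6
```

With stride 1 and an odd kernel, `padding = kernel_size // 2` is exactly the padding `SAME` produces. That is why both formulations should agree on the output shape.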
The following workaround seems to work. The resulting TFLite model loads just fine in TensorFlow 2.11.0:

```python
import torch
from torch import nn
import ai_edge_torch


class MaxPool2dWorkaround(nn.Module):
    def __init__(self):
        super().__init__()
        # Pad by 3 on every side with -inf (crop back to an effective pad of 2 in forward)
        self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))
        # 5x5 max pool with no implicit padding; the explicit pad keeps the spatial size
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        x1 = x1[:, :, 1:-1, 1:-1]  # crop 1 pixel per side -> net padding of 2
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


m = MaxPool2dWorkaround()
x = torch.zeros((1, 3, 10, 10))
print(m(x).shape)  # torch.Size([1, 6, 10, 10])
torch._dynamo.config.verbose = True
edge_model = ai_edge_torch.convert(m.eval(), sample_args=(x, ))
edge_model.export('test.tflite')
```

Just for the record, other attempts were not fruitful:

```python
import torch
from torch import nn


class PadAndCropWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return torch.cat([x, x1[:, :, 2:-2, 2:-2]], dim=1)


class MaxPool2DWithoutPaddingWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        return torch.cat([x, x1], dim=1)


class MaxPool2DWithoutCatWorks(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class JustPadWorksButBadOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class JustPadWith3WorksButBadOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        return x1


class ReconstructMaxPool2dFails(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x)
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


class ReconstructMaxPool2dWithMulFails(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((2,)*4), value=-float('inf'))
        self.m2 = nn.MaxPool2d(kernel_size=5, stride=1, padding=0)

    def forward(self, x):
        x1 = self.m1(x) * (1 - 1e-5)
        x1 = self.m2(x1)
        return torch.cat([x, x1], dim=1)


class ExtraPadAndCropWorksGoodOutput(nn.Module):
    def __init__(self):
        super().__init__()
        self.m1 = nn.ConstantPad2d(padding=((3,)*4), value=-float('inf'))

    def forward(self, x):
        x1 = self.m1(x)
        x1 = x1[:, :, 1:-1, 1:-1]
        return x1
```

I'll try it on YOLOv6 and report back.
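The PyTorch docs describe `nn.MaxPool2d`'s `padding` as implicit negative-infinity padding. So padding by 3 with `-inf`, cropping 1 pixel from every side, and pooling with `padding=0` (as in `MaxPool2dWorkaround`) should give the same values as `nn.MaxPool2d(kernel_size=5, stride=1, padding=2)`. Below is a dependency-free sketch that checks this on a small single-channel image. It uses plain Python lists in place of tensors. The helper names are hypothetical, and only stride 1 is handled.

```python
NEG_INF = float("-inf")


def pad2d(img, p, value=NEG_INF):
    """Constant-pad a 2D list of lists by p on every side (like nn.ConstantPad2d)."""
    width = len(img[0]) + 2 * p
    top = [[value] * width for _ in range(p)]
    body = [[value] * p + list(row) + [value] * p for row in img]
    bottom = [[value] * width for _ in range(p)]
    return top + body + bottom


def crop2d(img, c):
    """Drop c rows and columns from every side (like x[..., c:-c, c:-c])."""
    return [row[c:len(row) - c] for row in img[c:len(img) - c]]


def maxpool2d(img, k, padding=0):
    """Stride-1 k x k max pool; implicit padding uses -inf, as in torch's MaxPool2d."""
    src = pad2d(img, padding) if padding else img
    out_h, out_w = len(src) - k + 1, len(src[0]) - k + 1
    return [[max(src[i + di][j + dj] for di in range(k) for dj in range(k))
             for j in range(out_w)]
            for i in range(out_h)]


img = [[(3 * i + 7 * j) % 11 for j in range(6)] for i in range(6)]
direct = maxpool2d(img, 5, padding=2)                # MaxPool2d(5, stride=1, padding=2)
workaround = maxpool2d(crop2d(pad2d(img, 3), 1), 5)  # pad 3, crop 1, pool with padding=0
print(direct == workaround)  # True
```

The crop turns the 3-pixel `-inf` border back into a 2-pixel one, so the two paths are identical by construction. The workaround only changes which ops the converter sees, not the math.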
Using the aforementioned workaround, the model works with TF 2.11 and therefore with flutter-tflite. So I would suggest opening an issue to upgrade the TF library version. On another note, the quantized model created with TF 2.17 is not backwards compatible with TF Lite 2.11, because it requires the hybrid transpose-conv (introduced in v2.17), but that is a separate issue.
Here is a complete example on Colab that shows both the failure and the workaround.
Cell 1:
Cell 2:
Output:
I am having a problem with a specific TFLite model. It works on other platforms (TFLite on Linux) but fails with the Flutter TFLite package (`tflite_flutter: ^0.11.0`) in my Android emulator. The model in question is YOLOv6-seg.

Using the recipe for PyTorch to TFLite, I export a `.tflite` model for YOLOv6-seg (see notebook). When running with Python TFLite (using the `ai_edge_litert` package), everything seems to work fine. However, with the Flutter package `tflite_flutter: ^0.11.0` on Android, the call to `await Interpreter.fromAsset(...)` fails with the following log message:

This happens for both the quantized and the normal model. The failing node is number 63 in the quantized model and number 40 in the normal model. The node corresponds to the `CSPSPPFModule` class, which is a rather simple combination of concat, add, conv2d, BN, maxpool and ReLU.

For the model, I use the `yolov6-seg` branch. There is something slightly odd about the YOLOv6 code: the original authors added code to silence warnings:

This in turn broke `torch.export.export()`, so for the export to work I had to remove the warning filter. It just so happens that this warning filter resides in the aforementioned `CSPSPPFModule` class. I do not have an in-depth understanding of the model, but perhaps the warning the original authors wanted to silence has something to do with this failure. However, I am not getting any actual warning in PyTorch or in TFLite on Python, and inference with the exported model works perfectly with TFLite on Python. Again, the exact same model fails on Flutter Android.

Here is what it looks like in Model Explorer:
I assume this is a bug in the underlying TFLite implementation. However, since it clearly works on other platforms (namely Python on Colab), I assume it has been fixed in later TFLite versions, so I am filing it here against the Flutter TFLite package.
I can provide the quantized model, or you can run the notebook linked to above.