Pytorch convert to TfLite seems something is wrong, I don't know why. #446
Comments
Hi @ZTFtrue, I was able to convert and get similar outputs with this cleaned-up script. If you can let me know the number of classes (1000?) and where you got your weights, I can try something closer to your case.

import torch
import torchvision
import ai_edge_torch
import numpy as np
# Use EfficientNet-B1 with pre-trained weights.
enet_b1 = torchvision.models.efficientnet_b1(weights=torchvision.models.EfficientNet_B1_Weights.DEFAULT)
sample_inputs_NCHW = (torch.randn(1, 3, 224, 224),)
# Convert and serialize PyTorch model to a tflite flatbuffer. Note that we
# are setting the model to evaluation mode prior to conversion.
edge_model = ai_edge_torch.convert(enet_b1.eval(), sample_inputs_NCHW)
edge_model.export("enet_b1.tflite")
torch_output = enet_b1(*sample_inputs_NCHW)
edge_output = edge_model(*sample_inputs_NCHW)
if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
nhwc_enet_b1 = ai_edge_torch.to_channel_last_io(enet_b1, args=[0])
sample_inputs_NHWC = (torch.randn(1, 224, 224, 3),)
nhwc_edge_model = ai_edge_torch.convert(nhwc_enet_b1.eval(), sample_inputs_NHWC)
nhwc_edge_model.export("nhwc_enet_b1.tflite")
torch_output = nhwc_enet_b1(*sample_inputs_NHWC)
edge_output = nhwc_edge_model(*sample_inputs_NHWC)
if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")

I don't believe you can add metadata with this API that way.
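The np.allclose checks in the script above give only a pass/fail answer, which hides how far apart the two outputs actually are. As a rough sketch (the helper name `compare_outputs` and its report fields are hypothetical, made up for illustration, and not part of ai_edge_torch), something like the following reports the largest absolute difference and whether the top-1 class agrees, which is often what matters for a classifier:

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-5, rtol=1e-5):
    """Summarize how closely two (batch, classes) output arrays agree.

    Hypothetical helper for illustration; not part of ai_edge_torch.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    return {
        # Same element-wise check as in the script above.
        "allclose": bool(np.allclose(reference, candidate, atol=atol, rtol=rtol)),
        # Worst-case element-wise deviation.
        "max_abs_diff": float(np.max(np.abs(reference - candidate))),
        # Do both outputs predict the same class for every batch item?
        "top1_match": bool(
            np.array_equal(reference.argmax(axis=1), candidate.argmax(axis=1))
        ),
    }


# Made-up values standing in for torch_output.detach().numpy() and edge_output.
reference = np.array([[0.10, 2.50, -1.00]])
candidate = reference + 1e-3  # small uniform drift
report = compare_outputs(reference, candidate)
print(report)
```

If `allclose` is false but `top1_match` is true and `max_abs_diff` is small, the tolerance may simply be too strict for the converted model; a large `max_abs_diff` points to a real mismatch.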
Thanks. This code produces the error:
However, when I convert using your code, the errors in the TFLite file are much worse than those in the original model or in the ONNX model I converted. Here is my model: Google Drive Link. The number of classes is 584.
I also meant to ask: can I add metadata using this tool? On my Android device, I encounter this error:
Currently, I can only add this metadata using other tools.
The version I originally used was:
As for commit b183411, it resulted in an error:
When I removed the above line of code, the errors in the TFLite file still persisted. I used
Here is the test image: Test Image - Google Drive. It should output
Here is the model: Model - Google Drive.
Thank you.
Hi @ZTFtrue, I was able to modify the test to more closely match your original case, and I see some interesting results. I noticed that in your original test, torch_output was produced from a different tensor (not a transposed version of the same tensor), so I fixed that. However, in this case there seems to be an issue with the original input tensor format, while the NHWC case seems to work. So something is wrong at least:

import torch
import ai_edge_torch
import numpy as np
from torchvision import models
from torchvision.models import EfficientNet_B1_Weights
import torch
import torch.nn as nn
import torch.optim as optim
# Please ignore the file name.
model_path = "efficientnet_b1_model.pth"
weights = EfficientNet_B1_Weights.IMAGENET1K_V1
efficientnet = models.efficientnet_b1(weights=weights)
# "cuda:0" if torch.cuda.is_available() else
device = torch.device("cpu")
out_features = 224
num_classes = 584
fc_inputs = efficientnet.classifier[1].in_features
model = efficientnet.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
def load_model(model_path):
    model = efficientnet
    fc_inputs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(fc_inputs, out_features),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(out_features, num_classes),
        nn.LogSoftmax(dim=1),
    )
    model.load_state_dict(torch.load(model_path, weights_only=False, map_location=torch.device('cpu')))
    model = model.to(device)
    return model
loaded_model = load_model(model_path)
sample_inputs_NCHW = (torch.randn(1, 3, 224, 224),)
torch_output = loaded_model(*sample_inputs_NCHW)
edge_model_NCHW = ai_edge_torch.convert(loaded_model.eval(), sample_inputs_NCHW)
edge_model_NCHW.export("efficientnet_b1_model_NCHW.tflite")
edge_output = edge_model_NCHW(*sample_inputs_NCHW)
if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
sample_inputs_NHWC = (torch.randn(1, 224, 224, 3),)
# torch_output = loaded_model(*sample_inputs_NCHW)
# Transform the first input to NHWC.
loaded_model_NHWC = ai_edge_torch.to_channel_last_io(loaded_model, args=[0])
torch_output = loaded_model_NHWC(*sample_inputs_NHWC)
# Convert the transformed model with NHWC input(s).
edge_model_NHWC = ai_edge_torch.convert(loaded_model_NHWC.eval(), sample_inputs_NHWC)
edge_model_NHWC.export("efficientnet_b1_model_NHWC.tflite")
edge_output = edge_model_NHWC(*sample_inputs_NHWC)
if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
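One thing worth checking in the script above: in the NCHW case, torch_output is computed before loaded_model.eval() is called. PyTorch modules start in training mode, where Dropout (and BatchNorm) behave differently from inference, so the reference output may not be comparable to the converted model's output. Whether this fully explains the NCHW mismatch for this model is not confirmed. A minimal sketch with a toy head (the layer sizes are arbitrary, chosen only for illustration) shows the effect:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy classifier head with Dropout, loosely mirroring the one in load_model.
# The sizes (16, 64, 8) are arbitrary and chosen only for illustration.
head = nn.Sequential(
    nn.Linear(16, 64),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(64, 8),
)
x = torch.randn(1, 16)

# A freshly constructed module is in training mode, so Dropout is active.
assert head.training
with torch.no_grad():
    train_out = head(x)

# After eval(), Dropout is a no-op and repeated calls give identical results.
head.eval()
with torch.no_grad():
    eval_out_a = head(x)
    eval_out_b = head(x)

print("eval is deterministic:", torch.equal(eval_out_a, eval_out_b))
print(
    "train matches eval:",
    torch.allclose(train_out, eval_out_a, atol=1e-5, rtol=1e-5),
)
```

Applied to the script above, this would mean calling loaded_model.eval() before computing the NCHW torch_output, so that both sides of the comparison run in inference mode.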
Sorry, I have a question.
Running this code outputs "Something wrong with Pytorch --> TfLite", and I don't know why. I also tried it on my phone, and it outputs the wrong label.
Here is all the code.
I have another question: can I add metadata when exporting the model using this tool?