PyTorch to TFLite conversion: something seems wrong, and I don't know why. #446

Open
ZTFtrue opened this issue Jan 4, 2025 · 6 comments

@ZTFtrue

ZTFtrue commented Jan 4, 2025

Sorry, I have a question.

Running this code prints "Something wrong with Pytorch --> TfLite", and I don't know why.

I also tried the model on my phone, and it outputs the wrong label.

Name: ai-edge-quantizer-nightly
Version: 0.0.1.dev20240718
Summary: A quantizer for advanced developers to quantize converted AI Edge models.
Home-page: https://github.com/google-ai-edge/ai-edge-quantizer
Author: 
Author-email: 
License: 
Location: /home/ztftrue/.conda/envs/tensorflow/lib/python3.11/site-packages
Requires: immutabledict, numpy, tf-nightly
Required-by: ai-edge-torch, ai-edge-torch-nightly
----
Name: tf_nightly
Version: 2.19.0.dev20241227
Summary: TensorFlow is an open source machine learning framework for everyone.
Home-page: https://www.tensorflow.org/
Author: Google Inc.
Author-email: [email protected]
License: Apache 2.0
Location: /home/ztftrue/.conda/envs/tensorflow/lib/python3.11/site-packages
Requires: absl-py, astunparse, flatbuffers, gast, google-pasta, grpcio, h5py, keras-nightly, libclang, ml-dtypes, numpy, opt-einsum, packaging, protobuf, requests, setuptools, six, tb-nightly, tensorflow-io-gcs-filesystem, termcolor, typing-extensions, wrapt
Required-by: ai-edge-quantizer, ai-edge-quantizer-nightly, ai-edge-torch, ai-edge-torch-nightly

Here is all code.

import os
import numpy
import torch
import torch.nn as nn
import torch.optim as optim
import ai_edge_torch
from torchvision import models
from torchvision.models import EfficientNet_B1_Weights
#please ignore file name
model_path = "efficientnet_b7_model.pth"
weights=EfficientNet_B1_Weights.IMAGENET1K_V1

efficientnet = models.efficientnet_b1(weights=weights)
# "cuda:0" if torch.cuda.is_available() else
device = torch.device( "cpu")
out_features=224
image_folder = '/home/ztftrue/Documents/PyTorch/ArTest2/train'
num_classes = len(os.listdir(image_folder)) 
fc_inputs = efficientnet.classifier[1].in_features
 

model = efficientnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
 
def load_model(model_path):
    model = efficientnet
    fc_inputs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(fc_inputs, out_features),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(out_features, num_classes),
        nn.LogSoftmax(dim=1)
    )
    model.load_state_dict(torch.load(model_path,weights_only=False))
    model = model.to(device)
    return model


loaded_model = load_model(model_path)
sample_inputs_NCHW = (torch.randn(1, 3, 224, 224),)
sample_inputs_NHWC = (torch.randn(1, 224, 224, 3),)

torch_output = loaded_model(*sample_inputs_NCHW)
# Transform the first input to NHWC.

loaded_model = ai_edge_torch.to_channel_last_io(loaded_model,  args=[0])

# Convert the transformed model with NHWC input(s).
edge_model = ai_edge_torch.convert(loaded_model.eval(), sample_inputs_NHWC)
edge_model.export("efficientnet_b7_model.tflite")

edge_output = edge_model(*sample_inputs_NHWC)

if (numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
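A note on the check above (a sketch, not part of the original post): torch_output is computed from one random NCHW tensor, while edge_output comes from a second, unrelated random NHWC tensor, so allclose compares outputs for two different inputs. A fairer comparison derives the NHWC input from the same NCHW tensor with a transpose over axes (0, 2, 3, 1); a plain reshape would scramble pixels across channels. The helpers make_nhwc and compare_outputs are hypothetical names, and only NumPy is assumed.

```python
import numpy as np


def make_nhwc(x_nchw: np.ndarray) -> np.ndarray:
    """Return the same image batch in NHWC layout (N, H, W, C)."""
    # A transpose moves the channel axis last; a reshape would not.
    return np.ascontiguousarray(np.transpose(x_nchw, (0, 2, 3, 1)))


def compare_outputs(a, b, atol: float = 1e-5, rtol: float = 1e-5) -> dict:
    """Summarize how close two (N, num_classes) output arrays are."""
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    return {
        "allclose": bool(np.allclose(a, b, atol=atol, rtol=rtol)),
        "max_abs_diff": float(np.max(np.abs(a - b))),
        "same_top1": bool(np.array_equal(np.argmax(a, axis=1), np.argmax(b, axis=1))),
    }


# One random image batch, viewed in both layouts.
rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((1, 3, 224, 224), dtype=np.float32)
x_nhwc = make_nhwc(x_nchw)

# Pixel (h=5, w=7) keeps its three channel values after the transpose...
assert np.array_equal(x_nchw[0, :, 5, 7], x_nhwc[0, 5, 7, :])
# ...while a reshape yields a different tensor with the same shape.
assert not np.array_equal(x_nhwc, x_nchw.reshape(1, 224, 224, 3))

# With the real models (names from the script above):
# compare_outputs(torch_output.detach().numpy(), edge_output)
```

In the original script, this would mean deriving the NHWC input as sample_inputs_NCHW[0].permute(0, 2, 3, 1) instead of drawing a second random tensor, and calling loaded_model.eval() before computing torch_output. Reporting the maximum difference and top-1 agreement alongside allclose makes it easier to tell a tiny numerical drift from a real conversion problem.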

I have another question: Can I add metadata when exporting the model using this tool?

For example, metadata like this:

    ModelSpecificInfo(
            name="MobileNetV1 image classifier",
            version="v1",
            image_width=240,
            image_height=240,
            image_min=0,
            image_max=255,
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            num_classes=num_classes,
            author="TensorFlow")
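A possible pitfall with the metadata above, stated as an assumption to verify against the TFLite Support documentation: the NormalizationOptions mean/std are applied as (pixel - mean) / std to raw pixel values in the [image_min, image_max] range (here 0 to 255), whereas torchvision's ImageNet mean/std values are defined for pixels already scaled to [0, 1]. If that convention holds, the metadata values need to be multiplied by 255; otherwise the on-device input would be far from what the model saw in training, which could explain wrong labels on the phone. The helper names below are hypothetical.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # torchvision values, for pixels in [0, 1]
IMAGENET_STD = [0.229, 0.224, 0.225]


def to_pixel_scale(mean, std, image_max=255.0):
    """Rescale [0, 1]-based mean/std to raw pixel values in [0, image_max]."""
    return [m * image_max for m in mean], [s * image_max for s in std]


def torchvision_normalize(pixel, m, s):
    # ToTensor() divides by 255, then Normalize() applies (x - m) / s.
    return (pixel / 255.0 - m) / s


def metadata_normalize(pixel, m, s):
    # Assumed metadata convention: (x - m) / s applied to the raw pixel value.
    return (pixel - m) / s


mean_px, std_px = to_pixel_scale(IMAGENET_MEAN, IMAGENET_STD)
print([round(v, 3) for v in mean_px])  # [123.675, 116.28, 103.53]
print([round(v, 3) for v in std_px])   # [58.395, 57.12, 57.375]

# With rescaled values, both pipelines give the same normalized input.
for c in range(3):
    for pixel in (0, 128, 255):
        a = torchvision_normalize(pixel, IMAGENET_MEAN[c], IMAGENET_STD[c])
        b = metadata_normalize(pixel, mean_px[c], std_px[c])
        assert abs(a - b) < 1e-9
```

Without the rescaling, a white pixel (255) would be normalized to roughly 1111 instead of roughly 2.25. Separately, the snippet declares 240x240 images, while the conversion script traces the model with 224x224 inputs, so that is worth double-checking too.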
@pkgoogle
Contributor

pkgoogle commented Jan 6, 2025

Hi @ZTFtrue, I was able to convert the model and get similar outputs with this cleaned-up script. If you can tell me the number of classes (1000?) and where you got your weights, I can try something closer to your case.

import torch
import torchvision
import ai_edge_torch
import numpy as np

# Use EfficientNet-B1 with pre-trained weights.
enet_b1 = torchvision.models.efficientnet_b1(weights=torchvision.models.EfficientNet_B1_Weights.DEFAULT)
sample_inputs_NCHW = (torch.randn(1, 3, 224, 224),)

# Convert and serialize PyTorch model to a tflite flatbuffer. Note that we
# are setting the model to evaluation mode prior to conversion.
edge_model = ai_edge_torch.convert(enet_b1.eval(), sample_inputs_NCHW)
edge_model.export("enet_b1.tflite")

torch_output = enet_b1(*sample_inputs_NCHW)
edge_output = edge_model(*sample_inputs_NCHW)

if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")

nhwc_enet_b1 = ai_edge_torch.to_channel_last_io(enet_b1, args=[0])
sample_inputs_NHWC = (torch.randn(1, 224, 224, 3),)

nhwc_edge_model = ai_edge_torch.convert(nhwc_enet_b1.eval(), sample_inputs_NHWC)
nhwc_edge_model.export("nhwc_enet_b1.tflite")

torch_output = nhwc_enet_b1(*sample_inputs_NHWC)
edge_output = nhwc_edge_model(*sample_inputs_NHWC)

if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")

I don't believe you can add metadata with this API that way.

pkgoogle added the labels status:awaiting user response, status:more data needed, and type:support on Jan 6, 2025
@ZTFtrue
Author

ZTFtrue commented Jan 7, 2025

Thanks. The error is produced by this line:
torch_output = loaded_model(*sample_inputs_NCHW)

However, when I convert using your code, the TFLite model's prediction errors are much worse than those of the original model or of the ONNX model I converted.

Here is my model: Google Drive Link. The number of classes is 584.


To clarify my other question: can I add metadata using this tool?

On my Android device, I encounter this error:
Input tensor is of type kTfLiteFloat32: it requires NormalizationOptions metadata to preprocess the input image.

Currently, I can only add this metadata using other tools.

@pkgoogle
Contributor

pkgoogle commented Jan 7, 2025

Hi @ZTFtrue, I was able to run my script above without issues. I used commit b183411, installed with "pip install -e .".

Can you try with that version?

I need permission to access that Drive link. I do not currently see a way to add metadata using this library. Which tools are you using?

@ZTFtrue
Author

ZTFtrue commented Jan 8, 2025



The version I originally used was:

Name: ai-edge-torch  
Version: 0.2.1  

With commit b183411, the script failed with this error:

in <module>
    edge_output = edge_model(*sample_inputs_NHWC)
SystemError: <built-in method CreateWrapperFromBuffer of PyCapsule object at 0x79e6d69f75d0> returned a result with an exception set

When I removed that line, the TFLite model's errors still persisted.

I used tflite_support to add metadata.

Here is the test image: Test Image - Google Drive. It should be classified as class 1.

Here is the model: Model - Google Drive.

Thank you.

@ZTFtrue
Author

ZTFtrue commented Jan 8, 2025

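import numpy as np
from PIL import Image

# image_path, preprocess, and ort_session are defined elsewhere in the full script (not shown).
image = Image.open(image_path).convert("RGB")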
input_tensor = preprocess(image).unsqueeze(0)
onnx_input = input_tensor.numpy()
output = ort_session.run(None, {"input": onnx_input})[0]
predicted_class = np.argmax(output, axis=1)
print(f"Predicted class: {predicted_class[0]}")
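Because the classifier head ends in nn.LogSoftmax(dim=1), the model outputs log-probabilities: argmax on them picks the same class as on probabilities, and np.exp recovers the probabilities. When comparing the ONNX and TFLite predictions for the same test image, the top few classes and their probabilities say more than a single argmax. This is a minimal NumPy sketch; top_k and log_softmax are hypothetical helpers, and the fake 584-class output stands in for the real ONNX or TFLite result (for the NHWC TFLite model, the same preprocessed image would be fed as np.transpose(onnx_input, (0, 2, 3, 1))).

```python
import numpy as np


def top_k(log_probs: np.ndarray, k: int = 5):
    """Return (class_ids, probabilities) of the k most likely classes per row."""
    log_probs = np.asarray(log_probs, dtype=np.float64)
    ids = np.argsort(-log_probs, axis=1)[:, :k]  # most likely class first
    probs = np.exp(np.take_along_axis(log_probs, ids, axis=1))
    return ids, probs


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable log-softmax over the class axis (stand-in for the model head).
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


# Fake 584-class output standing in for `output` from ort_session.run(...).
rng = np.random.default_rng(1)
logits = rng.standard_normal((1, 584))
logits[0, 1] = 20.0  # make class 1 the clear winner
ids, probs = top_k(log_softmax(logits), k=3)
print(ids[0], probs[0].round(3))
```

Running top_k on both runtimes' outputs for the same image shows whether they disagree only on close calls or on confident predictions.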

[Screenshot: 2025-01-08_20-03]

@pkgoogle
Contributor

pkgoogle commented Jan 8, 2025

Hi @ZTFtrue, I modified the test to more closely match your original case, and I see some interesting results. In your original test, torch_output was produced from a different tensor (not a transposed version of the same tensor), so I fixed that. However, in this case there seems to be an issue with the original (NCHW) input format, while the NHWC case works. So at least something is wrong:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import ai_edge_torch
from torchvision import models
from torchvision.models import EfficientNet_B1_Weights


#please ignore file name
model_path = "efficientnet_b1_model.pth"
weights=EfficientNet_B1_Weights.IMAGENET1K_V1

efficientnet = models.efficientnet_b1(weights=weights)
# "cuda:0" if torch.cuda.is_available() else
device = torch.device("cpu")
out_features = 224
num_classes = 584
fc_inputs = efficientnet.classifier[1].in_features

model = efficientnet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
 
def load_model(model_path):
    model = efficientnet
    fc_inputs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Linear(fc_inputs, out_features), 
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(out_features, num_classes),
        nn.LogSoftmax(dim=1),
    )
    model.load_state_dict(torch.load(model_path, weights_only=False, map_location=torch.device('cpu')))
    model = model.to(device)
    return model


loaded_model = load_model(model_path)
sample_inputs_NCHW = (torch.randn(1, 3, 224, 224),)

torch_output = loaded_model(*sample_inputs_NCHW)
edge_model_NCHW = ai_edge_torch.convert(loaded_model.eval(), sample_inputs_NCHW)
edge_model_NCHW.export("efficientnet_b1_model_NCHW.tflite")
edge_output = edge_model_NCHW(*sample_inputs_NCHW)


if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")


sample_inputs_NHWC = (torch.randn(1, 224, 224, 3),)

# torch_output = loaded_model(*sample_inputs_NCHW)
# Transform the first input to NHWC.

loaded_model_NHWC = ai_edge_torch.to_channel_last_io(loaded_model, args=[0])
torch_output = loaded_model_NHWC(*sample_inputs_NHWC)

# Convert the transformed model with NHWC input(s).
edge_model_NHWC = ai_edge_torch.convert(loaded_model_NHWC.eval(), sample_inputs_NHWC)
edge_model_NHWC.export("efficientnet_b1_model_NHWC.tflite")

edge_output = edge_model_NHWC(*sample_inputs_NHWC)

if (np.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
)):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
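One possible explanation for the NCHW-only mismatch, offered as a hypothesis rather than a confirmed diagnosis: in the script above, torch_output = loaded_model(*sample_inputs_NCHW) runs before loaded_model.eval() is ever called, so the model is still in training mode. In that mode, Dropout(0.4) and EfficientNet's stochastic-depth layers act randomly, and every BatchNorm layer normalizes with the current batch's statistics and also updates its running statistics from that random-noise batch. The .eval() inside the first convert(...) call switches the module to eval mode in place, which would explain why the later NHWC comparison agrees. The original script uses the same ordering, so its exported model may carry slightly perturbed BatchNorm statistics. Calling loaded_model.eval() right after loading, before any forward pass, should rule this out. The NumPy sketch below mimics PyTorch's BatchNorm train/eval semantics on a toy layer; TinyBatchNorm and its "checkpoint" statistics are made up for illustration.

```python
import numpy as np


class TinyBatchNorm:
    """Per-channel BatchNorm without affine params, mimicking PyTorch train/eval modes."""

    def __init__(self, running_mean, running_var, momentum=0.1, eps=1e-5):
        self.running_mean = np.array(running_mean, dtype=np.float64)
        self.running_var = np.array(running_var, dtype=np.float64)
        self.momentum = momentum  # PyTorch's default BatchNorm momentum
        self.eps = eps
        self.training = True      # like nn.Module, starts in training mode

    def __call__(self, x):  # x has shape (N, C, H, W)
        if self.training:
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))  # biased variance, used to normalize
            n = x.size / x.shape[1]
            # Every training-mode forward updates the running stats (unbiased variance).
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var * n / (n - 1)
        else:
            mean, var = self.running_mean, self.running_var
        return (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + self.eps)


trained_mean, trained_var = [0.5, -0.2, 1.0], [2.0, 0.5, 1.5]  # pretend checkpoint stats
bn = TinyBatchNorm(trained_mean, trained_var)

x = np.random.default_rng(0).standard_normal((1, 3, 8, 8))
train_out = bn(x)   # like calling the model before .eval()
bn.training = False
eval_out = bn(x)    # what an eval-mode (exported) graph computes

print(np.allclose(train_out, eval_out, atol=1e-5, rtol=1e-5))
print(bn.running_mean - np.array(trained_mean))  # drift caused by one noisy forward
```

The first print reports whether the two modes agree, and the second shows how far a single training-mode call on random input moved the running means away from the checkpoint values.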

pkgoogle added the labels status:awaiting ai-edge-developer, type:precision/accuracy, and type:bug, and removed the labels status:awaiting user response, status:more data needed, type:support, and type:precision/accuracy on Jan 8, 2025

2 participants