Skip to content

Commit

Permalink
fix: remove tflite_runtime requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
LynnL4 committed Feb 5, 2025
1 parent be71bae commit ea0ded7
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions sscma/deploy/backend/tflite_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
import torch
import torchvision.transforms as transforms

from .base_infer import BaseInfer

from sscma.utils import lazy_import

import tensorflow as tf

from .base_infer import BaseInfer
class TFliteInfer(BaseInfer):

@lazy_import("tflite_runtime", install_only=True)
def __init__(self, weights="sscma.tflite", device=torch.device("cpu")):
super().__init__(weights=weights, device=device)
self.interpreter = None
Expand Down Expand Up @@ -54,9 +51,8 @@ def infer(self, input_data):
return results

def load_weights(self):
from tflite_runtime.interpreter import Interpreter

self.interpreter = Interpreter(model_path=self.weights) # load TFLite model
self.interpreter = tf.lite.Interpreter(model_path=self.weights) # load TFLite model

self.interpreter.allocate_tensors() # allocate
self.input_details = self.interpreter.get_input_details() # inputs
Expand Down

0 comments on commit ea0ded7

Please sign in to comment.