Skip to content

Commit

Permalink
refactor and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
m-romanenko committed Jul 7, 2020
1 parent d1ae5de commit 164e77e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 9 deletions.
19 changes: 19 additions & 0 deletions tests/tools/test_convert_tflite.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,21 @@
from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4
from tf2_yolov4.tools.convert_tflite import create_tflite_model


def test_import_convert_tflite_script_does_not_fail():
from tf2_yolov4.tools.convert_tflite import convert_tflite


def test_create_tflite_model_returns_correct_type():
model = YOLOv4(
input_shape=(640, 960, 3),
anchors=YOLOV4_ANCHORS,
num_classes=80,
training=False,
yolo_max_boxes=100,
yolo_iou_threshold=0.4,
yolo_score_threshold=0.1,
)
tflite_model = create_tflite_model(model)
assert isinstance(tflite_model, bytes)
30 changes: 21 additions & 9 deletions tf2_yolov4/tools/convert_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,26 @@
TFLITE_MODEL_PATH = "yolov4.tflite"


def create_tflite_model(model):
"""Converts a YOLOv4 model to a TfLite model
Args:
model (tensorflow.python.keras.engine.training.Model): YOLOv4 model
Returns:
(bytes): a binary TfLite model
"""
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]

converter.allow_custom_ops = True
return converter.convert()


@click.command()
@click.option("--num_classes", default=80, help="Number of classes")
@click.option(
Expand All @@ -38,15 +58,7 @@ def convert_tflite(num_classes, weights_path):
if weights_path:
model.load_weights(weights_path)

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]

converter.allow_custom_ops = True
tflite_model = converter.convert()
tflite_model = create_tflite_model(model)

with tf.io.gfile.GFile(TFLITE_MODEL_PATH, "wb") as file:
file.write(tflite_model)
Expand Down

0 comments on commit 164e77e

Please sign in to comment.