Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to Keras Mish implementation for TFLite compatibility #60

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions convert_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import click
import tensorflow as tf

from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4

HEIGHT, WIDTH = (640, 960)

TFLITE_MODEL_PATH = "yolov4.tflite"


@click.command()
@click.option("--num_classes", default=80, help="Number of classes")
@click.option(
    "--weights_path", default=None, help="Path to .h5 file with model weights"
)
@click.option(
    "--output_path",
    default=TFLITE_MODEL_PATH,
    help="Destination path for the converted .tflite model",
)
def main(num_classes, weights_path, output_path=TFLITE_MODEL_PATH):
    """Convert a YOLOv4 Keras model to a TFLite flatbuffer.

    Builds an inference-mode YOLOv4 at a fixed input resolution, optionally
    loads pretrained weights, runs the TFLite converter and writes the
    serialized model to ``output_path``.

    Args:
        num_classes: Number of object classes the model predicts.
        weights_path: Optional path to a ``.h5`` file with model weights.
        output_path: Where to write the converted ``.tflite`` file
            (defaults to ``TFLITE_MODEL_PATH`` for backward compatibility).
    """
    model = YOLOv4(
        input_shape=(HEIGHT, WIDTH, 3),
        anchors=YOLOV4_ANCHORS,
        num_classes=num_classes,
        training=False,
        yolo_max_boxes=100,
        yolo_iou_threshold=0.4,
        yolo_score_threshold=0.1,
    )

    if weights_path:
        model.load_weights(weights_path)

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # SELECT_TF_OPS lets the converter fall back to full TensorFlow kernels
    # for ops that have no TFLite builtin equivalent.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    converter.allow_custom_ops = True

    tflite_model = converter.convert()

    # tf.io.gfile supports local paths as well as GCS-style URIs.
    with tf.io.gfile.GFile(output_path, "wb") as f:
        f.write(tflite_model)


if __name__ == "__main__":
    main()
5 changes: 5 additions & 0 deletions tf2_yolov4/activations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Activations layers"""

from .mish import Mish
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mets un path absolu plutot que relatif


__all__ = ["Mish"]
27 changes: 27 additions & 0 deletions tf2_yolov4/activations/mish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Tensorflow-Keras Implementation of Mish
Source: https://github.com/digantamisra98/Mish/blob/master/Mish/TFKeras/mish.py
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer


class Mish(Layer):
    """Mish activation function as a Keras layer.

    .. math::
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))

    Shape:
        - Input: Arbitrary. Use the keyword argument `input_shape`
          (tuple of integers, does not include the samples axis)
          when using this layer as the first layer in a model.
        - Output: Same shape as the input.

    Examples:
        >>> X = Mish()(X_input)
    """

    # NOTE: no __init__ override — the base Layer constructor already
    # accepts **kwargs, so a pass-through __init__ would be redundant.

    def call(self, inputs, **kwargs):
        """Apply the Mish activation element-wise to `inputs`."""
        return inputs * tf.math.tanh(tf.math.softplus(inputs))
5 changes: 3 additions & 2 deletions tf2_yolov4/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Common layer architecture such as Conv->BN->Mish or Conv->BN->LeakyReLU"""
import tensorflow as tf
import tensorflow_addons as tfa

from tf2_yolov4.activations import Mish


def conv_bn(
Expand Down Expand Up @@ -41,6 +42,6 @@ def conv_bn(
if activation == "leaky_relu":
x = tf.keras.layers.LeakyReLU(alpha=0.1)(x)
elif activation == "mish":
x = tfa.activations.mish(x)
x = Mish()(x)

return x