-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_with_tflite_model_maker.py
67 lines (54 loc) · 2.06 KB
/
train_with_tflite_model_maker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import os
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
from pprint import pprint #Pretty printing for output
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
LABELMAP_FILENAME = 'images/all/labelmap.txt'
# Read the class names, one per line. The ``with`` block closes the file
# deterministically (the original handle was never closed), and an explicit
# encoding avoids depending on the platform's locale default. Surrounding
# whitespace is stripped and blank lines are skipped so that a stray empty
# line or trailing space does not become a bogus / mismatching class name.
with open(LABELMAP_FILENAME, "r", encoding="utf-8") as text_file:
    classes = [line.strip() for line in text_file if line.strip()]
if not classes:
    # Fail early with a clear message instead of an obscure loader error later.
    raise ValueError(f"No class names found in label map {LABELMAP_FILENAME!r}")
print("\nClasses to be used:")
print(classes)
def _load_split(split_dir):
    """Return a Pascal VOC DataLoader for one dataset split.

    The same directory is passed as both the images directory and the
    annotations directory, i.e. each split keeps its images and XML
    annotation files side by side. ``classes`` is the module-level list of
    class names read from the label map.
    """
    return object_detector.DataLoader.from_pascal_voc(split_dir, split_dir, classes)


# Loaded in the same order as before: train, validation, test.
train_data = _load_split('images/train')
validation_data = _load_split('images/validation')
test_data = _load_split('images/test')
print("\nUsing an EfficientDet-Lite0 model for training with 320x320 image resolution.")
spec = object_detector.EfficientDetLite0Spec()

# Hyper-parameters for this run. ``train_whole_model=True`` trains all
# layers rather than only the ones not frozen by the spec's freeze pattern.
training_options = dict(
    epochs=50,
    batch_size=4,
    train_whole_model=True,
)

print("\nTraining starts......")
model = object_detector.create(
    train_data=train_data,
    model_spec=spec,
    validation_data=validation_data,
    **training_options,
)
# Output locations. The TFLite file is evaluated from the same directory it
# is exported to, so changing EXPORT_DIR keeps export and evaluation in sync.
EXPORT_DIR = '.'
TFLITE_FILENAME = 'smrc_model.tflite'
LABELS_FILENAME = 'labels.txt'

# Evaluate the in-memory (Keras) model on the held-out test split and print
# the returned metrics.
print("\nEvaluating created model")
print("Evaluation result:")
result = model.evaluate(test_data)
pprint(result, width=10)

# Write the TFLite model and its label file into EXPORT_DIR.
print("\nExport model to tflite-format")
model.export(export_dir=EXPORT_DIR, tflite_filename=TFLITE_FILENAME, label_filename=LABELS_FILENAME,
             export_format=[ExportFormat.TFLITE, ExportFormat.LABEL])

# Evaluate the exported TFLite file on the same test split. The path is
# derived from EXPORT_DIR rather than the bare filename, which previously only
# worked because the export directory happened to be the working directory.
print("\n\nEvaluating tflite-model")
print("Evaluation result:")
tflite_path = os.path.join(EXPORT_DIR, TFLITE_FILENAME)
result = model.evaluate_tflite(tflite_path, test_data)
pprint(result, width=10)