-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
103 lines (82 loc) · 3.3 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import platform

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# Input images are resized to 256x256; 32 images per batch.
image_size = (256, 256)
batch_size = 32

# On aarch64 boards (e.g. Jetson-class devices, where GPU and CPU share RAM)
# cap TensorFlow's GPU memory at 1 GiB instead of letting it grab everything.
# NOTE: memory growth and a fixed virtual-device memory limit are mutually
# exclusive in TensorFlow -- configuring both on the same physical GPU raises
# "Cannot set memory growth on device when virtual devices configured" -- so
# only the explicit memory limit is applied here.
device = tf.config.list_physical_devices('GPU')
if platform.processor() == "aarch64" and device:
    tf.config.experimental.set_virtual_device_configuration(
        device[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)],
    )
# Load labeled images from class-named subdirectories under data/train_val/,
# holding out 20% for validation.  Using the same seed for both calls keeps
# the training and validation subsets disjoint and reproducible.
# (tf.keras.preprocessing.image_dataset_from_directory is the deprecated
# alias; tf.keras.utils.image_dataset_from_directory is the current name.)
train_ds = tf.keras.utils.image_dataset_from_directory(
    "data/train_val/",
    validation_split=0.2,
    # color_mode="grayscale",
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
    label_mode="categorical",  # one-hot labels, matching categorical_crossentropy
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    "data/train_val/",
    validation_split=0.2,
    # color_mode="grayscale",
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
    label_mode="categorical",
)

# Overlap data loading with training; AUTOTUNE lets tf.data choose the
# prefetch buffer size dynamically instead of the hard-coded 32.
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

num_classes = 5  # we have 5 categories of images
def make_model(input_shape, num_classes=num_classes):
    """Build a small Xception-style CNN for single-label image classification.

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        num_classes: Number of output categories.  Defaults to the
            module-level ``num_classes`` so existing callers are unchanged;
            passing it explicitly removes the hidden global dependence.

    Returns:
        An uncompiled ``keras.Model`` mapping images to softmax probabilities.
    """
    inputs = keras.Input(shape=input_shape)
    # Pixels arrive as 0-255; scale to [0, 1] inside the model so callers
    # never have to remember to normalize.
    x = layers.Rescaling(1.0 / 255)(inputs)

    # Entry block: two plain convolutions.
    x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv2D(64, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # set aside for the residual shortcut

    # Three downsampling blocks with residual shortcuts (Xception-like).
    # This depth is a good balance between accuracy and generalization.
    for size in [128, 256, 512]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
        # Project the shortcut to the pooled shape, then add it back in.
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])
        previous_block_activation = x

    # Exit block: one final separable convolution, then global pooling.
    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.GlobalAveragePooling2D()(x)

    x = layers.Dropout(0.5)(x)  # regularization against overfitting
    # Softmax over the classes: standard for single-label multi-class output.
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
# Build, compile, and train the classifier.
model = make_model(input_shape=(*image_size, 3))

# Training usually stops earlier via EarlyStopping; this is just an upper cap.
epochs = 50

# Ensure the checkpoint directory exists before training starts.
# NOTE(review): some Keras versions do not create it and fail on the first
# epoch's save -- confirm against the installed version.
os.makedirs("checkpoints", exist_ok=True)
callbacks = [
    # Stop automatically once val_loss has not improved for 3 consecutive
    # epochs, i.e. as soon as overfitting is detected.
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
    # Save the model after every epoch so the best iteration can be
    # cherry-picked afterwards.
    keras.callbacks.ModelCheckpoint("checkpoints/save_at_{epoch}.h5"),
]
model.compile(
    optimizer=keras.optimizers.Adam(1e-3),
    loss="categorical_crossentropy",  # one-hot labels -> categorical crossentropy
    metrics=["accuracy"],
)
model.fit(
    train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
)