-
Notifications
You must be signed in to change notification settings - Fork 84
/
train_cls.py
108 lines (84 loc) · 3.29 KB
/
train_cls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import json
import pandas as pd
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import EarlyStopping, ModelCheckpoint
def generate(batch, shape, ptrain, pval):
    """Build train/validation image generators and count available samples.

    # Arguments
        batch: Integer, batch size.
        shape: Tuple of two integers, target image size (height, width).
        ptrain: String, path to the training image directory.
        pval: String, path to the validation image directory.

    # Returns
        train_generator: training set generator (with augmentation).
        validation_generator: validation set generator (rescale only).
        count1: Integer, number of files found under ptrain.
        count2: Integer, number of files found under pval.
    """
    # Augmentation is applied only to the training data; validation data
    # is just rescaled so evaluation reflects the raw images.
    datagen1 = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        rotation_range=90,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)
    datagen2 = ImageDataGenerator(rescale=1. / 255)

    train_generator = datagen1.flow_from_directory(
        ptrain,
        target_size=shape,
        batch_size=batch,
        class_mode='categorical')
    validation_generator = datagen2.flow_from_directory(
        pval,
        target_size=shape,
        batch_size=batch,
        class_mode='categorical')

    # Count files per directory tree in one pass instead of a nested
    # per-file increment loop.
    count1 = sum(len(files) for _, _, files in os.walk(ptrain))
    count2 = sum(len(files) for _, _, files in os.walk(pval))

    return train_generator, validation_generator, count1, count2
def train():
    """Train a MobileNetV3 classifier driven by config/config.json.

    Reads hyper-parameters and paths from the config file, builds the
    requested MobileNetV3 variant, optionally loads pretrained weights,
    trains with early stopping and best-weights checkpointing, and writes
    the per-epoch history to <save_dir>/hist.csv.

    # Raises
        ValueError: if cfg['model'] is neither 'large' nor 'small'.
    """
    with open('config/config.json', 'r') as f:
        cfg = json.load(f)

    save_dir = cfg['save_dir']
    shape = (int(cfg['height']), int(cfg['width']), 3)
    n_class = int(cfg['class_number'])
    batch = int(cfg['batch'])

    # makedirs handles nested paths and exist_ok avoids the
    # check-then-create race of the original exists()/mkdir pair.
    os.makedirs(save_dir, exist_ok=True)

    if cfg['model'] == 'large':
        from model.mobilenet_v3_large import MobileNetV3_Large
        model = MobileNetV3_Large(shape, n_class).build()
    elif cfg['model'] == 'small':
        from model.mobilenet_v3_small import MobileNetV3_Small
        model = MobileNetV3_Small(shape, n_class).build()
    else:
        # The original fell through with `model` unbound and crashed later
        # with an opaque NameError; fail fast with a clear message instead.
        raise ValueError(
            "cfg['model'] must be 'large' or 'small', got {!r}".format(cfg['model']))

    pre_weights = cfg['weights']
    if pre_weights and os.path.exists(pre_weights):
        model.load_weights(pre_weights, by_name=True)

    opt = Adam(lr=float(cfg['learning_rate']))
    # NOTE(review): with metrics=['accuracy'], newer Keras reports the metric
    # as 'val_accuracy', not 'val_acc' — confirm against the installed version.
    earlystop = EarlyStopping(monitor='val_acc', patience=5, verbose=0, mode='auto')
    checkpoint = ModelCheckpoint(
        filepath=os.path.join(save_dir, '{}_weights.h5'.format(cfg['model'])),
        monitor='val_acc', save_best_only=True, save_weights_only=True)

    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

    train_generator, validation_generator, count1, count2 = generate(
        batch, shape[:2], cfg['train_dir'], cfg['eval_dir'])

    hist = model.fit_generator(
        train_generator,
        validation_data=validation_generator,
        steps_per_epoch=count1 // batch,
        validation_steps=count2 // batch,
        epochs=cfg['epochs'],
        callbacks=[earlystop, checkpoint])

    # Persist per-epoch metrics for later inspection.
    df = pd.DataFrame.from_dict(hist.history)
    df.to_csv(os.path.join(save_dir, 'hist.csv'), encoding='utf-8', index=False)
# Run training only when executed as a script, not on import.
if __name__ == '__main__':
    train()