"""
Andrew Player
September 2022
Script for training a network for MSTAR SAR Target Detection.
"""
import os
from math import ceil
import numpy as np
from mstar_io import load_sample
from model import create_net
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf

# Let TensorFlow allocate GPU memory as needed rather than
# reserving all of it up front.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))


class DataGenerator(Sequence):
    """
    Dataset generator for sequentially passing files from storage into the model.
    """

    def __init__(self, file_list, data_path, tile_size):
        self.file_list = file_list
        self.tile_size = tile_size
        self.data_path = data_path
        self.on_epoch_end()

    def __len__(self):
        """
        The number of files in the dataset.
        """
        return len(self.file_list)

    def __getitem__(self, index):
        """
        Returns the set of inputs and their corresponding truths.
        """
        # Generate the indexes of the batch (a single file).
        indexes = self.indexes[index:(index + 1)]

        file_list_temp = [self.file_list[k] for k in indexes]

        # Set of X_train and y_train.
        X, y = self.__data_generation(file_list_temp)

        return X, y

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.file_list))

    def __data_generation(self, file_list_temp):
        """
        Returns individual pairs of inputs and their corresponding truths.
        """
        # Generate the data; file_list_temp holds a single file name here.
        for ID in file_list_temp:
            magnitude, label = load_sample(os.path.join(self.data_path, ID))
            x = magnitude.reshape((1, self.tile_size, self.tile_size, 1))
            y = label.reshape((1, 7))

        return x, y


def train_model(
    model_name: str,
    dataset_dir: str,
    tile_size: int,
    num_epochs: int,
    batch_size: int
):
    train_path = os.path.join(dataset_dir, 'train')
    val_path = os.path.join(dataset_dir, 'validation')

    all_training_files = os.listdir(train_path)
    all_validation_files = os.listdir(val_path)

    training_partition = list(all_training_files)
    validation_partition = list(all_validation_files)

    training_generator = DataGenerator(training_partition, train_path, tile_size)
    validation_generator = DataGenerator(validation_partition, val_path, tile_size)

    model = create_net(
        model_name  = model_name,
        tile_size   = tile_size,
        label_count = 7
    )

    model.summary()

    early_stopping = EarlyStopping(
        monitor  = 'loss',
        patience = 2,
        verbose  = 1
    )

    checkpoint = ModelCheckpoint(
        filepath       = 'models/checkpoints/' + model_name,
        monitor        = 'val_loss',
        mode           = 'min',
        verbose        = 1,
        save_best_only = True
    )

    training_samples = len(training_partition)
    validation_samples = len(validation_partition)

    training_steps = ceil(training_samples / batch_size)
    validation_steps = ceil(validation_samples / batch_size)

    # batch_size is not passed to model.fit: Keras does not accept it for
    # Sequence inputs, since the generator defines its own batches (one
    # sample each here); batch_size only scales the per-epoch step counts.
    history = model.fit(
        training_generator,
        epochs           = num_epochs,
        validation_data  = validation_generator,
        steps_per_epoch  = training_steps,
        validation_steps = validation_steps,
        callbacks        = [checkpoint, early_stopping]
    )

    model.save("models/" + model_name)

    return history
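

# Example invocation (a sketch: 'mstar_net' and 'data/mstar' are hypothetical
# values; dataset_dir must contain 'train' and 'validation' subdirectories of
# samples readable by load_sample):
if __name__ == '__main__':
    train_model(
        model_name  = 'mstar_net',
        dataset_dir = 'data/mstar',
        tile_size   = 128,
        num_epochs  = 10,
        batch_size  = 1
    )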