# gan_model.py
# -*- coding: utf-8 -*-
"""GAN-model
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1tb7cRDKNj0au2L225WDAHRPK_wgvPZP_
"""
import numpy as np
import matplotlib.pyplot as plt
from os import listdir
from numpy import vstack
from numpy import savez_compressed
from PIL import Image
from numpy import load
from numpy.random import randint
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras.layers import Dropout
from keras.layers import BatchNormalization
# Load all images
class ImageLoader():
    """Load side-by-side satellite/map images and cache them as a ``.npz`` archive.

    Every file in ``path`` is loaded at ``img_size`` (height, width) and cut
    down the middle: the left half is stored as the source (satellite) image
    and the right half as the target (map) image.  Constructing the loader
    reads the whole directory and immediately writes the archive.

    Args:
        path: Directory containing the combined images.
        img_size: ``(height, width)`` passed to ``load_img`` as ``target_size``.
        filename: Destination of the compressed archive.  Defaults to the
            Google Drive location this script has always used.
    """

    def __init__(self, path, img_size=(256, 512),
                 filename='/content/drive/MyDrive/Google_maps.npz'):
        self.path = path
        self.img_size = img_size
        # Requested archive location; replaced below by the path actually written.
        self.filename = filename
        # Raw (unscaled) pixel arrays for the two halves of every image.
        [self.source_img, self.target_img] = self.load_images()
        # Persist the arrays and remember where they ended up.
        self.filename = self.save_images()

    def load_images(self):
        """Read every image in ``self.path`` and split it into two halves.

        Returns:
            ``[source_img, target_img]``: two numpy arrays holding the left
            and right halves of each image, in sorted file-name order.
            Pixel values are left as returned by ``img_to_array``; scaling
            happens later in ``load_data``.
        """
        from os.path import join

        source_img = []
        target_img = []
        # Split at the horizontal midpoint so a non-default img_size still works.
        half_width = self.img_size[1] // 2
        # listdir() order is arbitrary; sort for a reproducible dataset order.
        for name in sorted(listdir(self.path)):
            # join() copes with a directory given with or without a trailing slash.
            combined = load_img(join(self.path, name), target_size=self.img_size)
            combined = img_to_array(combined)
            source_img.append(combined[:, :half_width])
            target_img.append(combined[:, half_width:])
        return [np.array(source_img), np.array(target_img)]

    def save_images(self):
        """Write both arrays to ``self.filename`` as a compressed archive.

        Returns:
            The path of the written file.  ``savez_compressed`` appends
            ``.npz`` when the name lacks it, so the returned path includes
            that suffix.
        """
        filename = self.filename
        if not str(filename).endswith('.npz'):
            filename = str(filename) + '.npz'
        savez_compressed(filename, self.source_img, self.target_img)
        return filename

    def load_data(self):
        """Reload the saved archive and scale pixels from [0, 255] to [-1, 1].

        Returns:
            ``[source_img, target_img]`` as scaled arrays.
        """
        # The context manager closes the underlying file handle once both
        # arrays have been read into memory.
        with load(self.filename) as ds:
            src_img, tar_img = ds['arr_0'], ds['arr_1']
        source_img = (src_img - 127.5) / 127.5
        target_img = (tar_img - 127.5) / 127.5
        return [source_img, target_img]

    def plot_images(self):
        """Show up to four source images (top row) above their targets (bottom row)."""
        # Never index past the end of a small dataset.
        num_samples = min(4, len(self.source_img))
        # Top row: source images occupy subplot slots 1..num_samples.
        for i in range(num_samples):
            plt.subplot(2, num_samples, 1 + i)
            plt.axis("off")
            plt.imshow(self.source_img[i].astype('uint8'))
        # Bottom row: target images occupy the following num_samples slots.
        for i in range(num_samples):
            plt.subplot(2, num_samples, 1 + num_samples + i)
            plt.axis("off")
            plt.imshow(self.target_img[i].astype('uint8'))
        plt.show()
# To Develop and Train a Pix2Pix Model (CGANS)
class CGAN():
    """Pix2Pix-style conditional GAN translating source images to target images.

    Construction loads the dataset through ``ImageLoader``, plots a few
    samples, and builds three Keras models: the patch discriminator, the
    U-Net generator, and the combined model used to update the generator.

    Args:
        path: Directory with the training images.  Defaults to the Google
            Drive location this script has always used.
    """

    def __init__(self, path='/content/drive/MyDrive/maps/maps/train/'):
        # Image shape
        self.IMAGE_WIDTH = 256
        self.IMAGE_HEIGHT = 256
        self.IMAGE_CHANNELS = 3
        # Load dataset: image_loader is [source_images, target_images],
        # already scaled by ImageLoader.load_data().
        self.path = path
        self.image = ImageLoader(self.path)
        self.image_loader = self.image.load_data()
        print("Load images: ", self.image_loader[0].shape, self.image_loader[1].shape)
        # Per-image shape (everything after the sample axis) drives the model inputs.
        self.image_shape = self.image_loader[0].shape[1:]
        # plot_images() only draws; its result is stored under the historical
        # attribute name so existing readers of the attribute keep working.
        self.drwingImages = self.image.plot_images()
        # The discriminator must exist (and be compiled) before the combined
        # model freezes it in build_GAN().
        self.discriminator_model = self.build_discriminator()
        self.generator_model = self.build_generator()
        self.gan_model = self.build_GAN()

    def build_generator(self):
        """Build the U-Net generator: an encoder-decoder with skip connections.

        Returns:
            An uncompiled Keras ``Model`` mapping a source image to a
            generated image with ``tanh`` output.
        """
        def define_encoder(input_layer, num_filters, batchnorm=True):
            # Downsampling block: strided conv -> optional batch norm -> LeakyReLU.
            encoder = Conv2D(num_filters, (4, 4),
                             strides=(2, 2),
                             padding="same",
                             kernel_initializer=RandomNormal(stddev=0.02))(input_layer)
            if batchnorm:
                encoder = BatchNormalization()(encoder, training=True)
            encoder = LeakyReLU(alpha=0.2)(encoder)
            return encoder

        def define_decoder(input_layer, skip_input, num_filters, dropout=True):
            # Upsampling block: transposed conv -> batch norm -> optional dropout
            # -> concatenate with the matching encoder output -> ReLU.
            decoder = Conv2DTranspose(num_filters, (4, 4),
                                      strides=(2, 2),
                                      padding="same",
                                      kernel_initializer=RandomNormal(stddev=0.02))(input_layer)
            decoder = BatchNormalization()(decoder, training=True)
            if dropout:
                decoder = Dropout(0.5)(decoder, training=True)
            decoder = Concatenate()([decoder, skip_input])
            decoder = Activation("relu")(decoder)
            return decoder

        # Image input
        img_in = Input(shape=self.image_shape)
        # Encoder (downsampling); the first block skips batch normalization.
        # Each output is kept so the decoder can use it as a skip connection.
        skips = [define_encoder(img_in, 64, batchnorm=False)]
        for num_filters in (128, 256, 512, 512, 512, 512):
            skips.append(define_encoder(skips[-1], num_filters))
        # Bottleneck: one more strided convolution, ReLU, no batch norm.
        bottleneck = Conv2D(512, (4, 4), strides=(2, 2),
                            padding="same",
                            kernel_initializer=RandomNormal(stddev=0.02))(skips[-1])
        bottleneck = Activation("relu")(bottleneck)
        # Decoder (upsampling): (filters, use_dropout) per block, paired with
        # the encoder outputs from deepest to shallowest.
        decoder_plan = [(512, True), (512, True), (512, True), (512, False),
                        (256, False), (128, False), (64, False)]
        upsampled = bottleneck
        for (num_filters, use_dropout), skip in zip(decoder_plan, reversed(skips)):
            upsampled = define_decoder(upsampled, skip, num_filters, dropout=use_dropout)
        # Output layer: back to 3 channels, tanh to match the [-1, 1] data range.
        lr = Conv2DTranspose(3, (4, 4), strides=(2, 2),
                             padding="same",
                             kernel_initializer=RandomNormal(stddev=0.02))(upsampled)
        img_out = Activation("tanh")(lr)
        model = Model(img_in, img_out)
        model.summary()
        return model

    def build_discriminator(self):
        """Build and compile the patch discriminator.

        The model takes ``[source_image, target_image]``, concatenates them
        along the channel axis, and outputs a sigmoid map of per-patch scores.

        Returns:
            A compiled Keras ``Model``.
        """
        source_img_in = Input(shape=self.image_shape)
        target_img_in = Input(shape=self.image_shape)
        merged = Concatenate()([source_img_in, target_img_in])
        # C64: no batch normalization on the first block.
        model = Conv2D(64, (4, 4), strides=(2, 2), padding="same",
                       kernel_initializer=RandomNormal(stddev=0.02))(merged)
        model = LeakyReLU(alpha=0.2)(model)
        # C128, C256, C512: strided conv -> batch norm -> LeakyReLU.
        for num_filters in (128, 256, 512):
            model = Conv2D(num_filters, (4, 4), strides=(2, 2), padding="same",
                           kernel_initializer=RandomNormal(stddev=0.02))(model)
            model = BatchNormalization()(model)
            model = LeakyReLU(alpha=0.2)(model)
        # Last feature layer: same pattern but without striding.
        model = Conv2D(256, (4, 4), padding="same",
                       kernel_initializer=RandomNormal(stddev=0.02))(model)
        model = BatchNormalization()(model)
        model = LeakyReLU(alpha=0.2)(model)
        # Patch output
        model = Conv2D(1, (4, 4), padding="same",
                       kernel_initializer=RandomNormal(stddev=0.02))(model)
        patch_out = Activation("sigmoid")(model)
        model = Model([source_img_in, target_img_in], patch_out)
        # Learning rate is passed positionally so the call does not depend on
        # the keyword spelling (lr / learning_rate); beta_1 = 0.5 matches the
        # optimizer of the combined model in build_GAN().
        opt = Adam(0.0002, beta_1=0.5)
        model.compile(loss="binary_crossentropy", optimizer=opt, loss_weights=[0.5])
        return model

    def build_GAN(self):
        """Build and compile the combined generator + discriminator model.

        The discriminator is marked non-trainable here, so training this
        model updates the generator only.

        Returns:
            A compiled Keras ``Model`` mapping a source image to
            ``[discriminator_output, generated_image]``.
        """
        self.discriminator_model.trainable = False
        source_input = Input(self.image_shape)
        generator_output = self.generator_model(source_input)
        discriminator_output = self.discriminator_model([source_input, generator_output])
        gan_model = Model(source_input, [discriminator_output, generator_output])
        # Adversarial loss plus L1 (mae) reconstruction loss, weighted 1 : 100.
        gan_model.compile(loss=['binary_crossentropy', 'mae'],
                          optimizer=Adam(0.0002, 0.5),
                          loss_weights=[1, 100])
        return gan_model

    def _draw_real_pairs(self, count):
        """Return ``count`` randomly chosen (source, target) pairs from the dataset."""
        source_img, target_img = self.image_loader
        picks = randint(0, source_img.shape[0], count)
        return source_img[picks], target_img[picks]

    def train_model(self, num_sample, epochs, batch_size):
        """Train the discriminator and generator alternately.

        Args:
            num_sample: Number of image pairs drawn for each preview plot
                written by ``sample_image``.
            epochs: Number of passes over the dataset.
            batch_size: Number of image pairs drawn for every training step;
                also determines the number of steps per epoch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        # Side length of the discriminator's patch output; labels must match it.
        patch_size = self.discriminator_model.output_shape[1]
        self.source_img, self.target_img = self.image_loader
        # Label maps are identical for every step, so build them once.
        real = np.ones((batch_size, patch_size, patch_size, 1))
        fake = np.zeros((batch_size, patch_size, patch_size, 1))
        # Steps per epoch and total number of training steps.
        batch_train = len(self.source_img) // batch_size
        num_steps = batch_train * epochs
        for e in range(num_steps):
            # Draw a fresh batch and regenerate the fakes with the current
            # generator weights on every step; reusing one fixed batch would
            # train on the same few images and on stale fakes.
            self.X_real_A, self.X_real_B = self._draw_real_pairs(batch_size)
            self.X_fake_B = self.generator_model.predict(self.X_real_A)
            # ---------------------
            #  Train Discriminator
            # ---------------------
            disc_loss_real = self.discriminator_model.train_on_batch(
                [self.X_real_A, self.X_real_B], real)
            disc_loss_fake = self.discriminator_model.train_on_batch(
                [self.X_real_A, self.X_fake_B], fake)
            # ---------------------
            #  Train Generator
            # ---------------------
            g_loss = self.gan_model.train_on_batch(self.X_real_A, [real, self.X_real_B])
            print("Epochs: {}, D_loss_real: {}, D_loss_fake: {}, G_loss: {}".format(e+1,
                                                                                   disc_loss_real,
                                                                                   disc_loss_fake,
                                                                                   g_loss))
            # Every 10 epochs' worth of steps, write a preview plot and a checkpoint.
            if (e+1) % (batch_train*10) == 0:
                self.sample_image(e, num_sample)

    def sample_image(self, step, num_sample):
        """Save a preview plot and the generator weights for the given step.

        Draws ``num_sample`` random pairs, runs the generator on the sources,
        and writes ``plot_<step+1>.png`` (rows: source, generated, target)
        and ``model_<step+1>.h5`` to the current working directory.

        Args:
            step: Zero-based training step; file names use ``step + 1``.
            num_sample: Number of image columns in the plot.
        """
        real_A, real_B = self._draw_real_pairs(num_sample)
        fake_B = self.generator_model.predict(real_A)
        # Rescale copies from [-1, 1] to [0, 1] for display; the training
        # arrays themselves are left untouched.
        rows = [(images + 1) / 2.0 for images in (real_A, fake_B, real_B)]
        for row, images in enumerate(rows):
            for i in range(num_sample):
                plt.subplot(3, num_sample, 1 + row * num_sample + i)
                plt.axis('off')
                plt.imshow(images[i])
        # save plot to file
        file1 = 'plot_%06d.png' % (step+1)
        plt.savefig(file1)
        plt.close()
        # save the generator model
        file2 = 'model_%06d.h5' % (step+1)
        self.generator_model.save(file2)
        print('>Saved: %s and %s' % (file1, file2))
if __name__ == "__main__":
    # Script entry point: constructing CGAN loads the dataset and builds the
    # models; training then runs with num_sample=3, epochs=25, batch_size=1.
    pix2pix = CGAN()
    pix2pix.train_model(batch_size=1, epochs=25, num_sample=3)