# Training loop:
# 1. Set up the environment and data
# 2. Build the generator (g) and discriminator (d) networks
# 3. Manage the training process
# 4. Run periodic evaluations on specified metrics
# 5. Produce sample images over the course of training
# It supports training over data in TFRecords as produced by prepare_data.py
# Labels can optionally be provided, although they are not required
# If provided, images will be generated conditioned on the chosen labels
import glob
import numpy as np
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib.autosummary import autosummary
import pretrained_networks
from training import dataset as data
from training import misc
from training import visualize
from metrics import metric_base
# Data processing
# ----------------------------------------------------------------------------
# Just-in-time processing of the input images before feeding them to the networks
def process_reals(x, drange_data, drange_net, mirror_augment):
    with tf.name_scope("DynamicRange"):
        x = tf.cast(x, tf.float32)
        x.set_shape([None, 3, None, None])
        x = misc.adjust_dynamic_range(x, drange_data, drange_net)
    if mirror_augment:
        with tf.name_scope("MirrorAugment"):
            x = tf.where(tf.random_uniform([tf.shape(x)[0]]) < 0.5, x, tf.reverse(x, [3]))
    return x

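# read_data stages training data through a persistent variable: the returned read op
# slices the previously stored batch (var[:batch_gpu_in]), while the returned fetch op
# writes the next batch into the variable, so each training round consumes the batch
# loaded by the preceding run of the fetch op.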
def read_data(data, name, shape, batch_gpu_in):
    var = tf.Variable(name = name, trainable = False, initial_value = tf.zeros(shape))
    data_write = tf.concat([data, var[batch_gpu_in:]], axis = 0)
    data_fetch_op = tf.assign(var, data_write)
    data_read = var[:batch_gpu_in]
    return data_read, data_fetch_op

# Scheduling and optimization
# ----------------------------------------------------------------------------
# Evaluate time-varying training parameters
def training_schedule(
    sched_args,
    cur_nimg,               # The training length, measured in number of generated images
    dataset,                # The dataset object for accessing the data
    lrate_rampup_kimg = 0,  # Duration of learning rate ramp-up
    tick_kimg = 8):         # Default interval of progress snapshots

    # Initialize scheduling dictionary
    s = dnnlib.EasyDict()

    # Set parameters
    s.kimg = cur_nimg / 1000.0
    s.tick_kimg = tick_kimg
    s.resolution = 2 ** dataset.resolution_log2

    for arg in ["G_lrate", "D_lrate", "batch_size", "batch_gpu"]:
        s[arg] = sched_args[arg]

    # Optional learning rate ramp-up
    if lrate_rampup_kimg > 0:
        rampup = min(s.kimg / lrate_rampup_kimg, 1.0)
        s.G_lrate *= rampup
        s.D_lrate *= rampup

    return s

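# For illustration (hypothetical values): with sched_args = {"G_lrate": 0.002, "D_lrate": 0.002,
# "batch_size": 32, "batch_gpu": 4} and lrate_rampup_kimg = 1000, a call at cur_nimg = 500000
# (i.e. 500 kimg) gives rampup = min(500 / 1000, 1.0) = 0.5, so both learning rates become 0.001.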
# Build two optimizers for a network cN: one for its loss and one for its regularization terms
def set_optimizer(cN, lrate_in, batch_multiplier, lazy_regularization = True, clip = None):
    args = dict(cN.opt_args)
    args["batch_multiplier"] = batch_multiplier
    args["learning_rate"] = lrate_in

    if lazy_regularization:
        mb_ratio = cN.reg_interval / (cN.reg_interval + 1)
        args["learning_rate"] *= mb_ratio
        if "beta1" in args: args["beta1"] **= mb_ratio
        if "beta2" in args: args["beta2"] **= mb_ratio

    cN.opt = tflib.Optimizer(name = f"Loss{cN.name}", clip = clip, **args)
    cN.reg_opt = tflib.Optimizer(name = f"Reg{cN.name}", share = cN.opt, clip = clip, **args)

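# Note: with lazy regularization the regularization term is applied only every cN.reg_interval
# minibatches, so the learning rate and the optimizer's beta1/beta2 (when present) are rescaled by
# mb_ratio = reg_interval / (reg_interval + 1) to keep the overall update dynamics roughly unchanged
# (e.g. a hypothetical reg_interval of 16 gives mb_ratio = 16/17, approximately 0.94).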
# Create ops for computing and optimizing the loss and regularization terms, and for tracking their gradient norms
def set_optimizer_ops(cN, lazy_regularization, no_op):
    cN.reg_norm = tf.constant(0.0)
    cN.trainables = cN.gpu.trainables

    if cN.reg is not None:
        if lazy_regularization:
            cN.reg_opt.register_gradients(tf.reduce_mean(cN.reg * cN.reg_interval), cN.trainables)
            cN.reg_norm = cN.reg_opt.norm
        else:
            cN.loss += cN.reg

    cN.opt.register_gradients(tf.reduce_mean(cN.loss), cN.trainables)
    cN.norm = cN.opt.norm

    cN.loss_op = tf.reduce_mean(cN.loss) if cN.loss is not None else no_op
    cN.regval_op = tf.reduce_mean(cN.reg) if cN.reg is not None else no_op
    cN.ops = {"loss": cN.loss_op, "reg": cN.regval_op, "norm": cN.norm}

# Loading and logging
# ----------------------------------------------------------------------------
# Tracks exponential moving average: average, value -> new average
def emaAvg(avg, value, alpha = 0.995):
    if value is None:
        return avg
    if avg is None:
        return value
    return avg * alpha + value * (1 - alpha)

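# For example, emaAvg(avg = 1.0, value = 2.0) = 1.0 * 0.995 + 2.0 * 0.005 = 1.005,
# so the aggregated loss statistics drift slowly towards newly observed values.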
# Load networks from snapshot
def load_nets(resume_pkl, lG, lD, lGs, recompile):
misc.log("Loading networks from %s..." % resume_pkl, "white")
rG, rD, rGs = pretrained_networks.load_networks(resume_pkl)
if recompile:
misc.log("Copying nets...")
lG.copy_vars_from(rG); lD.copy_vars_from(rD); lGs.copy_vars_from(rGs)
else:
lG, lD, lGs = rG, rD, rGs
return lG, lD, lGs
# Training Loop
# ----------------------------------------------------------------------------
# 1. Set up the environment and data
# 2. Build the generator (g) and discriminator (d) networks
# 3. Manage the training process
# 4. Run periodic evaluations on specified metrics
# 5. Produce sample images over the course of training
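# cG and cD are EasyDict configurations that are expected to carry, per network:
# args (network construction kwargs), loss_args (loss function name and kwargs),
# opt_args (optimizer kwargs) and reg_interval (how often regularization is applied),
# as referenced throughout this function.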
def training_loop(
    # Configurations
    cG = {}, cD = {},                # Generator and Discriminator command-line arguments
    dataset_args = {},               # dataset.load_dataset() options
    sched_args = {},                 # train.TrainingSchedule options
    vis_args = {},                   # visualize.vis options
    grid_args = {},                  # train.setup_snapshot_img_grid() options
    metric_arg_list = [],            # MetricGroup options
    tf_config = {},                  # tflib.init_tf() options
    train = False,                   # Training mode
    eval = False,                    # Evaluation mode
    vis = False,                     # Visualization mode
    # Data
    data_dir = None,                 # Directory to load datasets from
    total_kimg = 25000,              # Total length of the training, measured in thousands of real images
    mirror_augment = False,          # Enable mirror augmentation?
    drange_net = [-1, 1],            # Dynamic range used when feeding image data to the networks
    # Optimization
    batch_repeats = 4,               # Number of batches to run before adjusting training parameters
    lazy_regularization = True,      # Perform regularization as a separate training step?
    smoothing_kimg = 10.0,           # Half-life of the running average of generator weights
    clip = None,                     # Gradient clipping threshold
    # Resumption
    resume_pkl = None,               # Network pickle to resume training from, None = train from scratch
    resume_kimg = 0.0,               # Assumed training progress at the beginning,
                                     # affects reporting and the training schedule
    resume_time = 0.0,               # Assumed wallclock time at the beginning, affects reporting
    recompile = False,               # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize = True,                # Create TensorBoard summaries
    save_tf_graph = False,           # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms = False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks = 3,          # How often to save image snapshots? None = disable
    network_snapshot_ticks = 3,      # How often to save network snapshots? None = only save networks-final.pkl
    last_snapshots = 10,             # Maximal number of prior snapshots to save
    eval_images_num = 50000,         # Sample size for the metrics
    printname = ""):                 # Experiment name for logging
    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir = dnnlib.convert_path(data_dir), verbose = True, **dataset_args)
    sched = training_schedule(sched_args, cur_nimg = total_kimg * 1000, dataset = dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G", num_channels = dataset.shape[0], resolution = dataset.shape[1],
                label_size = dataset.label_size, **cG.args)
            D = tflib.Network("D", num_channels = dataset.shape[0], resolution = dataset.shape[1],
                label_size = dataset.label_size, **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(dataset, **grid_args)
    misc.save_img_grid(grid_reals, dnnlib.make_run_dir_path("reals.png"), drange = dataset.dynamic_range, grid_size = grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" % resume_kimg)
        misc.save_pkl((G, D, Gs), pkl)

        misc.log("Run evaluation...")
        metric = metrics.run(pkl, num_imgs = eval_images_num, run_dir = dnnlib.make_run_dir_path(),
            data_dir = dnnlib.convert_path(data_dir), num_gpus = num_gpus, ratio = dataset.ratio,
            tf_config = tf_config, eval_mod = True, mirror_augment = mirror_augment)

    if vis:
        misc.log("Produce visualizations...")
        visualize.vis(Gs, dataset, batch_size = sched.batch_gpu,
            drange_net = drange_net, ratio = dataset.ratio, **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name = "lrate_in_g", shape = [])
        lrate_in_d = tf.placeholder(tf.float32, name = "lrate_in_d", shape = [])
        step = tf.placeholder(tf.int32, name = "step", shape = [])
        batch_size_in = tf.placeholder(tf.int32, name = "batch_size_in", shape = [])
        batch_gpu_in = tf.placeholder(tf.int32, name = "batch_gpu_in", shape = [])
        batch_multiplier = batch_size_in // (batch_gpu_in * num_gpus)
        beta = 0.5 ** tf.div(tf.cast(batch_size_in, tf.float32),
            smoothing_kimg * 1000.0) if smoothing_kimg > 0.0 else 0.0
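        # beta is the per-update decay of the Gs moving average: e.g. with a batch size of 32
        # and smoothing_kimg = 10.0, beta = 0.5 ** (32 / 10000) ~ 0.9978, which gives the
        # averaged generator a half-life of roughly 10k images.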

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, batch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_batch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net, mirror_augment)
                reals, reals_fetch = read_data(reals, "reals",
                    [sched.batch_gpu] + dataset.shape, batch_gpu_in)
                labels, labels_fetch = read_data(labels, "labels",
                    [sched.batch_gpu, dataset.label_size], batch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(G = cG.gpu, D = cD.gpu, dataset = dataset,
                    reals = reals, batch_size = batch_gpu_in, **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(G = cG.gpu, D = cD.gpu, dataset = dataset,
                    reals = reals, labels = labels, batch_size = batch_gpu_in, **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op = True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta = beta)
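    # Gs_update_op keeps Gs as an exponential moving average of G's weights,
    # using the batch-size-dependent decay beta computed above.
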
    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms(); D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("", cur_epoch = resume_kimg, max_epoch = total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args, cur_nimg = cur_nimg, dataset = dataset)
        assert sched.batch_size % (sched.batch_gpu * num_gpus) == 0
        dataset.configure(sched.batch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            batch_size_in: sched.batch_size,
            batch_gpu_in: sched.batch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(batch_repeats):
            rounds = range(0, sched.batch_size, sched.batch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter % cN.reg_interval == 0)

            cur_nimg += sched.batch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {k: None for k in ["loss", "reg", "norm", "reg_norm"]}

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run([cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run([cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k], cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress
            print(("tick %s kimg %s loss/reg: G (%s %s) D (%s %s) grad norms: G (%s %s) D (%s %s) " +
                   "time %s sec/kimg %s maxGPU %sGB %s") % (
                misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                misc.bcolored(f"{autosummary('Progress/kimg', cur_nimg / 1000.0):>8.1f}", "red"),
                misc.bcolored(f"{(cG.lossvals_agg['loss'] or 0):>6.3f}", "blue"),
                misc.bold(f"{(cG.lossvals_agg['reg'] or 0):>6.3f}"),
                misc.bcolored(f"{(cD.lossvals_agg['loss'] or 0):>6.3f}", "blue"),
                misc.bold(f"{(cD.lossvals_agg['reg'] or 0):>6.3f}"),
                misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                misc.bold("%-10s" % dnnlib.util.format_time(autosummary("Timing/total_sec", total_time))),
                f"{autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):>7.2f}",
                f"{autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):>4.1f}",
                misc.bold(printname)))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (cur_tick % img_snapshot_ticks == 0 or done):
                visualize.vis(Gs, dataset, batch_size = sched.batch_gpu, training = True,
                    step = cur_nimg // 1000, grid_size = grid_size, latents = grid_latents,
                    labels = grid_labels, drange_net = drange_net, ratio = dataset.ratio, **vis_args)

            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)

                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(pkl, num_imgs = eval_images_num, run_dir = dnnlib.make_run_dir_path(),
                        data_dir = dnnlib.convert_path(data_dir), num_gpus = num_gpus, ratio = dataset.ratio,
                        tf_config = tf_config, mirror_augment = mirror_augment)

                if last_snapshots > 0:
                    misc.rm(sorted(glob.glob(dnnlib.make_run_dir_path("network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update(None, cur_epoch = cur_nimg // 1000, max_epoch = total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl"))

    # All done
    if summarize:
        summary_log.close()
    dataset.close()