# single_length_train.py
# os and time are used below (os.getcwd(), time.time()); import them explicitly
# instead of relying on the star-imports from config/model.
import os
import sys
import time

from tensorflow.python.training.saver import latest_checkpoint

from config import *
from language_helpers import generate_argmax_samples_and_gt_samples, inf_train_gen, decode_indices_to_string
from objective import get_optimization_ops, define_objective
from summaries import define_summaries, log_samples

sys.path.append(os.getcwd())

from model import *
import model_and_data_serialization

# Download Google Billion Word at http://www.statmt.org/lm-benchmark/ and
# fill in the path to the extracted files here!
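
# run() trains the GAN at a single, fixed sequence length. When is_first is
# False it warm-starts from the checkpoint saved for prev_seq_length, which
# lets a caller chain runs across successive sequence lengths.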
def run(iterations, seq_length, is_first, charmap, inv_charmap, prev_seq_length):
    if len(DATA_DIR) == 0:
        raise Exception('Please specify path to data directory in single_length_train.py!')

    lines, _, _ = model_and_data_serialization.load_dataset(seq_length=seq_length, b_charmap=False,
                                                            b_inv_charmap=False, n_examples=FLAGS.MAX_N_EXAMPLES)

    real_inputs_discrete = tf.placeholder(tf.int32, shape=[BATCH_SIZE, seq_length])
    global_step = tf.Variable(0, trainable=False)

    disc_cost, gen_cost, fake_inputs, disc_fake, disc_real, disc_on_inference, inference_op = define_objective(
        charmap, real_inputs_discrete, seq_length)

    merged, train_writer = define_summaries(disc_cost, gen_cost, seq_length)
    disc_train_op, gen_train_op = get_optimization_ops(disc_cost, gen_cost, global_step)

    saver = tf.train.Saver(tf.trainable_variables())
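    # Note: only the trainable variables are checkpointed, so optimizer slot
    # variables and global_step are not part of the saved state.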

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        if not is_first:
            print("Loading previous checkpoint...")
            internal_checkpoint_dir = model_and_data_serialization.get_internal_checkpoint_dir(prev_seq_length)
            model_and_data_serialization.optimistic_restore(
                session, latest_checkpoint(internal_checkpoint_dir, "checkpoint"))

            # Global restore config: after the first sequence length finishes,
            # always load from the current session.
            restore_config.set_restore_dir(load_from_curr_session=True)

        gen = inf_train_gen(lines, charmap)
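
        # Alternating GAN updates (CRITIC_ITERS and GEN_ITERS come from config):
        # each iteration runs several critic/discriminator steps followed by the
        # generator steps, drawing a fresh minibatch from the infinite data
        # generator for every session.run call.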
        for iteration in range(iterations):
            start_time = time.time()

            # Train critic
            for i in range(CRITIC_ITERS):
                _data = next(gen)
                _disc_cost, _, real_scores = session.run(
                    [disc_cost, disc_train_op, disc_real],
                    feed_dict={real_inputs_discrete: _data}
                )

            # Train G
            for i in range(GEN_ITERS):
                _data = next(gen)
                _ = session.run(gen_train_op, feed_dict={real_inputs_discrete: _data})

            print("iteration %s/%s" % (iteration, iterations))
            print("disc cost %f" % _disc_cost)
            # Summaries and sample logging every 100 iterations
            if iteration % 100 == 99:
                _data = next(gen)
                summary_str = session.run(
                    merged,
                    feed_dict={real_inputs_discrete: _data}
                )
                train_writer.add_summary(summary_str, global_step=iteration)

                fake_samples, samples_real_probabilities, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, fake_inputs, disc_fake, gen, real_inputs_discrete, feed_gt=True)
                log_samples(fake_samples, fake_scores, iteration, seq_length, "train")
                log_samples(decode_indices_to_string(_data, inv_charmap), real_scores, iteration, seq_length, "gt")

                test_samples, _, fake_scores = generate_argmax_samples_and_gt_samples(
                    session, inv_charmap, inference_op, disc_on_inference, gen, real_inputs_discrete, feed_gt=False)
                log_samples(test_samples, fake_scores, iteration, seq_length, "test")
            # Periodic checkpoint inside the training loop
            if iteration % FLAGS.SAVE_CHECKPOINTS_EVERY == FLAGS.SAVE_CHECKPOINTS_EVERY - 1:
                saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

        # Final checkpoint for this sequence length
        saver.save(session, model_and_data_serialization.get_internal_checkpoint_dir(seq_length) + "/ckp")

    session.close()