forked from jason9693/MusicTransformer-tensorflow2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train.py
140 lines (115 loc) · 5.09 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from model import MusicTransformerDecoder
from custom.layers import *
from custom import callback
import params as par
from tensorflow.python.keras.optimizer_v2.adam import Adam
from data import Data
import utils
import argparse
import datetime
import sys
from scripts.utils import write_csv
import timeit
# NOTE: tf.executing_eagerly() only *reports* whether eager execution is on;
# its return value is discarded here, so this statement has no effect.
tf.executing_eagerly()


def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse hands ``type=`` callables the raw string, and the string
    ``'False'`` is truthy, so without this parser ``--multi_gpu False`` would
    be treated as enabled. Accepts true/false, t/f, yes/no, y/n, 1/0
    (case-insensitive); real bools (e.g. defaults) pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# learning rate; when omitted, the model setup falls back to CustomSchedule
parser.add_argument('--l_r', default=None, help='학습률', type=float)
parser.add_argument('--batch_size', default=2, help='batch size', type=int)
# dataset path (directory passed to Data)
parser.add_argument('--pickle_dir', default='music', help='데이터셋 경로')
# maximum sequence length
parser.add_argument('--max_seq', default=2048, help='최대 길이', type=int)
# number of epochs
parser.add_argument('--epochs', default=100, help='에폭 수', type=int)
# model load path (forwarded to MusicTransformerDecoder as loader_path)
parser.add_argument('--load_path', default=None, help='모델 로드 경로', type=str)
# model save path (checkpoint destination)
parser.add_argument('--save_path', default="result/dec0722", help='모델 저장 경로')
# Boolean switches: parsed with _str2bool so "False"/"0"/"no" really mean False.
parser.add_argument('--is_reuse', default=False, type=_str2bool)
parser.add_argument('--multi_gpu', default=True, type=_str2bool)
parser.add_argument('--num_layers', default=6, type=int)
args = parser.parse_args()

# set arguments
l_r = args.l_r
batch_size = args.batch_size
pickle_dir = args.pickle_dir
max_seq = args.max_seq
epochs = args.epochs
is_reuse = args.is_reuse
load_path = args.load_path
save_path = args.save_path
multi_gpu = args.multi_gpu
num_layer = args.num_layers
# load data
dataset = Data(pickle_dir)
print(dataset)

# Single source of truth for the model width. The learning-rate schedule was
# built from par.embedding_dim while the decoder hard-coded 256, so editing
# params.py would silently desynchronise the schedule from the model.
embedding_dim = par.embedding_dim

# load model
# Without an explicit --l_r, use the project's CustomSchedule, which is
# constructed from the embedding size; otherwise use the fixed CLI value.
learning_rate = callback.CustomSchedule(embedding_dim) if l_r is None else l_r
opt = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Wall-clock bookkeeping: skipped_time accumulates time spent checkpointing and
# printing so it can be subtracted from the reported training time.
start_time = timeit.default_timer()
skipped_time = 0

# define model
mt = MusicTransformerDecoder(
    embedding_dim=embedding_dim,
    vocab_size=par.vocab_size,
    num_layer=num_layer,
    max_seq=max_seq,
    dropout=0.2,
    debug=False,
    loader_path=load_path,
)
mt.compile(optimizer=opt, loss=callback.transformer_dist_train_loss)
# define tensorboard writer
# Each launch gets its own timestamped run directory, with training and
# evaluation summaries written to separate 'train' / 'eval' sub-directories.
current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
train_log_dir, eval_log_dir = (
    'logs/mt_decoder/{}/{}'.format(current_time, split) for split in ('train', 'eval')
)
train_summary_writer, eval_summary_writer = (
    tf.summary.create_file_writer(log_dir) for log_dir in (train_log_dir, eval_log_dir)
)

# Running sums over the logged training batches; averaged once training ends.
total_loss, loss_count = 0, 0
total_accuracy, accuracy_count = 0, 0
# Train Start
# idx is the TensorBoard step for the scalar/image summaries. It advances once
# per logged batch (every 100th batch index of an epoch), not once per batch.
idx = 0
# Training batches the data loader failed to build (reported periodically).
skipped_batches = 0
for e in range(epochs):
    mt.reset_metrics()
    for b in range(len(dataset.files) // batch_size):
        try:
            batch_x, batch_y = dataset.slide_seq2seq_batch(batch_size, max_seq)
        except Exception:
            # Best-effort: a batch the loader cannot build is skipped rather
            # than aborting the run. Catching Exception (not a bare `except:`)
            # lets KeyboardInterrupt/SystemExit still stop training, and the
            # skip is counted so it is visible in the progress report.
            skipped_batches += 1
            continue
        result_metrics = mt.train_on_batch(batch_x, batch_y)

        # Everything below runs only every 100th batch index: evaluate,
        # checkpoint, write summaries and print a progress report.
        if b % 100 != 0:
            continue

        eval_x, eval_y = dataset.slide_seq2seq_batch(batch_size, max_seq, 'eval')
        eval_result_metrics, weights = mt.evaluate(eval_x, eval_y)

        # Checkpointing time is excluded from the reported training time.
        print_time = timeit.default_timer()
        mt.save(save_path)
        skipped_time += timeit.default_timer() - print_time

        with train_summary_writer.as_default():
            if b == 0:
                # First batch of the epoch: histograms of the raw batch values.
                tf.summary.histogram("target_analysis", batch_y, step=e)
                tf.summary.histogram("source_analysis", batch_x, step=e)
            # result_metrics[0] / [1] are logged as loss / accuracy.
            tf.summary.scalar('loss', result_metrics[0], step=idx)
            total_loss += result_metrics[0]
            loss_count += 1
            tf.summary.scalar('accuracy', result_metrics[1], step=idx)
            total_accuracy += result_metrics[1]
            accuracy_count += 1

        with eval_summary_writer.as_default():
            if b == 0:
                mt.sanity_check(eval_x, eval_y, step=e)
            tf.summary.scalar('loss', eval_result_metrics[0], step=idx)
            tf.summary.scalar('accuracy', eval_result_metrics[1], step=idx)
            # One attention-image summary per entry of `weights`, scoped layer_<i>/w.
            for i, weight in enumerate(weights):
                with tf.name_scope("layer_%d" % i):
                    with tf.name_scope("w"):
                        utils.attention_image_summary(weight, step=idx)

        idx += 1

        # Console reporting time is also excluded from the training time.
        print_time = timeit.default_timer()
        print('\n====================================================')
        print('Epoch/Batch: {}/{}'.format(e, b))
        print('Train >>>> Loss: {:6.6}, Accuracy: {}'.format(result_metrics[0], result_metrics[1]))
        print('Eval >>>> Loss: {:6.6}, Accuracy: {}'.format(eval_result_metrics[0], eval_result_metrics[1]))
        if skipped_batches:
            print('Skipped batches so far: {}'.format(skipped_batches))
        skipped_time += timeit.default_timer() - print_time
# Summarise the run. Training time excludes checkpointing/printing
# (skipped_time). The averages are over the periodically logged training
# batches only, not over every batch trained.
train_time = timeit.default_timer() - start_time - skipped_time
# Guard against division by zero when nothing was logged (e.g. fewer files
# than batch_size, zero epochs, or every batch skipped by the loader).
avg_loss = float(total_loss) / float(loss_count) if loss_count else float('nan')
avg_accuracy = float(total_accuracy) / float(accuracy_count) if accuracy_count else float('nan')
if not loss_count:
    print('Warning: no training batches were logged; averages recorded as NaN.')
write_csv(__file__, epochs=epochs, accuracy=float(avg_accuracy), loss=float(avg_loss), time=train_time)