-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain_interface.py
122 lines (102 loc) · 4.31 KB
/
train_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import time
import torch
from Baseline import Baseline
from Baseline_TLS import Baseline_TLS
from Baseline_CTFA import Baseline_CTFA
from NUNet_TLS import NUNet_TLS
from trainer import train, valid, joint_train, joint_valid
from dataloader import create_dataloader
import tools
import config as cfg
import warnings
warnings.filterwarnings("ignore")
#######################################################################
#                      Set a job and a log folder                     #
#######################################################################
dir2sav = cfg.job_dir
dir2log = cfg.logs_dir
# Create the output folders. makedirs(exist_ok=True) also creates any
# missing parent directories and avoids the TOCTOU race between the
# original exists() check and mkdir() call.
os.makedirs(dir2sav, exist_ok=True)
os.makedirs(dir2log, exist_ok=True)
#######################################################################
#                              Model init                             #
#######################################################################
DEVICE = torch.device(cfg.DEVICE)

# Map the configured mode onto a model class. Fail fast on an unknown
# mode instead of leaving `model` undefined and raising a confusing
# NameError further down.
if cfg.model_mode == 'Baseline':
    model = Baseline().to(DEVICE)
elif cfg.model_mode == 'Baseline+TLS':
    model = Baseline_TLS().to(DEVICE)
elif cfg.model_mode == 'Baseline+CTFA':
    model = Baseline_CTFA().to(DEVICE)
elif cfg.model_mode == 'NUNet-TLS':
    model = NUNet_TLS().to(DEVICE)
else:
    raise ValueError('Unknown model_mode: {!r}'.format(cfg.model_mode))

# Select the single-loss or joint-loss train/valid routines.
if cfg.joint_loss:
    trainer = joint_train
    validator = joint_valid
else:
    trainer = train
    validator = valid

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)
total_params = tools.cal_total_params(model)

# Resume from a checkpoint when a pretrained directory is configured.
epoch_start_idx = 1
if os.path.exists(cfg.pretrained_addr):
    print('Load the pretrained model...')
    # map_location lets a checkpoint saved on another device (e.g. GPU)
    # be loaded onto the configured DEVICE without error.
    chkpt = torch.load(cfg.pretrained_addr + '/chkpt_{}.pt'.format(cfg.chkpt_num),
                       map_location=DEVICE)
    model.load_state_dict(chkpt['model'])
    optimizer.load_state_dict(chkpt['optimizer'])
    epoch_start_idx = chkpt['epoch'] + 1
    # Keep writing new checkpoints next to the ones we resumed from.
    dir2sav = cfg.pretrained_addr
#######################################################################
#                          Create Dataloader                          #
#######################################################################
# Build one loader per phase; all other settings come from config.
train_loader, valid_loader = (create_dataloader(mode=phase)
                              for phase in ('train', 'valid'))
#######################################################################
#                      Confirm model information                      #
#######################################################################
# Log the start timestamp. A single localtime() call avoids reading the
# clock six separate times (fields could straddle a minute rollover),
# and strftime zero-pads the fields (09:05:03, not 9:5:3).
print(time.strftime('%Y-%m-%d %H:%M:%S\n', time.localtime()))
# Parameter count, also expressed in millions and in MBytes assuming
# 4 bytes (float32) per parameter.
print('total params : %d (%.2f M, %.2f MBytes)\n' %
      (total_params,
       total_params / 1000000.0,
       total_params * 4.0 / 1000000.0))
# save the status information
tools.write_status(dir2sav)
#######################################################################
#######################################################################
#                                 Main                                #
#######################################################################
#######################################################################
writer = tools.Writer(dir2log)
print('Main program start...')

# The context manager guarantees the log file is closed even if an
# epoch raises; the per-epoch flush keeps the log usable if training
# crashes mid-run (the original only closed it on clean completion).
with open(dir2sav + '/train_log.txt', 'a') as train_log_fp:
    for epoch in range(epoch_start_idx, cfg.max_epoch + 1):
        st_time = time.time()

        # train
        train_loss = trainer(model, train_loader, optimizer, writer, epoch, DEVICE)

        # save checkpoint file to resume training
        save_path = dir2sav + '/chkpt_%d.pt' % epoch
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }, save_path)

        # validate
        valid_loss, pesq, stoi = validator(model, valid_loader, writer, epoch, DEVICE)

        # Measure elapsed time once so the console and the log file agree
        # (the original called time.time() twice and could differ).
        elapsed = time.time() - st_time
        epoch_line = ('EPOCH[{}] T {:.6f} | V {:.6f} takes {:.3f} seconds'
                      .format(epoch, train_loss, valid_loss, elapsed))
        metric_line = 'PESQ {:.6f} | STOI {:.6f}'.format(pesq, stoi)
        print(epoch_line)
        print(metric_line)

        # write train log
        train_log_fp.write(epoch_line + '\n')
        train_log_fp.write(metric_line + '\n')
        train_log_fp.flush()

print('Training has been finished.')