-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
87 lines (72 loc) · 3.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import yaml
import os
import shutil
import numpy as np
import random
import torch
from train import pretrain, adapt
from evaluation import test
from util import feat_list
from regularizer import LifeLongAgent
from preprocess import OnlinePreprocessor
from model import LSTM, IRM, Residual
from asteroid.losses.sdr import SingleSrcNegSDR
def main():
    """Entry point for SERIL speech enhancement experiments.

    Parses CLI arguments, seeds all RNGs for reproducibility, loads the
    YAML config, then either runs pretraining + lifelong adaptation
    (``--do train``) or evaluation (``--do test``).
    """
    parser = argparse.ArgumentParser(
        description='Argument Parser for SERIL.')
    parser.add_argument('--logdir', default='log',
                        help='Name of current experiment.')
    parser.add_argument('--n_jobs', default=2, type=int)
    parser.add_argument(
        '--do', choices=['train', 'test'], default='train', type=str)
    parser.add_argument(
        '--mode', choices=['seril', 'finetune'], default='seril', type=str)
    parser.add_argument(
        '--model', choices=['LSTM', 'Residual', 'IRM'], default='LSTM', type=str)
    # Options
    parser.add_argument(
        '--config', default='config/config.yaml', required=False)
    parser.add_argument('--seed', default=1126, type=int,
                        help='Random seed for reproducible results.', required=False)
    # BUG FIX: default was the string '2'. argparse applies `type` only to
    # command-line strings, never to defaults, so torch.cuda.set_device()
    # would have received a str when the flag was omitted.
    parser.add_argument('--gpu', default=2, type=int,
                        help='Assigning GPU id.')
    args = parser.parse_args()

    # Seed every RNG source used by the pipeline.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # build log directory
    os.makedirs(args.logdir, exist_ok=True)

    # Load configuration; `with` ensures the file handle is closed
    # (the original leaked it via yaml.load(open(...))).
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # BUG FIX: loss_func was only bound for 'sisdr'; any other config value
    # previously surfaced as a confusing NameError at model construction.
    # Fail fast with an explicit error instead.
    if config['train']['loss'] == 'sisdr':
        loss_func = SingleSrcNegSDR("sisdr", zero_mean=False,
                                    reduction='mean')
    else:
        raise ValueError(
            f"Unsupported loss in config: {config['train']['loss']!r} "
            "(only 'sisdr' is implemented)")

    if args.do == 'train':
        torch.cuda.set_device(args.gpu)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Each training task needs a matching clean/noisy pair list.
        assert len(config['dataset']['train']['clean']) == len(
            config['dataset']['train']['noisy']) and len(config['dataset']['train']['clean']) >= 1

        # Checkpoint paths for the T0 (pretraining) stage.
        model_path = f'{args.logdir}/pretrain/{args.model}_model_T0.pth'
        lifelong_agent_path = f'{args.logdir}/pretrain/{args.model}_synapses_T0.pth'
        if os.path.exists(model_path) and os.path.exists(lifelong_agent_path):
            # Resume from an existing pretrained checkpoint.
            print('[Runner] - pretrain model has already existed!')
            model = torch.load(model_path).to(device)
            lifelong_agent = torch.load(lifelong_agent_path).to(device)
            lifelong_agent.load_config(**config['train']['strategies'])
        else:
            print('[Runner] - run pretrain process!')
            preprocessor = OnlinePreprocessor(feat_list=feat_list).to(device)
            # IDIOM FIX: explicit class lookup instead of eval() on a
            # CLI-provided string (choices already restrict the value,
            # but eval is needless and fragile).
            model_cls = {'LSTM': LSTM, 'IRM': IRM, 'Residual': Residual}[args.model]
            model = model_cls(loss_func, preprocessor, **config['model']).to(device)
            lifelong_agent = LifeLongAgent(model, **config['train']['strategies'])
            pretrain(args, config, model, lifelong_agent)

        print('[Runner] - run adaptation process!')
        # Adaptation logs go in a mode-specific subdirectory.
        args.logdir = f'{args.logdir}/{args.mode}'
        if args.mode == 'seril':
            # SERIL: adaptation regularized by the lifelong-learning agent.
            adapt(args, config, model, lifelong_agent)
        elif args.mode == 'finetune':
            # Plain fine-tuning baseline: no regularizer.
            adapt(args, config, model)

    if args.do == 'test':
        test(args, config)
if __name__ == "__main__":
main()