forked from daibiaoxuwu/NeLoRa_Dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
137 lines (111 loc) · 5.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Main script for project."""
from __future__ import print_function
import config
import data_loader
import end2end
import os
import numpy as np
import torch
import sys
from model_components0 import maskCNNModel0, classificationHybridModel0
from utils import *
import collections
def _strip_module_prefix(state_dict):
    """Remove every ``'module.'`` substring from the keys of *state_dict*, in place.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` carry a
    ``module.`` prefix on each key that the bare model will not accept.
    The same mapping object is mutated and returned, so any ``_metadata``
    attribute that ``state_dict()`` attached to it is preserved.
    """
    # Snapshot the keys first: the dict is modified while we walk it.
    for key in list(state_dict.keys()):
        state_dict[key.replace('module.', '')] = state_dict.pop(key)
    return state_dict


def load_checkpoint(opts, maskCNNModel, classificationHybridModel):
    """Build the mask CNN (and optionally the classifier) and restore their weights.

    Checkpoint files are looked up in ``opts.load_checkpoint_dir`` as
    ``<load_iters>_maskCNN.pkl`` and ``<load_iters>_C_XtoY.pkl``.

    Args:
        opts: parsed options. Reads load_checkpoint_dir, load_iters,
            load_maskcnn, cxtoy, load_cxtoy, x_image_channel, n_classes
            and conv_dim_lstm. The flag options are compared against
            the string ``'True'``.
        maskCNNModel: class/factory called as ``maskCNNModel(opts)``.
        classificationHybridModel: class/factory called with the keyword
            arguments conv_dim_in, conv_dim_out and conv_dim_lstm.

    Returns:
        ``[maskCNN, C_XtoY]`` when ``opts.cxtoy == 'True'``, otherwise
        ``[maskCNN]``. Loaded tensors are mapped to CPU; move the models to
        the target device afterwards.

    Raises:
        FileNotFoundError: if ``opts.load_maskcnn == 'True'`` and the mask
            CNN checkpoint is missing (a missing classifier checkpoint only
            prints a warning).
    """
    maskCNN_path = os.path.join(opts.load_checkpoint_dir, str(opts.load_iters) + '_maskCNN.pkl')
    C_XtoY_path = os.path.join(opts.load_checkpoint_dir, str(opts.load_iters) + '_C_XtoY.pkl')
    print('LOAD MODEL:', maskCNN_path)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    maskCNN = maskCNNModel(opts)
    if opts.load_maskcnn == 'True':
        state_dict = torch.load(maskCNN_path, map_location='cpu')
        maskCNN.load_state_dict(_strip_module_prefix(state_dict))

    if opts.cxtoy != 'True':
        return [maskCNN]

    C_XtoY = classificationHybridModel(conv_dim_in=opts.x_image_channel,
                                       conv_dim_out=opts.n_classes,
                                       conv_dim_lstm=opts.conv_dim_lstm)
    if opts.load_cxtoy == 'True':
        if os.path.exists(C_XtoY_path):
            checkpoint = torch.load(C_XtoY_path, map_location='cpu')
            if isinstance(checkpoint, dict):
                # A state_dict (plain dict or OrderedDict).
                C_XtoY.load_state_dict(_strip_module_prefix(checkpoint))
            else:
                # A whole pickled model: reuse the object already loaded
                # instead of reading the file a second time.
                # NOTE(review): torch versions whose torch.load defaults to
                # weights_only=True refuse whole-model pickles; confirm the
                # torch version if this branch is still needed.
                C_XtoY = checkpoint
        else:
            print('--WARNING: C_XtoY CHECKPOINT NOT FOUND, USING FRESH WEIGHTS:', C_XtoY_path)
    return [maskCNN, C_XtoY]
def main(opts, models):
    """Run one training session.

    Empties the CUDA cache, builds the train/test data loaders with
    ``data_loader.lora_loader(opts)``, calls ``set_gpu(opts.free_gpu_id)``
    and then delegates to ``end2end.training_loop``.

    Args:
        opts: parsed and post-processed options.
        models: list of models passed straight to the training loop.

    Returns:
        Whatever ``end2end.training_loop`` returns.
    """
    torch.cuda.empty_cache()
    loaders = data_loader.lora_loader(opts)
    train_loader, test_loader = loaders
    # NOTE(review): set_gpu runs after the loaders are built, and in this
    # file's entry point after .cuda() was already called on the models; if
    # it selects the CUDA device it may need to run earlier — confirm.
    set_gpu(opts.free_gpu_id)
    return end2end.training_loop(train_loader, test_loader, models, opts)
def _latest_checkpoint_iter(checkpoint_dir):
    """Return the highest iteration number among ``<iter>_*.pkl`` files.

    Returns None when *checkpoint_dir* does not exist or holds no such file.
    ``.pkl`` files whose name does not start with an integer are skipped
    instead of raising ValueError.
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    iters = []
    for fname in os.listdir(checkpoint_dir):
        prefix = fname.split('_')[0]
        if fname.endswith('.pkl') and prefix.isdecimal():
            iters.append(int(prefix))
    return max(iters) if iters else None


def _snapshot_source(dest_dir):
    """Best-effort copy of every ``*.py`` file next to this script into *dest_dir*.

    Hidden (dot) files are skipped, matching the shell glob ``*.py``.
    A failed copy prints a warning instead of aborting the run.
    """
    import shutil  # local import: only needed for this snapshot

    src_dir = os.path.dirname(os.path.abspath(__file__))
    for fname in os.listdir(src_dir):
        path = os.path.join(src_dir, fname)
        if fname.endswith('.py') and not fname.startswith('.') and os.path.isfile(path):
            try:
                shutil.copy(path, dest_dir)
            except OSError as err:
                print('--WARNING: COULD NOT SNAPSHOT', path, ':', err, '--')


if __name__ == "__main__":
    print('=' * 80)
    print('Opts'.center(80))
    print('-' * 80)
    print('COMMAND: ', ' '.join(sys.argv))
    parser = config.create_parser()
    opts = parser.parse_args()

    # sf == -1: infer the spreading factor from the trailing '-<sf>' of checkpoint_dir.
    if opts.sf == -1:
        opts.sf = int(opts.checkpoint_dir.split('-')[-1])
    # NOTE(review): this unconditionally overrides any data_dir given on the
    # command line — confirm that is intended.
    opts.data_dir = '/data/djl/SF' + str(opts.sf) + '_125K'
    opts.n_classes = 2 ** opts.sf
    opts.stft_nfft = opts.n_classes * opts.fs // opts.bw
    opts.stft_window = opts.n_classes // 2 * opts.stft_mod
    opts.stft_overlap = opts.stft_window // 2 // opts.stft_mod
    opts.conv_dim_lstm = opts.n_classes * opts.fs // opts.bw
    print('opts.conv_dim_lstm ', opts.conv_dim_lstm)
    opts.freq_size = opts.n_classes
    create_dir(opts.checkpoint_dir)

    # -1 means "choose a default"; lower minimum SNR shrinks lr and grows w_image.
    if opts.lr == -1:
        opts.lr = 0.001
        if min(opts.snr) < -15:
            opts.lr *= 0.3
        if min(opts.snr) < -20:
            opts.lr /= 1.5
    if opts.w_image == -1:
        opts.w_image = 1
        if min(opts.snr) < -15:
            opts.w_image *= 4
        if min(opts.snr) < -20:
            opts.w_image *= 4

    # Default: resume from the directory checkpoints are written to.
    if opts.load_checkpoint_dir == '/data/djl':
        opts.load_checkpoint_dir = opts.checkpoint_dir

    maskCNNModel = maskCNNModel0
    classificationHybridModel = classificationHybridModel0

    # load_iters == -1: resume from the newest checkpoint, if any exists.
    if opts.load == 'yes' and opts.load_iters == -1:
        latest_iter = _latest_checkpoint_iter(opts.load_checkpoint_dir)
        if latest_iter is None or latest_iter == 0:
            opts.load = 'no'
            print('--WARNING: CHECKPOINT_DIR NOT EXIST, SETTING OPTS.LOAD TO NO--')
        else:
            opts.load_iters = latest_iter

    # Keep a copy of the source used for this run alongside the checkpoints.
    codepath = os.path.join(opts.checkpoint_dir, 'code' + str(opts.load_iters))
    create_dir(codepath)
    _snapshot_source(codepath)

    if opts.load == 'yes':
        print('LOAD ITER: ', opts.load_iters)
        loaded = load_checkpoint(opts, maskCNNModel, classificationHybridModel)
        mask_CNN = loaded[0]
        if opts.cxtoy == 'True':
            C_XtoY = loaded[1]
    else:
        mask_CNN = maskCNNModel(opts)
        if opts.cxtoy == 'True':
            C_XtoY = classificationHybridModel(conv_dim_in=opts.x_image_channel,
                                               conv_dim_out=opts.n_classes,
                                               conv_dim_lstm=opts.conv_dim_lstm)

    mask_CNN.cuda()
    models = [mask_CNN]
    if opts.cxtoy == 'True':
        C_XtoY.cuda()
        models.append(C_XtoY)

    opts.logfile = os.path.join(opts.checkpoint_dir, 'logfile-djl-train.txt')
    opts.logfile2 = os.path.join(opts.checkpoint_dir, 'logfile2-djl-train.txt')
    strlist = print_opts(opts)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(opts.logfile, 'a') as f:
        f.write('\n' + ' '.join(sys.argv))
        f.write('\n' + timestamp + ' ' + '\n'.join(strlist) + '\n')
    with open(opts.logfile2, 'a') as f:
        f.write(timestamp + ' snr ' + str(opts.snr) + ' : ')

    opts.init_train_iter = opts.load_iters
    models = main(opts, models)
    # NOTE(review): nothing reads opts after this point; looks like a leftover
    # from a repeated-training loop — confirm before removing.
    opts.init_train_iter += opts.train_iters