forked from daibiaoxuwu/NeLoRa_Dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathend2end.py
230 lines (196 loc) · 10.9 KB
/
end2end.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# end2end.py
from __future__ import division
import os
from scipy.signal import chirp
import warnings
import torch
import numpy as np
import cv2
from utils import *
import time
import math
# Silence every warning category for the whole process. This is the same
# filter that warnings.filterwarnings("ignore") installs.
# NOTE(review): it also hides library warnings that may signal misuse;
# consider narrowing it to specific categories/modules.
warnings.simplefilter("ignore")
def checkpoint(iteration, models, opts):
    """Persist model weights for the given training iteration.

    Writes ``models[0].state_dict()`` to
    ``<opts.checkpoint_dir>/<iteration>_maskCNN.pkl``. When
    ``opts.cxtoy == 'True'`` it also writes ``models[1].state_dict()`` to
    ``<opts.checkpoint_dir>/<iteration>_C_XtoY.pkl``. Only state dicts are
    saved, not the module objects.

    ``create_dir`` is called on the directory before saving. It is not defined
    in this module and presumably comes from ``utils`` (TODO confirm); it
    presumably creates the directory if it is missing.
    """
    ckpt_dir = opts.checkpoint_dir
    stem = str(iteration)
    mask_cnn_file = os.path.join(ckpt_dir, stem + '_maskCNN.pkl')
    create_dir(ckpt_dir)
    torch.save(models[0].state_dict(), mask_cnn_file)
    if opts.cxtoy == 'True':
        torch.save(models[1].state_dict(), os.path.join(ckpt_dir, stem + '_C_XtoY.pkl'))
    print('CKPT: ', mask_cnn_file)
def merge_images(sources, targets, Y, opts):
    """Tile (source, target, Y) triples into a single grid image.

    The grid has ``int(sqrt(opts.batch_size))`` rows and
    ``ceil(opts.batch_size / rows)`` columns of triples. Item ``k`` places
    ``sources[k]``, ``targets[k]`` and ``Y[k]`` side by side, left to right,
    in row-major order. Each tile's height/width come from ``sources[0]``,
    which is unpacked as ``(C, h, w)``. Tiles are written into a 2-channel
    canvas, so ``C`` must broadcast to 2. Cells without an item stay zero.

    Returns:
        A float array of shape ``(rows * h, cols * w * 3, 2)`` (channels last).
    """
    _, tile_h, tile_w = sources[0].shape
    n_rows = int(np.sqrt(opts.batch_size))
    n_cols = math.ceil(opts.batch_size / n_rows)
    canvas = np.zeros([2, n_rows * tile_h, n_cols * tile_w * 3])
    for k, triple in enumerate(zip(sources, targets, Y)):
        r, c = divmod(k, n_cols)
        top = r * tile_h
        for slot, tile in enumerate(triple):
            left = (c * 3 + slot) * tile_w
            canvas[:, top:top + tile_h, left:left + tile_w] = tile
    return canvas.transpose(1, 2, 0)
def save_samples(iteration, fixed_Y, fixed_X, fake_Y, name, opts):
    """Write a magnitude image comparing network input, output and target.

    Each batch element becomes three side-by-side tiles (via ``merge_images``):
    ``fixed_X``, ``fake_Y`` and ``fixed_Y``, in that order. In ``work`` these
    are the mask_CNN input spectrum, its output, and the target spectrum. The
    two channels of the merged grid are treated as real/imaginary parts and
    reduced to a magnitude. The magnitude is min-max scaled to [0, 255],
    flipped vertically and written to
    ``<opts.checkpoint_dir>/sample-<iteration>-snr<opts.snr>-Y<name>.png``.

    A constant-magnitude grid (zero dynamic range) is written as all zeros
    instead of dividing by zero. If ``cv2.imwrite`` reports failure, a failure
    message is printed instead of the "SAVED" message.
    """
    # Same conversion order as before: X, then Y, then the generated Y.
    xs = [to_data(i) for i in fixed_X]
    ys = [to_data(i) for i in fixed_Y]
    fakes = [to_data(i) for i in fake_Y]
    grid = merge_images(xs, fakes, ys, opts)
    path = os.path.join(opts.checkpoint_dir, 'sample-{:06d}-snr{:.1f}-Y{:s}.png'.format(iteration, opts.snr, name))
    magnitude = np.abs(grid[:, :, 0] + 1j * grid[:, :, 1])
    lo, hi = np.min(magnitude), np.max(magnitude)
    if hi > lo:
        image = (magnitude - lo) / (hi - lo) * 255
    else:
        # Zero dynamic range: min-max scaling would be 0/0 -> NaN everywhere.
        image = np.zeros_like(magnitude)
    image = cv2.flip(image, 0)
    # cv2.imwrite signals failure (e.g. unwritable path) through its return value.
    if cv2.imwrite(path, image):
        print('SAVED TEST SAMPLE: {}'.format(path))
    else:
        print('FAILED TO SAVE TEST SAMPLE: {}'.format(path))
def _spec_magnitude(spec):
    """Return |spec[:, 0] + 1j * spec[:, 1]|: magnitude of a real/imag channel pair on dim 1."""
    return torch.abs(spec[:, 0] + 1j * spec[:, 1])


def work2(fake_Y_spectrum, images_Y_spectrum, labels_X, C_XtoY, opts):
    """Compute the image loss, the classification loss and predicted labels.

    Image loss: ``opts.w_image * opts.loss_spec(fake, target)``. The loss is
    taken on both channels when ``opts.comp_channel == 2``, otherwise on the
    channel-0/1 magnitudes.

    Classification (``opts.cxtoy == 'True'``):
      * pretraining (``opts.cxtoy_pretrain == 'True'``): ``C_XtoY`` sees the
        target spectrum ``images_Y_spectrum``;
      * end-to-end: ``C_XtoY`` sees the generated ``fake_Y_spectrum``, so the
        class loss back-propagates into the generator.
      With ``opts.x_image_channel == 1`` the classifier input is reduced to a
      1-channel magnitude (``unsqueeze(..., 1)``).
    Without ``C_XtoY`` the class scores are a softmax over the summed
    magnitude of the generated spectrum after ``spec_to_network_input``.

    Returns:
        (image_loss, class_loss, predicted_labels) where predicted_labels is
        the argmax over dim 1 of the class scores.
    """
    if opts.comp_channel == 2:
        g_y_pix_loss = opts.loss_spec(fake_Y_spectrum, images_Y_spectrum)
    else:
        g_y_pix_loss = opts.loss_spec(_spec_magnitude(fake_Y_spectrum), _spec_magnitude(images_Y_spectrum))
    G_Image_loss = opts.w_image * g_y_pix_loss

    if opts.cxtoy == 'True':
        # Previously the 1-channel end-to-end branch fed images_Y_spectrum here,
        # disconnecting the class loss from the generated spectrum.
        if opts.cxtoy_pretrain == 'True':
            classifier_input = images_Y_spectrum
        else:
            classifier_input = fake_Y_spectrum
        if opts.x_image_channel == 1:
            classifier_input = torch.unsqueeze(_spec_magnitude(classifier_input), 1)
        labels_X_estimated = C_XtoY(classifier_input)
    else:
        # NOTE(review): torch.squeeze also drops the batch dim when the batch
        # size is 1, which would break the softmax over dim=1 below.
        fake_Y_spectrum = spec_to_network_input(torch.squeeze(fake_Y_spectrum[:, 0] + 1j * fake_Y_spectrum[:, 1]), opts)
        assert fake_Y_spectrum.dtype == torch.cfloat
        # NOTE(review): when opts.loss_class is CrossEntropyLoss (as set in
        # training_loop), it applies log-softmax itself, so these scores are
        # softmaxed twice - confirm this is intended.
        labels_X_estimated = torch.nn.functional.softmax(torch.abs(fake_Y_spectrum).sum(-1), dim=1).squeeze()

    G_Class_loss = opts.loss_class(labels_X_estimated, labels_X)
    _, predicted = torch.max(labels_X_estimated, 1)
    return G_Image_loss, G_Class_loss, predicted
def _stft_complex(signal, opts):
    """Return the complex STFT of *signal* using the STFT settings on *opts*.

    No ``window`` is passed, so ``torch.stft`` uses its default window of all ones.
    NOTE(review): ``opts.stft_overlap`` is passed as ``hop_length``; confirm
    that option holds the hop size rather than the overlap length.
    """
    return torch.stft(input=signal, n_fft=opts.stft_nfft,
                      hop_length=opts.stft_overlap, win_length=opts.stft_window,
                      pad_mode='constant', return_complex=True)


def work(images_X, labels_X, images_Y, opts, downchirp, mask_CNN, C_XtoY):
    """Run one forward pass on a batch and score it.

    ``images_X``/``images_Y`` go through ``to_var``. When
    ``opts.dechirp == 'True'`` both are multiplied by ``downchirp``. Each is
    then turned into a complex STFT and passed through
    ``spec_to_network_input`` followed by ``spec_to_network_input2``.
    ``mask_CNN`` maps the X spectrum to an estimate of the Y spectrum, and
    ``work2`` computes the losses and predictions.

    Side effect: when ``(opts.iteration - opts.init_train_iter) %
    opts.test_step == 1`` a sample image is written via ``save_samples``.

    Returns:
        (image_loss, class_loss, accuracy). accuracy is the fraction of
        ``labels_X`` predicted correctly, computed over the labels actually in
        this batch (previously divided by ``opts.batch_size``, which
        under-counted short batches).
    """
    images_X = to_var(images_X)
    images_Y = to_var(images_Y)
    if opts.dechirp == 'True':
        # NOTE(review): training_loop builds downchirp with opts.batch_size
        # rows, so a shorter final batch will not broadcast against it.
        images_X = images_X * downchirp
        images_Y = images_Y * downchirp

    images_X_spectrum = spec_to_network_input2(
        spec_to_network_input(_stft_complex(images_X, opts), opts), opts)
    images_Y_spectrum = spec_to_network_input2(
        spec_to_network_input(_stft_complex(images_Y, opts), opts), opts)

    fake_Y_spectrums = mask_CNN(images_X_spectrum)
    if (opts.iteration - opts.init_train_iter) % opts.test_step == 1:
        save_samples(opts.iteration, images_Y_spectrum, images_X_spectrum, fake_Y_spectrums, 'val', opts)

    G_Image_loss, G_Class_loss, labels_X_estimated = work2(fake_Y_spectrums, images_Y_spectrum, labels_X, C_XtoY, opts)
    G_Acc = torch.sum(labels_X_estimated == labels_X) / labels_X.shape[0]
    return G_Image_loss, G_Class_loss, G_Acc
def _make_downchirp(opts):
    """Build the complex linear down-chirp (bw/2 -> -bw/2), stacked opts.batch_size times, on the GPU."""
    nsamp = int(opts.fs * opts.n_classes / opts.bw)
    t = np.linspace(0, nsamp / opts.fs, nsamp)
    t1 = 2 ** opts.sf / opts.bw
    chirp_i = chirp(t, f0=opts.bw / 2, f1=-opts.bw / 2, t1=t1, method='linear', phi=90)
    chirp_q = chirp(t, f0=opts.bw / 2, f1=-opts.bw / 2, t1=t1, method='linear', phi=0)
    downchirp1 = torch.tensor(chirp_i + 1j * chirp_q, dtype=torch.cfloat).cuda()
    return torch.stack([downchirp1 for _ in range(opts.batch_size)])


def _next_batch(loader, iterator):
    """Return (batch, iterator), (re)starting *loader* when *iterator* is None or exhausted."""
    if iterator is not None:
        try:
            return next(iterator), iterator
        except StopIteration:
            pass
    iterator = iter(loader)
    try:
        return next(iterator), iterator
    except StopIteration:
        raise RuntimeError('data loader yielded no batches') from None


def _evaluate(testing_dataloader, iteration, downchirp, mask_CNN, C_XtoY, opts):
    """Evaluate on up to opts.max_test_iters test batches, print/log, and return accuracy.

    Returns None (and logs nothing) when ``opts.max_test_iters < 1``. A fresh
    test iterator is used per evaluation, so consecutive batches are distinct
    rather than the loader's first batch repeated.
    """
    from datetime import datetime

    mask_CNN.eval()
    if opts.cxtoy == 'True':
        C_XtoY.eval()
    acc_sum = 0
    image_loss_sum = 0
    class_loss_sum = 0
    count = 0
    test_iter = None
    with torch.no_grad():
        while count < opts.max_test_iters:
            # work() keys its sample-saving on opts.iteration.
            opts.iteration = iteration + count
            (images_X, labels_X, images_Y), test_iter = _next_batch(testing_dataloader, test_iter)
            image_loss, class_loss, acc = work(images_X, labels_X, images_Y, opts, downchirp, mask_CNN, C_XtoY)
            image_loss_sum += image_loss
            class_loss_sum += class_loss
            acc_sum += acc
            count += 1
    if count == 0:
        return None
    accuracy = (acc_sum / count).item()
    print('TEST: ACC:', accuracy, '[' + str(acc_sum.item()) + '/' + str(count) + ']',
          'ILOSS:', "{:6.3f}".format(image_loss_sum / count),
          'CLOSS:', "{:6.3f}".format(class_loss_sum / count))
    with open(opts.logfile2, 'a') as f:
        f.write('\n' + str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + ' , ' + "{:6d}".format(iteration) + ' , ' + "{:6.3f}".format(accuracy))
    with open(opts.logfile, 'a') as f:
        f.write(' , ' + "{:6d}".format(iteration) + ' , ' + "{:6.3f}".format(accuracy))
    return accuracy


def training_loop(training_dataloader, testing_dataloader, models, opts):
    """Train mask_CNN (and C_XtoY when opts.cxtoy == 'True') and return them.

    Runs iterations ``opts.init_train_iter + 1`` through
    ``opts.init_train_iter + opts.train_iters``, i.e. exactly
    ``opts.train_iters`` optimizer steps (the old ``<=`` bound ran one extra).
    Training batches come from a persistent iterator that restarts when
    exhausted, instead of rebuilding the iterator for every batch.

    * Every ``opts.log_step`` iterations: print running averages and append
      them to ``opts.logfile``.
    * Every ``opts.checkpoint_every`` iterations: save weights via ``checkpoint``.
    * When ``(iteration - opts.init_train_iter) % opts.test_step == 1``:
      evaluate; stop early once test accuracy >= ``opts.terminate_acc``.

    Sets ``opts.loss_spec``, ``opts.loss_class`` and ``opts.iteration``
    (read by ``work``/``work2``).

    Returns:
        ``[mask_CNN, C_XtoY]``; ``C_XtoY`` is None unless ``opts.cxtoy == 'True'``.
    """
    # datetime is not imported at module level here; import it explicitly.
    from datetime import datetime

    mask_CNN = models[0]
    C_XtoY = models[1] if opts.cxtoy == 'True' else None
    opts.loss_spec = torch.nn.MSELoss(reduction='mean')
    opts.loss_class = torch.nn.CrossEntropyLoss()
    opts.iteration = 0
    downchirp = _make_downchirp(opts)

    g_params = list(mask_CNN.parameters())
    if opts.cxtoy == 'True':
        g_params += list(C_XtoY.parameters())
    g_optimizer = torch.optim.Adam(g_params, opts.lr, [opts.beta1, opts.beta2])

    G_Y_loss_avg = 0
    G_Image_loss_avg = 0
    G_Class_loss_avg = 0
    G_Acc_avg = 0
    iteration = opts.init_train_iter
    last_iteration = opts.init_train_iter + opts.train_iters
    oldtime = time.time()
    train_iter = None
    header = ' CURRENT TIME ITER YLOSS ILOSS CLOSS ACC TIME ----TRAINING'
    print(header, opts.lr, '----')
    while iteration < last_iteration:
        iteration += 1
        mask_CNN.train()
        if opts.cxtoy == 'True':
            C_XtoY.train()
        (images_X, labels_X, images_Y), train_iter = _next_batch(training_dataloader, train_iter)
        g_optimizer.zero_grad()
        # NOTE(review): opts.iteration is only advanced inside _evaluate, so
        # during training work() sees the last test value; if
        # opts.max_test_iters % opts.test_step == 1 it saves a sample on
        # every training step. Confirm the intended trigger.
        G_Image_loss, G_Class_loss, G_Acc = work(images_X, labels_X, images_Y, opts, downchirp, mask_CNN, C_XtoY)
        G_Y_loss = G_Image_loss + G_Class_loss
        G_Y_loss.backward()
        g_optimizer.step()

        G_Y_loss_avg += G_Y_loss.item()
        G_Image_loss_avg += G_Image_loss.item()
        G_Class_loss_avg += G_Class_loss.item()
        G_Acc_avg += G_Acc
        if iteration % opts.log_step == 0:
            output_lst = ["{:6d}".format(iteration),
                          "{:6.3f}".format(G_Y_loss_avg / opts.log_step),
                          "{:6.3f}".format(G_Image_loss_avg / opts.log_step),
                          "{:6.3f}".format(G_Class_loss_avg / opts.log_step),
                          "{:6.3f}".format(G_Acc_avg / opts.log_step),
                          "{:6.3f}".format(time.time() - oldtime)]
            output_str = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + ' ' + ' '.join(output_lst)
            oldtime = time.time()
            print(output_str)
            with open(opts.logfile, 'a') as f:
                f.write('\n' + output_str)
            G_Y_loss_avg = 0
            G_Image_loss_avg = 0
            G_Class_loss_avg = 0
            G_Acc_avg = 0

        if iteration % opts.checkpoint_every == 0:
            checkpoint(iteration, models, opts)

        if (iteration - opts.init_train_iter) % opts.test_step == 1:
            accuracy = _evaluate(testing_dataloader, iteration, downchirp, mask_CNN, C_XtoY, opts)
            if accuracy is not None and accuracy >= opts.terminate_acc:
                print('REACHED', opts.terminate_acc, 'ACC, TERMINATINg...')
                break
            print(header, opts.lr, '----')
    return [mask_CNN, C_XtoY]