from itertools import chain
from time import time
import logging

import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

from src.discriminator.type1 import Discriminator as Discriminator1
from src.discriminator.type2 import Discriminator as Discriminator2
from src.generator.generator import Generator
from src.data import get_real_samples
from src.settings import g_settings as gs
from src.settings import d_settings as ds
from src.settings import t_settings as ts
from src.settings import data_settings as das
from src.settings import settings_init

'''
The plan:
1. Train the discriminator to distinguish real data from QCBM output
   (starting with the easiest thing to train, i.e., feeding it either a quantum
   state with amplitudes given by the square roots of the probabilities of a
   Gaussian, or a quantum state produced by the QCBM, and training the
   discriminator to tell the difference).
2. Train the QCBM to fool the discriminator
   (i.e., minimize the loss function for "the discriminator says the fake,
   i.e., QCBM-generated, data is fake").
3. Repeat, alternating the two steps above.
4. If it works, try harder to train encodings of the real/fake data.
'''
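

# QGAN couples the quantum generator (a QCBM) with one of two discriminator
# circuits (selected via ds.type) and trains them against each other.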
class QGAN(object):
    def __init__(self):
        self.d = Discriminator1() if ds.type == 1 else Discriminator2()
        self.g = Generator()

    # Returns a dataset of [[point0, point1, ..., point7], ...] and a list of
    # labels: 0 (fake) and 1 (real).
    def generate_dataset(self):
        f = self.g.gen_synthetics(das.synthetic_size)
        r = get_real_samples(das.real_size)
        labels = [0] * das.synthetic_size + [1] * das.real_size
        return list(chain(f, r)), labels
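
    # One discriminator round (step 1 of the plan): draw a fresh mixed batch of
    # synthetic and real samples and fit the discriminator on it.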
    def _train_discriminator(self):
        dataset, labels = self.generate_dataset()
        logging.info('New training set generated')
        if ts.print_accuracy:
            logging.info(f'Discriminator mean squared error pre: {self.d.test(dataset, labels)}')
        d_start_time = time()
        self.d.train(dataset, labels)
        d_end_time = time()
        logging.info(f'Discriminator training completed in {round(d_end_time - d_start_time, 2)} seconds')
        if ts.print_accuracy:
            logging.info(f'Discriminator mean squared error post: {self.d.test(dataset, labels)}')
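
    # One generator round (step 2 of the plan): train the generator against the
    # current discriminator so that its samples get classified as real.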
    def _train_generator(self):
        if ts.print_accuracy:
            logging.info(f'Generator mean squared error pre: {self.g.test(self.d)}')
        g_start_time = time()
        self.g.train(self.d)
        g_end_time = time()
        logging.info(f'Generator training completed in {round(g_end_time - g_start_time, 2)} seconds')
        if ts.print_accuracy:
            diff = 0.0
            for _ in range(20):
                diff += mean_squared_error(next(self.g.gen_synthetics(1)), next(get_real_samples(1)))
            diff /= 20.0
            logging.info(f'Average diff between generator output and true distribution: {diff}')
            logging.info(f'Generator mean squared error post: {self.g.test(self.d)}')
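
    # Full adversarial loop (step 3 of the plan): alternate discriminator and
    # generator training for ts.repeats rounds, saving a figure after each one.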
    def train(self):
        generator_samples = []
        total_start_time = time()
        for idx in range(ts.repeats):
            logging.info(f'Starting training iteration {idx}')
            it_start_time = time()
            self._train_discriminator()
            # Skip generator training in the final round if the discriminator
            # is configured to train last (ts.dend).
            if not (ts.dend and idx + 1 == ts.repeats):
                self._train_generator()
            it_end_time = time()
            logging.info(f'COMPLETED in {round(it_end_time - it_start_time, 2)} seconds')
            # Save an intermediate figure comparing the data to this round's
            # generator output.
            generator_samples.append(next(self.g.gen_synthetics(1)))
            plt.plot(next(get_real_samples(1)))
            plt.plot(generator_samples[-1])
            plt.legend(['Data', f'gen({idx})'])
            plt.savefig(f'iter_{idx}.pdf')
            plt.clf()
        total_end_time = time()
        logging.info(f'FINISHED in {round(total_end_time - total_start_time, 2)} seconds')
        # Final figure: the data plus every intermediate generator sample.
        plt.plot(next(get_real_samples(1)))
        legend = ['Data']
        for idx, sample in enumerate(generator_samples):
            plt.plot(sample)
            legend.append(f'gen({idx})')
        plt.legend(legend)
        # Save before showing: plt.show() may close the figure on return.
        plt.savefig('iter_final.pdf')
        if ts.show_figs:
            plt.show()
        plt.clf()
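
    # Plot the true distribution against a few final generator samples.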
    def _test_generator(self):
        print(f'Final generator squared error (can be high if the discriminator is trained well): {self.g.test(self.d)}')
        logging.getLogger('matplotlib').setLevel(logging.WARNING)  # silence matplotlib's debug output
        plt.plot(next(get_real_samples(1)))
        for dist in self.g.gen_synthetics(4):
            plt.plot(dist)
        legend = ['Data']
        legend.extend(f'gen{x}' for x in range(4))
        plt.legend(legend)
        # Save before showing: plt.show() may close the figure on return.
        plt.savefig('gen.pdf')
        if ts.show_figs:
            plt.show()
        plt.clf()
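
    # Evaluate the trained discriminator on a fresh dataset and log the
    # confusion-matrix counts.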
    def _test_discriminator(self):
        dataset, labels = self.generate_dataset()
        accuracy, details = self.d.test2(dataset, labels)
        logging.info(f'Final discriminator mean squared error (on test): {accuracy}')
        TP, FP, TN, FN = details
        logging.info(f'TP: {TP}')
        logging.info(f'FP: {FP}')
        logging.info(f'TN: {TN}')
        logging.info(f'FN: {FN}')
        logging.info(f'Total: {das.synthetic_size + das.real_size}')

    def test(self):
        self._test_generator()
        self._test_discriminator()
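

# Log every relevant hyperparameter before training starts.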
def parameter_prelude():
    logging.info(f'''
Training network for {ts.repeats} repeats, using
Generator:
    initial param type {gs.paramtype.name}
    training type {gs.trainingtype.name}
    using {gs.num_qubits} qubits
    depth {gs.depth} ({2 * gs.depth * gs.num_qubits} params to optimize)
    shots {gs.n_shots} (used to estimate probabilities on hardware if > 0)
    at most {gs.max_iter} iterations per repeat
    learning step rate {gs.step_rate} (used only if training type is ADAM)
Discriminator:
    initial param type {ds.paramtype.name}
    training type {ds.trainingtype.name}
    type {ds.type} (using {ds.num_qubits} qubits)
    depth {ds.depth} ({2 * ds.depth * ds.num_qubits} params to optimize)
    shots {ds.n_shots} (used to estimate probabilities on hardware if > 0)
    at most {ds.max_iter} iterations per repeat
Distribution:
    mu {das.mu}
    sigma {das.sigma}
    batch size {das.batch_size} (higher means better log-normal estimation)
Data:
    discriminator training set size {das.synthetic_size + das.real_size} ({das.synthetic_size} synthetic, {das.real_size} real)
    generator training set size {das.gen_size}
Training:
    repeats {ts.repeats}
    printing accuracy {ts.print_accuracy} (more info during training for a small slowdown)
    showing figures is set to {ts.show_figs}
    last one to train is {"discriminator" if ts.dend else "generator"}
''')


def main():
    parameter_prelude()
    qgan = QGAN()
    qgan.train()
    qgan.test()


if __name__ == '__main__':
    default_excited_state = False
    settings_init()
    main()