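"""
experiment.py (forked from ingitom99/learn_sink)

Training run: a generative net ('deer', GenNet) proposes pairs of
distributions while a predictive net ('puma', PredNet) is trained via
the_hunt (see src.train) to predict centred, log-domain Sinkhorn scaling
vectors for them. Test targets and Sinkhorn distances are precomputed on
samples from MNIST, CIFAR, LFW, bear, and quickdraw; results and experiment
metadata are saved to a timestamped folder under ./results.
"""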
# Imports
import datetime
import os
import torch
from tqdm import tqdm
from src.geometry import get_cost
from src.sinkhorn import sink_vec, sink
from src.train import the_hunt
from src.nets import GenNet, PredNet
from src.loss import hilb_proj_loss, mse_loss
from src.data_funcs import preprocessor, test_set_sampler
# Create 'stamp' folder for saving results
current_time = datetime.datetime.now()
formatted_time = current_time.strftime('%m-%d_%H_%M_%S')
stamp_folder_path = './results/experiment_'+formatted_time
os.mkdir(stamp_folder_path)
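# NOTE: os.mkdir does not create parent directories, so './results' must
# already exist; os.makedirs(stamp_folder_path, exist_ok=True) would be the
# more forgiving alternative.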
# Problem hyperparameters
length_prior = 10        # side length of prior distribution samples
length = 28              # side length of data samples
dim_prior = length_prior**2
dim = length**2
dust_const = 1e-6        # 'dust' constant added to measures (see preprocessor)
skip_const = 0.75        # skip connection constant for GenNet
width_gen = 4 * dim      # hidden layer width of gen net
width_pred = 4 * dim     # hidden layer width of pred net
# Training hyperparameters
n_loops = 10                 # outer training loops
n_mini_loops_gen = 1         # gen net mini loops per loop
n_mini_loops_pred = 1        # pred net mini loops per loop
n_batch = 200                # batch size per step
layer_weights_normed = False
loss_gen_reg_coeff = 10.0    # regularization coefficient on the gen loss
weight_decay_gen = 0.0
weight_decay_pred = 1e-4
lr_gen = 0.1
lr_pred = 0.1
lr_fact_gen = 1.0            # learning rate scale factor (1.0 = constant)
lr_fact_pred = 1.0
learn_gen = True             # whether the gen net is trained
bootstrapped = True          # bootstrap training targets
n_boot = 40                  # no. bootstraps (see the_hunt)
test_iter = 5                # test evaluation interval (in loops)
n_test = 3                   # no. samples per test set
plot_test_images = True
display_test_info = True
checkpoint_iter = n_loops    # checkpoint once, at the final loop
niter_warmstart = 200        # Sinkhorn iterations for warmstart
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
# Initialization of cost matrix
cost = get_cost(length).double().to(device)
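# get_cost presumably builds the (dim x dim) ground cost between pixel
# locations of the length x length grid (see src.geometry.get_cost).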
# Regularization parameter
eps = 1e-2
print(f'Entropic regularization param: {eps}')
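# Smaller eps means weaker entropic smoothing (closer to unregularized OT)
# but slower Sinkhorn convergence, hence the fixed iteration counts below.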
# Loading, preprocessing, and sampling for the test sets dictionary
with torch.no_grad():
    mnist = torch.load('./data/mnist.pt')
    cifar = torch.load('./data/cifar.pt')
    lfw = torch.load('./data/lfw.pt')
    bear = torch.load('./data/bear.pt')
    quickdraw = torch.load('./data/quickdraw.pt')
    mnist = preprocessor(mnist, length, dust_const)
    cifar = preprocessor(cifar, length, dust_const)
    lfw = preprocessor(lfw, length, dust_const)
    bear = preprocessor(bear, length, dust_const)
    quickdraw = preprocessor(quickdraw, length, dust_const)
    mnist = test_set_sampler(mnist, n_test).double().to(device)
    cifar = test_set_sampler(cifar, n_test).double().to(device)
    lfw = test_set_sampler(lfw, n_test).double().to(device)
    bear = test_set_sampler(bear, n_test).double().to(device)
    quickdraw = test_set_sampler(quickdraw, n_test).double().to(device)
test_sets = {'mnist': mnist, 'cifar': cifar, 'bear': bear,
             'quickdraw': quickdraw, 'lfw': lfw}
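# Each test set row concatenates a pair of measures: x[:dim] is the source
# (mu) and x[dim:] the target (nu), so each set has shape (n_test, 2 * dim).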
# Create dictionaries of test Sinkhorn distances and test targets per test set
test_sinks = {}
test_T = {}
print('Computing test Sinkhorn distances and targets...')
for key in test_sets.keys():
    print(f'{key}:')
    with torch.no_grad():
        X = test_sets[key]
        V0 = torch.ones_like(X[:, :dim])
        V = sink_vec(X[:, :dim], X[:, dim:], cost, eps, V0, 2000)[1]
        V = torch.log(V)
        # Centre the log scaling vectors: Sinkhorn potentials are only
        # defined up to an additive constant, so removing the per-sample
        # mean gives a well-defined regression target.
        T = V - torch.unsqueeze(V.mean(dim=1), 1).repeat(1, dim)
        test_T[key] = T
        sinks = []
        for x in tqdm(X):
            mu = x[:dim] / x[:dim].sum()
            nu = x[dim:] / x[dim:].sum()
            v0 = torch.ones_like(mu)
            _, _, _, sink_dist = sink(mu, nu, cost, eps, v0, 1000)
            sinks.append(sink_dist)
        sinks = torch.tensor(sinks)
        test_sinks[key] = sinks
# Initialization of loss function
loss_func = hilb_proj_loss
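# For reference, a minimal sketch of a Hilbert projective metric loss on
# batched log-potentials, assuming src.loss.hilb_proj_loss follows the
# standard definition (the actual implementation may differ):
#
#   def hilb_proj_loss(U, V):
#       diff = U - V
#       return (diff.max(dim=1)[0] - diff.min(dim=1)[0]).mean()
#
# mse_loss is imported above as a drop-in alternative.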
# Initialization of nets
deer = GenNet(dim_prior, dim, width_gen, dust_const,
              skip_const).double().to(device)
puma = PredNet(dim, width_pred).double().to(device)
# No. layers in each net
n_layers_gen = len(deer.layers)
n_layers_pred = len(puma.layers)
# Load model state dict
#deer.load_state_dict(torch.load(f'{stamp_folder_path}/deer.pt'))
#puma.load_state_dict(torch.load(f'{stamp_folder_path}/puma.pt'))
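# NOTE: as written these paths point at the freshly created (empty) stamp
# folder; to resume a run, point them at a previous run's folder instead.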
# Training mode
deer.train()
puma.train()
# Get total number of trainable parameters
n_params_gen = sum(p.numel() for p in deer.parameters() if p.requires_grad)
n_params_pred = sum(p.numel() for p in puma.parameters() if p.requires_grad)
print(f'No. trainable parameters in gen net: {n_params_gen}')
print(f'No. trainable parameters in pred net: {n_params_pred}')
# Record experiment_info in a txt file in the stamp folder
current_date = datetime.datetime.now().strftime('%d.%m.%Y')
experiment_info = {
    'date': current_date,
    'prior distribution length': length_prior,
    'data length': length,
    'prior distribution dimension': dim_prior,
    'data dimension': dim,
    'regularization parameter': eps,
    'dust constant': dust_const,
    'skip connection constant': skip_const,
    'no. layers gen': n_layers_gen,
    'no. layers pred': n_layers_pred,
    'hidden layer width gen': width_gen,
    'hidden layer width pred': width_pred,
    'total no. trainable parameters gen': n_params_gen,
    'total no. trainable parameters pred': n_params_pred,
    'device': device,
    'gen net weight decay factor': weight_decay_gen,
    'pred net weight decay factor': weight_decay_pred,
    'gen net learning rate': lr_gen,
    'pred net learning rate': lr_pred,
    'learning rate scale factor gen': lr_fact_gen,
    'learning rate scale factor pred': lr_fact_pred,
    'no. unique data points gen': n_loops*n_mini_loops_gen*n_batch,
    'no. unique data points pred': n_loops*n_mini_loops_pred*n_batch,
    'no. loops': n_loops,
    'no. mini loops gen': n_mini_loops_gen,
    'no. mini loops pred': n_mini_loops_pred,
    'batch size per step gen': n_batch,
    'batch size per step pred': n_batch,
    'test_iter': test_iter,
    'no. test samples': n_test,
    'learn gen?': learn_gen,
    'bootstrapped?': bootstrapped,
    'no. bootstraps': n_boot,
    'checkpoint': checkpoint_iter,
    'warmstart iterations': niter_warmstart,
}
# Print experiment_info
for key, value in experiment_info.items():
    print(f'{key}: {value}')
# Define the output file path
output_file = f'{stamp_folder_path}/experiment_info.txt'
# Save the experiment_info to the text file
with open(output_file, 'w', encoding='utf-8') as file:
    for key, value in experiment_info.items():
        file.write(f'{key}: {value}\n')
# Run the hunt
results = the_hunt(
    deer,
    puma,
    loss_func,
    loss_gen_reg_coeff,
    layer_weights_normed,
    cost,
    eps,
    dust_const,
    dim_prior,
    dim,
    device,
    test_sets,
    test_sinks,
    test_T,
    n_loops,
    n_mini_loops_gen,
    n_mini_loops_pred,
    n_batch,
    weight_decay_gen,
    weight_decay_pred,
    lr_pred,
    lr_gen,
    lr_fact_gen,
    lr_fact_pred,
    learn_gen,
    bootstrapped,
    n_boot,
    test_iter,
    plot_test_images,
    display_test_info,
    stamp_folder_path,
    checkpoint_iter,
    niter_warmstart
)
print('The hunt is over. Time to rest.')