-
Notifications
You must be signed in to change notification settings - Fork 1
/
classifier_inferencer.py
153 lines (140 loc) · 8.42 KB
/
classifier_inferencer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import argparse
import models
import datasets
import json
import torch
import numpy as np
import sys
import types
from pca_layers import change_all_pca_layer_thresholds, change_all_pca_layer_thresholds_and_inject_random_directions
from pca_layers import change_all_pca_layer_centering
import pandas as pd
from thop import profile
from trainer import Trainer
from time import time
# Command-line interface of the script.
# Every option relies on argparse's default 'store' action. Options declared
# without ``type`` keep their Python default when omitted and arrive as ``str``
# when given on the command line.
parser = argparse.ArgumentParser(description='Train a network on a dataset')

# Which network / dataset factory to look up by name.
parser.add_argument('-n', '--network', default='vgg11', dest='model_name')
parser.add_argument('-d', '--dataset', default='Cifar10', dest='dataset_name')

# No ``type`` here: parse_dataset() casts the value with int() itself.
parser.add_argument('-b', '--batch-size', default=32, dest='batch_size')

# Where results go and on which device to run (both forwarded to Trainer).
parser.add_argument('-o', '--output', default='logs', dest='output')
parser.add_argument('-c', '--compute-device', default='cpu', dest='device')
parser.add_argument('-r', '--run_id', default=0, dest='run_id')

# Optional JSON experiment schedule; when absent a single manual run is done.
parser.add_argument('-cf', '--config', default=None, dest='json_file')

# Forwarded to Trainer as ``saturation_device`` in scheduled runs.
parser.add_argument('-cs', '--saturation-device', default=None, dest='sat_device', type=str)
def parse_model(model_name, shape, num_classes):
    """Build the network registered under ``model_name`` in the ``models`` module.

    Args:
        model_name: Name of a model factory found in ``models.__dict__``.
        shape: Passed to the factory as ``input_size``.
        num_classes: Passed to the factory as ``num_classes``.

    Returns:
        Whatever the selected factory returns.

    Raises:
        NameError: If ``models`` has no entry called ``model_name``.
    """
    # Guard only the registry lookup. A KeyError raised while the factory is
    # building the network is a genuine error and must not be reported as
    # "model doesn't exist".
    try:
        factory = models.__dict__[model_name]
    except KeyError:
        raise NameError("%s doesn't exist." % model_name) from None
    return factory(input_size=shape, num_classes=num_classes)
def parse_dataset(dataset_name, batch_size):
    """Create the data loaders registered under ``dataset_name`` in ``datasets``.

    Args:
        dataset_name: Name of a dataset factory found in ``datasets.__dict__``.
        batch_size: Batch size; converted with ``int()`` so that string values
            coming from the command line are accepted.

    Returns:
        The 4-tuple ``(train_loader, test_loader, shape, num_classes)``
        produced by the selected factory.

    Raises:
        NameError: If ``datasets`` has no entry called ``dataset_name``.
        ValueError: If ``batch_size`` cannot be converted to ``int``.
    """
    batch_size = int(batch_size)
    # Guard only the registry lookup. A KeyError raised while the factory is
    # preparing the data is a genuine error and must not be reported as
    # "dataset doesn't exist".
    try:
        factory = datasets.__dict__[dataset_name]
    except KeyError:
        raise NameError("%s doesn't exist." % dataset_name) from None
    train_loader, test_loader, shape, num_classes = factory(batch_size=batch_size)
    return train_loader, test_loader, shape, num_classes
if __name__ == '__main__':
    # Script entry point.
    #   * Without --config: build one model/dataset pair and call trainer.train().
    #   * With --config: for every configured experiment, sweep a list of
    #     PCA-layer thresholds, call trainer.test() for each of them and write
    #     the collected numbers to CSV files.
    args = parser.parse_args()

    # Column accumulators for the summary CSV; each evaluated
    # (experiment, inference threshold) pair appends one entry to every list.
    model_names = []
    accs = []
    losses = []
    inference_thresholds = []
    dims = []
    fdims = []
    sat_avg = []
    datasets_csv = []
    downsamplings = []

    if args.json_file is None:
        print('Starting manual run')
        train_loader, test_loader, shape, num_classes = parse_dataset(args.dataset_name, args.batch_size)
        model = parse_model(args.model_name, shape, num_classes)
        trainer = Trainer(model, train_loader, test_loader, logs_dir=args.output, device=args.device, run_id=args.run_id)
        trainer.train()
    else:
        print('Automatized experiment schedule enabled using', args.json_file)
        # Use a context manager so the config file handle is closed right away
        # instead of being left to the garbage collector.
        with open(args.json_file, 'r') as config_file:
            config_dict = json.load(config_file)

        # Optional keys fall back to defaults; 'dataset' may be a single name
        # or a list of names.
        thresholds = config_dict.get('threshs', [.99])
        dss = config_dict['dataset']
        if not isinstance(dss, list):
            dss = [dss]
        # A run without downsampling (None) is always appended to the sweep.
        downsampling = config_dict.get('downsampling', []) + [None]
        optimizer = config_dict['optimizer']
        conv_method = config_dict.get('conv_method', 'channelwise')

        run_num = 0
        print(thresholds)

        # Total number of experiments, used for the progress message only.
        total_runs = (len(config_dict['batch_sizes'])
                      * len(config_dict['models'])
                      * len(thresholds)
                      * len(dss)
                      * len(downsampling))

        # Inference thresholds applied to every experiment (iterated in
        # reverse order below).
        # NOTE(review): 0.999 is listed twice, so that threshold is evaluated
        # twice and produces two rows in the summary CSV. Kept as-is to
        # preserve the output; confirm and drop the duplicate if unintended.
        eval_thresholds = [0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99,
                           0.992, 0.994, 0.996, 0.998, 0.999,
                           0.999, 0.9991, 0.9992, 0.9993, 0.9994, 0.9995,
                           0.9996, 0.9997, 0.9998, 0.9999, 3.0]

        for dataset in dss:
            for thresh in thresholds:
                for batch_size in config_dict['batch_sizes']:
                    for dwnsmpl in downsampling:
                        for model_key in config_dict['models']:
                            run_num += 1
                            print('Running Experiment', run_num, 'of', total_runs)

                            # The loaders are always built with a batch size of
                            # 500; the configured batch_size is only forwarded
                            # to Trainer below.
                            train_loader, test_loader, shape, num_classes = parse_dataset(dataset, 500)
                            model = parse_model(model_key, shape, num_classes)
                            change_all_pca_layer_centering(centering=config_dict['centering'],
                                                           network=model,
                                                           verbose=False,
                                                           downsampling=dwnsmpl)
                            trainer = Trainer(model,
                                              train_loader,
                                              test_loader,
                                              logs_dir=args.output,
                                              device=args.device,
                                              run_id=args.run_id,
                                              epochs=config_dict['epochs'],
                                              batch_size=batch_size,
                                              optimizer=optimizer,
                                              plot=True,
                                              compute_top_k=(dataset == 'ImageNet'),
                                              # The original conditional yielded False on
                                              # both branches, so the value is always False.
                                              # The keyword spelling is kept because Trainer's
                                              # signature is not visible from this file.
                                              data_prallel=False,
                                              saturation_device=args.sat_device,
                                              conv_method=conv_method,
                                              thresh=thresh,
                                              downsampling=dwnsmpl)
                            trainer.stats.stop()
                            # NOTE(review): the explicit weight restore below is disabled
                            # and nothing else in this block loads a checkpoint, so the
                            # messages only make sense if Trainer restores the weights
                            # itself - confirm.
                            # model.load_state_dict(torch.load(trainer.savepath.replace('.csv', '.pt'))['model_state_dict'])
                            print('Loading model from', trainer.savepath)
                            print('Model loaded')

                            for eval_thresh in reversed(eval_thresholds):
                                sat, indims, fsdims, lnames = change_all_pca_layer_thresholds(eval_thresh, network=model)
                                start = time()
                                print('Changed model threshold to', eval_thresh)
                                acc, loss = trainer.test(False)
                                total_indims = sum(indims)
                                print('InDims:', total_indims, 'Acc:', acc, 'Loss:', loss, 'for', model.name, 'at threshold:', eval_thresh)

                                avg = np.mean(sat)
                                model_names.append(model.name)
                                accs.append(acc)
                                losses.append(loss)
                                dims.append(total_indims)
                                inference_thresholds.append(eval_thresh)
                                fdims.append(sum(fsdims))
                                sat_avg.append(avg)
                                datasets_csv.append(dataset)
                                downsamplings.append(dwnsmpl)

                                # Per-threshold file: one column per layer name plus
                                # the average and the loss.
                                layer_sats = {name: [lsat] for name, lsat in zip(lnames, sat)}
                                layer_sats['avg_sat'] = avg
                                layer_sats['loss'] = loss
                                pd.DataFrame.from_dict(layer_sats).to_csv(
                                    trainer.savepath.replace('.csv', '_if{}.csv'.format(eval_thresh)),
                                    sep=';')
                                end = time()
                                print('Took:', end - start)

                                # Summary of everything collected so far.
                                # NOTE(review): the nesting level of this write is an
                                # assumption. Rewriting the file after every threshold
                                # gives the same final content for a completed
                                # scheduled run as a single write after the loops -
                                # confirm the intended placement.
                                pd.DataFrame.from_dict({
                                    'dataset': datasets_csv,
                                    'loss': losses,
                                    'model': model_names,
                                    'accs': accs,
                                    'thresh': inference_thresholds,
                                    'intrinsic_dimensions': dims,
                                    'featurespace_dimension': fdims,
                                    'sat_avg': sat_avg,
                                    'downsampling': downsamplings
                                }).to_csv('resner18_result.csv', sep=';')