-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdigits.py
131 lines (107 loc) · 4.17 KB
/
digits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import math
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.datasets import load_svmlight_file
import torch
from mmd_critic import Dataset, select_prototypes, select_criticisms
cwd = Path('.')
data_dir = cwd / 'data'
output_dir = cwd / 'output'
data_dir.mkdir(exist_ok=True)
output_dir.mkdir(exist_ok=True)
gamma = 0.026
num_prototypes = 32
num_criticisms = 10
kernel_type = 'local'
# kernel_type = 'global'
# regularizer = None
regularizer = 'logdet'
# regularizer = 'iterative'
make_plots = True
print('==============')
print(f'data_dir:{data_dir.absolute()}')
print(f'output_dir:{output_dir.absolute()}')
print(f'num_prototypes:{num_prototypes}')
print(f'num_criticisms:{num_criticisms}')
print(f'gamma:{gamma}')
print(f'kernel_type:{kernel_type}')
print(f'regularizer:{regularizer}')
print(f'make_plots:{make_plots}')
print('==============\n')
# torch.set_num_threads(64)
# Data setup
def load_data(fname):
X, y = load_svmlight_file(str(data_dir / fname))
X = torch.tensor(X.todense(), dtype=torch.float)
y = torch.tensor(y, dtype=torch.long)
sort_indices = y.argsort() # torch.argsort does not match np.argsort
# sort_indices = y.numpy().argsort() # argsort is not stable if quicksort is used
# print(sort_indices)
X = X[sort_indices, :]
y = y[sort_indices]
return X, y
print('Preparing data...', end='', flush=True)
X_train, y_train = load_data('usps')
X_test, y_test = load_data('usps.t')
y_train -= 1
y_test -= 1
d_train = Dataset(X_train, y_train)
if kernel_type == 'global':
d_train.compute_rbf_kernel(gamma)
elif kernel_type == 'local':
d_train.compute_local_rbf_kernel(gamma)
else:
raise KeyError('kernel_type must be either "global" or "local"')
print('Done.', flush=True)
# Prototypes
if num_prototypes > 0:
print('Computing prototypes...', end='', flush=True)
prototype_indices = select_prototypes(d_train.K, num_prototypes)
prototypes = d_train.X[prototype_indices]
prototype_labels = d_train.y[prototype_indices]
sorted_by_y_indices = prototype_labels.argsort()
prototypes_sorted = prototypes[sorted_by_y_indices]
prototype_labels = prototype_labels[sorted_by_y_indices]
print('Done.', flush=True)
print(prototype_indices.sort()[0].tolist())
# Visualize
if make_plots:
print('Plotting prototypes...', end='', flush=True)
num_cols = 8
num_rows = math.ceil(num_prototypes / num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(6, num_rows * 0.75))
for i, axis in enumerate(axes.ravel()):
if i >= num_prototypes:
axis.axis('off')
continue
axis.imshow(prototypes_sorted[i].view(16,16).numpy(), cmap='gray')
axis.axis('off')
fig.suptitle(f'{num_prototypes} Prototypes')
plt.savefig(output_dir / f'{num_prototypes}_prototypes_digits.svg')
print('Done.', flush=True)
# Criticisms
if num_criticisms > 0:
print('Computing criticisms...', end='', flush=True)
criticism_indices = select_criticisms(d_train.K, prototype_indices, num_criticisms, regularizer)
criticisms = d_train.X[criticism_indices]
criticism_labels = d_train.y[criticism_indices]
sorted_by_y_indices = criticism_labels.argsort()
criticisms_sorted = criticisms[sorted_by_y_indices]
criticism_labels = criticism_labels[sorted_by_y_indices]
print('Done.', flush=True)
print(criticism_indices.sort()[0].tolist())
# Visualize
if make_plots:
print('Plotting criticisms...', end='', flush=True)
num_cols = 8
num_rows = math.ceil(num_criticisms / num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(6, num_rows * 0.75))
for i, axis in enumerate(axes.ravel()):
if i >= num_criticisms:
axis.axis('off')
continue
axis.imshow(criticisms_sorted[i].view(16,16).numpy(), cmap='gray')
axis.axis('off')
fig.suptitle(f'{num_criticisms} Criticisms')
plt.savefig(output_dir / f'{num_criticisms}_criticisms_digits.svg')
print('Done.', flush=True)