#!/usr/bin/env python
"""DIVERSITY SAMPLING
Diversity Sampling examples for Active Learning in PyTorch
This is an open source example to accompany Chapter 4 from the book:
"Human-in-the-Loop Machine Learning"
It contains four Active Learning strategies:
1. Model-based outlier sampling
2. Cluster-based sampling
3. Representative sampling
4. Adaptive Representative sampling
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import math
import datetime
import csv
import re
import os
import sys
from random import shuffle
from collections import defaultdict
from uncertainty_sampling import UncertaintySampling
from pytorch_clusters import CosineClusters
from pytorch_clusters import Cluster
if sys.argv[0] == "diversity_sampling.py":
import active_learning
__author__ = "Robert Munro"
__license__ = "MIT"
__version__ = "1.0.1"
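
# Data items throughout this file are lists assumed to follow the layout
#   [textid, text, label, sampling_method, score]
# (an assumption documented for readability: the methods below read item[0]
# and item[1] and write their strategy name and score to item[3] and item[4])
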
class DiversitySampling():
def __init__(self, verbose=False):
self.verbose = verbose
def get_cluster_samples(self, data, num_clusters=5, max_epochs=5, limit=5000):
"""Create clusters using cosine similarity
Keyword arguments:
data -- data to be clustered
num_clusters -- the number of clusters to create
max_epochs -- maximum number of epochs to create clusters
limit -- sample only this many items for faster clustering (-1 = no limit)
        Creates clusters with the k-means clustering algorithm,
        using cosine similarity instead of the more common Euclidean distance

        Creates clusters until converged or until max_epochs passes over the data
"""
if limit > 0:
shuffle(data)
data = data[:limit]
cosine_clusters = CosineClusters(num_clusters)
cosine_clusters.add_random_training_items(data)
for i in range(0, max_epochs):
print("Epoch "+str(i))
added = cosine_clusters.add_items_to_best_cluster(data)
if added == 0:
break
centroids = cosine_clusters.get_centroids()
outliers = cosine_clusters.get_outliers()
randoms = cosine_clusters.get_randoms(3, self.verbose)
return centroids + outliers + randoms
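
    # Usage sketch (hedged; assumes `unlabeled_data` follows the item layout
    # noted at the top of this file):
    #
    #   sampler = DiversitySampling(verbose=True)
    #   sampled = sampler.get_cluster_samples(unlabeled_data, num_clusters=10)
    #
    # Each cluster contributes its centroid, its outlier, and 3 random
    # members, so expect roughly 5 * num_clusters items back.
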
def get_representative_samples(self, training_data, unlabeled_data, number=20, limit=10000):
"""Gets the most representative unlabeled items, compared to training data
Keyword arguments:
training_data -- data with a label, that the current model is trained on
unlabeled_data -- data that does not yet have a label
number -- number of items to sample
limit -- sample from only this many items for faster sampling (-1 = no limit)
Creates one cluster for each data set: training and unlabeled
"""
if limit > 0:
shuffle(training_data)
training_data = training_data[:limit]
shuffle(unlabeled_data)
unlabeled_data = unlabeled_data[:limit]
training_cluster = Cluster()
for item in training_data:
training_cluster.add_to_cluster(item)
unlabeled_cluster = Cluster()
for item in unlabeled_data:
unlabeled_cluster.add_to_cluster(item)
        for item in unlabeled_data:
            # "cosine_similary" [sic] matches the method name defined in pytorch_clusters
            training_score = training_cluster.cosine_similary(item)
            unlabeled_score = unlabeled_cluster.cosine_similary(item)

            # an item is representative when it is more similar to the pool of
            # unlabeled data than to the data the model was already trained on
            representativeness = unlabeled_score - training_score

            item[3] = "representative"
            item[4] = representativeness
unlabeled_data.sort(reverse=True, key=lambda x: x[4])
        return unlabeled_data[:number]
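
    # Usage sketch (hedged): returned items carry their score in item[4],
    # so the most representative item can be inspected directly:
    #
    #   reps = sampler.get_representative_samples(training_data, unlabeled_data)
    #   print(reps[0][1], reps[0][4])  # text and representativeness score
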
def get_adaptive_representative_samples(self, training_data, unlabeled_data, number=20, limit=5000):
"""Adaptively gets the most representative unlabeled items, compared to training data
Keyword arguments:
training_data -- data with a label, that the current model is trained on
unlabeled_data -- data that does not yet have a label
number -- number of items to sample
limit -- sample from only this many items for faster sampling (-1 = no limit)
Adaptive variant of get_representative_samples() where the training_data is updated
after each individual selection in order to increase diversity of samples
"""
samples = []
for i in range(0, number):
print("Epoch "+str(i))
representative_item = self.get_representative_samples(training_data, unlabeled_data, 1, limit)[0]
samples.append(representative_item)
unlabeled_data.remove(representative_item)
return samples
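
    # Note (hedged): each of the `number` selections re-clusters up to `limit`
    # items, so this costs roughly `number` times one call to
    # get_representative_samples(); it also mutates unlabeled_data by removing
    # each selected item before the next pass.
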
def get_validation_rankings(self, model, validation_data, feature_method):
"""Get model outliers from unlabeled data
Keyword arguments:
model -- current Machine Learning model for this task
unlabeled_data -- data that does not yet have a label
validation_data -- held out data drawn from the same distribution as the training data
feature_method -- the method to create features from the raw text
number -- number of items to sample
limit -- sample from only this many items for faster sampling (-1 = no limit)
An outlier is defined as
unlabeled_data with the lowest average from rank order of logits
where rank order is defined by validation data inference
"""
        validation_rankings = [] # 2D array: one ordered list of logit scores per neuron, across all validation items
# Get per-neuron scores from validation data
if self.verbose:
print("Getting neuron activation scores from validation data")
with torch.no_grad():
v=0
for item in validation_data:
textid = item[0]
text = item[1]
feature_vector = feature_method(text)
hidden, logits, log_probs = model(feature_vector, return_all_layers=True)
neuron_outputs = logits.data.tolist()[0] #logits
# initialize array if we haven't yet
if len(validation_rankings) == 0:
for output in neuron_outputs:
validation_rankings.append([0.0] * len(validation_data))
                for n, output in enumerate(neuron_outputs):
                    validation_rankings[n][v] = output
v += 1
        # Rank-order the validation scores per neuron (list.sort() is in-place)
        for neuron_scores in validation_rankings:
            neuron_scores.sort()
return validation_rankings
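
    # Shape sketch (hedged): for a model with 2 output neurons and 3
    # validation items whose logits were [1.0, -0.5], [0.2, 0.1], [0.6, 0.9],
    # validation_rankings would be [[0.2, 0.6, 1.0], [-0.5, 0.1, 0.9]]:
    # one sorted list per neuron.
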
def get_rank(self, value, rankings):
""" get the rank of the value in an ordered array as a percentage
Keyword arguments:
value -- the value for which we want to return the ranked value
rankings -- the ordered array in which to determine the value's ranking
returns linear distance between the indexes where value occurs, in the
case that there is not an exact match with the ranked values
"""
index = 0 # default: ranking = 0
for ranked_number in rankings:
if value < ranked_number:
break #NB: this O(N) loop could be optimized to O(log(N))
index += 1
        if index >= len(rankings):
            index = len(rankings) # maximum: ranking = 1
        elif index > 0:
            # get linear interpolation between the two closest indexes
            diff = rankings[index] - rankings[index - 1]
            perc = value - rankings[index - 1]
            linear = perc / diff
            index = float(index - 1) + linear

        absolute_ranking = index / len(rankings)

        return absolute_ranking
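
    # Worked example (hedged): get_rank(0.5, [0.0, 0.4, 0.6, 1.0]) stops at
    # index 2 (0.5 < 0.6), then interpolates between 0.4 and 0.6:
    # linear = (0.5 - 0.4) / (0.6 - 0.4) = 0.5, so index = 1.5 and the
    # returned ranking is 1.5 / 4 = 0.375.
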
def get_model_outliers(self, model, unlabeled_data, validation_data, feature_method, number=5, limit=10000):
"""Get model outliers from unlabeled data
Keyword arguments:
model -- current Machine Learning model for this task
unlabeled_data -- data that does not yet have a label
validation_data -- held out data drawn from the same distribution as the training data
feature_method -- the method to create features from the raw text
number -- number of items to sample
limit -- sample from only this many items for faster sampling (-1 = no limit)
        An outlier is defined as the unlabeled item whose logits have the
        lowest average percentile rank, where the rank of each logit is taken
        within the ordered logits of the validation data for that neuron
        """
# Get per-neuron scores from validation data
validation_rankings = self.get_validation_rankings(model, validation_data, feature_method)
# Iterate over unlabeled items
if self.verbose:
print("Getting rankings for unlabeled data")
outliers = []
        if limit == -1:
            if len(unlabeled_data) > 10000 and self.verbose: # we're drawing from *a lot* of data; this will take a while
                print("Get rankings for a large amount of unlabeled data: this might take a while")
        else:
            # only apply the model to a limited number of items
            shuffle(unlabeled_data)
            unlabeled_data = unlabeled_data[:limit]
with torch.no_grad():
for item in unlabeled_data:
text = item[1]
feature_vector = feature_method(text)
hidden, logits, log_probs = model(feature_vector, return_all_layers=True)
neuron_outputs = logits.data.tolist()[0] #logits
                ranks = []
                for n, output in enumerate(neuron_outputs):
                    rank = self.get_rank(output, validation_rankings[n])
                    ranks.append(rank)

                item[3] = "logit_rank_outlier"
                # a low average rank means these logits sit far below what the
                # validation data typically produces, so invert the average
                # rank to make bigger outliers score higher
                item[4] = 1 - (sum(ranks) / len(neuron_outputs))
outliers.append(item)
outliers.sort(reverse=True, key=lambda x: x[4])
        return outliers[:number]
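

# ----------------------------------------------------------------------
# Hedged end-to-end sketch, not part of the original module: ToyModel and
# toy_features are stand-in assumptions that mimic the call signatures of
# the real model and feature_method defined in active_learning.py, used
# only to exercise get_model_outliers() on synthetic data.
if __name__ == "__main__":
    class ToyModel(nn.Module):
        def __init__(self, vocab_size=32, num_labels=2):
            super().__init__()
            self.linear1 = nn.Linear(vocab_size, 128)
            self.linear2 = nn.Linear(128, num_labels)

        def forward(self, feature_vec, return_all_layers=False):
            hidden = self.linear1(feature_vec)
            output = self.linear2(hidden)
            log_softmax = F.log_softmax(output, dim=1)
            if return_all_layers:
                return [hidden, output, log_softmax]
            return log_softmax

    def toy_features(text, vocab_size=32):
        # hash each word into a fixed-size bag-of-words feature vector
        vec = torch.zeros(1, vocab_size)
        for word in text.split():
            vec[0][hash(word) % vocab_size] += 1
        return vec

    # synthetic items in the assumed [textid, text, label, sampling_method, score] layout
    validation_data = [[str(i), "validation text number "+str(i), "0", "", 0] for i in range(10)]
    unlabeled_data = [[str(100 + i), "unlabeled text number "+str(i), "", "", 0] for i in range(20)]

    sampler = DiversitySampling(verbose=True)
    outliers = sampler.get_model_outliers(ToyModel(), unlabeled_data, validation_data, toy_features, number=3)
    for item in outliers:
        print(item[0], item[4])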