-
Notifications
You must be signed in to change notification settings - Fork 0
/
Weighte_KNN.py
79 lines (62 loc) · 3.01 KB
/
Weighte_KNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Cell type prediction using the weighted_knn function
from collections import Counter

import numpy as np
from scipy import sparse
from sklearn.neighbors import KNeighborsTransformer
def remove_sparsity(adata):
    """Replace a sparse ``adata.X`` with its dense equivalent, in place.

    Parameters
    ----------
    adata
        Any object exposing a readable and writable ``X`` attribute
        (presumably an AnnData object; confirm against callers).

    Returns
    -------
    The very same ``adata`` object. If ``X`` was a SciPy sparse container it
    has been overwritten with a dense ``numpy.ndarray``; otherwise nothing
    is changed.
    """
    if sparse.issparse(adata.X):
        # ``toarray()`` is the documented way to densify; the ``.A`` shorthand
        # is not available on every SciPy sparse container / version.
        adata.X = adata.X.toarray()
    return adata
def _neighbor_vote_weights(distances):
    """Convert each row of neighbour distances into vote weights summing to 1."""
    distances = np.asarray(distances, dtype=float)
    # Per-row kernel scale. The reference formula is exp(-d / (2 / std) ** 2);
    # writing the division as a multiplication by (std / 2) ** 2 gives the
    # same value but avoids dividing by zero when all k distances are equal
    # (the exponent is then 0 and the weights come out uniform).
    scale = (np.std(distances, axis=1, keepdims=True) / 2.) ** 2
    # Shifting every row by its minimum leaves the normalised weights
    # unchanged, but guarantees that the closest neighbour gets exp(0) == 1.
    # Without it all k terms can underflow to 0 and the row becomes 0 / 0.
    shifted = distances - distances.min(axis=1, keepdims=True)
    kernel = np.exp(-shifted * scale)
    return kernel / kernel.sum(axis=1, keepdims=True)


def weighted_knn(train_adata, valid_adata, label_key, n_neighbors=50, threshold=0.5,
                 pred_unknown=True, return_uncertainty=True):
    """Predict labels for ``valid_adata`` by a distance-weighted k-NN vote.

    Taken from scnet:
    https://github.com/theislab/scarches/blob/e84cfa5cf361bb22fd70865cb1f398af72248684/scnet/utils.py

    For every row of ``valid_adata.X`` the ``n_neighbors`` closest rows of
    ``train_adata.X`` (Euclidean distance, brute force) are looked up. The
    candidate label is the one that occurs most often among those
    neighbours; its probability is the sum of the distance-derived weights
    of the neighbours carrying that label.

    Parameters
    ----------
    train_adata, valid_adata
        Objects providing ``.X`` and ``.obs[label_key]`` (presumably AnnData;
        confirm against callers). Both are passed through
        ``remove_sparsity``, so a sparse ``X`` is densified in place.
        ``valid_adata.obs[label_key]`` is required as well: the true labels
        are used to compute the uncertainties.
    label_key
        Column of ``.obs`` holding the labels.
    n_neighbors
        Number of neighbours queried per validation row.
    threshold
        Minimum weighted probability of the most common label for it to be
        accepted when ``pred_unknown`` is true.
    pred_unknown
        If true, rows whose probability is below ``threshold`` are labelled
        ``'Unknown'``; if false, the most common label is always returned.
    return_uncertainty
        If true, return ``(pred_labels, uncertainties)``; otherwise only
        ``pred_labels``.

    Returns
    -------
    pred_labels : numpy.ndarray of shape (n_valid, 1)
    uncertainties : numpy.ndarray of shape (n_valid, 1)
        ``1 - most_prob`` when the prediction equals the true label,
        otherwise ``1 -`` the summed weight of neighbours carrying the true
        label. Only returned when ``return_uncertainty`` is true.
    """
    print(f'Weighted KNN with n_neighbors = {n_neighbors} and threshold = {threshold} ... ', end='')

    k_neighbors_transformer = KNeighborsTransformer(n_neighbors=n_neighbors, mode='distance',
                                                    algorithm='brute', metric='euclidean',
                                                    n_jobs=-1)
    train_adata = remove_sparsity(train_adata)
    valid_adata = remove_sparsity(valid_adata)
    k_neighbors_transformer.fit(train_adata.X)

    y_train_labels = train_adata.obs[label_key].values
    y_valid_labels = valid_adata.obs[label_key].values

    top_k_distances, top_k_indices = k_neighbors_transformer.kneighbors(X=valid_adata.X)
    weights = _neighbor_vote_weights(top_k_distances)

    uncertainties = []
    pred_labels = []
    for row_weights, neighbor_idx, true_label in zip(weights, top_k_indices, y_valid_labels):
        neighbor_labels = y_train_labels[neighbor_idx]
        # The candidate is chosen by plain (unweighted) frequency; the
        # weights only decide how confident the vote is.
        most_common_label, _ = Counter(neighbor_labels).most_common(n=1)[0]
        most_prob = row_weights[neighbor_labels == most_common_label].sum()

        accepted = most_prob >= threshold
        pred_label = most_common_label if (accepted or not pred_unknown) else 'Unknown'

        if pred_label == true_label:
            uncertainties.append(1 - most_prob)
        else:
            # Wrong (or rejected) prediction: report how little weight the
            # true label received among the neighbours.
            true_prob = row_weights[neighbor_labels == true_label].sum()
            uncertainties.append(1 - true_prob)
        pred_labels.append(pred_label)

    pred_labels = np.array(pred_labels).reshape(-1, 1)
    uncertainties = np.array(uncertainties).reshape(-1, 1)

    print('finished!')

    if return_uncertainty:
        return pred_labels, uncertainties
    return pred_labels
# Predict the 'Subtype' label of every observation in adata_latentnew from its
# nearest neighbours in adata_latent.
# NOTE(review): adata_latent, adata_latentnew, n_neighbor and threshold are not
# defined in the visible part of this file - presumably created earlier by the
# surrounding script/notebook; confirm before running.
pred_labels, uncertainties = weighted_knn(
    adata_latent,
    adata_latentnew,
    'Subtype',
    n_neighbors=n_neighbor,
    threshold=threshold,
    pred_unknown=True,
    return_uncertainty=True,
)