-
Notifications
You must be signed in to change notification settings - Fork 9
/
simple_version.py
81 lines (72 loc) · 2.95 KB
/
simple_version.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
## This code is stripped to the minimum required to implement the
## differentiable sort with no dependencies except numpy
import numpy as np
def bitonic_matrices(n):
    """Return the bitonic matrices to sort n elements, where n=2^k.

    The result is a list of quadruples ``(l, r, map_l, map_r)``, one per
    compare-and-swap stage (k*(k+1)/2 stages in total):

    * ``l``, ``r`` are n/2 x n selection matrices that map the input to the
      left and right operands of the stage's n/2 compare-and-swaps.
    * ``map_l``, ``map_r`` are n x n/2 matrices that map the results of those
      compare-and-swaps (``diff_sort`` feeds the min through ``map_l`` and
      the max through ``map_r``) back to a vector of length n.

    For n == 1 the list is empty (nothing to sort).

    Raises
    ------
    ValueError
        If n is not a positive power of two.
    """
    if n < 1 or n & (n - 1):
        raise ValueError(f"n must be a positive power of two, got {n}")
    # number of outer layers; log2 is exact because n is a power of two
    layers = int(np.log2(n))
    half = n // 2
    matrices = []
    for layer in range(layers):
        for s in range(layer + 1):
            # distance between the two operands of each compare-and-swap
            m = 1 << (layer - s)
            l, r = np.zeros((half, n)), np.zeros((half, n))
            map_l, map_r = np.zeros((n, half)), np.zeros((n, half))
            out = 0
            for i in range(0, n, m << 1):
                for j in range(m):
                    ix = i + j
                    a, b = ix, ix + m
                    l[out, a], r[out, b] = 1, 1
                    # Blocks of size 2^(layer+1) alternate direction: in odd
                    # blocks the min is routed to the higher index and the
                    # max to the lower one, so each layer's output is bitonic.
                    if (ix >> (layer + 1)) & 1:
                        a, b = b, a
                    map_l[a, out], map_r[b, out] = 1, 1
                    out += 1
            matrices.append((l, r, map_l, map_r))
    return matrices
def diff_sort(matrices, x, alpha=1):
    """Approximate differentiable sort.

    Takes a set of bitonic sort matrices generated by ``bitonic_matrices(n)``
    and sorts a sequence ``x`` of length n, replacing every hard
    compare-and-swap with a smooth max/min. Values may be distorted slightly
    but will be ordered.

    Parameters
    ----------
    matrices : list of tuple
        ``(l, r, map_l, map_r)`` quadruples from ``bitonic_matrices(n)``.
    x : numpy.ndarray
        Sequence of length n (the leading axis is the one being sorted).
    alpha : float
        Sharpness of the smoothmax; larger values approach a hard max/min.

    Returns
    -------
    numpy.ndarray
        The smoothly sorted values, same shape as ``x``.
    """
    for l, r, map_l, map_r in matrices:
        a, b = l @ x, r @ x
        # Smoothmax: the softmax(alpha * [a, b])-weighted average of a and b.
        # Shifting both exponents by their maximum leaves the weights
        # unchanged but keeps np.exp from overflowing to inf (inf/inf = nan)
        # when alpha * x is large.
        za, zb = a * alpha, b * alpha
        shift = np.maximum(za, zb)
        wa, wb = np.exp(za - shift), np.exp(zb - shift)
        mx = (a * wa + b * wb) / (wa + wb)
        # The smooth min is the remainder, so that mn + mx == a + b.
        mn = a + b - mx
        x = map_l @ mn + map_r @ mx
    return x
def diff_argsort(matrices, x, sigma=0.1, alpha=1, transpose=False):
    """Return the smoothed, differentiable ranking of each element of x.

    Note that this function is deceptively named: in the default setting it
    returns the *ranking*, not the argsort.

    Parameters
    ----------
    matrices : list of tuple
        Bitonic sort matrices from ``bitonic_matrices(len(x))``.
    x : numpy.ndarray
        Values to rank; it is flattened with ``reshape(-1, 1)``, so a 1-D
        vector is expected.
    sigma : float
        Width of the Gaussian kernel that softly matches each element of
        ``x`` to the smoothly sorted values; smaller means sharper ranks.
    alpha : float
        Smoothmax sharpness forwarded to ``diff_sort``.
    transpose : bool
        If True, returns the argsort instead of the ranking.

    Ties are not broken in either mode.

    Returns
    -------
    numpy.ndarray
        One soft rank (or soft index, when ``transpose``) per position.
    """
    sortd = diff_sort(matrices, x, alpha)
    diff = (x.reshape(-1, 1) - sortd.reshape(1, -1)) ** 2
    # Subtract each row's smallest squared distance before exponentiating.
    # The row normalisation below cancels this per-row constant exactly, but
    # without it a row whose distances are all large (small sigma, or sort
    # distortion from a small alpha) underflows to all zeros -> 0/0 = nan.
    diff = diff - diff.min(axis=1, keepdims=True)
    rbf = np.exp(-diff / (2 * sigma ** 2))
    # Normalise each row (one row per element of x) to sum to 1.
    order = rbf / np.sum(rbf, axis=1, keepdims=True)
    if transpose:
        order = order.T
    return order @ np.arange(len(x))
if __name__ == "__main__":
    # test data, length 8 (the bitonic network needs a power-of-two length)
    x = np.array([5.0, -1.0, 9.5, 13.2, 16.2, 20.5, 42.0, 90.0])
    n = len(x)
    print(x)
    print("Sorted")
    print(np.sort(x))
    print("Diff. sorted")
    matrices = bitonic_matrices(n)
    print(diff_sort(matrices, x))
    print("Ranked")
    # The rank of each element is the inverse permutation of the argsort,
    # i.e. argsort applied twice; this is what diff_argsort returns by
    # default, so it is the right hard reference to compare against.
    print(np.argsort(np.argsort(x)))
    print("Diff. ranked")
    print(diff_argsort(matrices, x))