# km_matcher.py
'''
Maximum-weight bipartite assignment via the Hungarian (Kuhn-Munkres) algorithm.

reference: https://www.topcoder.com/community/competitive-programming/tutorials/assignment-problem-and-hungarian-algorithm/
'''
import numpy as np
# Maximum-weight assignment solver.
class KMMatcher:
    """Maximum-weight assignment of rows to columns of a weight matrix.

    Every row ``x`` of the ``n x m`` matrix (``n <= m``) is matched to a
    distinct column ``y`` so that the summed weight of the chosen cells is
    as large as possible.  The solver keeps a label per row and per column
    with ``label_x[x] + label_y[y] >= weights[x, y]`` and grows the matching
    one augmenting path at a time along edges where that bound is tight.

    Attributes set by ``__init__``:
        weights: float32 copy of the input matrix, shape ``(n, m)``.
        n, m: number of rows / columns.
        label_x, label_y: current row / column labels.
        max_match: number of rows matched so far.
        xy: ``xy[x]`` is the column matched to row ``x`` (``-1`` if none).
        yx: ``yx[y]`` is the row matched to column ``y`` (``-1`` if none).

    ``solve`` additionally stores the total weight in ``best``.
    """

    def __init__(self, weights):
        """Store the weights and build the initial labels and empty matching.

        Args:
            weights: 2-D array-like of shape ``(n, m)`` with ``n <= m``;
                converted to a float32 NumPy array.

        Raises:
            ValueError: if ``weights`` is not two-dimensional.
            AssertionError: if there are more rows than columns.
        """
        weights = np.array(weights).astype(np.float32)
        if weights.ndim != 2:
            raise ValueError(
                'weights must be a 2-D matrix, got {} dimension(s)'.format(weights.ndim))
        self.weights = weights
        self.n, self.m = weights.shape
        # Raised explicitly (instead of a bare ``assert``) so the check also
        # runs under ``python -O`` while keeping the exception type callers
        # may already rely on.
        if self.n > self.m:
            raise AssertionError(
                'weights must have n <= m, got n={} m={}'.format(self.n, self.m))

        # Initial labels: each row starts at its largest weight and every
        # column at zero, so label_x[x] + label_y[y] >= weights[x, y].
        if weights.size:
            self.label_x = np.max(weights, axis=1)
        else:
            # np.max cannot reduce over a zero-length axis.
            self.label_x = np.zeros((self.n,), dtype=np.float32)
        self.label_y = np.zeros((self.m,), dtype=np.float32)

        self.max_match = 0
        # Builtin ``int`` replaces the ``np.int`` alias removed in NumPy 1.24.
        self.xy = -np.ones((self.n,), dtype=int)
        self.yx = -np.ones((self.m,), dtype=int)

    def do_augment(self, x, y):
        """Flip the matching along the path ending in edge ``(x, y)``.

        Walks back through ``self.prev`` from ``x`` until the root marker
        ``-2`` is reached, re-matching each row on the way to the column it
        reached, and increases ``max_match`` by one.
        """
        self.max_match += 1
        while x != -2:
            self.yx[y] = x
            # Column that x gives up; its predecessor in the tree takes it.
            released_y = self.xy[x]
            self.xy[x] = y
            x, y = self.prev[x], released_y

    def find_augment_path(self):
        """Search for an augmenting path starting at the first unmatched row.

        Resets the per-search state (``S``, ``T``, ``slack``, ``slackyx``,
        ``prev``) and alternates a breadth-first search over tight edges with
        label updates until an unmatched column is reached.

        Returns:
            ``(x, y)``: the last edge of the path, where column ``y`` is
            unmatched; ``self.prev`` holds the rest of the path.

        Raises:
            RuntimeError: if every row is already matched.
        """
        # Builtin ``bool`` / ``int`` replace the aliases removed in NumPy 1.24.
        self.S = np.zeros((self.n,), dtype=bool)   # rows in the alternating tree
        self.T = np.zeros((self.m,), dtype=bool)   # columns in the alternating tree
        # l[slackyx[y]] + l[y] - w[slackyx[y], y] == slack[y]
        self.slackyx = -np.ones((self.m,), dtype=int)
        self.prev = -np.ones((self.n,), dtype=int)

        free_rows = np.flatnonzero(self.xy == -1)
        if free_rows.size == 0:
            # Without this guard the search would silently start from row -1.
            raise RuntimeError('find_augment_path called with no unmatched row')
        root = int(free_rows[0])
        self.prev[root] = -2   # marks the end of the path for do_augment
        self.S[root] = True
        queue, st = [root], 0

        self.slack = self.label_y + self.label_x[root] - self.weights[root]
        self.slackyx[:] = root

        while True:
            # Phase 1: breadth-first search along tight edges.
            while st < len(queue):
                x = queue[st]
                st += 1
                is_in_graph = np.isclose(self.weights[x], self.label_x[x] + self.label_y)
                nonzero_inds = np.nonzero(np.logical_and(is_in_graph, np.logical_not(self.T)))[0]
                for y in nonzero_inds:
                    if self.yx[y] == -1:
                        return x, y
                    self.T[y] = True
                    queue.append(self.yx[y])
                    self.add_to_tree(self.yx[y], x)

            # Phase 2: no tight edge leaves the tree, so shift the labels
            # until at least one column outside T has zero slack.
            self.update_labels()
            queue, st = [], 0
            is_in_graph = np.isclose(self.slack, 0)
            nonzero_inds = np.nonzero(np.logical_and(is_in_graph, np.logical_not(self.T)))[0]
            for y in nonzero_inds:
                x = self.slackyx[y]
                if self.yx[y] == -1:
                    return x, y
                self.T[y] = True
                if not self.S[self.yx[y]]:
                    # NOTE(review): this enqueues the tree row slackyx[y]
                    # rather than the newly added row yx[y]; the referenced
                    # tutorial appears to enqueue the latter.  Kept as-is
                    # because the new row's tight edges are still found
                    # through the slack array on the next pass - confirm
                    # before changing, since it can alter tie-breaking.
                    queue.append(x)
                    self.add_to_tree(self.yx[y], x)

    def solve(self, verbose=False):
        """Run the algorithm to completion and return the total weight.

        Args:
            verbose: when true, print every matched pair and the total.

        Returns:
            Sum of ``weights[x, xy[x]]`` over all rows; also stored in
            ``self.best``.
        """
        while self.max_match < self.n:
            x, y = self.find_augment_path()
            self.do_augment(x, y)

        total = 0.
        for x in range(self.n):
            if verbose:
                print('match {} to {}, weight {:.4f}'.format(x, self.xy[x], self.weights[x, self.xy[x]]))
            total += self.weights[x, self.xy[x]]

        self.best = total
        if verbose:
            print('ans: {:.4f}'.format(total))
        return total

    def add_to_tree(self, x, prevx):
        """Add row ``x`` to the tree below ``prevx`` and refresh the slacks.

        For every column where ``x`` offers a strictly smaller slack than
        the one recorded, ``slack`` and ``slackyx`` are updated to ``x``.
        """
        self.S[x] = True
        self.prev[x] = prevx
        candidate = self.label_x[x] + self.label_y - self.weights[x]
        better_slack_idx = candidate < self.slack
        self.slack[better_slack_idx] = candidate[better_slack_idx]
        self.slackyx[better_slack_idx] = x

    def update_labels(self):
        """Shift labels by the smallest slack among columns outside ``T``.

        Row labels in ``S`` go down and column labels in ``T`` go up by the
        same amount; slacks of the columns outside ``T`` shrink accordingly.
        """
        outside_tree = np.logical_not(self.T)
        delta = self.slack[outside_tree].min()
        self.label_x[self.S] -= delta
        self.label_y[self.T] += delta
        self.slack[outside_tree] -= delta