-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathday25.py
123 lines (91 loc) · 2.84 KB
/
day25.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import random
import sys
from pathlib import Path

import networkx as nx
# Command-line handling. The optional flags "--visualize" and "--fast" may
# appear anywhere; once they are removed, exactly one positional argument
# (the path of the input file) must remain.
visualize = "--visualize" in sys.argv
fast = "--fast" in sys.argv
# Strip every occurrence of each flag (list.remove only dropped the first one,
# so a repeated flag used to be mistaken for an extra positional argument).
sys.argv[:] = [arg for arg in sys.argv if arg not in ("--visualize", "--fast")]
if len(sys.argv) != 2:
    print("Missing input file.", file=sys.stderr)
    # sys.exit rather than the site-provided ``exit`` helper, which is meant
    # for the interactive shell and may be absent (e.g. under ``python -S``).
    sys.exit(1)
filename = sys.argv[1]
# Compare only the base name so "./sample.txt" or "inputs/sample.txt" are
# also recognised as the sample input.
is_sample = Path(filename).name == "sample.txt"
def parse(path=None):
    """Read the wiring diagram and return it as a symmetric directed graph.

    Each non-blank input line has the form ``name: other1 other2 ...``. Every
    wire is added as two opposite arcs, each with ``capacity=1``, so the
    networkx max-flow / min-cut calls (which read the ``capacity`` edge
    attribute by default) count each wire as one unit in either direction.

    Args:
        path: file to read; defaults to the module-level ``filename`` taken
            from the command line.

    Returns:
        An ``nx.DiGraph`` in which every arc carries ``capacity=1``.
    """
    if path is None:
        path = filename
    G = nx.DiGraph()
    # ``with`` guarantees the file is closed (the original leaked the handle).
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of crashing on the unpacking below.
                continue
            source, targets = line.split(": ")
            for target in targets.split():
                G.add_edge(source, target, capacity=1)
                G.add_edge(target, source, capacity=1)
    return G
def split_graph(G):
    """Remove the three bridging wires from G and return the two halves.

    G is modified in place: both arcs of every selected wire are deleted.

    Returns:
        A pair of node sets, one per remaining connected component.
    """
    cut_wires = edges_with_max_flow(G, 3)
    for u, v in cut_wires:
        # Each wire is stored as two opposite arcs; drop both of them.
        G.remove_edge(u, v)
        G.remove_edge(v, u)
    first, second = get_connected_components(G)
    return first, second
def edges_with_max_flow(G, flow, count=3):
    """Return the first *count* wires whose endpoints have max-flow *flow*.

    For every wire (u, v) the maximum u -> v flow is computed; wires whose
    value equals *flow* are treated as candidates for the cut. Since G stores
    each wire as two opposite arcs, only one orientation of every wire is
    tested.

    Args:
        G: symmetric directed graph whose arcs carry a ``capacity`` attribute.
        flow: required max-flow value between an edge's endpoints.
        count: how many such edges to collect; defaults to 3, the value
            ``split_graph`` relies on. Expected to be positive.

    Returns:
        A list of ``count`` ``(source, target)`` tuples in iteration order.

    Raises:
        ValueError: if fewer than ``count`` matching edges exist.
    """
    seen = set()
    edges = []
    for source in G:
        for target in G[source]:
            # skip reverse edges (G is directed, but our wires are not)
            if (target, source) in seen:
                continue
            seen.add((source, target))
            if nx.maximum_flow_value(G, source, target) == flow:
                edges.append((source, target))
                if len(edges) == count:
                    return edges
    # An explicit raise still fires under ``python -O``; a stripped
    # ``assert False`` would have silently returned None to the caller.
    raise ValueError(
        f"need {count} edges with flow {flow}, only found {len(edges)}"
    )
def get_connected_components(G):
    """Return the connected pieces of G, ignoring arc direction.

    The result is a list of node sets. Exactly two pieces are expected; an
    assertion fails otherwise.
    """
    undirected = G.to_undirected()
    pieces = list(nx.connected_components(undirected))
    assert len(pieces) == 2
    return pieces
def fast_minimum_cut(G, cut_size=3):
    """Find a partition of G that is separated by a cut of value *cut_size*.

    The function picks two distinct random nodes and asks networkx for a
    minimum s-t cut between them. It returns the first cut whose value equals
    *cut_size*; pairs whose minimum cut has any other value are retried.

    NOTE(review): if no node pair has a minimum cut equal to *cut_size*, this
    never returns.

    Args:
        G: graph whose edges carry a ``capacity`` attribute.
        cut_size: desired cut value; defaults to 3.

    Returns:
        (S, T): the two node sets of the partition.

    Raises:
        ValueError: if G has fewer than two nodes. Previously a single node
            spun forever and an empty graph raised IndexError.
    """
    # Build the node list once; G is not modified inside the loop.
    nodes = list(G.nodes)
    if len(nodes) < 2:
        raise ValueError("need at least two nodes to compute a cut")
    while True:
        # random.sample guarantees two distinct nodes in a single draw.
        u, v = random.sample(nodes, 2)
        cut, (S, T) = nx.minimum_cut(G, u, v)
        if cut == cut_size:
            return S, T
def plot(G, S):
    """Draw G with nodes in S in pink and all others in light blue, then show it.

    matplotlib is imported lazily so it is only required when plotting.
    """
    import matplotlib.pyplot as plt

    colors = ["#FFB6C1" if node in S else "#ADD8E6" for node in G]
    layout = nx.spring_layout(G, seed=13)  # tried seeds until one was nice
    nx.draw(
        G,
        pos=layout,
        font_size=12,
        node_size=1000,
        with_labels=True,
        font_color="#505050",
        node_color=colors,
    )
    plt.show()
def check(part, actual, expected=None):
    """Print the answer for *part* and compare it with *expected*.

    The output line gets a " (sample)" tag when the module-level
    ``is_sample`` flag is true. What is printed next depends on *expected*:

    - None: ❔ is printed.
    - Equal to *actual*: ✅ is printed.
    - Different: "≠ expected ❌" is printed and the process exits with
      status 1.
    """
    label = " (sample)" if is_sample else ""
    print(f"Part {part}{label}: {actual} ", end="")
    if expected is None:
        print("❔")
        return
    if actual != expected:
        print(f"≠ {expected} ❌")
        # sys.exit rather than the site-provided ``exit`` helper, which is
        # meant for the interactive shell and may be absent (``python -S``).
        sys.exit(1)
    print("✅")
if __name__ == "__main__":
    graph = parse()
    # fast: random-pair nx.minimum_cut (0.3 s); default: my original
    # max-flow-per-edge idea (30 s)
    solver = fast_minimum_cut if fast else split_graph
    S, T = solver(graph)
    check(1, len(S) * len(T), 54 if is_sample else 589036)
    if visualize:
        # split_graph removes the cut wires from `graph`, so draw a freshly
        # parsed copy instead.
        plot(parse().to_undirected(), S)