-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnode.py
141 lines (115 loc) · 4.4 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from abc import ABC
import networkx as nx
import numpy as np
class Node(ABC):
    """Base class for a node of a factor graph used for message passing.

    A node keeps a reference to the graph it belongs to (only the graph's
    ``neighbors`` method is used, to find adjacent nodes), a human-readable
    name, and an initial belief vector together with its element-wise log.
    The message-passing hooks (``message``, ``sum_product``, ``max_product``,
    ``max_sum``) are no-ops here that return ``None``; ``VarNode`` and
    ``FactorNode`` override them.
    """

    def __init__(self, graph: "nx.Graph", name: str, init=None):
        """Create a node.

        Args:
            graph: Graph containing this node. Only ``graph.neighbors(node)``
                is called. The annotation is quoted so networkx is not needed
                merely to define the class.
            name: Identifier for the node.
            init: Initial (unnormalized) belief vector. Defaults to
                ``[1., 1.]``, a uniform belief over a binary variable. The
                values are copied into a fresh float array.
        """
        self.graph = graph
        self.name = name
        if init is None:  # assume that this is a binary variable
            init = [1., 1.]
        # Keep a private float copy: in-place updates of ``belief`` must not
        # leak into the caller's array or into the value ``reset_belief``
        # restores. Float dtype also lets subclasses accumulate messages
        # in place (``*=`` / ``+=``) even when ``init`` was given as ints.
        self._init = np.array(init, dtype=float)
        self.belief = self._init.copy()
        self.log_belief = np.log(self._init)

    def reset_belief(self):
        """Restore both ``belief`` and ``log_belief`` to their initial values."""
        self.belief = self._init.copy()
        self.log_belief = np.log(self._init)

    def get_neighbors(self, exclude=None):
        """Return the neighbours of this node in ``self.graph`` as a list.

        Args:
            exclude: Optional neighbour to leave out (typically the node a
                message is being computed for). Compared with ``!=``.

        Returns:
            A list of adjacent nodes in the order the graph yields them.
            Always a list, so callers may reverse, index, or iterate it
            more than once.
        """
        return [
            neighbor
            for neighbor in self.graph.neighbors(self)
            if exclude is None or neighbor != exclude
        ]

    def message(self, node):
        """Hook for sending a message to ``node``; no-op in the base class."""
        pass

    def sum_product(self, node):
        """Sum-product message to ``node``; no-op in the base class."""
        pass

    def max_product(self, node):
        """Max-product message to ``node``; no-op in the base class."""
        pass

    def max_sum(self, node):
        """Max-sum (log-domain) message to ``node``; no-op in the base class."""
        pass
class VarNode(Node):
    """Variable node of a factor graph.

    ``messages_in`` maps a sending factor node to the message it delivered
    (factor nodes write ``node.messages_in[factor] = msg``). ``messages_out``
    maps a recipient node to the message this variable computed for it.
    Sum-product / max-product messages are in the probability domain;
    max-sum messages are in the log domain.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``Node`` and start with no messages."""
        super().__init__(*args, **kwargs)
        self.type = "variable"
        self.messages_in = {}
        self.messages_out = {}

    @staticmethod
    def _log_sum_exp(values):
        """Return ``log(sum(exp(values)))`` without overflow/underflow.

        Shifting by the maximum keeps every exponent <= 0. When the maximum
        is not finite (e.g. every entry is ``-inf``) the shift is undefined,
        so the plain formula is used instead.
        """
        peak = np.max(values)
        if not np.isfinite(peak):
            return np.log(np.sum(np.exp(values)))
        return peak + np.log(np.sum(np.exp(values - peak)))

    def clear_messages(self):
        """Drop all stored incoming and outgoing messages."""
        self.messages_in = {}
        self.messages_out = {}

    def update_belief(self):
        """Set ``belief`` to the L1-normalized product of all incoming messages."""
        # float dtype so the in-place product cannot fail on an int belief
        product = np.ones_like(self.belief, dtype=float)
        for incoming in self.messages_in.values():
            product *= incoming
        self.belief = product / np.linalg.norm(product, 1)

    def update_log_belief(self):
        """Update ``log_belief`` and ``belief`` from incoming log-domain messages.

        The incoming log-messages are summed and normalized with a
        max-shifted log-sum-exp. ``belief`` is the exponential of the
        normalized log-belief, so it sums to 1. Unlike
        ``exp(msg) / sum(exp(msg))``, this neither overflows for large
        log-messages nor underflows to zero (and ``log(0) = -inf``) for very
        negative ones.
        """
        total = np.zeros_like(self.log_belief, dtype=float)
        for incoming in self.messages_in.values():
            total += incoming
        self.log_belief = total - self._log_sum_exp(total)
        self.belief = np.exp(self.log_belief)

    def sum_product(self, node: Node, normalize=False):
        """Store in ``messages_out[node]`` the product of messages from all other neighbours.

        Args:
            node: Recipient; its own incoming message is excluded.
            normalize: If true, scale the message to sum to ~1 (a 1e-10 term
                guards against division by zero).
        """
        msg = np.ones_like(self.belief, dtype=float)
        for neighbor in self.get_neighbors(exclude=node):
            msg *= self.messages_in[neighbor]
        if normalize:
            msg = msg / (np.sum(msg) + 1e-10)
        self.messages_out[node] = msg

    def max_product(self, node: Node, normalize=False):
        """Max-product message to ``node``.

        For a variable node this is the same product of incoming messages
        as in sum-product, so it delegates to ``sum_product``.
        """
        self.sum_product(node, normalize=normalize)

    def max_sum(self, node: Node, normalize=False):
        """Store in ``messages_out[node]`` the sum of log-messages from all other neighbours.

        Args:
            node: Recipient; its own incoming message is excluded.
            normalize: If true, subtract the (stable) log-sum-exp so the
                message exponentiates to a distribution summing to 1.
        """
        msg = np.zeros_like(self.log_belief, dtype=float)
        for neighbor in self.get_neighbors(exclude=node):
            msg += self.messages_in[neighbor]
        if normalize:
            msg = msg - self._log_sum_exp(msg)
        self.messages_out[node] = msg

    def loopy_sum_product(self, node: Node):
        """Sum-product message to ``node``, always L1-normalized (for loopy BP)."""
        self.sum_product(node)
        msg = self.messages_out[node]
        self.messages_out[node] = msg / np.linalg.norm(msg, 1)
class FactorNode(Node):
    """Factor node holding a potential / conditional probability table.

    ``cpd`` has one axis per variable, in the order of ``ordered_variables``.
    ``order_neighbors`` maps each variable to its axis index. Messages are
    written directly into the recipient's ``messages_in[self]``, and incoming
    messages are read from each neighbour's ``messages_out[self]``.
    """

    def __init__(self, cpd, ordered_variables, **kwargs):
        """Create a factor node.

        Args:
            cpd: Table with one axis per variable in ``ordered_variables``.
                Array-likes (e.g. nested lists) are converted with
                ``np.asarray``; an existing ndarray is kept as-is.
            ordered_variables: Variables in the same order as the axes of
                ``cpd``.
            **kwargs: Forwarded to ``Node`` (``graph``, ``name``, ``init``).
        """
        super().__init__(**kwargs)
        # np.asarray so the .transpose() calls below also work for list input
        self.cpd = np.asarray(cpd)
        self.log_cpd = np.log(self.cpd)
        self.order_neighbors = {var: axis for axis, var in enumerate(ordered_variables)}
        self.type = "factor"

    @staticmethod
    def _log_sum_exp(values):
        """Return ``log(sum(exp(values)))`` without overflow/underflow.

        Shifting by the maximum keeps every exponent <= 0. When the maximum
        is not finite (e.g. every entry is ``-inf``) the shift is undefined,
        so the plain formula is used instead.
        """
        peak = np.max(values)
        if not np.isfinite(peak):
            return np.log(np.sum(np.exp(values)))
        return peak + np.log(np.sum(np.exp(values - peak)))

    def _get_init_msg(self, node, neighbors=None):
        """Return ``cpd`` transposed so ``node``'s axis comes first.

        The remaining axes follow in the order of ``neighbors`` (by default
        every neighbour except ``node``).

        Raises:
            IndexError: if ``node`` is not one of this factor's variables.
        """
        if node not in self.order_neighbors:
            raise IndexError("Node not the neighbor of this factor node")
        if neighbors is None:
            neighbors = self.get_neighbors(exclude=node)
        axes = (self.order_neighbors[node],) + tuple(
            self.order_neighbors[neighbor] for neighbor in neighbors
        )
        return self.cpd.transpose(axes)

    def sum_product(self, node: Node, normalize=False):
        """Send ``node`` the sum-product message and return it.

        The table is marginalized against every other neighbour's outgoing
        message. After the transpose, the last axis belongs to the last
        neighbour, so neighbours are contracted from the back;
        ``np.dot(N-d, 1-d)`` sums over the last axis.
        """
        neighbors = list(self.get_neighbors(exclude=node))
        msg = self._get_init_msg(node, neighbors=neighbors)
        for neighbor in reversed(neighbors):
            msg = np.dot(msg, neighbor.messages_out[self])
        if normalize:
            msg = msg / (np.sum(msg) + 1e-10)
        node.messages_in[self] = msg
        return msg

    def max_product(self, node: Node, normalize=False):
        """Send ``node`` the max-product message and return it.

        Like ``sum_product``, but each neighbour's message is broadcast along
        the last axis, multiplied in, and that axis is then maximized out.
        """
        neighbors = list(self.get_neighbors(exclude=node))
        msg = self._get_init_msg(node, neighbors=neighbors)
        for neighbor in reversed(neighbors):
            msg = np.multiply(msg, neighbor.messages_out[self]).max(-1)
        if normalize:
            msg = msg / (np.sum(msg) + 1e-10)
        node.messages_in[self] = msg
        return msg

    def max_sum(self, node: Node, normalize=False):
        """Send ``node`` the log-domain max-sum message and return it.

        Works on the log of the transposed table: each neighbour's log-message
        is added along the last axis, and that axis is maximized out. With
        ``normalize``, a max-shifted log-sum-exp is subtracted, so
        large-magnitude log values neither overflow nor underflow.
        """
        neighbors = list(self.get_neighbors(exclude=node))
        msg = self._get_init_msg(node, neighbors=neighbors)
        msg = np.log(msg)
        for neighbor in reversed(neighbors):
            msg = (msg + neighbor.messages_out[self]).max(-1)
        if normalize:
            msg = msg - self._log_sum_exp(msg)
        node.messages_in[self] = msg
        return msg