Skip to content

Commit

Permalink
Merge pull request #2 from scipp/fix-conceptual-naming-issues
Browse files Browse the repository at this point in the history
Fix conceptual naming issues and more
  • Loading branch information
SimonHeybrock authored May 14, 2024
2 parents 03db6c2 + 298bedd commit 56a3ccb
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 44 deletions.
121 changes: 82 additions & 39 deletions src/cyclebane/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _get_new_node_name(graph: nx.DiGraph) -> str:
return name


def _remove_with_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph:
def _remove_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph:
graph = graph.copy()
ancestors = nx.ancestors(graph, node)
ancestors_successors = {
Expand All @@ -46,28 +46,9 @@ def _remove_with_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph:
to_remove.append(ancestor)
graph.remove_nodes_from(to_remove)
graph.remove_edges_from(list(graph.in_edges(node)))
graph.remove_node(node)
return graph


def _check_for_conflicts(graph: nx.DiGraph, ancestor_graph):
for node in ancestor_graph.nodes:
if node in graph:
if graph.nodes[node] != ancestor_graph.nodes[node]:
raise ValueError(
f"Node '{node}' has different attributes in ancestor_graph"
)
if list(graph.in_edges(node)) != list(ancestor_graph.in_edges(node)):
raise ValueError(
f"Node '{node}' has different incoming edges in ancestor_graph"
)
# TODO The composite graph may add more edges, so this check is bad
if list(graph.out_edges(node)) != list(ancestor_graph.out_edges(node)):
raise ValueError(
f"Node '{node}' has different outgoing edges in ancestor_graph"
)


@dataclass(frozen=True)
class IndexValues:
axes: tuple[IndexName]
Expand Down Expand Up @@ -113,7 +94,7 @@ def __len__(self):

@dataclass(frozen=True)
class NodeName:
name: str
name: Hashable
index: IndexValues

def merge_index(self, other: IndexValues) -> NodeName:
Expand All @@ -123,6 +104,24 @@ def __str__(self):
return f'{self.name}({self.index})'


@dataclass(frozen=True)
class MappedNode:
name: Hashable
indices: tuple[int, ...]


def node_with_indices(node: Hashable, indices: tuple[int, ...]) -> MappedNode:
if isinstance(node, MappedNode):
return MappedNode(name=node.name, indices=indices + node.indices)
return MappedNode(name=node, indices=indices)


def node_indices(node: Hashable) -> tuple[int, ...] | None:
if isinstance(node, MappedNode):
return node.indices
return ()


def _find_successors(
graph: nx.DiGraph, *, root_nodes: tuple[Hashable]
) -> set[Hashable]:
Expand Down Expand Up @@ -297,6 +296,12 @@ def __init__(self, graph: nx.DiGraph, *, value_attr: str = 'value'):
self._value_attr = value_attr
self._node_values: dict[tuple[IndexName, ...], MappingToArrayLike] = {}

def copy(self) -> Graph:
graph = Graph(self.graph.copy(), value_attr=self._value_attr)
graph.indices = dict(self.indices)
graph._node_values = dict(self._node_values)
return graph

@property
def value_attr(self) -> str:
return self._value_attr
Expand Down Expand Up @@ -330,9 +335,11 @@ def map(self, node_values: MappingToArrayLike) -> Graph:
f'Conflicting new index names {named} with existing {self.index_names}'
)
successors = _find_successors(self.graph, root_nodes=root_nodes)
graph = self.graph.copy()
name_mapping: dict[Hashable, MappedNode] = {}
for node in successors:
graph.nodes[node]['indices'] = named + graph.nodes[node].get('indices', ())
name_mapping[node] = node_with_indices(node, named)
graph = nx.relabel_nodes(self.graph, name_mapping)

out = Graph(graph)
# TODO order?
out.indices = {**indices, **self.indices}
Expand Down Expand Up @@ -375,9 +382,8 @@ def reduce(
attrs = attrs or {}
if index is not None and axis is not None:
raise ValueError('Only one of index and axis can be given')
if key not in self.graph:
raise KeyError(f"Node '{key}' does not exist in the graph.")
indices: tuple[IndexName] = self.graph.nodes[key].get('indices', ())
key = self._from_orig_key(key)
indices: tuple[IndexName] = node_indices(key)
if index is not None and index not in indices:
raise ValueError(f"Node '{key}' does not have index '{index}'.")
# TODO We can support indexing from the back in the future.
Expand All @@ -390,19 +396,36 @@ def reduce(
new_index = tuple(value for i, value in enumerate(indices) if i != axis)
else:
new_index = None
indices_attr = {} if new_index is None else {'indices': new_index}
if name in self.graph:
raise ValueError(f'Node {name} already exists in the graph.')

graph = self.graph.copy()
graph.add_node(name, **attrs, **indices_attr)
name = MappedNode(name=name, indices=new_index) if new_index else name
graph.add_node(name, **attrs)
graph.add_edge(key, name)

out = Graph(graph)
out.indices = dict(self.indices)
out._node_values = dict(self._node_values)
return out

def _from_orig_key(self, key: Hashable) -> Hashable:
# Graph.map relabels nodes to include index names, which can be inconvenient
# for the user. Is this convenience of finding the node by its original name
# worth the complexity and a good idea?
if key not in self.graph:
matches = [
node
for node in self.graph.nodes
if isinstance(node, MappedNode) and node.name == key
]
if len(matches) == 0:
raise KeyError(f"Node '{key}' does not exist in the graph.")
if len(matches) > 1:
raise KeyError(f"Node '{key}' is ambiguous. Found {matches}.")
return matches[0]
return key

def by_position(self, index_name: IndexName) -> PositionalIndexer:
return PositionalIndexer(self, index_name)

Expand All @@ -411,10 +434,11 @@ def to_networkx(self) -> nx.DiGraph:
for index_name, index in reversed(self.indices.items()):
# Find all nodes with this index
nodes = []
for node, data in graph.nodes(data=True):
if (node_indices := data.get('indices', None)) is not None:
if index_name in node_indices:
nodes.append(node)
for node in graph.nodes():
if index_name in node_indices(
node.name if isinstance(node, NodeName) else node
):
nodes.append(node)
# Make a copy for each index value
graphs = [
_rename_successors(
Expand All @@ -423,6 +447,13 @@ def to_networkx(self) -> nx.DiGraph:
for index in _yield_index([(index_name, index)])
]
graph = nx.compose_all(graphs)
# Replace all MappingNodes with their name
new_names = {
node: NodeName(node.name.name, node.index)
for node in graph
if isinstance(node, NodeName)
}
graph = nx.relabel_nodes(graph, new_names)

# Get values using previously stored index values
for values in self._node_values.values():
Expand All @@ -446,6 +477,7 @@ def __getitem__(self, key: Hashable | slice) -> Graph:
"""
if isinstance(key, slice):
raise NotImplementedError('Only single nodes are supported ')
key = self._from_orig_key(key)
ancestors = nx.ancestors(self.graph, key)
ancestors.add(key)
out = Graph(self.graph.subgraph(ancestors))
Expand All @@ -467,16 +499,27 @@ def __setitem__(self, branch: Hashable | slice, other: Graph) -> None:
raise TypeError(f'Expected {Graph}, got {type(other)}')
new_branch = other.graph
sink = _get_unique_sink(new_branch)
graph = _remove_with_ancestors(self.graph, branch)
new_branch = nx.relabel_nodes(new_branch, {sink: branch})
if branch in self.graph:
graph = _remove_ancestors(self.graph, branch)
graph.nodes[branch].clear()
else:
graph = self.graph

# TODO Checks seem complicated, maybe we should just make it the user's
# responsibility to ensure the graphs are compatible?
# _check_for_conflicts(graph, ancestor_graph)
intersection_nodes = set(graph.nodes) & set(new_branch.nodes) - {branch}

for node in intersection_nodes:
if graph.pred[node] != new_branch.pred[node]:
raise ValueError(
f"Node inputs differ for node '{node}':\n"
f" {graph.pred[node]}\n"
f" {new_branch.pred[node]}\n"
)
if graph.nodes[node] != new_branch.nodes[node]:
raise ValueError(f"Node data differs for node '{node}'")

graph = nx.compose(graph, new_branch)
for child in self.graph.successors(branch):
edge_data = self.graph.get_edge_data(branch, child)
graph.add_edge(sink, child, **edge_data)

# Delay setting graph until we know no step fails
self.graph = graph
self.indices.update(other.indices)
Expand Down
59 changes: 54 additions & 5 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,11 @@ def test_map_reduce() -> None:

graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3]}).map({'x': [4, 5]})
reduced = mapped.reduce('c', name='func', axis=1)
reduced = mapped.reduce(name='func', axis=1)
# Axis 0 reduces 'x', so there are 2 reduce nodes.
assert len(reduced.to_networkx().nodes) == 19
# Axis 1 reduces 'a', so there are 3 reduce nodes.
reduced = mapped.reduce('c', name='func', axis=0)
reduced = mapped.reduce(name='func', axis=0)
assert len(reduced.to_networkx().nodes) == 20

a_data = [
Expand All @@ -326,7 +326,7 @@ def test_reduce_all_axes() -> None:

graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3]}).map({'b': [4, 5]})
reduced = mapped.reduce('c', name='sum', attrs={'func': 'sum'})
reduced = mapped.reduce(name='sum', attrs={'func': 'sum'})
# No axis or index given, all axes are reduced, so the new node has no index part.
assert 'sum' in reduced.graph
assert reduced.graph.nodes['sum'] == {'func': 'sum'}
Expand All @@ -349,7 +349,7 @@ def test_reduce_raises_if_new_node_name_exists() -> None:
graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3]})
with pytest.raises(ValueError):
mapped.reduce('c', name='other')
mapped.reduce(name='other')


@pytest.mark.parametrize('indexer', [{'axis': 1}, {'index': 'y'}])
Expand All @@ -361,7 +361,7 @@ def test_reduce_raises_if_axis_or_does_not_exist(indexer) -> None:
graph = cb.Graph(g)
mapped = graph.map({'a': sc.arange('x', 3)})
with pytest.raises(ValueError):
mapped.reduce('c', name='combine', **indexer)
mapped.reduce(name='combine', **indexer)


@pytest.mark.parametrize('indexer', [{'axis': 1}, {'index': 'y'}])
Expand Down Expand Up @@ -484,3 +484,52 @@ def test_setitem_raises_TypeError_if_given_networkx_graph() -> None:
graph = cb.Graph(g)
with pytest.raises(TypeError):
graph['a'] = nx.DiGraph()


def test_setitem_with_other_graph_keeps_nodename_of_key_but_replaces_node_data() -> (
None
):
g1 = nx.DiGraph()
g1.add_edge('b', 'a')
g1.nodes['b']['attr'] = 1
g2 = nx.DiGraph()
g2.add_edge('d', 'c')
g2.nodes['c']['attr'] = 2

graph = cb.Graph(g1)
graph['b'] = cb.Graph(g2)
assert 'b' in graph.to_networkx()
nx_graph = graph.to_networkx()
assert set(nx_graph.nodes) == {'a', 'b', 'd'}
assert len(nx_graph.edges) == 2
assert nx_graph.has_edge('d', 'b')
assert nx_graph.has_edge('b', 'a')
assert nx_graph.nodes['b'] == {'attr': 2}


def test_setitem_raises_on_conflicting_ancestor_node_data() -> None:
g1 = nx.DiGraph()
g1.add_edge('a', 'b')
g1.nodes['a']['attr'] = 1
g1.add_edge('x', 'b')
g2 = nx.DiGraph()
g2.add_edge('a', 'x')
g2.nodes['a']['attr'] = 2

graph = cb.Graph(g1)
with pytest.raises(ValueError, match="Node data differs for node 'a'"):
graph['x'] = cb.Graph(g2)


def test_setitem_raises_on_conflicting_input_nodes_in_ancestor() -> None:
g1 = nx.DiGraph()
g1.add_edge('a1', 'b')
g1.add_edge('b', 'c')
g1.add_edge('x', 'c')
g2 = nx.DiGraph()
g2.add_edge('a2', 'b')
g2.add_edge('b', 'x')

graph = cb.Graph(g1)
with pytest.raises(ValueError, match="Node inputs differ for node 'b'"):
graph['x'] = cb.Graph(g2)

0 comments on commit 56a3ccb

Please sign in to comment.