Skip to content

Commit

Permalink
Merge pull request #6 from scipp/fix-ancestor-removal-logic
Browse files Browse the repository at this point in the history
Fix ancestor removal logic
  • Loading branch information
SimonHeybrock authored Jun 10, 2024
2 parents c903bf9 + 806f8e8 commit d9c159f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 51 deletions.
46 changes: 25 additions & 21 deletions src/cyclebane/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

from collections.abc import Generator, Hashable, Iterable
from dataclasses import dataclass
from typing import Any, Generator, Hashable, Iterable
from typing import Any
from uuid import uuid4

import networkx as nx
Expand All @@ -26,16 +27,18 @@ def _get_new_node_name(graph: nx.DiGraph) -> str:


def _remove_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph:
    """Return a copy of ``graph`` with the in-edges of ``node`` removed and
    with every ancestor of ``node`` removed that serves no other node.

    An ancestor is kept if, in the graph without ``node``, it still has a
    descendant that is not itself a removal candidate — i.e. it feeds some
    node unrelated to ``node``.
    """
    # Query descendants on a copy that lacks `node`, so that paths which only
    # exist via `node` itself do not keep an ancestor alive.
    graph_without_node = graph.copy()
    graph_without_node.remove_node(node)
    ancestors = nx.ancestors(graph, node)
    # Considering the graph we obtain by removing `node`, we need to consider
    # the descendants of each ancestor. If an ancestor has descendants that
    # are not removal candidates, we should not remove the ancestor.
    to_remove = [
        ancestor
        for ancestor in ancestors
        if nx.descendants(graph_without_node, ancestor).issubset(ancestors)
    ]
    graph = graph.copy()
    graph.remove_nodes_from(to_remove)
    # Finally disconnect `node` from any surviving ancestors.
    graph.remove_edges_from(list(graph.in_edges(node)))
    return graph
Expand All @@ -59,7 +62,7 @@ def from_tuple(t: tuple[tuple[IndexName, IndexValue], ...]) -> IndexValues:
return IndexValues(axes=names, values=values)

def to_tuple(self) -> tuple[tuple[IndexName, IndexValue], ...]:
return tuple(zip(self.axes, self.values))
return tuple(zip(self.axes, self.values, strict=True))

def merge_index(self, other: IndexValues) -> IndexValues:
return IndexValues(
Expand All @@ -68,7 +71,8 @@ def merge_index(self, other: IndexValues) -> IndexValues:

def __str__(self) -> str:
    """Return a human-readable ``name=value`` listing of the index values."""
    return ', '.join(
        f'{name}={value}'
        for name, value in zip(self.axes, self.values, strict=True)
    )

def __len__(self) -> int:
Expand Down Expand Up @@ -139,7 +143,7 @@ def _rename_successors(


def _yield_index(
indices: list[tuple[IndexName, Iterable[IndexValue]]]
indices: list[tuple[IndexName, Iterable[IndexValue]]],
) -> Generator[tuple[tuple[IndexName, IndexValue], ...], None, None]:
"""Given a multi-dimensional index, yield all possible combinations."""
name, index = indices[0]
Expand All @@ -148,7 +152,7 @@ def _yield_index(
yield ((name, index_value),)
else:
for rest in _yield_index(indices[1:]):
yield ((name, index_value),) + rest
yield ((name, index_value), *rest)


class PositionalIndexer:
Expand Down Expand Up @@ -364,12 +368,12 @@ def to_networkx(self, value_attr: str = 'value') -> nx.DiGraph:
graph = self.graph
for index_name, index in reversed(self.indices.items()):
# Find all nodes with this index
nodes = []
for node in graph.nodes():
if index_name in _node_indices(
node.name if isinstance(node, NodeName) else node
):
nodes.append(node)
nodes = [
node
for node in graph.nodes()
if index_name
in _node_indices(node.name if isinstance(node, NodeName) else node)
]
# Make a copy for each index value
graphs = [
_rename_successors(
Expand Down Expand Up @@ -409,7 +413,7 @@ def __getitem__(self, key: Hashable | slice) -> Graph:
ancestors = nx.ancestors(self.graph, key)
ancestors.add(key)
# Drop all node values that are not in the branch
mapped = set(a.name for a in ancestors if isinstance(a, MappedNode))
mapped = {a.name for a in ancestors if isinstance(a, MappedNode)}
keep_values = [key for key in self._node_values.keys() if key in mapped]
return Graph(
self.graph.subgraph(ancestors),
Expand Down
25 changes: 12 additions & 13 deletions src/cyclebane/node_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import abc
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Mapping, Sequence
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
import numpy
Expand All @@ -26,7 +26,7 @@ class ValueArray(ABC):
simple Python iterables.
"""

_registry = []
_registry: ClassVar = []

def __init_subclass__(cls) -> None:
super().__init_subclass__()
Expand All @@ -43,8 +43,7 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray:

@staticmethod
@abstractmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None:
    """Construct an instance from ``obj`` if supported, else return None.

    NOTE(review): semantics inferred from the signature and the name —
    confirm against ``from_array_like``, which is not fully visible here.
    """

@abstractmethod
def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
Expand Down Expand Up @@ -121,7 +120,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class PandasSeriesAdapter(ValueArray):
def __init__(self, series: 'pandas.Series', *, axis_zero: int = 0):
def __init__(self, series: pandas.Series, *, axis_zero: int = 0):
self._series = series
self._axis_zero = axis_zero

Expand Down Expand Up @@ -170,7 +169,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
class XarrayDataArrayAdapter(ValueArray):
def __init__(
self,
data_array: 'xarray.DataArray',
data_array: xarray.DataArray,
):
default_indices = {
dim: range(size)
Expand Down Expand Up @@ -211,7 +210,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class ScippDataArrayAdapter(ValueArray):
def __init__(self, data_array: 'scipp.DataArray'):
def __init__(self, data_array: scipp.DataArray):
import scipp

default_indices = {
Expand Down Expand Up @@ -264,7 +263,7 @@ def shape(self) -> tuple[int, ...]:
def index_names(self) -> tuple[IndexName, ...]:
return tuple(self._data_array.dims)

def _index_for_dim(self, dim: str) -> list[tuple[Any, 'scipp.Unit']]:
def _index_for_dim(self, dim: str) -> list[tuple[Any, scipp.Unit]]:
# Work around some NetworkX errors. Probably scipp.Variable lacks functionality.
# For now we return a list of tuples, where the first element is the value and
# the second is the unit.
Expand All @@ -283,7 +282,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
class NumpyArrayAdapter(ValueArray):
def __init__(
self,
array: 'numpy.ndarray',
array: numpy.ndarray,
*,
indices: dict[IndexName, Iterable[IndexValue]] | None = None,
axis_zero: int = 0,
Expand Down Expand Up @@ -335,7 +334,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
return self._indices


class NodeValues(abc.Mapping[Hashable, ValueArray]):
class NodeValues(Mapping[Hashable, ValueArray]):
"""
A collection of pandas.DataFrame-like objects with distinct indices.
Expand All @@ -349,7 +348,7 @@ def __len__(self) -> int:
"""Return the number of columns."""
return len(self._values)

def __iter__(self) -> Iterator[Hashable]:
    """Iterate over the column names."""
    return iter(self._values)

Expand Down Expand Up @@ -377,7 +376,7 @@ def from_mapping(
def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues:
if value_arrays:
named = next(iter(value_arrays.values())).index_names
if any([name in self.indices for name in named]):
if any(name in self.indices for name in named):
raise ValueError(
f'Conflicting new index names {named} with existing '
f'{tuple(self.indices)}'
Expand Down
57 changes: 40 additions & 17 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Hashable, Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_map_with_previously_mapped_index_name_raises() -> None:
values = sc.arange('x', 3)

mapped = graph.map({'a': values})
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Conflicting new index names"):
mapped.map({'b': values})


Expand Down Expand Up @@ -441,7 +441,7 @@ def test_reduce_raises_if_axis_or_does_not_exist(indexer) -> None:

graph = cb.Graph(g)
mapped = graph.map({'a': sc.arange('x', 3)})
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="does not have"):
mapped.reduce(name='combine', **indexer)


Expand Down Expand Up @@ -498,8 +498,8 @@ def test_can_reduce_same_node_multiple_times() -> None:
reduced = mapped.reduce('b', name='c1', axis=0).reduce('b', name='c2', axis=0)
result = reduced.to_networkx()
assert len(result.nodes) == 3 + 3 + 1 + 1
c1_parents = [n for n in result.predecessors('c1')]
c2_parents = [n for n in result.predecessors('c2')]
c1_parents = list(result.predecessors('c1'))
c2_parents = list(result.predecessors('c2'))
assert c1_parents == c2_parents


Expand All @@ -525,19 +525,19 @@ def test_can_reduce_different_axes_or_indices_of_same_node() -> None:
reduced = reduced.to_networkx()

assert len(reduced.nodes) == 9 + 9 + 4 * 3
c0s = [n for n in helper.predecessors('d0')]
c1s = [n for n in helper.predecessors('d1')]
cxs = [n for n in helper.predecessors('dx')]
cys = [n for n in helper.predecessors('dy')]

for c0, cx in zip(c0s, cxs):
c0_parents = [n for n in reduced.predecessors(c0)]
cx_parents = [n for n in reduced.predecessors(cx)]
c0s = list(helper.predecessors('d0'))
c1s = list(helper.predecessors('d1'))
cxs = list(helper.predecessors('dx'))
cys = list(helper.predecessors('dy'))

for c0, cx in zip(c0s, cxs, strict=True):
c0_parents = list(reduced.predecessors(c0))
cx_parents = list(reduced.predecessors(cx))
assert c0_parents == cx_parents

for c1, cy in zip(c1s, cys):
c1_parents = [n for n in reduced.predecessors(c1)]
cy_parents = [n for n in reduced.predecessors(cy)]
for c1, cy in zip(c1s, cys, strict=True):
c1_parents = list(reduced.predecessors(c1))
cy_parents = list(reduced.predecessors(cy))
assert c1_parents == cy_parents


Expand Down Expand Up @@ -616,6 +616,29 @@ def test_setitem_raises_on_conflicting_input_nodes_in_ancestor() -> None:
graph['x'] = cb.Graph(g2)


def test_setitem_replaces_nodes_that_are_not_ancestors_of_unrelated_node() -> None:
    # Chain a -> b -> c -> d; replacing 'c' by a graph rooted at 'b' should
    # drop 'a', since nothing but the replaced branch depended on it.
    base = nx.DiGraph()
    base.add_edges_from([('a', 'b'), ('b', 'c'), ('c', 'd')])
    graph = cb.Graph(base)
    replacement = nx.DiGraph()
    replacement.add_edges_from([('b', 'c')])
    graph['c'] = cb.Graph(replacement)
    assert 'a' not in graph.to_networkx()

def test_setitem_preserves_nodes_that_are_ancestors_of_unrelated_node() -> None:
    g = nx.DiGraph()
    g.add_edge('a', 'b')
    g.add_edge('b', 'c')
    g.add_edge('b', 'd')
    g.add_edge('c', 'd')
    graph = cb.Graph(g)
    # Replacing 'c' with itself must keep 'a' and 'b': both remain ancestors
    # of the unrelated sink 'd'.
    graph['c'] = graph['c']
    # Bug fix: the boolean result of graphs_equal was previously discarded,
    # so this test could never fail.
    assert nx.utils.graphs_equal(graph.to_networkx(), g)


def test_getitem_returns_graph_containing_only_key_and_ancestors() -> None:
g = nx.DiGraph()
g.add_edge('a', 'b')
Expand Down Expand Up @@ -668,7 +691,7 @@ def test_getitem_keeps_only_relevant_node_values() -> None:
graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3]})
# This fails due to existing mapping...
with pytest.raises(ValueError):
with pytest.raises(ValueError, match='has already been mapped'):
mapped.map({'a': [1, 2]})
# ... but getitem drops the 'a' mapping, so we can map 'a' again:
mapped['b'].map({'a': [1, 2]})
Expand Down

0 comments on commit d9c159f

Please sign in to comment.