Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix ancestor removal logic #6

Merged
merged 4 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions src/cyclebane/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

from collections.abc import Generator, Hashable, Iterable
from dataclasses import dataclass
from typing import Any, Generator, Hashable, Iterable
from typing import Any
from uuid import uuid4

import networkx as nx
Expand All @@ -26,16 +27,20 @@ def _get_new_node_name(graph: nx.DiGraph) -> str:


def _remove_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph:
graph = graph.copy()
graph_without_node = graph.copy()
graph_without_node.remove_node(node)
ancestors = nx.ancestors(graph, node)
ancestors_successors = {
ancestor: graph.successors(ancestor) for ancestor in ancestors
# Considering the graph we obtain by removing `node`, we need to consider the
# descendants of each ancestor. If an ancestor has descendants that are not
# removal candidates, we should not remove the ancestor.
ancestors_descendants = {
ancestor: nx.descendants(graph_without_node, ancestor) for ancestor in ancestors
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use

Suggested change
ancestors_descendants = {
ancestor: nx.descendants(graph_without_node, ancestor) for ancestor in ancestors
}
ancestors_descendants = [
(ancestor, nx.descendants(graph_without_node, ancestor)) for ancestor in ancestors
]

to avoid computing hashes. And technically, you could also express the whole construction of to_remove as a list comprehension. (I'm surprised ruff didn't suggest the latter.)

to_remove = []
for ancestor, successors in ancestors_successors.items():
# If any successor does not have node as descendant we must keep the node
if all(nx.has_path(graph, successor, node) for successor in successors):
for ancestor, descendants in ancestors_descendants.items():
if descendants.issubset(ancestors):
to_remove.append(ancestor)
graph = graph.copy()
graph.remove_nodes_from(to_remove)
graph.remove_edges_from(list(graph.in_edges(node)))
return graph
Expand All @@ -59,7 +64,7 @@ def from_tuple(t: tuple[tuple[IndexName, IndexValue], ...]) -> IndexValues:
return IndexValues(axes=names, values=values)

def to_tuple(self) -> tuple[tuple[IndexName, IndexValue], ...]:
return tuple(zip(self.axes, self.values))
return tuple(zip(self.axes, self.values, strict=True))

def merge_index(self, other: IndexValues) -> IndexValues:
return IndexValues(
Expand All @@ -68,7 +73,8 @@ def merge_index(self, other: IndexValues) -> IndexValues:

def __str__(self) -> str:
return ', '.join(
f'{name}={value}' for name, value in zip(self.axes, self.values)
f'{name}={value}'
for name, value in zip(self.axes, self.values, strict=True)
)

def __len__(self) -> int:
Expand Down Expand Up @@ -139,7 +145,7 @@ def _rename_successors(


def _yield_index(
indices: list[tuple[IndexName, Iterable[IndexValue]]]
indices: list[tuple[IndexName, Iterable[IndexValue]]],
) -> Generator[tuple[tuple[IndexName, IndexValue], ...], None, None]:
"""Given a multi-dimensional index, yield all possible combinations."""
name, index = indices[0]
Expand All @@ -148,7 +154,7 @@ def _yield_index(
yield ((name, index_value),)
else:
for rest in _yield_index(indices[1:]):
yield ((name, index_value),) + rest
yield ((name, index_value), *rest)


class PositionalIndexer:
Expand Down Expand Up @@ -364,12 +370,12 @@ def to_networkx(self, value_attr: str = 'value') -> nx.DiGraph:
graph = self.graph
for index_name, index in reversed(self.indices.items()):
# Find all nodes with this index
nodes = []
for node in graph.nodes():
if index_name in _node_indices(
node.name if isinstance(node, NodeName) else node
):
nodes.append(node)
nodes = [
node
for node in graph.nodes()
if index_name
in _node_indices(node.name if isinstance(node, NodeName) else node)
]
# Make a copy for each index value
graphs = [
_rename_successors(
Expand Down Expand Up @@ -409,7 +415,7 @@ def __getitem__(self, key: Hashable | slice) -> Graph:
ancestors = nx.ancestors(self.graph, key)
ancestors.add(key)
# Drop all node values that are not in the branch
mapped = set(a.name for a in ancestors if isinstance(a, MappedNode))
mapped = {a.name for a in ancestors if isinstance(a, MappedNode)}
keep_values = [key for key in self._node_values.keys() if key in mapped]
return Graph(
self.graph.subgraph(ancestors),
Expand Down
25 changes: 12 additions & 13 deletions src/cyclebane/node_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import abc
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Mapping, Sequence
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
import numpy
Expand All @@ -26,7 +26,7 @@ class ValueArray(ABC):
simple Python iterables.
"""

_registry = []
_registry: ClassVar = []

def __init_subclass__(cls) -> None:
super().__init_subclass__()
Expand All @@ -43,8 +43,7 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray:

@staticmethod
@abstractmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None:
...
def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: ...

@abstractmethod
def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
Expand Down Expand Up @@ -121,7 +120,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class PandasSeriesAdapter(ValueArray):
def __init__(self, series: 'pandas.Series', *, axis_zero: int = 0):
def __init__(self, series: pandas.Series, *, axis_zero: int = 0):
self._series = series
self._axis_zero = axis_zero

Expand Down Expand Up @@ -170,7 +169,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
class XarrayDataArrayAdapter(ValueArray):
def __init__(
self,
data_array: 'xarray.DataArray',
data_array: xarray.DataArray,
):
default_indices = {
dim: range(size)
Expand Down Expand Up @@ -211,7 +210,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class ScippDataArrayAdapter(ValueArray):
def __init__(self, data_array: 'scipp.DataArray'):
def __init__(self, data_array: scipp.DataArray):
import scipp

default_indices = {
Expand Down Expand Up @@ -264,7 +263,7 @@ def shape(self) -> tuple[int, ...]:
def index_names(self) -> tuple[IndexName, ...]:
return tuple(self._data_array.dims)

def _index_for_dim(self, dim: str) -> list[tuple[Any, 'scipp.Unit']]:
def _index_for_dim(self, dim: str) -> list[tuple[Any, scipp.Unit]]:
# Work around some NetworkX errors. Probably scipp.Variable lacks functionality.
# For now we return a list of tuples, where the first element is the value and
# the second is the unit.
Expand All @@ -283,7 +282,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
class NumpyArrayAdapter(ValueArray):
def __init__(
self,
array: 'numpy.ndarray',
array: numpy.ndarray,
*,
indices: dict[IndexName, Iterable[IndexValue]] | None = None,
axis_zero: int = 0,
Expand Down Expand Up @@ -335,7 +334,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:
return self._indices


class NodeValues(abc.Mapping[Hashable, ValueArray]):
class NodeValues(Mapping[Hashable, ValueArray]):
"""
A collection of pandas.DataFrame-like objects with distinct indices.

Expand All @@ -349,7 +348,7 @@ def __len__(self) -> int:
"""Return the number of columns."""
return len(self._values)

def __iter__(self) -> Iterable[Hashable]:
def __iter__(self) -> Iterator[Hashable]:
"""Iterate over the column names."""
return iter(self._values)

Expand Down Expand Up @@ -377,7 +376,7 @@ def from_mapping(
def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues:
if value_arrays:
named = next(iter(value_arrays.values())).index_names
if any([name in self.indices for name in named]):
if any(name in self.indices for name in named):
raise ValueError(
f'Conflicting new index names {named} with existing '
f'{tuple(self.indices)}'
Expand Down
56 changes: 39 additions & 17 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from typing import Hashable, Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_map_with_previously_mapped_index_name_raises() -> None:
values = sc.arange('x', 3)

mapped = graph.map({'a': values})
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="Conflicting new index names"):
mapped.map({'b': values})


Expand Down Expand Up @@ -441,7 +441,7 @@ def test_reduce_raises_if_axis_or_does_not_exist(indexer) -> None:

graph = cb.Graph(g)
mapped = graph.map({'a': sc.arange('x', 3)})
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="does not have"):
mapped.reduce(name='combine', **indexer)


Expand Down Expand Up @@ -498,8 +498,8 @@ def test_can_reduce_same_node_multiple_times() -> None:
reduced = mapped.reduce('b', name='c1', axis=0).reduce('b', name='c2', axis=0)
result = reduced.to_networkx()
assert len(result.nodes) == 3 + 3 + 1 + 1
c1_parents = [n for n in result.predecessors('c1')]
c2_parents = [n for n in result.predecessors('c2')]
c1_parents = list(result.predecessors('c1'))
c2_parents = list(result.predecessors('c2'))
assert c1_parents == c2_parents


Expand All @@ -525,19 +525,19 @@ def test_can_reduce_different_axes_or_indices_of_same_node() -> None:
reduced = reduced.to_networkx()

assert len(reduced.nodes) == 9 + 9 + 4 * 3
c0s = [n for n in helper.predecessors('d0')]
c1s = [n for n in helper.predecessors('d1')]
cxs = [n for n in helper.predecessors('dx')]
cys = [n for n in helper.predecessors('dy')]

for c0, cx in zip(c0s, cxs):
c0_parents = [n for n in reduced.predecessors(c0)]
cx_parents = [n for n in reduced.predecessors(cx)]
c0s = list(helper.predecessors('d0'))
c1s = list(helper.predecessors('d1'))
cxs = list(helper.predecessors('dx'))
cys = list(helper.predecessors('dy'))

for c0, cx in zip(c0s, cxs, strict=True):
c0_parents = list(reduced.predecessors(c0))
cx_parents = list(reduced.predecessors(cx))
assert c0_parents == cx_parents

for c1, cy in zip(c1s, cys):
c1_parents = [n for n in reduced.predecessors(c1)]
cy_parents = [n for n in reduced.predecessors(cy)]
for c1, cy in zip(c1s, cys, strict=True):
c1_parents = list(reduced.predecessors(c1))
cy_parents = list(reduced.predecessors(cy))
assert c1_parents == cy_parents


Expand Down Expand Up @@ -616,6 +616,28 @@ def test_setitem_raises_on_conflicting_input_nodes_in_ancestor() -> None:
graph['x'] = cb.Graph(g2)


def test_setitem_replaces_nodes_that_are_not_ancestors_of_unrelated_node() -> None:
g1 = nx.DiGraph()
g1.add_edge('a', 'b')
g1.add_edge('b', 'c')
g1.add_edge('c', 'd')
graph = cb.Graph(g1)
g2 = nx.DiGraph()
g2.add_edge('b', 'c')
graph['c'] = cb.Graph(g2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no assert here.


def test_setitem_preserves_nodes_that_are_ancestors_of_unrelated_node() -> None:
g = nx.DiGraph()
g.add_edge('a', 'b')
g.add_edge('b', 'c')
g.add_edge('b', 'd')
g.add_edge('c', 'd')
graph = cb.Graph(g)
graph['c'] = graph['c']
assert 'a' in graph.to_networkx()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work?

Suggested change
assert 'a' in graph.to_networkx()
assert graph.to_networkx() == g

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be nx.utils.graphs_equal(graph.to_networkx(), g) for equality.



def test_getitem_returns_graph_containing_only_key_and_ancestors() -> None:
g = nx.DiGraph()
g.add_edge('a', 'b')
Expand Down Expand Up @@ -668,7 +690,7 @@ def test_getitem_keeps_only_relevant_node_values() -> None:
graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3]})
# This fails due to existing mapping...
with pytest.raises(ValueError):
with pytest.raises(ValueError, match='has already been mapped'):
mapped.map({'a': [1, 2]})
# ... but getitem drops the 'a' mapping, so we can map 'a' again:
mapped['b'].map({'a': [1, 2]})
Expand Down