diff --git a/src/cyclebane/graph.py b/src/cyclebane/graph.py index a91f1bd..6326e77 100644 --- a/src/cyclebane/graph.py +++ b/src/cyclebane/graph.py @@ -2,8 +2,9 @@ # Copyright (c) 2024 Scipp contributors (https://github.com/scipp) from __future__ import annotations +from collections.abc import Generator, Hashable, Iterable from dataclasses import dataclass -from typing import Any, Generator, Hashable, Iterable +from typing import Any from uuid import uuid4 import networkx as nx @@ -26,16 +27,18 @@ def _get_new_node_name(graph: nx.DiGraph) -> str: def _remove_ancestors(graph: nx.DiGraph, node: Hashable) -> nx.DiGraph: - graph = graph.copy() + graph_without_node = graph.copy() + graph_without_node.remove_node(node) ancestors = nx.ancestors(graph, node) - ancestors_successors = { - ancestor: graph.successors(ancestor) for ancestor in ancestors - } - to_remove = [] - for ancestor, successors in ancestors_successors.items(): - # If any successor does not have node as descendant we must keep the node - if all(nx.has_path(graph, successor, node) for successor in successors): - to_remove.append(ancestor) + # Considering the graph we obtain by removing `node`, we need to consider the + # descendants of each ancestor. If an ancestor has descendants that are not + # removal candidates, we should not remove the ancestor. + to_remove = [ + ancestor + for ancestor in ancestors + if nx.descendants(graph_without_node, ancestor).issubset(ancestors) + ] + graph = graph.copy() graph.remove_nodes_from(to_remove) graph.remove_edges_from(list(graph.in_edges(node))) return graph @@ -59,7 +62,7 @@ def from_tuple(t: tuple[tuple[IndexName, IndexValue], ...]) -> IndexValues: return IndexValues(axes=names, values=values) def to_tuple(self) -> tuple[tuple[IndexName, IndexValue], ...]: - return tuple(zip(self.axes, self.values)) + return tuple(zip(self.axes, self.values, strict=True)) def merge_index(self, other: IndexValues) -> IndexValues: return IndexValues( @@ -68,7 +71,8 @@ def merge_index(self, other: IndexValues) -> IndexValues: def __str__(self) -> str: return ', '.join( - f'{name}={value}' for name, value in zip(self.axes, self.values) + f'{name}={value}' + for name, value in zip(self.axes, self.values, strict=True) ) def __len__(self) -> int: @@ -139,7 +143,7 @@ def _rename_successors( def _yield_index( - indices: list[tuple[IndexName, Iterable[IndexValue]]] + indices: list[tuple[IndexName, Iterable[IndexValue]]], ) -> Generator[tuple[tuple[IndexName, IndexValue], ...], None, None]: """Given a multi-dimensional index, yield all possible combinations.""" name, index = indices[0] @@ -148,7 +152,7 @@ def _yield_index( yield ((name, index_value),) else: for rest in _yield_index(indices[1:]): - yield ((name, index_value),) + rest + yield ((name, index_value), *rest) class PositionalIndexer: @@ -364,12 +368,12 @@ def to_networkx(self, value_attr: str = 'value') -> nx.DiGraph: graph = self.graph for index_name, index in reversed(self.indices.items()): # Find all nodes with this index - nodes = [] - for node in graph.nodes(): - if index_name in _node_indices( - node.name if isinstance(node, NodeName) else node - ): - nodes.append(node) + nodes = [ + node + for node in graph.nodes() + if index_name + in _node_indices(node.name if isinstance(node, NodeName) else node) + ] # Make a copy for each index value graphs = [ _rename_successors( @@ -409,7 +413,7 @@ def __getitem__(self, key: Hashable | slice) -> Graph: ancestors = nx.ancestors(self.graph, key) ancestors.add(key) # Drop all node values that are not in the branch - mapped = set(a.name for a in ancestors if isinstance(a, MappedNode)) + mapped = {a.name for a in ancestors if isinstance(a, MappedNode)} keep_values = [key for key in self._node_values.keys() if key in mapped] return Graph( self.graph.subgraph(ancestors), diff --git a/src/cyclebane/node_values.py b/src/cyclebane/node_values.py index 4ade437..96b534f 100644 --- a/src/cyclebane/node_values.py +++ b/src/cyclebane/node_values.py @@ -3,8 +3,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections import abc -from typing import TYPE_CHECKING, Any, Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, ClassVar if TYPE_CHECKING: import numpy @@ -26,7 +26,7 @@ class ValueArray(ABC): simple Python iterables. """ - _registry = [] + _registry: ClassVar = [] def __init_subclass__(cls) -> None: super().__init_subclass__() @@ -43,8 +43,7 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray: @staticmethod @abstractmethod - def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: - ... + def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: ... @abstractmethod def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: @@ -121,7 +120,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: class PandasSeriesAdapter(ValueArray): - def __init__(self, series: 'pandas.Series', *, axis_zero: int = 0): + def __init__(self, series: pandas.Series, *, axis_zero: int = 0): self._series = series self._axis_zero = axis_zero @@ -170,7 +169,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: class XarrayDataArrayAdapter(ValueArray): def __init__( self, - data_array: 'xarray.DataArray', + data_array: xarray.DataArray, ): default_indices = { dim: range(size) @@ -211,7 +210,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: class ScippDataArrayAdapter(ValueArray): - def __init__(self, data_array: 'scipp.DataArray'): + def __init__(self, data_array: scipp.DataArray): import scipp default_indices = { @@ -264,7 +263,7 @@ def shape(self) -> tuple[int, ...]: def index_names(self) -> tuple[IndexName, ...]: return tuple(self._data_array.dims) - def _index_for_dim(self, dim: str) -> list[tuple[Any, 'scipp.Unit']]: + def _index_for_dim(self, dim: str) -> list[tuple[Any, scipp.Unit]]: # Work around some NetworkX errors. Probably scipp.Variable lacks functionality. # For now we return a list of tuples, where the first element is the value and # the second is the unit. @@ -283,7 +282,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: class NumpyArrayAdapter(ValueArray): def __init__( self, - array: 'numpy.ndarray', + array: numpy.ndarray, *, indices: dict[IndexName, Iterable[IndexValue]] | None = None, axis_zero: int = 0, @@ -335,7 +334,7 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: return self._indices -class NodeValues(abc.Mapping[Hashable, ValueArray]): +class NodeValues(Mapping[Hashable, ValueArray]): """ A collection of pandas.DataFrame-like objects with distinct indices. @@ -349,7 +348,7 @@ def __len__(self) -> int: """Return the number of columns.""" return len(self._values) - def __iter__(self) -> Iterable[Hashable]: + def __iter__(self) -> Iterator[Hashable]: """Iterate over the column names.""" return iter(self._values) @@ -377,7 +376,7 @@ def from_mapping( def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues: if value_arrays: named = next(iter(value_arrays.values())).index_names - if any([name in self.indices for name in named]): + if any(name in self.indices for name in named): raise ValueError( f'Conflicting new index names {named} with existing ' f'{tuple(self.indices)}' diff --git a/tests/graph_test.py b/tests/graph_test.py index e4cee95..7e657fa 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from typing import Hashable, Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence import networkx as nx import numpy as np @@ -362,7 +362,7 @@ def test_map_with_previously_mapped_index_name_raises() -> None: values = sc.arange('x', 3) mapped = graph.map({'a': values}) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Conflicting new index names"): mapped.map({'b': values}) @@ -441,7 +441,7 @@ def test_reduce_raises_if_axis_or_does_not_exist(indexer) -> None: graph = cb.Graph(g) mapped = graph.map({'a': sc.arange('x', 3)}) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="does not have"): mapped.reduce(name='combine', **indexer) @@ -498,8 +498,8 @@ def test_can_reduce_same_node_multiple_times() -> None: reduced = mapped.reduce('b', name='c1', axis=0).reduce('b', name='c2', axis=0) result = reduced.to_networkx() assert len(result.nodes) == 3 + 3 + 1 + 1 - c1_parents = [n for n in result.predecessors('c1')] - c2_parents = [n for n in result.predecessors('c2')] + c1_parents = list(result.predecessors('c1')) + c2_parents = list(result.predecessors('c2')) assert c1_parents == c2_parents @@ -525,19 +525,19 @@ def test_can_reduce_different_axes_or_indices_of_same_node() -> None: reduced = reduced.to_networkx() assert len(reduced.nodes) == 9 + 9 + 4 * 3 - c0s = [n for n in helper.predecessors('d0')] - c1s = [n for n in helper.predecessors('d1')] - cxs = [n for n in helper.predecessors('dx')] - cys = [n for n in helper.predecessors('dy')] - - for c0, cx in zip(c0s, cxs): - c0_parents = [n for n in reduced.predecessors(c0)] - cx_parents = [n for n in reduced.predecessors(cx)] + c0s = list(helper.predecessors('d0')) + c1s = list(helper.predecessors('d1')) + cxs = list(helper.predecessors('dx')) + cys = list(helper.predecessors('dy')) + + for c0, cx in zip(c0s, cxs, strict=True): + c0_parents = list(reduced.predecessors(c0)) + cx_parents = list(reduced.predecessors(cx)) assert c0_parents == cx_parents - for c1, cy in zip(c1s, cys): - c1_parents = [n for n in reduced.predecessors(c1)] - cy_parents = [n for n in reduced.predecessors(cy)] + for c1, cy in zip(c1s, cys, strict=True): + c1_parents = list(reduced.predecessors(c1)) + cy_parents = list(reduced.predecessors(cy)) assert c1_parents == cy_parents @@ -616,6 +616,29 @@ def test_setitem_raises_on_conflicting_input_nodes_in_ancestor() -> None: graph['x'] = cb.Graph(g2) +def test_setitem_replaces_nodes_that_are_not_ancestors_of_unrelated_node() -> None: + g1 = nx.DiGraph() + g1.add_edge('a', 'b') + g1.add_edge('b', 'c') + g1.add_edge('c', 'd') + graph = cb.Graph(g1) + g2 = nx.DiGraph() + g2.add_edge('b', 'c') + graph['c'] = cb.Graph(g2) + assert 'a' not in graph.to_networkx() + + +def test_setitem_preserves_nodes_that_are_ancestors_of_unrelated_node() -> None: + g = nx.DiGraph() + g.add_edge('a', 'b') + g.add_edge('b', 'c') + g.add_edge('b', 'd') + g.add_edge('c', 'd') + graph = cb.Graph(g) + graph['c'] = graph['c'] + nx.utils.graphs_equal(graph.to_networkx(), g) + + def test_getitem_returns_graph_containing_only_key_and_ancestors() -> None: g = nx.DiGraph() g.add_edge('a', 'b') @@ -668,7 +691,7 @@ def test_getitem_keeps_only_relevant_node_values() -> None: graph = cb.Graph(g) mapped = graph.map({'a': [1, 2, 3]}) # This fails due to existing mapping... - with pytest.raises(ValueError): + with pytest.raises(ValueError, match='has already been mapped'): mapped.map({'a': [1, 2]}) # ... but getitem drops the 'a' mapping, so we can map 'a' again: mapped['b'].map({'a': [1, 2]})