From c1ca8044f40c0a6e6c9d6228c746a1fa0485caf0 Mon Sep 17 00:00:00 2001 From: Simon Heybrock Date: Thu, 6 Jun 2024 05:49:07 +0200 Subject: [PATCH] Allow compatible values in __setitem__ --- src/cyclebane/node_values.py | 61 ++++++++++++++++++++++++++++++------ tests/graph_test.py | 60 ++++++++++++++++++++++++++++++++--- 2 files changed, 107 insertions(+), 14 deletions(-) diff --git a/src/cyclebane/node_values.py b/src/cyclebane/node_values.py index 96b534f..ffc9a4b 100644 --- a/src/cyclebane/node_values.py +++ b/src/cyclebane/node_values.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar +from types import ModuleType +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar if TYPE_CHECKING: import numpy @@ -15,6 +16,8 @@ IndexName = Hashable IndexValue = Hashable +T = TypeVar('T', bound='ValueArray') + class ValueArray(ABC): """ @@ -45,6 +48,17 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray: @abstractmethod def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: ... + def __eq__(self, other: object) -> bool: + if type(self) != type(other): + return False + return self._equal(other) + + def __ne__(self, other: object) -> bool: + return not self == other + + @abstractmethod + def _equal(self: T, other: T) -> bool: ... + @abstractmethod def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: """Return data by selecting from index with given name and index value.""" @@ -94,6 +108,13 @@ def __init__( def try_from(obj: Any, *, axis_zero: int = 0) -> SequenceAdapter | None: return SequenceAdapter(obj, axis_zero=axis_zero) + def _equal(self, other: SequenceAdapter) -> bool: + return ( + self._values == other._values + and self._index == other._index + and self._axis_zero == other._axis_zero + ) + def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: if len(key) != 1: raise ValueError('SequenceAdapter only supports single index') @@ -133,6 +154,11 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> PandasSeriesAdapter | None: if isinstance(obj, pandas.Series): return PandasSeriesAdapter(obj, axis_zero=axis_zero) + def _equal(self, other: PandasSeriesAdapter) -> bool: + return ( + self._series.equals(other._series) and self._axis_zero == other._axis_zero + ) + def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: if len(key) != 1: raise ValueError('PandasSeriesAdapter only supports single index') @@ -188,6 +214,9 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> XarrayDataArrayAdapter | None: except ModuleNotFoundError: pass + def _equal(self, other: XarrayDataArrayAdapter) -> bool: + return self._data_array.identical(other._data_array) + def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: return self._data_array.sel(dict(key)) @@ -210,15 +239,14 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]: class ScippDataArrayAdapter(ValueArray): - def __init__(self, data_array: scipp.DataArray): - import scipp - + def __init__(self, data_array: scipp.DataArray, scipp: ModuleType): default_indices = { dim: scipp.arange(dim, size, unit=None) for dim, size in data_array.sizes.items() if dim not in data_array.coords } self._data_array = data_array.assign_coords(default_indices) + self._scipp = scipp @staticmethod def try_from(obj: Any, *, axis_zero: int = 0) -> ScippDataArrayAdapter | None: @@ -226,15 +254,16 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> ScippDataArrayAdapter | None: import scipp if isinstance(obj, scipp.Variable): - return ScippDataArrayAdapter(scipp.DataArray(obj)) + return ScippDataArrayAdapter(scipp.DataArray(obj), scipp=scipp) if isinstance(obj, scipp.DataArray): - return ScippDataArrayAdapter(obj) + return ScippDataArrayAdapter(obj, scipp=scipp) except ModuleNotFoundError: pass - def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: - import scipp + def _equal(self, other: ScippDataArrayAdapter) -> bool: + return self._scipp.identical(self._data_array, other._data_array) + def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: values = self._data_array for dim, value in key: # Reconstruct label, to use label-based indexing instead of positional @@ -242,7 +271,7 @@ def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: value, unit = value else: unit = None - label = scipp.scalar(value, unit=unit) + label = self._scipp.scalar(value, unit=unit) # Scipp indexing uses a comma to separate dimension label from the index, # unlike Numpy and other libraries where it separates the indices for # different axes. @@ -253,7 +282,7 @@ def __getitem__(self, key: dict[IndexName, slice]) -> ScippDataArrayAdapter: values = self._data_array for dim, i in key: values = values[dim, i] - return ScippDataArrayAdapter(values) + return ScippDataArrayAdapter(values, scipp=self._scipp) @property def shape(self) -> tuple[int, ...]: @@ -307,6 +336,13 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> NumpyArrayAdapter | None: if isinstance(obj, numpy.ndarray): return NumpyArrayAdapter(obj, axis_zero=axis_zero) + def _equal(self, other: NumpyArrayAdapter) -> bool: + return ( + (self._array == other._array).all() + and self._indices == other._indices + and self._axis_zero == other._axis_zero + ) + def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any: index_tuple = tuple(self._indices[k].index(i) for k, i in key) return self._array[index_tuple] @@ -374,6 +410,11 @@ def from_mapping( return NodeValues(value_arrays) def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues: + value_arrays = { + key: value + for key, value in value_arrays.items() + if self.get(key, None) != value + } if value_arrays: named = next(iter(value_arrays.values())).index_names if any(name in self.indices for name in named): diff --git a/tests/graph_test.py b/tests/graph_test.py index 7e657fa..fe80de8 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -735,15 +735,67 @@ def test_setitem_with_mapped_operands_raises_on_conflict() -> None: def test_setitem_currently_does_not_allow_compatible_indices() -> None: g = nx.DiGraph() - g.add_edge('a', 'c') - g.add_edge('b', 'c') + g.add_edge('a', 'b') + g.add_edge('c', 'd') graph = cb.Graph(g) - mapped = graph.map({'a': [1, 2, 3], 'b': [11, 12, 13]}).reduce('c', name='d') + mapped1 = graph.map({'a': [1, 2, 3]}) + mapped2 = graph['d'].map({'c': [11, 12, 13]}).reduce('d', name='e') # Note: This is a limitation of the current implementation. We could check if the # indices are identical and allow this. For simplicity we currently do not. with pytest.raises(ValueError, match="Conflicting new index names"): - mapped['x'] = mapped['d'] + mapped1['x'] = mapped2 + + +@pytest.mark.parametrize( + 'node_values', + [ + {'a': [1, 2, 3], 'b': [11, 12, 13]}, + {'a': np.array([1, 2, 3]), 'b': np.array([11, 12, 13])}, + pd.DataFrame({'a': [1, 2, 3], 'b': [11, 12, 13]}), + { + 'a': sc.array(dims=['x'], values=[1, 2, 3]), + 'b': sc.array(dims=['x'], values=[11, 12, 13]), + }, + { + 'a': xr.DataArray(dims=('x',), data=[1, 2, 3]), + 'b': xr.DataArray(dims=('x',), data=[11, 12, 13]), + }, + ], +) +def test_setitem_allows_compatible_node_values(node_values) -> None: + g = nx.DiGraph() + g.add_edge('a', 'c') + g.add_edge('b', 'c') + + graph = cb.Graph(g) + mapped = graph.map(node_values).reduce('c', name='d') + mapped['x'] = mapped['d'] + assert len(mapped.index_names) == 1 + + +def test_setitem_raises_if_node_values_equivalent_but_of_different_type() -> None: + g = nx.DiGraph() + g.add_edge('a', 'b') + graph = cb.Graph(g) + mapped1 = graph.map({'a': [1, 2]}).reduce('b', name='d') + mapped2 = graph.map({'a': np.array([1, 2])}).reduce('b', name='d') + # One could imagine treating this as equivalent, but we are strict in the + # comparison. + with pytest.raises(ValueError, match="Conflicting new index names"): + mapped1['x'] = mapped2['d'] + + +def test_setitem_raises_if_node_values_incompatible() -> None: + g = nx.DiGraph() + g.add_edge('a', 'b') + graph = cb.Graph(g) + mapped1 = graph.map({'a': [1, 2]}).reduce('b', name='d') + mapped2 = graph.map({'a': sc.array(dims=('x',), values=[1, 2])}).reduce( + 'b', name='d' + ) + with pytest.raises(ValueError, match="has already been mapped"): + mapped1['x'] = mapped2['d'] def test_setitem_does_currently_not_support_slice_assignment() -> None: