Skip to content

Commit

Permalink
Allow compatible values in __setitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonHeybrock committed Jun 6, 2024
1 parent 806f8e8 commit c1ca804
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 14 deletions.
61 changes: 51 additions & 10 deletions src/cyclebane/node_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from abc import ABC, abstractmethod
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

if TYPE_CHECKING:
import numpy
Expand All @@ -15,6 +16,8 @@
IndexName = Hashable
IndexValue = Hashable

T = TypeVar('T', bound='ValueArray')


class ValueArray(ABC):
"""
Expand Down Expand Up @@ -45,6 +48,17 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray:
@abstractmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: ...

def __eq__(self, other: object) -> bool:
if type(self) != type(other):
return False
return self._equal(other)

def __ne__(self, other: object) -> bool:
return not self == other

# Type-specific payload comparison; only invoked from ``__eq__`` after the
# concrete types have already been checked to match, so implementations may
# assume ``other`` is the same class as ``self``.
@abstractmethod
def _equal(self: T, other: T) -> bool: ...

@abstractmethod
def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
"""Return data by selecting from index with given name and index value."""
Expand Down Expand Up @@ -94,6 +108,13 @@ def __init__(
def try_from(obj: Any, *, axis_zero: int = 0) -> SequenceAdapter | None:
return SequenceAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: SequenceAdapter) -> bool:
return (
self._values == other._values
and self._index == other._index
and self._axis_zero == other._axis_zero
)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
if len(key) != 1:
raise ValueError('SequenceAdapter only supports single index')
Expand Down Expand Up @@ -133,6 +154,11 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> PandasSeriesAdapter | None:
if isinstance(obj, pandas.Series):
return PandasSeriesAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: PandasSeriesAdapter) -> bool:
return (
self._series.equals(other._series) and self._axis_zero == other._axis_zero
)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
if len(key) != 1:
raise ValueError('PandasSeriesAdapter only supports single index')
Expand Down Expand Up @@ -188,6 +214,9 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> XarrayDataArrayAdapter | None:
except ModuleNotFoundError:
pass

def _equal(self, other: XarrayDataArrayAdapter) -> bool:
return self._data_array.identical(other._data_array)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
    """Label-based selection: forward the (name, value) pairs as a dict."""
    selection = dict(key)
    return self._data_array.sel(selection)

Expand All @@ -210,39 +239,39 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class ScippDataArrayAdapter(ValueArray):
def __init__(self, data_array: scipp.DataArray):
import scipp

def __init__(self, data_array: scipp.DataArray, scipp: ModuleType):
default_indices = {
dim: scipp.arange(dim, size, unit=None)
for dim, size in data_array.sizes.items()
if dim not in data_array.coords
}
self._data_array = data_array.assign_coords(default_indices)
self._scipp = scipp

@staticmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ScippDataArrayAdapter | None:
try:
import scipp

if isinstance(obj, scipp.Variable):
return ScippDataArrayAdapter(scipp.DataArray(obj))
return ScippDataArrayAdapter(scipp.DataArray(obj), scipp=scipp)
if isinstance(obj, scipp.DataArray):
return ScippDataArrayAdapter(obj)
return ScippDataArrayAdapter(obj, scipp=scipp)
except ModuleNotFoundError:
pass

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
import scipp
def _equal(self, other: ScippDataArrayAdapter) -> bool:
return self._scipp.identical(self._data_array, other._data_array)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
values = self._data_array
for dim, value in key:
# Reconstruct label, to use label-based indexing instead of positional
if isinstance(value, tuple):
value, unit = value
else:
unit = None
label = scipp.scalar(value, unit=unit)
label = self._scipp.scalar(value, unit=unit)
# Scipp indexing uses a comma to separate dimension label from the index,
# unlike Numpy and other libraries where it separates the indices for
# different axes.
Expand All @@ -253,7 +282,7 @@ def __getitem__(self, key: dict[IndexName, slice]) -> ScippDataArrayAdapter:
values = self._data_array
for dim, i in key:
values = values[dim, i]
return ScippDataArrayAdapter(values)
return ScippDataArrayAdapter(values, scipp=self._scipp)

@property
def shape(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -307,6 +336,13 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> NumpyArrayAdapter | None:
if isinstance(obj, numpy.ndarray):
return NumpyArrayAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: NumpyArrayAdapter) -> bool:
return (
(self._array == other._array).all()
and self._indices == other._indices
and self._axis_zero == other._axis_zero
)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
    """Translate index labels to positions, then index the array."""
    positions = []
    for name, label in key:
        positions.append(self._indices[name].index(label))
    return self._array[tuple(positions)]
Expand Down Expand Up @@ -374,6 +410,11 @@ def from_mapping(
return NodeValues(value_arrays)

def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues:
value_arrays = {
key: value
for key, value in value_arrays.items()
if self.get(key, None) != value
}
if value_arrays:
named = next(iter(value_arrays.values())).index_names
if any(name in self.indices for name in named):
Expand Down
60 changes: 56 additions & 4 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,15 +735,67 @@ def test_setitem_with_mapped_operands_raises_on_conflict() -> None:

def test_setitem_currently_does_not_allow_compatible_indices() -> None:
    """Assigning a branch mapped over an equally-sized but differently
    named index still raises."""
    g = nx.DiGraph()
    g.add_edge('a', 'b')
    g.add_edge('c', 'd')

    graph = cb.Graph(g)
    mapped1 = graph.map({'a': [1, 2, 3]})
    mapped2 = graph['d'].map({'c': [11, 12, 13]}).reduce('d', name='e')
    # Note: This is a limitation of the current implementation. We could check if the
    # indices are identical and allow this. For simplicity we currently do not.
    with pytest.raises(ValueError, match="Conflicting new index names"):
        mapped1['x'] = mapped2


@pytest.mark.parametrize(
    'node_values',
    [
        {'a': [1, 2, 3], 'b': [11, 12, 13]},
        {'a': np.array([1, 2, 3]), 'b': np.array([11, 12, 13])},
        pd.DataFrame({'a': [1, 2, 3], 'b': [11, 12, 13]}),
        {
            'a': sc.array(dims=['x'], values=[1, 2, 3]),
            'b': sc.array(dims=['x'], values=[11, 12, 13]),
        },
        {
            'a': xr.DataArray(dims=('x',), data=[1, 2, 3]),
            'b': xr.DataArray(dims=('x',), data=[11, 12, 13]),
        },
    ],
)
def test_setitem_allows_compatible_node_values(node_values) -> None:
    """Re-assigning a branch whose node values equal the existing mapping is
    accepted for every supported value-array flavor."""
    digraph = nx.DiGraph()
    digraph.add_edges_from([('a', 'c'), ('b', 'c')])

    mapped = cb.Graph(digraph).map(node_values).reduce('c', name='d')
    # Same values, same indices: no conflict and no extra index dimension.
    mapped['x'] = mapped['d']
    assert len(mapped.index_names) == 1


def test_setitem_raises_if_node_values_equivalent_but_of_different_type() -> None:
    """Equal values held by different adapter types do not compare equal."""
    digraph = nx.DiGraph()
    digraph.add_edge('a', 'b')
    base = cb.Graph(digraph)
    from_list = base.map({'a': [1, 2]}).reduce('b', name='d')
    from_array = base.map({'a': np.array([1, 2])}).reduce('b', name='d')
    # One could imagine treating this as equivalent, but we are strict in the
    # comparison.
    with pytest.raises(ValueError, match="Conflicting new index names"):
        from_list['x'] = from_array['d']


def test_setitem_raises_if_node_values_incompatible() -> None:
    """Mapping the same node over different values is rejected outright."""
    digraph = nx.DiGraph()
    digraph.add_edge('a', 'b')
    base = cb.Graph(digraph)
    plain = base.map({'a': [1, 2]}).reduce('b', name='d')
    scipp_backed = base.map({'a': sc.array(dims=('x',), values=[1, 2])}).reduce(
        'b', name='d'
    )
    with pytest.raises(ValueError, match="has already been mapped"):
        plain['x'] = scipp_backed['d']


def test_setitem_does_currently_not_support_slice_assignment() -> None:
Expand Down

0 comments on commit c1ca804

Please sign in to comment.