Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow compatible mapped values in __setitem__ #7

Merged
merged 3 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions src/cyclebane/node_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from abc import ABC, abstractmethod
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

if TYPE_CHECKING:
import numpy
Expand All @@ -15,6 +16,8 @@
IndexName = Hashable
IndexValue = Hashable

T = TypeVar('T', bound='ValueArray')


class ValueArray(ABC):
"""
Expand Down Expand Up @@ -45,6 +48,17 @@ def from_array_like(values: Any, *, axis_zero: int = 0) -> ValueArray:
@abstractmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ValueArray | None: ...

def __eq__(self, other: object) -> bool:
    """Compare with *other*, delegating to the subclass hook ``_equal``.

    Only objects of exactly the same concrete adapter type are comparable;
    for any other type we return ``NotImplemented`` so Python may try the
    reflected comparison on *other*.
    """
    if type(self) is type(other):
        return self._equal(other)
    return NotImplemented

def __ne__(self, other: object) -> bool:
    # Negate via the full ``==`` protocol (not ``self.__eq__``): if __eq__
    # returns NotImplemented, ``==`` falls back to identity and still yields
    # a bool, so ``not`` is safe here.
    return not self == other

# Subclass hook used by __eq__: *other* is guaranteed to be of the same
# concrete type as *self* when this is called.
@abstractmethod
def _equal(self: T, other: T) -> bool: ...

@abstractmethod
def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
"""Return data by selecting from index with given name and index value."""
Expand Down Expand Up @@ -94,6 +108,13 @@ def __init__(
def try_from(obj: Any, *, axis_zero: int = 0) -> SequenceAdapter | None:
return SequenceAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: SequenceAdapter) -> bool:
    """Equal when values, index, and axis offset all match."""
    if self._values != other._values:
        return False
    if self._index != other._index:
        return False
    return self._axis_zero == other._axis_zero

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
if len(key) != 1:
raise ValueError('SequenceAdapter only supports single index')
Expand Down Expand Up @@ -133,6 +154,11 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> PandasSeriesAdapter | None:
if isinstance(obj, pandas.Series):
return PandasSeriesAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: PandasSeriesAdapter) -> bool:
    """Equal when the wrapped series compare equal and axis offsets match."""
    same_series = self._series.equals(other._series)
    return same_series and self._axis_zero == other._axis_zero

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
if len(key) != 1:
raise ValueError('PandasSeriesAdapter only supports single index')
Expand Down Expand Up @@ -188,6 +214,9 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> XarrayDataArrayAdapter | None:
except ModuleNotFoundError:
pass

def _equal(self, other: XarrayDataArrayAdapter) -> bool:
    """Strict comparison (values, coords, names) via xarray's ``identical``."""
    mine = self._data_array
    theirs = other._data_array
    return mine.identical(theirs)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
    """Label-based selection: the key pairs form a dim -> label mapping."""
    selection = dict(key)
    return self._data_array.sel(selection)

Expand All @@ -210,39 +239,39 @@ def indices(self) -> dict[IndexName, Iterable[IndexValue]]:


class ScippDataArrayAdapter(ValueArray):
def __init__(self, data_array: scipp.DataArray):
import scipp

def __init__(self, data_array: scipp.DataArray, scipp: ModuleType):
    """Wrap *data_array*, adding a default range coord for each dim lacking one.

    The scipp module object is injected by the caller so that the import
    happens in a single place (``try_from``).
    """
    default_indices = {}
    for dim, size in data_array.sizes.items():
        if dim in data_array.coords:
            continue
        default_indices[dim] = scipp.arange(dim, size, unit=None)
    self._data_array = data_array.assign_coords(default_indices)
    self._scipp = scipp

@staticmethod
def try_from(obj: Any, *, axis_zero: int = 0) -> ScippDataArrayAdapter | None:
try:
import scipp

if isinstance(obj, scipp.Variable):
return ScippDataArrayAdapter(scipp.DataArray(obj))
return ScippDataArrayAdapter(scipp.DataArray(obj), scipp=scipp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit awkward that we have to pass the module as an arg everywhere...
Was it to avoid importing scipp in multiple places?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, making the code cleaner.

if isinstance(obj, scipp.DataArray):
return ScippDataArrayAdapter(obj)
return ScippDataArrayAdapter(obj, scipp=scipp)
except ModuleNotFoundError:
pass

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
import scipp
def _equal(self, other: ScippDataArrayAdapter) -> bool:
    """Strict comparison of the wrapped data arrays via ``scipp.identical``."""
    sc = self._scipp
    return sc.identical(self._data_array, other._data_array)

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
values = self._data_array
for dim, value in key:
# Reconstruct label, to use label-based indexing instead of positional
if isinstance(value, tuple):
value, unit = value
else:
unit = None
label = scipp.scalar(value, unit=unit)
label = self._scipp.scalar(value, unit=unit)
# Scipp indexing uses a comma to separate dimension label from the index,
# unlike Numpy and other libraries where it separates the indices for
# different axes.
Expand All @@ -253,7 +282,7 @@ def __getitem__(self, key: dict[IndexName, slice]) -> ScippDataArrayAdapter:
values = self._data_array
for dim, i in key:
values = values[dim, i]
return ScippDataArrayAdapter(values)
return ScippDataArrayAdapter(values, scipp=self._scipp)

@property
def shape(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -307,6 +336,13 @@ def try_from(obj: Any, *, axis_zero: int = 0) -> NumpyArrayAdapter | None:
if isinstance(obj, numpy.ndarray):
return NumpyArrayAdapter(obj, axis_zero=axis_zero)

def _equal(self, other: NumpyArrayAdapter) -> bool:
    """Equal when all array elements, the indices, and the axis offset match.

    The array comparison is evaluated first, matching the original
    short-circuit order.
    """
    arrays_match = (self._array == other._array).all()
    indices_match = self._indices == other._indices
    return arrays_match and indices_match and self._axis_zero == other._axis_zero

def sel(self, key: tuple[tuple[IndexName, IndexValue], ...]) -> Any:
    """Translate each index value to its position, then index the array."""
    positions = []
    for name, value in key:
        positions.append(self._indices[name].index(value))
    return self._array[tuple(positions)]
Expand Down Expand Up @@ -374,6 +410,11 @@ def from_mapping(
return NodeValues(value_arrays)

def merge(self, value_arrays: Mapping[Hashable, ValueArray]) -> NodeValues:
value_arrays = {
key: value
for key, value in value_arrays.items()
if self.get(key, None) != value
}
if value_arrays:
named = next(iter(value_arrays.values())).index_names
if any(name in self.indices for name in named):
Expand Down
65 changes: 61 additions & 4 deletions tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,15 +735,72 @@ def test_setitem_with_mapped_operands_raises_on_conflict() -> None:

def test_setitem_currently_does_not_allow_compatible_indices() -> None:
g = nx.DiGraph()
g.add_edge('a', 'c')
g.add_edge('b', 'c')
g.add_edge('a', 'b')
g.add_edge('c', 'd')

graph = cb.Graph(g)
mapped = graph.map({'a': [1, 2, 3], 'b': [11, 12, 13]}).reduce('c', name='d')
mapped1 = graph.map({'a': [1, 2, 3]})
mapped2 = graph['d'].map({'c': [11, 12, 13]}).reduce('d', name='e')
# Note: This is a limitation of the current implementation. We could check if the
# indices are identical and allow this. For simplicity we currently do not.
with pytest.raises(ValueError, match="Conflicting new index names"):
mapped['x'] = mapped['d']
mapped1['x'] = mapped2


@pytest.mark.parametrize(
'node_values',
[
{'a': [1, 2, 3]},
{'a': [1, 2, 3], 'b': [11, 12, 13]},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that {'a': [1, 2, 3]} is sufficient to make the test fail with the code on main. So there seems to be no need to have multiple mapped nodes to test this behaviour. Can you add a test that only maps one node?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

{'a': np.array([1, 2, 3])},
{'a': np.array([1, 2, 3]), 'b': np.array([11, 12, 13])},
pd.DataFrame({'a': [1, 2, 3]}),
pd.DataFrame({'a': [1, 2, 3], 'b': [11, 12, 13]}),
{'a': sc.array(dims=['x'], values=[1, 2, 3])},
{
'a': sc.array(dims=['x'], values=[1, 2, 3]),
'b': sc.array(dims=['x'], values=[11, 12, 13]),
},
Comment on lines +760 to +763
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm a little confused because the name above is a ScippDataArrayAdapter, so I thought it would be a sc.DataArray, like the pandas DataFrame above.
But I guess you just need a mapping of key to ArrayLike? Does this mean that you never use the .data in the DataArray?
I assume you need a structure that can be sliced/indexed. Could it be a DataGroup instead of a DataArray internally?

I realize this is besides the point of the current PR...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ScippDataArrayAdapter also handles scipp.Variable, interpreting the latter as a data array without coords.

I assume you need a structure that can be sliced/indexed. Could it be a DataGroup instead of a DataArray internally?

Using a dict currently; we could in principle add support for Dataset and DataGroup. But not "instead of": the DataArray holds the values for a single node, while the dict (or DataGroup, ...) maps to multiple nodes, just like the columns of pandas.DataFrame vs. its rows.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, I see now

{'a': xr.DataArray(dims=('x',), data=[1, 2, 3])},
{
'a': xr.DataArray(dims=('x',), data=[1, 2, 3]),
'b': xr.DataArray(dims=('x',), data=[11, 12, 13]),
},
],
)
def test_setitem_allows_compatible_node_values(node_values) -> None:
g = nx.DiGraph()
g.add_edge('a', 'c')
g.add_edge('b', 'c')

graph = cb.Graph(g)
mapped = graph.map(node_values).reduce('c', name='d')
mapped['x'] = mapped['d']
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, the 'x' here is unrelated to the 'x' dimension in

        {
            'a': sc.array(dims=['x'], values=[1, 2, 3]),
            'b': sc.array(dims=['x'], values=[11, 12, 13]),
        }

above.

If so, can we use a different name here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was asking because at first it confused me, as I thought they were related.

assert len(mapped.index_names) == 1


def test_setitem_raises_if_node_values_equivalent_but_of_different_type() -> None:
    """Equal values wrapped by different adapter types are treated as conflicting."""
    g = nx.DiGraph()
    g.add_edge('a', 'b')
    graph = cb.Graph(g)
    from_list = graph.map({'a': [1, 2]}).reduce('b', name='d')
    from_numpy = graph.map({'a': np.array([1, 2])}).reduce('b', name='d')
    # One could imagine treating these as equivalent, but the comparison is
    # deliberately strict about the adapter type.
    with pytest.raises(ValueError, match="Conflicting new index names"):
        from_list['x'] = from_numpy['d']


def test_setitem_raises_if_node_values_incompatible() -> None:
g = nx.DiGraph()
g.add_edge('a', 'b')
graph = cb.Graph(g)
mapped1 = graph.map({'a': [1, 2]}).reduce('b', name='d')
mapped2 = graph.map({'a': sc.array(dims=('x',), values=[1, 2])}).reduce(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand why this is different from the above? In mapped1 you have [1, 2] and in mapped2 a Scipp Variable. Isn't it the same as with the numpy array, that the types are different?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it actually raises for different reasons, this test is to ensure that both code paths work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess my question was why is it raising for a different reason, I would have expected to raise with the same reason.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is an artifact of the usual problem of having two checks in a particular order, so depending on the exact properties of the input you get one exception or the other.

'b', name='d'
)
with pytest.raises(ValueError, match="has already been mapped"):
mapped1['x'] = mapped2['d']


def test_setitem_does_currently_not_support_slice_assignment() -> None:
Expand Down