Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix view behavior for AwkwardArrays #1070

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
24 changes: 20 additions & 4 deletions anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,27 @@ class AlignedViewMixin:
parent_mapping: Mapping[str, V]
"""The object this is a view of."""

copied_objects: Mapping[str, V]
"""Objects that are copied at view creation, because they natively support copy on write"""

is_view = True

def __getitem__(self, key: str) -> V:
return as_view(
_subset(self.parent_mapping[key], self.subset_idx),
ElementRef(self.parent, self.attrname, (key,)),
)
try:
return self.copied_objects[key]
# key might not exist, or copied_objects might not yet be initialized
except (KeyError, AttributeError):
return as_view(
_subset(self.parent_mapping[key], self.subset_idx),
ElementRef(self.parent, self.attrname, (key,)),
)

def _copy_objects(self):
"""For some objects (Awkward arrays) we want to store a copy of the slice at view creation."""
self.copied_objects = {}
for key, value in self.parent_mapping.items():
if isinstance(value, AwkArray):
self.copied_objects[key] = AwkArray(self[key])

def __setitem__(self, key: str, value: V):
value = self._validate_value(value, key) # Validate before mutating
Expand Down Expand Up @@ -286,9 +300,11 @@ def __init__(
subset_idx: OneDIdx,
):
self.parent_mapping = parent_mapping
self.copied_objects = {}
self._parent = parent_view
self.subset_idx = subset_idx
self._axis = parent_mapping._axis
self._copy_objects()


AxisArraysBase._view_class = AxisArraysView
Expand Down
72 changes: 9 additions & 63 deletions anndata/_core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,15 @@ def as_view_zappy(z, view_args):
return z


@as_view.register(AwkArray)
def as_view_awkarray(array, view_args):
# We don't need any specific view behavior for awkward arrays. A slice of an awkward array is always a
# shallow copy of the original. This implies that setting a record field on a slice never modifies the original.
# Other fields than records are entirely immutable anyway.
# See also https://github.com/scverse/anndata/issues/1035#issuecomment-1687619270.
return AwkArray(array)


@as_view.register(CupyArray)
def as_view_cupy(array, view_args):
return CupyArrayView(array, view_args=view_args)
Expand All @@ -315,69 +324,6 @@ def as_view_cupy_csc(mtx, view_args):
return CupySparseCSCView(mtx, view_args=view_args)


try:
from ..compat import awkward as ak
import weakref

# Registry to store weak references from AwkwardArrayViews to their parent AnnData container
_registry = weakref.WeakValueDictionary()
_PARAM_NAME = "_view_args"

class AwkwardArrayView(_ViewMixin, AwkArray):
@property
def _view_args(self):
"""Override _view_args to retrieve the values from awkward arrays parameters.

Awkward arrays cannot be subclassed like other python objects. Instead subclasses need
to be attached as "behavior". These "behaviors" cannot take any additional parameters (as we do
for other data types to store `_view_args`). Therefore, we need to store `_view_args` using awkward's
parameter mechanism. These parameters need to be json-serializable, which is why we can't store
ElementRef directly, but need to replace the reference to the parent AnnDataView container with a weak
reference.
"""
parent_key, attrname, keys = self.layout.parameter(_PARAM_NAME)
parent = _registry[parent_key]
return ElementRef(parent, attrname, keys)

def __copy__(self) -> AwkArray:
"""
Turn the AwkwardArrayView into an actual AwkwardArray with no special behavior.

Need to override __copy__ instead of `.copy()` as awkward arrays don't implement `.copy()`
and are copied using python's standard copy mechanism in `aligned_mapping.py`.
"""
array = self
# makes a shallow copy and removes the reference to the original AnnData object
array = ak.with_parameter(self, _PARAM_NAME, None)
array = ak.with_parameter(array, "__list__", None)
return array

@as_view.register(AwkArray)
def as_view_awkarray(array, view_args):
parent, attrname, keys = view_args
parent_key = f"target-{id(parent)}"
_registry[parent_key] = parent
# TODO: See https://github.com/scverse/anndata/pull/647#discussion_r963494798_ for more details and
# possible strategies to stack behaviors.
# A better solution might be based on xarray-style "attrs", once this is implemented
# https://github.com/scikit-hep/awkward/issues/1391#issuecomment-1412297114
if type(array).__name__ != "Array":
raise NotImplementedError(
"Cannot create a view of an awkward array with __array__ parameter. "
"Please open an issue in the AnnData repo and describe your use-case."
)
array = ak.with_parameter(array, _PARAM_NAME, (parent_key, attrname, keys))
array = ak.with_parameter(array, "__list__", "AwkwardArrayView")
return array

ak.behavior["AwkwardArrayView"] = AwkwardArrayView

except ImportError:

class AwkwardArrayView:
pass


def _resolve_idxs(old, new, adata):
t = tuple(_resolve_idx(old[i], new[i], adata.shape[i]) for i in (0, 1))
return t
Expand Down
6 changes: 0 additions & 6 deletions anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,6 @@ def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None)))

@_REGISTRY.register_write(H5Group, AwkArray, IOSpec("awkward-array", "0.1.0"))
@_REGISTRY.register_write(ZarrGroup, AwkArray, IOSpec("awkward-array", "0.1.0"))
@_REGISTRY.register_write(
H5Group, views.AwkwardArrayView, IOSpec("awkward-array", "0.1.0")
)
@_REGISTRY.register_write(
ZarrGroup, views.AwkwardArrayView, IOSpec("awkward-array", "0.1.0")
)
def write_awkward(f, k, v, _writer, dataset_kwargs=MappingProxyType({})):
from anndata.compat import awkward as ak

Expand Down
113 changes: 87 additions & 26 deletions anndata/tests/test_awkward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from anndata.tests.helpers import assert_equal, gen_adata, gen_awkward
from anndata.compat import awkward as ak
from anndata import ImplicitModificationWarning
from anndata.utils import dim_len
from anndata import AnnData, read_h5ad
import anndata
Expand Down Expand Up @@ -123,31 +122,92 @@ def test_copy(key):
getattr(adata, key)["awk"]["d"]


@pytest.mark.parametrize("key", ["obsm", "varm"])
def test_view(key):
"""Check that modifying a view does not modify the original"""
@pytest.mark.parametrize(
"array,setter_slice,setter_value,expected",
[
# Non-records are immutable and setting on them results in a TypeError
pytest.param(
[[1], [2, 3], [4, 5, 6]],
(slice(None), 2),
42,
TypeError,
id="immutable_ragged_list",
),
pytest.param(
np.zeros((3, 3)),
(slice(None), 1),
42,
TypeError,
id="immutable_regular_type",
),
pytest.param(
[{"a": 1}, {"a": 2}, {"a": 3}],
"a",
[42, 43, 44],
[{"a": 42}, {"a": 43}, {"a": 44}],
id="updating_record",
),
pytest.param(
[{"a": 1}, {"a": 2}, {"a": 3}],
"b",
[42, 43, 44],
[{"a": 1, "b": 42}, {"a": 2, "b": 43}, {"a": 3, "b": 44}],
id="adding_record",
),
pytest.param(
[{"outer": {"a": 1}}, {"outer": {"a": 2}}, {"outer": {"a": 3}}],
("outer", "a"),
[42, 43, 44],
[{"outer": {"a": 42}}, {"outer": {"a": 43}}, {"outer": {"a": 44}}],
id="updating_nested_record",
),
],
)
@pytest.mark.parametrize(
"key",
[
"obsm",
"varm",
# "uns",
],
)
def test_view(key, array, setter_slice, setter_value, expected):
"""Check that modifying a view does not modify the original.

Parameters
----------
key
key in anndata, obsm, varm, or uns
view_func
a function that returns a view of an AnnData object
array
The array that is assigned to adata[key]["awk"] for testing.
setter_slice
The slice used for setting a value on the awkward array with `arr[slice] = ...`
setter_value
The value assigned to the array with `arr[slice] = setter_value`
expected
The expected array after setting the value. Can be an exception if setting the value is supposed
to result in an error.
"""
adata = gen_adata((3, 3), varm_types=(), obsm_types=(), layers_types=())
getattr(adata, key)["awk"] = ak.Array([{"a": [1], "b": [2], "c": [3]}] * 3)
adata_view = adata[:2, :2]

with pytest.warns(ImplicitModificationWarning, match="initializing view as actual"):
getattr(adata_view, key)["awk"]["c"] = np.full((2, 1), 4)
getattr(adata_view, key)["awk"]["d"] = np.full((2, 1), 5)
getattr(adata, key)["awk"] = ak.Array(array)
adata_view = adata[:]
get_awk_view = lambda *_: getattr(adata_view, key)["awk"]

# values in view were correctly set
npt.assert_equal(getattr(adata_view, key)["awk"]["c"], np.full((2, 1), 4))
npt.assert_equal(getattr(adata_view, key)["awk"]["d"], np.full((2, 1), 5))

# values in original were not updated
npt.assert_equal(getattr(adata, key)["awk"]["c"], np.full((3, 1), 3))
with pytest.raises(IndexError):
getattr(adata, key)["awk"]["d"]
if isinstance(expected, type):
with pytest.raises(expected):
get_awk_view()[setter_slice] = setter_value
else:
get_awk_view()[setter_slice] = setter_value
# values in view are correctly set
assert ak.to_list(get_awk_view()) == expected
# values in original were not modified
assert ak.to_list(getattr(adata, key)["awk"]) == array


def test_view_of_awkward_array_with_custom_behavior():
"""Currently can't create view of arrays with custom __name__ (in this case "string")
See https://github.com/scverse/anndata/pull/647#discussion_r963494798_"""

"""Ensure that a custom behavior persists when creating a view."""
from uuid import uuid4

BEHAVIOUR_ID = str(uuid4())
Expand All @@ -157,14 +217,15 @@ def reversed(self):
return self[..., ::-1]

ak.behavior[BEHAVIOUR_ID] = ReversibleArray
ak.behavior["*", BEHAVIOUR_ID] = ReversibleArray

adata = gen_adata((3, 3), varm_types=(), obsm_types=(), layers_types=())
adata.obsm["awk_string"] = ak.with_parameter(
ak.Array(["AAA", "BBB", "CCC"]), "__list__", BEHAVIOUR_ID
adata.obsm["awk_string"] = ak.with_name(
ak.Array([{"a": "AAA"}, {"a": "BBB"}, {"a": "CCC"}]), BEHAVIOUR_ID
)
adata_view = adata[:2]
ak_view = adata[1:]

with pytest.raises(NotImplementedError):
adata_view.obsm["awk_string"]
assert ak.to_list(ak_view.obsm["awk_string"].reversed()["a"]) == ["CCC", "BBB"]


@pytest.mark.parametrize(
Expand Down
Loading