diff --git a/anndata/_core/aligned_mapping.py b/anndata/_core/aligned_mapping.py
index 156146518..14c1face2 100644
--- a/anndata/_core/aligned_mapping.py
+++ b/anndata/_core/aligned_mapping.py
@@ -4,6 +4,7 @@
 from typing import Union, Optional, Type, ClassVar, TypeVar  # Special types
 from typing import Iterator, Mapping, Sequence  # ABCs
 from typing import Tuple, List, Dict  # Generic base types
+import weakref
 import warnings
 
 import numpy as np
@@ -65,8 +66,8 @@ def _validate_value(self, val: V, key: str) -> V:
                 message="Support for Awkward Arrays is currently experimental.*",
             )
         for i, axis in enumerate(self.axes):
-            if self.parent.shape[axis] != dim_len(val, i):
-                right_shape = tuple(self.parent.shape[a] for a in self.axes)
+            if self.parent_shape[axis] != dim_len(val, i):
+                right_shape = tuple(self.parent_shape[a] for a in self.axes)
                 actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes))
                 if actual_shape[i] is None and isinstance(val, AwkArray):
                     raise ValueError(
@@ -107,6 +108,13 @@ def is_view(self) -> bool:
     def parent(self) -> Union["anndata.AnnData", "raw.Raw"]:
         return self._parent
 
+    @property
+    def parent_shape(self) -> Tuple[int, int]:
+        """Shape of the parent; the value cached as `_parent_shape` is preferred."""
+        if hasattr(self, "_parent_shape"):
+            return self._parent_shape
+        return self._parent.shape
+
     def copy(self):
         d = self._actual_class(self.parent, self._axis)
         for k, v in self.items():
@@ -259,7 +267,7 @@ def _validate_value(self, val: V, key: str) -> V:
 
     @property
     def dim_names(self) -> pd.Index:
-        return (self.parent.obs_names, self.parent.var_names)[self._axis]
+        return (self._parent.obs_names, self._parent.var_names)[self._axis]
 
 
 class AxisArrays(AlignedActualMixin, AxisArraysBase):
@@ -269,14 +277,51 @@ def __init__(
         axis: int,
         vals: Union[Mapping, AxisArraysBase, None] = None,
     ):
-        self._parent = parent
+        # Hold an AnnData parent weakly so this mapping does not keep it alive;
+        # any other parent type is stored as a plain (strong) reference.
+        if isinstance(parent, anndata.AnnData):
+            self._parent_ref = weakref.ref(parent)
+            self._is_weak = True
+        else:
+            self._parent_ref = parent
+            self._is_weak = False
         if axis not in (0, 1):
             raise ValueError()
         self._axis = axis
+
+        # Cache the shape so `parent_shape` does not dereference the parent.
+        self._parent_shape = parent.shape
         self._data = dict()
         if vals is not None:
             self.update(vals)
 
+    @property
+    def _parent(self) -> Union["anndata.AnnData", "raw.Raw"]:
+        """The parent object, or None once a weakly held parent is collected."""
+        if self._is_weak:
+            return self._parent_ref()
+        return self._parent_ref
+
+    def __getstate__(self):
+        """Return the pickle state with the weak reference dereferenced."""
+        state = self.__dict__.copy()
+        if self._is_weak:
+            # weakref objects cannot be pickled; store the referent instead.
+            state["_parent_ref"] = state["_parent_ref"]()
+        return state
+
+    def __setstate__(self, state):
+        """Restore the pickle state, re-wrapping a live parent weakly."""
+        self.__dict__ = state.copy()
+        if self._is_weak:
+            parent = state["_parent_ref"]
+            if parent is None:
+                # The parent was already collected when the state was taken;
+                # weakref.ref(None) would raise TypeError.
+                self.__dict__["_is_weak"] = False
+            else:
+                self.__dict__["_parent_ref"] = weakref.ref(parent)
+
 
 class AxisArraysView(AlignedViewMixin, AxisArraysBase):
     def __init__(
@@ -315,11 +360,34 @@ def copy(self) -> "Layers":
 
 class Layers(AlignedActualMixin, LayersBase):
     def __init__(self, parent: "anndata.AnnData", vals: Optional[Mapping] = None):
-        self._parent = parent
+        self._parent_ref = weakref.ref(parent)
         self._data = dict()
         if vals is not None:
             self.update(vals)
 
+    @property
+    def _parent(self):
+        """The parent AnnData, or None once it has been garbage-collected."""
+        return self._parent_ref()
+
+    def __getstate__(self):
+        """Return the pickle state with the weak reference dereferenced."""
+        state = self.__dict__.copy()
+        # weakref objects cannot be pickled; store the referent instead.
+        state["_parent_ref"] = state["_parent_ref"]()
+        return state
+
+    def __setstate__(self, state):
+        """Restore the pickle state, re-wrapping a live parent weakly."""
+        self.__dict__ = state.copy()
+        parent = state["_parent_ref"]
+        # A parent collected before pickling is stored as None, and
+        # weakref.ref(None) raises TypeError; use a stand-in that
+        # dereferences to None like a dead weak reference.
+        self.__dict__["_parent_ref"] = (
+            weakref.ref(parent) if parent is not None else (lambda: None)
+        )
+
 
 class LayersView(AlignedViewMixin, LayersBase):
     def __init__(
@@ -368,7 +436,7 @@ def __init__(
         axis: int,
         vals: Optional[Mapping] = None,
     ):
-        self._parent = parent
+        self._parent_ref = weakref.ref(parent)
         if axis not in (0, 1):
             raise ValueError()
         self._axis = axis
@@ -376,6 +444,29 @@ def __init__(
         if vals is not None:
             self.update(vals)
 
+    @property
+    def _parent(self):
+        """The parent object, or None once it has been garbage-collected."""
+        return self._parent_ref()
+
+    def __getstate__(self):
+        """Return the pickle state with the weak reference dereferenced."""
+        state = self.__dict__.copy()
+        # weakref objects cannot be pickled; store the referent instead.
+        state["_parent_ref"] = state["_parent_ref"]()
+        return state
+
+    def __setstate__(self, state):
+        """Restore the pickle state, re-wrapping a live parent weakly."""
+        self.__dict__ = state.copy()
+        parent = state["_parent_ref"]
+        # A parent collected before pickling is stored as None, and
+        # weakref.ref(None) raises TypeError; use a stand-in that
+        # dereferences to None like a dead weak reference.
+        self.__dict__["_parent_ref"] = (
+            weakref.ref(parent) if parent is not None else (lambda: None)
+        )
+
 
 class PairwiseArraysView(AlignedViewMixin, PairwiseArraysBase):
     def __init__(
diff --git a/anndata/_core/file_backing.py b/anndata/_core/file_backing.py
index 02401873c..03e98bfea 100644
--- a/anndata/_core/file_backing.py
+++ b/anndata/_core/file_backing.py
@@ -3,6 +3,7 @@
 from pathlib import Path
 from typing import Optional, Union, Iterator, Literal
 from collections.abc import Mapping
+import weakref
 
 import h5py
 
@@ -20,13 +21,36 @@ def __init__(
         filename: Optional[PathLike] = None,
         filemode: Optional[Literal["r", "r+"]] = None,
     ):
-        self._adata = adata
+        self._adata_ref = weakref.ref(adata)
         self.filename = filename
         self._filemode = filemode
         self._file = None
         if filename:
             self.open()
 
+    def __getstate__(self):
+        """Return the pickle state with the weak reference dereferenced."""
+        state = self.__dict__.copy()
+        # weakref objects cannot be pickled; store the referent instead.
+        state["_adata_ref"] = state["_adata_ref"]()
+        return state
+
+    def __setstate__(self, state):
+        """Restore the pickle state, re-wrapping a live `adata` weakly."""
+        self.__dict__ = state.copy()
+        adata = state["_adata_ref"]
+        # An object collected before pickling is stored as None, and
+        # weakref.ref(None) raises TypeError; use a stand-in that
+        # dereferences to None like a dead weak reference.
+        self.__dict__["_adata_ref"] = (
+            weakref.ref(adata) if adata is not None else (lambda: None)
+        )
+
+    @property
+    def _adata(self):
+        """The object passed as `adata`, or None once it has been collected."""
+        return self._adata_ref()
+
     def __repr__(self) -> str:
         if self.filename is None:
             return "Backing file manager: no file is set."
diff --git a/anndata/tests/test_base.py b/anndata/tests/test_base.py
index d045bc5ca..7dfa3a2a0 100644
--- a/anndata/tests/test_base.py
+++ b/anndata/tests/test_base.py
@@ -1,4 +1,5 @@
 from itertools import product
+import tracemalloc
 import warnings
 
 import numpy as np
@@ -601,6 +602,56 @@ def assert_eq_not_id(a, b):
         assert_eq_not_id(map_sprs[key], map_copy[key])
 
 
+def test_memory_usage():
+    """Check that repeatedly re-creating an AnnData does not grow traced memory."""
+    N, M = 100, 200
+    RUNS = 10
+    obs_df = pd.DataFrame(
+        dict(
+            cat=pd.Categorical(np.arange(N, dtype=int)),
+            int=np.arange(N, dtype=int),
+            float=np.arange(N, dtype=float),
+            obj=[str(i) for i in np.arange(N, dtype=int)],
+        ),
+        index=[f"cell{i}" for i in np.arange(N, dtype=int)],
+    )
+    var_df = pd.DataFrame(
+        dict(
+            cat=pd.Categorical(np.arange(M, dtype=int)),
+            int=np.arange(M, dtype=int),
+            float=np.arange(M, dtype=float),
+            obj=[str(i) for i in np.arange(M, dtype=int)],
+        ),
+        index=[f"gene{i}" for i in np.arange(M, dtype=int)],
+    )
+
+    def get_memory(snapshot, key_type="lineno"):
+        """Total size of the traced allocations left after filtering."""
+        snapshot = snapshot.filter_traces(
+            (
+                tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
+                tracemalloc.Filter(False, "<unknown>"),
+            )
+        )
+        return sum(stat.size for stat in snapshot.statistics(key_type))
+
+    total = np.zeros(RUNS)
+    # Instantiate the anndata object first before memory calculation to
+    # only look at memory changes due to deletion of such an object.
+    adata = AnnData(X=np.random.random((N, M)), obs=obs_df, var=var_df)
+    adata.X[0, 0] = 1.0  # Disable Codacy issue
+    tracemalloc.start()
+    try:
+        for i in range(RUNS):
+            adata = AnnData(X=np.random.random((N, M)), obs=obs_df, var=var_df)
+            total[i] = get_memory(tracemalloc.take_snapshot())
+    finally:
+        # Stop tracing even if AnnData construction raises.
+        tracemalloc.stop()
+    relative_increase = total[:-1] / total[1:]
+    np.testing.assert_allclose(relative_increase, 1.0, atol=0.2)
+
+
 def test_to_memory_no_copy():
     adata = gen_adata((3, 5))
     mem = adata.to_memory()