Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory consumption increase for anndata objects #363

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
70 changes: 64 additions & 6 deletions anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Union, Optional, Type, ClassVar, TypeVar # Special types
from typing import Iterator, Mapping, Sequence # ABCs
from typing import Tuple, List, Dict # Generic base types
import weakref
import warnings

import numpy as np
Expand Down Expand Up @@ -65,8 +66,8 @@ def _validate_value(self, val: V, key: str) -> V:
message="Support for Awkward Arrays is currently experimental.*",
)
for i, axis in enumerate(self.axes):
if self.parent.shape[axis] != dim_len(val, i):
right_shape = tuple(self.parent.shape[a] for a in self.axes)
if self.parent_shape[axis] != dim_len(val, i):
right_shape = tuple(self.parent_shape[a] for a in self.axes)
actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes))
if actual_shape[i] is None and isinstance(val, AwkArray):
raise ValueError(
Expand Down Expand Up @@ -107,6 +108,12 @@ def is_view(self) -> bool:
def parent(self) -> Union["anndata.AnnData", "raw.Raw"]:
return self._parent

@property
def parent_shape(self) -> Tuple[int, int]:
if hasattr(self, "_parent_shape"):
return self._parent_shape
return self._parent.shape

def copy(self):
d = self._actual_class(self.parent, self._axis)
for k, v in self.items():
Expand Down Expand Up @@ -259,7 +266,7 @@ def _validate_value(self, val: V, key: str) -> V:

@property
def dim_names(self) -> pd.Index:
return (self.parent.obs_names, self.parent.var_names)[self._axis]
return (self._parent.obs_names, self._parent.var_names)[self._axis]


class AxisArrays(AlignedActualMixin, AxisArraysBase):
Expand All @@ -269,14 +276,39 @@ def __init__(
axis: int,
vals: Union[Mapping, AxisArraysBase, None] = None,
):
self._parent = parent
if isinstance(parent, anndata.AnnData):
self._parent_ref = weakref.ref(parent)
self._is_weak = True
else:
self._parent_ref = parent
self._is_weak = False
if axis not in (0, 1):
raise ValueError()
self._axis = axis

self._parent_shape = parent.shape
# self.dim_names = (parent.obs_names, parent.var_names)[self._axis]
self._data = dict()
if vals is not None:
self.update(vals)

@property
def _parent(self) -> Union["anndata.AnnData", "raw.Raw"]:
if self._is_weak:
return self._parent_ref()
return self._parent_ref

def __getstate__(self):
state = self.__dict__.copy()
if self._is_weak:
state["_parent_ref"] = state["_parent_ref"]()
return state

def __setstate__(self, state):
self.__dict__ = state.copy()
if self._is_weak:
self.__dict__["_parent_ref"] = weakref.ref(state["_parent_ref"])


class AxisArraysView(AlignedViewMixin, AxisArraysBase):
def __init__(
Expand Down Expand Up @@ -315,11 +347,24 @@ def copy(self) -> "Layers":

class Layers(AlignedActualMixin, LayersBase):
def __init__(self, parent: "anndata.AnnData", vals: Optional[Mapping] = None):
self._parent = parent
self._parent_ref = weakref.ref(parent)
self._data = dict()
if vals is not None:
self.update(vals)

@property
def _parent(self):
return self._parent_ref()

def __getstate__(self):
state = self.__dict__.copy()
state["_parent_ref"] = state["_parent_ref"]()
return state

def __setstate__(self, state):
self.__dict__ = state.copy()
self.__dict__["_parent_ref"] = weakref.ref(state["_parent_ref"])


class LayersView(AlignedViewMixin, LayersBase):
def __init__(
Expand Down Expand Up @@ -368,14 +413,27 @@ def __init__(
axis: int,
vals: Optional[Mapping] = None,
):
self._parent = parent
self._parent_ref = weakref.ref(parent)
if axis not in (0, 1):
raise ValueError()
self._axis = axis
self._data = dict()
if vals is not None:
self.update(vals)

@property
def _parent(self):
return self._parent_ref()

def __getstate__(self):
state = self.__dict__.copy()
state["_parent_ref"] = state["_parent_ref"]()
return state

def __setstate__(self, state):
self.__dict__ = state.copy()
self.__dict__["_parent_ref"] = weakref.ref(state["_parent_ref"])


class PairwiseArraysView(AlignedViewMixin, PairwiseArraysBase):
def __init__(
Expand Down
16 changes: 15 additions & 1 deletion anndata/_core/file_backing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Optional, Union, Iterator, Literal
from collections.abc import Mapping
import weakref

import h5py

Expand All @@ -20,13 +21,26 @@ def __init__(
filename: Optional[PathLike] = None,
filemode: Optional[Literal["r", "r+"]] = None,
):
self._adata = adata
self._adata_ref = weakref.ref(adata)
self.filename = filename
self._filemode = filemode
self._file = None
if filename:
self.open()

def __getstate__(self):
state = self.__dict__.copy()
state["_adata_ref"] = state["_adata_ref"]()
return state

def __setstate__(self, state):
self.__dict__ = state.copy()
self.__dict__["_adata_ref"] = weakref.ref(state["_adata_ref"])

@property
def _adata(self):
return self._adata_ref()

def __repr__(self) -> str:
if self.filename is None:
return "Backing file manager: no file is set."
Expand Down
47 changes: 47 additions & 0 deletions anndata/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import product
import tracemalloc
import warnings

import numpy as np
Expand Down Expand Up @@ -601,6 +602,52 @@ def assert_eq_not_id(a, b):
assert_eq_not_id(map_sprs[key], map_copy[key])


def test_memory_usage():
N, M = 100, 200
RUNS = 10
obs_df = pd.DataFrame(
dict(
cat=pd.Categorical(np.arange(N, dtype=int)),
int=np.arange(N, dtype=int),
float=np.arange(N, dtype=float),
obj=[str(i) for i in np.arange(N, dtype=int)],
),
index=[f"cell{i}" for i in np.arange(N, dtype=int)],
)
var_df = pd.DataFrame(
dict(
cat=pd.Categorical(np.arange(M, dtype=int)),
int=np.arange(M, dtype=int),
float=np.arange(M, dtype=float),
obj=[str(i) for i in np.arange(M, dtype=int)],
),
index=[f"gene{i}" for i in np.arange(M, dtype=int)],
)

def get_memory(snapshot, key_type="lineno"):
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<unknown>"),
)
)
total = sum(stat.size for stat in snapshot.statistics(key_type))
return total

total = np.zeros(RUNS)
# Instantiate the anndata object first before memory calculation to
# only look at memory changes due to deletion of such a object.
adata = AnnData(X=np.random.random((N, M)), obs=obs_df, var=var_df)
adata.X[0, 0] = 1.0 # Disable Codacy issue
tracemalloc.start()
for i in range(RUNS):
adata = AnnData(X=np.random.random((N, M)), obs=obs_df, var=var_df)
total[i] = get_memory(tracemalloc.take_snapshot())
tracemalloc.stop()
relative_increase = total[:-1] / total[1:]
np.testing.assert_allclose(relative_increase, 1.0, atol=0.2)


def test_to_memory_no_copy():
adata = gen_adata((3, 5))
mem = adata.to_memory()
Expand Down
Loading