Skip to content

Commit

Permalink
Backport PR #1589: (fix): disallow using dataframes with multi index …
Browse files Browse the repository at this point in the history
…columns (#1609)
  • Loading branch information
ilan-gold authored Aug 26, 2024
1 parent 35055ec commit cf0208b
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/release-notes/0.10.9.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Upper bound {mod}`numpy` for `gpu` installation on account of {issue}`cupy/cupy#8391` {pr}`1540` {user}`ilan-gold`
* Fix writing large number of columns for `h5` files {pr}`1147` {user}`ilan-gold` {user}`selmanozleyen`
* Upper bound dask on account of {issue}`1579` {pr}`1580` {user}`ilan-gold`
* Disallow using {class}`~pandas.DataFrame`s with multi-index columns {pr}`1589` {user}`ilan-gold`
* Ensure setting {attr}`pandas.DataFrame.index` on a view of a {class}`~anndata.AnnData` instantiates the {class}`~pandas.DataFrame` from the view {pr}`1586` {user}`ilan-gold`

#### Documentation
Expand Down
30 changes: 19 additions & 11 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
from ..compat import AwkArray
from ..utils import convert_to_dict, deprecated, dim_len, warn_once
from ..utils import (
convert_to_dict,
deprecated,
dim_len,
raise_value_error_if_multiindex_columns,
warn_once,
)
from .access import ElementRef
from .index import _subset
from .storage import coerce_array
Expand Down Expand Up @@ -258,16 +264,18 @@ def to_df(self) -> pd.DataFrame:
return df

def _validate_value(self, val: Value, key: str) -> Value:
# Checks DataFrame values before deferring to super()._validate_value:
# their index must equal self.dim_names, otherwise a ValueError is raised.
# NOTE(review): this is a flattened diff hunk with no +/- markers and no
# indentation. The first `if` block below appears to be the pre-change body
# and the second `if isinstance(val, pd.DataFrame):` block its replacement
# (the file header reports 19 additions & 11 deletions) — confirm against
# the actual commit before treating this text as runnable code.
if isinstance(val, pd.DataFrame) and not val.index.equals(self.dim_names):
# Could probably also re-order index if it’s contained
try:
# Used only to obtain a descriptive diff for the error message.
pd.testing.assert_index_equal(val.index, self.dim_names)
except AssertionError as e:
msg = f"value.index does not match parent’s {self.dim} names:\n{e}"
# `from None` suppresses the chained AssertionError traceback.
raise ValueError(msg) from None
else:
# Reached only if Index.equals said "different" but the pandas
# testing helper found no difference.
msg = "Index.equals and pd.testing.assert_index_equal disagree"
raise AssertionError(msg)
# Presumably the post-change body: the same index check, preceded by a
# MultiIndex-columns check that names the offending element as
# "<attrname>[<key>]".
if isinstance(val, pd.DataFrame):
raise_value_error_if_multiindex_columns(val, f"{self.attrname}[{key!r}]")
if not val.index.equals(self.dim_names):
# Could probably also re-order index if it’s contained
try:
pd.testing.assert_index_equal(val.index, self.dim_names)
except AssertionError as e:
msg = f"value.index does not match parent’s {self.dim} names:\n{e}"
raise ValueError(msg) from None
else:
msg = "Index.equals and pd.testing.assert_index_equal disagree"
raise AssertionError(msg)
return super()._validate_value(val, key)

@property
Expand Down
18 changes: 14 additions & 4 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
_move_adj_mtx,
)
from ..logging import anndata_logger as logger
from ..utils import deprecated, dim_len, ensure_df_homogeneous
from ..utils import (
deprecated,
dim_len,
ensure_df_homogeneous,
raise_value_error_if_multiindex_columns,
)
from .access import ElementRef
from .aligned_df import _gen_dataframe
from .aligned_mapping import AlignedMappingProperty, AxisArrays, Layers, PairwiseArrays
Expand Down Expand Up @@ -234,9 +239,13 @@ def __init__(
*,
obsp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
oidx: Index1D = None,
vidx: Index1D = None,
oidx: Index1D | None = None,
vidx: Index1D | None = None,
):
# check for any multi-indices that aren’t later checked in coerce_array
for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
if isinstance(attr, pd.DataFrame):
raise_value_error_if_multiindex_columns(attr, key)
if asview:
if not isinstance(X, AnnData):
raise ValueError("`X` has to be an AnnData object.")
Expand Down Expand Up @@ -731,9 +740,10 @@ def n_vars(self) -> int:
"""Number of variables/features."""
return len(self.var_names)

def _set_dim_df(self, value: pd.DataFrame, attr: str):
def _set_dim_df(self, value: pd.DataFrame, attr: Literal["obs", "var"]):
if not isinstance(value, pd.DataFrame):
raise ValueError(f"Can only assign pd.DataFrame to {attr}.")
raise_value_error_if_multiindex_columns(value, attr)
value_idx = self._prep_dim_index(value.index, attr)
if self.is_view:
self._init_as_actual(self.copy())
Expand Down
8 changes: 7 additions & 1 deletion src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
ZappyArray,
ZarrArray,
)
from ..utils import ensure_df_homogeneous, join_english
from ..utils import (
ensure_df_homogeneous,
join_english,
raise_value_error_if_multiindex_columns,
)
from .sparse_dataset import BaseCompressedSparseDataset

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,6 +86,8 @@ def coerce_array(
value = value.A
return value
if isinstance(value, pd.DataFrame):
if allow_df:
raise_value_error_if_multiindex_columns(value, name)
return value if allow_df else ensure_df_homogeneous(value, name)
# if value is an array-like object, try to convert it
e = None
Expand Down
11 changes: 11 additions & 0 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import random
import re
import warnings
Expand Down Expand Up @@ -785,3 +786,13 @@ def __init__(self, *_args, **_kwargs) -> None:
raise ImportError(
"zarr must be imported to create an `AccessTrackingStore` instance."
)


def get_multiindex_columns_df(shape: tuple[int, int]) -> pd.DataFrame:
    """Build a random DataFrame whose columns are a two-level MultiIndex.

    Parameters
    ----------
    shape
        ``(n_rows, n_columns)`` of the frame to create.

    Returns
    -------
    A DataFrame of uniform random floats with ``shape[1]`` columns. The
    columns are the tuples ``("a", 0), ("a", 1), ...`` followed by
    ``("b", 0), ("b", 1), ...``; when the column count is odd, the ``"a"``
    group gets the extra column.
    """
    n_rows, n_cols = shape[0], shape[1]
    # Integer-halve for the "b" group; "a" takes whatever is left over.
    n_b = n_cols // 2
    n_a = n_cols - n_b
    columns = pd.MultiIndex.from_tuples(
        [
            *itertools.product(["a"], range(n_a)),
            *itertools.product(["b"], range(n_b)),
        ]
    )
    return pd.DataFrame(np.random.rand(n_rows, n_cols), columns=columns)
9 changes: 9 additions & 0 deletions src/anndata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,12 @@ def is_hidden(attr) -> bool:
for item in type.__dir__(cls)
if not is_hidden(getattr(cls, item, None))
]


def raise_value_error_if_multiindex_columns(df: pd.DataFrame, attr: str) -> None:
    """Reject DataFrames whose columns are a :class:`pandas.MultiIndex`.

    Parameters
    ----------
    df
        The DataFrame to inspect.
    attr
        Name of the element being set; interpolated into the error message.

    Raises
    ------
    ValueError
        If ``df.columns`` is a :class:`pandas.MultiIndex`.
    """
    if not isinstance(df.columns, pd.MultiIndex):
        return
    msg = (
        "MultiIndex columns are not supported in AnnData. "
        f"Please use a single-level index for {attr}."
    )
    raise ValueError(msg)
8 changes: 8 additions & 0 deletions tests/test_annot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from natsort import natsorted

import anndata as ad
from anndata.tests.helpers import get_multiindex_columns_df


@pytest.mark.parametrize("dtype", [object, "string"])
Expand Down Expand Up @@ -63,3 +64,10 @@ def test_non_str_to_not_categorical():
result_non_transformed = adata.obs.drop(columns=["str_with_nan"])

pd.testing.assert_frame_equal(expected_non_transformed, result_non_transformed)


def test_error_multiindex():
    """Assigning a DataFrame with MultiIndex columns to ``obs`` raises ValueError."""
    adata = ad.AnnData(np.random.rand(100, 10))
    # Row count matches adata; 20 is the number of MultiIndex columns.
    df = get_multiindex_columns_df((adata.shape[0], 20))
    with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
        adata.obs = df
10 changes: 9 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.sparse import csr_matrix, issparse

from anndata import AnnData, ImplicitModificationWarning
from anndata.tests.helpers import assert_equal, gen_adata
from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df

# some test objects that we use below
adata_dense = AnnData(np.array([[1, 2], [3, 4]]))
Expand Down Expand Up @@ -113,6 +113,14 @@ def test_create_from_df():
assert df.index.tolist() == ad.obs_names.tolist()


@pytest.mark.parametrize("attr", ["X", "obs", "obsm"])
def test_error_create_from_multiindex_df(attr):
    """Constructing AnnData from a MultiIndex-column DataFrame raises ValueError.

    The frame is passed directly for ``X`` and ``obs``, and wrapped in a
    one-entry mapping for ``obsm``.
    """
    df = get_multiindex_columns_df((100, 20))
    val = df if attr != "obsm" else {"df": df}
    with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
        AnnData(**{attr: val}, shape=(100, 10))


def test_create_from_sparse_df():
s = sp.random(20, 30, density=0.2)
obs_names = [f"obs{i}" for i in range(20)]
Expand Down
7 changes: 7 additions & 0 deletions tests/test_obsmvarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy import sparse

from anndata import AnnData
from anndata.tests.helpers import get_multiindex_columns_df

M, N = (100, 100)

Expand Down Expand Up @@ -137,3 +138,9 @@ def test_shape_error(adata: AnnData):
),
):
adata.obsm["b"] = np.zeros((adata.shape[0] + 1, adata.shape[0]))


def test_error_set_multiindex_df(adata: AnnData):
    """Storing a MultiIndex-column DataFrame in ``obsm`` raises ValueError."""
    # Row count matches adata; 20 is the number of MultiIndex columns.
    df = get_multiindex_columns_df((adata.shape[0], 20))
    with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
        adata.obsm["df"] = df

0 comments on commit cf0208b

Please sign in to comment.