Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix): disallow using dataframes with multi index columns #1589

Merged
merged 13 commits into from
Aug 26, 2024
1 change: 1 addition & 0 deletions docs/release-notes/0.10.9.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Upper bound {mod}`numpy` for `gpu` installation on account of https://github.com/cupy/cupy/issues/8391 {pr}`1540` {user}`ilan-gold`
* Fix writing large number of columns for `h5` files {pr}`1147` {user}`ilan-gold` {user}`selmanozleyen`
* Upper bound dask on account of https://github.com/scverse/anndata/issues/1579 {pr}`1580` {user}`ilan-gold`
* Disallow using {class}`~pandas.DataFrame`s with mutli-index columns {pr}`1589` {use}`ilan-gold`

#### Documentation

Expand Down
30 changes: 19 additions & 11 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from .._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
from ..compat import AwkArray
from ..utils import axis_len, convert_to_dict, deprecated, warn_once
from ..utils import (
axis_len,
convert_to_dict,
deprecated,
raise_value_error_if_multiindex_columns,
warn_once,
)
from .access import ElementRef
from .index import _subset
from .storage import coerce_array
Expand Down Expand Up @@ -258,16 +264,18 @@
return df

def _validate_value(self, val: Value, key: str) -> Value:
if isinstance(val, pd.DataFrame) and not val.index.equals(self.dim_names):
# Could probably also re-order index if it’s contained
try:
pd.testing.assert_index_equal(val.index, self.dim_names)
except AssertionError as e:
msg = f"value.index does not match parent’s {self.dim} names:\n{e}"
raise ValueError(msg) from None
else:
msg = "Index.equals and pd.testing.assert_index_equal disagree"
raise AssertionError(msg)
if isinstance(val, pd.DataFrame):
raise_value_error_if_multiindex_columns(val)
if not val.index.equals(self.dim_names):
# Could probably also re-order index if it’s contained
try:
pd.testing.assert_index_equal(val.index, self.dim_names)
except AssertionError as e:
msg = f"value.index does not match parent’s {self.dim} names:\n{e}"
raise ValueError(msg) from None
else:
msg = "Index.equals and pd.testing.assert_index_equal disagree"
raise AssertionError(msg)

Check warning on line 278 in src/anndata/_core/aligned_mapping.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_core/aligned_mapping.py#L277-L278

Added lines #L277 - L278 were not covered by tests
return super()._validate_value(val, key)

@property
Expand Down
16 changes: 15 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from .._settings import settings
from ..compat import DaskArray, SpArray, ZarrArray, _move_adj_mtx
from ..logging import anndata_logger as logger
from ..utils import axis_len, deprecated, ensure_df_homogeneous
from ..utils import (
axis_len,
deprecated,
ensure_df_homogeneous,
raise_value_error_if_multiindex_columns,
)
from .access import ElementRef
from .aligned_df import _gen_dataframe
from .aligned_mapping import AlignedMappingProperty, AxisArrays, Layers, PairwiseArrays
Expand Down Expand Up @@ -234,6 +239,14 @@ def __init__(
oidx: Index1D = None,
vidx: Index1D = None,
):
# check for any multi-indices
df_elems = [obs, var, X]
for xxxm in [obsm, varm]:
if xxxm is not None and not isinstance(xxxm, np.ndarray):
df_elems += [v for v in xxxm.values() if isinstance(v, pd.DataFrame)]
for attr in df_elems:
if isinstance(attr, pd.DataFrame):
raise_value_error_if_multiindex_columns(attr)
if asview:
if not isinstance(X, AnnData):
raise ValueError("`X` has to be an AnnData object.")
Expand Down Expand Up @@ -736,6 +749,7 @@ def n_vars(self) -> int:
def _set_dim_df(self, value: pd.DataFrame, attr: str):
if not isinstance(value, pd.DataFrame):
raise ValueError(f"Can only assign pd.DataFrame to {attr}.")
raise_value_error_if_multiindex_columns(value)
value_idx = self._prep_dim_index(value.index, attr)
if self.is_view:
self._init_as_actual(self.copy())
Expand Down
8 changes: 7 additions & 1 deletion src/anndata/_core/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
ZappyArray,
ZarrArray,
)
from ..utils import ensure_df_homogeneous, join_english
from ..utils import (
ensure_df_homogeneous,
join_english,
raise_value_error_if_multiindex_columns,
)
from .sparse_dataset import BaseCompressedSparseDataset

if TYPE_CHECKING:
Expand Down Expand Up @@ -84,6 +88,8 @@ def coerce_array(
value = value.A
return value
if isinstance(value, pd.DataFrame):
if allow_df:
raise_value_error_if_multiindex_columns(value)
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
return value if allow_df else ensure_df_homogeneous(value, name)
# if value is an array-like object, try to convert it
e = None
Expand Down
11 changes: 11 additions & 0 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import random
import re
import warnings
Expand Down Expand Up @@ -1000,3 +1001,13 @@ def __init__(self, *_args, **_kwargs) -> None:
raise ImportError(
"zarr must be imported to create an `AccessTrackingStore` instance."
)


def get_multiindex_columns_df(shape):
return pd.DataFrame(
np.random.rand(shape[0], shape[1]),
columns=pd.MultiIndex.from_tuples(
list(itertools.product(["a"], range(shape[1] - (shape[1] // 2))))
+ list(itertools.product(["b"], range(shape[1] // 2)))
),
)
8 changes: 8 additions & 0 deletions src/anndata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,11 @@ def is_hidden(attr) -> bool:
for item in type.__dir__(cls)
if not is_hidden(getattr(cls, item, None))
]


def raise_value_error_if_multiindex_columns(df: pd.DataFrame):
if isinstance(df.columns, pd.MultiIndex):
raise ValueError(
"MultiIndex columns are not supported in AnnData. "
"Please use a single-level index."
ilan-gold marked this conversation as resolved.
Show resolved Hide resolved
)
8 changes: 8 additions & 0 deletions tests/test_annot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from natsort import natsorted

import anndata as ad
from anndata.tests.helpers import get_multiindex_columns_df


@pytest.mark.parametrize("dtype", [object, "string"])
Expand Down Expand Up @@ -63,3 +64,10 @@ def test_non_str_to_not_categorical():
result_non_transformed = adata.obs.drop(columns=["str_with_nan"])

pd.testing.assert_frame_equal(expected_non_transformed, result_non_transformed)


def test_error_multiindex():
adata = ad.AnnData(np.random.rand(100, 10))
df = get_multiindex_columns_df((adata.shape[0], 20))
with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
adata.obs = df
20 changes: 19 additions & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from anndata import AnnData, ImplicitModificationWarning
from anndata._settings import settings
from anndata.compat import CAN_USE_SPARSE_ARRAY
from anndata.tests.helpers import assert_equal, gen_adata
from anndata.tests.helpers import assert_equal, gen_adata, get_multiindex_columns_df

# some test objects that we use below
adata_dense = AnnData(np.array([[1, 2], [3, 4]]))
Expand Down Expand Up @@ -117,6 +117,24 @@ def test_create_from_df():
assert df.index.tolist() == ad.obs_names.tolist()


def test_error_create_from_multiindex_df():
df = get_multiindex_columns_df((100, 20))
with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
AnnData(df)


def test_error_with_obs_multiindex_df():
df = get_multiindex_columns_df((100, 20))
with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
AnnData(X=np.random.rand(100, 10), obs=df)


def test_error_with_obsm_multiindex_df():
df = get_multiindex_columns_df((100, 20))
with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
AnnData(X=np.random.rand(100, 10), obsm={"df": df})


def test_create_from_sparse_df():
s = sp.random(20, 30, density=0.2)
obs_names = [f"obs{i}" for i in range(20)]
Expand Down
7 changes: 7 additions & 0 deletions tests/test_obsmvarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy import sparse

from anndata import AnnData
from anndata.tests.helpers import get_multiindex_columns_df

M, N = (100, 100)

Expand Down Expand Up @@ -137,3 +138,9 @@ def test_shape_error(adata: AnnData):
),
):
adata.obsm["b"] = np.zeros((adata.shape[0] + 1, adata.shape[0]))


def test_error_set_multiindex_df(adata: AnnData):
df = get_multiindex_columns_df((adata.shape[0], 20))
with pytest.raises(ValueError, match=r"MultiIndex columns are not supported"):
adata.obsm["df"] = df
Loading