Skip to content

Commit

Permalink
Make DataFrame#loc return Series or DataFrame if a scalar is given (#866)
Browse files Browse the repository at this point in the history

Return Series or DataFrame if a scalar is given to DataFrame#loc

Fix #749.
  • Loading branch information
skatsuta authored Feb 12, 2024
1 parent fe28163 commit 176805d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
5 changes: 3 additions & 2 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class _iLocIndexerFrame(_iLocIndexer):
) -> None: ...

class _LocIndexerFrame(_LocIndexer):
@overload
def __getitem__(self, idx: Scalar) -> Series | DataFrame: ...
@overload
def __getitem__(
self,
Expand Down Expand Up @@ -184,8 +186,7 @@ class _LocIndexerFrame(_LocIndexer):
@overload
def __getitem__(
self,
idx: ScalarT
| Callable[[DataFrame], ScalarT]
idx: Callable[[DataFrame], ScalarT]
| tuple[
IndexType
| MaskType
Expand Down
31 changes: 20 additions & 11 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2208,16 +2208,16 @@ def test_frame_scalars_slice() -> None:

# Note: bool_ cannot be tested since the index is object and pandas does not
# support boolean access using loc except when the index is boolean
check(assert_type(df.loc[str_], pd.Series), pd.Series)
check(assert_type(df.loc[bytes_], pd.Series), pd.Series)
check(assert_type(df.loc[date], pd.Series), pd.Series)
check(assert_type(df.loc[datetime_], pd.Series), pd.Series)
check(assert_type(df.loc[timedelta], pd.Series), pd.Series)
check(assert_type(df.loc[int_], pd.Series), pd.Series)
check(assert_type(df.loc[float_], pd.Series), pd.Series)
check(assert_type(df.loc[complex_], pd.Series), pd.Series)
check(assert_type(df.loc[timestamp], pd.Series), pd.Series)
check(assert_type(df.loc[pd_timedelta], pd.Series), pd.Series)
check(assert_type(df.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[bytes_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[date], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[datetime_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[timedelta], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[int_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[float_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[complex_], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[timestamp], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[pd_timedelta], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[none], pd.Series), pd.Series)

check(assert_type(df.loc[:, str_], pd.Series), pd.Series)
Expand All @@ -2232,11 +2232,20 @@ def test_frame_scalars_slice() -> None:
check(assert_type(df.loc[:, pd_timedelta], pd.Series), pd.Series)
check(assert_type(df.loc[:, none], pd.Series), pd.Series)

# GH749

multi_idx = pd.MultiIndex.from_product([["a", "b"], [1, 2]], names=["alpha", "num"])
df2 = pd.DataFrame({"col1": range(4)}, index=multi_idx)
check(assert_type(df2.loc[str_], Union[pd.Series, pd.DataFrame]), pd.DataFrame)

df3 = pd.DataFrame({"x": range(2)}, index=pd.Index(["a", "b"]))
check(assert_type(df3.loc[str_], Union[pd.Series, pd.DataFrame]), pd.Series)


def test_boolean_loc() -> None:
# Booleans can only be used in loc when the index is boolean
df = pd.DataFrame([[0, 1], [1, 0]], columns=[True, False], index=[True, False])
check(assert_type(df.loc[True], pd.Series), pd.Series)
check(assert_type(df.loc[True], Union[pd.Series, pd.DataFrame]), pd.Series)
check(assert_type(df.loc[:, False], pd.Series), pd.Series)


Expand Down

0 comments on commit 176805d

Please sign in to comment.