From 7aa514c7c59d8402280ed66a475489bd895f8ab0 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Mon, 27 May 2024 15:37:42 +0200 Subject: [PATCH] fix(python): Fix `DataFrame.__getitem__` for empty list input - `df[[]]` (#16520) --- py-polars/polars/_utils/getitem.py | 13 +++++-------- py-polars/tests/unit/dataframe/test_getitem.py | 10 +++++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/_utils/getitem.py b/py-polars/polars/_utils/getitem.py index da774be84c99..743b457b58d7 100644 --- a/py-polars/polars/_utils/getitem.py +++ b/py-polars/polars/_utils/getitem.py @@ -145,20 +145,17 @@ def get_df_item_by_key( else: return _select_rows(selection, row_key) - # Single input, e.g. df[1] - elif isinstance(key, str): + # Single string input, e.g. df["a"] + if isinstance(key, str): # This case is required because empty strings are otherwise treated # as an empty Sequence in `_select_rows` return df.get_column(key) - elif isinstance(key, Sequence) and len(key) == 0: - # df[[]] - # TODO: This removes all columns, but it should remove all rows. - # https://github.com/pola-rs/polars/issues/4924 - return df.__class__() + + # Single input - df[1] - or multiple inputs - df["a", "b", "c"] try: return _select_rows(df, key) # type: ignore[arg-type] except TypeError: - return _select_columns(df, key) # type: ignore[arg-type] + return _select_columns(df, key) # `str` overlaps with `Sequence[str]` diff --git a/py-polars/tests/unit/dataframe/test_getitem.py b/py-polars/tests/unit/dataframe/test_getitem.py index 30b1dfa6274b..0583526fce7a 100644 --- a/py-polars/tests/unit/dataframe/test_getitem.py +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -234,6 +234,13 @@ def test_df_getitem_row_range_single_input() -> None: assert_frame_equal(result, expected) +def test_df_getitem_row_empty_list_single_input() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [5.0, 6.0]}) + result = df[[]] + expected = df.clear() + assert_frame_equal(result, expected) + + def test_df_getitem() -> None: """Test all the methods to use [] on a dataframe.""" df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]}) @@ -287,9 +294,6 @@ def test_df_getitem() -> None: # empty list with column selector drops rows but keeps columns assert_frame_equal(df[empty, :], df[:0]) - # empty list without column select return empty frame - assert_frame_equal(df[empty], pl.DataFrame({})) - # numpy array: assumed to be row indices if integers, or columns if strings # numpy array: positive idxs and empty idx