From e4e43bbfbc57b591ad090c9969fe62e3067b2db7 Mon Sep 17 00:00:00 2001 From: Rik van der Vlist Date: Sat, 9 Nov 2024 22:37:24 +0100 Subject: [PATCH] add check for column_names-series combination and refactor logic --- narwhals/utils.py | 33 ++++++++++++++++++--------------- tests/utils_test.py | 8 ++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/narwhals/utils.py b/narwhals/utils.py index 0a6f58bf8..a2cda86d8 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -314,33 +314,36 @@ def maybe_set_index( msg = "Either `column_names` or `index` should be provided" raise ValueError(msg) - if is_pandas_like_dataframe(native_frame): - if column_names is not None: + if column_names is not None: + if is_pandas_like_dataframe(native_frame): return df_any._from_compliant_dataframe( # type: ignore[no-any-return] df_any._compliant_frame._from_native_frame( native_frame.set_index(column_names) ) ) + elif is_pandas_like_series(native_frame): + msg = "Cannot set index using column names on a Series" + raise ValueError(msg) - if index is not None: # pragma: no cover - from narwhals.series import Series - - if _is_iterable(index): - index = [ - idx.to_native() if isinstance(idx, Series) else idx for idx in index - ] - if isinstance(index, Series): - index = index.to_native() + if index is not None: # pragma: no cover + from narwhals.series import Series - if is_pandas_like_series(df_any): - native_frame.index = index - else: - native_frame = native_frame.set_index(index) + if _is_iterable(index): + index = [idx.to_native() if isinstance(idx, Series) else idx for idx in index] + if isinstance(index, Series): + index = index.to_native() + if is_pandas_like_dataframe(native_frame): return df_any._from_compliant_dataframe( # type: ignore[no-any-return] df_any._compliant_frame._from_native_frame(native_frame.set_index(index)) ) + elif is_pandas_like_series(native_frame): + native_frame.index = index + return df_any._from_compliant_series( # type: ignore[no-any-return] + df_any._compliant_series._from_native_series(native_frame) + ) + return df_any # type: ignore[no-any-return] diff --git a/tests/utils_test.py b/tests/utils_test.py index e2b8e4523..dc2415c8d 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -143,6 +143,14 @@ def test_maybe_set_index_polars_direct_index( assert result is df +def test_maybe_set_index_pandas_series_column_names() -> None: + df = nw.from_native(pd.Series([0, 1, 2]), allow_series=True) + with pytest.raises( + ValueError, match="Cannot set index using column names on a Series" + ): + nw.maybe_set_index(df, column_names=["a"]) + + def test_maybe_set_index_pandas_either_index_or_column_names() -> None: df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) column_names = ["a", "b"]