Skip to content

Commit

Permalink
add check for column_names-series combination and refactor logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Riik committed Nov 9, 2024
1 parent 0edf519 commit e4e43bb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
33 changes: 18 additions & 15 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,33 +314,36 @@ def maybe_set_index(
msg = "Either `column_names` or `index` should be provided"
raise ValueError(msg)

if is_pandas_like_dataframe(native_frame):
if column_names is not None:
if column_names is not None:
if is_pandas_like_dataframe(native_frame):
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(
native_frame.set_index(column_names)
)
)
elif is_pandas_like_series(native_frame):
msg = "Cannot set index using column names on a Series"
raise ValueError(msg)

if index is not None: # pragma: no cover
from narwhals.series import Series

if _is_iterable(index):
index = [
idx.to_native() if isinstance(idx, Series) else idx for idx in index
]
if isinstance(index, Series):
index = index.to_native()
if index is not None: # pragma: no cover
from narwhals.series import Series

if is_pandas_like_series(df_any):
native_frame.index = index
else:
native_frame = native_frame.set_index(index)
if _is_iterable(index):
index = [idx.to_native() if isinstance(idx, Series) else idx for idx in index]
if isinstance(index, Series):
index = index.to_native()

if is_pandas_like_dataframe(native_frame):
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(native_frame.set_index(index))
)

elif is_pandas_like_series(native_frame):
native_frame.index = index
return df_any._from_compliant_series( # type: ignore[no-any-return]
df_any._compliant_series._from_native_series(native_frame)
)

return df_any # type: ignore[no-any-return]


Expand Down
8 changes: 8 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def test_maybe_set_index_polars_direct_index(
assert result is df


def test_maybe_set_index_pandas_series_column_names() -> None:
df = nw.from_native(pd.Series([0, 1, 2]), allow_series=True)
with pytest.raises(
ValueError, match="Cannot set index using column names on a Series"
):
nw.maybe_set_index(df, column_names=["a"])


def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
column_names = ["a", "b"]
Expand Down

0 comments on commit e4e43bb

Please sign in to comment.