Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support passing index object directly into maybe_set_index #1319

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 47 additions & 8 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,15 @@ def maybe_get_index(obj: T) -> Any | None:
return None


def maybe_set_index(df: T, column_names: str | list[str]) -> T:
def maybe_set_index(
df: T,
column_names: str | list[str] | None = None,
*,
index: Series | list[Series] | None = None,
) -> T:
"""
Set columns `columns` to be the index of `df`, if `df` is pandas-like.

Set the index of `df`, if `df` is pandas-like. The index can either be specified as
an existing column name or list of column names with `column_names`, or set directly with a Series or list of Series with `index`.
Notes:
This is only really intended for backwards-compatibility purposes,
for example if your library already aligns indices for users.
Expand All @@ -297,14 +302,48 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:
4 1
5 2
"""

df_any = cast(Any, df)
native_frame = to_native(df_any)
if is_pandas_like_dataframe(native_frame):
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(
native_frame.set_index(column_names)

if column_names is not None and index is not None:
msg = "Only one of `column_names` or `index` should be provided"
raise ValueError(msg)

if not column_names and not index:
msg = "Either `column_names` or `index` should be provided"
raise ValueError(msg)

if column_names is not None:
if is_pandas_like_dataframe(native_frame):
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(
native_frame.set_index(column_names)
)
)
)
elif is_pandas_like_series(native_frame):
msg = "Cannot set index using column names on a Series"
raise ValueError(msg)

if index is not None: # pragma: no cover
from narwhals.series import Series

if _is_iterable(index):
index = [idx.to_native() if isinstance(idx, Series) else idx for idx in index]
if isinstance(index, Series):
index = index.to_native()

if is_pandas_like_dataframe(native_frame):
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(native_frame.set_index(index))
)

elif is_pandas_like_series(native_frame):
native_frame.index = index
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
return df_any._from_compliant_series( # type: ignore[no-any-return]
df_any._compliant_series._from_native_series(native_frame)
)

return df_any # type: ignore[no-any-return]


Expand Down
108 changes: 100 additions & 8 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import string
from typing import TYPE_CHECKING

import hypothesis.strategies as st
import pandas as pd
Expand All @@ -15,6 +16,9 @@
from tests.utils import PANDAS_VERSION
from tests.utils import get_module_version_as_tuple

if TYPE_CHECKING:
from narwhals.series import Series


def test_maybe_align_index_pandas() -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0]))
Expand Down Expand Up @@ -58,21 +62,109 @@ def test_maybe_align_index_polars() -> None:
nw.maybe_align_index(df, s[1:])


def test_maybe_set_index_pandas() -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0]))
result = nw.maybe_set_index(df, "b")
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[1, 2, 0]).set_index(
"b"
)
@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_pandas_column_names(
    column_names: str | list[str] | None,
) -> None:
    """Check `maybe_set_index` with column names against pandas `set_index`.

    A single column name and a list of column names are both exercised on a
    pandas-backed frame.
    """
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    wrapped = nw.from_native(pd.DataFrame(data))
    result = nw.maybe_set_index(wrapped, column_names)
    # Reference result: plain pandas applied to a separate, identical frame.
    reference = pd.DataFrame(data).set_index(column_names)
    assert_frame_equal(nw.to_native(result), reference)


def test_maybe_set_index_polars() -> None:
@pytest.mark.parametrize("column_names", ["b", ["a", "b"]])
def test_maybe_set_index_polars_column_names(
    column_names: str | list[str] | None,
) -> None:
    """Check that a polars-backed frame is handed back as the same object."""
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    returned = nw.maybe_set_index(frame, column_names)
    # Identity, not equality: the very same narwhals object must come back.
    assert returned is frame


@pytest.mark.parametrize(
    "native_df_or_series",
    [
        pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pd.Series([0, 1, 2]),
    ],
)
@pytest.mark.parametrize(
    ("narwhals_index", "pandas_index"),
    [
        (
            nw.from_native(pd.Series([1, 2, 0]), series_only=True),
            pd.Series([1, 2, 0]),
        ),
        (
            [
                nw.from_native(pd.Series([0, 1, 2]), series_only=True),
                nw.from_native(pd.Series([1, 2, 0]), series_only=True),
            ],
            [pd.Series([0, 1, 2]), pd.Series([1, 2, 0])],
        ),
    ],
)
def test_maybe_set_index_pandas_direct_index(
    narwhals_index: Series | list[Series] | None,
    pandas_index: pd.Series | list[pd.Series] | None,
    native_df_or_series: pd.DataFrame | pd.Series,
) -> None:
    """Check `maybe_set_index(..., index=...)` on pandas-backed objects.

    Each index form (one narwhals Series, or a list of them) is applied to
    both a DataFrame and a Series, and the native result is compared with
    what pandas produces from the equivalent native index.
    """
    wrapped = nw.from_native(native_df_or_series, allow_series=True)
    result = nw.maybe_set_index(wrapped, index=narwhals_index)

    if not isinstance(native_df_or_series, pd.Series):
        reference = native_df_or_series.set_index(pandas_index)
        assert_frame_equal(nw.to_native(result), reference)
        return

    # Series have no `set_index`; the expected value is built by assigning
    # the index on the parametrized object itself.
    native_df_or_series.index = pandas_index
    assert_series_equal(nw.to_native(result), native_df_or_series)


@pytest.mark.parametrize(
    "index",
    [
        nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        [
            nw.from_native(pd.Series([0, 1, 2]), series_only=True),
            nw.from_native(pd.Series([1, 2, 0]), series_only=True),
        ],
    ],
)
def test_maybe_set_index_polars_direct_index(
    index: Series | list[Series] | None,
) -> None:
    """Check that passing `index=` with a polars-backed frame returns it as-is.

    Only the `index=` keyword path is exercised here. The earlier
    `maybe_set_index(df, "b")` call was dropped because its result was
    overwritten before being checked.
    """
    df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    result = nw.maybe_set_index(df, index=index)
    # Identity, not equality: the very same narwhals object must come back.
    assert result is df


def test_maybe_set_index_pandas_series_column_names() -> None:
    """Check that `column_names` on a pandas-backed Series raises `ValueError`."""
    series = nw.from_native(pd.Series([0, 1, 2]), allow_series=True)
    expected_message = "Cannot set index using column names on a Series"
    with pytest.raises(ValueError, match=expected_message):
        nw.maybe_set_index(series, column_names=["a"])


def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
    """Check argument validation: exactly one of `column_names` / `index`.

    Supplying both, or neither, must raise `ValueError` with the matching
    message.
    """
    frame = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    names = ["a", "b"]
    direct_index = nw.from_native(pd.Series([0, 1, 2]), series_only=True)

    both_given = "Only one of `column_names` or `index` should be provided"
    neither_given = "Either `column_names` or `index` should be provided"

    # Both arguments supplied at once.
    with pytest.raises(ValueError, match=both_given):
        nw.maybe_set_index(frame, column_names=names, index=direct_index)

    # No argument supplied at all.
    with pytest.raises(ValueError, match=neither_given):
        nw.maybe_set_index(frame)


def test_maybe_get_index_pandas() -> None:
pandas_df = pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0])
result = nw.maybe_get_index(nw.from_native(pandas_df))
Expand Down
Loading