Skip to content

Commit

Permalink
change to separate and parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Riik committed Nov 9, 2024
1 parent ec888de commit aaa1780
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 30 deletions.
45 changes: 34 additions & 11 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,15 @@ def maybe_get_index(obj: T) -> Any | None:
return None


def maybe_set_index(df: T, keys: str | Series | list[Series | str]) -> T:
def maybe_set_index(
df: T,
column_names: str | list[str] | None = None,
*,
index: Series | list[Series] | None = None,
) -> T:
"""
Set columns `keys` to be the index of `df`, if `df` is pandas-like. 'keys' should be
a name of an existing column, a Series, or a list of column names and/or Series.
Set the index of `df`, if `df` is pandas-like. The index can either be specified as
a existing column name or list of column names with `column_names`, or set directly with a Series or list of Series with `index`.
Notes:
This is only really intended for backwards-compatibility purposes,
for example if your library already aligns indices for users.
Expand All @@ -301,17 +306,35 @@ def maybe_set_index(df: T, keys: str | Series | list[Series | str]) -> T:
df_any = cast(Any, df)
native_frame = to_native(df_any)

if column_names is not None and index is not None:
msg = "Only one of `column_names` or `keys` should be provided"
raise ValueError(msg)

if not column_names and not index:
msg = "Either `column_names` or `keys` should be provided"
raise ValueError(msg)

if is_pandas_like_dataframe(native_frame):
from narwhals.series import Series
if column_names is not None:
return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(
native_frame.set_index(column_names)
)
)

if _is_iterable(keys):
keys = [key.to_native() if isinstance(key, Series) else key for key in keys]
if isinstance(keys, Series):
keys = keys.to_native()
if index is not None:
from narwhals.series import Series

return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(native_frame.set_index(keys))
)
if _is_iterable(index):
index = [
key.to_native() if isinstance(key, Series) else key for key in index
]
if isinstance(index, Series):
index = index.to_native()

return df_any._from_compliant_dataframe( # type: ignore[no-any-return]
df_any._compliant_frame._from_native_frame(native_frame.set_index(index))
)

return df_any # type: ignore[no-any-return]

Expand Down
80 changes: 61 additions & 19 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,42 +63,84 @@ def test_maybe_align_index_polars() -> None:


@pytest.mark.parametrize(
("pandas_keys", "narwhals_keys"),
"column_names",
["b", ["a", "b"]],
)
def test_maybe_set_index_pandas_column_names(
column_names: str | list[str] | None,
) -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
result = nw.maybe_set_index(df, column_names)
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).set_index(column_names)
assert_frame_equal(nw.to_native(result), expected)


@pytest.mark.parametrize(
"column_names",
[
("b", "b"),
(pd.Series([1, 2, 0]), nw.from_native(pd.Series([1, 2, 0]), series_only=True)),
(["a", "b"], ["a", "b"]),
(
[pd.Series([0, 1, 2]), "b"],
[nw.from_native(pd.Series([0, 1, 2]), series_only=True), "b"],
),
"b",
["a", "b"],
],
)
def test_maybe_set_index_pandas(
pandas_keys: str | Series | list[Series | str],
narwhals_keys: str | Series | list[Series | str],
def test_maybe_set_index_polars_column_names(
column_names: str | list[str] | None,
) -> None:
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
result = nw.maybe_set_index(df, column_names)
assert result is df


@pytest.mark.parametrize(
"index",
[
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
[
nw.from_native(pd.Series([0, 1, 2]), series_only=True),
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
],
],
)
def test_maybe_set_index_pandas_direct_index(
index: Series | list[Series] | None,
) -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
result = nw.maybe_set_index(df, narwhals_keys)
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).set_index(pandas_keys)
result = nw.maybe_set_index(df, index=index)
expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).set_index(index)
assert_frame_equal(nw.to_native(result), expected)


@pytest.mark.parametrize(
"narwhals_keys",
"index",
[
"b",
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
["a", "b"],
[nw.from_native(pd.Series([0, 1, 2]), series_only=True), "b"],
[
nw.from_native(pd.Series([0, 1, 2]), series_only=True),
nw.from_native(pd.Series([1, 2, 0]), series_only=True),
],
],
)
def test_maybe_set_index_polars(narwhals_keys: str | Series | list[Series | str]) -> None:
def test_maybe_set_index_polars_direct_index(
index: Series | list[Series] | None,
) -> None:
df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
result = nw.maybe_set_index(df, narwhals_keys)
result = nw.maybe_set_index(df, index=index)
assert result is df


def test_maybe_set_index_pandas_either_index_or_column_names() -> None:
df = nw.from_native(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
column_names = ["a", "b"]
index = nw.from_native(pd.Series([0, 1, 2]), series_only=True)
with pytest.raises(
ValueError, match="Only one of `column_names` or `keys` should be provided"
):
nw.maybe_set_index(df, column_names=column_names, index=index)
with pytest.raises(
ValueError, match="Either `column_names` or `keys` should be provided"
):
nw.maybe_set_index(df)


def test_maybe_get_index_pandas() -> None:
pandas_df = pd.DataFrame({"a": [1, 2, 3]}, index=[1, 2, 0])
result = nw.maybe_get_index(nw.from_native(pandas_df))
Expand Down

0 comments on commit aaa1780

Please sign in to comment.