Skip to content

Commit

Permalink
feat: add maybe_reset_index for pandas-like dataframe or series (#1095)
Browse files Browse the repository at this point in the history



---------

Co-authored-by: Francesco Bruzzesi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Marco Edward Gorelli <[email protected]>
  • Loading branch information
4 people authored Sep 30, 2024
1 parent 9094d5c commit 0ffd2bc
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Here are the top-level functions available in Narwhals.
- maybe_align_index
- maybe_convert_dtypes
- maybe_get_index
- maybe_reset_index
- maybe_set_index
- mean
- mean_horizontal
Expand Down
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from narwhals.utils import maybe_align_index
from narwhals.utils import maybe_convert_dtypes
from narwhals.utils import maybe_get_index
from narwhals.utils import maybe_reset_index
from narwhals.utils import maybe_set_index

__version__ = "1.8.4"
Expand All @@ -72,6 +73,7 @@
"maybe_align_index",
"maybe_convert_dtypes",
"maybe_get_index",
"maybe_reset_index",
"maybe_set_index",
"get_native_namespace",
"all",
Expand Down
31 changes: 31 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from narwhals.utils import maybe_align_index as nw_maybe_align_index
from narwhals.utils import maybe_convert_dtypes as nw_maybe_convert_dtypes
from narwhals.utils import maybe_get_index as nw_maybe_get_index
from narwhals.utils import maybe_reset_index as nw_maybe_reset_index
from narwhals.utils import maybe_set_index as nw_maybe_set_index

if TYPE_CHECKING:
Expand Down Expand Up @@ -1802,6 +1803,35 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:
return nw_maybe_set_index(df, column_names)


def maybe_reset_index(obj: T) -> T:
    """
    Reset the index to the default integer index of a DataFrame or a Series, if it's pandas-like.

    This stable-API entry point simply delegates to
    `narwhals.utils.maybe_reset_index`.

    Arguments:
        obj: DataFrame or Series whose index should be reset.

    Returns:
        An object of the same type as `obj`. For non-pandas-like inputs
        the input itself is returned.

    Notes:
        This is only really intended for backwards-compatibility purposes,
        for example if your library already resets the index for users.
        If you're designing a new library, we highly encourage you to not
        rely on the Index.
        For non-pandas-like inputs, this is a no-op.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import narwhals.stable.v1 as nw
        >>> df_pd = pd.DataFrame({"a": [1, 2], "b": [4, 5]}, index=([6, 7]))
        >>> df = nw.from_native(df_pd)
        >>> nw.to_native(nw.maybe_reset_index(df))
           a  b
        0  1  4
        1  2  5
        >>> series_pd = pd.Series([1, 2], index=[6, 7])
        >>> series = nw.from_native(series_pd, series_only=True)
        >>> nw.maybe_get_index(nw.maybe_reset_index(series))
        RangeIndex(start=0, stop=2, step=1)
        >>> df_pl = nw.from_native(pl.DataFrame({"a": [1, 2]}))
        >>> nw.maybe_reset_index(df_pl) is df_pl
        True
    """
    return nw_maybe_reset_index(obj)


def get_native_namespace(obj: Any) -> Any:
"""
Get native namespace from object.
Expand Down Expand Up @@ -2032,6 +2062,7 @@ def from_dict(
"maybe_align_index",
"maybe_convert_dtypes",
"maybe_get_index",
"maybe_reset_index",
"maybe_set_index",
"get_native_namespace",
"get_level",
Expand Down
41 changes: 41 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,47 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T:
return df_any # type: ignore[no-any-return]


def maybe_reset_index(obj: T) -> T:
    """
    Reset the index to the default integer index of a DataFrame or a Series, if it's pandas-like.

    Arguments:
        obj: DataFrame or Series whose index should be reset.

    Returns:
        For a pandas-like DataFrame or Series, a new object of the same
        kind wrapping the native object after `reset_index(drop=True)`.
        For anything else, `obj` itself.

    Notes:
        This is only really intended for backwards-compatibility purposes,
        for example if your library already resets the index for users.
        If you're designing a new library, we highly encourage you to not
        rely on the Index.
        For non-pandas-like inputs, this is a no-op.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import narwhals as nw
        >>> df_pd = pd.DataFrame({"a": [1, 2], "b": [4, 5]}, index=([6, 7]))
        >>> df = nw.from_native(df_pd)
        >>> nw.to_native(nw.maybe_reset_index(df))
           a  b
        0  1  4
        1  2  5
        >>> series_pd = pd.Series([1, 2], index=[6, 7])
        >>> series = nw.from_native(series_pd, series_only=True)
        >>> nw.maybe_get_index(nw.maybe_reset_index(series))
        RangeIndex(start=0, stop=2, step=1)
        >>> df_pl = nw.from_native(pl.DataFrame({"a": [1, 2]}))
        >>> nw.maybe_reset_index(df_pl) is df_pl
        True
    """
    # Work through an untyped alias so the private attributes used below
    # can be reached without tripping the type checker.
    wrapped = cast(Any, obj)
    native = to_native(wrapped)

    if is_pandas_like_dataframe(native):
        # drop=True discards the old index rather than inserting it as a column.
        frame = wrapped._compliant_frame._from_native_frame(
            native.reset_index(drop=True)
        )
        return wrapped._from_compliant_dataframe(frame)  # type: ignore[no-any-return]

    if is_pandas_like_series(native):
        series = wrapped._compliant_series._from_native_series(
            native.reset_index(drop=True)
        )
        return wrapped._from_compliant_series(series)  # type: ignore[no-any-return]

    # Not pandas-like: nothing to reset, hand the input back untouched.
    return wrapped  # type: ignore[no-any-return]


def maybe_convert_dtypes(obj: T, *args: bool, **kwargs: bool | str) -> T:
"""
Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like.
Expand Down
24 changes: 24 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,30 @@ def test_maybe_get_index_polars() -> None:
assert result is None


def test_maybe_reset_index_pandas() -> None:
    # DataFrame: the custom index [7, 8, 9] should be replaced by [0, 1, 2].
    native_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[7, 8, 9])
    reset_df = nw.maybe_reset_index(nw.from_native(native_df))
    assert_frame_equal(
        nw.to_native(reset_df),
        pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 2]),
    )

    # Series: same expectation for a series-only narwhals wrapper.
    native_s = pd.Series([1, 2, 3], index=[7, 8, 9])
    reset_s = nw.maybe_reset_index(nw.from_native(native_s, series_only=True))
    assert_series_equal(
        nw.to_native(reset_s),
        pd.Series([1, 2, 3], index=[0, 1, 2]),
    )


def test_maybe_reset_index_polars() -> None:
    # Non-pandas-like input is a no-op: the identical object must come back.
    frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    assert nw.maybe_reset_index(frame) is frame

    ser = nw.from_native(pl.Series([1, 2, 3]), series_only=True)
    assert nw.maybe_reset_index(ser) is ser


@pytest.mark.skipif(
parse_version(pd.__version__) < parse_version("1.0.0"),
reason="too old for convert_dtypes",
Expand Down

0 comments on commit 0ffd2bc

Please sign in to comment.