From 0ffd2bcbe3c5c95e17a11bea46d69f788a020093 Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Mon, 30 Sep 2024 19:54:34 +0100 Subject: [PATCH] feat: add `maybe_reset_index` for pandas-like dataframe or series (#1095) --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marco Edward Gorelli --- docs/api-reference/narwhals.md | 1 + narwhals/__init__.py | 2 ++ narwhals/stable/v1/__init__.py | 31 +++++++++++++++++++++++++ narwhals/utils.py | 41 ++++++++++++++++++++++++++++++++++ tests/utils_test.py | 24 ++++++++++++++++++++ 5 files changed, 99 insertions(+) diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index d678b8732..2700c48c7 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -22,6 +22,7 @@ Here are the top-level functions available in Narwhals. 
- maybe_align_index - maybe_convert_dtypes - maybe_get_index + - maybe_reset_index - maybe_set_index - mean - mean_horizontal diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 2b571f0e2..e5ee71a18 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -55,6 +55,7 @@ from narwhals.utils import maybe_align_index from narwhals.utils import maybe_convert_dtypes from narwhals.utils import maybe_get_index +from narwhals.utils import maybe_reset_index from narwhals.utils import maybe_set_index __version__ = "1.8.4" @@ -72,6 +73,7 @@ "maybe_align_index", "maybe_convert_dtypes", "maybe_get_index", + "maybe_reset_index", "maybe_set_index", "get_native_namespace", "all", diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index b542b90fa..b0cefc3e6 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -55,6 +55,7 @@ from narwhals.utils import maybe_align_index as nw_maybe_align_index from narwhals.utils import maybe_convert_dtypes as nw_maybe_convert_dtypes from narwhals.utils import maybe_get_index as nw_maybe_get_index +from narwhals.utils import maybe_reset_index as nw_maybe_reset_index from narwhals.utils import maybe_set_index as nw_maybe_set_index if TYPE_CHECKING: @@ -1802,6 +1803,35 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T: return nw_maybe_set_index(df, column_names) +def maybe_reset_index(obj: T) -> T: + """ + Reset the index to the default integer index of a DataFrame or a Series, if it's pandas-like. + + Notes: + This is only really intended for backwards-compatibility purposes, + for example if your library already resets the index for users. + If you're designing a new library, we highly encourage you to not + rely on the Index. + For non-pandas-like inputs, this is a no-op. 
+ + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals.stable.v1 as nw + >>> df_pd = pd.DataFrame({"a": [1, 2], "b": [4, 5]}, index=([6, 7])) + >>> df = nw.from_native(df_pd) + >>> nw.to_native(nw.maybe_reset_index(df)) + a b + 0 1 4 + 1 2 5 + >>> series_pd = pd.Series([1, 2], index=[6, 7]) + >>> series = nw.from_native(series_pd, series_only=True) + >>> nw.maybe_get_index(nw.maybe_reset_index(series)) + RangeIndex(start=0, stop=2, step=1) + """ + return nw_maybe_reset_index(obj) + + def get_native_namespace(obj: Any) -> Any: """ Get native namespace from object. @@ -2032,6 +2062,7 @@ def from_dict( "maybe_align_index", "maybe_convert_dtypes", "maybe_get_index", + "maybe_reset_index", "maybe_set_index", "get_native_namespace", "get_level", diff --git a/narwhals/utils.py b/narwhals/utils.py index 62ae7730b..37cce17d3 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -306,6 +306,47 @@ def maybe_set_index(df: T, column_names: str | list[str]) -> T: return df_any # type: ignore[no-any-return] +def maybe_reset_index(obj: T) -> T: + """ + Reset the index to the default integer index of a DataFrame or a Series, if it's pandas-like. + + Notes: + This is only really intended for backwards-compatibility purposes, + for example if your library already resets the index for users. + If you're designing a new library, we highly encourage you to not + rely on the Index. + For non-pandas-like inputs, this is a no-op. 
+ + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import narwhals as nw + >>> df_pd = pd.DataFrame({"a": [1, 2], "b": [4, 5]}, index=([6, 7])) + >>> df = nw.from_native(df_pd) + >>> nw.to_native(nw.maybe_reset_index(df)) + a b + 0 1 4 + 1 2 5 + >>> series_pd = pd.Series([1, 2], index=[6, 7]) + >>> series = nw.from_native(series_pd, series_only=True) + >>> nw.maybe_get_index(nw.maybe_reset_index(series)) + RangeIndex(start=0, stop=2, step=1) + """ + obj_any = cast(Any, obj) + native_obj = to_native(obj_any) + if is_pandas_like_dataframe(native_obj): + return obj_any._from_compliant_dataframe( # type: ignore[no-any-return] + obj_any._compliant_frame._from_native_frame(native_obj.reset_index(drop=True)) + ) + if is_pandas_like_series(native_obj): + return obj_any._from_compliant_series( # type: ignore[no-any-return] + obj_any._compliant_series._from_native_series( + native_obj.reset_index(drop=True) + ) + ) + return obj_any # type: ignore[no-any-return] + + def maybe_convert_dtypes(obj: T, *args: bool, **kwargs: bool | str) -> T: """ Convert columns or series to the best possible dtypes using dtypes supporting ``pd.NA``, if df is pandas-like. 
diff --git a/tests/utils_test.py b/tests/utils_test.py index f51c28eab..cea458bc9 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -84,6 +84,30 @@ def test_maybe_get_index_polars() -> None: assert result is None +def test_maybe_reset_index_pandas() -> None: + pandas_df = nw.from_native( + pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[7, 8, 9]) + ) + result = nw.maybe_reset_index(pandas_df) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}, index=[0, 1, 2]) + assert_frame_equal(nw.to_native(result), expected) + pandas_series = nw.from_native( + pd.Series([1, 2, 3], index=[7, 8, 9]), series_only=True + ) + result_s = nw.maybe_reset_index(pandas_series) + expected_s = pd.Series([1, 2, 3], index=[0, 1, 2]) + assert_series_equal(nw.to_native(result_s), expected_s) + + +def test_maybe_reset_index_polars() -> None: + df = nw.from_native(pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + result = nw.maybe_reset_index(df) + assert result is df + series = nw.from_native(pl.Series([1, 2, 3]), series_only=True) + result_s = nw.maybe_reset_index(series) + assert result_s is series + + @pytest.mark.skipif( parse_version(pd.__version__) < parse_version("1.0.0"), reason="too old for convert_dtypes",