diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 7b498c707..6636d4606 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -839,10 +839,12 @@ def pivot( separator: str = "_", ) -> Self: if self._implementation is not Implementation.PANDAS or ( - self._backend_version < (1, 5) + self._backend_version < (1, 1) ): - msg = "pivot is only supported for pandas>=1.5" + msg = "pivot is only supported for pandas>=1.1" raise NotImplementedError(msg) + from itertools import product + frame = self._native_frame if isinstance(on, str): @@ -872,16 +874,24 @@ def pivot( values=values_, index=index, columns=on, - aggfunc="size" if aggregate_function == "len" else aggregate_function, + aggfunc=aggregate_function, margins=False, observed=True, - sort=False, ) + # Put columns in the right order + if sort_columns: + uniques = { + col: sorted(self._native_frame[col].unique().tolist()) for col in on + } + else: + uniques = {col: self._native_frame[col].unique().tolist() for col in on} + all_lists = [values_, *list(uniques.values())] + ordered_cols = list(product(*all_lists)) + result = result.loc[:, ordered_cols] columns = result.columns.tolist() n_on = len(on) - if n_on == 1: new_columns = [ separator.join(col).strip() if len(values_) > 1 else col[-1] @@ -895,22 +905,7 @@ def pivot( for col in columns ] result.columns = new_columns - - if sort_columns: - # The inner sorting creates a list of sorted lists of columns for each value - # which then needs to be unpacked into a list. - # This probably can be done more performantly as suffixes are always same?! - sorted_columns = [ - col_value - for col_values in [ - sorted([c for c in new_columns if c.startswith(v)]) for v in values_ - ] - for col_value in col_values - ] - - result = result.loc[:, sorted_columns] - - result.columns.names = [""] + result.columns.names = [""] # type: ignore[attr-defined] return self._from_native_frame(result.reset_index()) def to_arrow(self: Self) -> Any: diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index a3dc4003b..34482072d 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -2597,7 +2597,7 @@ def pivot( - None: no aggregation takes place, will raise error if multiple values are in group. - A predefined aggregate function string, one of - {'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'} + {'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'} maintain_order: Sort the grouped keys so that the output order is predictable. sort_columns: Sort the transposed columns by name. Default is by order of discovery. diff --git a/tests/frame/pivot_test.py b/tests/frame/pivot_test.py index bd6bab905..a454c9a99 100644 --- a/tests/frame/pivot_test.py +++ b/tests/frame/pivot_test.py @@ -125,7 +125,7 @@ def test_pivot( if "pyarrow_table" in str(constructor_eager): request.applymarker(pytest.mark.xfail) if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( - "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5) + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) ): # not implemented request.applymarker(pytest.mark.xfail) @@ -154,7 +154,7 @@ def test_pivot_no_agg( if "pyarrow_table" in str(constructor_eager): request.applymarker(pytest.mark.xfail) if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( - "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5) + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) ): # not implemented request.applymarker(pytest.mark.xfail) @@ -177,7 +177,7 @@ def test_pivot_sort_columns( if "pyarrow_table" in str(constructor_eager): request.applymarker(pytest.mark.xfail) if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( - "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5) + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) ): # not implemented request.applymarker(pytest.mark.xfail) @@ -227,7 +227,7 @@ def test_pivot_names_out( if "pyarrow_table" in str(constructor_eager): request.applymarker(pytest.mark.xfail) if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( - "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5) + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) ): # not implemented request.applymarker(pytest.mark.xfail)