Skip to content

Commit

Permalink
more robust solution for old pandas
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Nov 12, 2024
1 parent b2b3684 commit 706ff43
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 26 deletions.
37 changes: 16 additions & 21 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,10 +839,12 @@ def pivot(
separator: str = "_",
) -> Self:
if self._implementation is not Implementation.PANDAS or (
self._backend_version < (1, 5)
self._backend_version < (1, 1)
):
msg = "pivot is only supported for pandas>=1.5"
msg = "pivot is only supported for pandas>=1.1"
raise NotImplementedError(msg)
from itertools import product

frame = self._native_frame

if isinstance(on, str):
Expand Down Expand Up @@ -872,16 +874,24 @@ def pivot(
values=values_,
index=index,
columns=on,
aggfunc="size" if aggregate_function == "len" else aggregate_function,
aggfunc=aggregate_function,
margins=False,
observed=True,
sort=False,
)

# Put columns in the right order
if sort_columns:
uniques = {
col: sorted(self._native_frame[col].unique().tolist()) for col in on
}
else:
uniques = {col: self._native_frame[col].unique().tolist() for col in on}
all_lists = [values_, *list(uniques.values())]
ordered_cols = list(product(*all_lists))
result = result.loc[:, ordered_cols]
columns = result.columns.tolist()

n_on = len(on)

if n_on == 1:
new_columns = [
separator.join(col).strip() if len(values_) > 1 else col[-1]
Expand All @@ -895,22 +905,7 @@ def pivot(
for col in columns
]
result.columns = new_columns

if sort_columns:
# The inner sorting creates a list of sorted lists of columns for each value
# which then needs to be unpacked into a list.
# This probably can be done more performantly as suffixes are always same?!
sorted_columns = [
col_value
for col_values in [
sorted([c for c in new_columns if c.startswith(v)]) for v in values_
]
for col_value in col_values
]

result = result.loc[:, sorted_columns]

result.columns.names = [""]
result.columns.names = [""] # type: ignore[attr-defined]
return self._from_native_frame(result.reset_index())

def to_arrow(self: Self) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2597,7 +2597,7 @@ def pivot(
- None: no aggregation takes place, will raise error if multiple values
are in group.
- A predefined aggregate function string, one of
{'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'}
{'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'}
maintain_order: Sort the grouped keys so that the output order is predictable.
sort_columns: Sort the transposed columns by name. Default is by order of
discovery.
Expand Down
8 changes: 4 additions & 4 deletions tests/frame/pivot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_pivot(
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5)
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
):
# not implemented
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -154,7 +154,7 @@ def test_pivot_no_agg(
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5)
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
):
# not implemented
request.applymarker(pytest.mark.xfail)
Expand All @@ -177,7 +177,7 @@ def test_pivot_sort_columns(
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5)
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
):
# not implemented
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_pivot_names_out(
if "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)
if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or (
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 5)
"pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1)
):
# not implemented
request.applymarker(pytest.mark.xfail)
Expand Down

0 comments on commit 706ff43

Please sign in to comment.