Skip to content

Commit

Permalink
feat: add DataFrame.pivot for pandas like and Polars backend (#546)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Marco Gorelli <[email protected]>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent e3de5b2 commit a52374a
Show file tree
Hide file tree
Showing 7 changed files with 441 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api-reference/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- lazy
- null_count
- pipe
- pivot
- rename
- row
- rows
Expand Down
84 changes: 84 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,90 @@ def clone(self: Self) -> Self:
def gather_every(self: Self, n: int, offset: int = 0) -> Self:
    """Return a new frame holding every `n`-th row, starting at position `offset`.

    Arguments:
        n: Step between the rows that are kept.
        offset: Positional index of the first row that is kept.

    Returns:
        The selected rows, wrapped via `_from_native_frame`.
    """
    every_nth = slice(offset, None, n)
    selected_rows = self._native_frame.iloc[every_nth]
    return self._from_native_frame(selected_rows)

def pivot(
    self: Self,
    on: str | list[str],
    *,
    index: str | list[str] | None,
    values: str | list[str] | None,
    aggregate_function: Any | None,
    maintain_order: bool,
    sort_columns: bool,
    separator: str = "_",
) -> Self:
    """Pivot the native pandas-like frame from long to wide format.

    Arguments:
        on: Column(s) whose unique values become the headers of the output.
        index: Column(s) kept as row identifiers. If None, every column that is
            neither in `on` nor in `values` is used.
        values: Column(s) whose entries fill the new columns. If None, every
            column that is neither in `on` nor in `index` is used.
        aggregate_function: None for no aggregation (delegates to the native
            `pivot`), "len" to count rows per group, otherwise forwarded as
            `aggfunc` to the native `pivot_table`.
        maintain_order: Accepted for signature parity; it is not used by this
            implementation.
        sort_columns: If True, the unique values of each `on` column are sorted
            before the output columns are laid out; otherwise they appear in
            order of discovery.
        separator: Delimiter placed between the value-column name and the `on`
            value(s) when more than one value column is pivoted.

    Returns:
        A new frame wrapping the pivoted result, with a flat column index.

    Raises:
        NotImplementedError: For pandas older than 1.1, or for the Modin backend.
        ValueError: If both `index` and `values` are None.
    """
    if self._implementation is Implementation.PANDAS and (
        self._backend_version < (1, 1)
    ):  # pragma: no cover
        msg = "pivot is only supported for pandas>=1.1"
        raise NotImplementedError(msg)
    if self._implementation is Implementation.MODIN:
        msg = "pivot is not supported for Modin backend due to https://github.com/modin-project/modin/issues/7409."
        raise NotImplementedError(msg)
    if index is None and values is None:
        # Without either of them there is no way to infer the other one.
        msg = "At least one of `values` and `index` must be passed"
        raise ValueError(msg)
    from itertools import product

    frame = self._native_frame

    on_ = [on] if isinstance(on, str) else on
    index_ = [index] if isinstance(index, str) else index

    if values is None:
        # `index_` cannot be None here: the both-None case was rejected above.
        values_ = [c for c in self.columns if c not in {*on_, *index_}]  # type: ignore[misc]
    elif isinstance(values, str):  # pragma: no cover
        values_ = [values]
    else:
        values_ = values

    if index_ is None:
        # Fall back to all remaining columns rather than handing None to pandas,
        # which would pivot on the frame's existing row index instead.
        index_ = [c for c in self.columns if c not in {*on_, *values_}]

    if aggregate_function is None:
        result = frame.pivot(columns=on_, index=index_, values=values_)

    elif aggregate_function == "len":
        # Count rows per (on, index) group first, then reshape the counts.
        result = (
            frame.groupby([*on_, *index_])
            .agg({v: "size" for v in values_})
            .reset_index()
            .pivot(columns=on_, index=index_, values=values_)
        )
    else:
        result = frame.pivot_table(
            values=values_,
            index=index_,
            columns=on_,
            aggfunc=aggregate_function,
            margins=False,
            observed=True,
        )

    # Put columns in the right order: value columns vary slowest, followed by
    # the unique values of each `on` column.
    uniques = {col: frame[col].unique().tolist() for col in on_}
    if sort_columns:
        uniques = {col: sorted(col_values) for col, col_values in uniques.items()}
    # NOTE(review): the cartesian product assumes every combination of `on`
    # values is present among the pivoted columns - TODO confirm for sparse data.
    ordered_cols = list(product(values_, *uniques.values()))
    result = result.loc[:, ordered_cols]
    columns = result.columns.tolist()

    # Flatten the (value, on_1, ..., on_k) column tuples into plain names.
    n_on = len(on_)
    multiple_values = len(values_) > 1
    if n_on == 1:
        new_columns = [
            separator.join(col).strip() if multiple_values else col[-1]
            for col in columns
        ]
    else:
        # Several `on` columns are rendered as {"a","b"}; presumably this mirrors
        # the naming used by the Polars backend - verify against Polars output.
        new_columns = [
            separator.join([col[0], '{"' + '","'.join(col[-n_on:]) + '"}'])
            if multiple_values
            else '{"' + '","'.join(col[-n_on:]) + '"}'
            for col in columns
        ]
    result.columns = new_columns
    result.columns.names = [""]  # type: ignore[attr-defined]
    return self._from_native_frame(result.reset_index())

def to_arrow(self: Self) -> Any:
if self._implementation is Implementation.CUDF:
return self._native_frame.to_arrow(preserve_index=False)
Expand Down
29 changes: 29 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import Sequence

from narwhals._polars.namespace import PolarsNamespace
Expand Down Expand Up @@ -247,6 +248,34 @@ def unpivot(
)
)

def pivot(
    self: Self,
    on: str | list[str],
    *,
    index: str | list[str] | None = None,
    values: str | list[str] | None = None,
    aggregate_function: Literal[
        "min", "max", "first", "last", "sum", "mean", "median", "len"
    ]
    | None = None,
    maintain_order: bool = True,
    sort_columns: bool = False,
    separator: str = "_",
) -> Self:
    """Delegate to the native Polars `pivot`, forwarding every option unchanged.

    Raises:
        NotImplementedError: If the backend version is below 1.0.0.
    """
    if (1, 0, 0) > self._backend_version:  # pragma: no cover
        raise NotImplementedError("`pivot` is only supported for Polars>=1.0.0")

    # `on` stays positional; everything else is handed over by keyword.
    pivot_options = {
        "index": index,
        "values": values,
        "aggregate_function": aggregate_function,
        "maintain_order": maintain_order,
        "sort_columns": sort_columns,
        "separator": separator,
    }
    pivoted = self._native_frame.pivot(on, **pivot_options)
    return self._from_native_object(pivoted)  # type: ignore[no-any-return]


class PolarsLazyFrame:
def __init__(
Expand Down
85 changes: 85 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,91 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self:
"""
return super().gather_every(n=n, offset=offset)

def pivot(
    self: Self,
    on: str | list[str],
    *,
    index: str | list[str] | None = None,
    values: str | list[str] | None = None,
    aggregate_function: Literal[
        "min", "max", "first", "last", "sum", "mean", "median", "len"
    ]
    | None = None,
    maintain_order: bool = True,
    sort_columns: bool = False,
    separator: str = "_",
) -> Self:
    r"""
    Create a spreadsheet-style pivot table as a DataFrame.

    Arguments:
        on: Name of the column(s) whose values will be used as the header of the
            output DataFrame.
        index: One or multiple keys to group by. If None, all remaining columns not
            specified on `on` and `values` will be used. At least one of `index` and
            `values` must be specified.
        values: One or multiple columns whose values fill the new columns created
            from `on`. If None, all remaining columns not specified on `on` and
            `index` will be used. At least one of `index` and `values` must be
            specified.
        aggregate_function: Choose from:

            - None: no aggregation takes place, will raise error if multiple values
                are in group.
            - A predefined aggregate function string, one of
                {'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'}
        maintain_order: Sort the grouped keys so that the output order is predictable.
        sort_columns: Sort the transposed columns by name. Default is by order of
            discovery.
        separator: Used as separator/delimiter in generated column names in case of
            multiple `values` columns.

    Raises:
        ValueError: If both `index` and `values` are None.

    Examples:
        >>> import narwhals as nw
        >>> import pandas as pd
        >>> import polars as pl
        >>> data = {
        ...     "ix": [1, 1, 2, 2, 1, 2],
        ...     "col": ["a", "a", "a", "a", "b", "b"],
        ...     "foo": [0, 1, 2, 2, 7, 1],
        ...     "bar": [0, 2, 0, 0, 9, 4],
        ... }
        >>> df_pd = pd.DataFrame(data)
        >>> df_pl = pl.DataFrame(data)

        Let's define a dataframe-agnostic function:

        >>> @nw.narwhalify
        ... def func(df):
        ...     return df.pivot("col", index="ix", aggregate_function="sum")

        We can then pass either pandas or Polars to `func`:

        >>> func(df_pd)
           ix  foo_a  foo_b  bar_a  bar_b
        0   1      1      7      2      9
        1   2      4      1      0      4
        >>> func(df_pl)
        shape: (2, 5)
        ┌─────┬───────┬───────┬───────┬───────┐
        │ ix  ┆ foo_a ┆ foo_b ┆ bar_a ┆ bar_b │
        │ --- ┆ ---   ┆ ---   ┆ ---   ┆ ---   │
        │ i64 ┆ i64   ┆ i64   ┆ i64   ┆ i64   │
        ╞═════╪═══════╪═══════╪═══════╪═══════╡
        │ 1   ┆ 1     ┆ 7     ┆ 2     ┆ 9     │
        │ 2   ┆ 4     ┆ 1     ┆ 0     ┆ 4     │
        └─────┴───────┴───────┴───────┴───────┘
    """
    # Enforce the documented contract up front so every backend fails the same
    # way instead of surfacing a backend-specific error.
    if values is None and index is None:
        msg = "At least one of `values` and `index` must be passed"
        raise ValueError(msg)

    return self._from_compliant_dataframe(
        self._compliant_frame.pivot(
            on=on,
            index=index,
            values=values,
            aggregate_function=aggregate_function,
            maintain_order=maintain_order,
            sort_columns=sort_columns,
            separator=separator,
        )
    )

def to_arrow(self: Self) -> pa.Table:
r"""
Convert to arrow table.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ lint.ignore = [
"ISC001",
"NPY002",
"PD901", # This is a auxiliary library so dataframe variables have no concrete business meaning
"PD010",
"PLR0911",
"PLR0912",
"PLR0913",
Expand Down
Loading

0 comments on commit a52374a

Please sign in to comment.