From a52374ad92b298667492e7ab48ed47df16ed0812 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:21:46 +0100 Subject: [PATCH] feat: add `DataFrame.pivot` for pandas like and Polars backend (#546) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> --- docs/api-reference/dataframe.md | 1 + narwhals/_pandas_like/dataframe.py | 84 +++++++++ narwhals/_polars/dataframe.py | 29 +++ narwhals/dataframe.py | 85 +++++++++ pyproject.toml | 1 + tests/frame/pivot_test.py | 240 +++++++++++++++++++++++++ utils/generate_backend_completeness.py | 2 +- 7 files changed, 441 insertions(+), 1 deletion(-) create mode 100644 tests/frame/pivot_test.py diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index 447acbc15..00ff2122e 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -26,6 +26,7 @@ - lazy - null_count - pipe + - pivot - rename - row - rows diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 3bf60f845..25a5c236f5 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -827,6 +827,90 @@ def clone(self: Self) -> Self: def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self._from_native_frame(self._native_frame.iloc[offset::n]) + def pivot( + self: Self, + on: str | list[str], + *, + index: str | list[str] | None, + values: str | list[str] | None, + aggregate_function: Any | None, + maintain_order: bool, + sort_columns: bool, + separator: str = "_", + ) -> Self: + if self._implementation is Implementation.PANDAS and ( + self._backend_version < (1, 1) + ): # pragma: no cover + msg = "pivot is only supported for pandas>=1.1" + raise NotImplementedError(msg) + if self._implementation is Implementation.MODIN: + msg = "pivot is not supported for Modin backend due to https://github.com/modin-project/modin/issues/7409." 
+ raise NotImplementedError(msg) + from itertools import product + + frame = self._native_frame + + if isinstance(on, str): + on = [on] + if isinstance(index, str): + index = [index] + + if values is None: + values_ = [c for c in self.columns if c not in {*on, *index}] # type: ignore[misc] + elif isinstance(values, str): # pragma: no cover + values_ = [values] + else: + values_ = values + + if aggregate_function is None: + result = frame.pivot(columns=on, index=index, values=values_) + + elif aggregate_function == "len": + result = ( + frame.groupby([*on, *index]) # type: ignore[misc] + .agg({v: "size" for v in values_}) + .reset_index() + .pivot(columns=on, index=index, values=values_) + ) + else: + result = frame.pivot_table( + values=values_, + index=index, + columns=on, + aggfunc=aggregate_function, + margins=False, + observed=True, + ) + + # Put columns in the right order + if sort_columns: + uniques = { + col: sorted(self._native_frame[col].unique().tolist()) for col in on + } + else: + uniques = {col: self._native_frame[col].unique().tolist() for col in on} + all_lists = [values_, *list(uniques.values())] + ordered_cols = list(product(*all_lists)) + result = result.loc[:, ordered_cols] + columns = result.columns.tolist() + + n_on = len(on) + if n_on == 1: + new_columns = [ + separator.join(col).strip() if len(values_) > 1 else col[-1] + for col in columns + ] + else: + new_columns = [ + separator.join([col[0], '{"' + '","'.join(col[-n_on:]) + '"}']) + if len(values_) > 1 + else '{"' + '","'.join(col[-n_on:]) + '"}' + for col in columns + ] + result.columns = new_columns + result.columns.names = [""] # type: ignore[attr-defined] + return self._from_native_frame(result.reset_index()) + def to_arrow(self: Self) -> Any: if self._implementation is Implementation.CUDF: return self._native_frame.to_arrow(preserve_index=False) diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index c136cc675..fffe5092d 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Literal from typing import Sequence from narwhals._polars.namespace import PolarsNamespace @@ -247,6 +248,34 @@ def unpivot( ) ) + def pivot( + self: Self, + on: str | list[str], + *, + index: str | list[str] | None = None, + values: str | list[str] | None = None, + aggregate_function: Literal[ + "min", "max", "first", "last", "sum", "mean", "median", "len" + ] + | None = None, + maintain_order: bool = True, + sort_columns: bool = False, + separator: str = "_", + ) -> Self: + if self._backend_version < (1, 0, 0): # pragma: no cover + msg = "`pivot` is only supported for Polars>=1.0.0" + raise NotImplementedError(msg) + result = self._native_frame.pivot( + on, + index=index, + values=values, + aggregate_function=aggregate_function, + maintain_order=maintain_order, + sort_columns=sort_columns, + separator=separator, + ) + return self._from_native_object(result) # type: ignore[no-any-return] + class PolarsLazyFrame: def __init__( diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index bb163b28d..34482072d 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -2567,6 +2567,91 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: """ return super().gather_every(n=n, offset=offset) + def pivot( + self: Self, + on: str | list[str], + *, + index: str | list[str] | None = None, + values: str | list[str] | None = None, + aggregate_function: Literal[ + "min", "max", "first", 
"last", "sum", "mean", "median", "len" + ] + | None = None, + maintain_order: bool = True, + sort_columns: bool = False, + separator: str = "_", + ) -> Self: + r""" + Create a spreadsheet-style pivot table as a DataFrame. + + Arguments: + on: Name of the column(s) whose values will be used as the header of the + output DataFrame. + index: One or multiple keys to group by. If None, all remaining columns not + specified on `on` and `values` will be used. At least one of `index` and + `values` must be specified. + values: One or multiple keys to group by. If None, all remaining columns not + specified on `on` and `index` will be used. At least one of `index` and + `values` must be specified. + aggregate_function: Choose from: + - None: no aggregation takes place, will raise error if multiple values + are in group. + - A predefined aggregate function string, one of + {'min', 'max', 'first', 'last', 'sum', 'mean', 'median', 'len'} + maintain_order: Sort the grouped keys so that the output order is predictable. + sort_columns: Sort the transposed columns by name. Default is by order of + discovery. + separator: Used as separator/delimiter in generated column names in case of + multiple `values` columns. + + Examples: + >>> import narwhals as nw + >>> import pandas as pd + >>> import polars as pl + >>> data = { + ... "ix": [1, 1, 2, 2, 1, 2], + ... "col": ["a", "a", "a", "a", "b", "b"], + ... "foo": [0, 1, 2, 2, 7, 1], + ... "bar": [0, 2, 0, 0, 9, 4], + ... } + >>> df_pd = pd.DataFrame(data) + >>> df_pl = pl.DataFrame(data) + + Let's define a dataframe-agnostic function: + + >>> @nw.narwhalify + ... def func(df): + ... return df.pivot("col", index="ix", aggregate_function="sum") + + We can then pass either pandas or Polars to `func`: + + >>> func(df_pd) + ix foo_a foo_b bar_a bar_b + 0 1 1 7 2 9 + 1 2 4 1 0 4 + >>> func(df_pl) + shape: (2, 5) + ┌─────┬───────┬───────┬───────┬───────┐ + │ ix ┆ foo_a ┆ foo_b ┆ bar_a ┆ bar_b │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ + ╞═════╪═══════╪═══════╪═══════╪═══════╡ + │ 1 ┆ 1 ┆ 7 ┆ 2 ┆ 9 │ + │ 2 ┆ 4 ┆ 1 ┆ 0 ┆ 4 │ + └─────┴───────┴───────┴───────┴───────┘ + """ + return self._from_compliant_dataframe( + self._compliant_frame.pivot( + on=on, + index=index, + values=values, + aggregate_function=aggregate_function, + maintain_order=maintain_order, + sort_columns=sort_columns, + separator=separator, + ) + ) + def to_arrow(self: Self) -> pa.Table: r""" Convert to arrow table. 
diff --git a/pyproject.toml b/pyproject.toml index 653221002..9f20bcf7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ lint.ignore = [ "ISC001", "NPY002", "PD901", # This is a auxiliary library so dataframe variables have no concrete business meaning + "PD010", "PLR0911", "PLR0912", "PLR0913", diff --git a/tests/frame/pivot_test.py b/tests/frame/pivot_test.py new file mode 100644 index 000000000..76217361e --- /dev/null +++ b/tests/frame/pivot_test.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from contextlib import nullcontext as does_not_raise +from typing import Any + +import polars as pl +import pytest + +import narwhals.stable.v1 as nw +from tests.utils import PANDAS_VERSION +from tests.utils import POLARS_VERSION +from tests.utils import assert_equal_data + +data = { + "ix": [1, 2, 1, 1, 2, 2], + "iy": [1, 2, 2, 1, 2, 1], + "col": ["b", "b", "a", "a", "a", "a"], + "col_b": ["x", "y", "x", "y", "x", "y"], + "foo": [7, 1, 0, 1, 2, 2], + "bar": [9, 4, 0, 2, 0, 0], +} + +data_no_dups = { + "ix": [1, 1, 2, 2], + "col": ["a", "b", "a", "b"], + "foo": [1, 2, 3, 4], + "bar": ["x", "y", "z", "w"], +} + + +@pytest.mark.parametrize( + ("agg_func", "expected"), + [ + ( + "min", + { + "ix": [1, 2], + "foo_a": [0, 2], + "foo_b": [7, 1], + "bar_a": [0, 0], + "bar_b": [9, 4], + }, + ), + ( + "max", + { + "ix": [1, 2], + "foo_a": [1, 2], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "first", + { + "ix": [1, 2], + "foo_a": [0, 2], + "foo_b": [7, 1], + "bar_a": [0, 0], + "bar_b": [9, 4], + }, + ), + ( + "last", + { + "ix": [1, 2], + "foo_a": [1, 2], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "sum", + { + "ix": [1, 2], + "foo_a": [1, 4], + "foo_b": [7, 1], + "bar_a": [2, 0], + "bar_b": [9, 4], + }, + ), + ( + "mean", + { + "ix": [1, 2], + "foo_a": [0.5, 2.0], + "foo_b": [7.0, 1.0], + "bar_a": [1.0, 0.0], + "bar_b": [9.0, 4.0], + }, + ), + ( + "median", + { + "ix": [1, 2], + "foo_a": [0.5, 2.0], + "foo_b": [7.0, 1.0], + "bar_a": [1.0, 0.0], + "bar_b": [9.0, 4.0], + }, + ), + ( + "len", + { + "ix": [1, 2], + "foo_a": [2, 2], + "foo_b": [1, 1], + "bar_a": [2, 2], + "bar_b": [1, 1], + }, + ), + ], +) +@pytest.mark.parametrize(("on", "index"), [("col", "ix"), (["col"], ["ix"])]) +def test_pivot( + constructor_eager: Any, + agg_func: str, + expected: dict[str, list[Any]], + on: str | list[str], + index: str | list[str], + request: pytest.FixtureRequest, +) -> None: + if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")): + request.applymarker(pytest.mark.xfail) + if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) + ): + # not implemented + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df.pivot( + on=on, + index=index, + values=["foo", "bar"], + aggregate_function=agg_func, # type: ignore[arg-type] + ) + + assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("data_", "context"), + [ + (data_no_dups, does_not_raise()), + (data, pytest.raises((ValueError, pl.exceptions.ComputeError))), + ], +) +def test_pivot_no_agg( + request: Any, constructor_eager: Any, data_: Any, context: Any +) -> None: + if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")): + request.applymarker(pytest.mark.xfail) + if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( + "pandas" in str(constructor_eager) and PANDAS_VERSION 
< (1, 1) + ): + # not implemented + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor_eager(data_), eager_only=True) + with context: + df.pivot("col", index="ix", aggregate_function=None) + + +@pytest.mark.parametrize( + ("sort_columns", "expected"), + [ + (True, ["ix", "foo_a", "foo_b", "bar_a", "bar_b"]), + (False, ["ix", "foo_b", "foo_a", "bar_b", "bar_a"]), + ], +) +def test_pivot_sort_columns( + request: Any, constructor_eager: Any, sort_columns: Any, expected: list[str] +) -> None: + if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")): + request.applymarker(pytest.mark.xfail) + if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) + ): + # not implemented + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df.pivot( + on="col", + index="ix", + values=["foo", "bar"], + aggregate_function="sum", + sort_columns=sort_columns, + ) + assert result.columns == expected + + +@pytest.mark.parametrize( + ("kwargs", "expected"), + [ + ({"on": ["col"], "values": ["foo"]}, ["ix", "b", "a"]), + ( + {"on": ["col"], "values": ["foo", "bar"]}, + ["ix", "foo_b", "foo_a", "bar_b", "bar_a"], + ), + ( + {"on": ["col", "col_b"], "values": ["foo"]}, + ["ix", '{"b","x"}', '{"b","y"}', '{"a","x"}', '{"a","y"}'], + ), + ( + {"on": ["col", "col_b"], "values": ["foo", "bar"]}, + [ + "ix", + 'foo_{"b","x"}', + 'foo_{"b","y"}', + 'foo_{"a","x"}', + 'foo_{"a","y"}', + 'bar_{"b","x"}', + 'bar_{"b","y"}', + 'bar_{"a","x"}', + 'bar_{"a","y"}', + ], + ), + ], +) +def test_pivot_names_out( + request: Any, constructor_eager: Any, kwargs: Any, expected: list[str] +) -> None: + if any(x in str(constructor_eager) for x in ("pyarrow_table", "modin")): + request.applymarker(pytest.mark.xfail) + if ("polars" in str(constructor_eager) and POLARS_VERSION < (1, 0)) or ( + "pandas" in str(constructor_eager) and PANDAS_VERSION < (1, 1) + ): + # not implemented + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor_eager(data), eager_only=True) + + result = ( + df.pivot(aggregate_function="min", index="ix", **kwargs).collect_schema().names() + ) + assert result == expected diff --git a/utils/generate_backend_completeness.py b/utils/generate_backend_completeness.py index d7d05daa2..2aafc5671 100644 --- a/utils/generate_backend_completeness.py +++ b/utils/generate_backend_completeness.py @@ -68,7 +68,7 @@ def render_table_and_write_to_output( results: list[pl.DataFrame], title: str, output_filename: str ) -> None: results = ( - pl.concat(results) # noqa: PD010 + pl.concat(results) .with_columns(supported=pl.lit(":white_check_mark:")) .pivot(on="Backend", values="supported", index=["Class", "Method"]) .filter(pl.col("narwhals").is_not_null())
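
For context on the `_pandas_like` code path above: `aggregate_function="len"` is
handled separately from the other aggregations, presumably because `"len"` is not a
recognised pandas aggregation string. The new method counts rows per group with a
`"size"` aggregation and then reshapes with `DataFrame.pivot`. A rough, standalone
pandas sketch of that step on the data from the docstring example (illustrative
only; the real method additionally reorders the columns and flattens the MultiIndex
names, using `separator` when there is more than one `values` column):

import pandas as pd

df = pd.DataFrame(
    {
        "ix": [1, 1, 2, 2, 1, 2],
        "col": ["a", "a", "a", "a", "b", "b"],
        "foo": [0, 1, 2, 2, 7, 1],
        "bar": [0, 2, 0, 0, 9, 4],
    }
)
# Count rows per (on, index) group, then spread the counts into wide form.
counts = (
    df.groupby(["col", "ix"])
    .agg({"foo": "size", "bar": "size"})
    .reset_index()
    .pivot(columns=["col"], index=["ix"], values=["foo", "bar"])
)
# MultiIndex columns such as ("foo", "a") are later flattened to "foo_a".
print(counts)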