From ae432ecbca3fdd59d5fea7a2e2ee98dbee743659 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 21 Feb 2024 18:26:22 +0000 Subject: [PATCH] revert using apply --- narwhals/pandas_like/group_by.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index 81761c0d9..8244bb7db 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -1,7 +1,6 @@ from __future__ import annotations import collections -import functools from copy import copy from typing import TYPE_CHECKING from typing import Any @@ -13,7 +12,6 @@ from narwhals.pandas_like.utils import get_namespace from narwhals.pandas_like.utils import horizontal_concat from narwhals.pandas_like.utils import is_simple_aggregation -from narwhals.pandas_like.utils import item from narwhals.pandas_like.utils import parse_into_exprs from narwhals.spec import GroupBy as GroupByProtocol from narwhals.spec import IntoExpr @@ -22,7 +20,6 @@ if TYPE_CHECKING: from narwhals.pandas_like.dataframe import DataFrame from narwhals.pandas_like.dataframe import LazyFrame - from narwhals.pandas_like.expr import Expr class GroupBy(GroupByProtocol): @@ -86,28 +83,17 @@ def agg( exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove] out: dict[str, list[Any]] = collections.defaultdict(list) - for keys, _ in grouped: + for keys, df_keys in grouped: for key, name in zip(keys, self._keys): out[name].append(key) - for expr in exprs: - assert expr._output_names is not None - if len(expr._output_names) != 1: - msg = ( - "Multi-output non-elementary aggregations are not supported in group_by.agg.\n" - "Instead of `pl.col('a', 'b').sum()`, please separate them, e.g. " - "`.agg(pl.col('a').sum(), pl.col('b').sum())`\n" - ) - raise ValueError(msg) - - def func(df: Any, expr: Expr) -> Any: - return item(expr._call(self._from_dataframe(df))[0]._series) - - res = grouped.apply(functools.partial(func, expr=expr)) - for key in self._keys: - if key in res.columns: - res = res.drop(columns=[key]) - res.columns = expr._output_names - dfs.append(res) + for expr in exprs: + # q1 from TPC-H is about twice as fast iterating over groups + # manually, compared with using `.apply`. + # maybe reconsider once https://github.com/rapidsai/cudf/issues/15084 + # is addressed? + results_keys = expr._call(self._from_dataframe(df_keys)) + for result_keys in results_keys: + out[result_keys.name].append(result_keys.item()) results_keys = dataframe_from_dict(out, implementation=implementation) results_keys = horizontal_concat(