Skip to content

Commit

Permalink
revert using apply
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Feb 21, 2024
1 parent 9a06adf commit ae432ec
Showing 1 changed file with 9 additions and 23 deletions.
32 changes: 9 additions & 23 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import collections
import functools
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
Expand All @@ -13,7 +12,6 @@
from narwhals.pandas_like.utils import get_namespace
from narwhals.pandas_like.utils import horizontal_concat
from narwhals.pandas_like.utils import is_simple_aggregation
from narwhals.pandas_like.utils import item
from narwhals.pandas_like.utils import parse_into_exprs
from narwhals.spec import GroupBy as GroupByProtocol
from narwhals.spec import IntoExpr
Expand All @@ -22,7 +20,6 @@
if TYPE_CHECKING:
from narwhals.pandas_like.dataframe import DataFrame
from narwhals.pandas_like.dataframe import LazyFrame
from narwhals.pandas_like.expr import Expr


class GroupBy(GroupByProtocol):
Expand Down Expand Up @@ -86,28 +83,17 @@ def agg(
exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove]

out: dict[str, list[Any]] = collections.defaultdict(list)
for keys, _ in grouped:
for keys, df_keys in grouped:
for key, name in zip(keys, self._keys):
out[name].append(key)
for expr in exprs:
assert expr._output_names is not None
if len(expr._output_names) != 1:
msg = (
"Multi-output non-elementary aggregations are not supported in group_by.agg.\n"
"Instead of `pl.col('a', 'b').sum()`, please separate them, e.g. "
"`.agg(pl.col('a').sum(), pl.col('b').sum())`\n"
)
raise ValueError(msg)

def func(df: Any, expr: Expr) -> Any:
return item(expr._call(self._from_dataframe(df))[0]._series)

res = grouped.apply(functools.partial(func, expr=expr))
for key in self._keys:
if key in res.columns:
res = res.drop(columns=[key])
res.columns = expr._output_names
dfs.append(res)
for expr in exprs:
# q1 from TPC-H is about twice as fast iterating over groups
# manually, compared with using `.apply`.
# maybe reconsider once https://github.com/rapidsai/cudf/issues/15084
# is addressed?
results_keys = expr._call(self._from_dataframe(df_keys))
for result_keys in results_keys:
out[result_keys.name].append(result_keys.item())

results_keys = dataframe_from_dict(out, implementation=implementation)
results_keys = horizontal_concat(
Expand Down

0 comments on commit ae432ec

Please sign in to comment.