diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index 983c9f5da..c795fa183 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -2,6 +2,7 @@ import collections import os +import warnings from copy import copy from typing import TYPE_CHECKING from typing import Any @@ -101,12 +102,21 @@ def agg_pandas( simple_aggs = [] complex_aggs = [] for expr in exprs: - if is_simple_aggregation(expr): + if is_simple_aggregation(expr, implementation="pandas"): simple_aggs.append(expr) else: complex_aggs.append(expr) simple_aggregations = {} for expr in simple_aggs: + if expr._depth == 0: + # e.g. agg(pl.len()) + assert expr._output_names is not None + for output_name in expr._output_names: + simple_aggregations[output_name] = pd.NamedAgg( + column=keys[0], aggfunc=expr._function_name.replace("len", "size") + ) + continue + assert expr._root_names is not None assert expr._output_names is not None for root_name, output_name in zip(expr._root_names, expr._output_names): @@ -124,17 +134,31 @@ def func(df: Any) -> Any: out_names.append(result_keys.name) return pd.Series(out_group, index=out_names) - if parse(pd.__version__) < parse("2.2.0"): - result_complex = grouped.apply(func) - else: - result_complex = grouped.apply(func, include_groups=False) + if complex_aggs: + warnings.warn( + "Found complex group-by expression, which can't be expressed efficiently with the " + "pandas API. If you can, please rewrite your query such that group-by aggregations " + "are simple (e.g. mean, std, min, max, ...).", + UserWarning, + stacklevel=2, + ) + if parse(pd.__version__) < parse("2.2.0"): + result_complex = grouped.apply(func) + else: + result_complex = grouped.apply(func, include_groups=False) - if result_simple is not None: + if result_simple is not None and not complex_aggs: + result = result_simple + elif result_simple is not None and complex_aggs: result = pd.concat( - [result_simple, result_complex.drop(columns=keys)], axis=1, copy=False + [result_simple, result_complex.drop(columns=keys)], + axis=1, + copy=False, ) - else: + elif complex_aggs: result = result_complex + else: + raise AssertionError("At least one aggregation should have been passed") return from_dataframe(result.loc[:, output_names]) @@ -149,7 +173,7 @@ def agg_generic( # noqa: PLR0913 dfs: list[Any] = [] to_remove: list[int] = [] for i, expr in enumerate(exprs): - if is_simple_aggregation(expr): + if is_simple_aggregation(expr, implementation): dfs.append(evaluate_simple_aggregation(expr, grouped)) to_remove.append(i) exprs = [expr for i, expr in enumerate(exprs) if i not in to_remove] diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 8093e5227..7f0351003 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from copy import copy from typing import TYPE_CHECKING from typing import Any @@ -195,13 +196,20 @@ def item(s: Any) -> Any: return s.iloc[0] -def is_simple_aggregation(expr: PandasExpr) -> bool: +def is_simple_aggregation(expr: PandasExpr, implementation: str) -> bool: return ( expr._function_name is not None and expr._depth is not None and expr._depth < 2 # todo: avoid this one? - and expr._root_names is not None + and ( + expr._root_names is not None + or ( + expr._depth == 0 + and implementation == "pandas" + and not os.environ.get("NARWHALS_FORCE_GENERIC") + ) + ) ) diff --git a/tests/tpch_q1_test.py b/tests/tpch_q1_test.py index c699f5797..e07dbea2a 100644 --- a/tests/tpch_q1_test.py +++ b/tests/tpch_q1_test.py @@ -24,24 +24,24 @@ def test_q1(df_raw: Any) -> None: df = nw.LazyFrame(df_raw) query_result = ( df.filter(nw.col("l_shipdate") <= var_1) + .with_columns( + disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")), + charge=( + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) + ), + ) .group_by(["l_returnflag", "l_linestatus"]) .agg( [ - nw.col("l_quantity").sum().alias("sum_qty"), - nw.col("l_extendedprice").sum().alias("sum_base_price"), - (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) - .sum() - .alias("sum_disc_price"), - ( - nw.col("l_extendedprice") - * (1.0 - nw.col("l_discount")) - * (1.0 + nw.col("l_tax")) - ) - .sum() - .alias("sum_charge"), - nw.col("l_quantity").mean().alias("avg_qty"), - nw.col("l_extendedprice").mean().alias("avg_price"), - nw.col("l_discount").mean().alias("avg_disc"), + nw.sum("l_quantity").alias("sum_qty"), + nw.sum("l_extendedprice").alias("sum_base_price"), + nw.sum("disc_price").alias("sum_disc_price"), + nw.col("charge").sum().alias("sum_charge"), + nw.mean("l_quantity").alias("avg_qty"), + nw.mean("l_extendedprice").alias("avg_price"), + nw.mean("l_discount").alias("avg_disc"), nw.len().alias("count_order"), ], ) @@ -85,21 +85,21 @@ def test_q1_w_pandas_agg_generic_path(df_raw: Any) -> None: df = nw.LazyFrame(df_raw) query_result = ( df.filter(nw.col("l_shipdate") <= var_1) + .with_columns( + disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")), + charge=( + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) + ), + ) .group_by(["l_returnflag", "l_linestatus"]) .agg( [ nw.sum("l_quantity").alias("sum_qty"), nw.sum("l_extendedprice").alias("sum_base_price"), - (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) - .sum() - .alias("sum_disc_price"), - ( - nw.col("l_extendedprice") - * (1.0 - nw.col("l_discount")) - * (1.0 + nw.col("l_tax")) - ) - .sum() - .alias("sum_charge"), + nw.sum("disc_price").alias("sum_disc_price"), + nw.col("charge").sum().alias("sum_charge"), nw.mean("l_quantity").alias("avg_qty"), nw.mean("l_extendedprice").alias("avg_price"), nw.mean("l_discount").alias("avg_disc"), diff --git a/tpch/q1.py b/tpch/q1.py index 4f179bc58..965069548 100644 --- a/tpch/q1.py +++ b/tpch/q1.py @@ -13,21 +13,21 @@ def q1(df_raw: Any) -> Any: df = nw.LazyFrame(df_raw) result = ( df.filter(nw.col("l_shipdate") <= var_1) + .with_columns( + disc_price=nw.col("l_extendedprice") * (1 - nw.col("l_discount")), + charge=( + nw.col("l_extendedprice") + * (1.0 - nw.col("l_discount")) + * (1.0 + nw.col("l_tax")) + ), + ) .group_by(["l_returnflag", "l_linestatus"]) .agg( [ nw.sum("l_quantity").alias("sum_qty"), nw.sum("l_extendedprice").alias("sum_base_price"), - (nw.col("l_extendedprice") * (1 - nw.col("l_discount"))) - .sum() - .alias("sum_disc_price"), - ( - nw.col("l_extendedprice") - * (1.0 - nw.col("l_discount")) - * (1.0 + nw.col("l_tax")) - ) - .sum() - .alias("sum_charge"), + nw.sum("disc_price").alias("sum_disc_price"), + nw.col("charge").sum().alias("sum_charge"), nw.mean("l_quantity").alias("avg_qty"), nw.mean("l_extendedprice").alias("avg_price"), nw.mean("l_discount").alias("avg_disc"),