From 187b7c9752ee7a95621f1a0cc3332ff982a6cf5a Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 18 Jul 2024 19:13:29 +0100 Subject: [PATCH] chore: refactor agg_arrow (#549) * chore: refactor agg_arrow * less magic indexing --- narwhals/_arrow/group_by.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 279d32b7c..59090c4d4 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -74,8 +74,8 @@ def agg_arrow( if all_simple_aggs: # Mapping from output name to - # (input_column_name, function_name, pyarrow_output_name) # noqa: ERA001 - simple_aggregations: dict[str, tuple[Any, str, str]] = {} + # (aggregation_args, pyarrow_output_name) # noqa: ERA001 + simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {} for expr in exprs: if expr._depth == 0: # e.g. agg(nw.len()) # noqa: ERA001 @@ -85,8 +85,7 @@ def agg_arrow( msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" raise AssertionError(msg) simple_aggregations[expr._output_names[0]] = ( - keys[0], - "count", + (keys[0], "count", pc.CountOptions(mode="all")), f"{keys[0]}_count", ) continue @@ -102,19 +101,18 @@ def agg_arrow( function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name) for root_name, output_name in zip(expr._root_names, expr._output_names): simple_aggregations[output_name] = ( - root_name, - function_name, + (root_name, function_name), f"{root_name}_{function_name}", ) aggs: list[Any] = [] name_mapping = {} - for output_name, named_agg in simple_aggregations.items(): - if named_agg[1] == "count": - aggs.append((named_agg[0], named_agg[1], pc.CountOptions(mode="all"))) - else: - aggs.append((named_agg[0], named_agg[1])) - name_mapping[named_agg[2]] = output_name + for output_name, ( + aggregation_args, + pyarrow_output_name, + ) in simple_aggregations.items(): + aggs.append(aggregation_args) + name_mapping[pyarrow_output_name] = output_name result_simple = grouped.aggregate(aggs) result_simple = result_simple.rename_columns( [name_mapping.get(col, col) for col in result_simple.column_names]