Skip to content

Commit

Permalink
chore: refactor agg_arrow (#549)
Browse files (browse the repository at this point in the history)
* chore: refactor agg_arrow

* less magic indexing
  • Loading branch information
MarcoGorelli authored Jul 18, 2024
1 parent 5a496f6 commit 187b7c9
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions narwhals/_arrow/group_by.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -74,8 +74,8 @@ def agg_arrow(

if all_simple_aggs:
# Mapping from output name to
# (input_column_name, function_name, pyarrow_output_name) # noqa: ERA001
simple_aggregations: dict[str, tuple[Any, str, str]] = {}
# (aggregation_args, pyarrow_output_name) # noqa: ERA001
simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
for expr in exprs:
if expr._depth == 0:
# e.g. agg(nw.len()) # noqa: ERA001
Expand All @@ -85,8 +85,7 @@ def agg_arrow(
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
simple_aggregations[expr._output_names[0]] = (
keys[0],
"count",
(keys[0], "count", pc.CountOptions(mode="all")),
f"{keys[0]}_count",
)
continue
Expand All @@ -102,19 +101,18 @@ def agg_arrow(
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (
root_name,
function_name,
(root_name, function_name),
f"{root_name}_{function_name}",
)

aggs: list[Any] = []
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
if named_agg[1] == "count":
aggs.append((named_agg[0], named_agg[1], pc.CountOptions(mode="all")))
else:
aggs.append((named_agg[0], named_agg[1]))
name_mapping[named_agg[2]] = output_name
for output_name, (
aggregation_args,
pyarrow_output_name,
) in simple_aggregations.items():
aggs.append(aggregation_args)
name_mapping[pyarrow_output_name] = output_name
result_simple = grouped.aggregate(aggs)
result_simple = result_simple.rename_columns(
[name_mapping.get(col, col) for col in result_simple.column_names]
Expand Down

0 comments on commit 187b7c9

Please sign in to comment.