From 7696b6e15d971a3dcc2230c4669d823754509f46 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:50:58 +0100 Subject: [PATCH] perf: improve `ArrowGroupBy.__iter__` performances (#1334) --- narwhals/_arrow/group_by.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 991a96a51..030fcd113 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -8,6 +8,7 @@ from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs +from narwhals.utils import generate_temporary_column_name from narwhals.utils import remove_prefix if TYPE_CHECKING: @@ -79,16 +80,36 @@ def agg( ) def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]: - key_values = self._df.select(*self._keys).unique(subset=self._keys, keep="first") - nw_namespace = self._df.__narwhals_namespace__() + import pyarrow as pa # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import + + col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) + null_token = "__null_token_value__" # noqa: S105 + + table = self._df._native_frame + key_values = pc.binary_join_element_wise( + *[pc.cast(table[key], pa.string()) for key in self._keys], + "", + null_handling="replace", + null_replacement=null_token, + ) + table = table.add_column(i=0, field_=col_token, column=key_values) + yield from ( ( - key_value, - self._df.filter( - *[nw_namespace.col(k) == v for k, v in zip(self._keys, key_value)] + next( + ( + t := self._df._from_native_frame( + table.filter(pc.equal(table[col_token], v)).drop([col_token]) + ) + ) + .select(*self._keys) + .head(1) + .iter_rows() ), + t, ) - for key_value in key_values.iter_rows() + for v in pc.unique(key_values) )