Skip to content

Commit

Permalink
perf: improve ArrowGroupBy.__iter__ performances (#1334)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Nov 8, 2024
1 parent 1f3f460 commit 7696b6e
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,16 +80,36 @@ def agg(
)

def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
key_values = self._df.select(*self._keys).unique(subset=self._keys, keep="first")
nw_namespace = self._df.__narwhals_namespace__()
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns)
null_token = "__null_token_value__" # noqa: S105

table = self._df._native_frame
key_values = pc.binary_join_element_wise(
*[pc.cast(table[key], pa.string()) for key in self._keys],
"",
null_handling="replace",
null_replacement=null_token,
)
table = table.add_column(i=0, field_=col_token, column=key_values)

yield from (
(
key_value,
self._df.filter(
*[nw_namespace.col(k) == v for k, v in zip(self._keys, key_value)]
next(
(
t := self._df._from_native_frame(
table.filter(pc.equal(table[col_token], v)).drop([col_token])
)
)
.select(*self._keys)
.head(1)
.iter_rows()
),
t,
)
for key_value in key_values.iter_rows()
for v in pc.unique(key_values)
)


Expand Down

0 comments on commit 7696b6e

Please sign in to comment.