From 01a65f0a17892d5c7bc7d124b2c48925c5429c68 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 27 Jun 2024 14:11:12 +0200 Subject: [PATCH] fix: Fix literal slice in group by (#17242) --- crates/polars-expr/src/expressions/mod.rs | 8 +++++--- crates/polars-expr/src/expressions/slice.rs | 14 +++++++++++--- py-polars/tests/unit/operations/test_group_by.py | 9 +++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 5b4b5407a614..87a26a685331 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -378,9 +378,11 @@ impl<'a> AggregationContext<'a> { /// Update the group tuples pub(crate) fn with_groups(&mut self, groups: GroupsProxy) -> &mut Self { - // In case of new groups, a series always needs to be flattened - self.with_series(self.flat_naive().into_owned(), false, None) - .unwrap(); + if let AggState::AggregatedList(_) = self.agg_state() { + // In case of new groups, a series always needs to be flattened + self.with_series(self.flat_naive().into_owned(), false, None) + .unwrap(); + } self.groups = Cow::Owned(groups); // make sure that previous setting is not used self.update_groups = UpdateGroups::No; diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs index 3b64a098073e..d2bc9137a7d3 100644 --- a/crates/polars-expr/src/expressions/slice.rs +++ b/crates/polars-expr/src/expressions/slice.rs @@ -1,5 +1,5 @@ use polars_core::prelude::*; -use polars_core::utils::{slice_offsets, CustomIterTools}; +use polars_core::utils::{slice_offsets, Container, CustomIterTools}; use polars_core::POOL; use rayon::prelude::*; use AnyValue::Null; @@ -106,13 +106,18 @@ impl PhysicalExpr for SliceExpr { let mut ac_length = results.pop().unwrap(); let mut ac_offset = results.pop().unwrap(); - let groups = ac.groups(); - use AggState::*; let groups = match (&ac_offset.state, &ac_length.state) { (Literal(offset), Literal(length)) => { let (offset, length) = extract_args(offset, length, &self.expr)?; + if let Literal(s) = ac.agg_state() { + let s1 = s.slice(offset, length); + ac.with_literal(s1); + return Ok(ac); + } + let groups = ac.groups(); + match groups.as_ref() { GroupsProxy::Idx(groups) => { let groups = groups @@ -134,6 +139,7 @@ impl PhysicalExpr for SliceExpr { } }, (Literal(offset), _) => { + let groups = ac.groups(); let offset = extract_offset(offset, &self.expr)?; let length = ac_length.aggregated(); check_argument(&length, groups, "length", &self.expr)?; @@ -168,6 +174,7 @@ impl PhysicalExpr for SliceExpr { } }, (_, Literal(length)) => { + let groups = ac.groups(); let length = extract_length(length, &self.expr)?; let offset = ac_offset.aggregated(); check_argument(&offset, groups, "offset", &self.expr)?; @@ -202,6 +209,7 @@ impl PhysicalExpr for SliceExpr { } }, _ => { + let groups = ac.groups(); let length = ac_length.aggregated(); let offset = ac_offset.aggregated(); check_argument(&length, groups, "length", &self.expr)?; diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 5b5494a17844..1c7ccd4d2396 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -1130,3 +1130,12 @@ def sub_col_min(column: str, min_column: str) -> pl.Expr: pl.List(pl.Float64), pl.List(pl.Float64), ] + + +def test_grouped_slice_literals() -> None: + assert pl.DataFrame({"idx": [1, 2, 3]}).group_by(True).agg( + x=pl.lit([1, 2]).slice( + -1, 1 + ), # slices a list of 1 element, so remains the same element + x2=pl.lit(pl.Series([1, 2])).slice(-1, 1), + ).to_dict(as_series=False) == {"literal": [True], "x": [[1, 2]], "x2": [2]}