diff --git a/crates/polars-lazy/src/physical_plan/expressions/binary.rs b/crates/polars-lazy/src/physical_plan/expressions/binary.rs
index ef20e7abdc33..4a2252c260d9 100644
--- a/crates/polars-lazy/src/physical_plan/expressions/binary.rs
+++ b/crates/polars-lazy/src/physical_plan/expressions/binary.rs
@@ -102,6 +102,24 @@ impl BinaryExpr {
         Ok(ac_l)
     }
 
+    fn apply_all_literal<'a>(
+        &self,
+        mut ac_l: AggregationContext<'a>,
+        mut ac_r: AggregationContext<'a>,
+    ) -> PolarsResult<AggregationContext<'a>> {
+        let name = ac_l.series().name().to_string();
+        ac_l.groups();
+        ac_r.groups();
+        polars_ensure!(ac_l.groups.len() == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length");
+        let left_s = ac_l.series().rechunk();
+        let right_s = ac_r.series().rechunk();
+        let res_s = apply_operator(&left_s, &right_s, self.op)?;
+        let ca = ListChunked::full(&name, &res_s, ac_l.groups.len());
+        ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
+        ac_l.with_series(ca.into_series(), true, Some(&self.expr))?;
+        Ok(ac_l)
+    }
+
     fn apply_group_aware<'a>(
         &self,
         mut ac_l: AggregationContext<'a>,
@@ -202,10 +220,15 @@ impl PhysicalExpr for BinaryExpr {
         let ac_r = result_b?;
 
         match (ac_l.agg_state(), ac_r.agg_state()) {
-            (
-                AggState::Literal(_) | AggState::NotAggregated(_),
-                AggState::Literal(_) | AggState::NotAggregated(_),
-            ) => self.apply_elementwise(ac_l, ac_r, false),
+            (AggState::Literal(s), AggState::NotAggregated(_))
+            | (AggState::NotAggregated(_), AggState::Literal(s)) => match s.len() {
+                1 => self.apply_elementwise(ac_l, ac_r, false),
+                _ => self.apply_group_aware(ac_l, ac_r),
+            },
+            (AggState::Literal(_), AggState::Literal(_)) => self.apply_all_literal(ac_l, ac_r),
+            (AggState::NotAggregated(_), AggState::NotAggregated(_)) => {
+                self.apply_elementwise(ac_l, ac_r, false)
+            },
             (
                 AggState::AggregatedScalar(_) | AggState::Literal(_),
                 AggState::AggregatedScalar(_) | AggState::Literal(_),
diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py
index 448bddac5ef3..1b1e98e3f3f6 100644
--- a/py-polars/tests/unit/operations/test_group_by.py
+++ b/py-polars/tests/unit/operations/test_group_by.py
@@ -443,6 +443,29 @@ def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:
     }
 
 
+def test_group_by_binary_agg_with_literal() -> None:
+    df = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})
+
+    out = df.group_by("id", maintain_order=True).agg(
+        pl.col("value") + pl.Series([1, 3])
+    )
+    assert out.to_dict(False) == {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}
+
+    out = df.group_by("id", maintain_order=True).agg(pl.col("value") + pl.lit(1))
+    assert out.to_dict(False) == {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}
+
+    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.lit(2))
+    assert out.to_dict(False) == {"id": ["a", "b"], "literal": [3, 3]}
+
+    out = df.group_by("id", maintain_order=True).agg(pl.lit(1) + pl.Series([2, 3]))
+    assert out.to_dict(False) == {"id": ["a", "b"], "literal": [[3, 4], [3, 4]]}
+
+    out = df.group_by("id", maintain_order=True).agg(
+        value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))
+    )
+    assert out.to_dict(False) == {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}
+
+
 @pytest.mark.slow()
 @pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])
 def test_overflow_mean_partitioned_group_by_5194(dtype: pl.PolarsDataType) -> None: