Skip to content

Commit

Permalink
fix: binary agg should group aware if literal not a scalar (#12043)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Oct 26, 2023
1 parent 929952a commit e7d3f04
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
31 changes: 27 additions & 4 deletions crates/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ impl BinaryExpr {
Ok(ac_l)
}

/// Evaluates the binary operator when both operands are in the `Literal` aggregation state.
///
/// The operator is applied once to the two (rechunked) literal series; the single result is
/// then stored as a list column with one entry per group, named after the left-hand series.
///
/// # Errors
///
/// Returns a `ComputeError` if the two contexts do not report the same number of groups, and
/// propagates any error raised by `apply_operator` or by `with_series`.
fn apply_all_literal<'a>(
    &self,
    mut ac_l: AggregationContext<'a>,
    mut ac_r: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
    // The output column takes its name from the left-hand series.
    let out_name = ac_l.series().name().to_string();

    // `groups()` is called on both contexts before their `groups` fields are read.
    ac_l.groups();
    ac_r.groups();
    let n_groups = ac_l.groups.len();
    polars_ensure!(n_groups == ac_r.groups.len(), ComputeError: "lhs and rhs should have same group length");

    let lhs = ac_l.series().rechunk();
    let rhs = ac_r.series().rechunk();
    let combined = apply_operator(&lhs, &rhs, self.op)?;

    // Same operator result for every group: one list entry per group.
    let per_group = ListChunked::full(&out_name, &combined, n_groups).into_series();

    ac_l.with_update_groups(UpdateGroups::WithSeriesLen);
    ac_l.with_series(per_group, true, Some(&self.expr))?;
    Ok(ac_l)
}

fn apply_group_aware<'a>(
&self,
mut ac_l: AggregationContext<'a>,
Expand Down Expand Up @@ -202,10 +220,15 @@ impl PhysicalExpr for BinaryExpr {
let ac_r = result_b?;

match (ac_l.agg_state(), ac_r.agg_state()) {
(
AggState::Literal(_) | AggState::NotAggregated(_),
AggState::Literal(_) | AggState::NotAggregated(_),
) => self.apply_elementwise(ac_l, ac_r, false),
(AggState::Literal(s), AggState::NotAggregated(_))
| (AggState::NotAggregated(_), AggState::Literal(s)) => match s.len() {
1 => self.apply_elementwise(ac_l, ac_r, false),
_ => self.apply_group_aware(ac_l, ac_r),
},
(AggState::Literal(_), AggState::Literal(_)) => self.apply_all_literal(ac_l, ac_r),
(AggState::NotAggregated(_), AggState::NotAggregated(_)) => {
self.apply_elementwise(ac_l, ac_r, false)
},
(
AggState::AggregatedScalar(_) | AggState::Literal(_),
AggState::AggregatedScalar(_) | AggState::Literal(_),
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,29 @@ def test_group_by_when_then_with_binary_and_agg_in_pred_6202() -> None:
}


def test_group_by_binary_agg_with_literal() -> None:
    """Check binary expressions with scalar and non-scalar literals inside a group-by agg."""
    frame = pl.DataFrame({"id": ["a", "a", "b", "b"], "value": [1, 2, 3, 4]})

    # Column + multi-element literal: the literal is added within each group.
    col_plus_series = frame.group_by("id", maintain_order=True).agg(
        pl.col("value") + pl.Series([1, 3])
    )
    expected = {"id": ["a", "b"], "value": [[2, 5], [4, 7]]}
    assert col_plus_series.to_dict(False) == expected

    # Column + scalar literal: the scalar is added to every element.
    col_plus_scalar = frame.group_by("id", maintain_order=True).agg(
        pl.col("value") + pl.lit(1)
    )
    expected = {"id": ["a", "b"], "value": [[2, 3], [4, 5]]}
    assert col_plus_scalar.to_dict(False) == expected

    # Scalar literal + scalar literal: one scalar per group.
    scalar_plus_scalar = frame.group_by("id", maintain_order=True).agg(
        pl.lit(1) + pl.lit(2)
    )
    expected = {"id": ["a", "b"], "literal": [3, 3]}
    assert scalar_plus_scalar.to_dict(False) == expected

    # Scalar literal + multi-element literal: the same list for every group.
    scalar_plus_series = frame.group_by("id", maintain_order=True).agg(
        pl.lit(1) + pl.Series([2, 3])
    )
    expected = {"id": ["a", "b"], "literal": [[3, 4], [3, 4]]}
    assert scalar_plus_series.to_dict(False) == expected

    # Two multi-element literals, named via keyword: the same list for every group.
    series_plus_series = frame.group_by("id", maintain_order=True).agg(
        value=pl.lit(pl.Series([1, 2])) + pl.lit(pl.Series([3, 4]))
    )
    expected = {"id": ["a", "b"], "value": [[4, 6], [4, 6]]}
    assert series_plus_series.to_dict(False) == expected


@pytest.mark.slow()
@pytest.mark.parametrize("dtype", [pl.Int32, pl.UInt32])
def test_overflow_mean_partitioned_group_by_5194(dtype: pl.PolarsDataType) -> None:
Expand Down

0 comments on commit e7d3f04

Please sign in to comment.