diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs index e07def33df53..d88c47274991 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs @@ -526,44 +526,45 @@ impl<'a> RewritingVisitor for CommonSubExprOptimizer<'a> { node.replace(lp); } } - ALogicalPlan::Aggregate { - input, - keys, - aggs, - options, - maintain_order, - apply, - schema, - } => { - if let Some(aggs) = - self.find_cse(aggs, &mut expr_arena, &mut id_array_offsets, true)? - { - let keys = keys.clone(); - let options = options.clone(); - let schema = schema.clone(); - let apply = apply.clone(); - let maintain_order = *maintain_order; - let input = *input; - - let input = node.with_arena_mut(|lp_arena| { - let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena) - .with_columns(aggs.cse_exprs().to_vec()) - .build(); - lp_arena.add(lp) - }); - - let lp = ALogicalPlan::Aggregate { - input, - keys, - aggs: aggs.default_exprs().to_vec(), - options, - schema, - maintain_order, - apply, - }; - node.replace(lp); - } - } + // TODO! activate once fixed + // ALogicalPlan::Aggregate { + // input, + // keys, + // aggs, + // options, + // maintain_order, + // apply, + // schema, + // } => { + // if let Some(aggs) = + // self.find_cse(aggs, &mut expr_arena, &mut id_array_offsets, true)? + // { + // let keys = keys.clone(); + // let options = options.clone(); + // let schema = schema.clone(); + // let apply = apply.clone(); + // let maintain_order = *maintain_order; + // let input = *input; + // + // let input = node.with_arena_mut(|lp_arena| { + // let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena) + // .with_columns(aggs.cse_exprs().to_vec()) + // .build(); + // lp_arena.add(lp) + // }); + // + // let lp = ALogicalPlan::Aggregate { + // input, + // keys, + // aggs: aggs.default_exprs().to_vec(), + // options, + // schema, + // maintain_order, + // apply, + // }; + // node.replace(lp); + // } + // } _ => {} }; std::mem::swap(self.expr_arena, &mut expr_arena); diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index eb8363aa4059..6d0dfcd465e6 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -214,6 +214,7 @@ def test_cse_expr_selection_streaming(monkeypatch: Any, capfd: Any) -> None: assert "df -> hstack[cse] -> ordered_sink" in err +@pytest.mark.skip(reason="activate once fixed") def test_cse_expr_groupby() -> None: q = pl.LazyFrame( { @@ -269,3 +270,24 @@ def test_windows_cse_excluded() -> None: "c_diff": [None, 2, -2, 1, 1, 1, -4], "c_diff_by_a": [None, 2, -2, None, 1, 1, None], } + + +def test_cse_groupby_10215() -> None: + assert ( + pl.DataFrame( + { + "a": [1, 2, 3], + "b": [1, 1, 1], + } + ) + .lazy() + .groupby( + "a", + ) + .agg( + (pl.col("a").sum() * pl.col("a").sum()).alias("x"), + (pl.col("b").sum() * pl.col("b").sum()).alias("y"), + ) + .collect() + .sort("a") + ).to_dict(False) == {"a": [1, 2, 3], "x": [1, 4, 9], "y": [1, 1, 1]}