Skip to content

Commit

Permalink
fix(rust, python): rollback cse in groupby: python 0.18.15 (#10491)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 15, 2023
1 parent 11b7583 commit 0357177
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 42 deletions.
82 changes: 41 additions & 41 deletions crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,47 +701,47 @@ impl<'a> RewritingVisitor for CommonSubExprOptimizer<'a> {
lp_arena.replace(arena_idx, lp);
}
},
ALogicalPlan::Aggregate {
input,
keys,
aggs,
options,
maintain_order,
apply,
schema,
} => {
let input_schema = lp_arena.get(*input).schema(lp_arena);
if let Some(aggs) = self.find_cse(
aggs,
&mut expr_arena,
&mut id_array_offsets,
true,
input_schema.as_ref().as_ref(),
)? {
let keys = keys.clone();
let options = options.clone();
let schema = schema.clone();
let apply = apply.clone();
let maintain_order = *maintain_order;
let input = *input;

let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena)
.with_columns(aggs.cse_exprs().to_vec(), Default::default())
.build();
let input = lp_arena.add(lp);

let lp = ALogicalPlan::Aggregate {
input,
keys,
aggs: aggs.default_exprs().to_vec(),
options,
schema,
maintain_order,
apply,
};
lp_arena.replace(arena_idx, lp);
}
},
// ALogicalPlan::Aggregate {
// input,
// keys,
// aggs,
// options,
// maintain_order,
// apply,
// schema,
// } => {
// let input_schema = lp_arena.get(*input).schema(lp_arena);
// if let Some(aggs) = self.find_cse(
// aggs,
// &mut expr_arena,
// &mut id_array_offsets,
// true,
// input_schema.as_ref().as_ref(),
// )? {
// let keys = keys.clone();
// let options = options.clone();
// let schema = schema.clone();
// let apply = apply.clone();
// let maintain_order = *maintain_order;
// let input = *input;
//
// let lp = ALogicalPlanBuilder::new(input, &mut expr_arena, lp_arena)
// .with_columns(aggs.cse_exprs().to_vec(), Default::default())
// .build();
// let input = lp_arena.add(lp);
//
// let lp = ALogicalPlan::Aggregate {
// input,
// keys,
// aggs: aggs.default_exprs().to_vec(),
// options,
// schema,
// maintain_order,
// apply,
// };
// lp_arena.replace(arena_idx, lp);
// }
// },
_ => {},
}
PolarsResult::Ok(())
Expand Down
2 changes: 1 addition & 1 deletion py-polars/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-polars"
-version = "0.18.14"
+version = "0.18.15"
edition = "2021"

[lib]
Expand Down
41 changes: 41 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_windows_cse_excluded() -> None:
}


@pytest.mark.skip()
def test_cse_groupby_10215() -> None:
q = (
pl.DataFrame(
Expand Down Expand Up @@ -292,3 +293,43 @@ def test_cse_10452() -> None:
)
assert "__POLARS_CSE" in q.explain(comm_subexpr_elim=True)
assert q.collect(comm_subexpr_elim=True).to_dict(False) == {"b": [13, 14, 15]}


def test_cse_groupby_ternary_10490() -> None:
    """Check a group-by that mixes ternary and repeated-sum aggregations.

    The query is collected with ``comm_subexpr_elim=True`` and the result,
    sorted by the group key ``a``, is compared against hard-coded values.
    (The ``10490`` suffix is presumably the upstream issue number.)
    """
    frame = pl.DataFrame(
        {
            "a": [1, 1, 2, 2],
            "b": [1, 2, 3, 4],
            "c": [2, 3, 4, 5],
        }
    )

    # One when/then/otherwise aggregation per value column, aliased back to
    # the column's own name.
    ternary_aggs = [
        pl.when(pl.col(col).is_null().all()).then(None).otherwise(1).alias(col)
        for col in ["b", "c"]
    ]

    # Products of sums; several of these share identical sub-expressions
    # (e.g. ``pl.col("a").sum()`` appears in "x", "x2" and "x3").
    product_aggs = [
        (pl.col("a").sum() * pl.col("a").sum()).alias("x"),
        (pl.col("b").sum() * pl.col("b").sum()).alias("y"),
        (pl.col("a").sum() * pl.col("a").sum()).alias("x2"),
        ((pl.col("a") + 2).sum() * pl.col("a").sum()).alias("x3"),
        ((pl.col("a") + 2).sum() * pl.col("b").sum()).alias("x4"),
    ]

    result = (
        frame.lazy()
        .groupby("a")
        .agg(ternary_aggs + product_aggs)
        .collect(comm_subexpr_elim=True)
        .sort("a")
        .to_dict(False)
    )

    expected = {
        "a": [1, 2],
        "b": [1, 1],
        "c": [1, 1],
        "x": [4, 16],
        "y": [9, 49],
        "x2": [4, 16],
        "x3": [12, 32],
        "x4": [18, 56],
    }
    assert result == expected

0 comments on commit 0357177

Please sign in to comment.