Skip to content

Commit

Permalink
fix(rust, python): fix cse windows (#10197)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 31, 2023
1 parent 47b91ab commit f8c4c4e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
43 changes: 30 additions & 13 deletions crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,31 +132,42 @@ impl ExprIdentifierVisitor<'_> {
unreachable!()
}

fn accept_node(&self, ae: &AExpr) -> bool {
/// return `None` -> node is accepted
/// return `Some(_)` node is not accepted and apply the given recursion operation
fn accept_node(&self, ae: &AExpr) -> Option<VisitRecursion> {
match ae {
// window expressions should `evaluate_on_groups`, not `evaluate`
// so we shouldn't cache the children as they are evaluated incorrectly
AExpr::Window { .. } => Some(VisitRecursion::Skip),
// skip window functions for now until we properly implemented the physical side
AExpr::Column(_)
| AExpr::Count
| AExpr::Literal(_)
| AExpr::Window { .. }
| AExpr::Alias(_, _) => false,
AExpr::Column(_) | AExpr::Count | AExpr::Literal(_) | AExpr::Alias(_, _) => {
Some(VisitRecursion::Continue)
}
#[cfg(feature = "random")]
AExpr::Function {
function: FunctionExpr::Random { .. },
..
} => false,
} => Some(VisitRecursion::Continue),
_ => {
// during aggregation we only store elementwise operation in the state
// other operations we cannot add to the state as they have the output size of the
// groups, not the original dataframe
if self.is_groupby {
match ae {
AExpr::Agg(_) | AExpr::AnonymousFunction { .. } => false,
AExpr::Function { options, .. } => !options.is_groups_sensitive(),
_ => true,
AExpr::Agg(_) | AExpr::AnonymousFunction { .. } => {
Some(VisitRecursion::Continue)
}
AExpr::Function { options, .. } => {
if options.is_groups_sensitive() {
Some(VisitRecursion::Continue)
} else {
None
}
}
_ => None,
}
} else {
true
None
}
}
}
Expand Down Expand Up @@ -186,11 +197,11 @@ impl Visitor for ExprIdentifierVisitor<'_> {

// if we don't store this node
// we only push the visit_stack, so the parents know the trail
if !self.accept_node(ae) {
if let Some(recurse) = self.accept_node(ae) {
self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx;
self.visit_stack
.push(VisitRecord::SubExprId(Rc::from(format!("{:E}", ae))));
return Ok(VisitRecursion::Continue);
return Ok(recurse);
}

// create the id of this node
Expand Down Expand Up @@ -290,6 +301,12 @@ impl RewritingVisitor for CommonSubExprRewriter<'_> {
return Ok(RewriteRecursion::Stop);
}

// check if we can accept node
// we don't traverse those children
if matches!(ae_node.to_aexpr(), AExpr::Window { .. }) {
return Ok(RewriteRecursion::Stop);
}

let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1;

// placeholder not overwritten, so we can skip this sub-expression
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,25 @@ def test_cse_expr_groupby() -> None:
for streaming in [True, False]:
out = q.collect(comm_subexpr_elim=True, streaming=streaming)
assert_frame_equal(out, expected)


def test_windows_cse_excluded() -> None:
lf = pl.LazyFrame(
data=[
("a", "aaa", 1),
("a", "bbb", 3),
("a", "ccc", 1),
("c", "xxx", 2),
("c", "yyy", 3),
("c", "zzz", 4),
("b", "qqq", 0),
],
schema=["a", "b", "c"],
)
assert lf.select(
c_diff=pl.col("c").diff(1),
c_diff_by_a=pl.col("c").diff(1).over("a"),
).collect(comm_subexpr_elim=True).to_dict(False) == {
"c_diff": [None, 2, -2, 1, 1, 1, -4],
"c_diff_by_a": [None, 2, -2, None, 1, 1, None],
}

0 comments on commit f8c4c4e

Please sign in to comment.