From f8c4c4eea8203db581dece729550d3ab0aedbbba Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 31 Jul 2023 13:28:07 +0200 Subject: [PATCH] fix(rust, python): fix cse windows (#10197) --- .../src/logical_plan/optimizer/cse_expr.rs | 43 +++++++++++++------ py-polars/tests/unit/test_cse.py | 22 ++++++++++ 2 files changed, 52 insertions(+), 13 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs index 0fb9d86c09c5..e07def33df53 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/cse_expr.rs @@ -132,31 +132,42 @@ impl ExprIdentifierVisitor<'_> { unreachable!() } - fn accept_node(&self, ae: &AExpr) -> bool { + /// return `None` -> node is accepted + /// return `Some(_)` node is not accepted and apply the given recursion operation + fn accept_node(&self, ae: &AExpr) -> Option { match ae { + // window expressions should `evaluate_on_groups`, not `evaluate` + // so we shouldn't cache the children as they are evaluated incorrectly + AExpr::Window { .. } => Some(VisitRecursion::Skip), // skip window functions for now until we properly implemented the physical side - AExpr::Column(_) - | AExpr::Count - | AExpr::Literal(_) - | AExpr::Window { .. } - | AExpr::Alias(_, _) => false, + AExpr::Column(_) | AExpr::Count | AExpr::Literal(_) | AExpr::Alias(_, _) => { + Some(VisitRecursion::Continue) + } #[cfg(feature = "random")] AExpr::Function { function: FunctionExpr::Random { .. }, .. - } => false, + } => Some(VisitRecursion::Continue), _ => { // during aggregation we only store elementwise operation in the state // other operations we cannot add to the state as they have the output size of the // groups, not the original dataframe if self.is_groupby { match ae { - AExpr::Agg(_) | AExpr::AnonymousFunction { .. } => false, - AExpr::Function { options, .. } => !options.is_groups_sensitive(), - _ => true, + AExpr::Agg(_) | AExpr::AnonymousFunction { .. } => { + Some(VisitRecursion::Continue) + } + AExpr::Function { options, .. } => { + if options.is_groups_sensitive() { + Some(VisitRecursion::Continue) + } else { + None + } + } + _ => None, } } else { - true + None } } } @@ -186,11 +197,11 @@ impl Visitor for ExprIdentifierVisitor<'_> { // if we don't store this node // we only push the visit_stack, so the parents know the trail - if !self.accept_node(ae) { + if let Some(recurse) = self.accept_node(ae) { self.identifier_array[pre_visit_idx + self.id_array_offset].0 = self.post_visit_idx; self.visit_stack .push(VisitRecord::SubExprId(Rc::from(format!("{:E}", ae)))); - return Ok(VisitRecursion::Continue); + return Ok(recurse); } // create the id of this node @@ -290,6 +301,12 @@ impl RewritingVisitor for CommonSubExprRewriter<'_> { return Ok(RewriteRecursion::Stop); } + // check if we can accept node + // we don't traverse those children + if matches!(ae_node.to_aexpr(), AExpr::Window { .. }) { + return Ok(RewriteRecursion::Stop); + } + let id = &self.identifier_array[self.visited_idx + self.id_array_offset].1; // placeholder not overwritten, so we can skip this sub-expression diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index c10b22b32ebe..eb8363aa4059 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -247,3 +247,25 @@ def test_cse_expr_groupby() -> None: for streaming in [True, False]: out = q.collect(comm_subexpr_elim=True, streaming=streaming) assert_frame_equal(out, expected) + + +def test_windows_cse_excluded() -> None: + lf = pl.LazyFrame( + data=[ + ("a", "aaa", 1), + ("a", "bbb", 3), + ("a", "ccc", 1), + ("c", "xxx", 2), + ("c", "yyy", 3), + ("c", "zzz", 4), + ("b", "qqq", 0), + ], + schema=["a", "b", "c"], + ) + assert lf.select( + c_diff=pl.col("c").diff(1), + c_diff_by_a=pl.col("c").diff(1).over("a"), + ).collect(comm_subexpr_elim=True).to_dict(False) == { + "c_diff": [None, 2, -2, 1, 1, 1, -4], + "c_diff_by_a": [None, 2, -2, None, 1, 1, None], + }