diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index a8812ea938f6..8e648a77f3d7 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -161,26 +161,6 @@ impl ExpressionConversionState { } } -pub fn create_physical_expr_streaming( - expr_ir: &ExprIR, - expr_arena: &Arena, - schema: Option<&SchemaRef>, - state: &mut ExpressionConversionState, -) -> PolarsResult> { - let phys_expr = - create_physical_expr_inner(expr_ir.node(), Context::Default, expr_arena, schema, state)?; - - if let Some(name) = expr_ir.get_alias() { - Ok(Arc::new(AliasExpr::new( - phys_expr, - name.clone(), - node_to_expr(expr_ir.node(), expr_arena), - ))) - } else { - Ok(phys_expr) - } -} - pub fn create_physical_expr( expr_ir: &ExprIR, ctxt: Context, diff --git a/crates/polars-stream/src/expression.rs b/crates/polars-stream/src/expression.rs new file mode 100644 index 000000000000..a6e41728d111 --- /dev/null +++ b/crates/polars-stream/src/expression.rs @@ -0,0 +1,41 @@ +use std::sync::Arc; + +use polars_core::frame::DataFrame; +use polars_core::prelude::Series; +use polars_error::PolarsResult; +use polars_expr::prelude::{ExecutionState, PhysicalExpr}; + +#[derive(Clone)] +pub(crate) struct StreamExpr { + inner: Arc, + // Whether the expression can be re-entering the engine (e.g. a function use the lazy api + // within that function) + reentrant: bool, +} + +impl StreamExpr { + pub(crate) fn new(phys_expr: Arc, reentrant: bool) -> Self { + Self { + inner: phys_expr, + reentrant, + } + } + + pub(crate) async fn evaluate( + &self, + df: &DataFrame, + state: &ExecutionState, + ) -> PolarsResult { + if self.reentrant { + let state = state.clone(); + let phys_expr = self.inner.clone(); + let df = df.clone(); + polars_io::pl_async::get_runtime() + .spawn_blocking(move || phys_expr.evaluate(&df, &state)) + .await + .unwrap() + } else { + self.inner.evaluate(df, state) + } + } +} diff --git a/crates/polars-stream/src/lib.rs b/crates/polars-stream/src/lib.rs index 00bc6e3cc11e..a31e20f22381 100644 --- a/crates/polars-stream/src/lib.rs +++ b/crates/polars-stream/src/lib.rs @@ -5,6 +5,7 @@ mod skeleton; pub use skeleton::run_query; mod execute; +pub(crate) mod expression; mod graph; mod morsel; mod nodes; diff --git a/crates/polars-stream/src/morsel.rs b/crates/polars-stream/src/morsel.rs index 2889bd5e400d..1e21c4802aa6 100644 --- a/crates/polars-stream/src/morsel.rs +++ b/crates/polars-stream/src/morsel.rs @@ -1,3 +1,4 @@ +use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, OnceLock}; @@ -130,6 +131,15 @@ impl Morsel { Ok(self) } + pub async fn async_try_map(mut self, f: M) -> Result + where + M: FnOnce(DataFrame) -> F, + F: Future>, + { + self.df = f(self.df).await?; + Ok(self) + } + pub fn set_consume_token(&mut self, token: WaitToken) { self.consume_token = Some(token); } diff --git a/crates/polars-stream/src/nodes/filter.rs b/crates/polars-stream/src/nodes/filter.rs index 0f263545c1e8..8a19b1a27986 100644 --- a/crates/polars-stream/src/nodes/filter.rs +++ b/crates/polars-stream/src/nodes/filter.rs @@ -1,16 +1,14 @@ -use std::sync::Arc; - use polars_error::polars_err; -use polars_expr::prelude::PhysicalExpr; use super::compute_node_prelude::*; +use crate::expression::StreamExpr; pub struct FilterNode { - predicate: Arc, + predicate: StreamExpr, } impl FilterNode { - pub fn new(predicate: Arc) -> Self { + pub fn new(predicate: StreamExpr) -> Self { Self { predicate } } } @@ -41,8 +39,9 @@ impl ComputeNode for FilterNode { let slf = &*self; join_handles.push(scope.spawn_task(TaskPriority::High, async move { while let Ok(morsel) = recv.recv().await { - let morsel = morsel.try_map(|df| { - let mask = slf.predicate.evaluate(&df, state)?; + + let morsel = morsel.async_try_map(|df| async move { + let mask = slf.predicate.evaluate(&df, state).await?; let mask = mask.bool().map_err(|_| { polars_err!( ComputeError: "filter predicate must be of type `Boolean`, got `{}`", mask.dtype() @@ -51,7 +50,7 @@ impl ComputeNode for FilterNode { // We already parallelize, call the sequential filter. df._filter_seq(mask) - })?; + }).await?; if morsel.df().is_empty() { continue; diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs index 91209b7d4e33..4dc4d859ba62 100644 --- a/crates/polars-stream/src/nodes/reduce.rs +++ b/crates/polars-stream/src/nodes/reduce.rs @@ -1,15 +1,15 @@ use std::sync::Arc; use polars_core::schema::Schema; -use polars_expr::prelude::PhysicalExpr; use polars_expr::reduce::Reduction; use super::compute_node_prelude::*; +use crate::expression::StreamExpr; use crate::morsel::SourceToken; enum ReduceState { Sink { - selectors: Vec>, + selectors: Vec, reductions: Vec>, }, Source(Option), @@ -23,7 +23,7 @@ pub struct ReduceNode { impl ReduceNode { pub fn new( - selectors: Vec>, + selectors: Vec, reductions: Vec>, output_schema: Arc, ) -> Self { @@ -37,7 +37,7 @@ impl ReduceNode { } fn spawn_sink<'env, 's>( - selectors: &'env [Arc], + selectors: &'env [StreamExpr], reductions: &'env mut [Box], scope: &'s TaskScope<'s, 'env>, recv: RecvPort<'_>, @@ -55,7 +55,7 @@ impl ReduceNode { while let Ok(morsel) = recv.recv().await { for (reduction, selector) in local_reductions.iter_mut().zip(selectors) { // TODO: don't convert to physical representation here. - let input = selector.evaluate(morsel.df(), state)?; + let input = selector.evaluate(morsel.df(), state).await?; reduction.update(&input.to_physical_repr())?; } } diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs index e021a3a0f197..568351ee4f47 100644 --- a/crates/polars-stream/src/nodes/select.rs +++ b/crates/polars-stream/src/nodes/select.rs @@ -1,27 +1,20 @@ use std::sync::Arc; use polars_core::schema::Schema; -use polars_expr::prelude::PhysicalExpr; use super::compute_node_prelude::*; +use crate::expression::StreamExpr; pub struct SelectNode { - selectors: Vec>, - selector_reentrant: Vec, + selectors: Vec, schema: Arc, extend_original: bool, } impl SelectNode { - pub fn new( - selectors: Vec>, - selector_reentrant: Vec, - schema: Arc, - extend_original: bool, - ) -> Self { + pub fn new(selectors: Vec, schema: Arc, extend_original: bool) -> Self { Self { selectors, - selector_reentrant, schema, extend_original, } @@ -56,20 +49,8 @@ impl ComputeNode for SelectNode { while let Ok(morsel) = recv.recv().await { let (df, seq, source_token, consume_token) = morsel.into_inner(); let mut selected = Vec::new(); - for (selector, reentrant) in slf.selectors.iter().zip(&slf.selector_reentrant) { - // We need spawn_blocking because evaluate could contain Python UDFs which - // recursively call the executor again. - let s = if *reentrant { - let df = df.clone(); - let selector = selector.clone(); - let state = state.clone(); - polars_io::pl_async::get_runtime() - .spawn_blocking(move || selector.evaluate(&df, &state)) - .await - .unwrap()? - } else { - selector.evaluate(&df, state)? - }; + for selector in slf.selectors.iter() { + let s = selector.evaluate(&df, state).await?; selected.push(s); } diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index 0096c9518a54..08aa31fb05d4 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use polars_error::PolarsResult; use polars_expr::reduce::can_convert_into_reduction; use polars_plan::plans::{AExpr, Context, IR}; -use polars_plan::prelude::{ArenaExprIter, FunctionFlags, SinkType}; +use polars_plan::prelude::SinkType; use polars_utils::arena::{Arena, Node}; use slotmap::SlotMap; @@ -13,15 +13,6 @@ fn is_streamable(node: Node, arena: &Arena) -> bool { polars_plan::plans::is_streamable(node, arena, Context::Default) } -fn has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { - arena.iter(node).any(|(_n, ae)| match ae { - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { - options.flags.contains(FunctionFlags::OPTIONAL_RE_ENTRANT) - }, - _ => false, - }) -} - #[recursive::recursive] pub fn lower_ir( node: Node, @@ -50,17 +41,12 @@ pub fn lower_ir( schema, .. } if expr.iter().all(|e| is_streamable(e.node(), expr_arena)) => { - let selector_reentrant = expr - .iter() - .map(|e| has_potential_recurring_entrance(e.node(), expr_arena)) - .collect(); let selectors = expr.clone(); let output_schema = schema.clone(); let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; Ok(phys_sm.insert(PhysNode::Select { input, selectors, - selector_reentrant, output_schema, extend_original: false, })) @@ -96,17 +82,12 @@ pub fn lower_ir( schema, .. } if exprs.iter().all(|e| is_streamable(e.node(), expr_arena)) => { - let selector_reentrant = exprs - .iter() - .map(|e| has_potential_recurring_entrance(e.node(), expr_arena)) - .collect(); let selectors = exprs.clone(); let output_schema = schema.clone(); let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; Ok(phys_sm.insert(PhysNode::Select { input, selectors, - selector_reentrant, output_schema, extend_original: true, })) diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 27147ed194f2..fdecbf0e48e7 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -30,7 +30,6 @@ pub enum PhysNode { Select { input: PhysNodeKey, selectors: Vec, - selector_reentrant: Vec, extend_original: bool, output_schema: Arc, }, diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index 3a604b05e7fd..f7cc5e321970 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -7,16 +7,42 @@ use polars_expr::reduce::into_reduction; use polars_expr::state::ExecutionState; use polars_mem_engine::create_physical_plan; use polars_plan::plans::expr_ir::ExprIR; -use polars_plan::plans::{AExpr, Context, IR}; -use polars_utils::arena::Arena; +use polars_plan::plans::{AExpr, ArenaExprIter, Context, IR}; +use polars_plan::prelude::FunctionFlags; +use polars_utils::arena::{Arena, Node}; use recursive::recursive; use slotmap::{SecondaryMap, SlotMap}; use super::{PhysNode, PhysNodeKey}; +use crate::expression::StreamExpr; use crate::graph::{Graph, GraphNodeKey}; use crate::nodes; use crate::utils::late_materialized_df::LateMaterializedDataFrame; +fn has_potential_recurring_entrance(node: Node, arena: &Arena) -> bool { + arena.iter(node).any(|(_n, ae)| match ae { + AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { + options.flags.contains(FunctionFlags::OPTIONAL_RE_ENTRANT) + }, + _ => false, + }) +} + +fn create_stream_expr( + expr_ir: &ExprIR, + ctx: &mut GraphConversionContext<'_>, +) -> PolarsResult { + let reentrant = has_potential_recurring_entrance(expr_ir.node(), ctx.expr_arena); + let phys = create_physical_expr( + expr_ir, + Context::Default, + ctx.expr_arena, + None, + &mut ctx.expr_conversion_state, + )?; + Ok(StreamExpr::new(phys, reentrant)) +} + struct GraphConversionContext<'a> { phys_sm: &'a SlotMap, expr_arena: &'a Arena, @@ -75,13 +101,7 @@ fn to_graph_rec<'a>( }, Filter { predicate, input } => { - let phys_predicate_expr = create_physical_expr( - predicate, - Context::Default, - ctx.expr_arena, - None, - &mut ctx.expr_conversion_state, - )?; + let phys_predicate_expr = create_stream_expr(predicate, ctx)?; let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( nodes::filter::FilterNode::new(phys_predicate_expr), @@ -91,28 +111,18 @@ fn to_graph_rec<'a>( Select { selectors, - selector_reentrant, input, output_schema, extend_original, } => { let phys_selectors = selectors .iter() - .map(|selector| { - create_physical_expr( - selector, - Context::Default, - ctx.expr_arena, - None, - &mut ctx.expr_conversion_state, - ) - }) + .map(|selector| create_stream_expr(selector, ctx)) .collect::>()?; let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( nodes::select::SelectNode::new( phys_selectors, - selector_reentrant.clone(), output_schema.clone(), *extend_original, ), @@ -136,13 +146,8 @@ fn to_graph_rec<'a>( .expect("invariant"); reductions.push(red); - let input_phys = create_physical_expr( - &ExprIR::from_node(input_node, ctx.expr_arena), - Context::Default, - ctx.expr_arena, - None, - &mut ctx.expr_conversion_state, - )?; + let input_phys = + create_stream_expr(&ExprIR::from_node(input_node, ctx.expr_arena), ctx)?; inputs.push(input_phys) }