Skip to content

Commit

Permalink
refactor(rust): Deal with re-entrant expressions locally (#17885)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 26, 2024
1 parent 55c15d7 commit ca8e445
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 105 deletions.
20 changes: 0 additions & 20 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,26 +161,6 @@ impl ExpressionConversionState {
}
}

/// Builds the physical expression for `expr_ir` for use by the streaming engine.
///
/// The expression node is converted with `Context::Default`. If `expr_ir` carries an
/// alias, the converted expression is wrapped in an `AliasExpr` under that name;
/// otherwise the converted expression is returned as-is.
///
/// # Errors
///
/// Propagates any error returned by `create_physical_expr_inner`.
pub fn create_physical_expr_streaming(
    expr_ir: &ExprIR,
    expr_arena: &Arena<AExpr>,
    schema: Option<&SchemaRef>,
    state: &mut ExpressionConversionState,
) -> PolarsResult<Arc<dyn PhysicalExpr>> {
    let converted =
        create_physical_expr_inner(expr_ir.node(), Context::Default, expr_arena, schema, state)?;

    match expr_ir.get_alias() {
        // An alias is present: wrap the converted expression so the output is renamed.
        Some(alias) => Ok(Arc::new(AliasExpr::new(
            converted,
            alias.clone(),
            node_to_expr(expr_ir.node(), expr_arena),
        ))),
        None => Ok(converted),
    }
}

pub fn create_physical_expr(
expr_ir: &ExprIR,
ctxt: Context,
Expand Down
41 changes: 41 additions & 0 deletions crates/polars-stream/src/expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::sync::Arc;

use polars_core::frame::DataFrame;
use polars_core::prelude::Series;
use polars_error::PolarsResult;
use polars_expr::prelude::{ExecutionState, PhysicalExpr};

/// A physical expression paired with a flag saying whether it may re-enter the engine.
///
/// The flag is consulted by `StreamExpr::evaluate` to decide how the wrapped
/// expression is run.
#[derive(Clone)]
pub(crate) struct StreamExpr {
// The wrapped physical expression that performs the actual evaluation.
inner: Arc<dyn PhysicalExpr>,
// Whether the expression may re-enter the engine (e.g. a function that uses the lazy API
// within that function).
reentrant: bool,
}

impl StreamExpr {
    /// Wraps `phys_expr`, recording whether it may re-enter the engine.
    pub(crate) fn new(phys_expr: Arc<dyn PhysicalExpr>, reentrant: bool) -> Self {
        let inner = phys_expr;
        Self { inner, reentrant }
    }

    /// Evaluates the wrapped expression against `df`.
    ///
    /// Non-reentrant expressions are evaluated inline. Reentrant expressions are
    /// evaluated through `spawn_blocking` on the runtime returned by
    /// `polars_io::pl_async::get_runtime()`, using owned clones of the frame, the
    /// state and the expression handle.
    ///
    /// # Panics
    ///
    /// Panics if awaiting the blocking task yields an error, because the join result
    /// is unwrapped.
    pub(crate) async fn evaluate(
        &self,
        df: &DataFrame,
        state: &ExecutionState,
    ) -> PolarsResult<Series> {
        // Fast path: nothing can call back into the engine, so evaluate in place.
        if !self.reentrant {
            return self.inner.evaluate(df, state);
        }

        // The blocking closure is `move`, so it needs owned copies of everything it uses.
        let owned_state = state.clone();
        let expr = self.inner.clone();
        let owned_df = df.clone();

        let handle = polars_io::pl_async::get_runtime()
            .spawn_blocking(move || expr.evaluate(&owned_df, &owned_state));
        handle.await.unwrap()
    }
}
1 change: 1 addition & 0 deletions crates/polars-stream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod skeleton;
pub use skeleton::run_query;

mod execute;
pub(crate) mod expression;
mod graph;
mod morsel;
mod nodes;
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-stream/src/morsel.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};

Expand Down Expand Up @@ -130,6 +131,15 @@ impl Morsel {
Ok(self)
}

/// Asynchronously replaces this morsel's `DataFrame` with the result of `f`.
///
/// `f` receives the current frame by value and returns a future. When that future
/// resolves to `Ok`, the new frame is stored and the morsel is returned.
///
/// # Errors
///
/// Returns the error produced by the future returned from `f`.
pub async fn async_try_map<E, M, F>(mut self, f: M) -> Result<Self, E>
where
    M: FnOnce(DataFrame) -> F,
    F: Future<Output = Result<DataFrame, E>>,
{
    let outcome = f(self.df).await;
    self.df = outcome?;
    Ok(self)
}

pub fn set_consume_token(&mut self, token: WaitToken) {
self.consume_token = Some(token);
}
Expand Down
15 changes: 7 additions & 8 deletions crates/polars-stream/src/nodes/filter.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use std::sync::Arc;

use polars_error::polars_err;
use polars_expr::prelude::PhysicalExpr;

use super::compute_node_prelude::*;
use crate::expression::StreamExpr;

pub struct FilterNode {
predicate: Arc<dyn PhysicalExpr>,
predicate: StreamExpr,
}

impl FilterNode {
pub fn new(predicate: Arc<dyn PhysicalExpr>) -> Self {
pub fn new(predicate: StreamExpr) -> Self {
Self { predicate }
}
}
Expand Down Expand Up @@ -41,8 +39,9 @@ impl ComputeNode for FilterNode {
let slf = &*self;
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Ok(morsel) = recv.recv().await {
let morsel = morsel.try_map(|df| {
let mask = slf.predicate.evaluate(&df, state)?;

let morsel = morsel.async_try_map(|df| async move {
let mask = slf.predicate.evaluate(&df, state).await?;
let mask = mask.bool().map_err(|_| {
polars_err!(
ComputeError: "filter predicate must be of type `Boolean`, got `{}`", mask.dtype()
Expand All @@ -51,7 +50,7 @@ impl ComputeNode for FilterNode {

// We already parallelize, call the sequential filter.
df._filter_seq(mask)
})?;
}).await?;

if morsel.df().is_empty() {
continue;
Expand Down
10 changes: 5 additions & 5 deletions crates/polars-stream/src/nodes/reduce.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::sync::Arc;

use polars_core::schema::Schema;
use polars_expr::prelude::PhysicalExpr;
use polars_expr::reduce::Reduction;

use super::compute_node_prelude::*;
use crate::expression::StreamExpr;
use crate::morsel::SourceToken;

enum ReduceState {
Sink {
selectors: Vec<Arc<dyn PhysicalExpr>>,
selectors: Vec<StreamExpr>,
reductions: Vec<Box<dyn Reduction>>,
},
Source(Option<DataFrame>),
Expand All @@ -23,7 +23,7 @@ pub struct ReduceNode {

impl ReduceNode {
pub fn new(
selectors: Vec<Arc<dyn PhysicalExpr>>,
selectors: Vec<StreamExpr>,
reductions: Vec<Box<dyn Reduction>>,
output_schema: Arc<Schema>,
) -> Self {
Expand All @@ -37,7 +37,7 @@ impl ReduceNode {
}

fn spawn_sink<'env, 's>(
selectors: &'env [Arc<dyn PhysicalExpr>],
selectors: &'env [StreamExpr],
reductions: &'env mut [Box<dyn Reduction>],
scope: &'s TaskScope<'s, 'env>,
recv: RecvPort<'_>,
Expand All @@ -55,7 +55,7 @@ impl ReduceNode {
while let Ok(morsel) = recv.recv().await {
for (reduction, selector) in local_reductions.iter_mut().zip(selectors) {
// TODO: don't convert to physical representation here.
let input = selector.evaluate(morsel.df(), state)?;
let input = selector.evaluate(morsel.df(), state).await?;
reduction.update(&input.to_physical_repr())?;
}
}
Expand Down
29 changes: 5 additions & 24 deletions crates/polars-stream/src/nodes/select.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
use std::sync::Arc;

use polars_core::schema::Schema;
use polars_expr::prelude::PhysicalExpr;

use super::compute_node_prelude::*;
use crate::expression::StreamExpr;

pub struct SelectNode {
selectors: Vec<Arc<dyn PhysicalExpr>>,
selector_reentrant: Vec<bool>,
selectors: Vec<StreamExpr>,
schema: Arc<Schema>,
extend_original: bool,
}

impl SelectNode {
pub fn new(
selectors: Vec<Arc<dyn PhysicalExpr>>,
selector_reentrant: Vec<bool>,
schema: Arc<Schema>,
extend_original: bool,
) -> Self {
pub fn new(selectors: Vec<StreamExpr>, schema: Arc<Schema>, extend_original: bool) -> Self {
Self {
selectors,
selector_reentrant,
schema,
extend_original,
}
Expand Down Expand Up @@ -56,20 +49,8 @@ impl ComputeNode for SelectNode {
while let Ok(morsel) = recv.recv().await {
let (df, seq, source_token, consume_token) = morsel.into_inner();
let mut selected = Vec::new();
for (selector, reentrant) in slf.selectors.iter().zip(&slf.selector_reentrant) {
// We need spawn_blocking because evaluate could contain Python UDFs which
// recursively call the executor again.
let s = if *reentrant {
let df = df.clone();
let selector = selector.clone();
let state = state.clone();
polars_io::pl_async::get_runtime()
.spawn_blocking(move || selector.evaluate(&df, &state))
.await
.unwrap()?
} else {
selector.evaluate(&df, state)?
};
for selector in slf.selectors.iter() {
let s = selector.evaluate(&df, state).await?;
selected.push(s);
}

Expand Down
21 changes: 1 addition & 20 deletions crates/polars-stream/src/physical_plan/lower_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use polars_error::PolarsResult;
use polars_expr::reduce::can_convert_into_reduction;
use polars_plan::plans::{AExpr, Context, IR};
use polars_plan::prelude::{ArenaExprIter, FunctionFlags, SinkType};
use polars_plan::prelude::SinkType;
use polars_utils::arena::{Arena, Node};
use slotmap::SlotMap;

Expand All @@ -13,15 +13,6 @@ fn is_streamable(node: Node, arena: &Arena<AExpr>) -> bool {
polars_plan::plans::is_streamable(node, arena, Context::Default)
}

/// Reports whether any expression reachable from `node` is a function flagged as
/// `FunctionFlags::OPTIONAL_RE_ENTRANT`.
///
/// Both `AExpr::Function` and `AExpr::AnonymousFunction` nodes are inspected; every
/// other expression kind counts as not re-entrant.
fn has_potential_recurring_entrance(node: Node, arena: &Arena<AExpr>) -> bool {
    arena.iter(node).any(|(_, expr)| {
        matches!(
            expr,
            AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. }
                if options.flags.contains(FunctionFlags::OPTIONAL_RE_ENTRANT)
        )
    })
}

#[recursive::recursive]
pub fn lower_ir(
node: Node,
Expand Down Expand Up @@ -50,17 +41,12 @@ pub fn lower_ir(
schema,
..
} if expr.iter().all(|e| is_streamable(e.node(), expr_arena)) => {
let selector_reentrant = expr
.iter()
.map(|e| has_potential_recurring_entrance(e.node(), expr_arena))
.collect();
let selectors = expr.clone();
let output_schema = schema.clone();
let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?;
Ok(phys_sm.insert(PhysNode::Select {
input,
selectors,
selector_reentrant,
output_schema,
extend_original: false,
}))
Expand Down Expand Up @@ -96,17 +82,12 @@ pub fn lower_ir(
schema,
..
} if exprs.iter().all(|e| is_streamable(e.node(), expr_arena)) => {
let selector_reentrant = exprs
.iter()
.map(|e| has_potential_recurring_entrance(e.node(), expr_arena))
.collect();
let selectors = exprs.clone();
let output_schema = schema.clone();
let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?;
Ok(phys_sm.insert(PhysNode::Select {
input,
selectors,
selector_reentrant,
output_schema,
extend_original: true,
}))
Expand Down
1 change: 0 additions & 1 deletion crates/polars-stream/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub enum PhysNode {
Select {
input: PhysNodeKey,
selectors: Vec<ExprIR>,
selector_reentrant: Vec<bool>,
extend_original: bool,
output_schema: Arc<Schema>,
},
Expand Down
Loading

0 comments on commit ca8e445

Please sign in to comment.