From 49f5ef532cf80dc9109e6b35da422782cf46c3cd Mon Sep 17 00:00:00 2001 From: Orson Peters Date: Wed, 26 Jun 2024 14:12:56 +0200 Subject: [PATCH] feat: add elementwise select and with_columns to new streaming engine (#17185) --- crates/polars-core/src/frame/from.rs | 21 ----- crates/polars-core/src/frame/mod.rs | 19 ++++ crates/polars-io/src/csv/read/read_impl.rs | 4 +- crates/polars-io/src/parquet/read/utils.rs | 2 +- .../sinks/group_by/generic/hash_table.rs | 2 +- .../src/executors/sinks/group_by/utils.rs | 2 +- .../src/executors/sinks/ordered.rs | 4 +- .../polars-pipe/src/executors/sinks/slice.rs | 4 +- crates/polars-plan/src/plans/ir/schema.rs | 3 + .../plans/optimizer/predicate_pushdown/mod.rs | 2 +- crates/polars-sql/src/context.rs | 2 +- crates/polars-stream/src/graph.rs | 1 + crates/polars-stream/src/lib.rs | 1 - .../polars-stream/src/nodes/in_memory_sink.rs | 28 ++++-- crates/polars-stream/src/nodes/mod.rs | 1 + crates/polars-stream/src/nodes/select.rs | 90 +++++++++++++++++++ .../src/physical_plan/lower_ir.rs | 48 +++++++++- crates/polars-stream/src/physical_plan/mod.rs | 11 ++- .../src/physical_plan/to_graph.rs | 33 ++++++- .../src/utils/in_memory_linearize.rs | 2 +- 20 files changed, 230 insertions(+), 50 deletions(-) create mode 100644 crates/polars-stream/src/nodes/select.rs diff --git a/crates/polars-core/src/frame/from.rs b/crates/polars-core/src/frame/from.rs index 72172ec7e736..607fab946857 100644 --- a/crates/polars-core/src/frame/from.rs +++ b/crates/polars-core/src/frame/from.rs @@ -28,24 +28,3 @@ impl TryFrom for DataFrame { DataFrame::new(columns) } } - -impl From<&Schema> for DataFrame { - fn from(schema: &Schema) -> Self { - let cols = schema - .iter() - .map(|(name, dtype)| Series::new_empty(name, dtype)) - .collect(); - unsafe { DataFrame::new_no_checks(cols) } - } -} - -impl From<&ArrowSchema> for DataFrame { - fn from(schema: &ArrowSchema) -> Self { - let cols = schema - .fields - .iter() - .map(|fld| Series::new_empty(fld.name.as_str(), &(fld.data_type().into()))) - .collect(); - unsafe { DataFrame::new_no_checks(cols) } - } -} diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 0a8f130dfe51..6a94cd089547 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -318,6 +318,25 @@ impl DataFrame { unsafe { DataFrame::new_no_checks(Vec::new()) } } + /// Create an empty `DataFrame` with empty columns as per the `schema`. + pub fn empty_with_schema(schema: &Schema) -> Self { + let cols = schema + .iter() + .map(|(name, dtype)| Series::new_empty(name, dtype)) + .collect(); + unsafe { DataFrame::new_no_checks(cols) } + } + + /// Create an empty `DataFrame` with empty columns as per the `schema`. + pub fn empty_with_arrow_schema(schema: &ArrowSchema) -> Self { + let cols = schema + .fields + .iter() + .map(|fld| Series::new_empty(fld.name.as_str(), &(fld.data_type().into()))) + .collect(); + unsafe { DataFrame::new_no_checks(cols) } + } + /// Removes the last `Series` from the `DataFrame` and returns it, or [`None`] if it is empty. /// /// # Example diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index d0248941b8c8..50ba63dd668a 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -475,9 +475,9 @@ impl<'a> CoreReader<'a> { // An empty file with a schema should return an empty DataFrame with that schema if bytes.is_empty() { let mut df = if projection.len() == self.schema.len() { - DataFrame::from(self.schema.as_ref()) + DataFrame::empty_with_schema(self.schema.as_ref()) } else { - DataFrame::from( + DataFrame::empty_with_schema( &projection .iter() .map(|&i| self.schema.get_at_index(i).unwrap()) diff --git a/crates/polars-io/src/parquet/read/utils.rs b/crates/polars-io/src/parquet/read/utils.rs index f3b79f3cd756..60e8d9a29bba 100644 --- a/crates/polars-io/src/parquet/read/utils.rs +++ b/crates/polars-io/src/parquet/read/utils.rs @@ -17,7 +17,7 @@ pub fn materialize_empty_df( } else { Cow::Borrowed(reader_schema) }; - let mut df = DataFrame::from(schema.as_ref()); + let mut df = DataFrame::empty_with_arrow_schema(&schema); if let Some(row_index) = row_index { df.insert_column(0, Series::new_empty(&row_index.name, &IDX_DTYPE)) diff --git a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs index 43023493c769..3a1ca17a183a 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/generic/hash_table.rs @@ -215,7 +215,7 @@ impl AggHashTable { let (skip_len, take_len) = if let Some((offset, slice_len)) = slice { if *offset as usize >= local_len { *offset -= local_len as i64; - return DataFrame::from(self.output_schema.as_ref()); + return DataFrame::empty_with_schema(&self.output_schema); } else { let out = (*offset as usize, *slice_len); *offset = 0; diff --git a/crates/polars-pipe/src/executors/sinks/group_by/utils.rs b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs index c95164fa6448..dd0d8ce4aea0 100644 --- a/crates/polars-pipe/src/executors/sinks/group_by/utils.rs +++ b/crates/polars-pipe/src/executors/sinks/group_by/utils.rs @@ -57,7 +57,7 @@ pub(super) fn finalize_group_by( ooc_payload: Option<(IOThread, Box)>, ) -> PolarsResult { let df = if dfs.is_empty() { - DataFrame::from(output_schema) + DataFrame::empty_with_schema(output_schema) } else { let mut df = accumulate_dataframes_vertical_unchecked(dfs); // re init to check duplicates diff --git a/crates/polars-pipe/src/executors/sinks/ordered.rs b/crates/polars-pipe/src/executors/sinks/ordered.rs index 083906125687..e76f38bdad30 100644 --- a/crates/polars-pipe/src/executors/sinks/ordered.rs +++ b/crates/polars-pipe/src/executors/sinks/ordered.rs @@ -48,8 +48,8 @@ impl Sink for OrderedSink { } fn finalize(&mut self, _context: &PExecutionContext) -> PolarsResult { if self.chunks.is_empty() { - return Ok(FinalizedSink::Finished(DataFrame::from( - self.schema.as_ref(), + return Ok(FinalizedSink::Finished(DataFrame::empty_with_schema( + &self.schema, ))); } self.sort(); diff --git a/crates/polars-pipe/src/executors/sinks/slice.rs b/crates/polars-pipe/src/executors/sinks/slice.rs index a5e1c0e24aca..8c85fc10f721 100644 --- a/crates/polars-pipe/src/executors/sinks/slice.rs +++ b/crates/polars-pipe/src/executors/sinks/slice.rs @@ -94,8 +94,8 @@ impl Sink for SliceSink { let mut chunks = chunks.lock().unwrap(); let chunks: Vec = std::mem::take(chunks.as_mut()); if chunks.is_empty() { - return Ok(FinalizedSink::Finished(DataFrame::from( - self.schema.as_ref(), + return Ok(FinalizedSink::Finished(DataFrame::empty_with_schema( + &self.schema, ))); } diff --git a/crates/polars-plan/src/plans/ir/schema.rs b/crates/polars-plan/src/plans/ir/schema.rs index 6047fe6d5943..46d793687345 100644 --- a/crates/polars-plan/src/plans/ir/schema.rs +++ b/crates/polars-plan/src/plans/ir/schema.rs @@ -1,3 +1,5 @@ +use recursive::recursive; + use super::*; impl IR { @@ -61,6 +63,7 @@ impl IR { } /// Get the schema of the logical plan node. + #[recursive] pub fn schema<'a>(&'a self, arena: &'a Arena) -> Cow<'a, SchemaRef> { use IR::*; let schema = match self { diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index 42a505a92b80..c9bdccee1e53 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -380,7 +380,7 @@ impl<'a> PredicatePushDown<'a> { } if new_paths.is_empty() { let schema = output_schema.as_ref().unwrap_or(&file_info.schema); - let df = DataFrame::from(schema.as_ref()); + let df = DataFrame::empty_with_schema(schema); return Ok(DataFrameScan { df: Arc::new(df), diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index ea86de06f4e5..b50008007612 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -478,7 +478,7 @@ impl SQLContext { None => { let tbl = table_name.to_string(); if let Some(lf) = self.table_map.get_mut(&tbl) { - *lf = DataFrame::from( + *lf = DataFrame::empty_with_schema( lf.schema_with_arenas(&mut self.lp_arena, &mut self.expr_arena) .unwrap() .as_ref(), diff --git a/crates/polars-stream/src/graph.rs b/crates/polars-stream/src/graph.rs index e3dbdef6b218..1445233b6025 100644 --- a/crates/polars-stream/src/graph.rs +++ b/crates/polars-stream/src/graph.rs @@ -71,6 +71,7 @@ pub struct GraphNode { } /// A pipe sends data between nodes. +#[allow(unused)] // TODO: remove. pub struct LogicalPipe { // Node that we send data to. sender: GraphNodeKey, diff --git a/crates/polars-stream/src/lib.rs b/crates/polars-stream/src/lib.rs index 263935da38ce..f3443e876ba3 100644 --- a/crates/polars-stream/src/lib.rs +++ b/crates/polars-stream/src/lib.rs @@ -1,6 +1,5 @@ #![allow(unused)] // TODO: remove. -#[allow(unused)] // TODO: remove. mod async_executor; #[allow(unused)] // TODO: remove. mod async_primitives; diff --git a/crates/polars-stream/src/nodes/in_memory_sink.rs b/crates/polars-stream/src/nodes/in_memory_sink.rs index 80919e01bbc5..edee107395de 100644 --- a/crates/polars-stream/src/nodes/in_memory_sink.rs +++ b/crates/polars-stream/src/nodes/in_memory_sink.rs @@ -1,15 +1,12 @@ -use std::cmp::Reverse; -use std::collections::{BinaryHeap, VecDeque}; +use std::sync::Arc; use parking_lot::Mutex; use polars_core::frame::DataFrame; +use polars_core::schema::Schema; +use polars_core::series::Series; use polars_core::utils::accumulate_dataframes_vertical_unchecked; -use polars_core::utils::rayon::iter::{IntoParallelIterator, ParallelIterator}; -use polars_core::POOL; use polars_error::PolarsResult; use polars_expr::state::ExecutionState; -use polars_utils::priority::Priority; -use polars_utils::sync::SyncPtr; use super::ComputeNode; use crate::async_executor::{JoinHandle, TaskScope}; @@ -17,9 +14,18 @@ use crate::async_primitives::pipe::{Receiver, Sender}; use crate::morsel::Morsel; use crate::utils::in_memory_linearize::linearize; -#[derive(Default)] pub struct InMemorySink { morsels_per_pipe: Mutex>>, + schema: Arc, +} + +impl InMemorySink { + pub fn new(schema: Arc) -> Self { + Self { + morsels_per_pipe: Mutex::default(), + schema, + } + } } impl ComputeNode for InMemorySink { @@ -51,8 +57,12 @@ impl ComputeNode for InMemorySink { } fn finalize(&mut self) -> PolarsResult> { - let mut morsels_per_pipe = core::mem::take(&mut *self.morsels_per_pipe.get_mut()); + let morsels_per_pipe = core::mem::take(&mut *self.morsels_per_pipe.get_mut()); let dataframes = linearize(morsels_per_pipe); - Ok(Some(accumulate_dataframes_vertical_unchecked(dataframes))) + if dataframes.is_empty() { + Ok(Some(DataFrame::empty_with_schema(&self.schema))) + } else { + Ok(Some(accumulate_dataframes_vertical_unchecked(dataframes))) + } } } diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs index f2aaebd172d5..eb85181b5b3d 100644 --- a/crates/polars-stream/src/nodes/mod.rs +++ b/crates/polars-stream/src/nodes/mod.rs @@ -9,6 +9,7 @@ use crate::morsel::Morsel; pub mod filter; pub mod in_memory_sink; pub mod in_memory_source; +pub mod select; pub mod simple_projection; pub trait ComputeNode { diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs new file mode 100644 index 000000000000..d108ccfce98a --- /dev/null +++ b/crates/polars-stream/src/nodes/select.rs @@ -0,0 +1,90 @@ +use std::sync::Arc; + +use polars_core::frame::DataFrame; +use polars_core::schema::Schema; +use polars_core::series::Series; +use polars_error::PolarsResult; +use polars_expr::prelude::PhysicalExpr; +use polars_expr::state::ExecutionState; + +use super::ComputeNode; +use crate::async_executor::{JoinHandle, TaskScope}; +use crate::async_primitives::pipe::{Receiver, Sender}; +use crate::morsel::Morsel; + +pub struct SelectNode { + selectors: Vec>, + schema: Arc, + extend_original: bool, +} + +impl SelectNode { + pub fn new( + selectors: Vec>, + schema: Arc, + extend_original: bool, + ) -> Self { + Self { + selectors, + schema, + extend_original, + } + } +} + +impl ComputeNode for SelectNode { + fn spawn<'env, 's>( + &'env self, + scope: &'s TaskScope<'s, 'env>, + _pipeline: usize, + recv: Vec>, + send: Vec>, + state: &'s ExecutionState, + ) -> JoinHandle> { + let [mut recv] = <[_; 1]>::try_from(recv).ok().unwrap(); + let [mut send] = <[_; 1]>::try_from(send).ok().unwrap(); + + scope.spawn_task(true, async move { + while let Ok(morsel) = recv.recv().await { + let morsel = morsel.try_map(|df| { + // Select columns. + let mut selected: Vec = self + .selectors + .iter() + .map(|s| s.evaluate(&df, state)) + .collect::>()?; + + // Extend or create new dataframe. + let ret = if self.extend_original { + let mut out = df.clone(); + out._add_columns(selected, &self.schema)?; + out + } else { + // Broadcast scalars. + let max_non_unit_length = selected + .iter() + .map(|s| s.len()) + .filter(|l| *l != 1) + .max() + .unwrap_or(1); + for s in &mut selected { + if s.len() != max_non_unit_length { + assert!(s.len() == 1, "got series of incompatible lengths"); + *s = s.new_from_index(0, max_non_unit_length); + } + } + unsafe { DataFrame::new_no_checks(selected) } + }; + + PolarsResult::Ok(ret) + })?; + + if send.send(morsel).await.is_err() { + break; + } + } + + Ok(()) + }) + } +} diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index fbad8573af9c..d522c87ca96c 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -17,7 +17,50 @@ pub fn lower_ir( expr_arena: &mut Arena, phys_sm: &mut SlotMap, ) -> PolarsResult { - match ir_arena.get(node) { + let ir_node = ir_arena.get(node); + match ir_node { + IR::SimpleProjection { input, columns } => { + let schema = columns.clone(); + let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; + Ok(phys_sm.insert(PhysNode::SimpleProjection { input, schema })) + }, + + // TODO: split partially streamable selections to avoid fallback as much as possible. + IR::Select { + input, + expr, + schema, + .. + } if expr.iter().all(|e| is_streamable(e.node(), expr_arena)) => { + let selectors = expr.clone(); + let schema = schema.clone(); + let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; + Ok(phys_sm.insert(PhysNode::Select { + input, + selectors, + schema, + extend_original: false, + })) + }, + + // TODO: split partially streamable selections to avoid fallback as much as possible. + IR::HStack { + input, + exprs, + schema, + .. + } if exprs.iter().all(|e| is_streamable(e.node(), expr_arena)) => { + let selectors = exprs.clone(); + let schema = schema.clone(); + let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; + Ok(phys_sm.insert(PhysNode::Select { + input, + selectors, + schema, + extend_original: true, + })) + }, + IR::Filter { input, predicate } if is_streamable(predicate.node(), expr_arena) => { let predicate = predicate.clone(); let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; @@ -57,8 +100,9 @@ pub fn lower_ir( IR::Sink { input, payload } => { if *payload == SinkType::Memory { + let schema = ir_node.schema(ir_arena).into_owned(); let input = lower_ir(*input, ir_arena, expr_arena, phys_sm)?; - return Ok(phys_sm.insert(PhysNode::InMemorySink { input })); + return Ok(phys_sm.insert(PhysNode::InMemorySink { input, schema })); } todo!() diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index dd34909c367a..e6a7d26155c4 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use polars_core::frame::DataFrame; -use polars_core::schema::SchemaRef; +use polars_core::schema::Schema; use polars_plan::prelude::expr_ir::ExprIR; use polars_utils::arena::Node; @@ -25,16 +25,23 @@ pub enum PhysNode { InMemorySource { df: Arc, }, + Select { + input: PhysNodeKey, + selectors: Vec, + extend_original: bool, + schema: Arc, + }, Filter { input: PhysNodeKey, predicate: ExprIR, }, SimpleProjection { input: PhysNodeKey, - schema: SchemaRef, + schema: Arc, }, InMemorySink { input: PhysNodeKey, + schema: Arc, }, // Fallback to the in-memory engine. Fallback(Node), diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index d88588b2d854..9acc930b914e 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -68,6 +68,31 @@ fn to_graph_rec<'a>( ) }, + Select { + selectors, + input, + schema, + extend_original, + } => { + let phys_selectors = selectors + .iter() + .map(|selector| { + create_physical_expr( + selector, + Context::Default, + ctx.expr_arena, + None, + &mut ctx.expr_conversion_state, + ) + }) + .collect::>()?; + let input_key = to_graph_rec(*input, ctx)?; + ctx.graph.add_node( + nodes::select::SelectNode::new(phys_selectors, schema.clone(), *extend_original), + [input_key], + ) + }, + SimpleProjection { schema, input } => { let input_key = to_graph_rec(*input, ctx)?; ctx.graph.add_node( @@ -76,10 +101,12 @@ fn to_graph_rec<'a>( ) }, - InMemorySink { input } => { + InMemorySink { input, schema } => { let input_key = to_graph_rec(*input, ctx)?; - ctx.graph - .add_node(nodes::in_memory_sink::InMemorySink::default(), [input_key]) + ctx.graph.add_node( + nodes::in_memory_sink::InMemorySink::new(schema.clone()), + [input_key], + ) }, // Fallback to the in-memory engine. diff --git a/crates/polars-stream/src/utils/in_memory_linearize.rs b/crates/polars-stream/src/utils/in_memory_linearize.rs index 24aa9f596a74..2cf1159cd286 100644 --- a/crates/polars-stream/src/utils/in_memory_linearize.rs +++ b/crates/polars-stream/src/utils/in_memory_linearize.rs @@ -34,7 +34,7 @@ pub fn linearize(mut morsels_per_pipe: Vec>) -> Vec { let morsels_per_p = &morsels_per_pipe; let mut dataframes: Vec = Vec::with_capacity(num_morsels); - let mut dataframes_ptr = unsafe { SyncPtr::new(dataframes.as_mut_ptr()) }; + let dataframes_ptr = unsafe { SyncPtr::new(dataframes.as_mut_ptr()) }; rayon::scope(|s| { let mut out_offset = 0; let mut stop_idx_per_pipe = vec![0; morsels_per_p.len()];