From 54b7fb861d87352fd902fa49f615c3660eb31ba5 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Thu, 25 Jul 2024 17:19:57 +0200
Subject: [PATCH] refactor(rust): Add zip node to streaming engine (#17866)

---
 crates/polars-stream/src/morsel.rs            |  49 ++++-
 .../src/nodes/in_memory_source.rs             |  15 +-
 crates/polars-stream/src/nodes/mod.rs         |   1 +
 crates/polars-stream/src/nodes/reduce.rs      |   3 +-
 crates/polars-stream/src/nodes/select.rs      |   4 +-
 .../src/nodes/streaming_slice.rs              |   7 +
 crates/polars-stream/src/nodes/zip.rs         | 178 ++++++++++++++++++
 .../src/physical_plan/lower_ir.rs             |  13 ++
 crates/polars-stream/src/physical_plan/mod.rs |   4 +
 .../src/physical_plan/to_graph.rs             |   8 +
 10 files changed, 270 insertions(+), 12 deletions(-)
 create mode 100644 crates/polars-stream/src/nodes/zip.rs

diff --git a/crates/polars-stream/src/morsel.rs b/crates/polars-stream/src/morsel.rs
index 58a236211dc5..2889bd5e400d 100644
--- a/crates/polars-stream/src/morsel.rs
+++ b/crates/polars-stream/src/morsel.rs
@@ -1,4 +1,5 @@
-use std::sync::OnceLock;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::{Arc, OnceLock};
 
 use polars_core::frame::DataFrame;
 
@@ -41,6 +42,30 @@ impl MorselSeq {
     }
 }
 
+/// A token indicating which source this morsel originated from, and a way to
+/// pass information/signals to it. Currently it's only used to request a source
+/// to stop passing new morsels for this execution phase.
+#[derive(Clone)]
+pub struct SourceToken {
+    stop: Arc<AtomicBool>,
+}
+
+impl SourceToken {
+    pub fn new() -> Self {
+        Self {
+            stop: Arc::new(AtomicBool::new(false)),
+        }
+    }
+
+    pub fn stop(&self) {
+        self.stop.store(true, Ordering::Relaxed);
+    }
+
+    pub fn stop_requested(&self) -> bool {
+        self.stop.load(Ordering::Relaxed)
+    }
+}
+
 pub struct Morsel {
     /// The data contained in this morsel.
     df: DataFrame,
@@ -49,28 +74,40 @@ pub struct Morsel {
     /// within a pipeline.
     seq: MorselSeq,
 
+    /// A token that indicates which source this morsel originates from.
+    source_token: SourceToken,
+
     /// Used to notify someone when this morsel is consumed, to provide backpressure.
     consume_token: Option<WaitToken>,
 }
 
 impl Morsel {
-    pub fn new(df: DataFrame, seq: MorselSeq) -> Self {
+    pub fn new(df: DataFrame, seq: MorselSeq, source_token: SourceToken) -> Self {
         Self {
             df,
             seq,
+            source_token,
             consume_token: None,
         }
     }
 
     #[allow(unused)]
-    pub fn into_inner(self) -> (DataFrame, MorselSeq, Option<WaitToken>) {
-        (self.df, self.seq, self.consume_token)
+    pub fn into_inner(self) -> (DataFrame, MorselSeq, SourceToken, Option<WaitToken>) {
+        (self.df, self.seq, self.source_token, self.consume_token)
+    }
+
+    pub fn into_df(self) -> DataFrame {
+        self.df
     }
 
     pub fn df(&self) -> &DataFrame {
         &self.df
     }
 
+    pub fn df_mut(&mut self) -> &mut DataFrame {
+        &mut self.df
+    }
+
     pub fn seq(&self) -> MorselSeq {
         self.seq
     }
@@ -100,4 +137,8 @@ impl Morsel {
     pub fn take_consume_token(&mut self) -> Option<WaitToken> {
         self.consume_token.take()
     }
+
+    pub fn source_token(&self) -> &SourceToken {
+        &self.source_token
+    }
 }
diff --git a/crates/polars-stream/src/nodes/in_memory_source.rs b/crates/polars-stream/src/nodes/in_memory_source.rs
index 1c6124fcb52f..826f9e5e5c83 100644
--- a/crates/polars-stream/src/nodes/in_memory_source.rs
+++ b/crates/polars-stream/src/nodes/in_memory_source.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;
 
 use super::compute_node_prelude::*;
 use crate::async_primitives::wait_group::WaitGroup;
-use crate::morsel::{get_ideal_morsel_size, MorselSeq};
+use crate::morsel::{get_ideal_morsel_size, MorselSeq, SourceToken};
 
 pub struct InMemorySourceNode {
     source: Option<Arc<DataFrame>>,
@@ -28,9 +28,9 @@ impl ComputeNode for InMemorySourceNode {
 
     fn initialize(&mut self, num_pipelines: usize) {
         let len = self.source.as_ref().unwrap().height();
-        let ideal_block_count = (len / get_ideal_morsel_size()).max(1);
-        let block_count = ideal_block_count.next_multiple_of(num_pipelines);
-        self.morsel_size = len.div_ceil(block_count).max(1);
+        let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);
+        let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
+        self.morsel_size = len.div_ceil(morsel_count).max(1);
         self.seq = AtomicU64::new(0);
     }
 
@@ -71,6 +71,7 @@ impl ComputeNode for InMemorySourceNode {
         let slf = &*self;
         join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
             let wait_group = WaitGroup::default();
+            let source_token = SourceToken::new();
             loop {
                 let seq = slf.seq.fetch_add(1, Ordering::Relaxed);
                 let offset = (seq as usize * slf.morsel_size) as i64;
@@ -79,12 +80,16 @@
                     break;
                 }
 
-                let mut morsel = Morsel::new(df, MorselSeq::new(seq));
+                let mut morsel = Morsel::new(df, MorselSeq::new(seq), source_token.clone());
                 morsel.set_consume_token(wait_group.token());
                 if send.send(morsel).await.is_err() {
                     break;
                 }
+
                 wait_group.wait().await;
+                if source_token.stop_requested() {
+                    break;
+                }
             }
 
             Ok(())
diff --git a/crates/polars-stream/src/nodes/mod.rs b/crates/polars-stream/src/nodes/mod.rs
index 31e8c710129d..fecc1b8e5abe 100644
--- a/crates/polars-stream/src/nodes/mod.rs
+++ b/crates/polars-stream/src/nodes/mod.rs
@@ -8,6 +8,7 @@ pub mod reduce;
 pub mod select;
 pub mod simple_projection;
 pub mod streaming_slice;
+pub mod zip;
 
 /// The imports you'll always need for implementing a ComputeNode.
 mod compute_node_prelude {
diff --git a/crates/polars-stream/src/nodes/reduce.rs b/crates/polars-stream/src/nodes/reduce.rs
index 4b59f4f22515..91209b7d4e33 100644
--- a/crates/polars-stream/src/nodes/reduce.rs
+++ b/crates/polars-stream/src/nodes/reduce.rs
@@ -5,6 +5,7 @@ use polars_expr::prelude::PhysicalExpr;
 use polars_expr::reduce::Reduction;
 
 use super::compute_node_prelude::*;
+use crate::morsel::SourceToken;
 
 enum ReduceState {
     Sink {
@@ -84,7 +85,7 @@ impl ReduceNode {
     ) {
         let mut send = send.serial();
         join_handles.push(scope.spawn_task(TaskPriority::High, async move {
-            let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0));
+            let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());
             let _ = send.send(morsel).await;
             Ok(())
         }));
diff --git a/crates/polars-stream/src/nodes/select.rs b/crates/polars-stream/src/nodes/select.rs
index 4f831f8ebaa4..e021a3a0f197 100644
--- a/crates/polars-stream/src/nodes/select.rs
+++ b/crates/polars-stream/src/nodes/select.rs
@@ -54,7 +54,7 @@ impl ComputeNode for SelectNode {
         let slf = &*self;
         join_handles.push(scope.spawn_task(TaskPriority::High, async move {
             while let Ok(morsel) = recv.recv().await {
-                let (df, seq, consume_token) = morsel.into_inner();
+                let (df, seq, source_token, consume_token) = morsel.into_inner();
                 let mut selected = Vec::new();
                 for (selector, reentrant) in slf.selectors.iter().zip(&slf.selector_reentrant) {
                     // We need spawn_blocking because evaluate could contain Python UDFs which
@@ -94,7 +94,7 @@ impl ComputeNode for SelectNode {
                     unsafe { DataFrame::new_no_checks(selected) }
                 };
 
-                let mut morsel = Morsel::new(ret, seq);
+                let mut morsel = Morsel::new(ret, seq, source_token);
                 if let Some(token) = consume_token {
                     morsel.set_consume_token(token);
                 }
diff --git a/crates/polars-stream/src/nodes/streaming_slice.rs b/crates/polars-stream/src/nodes/streaming_slice.rs
index a6d2688538d8..b46693bac808 100644
--- a/crates/polars-stream/src/nodes/streaming_slice.rs
+++ b/crates/polars-stream/src/nodes/streaming_slice.rs
@@ -74,6 +74,13 @@ impl ComputeNode for StreamingSliceNode {
                         }
                     });
 
+                    // Technically not necessary, but it's nice to already tell the
+                    // source to stop producing more morsels as we won't be
+                    // interested in the results anyway.
+                    if self.stream_offset >= stop_offset {
+                        morsel.source_token().stop();
+                    }
+
                     if !morsel.df().is_empty() && send.send(morsel).await.is_err() {
                         break;
                     }
diff --git a/crates/polars-stream/src/nodes/zip.rs b/crates/polars-stream/src/nodes/zip.rs
new file mode 100644
index 000000000000..8816bfbb0640
--- /dev/null
+++ b/crates/polars-stream/src/nodes/zip.rs
@@ -0,0 +1,178 @@
+use std::collections::VecDeque;
+
+use polars_core::functions::concat_df_horizontal;
+
+use super::compute_node_prelude::*;
+use crate::morsel::SourceToken;
+
+pub struct ZipNode {
+    out_seq: MorselSeq,
+    input_heads: Vec<VecDeque<Morsel>>,
+}
+
+impl ZipNode {
+    pub fn new() -> Self {
+        Self {
+            out_seq: MorselSeq::new(0),
+            input_heads: Vec::new(),
+        }
+    }
+}
+
+impl ComputeNode for ZipNode {
+    fn name(&self) -> &str {
+        "zip"
+    }
+
+    fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) {
+        assert!(send.len() == 1);
+        assert!(!recv.is_empty());
+
+        let any_input_blocked = recv.iter().any(|s| *s == PortState::Blocked);
+
+        let mut all_done = true;
+        let mut at_least_one_done = false;
+        let mut at_least_one_nonempty = false;
+        for (recv_idx, recv_state) in recv.iter().enumerate() {
+            let is_empty = self
+                .input_heads
+                .get(recv_idx)
+                .map(|h| h.is_empty())
+                .unwrap_or(true);
+            at_least_one_nonempty |= !is_empty;
+            if *recv_state == PortState::Done {
+                all_done &= is_empty;
+                at_least_one_done |= is_empty;
+            } else {
+                all_done = false;
+            }
+        }
+
+        assert!(
+            !(at_least_one_done && at_least_one_nonempty),
+            "zip received non-equal length inputs"
+        );
+
+        let new_recv_state = if send[0] == PortState::Done || all_done {
+            self.input_heads.clear();
+            send[0] = PortState::Done;
+            PortState::Done
+        } else if send[0] == PortState::Blocked || any_input_blocked {
+            send[0] = if any_input_blocked {
+                PortState::Blocked
+            } else {
+                PortState::Ready
+            };
+            PortState::Blocked
+        } else {
+            send[0] = PortState::Ready;
+            PortState::Ready
+        };
+
+        for r in recv {
+            *r = new_recv_state;
+        }
+    }
+
+    fn spawn<'env, 's>(
+        &'env mut self,
+        scope: &'s TaskScope<'s, 'env>,
+        recv: &mut [Option<RecvPort<'_>>],
+        send: &mut [Option<SendPort<'_>>],
+        _state: &'s ExecutionState,
+        join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
+    ) {
+        assert!(send.len() == 1);
+        assert!(!recv.is_empty());
+        let mut sender = send[0].take().unwrap().serial();
+        let mut receivers: Vec<_> = recv.iter_mut().map(|r| Some(r.take()?.serial())).collect();
+
+        self.input_heads.resize_with(receivers.len(), VecDeque::new);
+
+        join_handles.push(scope.spawn_task(TaskPriority::High, async move {
+            let mut out = Vec::new();
+            let source_token = SourceToken::new();
+            loop {
+                if source_token.stop_requested() {
+                    break;
+                }
+
+                // Fill input heads with non-empty morsels.
+                for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
+                    if let Some(recv) = opt_recv {
+                        if self.input_heads[recv_idx].is_empty() {
+                            while let Ok(morsel) = recv.recv().await {
+                                if morsel.df().height() > 0 {
+                                    self.input_heads[recv_idx].push_back(morsel);
+                                    break;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // TODO: recombine morsels to make sure the concatenation is
+                // close to the ideal morsel size.
+
+                // Compute common size and send a combined morsel.
+                let common_size = self
+                    .input_heads
+                    .iter()
+                    .map(|h| h.front().map(|m| m.df().height()).unwrap_or(0))
+                    .min()
+                    .unwrap();
+                if common_size == 0 {
+                    // One or more of the input heads is exhausted (this phase).
+ break; + } + + for input_head in &mut self.input_heads { + if input_head[0].df().height() == common_size { + out.push(input_head.pop_front().unwrap().into_df()); + } else { + let (head, tail) = input_head[0].df().split_at(common_size as i64); + *input_head[0].df_mut() = tail; + out.push(head); + } + } + + let out_df = concat_df_horizontal(&out)?; + out.clear(); + + let morsel = Morsel::new(out_df, self.out_seq, source_token.clone()); + self.out_seq = self.out_seq.successor(); + if sender.send(morsel).await.is_err() { + // Our receiver is no longer interested in any data, no + // need store the rest of the incoming stream, can directly + // return. + return Ok(()); + } + } + + // We can't continue because one or more input heads is empty. We + // must tell everyone to stop, unblock all pipes by consuming + // all ConsumeTokens, and then store all data that was still flowing + // through the pipelines into input_heads for the next phase. + for input_head in &mut self.input_heads { + for morsel in input_head { + morsel.source_token().stop(); + drop(morsel.take_consume_token()); + } + } + + for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() { + if let Some(recv) = opt_recv { + while let Ok(mut morsel) = recv.recv().await { + morsel.source_token().stop(); + drop(morsel.take_consume_token()); + if morsel.df().height() > 0 { + self.input_heads[recv_idx].push_back(morsel); + } + } + } + } + + Ok(()) + })); + } +} diff --git a/crates/polars-stream/src/physical_plan/lower_ir.rs b/crates/polars-stream/src/physical_plan/lower_ir.rs index e22d06fd869d..0096c9518a54 100644 --- a/crates/polars-stream/src/physical_plan/lower_ir.rs +++ b/crates/polars-stream/src/physical_plan/lower_ir.rs @@ -226,6 +226,19 @@ pub fn lower_ir( Ok(phys_sm.insert(PhysNode::OrderedUnion { inputs })) }, + IR::HConcat { + inputs, + schema: _, + options: _, + } => { + let inputs = inputs + .clone() // Needed to borrow ir_arena mutably. + .into_iter() + .map(|input| lower_ir(input, ir_arena, expr_arena, phys_sm)) + .collect::>()?; + Ok(phys_sm.insert(PhysNode::Zip { inputs })) + }, + _ => todo!(), } } diff --git a/crates/polars-stream/src/physical_plan/mod.rs b/crates/polars-stream/src/physical_plan/mod.rs index 015f4cdfec94..27147ed194f2 100644 --- a/crates/polars-stream/src/physical_plan/mod.rs +++ b/crates/polars-stream/src/physical_plan/mod.rs @@ -86,4 +86,8 @@ pub enum PhysNode { OrderedUnion { inputs: Vec, }, + + Zip { + inputs: Vec, + }, } diff --git a/crates/polars-stream/src/physical_plan/to_graph.rs b/crates/polars-stream/src/physical_plan/to_graph.rs index b30a7a6aab73..3a604b05e7fd 100644 --- a/crates/polars-stream/src/physical_plan/to_graph.rs +++ b/crates/polars-stream/src/physical_plan/to_graph.rs @@ -237,6 +237,14 @@ fn to_graph_rec<'a>( ctx.graph .add_node(nodes::ordered_union::OrderedUnionNode::new(), input_keys) }, + + Zip { inputs } => { + let input_keys = inputs + .iter() + .map(|i| to_graph_rec(*i, ctx)) + .collect::, _>>()?; + ctx.graph.add_node(nodes::zip::ZipNode::new(), input_keys) + }, }; ctx.phys_to_graph.insert(phys_node_key, graph_key);