refactor(rust): Add zip node to streaming engine (#17866)
orlp authored Jul 25, 2024
1 parent 5a108d4 commit 54b7fb8
Showing 10 changed files with 270 additions and 12 deletions.
49 changes: 45 additions & 4 deletions crates/polars-stream/src/morsel.rs
@@ -1,4 +1,5 @@
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};

use polars_core::frame::DataFrame;

@@ -41,6 +42,30 @@ impl MorselSeq {
}
}

/// A token indicating which source this morsel originated from, and a way to
/// pass information/signals to it. Currently it is only used to request that a
/// source stop passing new morsels for this execution phase.
#[derive(Clone)]
pub struct SourceToken {
stop: Arc<AtomicBool>,
}

impl SourceToken {
pub fn new() -> Self {
Self {
stop: Arc::new(AtomicBool::new(false)),
}
}

pub fn stop(&self) {
self.stop.store(true, Ordering::Relaxed);
}

pub fn stop_requested(&self) -> bool {
self.stop.load(Ordering::Relaxed)
}
}

pub struct Morsel {
/// The data contained in this morsel.
df: DataFrame,
@@ -49,28 +74,40 @@ pub struct Morsel {
/// within a pipeline.
seq: MorselSeq,

/// A token that indicates which source this morsel originates from.
source_token: SourceToken,

/// Used to notify someone when this morsel is consumed, to provide backpressure.
consume_token: Option<WaitToken>,
}

impl Morsel {
pub fn new(df: DataFrame, seq: MorselSeq) -> Self {
pub fn new(df: DataFrame, seq: MorselSeq, source_token: SourceToken) -> Self {
Self {
df,
seq,
source_token,
consume_token: None,
}
}

#[allow(unused)]
pub fn into_inner(self) -> (DataFrame, MorselSeq, Option<WaitToken>) {
(self.df, self.seq, self.consume_token)
pub fn into_inner(self) -> (DataFrame, MorselSeq, SourceToken, Option<WaitToken>) {
(self.df, self.seq, self.source_token, self.consume_token)
}

pub fn into_df(self) -> DataFrame {
self.df
}

pub fn df(&self) -> &DataFrame {
&self.df
}

pub fn df_mut(&mut self) -> &mut DataFrame {
&mut self.df
}

pub fn seq(&self) -> MorselSeq {
self.seq
}
@@ -100,4 +137,8 @@ impl Morsel {
pub fn take_consume_token(&mut self) -> Option<WaitToken> {
self.consume_token.take()
}

pub fn source_token(&self) -> &SourceToken {
&self.source_token
}
}
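
The new SourceToken is the stop channel between a source and its consumers: the source clones one token into every morsel it emits, and any downstream node may flip it to ask for no more morsels this execution phase. A minimal sketch of that round trip, assuming the crate-internal Morsel, MorselSeq, and SourceToken types above; the df! input data is purely illustrative.

use polars_core::prelude::*;

use crate::morsel::{Morsel, MorselSeq, SourceToken};

// Sketch only: the producer keeps one token, clones it into each morsel it
// sends, and checks it between morsels; a consumer flips it to request a stop.
fn stop_round_trip() -> PolarsResult<()> {
    let source_token = SourceToken::new();
    let df = polars_core::df!("a" => &[1i64, 2, 3])?;
    let morsel = Morsel::new(df, MorselSeq::new(0), source_token.clone());

    // Downstream: a consumer that has seen enough asks the source to stop.
    morsel.source_token().stop();

    // Upstream: the source notices the request and ends its loop for this phase.
    assert!(source_token.stop_requested());
    Ok(())
}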
15 changes: 10 additions & 5 deletions crates/polars-stream/src/nodes/in_memory_source.rs
@@ -3,7 +3,7 @@ use std::sync::Arc;

use super::compute_node_prelude::*;
use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::{get_ideal_morsel_size, MorselSeq};
use crate::morsel::{get_ideal_morsel_size, MorselSeq, SourceToken};

pub struct InMemorySourceNode {
source: Option<Arc<DataFrame>>,
@@ -28,9 +28,9 @@ impl ComputeNode for InMemorySourceNode {

fn initialize(&mut self, num_pipelines: usize) {
let len = self.source.as_ref().unwrap().height();
let ideal_block_count = (len / get_ideal_morsel_size()).max(1);
let block_count = ideal_block_count.next_multiple_of(num_pipelines);
self.morsel_size = len.div_ceil(block_count).max(1);
let ideal_morsel_count = (len / get_ideal_morsel_size()).max(1);
let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines);
self.morsel_size = len.div_ceil(morsel_count).max(1);
self.seq = AtomicU64::new(0);
}

@@ -71,6 +71,7 @@ impl ComputeNode for InMemorySourceNode {
let slf = &*self;
join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
let wait_group = WaitGroup::default();
let source_token = SourceToken::new();
loop {
let seq = slf.seq.fetch_add(1, Ordering::Relaxed);
let offset = (seq as usize * slf.morsel_size) as i64;
@@ -79,12 +80,16 @@ impl ComputeNode for InMemorySourceNode {
break;
}

let mut morsel = Morsel::new(df, MorselSeq::new(seq));
let mut morsel = Morsel::new(df, MorselSeq::new(seq), source_token.clone());
morsel.set_consume_token(wait_group.token());
if send.send(morsel).await.is_err() {
break;
}

wait_group.wait().await;
if source_token.stop_requested() {
break;
}
}

Ok(())
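
The renamed locals make the sizing logic in initialize() explicit: the source first picks a morsel count, rounded up to a multiple of the pipeline count so work divides evenly over pipelines, and only then derives the per-morsel size. A worked example with assumed numbers (1,000,000 rows, an ideal morsel size of 100,000 as a stand-in for get_ideal_morsel_size(), 4 pipelines):

// Worked example of the sizing arithmetic, with assumed inputs.
fn main() {
    let len = 1_000_000usize;             // source height
    let ideal_morsel_size = 100_000usize; // assumed stand-in for get_ideal_morsel_size()
    let num_pipelines = 4usize;

    let ideal_morsel_count = (len / ideal_morsel_size).max(1); // 10
    // Round up so every pipeline receives the same number of morsels.
    let morsel_count = ideal_morsel_count.next_multiple_of(num_pipelines); // 12
    let morsel_size = len.div_ceil(morsel_count).max(1); // 83_334

    assert_eq!(morsel_size, 83_334);
}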
1 change: 1 addition & 0 deletions crates/polars-stream/src/nodes/mod.rs
@@ -8,6 +8,7 @@ pub mod reduce;
pub mod select;
pub mod simple_projection;
pub mod streaming_slice;
pub mod zip;

/// The imports you'll always need for implementing a ComputeNode.
mod compute_node_prelude {
3 changes: 2 additions & 1 deletion crates/polars-stream/src/nodes/reduce.rs
@@ -5,6 +5,7 @@ use polars_expr::prelude::PhysicalExpr;
use polars_expr::reduce::Reduction;

use super::compute_node_prelude::*;
use crate::morsel::SourceToken;

enum ReduceState {
Sink {
@@ -84,7 +85,7 @@ impl ReduceNode {
) {
let mut send = send.serial();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0));
let morsel = Morsel::new(df.take().unwrap(), MorselSeq::new(0), SourceToken::new());
let _ = send.send(morsel).await;
Ok(())
}));
4 changes: 2 additions & 2 deletions crates/polars-stream/src/nodes/select.rs
@@ -54,7 +54,7 @@ impl ComputeNode for SelectNode {
let slf = &*self;
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Ok(morsel) = recv.recv().await {
let (df, seq, consume_token) = morsel.into_inner();
let (df, seq, source_token, consume_token) = morsel.into_inner();
let mut selected = Vec::new();
for (selector, reentrant) in slf.selectors.iter().zip(&slf.selector_reentrant) {
// We need spawn_blocking because evaluate could contain Python UDFs which
@@ -94,7 +94,7 @@ impl ComputeNode for SelectNode {
unsafe { DataFrame::new_no_checks(selected) }
};

let mut morsel = Morsel::new(ret, seq);
let mut morsel = Morsel::new(ret, seq, source_token);
if let Some(token) = consume_token {
morsel.set_consume_token(token);
}
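
With the extra tuple element returned by into_inner, transform nodes such as select simply thread the source token through to the rebuilt morsel, so stop requests and backpressure keep flowing across them. A hedged sketch of that pass-through pattern; `rebuild` is a hypothetical stand-in for whatever the node does to the frame.

use polars_core::frame::DataFrame;

use crate::morsel::Morsel;

// Unpack the morsel, replace the data, and reattach both tokens so ordering,
// stop requests, and backpressure all survive the node.
fn forward(morsel: Morsel, rebuild: impl FnOnce(DataFrame) -> DataFrame) -> Morsel {
    let (df, seq, source_token, consume_token) = morsel.into_inner();
    let mut out = Morsel::new(rebuild(df), seq, source_token);
    if let Some(token) = consume_token {
        out.set_consume_token(token);
    }
    out
}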
7 changes: 7 additions & 0 deletions crates/polars-stream/src/nodes/streaming_slice.rs
@@ -74,6 +74,13 @@ impl ComputeNode for StreamingSliceNode {
}
});

// Technically not necessary, but it's nice to already tell the
// source to stop producing more morsels as we won't be
// interested in the results anyway.
if self.stream_offset >= stop_offset {
morsel.source_token().stop();
}

if !morsel.df().is_empty() && send.send(morsel).await.is_err() {
break;
}
178 changes: 178 additions & 0 deletions crates/polars-stream/src/nodes/zip.rs
@@ -0,0 +1,178 @@
use std::collections::VecDeque;

use polars_core::functions::concat_df_horizontal;

use super::compute_node_prelude::*;
use crate::morsel::SourceToken;

pub struct ZipNode {
out_seq: MorselSeq,
input_heads: Vec<VecDeque<Morsel>>,
}

impl ZipNode {
pub fn new() -> Self {
Self {
out_seq: MorselSeq::new(0),
input_heads: Vec::new(),
}
}
}

impl ComputeNode for ZipNode {
fn name(&self) -> &str {
"zip"
}

fn update_state(&mut self, recv: &mut [PortState], send: &mut [PortState]) {
assert!(send.len() == 1);
assert!(!recv.is_empty());

let any_input_blocked = recv.iter().any(|s| *s == PortState::Blocked);

let mut all_done = true;
let mut at_least_one_done = false;
let mut at_least_one_nonempty = false;
for (recv_idx, recv_state) in recv.iter().enumerate() {
let is_empty = self
.input_heads
.get(recv_idx)
.map(|h| h.is_empty())
.unwrap_or(true);
at_least_one_nonempty |= !is_empty;
if *recv_state == PortState::Done {
all_done &= is_empty;
at_least_one_done |= is_empty;
} else {
all_done = false;
}
}

assert!(
!(at_least_one_done && at_least_one_nonempty),
"zip received non-equal length inputs"
);

let new_recv_state = if send[0] == PortState::Done || all_done {
self.input_heads.clear();
send[0] = PortState::Done;
PortState::Done
} else if send[0] == PortState::Blocked || any_input_blocked {
send[0] = if any_input_blocked {
PortState::Blocked
} else {
PortState::Ready
};
PortState::Blocked
} else {
send[0] = PortState::Ready;
PortState::Ready
};

for r in recv {
*r = new_recv_state;
}
}

fn spawn<'env, 's>(
&'env mut self,
scope: &'s TaskScope<'s, 'env>,
recv: &mut [Option<RecvPort<'_>>],
send: &mut [Option<SendPort<'_>>],
_state: &'s ExecutionState,
join_handles: &mut Vec<JoinHandle<PolarsResult<()>>>,
) {
assert!(send.len() == 1);
assert!(!recv.is_empty());
let mut sender = send[0].take().unwrap().serial();
let mut receivers: Vec<_> = recv.iter_mut().map(|r| Some(r.take()?.serial())).collect();

self.input_heads.resize_with(receivers.len(), VecDeque::new);

join_handles.push(scope.spawn_task(TaskPriority::High, async move {
let mut out = Vec::new();
let source_token = SourceToken::new();
loop {
if source_token.stop_requested() {
break;
}

// Fill input heads with non-empty morsels.
for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
if let Some(recv) = opt_recv {
if self.input_heads[recv_idx].is_empty() {
while let Ok(morsel) = recv.recv().await {
if morsel.df().height() > 0 {
self.input_heads[recv_idx].push_back(morsel);
break;
}
}
}
}
}

// TODO: recombine morsels to make sure the concatenation is
// close to the ideal morsel size.

// Compute common size and send a combined morsel.
let common_size = self
.input_heads
.iter()
.map(|h| h.front().map(|m| m.df().height()).unwrap_or(0))
.min()
.unwrap();
if common_size == 0 {
// One or more of the input heads is exhausted (for this phase).
break;
}

for input_head in &mut self.input_heads {
if input_head[0].df().height() == common_size {
out.push(input_head.pop_front().unwrap().into_df());
} else {
let (head, tail) = input_head[0].df().split_at(common_size as i64);
*input_head[0].df_mut() = tail;
out.push(head);
}
}

let out_df = concat_df_horizontal(&out)?;
out.clear();

let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
self.out_seq = self.out_seq.successor();
if sender.send(morsel).await.is_err() {
// Our receiver is no longer interested in any data; there is
// no need to store the rest of the incoming stream, so we can
// return directly.
return Ok(());
}
}

// We can't continue because one or more input heads is empty. We
// must tell everyone to stop, unblock all pipes by consuming
// all ConsumeTokens, and then store all data that was still flowing
// through the pipelines into input_heads for the next phase.
for input_head in &mut self.input_heads {
for morsel in input_head {
morsel.source_token().stop();
drop(morsel.take_consume_token());
}
}

for (recv_idx, opt_recv) in receivers.iter_mut().enumerate() {
if let Some(recv) = opt_recv {
while let Ok(mut morsel) = recv.recv().await {
morsel.source_token().stop();
drop(morsel.take_consume_token());
if morsel.df().height() > 0 {
self.input_heads[recv_idx].push_back(morsel);
}
}
}
}

Ok(())
}));
}
}
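
The core of the node is the splitting step inside the loop: take the smallest height among the input heads, split every head at that point, and concatenate the fronts horizontally; whatever is left over waits for the next iteration. An eager sketch of just that step, ignoring morsels, tokens, and phases:

use polars_core::functions::concat_df_horizontal;
use polars_core::prelude::*;

// One zip step over bare DataFrames: split every head at the smallest head
// height and glue the fronts together column-wise. Leftover rows stay in
// `heads` for the next step; None means at least one input is exhausted.
fn zip_step(heads: &mut [DataFrame]) -> PolarsResult<Option<DataFrame>> {
    let common_size = heads.iter().map(|df| df.height()).min().unwrap_or(0);
    if common_size == 0 {
        return Ok(None);
    }
    let mut fronts = Vec::with_capacity(heads.len());
    for head in heads.iter_mut() {
        let (front, rest) = head.split_at(common_size as i64);
        *head = rest;
        fronts.push(front);
    }
    concat_df_horizontal(&fronts).map(Some)
}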
13 changes: 13 additions & 0 deletions crates/polars-stream/src/physical_plan/lower_ir.rs
@@ -226,6 +226,19 @@ pub fn lower_ir(
Ok(phys_sm.insert(PhysNode::OrderedUnion { inputs }))
},

IR::HConcat {
inputs,
schema: _,
options: _,
} => {
let inputs = inputs
.clone() // Needed to borrow ir_arena mutably.
.into_iter()
.map(|input| lower_ir(input, ir_arena, expr_arena, phys_sm))
.collect::<Result<_, _>>()?;
Ok(phys_sm.insert(PhysNode::Zip { inputs }))
},

_ => todo!(),
}
}
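
The lowering maps IR::HConcat, i.e. a horizontal concatenation, onto the new PhysNode::Zip after recursively lowering each input. An eager illustration of the semantics the Zip node now streams; column names and data are made up.

use polars_core::functions::concat_df_horizontal;
use polars_core::prelude::*;

// Horizontal concatenation glues equal-length inputs column-wise.
fn main() -> PolarsResult<()> {
    let a = polars_core::df!("x" => &[1i64, 2, 3])?;
    let b = polars_core::df!("y" => &["a", "b", "c"])?;
    let zipped = concat_df_horizontal(&[a, b])?;
    assert_eq!(zipped.shape(), (3, 2)); // 3 rows, 2 columns
    Ok(())
}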