Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Jan 29, 2025
1 parent 52632d1 commit c6049f1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 51 deletions.
2 changes: 2 additions & 0 deletions crates/polars-stream/src/nodes/in_memory_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ impl ComputeNode for InMemorySourceNode {
break;
}

dbg!(&df);

let morsel_seq = MorselSeq::new(seq).offset_by(slf.seq_offset);
let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());
morsel.set_consume_token(wait_group.token());
Expand Down
141 changes: 90 additions & 51 deletions crates/polars-stream/src/nodes/merge_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use polars_core::schema::Schema;
use polars_ops::frame::_merge_sorted_dfs;
use polars_utils::pl_str::PlSmallStr;

use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::SourceToken;
use crate::nodes::compute_node_prelude::*;

Expand All @@ -14,17 +15,8 @@ enum Side {
Right,
}

/// Processing mode of the merge-sorted node, stored in its `state` field and
/// re-evaluated from the port states and the `merged`/`unmerged` buffers.
#[derive(Debug)]
enum State {
/// Merging values taken from the buffers or received from the input ports.
///
/// Chosen whenever `merged` or `unmerged` is non-empty, and also when
/// neither input port is done.
Merging,
/// Passing morsels along unchanged from one of the input ports.
///
/// Chosen only when both buffers are empty and at least one input port is
/// done.
Passing,
}

pub struct MergeSortedNode {
key_column_idx: usize,
state: State,

seq: MorselSeq,

Expand All @@ -44,7 +36,6 @@ impl MergeSortedNode {

Self {
key_column_idx,
state: State::Merging,

seq: MorselSeq::default(),

Expand All @@ -68,21 +59,20 @@ impl ComputeNode for MergeSortedNode {
if send[0] == PortState::Done {
recv[0] = PortState::Done;
recv[1] = PortState::Done;
return Ok(());
}

self.state = match (recv[0], recv[1]) {
_ if !self.merged.is_empty() || !self.unmerged.is_empty() => State::Merging,

// If one of the ports is closed and the buffers are empty, we can just start passing
// the morsels along the port.
(PortState::Done, PortState::Done) => {
send[0] = PortState::Done;
State::Passing
},
(PortState::Done, _) | (_, PortState::Done) => State::Passing,
if recv[0] == PortState::Done && recv[1] == PortState::Done {
send[0] = PortState::Done;
return Ok(())
}

_ => State::Merging,
};
if recv[0] != PortState::Done {
recv[0] = send[0];
}
if recv[1] != PortState::Done {
recv[1] = send[0];
}

Ok(())
}
Expand All @@ -104,17 +94,27 @@ impl ComputeNode for MergeSortedNode {
let unmerged = &mut self.unmerged;
let unmerged_side = &mut self.unmerged_side;

match self.state {
dbg!(&state);
dbg!(&merged);
dbg!(&unmerged);

match state {
State::Passing => {
assert!(unmerged.is_empty());
assert!(merged.is_empty());

let mut send = send_ports[0].take().unwrap().serial();

dbg!(recv_ports[0].is_some(), recv_ports[1].is_some());

match (recv_ports[0].take(), recv_ports[1].take()) {
(None, None) => {},
(Some(port), None) | (None, Some(port)) => {
// @TODO: Turn into parallel passing.
let mut recv = port.serial();
join_handles.push(scope.spawn_task(TaskPriority::High, async move {
while let Ok(morsel) = recv.recv().await {
dbg!(morsel.df());
match send.send(morsel).await {
Ok(_) => *seq = seq.successor(),
Err(m) => {
Expand All @@ -130,22 +130,70 @@ impl ComputeNode for MergeSortedNode {
(Some(_), Some(_)) => unreachable!(),
}
},
State::EmptyBuffered => {
let wait_group = WaitGroup::default();
let source_token = SourceToken::new();
let mut send = send_ports[0].take().unwrap().serial();

join_handles.push(scope.spawn_task(TaskPriority::High, async move {
wait_group.wait().await;
if source_token.stop_requested() {
return Ok(());
}

if !merged.is_empty() {
// @TODO: Break unmerged up in smaller chunks as it might be very
// big.
eprintln!("Sending merged: ({})", merged.height());
let morsel = Morsel::new(merged.clone(), *seq, source_token.clone());
if send.send(morsel).await.is_err() {
return Ok(());
}

*merged = merged.clear();
*seq = seq.successor();
}

if !unmerged.is_empty() {
eprintln!("Sending unmerged: ({})", unmerged.height());

// @TODO: Break unmerged up in smaller chunks as it might be very
// big.
if send
.send(Morsel::new(unmerged.clone(), *seq, source_token.clone()))
.await
.is_err()
{
return Ok(());
};

*unmerged = unmerged.clear();
*seq = seq.successor();
}

*state = State::Passing;
Ok(())
}));
},
State::Merging => {
let wait_group = WaitGroup::default();
let source_token = SourceToken::new();
let mut send = send_ports[0].take().unwrap().serial();

let mut left = recv_ports[0].take().map(|p| p.serial());
let mut right = recv_ports[1].take().map(|p| p.serial());
let mut left = recv_ports[0].take().unwrap().serial();
let mut right = recv_ports[1].take().unwrap().serial();

join_handles.push(scope.spawn_task(TaskPriority::Low, async move {
loop {
wait_group.wait().await;
if source_token.stop_requested() {
return Ok(());
}

if !merged.is_empty() {
// @TODO: Break unmerged up in smaller chunks as it might be very
// big.
eprintln!("Sending merged: ({})", merged.height());
let morsel = Morsel::new(merged.clone(), *seq, source_token.clone());
if send.send(morsel).await.is_err() {
return Ok(());
Expand All @@ -156,36 +204,27 @@ impl ComputeNode for MergeSortedNode {
}

let opposite_port = match *unmerged_side {
Side::Left => right.as_mut(),
Side::Right => left.as_mut(),
};

let other_unmerged = match opposite_port {
None => {
if unmerged.is_empty() {
return Ok(());
}

// @TODO: Break unmerged up in smaller chunks as it might be very
// big.
if send
.send(Morsel::new(unmerged.clone(), *seq, source_token.clone()))
.await
.is_err()
{
return Ok(());
};

*unmerged = unmerged.clear();
*seq = seq.successor();
return Ok(());
},
Some(port) => port.recv().await,
Side::Left => &mut right,
Side::Right => &mut left,
};

let Ok(other_unmerged) = other_unmerged else {
let Ok(other_unmerged) = opposite_port.recv().await else {
match *unmerged_side {
Side::Left => {
if matches!(*right_status, Status::Done) {
*right_status = Status::Flushed;
}
},
Side::Right => {
if matches!(*left_status, Status::Done) {
*left_status = Status::Flushed;
}
},
}
return Ok(());
};

dbg!(other_unmerged.df());
let other_unmerged = other_unmerged.into_df();
let taken_unmerged = std::mem::take(unmerged);
let (left_unmerged, right_unmerged) = match *unmerged_side {
Expand Down

0 comments on commit c6049f1

Please sign in to comment.