From 771d5fb59cda1fe2211925002b78f5b86d3a8944 Mon Sep 17 00:00:00 2001 From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> Date: Fri, 20 Sep 2024 07:06:48 -0400 Subject: [PATCH] Add back join hook --- .../physical-plan/src/joins/hash_join.rs | 123 ++++++++++++++++-- datafusion/physical-plan/src/joins/mod.rs | 4 +- 2 files changed, 118 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 3b730c01291c..1dfc543ebb69 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -47,6 +47,8 @@ use crate::{ Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use std::ops::{Deref, DerefMut}; +use std::task::Context; use arrow::array::{ Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array, @@ -71,11 +73,58 @@ use datafusion_physical_expr::equivalence::{ use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; +use arrow_buffer::BooleanBuffer; use datafusion_expr::Operator; use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; +pub struct SharedJoinState { + state_impl: Arc, +} + +impl SharedJoinState { + pub fn new(state_impl: Arc) -> Self { + Self { state_impl } + } + + fn num_task_partitions(&self) -> usize { + self.state_impl.num_task_partitions() + } + + fn poll_probe_completed( + &self, + mask: &BooleanBufferBuilder, + cx: &mut Context<'_>, + ) -> Poll> { + self.state_impl.poll_probe_completed(mask, cx) + } + + fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize) { + self.state_impl.register_metrics(metrics, partition) + } +} + +pub enum SharedProbeState { + // Probes are still running in other distributed tasks + Continue, + // Current task is last probe running so emit unmatched rows + // if required by join type + Ready(BooleanBuffer), +} + +pub trait SharedJoinStateImpl: Send + Sync + 'static { + fn num_task_partitions(&self) -> usize; + + fn poll_probe_completed( + &self, + visited_indices_bitmap: &BooleanBufferBuilder, + cx: &mut Context<'_>, + ) -> Poll>; + + fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize); +} + type SharedBitmapBuilder = Mutex; /// HashTable and input data for the left (build side) of a join @@ -89,6 +138,7 @@ struct JoinLeftData { /// Counter of running probe-threads, potentially /// able to update `visited_indices_bitmap` probe_threads_counter: AtomicUsize, + shared_state: Option>, /// Memory reservation that tracks memory used by `hash_map` hash table /// `batch`. Cleared on drop. #[allow(dead_code)] @@ -103,12 +153,14 @@ impl JoinLeftData { visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, reservation: MemoryReservation, + distributed_state: Option>, ) -> Self { Self { hash_map, batch, visited_indices_bitmap, probe_threads_counter, + shared_state: distributed_state, reservation, } } @@ -127,14 +179,34 @@ impl JoinLeftData { fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { &self.visited_indices_bitmap } - /// Decrements the counter of running threads, and returns `true` /// if caller is the last running thread fn report_probe_completed(&self) -> bool { - self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 + self.probe_threads_counter.load(Ordering::Relaxed) == 0 + || self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 } } +fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> { + if m1.len() != m2.len() { + return Err(DataFusionError::Execution(format!( + "local and shared indices bitmaps have different lengths: {} and {}", + m1.len(), + m2.len() + ))); + } + + for (b1, b2) in m1 + .as_slice_mut() + .iter_mut() + .zip(m2.inner().as_slice().iter().copied()) + { + *b1 |= b2; + } + + Ok(()) +} + /// Join execution plan: Evaluates eqijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post /// join. @@ -694,11 +766,25 @@ impl ExecutionPlan for HashJoinExec { ); } + let distributed_state = + context.session_config().get_extension::(); + let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let left_fut = match self.mode { PartitionMode::CollectLeft => self.left_fut.once(|| { let reservation = MemoryConsumer::new("HashJoinInput").register(context.memory_pool()); + + let probe_threads = distributed_state + .as_ref() + .map(|s| { + s.register_metrics(&self.metrics, partition); + s.num_task_partitions() + }) + .unwrap_or_else(|| { + self.right().output_partitioning().partition_count() + }); + collect_left_input( None, self.random_state.clone(), @@ -708,7 +794,8 @@ impl ExecutionPlan for HashJoinExec { join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), - self.right().output_partitioning().partition_count(), + probe_threads, + distributed_state, ) }), PartitionMode::Partitioned => { @@ -726,6 +813,7 @@ impl ExecutionPlan for HashJoinExec { reservation, need_produce_result_in_final(self.join_type), 1, + None, )) } PartitionMode::Auto => { @@ -812,6 +900,7 @@ async fn collect_left_input( reservation: MemoryReservation, with_visited_indices_bitmap: bool, probe_threads_count: usize, + distributed_state: Option>, ) -> Result { let schema = left.schema(); @@ -899,6 +988,7 @@ async fn collect_left_input( Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), reservation, + distributed_state, ); Ok(data) @@ -1286,7 +1376,7 @@ impl HashJoinStream { handle_state!(self.process_probe_batch()) } HashJoinStreamState::ExhaustedProbeSide => { - handle_state!(self.process_unmatched_build_batch()) + handle_state!(ready!(self.process_unmatched_build_batch(cx))) } HashJoinStreamState::Completed => Poll::Ready(None), }; @@ -1472,18 +1562,35 @@ impl HashJoinStream { /// Updates state to `Completed` fn process_unmatched_build_batch( &mut self, - ) -> Result>> { + cx: &mut Context<'_>, + ) -> Poll>>> { let timer = self.join_metrics.join_time.timer(); if !need_produce_result_in_final(self.join_type) { self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); } let build_side = self.build_side.try_as_ready()?; if !build_side.left_data.report_probe_completed() { self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if let Some(shared_state) = build_side.left_data.shared_state.as_ref() { + let mut guard = build_side.left_data.visited_indices_bitmap().lock(); + match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) { + Ok(SharedProbeState::Continue) => { + self.state = HashJoinStreamState::Completed; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Ok(SharedProbeState::Ready(shared_mask)) => { + if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) { + return Poll::Ready(Err(e)); + } + } + Err(err) => return Poll::Ready(Err(err)), + } } // use the global left bitmap to produce the left indices and right indices @@ -1514,7 +1621,7 @@ impl HashJoinStream { self.state = HashJoinStreamState::Completed; - Ok(StatefulStreamResult::Ready(Some(result?))) + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result?)))) } } diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 6ddf19c51193..221f664f0e34 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -18,7 +18,9 @@ //! DataFusion Join implementations pub use cross_join::CrossJoinExec; -pub use hash_join::HashJoinExec; +pub use hash_join::{ + HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState, +}; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec;