Skip to content

Commit

Permalink
Add back join hook
Browse files Browse the repository at this point in the history
  • Loading branch information
thinkharderdev authored and Dandandan committed Oct 30, 2024
1 parent 1fd6116 commit 771d5fb
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 9 deletions.
123 changes: 115 additions & 8 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ use crate::{
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};
use std::ops::{Deref, DerefMut};
use std::task::Context;

use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array,
Expand All @@ -71,11 +73,58 @@ use datafusion_physical_expr::equivalence::{
use datafusion_physical_expr::PhysicalExprRef;

use ahash::RandomState;
use arrow_buffer::BooleanBuffer;
use datafusion_expr::Operator;
use datafusion_physical_expr_common::datum::compare_op_for_nested;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;

pub struct SharedJoinState {
state_impl: Arc<dyn SharedJoinStateImpl>,
}

impl SharedJoinState {
pub fn new(state_impl: Arc<dyn SharedJoinStateImpl>) -> Self {
Self { state_impl }
}

fn num_task_partitions(&self) -> usize {
self.state_impl.num_task_partitions()
}

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>> {
self.state_impl.poll_probe_completed(mask, cx)
}

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize) {
self.state_impl.register_metrics(metrics, partition)
}
}

pub enum SharedProbeState {
// Probes are still running in other distributed tasks
Continue,
// Current task is last probe running so emit unmatched rows
// if required by join type
Ready(BooleanBuffer),
}

pub trait SharedJoinStateImpl: Send + Sync + 'static {
fn num_task_partitions(&self) -> usize;

fn poll_probe_completed(
&self,
visited_indices_bitmap: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>>;

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize);
}

type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// HashTable and input data for the left (build side) of a join
Expand All @@ -89,6 +138,7 @@ struct JoinLeftData {
/// Counter of running probe-threads, potentially
/// able to update `visited_indices_bitmap`
probe_threads_counter: AtomicUsize,
shared_state: Option<Arc<SharedJoinState>>,
/// Memory reservation that tracks memory used by `hash_map` hash table
/// `batch`. Cleared on drop.
#[allow(dead_code)]
Expand All @@ -103,12 +153,14 @@ impl JoinLeftData {
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Self {
Self {
hash_map,
batch,
visited_indices_bitmap,
probe_threads_counter,
shared_state: distributed_state,
reservation,
}
}
Expand All @@ -127,14 +179,34 @@ impl JoinLeftData {
fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder {
&self.visited_indices_bitmap
}

/// Decrements the counter of running threads, and returns `true`
/// if caller is the last running thread
fn report_probe_completed(&self) -> bool {
self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
self.probe_threads_counter.load(Ordering::Relaxed) == 0
|| self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
}
}

fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
if m1.len() != m2.len() {
return Err(DataFusionError::Execution(format!(
"local and shared indices bitmaps have different lengths: {} and {}",
m1.len(),
m2.len()
)));
}

for (b1, b2) in m1
.as_slice_mut()
.iter_mut()
.zip(m2.inner().as_slice().iter().copied())
{
*b1 |= b2;
}

Ok(())
}

/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple
/// partitions using a hash table and an optional filter list to apply post
/// join.
Expand Down Expand Up @@ -694,11 +766,25 @@ impl ExecutionPlan for HashJoinExec {
);
}

let distributed_state =
context.session_config().get_extension::<SharedJoinState>();

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
PartitionMode::CollectLeft => self.left_fut.once(|| {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

let probe_threads = distributed_state
.as_ref()
.map(|s| {
s.register_metrics(&self.metrics, partition);
s.num_task_partitions()
})
.unwrap_or_else(|| {
self.right().output_partitioning().partition_count()
});

collect_left_input(
None,
self.random_state.clone(),
Expand All @@ -708,7 +794,8 @@ impl ExecutionPlan for HashJoinExec {
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
probe_threads,
distributed_state,
)
}),
PartitionMode::Partitioned => {
Expand All @@ -726,6 +813,7 @@ impl ExecutionPlan for HashJoinExec {
reservation,
need_produce_result_in_final(self.join_type),
1,
None,
))
}
PartitionMode::Auto => {
Expand Down Expand Up @@ -812,6 +900,7 @@ async fn collect_left_input(
reservation: MemoryReservation,
with_visited_indices_bitmap: bool,
probe_threads_count: usize,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Result<JoinLeftData> {
let schema = left.schema();

Expand Down Expand Up @@ -899,6 +988,7 @@ async fn collect_left_input(
Mutex::new(visited_indices_bitmap),
AtomicUsize::new(probe_threads_count),
reservation,
distributed_state,
);

Ok(data)
Expand Down Expand Up @@ -1286,7 +1376,7 @@ impl HashJoinStream {
handle_state!(self.process_probe_batch())
}
HashJoinStreamState::ExhaustedProbeSide => {
handle_state!(self.process_unmatched_build_batch())
handle_state!(ready!(self.process_unmatched_build_batch(cx)))
}
HashJoinStreamState::Completed => Poll::Ready(None),
};
Expand Down Expand Up @@ -1472,18 +1562,35 @@ impl HashJoinStream {
/// Updates state to `Completed`
fn process_unmatched_build_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
cx: &mut Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let timer = self.join_metrics.join_time.timer();

if !need_produce_result_in_final(self.join_type) {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

let build_side = self.build_side.try_as_ready()?;
if !build_side.left_data.report_probe_completed() {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

if let Some(shared_state) = build_side.left_data.shared_state.as_ref() {
let mut guard = build_side.left_data.visited_indices_bitmap().lock();
match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) {
Ok(SharedProbeState::Continue) => {
self.state = HashJoinStreamState::Completed;
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
Ok(SharedProbeState::Ready(shared_mask)) => {
if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) {
return Poll::Ready(Err(e));
}
}
Err(err) => return Poll::Ready(Err(err)),
}
}

// use the global left bitmap to produce the left indices and right indices
Expand Down Expand Up @@ -1514,7 +1621,7 @@ impl HashJoinStream {

self.state = HashJoinStreamState::Completed;

Ok(StatefulStreamResult::Ready(Some(result?)))
Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result?))))
}
}

Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! DataFusion Join implementations

pub use cross_join::CrossJoinExec;
pub use hash_join::HashJoinExec;
pub use hash_join::{
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
};
pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
pub use sort_merge_join::SortMergeJoinExec;
Expand Down

0 comments on commit 771d5fb

Please sign in to comment.