Skip to content

Commit

Permalink
Hook for doing distributed CollectLeft joins (#269)
Browse files Browse the repository at this point in the history
* Add state for distributed joins

* Make DistributedJoinState public

* Implement shared bitmap hook

* make trait public

* remove unused method

* make enum public

* Get probe threads from distributed state

* Avoid issue with probe threads counter wrapping

* Use naming conventions from upstream PR

* Add hook to register metrics

* Add partition to register_metrics
  • Loading branch information
thinkharderdev authored Sep 20, 2024
1 parent f490e9e commit 534b4ac
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 15 deletions.
133 changes: 119 additions & 14 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@

//! [`HashJoinExec`] Partitioned Hash Join Operator
use std::fmt;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
use std::{any::Any, vec};

use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
Expand All @@ -46,6 +40,12 @@ use crate::{
Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream,
Statistics,
};
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::{any::Any, vec};

use arrow::array::{
Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
Expand All @@ -72,9 +72,56 @@ use datafusion_physical_expr::expressions::UnKnownColumn;
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use ahash::RandomState;
use arrow_buffer::BooleanBuffer;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use parking_lot::Mutex;

pub struct SharedJoinState {
state_impl: Arc<dyn SharedJoinStateImpl>,
}

impl SharedJoinState {
pub fn new(state_impl: Arc<dyn SharedJoinStateImpl>) -> Self {
Self { state_impl }
}

fn num_task_partitions(&self) -> usize {
self.state_impl.num_task_partitions()
}

fn poll_probe_completed(
&self,
mask: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>> {
self.state_impl.poll_probe_completed(mask, cx)
}

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize) {
self.state_impl.register_metrics(metrics, partition)
}
}

pub enum SharedProbeState {
// Probes are still running in other distributed tasks
Continue,
// Current task is last probe running so emit unmatched rows
// if required by join type
Ready(BooleanBuffer),
}

pub trait SharedJoinStateImpl: Send + Sync + 'static {
fn num_task_partitions(&self) -> usize;

fn poll_probe_completed(
&self,
visited_indices_bitmap: &BooleanBufferBuilder,
cx: &mut Context<'_>,
) -> Poll<Result<SharedProbeState>>;

fn register_metrics(&self, metrics: &ExecutionPlanMetricsSet, partition: usize);
}

type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// HashTable and input data for the left (build side) of a join
Expand All @@ -88,6 +135,7 @@ struct JoinLeftData {
/// Counter of running probe-threads, potentially
/// able to update `visited_indices_bitmap`
probe_threads_counter: AtomicUsize,
shared_state: Option<Arc<SharedJoinState>>,
/// Memory reservation that tracks memory used by `hash_map` hash table
/// `batch`. Cleared on drop.
#[allow(dead_code)]
Expand All @@ -102,12 +150,14 @@ impl JoinLeftData {
visited_indices_bitmap: SharedBitmapBuilder,
probe_threads_counter: AtomicUsize,
reservation: MemoryReservation,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Self {
Self {
hash_map,
batch,
visited_indices_bitmap,
probe_threads_counter,
shared_state: distributed_state,
reservation,
}
}
Expand All @@ -126,14 +176,34 @@ impl JoinLeftData {
fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder {
&self.visited_indices_bitmap
}

/// Decrements the counter of running threads, and returns `true`
/// if caller is the last running thread
fn report_probe_completed(&self) -> bool {
self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
self.probe_threads_counter.load(Ordering::Relaxed) == 0
|| self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1
}
}

fn merge_bitmap(m1: &mut BooleanBufferBuilder, m2: BooleanBuffer) -> Result<()> {
if m1.len() != m2.len() {
return Err(DataFusionError::Execution(format!(
"local and shared indices bitmaps have different lengths: {} and {}",
m1.len(),
m2.len()
)));
}

for (b1, b2) in m1
.as_slice_mut()
.iter_mut()
.zip(m2.inner().as_slice().iter().copied())
{
*b1 |= b2;
}

Ok(())
}

/// Join execution plan: Evaluates eqijoin predicates in parallel on multiple
/// partitions using a hash table and an optional filter list to apply post
/// join.
Expand Down Expand Up @@ -721,11 +791,25 @@ impl ExecutionPlan for HashJoinExec {
);
}

let distributed_state =
context.session_config().get_extension::<SharedJoinState>();

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
PartitionMode::CollectLeft => self.left_fut.once(|| {
let reservation =
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());

let probe_threads = distributed_state
.as_ref()
.map(|s| {
s.register_metrics(&self.metrics, partition);
s.num_task_partitions()
})
.unwrap_or_else(|| {
self.right().output_partitioning().partition_count()
});

collect_left_input(
None,
self.random_state.clone(),
Expand All @@ -735,7 +819,8 @@ impl ExecutionPlan for HashJoinExec {
join_metrics.clone(),
reservation,
need_produce_result_in_final(self.join_type),
self.right().output_partitioning().partition_count(),
probe_threads,
distributed_state,
)
}),
PartitionMode::Partitioned => {
Expand All @@ -753,6 +838,7 @@ impl ExecutionPlan for HashJoinExec {
reservation,
need_produce_result_in_final(self.join_type),
1,
None,
))
}
PartitionMode::Auto => {
Expand Down Expand Up @@ -838,6 +924,7 @@ async fn collect_left_input(
reservation: MemoryReservation,
with_visited_indices_bitmap: bool,
probe_threads_count: usize,
distributed_state: Option<Arc<SharedJoinState>>,
) -> Result<JoinLeftData> {
let schema = left.schema();

Expand Down Expand Up @@ -925,6 +1012,7 @@ async fn collect_left_input(
Mutex::new(visited_indices_bitmap),
AtomicUsize::new(probe_threads_count),
reservation,
distributed_state,
);

Ok(data)
Expand Down Expand Up @@ -1301,7 +1389,7 @@ impl HashJoinStream {
handle_state!(self.process_probe_batch())
}
HashJoinStreamState::ExhaustedProbeSide => {
handle_state!(self.process_unmatched_build_batch())
handle_state!(ready!(self.process_unmatched_build_batch(cx)))
}
HashJoinStreamState::Completed => Poll::Ready(None),
};
Expand Down Expand Up @@ -1486,18 +1574,35 @@ impl HashJoinStream {
/// Updates state to `Completed`
fn process_unmatched_build_batch(
&mut self,
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
cx: &mut Context<'_>,
) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let timer = self.join_metrics.join_time.timer();

if !need_produce_result_in_final(self.join_type) {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

let build_side = self.build_side.try_as_ready()?;
if !build_side.left_data.report_probe_completed() {
self.state = HashJoinStreamState::Completed;
return Ok(StatefulStreamResult::Continue);
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}

if let Some(shared_state) = build_side.left_data.shared_state.as_ref() {
let mut guard = build_side.left_data.visited_indices_bitmap().lock();
match ready!(shared_state.poll_probe_completed(guard.deref(), cx)) {
Ok(SharedProbeState::Continue) => {
self.state = HashJoinStreamState::Completed;
return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
Ok(SharedProbeState::Ready(shared_mask)) => {
if let Err(e) = merge_bitmap(guard.deref_mut(), shared_mask) {
return Poll::Ready(Err(e));
}
}
Err(err) => return Poll::Ready(Err(err)),
}
}

// use the global left bitmap to produce the left indices and right indices
Expand Down Expand Up @@ -1528,7 +1633,7 @@ impl HashJoinStream {

self.state = HashJoinStreamState::Completed;

Ok(StatefulStreamResult::Ready(Some(result?)))
Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result?))))
}
}

Expand Down
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
//! DataFusion Join implementations
pub use cross_join::CrossJoinExec;
pub use hash_join::HashJoinExec;
pub use hash_join::{
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
};
pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
pub use sort_merge_join::SortMergeJoinExec;
Expand Down

0 comments on commit 534b4ac

Please sign in to comment.