This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a71a76a996 refactor: `HashJoinStream` state machine (#8538)
a71a76a996 is described below
commit a71a76a996a32a0f068370940ebe475ec237b4ff
Author: Eduard Karacharov <[email protected]>
AuthorDate: Mon Dec 18 11:53:26 2023 +0200
refactor: `HashJoinStream` state machine (#8538)
* hash join state machine
* StreamJoinStateResult to StatefulStreamResult
* doc comments & naming & fmt
* suggestions from code review
Co-authored-by: Andrew Lamb <[email protected]>
* more review comments addressed
* post-merge fixes
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/physical-plan/src/joins/hash_join.rs | 431 ++++++++++++++-------
.../physical-plan/src/joins/stream_join_utils.rs | 127 ++----
.../physical-plan/src/joins/symmetric_hash_join.rs | 25 +-
datafusion/physical-plan/src/joins/utils.rs | 83 ++++
4 files changed, 420 insertions(+), 246 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 4846d0a5e0..13ac06ee30 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -28,7 +28,6 @@ use crate::joins::utils::{
calculate_join_output_ordering, get_final_indices_from_bit_map,
need_produce_result_in_final, JoinHashMap, JoinHashMapType,
};
-use crate::DisplayAs;
use crate::{
coalesce_batches::concat_batches,
coalesce_partitions::CoalescePartitionsExec,
@@ -38,12 +37,13 @@ use crate::{
joins::utils::{
adjust_right_output_partitioning, build_join_schema,
check_join_is_valid,
estimate_join_statistics, partitioned_join_output_partitioning,
- BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn,
+ BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, StatefulStreamResult,
},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
+use crate::{handle_state, DisplayAs};
use super::{
utils::{OnceAsync, OnceFut},
@@ -618,15 +618,14 @@ impl ExecutionPlan for HashJoinExec {
on_right,
filter: self.filter.clone(),
join_type: self.join_type,
- left_fut,
- visited_left_side: None,
right: right_stream,
column_indices: self.column_indices.clone(),
random_state: self.random_state.clone(),
join_metrics,
null_equals_null: self.null_equals_null,
- is_exhausted: false,
reservation,
+ state: HashJoinStreamState::WaitBuildSide,
+ build_side: BuildSide::Initial(BuildSideInitialState { left_fut }),
}))
}
@@ -789,6 +788,104 @@ where
Ok(())
}
+/// Represents build-side of hash join.
+enum BuildSide {
+ /// Indicates that build-side not collected yet
+ Initial(BuildSideInitialState),
+ /// Indicates that build-side data has been collected
+ Ready(BuildSideReadyState),
+}
+
+/// Container for BuildSide::Initial related data
+struct BuildSideInitialState {
+ /// Future for building hash table from build-side input
+ left_fut: OnceFut<JoinLeftData>,
+}
+
+/// Container for BuildSide::Ready related data
+struct BuildSideReadyState {
+ /// Collected build-side data
+ left_data: Arc<JoinLeftData>,
+ /// Which build-side rows have been matched while creating output.
+ /// For some OUTER joins, we need to know which rows have not been matched
+ /// to produce the correct output.
+ visited_left_side: BooleanBufferBuilder,
+}
+
+impl BuildSide {
+ /// Tries to extract BuildSideInitialState from BuildSide enum.
+ /// Returns an error if state is not Initial.
+ fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
+ match self {
+ BuildSide::Initial(state) => Ok(state),
+ _ => internal_err!("Expected build side in initial state"),
+ }
+ }
+
+ /// Tries to extract BuildSideReadyState from BuildSide enum.
+ /// Returns an error if state is not Ready.
+ fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
+ match self {
+ BuildSide::Ready(state) => Ok(state),
+ _ => internal_err!("Expected build side in ready state"),
+ }
+ }
+
+ /// Tries to extract BuildSideReadyState from BuildSide enum.
+ /// Returns an error if state is not Ready.
+ fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
+ match self {
+ BuildSide::Ready(state) => Ok(state),
+ _ => internal_err!("Expected build side in ready state"),
+ }
+ }
+}
+
+/// Represents state of HashJoinStream
+///
+/// Expected state transitions performed by HashJoinStream are:
+///
+/// ```text
+///
+/// WaitBuildSide
+/// │
+/// ▼
+/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed
+/// │ │
+/// │ ▼
+/// └─ ProcessProbeBatch
+///
+/// ```
+enum HashJoinStreamState {
+ /// Initial state for HashJoinStream indicating that build-side data not collected yet
+ WaitBuildSide,
+ /// Indicates that build-side has been collected, and stream is ready for fetching probe-side
+ FetchProbeBatch,
+ /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed
+ ProcessProbeBatch(ProcessProbeBatchState),
+ /// Indicates that probe-side has been fully processed
+ ExhaustedProbeSide,
+ /// Indicates that HashJoinStream execution is completed
+ Completed,
+}
+
+/// Container for HashJoinStreamState::ProcessProbeBatch related data
+struct ProcessProbeBatchState {
+ /// Current probe-side batch
+ batch: RecordBatch,
+}
+
+impl HashJoinStreamState {
+ /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum.
+ /// Returns an error if state is not ProcessProbeBatchState.
+ fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> {
+ match self {
+ HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
+ _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
+ }
+ }
+}
+
/// [`Stream`] for [`HashJoinExec`] that does the actual join.
///
/// This stream:
@@ -808,20 +905,10 @@ struct HashJoinStream {
filter: Option<JoinFilter>,
/// type of the join (left, right, semi, etc)
join_type: JoinType,
- /// future which builds hash table from left side
- left_fut: OnceFut<JoinLeftData>,
- /// Which left (probe) side rows have been matches while creating output.
- /// For some OUTER joins, we need to know which rows have not been matched
- /// to produce the correct output.
- visited_left_side: Option<BooleanBufferBuilder>,
/// right (probe) input
right: SendableRecordBatchStream,
/// Random state used for hashing initialization
random_state: RandomState,
- /// The join output is complete. For outer joins, this is used to
- /// distinguish when the input stream is exhausted and when any unmatched
- /// rows are output.
- is_exhausted: bool,
/// Metrics
join_metrics: BuildProbeJoinMetrics,
/// Information of index and left / right placement of columns
@@ -830,6 +917,10 @@ struct HashJoinStream {
null_equals_null: bool,
/// Memory reservation
reservation: MemoryReservation,
+ /// State of the stream
+ state: HashJoinStreamState,
+ /// Build side
+ build_side: BuildSide,
}
impl RecordBatchStream for HashJoinStream {
@@ -1069,19 +1160,44 @@ impl HashJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ HashJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ HashJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ HashJoinStreamState::ProcessProbeBatch(_) => {
+ handle_state!(self.process_probe_batch())
+ }
+ HashJoinStreamState::ExhaustedProbeSide => {
+ handle_state!(self.process_unmatched_build_batch())
+ }
+ HashJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ /// Collects build-side data by polling `OnceFut` future from initialized build-side
+ ///
+ /// Updates build-side to `Ready`, and state to `FetchProbeSide`
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
// build hash table from left (build) side, if not yet done
- let left_data = match ready!(self.left_fut.get(cx)) {
- Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
- };
+ let left_data = ready!(self
+ .build_side
+ .try_as_initial_mut()?
+ .left_fut
+ .get_shared(cx))?;
build_timer.done();
// Reserving memory for visited_left_side bitmap in case it hasn't been initialized yet
// and join_type requires to store it
- if self.visited_left_side.is_none()
- && need_produce_result_in_final(self.join_type)
- {
+ if need_produce_result_in_final(self.join_type) {
// TODO: Replace `ceil` wrapper with stable `div_cell` after
// https://github.com/rust-lang/rust/issues/88581
let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8);
@@ -1089,124 +1205,167 @@ impl HashJoinStream {
self.join_metrics.build_mem_used.add(visited_bitmap_size);
}
- let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+ let visited_left_side = if need_produce_result_in_final(self.join_type) {
let num_rows = left_data.num_rows();
- if need_produce_result_in_final(self.join_type) {
- // Some join types need to track which row has be matched or unmatched:
- // `left semi` join: need to use the bitmap to produce the matched row in the left side
- // `left` join: need to use the bitmap to produce the unmatched row in the left side with null
- // `left anti` join: need to use the bitmap to produce the unmatched row in the left side
- // `full` join: need to use the bitmap to produce the unmatched row in the left side with null
- let mut buffer = BooleanBufferBuilder::new(num_rows);
- buffer.append_n(num_rows, false);
- buffer
- } else {
- BooleanBufferBuilder::new(0)
- }
+ // Some join types need to track which row has be matched or unmatched:
+ // `left semi` join: need to use the bitmap to produce the matched row in the left side
+ // `left` join: need to use the bitmap to produce the unmatched row in the left side with null
+ // `left anti` join: need to use the bitmap to produce the unmatched row in the left side
+ // `full` join: need to use the bitmap to produce the unmatched row in the left side with null
+ let mut buffer = BooleanBufferBuilder::new(num_rows);
+ buffer.append_n(num_rows, false);
+ buffer
+ } else {
+ BooleanBufferBuilder::new(0)
+ };
+
+ self.state = HashJoinStreamState::FetchProbeBatch;
+ self.build_side = BuildSide::Ready(BuildSideReadyState {
+ left_data,
+ visited_left_side,
});
+
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+
+ /// Fetches next batch from probe-side
+ ///
+ /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`,
+ /// otherwise updates state to `ExhaustedProbeSide`
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ match ready!(self.right.poll_next_unpin(cx)) {
+ None => {
+ self.state = HashJoinStreamState::ExhaustedProbeSide;
+ }
+ Some(Ok(batch)) => {
+ self.state =
+ HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
+ batch,
+ });
+ }
+ Some(Err(err)) => return Poll::Ready(Err(err)),
+ };
+
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+
+ /// Joins current probe batch with build-side data and produces batch with matched output
+ ///
+ /// Updates state to `FetchProbeBatch`
+ fn process_probe_batch(
+ &mut self,
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let state = self.state.try_as_process_probe_batch()?;
+ let build_side = self.build_side.try_as_ready_mut()?;
+
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(state.batch.num_rows());
+ let timer = self.join_metrics.join_time.timer();
+
let mut hashes_buffer = vec![];
- // get next right (probe) input batch
- self.right
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- // one right batch in the join loop
- Some(Ok(batch)) => {
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- let timer = self.join_metrics.join_time.timer();
-
- // get the matched two indices for the on condition
- let left_right_indices = build_equal_condition_join_indices(
- left_data.hash_map(),
- left_data.batch(),
- &batch,
- &self.on_left,
- &self.on_right,
- &self.random_state,
- self.null_equals_null,
- &mut hashes_buffer,
- self.filter.as_ref(),
- JoinSide::Left,
- None,
- );
-
- let result = match left_right_indices {
- Ok((left_side, right_side)) => {
- // set the left bitmap
- // and only left, full, left semi, left anti need the left bitmap
- if need_produce_result_in_final(self.join_type) {
- left_side.iter().flatten().for_each(|x| {
- visited_left_side.set_bit(x as usize, true);
- });
- }
-
- // adjust the two side indices base on the join type
- let (left_side, right_side) = adjust_indices_by_join_type(
- left_side,
- right_side,
- batch.num_rows(),
- self.join_type,
- );
-
- let result = build_batch_from_indices(
- &self.schema,
- left_data.batch(),
- &batch,
- &left_side,
- &right_side,
- &self.column_indices,
- JoinSide::Left,
- );
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- Some(result)
- }
- Err(err) => Some(exec_err!(
- "Fail to build join indices in HashJoinExec, error:{err}"
- )),
- };
- timer.done();
- result
- }
- None => {
- let timer = self.join_metrics.join_time.timer();
- if need_produce_result_in_final(self.join_type) && !self.is_exhausted
- {
- // use the global left bitmap to produce the left indices and right indices
- let (left_side, right_side) = get_final_indices_from_bit_map(
- visited_left_side,
- self.join_type,
- );
- let empty_right_batch =
- RecordBatch::new_empty(self.right.schema());
- // use the left and right indices to produce the batch result
- let result = build_batch_from_indices(
- &self.schema,
- left_data.batch(),
- &empty_right_batch,
- &left_side,
- &right_side,
- &self.column_indices,
- JoinSide::Left,
- );
-
- if let Ok(ref batch) = result {
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
-
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
- timer.done();
- self.is_exhausted = true;
- Some(result)
- } else {
- // end of the join loop
- None
- }
+ // get the matched two indices for the on condition
+ let left_right_indices = build_equal_condition_join_indices(
+ build_side.left_data.hash_map(),
+ build_side.left_data.batch(),
+ &state.batch,
+ &self.on_left,
+ &self.on_right,
+ &self.random_state,
+ self.null_equals_null,
+ &mut hashes_buffer,
+ self.filter.as_ref(),
+ JoinSide::Left,
+ None,
+ );
+
+ let result = match left_right_indices {
+ Ok((left_side, right_side)) => {
+ // set the left bitmap
+ // and only left, full, left semi, left anti need the left bitmap
+ if need_produce_result_in_final(self.join_type) {
+ left_side.iter().flatten().for_each(|x| {
+ build_side.visited_left_side.set_bit(x as usize, true);
+ });
}
- Some(err) => Some(err),
- })
+
+ // adjust the two side indices base on the join type
+ let (left_side, right_side) = adjust_indices_by_join_type(
+ left_side,
+ right_side,
+ state.batch.num_rows(),
+ self.join_type,
+ );
+
+ let result = build_batch_from_indices(
+ &self.schema,
+ build_side.left_data.batch(),
+ &state.batch,
+ &left_side,
+ &right_side,
+ &self.column_indices,
+ JoinSide::Left,
+ );
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(state.batch.num_rows());
+ result
+ }
+ Err(err) => {
+ exec_err!("Fail to build join indices in HashJoinExec, error:{err}")
+ }
+ };
+ timer.done();
+
+ self.state = HashJoinStreamState::FetchProbeBatch;
+
+ Ok(StatefulStreamResult::Ready(Some(result?)))
+ }
+
+ /// Processes unmatched build-side rows for certain join types and produces output batch
+ ///
+ /// Updates state to `Completed`
+ fn process_unmatched_build_batch(
+ &mut self,
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let timer = self.join_metrics.join_time.timer();
+
+ if !need_produce_result_in_final(self.join_type) {
+ self.state = HashJoinStreamState::Completed;
+
+ return Ok(StatefulStreamResult::Continue);
+ }
+
+ let build_side = self.build_side.try_as_ready()?;
+
+ // use the global left bitmap to produce the left indices and right indices
+ let (left_side, right_side) =
+ get_final_indices_from_bit_map(&build_side.visited_left_side, self.join_type);
+ let empty_right_batch = RecordBatch::new_empty(self.right.schema());
+ // use the left and right indices to produce the batch result
+ let result = build_batch_from_indices(
+ &self.schema,
+ build_side.left_data.batch(),
+ &empty_right_batch,
+ &left_side,
+ &right_side,
+ &self.column_indices,
+ JoinSide::Left,
+ );
+
+ if let Ok(ref batch) = result {
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(batch.num_rows());
+
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ }
+ timer.done();
+
+ self.state = HashJoinStreamState::Completed;
+
+ Ok(StatefulStreamResult::Ready(Some(result?)))
}
}
diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs
index 2f74bd1c4b..64a976a1e3 100644
--- a/datafusion/physical-plan/src/joins/stream_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -23,9 +23,9 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::usize;
-use crate::joins::utils::{JoinFilter, JoinHashMapType};
+use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult};
use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
-use crate::{handle_async_state, metrics};
+use crate::{handle_async_state, handle_state, metrics};
use arrow::compute::concat_batches;
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch};
@@ -624,73 +624,6 @@ pub fn record_visited_indices<T: ArrowPrimitiveType>(
}
}
-/// The `handle_state` macro is designed to process the result of a state-changing
-/// operation, typically encountered in implementations of `EagerJoinStream`. It
-/// operates on a `StreamJoinStateResult` by matching its variants and executing
-/// corresponding actions. This macro is used to streamline code that deals with
-/// state transitions, reducing boilerplate and improving readability.
-///
-/// # Cases
-///
-/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the
-/// stream join operation should proceed to the next step.
-/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with the
-/// result, either yielding a value or indicating the stream is awaiting more
-/// data.
-/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
-/// during the stream join operation.
-///
-/// # Arguments
-///
-/// * `$match_case`: An expression that evaluates to a `Result<StreamJoinStateResult<_>>`.
-#[macro_export]
-macro_rules! handle_state {
- ($match_case:expr) => {
- match $match_case {
- Ok(StreamJoinStateResult::Continue) => continue,
- Ok(StreamJoinStateResult::Ready(result)) => {
- Poll::Ready(Ok(result).transpose())
- }
- Err(e) => Poll::Ready(Some(Err(e))),
- }
- };
-}
-
-/// The `handle_async_state` macro adapts the `handle_state` macro for use in
-/// asynchronous operations, particularly when dealing with `Poll` results within
-/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing
-/// function using `poll_unpin` and then passes the result to `handle_state` for
-/// further processing.
-///
-/// # Arguments
-///
-/// * `$state_func`: An async function or future that returns a
-/// `Result<StreamJoinStateResult<_>>`.
-/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`.
-///
-#[macro_export]
-macro_rules! handle_async_state {
- ($state_func:expr, $cx:expr) => {
- $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
- };
-}
-
-/// Represents the result of a stateful operation on `EagerJoinStream`.
-///
-/// This enumueration indicates whether the state produced a result that is
-/// ready for use (`Ready`) or if the operation requires continuation (`Continue`).
-///
-/// Variants:
-/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`.
-/// - `Continue`: Indicates that the operation is not yet complete and requires further
-/// processing or more data. When this variant is returned, it typically means that the
-/// current invocation of the state did not produce a final result, and the operation
-/// should be invoked again later with more data and possibly with a different state.
-pub enum StreamJoinStateResult<T> {
- Ready(T),
- Continue,
-}
-
/// Represents the various states of an eager join stream operation.
///
/// This enum is used to track the current state of streaming during a join
@@ -819,14 +752,14 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
async fn fetch_next_from_right_stream(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.right_stream().next().await {
Some(Ok(batch)) => {
if batch.num_rows() == 0 {
- return Ok(StreamJoinStateResult::Continue);
+ return Ok(StatefulStreamResult::Continue);
}
self.set_state(EagerJoinStreamState::PullLeft);
@@ -835,7 +768,7 @@ pub trait EagerJoinStream {
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::RightExhausted);
- Ok(StreamJoinStateResult::Continue)
+ Ok(StatefulStreamResult::Continue)
}
}
}
@@ -848,14 +781,14 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after pulling the batch.
async fn fetch_next_from_left_stream(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.left_stream().next().await {
Some(Ok(batch)) => {
if batch.num_rows() == 0 {
- return Ok(StreamJoinStateResult::Continue);
+ return Ok(StatefulStreamResult::Continue);
}
self.set_state(EagerJoinStreamState::PullRight);
self.process_batch_from_left(batch)
@@ -863,7 +796,7 @@ pub trait EagerJoinStream {
Some(Err(e)) => Err(e),
None => {
self.set_state(EagerJoinStreamState::LeftExhausted);
- Ok(StreamJoinStateResult::Continue)
+ Ok(StatefulStreamResult::Continue)
}
}
}
@@ -877,14 +810,14 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
async fn handle_right_stream_end(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.left_stream().next().await {
Some(Ok(batch)) => {
if batch.num_rows() == 0 {
- return Ok(StreamJoinStateResult::Continue);
+ return Ok(StatefulStreamResult::Continue);
}
self.process_batch_after_right_end(batch)
}
@@ -893,7 +826,7 @@ pub trait EagerJoinStream {
self.set_state(EagerJoinStreamState::BothExhausted {
final_result: false,
});
- Ok(StreamJoinStateResult::Continue)
+ Ok(StatefulStreamResult::Continue)
}
}
}
@@ -907,14 +840,14 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
async fn handle_left_stream_end(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
match self.right_stream().next().await {
Some(Ok(batch)) => {
if batch.num_rows() == 0 {
- return Ok(StreamJoinStateResult::Continue);
+ return Ok(StatefulStreamResult::Continue);
}
self.process_batch_after_left_end(batch)
}
@@ -923,7 +856,7 @@ pub trait EagerJoinStream {
self.set_state(EagerJoinStreamState::BothExhausted {
final_result: false,
});
- Ok(StreamJoinStateResult::Continue)
+ Ok(StatefulStreamResult::Continue)
}
}
}
@@ -936,10 +869,10 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
fn prepare_for_final_results_after_exhaustion(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.set_state(EagerJoinStreamState::BothExhausted { final_result: true });
self.process_batches_before_finalization()
}
@@ -952,11 +885,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after processing the batch.
fn process_batch_from_right(
&mut self,
batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
/// Handles a pulled batch from the left stream.
///
@@ -966,11 +899,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after processing the batch.
fn process_batch_from_left(
&mut self,
batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
/// Handles the situation when only the left stream is exhausted.
///
@@ -980,11 +913,11 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the left stream is exhausted.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after the left stream is exhausted.
fn process_batch_after_left_end(
&mut self,
right_batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
/// Handles the situation when only the right stream is exhausted.
///
@@ -994,20 +927,20 @@ pub trait EagerJoinStream {
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the right stream is exhausted.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state result after the right stream is exhausted.
fn process_batch_after_right_end(
&mut self,
left_batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
/// Handles the final state after both streams are exhausted.
///
/// # Returns
///
- /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The final state result after processing.
+ /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The final state result after processing.
fn process_batches_before_finalization(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
/// Provides mutable access to the right stream.
///
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 00a7f23eba..b9101b57c3 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -38,12 +38,11 @@ use crate::joins::stream_join_utils::{
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr,
StreamJoinMetrics,
- StreamJoinStateResult,
};
use crate::joins::utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex,
JoinFilter,
- JoinOn,
+ JoinOn, StatefulStreamResult,
};
use crate::{
expressions::{Column, PhysicalSortExpr},
@@ -956,13 +955,13 @@ impl EagerJoinStream for SymmetricHashJoinStream {
fn process_batch_from_right(
&mut self,
batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.perform_join_for_given_side(batch, JoinSide::Right)
.map(|maybe_batch| {
if maybe_batch.is_some() {
- StreamJoinStateResult::Ready(maybe_batch)
+ StatefulStreamResult::Ready(maybe_batch)
} else {
- StreamJoinStateResult::Continue
+ StatefulStreamResult::Continue
}
})
}
@@ -970,13 +969,13 @@ impl EagerJoinStream for SymmetricHashJoinStream {
fn process_batch_from_left(
&mut self,
batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.perform_join_for_given_side(batch, JoinSide::Left)
.map(|maybe_batch| {
if maybe_batch.is_some() {
- StreamJoinStateResult::Ready(maybe_batch)
+ StatefulStreamResult::Ready(maybe_batch)
} else {
- StreamJoinStateResult::Continue
+ StatefulStreamResult::Continue
}
})
}
@@ -984,20 +983,20 @@ impl EagerJoinStream for SymmetricHashJoinStream {
fn process_batch_after_left_end(
&mut self,
right_batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.process_batch_from_right(right_batch)
}
fn process_batch_after_right_end(
&mut self,
left_batch: RecordBatch,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
self.process_batch_from_left(left_batch)
}
fn process_batches_before_finalization(
&mut self,
- ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
// Get the left side results:
let left_result = build_side_determined_results(
&self.left,
@@ -1025,9 +1024,9 @@ impl EagerJoinStream for SymmetricHashJoinStream {
// Update the metrics:
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
- return Ok(StreamJoinStateResult::Ready(result));
+ return Ok(StatefulStreamResult::Ready(result));
}
- Ok(StreamJoinStateResult::Continue)
+ Ok(StatefulStreamResult::Continue)
}
fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 5e01ca227c..eae65ce9c2 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -849,6 +849,22 @@ impl<T: 'static> OnceFut<T> {
),
}
}
+
+ /// Get shared reference to the result of the computation if it is ready, without consuming it
+ pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> Poll<Result<Arc<T>>> {
+ if let OnceFutState::Pending(fut) = &mut self.state {
+ let r = ready!(fut.poll_unpin(cx));
+ self.state = OnceFutState::Ready(r);
+ }
+
+ match &self.state {
+ OnceFutState::Pending(_) => unreachable!(),
+ OnceFutState::Ready(r) => Poll::Ready(
+ r.clone()
+ .map_err(|e| DataFusionError::External(Box::new(e))),
+ ),
+ }
+ }
}
/// Some type `join_type` of join need to maintain the matched indices bit map for the left side, and
@@ -1277,6 +1293,73 @@ pub fn prepare_sorted_exprs(
Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
}
+/// The `handle_state` macro is designed to process the result of a state-changing
+/// operation, encountered e.g. in implementations of `EagerJoinStream`. It
+/// operates on a `StatefulStreamResult` by matching its variants and executing
+/// corresponding actions. This macro is used to streamline code that deals with
+/// state transitions, reducing boilerplate and improving readability.
+///
+/// # Cases
+///
+/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the
+/// stream join operation should proceed to the next step.
+/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with the
+/// result, either yielding a value or indicating the stream is awaiting more
+/// data.
+/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
+/// during the stream join operation.
+///
+/// # Arguments
+///
+/// * `$match_case`: An expression that evaluates to a `Result<StatefulStreamResult<_>>`.
+#[macro_export]
+macro_rules! handle_state {
+ ($match_case:expr) => {
+ match $match_case {
+ Ok(StatefulStreamResult::Continue) => continue,
+ Ok(StatefulStreamResult::Ready(result)) => {
+ Poll::Ready(Ok(result).transpose())
+ }
+ Err(e) => Poll::Ready(Some(Err(e))),
+ }
+ };
+}
+
+/// The `handle_async_state` macro adapts the `handle_state` macro for use in
+/// asynchronous operations, particularly when dealing with `Poll` results within
+/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing
+/// function using `poll_unpin` and then passes the result to `handle_state` for
+/// further processing.
+///
+/// # Arguments
+///
+/// * `$state_func`: An async function or future that returns a
+/// `Result<StatefulStreamResult<_>>`.
+/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`.
+///
+#[macro_export]
+macro_rules! handle_async_state {
+ ($state_func:expr, $cx:expr) => {
+ $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
+ };
+}
+
+/// Represents the result of an operation on stateful join stream.
+///
+/// This enumueration indicates whether the state produced a result that is
+/// ready for use (`Ready`) or if the operation requires continuation (`Continue`).
+///
+/// Variants:
+/// - `Ready(T)`: Indicates that the operation is complete with a result of type `T`.
+/// - `Continue`: Indicates that the operation is not yet complete and requires further
+/// processing or more data. When this variant is returned, it typically means that the
+/// current invocation of the state did not produce a final result, and the operation
+/// should be invoked again later with more data and possibly with a different state.
+pub enum StatefulStreamResult<T> {
+ Ready(T),
+ Continue,
+}
+
#[cfg(test)]
mod tests {
use std::pin::Pin;