This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a71a76a996 refactor: `HashJoinStream` state machine (#8538)
a71a76a996 is described below

commit a71a76a996a32a0f068370940ebe475ec237b4ff
Author: Eduard Karacharov <[email protected]>
AuthorDate: Mon Dec 18 11:53:26 2023 +0200

    refactor: `HashJoinStream` state machine (#8538)
    
    * hash join state machine
    
    * StreamJoinStateResult to StatefulStreamResult
    
    * doc comments & naming & fmt
    
    * suggestions from code review
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * more review comments addressed
    
    * post-merge fixes
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/physical-plan/src/joins/hash_join.rs    | 431 ++++++++++++++-------
 .../physical-plan/src/joins/stream_join_utils.rs   | 127 ++----
 .../physical-plan/src/joins/symmetric_hash_join.rs |  25 +-
 datafusion/physical-plan/src/joins/utils.rs        |  83 ++++
 4 files changed, 420 insertions(+), 246 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 4846d0a5e0..13ac06ee30 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -28,7 +28,6 @@ use crate::joins::utils::{
     calculate_join_output_ordering, get_final_indices_from_bit_map,
     need_produce_result_in_final, JoinHashMap, JoinHashMapType,
 };
-use crate::DisplayAs;
 use crate::{
     coalesce_batches::concat_batches,
     coalesce_partitions::CoalescePartitionsExec,
@@ -38,12 +37,13 @@ use crate::{
     joins::utils::{
         adjust_right_output_partitioning, build_join_schema, 
check_join_is_valid,
         estimate_join_statistics, partitioned_join_output_partitioning,
-        BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn,
+        BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, 
StatefulStreamResult,
     },
     metrics::{ExecutionPlanMetricsSet, MetricsSet},
     DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
     RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
+use crate::{handle_state, DisplayAs};
 
 use super::{
     utils::{OnceAsync, OnceFut},
@@ -618,15 +618,14 @@ impl ExecutionPlan for HashJoinExec {
             on_right,
             filter: self.filter.clone(),
             join_type: self.join_type,
-            left_fut,
-            visited_left_side: None,
             right: right_stream,
             column_indices: self.column_indices.clone(),
             random_state: self.random_state.clone(),
             join_metrics,
             null_equals_null: self.null_equals_null,
-            is_exhausted: false,
             reservation,
+            state: HashJoinStreamState::WaitBuildSide,
+            build_side: BuildSide::Initial(BuildSideInitialState { left_fut }),
         }))
     }
 
@@ -789,6 +788,104 @@ where
     Ok(())
 }
 
+/// Represents build-side of hash join.
+enum BuildSide {
+    /// Indicates that build-side has not been collected yet
+    Initial(BuildSideInitialState),
+    /// Indicates that build-side data has been collected
+    Ready(BuildSideReadyState),
+}
+
+/// Container for BuildSide::Initial related data
+struct BuildSideInitialState {
+    /// Future for building hash table from build-side input
+    left_fut: OnceFut<JoinLeftData>,
+}
+
+/// Container for BuildSide::Ready related data
+struct BuildSideReadyState {
+    /// Collected build-side data
+    left_data: Arc<JoinLeftData>,
+    /// Which build-side rows have been matched while creating output.
+    /// For some OUTER joins, we need to know which rows have not been matched
+    /// to produce the correct output.
+    visited_left_side: BooleanBufferBuilder,
+}
+
+impl BuildSide {
+    /// Tries to extract BuildSideInitialState from BuildSide enum.
+    /// Returns an error if state is not Initial.
+    fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
+        match self {
+            BuildSide::Initial(state) => Ok(state),
+            _ => internal_err!("Expected build side in initial state"),
+        }
+    }
+
+    /// Tries to extract BuildSideReadyState from BuildSide enum.
+    /// Returns an error if state is not Ready.
+    fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
+        match self {
+            BuildSide::Ready(state) => Ok(state),
+            _ => internal_err!("Expected build side in ready state"),
+        }
+    }
+
+    /// Tries to extract BuildSideReadyState from BuildSide enum.
+    /// Returns an error if state is not Ready.
+    fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
+        match self {
+            BuildSide::Ready(state) => Ok(state),
+            _ => internal_err!("Expected build side in ready state"),
+        }
+    }
+}
+
+/// Represents state of HashJoinStream
+///
+/// Expected state transitions performed by HashJoinStream are:
+///
+/// ```text
+///
+///       WaitBuildSide
+///             │
+///             ▼
+///  ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed
+///  │          │
+///  │          ▼
+///  └─ ProcessProbeBatch
+///
+/// ```
+enum HashJoinStreamState {
+    /// Initial state for HashJoinStream indicating that build-side data has not been collected yet
+    WaitBuildSide,
+    /// Indicates that build-side has been collected, and stream is ready for fetching probe-side
+    FetchProbeBatch,
+    /// Indicates that a batch has been fetched from probe-side, and is ready to be processed
+    ProcessProbeBatch(ProcessProbeBatchState),
+    /// Indicates that probe-side has been fully processed
+    ExhaustedProbeSide,
+    /// Indicates that HashJoinStream execution is completed
+    Completed,
+}
+
+/// Container for HashJoinStreamState::ProcessProbeBatch related data
+struct ProcessProbeBatchState {
+    /// Current probe-side batch
+    batch: RecordBatch,
+}
+
+impl HashJoinStreamState {
+    /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum.
+    /// Returns an error if state is not ProcessProbeBatch.
+    fn try_as_process_probe_batch(&self) -> Result<&ProcessProbeBatchState> {
+        match self {
+            HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
+            _ => internal_err!("Expected hash join stream in ProcessProbeBatch 
state"),
+        }
+    }
+}
+
 /// [`Stream`] for [`HashJoinExec`] that does the actual join.
 ///
 /// This stream:
@@ -808,20 +905,10 @@ struct HashJoinStream {
     filter: Option<JoinFilter>,
     /// type of the join (left, right, semi, etc)
     join_type: JoinType,
-    /// future which builds hash table from left side
-    left_fut: OnceFut<JoinLeftData>,
-    /// Which left (probe) side rows have been matches while creating output.
-    /// For some OUTER joins, we need to know which rows have not been matched
-    /// to produce the correct output.
-    visited_left_side: Option<BooleanBufferBuilder>,
     /// right (probe) input
     right: SendableRecordBatchStream,
     /// Random state used for hashing initialization
     random_state: RandomState,
-    /// The join output is complete. For outer joins, this is used to
-    /// distinguish when the input stream is exhausted and when any unmatched
-    /// rows are output.
-    is_exhausted: bool,
     /// Metrics
     join_metrics: BuildProbeJoinMetrics,
     /// Information of index and left / right placement of columns
@@ -830,6 +917,10 @@ struct HashJoinStream {
     null_equals_null: bool,
     /// Memory reservation
     reservation: MemoryReservation,
+    /// State of the stream
+    state: HashJoinStreamState,
+    /// Build side
+    build_side: BuildSide,
 }
 
 impl RecordBatchStream for HashJoinStream {
@@ -1069,19 +1160,44 @@ impl HashJoinStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                HashJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                HashJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                HashJoinStreamState::ProcessProbeBatch(_) => {
+                    handle_state!(self.process_probe_batch())
+                }
+                HashJoinStreamState::ExhaustedProbeSide => {
+                    handle_state!(self.process_unmatched_build_batch())
+                }
+                HashJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    /// Collects build-side data by polling `OnceFut` future from initialized build-side
+    ///
+    /// Updates build-side to `Ready`, and state to `FetchProbeBatch`
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         // build hash table from left (build) side, if not yet done
-        let left_data = match ready!(self.left_fut.get(cx)) {
-            Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
-        };
+        let left_data = ready!(self
+            .build_side
+            .try_as_initial_mut()?
+            .left_fut
+            .get_shared(cx))?;
         build_timer.done();
 
         // Reserving memory for visited_left_side bitmap in case it hasn't 
been initialized yet
         // and join_type requires to store it
-        if self.visited_left_side.is_none()
-            && need_produce_result_in_final(self.join_type)
-        {
+        if need_produce_result_in_final(self.join_type) {
             // TODO: Replace `ceil` wrapper with stable `div_cell` after
             // https://github.com/rust-lang/rust/issues/88581
             let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8);
@@ -1089,124 +1205,167 @@ impl HashJoinStream {
             self.join_metrics.build_mem_used.add(visited_bitmap_size);
         }
 
-        let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+        let visited_left_side = if 
need_produce_result_in_final(self.join_type) {
             let num_rows = left_data.num_rows();
-            if need_produce_result_in_final(self.join_type) {
-                // Some join types need to track which row has be matched or 
unmatched:
-                // `left semi` join:  need to use the bitmap to produce the 
matched row in the left side
-                // `left` join:       need to use the bitmap to produce the 
unmatched row in the left side with null
-                // `left anti` join:  need to use the bitmap to produce the 
unmatched row in the left side
-                // `full` join:       need to use the bitmap to produce the 
unmatched row in the left side with null
-                let mut buffer = BooleanBufferBuilder::new(num_rows);
-                buffer.append_n(num_rows, false);
-                buffer
-            } else {
-                BooleanBufferBuilder::new(0)
-            }
+            // Some join types need to track which row has been matched or unmatched:
+            // `left semi` join:  need to use the bitmap to produce the matched row in the left side
+            // `left` join:       need to use the bitmap to produce the unmatched row in the left side with null
+            // `left anti` join:  need to use the bitmap to produce the unmatched row in the left side
+            // `full` join:       need to use the bitmap to produce the unmatched row in the left side with null
+            let mut buffer = BooleanBufferBuilder::new(num_rows);
+            buffer.append_n(num_rows, false);
+            buffer
+        } else {
+            BooleanBufferBuilder::new(0)
+        };
+
+        self.state = HashJoinStreamState::FetchProbeBatch;
+        self.build_side = BuildSide::Ready(BuildSideReadyState {
+            left_data,
+            visited_left_side,
         });
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Fetches next batch from probe-side
+    ///
+    /// If a batch has been fetched, updates state to `ProcessProbeBatch`,
+    /// otherwise updates state to `ExhaustedProbeSide`
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.right.poll_next_unpin(cx)) {
+            None => {
+                self.state = HashJoinStreamState::ExhaustedProbeSide;
+            }
+            Some(Ok(batch)) => {
+                self.state =
+                    
HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
+                        batch,
+                    });
+            }
+            Some(Err(err)) => return Poll::Ready(Err(err)),
+        };
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Joins current probe batch with build-side data and produces batch with matched output
+    ///
+    /// Updates state to `FetchProbeBatch`
+    fn process_probe_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let state = self.state.try_as_process_probe_batch()?;
+        let build_side = self.build_side.try_as_ready_mut()?;
+
+        self.join_metrics.input_batches.add(1);
+        self.join_metrics.input_rows.add(state.batch.num_rows());
+        let timer = self.join_metrics.join_time.timer();
+
         let mut hashes_buffer = vec![];
-        // get next right (probe) input batch
-        self.right
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                // one right batch in the join loop
-                Some(Ok(batch)) => {
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(batch.num_rows());
-                    let timer = self.join_metrics.join_time.timer();
-
-                    // get the matched two indices for the on condition
-                    let left_right_indices = 
build_equal_condition_join_indices(
-                        left_data.hash_map(),
-                        left_data.batch(),
-                        &batch,
-                        &self.on_left,
-                        &self.on_right,
-                        &self.random_state,
-                        self.null_equals_null,
-                        &mut hashes_buffer,
-                        self.filter.as_ref(),
-                        JoinSide::Left,
-                        None,
-                    );
-
-                    let result = match left_right_indices {
-                        Ok((left_side, right_side)) => {
-                            // set the left bitmap
-                            // and only left, full, left semi, left anti need 
the left bitmap
-                            if need_produce_result_in_final(self.join_type) {
-                                left_side.iter().flatten().for_each(|x| {
-                                    visited_left_side.set_bit(x as usize, 
true);
-                                });
-                            }
-
-                            // adjust the two side indices base on the join 
type
-                            let (left_side, right_side) = 
adjust_indices_by_join_type(
-                                left_side,
-                                right_side,
-                                batch.num_rows(),
-                                self.join_type,
-                            );
-
-                            let result = build_batch_from_indices(
-                                &self.schema,
-                                left_data.batch(),
-                                &batch,
-                                &left_side,
-                                &right_side,
-                                &self.column_indices,
-                                JoinSide::Left,
-                            );
-                            self.join_metrics.output_batches.add(1);
-                            
self.join_metrics.output_rows.add(batch.num_rows());
-                            Some(result)
-                        }
-                        Err(err) => Some(exec_err!(
-                            "Fail to build join indices in HashJoinExec, 
error:{err}"
-                        )),
-                    };
-                    timer.done();
-                    result
-                }
-                None => {
-                    let timer = self.join_metrics.join_time.timer();
-                    if need_produce_result_in_final(self.join_type) && 
!self.is_exhausted
-                    {
-                        // use the global left bitmap to produce the left 
indices and right indices
-                        let (left_side, right_side) = 
get_final_indices_from_bit_map(
-                            visited_left_side,
-                            self.join_type,
-                        );
-                        let empty_right_batch =
-                            RecordBatch::new_empty(self.right.schema());
-                        // use the left and right indices to produce the batch 
result
-                        let result = build_batch_from_indices(
-                            &self.schema,
-                            left_data.batch(),
-                            &empty_right_batch,
-                            &left_side,
-                            &right_side,
-                            &self.column_indices,
-                            JoinSide::Left,
-                        );
-
-                        if let Ok(ref batch) = result {
-                            self.join_metrics.input_batches.add(1);
-                            self.join_metrics.input_rows.add(batch.num_rows());
-
-                            self.join_metrics.output_batches.add(1);
-                            
self.join_metrics.output_rows.add(batch.num_rows());
-                        }
-                        timer.done();
-                        self.is_exhausted = true;
-                        Some(result)
-                    } else {
-                        // end of the join loop
-                        None
-                    }
+        // get the matched two indices for the on condition
+        let left_right_indices = build_equal_condition_join_indices(
+            build_side.left_data.hash_map(),
+            build_side.left_data.batch(),
+            &state.batch,
+            &self.on_left,
+            &self.on_right,
+            &self.random_state,
+            self.null_equals_null,
+            &mut hashes_buffer,
+            self.filter.as_ref(),
+            JoinSide::Left,
+            None,
+        );
+
+        let result = match left_right_indices {
+            Ok((left_side, right_side)) => {
+                // set the left bitmap
+                // and only left, full, left semi, left anti need the left 
bitmap
+                if need_produce_result_in_final(self.join_type) {
+                    left_side.iter().flatten().for_each(|x| {
+                        build_side.visited_left_side.set_bit(x as usize, true);
+                    });
                 }
-                Some(err) => Some(err),
-            })
+
+                // adjust the two side indices base on the join type
+                let (left_side, right_side) = adjust_indices_by_join_type(
+                    left_side,
+                    right_side,
+                    state.batch.num_rows(),
+                    self.join_type,
+                );
+
+                let result = build_batch_from_indices(
+                    &self.schema,
+                    build_side.left_data.batch(),
+                    &state.batch,
+                    &left_side,
+                    &right_side,
+                    &self.column_indices,
+                    JoinSide::Left,
+                );
+                self.join_metrics.output_batches.add(1);
+                self.join_metrics.output_rows.add(state.batch.num_rows());
+                result
+            }
+            Err(err) => {
+                exec_err!("Fail to build join indices in HashJoinExec, 
error:{err}")
+            }
+        };
+        timer.done();
+
+        self.state = HashJoinStreamState::FetchProbeBatch;
+
+        Ok(StatefulStreamResult::Ready(Some(result?)))
+    }
+
+    /// Processes unmatched build-side rows for certain join types and produces output batch
+    ///
+    /// Updates state to `Completed`
+    fn process_unmatched_build_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let timer = self.join_metrics.join_time.timer();
+
+        if !need_produce_result_in_final(self.join_type) {
+            self.state = HashJoinStreamState::Completed;
+
+            return Ok(StatefulStreamResult::Continue);
+        }
+
+        let build_side = self.build_side.try_as_ready()?;
+
+        // use the global left bitmap to produce the left indices and right 
indices
+        let (left_side, right_side) =
+            get_final_indices_from_bit_map(&build_side.visited_left_side, 
self.join_type);
+        let empty_right_batch = RecordBatch::new_empty(self.right.schema());
+        // use the left and right indices to produce the batch result
+        let result = build_batch_from_indices(
+            &self.schema,
+            build_side.left_data.batch(),
+            &empty_right_batch,
+            &left_side,
+            &right_side,
+            &self.column_indices,
+            JoinSide::Left,
+        );
+
+        if let Ok(ref batch) = result {
+            self.join_metrics.input_batches.add(1);
+            self.join_metrics.input_rows.add(batch.num_rows());
+
+            self.join_metrics.output_batches.add(1);
+            self.join_metrics.output_rows.add(batch.num_rows());
+        }
+        timer.done();
+
+        self.state = HashJoinStreamState::Completed;
+
+        Ok(StatefulStreamResult::Ready(Some(result?)))
     }
 }
 
diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs
index 2f74bd1c4b..64a976a1e3 100644
--- a/datafusion/physical-plan/src/joins/stream_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -23,9 +23,9 @@ use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::usize;
 
-use crate::joins::utils::{JoinFilter, JoinHashMapType};
+use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult};
 use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder};
-use crate::{handle_async_state, metrics};
+use crate::{handle_async_state, handle_state, metrics};
 
 use arrow::compute::concat_batches;
 use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, 
RecordBatch};
@@ -624,73 +624,6 @@ pub fn record_visited_indices<T: ArrowPrimitiveType>(
     }
 }
 
-/// The `handle_state` macro is designed to process the result of a 
state-changing
-/// operation, typically encountered in implementations of `EagerJoinStream`. 
It
-/// operates on a `StreamJoinStateResult` by matching its variants and 
executing
-/// corresponding actions. This macro is used to streamline code that deals 
with
-/// state transitions, reducing boilerplate and improving readability.
-///
-/// # Cases
-///
-/// - `Ok(StreamJoinStateResult::Continue)`: Continues the loop, indicating the
-///   stream join operation should proceed to the next step.
-/// - `Ok(StreamJoinStateResult::Ready(result))`: Returns a `Poll::Ready` with 
the
-///   result, either yielding a value or indicating the stream is awaiting more
-///   data.
-/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
-///   during the stream join operation.
-///
-/// # Arguments
-///
-/// * `$match_case`: An expression that evaluates to a 
`Result<StreamJoinStateResult<_>>`.
-#[macro_export]
-macro_rules! handle_state {
-    ($match_case:expr) => {
-        match $match_case {
-            Ok(StreamJoinStateResult::Continue) => continue,
-            Ok(StreamJoinStateResult::Ready(result)) => {
-                Poll::Ready(Ok(result).transpose())
-            }
-            Err(e) => Poll::Ready(Some(Err(e))),
-        }
-    };
-}
-
-/// The `handle_async_state` macro adapts the `handle_state` macro for use in
-/// asynchronous operations, particularly when dealing with `Poll` results 
within
-/// async traits like `EagerJoinStream`. It polls the asynchronous 
state-changing
-/// function using `poll_unpin` and then passes the result to `handle_state` 
for
-/// further processing.
-///
-/// # Arguments
-///
-/// * `$state_func`: An async function or future that returns a
-///   `Result<StreamJoinStateResult<_>>`.
-/// * `$cx`: The context to be passed for polling, usually of type `&mut 
Context`.
-///
-#[macro_export]
-macro_rules! handle_async_state {
-    ($state_func:expr, $cx:expr) => {
-        $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
-    };
-}
-
-/// Represents the result of a stateful operation on `EagerJoinStream`.
-///
-/// This enumueration indicates whether the state produced a result that is
-/// ready for use (`Ready`) or if the operation requires continuation 
(`Continue`).
-///
-/// Variants:
-/// - `Ready(T)`: Indicates that the operation is complete with a result of 
type `T`.
-/// - `Continue`: Indicates that the operation is not yet complete and 
requires further
-///   processing or more data. When this variant is returned, it typically 
means that the
-///   current invocation of the state did not produce a final result, and the 
operation
-///   should be invoked again later with more data and possibly with a 
different state.
-pub enum StreamJoinStateResult<T> {
-    Ready(T),
-    Continue,
-}
-
 /// Represents the various states of an eager join stream operation.
 ///
 /// This enum is used to track the current state of streaming during a join
@@ -819,14 +752,14 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after pulling the batch.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after pulling the batch.
     async fn fetch_next_from_right_stream(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         match self.right_stream().next().await {
             Some(Ok(batch)) => {
                 if batch.num_rows() == 0 {
-                    return Ok(StreamJoinStateResult::Continue);
+                    return Ok(StatefulStreamResult::Continue);
                 }
 
                 self.set_state(EagerJoinStreamState::PullLeft);
@@ -835,7 +768,7 @@ pub trait EagerJoinStream {
             Some(Err(e)) => Err(e),
             None => {
                 self.set_state(EagerJoinStreamState::RightExhausted);
-                Ok(StreamJoinStateResult::Continue)
+                Ok(StatefulStreamResult::Continue)
             }
         }
     }
@@ -848,14 +781,14 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after pulling the batch.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after pulling the batch.
     async fn fetch_next_from_left_stream(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         match self.left_stream().next().await {
             Some(Ok(batch)) => {
                 if batch.num_rows() == 0 {
-                    return Ok(StreamJoinStateResult::Continue);
+                    return Ok(StatefulStreamResult::Continue);
                 }
                 self.set_state(EagerJoinStreamState::PullRight);
                 self.process_batch_from_left(batch)
@@ -863,7 +796,7 @@ pub trait EagerJoinStream {
             Some(Err(e)) => Err(e),
             None => {
                 self.set_state(EagerJoinStreamState::LeftExhausted);
-                Ok(StreamJoinStateResult::Continue)
+                Ok(StatefulStreamResult::Continue)
             }
         }
     }
@@ -877,14 +810,14 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after checking the exhaustion state.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after checking the exhaustion state.
     async fn handle_right_stream_end(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         match self.left_stream().next().await {
             Some(Ok(batch)) => {
                 if batch.num_rows() == 0 {
-                    return Ok(StreamJoinStateResult::Continue);
+                    return Ok(StatefulStreamResult::Continue);
                 }
                 self.process_batch_after_right_end(batch)
             }
@@ -893,7 +826,7 @@ pub trait EagerJoinStream {
                 self.set_state(EagerJoinStreamState::BothExhausted {
                     final_result: false,
                 });
-                Ok(StreamJoinStateResult::Continue)
+                Ok(StatefulStreamResult::Continue)
             }
         }
     }
@@ -907,14 +840,14 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after checking the exhaustion state.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after checking the exhaustion state.
     async fn handle_left_stream_end(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         match self.right_stream().next().await {
             Some(Ok(batch)) => {
                 if batch.num_rows() == 0 {
-                    return Ok(StreamJoinStateResult::Continue);
+                    return Ok(StatefulStreamResult::Continue);
                 }
                 self.process_batch_after_left_end(batch)
             }
@@ -923,7 +856,7 @@ pub trait EagerJoinStream {
                 self.set_state(EagerJoinStreamState::BothExhausted {
                     final_result: false,
                 });
-                Ok(StreamJoinStateResult::Continue)
+                Ok(StatefulStreamResult::Continue)
             }
         }
     }
@@ -936,10 +869,10 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after both streams are exhausted.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after both streams are exhausted.
     fn prepare_for_final_results_after_exhaustion(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         self.set_state(EagerJoinStreamState::BothExhausted { final_result: 
true });
         self.process_batches_before_finalization()
     }
@@ -952,11 +885,11 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after processing the batch.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after processing the batch.
     fn process_batch_from_right(
         &mut self,
         batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
 
     /// Handles a pulled batch from the left stream.
     ///
@@ -966,11 +899,11 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after processing the batch.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after processing the batch.
     fn process_batch_from_left(
         &mut self,
         batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
 
     /// Handles the situation when only the left stream is exhausted.
     ///
@@ -980,11 +913,11 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after the left stream is exhausted.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after the left stream is exhausted.
     fn process_batch_after_left_end(
         &mut self,
         right_batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
 
     /// Handles the situation when only the right stream is exhausted.
     ///
@@ -994,20 +927,20 @@ pub trait EagerJoinStream {
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state 
result after the right stream is exhausted.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The state 
result after the right stream is exhausted.
     fn process_batch_after_right_end(
         &mut self,
         left_batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
 
     /// Handles the final state after both streams are exhausted.
     ///
     /// # Returns
     ///
-    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The final 
state result after processing.
+    /// * `Result<StatefulStreamResult<Option<RecordBatch>>>` - The final 
state result after processing.
     fn process_batches_before_finalization(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>>;
 
     /// Provides mutable access to the right stream.
     ///
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 00a7f23eba..b9101b57c3 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -38,12 +38,11 @@ use crate::joins::stream_join_utils::{
     convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
     get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
     EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, 
StreamJoinMetrics,
-    StreamJoinStateResult,
 };
 use crate::joins::utils::{
     build_batch_from_indices, build_join_schema, check_join_is_valid,
     partitioned_join_output_partitioning, prepare_sorted_exprs, ColumnIndex, 
JoinFilter,
-    JoinOn,
+    JoinOn, StatefulStreamResult,
 };
 use crate::{
     expressions::{Column, PhysicalSortExpr},
@@ -956,13 +955,13 @@ impl EagerJoinStream for SymmetricHashJoinStream {
     fn process_batch_from_right(
         &mut self,
         batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         self.perform_join_for_given_side(batch, JoinSide::Right)
             .map(|maybe_batch| {
                 if maybe_batch.is_some() {
-                    StreamJoinStateResult::Ready(maybe_batch)
+                    StatefulStreamResult::Ready(maybe_batch)
                 } else {
-                    StreamJoinStateResult::Continue
+                    StatefulStreamResult::Continue
                 }
             })
     }
@@ -970,13 +969,13 @@ impl EagerJoinStream for SymmetricHashJoinStream {
     fn process_batch_from_left(
         &mut self,
         batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         self.perform_join_for_given_side(batch, JoinSide::Left)
             .map(|maybe_batch| {
                 if maybe_batch.is_some() {
-                    StreamJoinStateResult::Ready(maybe_batch)
+                    StatefulStreamResult::Ready(maybe_batch)
                 } else {
-                    StreamJoinStateResult::Continue
+                    StatefulStreamResult::Continue
                 }
             })
     }
@@ -984,20 +983,20 @@ impl EagerJoinStream for SymmetricHashJoinStream {
     fn process_batch_after_left_end(
         &mut self,
         right_batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         self.process_batch_from_right(right_batch)
     }
 
     fn process_batch_after_right_end(
         &mut self,
         left_batch: RecordBatch,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         self.process_batch_from_left(left_batch)
     }
 
     fn process_batches_before_finalization(
         &mut self,
-    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
         // Get the left side results:
         let left_result = build_side_determined_results(
             &self.left,
@@ -1025,9 +1024,9 @@ impl EagerJoinStream for SymmetricHashJoinStream {
             // Update the metrics:
             self.metrics.output_batches.add(1);
             self.metrics.output_rows.add(batch.num_rows());
-            return Ok(StreamJoinStateResult::Ready(result));
+            return Ok(StatefulStreamResult::Ready(result));
         }
-        Ok(StreamJoinStateResult::Continue)
+        Ok(StatefulStreamResult::Continue)
     }
 
     fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index 5e01ca227c..eae65ce9c2 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -849,6 +849,22 @@ impl<T: 'static> OnceFut<T> {
             ),
         }
     }
+
+    /// Get a shared reference to the result of the computation if it is ready, 
without consuming it
+    pub(crate) fn get_shared(&mut self, cx: &mut Context<'_>) -> 
Poll<Result<Arc<T>>> {
+        if let OnceFutState::Pending(fut) = &mut self.state {
+            let r = ready!(fut.poll_unpin(cx));
+            self.state = OnceFutState::Ready(r);
+        }
+
+        match &self.state {
+            OnceFutState::Pending(_) => unreachable!(),
+            OnceFutState::Ready(r) => Poll::Ready(
+                r.clone()
+                    .map_err(|e| DataFusionError::External(Box::new(e))),
+            ),
+        }
+    }
 }
 
 /// Some type `join_type` of join need to maintain the matched indices bit map 
for the left side, and
@@ -1277,6 +1293,73 @@ pub fn prepare_sorted_exprs(
     Ok((left_sorted_filter_expr, right_sorted_filter_expr, graph))
 }
 
+/// The `handle_state` macro is designed to process the result of a 
state-changing
+/// operation, encountered e.g. in implementations of `EagerJoinStream`. It
+/// operates on a `StatefulStreamResult` by matching its variants and executing
+/// corresponding actions. This macro is used to streamline code that deals 
with
+/// state transitions, reducing boilerplate and improving readability.
+///
+/// # Cases
+///
+/// - `Ok(StatefulStreamResult::Continue)`: Continues the loop, indicating the
+///   stream join operation should proceed to the next step.
+/// - `Ok(StatefulStreamResult::Ready(result))`: Returns a `Poll::Ready` with 
the
+///   result, either yielding a value or, when the result is `None`, signaling
+///   that the stream has ended.
+/// - `Err(e)`: Returns a `Poll::Ready` containing an error, signaling an issue
+///   during the stream join operation.
+///
+/// # Arguments
+///
+/// * `$match_case`: An expression that evaluates to a 
`Result<StatefulStreamResult<_>>`.
+#[macro_export]
+macro_rules! handle_state {
+    ($match_case:expr) => {
+        match $match_case {
+            Ok(StatefulStreamResult::Continue) => continue,
+            Ok(StatefulStreamResult::Ready(result)) => {
+                Poll::Ready(Ok(result).transpose())
+            }
+            Err(e) => Poll::Ready(Some(Err(e))),
+        }
+    };
+}
+
+/// The `handle_async_state` macro adapts the `handle_state` macro for use in
+/// asynchronous operations, particularly when dealing with `Poll` results 
within
+/// async traits like `EagerJoinStream`. It polls the asynchronous 
state-changing
+/// function using `poll_unpin` and then passes the result to `handle_state` 
for
+/// further processing.
+///
+/// # Arguments
+///
+/// * `$state_func`: An async function or future that returns a
+///   `Result<StatefulStreamResult<_>>`.
+/// * `$cx`: The context to be passed for polling, usually of type `&mut 
Context`.
+///
+#[macro_export]
+macro_rules! handle_async_state {
+    ($state_func:expr, $cx:expr) => {
+        $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
+    };
+}
+
+/// Represents the result of an operation on a stateful join stream.
+///
+/// This enumeration indicates whether the state produced a result that is
+/// ready for use (`Ready`) or if the operation requires continuation 
(`Continue`).
+///
+/// Variants:
+/// - `Ready(T)`: Indicates that the operation is complete with a result of 
type `T`.
+/// - `Continue`: Indicates that the operation is not yet complete and 
requires further
+///   processing or more data. When this variant is returned, it typically 
means that the
+///   current invocation of the state did not produce a final result, and the 
operation
+///   should be invoked again later with more data and possibly with a 
different state.
+pub enum StatefulStreamResult<T> {
+    Ready(T),
+    Continue,
+}
+
 #[cfg(test)]
 mod tests {
     use std::pin::Pin;


Reply via email to