This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2156dde546 Making stream joins extensible: A new Trait implementation for SHJ (#8234)
2156dde546 is described below

commit 2156dde54623d26635d4388d161d94ac79918cdc
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Mon Nov 20 13:40:25 2023 +0200

    Making stream joins extensible: A new Trait implementation for SHJ (#8234)
    
    * Upstream
    
    * Update utils.rs
    
    * Review
    
    * Name change and remove ignore on test
    
    * Comment revisions
    
    * Improve comments
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/physical-plan/src/joins/hash_join.rs    |   3 +-
 datafusion/physical-plan/src/joins/mod.rs          |   2 +-
 .../{hash_join_utils.rs => stream_join_utils.rs}   | 550 ++++++++++++++++-----
 .../physical-plan/src/joins/symmetric_hash_join.rs | 503 +++++++++++--------
 datafusion/physical-plan/src/joins/test_utils.rs   |  61 ++-
 datafusion/physical-plan/src/joins/utils.rs        | 131 ++++-
 datafusion/proto/proto/datafusion.proto            |  16 +
 datafusion/proto/src/generated/pbjson.rs           | 285 +++++++++++
 datafusion/proto/src/generated/prost.rs            |  48 +-
 datafusion/proto/src/physical_plan/mod.rs          | 168 ++++++-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  46 +-
 11 files changed, 1463 insertions(+), 350 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 7a08b56a6e..4846d0a5e0 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -26,7 +26,7 @@ use std::{any::Any, usize, vec};
 use crate::joins::utils::{
     adjust_indices_by_join_type, apply_join_filter_to_indices, 
build_batch_from_indices,
     calculate_join_output_ordering, get_final_indices_from_bit_map,
-    need_produce_result_in_final,
+    need_produce_result_in_final, JoinHashMap, JoinHashMapType,
 };
 use crate::DisplayAs;
 use crate::{
@@ -35,7 +35,6 @@ use crate::{
     expressions::Column,
     expressions::PhysicalSortExpr,
     hash_utils::create_hashes,
-    joins::hash_join_utils::{JoinHashMap, JoinHashMapType},
     joins::utils::{
         adjust_right_output_partitioning, build_join_schema, 
check_join_is_valid,
         estimate_join_statistics, partitioned_join_output_partitioning,
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 19f10d06e1..6ddf19c511 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -25,9 +25,9 @@ pub use sort_merge_join::SortMergeJoinExec;
 pub use symmetric_hash_join::SymmetricHashJoinExec;
 mod cross_join;
 mod hash_join;
-mod hash_join_utils;
 mod nested_loop_join;
 mod sort_merge_join;
+mod stream_join_utils;
 mod symmetric_hash_join;
 pub mod utils;
 
diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs
similarity index 67%
rename from datafusion/physical-plan/src/joins/hash_join_utils.rs
rename to datafusion/physical-plan/src/joins/stream_join_utils.rs
index db65c8bf08..aa57a4f896 100644
--- a/datafusion/physical-plan/src/joins/hash_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -15,151 +15,34 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! This file contains common subroutines for regular and symmetric hash join
+//! This file contains common subroutines for symmetric hash join
 //! related functionality, used both in join calculations and optimization 
rules.
 
 use std::collections::{HashMap, VecDeque};
-use std::fmt::Debug;
-use std::ops::IndexMut;
 use std::sync::Arc;
-use std::{fmt, usize};
+use std::task::{Context, Poll};
+use std::usize;
 
-use crate::joins::utils::JoinFilter;
+use crate::handle_async_state;
+use crate::joins::utils::{JoinFilter, JoinHashMapType};
 
 use arrow::compute::concat_batches;
-use arrow::datatypes::{ArrowNativeType, SchemaRef};
-use arrow_array::builder::BooleanBufferBuilder;
 use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, 
RecordBatch};
+use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
+use arrow_schema::SchemaRef;
+use async_trait::async_trait;
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue};
+use datafusion_execution::SendableRecordBatchStream;
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::intervals::{Interval, IntervalBound};
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
 
+use futures::{ready, FutureExt, StreamExt};
 use hashbrown::raw::RawTable;
 use hashbrown::HashSet;
 
-/// Maps a `u64` hash value based on the build side ["on" values] to a list of 
indices with this key's value.
-///
-/// By allocating a `HashMap` with capacity for *at least* the number of rows 
for entries at the build side,
-/// we make sure that we don't have to re-hash the hashmap, which needs access 
to the key (the hash in this case) value.
-///
-/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 
8 for hash value 1
-/// As the key is a hash value, we need to check possible hash collisions in 
the probe stage
-/// During this stage it might be the case that a row is contained the same 
hashmap value,
-/// but the values don't match. Those are checked in the 
[`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method.
-///
-/// The indices (values) are stored in a separate chained list stored in the 
`Vec<u64>`.
-///
-/// The first value (+1) is stored in the hashmap, whereas the next value is 
stored in array at the position value.
-///
-/// The chain can be followed until the value "0" has been reached, meaning 
the end of the list.
-/// Also see chapter 5.3 of [Balancing vectorized query execution with 
bandwidth-optimized 
storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
-///
-/// # Example
-///
-/// ``` text
-/// See the example below:
-///
-/// Insert (10,1)            <-- insert hash value 10 with row index 1
-/// map:
-/// ----------
-/// | 10 | 2 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 0 | 0 |
-/// ---------------------
-/// Insert (20,2)
-/// map:
-/// ----------
-/// | 10 | 2 |
-/// | 20 | 3 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 0 | 0 |
-/// ---------------------
-/// Insert (10,3)           <-- collision! row index 3 has a hash value of 10 
as well
-/// map:
-/// ----------
-/// | 10 | 4 |
-/// | 20 | 3 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 2 | 0 |  <--- hash value 10 maps to 4,2 (which means indices 
values 3,1)
-/// ---------------------
-/// Insert (10,4)          <-- another collision! row index 4 ALSO has a hash 
value of 10
-/// map:
-/// ---------
-/// | 10 | 5 |
-/// | 20 | 3 |
-/// ---------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means 
indices values 4,3,1)
-/// ---------------------
-/// ```
-pub struct JoinHashMap {
-    // Stores hash value to last row index
-    map: RawTable<(u64, u64)>,
-    // Stores indices in chained list data structure
-    next: Vec<u64>,
-}
-
-impl JoinHashMap {
-    #[cfg(test)]
-    pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
-        Self { map, next }
-    }
-
-    pub(crate) fn with_capacity(capacity: usize) -> Self {
-        JoinHashMap {
-            map: RawTable::with_capacity(capacity),
-            next: vec![0; capacity],
-        }
-    }
-}
-
-/// Trait defining methods that must be implemented by a hash map type to be 
used for joins.
-pub trait JoinHashMapType {
-    /// The type of list used to store the next list
-    type NextType: IndexMut<usize, Output = u64>;
-    /// Extend with zero
-    fn extend_zero(&mut self, len: usize);
-    /// Returns mutable references to the hash map and the next.
-    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
-    /// Returns a reference to the hash map.
-    fn get_map(&self) -> &RawTable<(u64, u64)>;
-    /// Returns a reference to the next.
-    fn get_list(&self) -> &Self::NextType;
-}
-
-/// Implementation of `JoinHashMapType` for `JoinHashMap`.
-impl JoinHashMapType for JoinHashMap {
-    type NextType = Vec<u64>;
-
-    // Void implementation
-    fn extend_zero(&mut self, _: usize) {}
-
-    /// Get mutable references to the hash map and the next.
-    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
-        (&mut self.map, &mut self.next)
-    }
-
-    /// Get a reference to the hash map.
-    fn get_map(&self) -> &RawTable<(u64, u64)> {
-        &self.map
-    }
-
-    /// Get a reference to the next.
-    fn get_list(&self) -> &Self::NextType {
-        &self.next
-    }
-}
-
 /// Implementation of `JoinHashMapType` for `PruningJoinHashMap`.
 impl JoinHashMapType for PruningJoinHashMap {
     type NextType = VecDeque<u64>;
@@ -185,12 +68,6 @@ impl JoinHashMapType for PruningJoinHashMap {
     }
 }
 
-impl fmt::Debug for JoinHashMap {
-    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
-        Ok(())
-    }
-}
-
 /// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
 /// the capability of pruning elements in an efficient manner. This structure
 /// is particularly useful for cases where it's necessary to remove elements
@@ -322,7 +199,7 @@ impl PruningJoinHashMap {
     }
 }
 
-fn check_filter_expr_contains_sort_information(
+pub fn check_filter_expr_contains_sort_information(
     expr: &Arc<dyn PhysicalExpr>,
     reference: &Arc<dyn PhysicalExpr>,
 ) -> bool {
@@ -740,20 +617,423 @@ pub fn record_visited_indices<T: ArrowPrimitiveType>(
     }
 }
 
+/// The `handle_state` macro is designed to process the result of a state-changing
+/// operation, typically encountered in implementations of `EagerJoinStream`. It
+/// operates on a `StreamJoinStateResult` by matching its variants and executing
+/// corresponding actions. It must be expanded inside a `loop` in a function that
+/// returns `Poll<Option<Result<_>>>`, with `StreamJoinStateResult` and `Poll` in scope.
+///
+/// # Cases
+///
+/// - `Ok(StreamJoinStateResult::Continue)`: Executes `continue` on the enclosing
+///   loop, so the stream join operation proceeds to its next state.
+/// - `Ok(StreamJoinStateResult::Ready(result))`: Evaluates to `Poll::Ready`; a
+///   `Some(batch)` result yields `Some(Ok(batch))`, while a `None` result yields
+///   `Poll::Ready(None)`, which signals the end of the stream.
+/// - `Err(e)`: Evaluates to `Poll::Ready(Some(Err(e)))`, surfacing the error to
+///   the consumer of the stream.
+///
+/// # Arguments
+///
+/// * `$match_case`: An expression that evaluates to a `Result<StreamJoinStateResult<Option<_>>>`.
+#[macro_export]
+macro_rules! handle_state {
+    ($match_case:expr) => {
+        match $match_case {
+            Ok(StreamJoinStateResult::Continue) => continue,
+            Ok(StreamJoinStateResult::Ready(result)) => {
+                Poll::Ready(Ok(result).transpose())
+            }
+            Err(e) => Poll::Ready(Some(Err(e))),
+        }
+    };
+}
+
+/// The `handle_async_state` macro adapts the `handle_state` macro for use in
+/// asynchronous operations, particularly within async traits like `EagerJoinStream`.
+/// It polls the given future once via `poll_unpin`; if the future is pending,
+/// `ready!` returns `Poll::Pending` from the enclosing function. Otherwise the output
+/// is handed to `handle_state`, so the call site also needs `ready` and `FutureExt`.
+///
+/// # Arguments
+///
+/// * `$state_func`: An `Unpin` future (e.g. the boxed future returned by an
+///   `#[async_trait]` method) whose output is a `Result<StreamJoinStateResult<_>>`.
+/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`.
+///
+#[macro_export]
+macro_rules! handle_async_state {
+    ($state_func:expr, $cx:expr) => {
+        $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
+    };
+}
+
+/// Represents the result of a stateful operation on `EagerJoinStream`.
+///
+/// This enumeration indicates whether the state produced a result that is
+/// ready for use (`Ready`) or if the operation requires continuation (`Continue`).
+/// The `handle_state` macro maps `Ready` to a `Poll::Ready` return value and
+/// `Continue` to a `continue` of the enclosing polling loop.
+#[derive(Debug)]
+pub enum StreamJoinStateResult<T> {
+    /// The operation is complete with a result of type `T`.
+    Ready(T),
+    /// The operation is not yet complete: the current invocation of the state
+    /// did not produce a final result, and the operation should be invoked
+    /// again later, possibly with more data and in a different state.
+    Continue,
+}
+
+/// Represents the various states of an eager join stream operation.
+///
+/// This enum is used to track the current state of streaming during a join
+/// operation. It indicates which side of the join needs to be pulled next, or
+/// whether one (or both) sides have been exhausted. `poll_next_impl` of
+/// `EagerJoinStream` dispatches on this state to pick its next step, and
+/// implementors store it and expose it through `state` / `set_state`.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum EagerJoinStreamState {
+    /// Indicates that the next step should pull from the right side of the join.
+    PullRight,
+
+    /// Indicates that the next step should pull from the left side of the join.
+    PullLeft,
+
+    /// The right stream is exhausted; only the left stream is pulled from now on.
+    RightExhausted,
+
+    /// The left stream is exhausted; only the right stream is pulled from now on.
+    LeftExhausted,
+
+    /// Represents a state where both sides of the join are exhausted.
+    ///
+    /// With `final_result: false`, final results have not been produced yet; with
+    /// `final_result: true`, they have been, and polling returns `Poll::Ready(None)`.
+    BothExhausted { final_result: bool },
+}
+
+/// `EagerJoinStream` is an asynchronous trait designed for managing incremental
+/// join operations between two streams, such as the stream that drives
+/// `SymmetricHashJoinExec`. Unlike traditional join approaches that need to scan
+/// one side of the join fully before proceeding, `EagerJoinStream` facilitates
+/// more dynamic join operations by working with streams as they emit data. This
+/// approach allows for more efficient processing, particularly in scenarios
+/// where waiting for complete data materialization is not feasible or optimal.
+/// The trait provides a framework for handling various states of such a join
+/// process, ensuring that join logic is efficiently executed as data becomes
+/// available from either stream.
+///
+/// Implementors of this trait can perform eager joins of data from two different
+/// asynchronous streams, typically referred to as left and right streams. The
+/// trait provides a comprehensive set of methods to control and execute the join
+/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are
+/// primarily focused on asynchronously fetching data batches from each stream,
+/// processing them, and managing transitions between various states of the join.
+///
+/// This trait's default implementations use a state machine approach to navigate
+/// different stages of the join operation, handling data from both streams and
+/// determining when the join completes.
+///
+/// State Transitions:
+/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
+///   - In `fetch_next_from_left_stream`, when fetching a batch from the left stream:
+///     - On success (`Some(Ok(batch))`), state transitions to `PullRight` and the
+///       batch is handed to `process_batch_from_left`.
+///     - On error (`Some(Err(e))`), the error is returned, and the state remains
+///       unchanged.
+///     - On no data (`None`), state changes to `LeftExhausted`, returning `Continue`
+///       to proceed with the join process.
+/// - From `PullRight` to `PullLeft` or `RightExhausted`:
+///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
+///     - On success, state changes to `PullLeft`; `process_batch_from_right` runs.
+///     - On error, the error is returned without changing the state.
+///     - If right stream is exhausted (`None`), state transitions to `RightExhausted`,
+///       with a `Continue` result.
+/// - Handling `RightExhausted` and `LeftExhausted`:
+///   - Methods `handle_right_stream_end` and `handle_left_stream_end` keep pulling
+///     from the remaining stream; their default bodies keep the state while batches arrive:
+///     - Batches go to `process_batch_after_right_end` / `process_batch_after_left_end`.
+///     - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`.
+/// - Transition to `BothExhausted { final_result: true }`:
+///   - Occurs in `prepare_for_final_results_after_exhaustion`, right before
+///     `process_batches_before_finalization` is asked for the final results.
+#[async_trait]
+pub trait EagerJoinStream {
+    /// Implements the main polling logic for the join stream.
+    ///
+    /// This method repeatedly inspects the current state and delegates to the
+    /// matching state handler until one of them produces a `Poll` result.
+    ///
+    /// # Arguments
+    ///
+    /// * `cx` - A context that facilitates cooperative non-blocking execution within a task.
+    ///
+    /// # Returns
+    ///
+    /// * `Poll<Option<Result<RecordBatch>>>` - `Ready(Some(Ok(batch)))` for an output batch,
+    ///   `Ready(Some(Err(_)))` on error, `Ready(None)` when the join is done, or `Pending`.
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>>
+    where
+        Self: Send,
+    {
+        loop {
+            return match self.state() {
+                EagerJoinStreamState::PullRight => {
+                    handle_async_state!(self.fetch_next_from_right_stream(), cx)
+                }
+                EagerJoinStreamState::PullLeft => {
+                    handle_async_state!(self.fetch_next_from_left_stream(), cx)
+                }
+                EagerJoinStreamState::RightExhausted => {
+                    handle_async_state!(self.handle_right_stream_end(), cx)
+                }
+                EagerJoinStreamState::LeftExhausted => {
+                    handle_async_state!(self.handle_left_stream_end(), cx)
+                }
+                EagerJoinStreamState::BothExhausted {
+                    final_result: false,
+                } => {
+                    handle_state!(self.prepare_for_final_results_after_exhaustion())
+                }
+                EagerJoinStreamState::BothExhausted { final_result: true } => {
+                    Poll::Ready(None)
+                }
+            };
+        }
+    }
+    /// Asynchronously pulls the next batch from the right stream.
+    ///
+    /// This default implementation checks for the next value in the right stream.
+    /// If a batch is found, the state is switched to `PullLeft`, and the batch handling
+    /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
+    async fn fetch_next_from_right_stream(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        match self.right_stream().next().await {
+            Some(Ok(batch)) => {
+                self.set_state(EagerJoinStreamState::PullLeft);
+                self.process_batch_from_right(batch)
+            }
+            Some(Err(e)) => Err(e),
+            None => {
+                self.set_state(EagerJoinStreamState::RightExhausted);
+                Ok(StreamJoinStateResult::Continue)
+            }
+        }
+    }
+
+    /// Asynchronously pulls the next batch from the left stream.
+    ///
+    /// This default implementation checks for the next value in the left stream.
+    /// If a batch is found, the state is switched to `PullRight`, and the batch handling
+    /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after pulling the batch.
+    async fn fetch_next_from_left_stream(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        match self.left_stream().next().await {
+            Some(Ok(batch)) => {
+                self.set_state(EagerJoinStreamState::PullRight);
+                self.process_batch_from_left(batch)
+            }
+            Some(Err(e)) => Err(e),
+            None => {
+                self.set_state(EagerJoinStreamState::LeftExhausted);
+                Ok(StreamJoinStateResult::Continue)
+            }
+        }
+    }
+
+    /// Asynchronously handles the scenario when the right stream is exhausted.
+    ///
+    /// In this default implementation, when the right stream is exhausted, it attempts
+    /// to pull from the left stream. If a batch is found in the left stream, it delegates
+    /// the handling to `process_batch_after_right_end` (the state is left unchanged). If
+    /// both streams are exhausted, the state is set to `BothExhausted { final_result: false }`.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
+    async fn handle_right_stream_end(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        match self.left_stream().next().await {
+            Some(Ok(batch)) => self.process_batch_after_right_end(batch),
+            Some(Err(e)) => Err(e),
+            None => {
+                self.set_state(EagerJoinStreamState::BothExhausted {
+                    final_result: false,
+                });
+                Ok(StreamJoinStateResult::Continue)
+            }
+        }
+    }
+
+    /// Asynchronously handles the scenario when the left stream is exhausted.
+    ///
+    /// When the left stream is exhausted, this default
+    /// implementation tries to pull from the right stream and delegates the batch
+    /// handling to `process_batch_after_left_end` (the state is left unchanged). If
+    /// both streams are exhausted, the state is set to `BothExhausted { final_result: false }`.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after checking the exhaustion state.
+    async fn handle_left_stream_end(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        match self.right_stream().next().await {
+            Some(Ok(batch)) => self.process_batch_after_left_end(batch),
+            Some(Err(e)) => Err(e),
+            None => {
+                self.set_state(EagerJoinStreamState::BothExhausted {
+                    final_result: false,
+                });
+                Ok(StreamJoinStateResult::Continue)
+            }
+        }
+    }
+
+    /// Handles the state when both streams are exhausted and final results are yet to be produced.
+    ///
+    /// This default implementation switches the state to `BothExhausted { final_result: true }`
+    /// before invoking `process_batches_before_finalization`, so it runs at most once: the next
+    /// time `poll_next_impl` inspects the state, it ends the stream with `Poll::Ready(None)`.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after both streams are exhausted.
+    fn prepare_for_final_results_after_exhaustion(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.set_state(EagerJoinStreamState::BothExhausted { final_result: true });
+        self.process_batches_before_finalization()
+    }
+
+    /// Handles a pulled batch from the right stream.
+    ///
+    /// # Arguments
+    ///
+    /// * `batch` - The pulled `RecordBatch` from the right stream.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
+    fn process_batch_from_right(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+    /// Handles a pulled batch from the left stream.
+    ///
+    /// # Arguments
+    ///
+    /// * `batch` - The pulled `RecordBatch` from the left stream.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after processing the batch.
+    fn process_batch_from_left(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+    /// Handles a batch pulled from the right stream after the left stream is exhausted.
+    ///
+    /// # Arguments
+    ///
+    /// * `right_batch` - The `RecordBatch` from the right stream.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the left stream is exhausted.
+    fn process_batch_after_left_end(
+        &mut self,
+        right_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+    /// Handles a batch pulled from the left stream after the right stream is exhausted.
+    ///
+    /// # Arguments
+    ///
+    /// * `left_batch` - The `RecordBatch` from the left stream.
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state result after the right stream is exhausted.
+    fn process_batch_after_right_end(
+        &mut self,
+        left_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+    /// Handles the final state after both streams are exhausted (called once by `prepare_for_final_results_after_exhaustion`).
+    ///
+    /// # Returns
+    ///
+    /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The final state result after processing.
+    fn process_batches_before_finalization(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+    /// Provides mutable access to the right stream.
+    ///
+    /// # Returns
+    ///
+    /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream.
+    fn right_stream(&mut self) -> &mut SendableRecordBatchStream;
+
+    /// Provides mutable access to the left stream.
+    ///
+    /// # Returns
+    ///
+    /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream.
+    fn left_stream(&mut self) -> &mut SendableRecordBatchStream;
+
+    /// Sets the current state of the join stream.
+    ///
+    /// # Arguments
+    ///
+    /// * `state` - The new state to be set.
+    fn set_state(&mut self, state: EagerJoinStreamState);
+
+    /// Fetches the current state of the join stream.
+    ///
+    /// # Returns
+    ///
+    /// * `EagerJoinStreamState` - The current state; `poll_next_impl` queries it on every loop iteration.
+    fn state(&mut self) -> EagerJoinStreamState;
+}
+
 #[cfg(test)]
 pub mod tests {
+    use std::sync::Arc;
+
     use super::*;
+    use crate::joins::stream_join_utils::{
+        build_filter_input_order, check_filter_expr_contains_sort_information,
+        convert_sort_expr_with_filter_schema, PruningJoinHashMap,
+    };
     use crate::{
         expressions::Column,
         expressions::PhysicalSortExpr,
         joins::utils::{ColumnIndex, JoinFilter},
     };
+
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::ScalarValue;
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::{binary, cast, col, lit};
-    use std::sync::Arc;
 
     /// Filter expr for a + b > c + 10 AND a + b < c + 100
     pub(crate) fn complicated_filter(
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 51561f5dab..d653297abe 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -25,20 +25,19 @@
 //! This plan uses the [`OneSideHashJoiner`] object to facilitate join 
calculations
 //! for both its children.
 
-use std::fmt;
-use std::fmt::Debug;
+use std::any::Any;
+use std::fmt::{self, Debug};
 use std::sync::Arc;
 use std::task::Poll;
-use std::vec;
-use std::{any::Any, usize};
+use std::{usize, vec};
 
 use crate::common::SharedMemoryReservation;
 use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
-use crate::joins::hash_join_utils::{
+use crate::joins::stream_join_utils::{
     calculate_filter_expr_intervals, combine_two_batches,
     convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
-    get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap,
-    SortedFilterExpr,
+    get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
+    EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, 
StreamJoinStateResult,
 };
 use crate::joins::utils::{
     build_batch_from_indices, build_join_schema, check_join_is_valid,
@@ -67,8 +66,7 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::intervals::ExprIntervalGraph;
 
 use ahash::RandomState;
-use futures::stream::{select, BoxStream};
-use futures::{Stream, StreamExt};
+use futures::Stream;
 use hashbrown::HashSet;
 use parking_lot::Mutex;
 
@@ -186,34 +184,34 @@ pub struct SymmetricHashJoinExec {
 }
 
 #[derive(Debug)]
-struct SymmetricHashJoinSideMetrics {
+pub struct StreamJoinSideMetrics {
     /// Number of batches consumed by this operator
-    input_batches: metrics::Count,
+    pub(crate) input_batches: metrics::Count,
     /// Number of rows consumed by this operator
-    input_rows: metrics::Count,
+    pub(crate) input_rows: metrics::Count,
 }

-/// Metrics for HashJoinExec
+/// Metrics for stream join operators such as `SymmetricHashJoinExec`
 #[derive(Debug)]
-struct SymmetricHashJoinMetrics {
+pub struct StreamJoinMetrics {
     /// Number of left batches/rows consumed by this operator
-    left: SymmetricHashJoinSideMetrics,
+    pub(crate) left: StreamJoinSideMetrics,
     /// Number of right batches/rows consumed by this operator
-    right: SymmetricHashJoinSideMetrics,
+    pub(crate) right: StreamJoinSideMetrics,
     /// Memory used by sides in bytes
     pub(crate) stream_memory_usage: metrics::Gauge,
     /// Number of batches produced by this operator
-    output_batches: metrics::Count,
+    pub(crate) output_batches: metrics::Count,
     /// Number of rows produced by this operator
-    output_rows: metrics::Count,
+    pub(crate) output_rows: metrics::Count,
 }
 
-impl SymmetricHashJoinMetrics {
+impl StreamJoinMetrics {
     pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
         let input_batches =
             MetricBuilder::new(metrics).counter("input_batches", partition);
         let input_rows = MetricBuilder::new(metrics).counter("input_rows", 
partition);
-        let left = SymmetricHashJoinSideMetrics {
+        let left = StreamJoinSideMetrics {
             input_batches,
             input_rows,
         };
@@ -221,7 +219,7 @@ impl SymmetricHashJoinMetrics {
         let input_batches =
             MetricBuilder::new(metrics).counter("input_batches", partition);
         let input_rows = MetricBuilder::new(metrics).counter("input_rows", 
partition);
-        let right = SymmetricHashJoinSideMetrics {
+        let right = StreamJoinSideMetrics {
             input_batches,
             input_rows,
         };
@@ -516,21 +514,9 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         let right_side_joiner =
             OneSideHashJoiner::new(JoinSide::Right, on_right, 
self.right.schema());
 
-        let left_stream = self
-            .left
-            .execute(partition, context.clone())?
-            .map(|val| (JoinSide::Left, val));
-
-        let right_stream = self
-            .right
-            .execute(partition, context.clone())?
-            .map(|val| (JoinSide::Right, val));
-        // This function will attempt to pull items from both streams.
-        // Each stream will be polled in a round-robin fashion, and whenever a 
stream is
-        // ready to yield an item that item is yielded.
-        // After one of the two input streams completes, the remaining one 
will be polled exclusively.
-        // The returned stream completes when both input streams have 
completed.
-        let input_stream = select(left_stream, right_stream).boxed();
+        let left_stream = self.left.execute(partition, context.clone())?;
+
+        let right_stream = self.right.execute(partition, context.clone())?;
 
         let reservation = Arc::new(Mutex::new(
             
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
@@ -541,7 +527,8 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         }
 
         Ok(Box::pin(SymmetricHashJoinStream {
-            input_stream,
+            left_stream,
+            right_stream,
             schema: self.schema(),
             filter: self.filter.clone(),
             join_type: self.join_type,
@@ -549,12 +536,12 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             left: left_side_joiner,
             right: right_side_joiner,
             column_indices: self.column_indices.clone(),
-            metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
+            metrics: StreamJoinMetrics::new(partition, &self.metrics),
             graph,
             left_sorted_filter_expr,
             right_sorted_filter_expr,
             null_equals_null: self.null_equals_null,
-            final_result: false,
+            state: EagerJoinStreamState::PullRight,
             reservation,
         }))
     }
@@ -562,8 +549,9 @@ impl ExecutionPlan for SymmetricHashJoinExec {
 
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
 struct SymmetricHashJoinStream {
-    /// Input stream
-    input_stream: BoxStream<'static, (JoinSide, Result<RecordBatch>)>,
+    /// Input streams
+    left_stream: SendableRecordBatchStream,
+    right_stream: SendableRecordBatchStream,
     /// Input schema
     schema: Arc<Schema>,
     /// join filter
@@ -587,11 +575,11 @@ struct SymmetricHashJoinStream {
     /// If null_equals_null is true, null == null else null != null
     null_equals_null: bool,
     /// Metrics
-    metrics: SymmetricHashJoinMetrics,
+    metrics: StreamJoinMetrics,
     /// Memory reservation
     reservation: SharedMemoryReservation,
-    /// Flag indicating whether there is nothing to process anymore
-    final_result: bool,
+    /// State machine for input execution
+    state: EagerJoinStreamState,
 }
 
 impl RecordBatchStream for SymmetricHashJoinStream {
@@ -763,7 +751,9 @@ pub(crate) fn build_side_determined_results(
     column_indices: &[ColumnIndex],
 ) -> Result<Option<RecordBatch>> {
     // Check if we need to produce a result in the final output:
-    if need_to_produce_result_in_final(build_hash_joiner.build_side, 
join_type) {
+    if prune_length > 0
+        && need_to_produce_result_in_final(build_hash_joiner.build_side, 
join_type)
+    {
         // Calculate the indices for build and probe sides based on join type 
and build side:
         let (build_indices, probe_indices) = calculate_indices_by_join_type(
             build_hash_joiner.build_side,
@@ -1019,10 +1009,104 @@ impl OneSideHashJoiner {
     }
 }
 
+/// Maps the optional output of one join step to the state machine's result:
+/// a produced batch is reported as `Ready`, while no output means `Continue`.
+fn join_output_to_state_result(
+    maybe_batch: Option<RecordBatch>,
+) -> StreamJoinStateResult<Option<RecordBatch>> {
+    match maybe_batch {
+        Some(batch) => StreamJoinStateResult::Ready(Some(batch)),
+        None => StreamJoinStateResult::Continue,
+    }
+}
+
+/// Drives `SymmetricHashJoinStream` through the eager join state machine: each
+/// incoming batch, from either input, is joined immediately against the state
+/// buffered for the opposite side.
+impl EagerJoinStream for SymmetricHashJoinStream {
+    /// Joins a batch pulled from the right input against the left side.
+    fn process_batch_from_right(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.perform_join_for_given_side(batch, JoinSide::Right)
+            .map(join_output_to_state_result)
+    }
+
+    /// Joins a batch pulled from the left input against the right side.
+    fn process_batch_from_left(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.perform_join_for_given_side(batch, JoinSide::Left)
+            .map(join_output_to_state_result)
+    }
+
+    /// After the left input has ended, right batches are joined exactly as before.
+    fn process_batch_after_left_end(
+        &mut self,
+        right_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.process_batch_from_right(right_batch)
+    }
+
+    /// After the right input has ended, left batches are joined exactly as before.
+    fn process_batch_after_right_end(
+        &mut self,
+        left_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.process_batch_from_left(left_batch)
+    }
+
+    /// Emits the results determined by every row still buffered on either side
+    /// (via `build_side_determined_results`), combined into a single batch.
+    fn process_batches_before_finalization(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        // Results determined by all rows still buffered on the left side:
+        let left_result = build_side_determined_results(
+            &self.left,
+            &self.schema,
+            self.left.input_buffer.num_rows(),
+            self.right.input_buffer.schema(),
+            self.join_type,
+            &self.column_indices,
+        )?;
+        // Results determined by all rows still buffered on the right side:
+        let right_result = build_side_determined_results(
+            &self.right,
+            &self.schema,
+            self.right.input_buffer.num_rows(),
+            self.left.input_buffer.schema(),
+            self.join_type,
+            &self.column_indices,
+        )?;
+
+        // Combine the left and right results:
+        let result = combine_two_batches(&self.schema, left_result, right_result)?;
+
+        // Update the output metrics when a batch was produced:
+        if let Some(batch) = &result {
+            self.metrics.output_batches.add(1);
+            self.metrics.output_rows.add(batch.num_rows());
+        }
+        Ok(join_output_to_state_result(result))
+    }
+
+    /// Mutable access to the right input stream.
+    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
+        &mut self.right_stream
+    }
+
+    /// Mutable access to the left input stream.
+    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
+        &mut self.left_stream
+    }
+
+    /// Records the next state of the join state machine.
+    fn set_state(&mut self, state: EagerJoinStreamState) {
+        self.state = state;
+    }
+
+    /// Returns a copy of the current state of the join state machine.
+    fn state(&mut self) -> EagerJoinStreamState {
+        self.state.clone()
+    }
+}
+
 impl SymmetricHashJoinStream {
     fn size(&self) -> usize {
         let mut size = 0;
-        size += std::mem::size_of_val(&self.input_stream);
         size += std::mem::size_of_val(&self.schema);
         size += std::mem::size_of_val(&self.filter);
         size += std::mem::size_of_val(&self.join_type);
@@ -1035,165 +1119,111 @@ impl SymmetricHashJoinStream {
         size += std::mem::size_of_val(&self.random_state);
         size += std::mem::size_of_val(&self.null_equals_null);
         size += std::mem::size_of_val(&self.metrics);
-        size += std::mem::size_of_val(&self.final_result);
         size
     }
-    /// Polls the next result of the join operation.
-    ///
-    /// If the result of the join is ready, it returns the next record batch.
-    /// If the join has completed and there are no more results, it returns
-    /// `Poll::Ready(None)`. If the join operation is not complete, but the
-    /// current stream is not ready yet, it returns `Poll::Pending`.
-    fn poll_next_impl(
+
+    /// Performs a join operation for the specified `probe_side` (either left or right).
+    /// This function:
+    /// 1. Determines which side is the probe and which is the build side.
+    /// 2. Updates the probe side's input metrics with `probe_batch`.
+    /// 3. Adds `probe_batch` to the probe side's hash joiner state, then joins it
+    ///    against the build side (applying the optional join filter).
+    /// 4. If both sides have sorted filter expressions and an interval graph is
+    ///    available, computes how many build-side rows to prune, emits the results
+    ///    those rows determine (`build_side_determined_results`) and prunes them.
+    /// 5. Combines the results, resizes the memory reservation to the stream's
+    ///    current size and updates the output metrics.
+    ///
+    /// Returns the combined batch, or `None` if this step produced no output.
+    ///
+    /// # Errors
+    ///
+    /// Propagates errors from the hash joiners, the filter/interval computations,
+    /// batch combination, and from resizing the memory reservation.
+    fn perform_join_for_given_side(
         &mut self,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Option<Result<RecordBatch>>> {
-        loop {
-            // Poll the next batch from `input_stream`:
-            match self.input_stream.poll_next_unpin(cx) {
-                // Batch is available
-                Poll::Ready(Some((side, Ok(probe_batch)))) => {
-                    // Determine which stream should be polled next. The side 
the
-                    // RecordBatch comes from becomes the probe side.
-                    let (
-                        probe_hash_joiner,
-                        build_hash_joiner,
-                        probe_side_sorted_filter_expr,
-                        build_side_sorted_filter_expr,
-                        probe_side_metrics,
-                    ) = if side.eq(&JoinSide::Left) {
-                        (
-                            &mut self.left,
-                            &mut self.right,
-                            &mut self.left_sorted_filter_expr,
-                            &mut self.right_sorted_filter_expr,
-                            &mut self.metrics.left,
-                        )
-                    } else {
-                        (
-                            &mut self.right,
-                            &mut self.left,
-                            &mut self.right_sorted_filter_expr,
-                            &mut self.left_sorted_filter_expr,
-                            &mut self.metrics.right,
-                        )
-                    };
-                    // Update the metrics for the stream that was polled:
-                    probe_side_metrics.input_batches.add(1);
-                    probe_side_metrics.input_rows.add(probe_batch.num_rows());
-                    // Update the internal state of the hash joiner for the 
build side:
-                    probe_hash_joiner
-                        .update_internal_state(&probe_batch, 
&self.random_state)?;
-                    // Join the two sides:
-                    let equal_result = join_with_probe_batch(
-                        build_hash_joiner,
-                        probe_hash_joiner,
-                        &self.schema,
-                        self.join_type,
-                        self.filter.as_ref(),
-                        &probe_batch,
-                        &self.column_indices,
-                        &self.random_state,
-                        self.null_equals_null,
-                    )?;
-                    // Increment the offset for the probe hash joiner:
-                    probe_hash_joiner.offset += probe_batch.num_rows();
-
-                    let anti_result = if let (
-                        Some(build_side_sorted_filter_expr),
-                        Some(probe_side_sorted_filter_expr),
-                        Some(graph),
-                    ) = (
-                        build_side_sorted_filter_expr.as_mut(),
-                        probe_side_sorted_filter_expr.as_mut(),
-                        self.graph.as_mut(),
-                    ) {
-                        // Calculate filter intervals:
-                        calculate_filter_expr_intervals(
-                            &build_hash_joiner.input_buffer,
-                            build_side_sorted_filter_expr,
-                            &probe_batch,
-                            probe_side_sorted_filter_expr,
-                        )?;
-                        let prune_length = build_hash_joiner
-                            .calculate_prune_length_with_probe_batch(
-                                build_side_sorted_filter_expr,
-                                probe_side_sorted_filter_expr,
-                                graph,
-                            )?;
-
-                        if prune_length > 0 {
-                            let res = build_side_determined_results(
-                                build_hash_joiner,
-                                &self.schema,
-                                prune_length,
-                                probe_batch.schema(),
-                                self.join_type,
-                                &self.column_indices,
-                            )?;
-                            
build_hash_joiner.prune_internal_state(prune_length)?;
-                            res
-                        } else {
-                            None
-                        }
-                    } else {
-                        None
-                    };
-
-                    // Combine results:
-                    let result =
-                        combine_two_batches(&self.schema, equal_result, 
anti_result)?;
-                    let capacity = self.size();
-                    self.metrics.stream_memory_usage.set(capacity);
-                    self.reservation.lock().try_resize(capacity)?;
-                    // Update the metrics if we have a batch; otherwise, 
continue the loop.
-                    if let Some(batch) = &result {
-                        self.metrics.output_batches.add(1);
-                        self.metrics.output_rows.add(batch.num_rows());
-                        return Poll::Ready(Ok(result).transpose());
-                    }
-                }
-                Poll::Ready(Some((_, Err(e)))) => return 
Poll::Ready(Some(Err(e))),
-                Poll::Ready(None) => {
-                    // If the final result has already been obtained, return 
`Poll::Ready(None)`:
-                    if self.final_result {
-                        return Poll::Ready(None);
-                    }
-                    self.final_result = true;
-                    // Get the left side results:
-                    let left_result = build_side_determined_results(
-                        &self.left,
-                        &self.schema,
-                        self.left.input_buffer.num_rows(),
-                        self.right.input_buffer.schema(),
-                        self.join_type,
-                        &self.column_indices,
-                    )?;
-                    // Get the right side results:
-                    let right_result = build_side_determined_results(
-                        &self.right,
-                        &self.schema,
-                        self.right.input_buffer.num_rows(),
-                        self.left.input_buffer.schema(),
-                        self.join_type,
-                        &self.column_indices,
-                    )?;
-
-                    // Combine the left and right results:
-                    let result =
-                        combine_two_batches(&self.schema, left_result, 
right_result)?;
-
-                    // Update the metrics and return the result:
-                    if let Some(batch) = &result {
-                        // Update the metrics:
-                        self.metrics.output_batches.add(1);
-                        self.metrics.output_rows.add(batch.num_rows());
-                        return Poll::Ready(Ok(result).transpose());
-                    }
-                }
-                Poll::Pending => return Poll::Pending,
-            }
+        probe_batch: RecordBatch,
+        probe_side: JoinSide,
+    ) -> Result<Option<RecordBatch>> {
+        // The side the batch came from is the probe side; the other is the build side:
+        let (
+            probe_hash_joiner,
+            build_hash_joiner,
+            probe_side_sorted_filter_expr,
+            build_side_sorted_filter_expr,
+            probe_side_metrics,
+        ) = if probe_side.eq(&JoinSide::Left) {
+            (
+                &mut self.left,
+                &mut self.right,
+                &mut self.left_sorted_filter_expr,
+                &mut self.right_sorted_filter_expr,
+                &mut self.metrics.left,
+            )
+        } else {
+            (
+                &mut self.right,
+                &mut self.left,
+                &mut self.right_sorted_filter_expr,
+                &mut self.left_sorted_filter_expr,
+                &mut self.metrics.right,
+            )
+        };
+        // Update the metrics for the stream that was polled:
+        probe_side_metrics.input_batches.add(1);
+        probe_side_metrics.input_rows.add(probe_batch.num_rows());
+        // Add the batch to the probe side's hash joiner; this side acts as the
+        // build side when batches later arrive from the other input:
+        probe_hash_joiner.update_internal_state(&probe_batch, 
&self.random_state)?;
+        // Join the probe batch against the build side (with the optional filter):
+        let equal_result = join_with_probe_batch(
+            build_hash_joiner,
+            probe_hash_joiner,
+            &self.schema,
+            self.join_type,
+            self.filter.as_ref(),
+            &probe_batch,
+            &self.column_indices,
+            &self.random_state,
+            self.null_equals_null,
+        )?;
+        // Increment the offset for the probe hash joiner:
+        probe_hash_joiner.offset += probe_batch.num_rows();
+
+        // Build-side pruning needs sorted filter expressions on both sides and an
+        // interval graph; without them nothing is pruned in this step:
+        let anti_result = if let (
+            Some(build_side_sorted_filter_expr),
+            Some(probe_side_sorted_filter_expr),
+            Some(graph),
+        ) = (
+            build_side_sorted_filter_expr.as_mut(),
+            probe_side_sorted_filter_expr.as_mut(),
+            self.graph.as_mut(),
+        ) {
+            // Calculate filter intervals:
+            calculate_filter_expr_intervals(
+                &build_hash_joiner.input_buffer,
+                build_side_sorted_filter_expr,
+                &probe_batch,
+                probe_side_sorted_filter_expr,
+            )?;
+            let prune_length = build_hash_joiner
+                .calculate_prune_length_with_probe_batch(
+                    build_side_sorted_filter_expr,
+                    probe_side_sorted_filter_expr,
+                    graph,
+                )?;
+            // `build_side_determined_results` skips result generation unless
+            // `prune_length > 0`, so no separate check is needed here.
+            let result = build_side_determined_results(
+                build_hash_joiner,
+                &self.schema,
+                prune_length,
+                probe_batch.schema(),
+                self.join_type,
+                &self.column_indices,
+            )?;
+            // NOTE(review): also called when `prune_length` is 0 — presumably a
+            // no-op in that case; confirm in `prune_internal_state`.
+            build_hash_joiner.prune_internal_state(prune_length)?;
+            result
+        } else {
+            None
+        };
+
+        // Combine results:
+        let result = combine_two_batches(&self.schema, equal_result, 
anti_result)?;
+        // Track the stream's current size and resize the memory reservation to it:
+        let capacity = self.size();
+        self.metrics.stream_memory_usage.set(capacity);
+        self.reservation.lock().try_resize(capacity)?;
+        // Record output metrics if this step produced a batch:
+        if let Some(batch) = &result {
+            self.metrics.output_batches.add(1);
+            self.metrics.output_rows.add(batch.num_rows());
         }
+        Ok(result)
     }
 }
 
@@ -1203,10 +1233,9 @@ mod tests {
     use std::sync::Mutex;
 
     use super::*;
-    use crate::joins::hash_join_utils::tests::complicated_filter;
     use crate::joins::test_utils::{
-        build_sides_record_batches, compare_batches, create_memory_table,
-        join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
+        build_sides_record_batches, compare_batches, complicated_filter,
+        create_memory_table, join_expr_tests_fixture_f64, 
join_expr_tests_fixture_i32,
         join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
         partitioned_sym_join_with_filter, split_record_batches,
     };
@@ -1833,6 +1862,73 @@ mod tests {
         Ok(())
     }
 
+    /// Full join whose left input declares two separate single-column orderings
+    /// (`la1` and `la2`, both with default ascending sort options), filtered by
+    /// `complicated_filter`.
+    #[tokio::test(flavor = "multi_thread")]
+    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
+        let cardinality = (3, 4);
+        let join_type = JoinType::Full;
+
+        let config = SessionConfig::new().with_repartition_joins(false);
+        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
+        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
+        let left_schema = &left_partition[0].schema();
+        let right_schema = &right_partition[0].schema();
+        // Two alternative orderings for the left input:
+        let left_sorted = vec![
+            vec![PhysicalSortExpr {
+                expr: col("la1", left_schema)?,
+                options: SortOptions::default(),
+            }],
+            vec![PhysicalSortExpr {
+                expr: col("la2", left_schema)?,
+                options: SortOptions::default(),
+            }],
+        ];
+
+        let right_sorted = vec![PhysicalSortExpr {
+            expr: col("ra1", right_schema)?,
+            options: SortOptions::default(),
+        }];
+
+        let (left, right) = create_memory_table(
+            left_partition,
+            right_partition,
+            left_sorted,
+            vec![right_sorted],
+        )?;
+
+        let on = vec![(
+            Column::new_with_schema("lc1", left_schema)?,
+            Column::new_with_schema("rc1", right_schema)?,
+        )];
+
+        // Filter: a + b > c + 10 AND a + b < c + 100, where `a` and `b` are left
+        // columns 0 and 4 and `c` is right column 0 (see `column_indices`).
+        let intermediate_schema = Schema::new(vec![
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
+            Field::new("2", DataType::Int32, true),
+        ]);
+        let filter_expr = complicated_filter(&intermediate_schema)?;
+        let column_indices = vec![
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 4,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+        ];
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
+
+        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
+        Ok(())
+    }
+
     #[rstest]
     #[tokio::test(flavor = "multi_thread")]
     async fn testing_with_temporal_columns(
@@ -1917,6 +2013,7 @@ mod tests {
         experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
         Ok(())
     }
+
     #[rstest]
     #[tokio::test(flavor = "multi_thread")]
     async fn test_with_interval_columns(
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs 
b/datafusion/physical-plan/src/joins/test_utils.rs
index bb4a861991..6deaa9ba1b 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -17,6 +17,9 @@
 
 //! This file has test utils for hash joins
 
+use std::sync::Arc;
+use std::usize;
+
 use crate::joins::utils::{JoinFilter, JoinOn};
 use crate::joins::{
     HashJoinExec, PartitionMode, StreamJoinPartitionMode, 
SymmetricHashJoinExec,
@@ -24,24 +27,24 @@ use crate::joins::{
 use crate::memory::MemoryExec;
 use crate::repartition::RepartitionExec;
 use crate::{common, ExecutionPlan, Partitioning};
+
 use arrow::util::pretty::pretty_format_batches;
 use arrow_array::{
     ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch,
     TimestampMillisecondArray,
 };
-use arrow_schema::Schema;
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use arrow_schema::{DataType, Schema};
+use datafusion_common::{Result, ScalarValue};
 use datafusion_execution::TaskContext;
 use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::expressions::{binary, cast, col, lit};
 use datafusion_physical_expr::intervals::test_utils::{
     gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
 };
 use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
+
 use rand::prelude::StdRng;
 use rand::{Rng, SeedableRng};
-use std::sync::Arc;
-use std::usize;
 
 pub fn compare_batches(collected_1: &[RecordBatch], collected_2: 
&[RecordBatch]) {
     // compare
@@ -500,3 +503,51 @@ pub fn create_memory_table(
         .with_sort_information(right_sorted);
     Ok((Arc::new(left), Arc::new(right)))
 }
+
+/// Filter expr for a + b > c + 10 AND a + b < c + 100
+///
+/// `a`, `b` and `c` are the columns named "0", "1" and "2" of `filter_schema`.
+/// The sum `a + b` is computed first and then cast to `Int64`; `c` is cast to
+/// `Int64` before the constant is added.
+pub(crate) fn complicated_filter(
+    filter_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    // CAST(a + b AS Int64): the left-hand side shared by both comparisons.
+    let a_plus_b = || -> Result<Arc<dyn PhysicalExpr>> {
+        cast(
+            binary(
+                col("0", filter_schema)?,
+                Operator::Plus,
+                col("1", filter_schema)?,
+                filter_schema,
+            )?,
+            filter_schema,
+            DataType::Int64,
+        )
+    };
+    // CAST(c AS Int64) + offset: the bound each comparison is checked against.
+    let c_plus = |offset: i64| -> Result<Arc<dyn PhysicalExpr>> {
+        binary(
+            cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?,
+            Operator::Plus,
+            lit(ScalarValue::Int64(Some(offset))),
+            filter_schema,
+        )
+    };
+
+    let left_expr = binary(a_plus_b()?, Operator::Gt, c_plus(10)?, filter_schema)?;
+    let right_expr = binary(a_plus_b()?, Operator::Lt, c_plus(100)?, filter_schema)?;
+    binary(left_expr, Operator::And, right_expr, filter_schema)
+}
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index f93f08255e..0729d365d6 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -18,12 +18,14 @@
 //! Join related functionality used both on logical and physical plans
 
 use std::collections::HashSet;
+use std::fmt::{self, Debug};
 use std::future::Future;
+use std::ops::IndexMut;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::usize;
 
-use crate::joins::hash_join_utils::{build_filter_input_order, 
SortedFilterExpr};
+use crate::joins::stream_join_utils::{build_filter_input_order, 
SortedFilterExpr};
 use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};
 
@@ -50,8 +52,135 @@ use datafusion_physical_expr::{
 
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
+use hashbrown::raw::RawTable;
 use parking_lot::Mutex;
 
+/// Maps a `u64` hash value based on the build side ["on" values] to a list of 
indices with this key's value.
+///
+/// By allocating a `HashMap` with capacity for *at least* the number of rows 
for entries at the build side,
+/// we make sure that we don't have to re-hash the hashmap, which needs access 
to the key (the hash in this case) value.
+///
+/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 
8 for hash value 1
+/// As the key is a hash value, we need to check possible hash collisions in 
the probe stage
+/// During this stage it might be the case that a row is contained the same 
hashmap value,
+/// but the values don't match. Those are checked in the 
[`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method.
+///
+/// The indices (values) are stored in a separate chained list stored in the 
`Vec<u64>`.
+///
+/// The first value (+1) is stored in the hashmap, whereas the next value is 
stored in array at the position value.
+///
+/// The chain can be followed until the value "0" has been reached, meaning 
the end of the list.
+/// Also see chapter 5.3 of [Balancing vectorized query execution with 
bandwidth-optimized 
storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
+///
+/// # Example
+///
+/// ``` text
+/// See the example below:
+///
+/// Insert (10,1)            <-- insert hash value 10 with row index 1
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (20,2)
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (10,3)           <-- collision! row index 3 has a hash value of 10 
as well
+/// map:
+/// ----------
+/// | 10 | 4 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 0 |  <--- hash value 10 maps to 4,2 (which means indices 
values 3,1)
+/// ---------------------
+/// Insert (10,4)          <-- another collision! row index 4 ALSO has a hash 
value of 10
+/// map:
+/// ---------
+/// | 10 | 5 |
+/// | 20 | 3 |
+/// ---------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means 
indices values 4,3,1)
+/// ---------------------
+/// ```
+pub struct JoinHashMap {
+    // Stores hash value to last row index
+    map: RawTable<(u64, u64)>,
+    // Stores indices in chained list data structure
+    next: Vec<u64>,
+}
+
+impl JoinHashMap {
+    #[cfg(test)]
+    pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
+        Self { map, next }
+    }
+
+    pub(crate) fn with_capacity(capacity: usize) -> Self {
+        JoinHashMap {
+            map: RawTable::with_capacity(capacity),
+            next: vec![0; capacity],
+        }
+    }
+}
+
+/// Trait defining methods that must be implemented by a hash map type to be used for joins.
+pub trait JoinHashMapType {
+    /// Storage type of the chained "next" list. It is indexed by row position; each
+    /// slot links to the previously inserted row (+1) with the same hash, and 0 ends the chain.
+    type NextType: IndexMut<usize, Output = u64>;
+    /// Extends the next list with `len` zeros, i.e. empty chain links.
+    /// Implementations whose next list is already sized for every row may make
+    /// this a no-op.
+    fn extend_zero(&mut self, len: usize);
+    /// Returns mutable references to the hash map and the next list at the same
+    /// time, so both can be updated without borrowing `self` twice.
+    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
+    /// Returns a reference to the hash map.
+    fn get_map(&self) -> &RawTable<(u64, u64)>;
+    /// Returns a reference to the next list.
+    fn get_list(&self) -> &Self::NextType;
+}
+
+/// Implementation of `JoinHashMapType` for `JoinHashMap`.
+impl JoinHashMapType for JoinHashMap {
+    type NextType = Vec<u64>;
+
+    /// No-op: `JoinHashMap::with_capacity` already allocates `next` as
+    /// `vec![0; capacity]`, so there is nothing to extend.
+    fn extend_zero(&mut self, _: usize) {}
+
+    /// Get mutable references to the hash map and the next list.
+    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
+        (&mut self.map, &mut self.next)
+    }
+
+    /// Get a reference to the hash map.
+    fn get_map(&self) -> &RawTable<(u64, u64)> {
+        &self.map
+    }
+
+    /// Get a reference to the next list.
+    fn get_list(&self) -> &Self::NextType {
+        &self.next
+    }
+}
+
+impl fmt::Debug for JoinHashMap {
+    /// Prints only the type name (`JoinHashMap { .. }`). The table and chain
+    /// vector hold one entry per build-side row, so they are deliberately not
+    /// dumped. This also avoids the empty output the previous no-op impl produced.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.debug_struct("JoinHashMap").finish_non_exhaustive()
+    }
+}
+
 /// The on clause of the join, as vector of (left, right) columns.
 pub type JoinOn = Vec<(Column, Column)>;
 /// Reference for JoinOn.
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 9d508078c7..9197343d74 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1155,6 +1155,7 @@ message PhysicalPlanNode {
     NestedLoopJoinExecNode nested_loop_join = 22;
     AnalyzeExecNode analyze = 23;
     JsonSinkExecNode json_sink = 24;
+    SymmetricHashJoinExecNode symmetric_hash_join = 25;
   }
 }
 
@@ -1432,6 +1433,21 @@ message HashJoinExecNode {
   JoinFilter filter = 8;
 }
 
+// Partitioning mode of a SymmetricHashJoinExec. On decode, SINGLE_PARTITION maps to
+// StreamJoinPartitionMode::SinglePartition and PARTITIONED_EXEC maps to
+// StreamJoinPartitionMode::Partitioned.
+enum StreamPartitionMode {
+  // Value 0: what a missing partition_mode field decodes to.
+  SINGLE_PARTITION = 0;
+  PARTITIONED_EXEC = 1;
+}
+
+// Serialized form of a SymmetricHashJoinExec physical plan node.
+message SymmetricHashJoinExecNode {
+  PhysicalPlanNode left = 1;
+  PhysicalPlanNode right = 2;
+  // Join key column pairs, each JoinOn holding the (left, right) columns.
+  repeated JoinOn on = 3;
+  JoinType join_type = 4;
+  // Field number 5 is not used by this message.
+  StreamPartitionMode partition_mode = 6;
+  bool null_equals_null = 7;
+  // Optional join filter; left unset when the plan has no filter.
+  JoinFilter filter = 8;
+}
+
 message UnionExecNode {
   repeated PhysicalPlanNode inputs = 1;
 }
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 0a8f415e20..8a63600237 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -17844,6 +17844,9 @@ impl serde::Serialize for PhysicalPlanNode {
                 physical_plan_node::PhysicalPlanType::JsonSink(v) => {
                     struct_ser.serialize_field("jsonSink", v)?;
                 }
+                physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => {
+                    struct_ser.serialize_field("symmetricHashJoin", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -17890,6 +17893,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             "analyze",
             "json_sink",
             "jsonSink",
+            "symmetric_hash_join",
+            "symmetricHashJoin",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -17917,6 +17922,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             NestedLoopJoin,
             Analyze,
             JsonSink,
+            SymmetricHashJoin,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -17961,6 +17967,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                             "nestedLoopJoin" | "nested_loop_join" => 
Ok(GeneratedField::NestedLoopJoin),
                             "analyze" => Ok(GeneratedField::Analyze),
                             "jsonSink" | "json_sink" => 
Ok(GeneratedField::JsonSink),
+                            "symmetricHashJoin" | "symmetric_hash_join" => 
Ok(GeneratedField::SymmetricHashJoin),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -18142,6 +18149,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode 
{
                                 return 
Err(serde::de::Error::duplicate_field("jsonSink"));
                             }
                             physical_plan_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink)
+;
+                        }
+                        GeneratedField::SymmetricHashJoin => {
+                            if physical_plan_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("symmetricHashJoin"));
+                            }
+                            physical_plan_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin)
 ;
                         }
                     }
@@ -23648,6 +23662,77 @@ impl<'de> serde::Deserialize<'de> for Statistics {
         deserializer.deserialize_struct("datafusion.Statistics", FIELDS, 
GeneratedVisitor)
     }
 }
+impl serde::Serialize for StreamPartitionMode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        let variant = match self {
+            Self::SinglePartition => "SINGLE_PARTITION",
+            Self::PartitionedExec => "PARTITIONED_EXEC",
+        };
+        serializer.serialize_str(variant)
+    }
+}
+impl<'de> serde::Deserialize<'de> for StreamPartitionMode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "SINGLE_PARTITION",
+            "PARTITIONED_EXEC",
+        ];
+
+        struct GeneratedVisitor;
+
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = StreamPartitionMode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                write!(formatter, "expected one of: {:?}", &FIELDS)
+            }
+
+            fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, 
E>
+            where
+                E: serde::de::Error,
+            {
+                i32::try_from(v)
+                    .ok()
+                    .and_then(|x| x.try_into().ok())
+                    .ok_or_else(|| {
+                        
serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
+                    })
+            }
+
+            fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, 
E>
+            where
+                E: serde::de::Error,
+            {
+                i32::try_from(v)
+                    .ok()
+                    .and_then(|x| x.try_into().ok())
+                    .ok_or_else(|| {
+                        
serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
+                    })
+            }
+
+            fn visit_str<E>(self, value: &str) -> 
std::result::Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                match value {
+                    "SINGLE_PARTITION" => 
Ok(StreamPartitionMode::SinglePartition),
+                    "PARTITIONED_EXEC" => 
Ok(StreamPartitionMode::PartitionedExec),
+                    _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
+                }
+            }
+        }
+        deserializer.deserialize_any(GeneratedVisitor)
+    }
+}
 impl serde::Serialize for StringifiedPlan {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
@@ -24066,6 +24151,206 @@ impl<'de> serde::Deserialize<'de> for 
SubqueryAliasNode {
         deserializer.deserialize_struct("datafusion.SubqueryAliasNode", 
FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for SymmetricHashJoinExecNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.left.is_some() {
+            len += 1;
+        }
+        if self.right.is_some() {
+            len += 1;
+        }
+        if !self.on.is_empty() {
+            len += 1;
+        }
+        if self.join_type != 0 {
+            len += 1;
+        }
+        if self.partition_mode != 0 {
+            len += 1;
+        }
+        if self.null_equals_null {
+            len += 1;
+        }
+        if self.filter.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?;
+        if let Some(v) = self.left.as_ref() {
+            struct_ser.serialize_field("left", v)?;
+        }
+        if let Some(v) = self.right.as_ref() {
+            struct_ser.serialize_field("right", v)?;
+        }
+        if !self.on.is_empty() {
+            struct_ser.serialize_field("on", &self.on)?;
+        }
+        if self.join_type != 0 {
+            let v = JoinType::try_from(self.join_type)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.join_type)))?;
+            struct_ser.serialize_field("joinType", &v)?;
+        }
+        if self.partition_mode != 0 {
+            let v = StreamPartitionMode::try_from(self.partition_mode)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.partition_mode)))?;
+            struct_ser.serialize_field("partitionMode", &v)?;
+        }
+        if self.null_equals_null {
+            struct_ser.serialize_field("nullEqualsNull", 
&self.null_equals_null)?;
+        }
+        if let Some(v) = self.filter.as_ref() {
+            struct_ser.serialize_field("filter", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "left",
+            "right",
+            "on",
+            "join_type",
+            "joinType",
+            "partition_mode",
+            "partitionMode",
+            "null_equals_null",
+            "nullEqualsNull",
+            "filter",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Left,
+            Right,
+            On,
+            JoinType,
+            PartitionMode,
+            NullEqualsNull,
+            Filter,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "left" => Ok(GeneratedField::Left),
+                            "right" => Ok(GeneratedField::Right),
+                            "on" => Ok(GeneratedField::On),
+                            "joinType" | "join_type" => 
Ok(GeneratedField::JoinType),
+                            "partitionMode" | "partition_mode" => 
Ok(GeneratedField::PartitionMode),
+                            "nullEqualsNull" | "null_equals_null" => 
Ok(GeneratedField::NullEqualsNull),
+                            "filter" => Ok(GeneratedField::Filter),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = SymmetricHashJoinExecNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct 
datafusion.SymmetricHashJoinExecNode")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<SymmetricHashJoinExecNode, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut left__ = None;
+                let mut right__ = None;
+                let mut on__ = None;
+                let mut join_type__ = None;
+                let mut partition_mode__ = None;
+                let mut null_equals_null__ = None;
+                let mut filter__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Left => {
+                            if left__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("left"));
+                            }
+                            left__ = map_.next_value()?;
+                        }
+                        GeneratedField::Right => {
+                            if right__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("right"));
+                            }
+                            right__ = map_.next_value()?;
+                        }
+                        GeneratedField::On => {
+                            if on__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("on"));
+                            }
+                            on__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::JoinType => {
+                            if join_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("joinType"));
+                            }
+                            join_type__ = Some(map_.next_value::<JoinType>()? 
as i32);
+                        }
+                        GeneratedField::PartitionMode => {
+                            if partition_mode__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("partitionMode"));
+                            }
+                            partition_mode__ = 
Some(map_.next_value::<StreamPartitionMode>()? as i32);
+                        }
+                        GeneratedField::NullEqualsNull => {
+                            if null_equals_null__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nullEqualsNull"));
+                            }
+                            null_equals_null__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::Filter => {
+                            if filter__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("filter"));
+                            }
+                            filter__ = map_.next_value()?;
+                        }
+                    }
+                }
+                Ok(SymmetricHashJoinExecNode {
+                    left: left__,
+                    right: right__,
+                    on: on__.unwrap_or_default(),
+                    join_type: join_type__.unwrap_or_default(),
+                    partition_mode: partition_mode__.unwrap_or_default(),
+                    null_equals_null: null_equals_null__.unwrap_or_default(),
+                    filter: filter__,
+                })
+            }
+        }
+        
deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, 
GeneratedVisitor)
+    }
+}
 impl serde::Serialize for TimeUnit {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 84fb84b948..4fb8e1599e 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1514,7 +1514,7 @@ pub mod owned_table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 
19, 20, 21, 22, 23, 24"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 
19, 20, 21, 22, 23, 24, 25"
     )]
     pub physical_plan_type: 
::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1571,6 +1571,8 @@ pub mod physical_plan_node {
         Analyze(::prost::alloc::boxed::Box<super::AnalyzeExecNode>),
         #[prost(message, tag = "24")]
         JsonSink(::prost::alloc::boxed::Box<super::JsonSinkExecNode>),
+        #[prost(message, tag = "25")]
+        
SymmetricHashJoin(::prost::alloc::boxed::Box<super::SymmetricHashJoinExecNode>),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -2009,6 +2011,24 @@ pub struct HashJoinExecNode {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SymmetricHashJoinExecNode {
+    #[prost(message, optional, boxed, tag = "1")]
+    pub left: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub right: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, repeated, tag = "3")]
+    pub on: ::prost::alloc::vec::Vec<JoinOn>,
+    #[prost(enumeration = "JoinType", tag = "4")]
+    pub join_type: i32,
+    #[prost(enumeration = "StreamPartitionMode", tag = "6")]
+    pub partition_mode: i32,
+    #[prost(bool, tag = "7")]
+    pub null_equals_null: bool,
+    #[prost(message, optional, tag = "8")]
+    pub filter: ::core::option::Option<JoinFilter>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct UnionExecNode {
     #[prost(message, repeated, tag = "1")]
     pub inputs: ::prost::alloc::vec::Vec<PhysicalPlanNode>,
@@ -3265,6 +3285,32 @@ impl PartitionMode {
 }
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
 #[repr(i32)]
+pub enum StreamPartitionMode {
+    SinglePartition = 0,
+    PartitionedExec = 1,
+}
+impl StreamPartitionMode {
+    /// String value of the enum field names used in the ProtoBuf definition.
+    ///
+    /// The values are not transformed in any way and thus are considered 
stable
+    /// (if the ProtoBuf definition does not change) and safe for programmatic 
use.
+    pub fn as_str_name(&self) -> &'static str {
+        match self {
+            StreamPartitionMode::SinglePartition => "SINGLE_PARTITION",
+            StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC",
+        }
+    }
+    /// Creates an enum from field names used in the ProtoBuf definition.
+    pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+        match value {
+            "SINGLE_PARTITION" => Some(Self::SinglePartition),
+            "PARTITIONED_EXEC" => Some(Self::PartitionedExec),
+            _ => None,
+        }
+    }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
+#[repr(i32)]
 pub enum AggregateMode {
     Partial = 0,
     Final = 1,
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 1eedbe987e..6714c35dc6 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -39,7 +39,9 @@ use datafusion::physical_plan::expressions::{Column, 
PhysicalSortExpr};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::insert::FileSinkExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
-use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
+use datafusion::physical_plan::joins::{
+    CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, 
SymmetricHashJoinExec,
+};
 use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
@@ -583,6 +585,97 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     hashjoin.null_equals_null,
                 )?))
             }
+            PhysicalPlanType::SymmetricHashJoin(sym_join) => {
+                let left = into_physical_plan(
+                    &sym_join.left,
+                    registry,
+                    runtime,
+                    extension_codec,
+                )?;
+                let right = into_physical_plan(
+                    &sym_join.right,
+                    registry,
+                    runtime,
+                    extension_codec,
+                )?;
+                let on = sym_join
+                    .on
+                    .iter()
+                    .map(|col| {
+                        let left = into_required!(col.left)?;
+                        let right = into_required!(col.right)?;
+                        Ok((left, right))
+                    })
+                    .collect::<Result<_>>()?;
+                let join_type = 
protobuf::JoinType::try_from(sym_join.join_type)
+                    .map_err(|_| {
+                        proto_error(format!(
+                            "Received a SymmetricHashJoin message with unknown 
JoinType {}",
+                            sym_join.join_type
+                        ))
+                    })?;
+                let filter = sym_join
+                    .filter
+                    .as_ref()
+                    .map(|f| {
+                        let schema = f
+                            .schema
+                            .as_ref()
+                            .ok_or_else(|| proto_error("Missing JoinFilter 
schema"))?
+                            .try_into()?;
+
+                        let expression = parse_physical_expr(
+                            f.expression.as_ref().ok_or_else(|| {
+                                proto_error("Unexpected empty filter 
expression")
+                            })?,
+                            registry, &schema
+                        )?;
+                        let column_indices = f.column_indices
+                            .iter()
+                            .map(|i| {
+                                let side = protobuf::JoinSide::try_from(i.side)
+                                    .map_err(|_| proto_error(format!(
+                                        "Received a HashJoinNode message with 
JoinSide in Filter {}",
+                                        i.side))
+                                    )?;
+
+                                Ok(ColumnIndex{
+                                    index: i.index as usize,
+                                    side: side.into(),
+                                })
+                            })
+                            .collect::<Result<_>>()?;
+
+                        Ok(JoinFilter::new(expression, column_indices, schema))
+                    })
+                    .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+
+                let partition_mode =
+                    
protobuf::StreamPartitionMode::try_from(sym_join.partition_mode).map_err(|_| {
+                        proto_error(format!(
+                            "Received a SymmetricHashJoin message with unknown 
PartitionMode {}",
+                            sym_join.partition_mode
+                        ))
+                    })?;
+                let partition_mode = match partition_mode {
+                    protobuf::StreamPartitionMode::SinglePartition => {
+                        StreamJoinPartitionMode::SinglePartition
+                    }
+                    protobuf::StreamPartitionMode::PartitionedExec => {
+                        StreamJoinPartitionMode::Partitioned
+                    }
+                };
+                SymmetricHashJoinExec::try_new(
+                    left,
+                    right,
+                    on,
+                    filter,
+                    &join_type.into(),
+                    sym_join.null_equals_null,
+                    partition_mode,
+                )
+                .map(|e| Arc::new(e) as _)
+            }
             PhysicalPlanType::Union(union) => {
                 let mut inputs: Vec<Arc<dyn ExecutionPlan>> = vec![];
                 for input in &union.inputs {
@@ -1008,6 +1101,79 @@ impl AsExecutionPlan for PhysicalPlanNode {
             });
         }
 
+        if let Some(exec) = plan.downcast_ref::<SymmetricHashJoinExec>() {
+            let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec.left().to_owned(),
+                extension_codec,
+            )?;
+            let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec.right().to_owned(),
+                extension_codec,
+            )?;
+            let on = exec
+                .on()
+                .iter()
+                .map(|tuple| protobuf::JoinOn {
+                    left: Some(protobuf::PhysicalColumn {
+                        name: tuple.0.name().to_string(),
+                        index: tuple.0.index() as u32,
+                    }),
+                    right: Some(protobuf::PhysicalColumn {
+                        name: tuple.1.name().to_string(),
+                        index: tuple.1.index() as u32,
+                    }),
+                })
+                .collect();
+            let join_type: protobuf::JoinType = 
exec.join_type().to_owned().into();
+            let filter = exec
+                .filter()
+                .as_ref()
+                .map(|f| {
+                    let expression = f.expression().to_owned().try_into()?;
+                    let column_indices = f
+                        .column_indices()
+                        .iter()
+                        .map(|i| {
+                            let side: protobuf::JoinSide = 
i.side.to_owned().into();
+                            protobuf::ColumnIndex {
+                                index: i.index as u32,
+                                side: side.into(),
+                            }
+                        })
+                        .collect();
+                    let schema = f.schema().try_into()?;
+                    Ok(protobuf::JoinFilter {
+                        expression: Some(expression),
+                        column_indices,
+                        schema: Some(schema),
+                    })
+                })
+                .map_or(Ok(None), |v: Result<protobuf::JoinFilter>| 
v.map(Some))?;
+
+            let partition_mode = match exec.partition_mode() {
+                StreamJoinPartitionMode::SinglePartition => {
+                    protobuf::StreamPartitionMode::SinglePartition
+                }
+                StreamJoinPartitionMode::Partitioned => {
+                    protobuf::StreamPartitionMode::PartitionedExec
+                }
+            };
+
+            return Ok(protobuf::PhysicalPlanNode {
+                physical_plan_type: 
Some(PhysicalPlanType::SymmetricHashJoin(Box::new(
+                    protobuf::SymmetricHashJoinExecNode {
+                        left: Some(Box::new(left)),
+                        right: Some(Box::new(right)),
+                        on,
+                        join_type: join_type.into(),
+                        partition_mode: partition_mode.into(),
+                        null_equals_null: exec.null_equals_null(),
+                        filter,
+                    },
+                ))),
+            });
+        }
+
         if let Some(exec) = plan.downcast_ref::<CrossJoinExec>() {
             let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
                 exec.left().to_owned(),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 23b0ea43c7..d7d762d470 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -45,7 +45,9 @@ use datafusion::physical_plan::expressions::{
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::functions::make_scalar_function;
 use datafusion::physical_plan::insert::FileSinkExec;
-use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, 
PartitionMode};
+use datafusion::physical_plan::joins::{
+    HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
+};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
@@ -754,3 +756,45 @@ fn roundtrip_json_sink() -> Result<()> {
         Some(sort_order),
     )))
 }
+
+/// Round-trips a `SymmetricHashJoinExec` over two single-column empty inputs
+/// through protobuf, for every join type in both stream partition modes.
+#[test]
+fn roundtrip_sym_hash_join() -> Result<()> {
+    let field = Field::new("col", DataType::Int64, false);
+    let left_schema = Arc::new(Schema::new(vec![field.clone()]));
+    let right_schema = Arc::new(Schema::new(vec![field]));
+    let on = vec![(
+        Column::new("col", left_schema.index_of("col")?),
+        Column::new("col", right_schema.index_of("col")?),
+    )];
+
+    let join_types = [
+        JoinType::Inner,
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftAnti,
+        JoinType::RightAnti,
+        JoinType::LeftSemi,
+        JoinType::RightSemi,
+    ];
+    let partition_modes = [
+        StreamJoinPartitionMode::Partitioned,
+        StreamJoinPartitionMode::SinglePartition,
+    ];
+
+    for join_type in &join_types {
+        for &partition_mode in &partition_modes {
+            let sym_join =
+                datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
+                    Arc::new(EmptyExec::new(false, left_schema.clone())),
+                    Arc::new(EmptyExec::new(false, right_schema.clone())),
+                    on.clone(),
+                    None,
+                    join_type,
+                    false,
+                    partition_mode,
+                )?;
+            roundtrip_test(Arc::new(sym_join))?;
+        }
+    }
+    Ok(())
+}


Reply via email to