This is an automated email from the ASF dual-hosted git repository.
ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2156dde546 Making stream joins extensible: A new Trait implementation for SHJ (#8234)
2156dde546 is described below
commit 2156dde54623d26635d4388d161d94ac79918cdc
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Mon Nov 20 13:40:25 2023 +0200
Making stream joins extensible: A new Trait implementation for SHJ (#8234)
* Upstream
* Update utils.rs
* Review
* Name change and remove ignore on test
* Comment revisions
* Improve comments
---------
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion/physical-plan/src/joins/hash_join.rs | 3 +-
datafusion/physical-plan/src/joins/mod.rs | 2 +-
.../{hash_join_utils.rs => stream_join_utils.rs} | 550 ++++++++++++++++-----
.../physical-plan/src/joins/symmetric_hash_join.rs | 503 +++++++++++--------
datafusion/physical-plan/src/joins/test_utils.rs | 61 ++-
datafusion/physical-plan/src/joins/utils.rs | 131 ++++-
datafusion/proto/proto/datafusion.proto | 16 +
datafusion/proto/src/generated/pbjson.rs | 285 +++++++++++
datafusion/proto/src/generated/prost.rs | 48 +-
datafusion/proto/src/physical_plan/mod.rs | 168 ++++++-
.../proto/tests/cases/roundtrip_physical_plan.rs | 46 +-
11 files changed, 1463 insertions(+), 350 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 7a08b56a6e..4846d0a5e0 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -26,7 +26,7 @@ use std::{any::Any, usize, vec};
use crate::joins::utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices,
build_batch_from_indices,
calculate_join_output_ordering, get_final_indices_from_bit_map,
- need_produce_result_in_final,
+ need_produce_result_in_final, JoinHashMap, JoinHashMapType,
};
use crate::DisplayAs;
use crate::{
@@ -35,7 +35,6 @@ use crate::{
expressions::Column,
expressions::PhysicalSortExpr,
hash_utils::create_hashes,
- joins::hash_join_utils::{JoinHashMap, JoinHashMapType},
joins::utils::{
adjust_right_output_partitioning, build_join_schema,
check_join_is_valid,
estimate_join_statistics, partitioned_join_output_partitioning,
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 19f10d06e1..6ddf19c511 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -25,9 +25,9 @@ pub use sort_merge_join::SortMergeJoinExec;
pub use symmetric_hash_join::SymmetricHashJoinExec;
mod cross_join;
mod hash_join;
-mod hash_join_utils;
mod nested_loop_join;
mod sort_merge_join;
+mod stream_join_utils;
mod symmetric_hash_join;
pub mod utils;
diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs
similarity index 67%
rename from datafusion/physical-plan/src/joins/hash_join_utils.rs
rename to datafusion/physical-plan/src/joins/stream_join_utils.rs
index db65c8bf08..aa57a4f896 100644
--- a/datafusion/physical-plan/src/joins/hash_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -15,151 +15,34 @@
// specific language governing permissions and limitations
// under the License.
-//! This file contains common subroutines for regular and symmetric hash join
+//! This file contains common subroutines for symmetric hash join
//! related functionality, used both in join calculations and optimization rules.
use std::collections::{HashMap, VecDeque};
-use std::fmt::Debug;
-use std::ops::IndexMut;
use std::sync::Arc;
-use std::{fmt, usize};
+use std::task::{Context, Poll};
+use std::usize;
-use crate::joins::utils::JoinFilter;
+use crate::handle_async_state;
+use crate::joins::utils::{JoinFilter, JoinHashMapType};
use arrow::compute::concat_batches;
-use arrow::datatypes::{ArrowNativeType, SchemaRef};
-use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray,
RecordBatch};
+use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
+use arrow_schema::SchemaRef;
+use async_trait::async_trait;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DataFusionError, JoinSide, Result, ScalarValue};
+use datafusion_execution::SendableRecordBatchStream;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::{Interval, IntervalBound};
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use futures::{ready, FutureExt, StreamExt};
use hashbrown::raw::RawTable;
use hashbrown::HashSet;
-/// Maps a `u64` hash value based on the build side ["on" values] to a list of
indices with this key's value.
-///
-/// By allocating a `HashMap` with capacity for *at least* the number of rows
for entries at the build side,
-/// we make sure that we don't have to re-hash the hashmap, which needs access
to the key (the hash in this case) value.
-///
-/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and
8 for hash value 1
-/// As the key is a hash value, we need to check possible hash collisions in
the probe stage
-/// During this stage it might be the case that a row is contained the same
hashmap value,
-/// but the values don't match. Those are checked in the
[`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method.
-///
-/// The indices (values) are stored in a separate chained list stored in the
`Vec<u64>`.
-///
-/// The first value (+1) is stored in the hashmap, whereas the next value is
stored in array at the position value.
-///
-/// The chain can be followed until the value "0" has been reached, meaning
the end of the list.
-/// Also see chapter 5.3 of [Balancing vectorized query execution with
bandwidth-optimized
storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
-///
-/// # Example
-///
-/// ``` text
-/// See the example below:
-///
-/// Insert (10,1) <-- insert hash value 10 with row index 1
-/// map:
-/// ----------
-/// | 10 | 2 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 0 | 0 |
-/// ---------------------
-/// Insert (20,2)
-/// map:
-/// ----------
-/// | 10 | 2 |
-/// | 20 | 3 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 0 | 0 |
-/// ---------------------
-/// Insert (10,3) <-- collision! row index 3 has a hash value of 10
as well
-/// map:
-/// ----------
-/// | 10 | 4 |
-/// | 20 | 3 |
-/// ----------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices
values 3,1)
-/// ---------------------
-/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash
value of 10
-/// map:
-/// ---------
-/// | 10 | 5 |
-/// | 20 | 3 |
-/// ---------
-/// next:
-/// ---------------------
-/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means
indices values 4,3,1)
-/// ---------------------
-/// ```
-pub struct JoinHashMap {
- // Stores hash value to last row index
- map: RawTable<(u64, u64)>,
- // Stores indices in chained list data structure
- next: Vec<u64>,
-}
-
-impl JoinHashMap {
- #[cfg(test)]
- pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
- Self { map, next }
- }
-
- pub(crate) fn with_capacity(capacity: usize) -> Self {
- JoinHashMap {
- map: RawTable::with_capacity(capacity),
- next: vec![0; capacity],
- }
- }
-}
-
-/// Trait defining methods that must be implemented by a hash map type to be
used for joins.
-pub trait JoinHashMapType {
- /// The type of list used to store the next list
- type NextType: IndexMut<usize, Output = u64>;
- /// Extend with zero
- fn extend_zero(&mut self, len: usize);
- /// Returns mutable references to the hash map and the next.
- fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
- /// Returns a reference to the hash map.
- fn get_map(&self) -> &RawTable<(u64, u64)>;
- /// Returns a reference to the next.
- fn get_list(&self) -> &Self::NextType;
-}
-
-/// Implementation of `JoinHashMapType` for `JoinHashMap`.
-impl JoinHashMapType for JoinHashMap {
- type NextType = Vec<u64>;
-
- // Void implementation
- fn extend_zero(&mut self, _: usize) {}
-
- /// Get mutable references to the hash map and the next.
- fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
- (&mut self.map, &mut self.next)
- }
-
- /// Get a reference to the hash map.
- fn get_map(&self) -> &RawTable<(u64, u64)> {
- &self.map
- }
-
- /// Get a reference to the next.
- fn get_list(&self) -> &Self::NextType {
- &self.next
- }
-}
-
/// Implementation of `JoinHashMapType` for `PruningJoinHashMap`.
impl JoinHashMapType for PruningJoinHashMap {
type NextType = VecDeque<u64>;
@@ -185,12 +68,6 @@ impl JoinHashMapType for PruningJoinHashMap {
}
}
-impl fmt::Debug for JoinHashMap {
- fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
- Ok(())
- }
-}
-
/// The `PruningJoinHashMap` is similar to a regular `JoinHashMap`, but with
/// the capability of pruning elements in an efficient manner. This structure
/// is particularly useful for cases where it's necessary to remove elements
@@ -322,7 +199,7 @@ impl PruningJoinHashMap {
}
}
-fn check_filter_expr_contains_sort_information(
+pub fn check_filter_expr_contains_sort_information(
expr: &Arc<dyn PhysicalExpr>,
reference: &Arc<dyn PhysicalExpr>,
) -> bool {
@@ -740,20 +617,423 @@ pub fn record_visited_indices<T: ArrowPrimitiveType>(
}
}
+/// Converts the outcome of one state-changing step of an `EagerJoinStream`
+/// into the value returned from `poll_next`, removing the repetitive `match`
+/// from every state handler.
+///
+/// The expansion either evaluates to a `Poll<Option<Result<_>>>` or executes
+/// `continue`, so it must be used inside a loop (as in
+/// `EagerJoinStream::poll_next_impl`).
+///
+/// # Cases
+///
+/// - `Ok(StreamJoinStateResult::Continue)`: executes `continue` in the
+///   enclosing loop, so the state machine moves on to its next step without
+///   producing output.
+/// - `Ok(StreamJoinStateResult::Ready(result))`: evaluates to
+///   `Poll::Ready(Ok(result).transpose())`, i.e. `Poll::Ready(Some(Ok(v)))`
+///   when `result` is `Some(v)` and `Poll::Ready(None)` when it is `None`.
+///   For a `Stream`, `Poll::Ready(None)` signals the end of the stream (it
+///   does NOT mean "waiting for more data"), so `Ready(None)` should only be
+///   produced when no more output will follow.
+/// - `Err(e)`: evaluates to `Poll::Ready(Some(Err(e)))`, forwarding the error
+///   to the consumer of the stream.
+///
+/// # Arguments
+///
+/// * `$match_case`: An expression that evaluates to a
+///   `Result<StreamJoinStateResult<Option<_>>>`.
+///
+/// NOTE(review): the expansion names `StreamJoinStateResult` and `Poll`
+/// without a path, so both must be in scope wherever the macro is invoked.
+#[macro_export]
+macro_rules! handle_state {
+ ($match_case:expr) => {
+ match $match_case {
+ Ok(StreamJoinStateResult::Continue) => continue,
+ Ok(StreamJoinStateResult::Ready(result)) => {
+ Poll::Ready(Ok(result).transpose())
+ }
+ Err(e) => Poll::Ready(Some(Err(e))),
+ }
+ };
+}
+
+/// Polls an asynchronous state-changing step once and forwards its output to
+/// `handle_state`.
+///
+/// `$state_func` is polled with `poll_unpin`. With `ready!` being
+/// `futures::ready` (as imported in this module), a pending future makes the
+/// enclosing function return `Poll::Pending`. When `$state_func` is a method
+/// call, as in `EagerJoinStream::poll_next_impl`, the temporary future is
+/// dropped at that point and a new one is created on the next poll. A ready
+/// output, a `Result<StreamJoinStateResult<_>>`, is passed to `handle_state`,
+/// which either `continue`s the enclosing loop or produces the value to
+/// return.
+///
+/// # Arguments
+///
+/// * `$state_func`: A future expression, typically a call to an async trait
+///   method, whose output is a `Result<StreamJoinStateResult<Option<_>>>`.
+/// * `$cx`: The task context used for polling, usually a `&mut Context<'_>`.
+///
+/// NOTE(review): the expansion uses `ready!` and the `poll_unpin` method
+/// (from `futures::FutureExt`) without a path, so both must be in scope at
+/// every use site, together with the names required by `handle_state`.
+#[macro_export]
+macro_rules! handle_async_state {
+ ($state_func:expr, $cx:expr) => {
+ $crate::handle_state!(ready!($state_func.poll_unpin($cx)))
+ };
+}
+
+/// Represents the result of a stateful operation on `EagerJoinStream`.
+///
+/// This enumeration indicates whether the state produced a result that is
+/// ready to be emitted (`Ready`) or whether the state machine should keep
+/// going (`Continue`).
+///
+/// Variants:
+/// - `Ready(T)`: The step finished with a result of type `T` that is handed
+///   back to the consumer of the stream (see `handle_state`).
+/// - `Continue`: The step produced nothing to emit. Instead of returning to
+///   the consumer, the state machine immediately runs the step for its
+///   current (possibly just updated) state, typically pulling more data.
+pub enum StreamJoinStateResult<T> {
+ Ready(T),
+ Continue,
+}
+
+/// Represents the various states of an eager join stream operation.
+///
+/// This enum tracks which input of the join is pulled next and whether one or
+/// both inputs have been exhausted. `EagerJoinStream::poll_next_impl`
+/// dispatches on it to choose the step to run.
+#[derive(Clone, Debug)]
+pub enum EagerJoinStreamState {
+ /// The next step pulls a batch from the right side of the join.
+ PullRight,
+
+ /// The next step pulls a batch from the left side of the join.
+ PullLeft,
+
+ /// The right input has returned `None`; only the left input is pulled
+ /// from now on.
+ RightExhausted,
+
+ /// The left input has returned `None`; only the right input is pulled
+ /// from now on.
+ LeftExhausted,
+
+ /// Both inputs are exhausted.
+ ///
+ /// `final_result` is `false` until the final results are requested via
+ /// `EagerJoinStream::prepare_for_final_results_after_exhaustion`, which
+ /// sets it to `true`. Once it is `true`, `poll_next_impl` returns
+ /// `Poll::Ready(None)`.
+ BothExhausted { final_result: bool },
+}
+
+/// `EagerJoinStream` is an asynchronous trait for joins that consume both of
+/// their input streams incrementally, such as `SymmetricHashJoinExec`. Unlike
+/// join approaches that must scan one side fully before proceeding, an
+/// implementor processes batches from either input as they are emitted,
+/// which helps when waiting for complete data materialization is not feasible
+/// or optimal.
+///
+/// Implementors provide the per-batch hooks (`process_batch_from_left`,
+/// `process_batch_from_right`, `process_batch_after_left_end`,
+/// `process_batch_after_right_end` and `process_batches_before_finalization`)
+/// together with accessors for the two input streams and the current
+/// `EagerJoinStreamState`. The default methods drive a state machine over
+/// those states.
+///
+/// The default implementation alternates strictly between the inputs: while
+/// waiting on the side selected by the current state, it does not poll the
+/// other side. A hook that returns `StreamJoinStateResult::Ready(None)` makes
+/// `poll_next_impl` return `Poll::Ready(None)` (see `handle_state`), which,
+/// once forwarded from `Stream::poll_next`, signals the end of the stream.
+/// Hooks with nothing to emit should therefore return `Continue` instead.
+///
+/// State Transitions:
+/// - From `PullLeft` to `PullRight` or `LeftExhausted`:
+///   - In `fetch_next_from_left_stream`, when fetching a batch from the left
+///     stream:
+///     - On success (`Some(Ok(batch))`), the state transitions to `PullRight`
+///       and the batch is handed to `process_batch_from_left`.
+///     - On error (`Some(Err(e))`), the error is returned and the state
+///       remains unchanged.
+///     - On no data (`None`), the state changes to `LeftExhausted`, returning
+///       `Continue` to proceed with the join process.
+/// - From `PullRight` to `PullLeft` or `RightExhausted`:
+///   - In `fetch_next_from_right_stream`, when fetching from the right stream:
+///     - If a batch is available, the state changes to `PullLeft` and the
+///       batch is handed to `process_batch_from_right`.
+///     - On error, the error is returned without changing the state.
+///     - If the right stream is exhausted (`None`), the state transitions to
+///       `RightExhausted`, with a `Continue` result.
+/// - Handling `RightExhausted` and `LeftExhausted`:
+///   - `handle_right_stream_end` and `handle_left_stream_end` keep pulling
+///     from the remaining stream. They pass each batch to
+///     `process_batch_after_right_end` or `process_batch_after_left_end`
+///     respectively, without changing the state themselves.
+///   - Once the remaining stream is exhausted too, the state changes to
+///     `BothExhausted { final_result: false }`.
+/// - Transition to `BothExhausted { final_result: true }`:
+///   - Occurs in `prepare_for_final_results_after_exhaustion`, which sets this
+///     state before calling `process_batches_before_finalization`. In this
+///     state `poll_next_impl` returns `Poll::Ready(None)`.
+#[async_trait]
+pub trait EagerJoinStream {
+ /// Implements the main polling logic for the join stream.
+ ///
+ /// Repeatedly dispatches on `self.state()` to the matching default handler.
+ /// A handler result of `StreamJoinStateResult::Continue` starts another
+ /// iteration. A `Ready` result, an error, or an input that is not ready
+ /// yet ends this call.
+ ///
+ /// # Arguments
+ ///
+ /// * `cx` - The task context used to poll the input streams.
+ ///
+ /// # Returns
+ ///
+ /// * `Poll<Option<Result<RecordBatch>>>` - `Poll::Pending` while the input
+ ///   being pulled is not ready, `Poll::Ready(Some(_))` with an output batch
+ ///   or an error, or `Poll::Ready(None)` once the join has finished.
+ fn poll_next_impl(
+ &mut self,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>>
+ where
+ Self: Send,
+ {
+ // Each async handler call below creates a fresh future that is polled
+ // once by `handle_async_state!`. If it is pending, the future is dropped
+ // and recreated on the next call. This is presumably safe because every
+ // default handler awaits `next()` on an input stream before it touches
+ // any state.
+ loop {
+ return match self.state() {
+ EagerJoinStreamState::PullRight => {
+ handle_async_state!(self.fetch_next_from_right_stream(),
cx)
+ }
+ EagerJoinStreamState::PullLeft => {
+ handle_async_state!(self.fetch_next_from_left_stream(), cx)
+ }
+ EagerJoinStreamState::RightExhausted => {
+ handle_async_state!(self.handle_right_stream_end(), cx)
+ }
+ EagerJoinStreamState::LeftExhausted => {
+ handle_async_state!(self.handle_left_stream_end(), cx)
+ }
+ EagerJoinStreamState::BothExhausted {
+ final_result: false,
+ } => {
+
handle_state!(self.prepare_for_final_results_after_exhaustion())
+ }
+ EagerJoinStreamState::BothExhausted { final_result: true } => {
+ Poll::Ready(None)
+ }
+ };
+ }
+ }
+ /// Asynchronously pulls the next batch from the right stream.
+ ///
+ /// If a batch is found, the state is switched to `PullLeft` and the batch
+ /// is handed to `process_batch_from_right`. If the stream has ended, the
+ /// state is set to `RightExhausted`. Errors from the stream are returned
+ /// without changing the state.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The result of
+ ///   `process_batch_from_right`, `Continue` at end of stream, or the
+ ///   stream's error.
+ async fn fetch_next_from_right_stream(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ match self.right_stream().next().await {
+ Some(Ok(batch)) => {
+ self.set_state(EagerJoinStreamState::PullLeft);
+ self.process_batch_from_right(batch)
+ }
+ Some(Err(e)) => Err(e),
+ None => {
+ self.set_state(EagerJoinStreamState::RightExhausted);
+ Ok(StreamJoinStateResult::Continue)
+ }
+ }
+ }
+
+ /// Asynchronously pulls the next batch from the left stream.
+ ///
+ /// If a batch is found, the state is switched to `PullRight` and the batch
+ /// is handed to `process_batch_from_left`. If the stream has ended, the
+ /// state is set to `LeftExhausted`. Errors from the stream are returned
+ /// without changing the state.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The result of
+ ///   `process_batch_from_left`, `Continue` at end of stream, or the
+ ///   stream's error.
+ async fn fetch_next_from_left_stream(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ match self.left_stream().next().await {
+ Some(Ok(batch)) => {
+ self.set_state(EagerJoinStreamState::PullRight);
+ self.process_batch_from_left(batch)
+ }
+ Some(Err(e)) => Err(e),
+ None => {
+ self.set_state(EagerJoinStreamState::LeftExhausted);
+ Ok(StreamJoinStateResult::Continue)
+ }
+ }
+ }
+
+ /// Asynchronously handles the scenario when the right stream is exhausted.
+ ///
+ /// Pulls the next batch from the left stream and hands it to
+ /// `process_batch_after_right_end`; this method does not change the state
+ /// itself in that case. When the left stream has ended as well, the state
+ /// is set to `BothExhausted { final_result: false }`.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The result of
+ ///   `process_batch_after_right_end`, `Continue` once both streams are
+ ///   exhausted, or the left stream's error.
+ async fn handle_right_stream_end(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ match self.left_stream().next().await {
+ Some(Ok(batch)) => self.process_batch_after_right_end(batch),
+ Some(Err(e)) => Err(e),
+ None => {
+ self.set_state(EagerJoinStreamState::BothExhausted {
+ final_result: false,
+ });
+ Ok(StreamJoinStateResult::Continue)
+ }
+ }
+ }
+
+ /// Asynchronously handles the scenario when the left stream is exhausted.
+ ///
+ /// Pulls the next batch from the right stream and hands it to
+ /// `process_batch_after_left_end`; this method does not change the state
+ /// itself in that case. When the right stream has ended as well, the state
+ /// is set to `BothExhausted { final_result: false }`.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The result of
+ ///   `process_batch_after_left_end`, `Continue` once both streams are
+ ///   exhausted, or the right stream's error.
+ async fn handle_left_stream_end(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ match self.right_stream().next().await {
+ Some(Ok(batch)) => self.process_batch_after_left_end(batch),
+ Some(Err(e)) => Err(e),
+ None => {
+ self.set_state(EagerJoinStreamState::BothExhausted {
+ final_result: false,
+ });
+ Ok(StreamJoinStateResult::Continue)
+ }
+ }
+ }
+
+ /// Handles the state when both streams are exhausted and final results
+ /// are yet to be produced.
+ ///
+ /// Sets the state to `BothExhausted { final_result: true }` first, so that
+ /// `poll_next_impl` does not dispatch here again. It then delegates to
+ /// `process_batches_before_finalization`.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The result of
+ ///   `process_batches_before_finalization`.
+ fn prepare_for_final_results_after_exhaustion(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+ self.set_state(EagerJoinStreamState::BothExhausted { final_result:
true });
+ self.process_batches_before_finalization()
+ }
+
+ /// Handles a batch pulled from the right stream by
+ /// `fetch_next_from_right_stream`.
+ ///
+ /// # Arguments
+ ///
+ /// * `batch` - The pulled `RecordBatch` from the right stream.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - `Ready` with
+ ///   output to emit, or `Continue` when there is nothing to emit yet.
+ fn process_batch_from_right(
+ &mut self,
+ batch: RecordBatch,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+ /// Handles a batch pulled from the left stream by
+ /// `fetch_next_from_left_stream`.
+ ///
+ /// # Arguments
+ ///
+ /// * `batch` - The pulled `RecordBatch` from the left stream.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - `Ready` with
+ ///   output to emit, or `Continue` when there is nothing to emit yet.
+ fn process_batch_from_left(
+ &mut self,
+ batch: RecordBatch,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+ /// Handles a batch pulled from the right stream after the left stream has
+ /// been exhausted (called by `handle_left_stream_end`).
+ ///
+ /// # Arguments
+ ///
+ /// * `right_batch` - The `RecordBatch` from the right stream.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state
+ ///   result after processing the batch.
+ fn process_batch_after_left_end(
+ &mut self,
+ right_batch: RecordBatch,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+ /// Handles a batch pulled from the left stream after the right stream has
+ /// been exhausted (called by `handle_right_stream_end`).
+ ///
+ /// # Arguments
+ ///
+ /// * `left_batch` - The `RecordBatch` from the left stream.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The state
+ ///   result after processing the batch.
+ fn process_batch_after_right_end(
+ &mut self,
+ left_batch: RecordBatch,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+ /// Produces any remaining output once both streams are exhausted. Called
+ /// by `prepare_for_final_results_after_exhaustion` after the state has
+ /// been set to `BothExhausted { final_result: true }`.
+ ///
+ /// # Returns
+ ///
+ /// * `Result<StreamJoinStateResult<Option<RecordBatch>>>` - The final
+ ///   state result after processing.
+ fn process_batches_before_finalization(
+ &mut self,
+ ) -> Result<StreamJoinStateResult<Option<RecordBatch>>>;
+
+ /// Provides mutable access to the right stream.
+ ///
+ /// # Returns
+ ///
+ /// * `&mut SendableRecordBatchStream` - A mutable reference to the right
+ ///   input stream.
+ fn right_stream(&mut self) -> &mut SendableRecordBatchStream;
+
+ /// Provides mutable access to the left stream.
+ ///
+ /// # Returns
+ ///
+ /// * `&mut SendableRecordBatchStream` - A mutable reference to the left
+ ///   input stream.
+ fn left_stream(&mut self) -> &mut SendableRecordBatchStream;
+
+ /// Sets the current state of the join stream.
+ ///
+ /// # Arguments
+ ///
+ /// * `state` - The new state to be set.
+ fn set_state(&mut self, state: EagerJoinStreamState);
+
+ /// Returns the current state of the join stream by value.
+ ///
+ /// `poll_next_impl` calls this on every loop iteration and dispatches on
+ /// the returned state.
+ ///
+ /// # Returns
+ ///
+ /// * `EagerJoinStreamState` - The current state of the join stream.
+ fn state(&mut self) -> EagerJoinStreamState;
+}
+
#[cfg(test)]
pub mod tests {
+ use std::sync::Arc;
+
use super::*;
+ use crate::joins::stream_join_utils::{
+ build_filter_input_order, check_filter_expr_contains_sort_information,
+ convert_sort_expr_with_filter_schema, PruningJoinHashMap,
+ };
use crate::{
expressions::Column,
expressions::PhysicalSortExpr,
joins::utils::{ColumnIndex, JoinFilter},
};
+
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, cast, col, lit};
- use std::sync::Arc;
/// Filter expr for a + b > c + 10 AND a + b < c + 100
pub(crate) fn complicated_filter(
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 51561f5dab..d653297abe 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -25,20 +25,19 @@
//! This plan uses the [`OneSideHashJoiner`] object to facilitate join
calculations
//! for both its children.
-use std::fmt;
-use std::fmt::Debug;
+use std::any::Any;
+use std::fmt::{self, Debug};
use std::sync::Arc;
use std::task::Poll;
-use std::vec;
-use std::{any::Any, usize};
+use std::{usize, vec};
use crate::common::SharedMemoryReservation;
use crate::joins::hash_join::{build_equal_condition_join_indices, update_hash};
-use crate::joins::hash_join_utils::{
+use crate::joins::stream_join_utils::{
calculate_filter_expr_intervals, combine_two_batches,
convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
- get_pruning_semi_indices, record_visited_indices, PruningJoinHashMap,
- SortedFilterExpr,
+ get_pruning_semi_indices, record_visited_indices, EagerJoinStream,
+ EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr,
StreamJoinStateResult,
};
use crate::joins::utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
@@ -67,8 +66,7 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::ExprIntervalGraph;
use ahash::RandomState;
-use futures::stream::{select, BoxStream};
-use futures::{Stream, StreamExt};
+use futures::Stream;
use hashbrown::HashSet;
use parking_lot::Mutex;
@@ -186,34 +184,34 @@ pub struct SymmetricHashJoinExec {
}
#[derive(Debug)]
-struct SymmetricHashJoinSideMetrics {
+/// Input metrics for one side (left or right) of a stream join.
+pub struct StreamJoinSideMetrics {
/// Number of batches consumed by this operator
- input_batches: metrics::Count,
+ pub(crate) input_batches: metrics::Count,
/// Number of rows consumed by this operator
- input_rows: metrics::Count,
+ pub(crate) input_rows: metrics::Count,
}
-/// Metrics for HashJoinExec
+/// Metrics for stream joins, currently recorded by `SymmetricHashJoinExec`.
#[derive(Debug)]
-struct SymmetricHashJoinMetrics {
+pub struct StreamJoinMetrics {
/// Number of left batches/rows consumed by this operator
- left: SymmetricHashJoinSideMetrics,
+ pub(crate) left: StreamJoinSideMetrics,
/// Number of right batches/rows consumed by this operator
- right: SymmetricHashJoinSideMetrics,
+ pub(crate) right: StreamJoinSideMetrics,
/// Memory used by sides in bytes
pub(crate) stream_memory_usage: metrics::Gauge,
/// Number of batches produced by this operator
- output_batches: metrics::Count,
+ pub(crate) output_batches: metrics::Count,
/// Number of rows produced by this operator
- output_rows: metrics::Count,
+ pub(crate) output_rows: metrics::Count,
}
-impl SymmetricHashJoinMetrics {
+impl StreamJoinMetrics {
pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self {
let input_batches =
MetricBuilder::new(metrics).counter("input_batches", partition);
let input_rows = MetricBuilder::new(metrics).counter("input_rows",
partition);
- let left = SymmetricHashJoinSideMetrics {
+ let left = StreamJoinSideMetrics {
input_batches,
input_rows,
};
@@ -221,7 +219,7 @@ impl SymmetricHashJoinMetrics {
let input_batches =
MetricBuilder::new(metrics).counter("input_batches", partition);
let input_rows = MetricBuilder::new(metrics).counter("input_rows",
partition);
- let right = SymmetricHashJoinSideMetrics {
+ let right = StreamJoinSideMetrics {
input_batches,
input_rows,
};
@@ -516,21 +514,9 @@ impl ExecutionPlan for SymmetricHashJoinExec {
let right_side_joiner =
OneSideHashJoiner::new(JoinSide::Right, on_right,
self.right.schema());
- let left_stream = self
- .left
- .execute(partition, context.clone())?
- .map(|val| (JoinSide::Left, val));
-
- let right_stream = self
- .right
- .execute(partition, context.clone())?
- .map(|val| (JoinSide::Right, val));
- // This function will attempt to pull items from both streams.
- // Each stream will be polled in a round-robin fashion, and whenever a
stream is
- // ready to yield an item that item is yielded.
- // After one of the two input streams completes, the remaining one
will be polled exclusively.
- // The returned stream completes when both input streams have
completed.
- let input_stream = select(left_stream, right_stream).boxed();
+ let left_stream = self.left.execute(partition, context.clone())?;
+
+ let right_stream = self.right.execute(partition, context.clone())?;
let reservation = Arc::new(Mutex::new(
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
@@ -541,7 +527,8 @@ impl ExecutionPlan for SymmetricHashJoinExec {
}
Ok(Box::pin(SymmetricHashJoinStream {
- input_stream,
+ left_stream,
+ right_stream,
schema: self.schema(),
filter: self.filter.clone(),
join_type: self.join_type,
@@ -549,12 +536,12 @@ impl ExecutionPlan for SymmetricHashJoinExec {
left: left_side_joiner,
right: right_side_joiner,
column_indices: self.column_indices.clone(),
- metrics: SymmetricHashJoinMetrics::new(partition, &self.metrics),
+ metrics: StreamJoinMetrics::new(partition, &self.metrics),
graph,
left_sorted_filter_expr,
right_sorted_filter_expr,
null_equals_null: self.null_equals_null,
- final_result: false,
+ state: EagerJoinStreamState::PullRight,
reservation,
}))
}
@@ -562,8 +549,9 @@ impl ExecutionPlan for SymmetricHashJoinExec {
/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
struct SymmetricHashJoinStream {
- /// Input stream
- input_stream: BoxStream<'static, (JoinSide, Result<RecordBatch>)>,
+ /// Input streams
+ left_stream: SendableRecordBatchStream,
+ right_stream: SendableRecordBatchStream,
/// Input schema
schema: Arc<Schema>,
/// join filter
@@ -587,11 +575,11 @@ struct SymmetricHashJoinStream {
/// If null_equals_null is true, null == null else null != null
null_equals_null: bool,
/// Metrics
- metrics: SymmetricHashJoinMetrics,
+ metrics: StreamJoinMetrics,
/// Memory reservation
reservation: SharedMemoryReservation,
- /// Flag indicating whether there is nothing to process anymore
- final_result: bool,
+ /// State machine for input execution
+ state: EagerJoinStreamState,
}
impl RecordBatchStream for SymmetricHashJoinStream {
@@ -763,7 +751,9 @@ pub(crate) fn build_side_determined_results(
column_indices: &[ColumnIndex],
) -> Result<Option<RecordBatch>> {
// Check if we need to produce a result in the final output:
- if need_to_produce_result_in_final(build_hash_joiner.build_side,
join_type) {
+ if prune_length > 0
+ && need_to_produce_result_in_final(build_hash_joiner.build_side,
join_type)
+ {
// Calculate the indices for build and probe sides based on join type
and build side:
let (build_indices, probe_indices) = calculate_indices_by_join_type(
build_hash_joiner.build_side,
@@ -1019,10 +1009,104 @@ impl OneSideHashJoiner {
}
}
+impl EagerJoinStream for SymmetricHashJoinStream {
+    /// Joins a freshly pulled right batch (the probe side) against the
+    /// buffered left side. Yields `Ready` only when the join emitted rows,
+    /// otherwise `Continue`.
+    fn process_batch_from_right(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        let joined = self.perform_join_for_given_side(batch, JoinSide::Right)?;
+        Ok(match joined {
+            produced @ Some(_) => StreamJoinStateResult::Ready(produced),
+            None => StreamJoinStateResult::Continue,
+        })
+    }
+
+    /// Mirror of `process_batch_from_right`: the left batch is the probe side
+    /// and the buffered right side is the build side.
+    fn process_batch_from_left(
+        &mut self,
+        batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        let joined = self.perform_join_for_given_side(batch, JoinSide::Left)?;
+        Ok(match joined {
+            produced @ Some(_) => StreamJoinStateResult::Ready(produced),
+            None => StreamJoinStateResult::Continue,
+        })
+    }
+
+    /// After the left input ends, right batches are joined exactly like
+    /// regular right batches.
+    fn process_batch_after_left_end(
+        &mut self,
+        right_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.process_batch_from_right(right_batch)
+    }
+
+    /// After the right input ends, left batches are joined exactly like
+    /// regular left batches.
+    fn process_batch_after_right_end(
+        &mut self,
+        left_batch: RecordBatch,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        self.process_batch_from_left(left_batch)
+    }
+
+    /// Final flush: passes each side's whole input buffer (prune length equal
+    /// to its row count) to `build_side_determined_results`, left first and
+    /// then right. The two results are merged into at most one output batch.
+    fn process_batches_before_finalization(
+        &mut self,
+    ) -> Result<StreamJoinStateResult<Option<RecordBatch>>> {
+        let left_output = build_side_determined_results(
+            &self.left,
+            &self.schema,
+            self.left.input_buffer.num_rows(),
+            self.right.input_buffer.schema(),
+            self.join_type,
+            &self.column_indices,
+        )?;
+        let right_output = build_side_determined_results(
+            &self.right,
+            &self.schema,
+            self.right.input_buffer.num_rows(),
+            self.left.input_buffer.schema(),
+            self.join_type,
+            &self.column_indices,
+        )?;
+
+        match combine_two_batches(&self.schema, left_output, right_output)? {
+            Some(batch) => {
+                // Only a produced batch counts towards the output metrics.
+                self.metrics.output_batches.add(1);
+                self.metrics.output_rows.add(batch.num_rows());
+                Ok(StreamJoinStateResult::Ready(Some(batch)))
+            }
+            None => Ok(StreamJoinStateResult::Continue),
+        }
+    }
+
+    fn right_stream(&mut self) -> &mut SendableRecordBatchStream {
+        &mut self.right_stream
+    }
+
+    fn left_stream(&mut self) -> &mut SendableRecordBatchStream {
+        &mut self.left_stream
+    }
+
+    fn set_state(&mut self, state: EagerJoinStreamState) {
+        self.state = state;
+    }
+
+    fn state(&mut self) -> EagerJoinStreamState {
+        self.state.clone()
+    }
+}
+
impl SymmetricHashJoinStream {
fn size(&self) -> usize {
let mut size = 0;
- size += std::mem::size_of_val(&self.input_stream);
size += std::mem::size_of_val(&self.schema);
size += std::mem::size_of_val(&self.filter);
size += std::mem::size_of_val(&self.join_type);
@@ -1035,165 +1119,111 @@ impl SymmetricHashJoinStream {
size += std::mem::size_of_val(&self.random_state);
size += std::mem::size_of_val(&self.null_equals_null);
size += std::mem::size_of_val(&self.metrics);
- size += std::mem::size_of_val(&self.final_result);
size
}
- /// Polls the next result of the join operation.
- ///
- /// If the result of the join is ready, it returns the next record batch.
- /// If the join has completed and there are no more results, it returns
- /// `Poll::Ready(None)`. If the join operation is not complete, but the
- /// current stream is not ready yet, it returns `Poll::Pending`.
- fn poll_next_impl(
+
+ /// Performs a join operation for the specified `probe_side` (either left
or right).
+ /// This function:
+ /// 1. Determines which side is the probe and which is the build side.
+ /// 2. Updates metrics based on the batch that was polled.
+ /// 3. Executes the join with the given `probe_batch`.
+ /// 4. Optionally computes anti-join results if all conditions are met.
+ /// 5. Combines the results and returns a combined batch or `None` if no
batch was produced.
+ fn perform_join_for_given_side(
&mut self,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Result<RecordBatch>>> {
- loop {
- // Poll the next batch from `input_stream`:
- match self.input_stream.poll_next_unpin(cx) {
- // Batch is available
- Poll::Ready(Some((side, Ok(probe_batch)))) => {
- // Determine which stream should be polled next. The side
the
- // RecordBatch comes from becomes the probe side.
- let (
- probe_hash_joiner,
- build_hash_joiner,
- probe_side_sorted_filter_expr,
- build_side_sorted_filter_expr,
- probe_side_metrics,
- ) = if side.eq(&JoinSide::Left) {
- (
- &mut self.left,
- &mut self.right,
- &mut self.left_sorted_filter_expr,
- &mut self.right_sorted_filter_expr,
- &mut self.metrics.left,
- )
- } else {
- (
- &mut self.right,
- &mut self.left,
- &mut self.right_sorted_filter_expr,
- &mut self.left_sorted_filter_expr,
- &mut self.metrics.right,
- )
- };
- // Update the metrics for the stream that was polled:
- probe_side_metrics.input_batches.add(1);
- probe_side_metrics.input_rows.add(probe_batch.num_rows());
- // Update the internal state of the hash joiner for the
build side:
- probe_hash_joiner
- .update_internal_state(&probe_batch,
&self.random_state)?;
- // Join the two sides:
- let equal_result = join_with_probe_batch(
- build_hash_joiner,
- probe_hash_joiner,
- &self.schema,
- self.join_type,
- self.filter.as_ref(),
- &probe_batch,
- &self.column_indices,
- &self.random_state,
- self.null_equals_null,
- )?;
- // Increment the offset for the probe hash joiner:
- probe_hash_joiner.offset += probe_batch.num_rows();
-
- let anti_result = if let (
- Some(build_side_sorted_filter_expr),
- Some(probe_side_sorted_filter_expr),
- Some(graph),
- ) = (
- build_side_sorted_filter_expr.as_mut(),
- probe_side_sorted_filter_expr.as_mut(),
- self.graph.as_mut(),
- ) {
- // Calculate filter intervals:
- calculate_filter_expr_intervals(
- &build_hash_joiner.input_buffer,
- build_side_sorted_filter_expr,
- &probe_batch,
- probe_side_sorted_filter_expr,
- )?;
- let prune_length = build_hash_joiner
- .calculate_prune_length_with_probe_batch(
- build_side_sorted_filter_expr,
- probe_side_sorted_filter_expr,
- graph,
- )?;
-
- if prune_length > 0 {
- let res = build_side_determined_results(
- build_hash_joiner,
- &self.schema,
- prune_length,
- probe_batch.schema(),
- self.join_type,
- &self.column_indices,
- )?;
-
build_hash_joiner.prune_internal_state(prune_length)?;
- res
- } else {
- None
- }
- } else {
- None
- };
-
- // Combine results:
- let result =
- combine_two_batches(&self.schema, equal_result,
anti_result)?;
- let capacity = self.size();
- self.metrics.stream_memory_usage.set(capacity);
- self.reservation.lock().try_resize(capacity)?;
- // Update the metrics if we have a batch; otherwise,
continue the loop.
- if let Some(batch) = &result {
- self.metrics.output_batches.add(1);
- self.metrics.output_rows.add(batch.num_rows());
- return Poll::Ready(Ok(result).transpose());
- }
- }
- Poll::Ready(Some((_, Err(e)))) => return
Poll::Ready(Some(Err(e))),
- Poll::Ready(None) => {
- // If the final result has already been obtained, return
`Poll::Ready(None)`:
- if self.final_result {
- return Poll::Ready(None);
- }
- self.final_result = true;
- // Get the left side results:
- let left_result = build_side_determined_results(
- &self.left,
- &self.schema,
- self.left.input_buffer.num_rows(),
- self.right.input_buffer.schema(),
- self.join_type,
- &self.column_indices,
- )?;
- // Get the right side results:
- let right_result = build_side_determined_results(
- &self.right,
- &self.schema,
- self.right.input_buffer.num_rows(),
- self.left.input_buffer.schema(),
- self.join_type,
- &self.column_indices,
- )?;
-
- // Combine the left and right results:
- let result =
- combine_two_batches(&self.schema, left_result,
right_result)?;
-
- // Update the metrics and return the result:
- if let Some(batch) = &result {
- // Update the metrics:
- self.metrics.output_batches.add(1);
- self.metrics.output_rows.add(batch.num_rows());
- return Poll::Ready(Ok(result).transpose());
- }
- }
- Poll::Pending => return Poll::Pending,
- }
+ probe_batch: RecordBatch,
+ probe_side: JoinSide,
+ ) -> Result<Option<RecordBatch>> {
+ let (
+ probe_hash_joiner,
+ build_hash_joiner,
+ probe_side_sorted_filter_expr,
+ build_side_sorted_filter_expr,
+ probe_side_metrics,
+ ) = if probe_side.eq(&JoinSide::Left) {
+ (
+ &mut self.left,
+ &mut self.right,
+ &mut self.left_sorted_filter_expr,
+ &mut self.right_sorted_filter_expr,
+ &mut self.metrics.left,
+ )
+ } else {
+ (
+ &mut self.right,
+ &mut self.left,
+ &mut self.right_sorted_filter_expr,
+ &mut self.left_sorted_filter_expr,
+ &mut self.metrics.right,
+ )
+ };
+ // Update the metrics for the stream that was polled:
+ probe_side_metrics.input_batches.add(1);
+ probe_side_metrics.input_rows.add(probe_batch.num_rows());
+ // Update the internal state of the hash joiner for the build side:
+ probe_hash_joiner.update_internal_state(&probe_batch,
&self.random_state)?;
+ // Join the two sides:
+ let equal_result = join_with_probe_batch(
+ build_hash_joiner,
+ probe_hash_joiner,
+ &self.schema,
+ self.join_type,
+ self.filter.as_ref(),
+ &probe_batch,
+ &self.column_indices,
+ &self.random_state,
+ self.null_equals_null,
+ )?;
+ // Increment the offset for the probe hash joiner:
+ probe_hash_joiner.offset += probe_batch.num_rows();
+
+ let anti_result = if let (
+ Some(build_side_sorted_filter_expr),
+ Some(probe_side_sorted_filter_expr),
+ Some(graph),
+ ) = (
+ build_side_sorted_filter_expr.as_mut(),
+ probe_side_sorted_filter_expr.as_mut(),
+ self.graph.as_mut(),
+ ) {
+ // Calculate filter intervals:
+ calculate_filter_expr_intervals(
+ &build_hash_joiner.input_buffer,
+ build_side_sorted_filter_expr,
+ &probe_batch,
+ probe_side_sorted_filter_expr,
+ )?;
+ let prune_length = build_hash_joiner
+ .calculate_prune_length_with_probe_batch(
+ build_side_sorted_filter_expr,
+ probe_side_sorted_filter_expr,
+ graph,
+ )?;
+ let result = build_side_determined_results(
+ build_hash_joiner,
+ &self.schema,
+ prune_length,
+ probe_batch.schema(),
+ self.join_type,
+ &self.column_indices,
+ )?;
+ build_hash_joiner.prune_internal_state(prune_length)?;
+ result
+ } else {
+ None
+ };
+
+ // Combine results:
+ let result = combine_two_batches(&self.schema, equal_result,
anti_result)?;
+ let capacity = self.size();
+ self.metrics.stream_memory_usage.set(capacity);
+ self.reservation.lock().try_resize(capacity)?;
+ // Update the metrics if we have a batch; otherwise, continue the loop.
+ if let Some(batch) = &result {
+ self.metrics.output_batches.add(1);
+ self.metrics.output_rows.add(batch.num_rows());
}
+ Ok(result)
}
}
@@ -1203,10 +1233,9 @@ mod tests {
use std::sync::Mutex;
use super::*;
- use crate::joins::hash_join_utils::tests::complicated_filter;
use crate::joins::test_utils::{
- build_sides_record_batches, compare_batches, create_memory_table,
- join_expr_tests_fixture_f64, join_expr_tests_fixture_i32,
+ build_sides_record_batches, compare_batches, complicated_filter,
+ create_memory_table, join_expr_tests_fixture_f64,
join_expr_tests_fixture_i32,
join_expr_tests_fixture_temporal, partitioned_hash_join_with_filter,
partitioned_sym_join_with_filter, split_record_batches,
};
@@ -1833,6 +1862,73 @@ mod tests {
Ok(())
}
+    /// Full join on `lc1 = rc1`. The left input is declared with two
+    /// single-column orderings (`la1`, `la2`) and the right input with one
+    /// (`ra1`). The join filter is `a + b > c + 10 AND a + b < c + 100`, with
+    /// `a` and `b` taken from left columns 0 and 4 and `c` from right column 0.
+    #[tokio::test(flavor = "multi_thread")]
+    async fn complex_join_all_one_ascending_equivalence() -> Result<()> {
+        let cardinality = (3, 4);
+        let join_type = JoinType::Full;
+
+        let config = SessionConfig::new().with_repartition_joins(false);
+        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
+        let (left_partition, right_partition) = get_or_create_table(cardinality, 8)?;
+        let left_schema = &left_partition[0].schema();
+        let right_schema = &right_partition[0].schema();
+
+        // Two alternative orderings for the left input, one for the right.
+        let left_sorted = vec![
+            vec![PhysicalSortExpr {
+                expr: col("la1", left_schema)?,
+                options: SortOptions::default(),
+            }],
+            vec![PhysicalSortExpr {
+                expr: col("la2", left_schema)?,
+                options: SortOptions::default(),
+            }],
+        ];
+        let right_sorted = vec![PhysicalSortExpr {
+            expr: col("ra1", right_schema)?,
+            options: SortOptions::default(),
+        }];
+
+        let (left, right) = create_memory_table(
+            left_partition,
+            right_partition,
+            left_sorted,
+            vec![right_sorted],
+        )?;
+
+        let on = vec![(
+            Column::new_with_schema("lc1", left_schema)?,
+            Column::new_with_schema("rc1", right_schema)?,
+        )];
+
+        // a + b > c + 10 AND a + b < c + 100
+        let intermediate_schema = Schema::new(vec![
+            Field::new("0", DataType::Int32, true),
+            Field::new("1", DataType::Int32, true),
+            Field::new("2", DataType::Int32, true),
+        ]);
+        let filter_expr = complicated_filter(&intermediate_schema)?;
+        let column_indices = vec![
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 4,
+                side: JoinSide::Left,
+            },
+            ColumnIndex {
+                index: 0,
+                side: JoinSide::Right,
+            },
+        ];
+        let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema);
+
+        experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
+        Ok(())
+    }
+
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn testing_with_temporal_columns(
@@ -1917,6 +2013,7 @@ mod tests {
experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
Ok(())
}
+
#[rstest]
#[tokio::test(flavor = "multi_thread")]
async fn test_with_interval_columns(
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs
b/datafusion/physical-plan/src/joins/test_utils.rs
index bb4a861991..6deaa9ba1b 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -17,6 +17,9 @@
//! This file has test utils for hash joins
+use std::sync::Arc;
+use std::usize;
+
use crate::joins::utils::{JoinFilter, JoinOn};
use crate::joins::{
HashJoinExec, PartitionMode, StreamJoinPartitionMode,
SymmetricHashJoinExec,
@@ -24,24 +27,24 @@ use crate::joins::{
use crate::memory::MemoryExec;
use crate::repartition::RepartitionExec;
use crate::{common, ExecutionPlan, Partitioning};
+
use arrow::util::pretty::pretty_format_batches;
use arrow_array::{
ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch,
TimestampMillisecondArray,
};
-use arrow_schema::Schema;
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use arrow_schema::{DataType, Schema};
+use datafusion_common::{Result, ScalarValue};
use datafusion_execution::TaskContext;
use datafusion_expr::{JoinType, Operator};
+use datafusion_physical_expr::expressions::{binary, cast, col, lit};
use datafusion_physical_expr::intervals::test_utils::{
gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
};
use datafusion_physical_expr::{LexOrdering, PhysicalExpr};
+
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
-use std::sync::Arc;
-use std::usize;
pub fn compare_batches(collected_1: &[RecordBatch], collected_2:
&[RecordBatch]) {
// compare
@@ -500,3 +503,51 @@ pub fn create_memory_table(
.with_sort_information(right_sorted);
Ok((Arc::new(left), Arc::new(right)))
}
+
+/// Filter expr for a + b > c + 10 AND a + b < c + 100
+///
+/// `filter_schema` must expose columns named "0" (a), "1" (b) and "2" (c).
+/// `a + b` is computed in the input type and then cast to `Int64`. `c` is
+/// cast to `Int64` before the literal offset is added, so both comparisons
+/// run on `Int64`.
+pub(crate) fn complicated_filter(
+    filter_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    // Builds `CAST(a + b AS Int64) <op> CAST(c AS Int64) + <offset>`.
+    let compare_sum_with_shifted_c =
+        |op: Operator, offset: i64| -> Result<Arc<dyn PhysicalExpr>> {
+            let a_plus_b = binary(
+                col("0", filter_schema)?,
+                Operator::Plus,
+                col("1", filter_schema)?,
+                filter_schema,
+            )?;
+            let lhs = cast(a_plus_b, filter_schema, DataType::Int64)?;
+            let c_as_i64 = cast(col("2", filter_schema)?, filter_schema, DataType::Int64)?;
+            let rhs = binary(
+                c_as_i64,
+                Operator::Plus,
+                lit(ScalarValue::Int64(Some(offset))),
+                filter_schema,
+            )?;
+            binary(lhs, op, rhs, filter_schema)
+        };
+
+    let lower_bound = compare_sum_with_shifted_c(Operator::Gt, 10)?;
+    let upper_bound = compare_sum_with_shifted_c(Operator::Lt, 100)?;
+    binary(lower_bound, Operator::And, upper_bound, filter_schema)
+}
diff --git a/datafusion/physical-plan/src/joins/utils.rs
b/datafusion/physical-plan/src/joins/utils.rs
index f93f08255e..0729d365d6 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -18,12 +18,14 @@
//! Join related functionality used both on logical and physical plans
use std::collections::HashSet;
+use std::fmt::{self, Debug};
use std::future::Future;
+use std::ops::IndexMut;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::usize;
-use crate::joins::hash_join_utils::{build_filter_input_order,
SortedFilterExpr};
+use crate::joins::stream_join_utils::{build_filter_input_order,
SortedFilterExpr};
use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};
@@ -50,8 +52,135 @@ use datafusion_physical_expr::{
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
+use hashbrown::raw::RawTable;
use parking_lot::Mutex;
+/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value.
+///
+/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side,
+/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value.
+///
+/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1.
+/// As the key is a hash value, we need to check possible hash collisions in the probe stage:
+/// a row may share a hash value with another row even though their "on" values differ.
+/// Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method.
+///
+/// The indices (values) are stored in a separate chained list stored in the `Vec<u64>`.
+///
+/// The most recently inserted index (+1) is stored in the hashmap, while the previous index (+1)
+/// with the same hash is stored in the `next` vector at the position of the inserted index.
+///
+/// The chain can be followed until the value "0" has been reached, meaning the end of the list.
+/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487)
+///
+/// # Example
+///
+/// ``` text
+/// See the example below:
+///
+/// Insert (10,1) <-- insert hash value 10 with row index 1
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (20,2)
+/// map:
+/// ----------
+/// | 10 | 2 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 0 | 0 |
+/// ---------------------
+/// Insert (10,3) <-- collision! row index 3 has a hash value of 10 as well
+/// map:
+/// ----------
+/// | 10 | 4 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 10 maps to 4,2 (which means indices values 3,1)
+/// ---------------------
+/// Insert (10,4) <-- another collision! row index 4 ALSO has a hash value of 10
+/// map:
+/// ----------
+/// | 10 | 5 |
+/// | 20 | 3 |
+/// ----------
+/// next:
+/// ---------------------
+/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 10 maps to 5,4,2 (which means indices values 4,3,1)
+/// ---------------------
+/// ```
+pub struct JoinHashMap {
+ /// Hash value -> (index of the last row inserted with that hash) + 1.
+ map: RawTable<(u64, u64)>,
+ /// Chained list of row indices: `next[i]` holds (previous row index with the
+ /// same hash as row `i`) + 1, or 0 when row `i` ends the chain.
+ next: Vec<u64>,
+}
+
+impl JoinHashMap {
+    /// Assembles a map from already-populated parts (test-only helper).
+    #[cfg(test)]
+    pub(crate) fn new(map: RawTable<(u64, u64)>, next: Vec<u64>) -> Self {
+        JoinHashMap { map, next }
+    }
+
+    /// Creates an empty map for `capacity` build-side rows: the hash table
+    /// reserves room for `capacity` entries and `next` holds `capacity`
+    /// zeros (0 = end of chain).
+    pub(crate) fn with_capacity(capacity: usize) -> Self {
+        let map = RawTable::with_capacity(capacity);
+        let next = vec![0; capacity];
+        Self { map, next }
+    }
+}
+
+/// Methods that must be implemented by a hash map type to be used for joins.
+pub trait JoinHashMapType {
+ /// The type of list used to store the `next` chain (returned by `get_list`).
+ type NextType: IndexMut<usize, Output = u64>;
+ /// Appends `len` zero entries (0 marks the end of a chain), presumably to the
+ /// `next` list. Implementations may make this a no-op, as `JoinHashMap` does.
+ fn extend_zero(&mut self, len: usize);
+ /// Returns mutable references to the hash map and the `next` list.
+ fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType);
+ /// Returns a reference to the hash map.
+ fn get_map(&self) -> &RawTable<(u64, u64)>;
+ /// Returns a reference to the `next` list.
+ fn get_list(&self) -> &Self::NextType;
+}
+
+/// `JoinHashMapType` for the plain `JoinHashMap`: a direct view of its two
+/// fields, with a `Vec<u64>` as the chained `next` list.
+impl JoinHashMapType for JoinHashMap {
+    type NextType = Vec<u64>;
+
+    /// No-op for this map type: `next` is sized when the map is built
+    /// (`with_capacity` fills it with `capacity` zeros).
+    fn extend_zero(&mut self, _len: usize) {}
+
+    /// Splits `self` into mutable borrows of the hash table and `next` list.
+    fn get_mut(&mut self) -> (&mut RawTable<(u64, u64)>, &mut Self::NextType) {
+        let JoinHashMap { map, next } = self;
+        (map, next)
+    }
+
+    /// Borrows the hash table.
+    fn get_map(&self) -> &RawTable<(u64, u64)> {
+        &self.map
+    }
+
+    /// Borrows the chained `next` list.
+    fn get_list(&self) -> &Self::NextType {
+        &self.next
+    }
+}
+
+impl fmt::Debug for JoinHashMap {
+    /// Prints `JoinHashMap { .. }` without dumping the table or chain
+    /// contents. Writing nothing would leave an empty hole in the `{:?}`
+    /// output of any struct that contains this map.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("JoinHashMap").finish_non_exhaustive()
+    }
+}
+
/// The `on` clause of the join, as a vector of `(left, right)` column pairs.
pub type JoinOn = Vec<(Column, Column)>;
/// Reference for JoinOn.
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 9d508078c7..9197343d74 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1155,6 +1155,7 @@ message PhysicalPlanNode {
NestedLoopJoinExecNode nested_loop_join = 22;
AnalyzeExecNode analyze = 23;
JsonSinkExecNode json_sink = 24;
+ SymmetricHashJoinExecNode symmetric_hash_join = 25;
}
}
@@ -1432,6 +1433,21 @@ message HashJoinExecNode {
JoinFilter filter = 8;
}
+// Partition mode of a symmetric hash join; presumably mirrors
+// `StreamJoinPartitionMode` in the physical plan crate (TODO confirm).
+// SINGLE_PARTITION (0) is what an unset field decodes to under proto3 rules.
+enum StreamPartitionMode {
+ SINGLE_PARTITION = 0;
+ PARTITIONED_EXEC = 1;
+}
+
+// Serialized form of a `SymmetricHashJoinExec`.
+// Field number 5 is not used; the numbering presumably follows `HashJoinExecNode`.
+// Never renumber existing fields: the numbers are part of the wire format.
+message SymmetricHashJoinExecNode {
+ PhysicalPlanNode left = 1;
+ PhysicalPlanNode right = 2;
+ repeated JoinOn on = 3;
+ JoinType join_type = 4;
+ StreamPartitionMode partition_mode = 6;
+ bool null_equals_null = 7;
+ JoinFilter filter = 8;
+}
+
message UnionExecNode {
repeated PhysicalPlanNode inputs = 1; // one entry per input of the union
}
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 0a8f415e20..8a63600237 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -17844,6 +17844,9 @@ impl serde::Serialize for PhysicalPlanNode {
physical_plan_node::PhysicalPlanType::JsonSink(v) => {
struct_ser.serialize_field("jsonSink", v)?;
}
+ physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => {
+ struct_ser.serialize_field("symmetricHashJoin", v)?;
+ }
}
}
struct_ser.end()
@@ -17890,6 +17893,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"analyze",
"json_sink",
"jsonSink",
+ "symmetric_hash_join",
+ "symmetricHashJoin",
];
#[allow(clippy::enum_variant_names)]
@@ -17917,6 +17922,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
NestedLoopJoin,
Analyze,
JsonSink,
+ SymmetricHashJoin,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -17961,6 +17967,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"nestedLoopJoin" | "nested_loop_join" =>
Ok(GeneratedField::NestedLoopJoin),
"analyze" => Ok(GeneratedField::Analyze),
"jsonSink" | "json_sink" =>
Ok(GeneratedField::JsonSink),
+ "symmetricHashJoin" | "symmetric_hash_join" =>
Ok(GeneratedField::SymmetricHashJoin),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -18142,6 +18149,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode
{
return
Err(serde::de::Error::duplicate_field("jsonSink"));
}
physical_plan_type__ =
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink)
+;
+ }
+ GeneratedField::SymmetricHashJoin => {
+ if physical_plan_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("symmetricHashJoin"));
+ }
+ physical_plan_type__ =
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin)
;
}
}
@@ -23648,6 +23662,77 @@ impl<'de> serde::Deserialize<'de> for Statistics {
deserializer.deserialize_struct("datafusion.Statistics", FIELDS,
GeneratedVisitor)
}
}
+// NOTE(review): lives under src/generated/ (presumably pbjson-build output); change
+// datafusion.proto and regenerate instead of editing by hand. Serializes the enum
+// as its proto variant name rather than its integer value.
+impl serde::Serialize for StreamPartitionMode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ let variant = match self {
+ Self::SinglePartition => "SINGLE_PARTITION",
+ Self::PartitionedExec => "PARTITIONED_EXEC",
+ };
+ serializer.serialize_str(variant)
+ }
+}
+// NOTE(review): generated code (src/generated/). Accepts either the variant name
+// string or its integer value (via visit_i64 / visit_u64).
+impl<'de> serde::Deserialize<'de> for StreamPartitionMode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "SINGLE_PARTITION",
+ "PARTITIONED_EXEC",
+ ];
+
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = StreamPartitionMode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value,
E>
+ where
+ E: serde::de::Error,
+ {
+ i32::try_from(v)
+ .ok()
+ .and_then(|x| x.try_into().ok())
+ .ok_or_else(|| {
+
serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self)
+ })
+ }
+
+ fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value,
E>
+ where
+ E: serde::de::Error,
+ {
+ i32::try_from(v)
+ .ok()
+ .and_then(|x| x.try_into().ok())
+ .ok_or_else(|| {
+
serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self)
+ })
+ }
+
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<Self::Value, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "SINGLE_PARTITION" =>
Ok(StreamPartitionMode::SinglePartition),
+ "PARTITIONED_EXEC" =>
Ok(StreamPartitionMode::PartitionedExec),
+ _ => Err(serde::de::Error::unknown_variant(value, FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_any(GeneratedVisitor)
+ }
+}
impl serde::Serialize for StringifiedPlan {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
@@ -24066,6 +24151,206 @@ impl<'de> serde::Deserialize<'de> for
SubqueryAliasNode {
deserializer.deserialize_struct("datafusion.SubqueryAliasNode",
FIELDS, GeneratedVisitor)
}
}
+// NOTE(review): generated code (src/generated/); change datafusion.proto and
+// regenerate instead of editing by hand.
+impl serde::Serialize for SymmetricHashJoinExecNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ // Only non-default fields are emitted; `len` counts them for serialize_struct.
+ let mut len = 0;
+ if self.left.is_some() {
+ len += 1;
+ }
+ if self.right.is_some() {
+ len += 1;
+ }
+ if !self.on.is_empty() {
+ len += 1;
+ }
+ if self.join_type != 0 {
+ len += 1;
+ }
+ if self.partition_mode != 0 {
+ len += 1;
+ }
+ if self.null_equals_null {
+ len += 1;
+ }
+ if self.filter.is_some() {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?;
+ if let Some(v) = self.left.as_ref() {
+ struct_ser.serialize_field("left", v)?;
+ }
+ if let Some(v) = self.right.as_ref() {
+ struct_ser.serialize_field("right", v)?;
+ }
+ if !self.on.is_empty() {
+ struct_ser.serialize_field("on", &self.on)?;
+ }
+ if self.join_type != 0 {
+ let v = JoinType::try_from(self.join_type)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.join_type)))?;
+ struct_ser.serialize_field("joinType", &v)?;
+ }
+ if self.partition_mode != 0 {
+ let v = StreamPartitionMode::try_from(self.partition_mode)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.partition_mode)))?;
+ struct_ser.serialize_field("partitionMode", &v)?;
+ }
+ if self.null_equals_null {
+ struct_ser.serialize_field("nullEqualsNull",
&self.null_equals_null)?;
+ }
+ if let Some(v) = self.filter.as_ref() {
+ struct_ser.serialize_field("filter", v)?;
+ }
+ struct_ser.end()
+ }
+}
+// NOTE(review): generated code (src/generated/). Accepts both snake_case and
+// camelCase field names and rejects duplicate fields.
+impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ const FIELDS: &[&str] = &[
+ "left",
+ "right",
+ "on",
+ "join_type",
+ "joinType",
+ "partition_mode",
+ "partitionMode",
+ "null_equals_null",
+ "nullEqualsNull",
+ "filter",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Left,
+ Right,
+ On,
+ JoinType,
+ PartitionMode,
+ NullEqualsNull,
+ Filter,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "left" => Ok(GeneratedField::Left),
+ "right" => Ok(GeneratedField::Right),
+ "on" => Ok(GeneratedField::On),
+ "joinType" | "join_type" =>
Ok(GeneratedField::JoinType),
+ "partitionMode" | "partition_mode" =>
Ok(GeneratedField::PartitionMode),
+ "nullEqualsNull" | "null_equals_null" =>
Ok(GeneratedField::NullEqualsNull),
+ "filter" => Ok(GeneratedField::Filter),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = SymmetricHashJoinExecNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct
datafusion.SymmetricHashJoinExecNode")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<SymmetricHashJoinExecNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut left__ = None;
+ let mut right__ = None;
+ let mut on__ = None;
+ let mut join_type__ = None;
+ let mut partition_mode__ = None;
+ let mut null_equals_null__ = None;
+ let mut filter__ = None;
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Left => {
+ if left__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("left"));
+ }
+ left__ = map_.next_value()?;
+ }
+ GeneratedField::Right => {
+ if right__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("right"));
+ }
+ right__ = map_.next_value()?;
+ }
+ GeneratedField::On => {
+ if on__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("on"));
+ }
+ on__ = Some(map_.next_value()?);
+ }
+ GeneratedField::JoinType => {
+ if join_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("joinType"));
+ }
+ join_type__ = Some(map_.next_value::<JoinType>()?
as i32);
+ }
+ GeneratedField::PartitionMode => {
+ if partition_mode__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("partitionMode"));
+ }
+ partition_mode__ =
Some(map_.next_value::<StreamPartitionMode>()? as i32);
+ }
+ GeneratedField::NullEqualsNull => {
+ if null_equals_null__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("nullEqualsNull"));
+ }
+ null_equals_null__ = Some(map_.next_value()?);
+ }
+ GeneratedField::Filter => {
+ if filter__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("filter"));
+ }
+ filter__ = map_.next_value()?;
+ }
+ }
+ }
+ // Absent fields fall back to their proto3 defaults (None / empty / 0 / false).
+ Ok(SymmetricHashJoinExecNode {
+ left: left__,
+ right: right__,
+ on: on__.unwrap_or_default(),
+ join_type: join_type__.unwrap_or_default(),
+ partition_mode: partition_mode__.unwrap_or_default(),
+ null_equals_null: null_equals_null__.unwrap_or_default(),
+ filter: filter__,
+ })
+ }
+ }
+
deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS,
GeneratedVisitor)
+ }
+}
impl serde::Serialize for TimeUnit {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 84fb84b948..4fb8e1599e 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1514,7 +1514,7 @@ pub mod owned_table_reference {
pub struct PhysicalPlanNode {
#[prost(
oneof = "physical_plan_node::PhysicalPlanType",
- tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24"
+ tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25"
)]
pub physical_plan_type:
::core::option::Option<physical_plan_node::PhysicalPlanType>,
}
@@ -1571,6 +1571,8 @@ pub mod physical_plan_node {
Analyze(::prost::alloc::boxed::Box<super::AnalyzeExecNode>),
#[prost(message, tag = "24")]
JsonSink(::prost::alloc::boxed::Box<super::JsonSinkExecNode>),
+ #[prost(message, tag = "25")]
+
SymmetricHashJoin(::prost::alloc::boxed::Box<super::SymmetricHashJoinExecNode>),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -2009,6 +2011,24 @@ pub struct HashJoinExecNode {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+/// Protobuf message for a symmetric hash join physical plan node.
+///
+/// Tag 5 is not assigned (tags jump from 4 to 6).
+/// NOTE(review): this struct appears to be prost-generated (generated/prost.rs);
+/// document the message in the .proto definition so the docs survive regeneration.
+pub struct SymmetricHashJoinExecNode {
+ /// Left (build/probe) input plan.
+ #[prost(message, optional, boxed, tag = "1")]
+ pub left:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// Right input plan.
+ #[prost(message, optional, boxed, tag = "2")]
+ pub right:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// Equi-join key pairs, one `JoinOn` (left column, right column) each.
+ #[prost(message, repeated, tag = "3")]
+ pub on: ::prost::alloc::vec::Vec<JoinOn>,
+ /// Raw `JoinType` discriminant.
+ #[prost(enumeration = "JoinType", tag = "4")]
+ pub join_type: i32,
+ /// Raw `StreamPartitionMode` discriminant.
+ #[prost(enumeration = "StreamPartitionMode", tag = "6")]
+ pub partition_mode: i32,
+ /// Whether null keys compare equal to each other.
+ #[prost(bool, tag = "7")]
+ pub null_equals_null: bool,
+ /// Optional non-equi join filter.
+ #[prost(message, optional, tag = "8")]
+ pub filter: ::core::option::Option<JoinFilter>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct UnionExecNode {
#[prost(message, repeated, tag = "1")]
pub inputs: ::prost::alloc::vec::Vec<PhysicalPlanNode>,
@@ -3265,6 +3285,32 @@ impl PartitionMode {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord,
::prost::Enumeration)]
#[repr(i32)]
+/// Partitioning mode of a symmetric hash join, stored as an `i32` in
+/// `SymmetricHashJoinExecNode::partition_mode`.
+pub enum StreamPartitionMode {
+ /// Discriminant 0, which is also the i32 default, so an unset
+ /// `partition_mode` reads as this variant.
+ SinglePartition = 0,
+ /// Discriminant 1; the plan decoder maps it to
+ /// `StreamJoinPartitionMode::Partitioned`.
+ PartitionedExec = 1,
+}
+impl StreamPartitionMode {
+ /// String value of the enum field names used in the ProtoBuf definition.
+ ///
+ /// The values are not transformed in any way and thus are considered stable
+ /// (if the ProtoBuf definition does not change) and safe for programmatic use.
+ pub fn as_str_name(&self) -> &'static str {
+ match self {
+ StreamPartitionMode::SinglePartition => "SINGLE_PARTITION",
+ StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC",
+ }
+ }
+ /// Creates an enum from field names used in the ProtoBuf definition.
+ /// Returns `None` for any other string.
+ pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
+ match value {
+ "SINGLE_PARTITION" => Some(Self::SinglePartition),
+ "PARTITIONED_EXEC" => Some(Self::PartitionedExec),
+ _ => None,
+ }
+ }
+}
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord,
::prost::Enumeration)]
+#[repr(i32)]
pub enum AggregateMode {
Partial = 0,
Final = 1,
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 1eedbe987e..6714c35dc6 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -39,7 +39,9 @@ use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::insert::FileSinkExec;
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
-use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
+use datafusion::physical_plan::joins::{
+ CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode,
SymmetricHashJoinExec,
+};
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
@@ -583,6 +585,97 @@ impl AsExecutionPlan for PhysicalPlanNode {
hashjoin.null_equals_null,
)?))
}
+            // Decode a SymmetricHashJoinExecNode back into a SymmetricHashJoinExec.
+            PhysicalPlanType::SymmetricHashJoin(sym_join) => {
+                let left = into_physical_plan(
+                    &sym_join.left,
+                    registry,
+                    runtime,
+                    extension_codec,
+                )?;
+                let right = into_physical_plan(
+                    &sym_join.right,
+                    registry,
+                    runtime,
+                    extension_codec,
+                )?;
+                // Equi-join key pairs: (left column, right column).
+                let on = sym_join
+                    .on
+                    .iter()
+                    .map(|col| {
+                        let left = into_required!(col.left)?;
+                        let right = into_required!(col.right)?;
+                        Ok((left, right))
+                    })
+                    .collect::<Result<_>>()?;
+                // Enum fields travel as raw i32s; reject values outside the enum.
+                let join_type = protobuf::JoinType::try_from(sym_join.join_type)
+                    .map_err(|_| {
+                        proto_error(format!(
+                            "Received a SymmetricHashJoin message with unknown JoinType {}",
+                            sym_join.join_type
+                        ))
+                    })?;
+                // The optional filter carries its own schema, which is needed to
+                // parse the filter expression and its column indices.
+                let filter = sym_join
+                    .filter
+                    .as_ref()
+                    .map(|f| -> Result<JoinFilter> {
+                        let schema = f
+                            .schema
+                            .as_ref()
+                            .ok_or_else(|| proto_error("Missing JoinFilter schema"))?
+                            .try_into()?;
+
+                        let expression = parse_physical_expr(
+                            f.expression.as_ref().ok_or_else(|| {
+                                proto_error("Unexpected empty filter expression")
+                            })?,
+                            registry,
+                            &schema,
+                        )?;
+                        let column_indices = f
+                            .column_indices
+                            .iter()
+                            .map(|i| {
+                                let side = protobuf::JoinSide::try_from(i.side)
+                                    .map_err(|_| {
+                                        proto_error(format!(
+                                            "Received a SymmetricHashJoin message with unknown JoinSide in Filter {}",
+                                            i.side
+                                        ))
+                                    })?;
+                                Ok(ColumnIndex {
+                                    index: i.index as usize,
+                                    side: side.into(),
+                                })
+                            })
+                            .collect::<Result<_>>()?;
+
+                        Ok(JoinFilter::new(expression, column_indices, schema))
+                    })
+                    // Option<Result<_>> -> Result<Option<_>>: an absent filter is
+                    // fine, a malformed one is propagated as an error.
+                    .transpose()?;
+
+                let partition_mode =
+                    protobuf::StreamPartitionMode::try_from(sym_join.partition_mode)
+                        .map_err(|_| {
+                            proto_error(format!(
+                                "Received a SymmetricHashJoin message with unknown PartitionMode {}",
+                                sym_join.partition_mode
+                            ))
+                        })?;
+                // Exhaustive on purpose: a new protobuf variant must be mapped here.
+                let partition_mode = match partition_mode {
+                    protobuf::StreamPartitionMode::SinglePartition => {
+                        StreamJoinPartitionMode::SinglePartition
+                    }
+                    protobuf::StreamPartitionMode::PartitionedExec => {
+                        StreamJoinPartitionMode::Partitioned
+                    }
+                };
+                SymmetricHashJoinExec::try_new(
+                    left,
+                    right,
+                    on,
+                    filter,
+                    &join_type.into(),
+                    sym_join.null_equals_null,
+                    partition_mode,
+                )
+                .map(|e| Arc::new(e) as _)
+            }
PhysicalPlanType::Union(union) => {
let mut inputs: Vec<Arc<dyn ExecutionPlan>> = vec![];
for input in &union.inputs {
@@ -1008,6 +1101,79 @@ impl AsExecutionPlan for PhysicalPlanNode {
});
}
+ // Encode a SymmetricHashJoinExec as a SymmetricHashJoinExecNode: both children,
+ // the join keys, enum fields (as i32), null handling and the optional filter.
+ if let Some(exec) = plan.downcast_ref::<SymmetricHashJoinExec>() {
+ let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+ exec.left().to_owned(),
+ extension_codec,
+ )?;
+ let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+ exec.right().to_owned(),
+ extension_codec,
+ )?;
+ // Each (left, right) key column is encoded by name and index.
+ // NOTE(review): `as u32` silently truncates indices above u32::MAX;
+ // presumably unreachable for real schemas.
+ let on = exec
+ .on()
+ .iter()
+ .map(|tuple| protobuf::JoinOn {
+ left: Some(protobuf::PhysicalColumn {
+ name: tuple.0.name().to_string(),
+ index: tuple.0.index() as u32,
+ }),
+ right: Some(protobuf::PhysicalColumn {
+ name: tuple.1.name().to_string(),
+ index: tuple.1.index() as u32,
+ }),
+ })
+ .collect();
+ let join_type: protobuf::JoinType =
exec.join_type().to_owned().into();
+ // map_or(Ok(None), ...) turns Option<Result<_>> into Result<Option<_>>,
+ // so a filter that fails to encode aborts the whole conversion.
+ let filter = exec
+ .filter()
+ .as_ref()
+ .map(|f| {
+ let expression = f.expression().to_owned().try_into()?;
+ let column_indices = f
+ .column_indices()
+ .iter()
+ .map(|i| {
+ let side: protobuf::JoinSide =
i.side.to_owned().into();
+ protobuf::ColumnIndex {
+ index: i.index as u32,
+ side: side.into(),
+ }
+ })
+ .collect();
+ let schema = f.schema().try_into()?;
+ Ok(protobuf::JoinFilter {
+ expression: Some(expression),
+ column_indices,
+ schema: Some(schema),
+ })
+ })
+ .map_or(Ok(None), |v: Result<protobuf::JoinFilter>|
v.map(Some))?;
+
+ // Exhaustive match: adding a StreamJoinPartitionMode variant forces an
+ // update here rather than silently mis-encoding it.
+ let partition_mode = match exec.partition_mode() {
+ StreamJoinPartitionMode::SinglePartition => {
+ protobuf::StreamPartitionMode::SinglePartition
+ }
+ StreamJoinPartitionMode::Partitioned => {
+ protobuf::StreamPartitionMode::PartitionedExec
+ }
+ };
+
+ return Ok(protobuf::PhysicalPlanNode {
+ physical_plan_type:
Some(PhysicalPlanType::SymmetricHashJoin(Box::new(
+ protobuf::SymmetricHashJoinExecNode {
+ left: Some(Box::new(left)),
+ right: Some(Box::new(right)),
+ on,
+ join_type: join_type.into(),
+ partition_mode: partition_mode.into(),
+ null_equals_null: exec.null_equals_null(),
+ filter,
+ },
+ ))),
+ });
+ }
+
if let Some(exec) = plan.downcast_ref::<CrossJoinExec>() {
let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
exec.left().to_owned(),
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 23b0ea43c7..d7d762d470 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -45,7 +45,9 @@ use datafusion::physical_plan::expressions::{
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::insert::FileSinkExec;
-use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec,
PartitionMode};
+use datafusion::physical_plan::joins::{
+ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
+};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::sorts::sort::SortExec;
@@ -754,3 +756,45 @@ fn roundtrip_json_sink() -> Result<()> {
Some(sort_order),
)))
}
+
+/// Round-trips a `SymmetricHashJoinExec` through protobuf for every
+/// combination of join type and stream partition mode.
+#[test]
+fn roundtrip_sym_hash_join() -> Result<()> {
+    // Both inputs expose a single non-nullable Int64 column named "col".
+    let col_field = Field::new("col", DataType::Int64, false);
+    let left_schema = Schema::new(vec![col_field.clone()]);
+    let right_schema = Schema::new(vec![col_field]);
+
+    // Equi-join on left.col = right.col.
+    let left_key = Column::new("col", left_schema.index_of("col")?);
+    let right_key = Column::new("col", right_schema.index_of("col")?);
+    let on = vec![(left_key, right_key)];
+
+    let left_schema = Arc::new(left_schema);
+    let right_schema = Arc::new(right_schema);
+
+    let join_types = [
+        JoinType::Inner,
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftAnti,
+        JoinType::RightAnti,
+        JoinType::LeftSemi,
+        JoinType::RightSemi,
+    ];
+    let partition_modes = [
+        StreamJoinPartitionMode::Partitioned,
+        StreamJoinPartitionMode::SinglePartition,
+    ];
+
+    for join_type in join_types.iter() {
+        for &partition_mode in partition_modes.iter() {
+            let left_input = Arc::new(EmptyExec::new(false, left_schema.clone()));
+            let right_input = Arc::new(EmptyExec::new(false, right_schema.clone()));
+            let plan = datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new(
+                left_input,
+                right_input,
+                on.clone(),
+                None,
+                join_type,
+                false,
+                partition_mode,
+            )?;
+            roundtrip_test(Arc::new(plan))?;
+        }
+    }
+    Ok(())
+}