This is an automated email from the ASF dual-hosted git repository.

ytyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
     new 033fc9b0a1  chore: split hash join to smaller modules (#17300)
033fc9b0a1 is described below

commit 033fc9b0a124082c6b4b6c68a8445c16bdaa51a1
Author: Yongting You <2010you...@gmail.com>
AuthorDate: Wed Aug 27 17:56:18 2025 +0800

    chore: split hash join to smaller modules (#17300)

    * split hash join to smaller modules
    * small cleanup
    * fix cargo doc
    * review
    * rust doc
    * fix visibility
---
 .../src/joins/{hash_join.rs => hash_join/exec.rs}  | 1020 +-------------------
 .../physical-plan/src/joins/hash_join/mod.rs       |   24 +
 .../src/joins/hash_join/shared_bounds.rs           |  296 ++++++
 .../physical-plan/src/joins/hash_join/stream.rs    |  628 ++++++++++++
 .../physical-plan/src/joins/symmetric_hash_join.rs |    7 +-
 datafusion/physical-plan/src/joins/utils.rs        |  123 ++-
 6 files changed, 1121 insertions(+), 977 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs
similarity index 81%
rename from datafusion/physical-plan/src/joins/hash_join.rs
rename to datafusion/physical-plan/src/joins/hash_join/exec.rs
index 80f1de5a5b..359d36a29c 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -15,30 +15,27 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`HashJoinExec`] Partitioned Hash Join Operator
-
 use std::fmt;
 use std::mem::size_of;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, OnceLock};
-use std::task::Poll;
 use std::{any::Any, vec};
 
-use super::utils::{
-    asymmetric_join_output_partitioning, get_final_indices_from_shared_bitmap,
-    reorder_output_after_swap, swap_join_projection,
-};
-use super::{
-    utils::{OnceAsync, OnceFut},
-    PartitionMode, SharedBitmapBuilder,
-};
-use super::{JoinOn, JoinOnRef};
 use crate::execution_plan::{boundedness_from_children, EmissionType};
 use crate::filter_pushdown::{
    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
    FilterPushdownPropagation,
 };
+use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator};
+use crate::joins::hash_join::stream::{
+    BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState,
+};
 use crate::joins::join_hash_map::{JoinHashMapU32, JoinHashMapU64};
+use crate::joins::utils::{
+    asymmetric_join_output_partitioning, reorder_output_after_swap, swap_join_projection,
+    update_hash, OnceAsync, OnceFut,
+};
+use crate::joins::{JoinOn, JoinOnRef, PartitionMode, SharedBitmapBuilder};
 use crate::projection::{
    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
    ProjectionExec,
@@ -47,324 +44,49 @@ use crate::spill::get_record_batch_memory_size;
 use crate::ExecutionPlanProperties;
 use crate::{
    common::can_project,
-    handle_state,
-    hash_utils::create_hashes,
-    joins::join_hash_map::JoinHashMapOffset,
    joins::utils::{
-        adjust_indices_by_join_type, apply_join_filter_to_indices,
-        build_batch_empty_build_side, build_batch_from_indices, build_join_schema,
-        check_join_is_valid, estimate_join_statistics, need_produce_result_in_final,
-        symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex,
-        JoinFilter, JoinHashMapType, StatefulStreamResult,
+        build_join_schema, check_join_is_valid, estimate_join_statistics,
+        need_produce_result_in_final, symmetric_join_output_partitioning,
+        BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType,
    },
    metrics::{ExecutionPlanMetricsSet, MetricsSet},
    DisplayAs,
DisplayFormatType, Distribution, ExecutionPlan, Partitioning, - PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, + PlanProperties, SendableRecordBatchStream, Statistics, }; -use arrow::array::{ - cast::downcast_array, Array, ArrayRef, BooleanArray, BooleanBufferBuilder, - UInt32Array, UInt64Array, -}; -use arrow::compute::kernels::cmp::{eq, not_distinct}; -use arrow::compute::{and, concat_batches, take, FilterBuilder}; -use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::ArrowError; +use arrow::array::{Array, ArrayRef, BooleanBufferBuilder}; +use arrow::compute::concat_batches; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use datafusion_common::config::ConfigOptions; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ - internal_datafusion_err, internal_err, plan_err, project_schema, JoinSide, JoinType, - NullEquality, Result, ScalarValue, + internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result, + ScalarValue, }; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_expr::Operator; use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr}; use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; -use datafusion_physical_expr_common::datum::compare_op_for_nested; use ahash::RandomState; use datafusion_physical_expr_common::physical_expr::fmt_sql; -use futures::{ready, Stream, StreamExt, TryStreamExt}; -use itertools::Itertools; +use futures::TryStreamExt; use parking_lot::Mutex; /// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions. const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); -/// Represents the minimum and maximum values for a specific column. -/// Used in dynamic filter pushdown to establish value boundaries. -#[derive(Debug, Clone, PartialEq)] -struct ColumnBounds { - /// The minimum value observed for this column - min: ScalarValue, - /// The maximum value observed for this column - max: ScalarValue, -} - -impl ColumnBounds { - fn new(min: ScalarValue, max: ScalarValue) -> Self { - Self { min, max } - } -} - -/// Represents the bounds for all join key columns from a single partition. -/// This contains the min/max values computed from one partition's build-side data. -#[derive(Debug, Clone)] -struct PartitionBounds { - /// Partition identifier for debugging and determinism (not strictly necessary) - partition: usize, - /// Min/max bounds for each join key column in this partition. - /// Index corresponds to the join key expression index. 
- column_bounds: Vec<ColumnBounds>, -} - -impl PartitionBounds { - fn new(partition: usize, column_bounds: Vec<ColumnBounds>) -> Self { - Self { - partition, - column_bounds, - } - } - - fn len(&self) -> usize { - self.column_bounds.len() - } - - fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> { - self.column_bounds.get(index) - } -} - -/// Coordinates dynamic filter bounds collection across multiple partitions -/// -/// This structure ensures that dynamic filters are built with complete information from all -/// relevant partitions before being applied to probe-side scans. Incomplete filters would -/// incorrectly eliminate valid join results. -/// -/// ## Synchronization Strategy -/// -/// 1. Each partition computes bounds from its build-side data -/// 2. Bounds are stored in the shared HashMap (indexed by partition_id) -/// 3. A counter tracks how many partitions have reported their bounds -/// 4. When the last partition reports (completed == total), bounds are merged and filter is updated -/// -/// ## Partition Counting -/// -/// The `total_partitions` count represents how many times `collect_build_side` will be called: -/// - **CollectLeft**: Number of output partitions (each accesses shared build data) -/// - **Partitioned**: Number of input partitions (each builds independently) -/// -/// ## Thread Safety -/// -/// All fields use a single mutex to ensure correct coordination between concurrent -/// partition executions. -struct SharedBoundsAccumulator { - /// Shared state protected by a single mutex to avoid ordering concerns - inner: Mutex<SharedBoundsState>, - /// Total number of partitions. - /// Need to know this so that we can update the dynamic filter once we are done - /// building *all* of the hash tables. - total_partitions: usize, - /// Dynamic filter for pushdown to probe side - dynamic_filter: Arc<DynamicFilterPhysicalExpr>, - /// Right side join expressions needed for creating filter bounds - on_right: Vec<PhysicalExprRef>, -} - -/// State protected by SharedBoundsAccumulator's mutex -struct SharedBoundsState { - /// Bounds from completed partitions. - /// Each element represents the column bounds computed by one partition. - bounds: Vec<PartitionBounds>, - /// Number of partitions that have reported completion. - completed_partitions: usize, -} - -impl SharedBoundsAccumulator { - /// Creates a new SharedBoundsAccumulator configured for the given partition mode - /// - /// This method calculates how many times `collect_build_side` will be called based on the - /// partition mode's execution pattern. This count is critical for determining when we have - /// complete information from all partitions to build the dynamic filter. - /// - /// ## Partition Mode Execution Patterns - /// - /// - **CollectLeft**: Build side is collected ONCE from partition 0 and shared via `OnceFut` - /// across all output partitions. Each output partition calls `collect_build_side` to access - /// the shared build data. Expected calls = number of output partitions. - /// - /// - **Partitioned**: Each partition independently builds its own hash table by calling - /// `collect_build_side` once. Expected calls = number of build partitions. - /// - /// - **Auto**: Placeholder mode resolved during optimization. Uses 1 as safe default since - /// the actual mode will be determined and a new bounds_accumulator created before execution. 
- /// - /// ## Why This Matters - /// - /// We cannot build a partial filter from some partitions - it would incorrectly eliminate - /// valid join results. We must wait until we have complete bounds information from ALL - /// relevant partitions before updating the dynamic filter. - fn new_from_partition_mode( - partition_mode: PartitionMode, - left_child: &dyn ExecutionPlan, - right_child: &dyn ExecutionPlan, - dynamic_filter: Arc<DynamicFilterPhysicalExpr>, - on_right: Vec<PhysicalExprRef>, - ) -> Self { - // Troubleshooting: If partition counts are incorrect, verify this logic matches - // the actual execution pattern in collect_build_side() - let expected_calls = match partition_mode { - // Each output partition accesses shared build data - PartitionMode::CollectLeft => { - right_child.output_partitioning().partition_count() - } - // Each partition builds its own data - PartitionMode::Partitioned => { - left_child.output_partitioning().partition_count() - } - // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) - PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), - }; - Self { - inner: Mutex::new(SharedBoundsState { - bounds: Vec::with_capacity(expected_calls), - completed_partitions: 0, - }), - total_partitions: expected_calls, - dynamic_filter, - on_right, - } - } - - /// Create a filter expression from individual partition bounds using OR logic. - /// - /// This creates a filter where each partition's bounds form a conjunction (AND) - /// of column range predicates, and all partitions are combined with OR. - /// - /// For example, with 2 partitions and 2 columns: - /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) - /// OR - /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) - fn create_filter_from_partition_bounds( - &self, - bounds: &[PartitionBounds], - ) -> Result<Arc<dyn PhysicalExpr>> { - if bounds.is_empty() { - return Ok(lit(true)); - } - - // Create a predicate for each partition - let mut partition_predicates = Vec::with_capacity(bounds.len()); - - for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { - // Create range predicates for each join key in this partition - let mut column_predicates = Vec::with_capacity(partition_bounds.len()); - - for (col_idx, right_expr) in self.on_right.iter().enumerate() { - if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { - // Create predicate: col >= min AND col <= max - let min_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::GtEq, - lit(column_bounds.min.clone()), - )) as Arc<dyn PhysicalExpr>; - let max_expr = Arc::new(BinaryExpr::new( - Arc::clone(right_expr), - Operator::LtEq, - lit(column_bounds.max.clone()), - )) as Arc<dyn PhysicalExpr>; - let range_expr = - Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) - as Arc<dyn PhysicalExpr>; - column_predicates.push(range_expr); - } - } - - // Combine all column predicates for this partition with AND - if !column_predicates.is_empty() { - let partition_predicate = column_predicates - .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::And, pred)) - as Arc<dyn PhysicalExpr> - }) - .unwrap(); - partition_predicates.push(partition_predicate); - } - } - - // Combine all partition predicates with OR - let combined_predicate = partition_predicates 
- .into_iter() - .reduce(|acc, pred| { - Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) - as Arc<dyn PhysicalExpr> - }) - .unwrap_or_else(|| lit(true)); - - Ok(combined_predicate) - } - - /// Report bounds from a completed partition and update dynamic filter if all partitions are done - /// - /// This method coordinates the dynamic filter updates across all partitions. It stores the - /// bounds from the current partition, increments the completion counter, and when all - /// partitions have reported, creates an OR'd filter from individual partition bounds. - /// - /// # Arguments - /// * `partition_bounds` - The bounds computed by this partition (if any) - /// - /// # Returns - /// * `Result<()>` - Ok if successful, Err if filter update failed - fn report_partition_bounds( - &self, - partition: usize, - partition_bounds: Option<Vec<ColumnBounds>>, - ) -> Result<()> { - let mut inner = self.inner.lock(); - - // Store bounds in the accumulator - this runs once per partition - if let Some(bounds) = partition_bounds { - // Only push actual bounds if they exist - inner.bounds.push(PartitionBounds::new(partition, bounds)); - } - - // Increment the completion counter - // Even empty partitions must report to ensure proper termination - inner.completed_partitions += 1; - let completed = inner.completed_partitions; - let total_partitions = self.total_partitions; - - // Critical synchronization point: Only update the filter when ALL partitions are complete - // Troubleshooting: If you see "completed > total_partitions", check partition - // count calculation in new_from_partition_mode() - it may not match actual execution calls - if completed == total_partitions && !inner.bounds.is_empty() { - let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?; - self.dynamic_filter.update(filter_expr)?; - } - - Ok(()) - } -} - -impl fmt::Debug for SharedBoundsAccumulator { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SharedBoundsAccumulator") - } -} - /// HashTable and input data for the left (build side) of a join -struct JoinLeftData { +pub(super) struct JoinLeftData { /// The hash table with indices into `batch` - hash_map: Box<dyn JoinHashMapType>, + pub(super) hash_map: Box<dyn JoinHashMapType>, /// The input rows for the build side batch: RecordBatch, /// The build side on expressions values @@ -380,12 +102,12 @@ struct JoinLeftData { /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle. 
_reservation: MemoryReservation, /// Bounds computed from the build side for dynamic filter pushdown - bounds: Option<Vec<ColumnBounds>>, + pub(super) bounds: Option<Vec<ColumnBounds>>, } impl JoinLeftData { /// Create a new `JoinLeftData` from its parts - fn new( + pub(super) fn new( hash_map: Box<dyn JoinHashMapType>, batch: RecordBatch, values: Vec<ArrayRef>, @@ -406,28 +128,28 @@ impl JoinLeftData { } /// return a reference to the hash map - fn hash_map(&self) -> &dyn JoinHashMapType { + pub(super) fn hash_map(&self) -> &dyn JoinHashMapType { &*self.hash_map } /// returns a reference to the build side batch - fn batch(&self) -> &RecordBatch { + pub(super) fn batch(&self) -> &RecordBatch { &self.batch } /// returns a reference to the build side expressions values - fn values(&self) -> &[ArrayRef] { + pub(super) fn values(&self) -> &[ArrayRef] { &self.values } /// returns a reference to the visited indices bitmap - fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { + pub(super) fn visited_indices_bitmap(&self) -> &SharedBitmapBuilder { &self.visited_indices_bitmap } /// Decrements the counter of running threads, and returns `true` /// if caller is the last running thread - fn report_probe_completed(&self) -> bool { + pub(super) fn report_probe_completed(&self) -> bool { self.probe_threads_counter.fetch_sub(1, Ordering::Relaxed) == 1 } } @@ -662,6 +384,12 @@ impl fmt::Debug for HashJoinExec { } } +impl EmbeddedProjection for HashJoinExec { + fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> { + self.with_projection(projection) + } +} + impl HashJoinExec { /// Tries to create a new [HashJoinExec]. /// @@ -1260,24 +988,24 @@ impl ExecutionPlan for HashJoinExec { .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::<Vec<_>>(); - Ok(Box::pin(HashJoinStream { + Ok(Box::pin(HashJoinStream::new( partition, - schema: self.schema(), + self.schema(), on_right, - filter: self.filter.clone(), - join_type: self.join_type, - right: right_stream, - column_indices: column_indices_after_projection, - random_state: self.random_state.clone(), + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), join_metrics, - null_equality: self.null_equality, - state: HashJoinStreamState::WaitBuildSide, - build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, - hashes_buffer: vec![], - right_side_ordered: self.right.output_ordering().is_some(), + vec![], + self.right.output_ordering().is_some(), bounds_accumulator, - })) + ))) } fn metrics(&self) -> Option<MetricsSet> { @@ -1613,666 +1341,22 @@ async fn collect_left_input( Ok(data) } -/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on` -/// using `offset` as a start value for `batch` row indices. -/// -/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap, -/// which allows to keep either first (if set to true) or last (if set to false) row index -/// as a chain head for rows with equal hash values. 
-#[allow(clippy::too_many_arguments)] -pub fn update_hash( - on: &[PhysicalExprRef], - batch: &RecordBatch, - hash_map: &mut dyn JoinHashMapType, - offset: usize, - random_state: &RandomState, - hashes_buffer: &mut Vec<u64>, - deleted_offset: usize, - fifo_hashmap: bool, -) -> Result<()> { - // evaluate the keys - let keys_values = on - .iter() - .map(|c| c.evaluate(batch)?.into_array(batch.num_rows())) - .collect::<Result<Vec<_>>>()?; - - // calculate the hash values - let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; - - // For usual JoinHashmap, the implementation is void. - hash_map.extend_zero(batch.num_rows()); - - // Updating JoinHashMap from hash values iterator - let hash_values_iter = hash_values - .iter() - .enumerate() - .map(|(i, val)| (i + offset, val)); - - if fifo_hashmap { - hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset); - } else { - hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset); - } - - Ok(()) -} - -/// Represents build-side of hash join. -enum BuildSide { - /// Indicates that build-side not collected yet - Initial(BuildSideInitialState), - /// Indicates that build-side data has been collected - Ready(BuildSideReadyState), -} - -/// Container for BuildSide::Initial related data -struct BuildSideInitialState { - /// Future for building hash table from build-side input - left_fut: OnceFut<JoinLeftData>, -} - -/// Container for BuildSide::Ready related data -struct BuildSideReadyState { - /// Collected build-side data - left_data: Arc<JoinLeftData>, -} - -impl BuildSide { - /// Tries to extract BuildSideInitialState from BuildSide enum. - /// Returns an error if state is not Initial. - fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> { - match self { - BuildSide::Initial(state) => Ok(state), - _ => internal_err!("Expected build side in initial state"), - } - } - - /// Tries to extract BuildSideReadyState from BuildSide enum. - /// Returns an error if state is not Ready. - fn try_as_ready(&self) -> Result<&BuildSideReadyState> { - match self { - BuildSide::Ready(state) => Ok(state), - _ => internal_err!("Expected build side in ready state"), - } - } - - /// Tries to extract BuildSideReadyState from BuildSide enum. - /// Returns an error if state is not Ready. - fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> { - match self { - BuildSide::Ready(state) => Ok(state), - _ => internal_err!("Expected build side in ready state"), - } - } -} - -/// Represents state of HashJoinStream -/// -/// Expected state transitions performed by HashJoinStream are: -/// -/// ```text -/// -/// WaitBuildSide -/// │ -/// ▼ -/// ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed -/// │ │ -/// │ ▼ -/// └─ ProcessProbeBatch -/// -/// ``` -#[derive(Debug, Clone)] -enum HashJoinStreamState { - /// Initial state for HashJoinStream indicating that build-side data not collected yet - WaitBuildSide, - /// Indicates that build-side has been collected, and stream is ready for fetching probe-side - FetchProbeBatch, - /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed - ProcessProbeBatch(ProcessProbeBatchState), - /// Indicates that probe-side has been fully processed - ExhaustedProbeSide, - /// Indicates that HashJoinStream execution is completed - Completed, -} - -impl HashJoinStreamState { - /// Tries to extract ProcessProbeBatchState from HashJoinStreamState enum. - /// Returns an error if state is not ProcessProbeBatchState. 
- fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> { - match self { - HashJoinStreamState::ProcessProbeBatch(state) => Ok(state), - _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"), - } - } -} - -/// Container for HashJoinStreamState::ProcessProbeBatch related data -#[derive(Debug, Clone)] -struct ProcessProbeBatchState { - /// Current probe-side batch - batch: RecordBatch, - /// Probe-side on expressions values - values: Vec<ArrayRef>, - /// Starting offset for JoinHashMap lookups - offset: JoinHashMapOffset, - /// Max joined probe-side index from current batch - joined_probe_idx: Option<usize>, -} - -impl ProcessProbeBatchState { - fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) { - self.offset = offset; - if joined_probe_idx.is_some() { - self.joined_probe_idx = joined_probe_idx; - } - } -} - -/// [`Stream`] for [`HashJoinExec`] that does the actual join. -/// -/// This stream: -/// -/// 1. Reads the entire left input (build) and constructs a hash table -/// -/// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins -/// them with the contents of the hash table -struct HashJoinStream { - /// Partition identifier for debugging and determinism - partition: usize, - /// Input schema - schema: Arc<Schema>, - /// equijoin columns from the right (probe side) - on_right: Vec<PhysicalExprRef>, - /// optional join filter - filter: Option<JoinFilter>, - /// type of the join (left, right, semi, etc) - join_type: JoinType, - /// right (probe) input - right: SendableRecordBatchStream, - /// Random state used for hashing initialization - random_state: RandomState, - /// Metrics - join_metrics: BuildProbeJoinMetrics, - /// Information of index and left / right placement of columns - column_indices: Vec<ColumnIndex>, - /// Defines the null equality for the join. - null_equality: NullEquality, - /// State of the stream - state: HashJoinStreamState, - /// Build side - build_side: BuildSide, - /// Maximum output batch size - batch_size: usize, - /// Scratch space for computing hashes - hashes_buffer: Vec<u64>, - /// Specifies whether the right side has an ordering to potentially preserve - right_side_ordered: bool, - /// Shared bounds accumulator for coordinating dynamic filter updates (optional) - bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>, -} - -impl RecordBatchStream for HashJoinStream { - fn schema(&self) -> SchemaRef { - Arc::clone(&self.schema) - } -} - -/// Executes lookups by hash against JoinHashMap and resolves potential -/// hash collisions. -/// Returns build/probe indices satisfying the equality condition, along with -/// (optional) starting point for next iteration. 
-/// -/// # Example -/// -/// For `LEFT.b1 = RIGHT.b2`: -/// LEFT (build) Table: -/// ```text -/// a1 b1 c1 -/// 1 1 10 -/// 3 3 30 -/// 5 5 50 -/// 7 7 70 -/// 9 8 90 -/// 11 8 110 -/// 13 10 130 -/// ``` -/// -/// RIGHT (probe) Table: -/// ```text -/// a2 b2 c2 -/// 2 2 20 -/// 4 4 40 -/// 6 6 60 -/// 8 8 80 -/// 10 10 100 -/// 12 10 120 -/// ``` -/// -/// The result is -/// ```text -/// "+----+----+-----+----+----+-----+", -/// "| a1 | b1 | c1 | a2 | b2 | c2 |", -/// "+----+----+-----+----+----+-----+", -/// "| 9 | 8 | 90 | 8 | 8 | 80 |", -/// "| 11 | 8 | 110 | 8 | 8 | 80 |", -/// "| 13 | 10 | 130 | 10 | 10 | 100 |", -/// "| 13 | 10 | 130 | 12 | 10 | 120 |", -/// "+----+----+-----+----+----+-----+" -/// ``` -/// -/// And the result of build and probe indices are: -/// ```text -/// Build indices: 4, 5, 6, 6 -/// Probe indices: 3, 3, 4, 5 -/// ``` -#[allow(clippy::too_many_arguments)] -fn lookup_join_hashmap( - build_hashmap: &dyn JoinHashMapType, - build_side_values: &[ArrayRef], - probe_side_values: &[ArrayRef], - null_equality: NullEquality, - hashes_buffer: &[u64], - limit: usize, - offset: JoinHashMapOffset, -) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> { - let (probe_indices, build_indices, next_offset) = - build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset); - - let build_indices: UInt64Array = build_indices.into(); - let probe_indices: UInt32Array = probe_indices.into(); - - let (build_indices, probe_indices) = equal_rows_arr( - &build_indices, - &probe_indices, - build_side_values, - probe_side_values, - null_equality, - )?; - - Ok((build_indices, probe_indices, next_offset)) -} - -// version of eq_dyn supporting equality on null arrays -fn eq_dyn_null( - left: &dyn Array, - right: &dyn Array, - null_equality: NullEquality, -) -> Result<BooleanArray, ArrowError> { - // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special - // implementation - // <https://github.com/apache/datafusion/issues/10749> - if left.data_type().is_nested() { - let op = match null_equality { - NullEquality::NullEqualsNothing => Operator::Eq, - NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, - }; - return Ok(compare_op_for_nested(op, &left, &right)?); - } - match null_equality { - NullEquality::NullEqualsNothing => eq(&left, &right), - NullEquality::NullEqualsNull => not_distinct(&left, &right), - } -} - -pub fn equal_rows_arr( - indices_left: &UInt64Array, - indices_right: &UInt32Array, - left_arrays: &[ArrayRef], - right_arrays: &[ArrayRef], - null_equality: NullEquality, -) -> Result<(UInt64Array, UInt32Array)> { - let mut iter = left_arrays.iter().zip(right_arrays.iter()); - - let Some((first_left, first_right)) = iter.next() else { - return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into())); - }; - - let arr_left = take(first_left.as_ref(), indices_left, None)?; - let arr_right = take(first_right.as_ref(), indices_right, None)?; - - let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?; - - // Use map and try_fold to iterate over the remaining pairs of arrays. - // In each iteration, take is used on the pair of arrays and their equality is determined. - // The results are then folded (combined) using the and function to get a final equality result. 
- equal = iter - .map(|(left, right)| { - let arr_left = take(left.as_ref(), indices_left, None)?; - let arr_right = take(right.as_ref(), indices_right, None)?; - eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality) - }) - .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; - - let filter_builder = FilterBuilder::new(&equal).optimize().build(); - - let left_filtered = filter_builder.filter(indices_left)?; - let right_filtered = filter_builder.filter(indices_right)?; - - Ok(( - downcast_array(left_filtered.as_ref()), - downcast_array(right_filtered.as_ref()), - )) -} - -impl HashJoinStream { - /// Separate implementation function that unpins the [`HashJoinStream`] so - /// that partial borrows work correctly - fn poll_next_impl( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll<Option<Result<RecordBatch>>> { - loop { - return match self.state { - HashJoinStreamState::WaitBuildSide => { - handle_state!(ready!(self.collect_build_side(cx))) - } - HashJoinStreamState::FetchProbeBatch => { - handle_state!(ready!(self.fetch_probe_batch(cx))) - } - HashJoinStreamState::ProcessProbeBatch(_) => { - let poll = handle_state!(self.process_probe_batch()); - self.join_metrics.baseline.record_poll(poll) - } - HashJoinStreamState::ExhaustedProbeSide => { - let poll = handle_state!(self.process_unmatched_build_batch()); - self.join_metrics.baseline.record_poll(poll) - } - HashJoinStreamState::Completed => Poll::Ready(None), - }; - } - } - - /// Collects build-side data by polling `OnceFut` future from initialized build-side - /// - /// Updates build-side to `Ready`, and state to `FetchProbeSide` - fn collect_build_side( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { - let build_timer = self.join_metrics.build_time.timer(); - // build hash table from left (build) side, if not yet done - let left_data = ready!(self - .build_side - .try_as_initial_mut()? 
- .left_fut - .get_shared(cx))?; - build_timer.done(); - - // Handle dynamic filter bounds accumulation - // - // Dynamic filter coordination between partitions: - // Report bounds to the accumulator which will handle synchronization and filter updates - if let Some(ref bounds_accumulator) = self.bounds_accumulator { - bounds_accumulator - .report_partition_bounds(self.partition, left_data.bounds.clone())?; - } - - self.state = HashJoinStreamState::FetchProbeBatch; - self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); - - Poll::Ready(Ok(StatefulStreamResult::Continue)) - } - - /// Fetches next batch from probe-side - /// - /// If non-empty batch has been fetched, updates state to `ProcessProbeBatchState`, - /// otherwise updates state to `ExhaustedProbeSide` - fn fetch_probe_batch( - &mut self, - cx: &mut std::task::Context<'_>, - ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> { - match ready!(self.right.poll_next_unpin(cx)) { - None => { - self.state = HashJoinStreamState::ExhaustedProbeSide; - } - Some(Ok(batch)) => { - // Precalculate hash values for fetched batch - let keys_values = self - .on_right - .iter() - .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows())) - .collect::<Result<Vec<_>>>()?; - - self.hashes_buffer.clear(); - self.hashes_buffer.resize(batch.num_rows(), 0); - create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?; - - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.state = - HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState { - batch, - values: keys_values, - offset: (0, None), - joined_probe_idx: None, - }); - } - Some(Err(err)) => return Poll::Ready(Err(err)), - }; - - Poll::Ready(Ok(StatefulStreamResult::Continue)) - } - - /// Joins current probe batch with build-side data and produces batch with matched output - /// - /// Updates state to `FetchProbeBatch` - fn process_probe_batch( - &mut self, - ) -> Result<StatefulStreamResult<Option<RecordBatch>>> { - let state = self.state.try_as_process_probe_batch_mut()?; - let build_side = self.build_side.try_as_ready_mut()?; - - let timer = self.join_metrics.join_time.timer(); - - // if the left side is empty, we can skip the (potentially expensive) join operation - if build_side.left_data.hash_map.is_empty() && self.filter.is_none() { - let result = build_batch_empty_build_side( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &self.column_indices, - self.join_type, - )?; - self.join_metrics.output_batches.add(1); - timer.done(); - - self.state = HashJoinStreamState::FetchProbeBatch; - - return Ok(StatefulStreamResult::Ready(Some(result))); - } - - // get the matched by join keys indices - let (left_indices, right_indices, next_offset) = lookup_join_hashmap( - build_side.left_data.hash_map(), - build_side.left_data.values(), - &state.values, - self.null_equality, - &self.hashes_buffer, - self.batch_size, - state.offset, - )?; - - // apply join filter if exists - let (left_indices, right_indices) = if let Some(filter) = &self.filter { - apply_join_filter_to_indices( - build_side.left_data.batch(), - &state.batch, - left_indices, - right_indices, - filter, - JoinSide::Left, - None, - )? 
- } else { - (left_indices, right_indices) - }; - - // mark joined left-side indices as visited, if required by join type - if need_produce_result_in_final(self.join_type) { - let mut bitmap = build_side.left_data.visited_indices_bitmap().lock(); - left_indices.iter().flatten().for_each(|x| { - bitmap.set_bit(x as usize, true); - }); - } - - // The goals of index alignment for different join types are: - // - // 1) Right & FullJoin -- to append all missing probe-side indices between - // previous (excluding) and current joined indices. - // 2) SemiJoin -- deduplicate probe indices in range between previous - // (excluding) and current joined indices. - // 3) AntiJoin -- return only missing indices in range between - // previous and current joined indices. - // Inclusion/exclusion of the indices themselves don't matter - // - // As a summary -- alignment range can be produced based only on - // joined (matched with filters applied) probe side indices, excluding starting one - // (left from previous iteration). - - // if any rows have been joined -- get last joined probe-side (right) row - // it's important that index counts as "joined" after hash collisions checks - // and join filters applied. - let last_joined_right_idx = match right_indices.len() { - 0 => None, - n => Some(right_indices.value(n - 1) as usize), - }; - - // Calculate range and perform alignment. - // In case probe batch has been processed -- align all remaining rows. - let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if next_offset.is_none() { - state.batch.num_rows() - } else { - last_joined_right_idx.map_or(0, |v| v + 1) - }; - - let (left_indices, right_indices) = adjust_indices_by_join_type( - left_indices, - right_indices, - index_alignment_range_start..index_alignment_range_end, - self.join_type, - self.right_side_ordered, - )?; - - let result = if self.join_type == JoinType::RightMark { - build_batch_from_indices( - &self.schema, - &state.batch, - build_side.left_data.batch(), - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Right, - )? - } else { - build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &state.batch, - &left_indices, - &right_indices, - &self.column_indices, - JoinSide::Left, - )? 
- }; - - self.join_metrics.output_batches.add(1); - timer.done(); - - if next_offset.is_none() { - self.state = HashJoinStreamState::FetchProbeBatch; - } else { - state.advance( - next_offset - .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?, - last_joined_right_idx, - ) - }; - - Ok(StatefulStreamResult::Ready(Some(result))) - } - - /// Processes unmatched build-side rows for certain join types and produces output batch - /// - /// Updates state to `Completed` - fn process_unmatched_build_batch( - &mut self, - ) -> Result<StatefulStreamResult<Option<RecordBatch>>> { - let timer = self.join_metrics.join_time.timer(); - - if !need_produce_result_in_final(self.join_type) { - self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); - } - - let build_side = self.build_side.try_as_ready()?; - if !build_side.left_data.report_probe_completed() { - self.state = HashJoinStreamState::Completed; - return Ok(StatefulStreamResult::Continue); - } - - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = get_final_indices_from_shared_bitmap( - build_side.left_data.visited_indices_bitmap(), - self.join_type, - ); - let empty_right_batch = RecordBatch::new_empty(self.right.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - build_side.left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - - self.join_metrics.output_batches.add(1); - } - timer.done(); - - self.state = HashJoinStreamState::Completed; - - Ok(StatefulStreamResult::Ready(Some(result?))) - } -} - -impl Stream for HashJoinStream { - type Item = Result<RecordBatch>; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll<Option<Self::Item>> { - self.poll_next_impl(cx) - } -} - -impl EmbeddedProjection for HashJoinExec { - fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> { - self.with_projection(projection) - } -} - #[cfg(test)] mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; + use crate::joins::hash_join::stream::lookup_join_hashmap; use crate::test::{assert_join_metrics, TestMemoryExec}; use crate::{ common, expressions::Column, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, StructArray}; + use arrow::array::{Date32Array, Int32Array, StructArray, UInt32Array, UInt64Array}; use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; + use arrow_schema::Schema; + use datafusion_common::hash_utils::create_hashes; use datafusion_common::test_util::{batches_to_sort_string, batches_to_string}; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs new file mode 100644 index 0000000000..7f1e5cae13 --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`HashJoinExec`] Partitioned Hash Join Operator + +pub use exec::HashJoinExec; + +mod exec; +mod shared_bounds; +mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs new file mode 100644 index 0000000000..73e65be686 --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for shared bounds. Used in dynamic filter pushdown in Hash Joins. +// TODO: include the link to the Dynamic Filter blog post. + +use std::fmt; +use std::sync::Arc; + +use crate::joins::PartitionMode; +use crate::ExecutionPlan; +use crate::ExecutionPlanProperties; + +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{lit, BinaryExpr, DynamicFilterPhysicalExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; + +use itertools::Itertools; +use parking_lot::Mutex; + +/// Represents the minimum and maximum values for a specific column. +/// Used in dynamic filter pushdown to establish value boundaries. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct ColumnBounds { + /// The minimum value observed for this column + min: ScalarValue, + /// The maximum value observed for this column + max: ScalarValue, +} + +impl ColumnBounds { + pub(crate) fn new(min: ScalarValue, max: ScalarValue) -> Self { + Self { min, max } + } +} + +/// Represents the bounds for all join key columns from a single partition. +/// This contains the min/max values computed from one partition's build-side data. +#[derive(Debug, Clone)] +pub(crate) struct PartitionBounds { + /// Partition identifier for debugging and determinism (not strictly necessary) + partition: usize, + /// Min/max bounds for each join key column in this partition. + /// Index corresponds to the join key expression index. 
+ column_bounds: Vec<ColumnBounds>, +} + +impl PartitionBounds { + pub(crate) fn new(partition: usize, column_bounds: Vec<ColumnBounds>) -> Self { + Self { + partition, + column_bounds, + } + } + + pub(crate) fn len(&self) -> usize { + self.column_bounds.len() + } + + pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> { + self.column_bounds.get(index) + } +} + +/// Coordinates dynamic filter bounds collection across multiple partitions +/// +/// This structure ensures that dynamic filters are built with complete information from all +/// relevant partitions before being applied to probe-side scans. Incomplete filters would +/// incorrectly eliminate valid join results. +/// +/// ## Synchronization Strategy +/// +/// 1. Each partition computes bounds from its build-side data +/// 2. Bounds are stored in the shared HashMap (indexed by partition_id) +/// 3. A counter tracks how many partitions have reported their bounds +/// 4. When the last partition reports (completed == total), bounds are merged and filter is updated +/// +/// ## Partition Counting +/// +/// The `total_partitions` count represents how many times `collect_build_side` will be called: +/// - **CollectLeft**: Number of output partitions (each accesses shared build data) +/// - **Partitioned**: Number of input partitions (each builds independently) +/// +/// ## Thread Safety +/// +/// All fields use a single mutex to ensure correct coordination between concurrent +/// partition executions. +pub(crate) struct SharedBoundsAccumulator { + /// Shared state protected by a single mutex to avoid ordering concerns + inner: Mutex<SharedBoundsState>, + /// Total number of partitions. + /// Need to know this so that we can update the dynamic filter once we are done + /// building *all* of the hash tables. + total_partitions: usize, + /// Dynamic filter for pushdown to probe side + dynamic_filter: Arc<DynamicFilterPhysicalExpr>, + /// Right side join expressions needed for creating filter bounds + on_right: Vec<PhysicalExprRef>, +} + +/// State protected by SharedBoundsAccumulator's mutex +struct SharedBoundsState { + /// Bounds from completed partitions. + /// Each element represents the column bounds computed by one partition. + bounds: Vec<PartitionBounds>, + /// Number of partitions that have reported completion. + completed_partitions: usize, +} + +impl SharedBoundsAccumulator { + /// Creates a new SharedBoundsAccumulator configured for the given partition mode + /// + /// This method calculates how many times `collect_build_side` will be called based on the + /// partition mode's execution pattern. This count is critical for determining when we have + /// complete information from all partitions to build the dynamic filter. + /// + /// ## Partition Mode Execution Patterns + /// + /// - **CollectLeft**: Build side is collected ONCE from partition 0 and shared via `OnceFut` + /// across all output partitions. Each output partition calls `collect_build_side` to access + /// the shared build data. Expected calls = number of output partitions. + /// + /// - **Partitioned**: Each partition independently builds its own hash table by calling + /// `collect_build_side` once. Expected calls = number of build partitions. + /// + /// - **Auto**: Placeholder mode resolved during optimization. Uses 1 as safe default since + /// the actual mode will be determined and a new bounds_accumulator created before execution. 
+ /// + /// ## Why This Matters + /// + /// We cannot build a partial filter from some partitions - it would incorrectly eliminate + /// valid join results. We must wait until we have complete bounds information from ALL + /// relevant partitions before updating the dynamic filter. + pub(crate) fn new_from_partition_mode( + partition_mode: PartitionMode, + left_child: &dyn ExecutionPlan, + right_child: &dyn ExecutionPlan, + dynamic_filter: Arc<DynamicFilterPhysicalExpr>, + on_right: Vec<PhysicalExprRef>, + ) -> Self { + // Troubleshooting: If partition counts are incorrect, verify this logic matches + // the actual execution pattern in collect_build_side() + let expected_calls = match partition_mode { + // Each output partition accesses shared build data + PartitionMode::CollectLeft => { + right_child.output_partitioning().partition_count() + } + // Each partition builds its own data + PartitionMode::Partitioned => { + left_child.output_partitioning().partition_count() + } + // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) + PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), + }; + Self { + inner: Mutex::new(SharedBoundsState { + bounds: Vec::with_capacity(expected_calls), + completed_partitions: 0, + }), + total_partitions: expected_calls, + dynamic_filter, + on_right, + } + } + + /// Create a filter expression from individual partition bounds using OR logic. + /// + /// This creates a filter where each partition's bounds form a conjunction (AND) + /// of column range predicates, and all partitions are combined with OR. + /// + /// For example, with 2 partitions and 2 columns: + /// ((col0 >= p0_min0 AND col0 <= p0_max0 AND col1 >= p0_min1 AND col1 <= p0_max1) + /// OR + /// (col0 >= p1_min0 AND col0 <= p1_max0 AND col1 >= p1_min1 AND col1 <= p1_max1)) + pub(crate) fn create_filter_from_partition_bounds( + &self, + bounds: &[PartitionBounds], + ) -> Result<Arc<dyn PhysicalExpr>> { + if bounds.is_empty() { + return Ok(lit(true)); + } + + // Create a predicate for each partition + let mut partition_predicates = Vec::with_capacity(bounds.len()); + + for partition_bounds in bounds.iter().sorted_by_key(|b| b.partition) { + // Create range predicates for each join key in this partition + let mut column_predicates = Vec::with_capacity(partition_bounds.len()); + + for (col_idx, right_expr) in self.on_right.iter().enumerate() { + if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) { + // Create predicate: col >= min AND col <= max + let min_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::GtEq, + lit(column_bounds.min.clone()), + )) as Arc<dyn PhysicalExpr>; + let max_expr = Arc::new(BinaryExpr::new( + Arc::clone(right_expr), + Operator::LtEq, + lit(column_bounds.max.clone()), + )) as Arc<dyn PhysicalExpr>; + let range_expr = + Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr)) + as Arc<dyn PhysicalExpr>; + column_predicates.push(range_expr); + } + } + + // Combine all column predicates for this partition with AND + if !column_predicates.is_empty() { + let partition_predicate = column_predicates + .into_iter() + .reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::And, pred)) + as Arc<dyn PhysicalExpr> + }) + .unwrap(); + partition_predicates.push(partition_predicate); + } + } + + // Combine all partition predicates with OR + let combined_predicate 
= partition_predicates + .into_iter() + .reduce(|acc, pred| { + Arc::new(BinaryExpr::new(acc, Operator::Or, pred)) + as Arc<dyn PhysicalExpr> + }) + .unwrap_or_else(|| lit(true)); + + Ok(combined_predicate) + } + + /// Report bounds from a completed partition and update dynamic filter if all partitions are done + /// + /// This method coordinates the dynamic filter updates across all partitions. It stores the + /// bounds from the current partition, increments the completion counter, and when all + /// partitions have reported, creates an OR'd filter from individual partition bounds. + /// + /// # Arguments + /// * `partition_bounds` - The bounds computed by this partition (if any) + /// + /// # Returns + /// * `Result<()>` - Ok if successful, Err if filter update failed + pub(crate) fn report_partition_bounds( + &self, + partition: usize, + partition_bounds: Option<Vec<ColumnBounds>>, + ) -> Result<()> { + let mut inner = self.inner.lock(); + + // Store bounds in the accumulator - this runs once per partition + if let Some(bounds) = partition_bounds { + // Only push actual bounds if they exist + inner.bounds.push(PartitionBounds::new(partition, bounds)); + } + + // Increment the completion counter + // Even empty partitions must report to ensure proper termination + inner.completed_partitions += 1; + let completed = inner.completed_partitions; + let total_partitions = self.total_partitions; + + // Critical synchronization point: Only update the filter when ALL partitions are complete + // Troubleshooting: If you see "completed > total_partitions", check partition + // count calculation in new_from_partition_mode() - it may not match actual execution calls + if completed == total_partitions && !inner.bounds.is_empty() { + let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?; + self.dynamic_filter.update(filter_expr)?; + } + + Ok(()) + } +} + +impl fmt::Debug for SharedBoundsAccumulator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SharedBoundsAccumulator") + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs new file mode 100644 index 0000000000..d368a9cf8e --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -0,0 +1,628 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream implementation for Hash Join +//! +//! This module implements [`HashJoinStream`], the streaming engine for +//! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. 
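// A minimal, self-contained sketch of the state-machine pattern this module
// describes, using placeholder types: `State` stands in for the real
// `HashJoinStreamState` defined below, and `Vec<i32>` for a probe-side
// `RecordBatch`. An illustrative sketch only, not DataFusion's actual code.
enum State {
    WaitBuildSide,
    FetchProbeBatch,
    ProcessProbeBatch(Vec<i32>),
    ExhaustedProbeSide,
    Completed,
}

fn main() {
    // Two probe-side "batches" drive the machine through all transitions.
    let mut batches = vec![vec![1, 2], vec![3]].into_iter();
    let mut output = Vec::new();
    let mut state = State::WaitBuildSide;
    loop {
        state = match state {
            // The build side is ready: start consuming the probe side.
            State::WaitBuildSide => State::FetchProbeBatch,
            // Fetch the next probe batch, or finish once the input is drained.
            State::FetchProbeBatch => match batches.next() {
                Some(batch) => State::ProcessProbeBatch(batch),
                None => State::ExhaustedProbeSide,
            },
            // "Join" the batch (here: pass rows through) and loop back.
            State::ProcessProbeBatch(batch) => {
                output.extend(batch);
                State::FetchProbeBatch
            }
            // Outer joins would emit unmatched build-side rows at this point.
            State::ExhaustedProbeSide => State::Completed,
            State::Completed => break,
        };
    }
    assert_eq!(output, vec![1, 2, 3]);
}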
+
+use std::sync::Arc;
+use std::task::Poll;
+
+use crate::joins::hash_join::exec::JoinLeftData;
+use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator;
+use crate::joins::utils::{
+    equal_rows_arr, get_final_indices_from_shared_bitmap, OnceFut,
+};
+use crate::{
+    handle_state,
+    hash_utils::create_hashes,
+    joins::join_hash_map::JoinHashMapOffset,
+    joins::utils::{
+        adjust_indices_by_join_type, apply_join_filter_to_indices,
+        build_batch_empty_build_side, build_batch_from_indices,
+        need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
+        JoinHashMapType, StatefulStreamResult,
+    },
+    RecordBatchStream, SendableRecordBatchStream,
+};
+
+use arrow::array::{ArrayRef, UInt32Array, UInt64Array};
+use arrow::datatypes::{Schema, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use datafusion_common::{
+    internal_datafusion_err, internal_err, JoinSide, JoinType, NullEquality, Result,
+};
+use datafusion_physical_expr::PhysicalExprRef;
+
+use ahash::RandomState;
+use futures::{ready, Stream, StreamExt};
+
+/// Represents the build side of a hash join.
+pub(super) enum BuildSide {
+    /// Indicates that the build side has not been collected yet
+    Initial(BuildSideInitialState),
+    /// Indicates that the build-side data has been collected
+    Ready(BuildSideReadyState),
+}
+
+/// Container for BuildSide::Initial related data
+pub(super) struct BuildSideInitialState {
+    /// Future for building hash table from build-side input
+    pub(super) left_fut: OnceFut<JoinLeftData>,
+}
+
+/// Container for BuildSide::Ready related data
+pub(super) struct BuildSideReadyState {
+    /// Collected build-side data
+    left_data: Arc<JoinLeftData>,
+}
+
+impl BuildSide {
+    /// Tries to extract BuildSideInitialState from the BuildSide enum.
+    /// Returns an error if the state is not Initial.
+    fn try_as_initial_mut(&mut self) -> Result<&mut BuildSideInitialState> {
+        match self {
+            BuildSide::Initial(state) => Ok(state),
+            _ => internal_err!("Expected build side in initial state"),
+        }
+    }
+
+    /// Tries to extract BuildSideReadyState from the BuildSide enum.
+    /// Returns an error if the state is not Ready.
+    fn try_as_ready(&self) -> Result<&BuildSideReadyState> {
+        match self {
+            BuildSide::Ready(state) => Ok(state),
+            _ => internal_err!("Expected build side in ready state"),
+        }
+    }
+
+    /// Tries to extract BuildSideReadyState from the BuildSide enum.
+    /// Returns an error if the state is not Ready.
+    fn try_as_ready_mut(&mut self) -> Result<&mut BuildSideReadyState> {
+        match self {
+            BuildSide::Ready(state) => Ok(state),
+            _ => internal_err!("Expected build side in ready state"),
+        }
+    }
+}
+
+/// Represents the state of HashJoinStream
+///
+/// Expected state transitions performed by HashJoinStream are:
+///
+/// ```text
+///
+///       WaitBuildSide
+///             │
+///             ▼
+///  ┌─► FetchProbeBatch ───► ExhaustedProbeSide ───► Completed
+///  │          │
+///  │          ▼
+///  └─ ProcessProbeBatch
+///
+/// ```
+#[derive(Debug, Clone)]
+pub(super) enum HashJoinStreamState {
+    /// Initial state for HashJoinStream, indicating that build-side data has not been collected yet
+    WaitBuildSide,
+    /// Indicates that the build side has been collected, and the stream is ready to fetch the probe side
+    FetchProbeBatch,
+    /// Indicates that a non-empty batch has been fetched from the probe side and is ready to be processed
+    ProcessProbeBatch(ProcessProbeBatchState),
+    /// Indicates that the probe side has been fully processed
+    ExhaustedProbeSide,
+    /// Indicates that HashJoinStream execution is completed
+    Completed,
+}
+
+impl HashJoinStreamState {
+    /// Tries to extract ProcessProbeBatchState from the HashJoinStreamState enum.
+    /// Returns an error if the state is not ProcessProbeBatch.
+    fn try_as_process_probe_batch_mut(&mut self) -> Result<&mut ProcessProbeBatchState> {
+        match self {
+            HashJoinStreamState::ProcessProbeBatch(state) => Ok(state),
+            _ => internal_err!("Expected hash join stream in ProcessProbeBatch state"),
+        }
+    }
+}
+
+/// Container for HashJoinStreamState::ProcessProbeBatch related data
+#[derive(Debug, Clone)]
+pub(super) struct ProcessProbeBatchState {
+    /// Current probe-side batch
+    batch: RecordBatch,
+    /// Probe-side `on` expression values
+    values: Vec<ArrayRef>,
+    /// Starting offset for JoinHashMap lookups
+    offset: JoinHashMapOffset,
+    /// Max joined probe-side index from the current batch
+    joined_probe_idx: Option<usize>,
+}
+
+impl ProcessProbeBatchState {
+    fn advance(&mut self, offset: JoinHashMapOffset, joined_probe_idx: Option<usize>) {
+        self.offset = offset;
+        if joined_probe_idx.is_some() {
+            self.joined_probe_idx = joined_probe_idx;
+        }
+    }
+}
+
+/// [`Stream`] for [`super::HashJoinExec`] that does the actual join.
+///
+/// This stream:
+///
+/// - Collects the build side (left input) into a hash map
+/// - Iterates over the probe side (right input) in streaming fashion
+/// - Looks up matches against the hash table and applies join filters
+/// - Produces joined [`RecordBatch`]es incrementally
+/// - Emits unmatched rows for outer/semi/anti joins in the final stage
+pub(super) struct HashJoinStream {
+    /// Partition identifier for debugging and determinism
+    partition: usize,
+    /// Input schema
+    schema: Arc<Schema>,
+    /// Equijoin columns from the right (probe side)
+    on_right: Vec<PhysicalExprRef>,
+    /// Optional join filter
+    filter: Option<JoinFilter>,
+    /// Type of the join (left, right, semi, etc.)
+    join_type: JoinType,
+    /// Right (probe) input
+    right: SendableRecordBatchStream,
+    /// Random state used for hashing initialization
+    random_state: RandomState,
+    /// Metrics
+    join_metrics: BuildProbeJoinMetrics,
+    /// Information about the index and left/right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    /// Defines the null equality for the join.
+    null_equality: NullEquality,
+    /// State of the stream
+    state: HashJoinStreamState,
+    /// Build side
+    build_side: BuildSide,
+    /// Maximum output batch size
+    batch_size: usize,
+    /// Scratch space for computing hashes
+    hashes_buffer: Vec<u64>,
+    /// Specifies whether the right side has an ordering to potentially preserve
+    right_side_ordered: bool,
+    /// Shared bounds accumulator for coordinating dynamic filter updates (optional)
+    bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
+}
+
+impl RecordBatchStream for HashJoinStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
+
+/// Executes lookups by hash against the JoinHashMap and resolves potential
+/// hash collisions.
+/// Returns the build/probe indices satisfying the equality condition, along with
+/// an (optional) starting point for the next iteration.
+///
+/// # Example
+///
+/// For `LEFT.b1 = RIGHT.b2`:
+/// LEFT (build) Table:
+/// ```text
+/// a1  b1  c1
+/// 1   1   10
+/// 3   3   30
+/// 5   5   50
+/// 7   7   70
+/// 9   8   90
+/// 11  8   110
+/// 13  10  130
+/// ```
+///
+/// RIGHT (probe) Table:
+/// ```text
+/// a2  b2  c2
+/// 2   2   20
+/// 4   4   40
+/// 6   6   60
+/// 8   8   80
+/// 10  10  100
+/// 12  10  120
+/// ```
+///
+/// The result is:
+/// ```text
+/// "+----+----+-----+----+----+-----+",
+/// "| a1 | b1 | c1  | a2 | b2 | c2  |",
+/// "+----+----+-----+----+----+-----+",
+/// "| 9  | 8  | 90  | 8  | 8  | 80  |",
+/// "| 11 | 8  | 110 | 8  | 8  | 80  |",
+/// "| 13 | 10 | 130 | 10 | 10 | 100 |",
+/// "| 13 | 10 | 130 | 12 | 10 | 120 |",
+/// "+----+----+-----+----+----+-----+"
+/// ```
+///
+/// And the resulting build and probe indices are:
+/// ```text
+/// Build indices: 4, 5, 6, 6
+/// Probe indices: 3, 3, 4, 5
+/// ```
+#[allow(clippy::too_many_arguments)]
+pub(super) fn lookup_join_hashmap(
+    build_hashmap: &dyn JoinHashMapType,
+    build_side_values: &[ArrayRef],
+    probe_side_values: &[ArrayRef],
+    null_equality: NullEquality,
+    hashes_buffer: &[u64],
+    limit: usize,
+    offset: JoinHashMapOffset,
+) -> Result<(UInt64Array, UInt32Array, Option<JoinHashMapOffset>)> {
+    let (probe_indices, build_indices, next_offset) =
+        build_hashmap.get_matched_indices_with_limit_offset(hashes_buffer, limit, offset);
+
+    let build_indices: UInt64Array = build_indices.into();
+    let probe_indices: UInt32Array = probe_indices.into();
+
+    let (build_indices, probe_indices) = equal_rows_arr(
+        &build_indices,
+        &probe_indices,
+        build_side_values,
+        probe_side_values,
+        null_equality,
+    )?;
+
+    Ok((build_indices, probe_indices, next_offset))
+}
+
+impl HashJoinStream {
+    #[allow(clippy::too_many_arguments)]
+    pub(super) fn new(
+        partition: usize,
+        schema: Arc<Schema>,
+        on_right: Vec<PhysicalExprRef>,
+        filter: Option<JoinFilter>,
+        join_type: JoinType,
+        right: SendableRecordBatchStream,
+        random_state: RandomState,
+        join_metrics: BuildProbeJoinMetrics,
+        column_indices: Vec<ColumnIndex>,
+        null_equality: NullEquality,
+        state: HashJoinStreamState,
+        build_side: BuildSide,
+        batch_size: usize,
+        hashes_buffer: Vec<u64>,
+        right_side_ordered: bool,
+        bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
+    ) -> Self {
+        Self {
+            partition,
+            schema,
+            on_right,
+            filter,
+            join_type,
+            right,
+            random_state,
+            join_metrics,
+            column_indices,
+            null_equality,
+            state,
+            build_side,
+            batch_size,
+            hashes_buffer,
+            right_side_ordered,
+            bounds_accumulator,
+        }
+    }
+
+    /// Separate implementation function that unpins the [`HashJoinStream`] so
+    /// that partial borrows work correctly
+    fn poll_next_impl(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                HashJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                HashJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                HashJoinStreamState::ProcessProbeBatch(_) => {
+                    let poll = handle_state!(self.process_probe_batch());
+                    self.join_metrics.baseline.record_poll(poll)
+                }
+                HashJoinStreamState::ExhaustedProbeSide => {
+                    let poll = handle_state!(self.process_unmatched_build_batch());
+                    self.join_metrics.baseline.record_poll(poll)
+                }
+                HashJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    /// Collects build-side data by polling the `OnceFut` future from the initialized build side
+    ///
+    /// Updates the build side to `Ready` and the state to `FetchProbeBatch`
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        let build_timer = self.join_metrics.build_time.timer();
+        // build the hash table from the left (build) side, if not yet done
+        let left_data = ready!(self
+            .build_side
+            .try_as_initial_mut()?
+            .left_fut
+            .get_shared(cx))?;
+        build_timer.done();
+
+        // Handle dynamic filter bounds accumulation
+        //
+        // Dynamic filter coordination between partitions:
+        // Report bounds to the accumulator, which will handle synchronization and filter updates
+        if let Some(ref bounds_accumulator) = self.bounds_accumulator {
+            bounds_accumulator
+                .report_partition_bounds(self.partition, left_data.bounds.clone())?;
+        }
+
+        self.state = HashJoinStreamState::FetchProbeBatch;
+        self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Fetches the next batch from the probe side
+    ///
+    /// If a non-empty batch has been fetched, updates the state to `ProcessProbeBatch`,
+    /// otherwise updates the state to `ExhaustedProbeSide`
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.right.poll_next_unpin(cx)) {
+            None => {
+                self.state = HashJoinStreamState::ExhaustedProbeSide;
+            }
+            Some(Ok(batch)) => {
+                // Precalculate hash values for the fetched batch
+                let keys_values = self
+                    .on_right
+                    .iter()
+                    .map(|c| c.evaluate(&batch)?.into_array(batch.num_rows()))
+                    .collect::<Result<Vec<_>>>()?;
+
+                self.hashes_buffer.clear();
+                self.hashes_buffer.resize(batch.num_rows(), 0);
+                create_hashes(&keys_values, &self.random_state, &mut self.hashes_buffer)?;
+
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(batch.num_rows());
+
+                self.state =
+                    HashJoinStreamState::ProcessProbeBatch(ProcessProbeBatchState {
+                        batch,
+                        values: keys_values,
+                        offset: (0, None),
+                        joined_probe_idx: None,
+                    });
+            }
+            Some(Err(err)) => return Poll::Ready(Err(err)),
+        };
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Joins the current probe batch with the build-side data and produces a batch of matched output
+    ///
+    /// Updates the state to `FetchProbeBatch`
+    fn process_probe_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let state = self.state.try_as_process_probe_batch_mut()?;
+        let build_side = self.build_side.try_as_ready_mut()?;
+
+        let timer = self.join_metrics.join_time.timer();
+
+        // if the left side is empty, we can skip the (potentially expensive) join operation
+        if build_side.left_data.hash_map.is_empty() && self.filter.is_none() {
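+            // With an empty build side the output depends only on the join type:
+            // e.g. Inner and LeftSemi joins produce no rows here, while Right and
+            // Full joins emit every probe row null-padded on the build side.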
+            let result = build_batch_empty_build_side(
+                &self.schema,
+                build_side.left_data.batch(),
+                &state.batch,
+                &self.column_indices,
+                self.join_type,
+            )?;
+            self.join_metrics.output_batches.add(1);
+            timer.done();
+
+            self.state = HashJoinStreamState::FetchProbeBatch;
+
+            return Ok(StatefulStreamResult::Ready(Some(result)));
+        }
+
+        // get the indices matched by the join keys
+        let (left_indices, right_indices, next_offset) = lookup_join_hashmap(
+            build_side.left_data.hash_map(),
+            build_side.left_data.values(),
+            &state.values,
+            self.null_equality,
+            &self.hashes_buffer,
+            self.batch_size,
+            state.offset,
+        )?;
+
+        // apply the join filter if one exists
+        let (left_indices, right_indices) = if let Some(filter) = &self.filter {
+            apply_join_filter_to_indices(
+                build_side.left_data.batch(),
+                &state.batch,
+                left_indices,
+                right_indices,
+                filter,
+                JoinSide::Left,
+                None,
+            )?
+        } else {
+            (left_indices, right_indices)
+        };
+
+        // mark joined left-side indices as visited, if required by the join type
+        if need_produce_result_in_final(self.join_type) {
+            let mut bitmap = build_side.left_data.visited_indices_bitmap().lock();
+            left_indices.iter().flatten().for_each(|x| {
+                bitmap.set_bit(x as usize, true);
+            });
+        }
+
+        // The goals of index alignment for different join types are:
+        //
+        // 1) Right & Full joins -- to append all missing probe-side indices
+        //    between the previous (excluding) and current joined indices.
+        // 2) Semi joins -- to deduplicate probe indices in the range between
+        //    the previous (excluding) and current joined indices.
+        // 3) Anti joins -- to return only the missing indices in the range
+        //    between the previous and current joined indices.
+        //    Inclusion/exclusion of the range endpoints themselves doesn't matter.
+        //
+        // In summary -- the alignment range can be produced based only on the
+        // joined (matched, with filters applied) probe-side indices, excluding
+        // the starting one (left over from the previous iteration).
+
+        // if any rows have been joined -- get the last joined probe-side (right) row;
+        // an index only counts as "joined" after hash-collision checks and join
+        // filters have been applied
+        let last_joined_right_idx = match right_indices.len() {
+            0 => None,
+            n => Some(right_indices.value(n - 1) as usize),
+        };
+
+        // Calculate the range and perform the alignment.
+        // If the probe batch has been fully processed -- align all remaining rows.
+        let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1);
+        let index_alignment_range_end = if next_offset.is_none() {
+            state.batch.num_rows()
+        } else {
+            last_joined_right_idx.map_or(0, |v| v + 1)
+        };
+
+        let (left_indices, right_indices) = adjust_indices_by_join_type(
+            left_indices,
+            right_indices,
+            index_alignment_range_start..index_alignment_range_end,
+            self.join_type,
+            self.right_side_ordered,
+        )?;
+
+        let result = if self.join_type == JoinType::RightMark {
+            build_batch_from_indices(
+                &self.schema,
+                &state.batch,
+                build_side.left_data.batch(),
+                &left_indices,
+                &right_indices,
+                &self.column_indices,
+                JoinSide::Right,
+            )?
+        } else {
+            build_batch_from_indices(
+                &self.schema,
+                build_side.left_data.batch(),
+                &state.batch,
+                &left_indices,
+                &right_indices,
+                &self.column_indices,
+                JoinSide::Left,
+            )?
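+            // Note: the RightMark arm above swaps the build/probe arguments,
+            // since the mark is computed for probe-side rows and the output
+            // batch is assembled from the probe side's perspective.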
+        };
+
+        self.join_metrics.output_batches.add(1);
+        timer.done();
+
+        if next_offset.is_none() {
+            self.state = HashJoinStreamState::FetchProbeBatch;
+        } else {
+            state.advance(
+                next_offset
+                    .ok_or_else(|| internal_datafusion_err!("unexpected None offset"))?,
+                last_joined_right_idx,
+            )
+        };
+
+        Ok(StatefulStreamResult::Ready(Some(result)))
+    }
+
+    /// Processes unmatched build-side rows for certain join types and produces an output batch
+    ///
+    /// Updates the state to `Completed`
+    fn process_unmatched_build_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let timer = self.join_metrics.join_time.timer();
+
+        if !need_produce_result_in_final(self.join_type) {
+            self.state = HashJoinStreamState::Completed;
+            return Ok(StatefulStreamResult::Continue);
+        }
+
+        let build_side = self.build_side.try_as_ready()?;
+        if !build_side.left_data.report_probe_completed() {
+            self.state = HashJoinStreamState::Completed;
+            return Ok(StatefulStreamResult::Continue);
+        }
+
+        // use the global left bitmap to produce the left indices and right indices
+        let (left_side, right_side) = get_final_indices_from_shared_bitmap(
+            build_side.left_data.visited_indices_bitmap(),
+            self.join_type,
+        );
+        let empty_right_batch = RecordBatch::new_empty(self.right.schema());
+        // use the left and right indices to produce the batch result
+        let result = build_batch_from_indices(
+            &self.schema,
+            build_side.left_data.batch(),
+            &empty_right_batch,
+            &left_side,
+            &right_side,
+            &self.column_indices,
+            JoinSide::Left,
+        );
+
+        if let Ok(ref batch) = result {
+            self.join_metrics.input_batches.add(1);
+            self.join_metrics.input_rows.add(batch.num_rows());
+
+            self.join_metrics.output_batches.add(1);
+        }
+        timer.done();
+
+        self.state = HashJoinStreamState::Completed;
+
+        Ok(StatefulStreamResult::Ready(Some(result?)))
+    }
+}
+
+impl Stream for HashJoinStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 9a8d4cbb66..aedeb97186 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -34,7 +34,6 @@ use std::vec;
 
 use crate::common::SharedMemoryReservation;
 use crate::execution_plan::{boundedness_from_children, emission_type_from_children};
-use crate::joins::hash_join::{equal_rows_arr, update_hash};
 use crate::joins::stream_join_utils::{
     calculate_filter_expr_intervals, combine_two_batches,
     convert_sort_expr_with_filter_schema, get_pruning_anti_indices,
@@ -43,9 +42,9 @@ use crate::joins::stream_join_utils::{
 };
 use crate::joins::utils::{
     apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
-    check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter,
-    BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef,
-    NoopBatchTransformer, StatefulStreamResult,
+    check_join_is_valid, equal_rows_arr, symmetric_join_output_partitioning, update_hash,
+    BatchSplitter, BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
+    JoinOnRef, NoopBatchTransformer, StatefulStreamResult,
 };
 use crate::projection::{
     join_allows_pushdown, join_table_borders, new_join_children,
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index a7cd81a98f..cf665f2a5a 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -37,23 +37,29 @@ pub use super::join_filter::JoinFilter;
 pub use super::join_hash_map::JoinHashMapType;
 pub use crate::joins::{JoinOn, JoinOnRef};
 
-use arrow::array::BooleanArray;
+use ahash::RandomState;
 use arrow::array::{
     builder::UInt64Builder, downcast_array, new_null_array, Array, ArrowPrimitiveType,
     BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions,
     UInt32Array, UInt32Builder, UInt64Array,
 };
+use arrow::array::{ArrayRef, BooleanArray};
 use arrow::buffer::{BooleanBuffer, NullBuffer};
-use arrow::compute;
+use arrow::compute::kernels::cmp::eq;
+use arrow::compute::{self, and, take, FilterBuilder};
 use arrow::datatypes::{
     ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
 };
+use arrow_ord::cmp::not_distinct;
+use arrow_schema::ArrowError;
 use datafusion_common::cast::as_boolean_array;
+use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult,
+    plan_err, DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
 };
 use datafusion_expr::interval_arithmetic::Interval;
+use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{
@@ -61,6 +67,7 @@ use datafusion_physical_expr::{
     PhysicalExprRef,
 };
+use datafusion_physical_expr_common::datum::compare_op_for_nested;
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
 use parking_lot::Mutex;
@@ -963,7 +970,7 @@ pub(crate) fn build_batch_from_indices(
             assert_eq!(build_indices.null_count(), build_indices.len());
             new_null_array(array.data_type(), build_indices.len())
         } else {
-            compute::take(array.as_ref(), build_indices, None)?
+            take(array.as_ref(), build_indices, None)?
         }
     } else {
         let array = probe_batch.column(column_index.index);
@@ -971,7 +978,7 @@
             assert_eq!(probe_indices.null_count(), probe_indices.len());
             new_null_array(array.data_type(), probe_indices.len())
         } else {
-            compute::take(array.as_ref(), probe_indices, None)?
+            take(array.as_ref(), probe_indices, None)?
         }
     };
@@ -1633,6 +1640,112 @@ pub fn swap_join_projection(
     }
 }
 
+/// Updates `hash_map` with new entries from `batch`, evaluated against the expressions `on`,
+/// using `offset` as a start value for the `batch` row indices.
+///
+/// `fifo_hashmap` sets the order of iteration over the `batch` rows while updating the hashmap,
+/// which allows keeping either the first (if set to true) or the last (if set to false) row index
+/// as the chain head for rows with equal hash values.
+#[allow(clippy::too_many_arguments)]
+pub fn update_hash(
+    on: &[PhysicalExprRef],
+    batch: &RecordBatch,
+    hash_map: &mut dyn JoinHashMapType,
+    offset: usize,
+    random_state: &RandomState,
+    hashes_buffer: &mut Vec<u64>,
+    deleted_offset: usize,
+    fifo_hashmap: bool,
+) -> Result<()> {
+    // evaluate the keys
+    let keys_values = on
+        .iter()
+        .map(|c| c.evaluate(batch)?.into_array(batch.num_rows()))
+        .collect::<Result<Vec<_>>>()?;
+
+    // calculate the hash values
+    let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
+
+    // For the usual JoinHashMap, this call is a no-op.
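+    // Chained hash maps that support pruning (for example the one used by
+    // the symmetric hash join) use this to grow their `next`-chain storage
+    // with zeroed slots, so the row indices inserted below have entries to
+    // link into.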
+    hash_map.extend_zero(batch.num_rows());
+
+    // Updating JoinHashMap from hash values iterator
+    let hash_values_iter = hash_values
+        .iter()
+        .enumerate()
+        .map(|(i, val)| (i + offset, val));
+
+    if fifo_hashmap {
+        hash_map.update_from_iter(Box::new(hash_values_iter.rev()), deleted_offset);
+    } else {
+        hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset);
+    }
+
+    Ok(())
+}
+
+pub(super) fn equal_rows_arr(
+    indices_left: &UInt64Array,
+    indices_right: &UInt32Array,
+    left_arrays: &[ArrayRef],
+    right_arrays: &[ArrayRef],
+    null_equality: NullEquality,
+) -> Result<(UInt64Array, UInt32Array)> {
+    let mut iter = left_arrays.iter().zip(right_arrays.iter());
+
+    let Some((first_left, first_right)) = iter.next() else {
+        return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
+    };
+
+    let arr_left = take(first_left.as_ref(), indices_left, None)?;
+    let arr_right = take(first_right.as_ref(), indices_right, None)?;
+
+    let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;
+
+    // Use map and try_fold to iterate over the remaining pairs of arrays.
+    // In each iteration, take is used on the pair of arrays and their equality is determined.
+    // The results are then folded (combined) using the and function to get a final equality result.
+    equal = iter
+        .map(|(left, right)| {
+            let arr_left = take(left.as_ref(), indices_left, None)?;
+            let arr_right = take(right.as_ref(), indices_right, None)?;
+            eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
+        })
+        .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;
+
+    let filter_builder = FilterBuilder::new(&equal).optimize().build();
+
+    let left_filtered = filter_builder.filter(indices_left)?;
+    let right_filtered = filter_builder.filter(indices_right)?;
+
+    Ok((
+        downcast_array(left_filtered.as_ref()),
+        downcast_array(right_filtered.as_ref()),
+    ))
+}
+
+// version of eq_dyn supporting equality on null arrays
+fn eq_dyn_null(
+    left: &dyn Array,
+    right: &dyn Array,
+    null_equality: NullEquality,
+) -> Result<BooleanArray, ArrowError> {
+    // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
+    // implementation
+    // <https://github.com/apache/datafusion/issues/10749>
+    if left.data_type().is_nested() {
+        let op = match null_equality {
+            NullEquality::NullEqualsNothing => Operator::Eq,
+            NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
+        };
+        return Ok(compare_op_for_nested(op, &left, &right)?);
+    }
+    match null_equality {
+        NullEquality::NullEqualsNothing => eq(&left, &right),
+        NullEquality::NullEqualsNull => not_distinct(&left, &right),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org