This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 87e931c976 Split output batches of joins that do not respect batch size (#12969)
87e931c976 is described below

commit 87e931c976a7aa24cecaa9bf3658b42bba12a51e
Author: Alihan Çelikcan <[email protected]>
AuthorDate: Fri Oct 18 14:34:42 2024 +0300

    Split output batches of joins that do not respect batch size (#12969)
    
    * Add BatchSplitter to joins that do not respect batch size
    
    * Group relevant imports
    
    * Update configs.md
    
    * Update SQL logic tests for config
    
    * Review
    
    * Use PrimitiveBuilder for PrimitiveArray concatenation
    
    * Fix into_builder() bug
    
    * Apply suggestions from code review
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Update config docs
    
    * Format
    
    * Update config SQL Logic Test
    
    ---------
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/common/src/config.rs                    |  26 +-
 datafusion/execution/src/config.rs                 |  14 +
 datafusion/physical-plan/src/joins/cross_join.rs   |  84 +++--
 datafusion/physical-plan/src/joins/hash_join.rs    |   2 +-
 .../physical-plan/src/joins/nested_loop_join.rs    | 356 ++++++++++++++-------
 .../physical-plan/src/joins/stream_join_utils.rs   |  83 +++--
 .../physical-plan/src/joins/symmetric_hash_join.rs | 252 ++++++++-------
 datafusion/physical-plan/src/joins/utils.rs        | 220 +++++++++++--
 .../sqllogictest/test_files/information_schema.slt |   2 +
 docs/source/user-guide/configs.md                  |   1 +
 10 files changed, 709 insertions(+), 331 deletions(-)

diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1e1c5d5424..47ffe0b1c6 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -338,6 +338,12 @@ config_namespace! {
         /// if the source of statistics is accurate.
         /// We plan to make this the default in the future.
         pub use_row_number_estimates_to_optimize_partitioning: bool, default = 
false
+
+        /// Should DataFusion enforce batch size in joins or not. By default,
+        /// DataFusion will not enforce batch size in joins. Enforcing batch 
size
+        /// in joins can reduce memory usage when joining large
+        /// tables with a highly-selective join filter, but is also slightly 
slower.
+        pub enforce_batch_size_in_joins: bool, default = false
     }
 }
 
@@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions {
     fn set(&mut self, key: &str, value: &str) -> Result<()> {
         // Extensions are handled in the public `ConfigOptions::set`
         let (key, rem) = key.split_once('.').unwrap_or((key, ""));
-        let Some(format) = &self.current_format else {
-            return _config_err!("Specify a format for TableOptions");
-        };
         match key {
-            "format" => match format {
-                #[cfg(feature = "parquet")]
-                ConfigFileType::PARQUET => self.parquet.set(rem, value),
-                ConfigFileType::CSV => self.csv.set(rem, value),
-                ConfigFileType::JSON => self.json.set(rem, value),
-            },
+            "format" => {
+                let Some(format) = &self.current_format else {
+                    return _config_err!("Specify a format for TableOptions");
+                };
+                match format {
+                    #[cfg(feature = "parquet")]
+                    ConfigFileType::PARQUET => self.parquet.set(rem, value),
+                    ConfigFileType::CSV => self.csv.set(rem, value),
+                    ConfigFileType::JSON => self.json.set(rem, value),
+                }
+            }
             _ => _config_err!("Config value \"{key}\" not found on 
TableOptions"),
         }
     }
diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs
index cede75d21c..53646dc5b4 100644
--- a/datafusion/execution/src/config.rs
+++ b/datafusion/execution/src/config.rs
@@ -432,6 +432,20 @@ impl SessionConfig {
         self
     }
 
+    /// Enables or disables the enforcement of batch size in joins
+    pub fn with_enforce_batch_size_in_joins(
+        mut self,
+        enforce_batch_size_in_joins: bool,
+    ) -> Self {
+        self.options.execution.enforce_batch_size_in_joins = 
enforce_batch_size_in_joins;
+        self
+    }
+
+    /// Returns true if the joins will be enforced to output batches of the 
configured size
+    pub fn enforce_batch_size_in_joins(&self) -> bool {
+        self.options.execution.enforce_batch_size_in_joins
+    }
+
     /// Convert configuration options to name-value pairs with values
     /// converted to strings.
     ///
diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index a70645f3d6..8f2bef56da 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -19,7 +19,8 @@
 //! and producing batches in parallel for the right partitions
 
 use super::utils::{
-    adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, 
OnceFut,
+    adjust_right_output_partitioning, BatchSplitter, BatchTransformer,
+    BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
     StatefulStreamResult,
 };
 use crate::coalesce_partitions::CoalescePartitionsExec;
@@ -86,6 +87,7 @@ impl CrossJoinExec {
 
         let schema = 
Arc::new(Schema::new(all_columns).with_metadata(metadata));
         let cache = Self::compute_properties(&left, &right, 
Arc::clone(&schema));
+
         CrossJoinExec {
             left,
             right,
@@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec {
         let reservation =
             
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
 
+        let batch_size = context.session_config().batch_size();
+        let enforce_batch_size_in_joins =
+            context.session_config().enforce_batch_size_in_joins();
+
         let left_fut = self.left_fut.once(|| {
             load_left_input(
                 Arc::clone(&self.left),
@@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec {
             )
         });
 
-        Ok(Box::pin(CrossJoinStream {
-            schema: Arc::clone(&self.schema),
-            left_fut,
-            right: stream,
-            left_index: 0,
-            join_metrics,
-            state: CrossJoinStreamState::WaitBuildSide,
-            left_data: RecordBatch::new_empty(self.left().schema()),
-        }))
+        if enforce_batch_size_in_joins {
+            Ok(Box::pin(CrossJoinStream {
+                schema: Arc::clone(&self.schema),
+                left_fut,
+                right: stream,
+                left_index: 0,
+                join_metrics,
+                state: CrossJoinStreamState::WaitBuildSide,
+                left_data: RecordBatch::new_empty(self.left().schema()),
+                batch_transformer: BatchSplitter::new(batch_size),
+            }))
+        } else {
+            Ok(Box::pin(CrossJoinStream {
+                schema: Arc::clone(&self.schema),
+                left_fut,
+                right: stream,
+                left_index: 0,
+                join_metrics,
+                state: CrossJoinStreamState::WaitBuildSide,
+                left_data: RecordBatch::new_empty(self.left().schema()),
+                batch_transformer: NoopBatchTransformer::new(),
+            }))
+        }
     }
 
     fn statistics(&self) -> Result<Statistics> {
@@ -319,7 +339,7 @@ fn stats_cartesian_product(
 }
 
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
-struct CrossJoinStream {
+struct CrossJoinStream<T> {
     /// Input schema
     schema: Arc<Schema>,
     /// Future for data from left side
@@ -334,9 +354,11 @@ struct CrossJoinStream {
     state: CrossJoinStreamState,
     /// Left data
     left_data: RecordBatch,
+    /// Batch transformer
+    batch_transformer: T,
 }
 
-impl RecordBatchStream for CrossJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for 
CrossJoinStream<T> {
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
@@ -390,7 +412,7 @@ fn build_batch(
 }
 
 #[async_trait]
-impl Stream for CrossJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
     type Item = Result<RecordBatch>;
 
     fn poll_next(
@@ -401,7 +423,7 @@ impl Stream for CrossJoinStream {
     }
 }
 
-impl CrossJoinStream {
+impl<T: BatchTransformer> CrossJoinStream<T> {
     /// Separate implementation function that unpins the [`CrossJoinStream`] so
     /// that partial borrows work correctly
     fn poll_next_impl(
@@ -470,21 +492,33 @@ impl CrossJoinStream {
     fn build_batches(&mut self) -> 
Result<StatefulStreamResult<Option<RecordBatch>>> {
         let right_batch = self.state.try_as_record_batch()?;
         if self.left_index < self.left_data.num_rows() {
-            let join_timer = self.join_metrics.join_time.timer();
-            let result =
-                build_batch(self.left_index, right_batch, &self.left_data, 
&self.schema);
-            join_timer.done();
-
-            if let Ok(ref batch) = result {
-                self.join_metrics.output_batches.add(1);
-                self.join_metrics.output_rows.add(batch.num_rows());
+            match self.batch_transformer.next() {
+                None => {
+                    let join_timer = self.join_metrics.join_time.timer();
+                    let result = build_batch(
+                        self.left_index,
+                        right_batch,
+                        &self.left_data,
+                        &self.schema,
+                    );
+                    join_timer.done();
+
+                    self.batch_transformer.set_batch(result?);
+                }
+                Some((batch, last)) => {
+                    if last {
+                        self.left_index += 1;
+                    }
+
+                    self.join_metrics.output_batches.add(1);
+                    self.join_metrics.output_rows.add(batch.num_rows());
+                    return Ok(StatefulStreamResult::Ready(Some(batch)));
+                }
             }
-            self.left_index += 1;
-            result.map(|r| StatefulStreamResult::Ready(Some(r)))
         } else {
             self.state = CrossJoinStreamState::FetchProbeBatch;
-            Ok(StatefulStreamResult::Continue)
         }
+        Ok(StatefulStreamResult::Continue)
     }
 }
 
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs
index 74a45a7e47..3b730c0129 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -1438,7 +1438,7 @@ impl HashJoinStream {
             index_alignment_range_start..index_alignment_range_end,
             self.join_type,
             self.right_side_ordered,
-        );
+        )?;
 
         let result = build_batch_from_indices(
             &self.schema,
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 6068e75263..358ff02473 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::task::Poll;
 
-use super::utils::{asymmetric_join_output_partitioning, 
need_produce_result_in_final};
+use super::utils::{
+    asymmetric_join_output_partitioning, need_produce_result_in_final, 
BatchSplitter,
+    BatchTransformer, NoopBatchTransformer, StatefulStreamResult,
+};
 use crate::coalesce_partitions::CoalescePartitionsExec;
 use crate::joins::utils::{
     adjust_indices_by_join_type, apply_join_filter_to_indices, 
build_batch_from_indices,
@@ -35,8 +38,8 @@ use crate::joins::utils::{
 };
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::{
-    execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution,
-    ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
+    execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType,
+    Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, 
PlanProperties,
     RecordBatchStream, SendableRecordBatchStream,
 };
 
@@ -45,7 +48,9 @@ use arrow::compute::concat_batches;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util;
-use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics};
+use datafusion_common::{
+    exec_datafusion_err, internal_err, JoinSide, Result, Statistics,
+};
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_expr::JoinType;
@@ -230,10 +235,11 @@ impl NestedLoopJoinExec {
             asymmetric_join_output_partitioning(left, right, &join_type);
 
         // Determine execution mode:
-        let mut mode = execution_mode_from_children([left, right]);
-        if mode.is_unbounded() {
-            mode = ExecutionMode::PipelineBreaking;
-        }
+        let mode = if left.execution_mode().is_unbounded() {
+            ExecutionMode::PipelineBreaking
+        } else {
+            execution_mode_from_children([left, right])
+        };
 
         PlanProperties::new(eq_properties, output_partitioning, mode)
     }
@@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
             )
         });
 
+        let batch_size = context.session_config().batch_size();
+        let enforce_batch_size_in_joins =
+            context.session_config().enforce_batch_size_in_joins();
+
         let outer_table = self.right.execute(partition, context)?;
 
         let indices_cache = (UInt64Array::new_null(0), 
UInt32Array::new_null(0));
@@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec {
         // Right side has an order and it is maintained during operation.
         let right_side_ordered =
             self.maintains_input_order()[1] && 
self.right.output_ordering().is_some();
-        Ok(Box::pin(NestedLoopJoinStream {
-            schema: Arc::clone(&self.schema),
-            filter: self.filter.clone(),
-            join_type: self.join_type,
-            outer_table,
-            inner_table,
-            is_exhausted: false,
-            column_indices: self.column_indices.clone(),
-            join_metrics,
-            indices_cache,
-            right_side_ordered,
-        }))
+
+        if enforce_batch_size_in_joins {
+            Ok(Box::pin(NestedLoopJoinStream {
+                schema: Arc::clone(&self.schema),
+                filter: self.filter.clone(),
+                join_type: self.join_type,
+                outer_table,
+                inner_table,
+                column_indices: self.column_indices.clone(),
+                join_metrics,
+                indices_cache,
+                right_side_ordered,
+                state: NestedLoopJoinStreamState::WaitBuildSide,
+                batch_transformer: BatchSplitter::new(batch_size),
+                left_data: None,
+            }))
+        } else {
+            Ok(Box::pin(NestedLoopJoinStream {
+                schema: Arc::clone(&self.schema),
+                filter: self.filter.clone(),
+                join_type: self.join_type,
+                outer_table,
+                inner_table,
+                column_indices: self.column_indices.clone(),
+                join_metrics,
+                indices_cache,
+                right_side_ordered,
+                state: NestedLoopJoinStreamState::WaitBuildSide,
+                batch_transformer: NoopBatchTransformer::new(),
+                left_data: None,
+            }))
+        }
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -442,8 +472,37 @@ async fn collect_left_input(
     ))
 }
 
+/// This enumeration represents various states of the nested loop join 
algorithm.
+#[derive(Debug, Clone)]
+enum NestedLoopJoinStreamState {
+    /// The initial state, indicating that build-side data not collected yet
+    WaitBuildSide,
+    /// Indicates that build-side has been collected, and stream is ready for
+    /// fetching probe-side
+    FetchProbeBatch,
+    /// Indicates that a non-empty batch has been fetched from probe-side, and
+    /// is ready to be processed
+    ProcessProbeBatch(RecordBatch),
+    /// Indicates that probe-side has been fully processed
+    ExhaustedProbeSide,
+    /// Indicates that NestedLoopJoinStream execution is completed
+    Completed,
+}
+
+impl NestedLoopJoinStreamState {
+    /// Tries to extract a `ProcessProbeBatchState` from the
+    /// `NestedLoopJoinStreamState` enum. Returns an error if state is not
+    /// `ProcessProbeBatchState`.
+    fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> {
+        match self {
+            NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state),
+            _ => internal_err!("Expected join stream in ProcessProbeBatch 
state"),
+        }
+    }
+}
+
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
-struct NestedLoopJoinStream {
+struct NestedLoopJoinStream<T> {
     /// Input schema
     schema: Arc<Schema>,
     /// join filter
@@ -454,8 +513,6 @@ struct NestedLoopJoinStream {
     outer_table: SendableRecordBatchStream,
     /// the inner table data of the nested loop join
     inner_table: OnceFut<JoinLeftData>,
-    /// There is nothing to process anymore and left side is processed in case 
of full join
-    is_exhausted: bool,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
     // TODO: support null aware equal
@@ -466,6 +523,12 @@ struct NestedLoopJoinStream {
     indices_cache: (UInt64Array, UInt32Array),
     /// Whether the right side is ordered
     right_side_ordered: bool,
+    /// Current state of the stream
+    state: NestedLoopJoinStreamState,
+    /// Transforms the output batch before returning.
+    batch_transformer: T,
+    /// Result of the left data future
+    left_data: Option<Arc<JoinLeftData>>,
 }
 
 /// Creates a Cartesian product of two input batches, preserving the order of 
the right batch,
@@ -544,107 +607,164 @@ fn build_join_indices(
     }
 }
 
-impl NestedLoopJoinStream {
+impl<T: BatchTransformer> NestedLoopJoinStream<T> {
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
-        // all left row
+        loop {
+            return match self.state {
+                NestedLoopJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                NestedLoopJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                NestedLoopJoinStreamState::ProcessProbeBatch(_) => {
+                    handle_state!(self.process_probe_batch())
+                }
+                NestedLoopJoinStreamState::ExhaustedProbeSide => {
+                    handle_state!(self.process_unmatched_build_batch())
+                }
+                NestedLoopJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
-        let left_data = match ready!(self.inner_table.get_shared(cx)) {
-            Ok(data) => data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
-        };
+        // build hash table from left (build) side, if not yet done
+        self.left_data = Some(ready!(self.inner_table.get_shared(cx))?);
         build_timer.done();
 
-        // Get or initialize visited_left_side bitmap if required by join type
+        self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Fetches next batch from probe-side
+    ///
+    /// If a non-empty batch has been fetched, updates state to
+    /// `ProcessProbeBatchState`, otherwise updates state to 
`ExhaustedProbeSide`.
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.outer_table.poll_next_unpin(cx)) {
+            None => {
+                self.state = NestedLoopJoinStreamState::ExhaustedProbeSide;
+            }
+            Some(Ok(right_batch)) => {
+                self.state = 
NestedLoopJoinStreamState::ProcessProbeBatch(right_batch);
+            }
+            Some(Err(err)) => return Poll::Ready(Err(err)),
+        };
+
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
+    /// Joins current probe batch with build-side data and produces batch with
+    /// matched output, updates state to `FetchProbeBatch`.
+    fn process_probe_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let Some(left_data) = self.left_data.clone() else {
+            return internal_err!(
+                "Expected left_data to be Some in ProcessProbeBatch state"
+            );
+        };
         let visited_left_side = left_data.bitmap();
+        let batch = self.state.try_as_process_probe_batch()?;
+
+        match self.batch_transformer.next() {
+            None => {
+                // Setting up timer & updating input metrics
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(batch.num_rows());
+                let timer = self.join_metrics.join_time.timer();
+
+                let result = join_left_and_right_batch(
+                    left_data.batch(),
+                    batch,
+                    self.join_type,
+                    self.filter.as_ref(),
+                    &self.column_indices,
+                    &self.schema,
+                    visited_left_side,
+                    &mut self.indices_cache,
+                    self.right_side_ordered,
+                );
+                timer.done();
+
+                self.batch_transformer.set_batch(result?);
+                Ok(StatefulStreamResult::Continue)
+            }
+            Some((batch, last)) => {
+                if last {
+                    self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+                }
 
-        // Check is_exhausted before polling the outer_table, such that when 
the outer table
-        // does not support `FusedStream`, Self will not poll it again
-        if self.is_exhausted {
-            return Poll::Ready(None);
+                self.join_metrics.output_batches.add(1);
+                self.join_metrics.output_rows.add(batch.num_rows());
+                Ok(StatefulStreamResult::Ready(Some(batch)))
+            }
         }
+    }
 
-        self.outer_table
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(right_batch)) => {
-                    // Setting up timer & updating input metrics
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(right_batch.num_rows());
-                    let timer = self.join_metrics.join_time.timer();
-
-                    let result = join_left_and_right_batch(
-                        left_data.batch(),
-                        &right_batch,
-                        self.join_type,
-                        self.filter.as_ref(),
-                        &self.column_indices,
-                        &self.schema,
-                        visited_left_side,
-                        &mut self.indices_cache,
-                        self.right_side_ordered,
-                    );
-
-                    // Recording time & updating output metrics
-                    if let Ok(batch) = &result {
-                        timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
-                    }
-
-                    Some(result)
-                }
-                Some(err) => Some(err),
-                None => {
-                    if need_produce_result_in_final(self.join_type) {
-                        // At this stage `visited_left_side` won't be updated, 
so it's
-                        // safe to report about probe completion.
-                        //
-                        // Setting `is_exhausted` / returning None will 
prevent from
-                        // multiple calls of `report_probe_completed()`
-                        if !left_data.report_probe_completed() {
-                            self.is_exhausted = true;
-                            return None;
-                        };
-
-                        // Only setting up timer, input is exhausted
-                        let timer = self.join_metrics.join_time.timer();
-                        // use the global left bitmap to produce the left 
indices and right indices
-                        let (left_side, right_side) =
-                            get_final_indices_from_shared_bitmap(
-                                visited_left_side,
-                                self.join_type,
-                            );
-                        let empty_right_batch =
-                            RecordBatch::new_empty(self.outer_table.schema());
-                        // use the left and right indices to produce the batch 
result
-                        let result = build_batch_from_indices(
-                            &self.schema,
-                            left_data.batch(),
-                            &empty_right_batch,
-                            &left_side,
-                            &right_side,
-                            &self.column_indices,
-                            JoinSide::Left,
-                        );
-                        self.is_exhausted = true;
-
-                        // Recording time & updating output metrics
-                        if let Ok(batch) = &result {
-                            timer.done();
-                            self.join_metrics.output_batches.add(1);
-                            
self.join_metrics.output_rows.add(batch.num_rows());
-                        }
-
-                        Some(result)
-                    } else {
-                        // end of the join loop
-                        None
-                    }
-                }
-            })
+    /// Processes unmatched build-side rows for certain join types and produces
+    /// output batch, updates state to `Completed`.
+    fn process_unmatched_build_batch(
+        &mut self,
+    ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let Some(left_data) = self.left_data.clone() else {
+            return internal_err!(
+                "Expected left_data to be Some in ExhaustedProbeSide state"
+            );
+        };
+        let visited_left_side = left_data.bitmap();
+        if need_produce_result_in_final(self.join_type) {
+            // At this stage `visited_left_side` won't be updated, so it's
+            // safe to report about probe completion.
+            //
+            // Setting `is_exhausted` / returning None will prevent from
+            // multiple calls of `report_probe_completed()`
+            if !left_data.report_probe_completed() {
+                self.state = NestedLoopJoinStreamState::Completed;
+                return Ok(StatefulStreamResult::Ready(None));
+            };
+
+            // Only setting up timer, input is exhausted
+            let timer = self.join_metrics.join_time.timer();
+            // use the global left bitmap to produce the left indices and 
right indices
+            let (left_side, right_side) =
+                get_final_indices_from_shared_bitmap(visited_left_side, 
self.join_type);
+            let empty_right_batch = 
RecordBatch::new_empty(self.outer_table.schema());
+            // use the left and right indices to produce the batch result
+            let result = build_batch_from_indices(
+                &self.schema,
+                left_data.batch(),
+                &empty_right_batch,
+                &left_side,
+                &right_side,
+                &self.column_indices,
+                JoinSide::Left,
+            );
+            self.state = NestedLoopJoinStreamState::Completed;
+
+            // Recording time
+            if result.is_ok() {
+                timer.done();
+            }
+
+            Ok(StatefulStreamResult::Ready(Some(result?)))
+        } else {
+            // end of the join loop
+            self.state = NestedLoopJoinStreamState::Completed;
+            Ok(StatefulStreamResult::Ready(None))
+        }
     }
 }
 
@@ -684,7 +804,7 @@ fn join_left_and_right_batch(
         0..right_batch.num_rows(),
         join_type,
         right_side_ordered,
-    );
+    )?;
 
     build_batch_from_indices(
         schema,
@@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap(
     get_final_indices_from_bit_map(&bitmap, join_type)
 }
 
-impl Stream for NestedLoopJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> {
     type Item = Result<RecordBatch>;
 
     fn poll_next(
@@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream {
     }
 }
 
-impl RecordBatchStream for NestedLoopJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for 
NestedLoopJoinStream<T> {
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
 }
 
 #[cfg(test)]
-mod tests {
+pub(crate) mod tests {
     use super::*;
     use crate::{
         common, expressions::Column, memory::MemoryExec, 
repartition::RepartitionExec,
@@ -850,7 +970,7 @@ mod tests {
         JoinFilter::new(filter_expression, column_indices, intermediate_schema)
     }
 
-    async fn multi_partitioned_join_collect(
+    pub(crate) async fn multi_partitioned_join_collect(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
         join_type: &JoinType,
diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs
index ba9384aef1..bddd152341 100644
--- a/datafusion/physical-plan/src/joins/stream_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -31,8 +31,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
 use arrow_schema::{Schema, SchemaRef};
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{
-    arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, 
Result,
-    ScalarValue,
+    arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue,
 };
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_physical_expr::expressions::Column;
@@ -369,34 +368,40 @@ impl SortedFilterExpr {
         filter_expr: Arc<dyn PhysicalExpr>,
         filter_schema: &Schema,
     ) -> Result<Self> {
-        let dt = &filter_expr.data_type(filter_schema)?;
+        let dt = filter_expr.data_type(filter_schema)?;
         Ok(Self {
             origin_sorted_expr,
             filter_expr,
-            interval: Interval::make_unbounded(dt)?,
+            interval: Interval::make_unbounded(&dt)?,
             node_index: 0,
         })
     }
+
     /// Get origin expr information
     pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
         &self.origin_sorted_expr
     }
+
     /// Get filter expr information
     pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
         &self.filter_expr
     }
+
     /// Get interval information
     pub fn interval(&self) -> &Interval {
         &self.interval
     }
+
     /// Sets interval
     pub fn set_interval(&mut self, interval: Interval) {
         self.interval = interval;
     }
+
     /// Node index in ExprIntervalGraph
     pub fn node_index(&self) -> usize {
         self.node_index
     }
+
     /// Node index setter in ExprIntervalGraph
     pub fn set_node_index(&mut self, node_index: usize) {
         self.node_index = node_index;
@@ -409,41 +414,45 @@ impl SortedFilterExpr {
 /// on the first or the last value of the expression in `build_input_buffer`
 /// and `probe_batch`.
 ///
-/// # Arguments
+/// # Parameters
 ///
 /// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
 /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
 /// * `probe_batch` - The `RecordBatch` on the probe side of the join.
 /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
 ///
-/// ### Note
-/// ```text
+/// ## Note
 ///
-/// Interval arithmetic is used to calculate viable join ranges for build-side
-/// pruning. This is done by first creating an interval for join filter values 
in
-/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on 
the
-/// ordering (descending/ascending) of the filter expression. Here, FV denotes 
the
-/// first value on the build side. This range is then compared with the probe 
side
-/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering
-/// (ascending/descending) of the probe side. Here, LV denotes the last value 
on
-/// the probe side.
+/// Utilizing interval arithmetic, this function computes feasible join intervals
+/// on the pruning (build) side by evaluating the prospective value ranges that
+/// might emerge in subsequent data batches from the enforcer (probe) side. This
+/// is done by first creating an interval for join filter values in the pruning
+/// side of the join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering
+/// (descending/ascending) of the filter expression. Here, `FV` denotes the first
+/// value on the pruning side. This range is then compared with the enforcer side
+/// interval, which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering
+/// (ascending/descending) of the enforcer side. Here, `LV` denotes the last value
+/// on the enforcer side.
 ///
 /// As a concrete example, consider the following query:
 ///
+/// ```text
 ///   SELECT * FROM left_table, right_table
 ///   WHERE
 ///     left_key = right_key AND
 ///     a > b - 3 AND
 ///     a < b + 10
+/// ```
 ///
-/// where columns "a" and "b" come from tables "left_table" and "right_table",
+/// where columns `a` and `b` come from tables `left_table` and `right_table`,
 /// respectively. When a new `RecordBatch` arrives at the right side, the
-/// condition a > b - 3 will possibly indicate a prunable range for the left
+/// condition `a > b - 3` will possibly indicate a prunable range for the left
 /// side. Conversely, when a new `RecordBatch` arrives at the left side, the
-/// condition a < b + 10 will possibly indicate prunability for the right side.
-/// Let’s inspect what happens when a new RecordBatch` arrives at the right
+/// condition `a < b + 10` will possibly indicate prunability for the right side.
+/// Let’s inspect what happens when a new `RecordBatch` arrives at the right
 /// side (i.e. when the left side is the build side):
 ///
+/// ```text
 ///         Build      Probe
 ///       +-------+  +-------+
 ///       | a | z |  | b | y |
@@ -456,13 +465,13 @@ impl SortedFilterExpr {
 ///       |+--|--+|  |+--|--+|
 ///       | 7 | 1 |  | 6 | 3 |
 ///       +-------+  +-------+
+/// ```
 ///
 /// In this case, the interval representing viable (i.e. joinable) values for
-/// column "a" is [1, ∞], and the interval representing possible future values
-/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
+/// column `a` is `[1, ∞]`, and the interval representing possible future values
+/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate
 /// intervals for the whole filter expression and propagate join constraint by
 /// traversing the expression graph.
-/// ```
 pub fn calculate_filter_expr_intervals(
     build_input_buffer: &RecordBatch,
     build_sorted_filter_expr: &mut SortedFilterExpr,
@@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices(
     }
 }
 
-/// Prepares and sorts expressions based on a given filter, left and right 
execution plans, and sort expressions.
+/// Prepares and sorts expressions based on a given filter, left and right
+/// execution plans, and sort expressions.
 ///
-/// # Arguments
+/// This function prepares sorted filter expressions for both the left and right
+/// sides of a join operation. It first builds the filter order for each side
+/// based on the provided `ExecutionPlan`; if either side's sort order is not
+/// covered by the filter, a planning error is returned. Otherwise, it constructs
+/// an expression interval graph and updates the sorted expressions with node
+/// indices. The final sorted filter expressions for both sides are then returned.
+///
+/// # Parameters
 ///
 /// * `filter` - The join filter to base the sorting on.
-/// * `left` - The left execution plan.
-/// * `right` - The right execution plan.
+/// * `left` - The `ExecutionPlan` for the left side of the join.
+/// * `right` - The `ExecutionPlan` for the right side of the join.
 /// * `left_sort_exprs` - The expressions to sort on the left side.
 /// * `right_sort_exprs` - The expressions to sort on the right side.
 ///
@@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs(
     left_sort_exprs: &[PhysicalSortExpr],
     right_sort_exprs: &[PhysicalSortExpr],
 ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
-    // Build the filter order for the left side
-    let err = || plan_datafusion_err!("Filter does not include the child 
order");
+    let err = || {
+        datafusion_common::plan_datafusion_err!("Filter does not include the 
child order")
+    };
 
+    // Build the filter order for the left side:
     let left_temp_sorted_filter_expr = build_filter_input_order(
         JoinSide::Left,
         filter,
@@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs(
     )?
     .ok_or_else(err)?;
 
-    // Build the filter order for the right side
+    // Build the filter order for the right side:
     let right_temp_sorted_filter_expr = build_filter_input_order(
         JoinSide::Right,
         filter,
@@ -952,15 +971,15 @@ pub mod tests {
         let filter_expr = complicated_filter(&intermediate_schema)?;
         let column_indices = vec![
             ColumnIndex {
-                index: 0,
+                index: left_schema.index_of("la1")?,
                 side: JoinSide::Left,
             },
             ColumnIndex {
-                index: 4,
+                index: left_schema.index_of("la2")?,
                 side: JoinSide::Left,
             },
             ColumnIndex {
-                index: 0,
+                index: right_schema.index_of("ra1")?,
                 side: JoinSide::Right,
             },
         ];
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index ac718a95e9..70ada3892a 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -32,7 +32,6 @@ use std::task::{Context, Poll};
 use std::vec;
 
 use crate::common::SharedMemoryReservation;
-use crate::handle_state;
 use crate::joins::hash_join::{equal_rows_arr, update_hash};
 use crate::joins::stream_join_utils::{
     calculate_filter_expr_intervals, combine_two_batches,
@@ -42,8 +41,9 @@ use crate::joins::stream_join_utils::{
 };
 use crate::joins::utils::{
     apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
-    check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, 
JoinFilter,
-    JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult,
+    check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter,
+    BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, 
JoinOnRef,
+    NoopBatchTransformer, StatefulStreamResult,
 };
 use crate::{
     execution_mode_from_children,
@@ -465,23 +465,27 @@ impl ExecutionPlan for SymmetricHashJoinExec {
                  consider using RepartitionExec"
             );
         }
-        // If `filter_state` and `filter` are both present, then calculate 
sorted filter expressions
-        // for both sides, and build an expression graph.
-        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
-            match (&self.left_sort_exprs, &self.right_sort_exprs, 
&self.filter) {
-                (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) 
=> {
-                    let (left, right, graph) = prepare_sorted_exprs(
-                        filter,
-                        &self.left,
-                        &self.right,
-                        left_sort_exprs,
-                        right_sort_exprs,
-                    )?;
-                    (Some(left), Some(right), Some(graph))
-                }
-                // If `filter_state` or `filter` is not present, then return 
None for all three values:
-                _ => (None, None, None),
-            };
+        // If both sides' sort expressions and `filter` are present, calculate
+        // sorted filter expressions for both sides and build an expression graph.
+        let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match 
(
+            self.left_sort_exprs(),
+            self.right_sort_exprs(),
+            &self.filter,
+        ) {
+            (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
+                let (left, right, graph) = prepare_sorted_exprs(
+                    filter,
+                    &self.left,
+                    &self.right,
+                    left_sort_exprs,
+                    right_sort_exprs,
+                )?;
+                (Some(left), Some(right), Some(graph))
+            }
+            // If either side's sort expressions or `filter` is missing, then
+            // return None for all three values:
+            _ => (None, None, None),
+        };
 
         let (on_left, on_right) = self.on.iter().cloned().unzip();
 
@@ -494,6 +498,10 @@ impl ExecutionPlan for SymmetricHashJoinExec {
 
         let right_stream = self.right.execute(partition, 
Arc::clone(&context))?;
 
+        let batch_size = context.session_config().batch_size();
+        let enforce_batch_size_in_joins =
+            context.session_config().enforce_batch_size_in_joins();
+
         let reservation = Arc::new(Mutex::new(
             
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
                 .register(context.memory_pool()),
@@ -502,29 +510,52 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             reservation.lock().try_grow(g.size())?;
         }
 
-        Ok(Box::pin(SymmetricHashJoinStream {
-            left_stream,
-            right_stream,
-            schema: self.schema(),
-            filter: self.filter.clone(),
-            join_type: self.join_type,
-            random_state: self.random_state.clone(),
-            left: left_side_joiner,
-            right: right_side_joiner,
-            column_indices: self.column_indices.clone(),
-            metrics: StreamJoinMetrics::new(partition, &self.metrics),
-            graph,
-            left_sorted_filter_expr,
-            right_sorted_filter_expr,
-            null_equals_null: self.null_equals_null,
-            state: SHJStreamState::PullRight,
-            reservation,
-        }))
+        if enforce_batch_size_in_joins {
+            Ok(Box::pin(SymmetricHashJoinStream {
+                left_stream,
+                right_stream,
+                schema: self.schema(),
+                filter: self.filter.clone(),
+                join_type: self.join_type,
+                random_state: self.random_state.clone(),
+                left: left_side_joiner,
+                right: right_side_joiner,
+                column_indices: self.column_indices.clone(),
+                metrics: StreamJoinMetrics::new(partition, &self.metrics),
+                graph,
+                left_sorted_filter_expr,
+                right_sorted_filter_expr,
+                null_equals_null: self.null_equals_null,
+                state: SHJStreamState::PullRight,
+                reservation,
+                batch_transformer: BatchSplitter::new(batch_size),
+            }))
+        } else {
+            Ok(Box::pin(SymmetricHashJoinStream {
+                left_stream,
+                right_stream,
+                schema: self.schema(),
+                filter: self.filter.clone(),
+                join_type: self.join_type,
+                random_state: self.random_state.clone(),
+                left: left_side_joiner,
+                right: right_side_joiner,
+                column_indices: self.column_indices.clone(),
+                metrics: StreamJoinMetrics::new(partition, &self.metrics),
+                graph,
+                left_sorted_filter_expr,
+                right_sorted_filter_expr,
+                null_equals_null: self.null_equals_null,
+                state: SHJStreamState::PullRight,
+                reservation,
+                batch_transformer: NoopBatchTransformer::new(),
+            }))
+        }
     }
 }
 
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
-struct SymmetricHashJoinStream {
+struct SymmetricHashJoinStream<T> {
     /// Input streams
     left_stream: SendableRecordBatchStream,
     right_stream: SendableRecordBatchStream,
@@ -556,15 +587,19 @@ struct SymmetricHashJoinStream {
     reservation: SharedMemoryReservation,
     /// State machine for input execution
     state: SHJStreamState,
+    /// Transforms the output batch before returning.
+    batch_transformer: T,
 }
 
-impl RecordBatchStream for SymmetricHashJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
+    for SymmetricHashJoinStream<T>
+{
     fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
 }
 
-impl Stream for SymmetricHashJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T> 
{
     type Item = Result<RecordBatch>;
 
     fn poll_next(
@@ -1140,7 +1175,7 @@ impl OneSideHashJoiner {
 /// - Transition to `BothExhausted { final_result: true }`:
 ///   - Occurs in `prepare_for_final_results_after_exhaustion` when both 
streams are
 ///     exhausted, indicating completion of processing and availability of 
final results.
-impl SymmetricHashJoinStream {
+impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
     /// Implements the main polling logic for the join stream.
     ///
     /// This method continuously checks the state of the join stream and
@@ -1159,26 +1194,45 @@ impl SymmetricHashJoinStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         loop {
-            return match self.state() {
-                SHJStreamState::PullRight => {
-                    
handle_state!(ready!(self.fetch_next_from_right_stream(cx)))
-                }
-                SHJStreamState::PullLeft => {
-                    handle_state!(ready!(self.fetch_next_from_left_stream(cx)))
+            match self.batch_transformer.next() {
+                None => {
+                    let result = match self.state() {
+                        SHJStreamState::PullRight => {
+                            ready!(self.fetch_next_from_right_stream(cx))
+                        }
+                        SHJStreamState::PullLeft => {
+                            ready!(self.fetch_next_from_left_stream(cx))
+                        }
+                        SHJStreamState::RightExhausted => {
+                            ready!(self.handle_right_stream_end(cx))
+                        }
+                        SHJStreamState::LeftExhausted => {
+                            ready!(self.handle_left_stream_end(cx))
+                        }
+                        SHJStreamState::BothExhausted {
+                            final_result: false,
+                        } => self.prepare_for_final_results_after_exhaustion(),
+                        SHJStreamState::BothExhausted { final_result: true } 
=> {
+                            return Poll::Ready(None);
+                        }
+                    };
+
+                    match result? {
+                        StatefulStreamResult::Ready(None) => {
+                            return Poll::Ready(None);
+                        }
+                        StatefulStreamResult::Ready(Some(batch)) => {
+                            self.batch_transformer.set_batch(batch);
+                        }
+                        _ => {}
+                    }
                 }
-                SHJStreamState::RightExhausted => {
-                    handle_state!(ready!(self.handle_right_stream_end(cx)))
-                }
-                SHJStreamState::LeftExhausted => {
-                    handle_state!(ready!(self.handle_left_stream_end(cx)))
-                }
-                SHJStreamState::BothExhausted {
-                    final_result: false,
-                } => {
-                    
handle_state!(self.prepare_for_final_results_after_exhaustion())
+                Some((batch, _)) => {
+                    self.metrics.output_batches.add(1);
+                    self.metrics.output_rows.add(batch.num_rows());
+                    return Poll::Ready(Some(Ok(batch)));
                 }
-                SHJStreamState::BothExhausted { final_result: true } => 
Poll::Ready(None),
-            };
+            }
         }
     }
     /// Asynchronously pulls the next batch from the right stream.
@@ -1384,11 +1438,8 @@ impl SymmetricHashJoinStream {
         // Combine the left and right results:
         let result = combine_two_batches(&self.schema, left_result, 
right_result)?;
 
-        // Update the metrics and return the result:
-        if let Some(batch) = &result {
-            // Update the metrics:
-            self.metrics.output_batches.add(1);
-            self.metrics.output_rows.add(batch.num_rows());
+        // Return the result:
+        if result.is_some() {
             return Ok(StatefulStreamResult::Ready(result));
         }
         Ok(StatefulStreamResult::Continue)
@@ -1523,11 +1574,6 @@ impl SymmetricHashJoinStream {
         let capacity = self.size();
         self.metrics.stream_memory_usage.set(capacity);
         self.reservation.lock().try_resize(capacity)?;
-        // Update the metrics if we have a batch; otherwise, continue the loop.
-        if let Some(batch) = &result {
-            self.metrics.output_batches.add(1);
-            self.metrics.output_rows.add(batch.num_rows());
-        }
         Ok(result)
     }
 }
@@ -1716,15 +1762,15 @@ mod tests {
         let filter_expr = complicated_filter(&intermediate_schema)?;
         let column_indices = vec![
             ColumnIndex {
-                index: 0,
+                index: left_schema.index_of("la1")?,
                 side: JoinSide::Left,
             },
             ColumnIndex {
-                index: 4,
+                index: left_schema.index_of("la2")?,
                 side: JoinSide::Left,
             },
             ColumnIndex {
-                index: 0,
+                index: right_schema.index_of("ra1")?,
                 side: JoinSide::Right,
             },
         ];
@@ -1771,10 +1817,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -1825,10 +1868,7 @@ mod tests {
         let (left, right) =
             create_memory_table(left_partition, right_partition, vec![], 
vec![])?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -1877,10 +1917,7 @@ mod tests {
         let (left, right) =
             create_memory_table(left_partition, right_partition, vec![], 
vec![])?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
         experiment(left, right, None, join_type, on, task_ctx).await?;
         Ok(())
     }
@@ -1926,10 +1963,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -1987,10 +2021,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -2048,10 +2079,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -2111,10 +2139,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Int32, true),
@@ -2170,10 +2195,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("0", DataType::Int32, true),
@@ -2237,10 +2259,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("0", DataType::Int32, true),
@@ -2296,10 +2315,7 @@ mod tests {
 
         let left_schema = &left_partition[0].schema();
         let right_schema = &right_partition[0].schema();
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
         let left_sorted = vec![PhysicalSortExpr {
             expr: col("lt1", left_schema)?,
             options: SortOptions {
@@ -2380,10 +2396,7 @@ mod tests {
 
         let left_schema = &left_partition[0].schema();
         let right_schema = &right_partition[0].schema();
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
         let left_sorted = vec![PhysicalSortExpr {
             expr: col("li1", left_schema)?,
             options: SortOptions {
@@ -2473,10 +2486,7 @@ mod tests {
             vec![right_sorted],
         )?;
 
-        let on = vec![(
-            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
-            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
-        )];
+        let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
 
         let intermediate_schema = Schema::new(vec![
             Field::new("left", DataType::Float64, true),
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index 89f3feaf07..c520e42714 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -546,15 +546,16 @@ pub struct ColumnIndex {
     pub side: JoinSide,
 }
 
-/// Filter applied before join output
+/// Filter applied before join output. Fields are `pub(crate)` so that join
+/// implementations elsewhere in this crate can access them directly.
 #[derive(Debug, Clone)]
 pub struct JoinFilter {
     /// Filter expression
-    expression: Arc<dyn PhysicalExpr>,
+    pub(crate) expression: Arc<dyn PhysicalExpr>,
     /// Column indices required to construct intermediate batch for filtering
-    column_indices: Vec<ColumnIndex>,
+    pub(crate) column_indices: Vec<ColumnIndex>,
     /// Physical schema of intermediate batch
-    schema: Schema,
+    pub(crate) schema: Schema,
 }
 
 impl JoinFilter {
@@ -1280,15 +1281,15 @@ pub(crate) fn adjust_indices_by_join_type(
     adjust_range: Range<usize>,
     join_type: JoinType,
     preserve_order_for_right: bool,
-) -> (UInt64Array, UInt32Array) {
+) -> Result<(UInt64Array, UInt32Array)> {
     match join_type {
         JoinType::Inner => {
             // matched
-            (left_indices, right_indices)
+            Ok((left_indices, right_indices))
         }
         JoinType::Left => {
             // matched
-            (left_indices, right_indices)
+            Ok((left_indices, right_indices))
             // unmatched left row will be produced in the end of loop, and it 
has been set in the left visited bitmap
         }
         JoinType::Right => {
@@ -1307,22 +1308,22 @@ pub(crate) fn adjust_indices_by_join_type(
             // need to remove the duplicated record in the right side
             let right_indices = get_semi_indices(adjust_range, &right_indices);
             // the left_indices will not be used later for the `right semi` 
join
-            (left_indices, right_indices)
+            Ok((left_indices, right_indices))
         }
         JoinType::RightAnti => {
             // need to remove the duplicated record in the right side
             // get the anti index for the right side
             let right_indices = get_anti_indices(adjust_range, &right_indices);
             // the left_indices will not be used later for the `right anti` 
join
-            (left_indices, right_indices)
+            Ok((left_indices, right_indices))
         }
         JoinType::LeftSemi | JoinType::LeftAnti => {
             // matched or unmatched left row will be produced in the end of 
loop
             // When visit the right batch, we can output the matched left row 
and don't need to wait the end of loop
-            (
+            Ok((
                 UInt64Array::from_iter_values(vec![]),
                 UInt32Array::from_iter_values(vec![]),
-            )
+            ))
         }
     }
 }
@@ -1347,27 +1348,64 @@ pub(crate) fn append_right_indices(
     right_indices: UInt32Array,
     adjust_range: Range<usize>,
     preserve_order_for_right: bool,
-) -> (UInt64Array, UInt32Array) {
+) -> Result<(UInt64Array, UInt32Array)> {
     if preserve_order_for_right {
-        append_probe_indices_in_order(left_indices, right_indices, 
adjust_range)
+        Ok(append_probe_indices_in_order(
+            left_indices,
+            right_indices,
+            adjust_range,
+        ))
     } else {
         let right_unmatched_indices = get_anti_indices(adjust_range, 
&right_indices);
 
         if right_unmatched_indices.is_empty() {
-            (left_indices, right_indices)
+            Ok((left_indices, right_indices))
         } else {
-            let unmatched_size = right_unmatched_indices.len();
+            // `into_builder()` fails when the underlying buffer is shared, e.g. when
+            // nothing was filtered and `left_indices`/`right_indices` still reference
+            // the cached indices. In that case, we fall back to a slower copy.
+
             // the new left indices: left_indices + null array
+            let mut new_left_indices_builder =
+                left_indices.into_builder().unwrap_or_else(|left_indices| {
+                    let mut builder = UInt64Builder::with_capacity(
+                        left_indices.len() + right_unmatched_indices.len(),
+                    );
+                    debug_assert_eq!(
+                        left_indices.null_count(),
+                        0,
+                        "expected left indices to have no nulls"
+                    );
+                    builder.append_slice(left_indices.values());
+                    builder
+                });
+            
new_left_indices_builder.append_nulls(right_unmatched_indices.len());
+            let new_left_indices = 
UInt64Array::from(new_left_indices_builder.finish());
+
             // the new right indices: right_indices + right_unmatched_indices
-            let new_left_indices = left_indices
-                .iter()
-                .chain(std::iter::repeat(None).take(unmatched_size))
-                .collect();
-            let new_right_indices = right_indices
-                .iter()
-                .chain(right_unmatched_indices.iter())
-                .collect();
-            (new_left_indices, new_right_indices)
+            let mut new_right_indices_builder = right_indices
+                .into_builder()
+                .unwrap_or_else(|right_indices| {
+                    let mut builder = UInt32Builder::with_capacity(
+                        right_indices.len() + right_unmatched_indices.len(),
+                    );
+                    debug_assert_eq!(
+                        right_indices.null_count(),
+                        0,
+                        "expected right indices to have no nulls"
+                    );
+                    builder.append_slice(right_indices.values());
+                    builder
+                });
+            debug_assert_eq!(
+                right_unmatched_indices.null_count(),
+                0,
+                "expected right unmatched indices to have no nulls"
+            );
+            
new_right_indices_builder.append_slice(right_unmatched_indices.values());
+            let new_right_indices = 
UInt32Array::from(new_right_indices_builder.finish());
+
+            Ok((new_left_indices, new_right_indices))
         }
     }
 }
@@ -1635,6 +1673,91 @@ pub(crate) fn asymmetric_join_output_partitioning(
     }
 }
 
+/// Emits join output incrementally instead of as one potentially huge batch.
+///
+/// A caller hands over a [`RecordBatch`] with [`set_batch`](Self::set_batch)
+/// and then repeatedly calls [`next`](Self::next) until it yields `None`.
+/// This lets some joins bound the size of the batches they produce.
+pub(crate) trait BatchTransformer: Debug + Clone {
+    /// Installs `batch` as the input for subsequent [`next`](Self::next) calls.
+    fn set_batch(&mut self, batch: RecordBatch);
+
+    /// Produces the next output batch, or `None` once everything derived from
+    /// the current input has been emitted.
+    ///
+    /// The `bool` in the returned pair is `true` for the final output batch.
+    fn next(&mut self) -> Option<(RecordBatch, bool)>;
+}
+
+/// A [`BatchTransformer`] that forwards each input batch unchanged.
+///
+/// Whatever was passed to `set_batch` is returned whole by the next `next`
+/// call, flagged as the last (and only) output batch.
+#[derive(Debug, Clone)]
+pub(crate) struct NoopBatchTransformer {
+    /// Batch waiting to be handed back, if any
+    batch: Option<RecordBatch>,
+}
+
+impl NoopBatchTransformer {
+    /// Creates a transformer with nothing pending.
+    pub fn new() -> Self {
+        NoopBatchTransformer { batch: None }
+    }
+}
+
+impl BatchTransformer for NoopBatchTransformer {
+    fn set_batch(&mut self, batch: RecordBatch) {
+        self.batch = Some(batch);
+    }
+
+    fn next(&mut self) -> Option<(RecordBatch, bool)> {
+        // Moving the batch out leaves `None` behind, so the following call
+        // reports exhaustion.
+        let pending = self.batch.take()?;
+        Some((pending, true))
+    }
+}
+
+/// Splits large batches into smaller batches with a maximum number of rows.
+///
+/// Each output batch is a [`RecordBatch::slice`] of the batch passed to
+/// [`BatchTransformer::set_batch`]. Output preserves row order. Every output
+/// batch holds `batch_size` rows except possibly the final one. An input batch
+/// with zero rows is emitted once, as a single empty batch flagged as last.
+/// Calling `set_batch` again discards any rows not yet emitted.
+#[derive(Debug, Clone)]
+pub(crate) struct BatchSplitter {
+    /// RecordBatch being split; `None` once it has been fully emitted
+    batch: Option<RecordBatch>,
+    /// Maximum number of rows in a split batch (always greater than zero)
+    batch_size: usize,
+    /// Index of the first row of `batch` that has not been emitted yet
+    row_index: usize,
+}
+
+impl BatchSplitter {
+    /// Creates a new `BatchSplitter` with the specified batch size.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `batch_size` is zero. With a zero-row limit `next` could never
+    /// advance through a non-empty batch and would return empty slices forever.
+    pub(crate) fn new(batch_size: usize) -> Self {
+        assert!(batch_size > 0, "BatchSplitter batch_size must be greater than 0");
+        Self {
+            batch: None,
+            batch_size,
+            row_index: 0,
+        }
+    }
+}
+
+impl BatchTransformer for BatchSplitter {
+    fn set_batch(&mut self, batch: RecordBatch) {
+        self.batch = Some(batch);
+        self.row_index = 0;
+    }
+
+    fn next(&mut self) -> Option<(RecordBatch, bool)> {
+        let batch = self.batch.as_ref()?;
+        let total_rows = batch.num_rows();
+
+        let rows_to_slice = (total_rows - self.row_index).min(self.batch_size);
+        let sliced_batch = batch.slice(self.row_index, rows_to_slice);
+        self.row_index += rows_to_slice;
+
+        // Once every row has been handed out, drop the input so that later
+        // calls return `None`.
+        let last = self.row_index >= total_rows;
+        if last {
+            self.batch = None;
+        }
+
+        Some((sliced_batch, last))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::pin::Pin;
@@ -1643,11 +1766,13 @@ mod tests {
 
     use arrow::datatypes::{DataType, Fields};
     use arrow::error::{ArrowError, Result as ArrowResult};
+    use arrow_array::Int32Array;
     use arrow_schema::SortOptions;
-
     use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
     use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
 
+    use rstest::rstest;
+
     fn check(
         left: &[Column],
         right: &[Column],
@@ -2554,4 +2679,49 @@ mod tests {
 
         Ok(())
     }
+
+    /// Builds a batch with one non-nullable Int32 column `a` holding
+    /// `0, 1, ..., num_rows - 1`.
+    fn create_test_batch(num_rows: usize) -> RecordBatch {
+        let field = Field::new("a", DataType::Int32, false);
+        let schema = Arc::new(Schema::new(vec![field]));
+        let column = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
+        RecordBatch::try_new(schema, vec![column]).unwrap()
+    }
+
+    /// Checks that `batches` is a correct split of a `create_test_batch(num_rows)`
+    /// batch into pieces of at most `batch_size` rows.
+    ///
+    /// Verifies piece sizes, contiguous row values, the `last` flag, and that the
+    /// pieces together cover every input row.
+    fn assert_split_batches(
+        batches: Vec<(RecordBatch, bool)>,
+        batch_size: usize,
+        num_rows: usize,
+    ) {
+        let mut row_count = 0;
+        for (batch, last) in batches.into_iter() {
+            // Every piece is full-sized except possibly the final one.
+            assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size));
+            let column = batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap();
+            // Pieces must be in order and contiguous: row `i` of this piece is
+            // row `row_count + i` of the original batch.
+            for i in 0..batch.num_rows() {
+                assert_eq!(column.value(i), i as i32 + row_count as i32);
+            }
+            row_count += batch.num_rows();
+            // Only the piece that completes the input may be flagged as last.
+            assert_eq!(last, row_count == num_rows);
+        }
+        // Without this check an empty or truncated `batches` would pass vacuously.
+        assert_eq!(row_count, num_rows, "split batches must cover every input row");
+    }
+
+    #[rstest]
+    #[test]
+    fn test_batch_splitter(
+        #[values(1, 3, 11)] batch_size: usize,
+        #[values(1, 6, 50)] num_rows: usize,
+    ) {
+        let mut splitter = BatchSplitter::new(batch_size);
+        splitter.set_batch(create_test_batch(num_rows));
+
+        // `BatchSplitter` is not an `Iterator`; adapt it so it can be drained
+        // with `collect`.
+        let batches: Vec<(RecordBatch, bool)> =
+            std::iter::from_fn(|| splitter.next()).collect();
+
+        // An exhausted splitter must keep reporting `None`.
+        assert!(splitter.next().is_none());
+        assert_split_batches(batches, batch_size, num_rows);
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt 
b/datafusion/sqllogictest/test_files/information_schema.slt
index 7acdf25b65..57bf029a63 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -173,6 +173,7 @@ datafusion.execution.batch_size 8192
 datafusion.execution.coalesce_batches true
 datafusion.execution.collect_statistics false
 datafusion.execution.enable_recursive_ctes true
+datafusion.execution.enforce_batch_size_in_joins false
 datafusion.execution.keep_partition_by_columns false
 datafusion.execution.listing_table_ignore_subdirectory true
 datafusion.execution.max_buffered_batches_per_output_file 2
@@ -263,6 +264,7 @@ datafusion.execution.batch_size 8192 Default batch size 
while creating new batch
 datafusion.execution.coalesce_batches true When set to true, record batches 
will be examined between each operator and small batches will be coalesced into 
larger batches. This is helpful when there are highly selective filters or 
joins that could produce tiny output batches. The target batch size is 
determined by the configuration setting
 datafusion.execution.collect_statistics false Should DataFusion collect 
statistics after listing files
 datafusion.execution.enable_recursive_ctes true Should DataFusion support 
recursive CTEs
+datafusion.execution.enforce_batch_size_in_joins false Should DataFusion 
enforce batch size in joins or not. By default, DataFusion will not enforce 
batch size in joins. Enforcing batch size in joins can reduce memory usage when 
joining large tables with a highly-selective join filter, but is also slightly 
slower.
 datafusion.execution.keep_partition_by_columns false Should DataFusion keep 
the columns used for partition_by in the output RecordBatches
 datafusion.execution.listing_table_ignore_subdirectory true Should sub 
directories be ignored when scanning directories for data files. Defaults to 
true (ignores subdirectories), consistent with Hive. Note that this setting 
does not affect reading partitioned tables (e.g. 
`/table/year=2021/month=01/data.parquet`).
 datafusion.execution.max_buffered_batches_per_output_file 2 This is the 
maximum number of RecordBatches buffered for each output file being worked. 
Higher values can potentially give faster write performance at the cost of 
higher peak memory consumption
diff --git a/docs/source/user-guide/configs.md 
b/docs/source/user-guide/configs.md
index f34d148f09..c61a7b6733 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig` 
initialisation so they mus
 | datafusion.execution.skip_partial_aggregation_probe_ratio_threshold     | 
0.8                       | Aggregation ratio (number of distinct groups / 
number of input rows) threshold for skipping partial aggregation. If the value 
is greater then partial aggregation will skip aggregation for further input     
                                                                                
                                                                                
                       [...]
 | datafusion.execution.skip_partial_aggregation_probe_rows_threshold      | 
100000                    | Number of input rows partial aggregation partition 
should process, before aggregation ratio check and trying to switch to skipping 
aggregation mode                                                                
                                                                                
                                                                                
                  [...]
 | datafusion.execution.use_row_number_estimates_to_optimize_partitioning  | 
false                     | Should DataFusion use row number estimates at the 
input to decide whether increasing parallelism is beneficial or not. By 
default, only exact row numbers (not estimates) are used for this decision. 
Setting this flag to `true` will likely produce better plans. if the source of 
statistics is accurate. We plan to make this the default in the future.         
                                [...]
+| datafusion.execution.enforce_batch_size_in_joins                        | 
false                     | Should DataFusion enforce batch size in joins or 
not. By default, DataFusion will not enforce batch size in joins. Enforcing 
batch size in joins can reduce memory usage when joining large tables with a 
highly-selective join filter, but is also slightly slower.                      
                                                                                
                           [...]
 | datafusion.optimizer.enable_distinct_aggregation_soft_limit             | 
true                      | When set to true, the optimizer will push a limit 
operation into grouped aggregations which have no aggregate expressions, as a 
soft limit, emitting groups once the limit is reached, before all rows in the 
group are read.                                                                 
                                                                                
                       [...]
 | datafusion.optimizer.enable_round_robin_repartition                     | 
true                      | When set to true, the physical plan optimizer will 
try to add round robin repartitioning to increase parallelism to leverage more 
CPU cores                                                                       
                                                                                
                                                                                
                   [...]
 | datafusion.optimizer.enable_topk_aggregation                            | 
true                      | When set to true, the optimizer will attempt to 
perform limit operations during aggregations, if possible                       
                                                                                
                                                                                
                                                                                
                     [...]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to