This is an automated email from the ASF dual-hosted git repository.
ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 87e931c976 Split output batches of joins that do not respect batch
size (#12969)
87e931c976 is described below
commit 87e931c976a7aa24cecaa9bf3658b42bba12a51e
Author: Alihan Çelikcan <[email protected]>
AuthorDate: Fri Oct 18 14:34:42 2024 +0300
Split output batches of joins that do not respect batch size (#12969)
* Add BatchSplitter to joins that do not respect batch size
* Group relevant imports
* Update configs.md
* Update SQL logic tests for config
* Review
* Use PrimitiveBuilder for PrimitiveArray concatenation
* Fix into_builder() bug
* Apply suggestions from code review
Co-authored-by: Andrew Lamb <[email protected]>
* Update config docs
* Format
* Update config SQL Logic Test
---------
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/common/src/config.rs | 26 +-
datafusion/execution/src/config.rs | 14 +
datafusion/physical-plan/src/joins/cross_join.rs | 84 +++--
datafusion/physical-plan/src/joins/hash_join.rs | 2 +-
.../physical-plan/src/joins/nested_loop_join.rs | 356 ++++++++++++++-------
.../physical-plan/src/joins/stream_join_utils.rs | 83 +++--
.../physical-plan/src/joins/symmetric_hash_join.rs | 252 ++++++++-------
datafusion/physical-plan/src/joins/utils.rs | 220 +++++++++++--
.../sqllogictest/test_files/information_schema.slt | 2 +
docs/source/user-guide/configs.md | 1 +
10 files changed, 709 insertions(+), 331 deletions(-)
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 1e1c5d5424..47ffe0b1c6 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -338,6 +338,12 @@ config_namespace! {
/// if the source of statistics is accurate.
/// We plan to make this the default in the future.
pub use_row_number_estimates_to_optimize_partitioning: bool, default =
false
+
+ /// Should DataFusion enforce batch size in joins or not. By default,
+ /// DataFusion will not enforce batch size in joins. Enforcing batch
size
+ /// in joins can reduce memory usage when joining large
+ /// tables with a highly-selective join filter, but is also slightly
slower.
+ pub enforce_batch_size_in_joins: bool, default = false
}
}
@@ -1222,16 +1228,18 @@ impl ConfigField for TableOptions {
fn set(&mut self, key: &str, value: &str) -> Result<()> {
// Extensions are handled in the public `ConfigOptions::set`
let (key, rem) = key.split_once('.').unwrap_or((key, ""));
- let Some(format) = &self.current_format else {
- return _config_err!("Specify a format for TableOptions");
- };
match key {
- "format" => match format {
- #[cfg(feature = "parquet")]
- ConfigFileType::PARQUET => self.parquet.set(rem, value),
- ConfigFileType::CSV => self.csv.set(rem, value),
- ConfigFileType::JSON => self.json.set(rem, value),
- },
+ "format" => {
+ let Some(format) = &self.current_format else {
+ return _config_err!("Specify a format for TableOptions");
+ };
+ match format {
+ #[cfg(feature = "parquet")]
+ ConfigFileType::PARQUET => self.parquet.set(rem, value),
+ ConfigFileType::CSV => self.csv.set(rem, value),
+ ConfigFileType::JSON => self.json.set(rem, value),
+ }
+ }
_ => _config_err!("Config value \"{key}\" not found on
TableOptions"),
}
}
diff --git a/datafusion/execution/src/config.rs
b/datafusion/execution/src/config.rs
index cede75d21c..53646dc5b4 100644
--- a/datafusion/execution/src/config.rs
+++ b/datafusion/execution/src/config.rs
@@ -432,6 +432,20 @@ impl SessionConfig {
self
}
+ /// Enables or disables the enforcement of batch size in joins
+ pub fn with_enforce_batch_size_in_joins(
+ mut self,
+ enforce_batch_size_in_joins: bool,
+ ) -> Self {
+ self.options.execution.enforce_batch_size_in_joins =
enforce_batch_size_in_joins;
+ self
+ }
+
+ /// Returns true if joins are required to produce output batches of the
configured size
+ pub fn enforce_batch_size_in_joins(&self) -> bool {
+ self.options.execution.enforce_batch_size_in_joins
+ }
+
/// Convert configuration options to name-value pairs with values
/// converted to strings.
///
diff --git a/datafusion/physical-plan/src/joins/cross_join.rs
b/datafusion/physical-plan/src/joins/cross_join.rs
index a70645f3d6..8f2bef56da 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -19,7 +19,8 @@
//! and producing batches in parallel for the right partitions
use super::utils::{
- adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync,
OnceFut,
+ adjust_right_output_partitioning, BatchSplitter, BatchTransformer,
+ BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut,
StatefulStreamResult,
};
use crate::coalesce_partitions::CoalescePartitionsExec;
@@ -86,6 +87,7 @@ impl CrossJoinExec {
let schema =
Arc::new(Schema::new(all_columns).with_metadata(metadata));
let cache = Self::compute_properties(&left, &right,
Arc::clone(&schema));
+
CrossJoinExec {
left,
right,
@@ -246,6 +248,10 @@ impl ExecutionPlan for CrossJoinExec {
let reservation =
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool());
+ let batch_size = context.session_config().batch_size();
+ let enforce_batch_size_in_joins =
+ context.session_config().enforce_batch_size_in_joins();
+
let left_fut = self.left_fut.once(|| {
load_left_input(
Arc::clone(&self.left),
@@ -255,15 +261,29 @@ impl ExecutionPlan for CrossJoinExec {
)
});
- Ok(Box::pin(CrossJoinStream {
- schema: Arc::clone(&self.schema),
- left_fut,
- right: stream,
- left_index: 0,
- join_metrics,
- state: CrossJoinStreamState::WaitBuildSide,
- left_data: RecordBatch::new_empty(self.left().schema()),
- }))
+ if enforce_batch_size_in_joins {
+ Ok(Box::pin(CrossJoinStream {
+ schema: Arc::clone(&self.schema),
+ left_fut,
+ right: stream,
+ left_index: 0,
+ join_metrics,
+ state: CrossJoinStreamState::WaitBuildSide,
+ left_data: RecordBatch::new_empty(self.left().schema()),
+ batch_transformer: BatchSplitter::new(batch_size),
+ }))
+ } else {
+ Ok(Box::pin(CrossJoinStream {
+ schema: Arc::clone(&self.schema),
+ left_fut,
+ right: stream,
+ left_index: 0,
+ join_metrics,
+ state: CrossJoinStreamState::WaitBuildSide,
+ left_data: RecordBatch::new_empty(self.left().schema()),
+ batch_transformer: NoopBatchTransformer::new(),
+ }))
+ }
}
fn statistics(&self) -> Result<Statistics> {
@@ -319,7 +339,7 @@ fn stats_cartesian_product(
}
/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
-struct CrossJoinStream {
+struct CrossJoinStream<T> {
/// Input schema
schema: Arc<Schema>,
/// Future for data from left side
@@ -334,9 +354,11 @@ struct CrossJoinStream {
state: CrossJoinStreamState,
/// Left data
left_data: RecordBatch,
+ /// Batch transformer
+ batch_transformer: T,
}
-impl RecordBatchStream for CrossJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for
CrossJoinStream<T> {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
@@ -390,7 +412,7 @@ fn build_batch(
}
#[async_trait]
-impl Stream for CrossJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for CrossJoinStream<T> {
type Item = Result<RecordBatch>;
fn poll_next(
@@ -401,7 +423,7 @@ impl Stream for CrossJoinStream {
}
}
-impl CrossJoinStream {
+impl<T: BatchTransformer> CrossJoinStream<T> {
/// Separate implementation function that unpins the [`CrossJoinStream`] so
/// that partial borrows work correctly
fn poll_next_impl(
@@ -470,21 +492,33 @@ impl CrossJoinStream {
fn build_batches(&mut self) ->
Result<StatefulStreamResult<Option<RecordBatch>>> {
let right_batch = self.state.try_as_record_batch()?;
if self.left_index < self.left_data.num_rows() {
- let join_timer = self.join_metrics.join_time.timer();
- let result =
- build_batch(self.left_index, right_batch, &self.left_data,
&self.schema);
- join_timer.done();
-
- if let Ok(ref batch) = result {
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
+ match self.batch_transformer.next() {
+ None => {
+ let join_timer = self.join_metrics.join_time.timer();
+ let result = build_batch(
+ self.left_index,
+ right_batch,
+ &self.left_data,
+ &self.schema,
+ );
+ join_timer.done();
+
+ self.batch_transformer.set_batch(result?);
+ }
+ Some((batch, last)) => {
+ if last {
+ self.left_index += 1;
+ }
+
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ return Ok(StatefulStreamResult::Ready(Some(batch)));
+ }
}
- self.left_index += 1;
- result.map(|r| StatefulStreamResult::Ready(Some(r)))
} else {
self.state = CrossJoinStreamState::FetchProbeBatch;
- Ok(StatefulStreamResult::Continue)
}
+ Ok(StatefulStreamResult::Continue)
}
}
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs
b/datafusion/physical-plan/src/joins/hash_join.rs
index 74a45a7e47..3b730c0129 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -1438,7 +1438,7 @@ impl HashJoinStream {
index_alignment_range_start..index_alignment_range_end,
self.join_type,
self.right_side_ordered,
- );
+ )?;
let result = build_batch_from_indices(
&self.schema,
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs
b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 6068e75263..358ff02473 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -25,7 +25,10 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::Poll;
-use super::utils::{asymmetric_join_output_partitioning,
need_produce_result_in_final};
+use super::utils::{
+ asymmetric_join_output_partitioning, need_produce_result_in_final,
BatchSplitter,
+ BatchTransformer, NoopBatchTransformer, StatefulStreamResult,
+};
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::joins::utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices,
build_batch_from_indices,
@@ -35,8 +38,8 @@ use crate::joins::utils::{
};
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::{
- execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution,
- ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
+ execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType,
+ Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties,
PlanProperties,
RecordBatchStream, SendableRecordBatchStream,
};
@@ -45,7 +48,9 @@ use arrow::compute::concat_batches;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::util::bit_util;
-use datafusion_common::{exec_datafusion_err, JoinSide, Result, Statistics};
+use datafusion_common::{
+ exec_datafusion_err, internal_err, JoinSide, Result, Statistics,
+};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_expr::JoinType;
@@ -230,10 +235,11 @@ impl NestedLoopJoinExec {
asymmetric_join_output_partitioning(left, right, &join_type);
// Determine execution mode:
- let mut mode = execution_mode_from_children([left, right]);
- if mode.is_unbounded() {
- mode = ExecutionMode::PipelineBreaking;
- }
+ let mode = if left.execution_mode().is_unbounded() {
+ ExecutionMode::PipelineBreaking
+ } else {
+ execution_mode_from_children([left, right])
+ };
PlanProperties::new(eq_properties, output_partitioning, mode)
}
@@ -345,6 +351,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
)
});
+ let batch_size = context.session_config().batch_size();
+ let enforce_batch_size_in_joins =
+ context.session_config().enforce_batch_size_in_joins();
+
let outer_table = self.right.execute(partition, context)?;
let indices_cache = (UInt64Array::new_null(0),
UInt32Array::new_null(0));
@@ -352,18 +362,38 @@ impl ExecutionPlan for NestedLoopJoinExec {
// Right side has an order and it is maintained during operation.
let right_side_ordered =
self.maintains_input_order()[1] &&
self.right.output_ordering().is_some();
- Ok(Box::pin(NestedLoopJoinStream {
- schema: Arc::clone(&self.schema),
- filter: self.filter.clone(),
- join_type: self.join_type,
- outer_table,
- inner_table,
- is_exhausted: false,
- column_indices: self.column_indices.clone(),
- join_metrics,
- indices_cache,
- right_side_ordered,
- }))
+
+ if enforce_batch_size_in_joins {
+ Ok(Box::pin(NestedLoopJoinStream {
+ schema: Arc::clone(&self.schema),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ outer_table,
+ inner_table,
+ column_indices: self.column_indices.clone(),
+ join_metrics,
+ indices_cache,
+ right_side_ordered,
+ state: NestedLoopJoinStreamState::WaitBuildSide,
+ batch_transformer: BatchSplitter::new(batch_size),
+ left_data: None,
+ }))
+ } else {
+ Ok(Box::pin(NestedLoopJoinStream {
+ schema: Arc::clone(&self.schema),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ outer_table,
+ inner_table,
+ column_indices: self.column_indices.clone(),
+ join_metrics,
+ indices_cache,
+ right_side_ordered,
+ state: NestedLoopJoinStreamState::WaitBuildSide,
+ batch_transformer: NoopBatchTransformer::new(),
+ left_data: None,
+ }))
+ }
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -442,8 +472,37 @@ async fn collect_left_input(
))
}
+/// This enumeration represents various states of the nested loop join
algorithm.
+#[derive(Debug, Clone)]
+enum NestedLoopJoinStreamState {
+ /// The initial state, indicating that build-side data has not been collected yet
+ WaitBuildSide,
+ /// Indicates that build-side has been collected, and stream is ready for
+ /// fetching probe-side
+ FetchProbeBatch,
+ /// Indicates that a non-empty batch has been fetched from probe-side, and
+ /// is ready to be processed
+ ProcessProbeBatch(RecordBatch),
+ /// Indicates that probe-side has been fully processed
+ ExhaustedProbeSide,
+ /// Indicates that NestedLoopJoinStream execution is completed
+ Completed,
+}
+
+impl NestedLoopJoinStreamState {
+ /// Tries to extract a `ProcessProbeBatchState` from the
+ /// `NestedLoopJoinStreamState` enum. Returns an error if state is not
+ /// `ProcessProbeBatchState`.
+ fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> {
+ match self {
+ NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state),
+ _ => internal_err!("Expected join stream in ProcessProbeBatch
state"),
+ }
+ }
+}
+
/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
-struct NestedLoopJoinStream {
+struct NestedLoopJoinStream<T> {
/// Input schema
schema: Arc<Schema>,
/// join filter
@@ -454,8 +513,6 @@ struct NestedLoopJoinStream {
outer_table: SendableRecordBatchStream,
/// the inner table data of the nested loop join
inner_table: OnceFut<JoinLeftData>,
- /// There is nothing to process anymore and left side is processed in case
of full join
- is_exhausted: bool,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
// TODO: support null aware equal
@@ -466,6 +523,12 @@ struct NestedLoopJoinStream {
indices_cache: (UInt64Array, UInt32Array),
/// Whether the right side is ordered
right_side_ordered: bool,
+ /// Current state of the stream
+ state: NestedLoopJoinStreamState,
+ /// Transforms the output batch before returning.
+ batch_transformer: T,
+ /// Result of the left data future
+ left_data: Option<Arc<JoinLeftData>>,
}
/// Creates a Cartesian product of two input batches, preserving the order of
the right batch,
@@ -544,107 +607,164 @@ fn build_join_indices(
}
}
-impl NestedLoopJoinStream {
+impl<T: BatchTransformer> NestedLoopJoinStream<T> {
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
- // all left row
+ loop {
+ return match self.state {
+ NestedLoopJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ NestedLoopJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ NestedLoopJoinStreamState::ProcessProbeBatch(_) => {
+ handle_state!(self.process_probe_batch())
+ }
+ NestedLoopJoinStreamState::ExhaustedProbeSide => {
+ handle_state!(self.process_unmatched_build_batch())
+ }
+ NestedLoopJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
- let left_data = match ready!(self.inner_table.get_shared(cx)) {
- Ok(data) => data,
- Err(e) => return Poll::Ready(Some(Err(e))),
- };
+ // Collect the data from the left (build) side, if not yet done
+ self.left_data = Some(ready!(self.inner_table.get_shared(cx))?);
build_timer.done();
- // Get or initialize visited_left_side bitmap if required by join type
+ self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+
+ /// Fetches next batch from probe-side
+ ///
+ /// If a non-empty batch has been fetched, updates state to
+ /// `ProcessProbeBatchState`, otherwise updates state to
`ExhaustedProbeSide`.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ match ready!(self.outer_table.poll_next_unpin(cx)) {
+ None => {
+ self.state = NestedLoopJoinStreamState::ExhaustedProbeSide;
+ }
+ Some(Ok(right_batch)) => {
+ self.state =
NestedLoopJoinStreamState::ProcessProbeBatch(right_batch);
+ }
+ Some(Err(err)) => return Poll::Ready(Err(err)),
+ };
+
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+
+ /// Joins current probe batch with build-side data and produces batch with
+ /// matched output, updates state to `FetchProbeBatch`.
+ fn process_probe_batch(
+ &mut self,
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let Some(left_data) = self.left_data.clone() else {
+ return internal_err!(
+ "Expected left_data to be Some in ProcessProbeBatch state"
+ );
+ };
let visited_left_side = left_data.bitmap();
+ let batch = self.state.try_as_process_probe_batch()?;
+
+ match self.batch_transformer.next() {
+ None => {
+ // Setting up timer & updating input metrics
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(batch.num_rows());
+ let timer = self.join_metrics.join_time.timer();
+
+ let result = join_left_and_right_batch(
+ left_data.batch(),
+ batch,
+ self.join_type,
+ self.filter.as_ref(),
+ &self.column_indices,
+ &self.schema,
+ visited_left_side,
+ &mut self.indices_cache,
+ self.right_side_ordered,
+ );
+ timer.done();
+
+ self.batch_transformer.set_batch(result?);
+ Ok(StatefulStreamResult::Continue)
+ }
+ Some((batch, last)) => {
+ if last {
+ self.state = NestedLoopJoinStreamState::FetchProbeBatch;
+ }
- // Check is_exhausted before polling the outer_table, such that when
the outer table
- // does not support `FusedStream`, Self will not poll it again
- if self.is_exhausted {
- return Poll::Ready(None);
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ Ok(StatefulStreamResult::Ready(Some(batch)))
+ }
}
+ }
- self.outer_table
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(right_batch)) => {
- // Setting up timer & updating input metrics
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(right_batch.num_rows());
- let timer = self.join_metrics.join_time.timer();
-
- let result = join_left_and_right_batch(
- left_data.batch(),
- &right_batch,
- self.join_type,
- self.filter.as_ref(),
- &self.column_indices,
- &self.schema,
- visited_left_side,
- &mut self.indices_cache,
- self.right_side_ordered,
- );
-
- // Recording time & updating output metrics
- if let Ok(batch) = &result {
- timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
-
- Some(result)
- }
- Some(err) => Some(err),
- None => {
- if need_produce_result_in_final(self.join_type) {
- // At this stage `visited_left_side` won't be updated,
so it's
- // safe to report about probe completion.
- //
- // Setting `is_exhausted` / returning None will
prevent from
- // multiple calls of `report_probe_completed()`
- if !left_data.report_probe_completed() {
- self.is_exhausted = true;
- return None;
- };
-
- // Only setting up timer, input is exhausted
- let timer = self.join_metrics.join_time.timer();
- // use the global left bitmap to produce the left
indices and right indices
- let (left_side, right_side) =
- get_final_indices_from_shared_bitmap(
- visited_left_side,
- self.join_type,
- );
- let empty_right_batch =
- RecordBatch::new_empty(self.outer_table.schema());
- // use the left and right indices to produce the batch
result
- let result = build_batch_from_indices(
- &self.schema,
- left_data.batch(),
- &empty_right_batch,
- &left_side,
- &right_side,
- &self.column_indices,
- JoinSide::Left,
- );
- self.is_exhausted = true;
-
- // Recording time & updating output metrics
- if let Ok(batch) = &result {
- timer.done();
- self.join_metrics.output_batches.add(1);
-
self.join_metrics.output_rows.add(batch.num_rows());
- }
-
- Some(result)
- } else {
- // end of the join loop
- None
- }
- }
- })
+ /// Processes unmatched build-side rows for certain join types and produces
+ /// output batch, updates state to `Completed`.
+ fn process_unmatched_build_batch(
+ &mut self,
+ ) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let Some(left_data) = self.left_data.clone() else {
+ return internal_err!(
+ "Expected left_data to be Some in ExhaustedProbeSide state"
+ );
+ };
+ let visited_left_side = left_data.bitmap();
+ if need_produce_result_in_final(self.join_type) {
+ // At this stage `visited_left_side` won't be updated, so it's
+ // safe to report about probe completion.
+ //
+ // Setting `is_exhausted` / returning None will prevent from
+ // multiple calls of `report_probe_completed()`
+ if !left_data.report_probe_completed() {
+ self.state = NestedLoopJoinStreamState::Completed;
+ return Ok(StatefulStreamResult::Ready(None));
+ };
+
+ // Only setting up timer, input is exhausted
+ let timer = self.join_metrics.join_time.timer();
+ // use the global left bitmap to produce the left indices and
right indices
+ let (left_side, right_side) =
+ get_final_indices_from_shared_bitmap(visited_left_side,
self.join_type);
+ let empty_right_batch =
RecordBatch::new_empty(self.outer_table.schema());
+ // use the left and right indices to produce the batch result
+ let result = build_batch_from_indices(
+ &self.schema,
+ left_data.batch(),
+ &empty_right_batch,
+ &left_side,
+ &right_side,
+ &self.column_indices,
+ JoinSide::Left,
+ );
+ self.state = NestedLoopJoinStreamState::Completed;
+
+ // Recording time
+ if result.is_ok() {
+ timer.done();
+ }
+
+ Ok(StatefulStreamResult::Ready(Some(result?)))
+ } else {
+ // end of the join loop
+ self.state = NestedLoopJoinStreamState::Completed;
+ Ok(StatefulStreamResult::Ready(None))
+ }
}
}
@@ -684,7 +804,7 @@ fn join_left_and_right_batch(
0..right_batch.num_rows(),
join_type,
right_side_ordered,
- );
+ )?;
build_batch_from_indices(
schema,
@@ -705,7 +825,7 @@ fn get_final_indices_from_shared_bitmap(
get_final_indices_from_bit_map(&bitmap, join_type)
}
-impl Stream for NestedLoopJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for NestedLoopJoinStream<T> {
type Item = Result<RecordBatch>;
fn poll_next(
@@ -716,14 +836,14 @@ impl Stream for NestedLoopJoinStream {
}
}
-impl RecordBatchStream for NestedLoopJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream for
NestedLoopJoinStream<T> {
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
#[cfg(test)]
-mod tests {
+pub(crate) mod tests {
use super::*;
use crate::{
common, expressions::Column, memory::MemoryExec,
repartition::RepartitionExec,
@@ -850,7 +970,7 @@ mod tests {
JoinFilter::new(filter_expression, column_indices, intermediate_schema)
}
- async fn multi_partitioned_join_collect(
+ pub(crate) async fn multi_partitioned_join_collect(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
join_type: &JoinType,
diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs
b/datafusion/physical-plan/src/joins/stream_join_utils.rs
index ba9384aef1..bddd152341 100644
--- a/datafusion/physical-plan/src/joins/stream_join_utils.rs
+++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs
@@ -31,8 +31,7 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder};
use arrow_schema::{Schema, SchemaRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
- arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide,
Result,
- ScalarValue,
+ arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue,
};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
@@ -369,34 +368,40 @@ impl SortedFilterExpr {
filter_expr: Arc<dyn PhysicalExpr>,
filter_schema: &Schema,
) -> Result<Self> {
- let dt = &filter_expr.data_type(filter_schema)?;
+ let dt = filter_expr.data_type(filter_schema)?;
Ok(Self {
origin_sorted_expr,
filter_expr,
- interval: Interval::make_unbounded(dt)?,
+ interval: Interval::make_unbounded(&dt)?,
node_index: 0,
})
}
+
/// Get origin expr information
pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr {
&self.origin_sorted_expr
}
+
/// Get filter expr information
pub fn filter_expr(&self) -> &Arc<dyn PhysicalExpr> {
&self.filter_expr
}
+
/// Get interval information
pub fn interval(&self) -> &Interval {
&self.interval
}
+
/// Sets interval
pub fn set_interval(&mut self, interval: Interval) {
self.interval = interval;
}
+
/// Node index in ExprIntervalGraph
pub fn node_index(&self) -> usize {
self.node_index
}
+
/// Node index setter in ExprIntervalGraph
pub fn set_node_index(&mut self, node_index: usize) {
self.node_index = node_index;
@@ -409,41 +414,45 @@ impl SortedFilterExpr {
/// on the first or the last value of the expression in `build_input_buffer`
/// and `probe_batch`.
///
-/// # Arguments
+/// # Parameters
///
/// * `build_input_buffer` - The [RecordBatch] on the build side of the join.
/// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update.
/// * `probe_batch` - The `RecordBatch` on the probe side of the join.
/// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update.
///
-/// ### Note
-/// ```text
+/// ## Note
///
-/// Interval arithmetic is used to calculate viable join ranges for build-side
-/// pruning. This is done by first creating an interval for join filter values
in
-/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on
the
-/// ordering (descending/ascending) of the filter expression. Here, FV denotes
the
-/// first value on the build side. This range is then compared with the probe
side
-/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering
-/// (ascending/descending) of the probe side. Here, LV denotes the last value
on
-/// the probe side.
+/// Utilizing interval arithmetic, this function computes feasible join
intervals
+/// on the pruning side by evaluating the prospective value ranges that might
+/// emerge in subsequent data batches from the enforcer side. This is done by
+/// first creating an interval for join filter values in the pruning side of
the
+/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering
(descending/
+/// ascending) of the filter expression. Here, `FV` denotes the first value on
the
+/// pruning side. This range is then compared with the enforcer side interval,
+/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering
(ascending/
+/// descending) of the probe side. Here, `LV` denotes the last value on the
enforcer
+/// side.
///
/// As a concrete example, consider the following query:
///
+/// ```text
/// SELECT * FROM left_table, right_table
/// WHERE
/// left_key = right_key AND
/// a > b - 3 AND
/// a < b + 10
+/// ```
///
-/// where columns "a" and "b" come from tables "left_table" and "right_table",
+/// where columns `a` and `b` come from tables `left_table` and `right_table`,
/// respectively. When a new `RecordBatch` arrives at the right side, the
-/// condition a > b - 3 will possibly indicate a prunable range for the left
+/// condition `a > b - 3` will possibly indicate a prunable range for the left
/// side. Conversely, when a new `RecordBatch` arrives at the left side, the
-/// condition a < b + 10 will possibly indicate prunability for the right side.
-/// Let’s inspect what happens when a new RecordBatch` arrives at the right
+/// condition `a < b + 10` will possibly indicate prunability for the right
side.
+/// Let’s inspect what happens when a new `RecordBatch` arrives at the right
/// side (i.e. when the left side is the build side):
///
+/// ```text
/// Build Probe
/// +-------+ +-------+
/// | a | z | | b | y |
@@ -456,13 +465,13 @@ impl SortedFilterExpr {
/// |+--|--+| |+--|--+|
/// | 7 | 1 | | 6 | 3 |
/// +-------+ +-------+
+/// ```
///
/// In this case, the interval representing viable (i.e. joinable) values for
-/// column "a" is [1, ∞], and the interval representing possible future values
-/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate
+/// column `a` is `[1, ∞]`, and the interval representing possible future
values
+/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate
/// intervals for the whole filter expression and propagate join constraint by
/// traversing the expression graph.
-/// ```
pub fn calculate_filter_expr_intervals(
build_input_buffer: &RecordBatch,
build_sorted_filter_expr: &mut SortedFilterExpr,
@@ -710,13 +719,21 @@ fn update_sorted_exprs_with_node_indices(
}
}
-/// Prepares and sorts expressions based on a given filter, left and right
execution plans, and sort expressions.
+/// Prepares and sorts expressions based on a given filter, left and right
schemas,
+/// and sort expressions.
///
-/// # Arguments
+/// This function prepares sorted filter expressions for both the left and
right
+/// sides of a join operation. It first builds the filter order for each side
+/// based on the provided `ExecutionPlan`. If both sides have valid sorted
filter
+/// expressions, the function then constructs an expression interval graph and
+/// updates the sorted expressions with node indices. The final sorted filter
+/// expressions for both sides are then returned.
+///
+/// # Parameters
///
/// * `filter` - The join filter to base the sorting on.
-/// * `left` - The left execution plan.
-/// * `right` - The right execution plan.
+/// * `left` - The `ExecutionPlan` for the left side of the join.
+/// * `right` - The `ExecutionPlan` for the right side of the join.
/// * `left_sort_exprs` - The expressions to sort on the left side.
/// * `right_sort_exprs` - The expressions to sort on the right side.
///
@@ -730,9 +747,11 @@ pub fn prepare_sorted_exprs(
left_sort_exprs: &[PhysicalSortExpr],
right_sort_exprs: &[PhysicalSortExpr],
) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> {
- // Build the filter order for the left side
- let err = || plan_datafusion_err!("Filter does not include the child
order");
+ let err = || {
+ datafusion_common::plan_datafusion_err!("Filter does not include the
child order")
+ };
+ // Build the filter order for the left side:
let left_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Left,
filter,
@@ -741,7 +760,7 @@ pub fn prepare_sorted_exprs(
)?
.ok_or_else(err)?;
- // Build the filter order for the right side
+ // Build the filter order for the right side:
let right_temp_sorted_filter_expr = build_filter_input_order(
JoinSide::Right,
filter,
@@ -952,15 +971,15 @@ pub mod tests {
let filter_expr = complicated_filter(&intermediate_schema)?;
let column_indices = vec![
ColumnIndex {
- index: 0,
+ index: left_schema.index_of("la1")?,
side: JoinSide::Left,
},
ColumnIndex {
- index: 4,
+ index: left_schema.index_of("la2")?,
side: JoinSide::Left,
},
ColumnIndex {
- index: 0,
+ index: right_schema.index_of("ra1")?,
side: JoinSide::Right,
},
];
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index ac718a95e9..70ada3892a 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -32,7 +32,6 @@ use std::task::{Context, Poll};
use std::vec;
use crate::common::SharedMemoryReservation;
-use crate::handle_state;
use crate::joins::hash_join::{equal_rows_arr, update_hash};
use crate::joins::stream_join_utils::{
calculate_filter_expr_intervals, combine_two_batches,
@@ -42,8 +41,9 @@ use crate::joins::stream_join_utils::{
};
use crate::joins::utils::{
apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
- check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex,
JoinFilter,
- JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult,
+ check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter,
+ BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn,
JoinOnRef,
+ NoopBatchTransformer, StatefulStreamResult,
};
use crate::{
execution_mode_from_children,
@@ -465,23 +465,27 @@ impl ExecutionPlan for SymmetricHashJoinExec {
consider using RepartitionExec"
);
}
- // If `filter_state` and `filter` are both present, then calculate
sorted filter expressions
- // for both sides, and build an expression graph.
- let (left_sorted_filter_expr, right_sorted_filter_expr, graph) =
- match (&self.left_sort_exprs, &self.right_sort_exprs,
&self.filter) {
- (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter))
=> {
- let (left, right, graph) = prepare_sorted_exprs(
- filter,
- &self.left,
- &self.right,
- left_sort_exprs,
- right_sort_exprs,
- )?;
- (Some(left), Some(right), Some(graph))
- }
- // If `filter_state` or `filter` is not present, then return
None for all three values:
- _ => (None, None, None),
- };
+ // If `filter_state` and `filter` are both present, then calculate
sorted
+ // filter expressions for both sides, and build an expression graph.
+ let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match
(
+ self.left_sort_exprs(),
+ self.right_sort_exprs(),
+ &self.filter,
+ ) {
+ (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => {
+ let (left, right, graph) = prepare_sorted_exprs(
+ filter,
+ &self.left,
+ &self.right,
+ left_sort_exprs,
+ right_sort_exprs,
+ )?;
+ (Some(left), Some(right), Some(graph))
+ }
+ // If `filter_state` or `filter` is not present, then return None
+ // for all three values:
+ _ => (None, None, None),
+ };
let (on_left, on_right) = self.on.iter().cloned().unzip();
@@ -494,6 +498,10 @@ impl ExecutionPlan for SymmetricHashJoinExec {
let right_stream = self.right.execute(partition,
Arc::clone(&context))?;
+ let batch_size = context.session_config().batch_size();
+ let enforce_batch_size_in_joins =
+ context.session_config().enforce_batch_size_in_joins();
+
let reservation = Arc::new(Mutex::new(
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
.register(context.memory_pool()),
@@ -502,29 +510,52 @@ impl ExecutionPlan for SymmetricHashJoinExec {
reservation.lock().try_grow(g.size())?;
}
- Ok(Box::pin(SymmetricHashJoinStream {
- left_stream,
- right_stream,
- schema: self.schema(),
- filter: self.filter.clone(),
- join_type: self.join_type,
- random_state: self.random_state.clone(),
- left: left_side_joiner,
- right: right_side_joiner,
- column_indices: self.column_indices.clone(),
- metrics: StreamJoinMetrics::new(partition, &self.metrics),
- graph,
- left_sorted_filter_expr,
- right_sorted_filter_expr,
- null_equals_null: self.null_equals_null,
- state: SHJStreamState::PullRight,
- reservation,
- }))
+ if enforce_batch_size_in_joins {
+ Ok(Box::pin(SymmetricHashJoinStream {
+ left_stream,
+ right_stream,
+ schema: self.schema(),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ random_state: self.random_state.clone(),
+ left: left_side_joiner,
+ right: right_side_joiner,
+ column_indices: self.column_indices.clone(),
+ metrics: StreamJoinMetrics::new(partition, &self.metrics),
+ graph,
+ left_sorted_filter_expr,
+ right_sorted_filter_expr,
+ null_equals_null: self.null_equals_null,
+ state: SHJStreamState::PullRight,
+ reservation,
+ batch_transformer: BatchSplitter::new(batch_size),
+ }))
+ } else {
+ Ok(Box::pin(SymmetricHashJoinStream {
+ left_stream,
+ right_stream,
+ schema: self.schema(),
+ filter: self.filter.clone(),
+ join_type: self.join_type,
+ random_state: self.random_state.clone(),
+ left: left_side_joiner,
+ right: right_side_joiner,
+ column_indices: self.column_indices.clone(),
+ metrics: StreamJoinMetrics::new(partition, &self.metrics),
+ graph,
+ left_sorted_filter_expr,
+ right_sorted_filter_expr,
+ null_equals_null: self.null_equals_null,
+ state: SHJStreamState::PullRight,
+ reservation,
+ batch_transformer: NoopBatchTransformer::new(),
+ }))
+ }
}
}
/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
-struct SymmetricHashJoinStream {
+struct SymmetricHashJoinStream<T> {
/// Input streams
left_stream: SendableRecordBatchStream,
right_stream: SendableRecordBatchStream,
@@ -556,15 +587,19 @@ struct SymmetricHashJoinStream {
reservation: SharedMemoryReservation,
/// State machine for input execution
state: SHJStreamState,
+ /// Transforms the output batch before returning.
+ batch_transformer: T,
}
-impl RecordBatchStream for SymmetricHashJoinStream {
+impl<T: BatchTransformer + Unpin + Send> RecordBatchStream
+ for SymmetricHashJoinStream<T>
+{
fn schema(&self) -> SchemaRef {
Arc::clone(&self.schema)
}
}
-impl Stream for SymmetricHashJoinStream {
+impl<T: BatchTransformer + Unpin + Send> Stream for SymmetricHashJoinStream<T>
{
type Item = Result<RecordBatch>;
fn poll_next(
@@ -1140,7 +1175,7 @@ impl OneSideHashJoiner {
/// - Transition to `BothExhausted { final_result: true }`:
/// - Occurs in `prepare_for_final_results_after_exhaustion` when both
streams are
/// exhausted, indicating completion of processing and availability of
final results.
-impl SymmetricHashJoinStream {
+impl<T: BatchTransformer> SymmetricHashJoinStream<T> {
/// Implements the main polling logic for the join stream.
///
/// This method continuously checks the state of the join stream and
@@ -1159,26 +1194,45 @@ impl SymmetricHashJoinStream {
cx: &mut Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
loop {
- return match self.state() {
- SHJStreamState::PullRight => {
-
handle_state!(ready!(self.fetch_next_from_right_stream(cx)))
- }
- SHJStreamState::PullLeft => {
- handle_state!(ready!(self.fetch_next_from_left_stream(cx)))
+ match self.batch_transformer.next() {
+ None => {
+ let result = match self.state() {
+ SHJStreamState::PullRight => {
+ ready!(self.fetch_next_from_right_stream(cx))
+ }
+ SHJStreamState::PullLeft => {
+ ready!(self.fetch_next_from_left_stream(cx))
+ }
+ SHJStreamState::RightExhausted => {
+ ready!(self.handle_right_stream_end(cx))
+ }
+ SHJStreamState::LeftExhausted => {
+ ready!(self.handle_left_stream_end(cx))
+ }
+ SHJStreamState::BothExhausted {
+ final_result: false,
+ } => self.prepare_for_final_results_after_exhaustion(),
+ SHJStreamState::BothExhausted { final_result: true }
=> {
+ return Poll::Ready(None);
+ }
+ };
+
+ match result? {
+ StatefulStreamResult::Ready(None) => {
+ return Poll::Ready(None);
+ }
+ StatefulStreamResult::Ready(Some(batch)) => {
+ self.batch_transformer.set_batch(batch);
+ }
+ _ => {}
+ }
}
- SHJStreamState::RightExhausted => {
- handle_state!(ready!(self.handle_right_stream_end(cx)))
- }
- SHJStreamState::LeftExhausted => {
- handle_state!(ready!(self.handle_left_stream_end(cx)))
- }
- SHJStreamState::BothExhausted {
- final_result: false,
- } => {
-
handle_state!(self.prepare_for_final_results_after_exhaustion())
+ Some((batch, _)) => {
+ self.metrics.output_batches.add(1);
+ self.metrics.output_rows.add(batch.num_rows());
+ return Poll::Ready(Some(Ok(batch)));
}
- SHJStreamState::BothExhausted { final_result: true } =>
Poll::Ready(None),
- };
+ }
}
}
/// Asynchronously pulls the next batch from the right stream.
@@ -1384,11 +1438,8 @@ impl SymmetricHashJoinStream {
// Combine the left and right results:
let result = combine_two_batches(&self.schema, left_result,
right_result)?;
- // Update the metrics and return the result:
- if let Some(batch) = &result {
- // Update the metrics:
- self.metrics.output_batches.add(1);
- self.metrics.output_rows.add(batch.num_rows());
+ // Return the result:
+ if result.is_some() {
return Ok(StatefulStreamResult::Ready(result));
}
Ok(StatefulStreamResult::Continue)
@@ -1523,11 +1574,6 @@ impl SymmetricHashJoinStream {
let capacity = self.size();
self.metrics.stream_memory_usage.set(capacity);
self.reservation.lock().try_resize(capacity)?;
- // Update the metrics if we have a batch; otherwise, continue the loop.
- if let Some(batch) = &result {
- self.metrics.output_batches.add(1);
- self.metrics.output_rows.add(batch.num_rows());
- }
Ok(result)
}
}
@@ -1716,15 +1762,15 @@ mod tests {
let filter_expr = complicated_filter(&intermediate_schema)?;
let column_indices = vec![
ColumnIndex {
- index: 0,
+ index: left_schema.index_of("la1")?,
side: JoinSide::Left,
},
ColumnIndex {
- index: 4,
+ index: left_schema.index_of("la2")?,
side: JoinSide::Left,
},
ColumnIndex {
- index: 0,
+ index: right_schema.index_of("ra1")?,
side: JoinSide::Right,
},
];
@@ -1771,10 +1817,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -1825,10 +1868,7 @@ mod tests {
let (left, right) =
create_memory_table(left_partition, right_partition, vec![],
vec![])?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -1877,10 +1917,7 @@ mod tests {
let (left, right) =
create_memory_table(left_partition, right_partition, vec![],
vec![])?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
experiment(left, right, None, join_type, on, task_ctx).await?;
Ok(())
}
@@ -1926,10 +1963,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -1987,10 +2021,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -2048,10 +2079,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -2111,10 +2139,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Int32, true),
@@ -2170,10 +2195,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
@@ -2237,10 +2259,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("0", DataType::Int32, true),
@@ -2296,10 +2315,7 @@ mod tests {
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("lt1", left_schema)?,
options: SortOptions {
@@ -2380,10 +2396,7 @@ mod tests {
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("li1", left_schema)?,
options: SortOptions {
@@ -2473,10 +2486,7 @@ mod tests {
vec![right_sorted],
)?;
- let on = vec![(
- Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
- Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
- )];
+ let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)];
let intermediate_schema = Schema::new(vec![
Field::new("left", DataType::Float64, true),
diff --git a/datafusion/physical-plan/src/joins/utils.rs
b/datafusion/physical-plan/src/joins/utils.rs
index 89f3feaf07..c520e42714 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -546,15 +546,16 @@ pub struct ColumnIndex {
pub side: JoinSide,
}
-/// Filter applied before join output
+/// Filter applied before join output. Fields are crate-public to allow
+/// downstream implementations to experiment with custom joins.
#[derive(Debug, Clone)]
pub struct JoinFilter {
/// Filter expression
- expression: Arc<dyn PhysicalExpr>,
+ pub(crate) expression: Arc<dyn PhysicalExpr>,
/// Column indices required to construct intermediate batch for filtering
- column_indices: Vec<ColumnIndex>,
+ pub(crate) column_indices: Vec<ColumnIndex>,
/// Physical schema of intermediate batch
- schema: Schema,
+ pub(crate) schema: Schema,
}
impl JoinFilter {
@@ -1280,15 +1281,15 @@ pub(crate) fn adjust_indices_by_join_type(
adjust_range: Range<usize>,
join_type: JoinType,
preserve_order_for_right: bool,
-) -> (UInt64Array, UInt32Array) {
+) -> Result<(UInt64Array, UInt32Array)> {
match join_type {
JoinType::Inner => {
// matched
- (left_indices, right_indices)
+ Ok((left_indices, right_indices))
}
JoinType::Left => {
// matched
- (left_indices, right_indices)
+ Ok((left_indices, right_indices))
// unmatched left row will be produced in the end of loop, and it
has been set in the left visited bitmap
}
JoinType::Right => {
@@ -1307,22 +1308,22 @@ pub(crate) fn adjust_indices_by_join_type(
// need to remove the duplicated record in the right side
let right_indices = get_semi_indices(adjust_range, &right_indices);
// the left_indices will not be used later for the `right semi`
join
- (left_indices, right_indices)
+ Ok((left_indices, right_indices))
}
JoinType::RightAnti => {
// need to remove the duplicated record in the right side
// get the anti index for the right side
let right_indices = get_anti_indices(adjust_range, &right_indices);
// the left_indices will not be used later for the `right anti`
join
- (left_indices, right_indices)
+ Ok((left_indices, right_indices))
}
JoinType::LeftSemi | JoinType::LeftAnti => {
// matched or unmatched left row will be produced in the end of
loop
// When visit the right batch, we can output the matched left row
and don't need to wait the end of loop
- (
+ Ok((
UInt64Array::from_iter_values(vec![]),
UInt32Array::from_iter_values(vec![]),
- )
+ ))
}
}
}
@@ -1347,27 +1348,64 @@ pub(crate) fn append_right_indices(
right_indices: UInt32Array,
adjust_range: Range<usize>,
preserve_order_for_right: bool,
-) -> (UInt64Array, UInt32Array) {
+) -> Result<(UInt64Array, UInt32Array)> {
if preserve_order_for_right {
- append_probe_indices_in_order(left_indices, right_indices,
adjust_range)
+ Ok(append_probe_indices_in_order(
+ left_indices,
+ right_indices,
+ adjust_range,
+ ))
} else {
let right_unmatched_indices = get_anti_indices(adjust_range,
&right_indices);
if right_unmatched_indices.is_empty() {
- (left_indices, right_indices)
+ Ok((left_indices, right_indices))
} else {
- let unmatched_size = right_unmatched_indices.len();
+ // `into_builder()` can fail here when there is nothing to be
filtered and
+ // left_indices or right_indices has the same reference to the
cached indices.
+ // In that case, we use a slower alternative.
+
// the new left indices: left_indices + null array
+ let mut new_left_indices_builder =
+ left_indices.into_builder().unwrap_or_else(|left_indices| {
+ let mut builder = UInt64Builder::with_capacity(
+ left_indices.len() + right_unmatched_indices.len(),
+ );
+ debug_assert_eq!(
+ left_indices.null_count(),
+ 0,
+ "expected left indices to have no nulls"
+ );
+ builder.append_slice(left_indices.values());
+ builder
+ });
+
new_left_indices_builder.append_nulls(right_unmatched_indices.len());
+ let new_left_indices =
UInt64Array::from(new_left_indices_builder.finish());
+
// the new right indices: right_indices + right_unmatched_indices
- let new_left_indices = left_indices
- .iter()
- .chain(std::iter::repeat(None).take(unmatched_size))
- .collect();
- let new_right_indices = right_indices
- .iter()
- .chain(right_unmatched_indices.iter())
- .collect();
- (new_left_indices, new_right_indices)
+ let mut new_right_indices_builder = right_indices
+ .into_builder()
+ .unwrap_or_else(|right_indices| {
+ let mut builder = UInt32Builder::with_capacity(
+ right_indices.len() + right_unmatched_indices.len(),
+ );
+ debug_assert_eq!(
+ right_indices.null_count(),
+ 0,
+ "expected right indices to have no nulls"
+ );
+ builder.append_slice(right_indices.values());
+ builder
+ });
+ debug_assert_eq!(
+ right_unmatched_indices.null_count(),
+ 0,
+ "expected right unmatched indices to have no nulls"
+ );
+
new_right_indices_builder.append_slice(right_unmatched_indices.values());
+ let new_right_indices =
UInt32Array::from(new_right_indices_builder.finish());
+
+ Ok((new_left_indices, new_right_indices))
}
}
}
@@ -1635,6 +1673,91 @@ pub(crate) fn asymmetric_join_output_partitioning(
}
}
+/// Trait for incrementally generating Join output.
+///
+/// This trait is used to limit some join outputs
+/// so that they do not produce single large batches
+pub(crate) trait BatchTransformer: Debug + Clone {
+ /// Sets the next `RecordBatch` to be processed.
+ fn set_batch(&mut self, batch: RecordBatch);
+
+ /// Retrieves the next `RecordBatch` from the transformer.
+ /// Returns `None` if all batches have been produced.
+ /// The boolean flag indicates whether the batch is the last one.
+ fn next(&mut self) -> Option<(RecordBatch, bool)>;
+}
+
+#[derive(Debug, Clone)]
+/// A batch transformer that does nothing.
+pub(crate) struct NoopBatchTransformer {
+ /// RecordBatch to be processed
+ batch: Option<RecordBatch>,
+}
+
+impl NoopBatchTransformer {
+ pub fn new() -> Self {
+ Self { batch: None }
+ }
+}
+
+impl BatchTransformer for NoopBatchTransformer {
+ fn set_batch(&mut self, batch: RecordBatch) {
+ self.batch = Some(batch);
+ }
+
+ fn next(&mut self) -> Option<(RecordBatch, bool)> {
+ self.batch.take().map(|batch| (batch, true))
+ }
+}
+
+#[derive(Debug, Clone)]
+/// Splits large batches into smaller batches with a maximum number of rows.
+pub(crate) struct BatchSplitter {
+ /// RecordBatch to be split
+ batch: Option<RecordBatch>,
+ /// Maximum number of rows in a split batch
+ batch_size: usize,
+ /// Current row index
+ row_index: usize,
+}
+
+impl BatchSplitter {
+ /// Creates a new `BatchSplitter` with the specified batch size.
+ pub(crate) fn new(batch_size: usize) -> Self {
+ Self {
+ batch: None,
+ batch_size,
+ row_index: 0,
+ }
+ }
+}
+
+impl BatchTransformer for BatchSplitter {
+ fn set_batch(&mut self, batch: RecordBatch) {
+ self.batch = Some(batch);
+ self.row_index = 0;
+ }
+
+ fn next(&mut self) -> Option<(RecordBatch, bool)> {
+ let Some(batch) = &self.batch else {
+ return None;
+ };
+
+ let remaining_rows = batch.num_rows() - self.row_index;
+ let rows_to_slice = remaining_rows.min(self.batch_size);
+ let sliced_batch = batch.slice(self.row_index, rows_to_slice);
+ self.row_index += rows_to_slice;
+
+ let mut last = false;
+ if self.row_index >= batch.num_rows() {
+ self.batch = None;
+ last = true;
+ }
+
+ Some((sliced_batch, last))
+ }
+}
+
#[cfg(test)]
mod tests {
use std::pin::Pin;
@@ -1643,11 +1766,13 @@ mod tests {
use arrow::datatypes::{DataType, Fields};
use arrow::error::{ArrowError, Result as ArrowResult};
+ use arrow_array::Int32Array;
use arrow_schema::SortOptions;
-
use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
+ use rstest::rstest;
+
fn check(
left: &[Column],
right: &[Column],
@@ -2554,4 +2679,49 @@ mod tests {
Ok(())
}
+
+ fn create_test_batch(num_rows: usize) -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a",
DataType::Int32, false)]));
+ let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32));
+ RecordBatch::try_new(schema, vec![data]).unwrap()
+ }
+
+ fn assert_split_batches(
+ batches: Vec<(RecordBatch, bool)>,
+ batch_size: usize,
+ num_rows: usize,
+ ) {
+ let mut row_count = 0;
+ for (batch, last) in batches.into_iter() {
+ assert_eq!(batch.num_rows(), (num_rows -
row_count).min(batch_size));
+ let column = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .unwrap();
+ for i in 0..batch.num_rows() {
+ assert_eq!(column.value(i), i as i32 + row_count as i32);
+ }
+ row_count += batch.num_rows();
+ assert_eq!(last, row_count == num_rows);
+ }
+ }
+
+ #[rstest]
+ #[test]
+ fn test_batch_splitter(
+ #[values(1, 3, 11)] batch_size: usize,
+ #[values(1, 6, 50)] num_rows: usize,
+ ) {
+ let mut splitter = BatchSplitter::new(batch_size);
+ splitter.set_batch(create_test_batch(num_rows));
+
+ let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size));
+ while let Some(batch) = splitter.next() {
+ batches.push(batch);
+ }
+
+ assert!(splitter.next().is_none());
+ assert_split_batches(batches, batch_size, num_rows);
+ }
}
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt
b/datafusion/sqllogictest/test_files/information_schema.slt
index 7acdf25b65..57bf029a63 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -173,6 +173,7 @@ datafusion.execution.batch_size 8192
datafusion.execution.coalesce_batches true
datafusion.execution.collect_statistics false
datafusion.execution.enable_recursive_ctes true
+datafusion.execution.enforce_batch_size_in_joins false
datafusion.execution.keep_partition_by_columns false
datafusion.execution.listing_table_ignore_subdirectory true
datafusion.execution.max_buffered_batches_per_output_file 2
@@ -263,6 +264,7 @@ datafusion.execution.batch_size 8192 Default batch size
while creating new batch
datafusion.execution.coalesce_batches true When set to true, record batches
will be examined between each operator and small batches will be coalesced into
larger batches. This is helpful when there are highly selective filters or
joins that could produce tiny output batches. The target batch size is
determined by the configuration setting
datafusion.execution.collect_statistics false Should DataFusion collect
statistics after listing files
datafusion.execution.enable_recursive_ctes true Should DataFusion support
recursive CTEs
+datafusion.execution.enforce_batch_size_in_joins false Should DataFusion
enforce batch size in joins or not. By default, DataFusion will not enforce
batch size in joins. Enforcing batch size in joins can reduce memory usage when
joining large tables with a highly-selective join filter, but is also slightly
slower.
datafusion.execution.keep_partition_by_columns false Should DataFusion keep
the columns used for partition_by in the output RecordBatches
datafusion.execution.listing_table_ignore_subdirectory true Should sub
directories be ignored when scanning directories for data files. Defaults to
true (ignores subdirectories), consistent with Hive. Note that this setting
does not affect reading partitioned tables (e.g.
`/table/year=2021/month=01/data.parquet`).
datafusion.execution.max_buffered_batches_per_output_file 2 This is the
maximum number of RecordBatches buffered for each output file being worked.
Higher values can potentially give faster write performance at the cost of
higher peak memory consumption
diff --git a/docs/source/user-guide/configs.md
b/docs/source/user-guide/configs.md
index f34d148f09..c61a7b6733 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -91,6 +91,7 @@ Environment variables are read during `SessionConfig`
initialisation so they mus
| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold |
0.8 | Aggregation ratio (number of distinct groups /
number of input rows) threshold for skipping partial aggregation. If the value
is greater then partial aggregation will skip aggregation for further input
[...]
| datafusion.execution.skip_partial_aggregation_probe_rows_threshold |
100000 | Number of input rows partial aggregation partition
should process, before aggregation ratio check and trying to switch to skipping
aggregation mode
[...]
| datafusion.execution.use_row_number_estimates_to_optimize_partitioning |
false | Should DataFusion use row number estimates at the
input to decide whether increasing parallelism is beneficial or not. By
default, only exact row numbers (not estimates) are used for this decision.
Setting this flag to `true` will likely produce better plans. if the source of
statistics is accurate. We plan to make this the default in the future.
[...]
+| datafusion.execution.enforce_batch_size_in_joins |
false | Should DataFusion enforce batch size in joins or
not. By default, DataFusion will not enforce batch size in joins. Enforcing
batch size in joins can reduce memory usage when joining large tables with a
highly-selective join filter, but is also slightly slower.
[...]
| datafusion.optimizer.enable_distinct_aggregation_soft_limit |
true | When set to true, the optimizer will push a limit
operation into grouped aggregations which have no aggregate expressions, as a
soft limit, emitting groups once the limit is reached, before all rows in the
group are read.
[...]
| datafusion.optimizer.enable_round_robin_repartition |
true | When set to true, the physical plan optimizer will
try to add round robin repartitioning to increase parallelism to leverage more
CPU cores
[...]
| datafusion.optimizer.enable_topk_aggregation |
true | When set to true, the optimizer will attempt to
perform limit operations during aggregations, if possible
[...]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]