This is an automated email from the ASF dual-hosted git repository.
ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4bd7c137e0 CrossJoin Refactor (#9830)
4bd7c137e0 is described below
commit 4bd7c137e0e205140e273a7c25824c94b457c660
Author: Berkay Şahin <[email protected]>
AuthorDate: Thu Apr 4 12:30:16 2024 +0300
CrossJoin Refactor (#9830)
* First iteration
* Wrap the logic inside function
* Send batches in the size of left batches
* Update cross_join.rs
* fuzz tests
* Update cross_join_fuzz.rs
* Update cross_join_fuzz.rs
* Test version 2
* Minor changes
* Minor changes
* Stateful implementation of CJ
* Adding comments
* Update cross_join_fuzz.rs
* Update cross_join.rs
* collect until batch size
* tmp
* revert changes
* Preserve the join strategy, clean the algorithm and states
* Update cross_join.rs
* Review
* Update cross_join.rs
---------
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion/physical-plan/src/joins/cross_join.rs | 142 +++++++++++++++--------
1 file changed, 95 insertions(+), 47 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/cross_join.rs
b/datafusion/physical-plan/src/joins/cross_join.rs
index 19d34f8048..9d1de3715f 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -22,14 +22,15 @@ use std::{any::Any, sync::Arc, task::Poll};
use super::utils::{
adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync,
OnceFut,
+ StatefulStreamResult,
};
use crate::coalesce_batches::concat_batches;
use crate::coalesce_partitions::CoalescePartitionsExec;
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::ExecutionPlanProperties;
use crate::{
- execution_mode_from_children, ColumnStatistics, DisplayAs,
DisplayFormatType,
- Distribution, ExecutionMode, ExecutionPlan, PlanProperties,
RecordBatchStream,
+ execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
+ DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
+ ExecutionPlanProperties, PlanProperties, RecordBatchStream,
SendableRecordBatchStream, Statistics,
};
@@ -37,7 +38,7 @@ use arrow::datatypes::{Fields, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_array::RecordBatchOptions;
use datafusion_common::stats::Precision;
-use datafusion_common::{JoinType, Result, ScalarValue};
+use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
@@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
schema: self.schema.clone(),
left_fut,
right: stream,
- right_batch: Arc::new(parking_lot::Mutex::new(None)),
left_index: 0,
join_metrics,
+ state: CrossJoinStreamState::WaitBuildSide,
+ left_data: RecordBatch::new_empty(self.left().schema()),
}))
}
@@ -319,16 +321,18 @@ fn stats_cartesian_product(
struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
- /// future for data from left side
+ /// Future for data from left side
left_fut: OnceFut<JoinLeftData>,
- /// right
+ /// Right side stream
right: SendableRecordBatchStream,
/// Current value on the left
left_index: usize,
- /// Current batch being processed from the right side
- right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
- /// join execution metrics
+ /// Join execution metrics
join_metrics: BuildProbeJoinMetrics,
+ /// State of the stream
+ state: CrossJoinStreamState,
+ /// Left data
+ left_data: RecordBatch,
}
impl RecordBatchStream for CrossJoinStream {
@@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
}
}
+/// Represents states of CrossJoinStream
+enum CrossJoinStreamState {
+ WaitBuildSide,
+ FetchProbeBatch,
+ /// Holds the currently processed right side batch
+ BuildBatches(RecordBatch),
+}
+
+impl CrossJoinStreamState {
+ /// Tries to extract RecordBatch from CrossJoinStreamState enum.
+ /// Returns an error if state is not BuildBatches state.
+ fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
+ match self {
+ CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
+ _ => internal_err!("Expected RecordBatch in BuildBatches state"),
+ }
+ }
+}
+
fn build_batch(
left_index: usize,
batch: &RecordBatch,
@@ -384,58 +407,83 @@ impl CrossJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::BuildBatches(_) => {
+ handle_state!(self.build_batches())
+ }
+ };
+ }
+ }
+
+ /// Collects build (left) side of the join into the state. In case of an
empty build batch,
+ /// the execution terminates. Otherwise, the state is updated to fetch
probe (right) batch.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
- }
+ let result = if left_data.num_rows() == 0 {
+ StatefulStreamResult::Ready(None)
+ } else {
+ self.left_data = left_data.clone();
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ StatefulStreamResult::Continue
+ };
+ Poll::Ready(Ok(result))
+ }
+
+ /// Fetches the probe (right) batch, updates the metrics, and saves the
batch in the state.
+ /// Then, the state is updated to build result batches.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ self.left_index = 0;
+ let right_data = match ready!(self.right.poll_next_unpin(cx)) {
+ Some(Ok(right_data)) => right_data,
+ Some(Err(e)) => return Poll::Ready(Err(e)),
+ None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
+ };
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_data.num_rows());
+
+ self.state = CrossJoinStreamState::BuildBatches(right_data);
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
- if self.left_index > 0 && self.left_index < left_data.num_rows() {
+ /// Joins the indexed row of left data with the current probe batch.
+ /// If all the results are produced, the state is set to fetch new probe
batch.
+ fn build_batches(&mut self) ->
Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let right_batch = self.state.try_as_record_batch()?;
+ if self.left_index < self.left_data.num_rows() {
let join_timer = self.join_metrics.join_time.timer();
- let right_batch = {
- let right_batch = self.right_batch.lock();
- right_batch.clone().unwrap()
- };
let result =
- build_batch(self.left_index, &right_batch, left_data,
&self.schema);
- self.join_metrics.input_rows.add(right_batch.num_rows());
+ build_batch(self.left_index, right_batch, &self.left_data,
&self.schema);
+ join_timer.done();
+
if let Ok(ref batch) = result {
- join_timer.done();
self.join_metrics.output_batches.add(1);
self.join_metrics.output_rows.add(batch.num_rows());
}
self.left_index += 1;
- return Poll::Ready(Some(result));
+ result.map(|r| StatefulStreamResult::Ready(Some(r)))
+ } else {
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Ok(StatefulStreamResult::Continue)
}
- self.left_index = 0;
- self.right
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(batch)) => {
- let join_timer = self.join_metrics.join_time.timer();
- let result =
- build_batch(self.left_index, &batch, left_data,
&self.schema);
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
- self.left_index = 1;
-
- let mut right_batch = self.right_batch.lock();
- *right_batch = Some(batch);
-
- Some(result)
- }
- other => other,
- })
}
}