This is an automated email from the ASF dual-hosted git repository.

ozankabak pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bd7c137e0 CrossJoin Refactor (#9830)
4bd7c137e0 is described below

commit 4bd7c137e0e205140e273a7c25824c94b457c660
Author: Berkay Şahin <[email protected]>
AuthorDate: Thu Apr 4 12:30:16 2024 +0300

    CrossJoin Refactor (#9830)
    
    * First iteration
    
    * Wrap the logic inside function
    
    * Send batches in the size of left batches
    
    * Update cross_join.rs
    
    * fuzz tests
    
    * Update cross_join_fuzz.rs
    
    * Update cross_join_fuzz.rs
    
    * Test version 2
    
    * Minor changes
    
    * Minor changes
    
    * Stateful implementation of CJ
    
    * Adding comments
    
    * Update cross_join_fuzz.rs
    
    * Update cross_join.rs
    
    * collect until batch size
    
    * tmp
    
    * revert changes
    
    * Preserve the join strategy, clean the algorithm and states
    
    * Update cross_join.rs
    
    * Review
    
    * Update cross_join.rs
    
    ---------
    
    Co-authored-by: Mustafa Akur <[email protected]>
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 datafusion/physical-plan/src/joins/cross_join.rs | 142 +++++++++++++++--------
 1 file changed, 95 insertions(+), 47 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs
index 19d34f8048..9d1de3715f 100644
--- a/datafusion/physical-plan/src/joins/cross_join.rs
+++ b/datafusion/physical-plan/src/joins/cross_join.rs
@@ -22,14 +22,15 @@ use std::{any::Any, sync::Arc, task::Poll};
 
 use super::utils::{
     adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, 
OnceFut,
+    StatefulStreamResult,
 };
 use crate::coalesce_batches::concat_batches;
 use crate::coalesce_partitions::CoalescePartitionsExec;
 use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
-use crate::ExecutionPlanProperties;
 use crate::{
-    execution_mode_from_children, ColumnStatistics, DisplayAs, 
DisplayFormatType,
-    Distribution, ExecutionMode, ExecutionPlan, PlanProperties, 
RecordBatchStream,
+    execution_mode_from_children, handle_state, ColumnStatistics, DisplayAs,
+    DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan,
+    ExecutionPlanProperties, PlanProperties, RecordBatchStream,
     SendableRecordBatchStream, Statistics,
 };
 
@@ -37,7 +38,7 @@ use arrow::datatypes::{Fields, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use arrow_array::RecordBatchOptions;
 use datafusion_common::stats::Precision;
-use datafusion_common::{JoinType, Result, ScalarValue};
+use datafusion_common::{internal_err, JoinType, Result, ScalarValue};
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
@@ -257,9 +258,10 @@ impl ExecutionPlan for CrossJoinExec {
             schema: self.schema.clone(),
             left_fut,
             right: stream,
-            right_batch: Arc::new(parking_lot::Mutex::new(None)),
             left_index: 0,
             join_metrics,
+            state: CrossJoinStreamState::WaitBuildSide,
+            left_data: RecordBatch::new_empty(self.left().schema()),
         }))
     }
 
@@ -319,16 +321,18 @@ fn stats_cartesian_product(
 struct CrossJoinStream {
     /// Input schema
     schema: Arc<Schema>,
-    /// future for data from left side
+    /// Future for data from left side
     left_fut: OnceFut<JoinLeftData>,
-    /// right
+    /// Right side stream
     right: SendableRecordBatchStream,
     /// Current value on the left
     left_index: usize,
-    /// Current batch being processed from the right side
-    right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
-    /// join execution metrics
+    /// Join execution metrics
     join_metrics: BuildProbeJoinMetrics,
+    /// State of the stream
+    state: CrossJoinStreamState,
+    /// Left data
+    left_data: RecordBatch,
 }
 
 impl RecordBatchStream for CrossJoinStream {
@@ -337,6 +341,25 @@ impl RecordBatchStream for CrossJoinStream {
     }
 }
 
+/// Represents states of CrossJoinStream
+enum CrossJoinStreamState {
+    WaitBuildSide,
+    FetchProbeBatch,
+    /// Holds the currently processed right side batch
+    BuildBatches(RecordBatch),
+}
+
+impl CrossJoinStreamState {
+    /// Tries to extract RecordBatch from CrossJoinStreamState enum.
+    /// Returns an error if state is not BuildBatches state.
+    fn try_as_record_batch(&mut self) -> Result<&RecordBatch> {
+        match self {
+            CrossJoinStreamState::BuildBatches(rb) => Ok(rb),
+            _ => internal_err!("Expected RecordBatch in BuildBatches state"),
+        }
+    }
+}
+
 fn build_batch(
     left_index: usize,
     batch: &RecordBatch,
@@ -384,58 +407,83 @@ impl CrossJoinStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                CrossJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                CrossJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                CrossJoinStreamState::BuildBatches(_) => {
+                    handle_state!(self.build_batches())
+                }
+            };
+        }
+    }
+
+    /// Collects build (left) side of the join into the state. In case of an 
empty build batch,
+    /// the execution terminates. Otherwise, the state is updated to fetch 
probe (right) batch.
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         let (left_data, _) = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
+            Err(e) => return Poll::Ready(Err(e)),
         };
         build_timer.done();
 
-        if left_data.num_rows() == 0 {
-            return Poll::Ready(None);
-        }
+        let result = if left_data.num_rows() == 0 {
+            StatefulStreamResult::Ready(None)
+        } else {
+            self.left_data = left_data.clone();
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            StatefulStreamResult::Continue
+        };
+        Poll::Ready(Ok(result))
+    }
+
+    /// Fetches the probe (right) batch, updates the metrics, and saves the batch in the state.
+    /// Then, the state is updated to build result batches.
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        self.left_index = 0;
+        let right_data = match ready!(self.right.poll_next_unpin(cx)) {
+            Some(Ok(right_data)) => right_data,
+            Some(Err(e)) => return Poll::Ready(Err(e)),
+            None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
+        };
+        self.join_metrics.input_batches.add(1);
+        self.join_metrics.input_rows.add(right_data.num_rows());
+
+        self.state = CrossJoinStreamState::BuildBatches(right_data);
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
 
-        if self.left_index > 0 && self.left_index < left_data.num_rows() {
+    /// Joins the indexed row of left data with the current probe batch.
+    /// If all the results are produced, the state is set to fetch new probe 
batch.
+    fn build_batches(&mut self) -> 
Result<StatefulStreamResult<Option<RecordBatch>>> {
+        let right_batch = self.state.try_as_record_batch()?;
+        if self.left_index < self.left_data.num_rows() {
             let join_timer = self.join_metrics.join_time.timer();
-            let right_batch = {
-                let right_batch = self.right_batch.lock();
-                right_batch.clone().unwrap()
-            };
             let result =
-                build_batch(self.left_index, &right_batch, left_data, 
&self.schema);
-            self.join_metrics.input_rows.add(right_batch.num_rows());
+                build_batch(self.left_index, right_batch, &self.left_data, 
&self.schema);
+            join_timer.done();
+
             if let Ok(ref batch) = result {
-                join_timer.done();
                 self.join_metrics.output_batches.add(1);
                 self.join_metrics.output_rows.add(batch.num_rows());
             }
             self.left_index += 1;
-            return Poll::Ready(Some(result));
+            result.map(|r| StatefulStreamResult::Ready(Some(r)))
+        } else {
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Ok(StatefulStreamResult::Continue)
         }
-        self.left_index = 0;
-        self.right
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(batch)) => {
-                    let join_timer = self.join_metrics.join_time.timer();
-                    let result =
-                        build_batch(self.left_index, &batch, left_data, 
&self.schema);
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(batch.num_rows());
-                    if let Ok(ref batch) = result {
-                        join_timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
-                    }
-                    self.left_index = 1;
-
-                    let mut right_batch = self.right_batch.lock();
-                    *right_batch = Some(batch);
-
-                    Some(result)
-                }
-                other => other,
-            })
     }
 }
 

Reply via email to