korowa commented on code in PR #9830:
URL: https://github.com/apache/arrow-datafusion/pull/9830#discussion_r1545459833


##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
 }
 
 impl CrossJoinStream {
-    /// Separate implementation function that unpins the [`CrossJoinStream`] so
-    /// that partial borrows work correctly
+    /// Separate implementation function that unpins the [`CrossJoinStream`]
+    /// so that partial borrows work correctly
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                CrossJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                CrossJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                CrossJoinStreamState::GenerateResult => {
+                    handle_state!(self.generate_result())
+                }
+                CrossJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    /// Waits until the left data computation completes. After it is ready,
+    /// copies it into the state and continues with fetching probe side. If we
+    /// cannot receive any row from left, the operation ends without polling 
right.
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         let (left_data, _) = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
+            Err(e) => return Poll::Ready(Err(e)),
         };
         build_timer.done();
 
-        if left_data.num_rows() == 0 {
-            return Poll::Ready(None);
+        // If the left batch is empty, we can return `Poll::Ready(None)` 
immediately.
+        if left_data.iter().all(|batch| batch.num_rows() == 0) {
+            self.state = CrossJoinStreamState::Completed;
+            Poll::Ready(Ok(StatefulStreamResult::Continue))
+        } else {
+            self.left_data = left_data

Review Comment:
   Will this code clone all contents of collected LeftData for each 
CrossJoinExec stream? If so -- could we avoid having multiple copies of 
left-side batches and use only the original data from left_fut, accessed via 
Arc?



##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -311,20 +311,27 @@ fn stats_cartesian_product(
     }
 }
 
-/// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
+/// A stream that issues [RecordBatch]es as they arrive from the right of the 
join.
+/// Right column orders are preserved.
 struct CrossJoinStream {
     /// Input schema
     schema: Arc<Schema>,
-    /// future for data from left side
+    /// Future for data from left side
     left_fut: OnceFut<JoinLeftData>,
-    /// right
+    /// Right stream
     right: SendableRecordBatchStream,
-    /// Current value on the left
-    left_index: usize,
-    /// Current batch being processed from the right side
-    right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
-    /// join execution metrics
+    /// Join execution metrics
     join_metrics: BuildProbeJoinMetrics,
+    /// State information
+    state: CrossJoinStreamState,
+    /// Left data
+    left_data: Vec<RecordBatch>,
+    /// Current right batch
+    right_batch: RecordBatch,
+    /// Indexes the next processed build side batch
+    left_batch_index: usize,

Review Comment:
   minor: Could the batch & row indices be part of the `GenerateResult` state (as 
they seem to be valid only during result generation)?



##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -311,20 +311,27 @@ fn stats_cartesian_product(
     }
 }
 
-/// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
+/// A stream that issues [RecordBatch]es as they arrive from the right of the 
join.
+/// Right column orders are preserved.
 struct CrossJoinStream {
     /// Input schema
     schema: Arc<Schema>,
-    /// future for data from left side
+    /// Future for data from left side
     left_fut: OnceFut<JoinLeftData>,
-    /// right
+    /// Right stream
     right: SendableRecordBatchStream,
-    /// Current value on the left
-    left_index: usize,
-    /// Current batch being processed from the right side
-    right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
-    /// join execution metrics
+    /// Join execution metrics
     join_metrics: BuildProbeJoinMetrics,
+    /// State information
+    state: CrossJoinStreamState,
+    /// Left data
+    left_data: Vec<RecordBatch>,

Review Comment:
   minor: we store both fut & data -- wouldn't it make sense to store 
`stream.left_data` as an enum (with initial/ready variants, wrapping the future 
and the data respectively)?



##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
 }
 
 impl CrossJoinStream {
-    /// Separate implementation function that unpins the [`CrossJoinStream`] so
-    /// that partial borrows work correctly
+    /// Separate implementation function that unpins the [`CrossJoinStream`]
+    /// so that partial borrows work correctly
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                CrossJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                CrossJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                CrossJoinStreamState::GenerateResult => {
+                    handle_state!(self.generate_result())
+                }
+                CrossJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    /// Waits until the left data computation completes. After it is ready,
+    /// copies it into the state and continues with fetching probe side. If we
+    /// cannot receive any row from left, the operation ends without polling 
right.
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         let (left_data, _) = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
+            Err(e) => return Poll::Ready(Err(e)),
         };
         build_timer.done();
 
-        if left_data.num_rows() == 0 {
-            return Poll::Ready(None);
+        // If the left batch is empty, we can return `Poll::Ready(None)` 
immediately.
+        if left_data.iter().all(|batch| batch.num_rows() == 0) {
+            self.state = CrossJoinStreamState::Completed;
+            Poll::Ready(Ok(StatefulStreamResult::Continue))
+        } else {
+            self.left_data = left_data
+                .clone()
+                .into_iter()
+                .filter(|batch| batch.num_rows() > 0)
+                .collect();
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Poll::Ready(Ok(StatefulStreamResult::Continue))
         }
+    }
 
-        if self.left_index > 0 && self.left_index < left_data.num_rows() {
-            let join_timer = self.join_metrics.join_time.timer();
-            let right_batch = {
-                let right_batch = self.right_batch.lock();
-                right_batch.clone().unwrap()
-            };
-            let result =
-                build_batch(self.left_index, &right_batch, left_data, 
&self.schema);
-            self.join_metrics.input_rows.add(right_batch.num_rows());
-            if let Ok(ref batch) = result {
-                join_timer.done();
-                self.join_metrics.output_batches.add(1);
-                self.join_metrics.output_rows.add(batch.num_rows());
+    /// Polls the right (probe) side until a non-empty batch is ready.
+    /// Then, the next state is set as the result generation step after
+    /// the polled batch is stored in the state and indices are reset.
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.right.poll_next_unpin(cx)) {
+            None => {
+                self.state = CrossJoinStreamState::Completed;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
             }
-            self.left_index += 1;
-            return Poll::Ready(Some(result));
-        }
-        self.left_index = 0;
-        self.right
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(batch)) => {
-                    let join_timer = self.join_metrics.join_time.timer();
-                    let result =
-                        build_batch(self.left_index, &batch, left_data, 
&self.schema);
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(batch.num_rows());
-                    if let Ok(ref batch) = result {
-                        join_timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
-                    }
-                    self.left_index = 1;
-
-                    let mut right_batch = self.right_batch.lock();
-                    *right_batch = Some(batch);
-
-                    Some(result)
+            Some(Ok(right_batch)) => {
+                // Update the metrics.
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(right_batch.num_rows());
+                if right_batch.num_rows() == 0 {
+                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                 }
-                other => other,
-            })
+                // New batch arrives, reset the indices.
+                self.left_batch_index = 0;
+                self.right_row_index = 0;
+                // Store the new batch into the state.
+                self.right_batch = right_batch;
+                self.state = CrossJoinStreamState::GenerateResult;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
+            }
+            Some(Err(err)) => Poll::Ready(Err(err)),
+        }
+    }
+
+    /// If there is non-paired rows in the probe batch, the function process 
them.
+    /// If not, it directs the state to fetching probe side.
+    fn generate_result(&mut self) -> 
Result<StatefulStreamResult<Option<RecordBatch>>> {
+        if self.right_row_index < self.right_batch.num_rows() {
+            // Right batch has some unpaired rows, continue with the next row.
+            let result_batch = self.build_batch()?;
+            Ok(StatefulStreamResult::Ready(Some(result_batch)))
+        } else {
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Ok(StatefulStreamResult::Continue)
+        }
+    }
+
+    /// This function constructs a new `RecordBatch` by joining the left and 
right batches
+    /// based on the current indices. It also updates the metrics.
+    ///
+    /// # Arguments
+    /// * `self.left_data` - Array of `RecordBatch`es from the left side to be 
joined.
+    /// * `self.right_batch` - The current `RecordBatch` from the right side 
to be joined.
+    /// * `self.left_batch_index` - Index of the current left batch being 
processed.
+    /// * `self.right_row_index` - Index of the current row in the right batch 
to be joined.
+    /// * `join_metrics` - Metrics container to track performance of the join 
operation.
+
+    fn build_batch(&mut self) -> Result<RecordBatch> {

Review Comment:
   minor: Got the idea, but it may be a bit confusing that the docstring doesn't 
match the signature.
   
   BTW wouldn't it be cleaner to leave it as a standalone function (as it was 
before) and leave the metrics & offset updates to `generate_result`?



##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
 }
 
 impl CrossJoinStream {
-    /// Separate implementation function that unpins the [`CrossJoinStream`] so
-    /// that partial borrows work correctly
+    /// Separate implementation function that unpins the [`CrossJoinStream`]
+    /// so that partial borrows work correctly
     fn poll_next_impl(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Result<RecordBatch>>> {
+        loop {
+            return match self.state {
+                CrossJoinStreamState::WaitBuildSide => {
+                    handle_state!(ready!(self.collect_build_side(cx)))
+                }
+                CrossJoinStreamState::FetchProbeBatch => {
+                    handle_state!(ready!(self.fetch_probe_batch(cx)))
+                }
+                CrossJoinStreamState::GenerateResult => {
+                    handle_state!(self.generate_result())
+                }
+                CrossJoinStreamState::Completed => Poll::Ready(None),
+            };
+        }
+    }
+
+    /// Waits until the left data computation completes. After it is ready,
+    /// copies it into the state and continues with fetching probe side. If we
+    /// cannot receive any row from left, the operation ends without polling 
right.
+    fn collect_build_side(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
         let build_timer = self.join_metrics.build_time.timer();
         let (left_data, _) = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => return Poll::Ready(Some(Err(e))),
+            Err(e) => return Poll::Ready(Err(e)),
         };
         build_timer.done();
 
-        if left_data.num_rows() == 0 {
-            return Poll::Ready(None);
+        // If the left batch is empty, we can return `Poll::Ready(None)` 
immediately.
+        if left_data.iter().all(|batch| batch.num_rows() == 0) {
+            self.state = CrossJoinStreamState::Completed;
+            Poll::Ready(Ok(StatefulStreamResult::Continue))
+        } else {
+            self.left_data = left_data
+                .clone()
+                .into_iter()
+                .filter(|batch| batch.num_rows() > 0)
+                .collect();
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Poll::Ready(Ok(StatefulStreamResult::Continue))
         }
+    }
 
-        if self.left_index > 0 && self.left_index < left_data.num_rows() {
-            let join_timer = self.join_metrics.join_time.timer();
-            let right_batch = {
-                let right_batch = self.right_batch.lock();
-                right_batch.clone().unwrap()
-            };
-            let result =
-                build_batch(self.left_index, &right_batch, left_data, 
&self.schema);
-            self.join_metrics.input_rows.add(right_batch.num_rows());
-            if let Ok(ref batch) = result {
-                join_timer.done();
-                self.join_metrics.output_batches.add(1);
-                self.join_metrics.output_rows.add(batch.num_rows());
+    /// Polls the right (probe) side until a non-empty batch is ready.
+    /// Then, the next state is set as the result generation step after
+    /// the polled batch is stored in the state and indices are reset.
+    fn fetch_probe_batch(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        match ready!(self.right.poll_next_unpin(cx)) {
+            None => {
+                self.state = CrossJoinStreamState::Completed;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
             }
-            self.left_index += 1;
-            return Poll::Ready(Some(result));
-        }
-        self.left_index = 0;
-        self.right
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(batch)) => {
-                    let join_timer = self.join_metrics.join_time.timer();
-                    let result =
-                        build_batch(self.left_index, &batch, left_data, 
&self.schema);
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(batch.num_rows());
-                    if let Ok(ref batch) = result {
-                        join_timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
-                    }
-                    self.left_index = 1;
-
-                    let mut right_batch = self.right_batch.lock();
-                    *right_batch = Some(batch);
-
-                    Some(result)
+            Some(Ok(right_batch)) => {
+                // Update the metrics.
+                self.join_metrics.input_batches.add(1);
+                self.join_metrics.input_rows.add(right_batch.num_rows());
+                if right_batch.num_rows() == 0 {
+                    return Poll::Ready(Ok(StatefulStreamResult::Continue));
                 }
-                other => other,
-            })
+                // New batch arrives, reset the indices.
+                self.left_batch_index = 0;
+                self.right_row_index = 0;
+                // Store the new batch into the state.
+                self.right_batch = right_batch;
+                self.state = CrossJoinStreamState::GenerateResult;
+                Poll::Ready(Ok(StatefulStreamResult::Continue))
+            }
+            Some(Err(err)) => Poll::Ready(Err(err)),
+        }
+    }
+
+    /// If there is non-paired rows in the probe batch, the function process 
them.
+    /// If not, it directs the state to fetching probe side.
+    fn generate_result(&mut self) -> 
Result<StatefulStreamResult<Option<RecordBatch>>> {
+        if self.right_row_index < self.right_batch.num_rows() {
+            // Right batch has some unpaired rows, continue with the next row.
+            let result_batch = self.build_batch()?;
+            Ok(StatefulStreamResult::Ready(Some(result_batch)))
+        } else {
+            self.state = CrossJoinStreamState::FetchProbeBatch;
+            Ok(StatefulStreamResult::Continue)
+        }
+    }
+
+    /// This function constructs a new `RecordBatch` by joining the left and 
right batches
+    /// based on the current indices. It also updates the metrics.
+    ///
+    /// # Arguments
+    /// * `self.left_data` - Array of `RecordBatch`es from the left side to be 
joined.
+    /// * `self.right_batch` - The current `RecordBatch` from the right side 
to be joined.
+    /// * `self.left_batch_index` - Index of the current left batch being 
processed.
+    /// * `self.right_row_index` - Index of the current row in the right batch 
to be joined.
+    /// * `join_metrics` - Metrics container to track performance of the join 
operation.
+
+    fn build_batch(&mut self) -> Result<RecordBatch> {
+        let join_timer = self.join_metrics.join_time.timer();
+        // Create copies of the indexed right-side row for joining.
+        let right_copies: Vec<Arc<dyn Array>> = get_arrayref_at_indices(

Review Comment:
   Just a question -- is this version of `build_batch` more performant due to 
avoiding the conversion through `ScalarValue`?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to