berkaysynnada commented on code in PR #9830:
URL: https://github.com/apache/arrow-datafusion/pull/9830#discussion_r1546084513
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
}
impl CrossJoinStream {
- /// Separate implementation function that unpins the [`CrossJoinStream`] so
- /// that partial borrows work correctly
+ /// Separate implementation function that unpins the [`CrossJoinStream`]
+ /// so that partial borrows work correctly
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::GenerateResult => {
+ handle_state!(self.generate_result())
+ }
+ CrossJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ /// Waits until the left data computation completes. After it is ready,
+ /// copies it into the state and continues with fetching probe side. If we
+ /// cannot receive any row from left, the operation ends without polling right.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
+ // If the left batch is empty, we can return `Poll::Ready(None)` immediately.
+ if left_data.iter().all(|batch| batch.num_rows() == 0) {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ } else {
+ self.left_data = left_data
+ .clone()
+ .into_iter()
+ .filter(|batch| batch.num_rows() > 0)
+ .collect();
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
+ }
- if self.left_index > 0 && self.left_index < left_data.num_rows() {
- let join_timer = self.join_metrics.join_time.timer();
- let right_batch = {
- let right_batch = self.right_batch.lock();
- right_batch.clone().unwrap()
- };
- let result =
- build_batch(self.left_index, &right_batch, left_data, &self.schema);
- self.join_metrics.input_rows.add(right_batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
+ /// Polls the right (probe) side until a non-empty batch is ready.
+ /// Then, the next state is set as the result generation step after
+ /// the polled batch is stored in the state and indices are reset.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ match ready!(self.right.poll_next_unpin(cx)) {
+ None => {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
- self.left_index += 1;
- return Poll::Ready(Some(result));
- }
- self.left_index = 0;
- self.right
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(batch)) => {
- let join_timer = self.join_metrics.join_time.timer();
- let result =
- build_batch(self.left_index, &batch, left_data, &self.schema);
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
- self.left_index = 1;
-
- let mut right_batch = self.right_batch.lock();
- *right_batch = Some(batch);
-
- Some(result)
+ Some(Ok(right_batch)) => {
+ // Update the metrics.
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_batch.num_rows());
+ if right_batch.num_rows() == 0 {
+ return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
- other => other,
- })
+ // New batch arrives, reset the indices.
+ self.left_batch_index = 0;
+ self.right_row_index = 0;
+ // Store the new batch into the state.
+ self.right_batch = right_batch;
+ self.state = CrossJoinStreamState::GenerateResult;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+ Some(Err(err)) => Poll::Ready(Err(err)),
+ }
+ }
+
+ /// If there is non-paired rows in the probe batch, the function process them.
+ /// If not, it directs the state to fetching probe side.
+ fn generate_result(&mut self) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
+ if self.right_row_index < self.right_batch.num_rows() {
+ // Right batch has some unpaired rows, continue with the next row.
+ let result_batch = self.build_batch()?;
+ Ok(StatefulStreamResult::Ready(Some(result_batch)))
+ } else {
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Ok(StatefulStreamResult::Continue)
+ }
+ }
+
+ /// This function constructs a new `RecordBatch` by joining the left and right batches
+ /// based on the current indices. It also updates the metrics.
+ ///
+ /// # Arguments
+ /// * `self.left_data` - Array of `RecordBatch`es from the left side to be joined.
+ /// * `self.right_batch` - The current `RecordBatch` from the right side to be joined.
+ /// * `self.left_batch_index` - Index of the current left batch being processed.
+ /// * `self.right_row_index` - Index of the current row in the right batch to be joined.
+ /// * `join_metrics` - Metrics container to track performance of the join operation.
+
+ fn build_batch(&mut self) -> Result<RecordBatch> {
+ let join_timer = self.join_metrics.join_time.timer();
+ // Create copies of the indexed right-side row for joining.
+ let right_copies: Vec<Arc<dyn Array>> = get_arrayref_at_indices(
Review Comment:
I believe so, since the other joins use the same approach, but unfortunately
I have no concrete evidence.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]