korowa commented on code in PR #9830:
URL: https://github.com/apache/arrow-datafusion/pull/9830#discussion_r1545459833
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
}
impl CrossJoinStream {
- /// Separate implementation function that unpins the [`CrossJoinStream`] so
- /// that partial borrows work correctly
+ /// Separate implementation function that unpins the [`CrossJoinStream`]
+ /// so that partial borrows work correctly
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::GenerateResult => {
+ handle_state!(self.generate_result())
+ }
+ CrossJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ /// Waits until the left data computation completes. After it is ready,
+ /// copies it into the state and continues with fetching probe side. If we
+ /// cannot receive any row from left, the operation ends without polling
right.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
+ // If the left batch is empty, we can return `Poll::Ready(None)`
immediately.
+ if left_data.iter().all(|batch| batch.num_rows() == 0) {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ } else {
+ self.left_data = left_data
Review Comment:
Will this code clone all contents of collected LeftData for each
CrossJoinExec stream? If so -- could we avoid having multiple copies of
left-side batches and use only the original data from left_fut, accessed via
Arc?
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -311,20 +311,27 @@ fn stats_cartesian_product(
}
}
-/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
+/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
+/// Right column orders are preserved.
struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
- /// future for data from left side
+ /// Future for data from left side
left_fut: OnceFut<JoinLeftData>,
- /// right
+ /// Right stream
right: SendableRecordBatchStream,
- /// Current value on the left
- left_index: usize,
- /// Current batch being processed from the right side
- right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
- /// join execution metrics
+ /// Join execution metrics
join_metrics: BuildProbeJoinMetrics,
+ /// State information
+ state: CrossJoinStreamState,
+ /// Left data
+ left_data: Vec<RecordBatch>,
+ /// Current right batch
+ right_batch: RecordBatch,
+ /// Indexes the next processed build side batch
+ left_batch_index: usize,
Review Comment:
minor: Could the batch & row indices be part of the GenerateResult state (as
they seem to be valid only during result generation)?
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -311,20 +311,27 @@ fn stats_cartesian_product(
}
}
-/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
+/// A stream that issues [RecordBatch]es as they arrive from the right of the
join.
+/// Right column orders are preserved.
struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
- /// future for data from left side
+ /// Future for data from left side
left_fut: OnceFut<JoinLeftData>,
- /// right
+ /// Right stream
right: SendableRecordBatchStream,
- /// Current value on the left
- left_index: usize,
- /// Current batch being processed from the right side
- right_batch: Arc<parking_lot::Mutex<Option<RecordBatch>>>,
- /// join execution metrics
+ /// Join execution metrics
join_metrics: BuildProbeJoinMetrics,
+ /// State information
+ state: CrossJoinStreamState,
+ /// Left data
+ left_data: Vec<RecordBatch>,
Review Comment:
minor: we store both fut & data -- wouldn't it make sense to store
`stream.left_data` as an enum (with initial/ready variants, wrapping the future
and the data respectively)?
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
}
impl CrossJoinStream {
- /// Separate implementation function that unpins the [`CrossJoinStream`] so
- /// that partial borrows work correctly
+ /// Separate implementation function that unpins the [`CrossJoinStream`]
+ /// so that partial borrows work correctly
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::GenerateResult => {
+ handle_state!(self.generate_result())
+ }
+ CrossJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ /// Waits until the left data computation completes. After it is ready,
+ /// copies it into the state and continues with fetching probe side. If we
+ /// cannot receive any row from left, the operation ends without polling
right.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
+ // If the left batch is empty, we can return `Poll::Ready(None)`
immediately.
+ if left_data.iter().all(|batch| batch.num_rows() == 0) {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ } else {
+ self.left_data = left_data
+ .clone()
+ .into_iter()
+ .filter(|batch| batch.num_rows() > 0)
+ .collect();
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
+ }
- if self.left_index > 0 && self.left_index < left_data.num_rows() {
- let join_timer = self.join_metrics.join_time.timer();
- let right_batch = {
- let right_batch = self.right_batch.lock();
- right_batch.clone().unwrap()
- };
- let result =
- build_batch(self.left_index, &right_batch, left_data,
&self.schema);
- self.join_metrics.input_rows.add(right_batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
+ /// Polls the right (probe) side until a non-empty batch is ready.
+ /// Then, the next state is set as the result generation step after
+ /// the polled batch is stored in the state and indices are reset.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ match ready!(self.right.poll_next_unpin(cx)) {
+ None => {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
- self.left_index += 1;
- return Poll::Ready(Some(result));
- }
- self.left_index = 0;
- self.right
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(batch)) => {
- let join_timer = self.join_metrics.join_time.timer();
- let result =
- build_batch(self.left_index, &batch, left_data,
&self.schema);
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
- self.left_index = 1;
-
- let mut right_batch = self.right_batch.lock();
- *right_batch = Some(batch);
-
- Some(result)
+ Some(Ok(right_batch)) => {
+ // Update the metrics.
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_batch.num_rows());
+ if right_batch.num_rows() == 0 {
+ return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
- other => other,
- })
+ // New batch arrives, reset the indices.
+ self.left_batch_index = 0;
+ self.right_row_index = 0;
+ // Store the new batch into the state.
+ self.right_batch = right_batch;
+ self.state = CrossJoinStreamState::GenerateResult;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+ Some(Err(err)) => Poll::Ready(Err(err)),
+ }
+ }
+
+ /// If there is non-paired rows in the probe batch, the function process
them.
+ /// If not, it directs the state to fetching probe side.
+ fn generate_result(&mut self) ->
Result<StatefulStreamResult<Option<RecordBatch>>> {
+ if self.right_row_index < self.right_batch.num_rows() {
+ // Right batch has some unpaired rows, continue with the next row.
+ let result_batch = self.build_batch()?;
+ Ok(StatefulStreamResult::Ready(Some(result_batch)))
+ } else {
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Ok(StatefulStreamResult::Continue)
+ }
+ }
+
+ /// This function constructs a new `RecordBatch` by joining the left and
right batches
+ /// based on the current indices. It also updates the metrics.
+ ///
+ /// # Arguments
+ /// * `self.left_data` - Array of `RecordBatch`es from the left side to be
joined.
+ /// * `self.right_batch` - The current `RecordBatch` from the right side
to be joined.
+ /// * `self.left_batch_index` - Index of the current left batch being
processed.
+ /// * `self.right_row_index` - Index of the current row in the right batch
to be joined.
+ /// * `join_metrics` - Metrics container to track performance of the join
operation.
+
+ fn build_batch(&mut self) -> Result<RecordBatch> {
Review Comment:
minor: Got the idea, but it may be a bit confusing that the docstring doesn't
match the signature.
BTW wouldn't it be cleaner to leave it as a standalone function (as it was
before) and leave the metrics & offsets updates to `generate_result`?
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -374,64 +376,147 @@ impl Stream for CrossJoinStream {
}
impl CrossJoinStream {
- /// Separate implementation function that unpins the [`CrossJoinStream`] so
- /// that partial borrows work correctly
+ /// Separate implementation function that unpins the [`CrossJoinStream`]
+ /// so that partial borrows work correctly
fn poll_next_impl(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::GenerateResult => {
+ handle_state!(self.generate_result())
+ }
+ CrossJoinStreamState::Completed => Poll::Ready(None),
+ };
+ }
+ }
+
+ /// Waits until the left data computation completes. After it is ready,
+ /// copies it into the state and continues with fetching probe side. If we
+ /// cannot receive any row from left, the operation ends without polling
right.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
+ // If the left batch is empty, we can return `Poll::Ready(None)`
immediately.
+ if left_data.iter().all(|batch| batch.num_rows() == 0) {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ } else {
+ self.left_data = left_data
+ .clone()
+ .into_iter()
+ .filter(|batch| batch.num_rows() > 0)
+ .collect();
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
+ }
- if self.left_index > 0 && self.left_index < left_data.num_rows() {
- let join_timer = self.join_metrics.join_time.timer();
- let right_batch = {
- let right_batch = self.right_batch.lock();
- right_batch.clone().unwrap()
- };
- let result =
- build_batch(self.left_index, &right_batch, left_data,
&self.schema);
- self.join_metrics.input_rows.add(right_batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
+ /// Polls the right (probe) side until a non-empty batch is ready.
+ /// Then, the next state is set as the result generation step after
+ /// the polled batch is stored in the state and indices are reset.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ match ready!(self.right.poll_next_unpin(cx)) {
+ None => {
+ self.state = CrossJoinStreamState::Completed;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
}
- self.left_index += 1;
- return Poll::Ready(Some(result));
- }
- self.left_index = 0;
- self.right
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(batch)) => {
- let join_timer = self.join_metrics.join_time.timer();
- let result =
- build_batch(self.left_index, &batch, left_data,
&self.schema);
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(batch.num_rows());
- if let Ok(ref batch) = result {
- join_timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
- }
- self.left_index = 1;
-
- let mut right_batch = self.right_batch.lock();
- *right_batch = Some(batch);
-
- Some(result)
+ Some(Ok(right_batch)) => {
+ // Update the metrics.
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_batch.num_rows());
+ if right_batch.num_rows() == 0 {
+ return Poll::Ready(Ok(StatefulStreamResult::Continue));
}
- other => other,
- })
+ // New batch arrives, reset the indices.
+ self.left_batch_index = 0;
+ self.right_row_index = 0;
+ // Store the new batch into the state.
+ self.right_batch = right_batch;
+ self.state = CrossJoinStreamState::GenerateResult;
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
+ Some(Err(err)) => Poll::Ready(Err(err)),
+ }
+ }
+
+ /// If there is non-paired rows in the probe batch, the function process
them.
+ /// If not, it directs the state to fetching probe side.
+ fn generate_result(&mut self) ->
Result<StatefulStreamResult<Option<RecordBatch>>> {
+ if self.right_row_index < self.right_batch.num_rows() {
+ // Right batch has some unpaired rows, continue with the next row.
+ let result_batch = self.build_batch()?;
+ Ok(StatefulStreamResult::Ready(Some(result_batch)))
+ } else {
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ Ok(StatefulStreamResult::Continue)
+ }
+ }
+
+ /// This function constructs a new `RecordBatch` by joining the left and
right batches
+ /// based on the current indices. It also updates the metrics.
+ ///
+ /// # Arguments
+ /// * `self.left_data` - Array of `RecordBatch`es from the left side to be
joined.
+ /// * `self.right_batch` - The current `RecordBatch` from the right side
to be joined.
+ /// * `self.left_batch_index` - Index of the current left batch being
processed.
+ /// * `self.right_row_index` - Index of the current row in the right batch
to be joined.
+ /// * `join_metrics` - Metrics container to track performance of the join
operation.
+
+ fn build_batch(&mut self) -> Result<RecordBatch> {
+ let join_timer = self.join_metrics.join_time.timer();
+ // Create copies of the indexed right-side row for joining.
+ let right_copies: Vec<Arc<dyn Array>> = get_arrayref_at_indices(
Review Comment:
Just a question -- is this version of `build_batch` more performant due to
avoiding the conversion through `ScalarValue`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]