korowa commented on code in PR #9830:
URL: https://github.com/apache/arrow-datafusion/pull/9830#discussion_r1550149513
##########
datafusion/physical-plan/src/joins/cross_join.rs:
##########
@@ -380,58 +403,84 @@ impl CrossJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Result<RecordBatch>>> {
+ loop {
+ return match self.state {
+ CrossJoinStreamState::WaitBuildSide => {
+ handle_state!(ready!(self.collect_build_side(cx)))
+ }
+ CrossJoinStreamState::FetchProbeBatch => {
+ handle_state!(ready!(self.fetch_probe_batch(cx)))
+ }
+ CrossJoinStreamState::BuildBatches(_) => {
+ handle_state!(self.build_batches())
+ }
+ };
+ }
+ }
+
+ /// Collects the build (left) side of the join into the state. If the build
+ /// side has no rows, execution terminates. Otherwise, the state is updated
+ /// to fetch the probe (right) batch.
+ fn collect_build_side(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
let build_timer = self.join_metrics.build_time.timer();
+ // `ready!` returns `Poll::Pending` early until `left_fut` has resolved.
let (left_data, _) = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => return Poll::Ready(Some(Err(e))),
+ Err(e) => return Poll::Ready(Err(e)),
};
build_timer.done();
- if left_data.num_rows() == 0 {
- return Poll::Ready(None);
- }
+ // A cross join with an empty build side yields no rows, so return
+ // `Ready(None)` without ever polling the probe side.
+ let result = if left_data.num_rows() == 0 {
+ StatefulStreamResult::Ready(None)
+ } else {
+ self.left_data = left_data.clone();
+ self.state = CrossJoinStreamState::FetchProbeBatch;
+ StatefulStreamResult::Continue
+ };
+ Poll::Ready(Ok(result))
+ }
+
+ /// Fetches the next probe (right) batch, updates the input metrics, and
+ /// stores the batch in the state, which then moves on to building result
+ /// batches. Returns `StatefulStreamResult::Ready(None)` once the probe
+ /// stream is exhausted; errors from the probe stream are returned as `Err`.
+ fn fetch_probe_batch(
+ &mut self,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+ // Each new probe batch is joined starting from the first build-side row.
+ self.left_index = 0;
+ let right_data = match ready!(self.right.poll_next_unpin(cx)) {
+ Some(Ok(right_data)) => right_data,
+ Some(Err(e)) => return Poll::Ready(Err(e)),
+ None => return Poll::Ready(Ok(StatefulStreamResult::Ready(None))),
+ };
+ // Probe-side input metrics are recorded here, once per fetched batch.
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_data.num_rows());
+
+ self.state = CrossJoinStreamState::BuildBatches(right_data);
+ Poll::Ready(Ok(StatefulStreamResult::Continue))
+ }
- if self.left_index > 0 && self.left_index < left_data.num_rows() {
+ /// Joins the left-data row at `left_index` with the current probe batch.
+ /// Once all results for this probe batch are produced, the state is set to
+ /// fetch the next probe batch.
+ fn build_batches(&mut self) ->
Result<StatefulStreamResult<Option<RecordBatch>>> {
+ let right_batch = self.state.try_as_record_batch()?;
+ if self.left_index < self.left_data.num_rows() {
let join_timer = self.join_metrics.join_time.timer();
- let right_batch = {
- let right_batch = self.right_batch.lock();
- right_batch.clone().unwrap()
- };
let result =
- build_batch(self.left_index, &right_batch, left_data,
&self.schema);
+ build_batch(self.left_index, right_batch, &self.left_data,
&self.schema);
+ join_timer.done();
self.join_metrics.input_rows.add(right_batch.num_rows());
Review Comment:
This `input_rows` metric increment looks like a duplicate of the one in `fetch_probe_batch`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]