alihan-synnada commented on code in PR #12634:
URL: https://github.com/apache/datafusion/pull/12634#discussion_r1778291069
##########
datafusion/physical-plan/src/joins/nested_loop_join.rs:
##########
@@ -560,91 +729,147 @@ impl NestedLoopJoinStream {
// Get or initialize visited_left_side bitmap if required by join type
let visited_left_side = left_data.bitmap();
- // Check is_exhausted before polling the outer_table, such that when
the outer table
- // does not support `FusedStream`, Self will not poll it again
- if self.is_exhausted {
- return Poll::Ready(None);
- }
+ loop {
+ if let Some(batch) = self.output_buffer.next() {
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ return Poll::Ready(Some(Ok(batch)));
+ }
- self.outer_table
- .poll_next_unpin(cx)
- .map(|maybe_batch| match maybe_batch {
- Some(Ok(right_batch)) => {
- // Setting up timer & updating input metrics
- self.join_metrics.input_batches.add(1);
- self.join_metrics.input_rows.add(right_batch.num_rows());
- let timer = self.join_metrics.join_time.timer();
-
- let result = join_left_and_right_batch(
- left_data.batch(),
- &right_batch,
- self.join_type,
- self.filter.as_ref(),
- &self.column_indices,
- &self.schema,
- visited_left_side,
- &mut self.indices_cache,
- self.right_side_ordered,
- );
-
- // Recording time & updating output metrics
- if let Ok(batch) = &result {
- timer.done();
- self.join_metrics.output_batches.add(1);
- self.join_metrics.output_rows.add(batch.num_rows());
+ // Check is_exhausted before polling the outer_table, such that
when the outer table
+ // does not support `FusedStream`, Self will not poll it again
+ if self.is_exhausted {
+ let batch = self.output_buffer.finish()?;
+ if let Some(batch) = &batch {
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ }
+ return Poll::Ready(Ok(batch).transpose());
+ }
+
+ if self.outer_record_batch.is_none() {
+ // Get the next outer record batch
+ match self.outer_table.poll_next_unpin(cx) {
+ Poll::Ready(Some(Ok(batch))) => {
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(batch.num_rows());
+ self.memory_reservation
+ .try_grow(batch.get_array_memory_size())?;
+ self.outer_record_batch = Some(batch);
+ self.outer_record_batch_row = 0;
}
+ Poll::Ready(Some(Err(e))) => return
Poll::Ready(Some(Err(e))),
+ Poll::Ready(None) => {
+ if need_produce_result_in_final(self.join_type) {
+ // At this stage `visited_left_side` won't be
updated, so it's
+ // safe to report about probe completion.
+ //
+ // Setting `is_exhausted` will prevent from
multiple calls of
+ // `report_probe_completed()`
+ if !left_data.report_probe_completed() {
+ self.is_exhausted = true;
+ continue;
+ };
+
+ // Only setting up timer, input is exhausted
+ let timer = self.join_metrics.join_time.timer();
+ // use the global left bitmap to produce the left
indices and right indices
+ let (left_side, right_side) =
+ get_final_indices_from_shared_bitmap(
+ visited_left_side,
+ self.join_type,
+ );
+ let empty_right_batch =
+
RecordBatch::new_empty(self.outer_table.schema());
+ // use the left and right indices to produce the
batch result
+ let result = build_batch_from_indices(
+ &self.schema,
+ left_data.batch(),
+ &empty_right_batch,
+ &left_side,
+ &right_side,
+ &self.column_indices,
+ JoinSide::Left,
+ );
+ self.is_exhausted = true;
- Some(result)
- }
- Some(err) => Some(err),
- None => {
- if need_produce_result_in_final(self.join_type) {
- // At this stage `visited_left_side` won't be updated,
so it's
- // safe to report about probe completion.
- //
- // Setting `is_exhausted` / returning None will
prevent from
- // multiple calls of `report_probe_completed()`
- if !left_data.report_probe_completed() {
+ // Recording time & updating output metrics
+ match result {
+ Ok(batch) => {
+ timer.done();
+ self.output_buffer.push(batch)?;
+ continue;
+ }
+ Err(e) => return Poll::Ready(Some(Err(e))),
+ }
+ } else {
self.is_exhausted = true;
- return None;
+ continue;
+ }
+ }
+ Poll::Pending => {
+ return match self.output_buffer.flush() {
+ Ok(Some(batch)) => {
+ // If there was anything in the output buffer
flush it
+ // so that it can be processed.
+ self.join_metrics.output_batches.add(1);
+
self.join_metrics.output_rows.add(batch.num_rows());
+ Poll::Ready(Some(Ok(batch)))
+ }
+ Ok(None) => Poll::Pending,
+ Err(err) => Poll::Ready(Some(Err(err))),
};
+ }
+ }
+ }
- // Only setting up timer, input is exhausted
- let timer = self.join_metrics.join_time.timer();
- // use the global left bitmap to produce the left
indices and right indices
- let (left_side, right_side) =
- get_final_indices_from_shared_bitmap(
- visited_left_side,
- self.join_type,
- );
- let empty_right_batch =
- RecordBatch::new_empty(self.outer_table.schema());
- // use the left and right indices to produce the batch
result
- let result = build_batch_from_indices(
- &self.schema,
- left_data.batch(),
- &empty_right_batch,
- &left_side,
- &right_side,
- &self.column_indices,
- JoinSide::Left,
- );
- self.is_exhausted = true;
-
- // Recording time & updating output metrics
- if let Ok(batch) = &result {
- timer.done();
- self.join_metrics.output_batches.add(1);
-
self.join_metrics.output_rows.add(batch.num_rows());
- }
+ debug_assert!(self.outer_record_batch.is_some());
+ let right_batch = self.outer_record_batch.as_ref().unwrap();
+ let num_rows = match (self.join_type,
left_data.batch().num_rows()) {
+ // An inner join will only produce 1 output row per input row.
+ (JoinType::Inner, _) | (_, 0) =>
self.output_buffer.needed_rows(),
Review Comment:
Isn't it possible for an Inner Join to produce multiple output rows per input
row? I believe the "1 output row per input row" statement only holds true for
equijoins. NestedLoopJoin works on non-equijoins by definition.
https://github.com/apache/datafusion/blob/9b4f90ad1eefabdc0d5bbbfd99e58765b041bb77/datafusion/physical-plan/src/joins/nested_loop_join.rs#L103-L104
For example, for `left=[1, 2]` and `right=[1, 2, 3]` with the `ON` clause
`left<>right`, it produces `[(1, 2), (2, 1), (1, 3), (2, 3)]`. The row `3` from
the right side produces 2 rows.
It seems impossible to predict the number of output rows without running the
join.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]