alihan-synnada commented on code in PR #12634: URL: https://github.com/apache/datafusion/pull/12634#discussion_r1778291069
########## datafusion/physical-plan/src/joins/nested_loop_join.rs: ########## @@ -560,91 +729,147 @@ impl NestedLoopJoinStream { // Get or initialize visited_left_side bitmap if required by join type let visited_left_side = left_data.bitmap(); - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); - } + loop { + if let Some(batch) = self.output_buffer.next() { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); + } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + // Check is_exhausted before polling the outer_table, such that when the outer table + // does not support `FusedStream`, Self will not poll it again + if self.is_exhausted { + let batch = self.output_buffer.finish()?; + if let Some(batch) = &batch { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + return Poll::Ready(Ok(batch).transpose()); + } + + if self.outer_record_batch.is_none() { + // Get the next outer record batch + match self.outer_table.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + 
self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.memory_reservation + .try_grow(batch.get_array_memory_size())?; + self.outer_record_batch = Some(batch); + self.outer_record_batch_row = 0; } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` will prevent from multiple calls of + // `report_probe_completed()` + if !left_data.report_probe_completed() { + self.is_exhausted = true; + continue; + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap( + visited_left_side, + self.join_type, + ); + let empty_right_batch = + RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.is_exhausted = true; - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. 
- // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { + // Recording time & updating output metrics + match result { + Ok(batch) => { + timer.done(); + self.output_buffer.push(batch)?; + continue; + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } else { self.is_exhausted = true; - return None; + continue; + } + } + Poll::Pending => { + return match self.output_buffer.flush() { + Ok(Some(batch)) => { + // If there was anything in the output buffer flush it + // so that it can be processed. + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Poll::Ready(Some(Ok(batch))) + } + Ok(None) => Poll::Pending, + Err(err) => Poll::Ready(Some(Err(err))), }; + } + } + } - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } + debug_assert!(self.outer_record_batch.is_some()); + let right_batch = self.outer_record_batch.as_ref().unwrap(); + let num_rows = match (self.join_type, left_data.batch().num_rows()) { + // An inner join will only produce 1 output row per input row. 
+ (JoinType::Inner, _) | (_, 0) => self.output_buffer.needed_rows(), Review Comment: Isn't it possible for an Inner Join to produce multiple output rows per input row? I believe the "1 output row per input row" statement only holds true for equijoins. NestedLoopJoin works on non-equijoins by definition. https://github.com/apache/datafusion/blob/9b4f90ad1eefabdc0d5bbbfd99e58765b041bb77/datafusion/physical-plan/src/joins/nested_loop_join.rs#L103-L104 For example, for `left=[1, 2]` and `right=[1, 2, 3]` with the `ON` clause `left<>right`, it produces `[(1, 2), (2, 1), (1, 3), (2, 3)]`. The row `3` from the right side produces 2 rows. It seems impossible to predict the number of output rows without running the join. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org