This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4eb1a57ff Always wrapping OnceAsync for the inner table side in 
NestedLoopJoinExec (#5156)
4eb1a57ff is described below

commit 4eb1a57ffec82a8b3abd3ee266716921897f1630
Author: ygf11 <[email protected]>
AuthorDate: Sun Feb 12 21:10:31 2023 +0800

    Always wrapping OnceAsync for the inner table side in NestedLoopJoinExec 
(#5156)
    
    * Always wrapping OnceFut for the inner table side in NestedLoopJoinExec
    
    * fix cargo fmt
    
    * fix comment
    
    * fix comment
    
    * fix comment
    
    * Update datafusion/core/src/physical_plan/joins/nested_loop_join.rs
    
    Co-authored-by: jakevin <[email protected]>
    
    * Update datafusion/core/src/physical_plan/joins/nested_loop_join.rs
    
    Co-authored-by: jakevin <[email protected]>
    
    * Update datafusion/core/src/physical_plan/joins/nested_loop_join.rs
    
    Co-authored-by: jakevin <[email protected]>
    
    ---------
    
    Co-authored-by: jakevin <[email protected]>
---
 .../src/physical_plan/joins/nested_loop_join.rs    | 405 ++++++++++++++-------
 datafusion/core/src/physical_plan/joins/utils.rs   |  34 ++
 datafusion/core/tests/sql/joins.rs                 | 104 ++++++
 datafusion/core/tests/sql/mod.rs                   |  50 +++
 4 files changed, 454 insertions(+), 139 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs 
b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index 1e7f6c4db..3d4e64aa5 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -20,11 +20,11 @@
 //! determined by the [`JoinType`].
 
 use crate::physical_plan::joins::utils::{
-    adjust_indices_by_join_type, adjust_right_output_partitioning,
-    apply_join_filter_to_indices, build_batch_from_indices, build_join_schema,
-    check_join_is_valid, combine_join_equivalence_properties, 
estimate_join_statistics,
-    get_final_indices_from_bit_map, need_produce_result_in_final, ColumnIndex,
-    JoinFilter, OnceAsync, OnceFut,
+    adjust_right_output_partitioning, append_right_indices, 
apply_join_filter_to_indices,
+    build_batch_from_indices, build_join_schema, check_join_is_valid,
+    combine_join_equivalence_properties, estimate_join_statistics, 
get_anti_indices,
+    get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices,
+    get_semi_u64_indices, ColumnIndex, JoinFilter, OnceAsync, OnceFut,
 };
 use crate::physical_plan::{
     DisplayFormatType, Distribution, ExecutionPlan, Partitioning, 
RecordBatchStream,
@@ -50,9 +50,26 @@ use crate::error::Result;
 use crate::execution::context::TaskContext;
 use crate::physical_plan::coalesce_batches::concat_batches;
 
-/// Data of the left side
+/// Data of the inner table side
 type JoinLeftData = RecordBatch;
 
+/// NestedLoopJoinExec executes partitions in parallel.
+/// One input will be collected to a single partition, call it inner-table.
+/// The other side of the input is treated as outer-table, and the output 
Partitioning is from it.
+/// Given an output partition number x, the execution will be:
+///
+/// ```text
+/// for outer-table-batch in outer-table-partition-x
+///     check-join(outer-table-batch, inner-table-data)
+/// ```
+///
+/// One of the inputs will become inner table, and it is decided by the join 
type.
+/// Following is the relation table:
+///
+/// | JoinType                       | Distribution (left, right)              
   | Inner-table |
+/// 
|--------------------------------|--------------------------------------------|-------------|
+/// | Inner/Left/LeftSemi/LeftAnti   | (UnspecifiedDistribution, 
SinglePartition) | right       |
+/// | Right/RightSemi/RightAnti/Full | (SinglePartition, 
UnspecifiedDistribution) | left        |
 ///
 #[derive(Debug)]
 pub struct NestedLoopJoinExec {
@@ -67,7 +84,7 @@ pub struct NestedLoopJoinExec {
     /// The schema once the join is applied
     schema: SchemaRef,
     /// Build-side data
-    left_fut: OnceAsync<JoinLeftData>,
+    inner_table: OnceAsync<JoinLeftData>,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
 }
@@ -91,24 +108,10 @@ impl NestedLoopJoinExec {
             filter,
             join_type: *join_type,
             schema: Arc::new(schema),
-            left_fut: Default::default(),
+            inner_table: Default::default(),
             column_indices,
         })
     }
-
-    fn is_single_partition_for_left(&self) -> bool {
-        matches!(
-            self.required_input_distribution()[0],
-            Distribution::SinglePartition
-        )
-    }
-
-    fn is_single_partition_for_right(&self) -> bool {
-        matches!(
-            self.required_input_distribution()[1],
-            Distribution::SinglePartition
-        )
-    }
 }
 
 impl ExecutionPlan for NestedLoopJoinExec {
@@ -186,37 +189,28 @@ impl ExecutionPlan for NestedLoopJoinExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        // left side
-        let left_fut = if self.is_single_partition_for_left() {
-            // if the distribution of left is `SinglePartition`, just need to 
collect the left one
-            self.left_fut.once(|| {
-                // just one partition for the left side, and the first 
partition is all of data for left
-                load_left_specified_partition(0, self.left.clone(), 
context.clone())
-            })
-        } else {
-            // the distribution of left is not single partition, just need the 
specified partition for left
-            OnceFut::new(load_left_specified_partition(
-                partition,
-                self.left.clone(),
-                context.clone(),
-            ))
-        };
-        // right side
-        let right_side = if self.is_single_partition_for_right() {
-            // the distribution of right is `SinglePartition`
-            // if the distribution of right is `SinglePartition`, just need to 
collect the right one
-            self.right.execute(0, context)?
+        let (outer_table, inner_table) = if left_is_build_side(self.join_type) 
{
+            // left must be single partition
+            let inner_table = self.inner_table.once(|| {
+                load_specified_partition_of_input(0, self.left.clone(), 
context.clone())
+            });
+            let outer_table = self.right.execute(partition, context)?;
+            (outer_table, inner_table)
         } else {
-            // the distribution of right is not single partition, just need 
the specified partition for right
-            self.right.execute(partition, context)?
+            // right must be single partition
+            let inner_table = self.inner_table.once(|| {
+                load_specified_partition_of_input(0, self.right.clone(), 
context.clone())
+            });
+            let outer_table = self.left.execute(partition, context)?;
+            (outer_table, inner_table)
         };
 
         Ok(Box::pin(NestedLoopJoinStream {
             schema: self.schema.clone(),
             filter: self.filter.clone(),
             join_type: self.join_type,
-            left_fut,
-            right: right_side,
+            outer_table,
+            inner_table,
             is_exhausted: false,
             visited_left_side: None,
             column_indices: self.column_indices.clone(),
@@ -274,14 +268,14 @@ fn distribution_from_join_type(join_type: &JoinType) -> 
Vec<Distribution> {
     }
 }
 
-/// Asynchronously collect the result of the left child for the specified 
partition
-async fn load_left_specified_partition(
+/// Asynchronously collect the specified partition data of the input
+async fn load_specified_partition_of_input(
     partition: usize,
-    left: Arc<dyn ExecutionPlan>,
+    input: Arc<dyn ExecutionPlan>,
     context: Arc<TaskContext>,
 ) -> Result<JoinLeftData> {
     let start = Instant::now();
-    let stream = left.execute(partition, context)?;
+    let stream = input.execute(partition, context)?;
 
     // Load all batches and count the rows
     let (batches, num_rows) = stream
@@ -292,10 +286,10 @@ async fn load_left_specified_partition(
         })
         .await?;
 
-    let merged_batch = concat_batches(&left.schema(), &batches, num_rows)?;
+    let merged_batch = concat_batches(&input.schema(), &batches, num_rows)?;
 
     debug!(
-        "Built left-side of nested loop join containing {} rows in {} ms for 
partition {}",
+        "Built input of nested loop join containing {} rows in {} ms for 
partition {}",
         num_rows,
         start.elapsed().as_millis(),
         partition
@@ -304,6 +298,14 @@ async fn load_left_specified_partition(
     Ok(merged_batch)
 }
 
+// BuildLeft means the left relation is the single partition side.
+// For full join, both sides are single partition, so it is both BuildLeft and 
BuildRight; treat it as BuildLeft.
+pub fn left_is_build_side(join_type: JoinType) -> bool {
+    matches!(
+        join_type,
+        JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | 
JoinType::Full
+    )
+}
 /// A stream that issues [RecordBatch]es as they arrive from the right  of the 
join.
 struct NestedLoopJoinStream {
     /// Input schema
@@ -312,11 +314,11 @@ struct NestedLoopJoinStream {
     filter: Option<JoinFilter>,
     /// type of the join
     join_type: JoinType,
-    /// future for data from left side
-    left_fut: OnceFut<JoinLeftData>,
-    /// right
-    right: SendableRecordBatchStream,
-    /// There is nothing to process anymore and left side is processed in case 
of left/left semi/left anti/full join
+    /// the outer table data of the nested loop join
+    outer_table: SendableRecordBatchStream,
+    /// the inner table data of the nested loop join
+    inner_table: OnceFut<JoinLeftData>,
+    /// There is nothing to process anymore and left side is processed in case 
of full join
     is_exhausted: bool,
     /// Keeps track of the left side rows whether they are visited
     visited_left_side: Option<BooleanBufferBuilder>,
@@ -332,9 +334,10 @@ fn build_join_indices(
     left_data: &JoinLeftData,
     filter: Option<&JoinFilter>,
 ) -> Result<(UInt64Array, UInt32Array)> {
-    let right_row_count = batch.num_rows();
     // left indices: [left_index, left_index, ...., left_index]
     // right indices: [0, 1, 2, 3, 4,....,right_row_count]
+
+    let right_row_count = batch.num_rows();
     let left_indices = UInt64Array::from(vec![left_index as u64; 
right_row_count]);
     let right_indices = UInt32Array::from_iter_values(0..(right_row_count as 
u32));
     // in the nested loop join, the filter can contain non-equal and equal 
condition.
@@ -352,24 +355,22 @@ fn build_join_indices(
 }
 
 impl NestedLoopJoinStream {
-    fn poll_next_impl(
+    /// For Right/RightSemi/RightAnti/Full joins, left is the single partition 
side.
+    fn poll_next_impl_for_build_left(
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         // all left row
-        let left_data = match ready!(self.left_fut.get(cx)) {
-            Ok(left_data) => left_data,
+        let left_data = match ready!(self.inner_table.get(cx)) {
+            Ok(data) => data,
             Err(e) => return Poll::Ready(Some(Err(e))),
         };
 
+        // add a bitmap for full join.
         let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
             let left_num_rows = left_data.num_rows();
-            if need_produce_result_in_final(self.join_type) {
-                // these join type need the bitmap to identify which row has 
be matched or unmatched.
-                // For the `left semi` join, need to use the bitmap to produce 
the matched row in the left side
-                // For the `left` join, need to use the bitmap to produce the 
unmatched row in the left side with null
-                // For the `left anti` join, need to use the bitmap to produce 
the unmatched row in the left side
-                // For the `full` join, need to use the bitmap to produce the 
unmatched row in the left side with null
+            // only full join need bitmap
+            if self.join_type == JoinType::Full {
                 let mut buffer = BooleanBufferBuilder::new(left_num_rows);
                 buffer.append_n(left_num_rows, false);
                 buffer
@@ -378,89 +379,31 @@ impl NestedLoopJoinStream {
             }
         });
 
-        // iter the right batch
-        self.right
+        self.outer_table
             .poll_next_unpin(cx)
             .map(|maybe_batch| match maybe_batch {
                 Some(Ok(right_batch)) => {
-                    // TODO: optimize this logic like the cross join, and just 
return a small batch for each loop
-                    // get the matched left and right indices
-                    // each left row will try to match every right row
-                    let indices_result = (0..left_data.num_rows())
-                        .map(|left_row_index| {
-                            build_join_indices(
-                                left_row_index,
-                                &right_batch,
-                                left_data,
-                                self.filter.as_ref(),
-                            )
-                        })
-                        .collect::<Result<Vec<(UInt64Array, UInt32Array)>>>();
-                    let mut left_indices_builder = UInt64Builder::new();
-                    let mut right_indices_builder = UInt32Builder::new();
-                    let left_right_indices = match indices_result {
-                        Err(_) => Err(DataFusionError::Execution(
-                            "Build left right indices error".to_string(),
-                        )),
-                        Ok(indices) => {
-                            for (left_side, right_side) in indices {
-                                left_indices_builder.append_values(
-                                    left_side.values(),
-                                    &vec![true; left_side.len()],
-                                );
-                                right_indices_builder.append_values(
-                                    right_side.values(),
-                                    &vec![true; right_side.len()],
-                                );
-                            }
-                            Ok((
-                                left_indices_builder.finish(),
-                                right_indices_builder.finish(),
-                            ))
-                        }
-                    };
-                    let result = match left_right_indices {
-                        Ok((left_side, right_side)) => {
-                            // set the left bitmap
-                            // and only left, full, left semi, left anti need 
the left bitmap
-                            if need_produce_result_in_final(self.join_type) {
-                                left_side.iter().flatten().for_each(|x| {
-                                    visited_left_side.set_bit(x as usize, 
true);
-                                });
-                            }
-                            // adjust the two side indices base on the join 
type
-                            let (left_side, right_side) = 
adjust_indices_by_join_type(
-                                left_side,
-                                right_side,
-                                right_batch.num_rows(),
-                                self.join_type,
-                            );
-
-                            let result = build_batch_from_indices(
-                                &self.schema,
-                                left_data,
-                                &right_batch,
-                                left_side,
-                                right_side,
-                                &self.column_indices,
-                            );
-                            Some(result)
-                        }
-                        Err(e) => Some(Err(e)),
-                    };
-                    result
+                    let result = join_left_and_right_batch(
+                        left_data,
+                        &right_batch,
+                        self.join_type,
+                        self.filter.as_ref(),
+                        &self.column_indices,
+                        &self.schema,
+                        visited_left_side,
+                    );
+                    Some(result)
                 }
                 Some(err) => Some(err),
                 None => {
-                    if need_produce_result_in_final(self.join_type) && 
!self.is_exhausted
-                    {
+                    if self.join_type == JoinType::Full && !self.is_exhausted {
                         // use the global left bitmap to produce the left 
indices and right indices
                         let (left_side, right_side) = 
get_final_indices_from_bit_map(
                             visited_left_side,
                             self.join_type,
                         );
                         let empty_right_batch =
-                            RecordBatch::new_empty(self.right.schema());
+                            RecordBatch::new_empty(self.outer_table.schema());
                         // use the left and right indices to produce the batch 
result
                         let result = build_batch_from_indices(
                             &self.schema,
@@ -479,6 +422,186 @@ impl NestedLoopJoinStream {
                 }
             })
     }
+
+    /// For Inner/Left/LeftSemi/LeftAnti joins, right is the single partition 
side.
+    fn poll_next_impl_for_build_right(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        // all right row
+        let right_data = match ready!(self.inner_table.get(cx)) {
+            Ok(data) => data,
+            Err(e) => return Poll::Ready(Some(Err(e))),
+        };
+
+        // for build right, bitmap is not needed.
+        let mut empty_visited_left_side = BooleanBufferBuilder::new(0);
+        self.outer_table
+            .poll_next_unpin(cx)
+            .map(|maybe_batch| match maybe_batch {
+                Some(Ok(left_batch)) => {
+                    let result = join_left_and_right_batch(
+                        &left_batch,
+                        right_data,
+                        self.join_type,
+                        self.filter.as_ref(),
+                        &self.column_indices,
+                        &self.schema,
+                        &mut empty_visited_left_side,
+                    );
+                    Some(result)
+                }
+                Some(err) => Some(err),
+                None => None,
+            })
+    }
+}
+
+fn join_left_and_right_batch(
+    left_batch: &RecordBatch,
+    right_batch: &RecordBatch,
+    join_type: JoinType,
+    filter: Option<&JoinFilter>,
+    column_indices: &[ColumnIndex],
+    schema: &Schema,
+    visited_left_side: &mut BooleanBufferBuilder,
+) -> Result<RecordBatch> {
+    let indices_result = (0..left_batch.num_rows())
+        .map(|left_row_index| {
+            build_join_indices(left_row_index, right_batch, left_batch, filter)
+        })
+        .collect::<Result<Vec<(UInt64Array, UInt32Array)>>>();
+
+    let mut left_indices_builder = UInt64Builder::new();
+    let mut right_indices_builder = UInt32Builder::new();
+    let left_right_indices = match indices_result {
+        Err(_) => Err(DataFusionError::Execution(
+            "Build left right indices error".to_string(),
+        )),
+        Ok(indices) => {
+            for (left_side, right_side) in indices {
+                left_indices_builder
+                    .append_values(left_side.values(), &vec![true; 
left_side.len()]);
+                right_indices_builder
+                    .append_values(right_side.values(), &vec![true; 
right_side.len()]);
+            }
+            Ok((
+                left_indices_builder.finish(),
+                right_indices_builder.finish(),
+            ))
+        }
+    };
+    match left_right_indices {
+        Ok((left_side, right_side)) => {
+            // set the left bitmap
+            // and only full join need the left bitmap
+            if join_type == JoinType::Full {
+                left_side.iter().flatten().for_each(|x| {
+                    visited_left_side.set_bit(x as usize, true);
+                });
+            }
+            // adjust the indices of both sides based on the join 
+            let (left_side, right_side) = adjust_indices_by_join_type(
+                left_side,
+                right_side,
+                left_batch.num_rows(),
+                right_batch.num_rows(),
+                join_type,
+            );
+
+            build_batch_from_indices(
+                schema,
+                left_batch,
+                right_batch,
+                left_side,
+                right_side,
+                column_indices,
+            )
+        }
+        Err(e) => Err(e),
+    }
+}
+
+fn adjust_indices_by_join_type(
+    left_indices: UInt64Array,
+    right_indices: UInt32Array,
+    count_left_batch: usize,
+    count_right_batch: usize,
+    join_type: JoinType,
+) -> (UInt64Array, UInt32Array) {
+    match join_type {
+        JoinType::Inner => (left_indices, right_indices),
+        JoinType::Left => {
+            // matched
+            // unmatched left row will be produced in this batch
+            let left_unmatched_indices =
+                get_anti_u64_indices(count_left_batch, &left_indices);
+            // combine the matched and unmatched left result together
+            append_left_indices(left_indices, right_indices, 
left_unmatched_indices)
+        }
+        JoinType::LeftSemi => {
+            // need to remove the duplicated record in the left side
+            let left_indices = get_semi_u64_indices(count_left_batch, 
&left_indices);
+            // the right_indices will not be used later for the `left semi` 
join
+            (left_indices, right_indices)
+        }
+        JoinType::LeftAnti => {
+            // need to remove the duplicated record in the left side
+            // get the anti index for the left side
+            let left_indices = get_anti_u64_indices(count_left_batch, 
&left_indices);
+            // the right_indices will not be used later for the `left anti` 
join
+            (left_indices, right_indices)
+        }
+        // right/right-semi/right-anti/full => right = outer_table, left = 
inner_table
+        JoinType::Right | JoinType::Full => {
+            // matched
+            // unmatched right row will be produced in this batch
+            let right_unmatched_indices =
+                get_anti_indices(count_right_batch, &right_indices);
+            // combine the matched and unmatched right result together
+            append_right_indices(left_indices, right_indices, 
right_unmatched_indices)
+        }
+        JoinType::RightSemi => {
+            // need to remove the duplicated record in the right side
+            let right_indices = get_semi_indices(count_right_batch, 
&right_indices);
+            // the left_indices will not be used later for the `right semi` 
join
+            (left_indices, right_indices)
+        }
+        JoinType::RightAnti => {
+            // need to remove the duplicated record in the right side
+            // get the anti index for the right side
+            let right_indices = get_anti_indices(count_right_batch, 
&right_indices);
+            // the left_indices will not be used later for the `right anti` 
join
+            (left_indices, right_indices)
+        }
+    }
+}
+
+/// Appends the `left_unmatched_indices` to the `left_indices`,
+/// and fills Null to tail of `right_indices` to
+/// keep the length of `left_indices` and `right_indices` consistent.
+fn append_left_indices(
+    left_indices: UInt64Array,
+    right_indices: UInt32Array,
+    left_unmatched_indices: UInt64Array,
+) -> (UInt64Array, UInt32Array) {
+    if left_unmatched_indices.is_empty() {
+        (left_indices, right_indices)
+    } else {
+        let unmatched_size = left_unmatched_indices.len();
+        // the new left indices: left_indices + left_unmatched_indices
+        // the new right indices: right_indices + null array
+        let new_left_indices = left_indices
+            .iter()
+            .chain(left_unmatched_indices.iter())
+            .collect::<UInt64Array>();
+        let new_right_indices = right_indices
+            .iter()
+            .chain(std::iter::repeat(None).take(unmatched_size))
+            .collect::<UInt32Array>();
+
+        (new_left_indices, new_right_indices)
+    }
 }
 
 impl Stream for NestedLoopJoinStream {
@@ -488,7 +611,11 @@ impl Stream for NestedLoopJoinStream {
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        self.poll_next_impl(cx)
+        if left_is_build_side(self.join_type) {
+            self.poll_next_impl_for_build_left(cx)
+        } else {
+            self.poll_next_impl_for_build_right(cx)
+        }
     }
 }
 
diff --git a/datafusion/core/src/physical_plan/joins/utils.rs 
b/datafusion/core/src/physical_plan/joins/utils.rs
index 372458523..b01483f56 100644
--- a/datafusion/core/src/physical_plan/joins/utils.rs
+++ b/datafusion/core/src/physical_plan/joins/utils.rs
@@ -917,6 +917,23 @@ pub(crate) fn get_anti_indices(
         .collect::<UInt32Array>()
 }
 
+/// Get unmatched and deduplicated indices
+pub(crate) fn get_anti_u64_indices(
+    row_count: usize,
+    input_indices: &UInt64Array,
+) -> UInt64Array {
+    let mut bitmap = BooleanBufferBuilder::new(row_count);
+    bitmap.append_n(row_count, false);
+    input_indices.iter().flatten().for_each(|v| {
+        bitmap.set_bit(v as usize, true);
+    });
+
+    // get the anti index
+    (0..row_count)
+        .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u64))
+        .collect::<UInt64Array>()
+}
+
 /// Get matched and deduplicated indices
 pub(crate) fn get_semi_indices(
     row_count: usize,
@@ -934,6 +951,23 @@ pub(crate) fn get_semi_indices(
         .collect::<UInt32Array>()
 }
 
+/// Get matched and deduplicated indices
+pub(crate) fn get_semi_u64_indices(
+    row_count: usize,
+    input_indices: &UInt64Array,
+) -> UInt64Array {
+    let mut bitmap = BooleanBufferBuilder::new(row_count);
+    bitmap.append_n(row_count, false);
+    input_indices.iter().flatten().for_each(|v| {
+        bitmap.set_bit(v as usize, true);
+    });
+
+    // get the semi index
+    (0..row_count)
+        .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u64))
+        .collect::<UInt64Array>()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index 37b662a28..30bdbd418 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -3298,3 +3298,107 @@ async fn two_in_subquery_to_join_with_outer_filter() -> 
Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn right_as_inner_table_nested_loop_join() -> Result<()> {
+    let ctx = create_nested_loop_join_context()?;
+
+    // Distribution: left is `UnspecifiedDistribution`, right is 
`SinglePartition`.
+    let sql = "SELECT t1.t1_id, t2.t2_id 
+                     FROM t1 INNER JOIN t2 ON t1.t1_id > t2.t2_id 
+                     WHERE t1.t1_id > 10 AND t2.t2_int > 1";
+
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+    let physical_plan = dataframe.create_physical_plan().await?;
+
+    // right is single partition side, so it will be visited many times.
+    let expected = vec![
+        "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@1 as t2_id]",
+        "  NestedLoopJoinExec: join_type=Inner, filter=BinaryExpr { left: 
Column { name: \"t1_id\", index: 0 }, op: Gt, right: Column { name: \"t2_id\", 
index: 1 } }",
+        "    CoalesceBatchesExec: target_batch_size=4096",
+        "      FilterExec: t1_id@0 > 10",
+        "        RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1",
+        "          MemoryExec: partitions=1, partition_sizes=[1]",
+        "    CoalescePartitionsExec",
+        "      CoalesceBatchesExec: target_batch_size=4096",
+        "        FilterExec: t2_int@1 > 1",
+        "          RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1",
+        "            MemoryExec: partitions=1, partition_sizes=[1]",
+    ];
+    let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+-------+",
+        "| t1_id | t2_id |",
+        "+-------+-------+",
+        "| 22    | 11    |",
+        "| 33    | 11    |",
+        "| 44    | 11    |",
+        "+-------+-------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn left_as_inner_table_nested_loop_join() -> Result<()> {
+    let ctx = create_nested_loop_join_context()?;
+
+    // Distribution: left is `SinglePartition`, right is 
`UnspecifiedDistribution`.
+    let sql = "SELECT t1.t1_id,t2.t2_id FROM (select t1_id from t1 where 
t1.t1_id > 22) as t1 
+                                                 RIGHT JOIN (select t2_id from 
t2 where t2.t2_id > 11) as t2 
+                                                 ON t1.t1_id < t2.t2_id";
+
+    let msg = format!("Creating logical plan for '{sql}'");
+    let dataframe = ctx.sql(sql).await.expect(&msg);
+    let physical_plan = dataframe.create_physical_plan().await?;
+
+    // left is single partition side, so it will be visited many times.
+    let expected = vec![
+        "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@1 as t2_id]",
+        "  NestedLoopJoinExec: join_type=Right, filter=BinaryExpr { left: 
Column { name: \"t1_id\", index: 0 }, op: Lt, right: Column { name: \"t2_id\", 
index: 1 } }",
+        "    CoalescePartitionsExec",
+        "      ProjectionExec: expr=[t1_id@0 as t1_id]",
+        "        CoalesceBatchesExec: target_batch_size=4096",
+        "          FilterExec: t1_id@0 > 22",
+        "            RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1",
+        "              MemoryExec: partitions=1, partition_sizes=[1]",
+        "    ProjectionExec: expr=[t2_id@0 as t2_id]",
+        "      CoalesceBatchesExec: target_batch_size=4096",
+        "        FilterExec: t2_id@0 > 11",
+        "          RepartitionExec: partitioning=RoundRobinBatch(4), 
input_partitions=1",
+        "            MemoryExec: partitions=1, partition_sizes=[1]",
+    ];
+    let formatted = displayable(physical_plan.as_ref()).indent().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
+    );
+
+    let expected = vec![
+        "+-------+-------+",
+        "| t1_id | t2_id |",
+        "+-------+-------+",
+        "|       | 22    |",
+        "| 33    | 44    |",
+        "| 33    | 55    |",
+        "| 44    | 55    |",
+        "+-------+-------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index e2a1199d5..4b3c60d7e 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -710,6 +710,56 @@ fn create_union_context() -> Result<SessionContext> {
     Ok(ctx)
 }
 
+fn create_nested_loop_join_context() -> Result<SessionContext> {
+    let ctx = SessionContext::with_config(
+        SessionConfig::new()
+            .with_target_partitions(4)
+            .with_batch_size(4096),
+    );
+
+    let t1_schema = Arc::new(Schema::new(vec![
+        Field::new("t1_id", DataType::UInt32, true),
+        Field::new("t1_name", DataType::Utf8, true),
+        Field::new("t1_int", DataType::UInt32, true),
+    ]));
+    let t1_data = RecordBatch::try_new(
+        t1_schema,
+        vec![
+            Arc::new(UInt32Array::from_slice([11, 22, 33, 44])),
+            Arc::new(StringArray::from(vec![
+                Some("a"),
+                Some("b"),
+                Some("c"),
+                Some("d"),
+            ])),
+            Arc::new(UInt32Array::from_slice([1, 2, 3, 4])),
+        ],
+    )?;
+    ctx.register_batch("t1", t1_data)?;
+
+    let t2_schema = Arc::new(Schema::new(vec![
+        Field::new("t2_id", DataType::UInt32, true),
+        Field::new("t2_name", DataType::Utf8, true),
+        Field::new("t2_int", DataType::UInt32, true),
+    ]));
+    let t2_data = RecordBatch::try_new(
+        t2_schema,
+        vec![
+            Arc::new(UInt32Array::from_slice([11, 22, 44, 55])),
+            Arc::new(StringArray::from(vec![
+                Some("z"),
+                Some("y"),
+                Some("x"),
+                Some("w"),
+            ])),
+            Arc::new(UInt32Array::from_slice([3, 1, 3, 3])),
+        ],
+    )?;
+    ctx.register_batch("t2", t2_data)?;
+
+    Ok(ctx)
+}
+
 fn get_tpch_table_schema(table: &str) -> Schema {
     match table {
         "customer" => Schema::new(vec![

Reply via email to