This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 69687da337 feat: support `RightAnti` for `SortMergeJoin` (#13680)
69687da337 is described below

commit 69687da337331c978adc201a75955d7e96605d6f
Author: irenjj <[email protected]>
AuthorDate: Mon Jan 6 04:42:20 2025 +0800

    feat: support `RightAnti` for `SortMergeJoin` (#13680)
    
    * feat: support `RightAnti` for `SortMergeJoin`
---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs      |  28 +-
 .../physical-plan/src/joins/sort_merge_join.rs     | 710 ++++++++++++++++-----
 datafusion/physical-plan/src/test.rs               |  20 +
 .../sqllogictest/test_files/sort_merge_join.slt    |  48 ++
 4 files changed, 643 insertions(+), 163 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index cf1742a30e..b331388f4f 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -210,7 +210,7 @@ async fn test_semi_join_1k_filtered() {
 }
 
 #[tokio::test]
-async fn test_anti_join_1k() {
+async fn test_left_anti_join_1k() {
     JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
@@ -222,7 +222,7 @@ async fn test_anti_join_1k() {
 }
 
 #[tokio::test]
-async fn test_anti_join_1k_filtered() {
+async fn test_left_anti_join_1k_filtered() {
     JoinFuzzTestCase::new(
         make_staggered_batches(1000),
         make_staggered_batches(1000),
@@ -233,6 +233,30 @@ async fn test_anti_join_1k_filtered() {
     .await
 }
 
+/// Fuzz test for `JoinType::RightAnti` without a join filter (`None`), run on
+/// two inputs produced by `make_staggered_batches(1000)`.
+// NOTE(review): `HjSmj` / `NljHj` presumably name the pairs of join
+// implementations whose results are compared — confirm in `run_test`.
+#[tokio::test]
+async fn test_right_anti_join_1k() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::RightAnti,
+        None,
+    )
+    .run_test(&[HjSmj, NljHj], false)
+    .await
+}
+
+/// Fuzz test for `JoinType::RightAnti` with the `col_lt_col_filter` join
+/// filter, run on two inputs produced by `make_staggered_batches(1000)`.
+// NOTE(review): `HjSmj` / `NljHj` presumably name the pairs of join
+// implementations whose results are compared — confirm in `run_test`.
+#[tokio::test]
+async fn test_right_anti_join_1k_filtered() {
+    JoinFuzzTestCase::new(
+        make_staggered_batches(1000),
+        make_staggered_batches(1000),
+        JoinType::RightAnti,
+        Some(Box::new(col_lt_col_filter)),
+    )
+    .run_test(&[HjSmj, NljHj], false)
+    .await
+}
+
 #[tokio::test]
 async fn test_left_mark_join_1k() {
     JoinFuzzTestCase::new(
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 54bd63084e..bcacc7dcae 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -935,7 +935,7 @@ fn get_corrected_filter_mask(
 
             Some(corrected_mask.finish())
         }
-        JoinType::LeftAnti => {
+        JoinType::LeftAnti | JoinType::RightAnti => {
             for i in 0..row_indices_length {
                 let last_index =
                     last_index_for_row(i, row_indices, batch_ids, 
row_indices_length);
@@ -1047,6 +1047,7 @@ impl Stream for SortMergeJoinStream {
                                                 | JoinType::LeftMark
                                                 | JoinType::Right
                                                 | JoinType::LeftAnti
+                                                | JoinType::RightAnti
                                                 | JoinType::Full
                                         )
                                     {
@@ -1129,6 +1130,7 @@ impl Stream for SortMergeJoinStream {
                                         | JoinType::LeftSemi
                                         | JoinType::Right
                                         | JoinType::LeftAnti
+                                        | JoinType::RightAnti
                                         | JoinType::LeftMark
                                         | JoinType::Full
                                 )
@@ -1152,6 +1154,7 @@ impl Stream for SortMergeJoinStream {
                                     | JoinType::LeftSemi
                                     | JoinType::Right
                                     | JoinType::LeftAnti
+                                    | JoinType::RightAnti
                                     | JoinType::Full
                                     | JoinType::LeftMark
                             )
@@ -1468,6 +1471,7 @@ impl SortMergeJoinStream {
                         | JoinType::RightSemi
                         | JoinType::Full
                         | JoinType::LeftAnti
+                        | JoinType::RightAnti
                         | JoinType::LeftMark
                 ) {
                     join_streamed = !self.streamed_joined;
@@ -1501,7 +1505,9 @@ impl SortMergeJoinStream {
                     join_buffered = true;
                 };
 
-                if matches!(self.join_type, JoinType::LeftAnti) && 
self.filter.is_some() {
+                if matches!(self.join_type, JoinType::LeftAnti | 
JoinType::RightAnti)
+                    && self.filter.is_some()
+                {
                     join_streamed = !self.streamed_joined;
                     join_buffered = join_streamed;
                 }
@@ -1684,7 +1690,10 @@ impl SortMergeJoinStream {
             let right_indices: UInt64Array = chunk.buffered_indices.finish();
             let mut right_columns = if matches!(self.join_type, 
JoinType::LeftMark) {
                 vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
-            } else if matches!(self.join_type, JoinType::LeftSemi | 
JoinType::LeftAnti) {
+            } else if matches!(
+                self.join_type,
+                JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti
+            ) {
                 vec![]
             } else if let Some(buffered_idx) = chunk.buffered_batch_idx {
                 fetch_right_columns_by_idxs(
@@ -1717,6 +1726,14 @@ impl SortMergeJoinStream {
                         )?;
 
                         get_filter_column(&self.filter, &left_columns, 
&right_cols)
+                    } else if matches!(self.join_type, JoinType::RightAnti) {
+                        let right_cols = fetch_right_columns_by_idxs(
+                            &self.buffered_data,
+                            chunk.buffered_batch_idx.unwrap(),
+                            &right_indices,
+                        )?;
+
+                        get_filter_column(&self.filter, &right_cols, 
&left_columns)
                     } else {
                         get_filter_column(&self.filter, &left_columns, 
&right_columns)
                     }
@@ -1773,6 +1790,7 @@ impl SortMergeJoinStream {
                             | JoinType::LeftSemi
                             | JoinType::Right
                             | JoinType::LeftAnti
+                            | JoinType::RightAnti
                             | JoinType::LeftMark
                             | JoinType::Full
                     ) {
@@ -1856,6 +1874,7 @@ impl SortMergeJoinStream {
                     | JoinType::LeftSemi
                     | JoinType::Right
                     | JoinType::LeftAnti
+                    | JoinType::RightAnti
                     | JoinType::LeftMark
                     | JoinType::Full
             ))
@@ -1902,6 +1921,14 @@ impl SortMergeJoinStream {
             &out_mask
         };
 
+        self.filter_record_batch_by_join_type(record_batch, corrected_mask)
+    }
+
+    fn filter_record_batch_by_join_type(
+        &mut self,
+        record_batch: RecordBatch,
+        corrected_mask: &BooleanArray,
+    ) -> Result<RecordBatch> {
         let mut filtered_record_batch =
             filter_record_batch(&record_batch, corrected_mask)?;
         let left_columns_length = self.streamed_schema.fields.len();
@@ -1954,6 +1981,10 @@ impl SortMergeJoinStream {
             let output_column_indices = 
(0..left_columns_length).collect::<Vec<_>>();
             filtered_record_batch =
                 filtered_record_batch.project(&output_column_indices)?;
+        } else if matches!(self.join_type, JoinType::RightAnti) {
+            let output_column_indices = 
(0..right_columns_length).collect::<Vec<_>>();
+            filtered_record_batch =
+                filtered_record_batch.project(&output_column_indices)?;
         } else if matches!(self.join_type, JoinType::Full)
             && corrected_mask.false_count() > 0
         {
@@ -2389,6 +2420,7 @@ mod tests {
     use arrow_array::builder::{BooleanBuilder, UInt64Builder};
     use arrow_array::{BooleanArray, UInt64Array};
 
+    use datafusion_common::JoinSide;
     use datafusion_common::JoinType::*;
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, assert_contains, 
JoinType, Result,
@@ -2397,13 +2429,15 @@ mod tests {
     use datafusion_execution::disk_manager::DiskManagerConfig;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use datafusion_execution::TaskContext;
+    use datafusion_expr::Operator;
+    use datafusion_physical_expr::expressions::BinaryExpr;
 
     use crate::expressions::Column;
     use crate::joins::sort_merge_join::{get_corrected_filter_mask, 
JoinedRecordBatches};
-    use crate::joins::utils::JoinOn;
+    use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
     use crate::joins::SortMergeJoinExec;
     use crate::memory::MemoryExec;
-    use crate::test::build_table_i32;
+    use crate::test::{build_table_i32, build_table_i32_two_cols};
     use crate::{common, ExecutionPlan};
 
     fn build_table(
@@ -2494,6 +2528,15 @@ mod tests {
         Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
     }
 
+    /// Builds a two-column `i32` record batch with `build_table_i32_two_cols`
+    /// and wraps it in a `MemoryExec` that uses the batch's own schema.
+    ///
+    /// `a` and `b` are `(column name, values)` pairs. Panics if
+    /// `MemoryExec::try_new` returns an error.
+    pub fn build_table_two_cols(
+        a: (&str, &Vec<i32>),
+        b: (&str, &Vec<i32>),
+    ) -> Arc<dyn ExecutionPlan> {
+        let batch = build_table_i32_two_cols(a, b);
+        let schema = batch.schema();
+        Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+    }
+
     fn join(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -2523,6 +2566,26 @@ mod tests {
         )
     }
 
+    /// Creates a `SortMergeJoinExec` that applies `filter` in addition to the
+    /// equi-join keys in `on`.
+    ///
+    /// All arguments are forwarded unchanged to `SortMergeJoinExec::try_new`;
+    /// the filter is passed as `Some(filter)`.
+    fn join_with_filter(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        filter: JoinFilter,
+        join_type: JoinType,
+        sort_options: Vec<SortOptions>,
+        null_equals_null: bool,
+    ) -> Result<SortMergeJoinExec> {
+        SortMergeJoinExec::try_new(
+            left,
+            right,
+            on,
+            Some(filter),
+            join_type,
+            sort_options,
+            null_equals_null,
+        )
+    }
+
     async fn join_collect(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -2533,6 +2596,25 @@ mod tests {
         join_collect_with_options(left, right, on, join_type, sort_options, 
false).await
     }
 
+    /// Runs a filtered sort-merge join and collects its output.
+    ///
+    /// Uses one default `SortOptions` per join key, passes `false` for
+    /// `null_equals_null`, and executes partition 0 with a default
+    /// `TaskContext`. Returns the result of `columns(&join.schema())` together
+    /// with the collected record batches.
+    async fn join_collect_with_filter(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        filter: JoinFilter,
+        join_type: JoinType,
+    ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+        let sort_options = vec![SortOptions::default(); on.len()];
+
+        let task_ctx = Arc::new(TaskContext::default());
+        let join =
+            join_with_filter(left, right, on, filter, join_type, sort_options, 
false)?;
+        let columns = columns(&join.schema());
+
+        let stream = join.execute(0, task_ctx)?;
+        let batches = common::collect(stream).await?;
+        Ok((columns, batches))
+    }
+
     async fn join_collect_with_options(
         left: Arc<dyn ExecutionPlan>,
         right: Arc<dyn ExecutionPlan>,
@@ -2914,7 +2996,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn join_anti() -> Result<()> {
+    async fn join_left_anti() -> Result<()> {
         let left = build_table(
             ("a1", &vec![1, 2, 2, 3, 5]),
             ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
@@ -2944,6 +3026,310 @@ mod tests {
         Ok(())
     }
 
+    /// RightAnti join on the single key `b1`.
+    ///
+    /// Left `b1` values are 4, 5, 5 and right `b1` values are 4, 5, 6, so only
+    /// the right row with `b1 = 6` has no match and is returned. The case is
+    /// checked once with a two-column right input and once with a three-column
+    /// right input.
+    #[tokio::test]
+    async fn join_right_anti_one_one() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b1", &vec![4, 5, 5]),
+            ("c1", &vec![7, 8, 8]),
+        );
+        let right =
+            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 
6]));
+        let on = vec![(
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+        )];
+
+        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+        let expected = [
+            "+----+----+",
+            "| a2 | b1 |",
+            "+----+----+",
+            "| 30 | 6  |",
+            "+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+
+        let left2 = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b1", &vec![4, 5, 5]),
+            ("c1", &vec![7, 8, 8]),
+        );
+        let right2 = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = vec![(
+            Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _,
+        )];
+
+        let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?;
+        let expected2 = [
+            "+----+----+----+",
+            "| a2 | b1 | c2 |",
+            "+----+----+----+",
+            "| 30 | 6  | 90 |",
+            "+----+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected2, &batches2);
+
+        Ok(())
+    }
+
+    /// RightAnti join on the two key pairs `a1 = a2` and `b1 = b1`.
+    ///
+    /// Left `a1` values (1, 2, 2) never equal right `a2` values (10, 20, 30),
+    /// so no right row finds a match and all three right rows are returned.
+    /// The case is checked with a two-column and a three-column right input.
+    #[tokio::test]
+    async fn join_right_anti_two_two() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b1", &vec![4, 5, 5]),
+            ("c1", &vec![7, 8, 8]),
+        );
+        let right =
+            build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5, 
6]));
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+
+        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+        let expected = [
+            "+----+----+",
+            "| a2 | b1 |",
+            "+----+----+",
+            "| 10 | 4  |",
+            "| 20 | 5  |",
+            "| 30 | 6  |",
+            "+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+
+        let left = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b1", &vec![4, 5, 5]),
+            ("c1", &vec![7, 8, 8]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+
+        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+        let expected = [
+            "+----+----+----+",
+            "| a2 | b1 | c2 |",
+            "+----+----+----+",
+            "| 10 | 4  | 70 |",
+            "| 20 | 5  | 80 |",
+            "| 30 | 6  | 90 |",
+            "+----+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+
+        Ok(())
+    }
+
+    /// RightAnti join on (`a1`, `b1`) with the extra join filter `c2 > c1`.
+    ///
+    /// The single left and right rows agree on both keys (1, 10), but the
+    /// filter evaluates `20 > 30`, which is false, so the key match is
+    /// rejected and the right row is returned.
+    #[tokio::test]
+    async fn join_right_anti_two_with_filter() -> Result<()> {
+        let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1", 
&vec![30]));
+        let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2", 
&vec![20]));
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+        // Filter expression over the intermediate schema [c1, c2]: index 0 is
+        // column 2 of the left input, index 1 is column 2 of the right input.
+        let filter = JoinFilter::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("c2", 1)),
+                Operator::Gt,
+                Arc::new(Column::new("c1", 0)),
+            )),
+            vec![
+                ColumnIndex {
+                    index: 2,
+                    side: JoinSide::Left,
+                },
+                ColumnIndex {
+                    index: 2,
+                    side: JoinSide::Right,
+                },
+            ],
+            Schema::new(vec![
+                Field::new("c1", DataType::Int32, true),
+                Field::new("c2", DataType::Int32, true),
+            ]),
+        );
+        let (_, batches) =
+            join_collect_with_filter(left, right, on, filter, 
RightAnti).await?;
+        let expected = [
+            "+----+----+----+",
+            "| a1 | b1 | c2 |",
+            "+----+----+----+",
+            "| 1  | 10 | 20 |",
+            "+----+----+----+",
+        ];
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    /// RightAnti join on (`a1`, `b1`) where both inputs contain a NULL key.
+    ///
+    /// Every right row with a non-NULL key pair has an equal pair on the left.
+    /// The right row (2, NULL) is the only one returned, even though the left
+    /// also holds (2, NULL): NULL keys do not match each other here.
+    // NOTE(review): `join_collect` passes `false` as the last option, which is
+    // presumably `null_equals_null` — confirm in `join_collect_with_options`.
+    #[tokio::test]
+    async fn join_right_anti_with_nulls() -> Result<()> {
+        let left = build_table_i32_nullable(
+            ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
+            ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
+            ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
+        );
+        let right = build_table_i32_nullable(
+            ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
+            ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key 
field
+            ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key 
field
+        );
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+
+        let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+        let expected = [
+            "+----+----+----+",
+            "| a1 | b1 | c2 |",
+            "+----+----+----+",
+            "| 2  |    | 8  |",
+            "+----+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    /// RightAnti join on (`a1`, `b1`) with explicit options: both keys use
+    /// `descending: true, nulls_first: false`, and `true` is passed as the
+    /// final argument of `join_collect_with_options`.
+    // NOTE(review): the final `true` is presumably `null_equals_null` —
+    // confirm in `join_collect_with_options`.
+    // NOTE(review): the left input's `a1` values (1, 2, 1, 0, 2) are not in
+    // descending order, so the expected rows reflect how the merge walks this
+    // particular input order — confirm this is intended.
+    #[tokio::test]
+    async fn join_right_anti_with_nulls_with_options() -> Result<()> {
+        let left = build_table_i32_nullable(
+            ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
+            ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
+            ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
+        );
+        let right = build_table_i32_nullable(
+            ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
+            ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key 
field
+            ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key 
field
+        );
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+
+        let (_, batches) = join_collect_with_options(
+            left,
+            right,
+            on,
+            RightAnti,
+            vec![
+                SortOptions {
+                    descending: true,
+                    nulls_first: false,
+                };
+                2
+            ],
+            true,
+        )
+        .await?;
+
+        let expected = [
+            "+----+----+----+",
+            "| a1 | b1 | c2 |",
+            "+----+----+----+",
+            "| 3  |    | 9  |",
+            "| 2  | 5  |    |",
+            "| 2  | 5  | 8  |",
+            "+----+----+----+",
+        ];
+        // The output order is important as SMJ preserves sortedness
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
+    /// RightAnti join on (`a1 = a2`, `b1 = b1`) executed through
+    /// `join_collect_batch_size_equals_two`.
+    ///
+    /// Left `a1` values (1, 2, 2) never equal right `a2` values (10, 20, 30),
+    /// so all three right rows are returned; the output is expected to arrive
+    /// as one batch of two rows followed by one batch of one row.
+    #[tokio::test]
+    async fn join_right_anti_output_two_batches() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 2]),
+            ("b1", &vec![4, 5, 5]),
+            ("c1", &vec![7, 8, 8]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 20, 30]),
+            ("b1", &vec![4, 5, 6]),
+            ("c2", &vec![70, 80, 90]),
+        );
+        let on = vec![
+            (
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+            ),
+            (
+                Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+            ),
+        ];
+
+        // Exercise RightAnti (not LeftAnti) so the test matches its name and
+        // actually covers the right-side output path.
+        let (_, batches) =
+            join_collect_batch_size_equals_two(left, right, on, RightAnti).await?;
+        let expected = [
+            "+----+----+----+",
+            "| a2 | b1 | c2 |",
+            "+----+----+----+",
+            "| 10 | 4  | 70 |",
+            "| 20 | 5  | 80 |",
+            "| 30 | 6  | 90 |",
+            "+----+----+----+",
+        ];
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 2);
+        assert_eq!(batches[1].num_rows(), 1);
+        assert_batches_eq!(expected, &batches);
+        Ok(())
+    }
+
     #[tokio::test]
     async fn join_semi() -> Result<()> {
         let left = build_table(
@@ -4138,174 +4524,176 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn test_left_anti_join_filtered_mask() -> Result<()> {
-        let mut joined_batches = build_joined_record_batches()?;
-        let schema = joined_batches.batches.first().unwrap().schema();
-
-        let output = concat_batches(&schema, &joined_batches.batches)?;
-        let out_mask = joined_batches.filter_mask.finish();
-        let out_indices = joined_batches.row_indices.finish();
+    async fn test_anti_join_filtered_mask() -> Result<()> {
+        for join_type in [LeftAnti, RightAnti] {
+            let mut joined_batches = build_joined_record_batches()?;
+            let schema = joined_batches.batches.first().unwrap().schema();
+
+            let output = concat_batches(&schema, &joined_batches.batches)?;
+            let out_mask = joined_batches.filter_mask.finish();
+            let out_indices = joined_batches.row_indices.finish();
+
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0]),
+                    &[0usize],
+                    &BooleanArray::from(vec![true]),
+                    1
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0]),
-                &[0usize],
-                &BooleanArray::from(vec![true]),
-                1
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0]),
+                    &[0usize],
+                    &BooleanArray::from(vec![false]),
+                    1
+                )
+                .unwrap(),
+                BooleanArray::from(vec![Some(true)])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0]),
-                &[0usize],
-                &BooleanArray::from(vec![false]),
-                1
-            )
-            .unwrap(),
-            BooleanArray::from(vec![Some(true)])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0]),
+                    &[0usize; 2],
+                    &BooleanArray::from(vec![true, true]),
+                    2
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0]),
-                &[0usize; 2],
-                &BooleanArray::from(vec![true, true]),
-                2
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None, None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0, 0]),
+                    &[0usize; 3],
+                    &BooleanArray::from(vec![true, true, true]),
+                    3
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None, None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0, 0]),
-                &[0usize; 3],
-                &BooleanArray::from(vec![true, true, true]),
-                3
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None, None, None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0, 0]),
+                    &[0usize; 3],
+                    &BooleanArray::from(vec![true, false, true]),
+                    3
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None, None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0, 0]),
-                &[0usize; 3],
-                &BooleanArray::from(vec![true, false, true]),
-                3
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None, None, None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0, 0]),
+                    &[0usize; 3],
+                    &BooleanArray::from(vec![false, false, true]),
+                    3
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None, None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0, 0]),
-                &[0usize; 3],
-                &BooleanArray::from(vec![false, false, true]),
-                3
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None, None, None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0, 0]),
+                    &[0usize; 3],
+                    &BooleanArray::from(vec![false, true, true]),
+                    3
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None, None])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0, 0]),
-                &[0usize; 3],
-                &BooleanArray::from(vec![false, true, true]),
-                3
-            )
-            .unwrap(),
-            BooleanArray::from(vec![None, None, None])
-        );
+            assert_eq!(
+                get_corrected_filter_mask(
+                    join_type,
+                    &UInt64Array::from(vec![0, 0, 0]),
+                    &[0usize; 3],
+                    &BooleanArray::from(vec![false, false, false]),
+                    3
+                )
+                .unwrap(),
+                BooleanArray::from(vec![None, None, Some(true)])
+            );
 
-        assert_eq!(
-            get_corrected_filter_mask(
-                LeftAnti,
-                &UInt64Array::from(vec![0, 0, 0]),
-                &[0usize; 3],
-                &BooleanArray::from(vec![false, false, false]),
-                3
+            let corrected_mask = get_corrected_filter_mask(
+                join_type,
+                &out_indices,
+                &joined_batches.batch_ids,
+                &out_mask,
+                output.num_rows(),
             )
-            .unwrap(),
-            BooleanArray::from(vec![None, None, Some(true)])
-        );
-
-        let corrected_mask = get_corrected_filter_mask(
-            LeftAnti,
-            &out_indices,
-            &joined_batches.batch_ids,
-            &out_mask,
-            output.num_rows(),
-        )
-        .unwrap();
-
-        assert_eq!(
-            corrected_mask,
-            BooleanArray::from(vec![
-                None,
-                None,
-                None,
-                None,
-                None,
-                Some(true),
-                None,
-                Some(true)
-            ])
-        );
-
-        let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
+            .unwrap();
+
+            assert_eq!(
+                corrected_mask,
+                BooleanArray::from(vec![
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(true),
+                    None,
+                    Some(true)
+                ])
+            );
 
-        assert_batches_eq!(
-            &[
-                "+---+----+---+----+",
-                "| a | b  | x | y  |",
-                "+---+----+---+----+",
-                "| 1 | 13 | 1 | 12 |",
-                "| 1 | 14 | 1 | 11 |",
-                "+---+----+---+----+",
-            ],
-            &[filtered_rb]
-        );
+            let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
+
+            assert_batches_eq!(
+                &[
+                    "+---+----+---+----+",
+                    "| a | b  | x | y  |",
+                    "+---+----+---+----+",
+                    "| 1 | 13 | 1 | 12 |",
+                    "| 1 | 14 | 1 | 11 |",
+                    "+---+----+---+----+",
+                ],
+                &[filtered_rb]
+            );
 
-        // output null rows
-        let null_mask = arrow::compute::not(&corrected_mask)?;
-        assert_eq!(
-            null_mask,
-            BooleanArray::from(vec![
-                None,
-                None,
-                None,
-                None,
-                None,
-                Some(false),
-                None,
-                Some(false),
-            ])
-        );
+            // output null rows
+            let null_mask = arrow::compute::not(&corrected_mask)?;
+            assert_eq!(
+                null_mask,
+                BooleanArray::from(vec![
+                    None,
+                    None,
+                    None,
+                    None,
+                    None,
+                    Some(false),
+                    None,
+                    Some(false),
+                ])
+            );
 
-        let null_joined_batch = filter_record_batch(&output, &null_mask)?;
+            let null_joined_batch = filter_record_batch(&output, &null_mask)?;
 
-        assert_batches_eq!(
-            &[
-                "+---+---+---+---+",
-                "| a | b | x | y |",
-                "+---+---+---+---+",
-                "+---+---+---+---+",
-            ],
-            &[null_joined_batch]
-        );
+            assert_batches_eq!(
+                &[
+                    "+---+---+---+---+",
+                    "| a | b | x | y |",
+                    "+---+---+---+---+",
+                    "+---+---+---+---+",
+                ],
+                &[null_joined_batch]
+            );
+        }
         Ok(())
     }
 
diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs
index 90ec9b1068..b7bbfd1169 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -88,6 +88,26 @@ pub fn build_table_i32(
     .unwrap()
 }
 
+/// Builds an in-memory [`RecordBatch`] from two named, non-nullable `Int32` columns
+pub fn build_table_i32_two_cols(
+    a: (&str, &Vec<i32>),
+    b: (&str, &Vec<i32>),
+) -> RecordBatch {
+    let ((a_name, a_values), (b_name, b_values)) = (a, b);
+    let fields = vec![
+        Field::new(a_name, DataType::Int32, false),
+        Field::new(b_name, DataType::Int32, false),
+    ];
+    RecordBatch::try_new(
+        Arc::new(Schema::new(fields)),
+        vec![
+            Arc::new(Int32Array::from(a_values.clone())),
+            Arc::new(Int32Array::from(b_values.clone())),
+        ],
+    )
+    .unwrap()
+}
+
 /// Returns memory table scan wrapped around record batch with 3 columns of i32
 pub fn build_table_scan_i32(
     a: (&str, &Vec<i32>),
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt 
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 9a20e7987f..1df52dd1eb 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -647,6 +647,54 @@ NULL NULL 7 9
 NULL NULL 8 10
 NULL NULL 9 11
 
+query II
+select * from (
+with
+t1 as (
+    select 31 a, 32 b union all
+    select 31 a, 33 b
+),
+t2 as (
+    select 31 a, 32 b union all
+    select 31 a, 35 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+31 35
+
+query II
+select * from (
+with
+t1 as (
+    select 41 a, 42 b union all
+    select 41 a, 43 b
+),
+t2 as (
+    select 41 a, 42 b union all
+    select 41 a, 45 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+41 45
+
+query II
+select * from (
+with
+t1 as (
+    select 51 a, 52 b union all
+    select 51 a, 53 b
+),
+t2 as (
+    select 51 a, 52 b union all
+    select 51 a, 54 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+51 54
+
 # return sql params back to default values
 statement ok
 set datafusion.optimizer.prefer_hash_join = true;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to