This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d6ddd23795 Fix SMJ Left Anti Join when the join filter is set (#10724)
d6ddd23795 is described below

commit d6ddd23795222672055e0b737c20bc1fc19e7dd3
Author: Oleks V <[email protected]>
AuthorDate: Fri May 31 14:32:56 2024 -0700

    Fix SMJ Left Anti Join when the join filter is set (#10724)
    
    * Fix: Sort Merge Join crashes on TPCH Q21
    
    * Fix LeftAnti SMJ join when the join filter is set
    
    * rm dbg
---
 .../physical-plan/src/joins/sort_merge_join.rs     | 249 +++++++++++++++++----
 .../sqllogictest/test_files/sort_merge_join.slt    | 121 ++++++++--
 2 files changed, 306 insertions(+), 64 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index ec83fe3f2a..143a726d31 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -487,7 +487,6 @@ struct StreamedBatch {
     /// The join key arrays of streamed batch which are used to compare with 
buffered batches
     /// and to produce output. They are produced by evaluating `on` 
expressions.
     pub join_arrays: Vec<ArrayRef>,
-
     /// Chunks of indices from buffered side (may be nulls) joined to streamed
     pub output_indices: Vec<StreamedJoinedChunk>,
     /// Index of currently scanned batch from buffered data
@@ -1021,6 +1020,15 @@ impl SMJStream {
                     join_streamed = true;
                     join_buffered = true;
                 };
+
+                if matches!(self.join_type, JoinType::LeftAnti) && 
self.filter.is_some() {
+                    join_streamed = !self
+                        .streamed_batch
+                        .join_filter_matched_idxs
+                        .contains(&(self.streamed_batch.idx as u64))
+                        && !self.streamed_joined;
+                    join_buffered = join_streamed;
+                }
             }
             Ordering::Greater => {
                 if matches!(self.join_type, JoinType::Full) {
@@ -1181,7 +1189,10 @@ impl SMJStream {
             let filter_columns = if chunk.buffered_batch_idx.is_some() {
                 if matches!(self.join_type, JoinType::Right) {
                     get_filter_column(&self.filter, &buffered_columns, 
&streamed_columns)
-                } else if matches!(self.join_type, JoinType::LeftSemi) {
+                } else if matches!(
+                    self.join_type,
+                    JoinType::LeftSemi | JoinType::LeftAnti
+                ) {
                     // unwrap is safe here as we check is_some on top of if 
statement
                     let buffered_columns = get_buffered_columns(
                         &self.buffered_data,
@@ -1228,7 +1239,15 @@ impl SMJStream {
                         
datafusion_common::cast::as_boolean_array(&filter_result)?;
 
                     let maybe_filtered_join_mask: Option<(BooleanArray, 
Vec<u64>)> =
-                        get_filtered_join_mask(self.join_type, 
streamed_indices, mask);
+                        get_filtered_join_mask(
+                            self.join_type,
+                            streamed_indices,
+                            mask,
+                            &self.streamed_batch.join_filter_matched_idxs,
+                            &self.buffered_data.scanning_batch_idx,
+                            &self.buffered_data.batches.len(),
+                        );
+
                     if let Some(ref filtered_join_mask) = 
maybe_filtered_join_mask {
                         mask = &filtered_join_mask.0;
                         self.streamed_batch
@@ -1419,51 +1438,87 @@ fn get_buffered_columns(
         .collect::<Result<Vec<_>, ArrowError>>()
 }
 
-// Calculate join filter bit mask considering join type specifics
-// `streamed_indices` - array of streamed datasource JOINED row indices
-// `mask` - array booleans representing computed join filter expression eval 
result:
-//      true = the row index matches the join filter
-//      false = the row index doesn't match the join filter
-// `streamed_indices` have the same length as `mask`
+/// Calculate join filter bit mask considering join type specifics
+/// `streamed_indices` - array of streamed datasource JOINED row indices
+/// `mask` - array booleans representing computed join filter expression eval 
result:
+///      true = the row index matches the join filter
+///      false = the row index doesn't match the join filter
+/// `streamed_indices` have the same length as `mask`
+/// `matched_indices` - set of streamed indices that already have a join 
filter match
+/// `scanning_buffered_batch_idx` - index of the buffered batch currently being scanned
+/// `buffered_batches_len` - number of batches in the buffered data
 fn get_filtered_join_mask(
     join_type: JoinType,
     streamed_indices: UInt64Array,
     mask: &BooleanArray,
+    matched_indices: &HashSet<u64>,
+    scanning_buffered_batch_idx: &usize,
+    buffered_batches_len: &usize,
 ) -> Option<(BooleanArray, Vec<u64>)> {
-    // for LeftSemi Join the filter mask should be calculated in its own way:
-    // if we find at least one matching row for specific streaming index
-    // we don't need to check any others for the same index
-    if matches!(join_type, JoinType::LeftSemi) {
-        // have we seen a filter match for a streaming index before
-        let mut seen_as_true: bool = false;
-        let streamed_indices_length = streamed_indices.len();
-        let mut corrected_mask: BooleanBuilder =
-            BooleanBuilder::with_capacity(streamed_indices_length);
-
-        let mut filter_matched_indices: Vec<u64> = vec![];
-
-        #[allow(clippy::needless_range_loop)]
-        for i in 0..streamed_indices_length {
-            // LeftSemi respects only first true values for specific streaming 
index,
-            // others true values for the same index must be false
-            if mask.value(i) && !seen_as_true {
-                seen_as_true = true;
-                corrected_mask.append_value(true);
-                filter_matched_indices.push(streamed_indices.value(i));
-            } else {
-                corrected_mask.append_value(false);
+    let mut seen_as_true: bool = false;
+    let streamed_indices_length = streamed_indices.len();
+    let mut corrected_mask: BooleanBuilder =
+        BooleanBuilder::with_capacity(streamed_indices_length);
+
+    let mut filter_matched_indices: Vec<u64> = vec![];
+
+    #[allow(clippy::needless_range_loop)]
+    match join_type {
+        // for LeftSemi Join the filter mask should be calculated in its own 
way:
+        // if we find at least one matching row for specific streaming index
+        // we don't need to check any others for the same index
+        JoinType::LeftSemi => {
+            // have we seen a filter match for a streaming index before
+            for i in 0..streamed_indices_length {
+                // LeftSemi respects only first true values for specific 
streaming index,
+                // others true values for the same index must be false
+                if mask.value(i) && !seen_as_true {
+                    seen_as_true = true;
+                    corrected_mask.append_value(true);
+                    filter_matched_indices.push(streamed_indices.value(i));
+                } else {
+                    corrected_mask.append_value(false);
+                }
+
+                // if switched to next streaming index(e.g. from 0 to 1, or 
from 1 to 2), we reset seen_as_true flag
+                if i < streamed_indices_length - 1
+                    && streamed_indices.value(i) != streamed_indices.value(i + 
1)
+                {
+                    seen_as_true = false;
+                }
             }
+            Some((corrected_mask.finish(), filter_matched_indices))
+        }
+        // LeftAnti semantics: return true if for every x in the collection, 
p(x) is false.
+        // the true(if any) flag needs to be set only once per streaming index
+        // to prevent duplicates in the output
+        JoinType::LeftAnti => {
+            // have we seen a filter match for a streaming index before
+            for i in 0..streamed_indices_length {
+                if mask.value(i) && !seen_as_true {
+                    seen_as_true = true;
+                    filter_matched_indices.push(streamed_indices.value(i));
+                }
 
-            // if switched to next streaming index(e.g. from 0 to 1, or from 1 
to 2), we reset seen_as_true flag
-            if i < streamed_indices_length - 1
-                && streamed_indices.value(i) != streamed_indices.value(i + 1)
-            {
-                seen_as_true = false;
+                // if switched to next streaming index(e.g. from 0 to 1, or 
from 1 to 2), we reset seen_as_true flag
+                if (i < streamed_indices_length - 1
+                    && streamed_indices.value(i) != streamed_indices.value(i + 
1))
+                    || (i == streamed_indices_length - 1
+                        && *scanning_buffered_batch_idx == 
buffered_batches_len - 1)
+                {
+                    corrected_mask.append_value(
+                        !matched_indices.contains(&streamed_indices.value(i))
+                            && !seen_as_true,
+                    );
+                    seen_as_true = false;
+                } else {
+                    corrected_mask.append_value(false);
+                }
             }
+
+            Some((corrected_mask.finish(), filter_matched_indices))
         }
-        Some((corrected_mask.finish(), filter_matched_indices))
-    } else {
-        None
+        _ => None,
     }
 }
 
@@ -1711,8 +1766,9 @@ mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
     use arrow_array::{BooleanArray, UInt64Array};
+    use hashbrown::HashSet;
 
-    use datafusion_common::JoinType::LeftSemi;
+    use datafusion_common::JoinType::{LeftAnti, LeftSemi};
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, assert_contains, 
JoinType, Result,
     };
@@ -2754,7 +2810,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 0, 1, 1]),
-                &BooleanArray::from(vec![true, true, false, false])
+                &BooleanArray::from(vec![true, true, false, false]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((BooleanArray::from(vec![true, false, false, false]), 
vec![0]))
         );
@@ -2763,7 +2822,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 1]),
-                &BooleanArray::from(vec![true, true])
+                &BooleanArray::from(vec![true, true]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
         );
@@ -2772,7 +2834,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 1]),
-                &BooleanArray::from(vec![false, true])
+                &BooleanArray::from(vec![false, true]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((BooleanArray::from(vec![false, true]), vec![1]))
         );
@@ -2781,7 +2846,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 1]),
-                &BooleanArray::from(vec![true, false])
+                &BooleanArray::from(vec![true, false]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((BooleanArray::from(vec![true, false]), vec![0]))
         );
@@ -2790,7 +2858,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
-                &BooleanArray::from(vec![false, true, true, true, true, true])
+                &BooleanArray::from(vec![false, true, true, true, true, true]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((
                 BooleanArray::from(vec![false, true, false, true, false, 
false]),
@@ -2802,7 +2873,10 @@ mod tests {
             get_filtered_join_mask(
                 LeftSemi,
                 UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
-                &BooleanArray::from(vec![false, false, false, false, false, 
true])
+                &BooleanArray::from(vec![false, false, false, false, false, 
true]),
+                &HashSet::new(),
+                &0,
+                &0
             ),
             Some((
                 BooleanArray::from(vec![false, false, false, false, false, 
true]),
@@ -2813,6 +2887,89 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn left_anti_join_filtered_mask() -> Result<()> {
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 0, 1, 1]),
+                &BooleanArray::from(vec![true, true, false, false]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((BooleanArray::from(vec![false, false, false, true]), 
vec![0]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![true, true]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((BooleanArray::from(vec![false, false]), vec![0, 1]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![false, true]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((BooleanArray::from(vec![true, false]), vec![1]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 1]),
+                &BooleanArray::from(vec![true, false]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((BooleanArray::from(vec![false, true]), vec![0]))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+                &BooleanArray::from(vec![false, true, true, true, true, true]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((
+                BooleanArray::from(vec![false, false, false, false, false, 
false]),
+                vec![0, 1]
+            ))
+        );
+
+        assert_eq!(
+            get_filtered_join_mask(
+                LeftAnti,
+                UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+                &BooleanArray::from(vec![false, false, false, false, false, 
true]),
+                &HashSet::new(),
+                &0,
+                &1
+            ),
+            Some((
+                BooleanArray::from(vec![false, false, true, false, false, 
false]),
+                vec![1]
+            ))
+        );
+
+        Ok(())
+    }
+
     /// Returns the column names on the schema
     fn columns(schema: &Schema) -> Vec<String> {
         schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt 
b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 3a27d9693d..babb7dc8fd 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -378,24 +378,6 @@ select t1.* from t1 where exists (select 1 from t2 where 
t2.a = t1.a and t2.b !=
 11 12
 11 13
 
-#LEFTANTI tests
-# returns no rows instead of correct result
-#query III
-#select * from (
-#with
-#t1 as (
-#    select 11 a, 12 b, 1 c union all
-#    select 11 a, 13 b, 2 c),
-#t2 as (
-#    select 11 a, 12 b, 3 c union all
-#    select 11 a, 14 b, 4 c
-#    )
-#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
-#) order by 1, 2;
-#----
-#11 12 1
-#11 13 2
-
 # Set batch size to 1 for sort merge join to test scenario when data spread 
across multiple batches
 statement ok
 set datafusion.execution.batch_size = 1;
@@ -431,5 +413,108 @@ select t1.* from t1 where exists (select 1 from t2 where 
t2.a = t1.a and t2.b !=
 11 12
 11 13
 
+#LEFTANTI tests
+statement ok
+set datafusion.execution.batch_size = 10;
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b),
+t2 as (
+    select 11 a, 13 c union all
+    select 11 a, 14 c
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t1.b > t2.c)
+) order by 1, 2
+----
+11 12
+
+query III
+select * from (
+with
+t1 as (
+    select 11 a, 12 b, 1 c union all
+    select 11 a, 13 b, 2 c),
+t2 as (
+    select 11 a, 12 b, 3 c union all
+    select 11 a, 14 b, 4 c
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+query III
+select * from (
+with
+t1 as (
+    select 11 a, 12 b, 1 c union all
+    select 11 a, 13 b, 2 c),
+t2 as (
+    select 11 a, 12 b, 3 c where false
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+# Test LEFT ANTI with cross batch data distribution
+statement ok
+set datafusion.execution.batch_size = 1;
+
+query II
+select * from (
+with
+t1 as (
+    select 11 a, 12 b),
+t2 as (
+    select 11 a, 13 c union all
+    select 11 a, 14 c
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t1.b > t2.c)
+) order by 1, 2
+----
+11 12
+
+query III
+select * from (
+with
+t1 as (
+    select 11 a, 12 b, 1 c union all
+    select 11 a, 13 b, 2 c),
+t2 as (
+    select 11 a, 12 b, 3 c union all
+    select 11 a, 14 b, 4 c
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+query III
+select * from (
+with
+t1 as (
+    select 11 a, 12 b, 1 c union all
+    select 11 a, 13 b, 2 c),
+t2 as (
+    select 11 a, 12 b, 3 c where false
+    )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and 
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+# return sql params back to default values
 statement ok
 set datafusion.optimizer.prefer_hash_join = true;
+
+statement ok
+set datafusion.execution.batch_size = 8192;
+


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to