This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d6ddd23795 Fix SMJ Left Anti Join when the join filter is set (#10724)
d6ddd23795 is described below
commit d6ddd23795222672055e0b737c20bc1fc19e7dd3
Author: Oleks V <[email protected]>
AuthorDate: Fri May 31 14:32:56 2024 -0700
Fix SMJ Left Anti Join when the join filter is set (#10724)
* Fix: Sort Merge Join crashes on TPCH Q21
* Fix LeftAnti SMJ join when the join filter is set
* rm dbg
---
.../physical-plan/src/joins/sort_merge_join.rs | 249 +++++++++++++++++----
.../sqllogictest/test_files/sort_merge_join.slt | 121 ++++++++--
2 files changed, 306 insertions(+), 64 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index ec83fe3f2a..143a726d31 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -487,7 +487,6 @@ struct StreamedBatch {
/// The join key arrays of streamed batch which are used to compare with
buffered batches
/// and to produce output. They are produced by evaluating `on`
expressions.
pub join_arrays: Vec<ArrayRef>,
-
/// Chunks of indices from buffered side (may be nulls) joined to streamed
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
@@ -1021,6 +1020,15 @@ impl SMJStream {
join_streamed = true;
join_buffered = true;
};
+
+ if matches!(self.join_type, JoinType::LeftAnti) &&
self.filter.is_some() {
+ join_streamed = !self
+ .streamed_batch
+ .join_filter_matched_idxs
+ .contains(&(self.streamed_batch.idx as u64))
+ && !self.streamed_joined;
+ join_buffered = join_streamed;
+ }
}
Ordering::Greater => {
if matches!(self.join_type, JoinType::Full) {
@@ -1181,7 +1189,10 @@ impl SMJStream {
let filter_columns = if chunk.buffered_batch_idx.is_some() {
if matches!(self.join_type, JoinType::Right) {
get_filter_column(&self.filter, &buffered_columns,
&streamed_columns)
- } else if matches!(self.join_type, JoinType::LeftSemi) {
+ } else if matches!(
+ self.join_type,
+ JoinType::LeftSemi | JoinType::LeftAnti
+ ) {
// unwrap is safe here as we check is_some on top of if
statement
let buffered_columns = get_buffered_columns(
&self.buffered_data,
@@ -1228,7 +1239,15 @@ impl SMJStream {
datafusion_common::cast::as_boolean_array(&filter_result)?;
let maybe_filtered_join_mask: Option<(BooleanArray,
Vec<u64>)> =
- get_filtered_join_mask(self.join_type,
streamed_indices, mask);
+ get_filtered_join_mask(
+ self.join_type,
+ streamed_indices,
+ mask,
+ &self.streamed_batch.join_filter_matched_idxs,
+ &self.buffered_data.scanning_batch_idx,
+ &self.buffered_data.batches.len(),
+ );
+
if let Some(ref filtered_join_mask) =
maybe_filtered_join_mask {
mask = &filtered_join_mask.0;
self.streamed_batch
@@ -1419,51 +1438,87 @@ fn get_buffered_columns(
.collect::<Result<Vec<_>, ArrowError>>()
}
-// Calculate join filter bit mask considering join type specifics
-// `streamed_indices` - array of streamed datasource JOINED row indices
-// `mask` - array booleans representing computed join filter expression eval
result:
-// true = the row index matches the join filter
-// false = the row index doesn't match the join filter
-// `streamed_indices` have the same length as `mask`
+/// Calculate join filter bit mask considering join type specifics
+/// `streamed_indices` - array of streamed datasource JOINED row indices
+/// `mask` - array booleans representing computed join filter expression eval
result:
+/// true = the row index matches the join filter
+/// false = the row index doesn't match the join filter
+/// `streamed_indices` have the same length as `mask`
+/// `matched_indices` array of streaming indices that already has a join
filter match
+/// `scanning_batch_idx` current buffered batch
+/// `buffered_batches_len` how many batches are in buffered data
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
mask: &BooleanArray,
+ matched_indices: &HashSet<u64>,
+ scanning_buffered_batch_idx: &usize,
+ buffered_batches_len: &usize,
) -> Option<(BooleanArray, Vec<u64>)> {
- // for LeftSemi Join the filter mask should be calculated in its own way:
- // if we find at least one matching row for specific streaming index
- // we don't need to check any others for the same index
- if matches!(join_type, JoinType::LeftSemi) {
- // have we seen a filter match for a streaming index before
- let mut seen_as_true: bool = false;
- let streamed_indices_length = streamed_indices.len();
- let mut corrected_mask: BooleanBuilder =
- BooleanBuilder::with_capacity(streamed_indices_length);
-
- let mut filter_matched_indices: Vec<u64> = vec![];
-
- #[allow(clippy::needless_range_loop)]
- for i in 0..streamed_indices_length {
- // LeftSemi respects only first true values for specific streaming
index,
- // others true values for the same index must be false
- if mask.value(i) && !seen_as_true {
- seen_as_true = true;
- corrected_mask.append_value(true);
- filter_matched_indices.push(streamed_indices.value(i));
- } else {
- corrected_mask.append_value(false);
+ let mut seen_as_true: bool = false;
+ let streamed_indices_length = streamed_indices.len();
+ let mut corrected_mask: BooleanBuilder =
+ BooleanBuilder::with_capacity(streamed_indices_length);
+
+ let mut filter_matched_indices: Vec<u64> = vec![];
+
+ #[allow(clippy::needless_range_loop)]
+ match join_type {
+ // for LeftSemi Join the filter mask should be calculated in its own
way:
+ // if we find at least one matching row for specific streaming index
+ // we don't need to check any others for the same index
+ JoinType::LeftSemi => {
+ // have we seen a filter match for a streaming index before
+ for i in 0..streamed_indices_length {
+ // LeftSemi respects only first true values for specific
streaming index,
+ // others true values for the same index must be false
+ if mask.value(i) && !seen_as_true {
+ seen_as_true = true;
+ corrected_mask.append_value(true);
+ filter_matched_indices.push(streamed_indices.value(i));
+ } else {
+ corrected_mask.append_value(false);
+ }
+
+ // if switched to next streaming index(e.g. from 0 to 1, or
from 1 to 2), we reset seen_as_true flag
+ if i < streamed_indices_length - 1
+ && streamed_indices.value(i) != streamed_indices.value(i +
1)
+ {
+ seen_as_true = false;
+ }
}
+ Some((corrected_mask.finish(), filter_matched_indices))
+ }
+ // LeftAnti semantics: return true if for every x in the collection,
p(x) is false.
+ // the true(if any) flag needs to be set only once per streaming index
+ // to prevent duplicates in the output
+ JoinType::LeftAnti => {
+ // have we seen a filter match for a streaming index before
+ for i in 0..streamed_indices_length {
+ if mask.value(i) && !seen_as_true {
+ seen_as_true = true;
+ filter_matched_indices.push(streamed_indices.value(i));
+ }
- // if switched to next streaming index(e.g. from 0 to 1, or from 1
to 2), we reset seen_as_true flag
- if i < streamed_indices_length - 1
- && streamed_indices.value(i) != streamed_indices.value(i + 1)
- {
- seen_as_true = false;
+ // if switched to next streaming index(e.g. from 0 to 1, or
from 1 to 2), we reset seen_as_true flag
+ if (i < streamed_indices_length - 1
+ && streamed_indices.value(i) != streamed_indices.value(i +
1))
+ || (i == streamed_indices_length - 1
+ && *scanning_buffered_batch_idx ==
buffered_batches_len - 1)
+ {
+ corrected_mask.append_value(
+ !matched_indices.contains(&streamed_indices.value(i))
+ && !seen_as_true,
+ );
+ seen_as_true = false;
+ } else {
+ corrected_mask.append_value(false);
+ }
}
+
+ Some((corrected_mask.finish(), filter_matched_indices))
}
- Some((corrected_mask.finish(), filter_matched_indices))
- } else {
- None
+ _ => None,
}
}
@@ -1711,8 +1766,9 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::{BooleanArray, UInt64Array};
+ use hashbrown::HashSet;
- use datafusion_common::JoinType::LeftSemi;
+ use datafusion_common::JoinType::{LeftAnti, LeftSemi};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains,
JoinType, Result,
};
@@ -2754,7 +2810,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 1, 1]),
- &BooleanArray::from(vec![true, true, false, false])
+ &BooleanArray::from(vec![true, true, false, false]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((BooleanArray::from(vec![true, false, false, false]),
vec![0]))
);
@@ -2763,7 +2822,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
- &BooleanArray::from(vec![true, true])
+ &BooleanArray::from(vec![true, true]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((BooleanArray::from(vec![true, true]), vec![0, 1]))
);
@@ -2772,7 +2834,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
- &BooleanArray::from(vec![false, true])
+ &BooleanArray::from(vec![false, true]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((BooleanArray::from(vec![false, true]), vec![1]))
);
@@ -2781,7 +2846,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
- &BooleanArray::from(vec![true, false])
+ &BooleanArray::from(vec![true, false]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((BooleanArray::from(vec![true, false]), vec![0]))
);
@@ -2790,7 +2858,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
- &BooleanArray::from(vec![false, true, true, true, true, true])
+ &BooleanArray::from(vec![false, true, true, true, true, true]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((
BooleanArray::from(vec![false, true, false, true, false,
false]),
@@ -2802,7 +2873,10 @@ mod tests {
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
- &BooleanArray::from(vec![false, false, false, false, false,
true])
+ &BooleanArray::from(vec![false, false, false, false, false,
true]),
+ &HashSet::new(),
+ &0,
+ &0
),
Some((
BooleanArray::from(vec![false, false, false, false, false,
true]),
@@ -2813,6 +2887,89 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn left_anti_join_filtered_mask() -> Result<()> {
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 0, 1, 1]),
+ &BooleanArray::from(vec![true, true, false, false]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((BooleanArray::from(vec![false, false, false, true]),
vec![0]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![true, true]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((BooleanArray::from(vec![false, false]), vec![0, 1]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![false, true]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((BooleanArray::from(vec![true, false]), vec![1]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 1]),
+ &BooleanArray::from(vec![true, false]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((BooleanArray::from(vec![false, true]), vec![0]))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &BooleanArray::from(vec![false, true, true, true, true, true]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((
+ BooleanArray::from(vec![false, false, false, false, false,
false]),
+ vec![0, 1]
+ ))
+ );
+
+ assert_eq!(
+ get_filtered_join_mask(
+ LeftAnti,
+ UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
+ &BooleanArray::from(vec![false, false, false, false, false,
true]),
+ &HashSet::new(),
+ &0,
+ &1
+ ),
+ Some((
+ BooleanArray::from(vec![false, false, true, false, false,
false]),
+ vec![1]
+ ))
+ );
+
+ Ok(())
+ }
+
/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 3a27d9693d..babb7dc8fd 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -378,24 +378,6 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b !=
11 12
11 13
-#LEFTANTI tests
-# returns no rows instead of correct result
-#query III
-#select * from (
-#with
-#t1 as (
-# select 11 a, 12 b, 1 c union all
-# select 11 a, 13 b, 2 c),
-#t2 as (
-# select 11 a, 12 b, 3 c union all
-# select 11 a, 14 b, 4 c
-# )
-#select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t2.b != t1.b and t1.c > t2.c)
-#) order by 1, 2;
-#----
-#11 12 1
-#11 13 2
-
# Set batch size to 1 for sort merge join to test scenario when data spread
across multiple batches
statement ok
set datafusion.execution.batch_size = 1;
@@ -431,5 +413,108 @@ select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b !=
11 12
11 13
+#LEFTANTI tests
+statement ok
+set datafusion.execution.batch_size = 10;
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b),
+t2 as (
+ select 11 a, 13 c union all
+ select 11 a, 14 c
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t1.b > t2.c)
+) order by 1, 2
+----
+11 12
+
+query III
+select * from (
+with
+t1 as (
+ select 11 a, 12 b, 1 c union all
+ select 11 a, 13 b, 2 c),
+t2 as (
+ select 11 a, 12 b, 3 c union all
+ select 11 a, 14 b, 4 c
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+query III
+select * from (
+with
+t1 as (
+ select 11 a, 12 b, 1 c union all
+ select 11 a, 13 b, 2 c),
+t2 as (
+ select 11 a, 12 b, 3 c where false
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+# Test LEFT ANTI with cross batch data distribution
+statement ok
+set datafusion.execution.batch_size = 1;
+
+query II
+select * from (
+with
+t1 as (
+ select 11 a, 12 b),
+t2 as (
+ select 11 a, 13 c union all
+ select 11 a, 14 c
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t1.b > t2.c)
+) order by 1, 2
+----
+11 12
+
+query III
+select * from (
+with
+t1 as (
+ select 11 a, 12 b, 1 c union all
+ select 11 a, 13 b, 2 c),
+t2 as (
+ select 11 a, 12 b, 3 c union all
+ select 11 a, 14 b, 4 c
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+query III
+select * from (
+with
+t1 as (
+ select 11 a, 12 b, 1 c union all
+ select 11 a, 13 b, 2 c),
+t2 as (
+ select 11 a, 12 b, 3 c where false
+ )
+select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and
t2.b != t1.b and t1.c > t2.c)
+) order by 1, 2;
+----
+11 12 1
+11 13 2
+
+# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
+
+statement ok
+set datafusion.execution.batch_size = 8192;
+
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]