This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 69687da337 feat: support `RightAnti` for `SortMergeJoin` (#13680)
69687da337 is described below
commit 69687da337331c978adc201a75955d7e96605d6f
Author: irenjj <[email protected]>
AuthorDate: Mon Jan 6 04:42:20 2025 +0800
feat: support `RightAnti` for `SortMergeJoin` (#13680)
* feat: support `RightAnti` for `SortMergeJoin`
---
datafusion/core/tests/fuzz_cases/join_fuzz.rs | 28 +-
.../physical-plan/src/joins/sort_merge_join.rs | 710 ++++++++++++++++-----
datafusion/physical-plan/src/test.rs | 20 +
.../sqllogictest/test_files/sort_merge_join.slt | 48 ++
4 files changed, 643 insertions(+), 163 deletions(-)
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index cf1742a30e..b331388f4f 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -210,7 +210,7 @@ async fn test_semi_join_1k_filtered() {
}
#[tokio::test]
-async fn test_anti_join_1k() {
+async fn test_left_anti_join_1k() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
@@ -222,7 +222,7 @@ async fn test_anti_join_1k() {
}
#[tokio::test]
-async fn test_anti_join_1k_filtered() {
+async fn test_left_anti_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
@@ -233,6 +233,30 @@ async fn test_anti_join_1k_filtered() {
.await
}
+#[tokio::test]
+async fn test_right_anti_join_1k() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::RightAnti,
+ None,
+ )
+ .run_test(&[HjSmj, NljHj], false)
+ .await
+}
+
+#[tokio::test]
+async fn test_right_anti_join_1k_filtered() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::RightAnti,
+ Some(Box::new(col_lt_col_filter)),
+ )
+ .run_test(&[HjSmj, NljHj], false)
+ .await
+}
+
#[tokio::test]
async fn test_left_mark_join_1k() {
JoinFuzzTestCase::new(
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index 54bd63084e..bcacc7dcae 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -935,7 +935,7 @@ fn get_corrected_filter_mask(
Some(corrected_mask.finish())
}
- JoinType::LeftAnti => {
+ JoinType::LeftAnti | JoinType::RightAnti => {
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
@@ -1047,6 +1047,7 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftMark
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::Full
)
{
@@ -1129,6 +1130,7 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::LeftMark
| JoinType::Full
)
@@ -1152,6 +1154,7 @@ impl Stream for SortMergeJoinStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::Full
| JoinType::LeftMark
)
@@ -1468,6 +1471,7 @@ impl SortMergeJoinStream {
| JoinType::RightSemi
| JoinType::Full
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::LeftMark
) {
join_streamed = !self.streamed_joined;
@@ -1501,7 +1505,9 @@ impl SortMergeJoinStream {
join_buffered = true;
};
- if matches!(self.join_type, JoinType::LeftAnti) &&
self.filter.is_some() {
+ if matches!(self.join_type, JoinType::LeftAnti |
JoinType::RightAnti)
+ && self.filter.is_some()
+ {
join_streamed = !self.streamed_joined;
join_buffered = join_streamed;
}
@@ -1684,7 +1690,10 @@ impl SortMergeJoinStream {
let right_indices: UInt64Array = chunk.buffered_indices.finish();
let mut right_columns = if matches!(self.join_type,
JoinType::LeftMark) {
vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef]
- } else if matches!(self.join_type, JoinType::LeftSemi |
JoinType::LeftAnti) {
+ } else if matches!(
+ self.join_type,
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti
+ ) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
fetch_right_columns_by_idxs(
@@ -1717,6 +1726,14 @@ impl SortMergeJoinStream {
)?;
get_filter_column(&self.filter, &left_columns,
&right_cols)
+ } else if matches!(self.join_type, JoinType::RightAnti) {
+ let right_cols = fetch_right_columns_by_idxs(
+ &self.buffered_data,
+ chunk.buffered_batch_idx.unwrap(),
+ &right_indices,
+ )?;
+
+ get_filter_column(&self.filter, &right_cols,
&left_columns)
} else {
get_filter_column(&self.filter, &left_columns,
&right_columns)
}
@@ -1773,6 +1790,7 @@ impl SortMergeJoinStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::LeftMark
| JoinType::Full
) {
@@ -1856,6 +1874,7 @@ impl SortMergeJoinStream {
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
+ | JoinType::RightAnti
| JoinType::LeftMark
| JoinType::Full
))
@@ -1902,6 +1921,14 @@ impl SortMergeJoinStream {
&out_mask
};
+ self.filter_record_batch_by_join_type(record_batch, corrected_mask)
+ }
+
+ fn filter_record_batch_by_join_type(
+ &mut self,
+ record_batch: RecordBatch,
+ corrected_mask: &BooleanArray,
+ ) -> Result<RecordBatch> {
let mut filtered_record_batch =
filter_record_batch(&record_batch, corrected_mask)?;
let left_columns_length = self.streamed_schema.fields.len();
@@ -1954,6 +1981,10 @@ impl SortMergeJoinStream {
let output_column_indices =
(0..left_columns_length).collect::<Vec<_>>();
filtered_record_batch =
filtered_record_batch.project(&output_column_indices)?;
+ } else if matches!(self.join_type, JoinType::RightAnti) {
+ let output_column_indices =
(0..right_columns_length).collect::<Vec<_>>();
+ filtered_record_batch =
+ filtered_record_batch.project(&output_column_indices)?;
} else if matches!(self.join_type, JoinType::Full)
&& corrected_mask.false_count() > 0
{
@@ -2389,6 +2420,7 @@ mod tests {
use arrow_array::builder::{BooleanBuilder, UInt64Builder};
use arrow_array::{BooleanArray, UInt64Array};
+ use datafusion_common::JoinSide;
use datafusion_common::JoinType::*;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains,
JoinType, Result,
@@ -2397,13 +2429,15 @@ mod tests {
use datafusion_execution::disk_manager::DiskManagerConfig;
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_execution::TaskContext;
+ use datafusion_expr::Operator;
+ use datafusion_physical_expr::expressions::BinaryExpr;
use crate::expressions::Column;
use crate::joins::sort_merge_join::{get_corrected_filter_mask,
JoinedRecordBatches};
- use crate::joins::utils::JoinOn;
+ use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
use crate::joins::SortMergeJoinExec;
use crate::memory::MemoryExec;
- use crate::test::build_table_i32;
+ use crate::test::{build_table_i32, build_table_i32_two_cols};
use crate::{common, ExecutionPlan};
fn build_table(
@@ -2494,6 +2528,15 @@ mod tests {
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
}
+ pub fn build_table_two_cols(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+ ) -> Arc<dyn ExecutionPlan> {
+ let batch = build_table_i32_two_cols(a, b);
+ let schema = batch.schema();
+ Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
+ }
+
fn join(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
@@ -2523,6 +2566,26 @@ mod tests {
)
}
+ fn join_with_filter(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ filter: JoinFilter,
+ join_type: JoinType,
+ sort_options: Vec<SortOptions>,
+ null_equals_null: bool,
+ ) -> Result<SortMergeJoinExec> {
+ SortMergeJoinExec::try_new(
+ left,
+ right,
+ on,
+ Some(filter),
+ join_type,
+ sort_options,
+ null_equals_null,
+ )
+ }
+
async fn join_collect(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
@@ -2533,6 +2596,25 @@ mod tests {
join_collect_with_options(left, right, on, join_type, sort_options,
false).await
}
+ async fn join_collect_with_filter(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ filter: JoinFilter,
+ join_type: JoinType,
+ ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+ let sort_options = vec![SortOptions::default(); on.len()];
+
+ let task_ctx = Arc::new(TaskContext::default());
+ let join =
+ join_with_filter(left, right, on, filter, join_type, sort_options,
false)?;
+ let columns = columns(&join.schema());
+
+ let stream = join.execute(0, task_ctx)?;
+ let batches = common::collect(stream).await?;
+ Ok((columns, batches))
+ }
+
async fn join_collect_with_options(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
@@ -2914,7 +2996,7 @@ mod tests {
}
#[tokio::test]
- async fn join_anti() -> Result<()> {
+ async fn join_left_anti() -> Result<()> {
let left = build_table(
("a1", &vec![1, 2, 2, 3, 5]),
("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
@@ -2944,6 +3026,310 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn join_right_anti_one_one() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b1", &vec![4, 5, 5]),
+ ("c1", &vec![7, 8, 8]),
+ );
+ let right =
+ build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5,
6]));
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ )];
+
+ let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+ let expected = [
+ "+----+----+",
+ "| a2 | b1 |",
+ "+----+----+",
+ "| 30 | 6 |",
+ "+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+
+ let left2 = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b1", &vec![4, 5, 5]),
+ ("c1", &vec![7, 8, 8]),
+ );
+ let right2 = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left2.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right2.schema())?) as _,
+ )];
+
+ let (_, batches2) = join_collect(left2, right2, on, RightAnti).await?;
+ let expected2 = [
+ "+----+----+----+",
+ "| a2 | b1 | c2 |",
+ "+----+----+----+",
+ "| 30 | 6 | 90 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected2, &batches2);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_anti_two_two() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b1", &vec![4, 5, 5]),
+ ("c1", &vec![7, 8, 8]),
+ );
+ let right =
+ build_table_two_cols(("a2", &vec![10, 20, 30]), ("b1", &vec![4, 5,
6]));
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+
+ let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+ let expected = [
+ "+----+----+",
+ "| a2 | b1 |",
+ "+----+----+",
+ "| 10 | 4 |",
+ "| 20 | 5 |",
+ "| 30 | 6 |",
+ "+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+
+ let left = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b1", &vec![4, 5, 5]),
+ ("c1", &vec![7, 8, 8]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+
+ let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+ let expected = [
+ "+----+----+----+",
+ "| a2 | b1 | c2 |",
+ "+----+----+----+",
+ "| 10 | 4 | 70 |",
+ "| 20 | 5 | 80 |",
+ "| 30 | 6 | 90 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_anti_two_with_filter() -> Result<()> {
+ let left = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c1",
&vec![30]));
+ let right = build_table(("a1", &vec![1]), ("b1", &vec![10]), ("c2",
&vec![20]));
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+ let filter = JoinFilter::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("c2", 1)),
+ Operator::Gt,
+ Arc::new(Column::new("c1", 0)),
+ )),
+ vec![
+ ColumnIndex {
+ index: 2,
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 2,
+ side: JoinSide::Right,
+ },
+ ],
+ Schema::new(vec![
+ Field::new("c1", DataType::Int32, true),
+ Field::new("c2", DataType::Int32, true),
+ ]),
+ );
+ let (_, batches) =
+ join_collect_with_filter(left, right, on, filter,
RightAnti).await?;
+ let expected = [
+ "+----+----+----+",
+ "| a1 | b1 | c2 |",
+ "+----+----+----+",
+ "| 1 | 10 | 20 |",
+ "+----+----+----+",
+ ];
+ assert_batches_eq!(expected, &batches);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_anti_with_nulls() -> Result<()> {
+ let left = build_table_i32_nullable(
+ ("a1", &vec![Some(0), Some(1), Some(2), Some(2), Some(3)]),
+ ("b1", &vec![Some(3), Some(4), Some(5), None, Some(6)]),
+ ("c2", &vec![Some(60), None, Some(80), Some(85), Some(90)]),
+ );
+ let right = build_table_i32_nullable(
+ ("a1", &vec![Some(1), Some(2), Some(2), Some(3)]),
+ ("b1", &vec![Some(4), Some(5), None, Some(6)]), // null in key
field
+ ("c2", &vec![Some(7), Some(8), Some(8), None]), // null in non-key
field
+ );
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+
+ let (_, batches) = join_collect(left, right, on, RightAnti).await?;
+ let expected = [
+ "+----+----+----+",
+ "| a1 | b1 | c2 |",
+ "+----+----+----+",
+ "| 2 | | 8 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_right_anti_with_nulls_with_options() -> Result<()> {
+ let left = build_table_i32_nullable(
+ ("a1", &vec![Some(1), Some(2), Some(1), Some(0), Some(2)]),
+ ("b1", &vec![Some(4), Some(5), Some(5), None, Some(5)]),
+ ("c1", &vec![Some(7), Some(8), Some(8), Some(60), None]),
+ );
+ let right = build_table_i32_nullable(
+ ("a1", &vec![Some(3), Some(2), Some(2), Some(1)]),
+ ("b1", &vec![None, Some(5), Some(5), Some(4)]), // null in key
field
+ ("c2", &vec![Some(9), None, Some(8), Some(7)]), // null in non-key
field
+ );
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+
+ let (_, batches) = join_collect_with_options(
+ left,
+ right,
+ on,
+ RightAnti,
+ vec![
+ SortOptions {
+ descending: true,
+ nulls_first: false,
+ };
+ 2
+ ],
+ true,
+ )
+ .await?;
+
+ let expected = [
+ "+----+----+----+",
+ "| a1 | b1 | c2 |",
+ "+----+----+----+",
+ "| 3 | | 9 |",
+ "| 2 | 5 | |",
+ "| 2 | 5 | 8 |",
+ "+----+----+----+",
+ ];
+ // The output order is important as SMJ preserves sortedness
+ assert_batches_eq!(expected, &batches);
+ Ok(())
+ }
+
+ /// With a batch size of two, a `RightAnti` join whose keys never match
+ /// emits every right-side row, split across two output batches.
+ #[tokio::test]
+ async fn join_right_anti_output_two_batches() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 2]),
+ ("b1", &vec![4, 5, 5]),
+ ("c1", &vec![7, 8, 8]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+ ),
+ ];
+
+ // Exercise `RightAnti`, the join type this test is named for; the output
+ // therefore carries the right input's columns (a2, b1, c2).
+ let (_, batches) =
+ join_collect_batch_size_equals_two(left, right, on, RightAnti).await?;
+ let expected = [
+ "+----+----+----+",
+ "| a2 | b1 | c2 |",
+ "+----+----+----+",
+ "| 10 | 4 | 70 |",
+ "| 20 | 5 | 80 |",
+ "| 30 | 6 | 90 |",
+ "+----+----+----+",
+ ];
+ // Three output rows with a batch size of two: one full batch, one remainder.
+ assert_eq!(batches.len(), 2);
+ assert_eq!(batches[0].num_rows(), 2);
+ assert_eq!(batches[1].num_rows(), 1);
+ assert_batches_eq!(expected, &batches);
+ Ok(())
+ }
+
#[tokio::test]
async fn join_semi() -> Result<()> {
let left = build_table(
@@ -4138,174 +4524,176 @@ mod tests {
}
#[tokio::test]
- async fn test_left_anti_join_filtered_mask() -> Result<()> {
- let mut joined_batches = build_joined_record_batches()?;
- let schema = joined_batches.batches.first().unwrap().schema();
-
- let output = concat_batches(&schema, &joined_batches.batches)?;
- let out_mask = joined_batches.filter_mask.finish();
- let out_indices = joined_batches.row_indices.finish();
+ async fn test_anti_join_filtered_mask() -> Result<()> {
+ for join_type in [LeftAnti, RightAnti] {
+ let mut joined_batches = build_joined_record_batches()?;
+ let schema = joined_batches.batches.first().unwrap().schema();
+
+ let output = concat_batches(&schema, &joined_batches.batches)?;
+ let out_mask = joined_batches.filter_mask.finish();
+ let out_indices = joined_batches.row_indices.finish();
+
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0]),
+ &[0usize],
+ &BooleanArray::from(vec![true]),
+ 1
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0]),
- &[0usize],
- &BooleanArray::from(vec![true]),
- 1
- )
- .unwrap(),
- BooleanArray::from(vec![None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0]),
+ &[0usize],
+ &BooleanArray::from(vec![false]),
+ 1
+ )
+ .unwrap(),
+ BooleanArray::from(vec![Some(true)])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0]),
- &[0usize],
- &BooleanArray::from(vec![false]),
- 1
- )
- .unwrap(),
- BooleanArray::from(vec![Some(true)])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0]),
+ &[0usize; 2],
+ &BooleanArray::from(vec![true, true]),
+ 2
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0]),
- &[0usize; 2],
- &BooleanArray::from(vec![true, true]),
- 2
- )
- .unwrap(),
- BooleanArray::from(vec![None, None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0, 0]),
+ &[0usize; 3],
+ &BooleanArray::from(vec![true, true, true]),
+ 3
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None, None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0, 0]),
- &[0usize; 3],
- &BooleanArray::from(vec![true, true, true]),
- 3
- )
- .unwrap(),
- BooleanArray::from(vec![None, None, None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0, 0]),
+ &[0usize; 3],
+ &BooleanArray::from(vec![true, false, true]),
+ 3
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None, None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0, 0]),
- &[0usize; 3],
- &BooleanArray::from(vec![true, false, true]),
- 3
- )
- .unwrap(),
- BooleanArray::from(vec![None, None, None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0, 0]),
+ &[0usize; 3],
+ &BooleanArray::from(vec![false, false, true]),
+ 3
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None, None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0, 0]),
- &[0usize; 3],
- &BooleanArray::from(vec![false, false, true]),
- 3
- )
- .unwrap(),
- BooleanArray::from(vec![None, None, None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0, 0]),
+ &[0usize; 3],
+ &BooleanArray::from(vec![false, true, true]),
+ 3
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None, None])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0, 0]),
- &[0usize; 3],
- &BooleanArray::from(vec![false, true, true]),
- 3
- )
- .unwrap(),
- BooleanArray::from(vec![None, None, None])
- );
+ assert_eq!(
+ get_corrected_filter_mask(
+ join_type,
+ &UInt64Array::from(vec![0, 0, 0]),
+ &[0usize; 3],
+ &BooleanArray::from(vec![false, false, false]),
+ 3
+ )
+ .unwrap(),
+ BooleanArray::from(vec![None, None, Some(true)])
+ );
- assert_eq!(
- get_corrected_filter_mask(
- LeftAnti,
- &UInt64Array::from(vec![0, 0, 0]),
- &[0usize; 3],
- &BooleanArray::from(vec![false, false, false]),
- 3
+ let corrected_mask = get_corrected_filter_mask(
+ join_type,
+ &out_indices,
+ &joined_batches.batch_ids,
+ &out_mask,
+ output.num_rows(),
)
- .unwrap(),
- BooleanArray::from(vec![None, None, Some(true)])
- );
-
- let corrected_mask = get_corrected_filter_mask(
- LeftAnti,
- &out_indices,
- &joined_batches.batch_ids,
- &out_mask,
- output.num_rows(),
- )
- .unwrap();
-
- assert_eq!(
- corrected_mask,
- BooleanArray::from(vec![
- None,
- None,
- None,
- None,
- None,
- Some(true),
- None,
- Some(true)
- ])
- );
-
- let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
+ .unwrap();
+
+ assert_eq!(
+ corrected_mask,
+ BooleanArray::from(vec![
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(true),
+ None,
+ Some(true)
+ ])
+ );
- assert_batches_eq!(
- &[
- "+---+----+---+----+",
- "| a | b | x | y |",
- "+---+----+---+----+",
- "| 1 | 13 | 1 | 12 |",
- "| 1 | 14 | 1 | 11 |",
- "+---+----+---+----+",
- ],
- &[filtered_rb]
- );
+ let filtered_rb = filter_record_batch(&output, &corrected_mask)?;
+
+ assert_batches_eq!(
+ &[
+ "+---+----+---+----+",
+ "| a | b | x | y |",
+ "+---+----+---+----+",
+ "| 1 | 13 | 1 | 12 |",
+ "| 1 | 14 | 1 | 11 |",
+ "+---+----+---+----+",
+ ],
+ &[filtered_rb]
+ );
- // output null rows
- let null_mask = arrow::compute::not(&corrected_mask)?;
- assert_eq!(
- null_mask,
- BooleanArray::from(vec![
- None,
- None,
- None,
- None,
- None,
- Some(false),
- None,
- Some(false),
- ])
- );
+ // output null rows
+ let null_mask = arrow::compute::not(&corrected_mask)?;
+ assert_eq!(
+ null_mask,
+ BooleanArray::from(vec![
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(false),
+ None,
+ Some(false),
+ ])
+ );
- let null_joined_batch = filter_record_batch(&output, &null_mask)?;
+ let null_joined_batch = filter_record_batch(&output, &null_mask)?;
- assert_batches_eq!(
- &[
- "+---+---+---+---+",
- "| a | b | x | y |",
- "+---+---+---+---+",
- "+---+---+---+---+",
- ],
- &[null_joined_batch]
- );
+ assert_batches_eq!(
+ &[
+ "+---+---+---+---+",
+ "| a | b | x | y |",
+ "+---+---+---+---+",
+ "+---+---+---+---+",
+ ],
+ &[null_joined_batch]
+ );
+ }
Ok(())
}
diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs
index 90ec9b1068..b7bbfd1169 100644
--- a/datafusion/physical-plan/src/test.rs
+++ b/datafusion/physical-plan/src/test.rs
@@ -88,6 +88,26 @@ pub fn build_table_i32(
.unwrap()
}
+/// Returns record batch with 2 columns of i32 in memory
+pub fn build_table_i32_two_cols(
+ a: (&str, &Vec<i32>),
+ b: (&str, &Vec<i32>),
+) -> RecordBatch {
+ let schema = Schema::new(vec![
+ Field::new(a.0, DataType::Int32, false),
+ Field::new(b.0, DataType::Int32, false),
+ ]);
+
+ RecordBatch::try_new(
+ Arc::new(schema),
+ vec![
+ Arc::new(Int32Array::from(a.1.clone())),
+ Arc::new(Int32Array::from(b.1.clone())),
+ ],
+ )
+ .unwrap()
+}
+
/// Returns memory table scan wrapped around record batch with 3 columns of i32
pub fn build_table_scan_i32(
a: (&str, &Vec<i32>),
diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt
index 9a20e7987f..1df52dd1eb 100644
--- a/datafusion/sqllogictest/test_files/sort_merge_join.slt
+++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt
@@ -647,6 +647,54 @@ NULL NULL 7 9
NULL NULL 8 10
NULL NULL 9 11
+query II
+select * from (
+with
+t1 as (
+ select 31 a, 32 b union all
+ select 31 a, 33 b
+),
+t2 as (
+ select 31 a, 32 b union all
+ select 31 a, 35 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+31 35
+
+query II
+select * from (
+with
+t1 as (
+ select 41 a, 42 b union all
+ select 41 a, 43 b
+),
+t2 as (
+ select 41 a, 42 b union all
+ select 41 a, 45 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+41 45
+
+query II
+select * from (
+with
+t1 as (
+ select 51 a, 52 b union all
+ select 51 a, 53 b
+),
+t2 as (
+ select 51 a, 52 b union all
+ select 51 a, 54 b
+)
+select t2.* from t1 right anti join t2 on t1.a = t2.a and t1.b = t2.b
+) order by 1, 2;
+----
+51 54
+
# return sql params back to default values
statement ok
set datafusion.optimizer.prefer_hash_join = true;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]