This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 97ea05c0f6 Extending join fuzz tests to support join filtering (#10728)
97ea05c0f6 is described below
commit 97ea05c0f60aa11a270420968d3fefc859c0d346
Author: Edmondo Porcu <[email protected]>
AuthorDate: Tue Jun 11 22:19:55 2024 -0400
Extending join fuzz tests to support join filtering (#10728)
* Extending join fuzz tests to support join filtering
---------
Co-authored-by: Oleks V <[email protected]>
---
datafusion/core/tests/fuzz_cases/join_fuzz.rs | 407 +++++++++++++++++++-------
1 file changed, 296 insertions(+), 111 deletions(-)
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 824f1eec4a..8c2e24de56 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -22,6 +22,11 @@ use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::Schema;
+
+use datafusion_common::ScalarValue;
+use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::PhysicalExprRef;
+
use rand::Rng;
use datafusion::common::JoinSide;
@@ -40,92 +45,207 @@ use test_utils::stagger_batch_with_seed;
#[tokio::test]
async fn test_inner_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::Inner,
+ None,
+ )
+ .run_test()
+ .await
+}
+
+fn less_than_10_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) ->
JoinFilter {
+ let less_than_100 = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Lt,
+ Arc::new(Literal::new(ScalarValue::from(100))),
+ )) as _;
+ let column_indices = vec![ColumnIndex {
+ index: 0,
+ side: JoinSide::Left,
+ }];
+ let intermediate_schema =
+ Schema::new(vec![schema1.field_with_name("a").unwrap().to_owned()]);
+
+ JoinFilter::new(less_than_100, column_indices, intermediate_schema)
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_filtered() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::Inner,
+ Some(Box::new(less_than_10_join_filter)),
+ )
+ .run_test()
+ .await
+}
+
+#[tokio::test]
+async fn test_inner_join_1k_smjoin() {
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Inner,
+ None,
)
+ .run_test()
.await
}
#[tokio::test]
async fn test_left_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::Left,
+ None,
+ )
+ .run_test()
+ .await
+}
+
+#[tokio::test]
+async fn test_left_join_1k_filtered() {
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Left,
+ Some(Box::new(less_than_10_join_filter)),
)
+ .run_test()
.await
}
#[tokio::test]
async fn test_right_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::Right,
+ None,
+ )
+ .run_test()
+ .await
+}
+// Add support for Right filtered joins
+#[ignore]
+#[tokio::test]
+async fn test_right_join_1k_filtered() {
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Right,
+ Some(Box::new(less_than_10_join_filter)),
)
+ .run_test()
.await
}
#[tokio::test]
async fn test_full_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::Full,
+ None,
)
+ .run_test()
+ .await
+}
+
+#[tokio::test]
+async fn test_full_join_1k_filtered() {
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::Full,
+ Some(Box::new(less_than_10_join_filter)),
+ )
+ .run_test()
.await
}
#[tokio::test]
async fn test_semi_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::LeftSemi,
+ None,
+ )
+ .run_test()
+ .await
+}
+
+#[tokio::test]
+async fn test_semi_join_1k_filtered() {
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftSemi,
+ Some(Box::new(less_than_10_join_filter)),
)
+ .run_test()
.await
}
#[tokio::test]
async fn test_anti_join_1k() {
- run_join_test(
+ JoinFuzzTestCase::new(
+ make_staggered_batches(1000),
+ make_staggered_batches(1000),
+ JoinType::LeftAnti,
+ None,
+ )
+ .run_test()
+ .await
+}
+
+// Test failed for now. https://github.com/apache/datafusion/issues/10872
+#[ignore]
+#[tokio::test]
+async fn test_anti_join_1k_filtered() {
+ JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftAnti,
+ Some(Box::new(less_than_10_join_filter)),
)
+ .run_test()
.await
}
-/// Perform sort-merge join and hash join on same input
-/// and verify two outputs are equal
-async fn run_join_test(
+type JoinFilterBuilder = Box<dyn Fn(Arc<Schema>, Arc<Schema>) -> JoinFilter>;
+
+struct JoinFuzzTestCase {
+ batch_sizes: &'static [usize],
input1: Vec<RecordBatch>,
input2: Vec<RecordBatch>,
join_type: JoinType,
-) {
- let batch_sizes = [1, 2, 7, 49, 50, 51, 100];
- for batch_size in batch_sizes {
- let session_config = SessionConfig::new().with_batch_size(batch_size);
- let ctx = SessionContext::new_with_config(session_config);
- let task_ctx = ctx.task_ctx();
-
- let schema1 = input1[0].schema();
- let schema2 = input2[0].schema();
- let on_columns = vec![
- (
- Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
- Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
- ),
- (
- Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
- Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
- ),
- ];
+ join_filter_builder: Option<JoinFilterBuilder>,
+}
- // Nested loop join uses filter for joining records
- let column_indices = vec![
+impl JoinFuzzTestCase {
+ fn new(
+ input1: Vec<RecordBatch>,
+ input2: Vec<RecordBatch>,
+ join_type: JoinType,
+ join_filter_builder: Option<JoinFilterBuilder>,
+ ) -> Self {
+ Self {
+ batch_sizes: &[1, 2, 7, 49, 50, 51, 100],
+ input1,
+ input2,
+ join_type,
+ join_filter_builder,
+ }
+ }
+
+ fn column_indices(&self) -> Vec<ColumnIndex> {
+ vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
@@ -142,120 +262,185 @@ async fn run_join_test(
index: 1,
side: JoinSide::Right,
},
- ];
- let intermediate_schema = Schema::new(vec![
- schema1.field_with_name("a").unwrap().to_owned(),
- schema1.field_with_name("b").unwrap().to_owned(),
- schema2.field_with_name("a").unwrap().to_owned(),
- schema2.field_with_name("b").unwrap().to_owned(),
- ]);
+ ]
+ }
- let equal_a = Arc::new(BinaryExpr::new(
- Arc::new(Column::new("a", 0)),
- Operator::Eq,
- Arc::new(Column::new("a", 2)),
- )) as _;
- let equal_b = Arc::new(BinaryExpr::new(
- Arc::new(Column::new("b", 1)),
- Operator::Eq,
- Arc::new(Column::new("b", 3)),
- )) as _;
- let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And,
equal_b)) as _;
+ fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
+ let schema1 = self.input1[0].schema();
+ let schema2 = self.input2[0].schema();
+ vec![
+ (
+ Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+ Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
+ ),
+ ]
+ }
- let on_filter = JoinFilter::new(expression, column_indices,
intermediate_schema);
+ fn intermediate_schema(&self) -> Schema {
+ let schema1 = self.input1[0].schema();
+ let schema2 = self.input2[0].schema();
+ Schema::new(vec![
+ schema1
+ .field_with_name("a")
+ .unwrap()
+ .to_owned()
+ .with_nullable(true),
+ schema1
+ .field_with_name("b")
+ .unwrap()
+ .to_owned()
+ .with_nullable(true),
+ schema2.field_with_name("a").unwrap().to_owned(),
+ schema2.field_with_name("b").unwrap().to_owned(),
+ ])
+ }
- // sort-merge join
+ fn left_right(&self) -> (Arc<MemoryExec>, Arc<MemoryExec>) {
+ let schema1 = self.input1[0].schema();
+ let schema2 = self.input2[0].schema();
let left = Arc::new(
- MemoryExec::try_new(&[input1.clone()], schema1.clone(),
None).unwrap(),
+ MemoryExec::try_new(&[self.input1.clone()], schema1.clone(),
None).unwrap(),
);
let right = Arc::new(
- MemoryExec::try_new(&[input2.clone()], schema2.clone(),
None).unwrap(),
+ MemoryExec::try_new(&[self.input2.clone()], schema2.clone(),
None).unwrap(),
);
- let smj = Arc::new(
+ (left, right)
+ }
+
+ fn join_filter(&self) -> Option<JoinFilter> {
+ let schema1 = self.input1[0].schema();
+ let schema2 = self.input2[0].schema();
+ self.join_filter_builder
+ .as_ref()
+ .map(|builder| builder(schema1, schema2))
+ }
+
+ fn sort_merge_join(&self) -> Arc<SortMergeJoinExec> {
+ let (left, right) = self.left_right();
+ Arc::new(
SortMergeJoinExec::try_new(
left,
right,
- on_columns.clone(),
- None,
- join_type,
+ self.on_columns().clone(),
+ self.join_filter(),
+ self.join_type,
vec![SortOptions::default(), SortOptions::default()],
false,
)
.unwrap(),
- );
- let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
+ )
+ }
- // hash join
- let left = Arc::new(
- MemoryExec::try_new(&[input1.clone()], schema1.clone(),
None).unwrap(),
- );
- let right = Arc::new(
- MemoryExec::try_new(&[input2.clone()], schema2.clone(),
None).unwrap(),
- );
- let hj = Arc::new(
+ fn hash_join(&self) -> Arc<HashJoinExec> {
+ let (left, right) = self.left_right();
+ Arc::new(
HashJoinExec::try_new(
left,
right,
- on_columns.clone(),
- None,
- &join_type,
+ self.on_columns().clone(),
+ self.join_filter(),
+ &self.join_type,
None,
PartitionMode::Partitioned,
false,
)
.unwrap(),
- );
- let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
+ )
+ }
- // nested loop join
- let left = Arc::new(
- MemoryExec::try_new(&[input1.clone()], schema1.clone(),
None).unwrap(),
- );
- let right = Arc::new(
- MemoryExec::try_new(&[input2.clone()], schema2.clone(),
None).unwrap(),
- );
- let nlj = Arc::new(
- NestedLoopJoinExec::try_new(left, right, Some(on_filter),
&join_type)
+ fn nested_loop_join(&self) -> Arc<NestedLoopJoinExec> {
+ let (left, right) = self.left_right();
+ // Nested loop join uses filter for joining records
+ let column_indices = self.column_indices();
+ let intermediate_schema = self.intermediate_schema();
+
+ let equal_a = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Eq,
+ Arc::new(Column::new("a", 2)),
+ )) as _;
+ let equal_b = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("b", 1)),
+ Operator::Eq,
+ Arc::new(Column::new("b", 3)),
+ )) as _;
+ let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And,
equal_b)) as _;
+
+ let on_filter = JoinFilter::new(expression, column_indices,
intermediate_schema);
+
+ Arc::new(
+ NestedLoopJoinExec::try_new(left, right, Some(on_filter),
&self.join_type)
.unwrap(),
- );
- let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+ )
+ }
- // compare
- let smj_formatted =
pretty_format_batches(&smj_collected).unwrap().to_string();
- let hj_formatted =
pretty_format_batches(&hj_collected).unwrap().to_string();
- let nlj_formatted =
pretty_format_batches(&nlj_collected).unwrap().to_string();
+ /// Perform sort-merge join and hash join on same input
+ /// and verify two outputs are equal
+ async fn run_test(&self) {
+ for batch_size in self.batch_sizes {
+ let session_config =
SessionConfig::new().with_batch_size(*batch_size);
+ let ctx = SessionContext::new_with_config(session_config);
+ let task_ctx = ctx.task_ctx();
+ let smj = self.sort_merge_join();
+ let smj_collected = collect(smj, task_ctx.clone()).await.unwrap();
- let mut smj_formatted_sorted: Vec<&str> =
smj_formatted.trim().lines().collect();
- smj_formatted_sorted.sort_unstable();
+ let hj = self.hash_join();
+ let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
- let mut hj_formatted_sorted: Vec<&str> =
hj_formatted.trim().lines().collect();
- hj_formatted_sorted.sort_unstable();
+ let nlj = self.nested_loop_join();
+ let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
- let mut nlj_formatted_sorted: Vec<&str> =
nlj_formatted.trim().lines().collect();
- nlj_formatted_sorted.sort_unstable();
+ // compare
+ let smj_formatted =
+ pretty_format_batches(&smj_collected).unwrap().to_string();
+ let hj_formatted =
pretty_format_batches(&hj_collected).unwrap().to_string();
+ let nlj_formatted =
+ pretty_format_batches(&nlj_collected).unwrap().to_string();
- for (i, (smj_line, hj_line)) in smj_formatted_sorted
- .iter()
- .zip(&hj_formatted_sorted)
- .enumerate()
- {
- assert_eq!(
- (i, smj_line),
- (i, hj_line),
- "SortMergeJoinExec and HashJoinExec produced different results"
- );
- }
+ let mut smj_formatted_sorted: Vec<&str> =
+ smj_formatted.trim().lines().collect();
+ smj_formatted_sorted.sort_unstable();
+
+ let mut hj_formatted_sorted: Vec<&str> =
+ hj_formatted.trim().lines().collect();
+ hj_formatted_sorted.sort_unstable();
+
+ let mut nlj_formatted_sorted: Vec<&str> =
+ nlj_formatted.trim().lines().collect();
+ nlj_formatted_sorted.sort_unstable();
- for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
- .iter()
- .zip(&hj_formatted_sorted)
- .enumerate()
- {
assert_eq!(
- (i, nlj_line),
- (i, hj_line),
- "NestedLoopJoinExec and HashJoinExec produced different
results"
+ smj_formatted_sorted.len(),
+ hj_formatted_sorted.len(),
+ "SortMergeJoinExec and HashJoinExec produced different row
counts"
);
+ for (i, (smj_line, hj_line)) in smj_formatted_sorted
+ .iter()
+ .zip(&hj_formatted_sorted)
+ .enumerate()
+ {
+ assert_eq!(
+ (i, smj_line),
+ (i, hj_line),
+ "SortMergeJoinExec and HashJoinExec produced different
results"
+ );
+ }
+
+ for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
+ .iter()
+ .zip(&hj_formatted_sorted)
+ .enumerate()
+ {
+ assert_eq!(
+ (i, nlj_line),
+ (i, hj_line),
+ "NestedLoopJoinExec and HashJoinExec produced different
results"
+ );
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]