This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e9f9a239ae Minor: Add routine to debug join fuzz tests (#10970)
e9f9a239ae is described below

commit e9f9a239ae9467850b7d17c42f0f11555a7d3058
Author: Oleks V <[email protected]>
AuthorDate: Tue Jun 18 10:25:53 2024 -0700

    Minor: Add routine to debug join fuzz tests (#10970)
    
    * Minor: Add routine to debug join fuzz tests
---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs | 202 +++++++++++++++++++++-----
 1 file changed, 162 insertions(+), 40 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index 516749e82a..5fdf020794 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::prelude::{SessionConfig, SessionContext};
 use test_utils::stagger_batch_with_seed;
 
+// Determines what Fuzz tests needs to run
+// Ideally all tests should match, but in reality some tests
+// passes only partial cases
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum JoinTestType {
+    // compare NestedLoopJoin and HashJoin
+    NljHj,
+    // compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin
+    // because if existing variants both passed that means SortMergeJoin and NestedLoopJoin also passes
+    HjSmj,
+}
 #[tokio::test]
 async fn test_inner_join_1k() {
     JoinFuzzTestCase::new(
@@ -51,7 +62,7 @@ async fn test_inner_join_1k() {
         JoinType::Inner,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> Joi
     JoinFilter::new(less_than_100, column_indices, intermediate_schema)
 }
 
+fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
+    let less_than_100 = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("x", 1)),
+        Operator::Lt,
+        Arc::new(Column::new("x", 0)),
+    )) as _;
+    let column_indices = vec![
+        ColumnIndex {
+            index: 2,
+            side: JoinSide::Left,
+        },
+        ColumnIndex {
+            index: 2,
+            side: JoinSide::Right,
+        },
+    ];
+    let intermediate_schema = Schema::new(vec![
+        schema1.field_with_name("x").unwrap().to_owned(),
+        schema2.field_with_name("x").unwrap().to_owned(),
+    ]);
+
+    JoinFilter::new(less_than_100, column_indices, intermediate_schema)
+}
+
 #[tokio::test]
 async fn test_inner_join_1k_filtered() {
     JoinFuzzTestCase::new(
@@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
         JoinType::Inner,
         Some(Box::new(less_than_100_join_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
         JoinType::Inner,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -103,7 +138,7 @@ async fn test_left_join_1k() {
         JoinType::Left,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() {
         JoinType::Left,
         Some(Box::new(less_than_100_join_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -127,7 +162,7 @@ async fn test_right_join_1k() {
         JoinType::Right,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 // Add support for Right filtered joins
@@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() {
         JoinType::Right,
         Some(Box::new(less_than_100_join_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -152,7 +187,7 @@ async fn test_full_join_1k() {
         JoinType::Full,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() {
         JoinType::Full,
         Some(Box::new(less_than_100_join_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -176,12 +211,13 @@ async fn test_semi_join_1k() {
         JoinType::LeftSemi,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
 // The test is flaky
 // https://github.com/apache/datafusion/issues/10886
+// SMJ produces 1 more row in the output
 #[ignore]
 #[tokio::test]
 async fn test_semi_join_1k_filtered() {
@@ -189,9 +225,9 @@ async fn test_semi_join_1k_filtered() {
         make_staggered_batches(1000),
         make_staggered_batches(1000),
         JoinType::LeftSemi,
-        Some(Box::new(less_than_100_join_filter)),
+        Some(Box::new(col_lt_col_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj], false)
     .await
 }
 
@@ -203,7 +239,7 @@ async fn test_anti_join_1k() {
         JoinType::LeftAnti,
         None,
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
         JoinType::LeftAnti,
         Some(Box::new(less_than_100_join_filter)),
     )
-    .run_test()
+    .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
     .await
 }
 
@@ -331,7 +367,7 @@ impl JoinFuzzTestCase {
                 self.on_columns().clone(),
                 self.join_filter(),
                 self.join_type,
-                vec![SortOptions::default(), SortOptions::default()],
+                vec![SortOptions::default(); self.on_columns().len()],
                 false,
             )
             .unwrap(),
@@ -381,9 +417,11 @@ impl JoinFuzzTestCase {
         )
     }
 
-    /// Perform sort-merge join and hash join on same input
-    /// and verify two outputs are equal
-    async fn run_test(&self) {
+    /// Perform joins tests on same inputs and verify outputs are equal
+    /// `join_tests` - identifies what join types to test
+    /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders,
+    /// so it is easy to debug a test on top of the failed data
+    async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) {
         for batch_size in self.batch_sizes {
             let session_config = SessionConfig::new().with_batch_size(*batch_size);
             let ctx = SessionContext::new_with_config(session_config);
@@ -394,17 +432,30 @@ impl JoinFuzzTestCase {
             let hj = self.hash_join();
             let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();
 
+            let nlj = self.nested_loop_join();
+            let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+
             // Get actual row counts(without formatting overhead) for HJ and SMJ
             let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
             let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
+            let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
 
-            assert_eq!(
-                hj_rows, smj_rows,
-                "SortMergeJoinExec and HashJoinExec produced different row counts"
-            );
+            if debug {
+                println!("The debug is ON. Input data will be saved");
+                let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}");
+                Self::save_as_parquet(&self.input1, out_dir_name, "input1");
+                Self::save_as_parquet(&self.input2, out_dir_name, "input2");
 
-            let nlj = self.nested_loop_join();
-            let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
+                if join_tests.contains(&JoinTestType::NljHj) {
+                    Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj");
+                    Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
+                }
+
+                if join_tests.contains(&JoinTestType::HjSmj) {
+                    Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
+                    Self::save_as_parquet(&smj_collected, out_dir_name, "smj");
+                }
+            }
 
             // compare
             let smj_formatted =
@@ -425,35 +476,106 @@ impl JoinFuzzTestCase {
                 nlj_formatted.trim().lines().collect();
             nlj_formatted_sorted.sort_unstable();
 
-            // row level compare if any of joins returns the result
-            // the reason is different formatting when there is no rows
-            if smj_rows > 0 || hj_rows > 0 {
-                for (i, (smj_line, hj_line)) in smj_formatted_sorted
+            if join_tests.contains(&JoinTestType::NljHj) {
+                let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
+                assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());
+
+                let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size);
+                // row level compare if any of joins returns the result
+                // the reason is different formatting when there is no rows
+                for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
                     .iter()
                     .zip(&hj_formatted_sorted)
                     .enumerate()
                 {
                     assert_eq!(
-                        (i, smj_line),
+                        (i, nlj_line),
                         (i, hj_line),
-                        "SortMergeJoinExec and HashJoinExec produced different results"
+                        "{}",
+                        err_msg_contents.as_str()
                     );
                 }
             }
 
-            for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
-                .iter()
-                .zip(&hj_formatted_sorted)
-                .enumerate()
-            {
-                assert_eq!(
-                    (i, nlj_line),
-                    (i, hj_line),
-                    "NestedLoopJoinExec and HashJoinExec produced different results"
-                );
+            if join_tests.contains(&JoinTestType::HjSmj) {
+                let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
+                assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());
+
+                let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);
+                // row level compare if any of joins returns the result
+                // the reason is different formatting when there is no rows
+                if smj_rows > 0 || hj_rows > 0 {
+                    for (i, (smj_line, hj_line)) in smj_formatted_sorted
+                        .iter()
+                        .zip(&hj_formatted_sorted)
+                        .enumerate()
+                    {
+                        assert_eq!(
+                            (i, smj_line),
+                            (i, hj_line),
+                            "{}",
+                            err_msg_contents.as_str()
+                        );
+                    }
+                }
             }
         }
     }
+
+    /// This method useful for debugging fuzz tests
+    /// It helps to save randomly generated input test data for both join inputs into the user folder
+    /// as a parquet files preserving partitioning.
+    /// Once the data is saved it is possible to run a custom test on top of the saved data and debug
+    ///
+    ///     let ctx: SessionContext = SessionContext::new();
+    ///     let df = ctx
+    ///         .read_parquet(
+    ///             "/tmp/input1/*.parquet",
+    ///             ParquetReadOptions::default(),
+    ///         )
+    ///         .await
+    ///         .unwrap();
+    ///     let left = df.collect().await.unwrap();
+    ///
+    ///     let df = ctx
+    ///         .read_parquet(
+    ///             "/tmp/input2/*.parquet",
+    ///             ParquetReadOptions::default(),
+    ///         )
+    ///         .await
+    ///         .unwrap();
+    ///
+    ///     let right = df.collect().await.unwrap();
+    ///         JoinFuzzTestCase::new(
+    ///             left,
+    ///             right,
+    ///             JoinType::LeftSemi,
+    ///             Some(Box::new(less_than_100_join_filter)),
+    ///         )
+    ///         .run_test()
+    ///         .await
+    /// }
+    fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) {
+        let out_path = &format!("{output_dir}/{out_name}");
+        std::fs::remove_dir_all(out_path).unwrap_or(());
+        std::fs::create_dir_all(out_path).unwrap();
+
+        input.iter().enumerate().for_each(|(idx, batch)| {
+            let mut file =
+                std::fs::File::create(format!("{out_path}/file_{}.parquet", idx))
+                    .unwrap();
+            let mut writer = parquet::arrow::ArrowWriter::try_new(
+                &mut file,
+                input.first().unwrap().schema(),
+                None,
+            )
+            .expect("creating writer");
+            writer.write(batch).unwrap();
+            writer.close().unwrap();
+        });
+
+        println!("The data {out_name} saved as parquet into {out_path}");
+    }
 }
 
 /// Return randomly sized record batches with:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to