This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new df49f9f2ee fix: preserve null_equals_null flag in eliminate_cross_join rule (#16356)
df49f9f2ee is described below

commit df49f9f2ee0c23d94750619673b4ea993cdc02a3
Author: Ruihang Xia <waynest...@gmail.com>
AuthorDate: Thu Jun 12 07:25:28 2025 +0800

    fix: preserve null_equals_null flag in eliminate_cross_join rule (#16356)
    
    Signed-off-by: Ruihang Xia <waynest...@gmail.com>
---
 datafusion/optimizer/src/eliminate_cross_join.rs | 120 +++++++++++++++++++----
 1 file changed, 99 insertions(+), 21 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index d465faf0c5..deefaef2c0 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
         let mut possible_join_keys = JoinKeySet::new();
         let mut all_inputs: Vec<LogicalPlan> = vec![];
         let mut all_filters: Vec<Expr> = vec![];
+        let mut null_equals_null = false;
 
         let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
             // if input isn't a join that can potentially be rewritten
@@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin {
             let Filter {
                 input, predicate, ..
             } = filter;
+
+            // Extract null_equals_null setting from the input join
+            if let LogicalPlan::Join(join) = input.as_ref() {
+                null_equals_null = join.null_equals_null;
+            }
+
             flatten_join_inputs(
                 Arc::unwrap_or_clone(input),
                 &mut possible_join_keys,
@@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin {
 
             extract_possible_join_keys(&predicate, &mut possible_join_keys);
             Some(predicate)
-        } else if matches!(
-            plan,
-            LogicalPlan::Join(Join {
-                join_type: JoinType::Inner,
-                ..
-            })
-        ) {
-            if !can_flatten_join_inputs(&plan) {
-                return Ok(Transformed::no(plan));
-            }
-            flatten_join_inputs(
-                plan,
-                &mut possible_join_keys,
-                &mut all_inputs,
-                &mut all_filters,
-            )?;
-            None
         } else {
-            // recursively try to rewrite children
-            return rewrite_children(self, plan, config);
+            match plan {
+                LogicalPlan::Join(Join {
+                    join_type: JoinType::Inner,
+                    null_equals_null: original_null_equals_null,
+                    ..
+                }) => {
+                    if !can_flatten_join_inputs(&plan) {
+                        return Ok(Transformed::no(plan));
+                    }
+                    flatten_join_inputs(
+                        plan,
+                        &mut possible_join_keys,
+                        &mut all_inputs,
+                        &mut all_filters,
+                    )?;
+                    null_equals_null = original_null_equals_null;
+                    None
+                }
+                _ => {
+                    // recursively try to rewrite children
+                    return rewrite_children(self, plan, config);
+                }
+            }
         };
 
         // Join keys are handled locally:
@@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
                 &mut all_inputs,
                 &possible_join_keys,
                 &mut all_join_keys,
+                null_equals_null,
             )?;
         }
 
@@ -290,6 +302,7 @@ fn find_inner_join(
     rights: &mut Vec<LogicalPlan>,
     possible_join_keys: &JoinKeySet,
     all_join_keys: &mut JoinKeySet,
+    null_equals_null: bool,
 ) -> Result<LogicalPlan> {
     for (i, right_input) in rights.iter().enumerate() {
         let mut join_keys = vec![];
@@ -328,7 +341,7 @@ fn find_inner_join(
                 on: join_keys,
                 filter: None,
                 schema: join_schema,
-                null_equals_null: false,
+                null_equals_null,
             }));
         }
     }
@@ -350,7 +363,7 @@ fn find_inner_join(
         filter: None,
         join_type: JoinType::Inner,
         join_constraint: JoinConstraint::On,
-        null_equals_null: false,
+        null_equals_null,
     }))
 }
 
@@ -1333,4 +1346,69 @@ mod tests {
         "
         )
     }
+
+    #[test]
+    fn preserve_null_equals_null_setting() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        // Create an inner join with null_equals_null: true
+        let join_schema = Arc::new(build_join_schema(
+            t1.schema(),
+            t2.schema(),
+            &JoinType::Inner,
+        )?);
+
+        let inner_join = LogicalPlan::Join(Join {
+            left: Arc::new(t1),
+            right: Arc::new(t2),
+            join_type: JoinType::Inner,
+            join_constraint: JoinConstraint::On,
+            on: vec![],
+            filter: None,
+            schema: join_schema,
+            null_equals_null: true, // Set to true to test preservation
+        });
+
+        // Apply filter that can create join conditions
+        let plan = LogicalPlanBuilder::from(inner_join)
+            .filter(binary_expr(
+                col("t1.a").eq(col("t2.a")),
+                And,
+                col("t2.c").lt(lit(20u32)),
+            ))?
+            .build()?;
+
+        let rule = EliminateCrossJoin::new();
+        let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;
+
+        // Verify that null_equals_null is preserved in the optimized plan
+        fn check_null_equals_null_preserved(plan: &LogicalPlan) -> bool {
+            match plan {
+                LogicalPlan::Join(join) => {
+                    // All joins in the optimized plan should preserve null_equals_null: true
+                    if !join.null_equals_null {
+                        return false;
+                    }
+                    // Recursively check child plans
+                    plan.inputs()
+                        .iter()
+                        .all(|input| check_null_equals_null_preserved(input))
+                }
+                _ => {
+                    // Recursively check child plans for non-join nodes
+                    plan.inputs()
+                        .iter()
+                        .all(|input| check_null_equals_null_preserved(input))
+                }
+            }
+        }
+
+        assert!(
+            check_null_equals_null_preserved(&optimized_plan),
+            "null_equals_null setting should be preserved after optimization"
+        );
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to