This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a758270730 fix: skip EliminateCrossJoin rule if inner join with filter is found (#7529)
a758270730 is described below

commit a758270730f6a66319987363202fcdf1e6fba1c6
Author: epsio-banay <[email protected]>
AuthorDate: Wed Sep 13 19:57:42 2023 +0300

    fix: skip EliminateCrossJoin rule if inner join with filter is found (#7529)
    
    * Skip EliminateCrossJoin rule if inner join with filter is found - check recursively
    
    * Add test eliminate_cross_not_possible_nested_inner_join_with_filter
---
 datafusion/optimizer/src/eliminate_cross_join.rs | 81 +++++++++++++++++-------
 1 file changed, 59 insertions(+), 22 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index ec4d8a2cbf..d4832d674e 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -60,30 +60,26 @@ impl OptimizerRule for EliminateCrossJoin {
 
                 let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
                 let mut all_inputs: Vec<LogicalPlan> = vec![];
-                match &input {
+                let did_flat_successfully = match &input {
                     LogicalPlan::Join(join) if (join.join_type == 
JoinType::Inner) => {
-                        // The filter of inner join will lost, skip this rule.
-                        // issue: 
https://github.com/apache/arrow-datafusion/issues/4844
-                        if join.filter.is_some() {
-                            return Ok(None);
-                        }
-
-                        flatten_join_inputs(
+                        try_flatten_join_inputs(
                             &input,
                             &mut possible_join_keys,
                             &mut all_inputs,
-                        )?;
-                    }
-                    LogicalPlan::CrossJoin(_) => {
-                        flatten_join_inputs(
-                            &input,
-                            &mut possible_join_keys,
-                            &mut all_inputs,
-                        )?;
+                        )?
                     }
+                    LogicalPlan::CrossJoin(_) => try_flatten_join_inputs(
+                        &input,
+                        &mut possible_join_keys,
+                        &mut all_inputs,
+                    )?,
                     _ => {
                         return utils::optimize_children(self, plan, config);
                     }
+                };
+
+                if !did_flat_successfully {
+                    return Ok(None);
                 }
 
                 let predicate = &filter.predicate;
@@ -137,13 +133,20 @@ impl OptimizerRule for EliminateCrossJoin {
     }
 }
 
-fn flatten_join_inputs(
+/// Recursively accumulate possible_join_keys and inputs from inner joins 
(including cross joins).
+/// Returns a boolean indicating whether the flattening was successful.
+fn try_flatten_join_inputs(
     plan: &LogicalPlan,
     possible_join_keys: &mut Vec<(Expr, Expr)>,
     all_inputs: &mut Vec<LogicalPlan>,
-) -> Result<()> {
+) -> Result<bool> {
     let children = match plan {
-        LogicalPlan::Join(join) => {
+        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
+            if join.filter.is_some() {
+                // The filter of inner join will lost, skip this rule.
+                // issue: 
https://github.com/apache/arrow-datafusion/issues/4844
+                return Ok(false);
+            }
             possible_join_keys.extend(join.on.clone());
             let left = &*(join.left);
             let right = &*(join.right);
@@ -163,18 +166,22 @@ fn flatten_join_inputs(
         match *child {
             LogicalPlan::Join(left_join) => {
                 if left_join.join_type == JoinType::Inner {
-                    flatten_join_inputs(child, possible_join_keys, 
all_inputs)?;
+                    if !try_flatten_join_inputs(child, possible_join_keys, 
all_inputs)? {
+                        return Ok(false);
+                    }
                 } else {
                     all_inputs.push((*child).clone());
                 }
             }
             LogicalPlan::CrossJoin(_) => {
-                flatten_join_inputs(child, possible_join_keys, all_inputs)?;
+                if !try_flatten_join_inputs(child, possible_join_keys, 
all_inputs)? {
+                    return Ok(false);
+                }
             }
             _ => all_inputs.push((*child).clone()),
         }
     }
-    Ok(())
+    Ok(true)
 }
 
 fn find_inner_join(
@@ -363,6 +370,12 @@ mod tests {
         assert_eq!(plan.schema(), optimized_plan.schema())
     }
 
+    fn assert_optimization_rule_fails(plan: &LogicalPlan) {
+        let rule = EliminateCrossJoin::new();
+        let optimized_plan = rule.try_optimize(plan, 
&OptimizerContext::new()).unwrap();
+        assert!(optimized_plan.is_none()); // None: the rule returned no rewritten plan
+    }
+
     #[test]
     fn eliminate_cross_with_simple_and() -> Result<()> {
         let t1 = test_table_scan_with_name("t1")?;
@@ -531,6 +544,30 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    /// Regression test: a filtered inner join nested under another inner join must make the rule bail out. See https://github.com/apache/arrow-datafusion/issues/7530
+    fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> 
Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+        let t3 = test_table_scan_with_name("t3")?;
+
+        // the nested t1/t3 inner join has a join filter (t1.a > 20), so the rule must not rewrite this plan
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t3,
+                JoinType::Inner,
+                (vec!["t1.a"], vec!["t3.a"]),
+                Some(col("t1.a").gt(lit(20u32))),
+            )?
+            .join(t2, JoinType::Inner, (vec!["t1.a"], vec!["t2.a"]), None)?
+            .filter(col("t1.a").gt(lit(15u32)))?
+            .build()?;
+
+        assert_optimization_rule_fails(&plan);
+
+        Ok(())
+    }
+
     #[test]
     /// ```txt
     /// filter: a.id = b.id and a.id = c.id

Reply via email to