This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b978cf8236 Support filter in cross join elimination (#13025)
b978cf8236 is described below

commit b978cf8236436038a106ed94fb0d7eaa6ba99962
Author: Daniël Heres <[email protected]>
AuthorDate: Tue Oct 22 04:57:04 2024 +0200

    Support filter in cross join elimination (#13025)
    
    * Support filter in cross join elimination
    
    * Support filter in cross join elimination
    
    * Support filter in cross join elimination
    
    * Support filter in cross join elimination
---
 datafusion/optimizer/src/eliminate_cross_join.rs | 61 ++++++++++++++----------
 datafusion/sqllogictest/test_files/join.slt      |  2 +-
 2 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index bce5c77ca6..8a365fb389 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -22,13 +22,13 @@ use crate::{OptimizerConfig, OptimizerRule};
 
 use crate::join_key_set::JoinKeySet;
 use datafusion_common::tree_node::{Transformed, TreeNode};
-use datafusion_common::{internal_err, Result};
+use datafusion_common::Result;
 use datafusion_expr::expr::{BinaryExpr, Expr};
 use datafusion_expr::logical_plan::{
     Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
 };
 use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
-use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
+use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator};
 
 #[derive(Default, Debug)]
 pub struct EliminateCrossJoin;
@@ -88,6 +88,7 @@ impl OptimizerRule for EliminateCrossJoin {
         let plan_schema = Arc::clone(plan.schema());
         let mut possible_join_keys = JoinKeySet::new();
         let mut all_inputs: Vec<LogicalPlan> = vec![];
+        let mut all_filters: Vec<Expr> = vec![];
 
         let parent_predicate = if let LogicalPlan::Filter(filter) = plan {
             // if input isn't a join that can potentially be rewritten
@@ -116,6 +117,7 @@ impl OptimizerRule for EliminateCrossJoin {
                 Arc::unwrap_or_clone(input),
                 &mut possible_join_keys,
                 &mut all_inputs,
+                &mut all_filters,
             )?;
 
             extract_possible_join_keys(&predicate, &mut possible_join_keys);
@@ -130,7 +132,12 @@ impl OptimizerRule for EliminateCrossJoin {
             if !can_flatten_join_inputs(&plan) {
                 return Ok(Transformed::no(plan));
             }
-            flatten_join_inputs(plan, &mut possible_join_keys, &mut 
all_inputs)?;
+            flatten_join_inputs(
+                plan,
+                &mut possible_join_keys,
+                &mut all_inputs,
+                &mut all_filters,
+            )?;
             None
         } else {
             // recursively try to rewrite children
@@ -158,6 +165,13 @@ impl OptimizerRule for EliminateCrossJoin {
             ));
         }
 
+        if !all_filters.is_empty() {
+            // Add any filters on top - PushDownFilter can push filters down 
to applicable join
+            let first = all_filters.swap_remove(0);
+            let predicate = all_filters.into_iter().fold(first, and);
+            left = LogicalPlan::Filter(Filter::try_new(predicate, 
Arc::new(left))?);
+        }
+
         let Some(predicate) = parent_predicate else {
             return Ok(Transformed::yes(left));
         };
@@ -206,25 +220,25 @@ fn flatten_join_inputs(
     plan: LogicalPlan,
     possible_join_keys: &mut JoinKeySet,
     all_inputs: &mut Vec<LogicalPlan>,
+    all_filters: &mut Vec<Expr>,
 ) -> Result<()> {
     match plan {
         LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
-            // checked in can_flatten_join_inputs
-            if join.filter.is_some() {
-                return internal_err!(
-                    "should not have filter in inner join in 
flatten_join_inputs"
-                );
+            if let Some(filter) = join.filter {
+                all_filters.push(filter);
             }
             possible_join_keys.insert_all_owned(join.on);
             flatten_join_inputs(
                 Arc::unwrap_or_clone(join.left),
                 possible_join_keys,
                 all_inputs,
+                all_filters,
             )?;
             flatten_join_inputs(
                 Arc::unwrap_or_clone(join.right),
                 possible_join_keys,
                 all_inputs,
+                all_filters,
             )?;
         }
         LogicalPlan::CrossJoin(join) => {
@@ -232,11 +246,13 @@ fn flatten_join_inputs(
                 Arc::unwrap_or_clone(join.left),
                 possible_join_keys,
                 all_inputs,
+                all_filters,
             )?;
             flatten_join_inputs(
                 Arc::unwrap_or_clone(join.right),
                 possible_join_keys,
                 all_inputs,
+                all_filters,
             )?;
         }
         _ => {
@@ -253,13 +269,7 @@ fn flatten_join_inputs(
 fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool {
     // can only flatten inner / cross joins
     match plan {
-        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {
-            // The filter of inner join will lost, skip this rule.
-            // issue: https://github.com/apache/datafusion/issues/4844
-            if join.filter.is_some() {
-                return false;
-            }
-        }
+        LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {}
         LogicalPlan::CrossJoin(_) => {}
         _ => return false,
     };
@@ -467,12 +477,6 @@ mod tests {
         assert_eq!(&starting_schema, optimized_plan.schema())
     }
 
-    fn assert_optimization_rule_fails(plan: LogicalPlan) {
-        let rule = EliminateCrossJoin::new();
-        let transformed_plan = rule.rewrite(plan, 
&OptimizerContext::new()).unwrap();
-        assert!(!transformed_plan.transformed)
-    }
-
     #[test]
     fn eliminate_cross_with_simple_and() -> Result<()> {
         let t1 = test_table_scan_with_name("t1")?;
@@ -642,8 +646,7 @@ mod tests {
     }
 
     #[test]
-    /// See https://github.com/apache/datafusion/issues/7530
-    fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> 
Result<()> {
+    fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> {
         let t1 = test_table_scan_with_name("t1")?;
         let t2 = test_table_scan_with_name("t2")?;
         let t3 = test_table_scan_with_name("t3")?;
@@ -660,7 +663,17 @@ mod tests {
             .filter(col("t1.a").gt(lit(15u32)))?
             .build()?;
 
-        assert_optimization_rule_fails(plan);
+        let expected = vec![
+            "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "  Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "    Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "      Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32]",
+            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", 
+            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"
+        ];
+
+        assert_optimized_plan_eq(plan, expected);
 
         Ok(())
     }
diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt
index fe9ceaa790..39f903a587 100644
--- a/datafusion/sqllogictest/test_files/join.slt
+++ b/datafusion/sqllogictest/test_files/join.slt
@@ -1152,7 +1152,7 @@ logical_plan
 01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1
 02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS 
Float64) > Float64(0)
 03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4
-04)------Inner Join: Using t1.v0 = t5.v0, t1.v1 = t5.v1
+04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1
 05)--------TableScan: t1 projection=[v0, v1]
 06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4]
 07)----TableScan: t0 projection=[v0, v1]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to