alamb commented on code in PR #4185:
URL: https://github.com/apache/arrow-datafusion/pull/4185#discussion_r1028287950


##########
datafusion/optimizer/src/eliminate_cross_join.rs:
##########
@@ -849,14 +955,14 @@ mod tests {
 
         let expected = vec![
             "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < 
UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, 
c:UInt32]",
-            "  Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, 
c:UInt32]",

Review Comment:
   nice



##########
datafusion/optimizer/src/eliminate_cross_join.rs:
##########
@@ -44,143 +44,202 @@ impl ReduceCrossJoin {
     }
 }
 
+/// Attempt to reorder join tp reduce cross joins to inner joins.
+/// for queries:
+/// 'select ... from a, b where a.x = b.y and b.xx = 100;'
+/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and 
b.xx = 200);'
+/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
+/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
+/// For above queries, the join predicate is available in filters and they are 
moved to
+/// join nodes appropriately
+/// This fix helps to improve the performance of TPCH Q19. issue#78
+///
 impl OptimizerRule for ReduceCrossJoin {
     fn optimize(
         &self,
         plan: &LogicalPlan,
         _optimizer_config: &mut OptimizerConfig,
     ) -> Result<LogicalPlan> {
-        let mut possible_join_keys: Vec<(Column, Column)> = vec![];
-        let mut all_join_keys = HashSet::new();
+        match plan {
+            LogicalPlan::Filter(filter) => {
+                let input = (**filter.input()).clone();
+
+                let mut possible_join_keys: Vec<(Column, Column)> = vec![];
+                let mut all_inputs: Vec<LogicalPlan> = vec![];
+                match &input {
+                    LogicalPlan::Join(join) if (join.join_type == 
JoinType::Inner) => {
+                        flatten_join_inputs(
+                            &input,
+                            &mut possible_join_keys,
+                            &mut all_inputs,
+                        )?;
+                    }
+                    LogicalPlan::CrossJoin(_) => {
+                        flatten_join_inputs(
+                            &input,
+                            &mut possible_join_keys,
+                            &mut all_inputs,
+                        )?;
+                    }
+                    _ => {
+                        return utils::optimize_children(self, plan, 
_optimizer_config);
+                    }
+                }
+
+                let predicate = filter.predicate();
+                // join keys are handled locally
+                let mut all_join_keys: HashSet<(Column, Column)> = 
HashSet::new();
+
+                extract_possible_join_keys(predicate, &mut possible_join_keys);
+
+                let mut left = all_inputs.remove(0);
+                while !all_inputs.is_empty() {
+                    left = find_inner_join(
+                        &left,
+                        &mut all_inputs,
+                        &mut possible_join_keys,
+                        &mut all_join_keys,
+                    )?;
+                }
 
-        reduce_cross_join(self, plan, &mut possible_join_keys, &mut 
all_join_keys)
+                left = utils::optimize_children(self, &left, 
_optimizer_config)?;
+                if plan.schema() != left.schema() {
+                    left = LogicalPlan::Projection(Projection::new_from_schema(
+                        Arc::new(left.clone()),
+                        plan.schema().clone(),
+                        None,
+                    ));
+                }
+
+                // if there are no join keys then do nothing.
+                if all_join_keys.is_empty() {
+                    Ok(LogicalPlan::Filter(Filter::try_new(
+                        predicate.clone(),
+                        Arc::new(left),
+                    )?))
+                } else {
+                    // remove join expressions from filter
+                    match remove_join_expressions(predicate, &all_join_keys)? {
+                        Some(filter_expr) => 
Ok(LogicalPlan::Filter(Filter::try_new(
+                            filter_expr,
+                            Arc::new(left),
+                        )?)),
+                        _ => Ok(left),
+                    }
+                }
+            }
+
+            _ => utils::optimize_children(self, plan, _optimizer_config),
+        }
     }
 
     fn name(&self) -> &str {
         "reduce_cross_join"
     }
 }
 
-/// Attempt to reduce cross joins to inner joins.
-/// for queries:
-/// 'select ... from a, b where a.x = b.y and b.xx = 100;'
-/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and 
b.xx = 200);'
-/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z)
-/// or (a.x = b.y and b.xx = 200 and a.z=c.z);'
-/// For above queries, the join predicate is available in filters and they are 
moved to
-/// join nodes appropriately
-/// This fix helps to improve the performance of TPCH Q19. issue#78
-///
-fn reduce_cross_join(
-    _optimizer: &ReduceCrossJoin,
+fn flatten_join_inputs(
     plan: &LogicalPlan,
     possible_join_keys: &mut Vec<(Column, Column)>,
-    all_join_keys: &mut HashSet<(Column, Column)>,
-) -> Result<LogicalPlan> {
-    match plan {
-        LogicalPlan::Filter(filter) => {
-            let input = filter.input();
-            let predicate = filter.predicate();
-            // join keys are handled locally
-            let mut new_possible_join_keys: Vec<(Column, Column)> = vec![];
-            let mut new_all_join_keys = HashSet::new();
-
-            extract_possible_join_keys(predicate, &mut new_possible_join_keys);
-
-            let new_plan = reduce_cross_join(
-                _optimizer,
-                input,
-                &mut new_possible_join_keys,
-                &mut new_all_join_keys,
-            )?;
-
-            // if there are no join keys then do nothing.
-            if new_all_join_keys.is_empty() {
-                Ok(LogicalPlan::Filter(Filter::try_new(
-                    predicate.clone(),
-                    Arc::new(new_plan),
-                )?))
-            } else {
-                // remove join expressions from filter
-                match remove_join_expressions(predicate, &new_all_join_keys)? {
-                    Some(filter_expr) => 
Ok(LogicalPlan::Filter(Filter::try_new(
-                        filter_expr,
-                        Arc::new(new_plan),
-                    )?)),
-                    _ => Ok(new_plan),
-                }
+    all_inputs: &mut Vec<LogicalPlan>,
+) -> Result<()> {
+    let children = match plan {
+        LogicalPlan::Join(join) => {
+            for join_keys in join.on.iter() {
+                possible_join_keys.push(join_keys.clone());
             }
+            let left = &*(join.left);
+            let right = &*(join.right);
+            Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
         }
-        LogicalPlan::CrossJoin(cross_join) => {
-            let left_plan = reduce_cross_join(
-                _optimizer,
-                &cross_join.left,
-                possible_join_keys,
-                all_join_keys,
-            )?;
-            let right_plan = reduce_cross_join(
-                _optimizer,
-                &cross_join.right,
-                possible_join_keys,
-                all_join_keys,
-            )?;
-            // can we find a match?
-            let left_schema = left_plan.schema();
-            let right_schema = right_plan.schema();
-            let mut join_keys = vec![];
-
-            for (l, r) in possible_join_keys {
-                if left_schema.field_from_column(l).is_ok()
-                    && right_schema.field_from_column(r).is_ok()
-                    && 
can_hash(left_schema.field_from_column(l).unwrap().data_type())
-                {
-                    join_keys.push((l.clone(), r.clone()));
-                } else if left_schema.field_from_column(r).is_ok()
-                    && right_schema.field_from_column(l).is_ok()
-                    && 
can_hash(left_schema.field_from_column(r).unwrap().data_type())
-                {
-                    join_keys.push((r.clone(), l.clone()));
+        LogicalPlan::CrossJoin(join) => {
+            let left = &*(join.left);
+            let right = &*(join.right);
+            Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
+        }
+        _ => {
+            return Err(DataFusionError::Plan(
+                "flatten_join_inputs just can call 
join/cross_join".to_string(),
+            ));
+        }
+    }?;
+
+    for child in children.iter() {
+        match *child {
+            LogicalPlan::Join(left_join) => {
+                if left_join.join_type == JoinType::Inner {
+                    flatten_join_inputs(child, possible_join_keys, 
all_inputs)?;
+                } else {
+                    all_inputs.push((*child).clone());
                 }
             }
+            LogicalPlan::CrossJoin(_) => {
+                flatten_join_inputs(child, possible_join_keys, all_inputs)?;
+            }
+            _ => all_inputs.push((*child).clone()),

Review Comment:
   Eventually it would be awesome to avoid so much cloning -- maybe as a 
follow-on PR



##########
datafusion/sql/src/planner.rs:
##########
@@ -2955,30 +2807,6 @@ fn extract_join_keys(
     }
 }
 
-/// Extract join keys from a WHERE clause
-fn extract_possible_join_keys(
-    expr: &Expr,
-    accum: &mut Vec<(Column, Column)>,
-) -> Result<()> {
-    match expr {
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
-            Operator::Eq => match (left.as_ref(), right.as_ref()) {
-                (Expr::Column(l), Expr::Column(r)) => {
-                    accum.push((l.clone(), r.clone()));
-                    Ok(())
-                }
-                _ => Ok(()),
-            },
-            Operator::And => {
-                extract_possible_join_keys(left, accum)?;
-                extract_possible_join_keys(right, accum)
-            }
-            _ => Ok(()),
-        },
-        _ => Ok(()),
-    }

Review Comment:
   Maybe you are thinking about this copy: 
https://github.com/apache/arrow-datafusion/blob/bcd624855778384ee27648161de73951e3fb6ea1/datafusion/optimizer/src/reduce_cross_join.rs#L238-L279



##########
benchmarks/expected-plans/q2.txt:
##########
@@ -1,24 +1,25 @@
 Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, 
supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST
   Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, 
part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, 
supplier.s_comment
-    Inner Join: part.p_partkey = __sq_1.ps_partkey, partsupp.ps_supplycost = 
__sq_1.__value
-      Inner Join: nation.n_regionkey = region.r_regionkey
-        Inner Join: supplier.s_nationkey = nation.n_nationkey
-          Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
-            Inner Join: part.p_partkey = partsupp.ps_partkey
-              Filter: part.p_size = Int32(15) AND part.p_type LIKE 
Utf8("%BRASS")
-                TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size]
-              TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
-            TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
-          TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
-        Filter: region.r_name = Utf8("EUROPE")
-          TableScan: region projection=[r_regionkey, r_name]
-      Projection: partsupp.ps_partkey, MIN(partsupp.ps_supplycost) AS __value, 
alias=__sq_1
-        Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[MIN(partsupp.ps_supplycost)]]
-          Inner Join: nation.n_regionkey = region.r_regionkey
-            Inner Join: supplier.s_nationkey = nation.n_nationkey
-              Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
+    Projection: part.p_partkey, part.p_mfgr, supplier.s_name, 
supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, 
nation.n_name

Review Comment:
   I found the difference much easier to see with a whitespace-blind diff:
   
   https://github.com/apache/arrow-datafusion/pull/4185/files?w=1
   
   I found GitHub's rendering of the plan change really hard to understand 
-- I drew out the join graphs by hand to make sure they were the same.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to