This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 734c211c3 Move the extract_join_keys to optimizer (#4711)
734c211c3 is described below

commit 734c211c3832b004cdb3cd57d1815c3fe006388a
Author: ygf11 <[email protected]>
AuthorDate: Tue Dec 27 18:00:13 2022 +0800

    Move the extract_join_keys to optimizer (#4711)
    
    * Move extract_join_keys to optimizer
    
    * rename ExtractEquijoinExpr to ExtractEquijoinPredicate
    
    * fix cargo clippy
    
    * utilize the optimizer to traverse the plan tree
    
    * add a new test
---
 .../optimizer/src/extract_equijoin_predicate.rs    | 420 +++++++++++++++++++++
 datafusion/optimizer/src/lib.rs                    |   1 +
 datafusion/optimizer/src/optimizer.rs              |   2 +
 datafusion/sql/src/planner.rs                      | 276 ++++----------
 4 files changed, 487 insertions(+), 212 deletions(-)

diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs 
b/datafusion/optimizer/src/extract_equijoin_predicate.rs
new file mode 100644
index 000000000..214fbe728
--- /dev/null
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -0,0 +1,420 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule that extracts equijoin predicates from a join filter
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::DFSchema;
+use datafusion_common::Result;
+use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
+use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, 
Operator};
+use std::sync::Arc;
+
+/// Optimization rule that extracts equijoin expressions from the join filter
+#[derive(Default)]
+pub struct ExtractEquijoinPredicate;
+
+impl ExtractEquijoinPredicate {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for ExtractEquijoinPredicate {
+    fn try_optimize(
+        &self,
+        plan: &LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Option<LogicalPlan>> {
+        match plan {
+            LogicalPlan::Join(Join {
+                left,
+                right,
+                on,
+                filter,
+                join_type,
+                join_constraint,
+                schema,
+                null_equals_null,
+            }) => {
+                let left_schema = left.schema();
+                let right_schema = right.schema();
+
+                filter.as_ref().map_or(Result::Ok(None), |expr| {
+                    let mut accum: Vec<(Expr, Expr)> = vec![];
+                    let mut accum_filter: Vec<Expr> = vec![];
+                    // TODO: avoid this clone by using split_conjunction
+                    extract_join_keys(
+                        expr.clone(),
+                        &mut accum,
+                        &mut accum_filter,
+                        left_schema,
+                        right_schema,
+                    )?;
+
+                    let optimized_plan = (!accum.is_empty()).then(|| {
+                        let mut new_on = on.clone();
+                        new_on.extend(accum);
+
+                        let new_filter = 
accum_filter.into_iter().reduce(Expr::and);
+                        LogicalPlan::Join(Join {
+                            left: left.clone(),
+                            right: right.clone(),
+                            on: new_on,
+                            filter: new_filter,
+                            join_type: *join_type,
+                            join_constraint: *join_constraint,
+                            schema: schema.clone(),
+                            null_equals_null: *null_equals_null,
+                        })
+                    });
+
+                    Ok(optimized_plan)
+                })
+            }
+            _ => Ok(None),
+        }
+    }
+
+    fn name(&self) -> &str {
+        "extract_equijoin_predicate"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+}
+
+/// Extracts join keys from an ON condition: a single Eq or multiple conjunctive Eqs
+/// Filters matching this pattern are added to `accum`
+/// Filters that don't match this pattern are added to `accum_filter`
+/// Examples:
+/// ```text
+/// foo = bar => accum=[(foo, bar)] accum_filter=[]
+/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
+/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
+///
+/// For equijoin join keys, assume we have tables -- a(c0, c1, c2) and b(c0, c1, 
c2):
+/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
+/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)],  accum_filter=[]
+/// (a.c0 + b.c0 = 10) =>  accum=[], accum_filter=[a.c0 + b.c0 = 10]
+/// ```
+fn extract_join_keys(
+    expr: Expr,
+    accum: &mut Vec<(Expr, Expr)>,
+    accum_filter: &mut Vec<Expr>,
+    left_schema: &Arc<DFSchema>,
+    right_schema: &Arc<DFSchema>,
+) -> Result<()> {
+    match &expr {
+        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
+            Operator::Eq => {
+                let left = *left.clone();
+                let right = *right.clone();
+                let left_using_columns = left.to_columns()?;
+                let right_using_columns = right.to_columns()?;
+
+                // When one side key does not contain columns, we need to move 
this expression to the filter.
+                // For example: a = 1, a = now() + 10.
+                if left_using_columns.is_empty() || 
right_using_columns.is_empty() {
+                    accum_filter.push(expr);
+                    return Ok(());
+                }
+
+                // Check whether the left join key is from the left schema and the right join key 
is from the right schema, or the opposite.
+                let l_is_left = check_all_column_from_schema(
+                    &left_using_columns,
+                    left_schema.clone(),
+                )?;
+                let r_is_right = check_all_column_from_schema(
+                    &right_using_columns,
+                    right_schema.clone(),
+                )?;
+
+                let r_is_left_and_l_is_right = || {
+                    let result = check_all_column_from_schema(
+                        &right_using_columns,
+                        left_schema.clone(),
+                    )? && check_all_column_from_schema(
+                        &left_using_columns,
+                        right_schema.clone(),
+                    )?;
+
+                    Result::Ok(result)
+                };
+
+                let join_key_pair = match (l_is_left, r_is_right) {
+                    (true, true) => Some((left, right)),
+                    (_, _) if r_is_left_and_l_is_right()? => Some((right, 
left)),
+                    _ => None,
+                };
+
+                if let Some((left_expr, right_expr)) = join_key_pair {
+                    let left_expr_type = left_expr.get_type(left_schema)?;
+                    let right_expr_type = right_expr.get_type(right_schema)?;
+
+                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) 
{
+                        accum.push((left_expr, right_expr));
+                    } else {
+                        accum_filter.push(expr);
+                    }
+                } else {
+                    accum_filter.push(expr);
+                }
+            }
+            Operator::And => {
+                if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = 
expr {
+                    extract_join_keys(
+                        *left,
+                        accum,
+                        accum_filter,
+                        left_schema,
+                        right_schema,
+                    )?;
+                    extract_join_keys(
+                        *right,
+                        accum,
+                        accum_filter,
+                        left_schema,
+                        right_schema,
+                    )?;
+                }
+            }
+            _other => {
+                accum_filter.push(expr);
+            }
+        },
+        _other => {
+            accum_filter.push(expr);
+        }
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::*;
+    use datafusion_common::Column;
+    use datafusion_expr::{
+        col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
+    };
+
+    fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+        assert_optimized_plan_eq_display_indent(
+            Arc::new(ExtractEquijoinPredicate {}),
+            plan,
+            expected,
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn join_with_only_column_equi_predicate() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(col("t1.a").eq(col("t2.a"))),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_only_equi_expr_predicate() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_only_none_equi_predicate() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(
+                    (col("t1.a") + lit(10i64))
+                        .gt_eq(col("t2.a") * lit(2u32))
+                        .and(col("t1.b").lt(lit(100i32))),
+                ),
+            )?
+            .build()?;
+        let expected = "Left Join:  Filter: t1.a + Int64(10) >= t2.a * 
UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, 
b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(t1)
+            .join_with_expr_keys(
+                t2,
+                JoinType::Left,
+                (
+                    vec![col("t1.a") + lit(11u32)],
+                    vec![col("t2.a") * lit(2u32)],
+                ),
+                Some(
+                    (col("t1.a") + lit(10i64))
+                        .eq(col("t2.a") * lit(2u32))
+                        .and(col("t1.b").lt(lit(100i32))),
+                ),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a 
+ Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, 
c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_and_or_filter() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                t2,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(
+                    col("t1.c")
+                        .eq(col("t2.c"))
+                        .or((col("t1.a") + col("t1.b")).gt(col("t2.b") + 
col("t2.c")))
+                        .and(
+                            
col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
+                        ),
+                ),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = 
t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, 
b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_multiple_table() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+        let t3 = test_table_scan_with_name("t3")?;
+
+        let input = LogicalPlanBuilder::from(t2)
+            .join(
+                t3,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(
+                    col("t2.a")
+                        .eq(col("t3.a"))
+                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
+                ),
+            )?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                input,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(
+                    col("t1.a")
+                        .eq(col("t2.a"))
+                        .and((col("t1.c") + col("t2.c") + 
col("t3.c")).lt(lit(100u32))),
+                ),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < 
UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32]\
+            \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+            \n  Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+            \n    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+            \n    TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+
+    #[test]
+    fn join_with_multiple_table_and_eq_filter() -> Result<()> {
+        let t1 = test_table_scan_with_name("t1")?;
+        let t2 = test_table_scan_with_name("t2")?;
+        let t3 = test_table_scan_with_name("t3")?;
+
+        let input = LogicalPlanBuilder::from(t2)
+            .join(
+                t3,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                Some(
+                    col("t2.a")
+                        .eq(col("t3.a"))
+                        .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
+                ),
+            )?
+            .build()?;
+        let plan = LogicalPlanBuilder::from(t1)
+            .join(
+                input,
+                JoinType::Left,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                
Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
+            )?
+            .build()?;
+        let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+        \n  TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+        \n  Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+        \n    TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
+
+        assert_plan_eq(&plan, expected)
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a4804ca5b..b03725fe6 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -22,6 +22,7 @@ pub mod eliminate_cross_join;
 pub mod eliminate_filter;
 pub mod eliminate_limit;
 pub mod eliminate_outer_join;
+pub mod extract_equijoin_predicate;
 pub mod filter_null_join_keys;
 pub mod inline_table_scan;
 pub mod optimizer;
diff --git a/datafusion/optimizer/src/optimizer.rs 
b/datafusion/optimizer/src/optimizer.rs
index 6fe94e792..36968f2f1 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -24,6 +24,7 @@ use crate::eliminate_cross_join::EliminateCrossJoin;
 use crate::eliminate_filter::EliminateFilter;
 use crate::eliminate_limit::EliminateLimit;
 use crate::eliminate_outer_join::EliminateOuterJoin;
+use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
 use crate::inline_table_scan::InlineTableScan;
 use crate::propagate_empty_relation::PropagateEmptyRelation;
@@ -237,6 +238,7 @@ impl Optimizer {
         let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
             Arc::new(InlineTableScan::new()),
             Arc::new(TypeCoercion::new()),
+            Arc::new(ExtractEquijoinPredicate::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
             Arc::new(DecorrelateWhereExists::new()),
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index dc43cbaf1..a4cc0b775 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -58,9 +58,8 @@ use datafusion_expr::logical_plan::{
 };
 use datafusion_expr::logical_plan::{Filter, Prepare, Subquery};
 use datafusion_expr::utils::{
-    can_hash, check_all_column_from_schema, expand_qualified_wildcard, 
expand_wildcard,
-    expr_as_column_expr, expr_to_columns, find_aggregate_exprs, 
find_column_exprs,
-    find_window_exprs, COUNT_STAR_EXPANSION,
+    expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, 
expr_to_columns,
+    find_aggregate_exprs, find_column_exprs, find_window_exprs, 
COUNT_STAR_EXPANSION,
 };
 use datafusion_expr::Expr::Alias;
 use datafusion_expr::{
@@ -806,7 +805,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     ) -> Result<LogicalPlan> {
         match constraint {
             JoinConstraint::On(sql_expr) => {
-                let mut keys: Vec<(Expr, Expr)> = vec![];
                 let join_schema = left.schema().join(right.schema())?;
 
                 // parse ON expression
@@ -820,45 +818,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
                 // normalize all columns in expression
                 let using_columns = expr.to_columns()?;
-                let normalized_expr = normalize_col_with_schemas(
+                let filter = normalize_col_with_schemas(
                     expr,
                     &[left.schema(), right.schema()],
                     &[using_columns],
                 )?;
 
-                // expression that didn't match equi-join pattern
-                let mut filter = vec![];
-
-                // extract join keys
-                extract_join_keys(
-                    normalized_expr,
-                    &mut keys,
-                    &mut filter,
-                    left.schema(),
-                    right.schema(),
-                )?;
-
-                let (left_keys, right_keys): (Vec<Expr>, Vec<Expr>) =
-                    keys.into_iter().unzip();
-
-                let join_filter = filter.into_iter().reduce(Expr::and);
-
-                if left_keys.is_empty() && join_filter.is_none() {
-                    let mut join = 
LogicalPlanBuilder::from(left).cross_join(right)?;
-                    if let Some(filter) = join_filter {
-                        join = join.filter(filter)?;
-                    }
-                    join.build()
-                } else {
-                    LogicalPlanBuilder::from(left)
-                        .join_with_expr_keys(
-                            right,
-                            join_type,
-                            (left_keys, right_keys),
-                            join_filter,
-                        )?
-                        .build()
-                }
+                LogicalPlanBuilder::from(left)
+                    .join(
+                        right,
+                        join_type,
+                        (Vec::<Column>::new(), Vec::<Column>::new()),
+                        Some(filter),
+                    )?
+                    .build()
             }
             JoinConstraint::Using(idents) => {
                 let keys: Vec<Column> = idents
@@ -3095,113 +3068,6 @@ pub fn object_name_to_qualifier(sql_table_name: 
&ObjectName) -> String {
         .join(" AND ")
 }
 
-/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
-/// Filters matching this pattern are added to `accum`
-/// Filters that don't match this pattern are added to `accum_filter`
-/// Examples:
-/// ```text
-/// foo = bar => accum=[(foo, bar)] accum_filter=[]
-/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
-/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
-///
-/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, 
c2):
-/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
-/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)],  accum_filter=[]
-/// (a.c0 + b.c0 = 10) =>  accum=[], accum_filter=[a.c0 + b.c0 = 10]
-/// ```
-fn extract_join_keys(
-    expr: Expr,
-    accum: &mut Vec<(Expr, Expr)>,
-    accum_filter: &mut Vec<Expr>,
-    left_schema: &Arc<DFSchema>,
-    right_schema: &Arc<DFSchema>,
-) -> Result<()> {
-    match &expr {
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
-            Operator::Eq => {
-                let left = *left.clone();
-                let right = *right.clone();
-                let left_using_columns = left.to_columns()?;
-                let right_using_columns = right.to_columns()?;
-
-                // When one side key does not contain columns, we need move 
this expression to filter.
-                // For example: a = 1, a = now() + 10.
-                if left_using_columns.is_empty() || 
right_using_columns.is_empty() {
-                    accum_filter.push(expr);
-                    return Ok(());
-                }
-
-                // Checking left join key is from left schema, right join key 
is from right schema, or the opposite.
-                let l_is_left = check_all_column_from_schema(
-                    &left_using_columns,
-                    left_schema.clone(),
-                )?;
-                let r_is_right = check_all_column_from_schema(
-                    &right_using_columns,
-                    right_schema.clone(),
-                )?;
-
-                let r_is_left_and_l_is_right = || {
-                    let result = check_all_column_from_schema(
-                        &right_using_columns,
-                        left_schema.clone(),
-                    )? && check_all_column_from_schema(
-                        &left_using_columns,
-                        right_schema.clone(),
-                    )?;
-
-                    Result::Ok(result)
-                };
-
-                let join_key_pair = match (l_is_left, r_is_right) {
-                    (true, true) => Some((left, right)),
-                    (_, _) if r_is_left_and_l_is_right()? => Some((right, 
left)),
-                    _ => None,
-                };
-
-                if let Some((left_expr, right_expr)) = join_key_pair {
-                    let left_expr_type = left_expr.get_type(left_schema)?;
-                    let right_expr_type = right_expr.get_type(right_schema)?;
-
-                    if can_hash(&left_expr_type) && can_hash(&right_expr_type) 
{
-                        accum.push((left_expr, right_expr));
-                    } else {
-                        accum_filter.push(expr);
-                    }
-                } else {
-                    accum_filter.push(expr);
-                }
-            }
-            Operator::And => {
-                if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = 
expr {
-                    extract_join_keys(
-                        *left,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
-                    extract_join_keys(
-                        *right,
-                        accum,
-                        accum_filter,
-                        left_schema,
-                        right_schema,
-                    )?;
-                }
-            }
-            _other => {
-                accum_filter.push(expr);
-            }
-        },
-        _other => {
-            accum_filter.push(expr);
-        }
-    }
-
-    Ok(())
-}
-
 /// Ensure any column reference of the expression is unambiguous.
 /// Assume we have two schema:
 /// schema1: a, b ,c
@@ -4620,9 +4486,9 @@ mod tests {
             JOIN orders \
             ON id = customer_id";
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id = orders.customer_id\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id = orders.customer_id\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -4633,7 +4499,7 @@ mod tests {
             JOIN orders \
             ON id = customer_id AND order_id > 1 ";
         let expected = "Projection: person.id, orders.order_id\
-            \n  Inner Join: person.id = orders.customer_id Filter: 
orders.order_id > Int64(1)\
+            \n  Inner Join:  Filter: person.id = orders.customer_id AND 
orders.order_id > Int64(1)\
             \n    TableScan: person\
             \n    TableScan: orders";
 
@@ -4647,7 +4513,7 @@ mod tests {
             LEFT JOIN orders \
             ON id = customer_id AND order_id > 1 AND age < 30";
         let expected = "Projection: person.id, orders.order_id\
-            \n  Left Join: person.id = orders.customer_id Filter: 
orders.order_id > Int64(1) AND person.age < Int64(30)\
+            \n  Left Join:  Filter: person.id = orders.customer_id AND 
orders.order_id > Int64(1) AND person.age < Int64(30)\
             \n    TableScan: person\
             \n    TableScan: orders";
         quick_test(sql, expected);
@@ -4659,8 +4525,9 @@ mod tests {
             FROM person \
             RIGHT JOIN orders \
             ON id = customer_id AND id > 1 AND order_id < 100";
+
         let expected = "Projection: person.id, orders.order_id\
-            \n  Right Join: person.id = orders.customer_id Filter: person.id > 
Int64(1) AND orders.order_id < Int64(100)\
+            \n  Right Join:  Filter: person.id = orders.customer_id AND 
person.id > Int64(1) AND orders.order_id < Int64(100)\
             \n    TableScan: person\
             \n    TableScan: orders";
         quick_test(sql, expected);
@@ -4673,9 +4540,9 @@ mod tests {
             FULL JOIN orders \
             ON id = customer_id AND id > 1 AND order_id < 100";
         let expected = "Projection: person.id, orders.order_id\
-        \n  Full Join: person.id = orders.customer_id Filter: person.id > 
Int64(1) AND orders.order_id < Int64(100)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Full Join:  Filter: person.id = orders.customer_id AND 
person.id > Int64(1) AND orders.order_id < Int64(100)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -4686,9 +4553,9 @@ mod tests {
             JOIN orders \
             ON person.id = orders.customer_id";
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id = orders.customer_id\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id = orders.customer_id\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -4727,8 +4594,8 @@ mod tests {
             JOIN orders ON id = customer_id \
             JOIN lineitem ON o_item_id = l_item_id";
         let expected = "Projection: person.id, orders.order_id, 
lineitem.l_description\
-            \n  Inner Join: orders.o_item_id = lineitem.l_item_id\
-            \n    Inner Join: person.id = orders.customer_id\
+            \n  Inner Join:  Filter: orders.o_item_id = lineitem.l_item_id\
+            \n    Inner Join:  Filter: person.id = orders.customer_id\
             \n      TableScan: person\
             \n      TableScan: orders\
             \n    TableScan: lineitem";
@@ -5517,11 +5384,11 @@ mod tests {
     fn join_with_aliases() {
         let sql = "select peeps.id, folks.first_name from person as peeps join 
person as folks on peeps.id = folks.id";
         let expected = "Projection: peeps.id, folks.first_name\
-                                    \n  Inner Join: peeps.id = folks.id\
-                                    \n    SubqueryAlias: peeps\
-                                    \n      TableScan: person\
-                                    \n    SubqueryAlias: folks\
-                                    \n      TableScan: person";
+            \n  Inner Join:  Filter: peeps.id = folks.id\
+            \n    SubqueryAlias: peeps\
+            \n      TableScan: person\
+            \n    SubqueryAlias: folks\
+            \n      TableScan: person";
         quick_test(sql, expected);
     }
 
@@ -5855,7 +5722,7 @@ mod tests {
             FROM person \
             JOIN orders ON id = customer_id AND (person.age > 30 OR 
person.last_name = 'X')";
         let expected = "Projection: person.id, orders.order_id\
-            \n  Inner Join: person.id = orders.customer_id Filter: person.age 
> Int64(30) OR person.last_name = Utf8(\"X\")\
+            \n  Inner Join:  Filter: person.id = orders.customer_id AND 
(person.age > Int64(30) OR person.last_name = Utf8(\"X\"))\
             \n    TableScan: person\
             \n    TableScan: orders";
         quick_test(sql, expected);
@@ -5981,9 +5848,9 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: orders.customer_id * Int64(2) = person.id 
+ Int64(10)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
 
         quick_test(sql, expected);
     }
@@ -5996,9 +5863,9 @@ mod tests {
             ON person.id + 10 = orders.customer_id * 2";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id + Int64(10) = 
orders.customer_id * Int64(2)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6010,37 +5877,37 @@ mod tests {
             ON person.id + person.age + 10 = orders.customer_id * 2 - 
orders.price";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id * Int64(2) - orders.price\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id + person.age + Int64(10) = 
orders.customer_id * Int64(2) - orders.price\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
     #[test]
-    fn test_left_projection_expr_eq_join() {
+    fn test_left_expr_eq_join() {
         let sql = "SELECT id, order_id \
             FROM person \
             INNER JOIN orders \
             ON person.id + person.age + 10 = orders.customer_id";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id + person.age + Int64(10) = 
orders.customer_id\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
     #[test]
-    fn test_right_projection_expr_eq_join() {
+    fn test_right_expr_eq_join() {
         let sql = "SELECT id, order_id \
             FROM person \
             INNER JOIN orders \
             ON person.id = orders.customer_id * 2 - orders.price";
 
         let expected = "Projection: person.id, orders.order_id\
-       \n  Inner Join: person.id = orders.customer_id * Int64(2) - 
orders.price\
-       \n    TableScan: person\
-       \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id = orders.customer_id * Int64(2) 
- orders.price\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6108,9 +5975,9 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: person.id, person.first_name, 
person.last_name, person.age, person.state, person.salary, person.birth_date, 
person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: orders.customer_id * Int64(2) = person.id 
+ Int64(10)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6122,24 +5989,9 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: orders.customer_id * Int64(2), person.id + 
Int64(10)\
-        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
-        quick_test(sql, expected);
-    }
-
-    #[test]
-    fn test_non_projetion_after_inner_join() {
-        // There's no need to add projection for left and right, so does 
adding projection after join.
-        let sql = "SELECT  person.id, person.age
-            FROM person
-            INNER JOIN orders
-            ON orders.customer_id = person.id";
-
-        let expected = "Projection: person.id, person.age\
-        \n  Inner Join: person.id = orders.customer_id\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: orders.customer_id * Int64(2) = person.id 
+ Int64(10)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6152,9 +6004,9 @@ mod tests {
             ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = 
orders.order_id";
 
         let expected = "Projection: person.id, person.age\
-        \n  Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), 
person.id * Int64(2) = orders.order_id\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id * Int64(2) = orders.customer_id 
+ Int64(10) AND person.id * Int64(2) = orders.order_id\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6167,9 +6019,9 @@ mod tests {
             ON person.id * 2 = orders.customer_id + 10 and person.id =  
orders.customer_id + 10";
 
         let expected = "Projection: person.id, person.age\
-        \n  Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), 
person.id = orders.customer_id + Int64(10)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: person.id * Int64(2) = orders.customer_id 
+ Int64(10) AND person.id = orders.customer_id + Int64(10)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6587,9 +6439,9 @@ mod tests {
             ON cast(person.id as Int) = cast(orders.customer_id as Int)";
 
         let expected = "Projection: person.id, person.age\
-        \n  Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS 
Int32)\
-        \n    TableScan: person\
-        \n    TableScan: orders";
+            \n  Inner Join:  Filter: CAST(person.id AS Int32) = 
CAST(orders.customer_id AS Int32)\
+            \n    TableScan: person\
+            \n    TableScan: orders";
         quick_test(sql, expected);
     }
 


Reply via email to