This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d36529ab Refactor: Change equijoin keys from column to expression in logical join (#4602)
8d36529ab is described below

commit 8d36529abab500f4effd805e204b78ff0ab03d17
Author: ygf11 <[email protected]>
AuthorDate: Sat Dec 17 19:06:40 2022 +0800

    Refactor: Change equijoin keys from column to expression in logical join (#4602)
    
    * Refactor: Change equi-join keys from columns to expressions in the logical join plan
    
    * resolve conflicts
    
    * add tests
    
    * remove unused code
    
    * fix test
    
    * add test
    
    * remove unused code
    
    * fmt
    
    * clippy
    
    * resolve merge conflict
    
    * Update datafusion/expr/src/logical_plan/builder.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * rename join_keys to equi_exprs
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/planner.rs       |  78 ++++-
 datafusion/core/tests/sql/joins.rs                 | 368 +++++++++++++++++++--
 datafusion/core/tests/sql/mod.rs                   |   4 +-
 datafusion/expr/src/logical_plan/builder.rs        | 112 ++++++-
 datafusion/expr/src/logical_plan/plan.rs           |  59 +++-
 datafusion/optimizer/src/eliminate_cross_join.rs   |  84 ++---
 datafusion/optimizer/src/filter_null_join_keys.rs  | 130 +++++---
 datafusion/optimizer/src/push_down_filter.rs       |  12 +-
 datafusion/optimizer/src/push_down_projection.rs   |   4 +-
 .../optimizer/src/subquery_filter_to_join.rs       |   2 +-
 datafusion/proto/proto/datafusion.proto            |   4 +-
 datafusion/proto/src/generated/pbjson.rs           |  52 +--
 datafusion/proto/src/generated/prost.rs            |   4 +-
 datafusion/proto/src/logical_plan.rs               |  47 ++-
 datafusion/sql/src/planner.rs                      | 131 +++-----
 15 files changed, 807 insertions(+), 284 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 5ba8cabe9..53c88d5e9 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -65,6 +65,8 @@ use datafusion_expr::expr::{
     Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, TryCast,
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
+use datafusion_expr::logical_plan;
+use 
datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
 use datafusion_expr::utils::expand_wildcard;
 use datafusion_expr::{WindowFrame, WindowFrameBound};
 use datafusion_optimizer::utils::unalias;
@@ -846,8 +848,78 @@ impl DefaultPhysicalPlanner {
                     filter,
                     join_type,
                     null_equals_null,
+                    schema: join_schema,
                     ..
                 }) => {
+                    // If join has expression equijoin keys, add physical 
projecton.
+                    let has_expr_join_key = keys.iter().any(|(l, r)| {
+                        !(matches!(l, Expr::Column(_))
+                            && matches!(r, Expr::Column(_)))
+                    });
+                    if has_expr_join_key {
+                        let left_keys = keys
+                            .iter()
+                            .map(|(l, _r)| l)
+                            .cloned()
+                            .collect::<Vec<_>>();
+                        let right_keys = keys
+                            .iter()
+                            .map(|(_l, r)| r)
+                            .cloned()
+                            .collect::<Vec<_>>();
+                        let (left, right, column_on, added_project) = {
+                            let (left, left_col_keys, left_projected) =
+                                wrap_projection_for_join_if_necessary(
+                                    left_keys.as_slice(),
+                                    left.as_ref().clone(),
+                                )?;
+                            let (right, right_col_keys, right_projected) =
+                                wrap_projection_for_join_if_necessary(
+                                    &right_keys,
+                                    right.as_ref().clone(),
+                                )?;
+                            (
+                                left,
+                                right,
+                                (left_col_keys, right_col_keys),
+                                left_projected || right_projected,
+                            )
+                        };
+
+                        let join_plan =
+                            LogicalPlan::Join(Join::try_new_with_project_input(
+                                logical_plan,
+                                Arc::new(left),
+                                Arc::new(right),
+                                column_on,
+                            )?);
+
+                        // Remove temporary projected columns
+                        let join_plan = if added_project {
+                            let final_join_result = join_schema
+                                .fields()
+                                .iter()
+                                .map(|field| {
+                                    Expr::Column(field.qualified_column())
+                                })
+                                .collect::<Vec<_>>();
+                            let projection =
+                                logical_plan::Projection::try_new_with_schema(
+                                    final_join_result,
+                                    Arc::new(join_plan),
+                                    join_schema.clone(),
+                                )?;
+                            LogicalPlan::Projection(projection)
+                        } else {
+                            join_plan
+                        };
+
+                        return self
+                            .create_initial_plan(&join_plan, session_state)
+                            .await;
+                    }
+
+                    // All equi-join keys are columns now, create physical 
join plan
                     let left_df_schema = left.schema();
                     let physical_left = self.create_initial_plan(left, 
session_state).await?;
                     let right_df_schema = right.schema();
@@ -855,9 +927,11 @@ impl DefaultPhysicalPlanner {
                     let join_on = keys
                         .iter()
                         .map(|(l, r)| {
+                            let l = l.try_into_col()?;
+                            let r = r.try_into_col()?;
                             Ok((
-                                Column::new(&l.name, 
left_df_schema.index_of_column(l)?),
-                                Column::new(&r.name, 
right_df_schema.index_of_column(r)?),
+                                Column::new(&l.name, 
left_df_schema.index_of_column(&l)?),
+                                Column::new(&r.name, 
right_df_schema.index_of_column(&r)?),
                             ))
                         })
                         .collect::<Result<join_utils::JoinOn>>()?;
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 1094a818f..e3f92cbb3 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2322,12 +2322,11 @@ async fn reduce_cross_join_with_expr_join_key_all() -> Result<()> {
         let expected = vec![
             "Explain [plan_type:Utf8, plan:Utf8]",
             "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, 
t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "    Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1) 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
-            "      Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id 
AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t1.t1_id + Int64(12):Int64;N]",
-            "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-            "      Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id 
AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id 
+ Int64(1):Int64;N]",
-            "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "    Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) = 
CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         ];
+
         let formatted = plan.display_indent_schema().to_string();
         let actual: Vec<&str> = formatted.trim().lines().collect();
         assert_eq!(
@@ -2367,15 +2366,13 @@ async fn reduce_cross_join_with_cast_expr_join_key() -> Result<()> {
         let state = ctx.state();
         let plan = state.optimize(&plan)?;
         let expected = vec![
-            "Explain [plan_type:Utf8, plan:Utf8]",
-            "  Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, 
t2_id:UInt32;N, t1_name:Utf8;N]",
-            "    Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t2_id:UInt32;N]",
-            "      Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) 
[t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N, 
CAST(t2.t2_id AS Int64):Int64;N]",
-            "        Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64) 
+ Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]",
-            "          TableScan: t1 projection=[t1_id, t1_name] 
[t1_id:UInt32;N, t1_name:Utf8;N]",
-            "        Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS 
CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
-            "          TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+           "Explain [plan_type:Utf8, plan:Utf8]",
+           "  Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N, 
t2_id:UInt32;N, t1_name:Utf8;N]",
+           "    Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = 
CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+           "      TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, 
t1_name:Utf8;N]",
+           "      TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
         ];
+
         let formatted = plan.display_indent_schema().to_string();
         let actual: Vec<&str> = formatted.trim().lines().collect();
         assert_eq!(
@@ -2407,6 +2404,8 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> {
         let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
 
         let sql = "select *,t1.t1_id+11 from t1,t2 where t1.t1_id+11=t2.t2_id";
+
+        // assert logical plan
         let msg = format!("Creating logical plan for '{}'", sql);
         let plan = ctx
             .create_logical_plan(&("explain ".to_owned() + sql))
@@ -2417,12 +2416,9 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> {
         let expected = vec![
             "Explain [plan_type:Utf8, plan:Utf8]",
             "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, 
t2.t2_name, t2.t2_int, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N, t1.t1_id + Int64(11):Int64;N]",
-            "    Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, 
t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-            "      Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64) 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(11):Int64;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, CAST(t2.t2_id AS 
Int64):Int64;N]",
-            "        Projection: t1.t1_id, t1.t1_name, t1.t1_int, 
CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N, t1.t1_id + Int64(11):Int64;N]",
-            "          TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-            "        Projection: t2.t2_id, t2.t2_name, t2.t2_int, 
CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
-            "          TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "    Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = 
CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+            "      TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+            "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]"
         ];
 
         let formatted = plan.display_indent_schema().to_string();
@@ -2432,6 +2428,54 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> {
             "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
             expected, actual
         );
+
+        // assert physical plan
+        let msg = format!("Creating physical plan for '{}'", sql);
+        let plan = ctx.create_logical_plan(sql).expect(&msg);
+        let state = ctx.state();
+        let logical_plan = state.optimize(&plan)?;
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+        let expected = if repartition_joins {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, 
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, 
CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 
as t2_int]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name: 
\"CAST(t2.t2_id AS Int64)\", index: 3 })]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t1.t1_id + Int64(11)\", index: 3 }], 2)",
+                "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + 
Int64(11)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"CAST(t2.t2_id AS Int64)\", index: 3 }], 2)",
+                "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 
as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS 
Int64)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+           ]
+        } else {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, 
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int, 
CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6 
as t2_int]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=CollectLeft, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name: 
\"CAST(t2.t2_id AS Int64)\", index: 3 })]",
+                "        CoalescePartitionsExec",
+                "          ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + 
Int64(11)]",
+                "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "              MemoryExec: partitions=1, partition_sizes=[1]",
+                "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as 
t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS 
Int64)]",
+                "          RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "            MemoryExec: partitions=1, partition_sizes=[1]",
+            ]
+        };
+        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        // assert execution result
         let expected = vec![
             
"+-------+---------+--------+-------+---------+--------+----------------------+",
             "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int | t1.t1_id 
+ Int64(11) |",
@@ -2448,3 +2492,289 @@ async fn reduce_cross_join_with_wildcard_and_expr() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn both_side_expr_key_inner_join() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+                         FROM t1 \
+                         INNER JOIN t2 \
+                         ON t1.t1_id + cast(12 as INT UNSIGNED)  = t2.t2_id + 
cast(1 as INT UNSIGNED)";
+
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx.create_logical_plan(sql).expect(&msg);
+        let state = ctx.state();
+        let logical_plan = state.optimize(&plan)?;
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+        let expected = if repartition_joins {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@3 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name: 
\"t2.t2_id + Int64(1)\", index: 1 })]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t1.t1_id + Int64(12)\", index: 2 }], 2)",
+                "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t2.t2_id + Int64(1)\", index: 1 }], 2)",
+                "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 
CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+           ]
+        } else {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@3 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=CollectLeft, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name: 
\"t2.t2_id + Int64(1)\", index: 1 })]",
+                "        CoalescePartitionsExec",
+                "          ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+                "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "              MemoryExec: partitions=1, partition_sizes=[1]",
+                "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 + 
CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+                "          RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "            MemoryExec: partitions=1, partition_sizes=[1]",
+            ]
+        };
+        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        let expected = vec![
+            "+-------+-------+---------+",
+            "| t1_id | t2_id | t1_name |",
+            "+-------+-------+---------+",
+            "| 11    | 22    | a       |",
+            "| 33    | 44    | c       |",
+            "| 44    | 55    | d       |",
+            "+-------+-------+---------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn left_side_expr_key_inner_join() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+                         FROM t1 \
+                         INNER JOIN t2 \
+                         ON t1.t1_id + cast(11 as INT UNSIGNED)  = t2.t2_id";
+
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx.create_logical_plan(sql).expect(&msg);
+        let state = ctx.state();
+        let logical_plan = state.optimize(&plan)?;
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+        let expected = if repartition_joins {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@3 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name: 
\"t2_id\", index: 0 })]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t1.t1_id + Int64(11)\", index: 2 }], 2)",
+                "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t2_id\", index: 0 }], 2)",
+                "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "              MemoryExec: partitions=1, partition_sizes=[1]",
+           ]
+        } else {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@3 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "        HashJoinExec: mode=CollectLeft, join_type=Inner, 
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name: 
\"t2_id\", index: 0 })]",
+                "          CoalescePartitionsExec",
+                "            ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 
as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+                "          MemoryExec: partitions=1, partition_sizes=[1]",
+            ]
+        };
+        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        let expected = vec![
+            "+-------+-------+---------+",
+            "| t1_id | t2_id | t1_name |",
+            "+-------+-------+---------+",
+            "| 11    | 22    | a       |",
+            "| 33    | 44    | c       |",
+            "| 44    | 55    | d       |",
+            "+-------+-------+---------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn right_side_expr_key_inner_join() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+                         FROM t1 \
+                         INNER JOIN t2 \
+                         ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)";
+
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx.create_logical_plan(sql).expect(&msg);
+        let state = ctx.state();
+        let logical_plan = state.optimize(&plan)?;
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+        let expected = if repartition_joins {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@2 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - 
Int64(11)\", index: 1 })]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t1_id\", index: 0 }], 2)",
+                "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "              MemoryExec: partitions=1, partition_sizes=[1]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t2.t2_id - Int64(11)\", index: 1 }], 2)",
+                "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 
CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+           ]
+        } else {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id, 
t1_name@1 as t1_name]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t2_id@2 as t2_id]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=CollectLeft, join_type=Inner, 
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - 
Int64(11)\", index: 1 })]",
+                "        MemoryExec: partitions=1, partition_sizes=[1]",
+                "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 - 
CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+                "          RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "            MemoryExec: partitions=1, partition_sizes=[1]",
+            ]
+        };
+        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        let expected = vec![
+            "+-------+-------+---------+",
+            "| t1_id | t2_id | t1_name |",
+            "+-------+-------+---------+",
+            "| 11    | 22    | a       |",
+            "| 33    | 44    | c       |",
+            "| 44    | 55    | d       |",
+            "+-------+-------+---------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
+    let test_repartition_joins = vec![true, false];
+    for repartition_joins in test_repartition_joins {
+        let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+        let sql = "SELECT * \
+                         FROM t1 \
+                         INNER JOIN t2 \
+                         ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)";
+
+        let msg = format!("Creating logical plan for '{}'", sql);
+        let plan = ctx.create_logical_plan(sql).expect(&msg);
+        let state = ctx.state();
+        let logical_plan = state.optimize(&plan)?;
+        let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+        let expected = if repartition_joins {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, 
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as 
t2_int]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 
as t2_int]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=Partitioned, join_type=Inner, 
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - 
Int64(11)\", index: 3 })]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t1_id\", index: 0 }], 2)",
+                "            RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "              MemoryExec: partitions=1, partition_sizes=[1]",
+                "        CoalesceBatchesExec: target_batch_size=4096",
+                "          RepartitionExec: partitioning=Hash([Column { name: 
\"t2.t2_id - Int64(11)\", index: 3 }], 2)",
+                "            ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 
as t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - 
Int64(11)]",
+                "              RepartitionExec: 
partitioning=RoundRobinBatch(2)",
+                "                MemoryExec: partitions=1, 
partition_sizes=[1]",
+           ]
+        } else {
+            vec![
+                "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name, 
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as 
t2_int]",
+                "  ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as 
t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 
as t2_int]",
+                "    CoalesceBatchesExec: target_batch_size=4096",
+                "      HashJoinExec: mode=CollectLeft, join_type=Inner, 
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id - 
Int64(11)\", index: 3 })]",
+                "        MemoryExec: partitions=1, partition_sizes=[1]",
+                "        ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as 
t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id - 
Int64(11)]",
+                "          RepartitionExec: partitioning=RoundRobinBatch(2)",
+                "            MemoryExec: partitions=1, partition_sizes=[1]",
+            ]
+        };
+        let formatted = 
displayable(physical_plan.as_ref()).indent().to_string();
+        let actual: Vec<&str> = formatted.trim().lines().collect();
+        assert_eq!(
+            expected, actual,
+            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+            expected, actual
+        );
+
+        let expected = vec![
+            "+-------+---------+--------+-------+---------+--------+",
+            "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+            "+-------+---------+--------+-------+---------+--------+",
+            "| 11    | a       | 1      | 22    | y       | 1      |",
+            "| 33    | c       | 3      | 44    | x       | 3      |",
+            "| 44    | d       | 4      | 55    | w       | 3      |",
+            "+-------+---------+--------+-------+---------+--------+",
+        ];
+
+        let results = execute_to_batches(&ctx, sql).await;
+        assert_batches_sorted_eq!(expected, &results);
+    }
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 1075041c2..d97eb04bd 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -185,7 +185,9 @@ fn create_join_context(
     repartition_joins: bool,
 ) -> Result<SessionContext> {
     let ctx = SessionContext::with_config(
-        SessionConfig::new().with_repartition_joins(repartition_joins),
+        SessionConfig::new()
+            .with_repartition_joins(repartition_joins)
+            .with_target_partitions(2),
     );
 
     let t1_schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/expr/src/logical_plan/builder.rs 
b/datafusion/expr/src/logical_plan/builder.rs
index 2c16a3ca7..f0a9c9e5b 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -18,7 +18,8 @@
 //! This module provides a builder for creating LogicalPlans
 
 use crate::expr_rewriter::{
-    coerce_plan_expr_for_schema, normalize_col, normalize_cols, 
rewrite_sort_cols_by_aggs,
+    coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas,
+    normalize_cols, rewrite_sort_cols_by_aggs,
 };
 use crate::type_coercion::binary::comparison_coercion;
 use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
@@ -31,8 +32,8 @@ use crate::{
         Union, Values, Window,
     },
     utils::{
-        can_hash, expand_qualified_wildcard, expand_wildcard,
-        group_window_expr_by_sort_keys,
+        can_hash, check_all_column_from_schema, expand_qualified_wildcard,
+        expand_wildcard, group_window_expr_by_sort_keys,
     },
     Expr, ExprSchemable, TableSource,
 };
@@ -555,7 +556,11 @@ impl LogicalPlanBuilder {
         let left_keys = 
left_keys.into_iter().collect::<Result<Vec<Column>>>()?;
         let right_keys = 
right_keys.into_iter().collect::<Result<Vec<Column>>>()?;
 
-        let on: Vec<(_, _)> = 
left_keys.into_iter().zip(right_keys.into_iter()).collect();
+        let on = left_keys
+            .into_iter()
+            .zip(right_keys.into_iter())
+            .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
+            .collect();
         let join_schema =
             build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
 
@@ -591,19 +596,19 @@ impl LogicalPlanBuilder {
         let on: Vec<(_, _)> = 
left_keys.into_iter().zip(right_keys.into_iter()).collect();
         let join_schema =
             build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
-        let mut join_on: Vec<(Column, Column)> = vec![];
+        let mut join_on: Vec<(Expr, Expr)> = vec![];
         let mut filters: Option<Expr> = None;
         for (l, r) in &on {
             if self.plan.schema().field_from_column(l).is_ok()
                 && right.schema().field_from_column(r).is_ok()
                 && 
can_hash(self.plan.schema().field_from_column(l)?.data_type())
             {
-                join_on.push((l.clone(), r.clone()));
+                join_on.push((Expr::Column(l.clone()), 
Expr::Column(r.clone())));
             } else if self.plan.schema().field_from_column(r).is_ok()
                 && right.schema().field_from_column(l).is_ok()
                 && 
can_hash(self.plan.schema().field_from_column(r)?.data_type())
             {
-                join_on.push((r.clone(), l.clone()));
+                join_on.push((Expr::Column(r.clone()), 
Expr::Column(l.clone())));
             } else {
                 let expr = binary_expr(
                     Expr::Column(l.clone()),
@@ -616,6 +621,7 @@ impl LogicalPlanBuilder {
                 }
             }
         }
+
         if join_on.is_empty() {
             let join = Self::from(self.plan).cross_join(right)?;
             join.filter(filters.ok_or_else(|| {
@@ -791,6 +797,98 @@ impl LogicalPlanBuilder {
     pub fn build(self) -> Result<LogicalPlan> {
         Ok(self.plan)
     }
+
+    /// Apply a join whose equijoin keys are arbitrary expressions rather than
+    /// plain columns.
+    ///
+    /// `equi_exprs` holds the equijoin key expressions as `(left_keys, right_keys)`;
+    /// the i-th left key is compared for equality with the i-th right key. Each key
+    /// is normalized against both input schemas. If a pair's first expression only
+    /// references columns of `right` and its second only columns of `self`, the
+    /// pair is swapped so the first element always refers to `self` (left input)
+    /// and the second to `right`.
+    ///
+    /// `filter` is any other (non-equi) predicate to apply during the join.
+    /// Equijoin predicates are likely to be evaluated more quickly than the
+    /// filter expression.
+    ///
+    /// # Errors
+    ///
+    /// Returns `DataFusionError::Plan` if the two key lists differ in length or if
+    /// a key pair cannot be split into one expression per input; errors from key
+    /// normalization and join-schema construction are propagated.
+    pub fn join_with_expr_keys(
+        self,
+        right: LogicalPlan,
+        join_type: JoinType,
+        equi_exprs: (Vec<impl Into<Expr>>, Vec<impl Into<Expr>>),
+        filter: Option<Expr>,
+    ) -> Result<Self> {
+        if equi_exprs.0.len() != equi_exprs.1.len() {
+            return Err(DataFusionError::Plan(
+                "left_keys and right_keys were not the same length".to_string(),
+            ));
+        }
+
+        let join_key_pairs = equi_exprs
+            .0
+            .into_iter()
+            .zip(equi_exprs.1.into_iter())
+            .map(|(l, r)| {
+                let left_key: Expr = l.into();
+                let right_key: Expr = r.into();
+
+                let left_using_columns = left_key.to_columns()?;
+                let normalized_left_key = normalize_col_with_schemas(
+                    left_key,
+                    &[self.plan.schema(), right.schema()],
+                    &[left_using_columns],
+                )?;
+
+                let right_using_columns = right_key.to_columns()?;
+                let normalized_right_key = normalize_col_with_schemas(
+                    right_key,
+                    &[self.plan.schema(), right.schema()],
+                    &[right_using_columns],
+                )?;
+
+                let normalized_left_using_columns = normalized_left_key.to_columns()?;
+                let l_is_left = check_all_column_from_schema(
+                    &normalized_left_using_columns,
+                    self.plan.schema().clone(),
+                )?;
+
+                let normalized_right_using_columns = normalized_right_key.to_columns()?;
+                let r_is_right = check_all_column_from_schema(
+                    &normalized_right_using_columns,
+                    right.schema().clone(),
+                )?;
+
+                // Lazily checked: only needed when the pair is not already in
+                // (left input, right input) order.
+                let r_is_left_and_l_is_right = || {
+                    let result = check_all_column_from_schema(
+                        &normalized_right_using_columns,
+                        self.plan.schema().clone(),
+                    )? && check_all_column_from_schema(
+                        &normalized_left_using_columns,
+                        right.schema().clone(),
+                    )?;
+                    Result::Ok(result)
+                };
+
+                if l_is_left && r_is_right {
+                    Ok((normalized_left_key, normalized_right_key))
+                } else if r_is_left_and_l_is_right()? {
+                    // Keys were written in (right, left) order; swap them.
+                    Ok((normalized_right_key, normalized_left_key))
+                } else {
+                    Err(DataFusionError::Plan(format!(
+                        "can't create join plan, join key should belong to one input, error key: ({},{})",
+                        normalized_left_key, normalized_right_key
+                    )))
+                }
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let join_schema =
+            build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+
+        Ok(Self::from(LogicalPlan::Join(Join {
+            left: Arc::new(self.plan),
+            right: Arc::new(right),
+            on: join_key_pairs,
+            filter,
+            join_type,
+            join_constraint: JoinConstraint::On,
+            schema: DFSchemaRef::new(join_schema),
+            null_equals_null: false,
+        })))
+    }
 }
 
 /// Creates a schema for a join operation.
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index f1eb3f135..9368b1ec4 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -25,7 +25,9 @@ use crate::utils::{
     self, exprlist_to_fields, from_plan, grouping_set_expr_count,
     grouping_set_to_exprlist,
 };
-use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
+use crate::{
+    build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown, 
TableSource,
+};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{
     plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, 
OwnedTableReference,
@@ -252,7 +254,7 @@ impl LogicalPlan {
             }) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
             LogicalPlan::Join(Join { on, filter, .. }) => on
                 .iter()
-                .flat_map(|(l, r)| vec![Expr::Column(l.clone()), 
Expr::Column(r.clone())])
+                .flat_map(|(l, r)| vec![l.clone(), r.clone()])
                 .chain(
                     filter
                         .as_ref()
@@ -343,12 +345,14 @@ impl LogicalPlan {
                     ..
                 }) = plan
                 {
-                    self.using_columns.push(
-                        on.iter()
-                            .flat_map(|entry| [&entry.0, &entry.1])
-                            .cloned()
-                            .collect::<HashSet<Column>>(),
-                    );
+                    // The join keys in using-join must be columns.
+                    let columns =
+                        on.iter().try_fold(HashSet::new(), |mut accumu, (l, 
r)| {
+                            accumu.insert(l.try_into_col()?);
+                            accumu.insert(r.try_into_col()?);
+                            Result::<_, DataFusionError>::Ok(accumu)
+                        })?;
+                    self.using_columns.push(columns);
                 }
                 Ok(true)
             }
@@ -1646,8 +1650,8 @@ pub struct Join {
     pub left: Arc<LogicalPlan>,
     /// Right input
     pub right: Arc<LogicalPlan>,
-    /// Equijoin clause expressed as pairs of (left, right) join columns
-    pub on: Vec<(Column, Column)>,
+    /// Equijoin clause expressed as pairs of (left, right) join expressions
+    pub on: Vec<(Expr, Expr)>,
     /// Filters applied during join (non-equi conditions)
     pub filter: Option<Expr>,
     /// Join type
@@ -1660,6 +1664,41 @@ pub struct Join {
     pub null_equals_null: bool,
 }
 
+impl Join {
+    /// Create a `Join` with new `left`/`right` inputs and column-only equijoin
+    /// keys, reusing the join type, join constraint, filter and
+    /// `null_equals_null` setting of `original`.
+    ///
+    /// This helps create the physical join when the inputs have been wrapped in a
+    /// projection: `column_on` names the key columns as
+    /// `(left_columns, right_columns)`, paired by position. The join schema is
+    /// rebuilt from the new inputs.
+    ///
+    /// # Errors
+    ///
+    /// Returns a plan error if `original` is not a `LogicalPlan::Join` or if the
+    /// two column lists differ in length; join-schema errors are propagated.
+    pub fn try_new_with_project_input(
+        original: &LogicalPlan,
+        left: Arc<LogicalPlan>,
+        right: Arc<LogicalPlan>,
+        column_on: (Vec<Column>, Vec<Column>),
+    ) -> Result<Self, DataFusionError> {
+        let original_join = match original {
+            LogicalPlan::Join(join) => join,
+            _ => return plan_err!("Could not create join with project input"),
+        };
+
+        // `zip` below would silently drop unmatched keys and join on fewer
+        // columns than requested, so reject mismatched key lists up front.
+        if column_on.0.len() != column_on.1.len() {
+            return plan_err!("left_keys and right_keys were not the same length");
+        }
+
+        let on: Vec<(Expr, Expr)> = column_on
+            .0
+            .into_iter()
+            .zip(column_on.1.into_iter())
+            .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
+            .collect();
+        let join_schema =
+            build_join_schema(left.schema(), right.schema(), &original_join.join_type)?;
+
+        Ok(Join {
+            left,
+            right,
+            on,
+            filter: original_join.filter.clone(),
+            join_type: original_join.join_type,
+            join_constraint: original_join.join_constraint,
+            schema: Arc::new(join_schema),
+            null_equals_null: original_join.null_equals_null,
+        })
+    }
+}
+
 /// Subquery
 #[derive(Clone)]
 pub struct Subquery {
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs 
b/datafusion/optimizer/src/eliminate_cross_join.rs
index 338e58f68..6aaf02d26 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -19,18 +19,14 @@
 use std::collections::HashSet;
 use std::sync::Arc;
 
+use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::expr::{BinaryExpr, Expr};
 use datafusion_expr::logical_plan::{
     CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
 };
 use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
-use datafusion_expr::{
-    and, build_join_schema, or, wrap_projection_for_join_if_necessary, 
ExprSchemable,
-    Operator,
-};
-
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator};
 
 #[derive(Default)]
 pub struct EliminateCrossJoin;
@@ -142,11 +138,7 @@ fn flatten_join_inputs(
 ) -> Result<()> {
     let children = match plan {
         LogicalPlan::Join(join) => {
-            for join_keys in join.on.iter() {
-                let join_keys = join_keys.clone();
-                possible_join_keys
-                    .push((Expr::Column(join_keys.0), 
Expr::Column(join_keys.1)));
-            }
+            possible_join_keys.extend(join.on.clone());
             let left = &*(join.left);
             let right = &*(join.right);
             Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
@@ -239,26 +231,12 @@ fn find_inner_join(
                 &JoinType::Inner,
             )?);
 
-            // Wrap projection
-            let (left_on, right_on): (Vec<Expr>, Vec<Expr>) =
-                join_keys.into_iter().unzip();
-            let (new_left_input, new_left_on, _) =
-                wrap_projection_for_join_if_necessary(&left_on, 
left_input.clone())?;
-            let (new_right_input, new_right_on, _) =
-                wrap_projection_for_join_if_necessary(&right_on, right_input)?;
-
-            // Build new join on
-            let join_on = new_left_on
-                .into_iter()
-                .zip(new_right_on.into_iter())
-                .collect::<Vec<_>>();
-
             return Ok(LogicalPlan::Join(Join {
-                left: Arc::new(new_left_input),
-                right: Arc::new(new_right_input),
+                left: Arc::new(left_input.clone()),
+                right: Arc::new(right_input),
                 join_type: JoinType::Inner,
                 join_constraint: JoinConstraint::On,
-                on: join_on,
+                on: join_keys,
                 filter: None,
                 schema: join_schema,
                 null_equals_null: false,
@@ -1108,14 +1086,10 @@ mod tests {
             .build()?;
 
         let expected = vec![
-              "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32]",
-              "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
-              "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, 
c:UInt32, t2.a * UInt32(2):UInt32]",
-              "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
-              "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
-              "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
-              "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
-        ];
+            "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, 
a:UInt32, b:UInt32, c:UInt32]",
+            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"];
 
         assert_optimized_plan_eq(&plan, expected);
 
@@ -1167,13 +1141,10 @@ mod tests {
             .build()?;
 
         let expected = vec![
-               "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
-               "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
-               "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, 
c:UInt32, t2.a * UInt32(2):UInt32]",
-               "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
-               "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
-               "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
-               "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+            "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+            "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
         ];
 
         assert_optimized_plan_eq(&plan, expected);
@@ -1198,14 +1169,11 @@ mod tests {
             .build()?;
 
         let expected = vec![
-               "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
-               "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
-               "    Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32, 
c:UInt32, t2.a * UInt32(2):UInt32]",
-               "      Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
-               "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
-               "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
-               "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
-       ];
+        "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, 
c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+        "  Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+        "    TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+        "    TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+        ];
 
         assert_optimized_plan_eq(&plan, expected);
 
@@ -1240,15 +1208,11 @@ mod tests {
         let expected = vec![
             "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, 
c:UInt32]",
             "  Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, 
t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, 
b:UInt32, c:UInt32]",
-            "    Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32, 
a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
-            "      Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a + 
UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + 
UInt32(100):UInt32]",
-            "        Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32, 
c:UInt32, t3.a + UInt32(100):UInt32]",
-            "          Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2) 
[a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]",
-            "            TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
-            "          Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
-            "            TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
-            "      Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
-            "        TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+            "    Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, 
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, 
c:UInt32]",
+            "      Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) 
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+            "        TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+            "        TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
+            "      TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
         ];
 
         assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs 
b/datafusion/optimizer/src/filter_null_join_keys.rs
index 4bee88359..eea98ad1f 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -20,14 +20,12 @@
 //! and then insert an `IsNotNull` filter on the nullable side since null 
values
 //! can never match.
 
-use std::sync::Arc;
-
-use datafusion_common::{Column, DFField, DFSchemaRef, Result};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::Result;
 use datafusion_expr::{
-    and, logical_plan::Filter, logical_plan::JoinType, Expr, LogicalPlan,
+    and, logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, 
LogicalPlan,
 };
-
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use std::sync::Arc;
 
 /// The FilterNullJoinKeys rule will identify inner joins with equi-join 
conditions
 /// where the join key is nullable on one side and non-nullable on the other 
side
@@ -66,15 +64,12 @@ impl OptimizerRule for FilterNullJoinKeys {
                 let mut right_filters = vec![];
 
                 for (l, r) in &join.on {
-                    if let Some((left_field, right_field)) =
-                        resolve_join_key_pair(left_schema, right_schema, l, r)
-                    {
-                        if left_field.is_nullable() {
-                            left_filters.push(l.clone());
-                        }
-                        if right_field.is_nullable() {
-                            right_filters.push(r.clone());
-                        }
+                    if l.nullable(left_schema)? {
+                        left_filters.push(l.clone());
+                    }
+
+                    if r.nullable(right_schema)? {
+                        right_filters.push(r.clone());
                     }
                 }
 
@@ -106,10 +101,10 @@ impl OptimizerRule for FilterNullJoinKeys {
     }
 }
 
-fn create_not_null_predicate(columns: Vec<Column>) -> Expr {
-    let not_null_exprs: Vec<Expr> = columns
+fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
+    let not_null_exprs: Vec<Expr> = filters
         .into_iter()
-        .map(|c| Expr::IsNotNull(Box::new(Expr::Column(c))))
+        .map(|c| Expr::IsNotNull(Box::new(c)))
         .collect();
     // combine the IsNotNull expressions with AND
     not_null_exprs
@@ -118,42 +113,13 @@ fn create_not_null_predicate(columns: Vec<Column>) -> 
Expr {
         .fold(not_null_exprs[0].clone(), |a, b| and(a, b.clone()))
 }
 
-fn resolve_join_key_pair(
-    left_schema: &DFSchemaRef,
-    right_schema: &DFSchemaRef,
-    c1: &Column,
-    c2: &Column,
-) -> Option<(DFField, DFField)> {
-    resolve_fields(left_schema, right_schema, c1, c2)
-        .or_else(|| resolve_fields(left_schema, right_schema, c2, c1))
-}
-
-fn resolve_fields(
-    left_schema: &DFSchemaRef,
-    right_schema: &DFSchemaRef,
-    c1: &Column,
-    c2: &Column,
-) -> Option<(DFField, DFField)> {
-    match (
-        left_schema.index_of_column(c1),
-        right_schema.index_of_column(c2),
-    ) {
-        (Ok(left_index), Ok(right_index)) => {
-            let left_field = left_schema.field(left_index);
-            let right_field = right_schema.field(right_index);
-            Some((left_field.clone(), right_field.clone()))
-        }
-        _ => None,
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use arrow::datatypes::{DataType, Field, Schema};
 
     use datafusion_common::{Column, Result};
     use datafusion_expr::logical_plan::table_scan;
-    use datafusion_expr::{logical_plan::JoinType, LogicalPlanBuilder};
+    use datafusion_expr::{col, lit, logical_plan::JoinType, 
LogicalPlanBuilder};
 
     use crate::optimizer::OptimizerContext;
 
@@ -234,6 +200,74 @@ mod tests {
         Ok(())
     }
 
+    /// Only t1's join key is nullable, so only the left input gets an
+    /// `IS NOT NULL` filter on the key expression.
+    #[test]
+    fn left_nullable_expr_key() -> Result<()> {
+        let (t1, t2) = test_tables()?;
+        let plan = LogicalPlanBuilder::from(t1)
+            .join_with_expr_keys(
+                t2,
+                JoinType::Inner,
+                (
+                    vec![col("t1.optional_id") + lit(1u32)],
+                    vec![col("t2.id") + lit(1u32)],
+                ),
+                None,
+            )?
+            .build()?;
+        let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id + UInt32(1)\
+        \n  Filter: t1.optional_id + UInt32(1) IS NOT NULL\
+        \n    TableScan: t1\
+        \n  TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// Only t2's join key is nullable, so only the right input gets an
+    /// `IS NOT NULL` filter on the key expression.
+    #[test]
+    fn right_nullable_expr_key() -> Result<()> {
+        let (t1, t2) = test_tables()?;
+        let plan = LogicalPlanBuilder::from(t1)
+            .join_with_expr_keys(
+                t2,
+                JoinType::Inner,
+                (
+                    vec![col("t1.id") + lit(1u32)],
+                    vec![col("t2.optional_id") + lit(1u32)],
+                ),
+                None,
+            )?
+            .build()?;
+        let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id + UInt32(1)\
+        \n  TableScan: t1\
+        \n  Filter: t2.optional_id + UInt32(1) IS NOT NULL\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// Both join keys are nullable, so each input gets its own `IS NOT NULL`
+    /// filter on its key expression.
+    #[test]
+    fn both_side_nullable_expr_key() -> Result<()> {
+        let (t1, t2) = test_tables()?;
+        let plan = LogicalPlanBuilder::from(t1)
+            .join_with_expr_keys(
+                t2,
+                JoinType::Inner,
+                (
+                    vec![col("t1.optional_id") + lit(1u32)],
+                    vec![col("t2.optional_id") + lit(1u32)],
+                ),
+                None,
+            )?
+            .build()?;
+        let expected =
+            "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id + UInt32(1)\
+        \n  Filter: t1.optional_id + UInt32(1) IS NOT NULL\
+        \n    TableScan: t1\
+        \n  Filter: t2.optional_id + UInt32(1) IS NOT NULL\
+        \n    TableScan: t2";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
     fn build_plan(
         left_table: LogicalPlan,
         right_table: LogicalPlan,
diff --git a/datafusion/optimizer/src/push_down_filter.rs 
b/datafusion/optimizer/src/push_down_filter.rs
index 0cf6c635b..4ca14c278 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -457,8 +457,18 @@ fn push_down_join(
                     Err(e) => return Some(Err(e)),
                 };
 
+                // Only allow both side key is column.
+                let join_col_keys = join
+                    .on
+                    .iter()
+                    .flat_map(|(l, r)| match (l.try_into_col(), 
r.try_into_col()) {
+                        (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>();
+
                 for col in columns.iter() {
-                    for (l, r) in join.on.iter() {
+                    for (l, r) in join_col_keys.iter() {
                         if col == l {
                             join_cols_to_replace.insert(col, r);
                             break;
diff --git a/datafusion/optimizer/src/push_down_projection.rs 
b/datafusion/optimizer/src/push_down_projection.rs
index 01b089e42..fb8e241fd 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -156,8 +156,8 @@ fn optimize_plan(
             ..
         }) => {
             for (l, r) in on {
-                new_required_columns.insert(l.clone());
-                new_required_columns.insert(r.clone());
+                new_required_columns.extend(l.to_columns()?);
+                new_required_columns.extend(r.to_columns()?);
             }
 
             if let Some(expr) = filter {
diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs 
b/datafusion/optimizer/src/subquery_filter_to_join.rs
index 436d478b9..da7986956 100644
--- a/datafusion/optimizer/src/subquery_filter_to_join.rs
+++ b/datafusion/optimizer/src/subquery_filter_to_join.rs
@@ -132,7 +132,7 @@ impl OptimizerRule for SubqueryFilterToJoin {
                             Ok(LogicalPlan::Join(Join {
                                 left: Arc::new(input),
                                 right: Arc::new(right_input),
-                                on: vec![(left_key, right_key)],
+                                on: vec![(Expr::Column(left_key), 
Expr::Column(right_key))],
                                 filter: None,
                                 join_type,
                                 join_constraint: JoinConstraint::On,
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 97ba57a7e..0b0124007 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -257,8 +257,8 @@ message JoinNode {
   LogicalPlanNode right = 2;
   JoinType join_type = 3;
   JoinConstraint join_constraint = 4;
-  repeated datafusion.Column left_join_column = 5;
-  repeated datafusion.Column right_join_column = 6;
+  repeated LogicalExprNode left_join_key = 5;
+  repeated LogicalExprNode right_join_key = 6;
   bool null_equals_null = 7;
   LogicalExprNode filter = 8;
 }
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 13236a935..402048203 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -9591,10 +9591,10 @@ impl serde::Serialize for JoinNode {
         if self.join_constraint != 0 {
             len += 1;
         }
-        if !self.left_join_column.is_empty() {
+        if !self.left_join_key.is_empty() {
             len += 1;
         }
-        if !self.right_join_column.is_empty() {
+        if !self.right_join_key.is_empty() {
             len += 1;
         }
         if self.null_equals_null {
@@ -9620,11 +9620,11 @@ impl serde::Serialize for JoinNode {
                 .ok_or_else(|| serde::ser::Error::custom(format!("Invalid 
variant {}", self.join_constraint)))?;
             struct_ser.serialize_field("joinConstraint", &v)?;
         }
-        if !self.left_join_column.is_empty() {
-            struct_ser.serialize_field("leftJoinColumn", 
&self.left_join_column)?;
+        if !self.left_join_key.is_empty() {
+            struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?;
         }
-        if !self.right_join_column.is_empty() {
-            struct_ser.serialize_field("rightJoinColumn", 
&self.right_join_column)?;
+        if !self.right_join_key.is_empty() {
+            struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?;
         }
         if self.null_equals_null {
             struct_ser.serialize_field("nullEqualsNull", 
&self.null_equals_null)?;
@@ -9648,10 +9648,10 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
             "joinType",
             "join_constraint",
             "joinConstraint",
-            "left_join_column",
-            "leftJoinColumn",
-            "right_join_column",
-            "rightJoinColumn",
+            "left_join_key",
+            "leftJoinKey",
+            "right_join_key",
+            "rightJoinKey",
             "null_equals_null",
             "nullEqualsNull",
             "filter",
@@ -9663,8 +9663,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
             Right,
             JoinType,
             JoinConstraint,
-            LeftJoinColumn,
-            RightJoinColumn,
+            LeftJoinKey,
+            RightJoinKey,
             NullEqualsNull,
             Filter,
         }
@@ -9692,8 +9692,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                             "right" => Ok(GeneratedField::Right),
                             "joinType" | "join_type" => 
Ok(GeneratedField::JoinType),
                             "joinConstraint" | "join_constraint" => 
Ok(GeneratedField::JoinConstraint),
-                            "leftJoinColumn" | "left_join_column" => 
Ok(GeneratedField::LeftJoinColumn),
-                            "rightJoinColumn" | "right_join_column" => 
Ok(GeneratedField::RightJoinColumn),
+                            "leftJoinKey" | "left_join_key" => 
Ok(GeneratedField::LeftJoinKey),
+                            "rightJoinKey" | "right_join_key" => 
Ok(GeneratedField::RightJoinKey),
                             "nullEqualsNull" | "null_equals_null" => 
Ok(GeneratedField::NullEqualsNull),
                             "filter" => Ok(GeneratedField::Filter),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
@@ -9719,8 +9719,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                 let mut right__ = None;
                 let mut join_type__ = None;
                 let mut join_constraint__ = None;
-                let mut left_join_column__ = None;
-                let mut right_join_column__ = None;
+                let mut left_join_key__ = None;
+                let mut right_join_key__ = None;
                 let mut null_equals_null__ = None;
                 let mut filter__ = None;
                 while let Some(k) = map.next_key()? {
@@ -9749,17 +9749,17 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                             }
                             join_constraint__ = 
Some(map.next_value::<JoinConstraint>()? as i32);
                         }
-                        GeneratedField::LeftJoinColumn => {
-                            if left_join_column__.is_some() {
-                                return 
Err(serde::de::Error::duplicate_field("leftJoinColumn"));
+                        GeneratedField::LeftJoinKey => {
+                            if left_join_key__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("leftJoinKey"));
                             }
-                            left_join_column__ = Some(map.next_value()?);
+                            left_join_key__ = Some(map.next_value()?);
                         }
-                        GeneratedField::RightJoinColumn => {
-                            if right_join_column__.is_some() {
-                                return 
Err(serde::de::Error::duplicate_field("rightJoinColumn"));
+                        GeneratedField::RightJoinKey => {
+                            if right_join_key__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("rightJoinKey"));
                             }
-                            right_join_column__ = Some(map.next_value()?);
+                            right_join_key__ = Some(map.next_value()?);
                         }
                         GeneratedField::NullEqualsNull => {
                             if null_equals_null__.is_some() {
@@ -9780,8 +9780,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
                     right: right__,
                     join_type: join_type__.unwrap_or_default(),
                     join_constraint: join_constraint__.unwrap_or_default(),
-                    left_join_column: left_join_column__.unwrap_or_default(),
-                    right_join_column: right_join_column__.unwrap_or_default(),
+                    left_join_key: left_join_key__.unwrap_or_default(),
+                    right_join_key: right_join_key__.unwrap_or_default(),
                     null_equals_null: null_equals_null__.unwrap_or_default(),
                     filter: filter__,
                 })
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index 1405e1eba..18bf2796f 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -364,9 +364,9 @@ pub struct JoinNode {
     #[prost(enumeration = "JoinConstraint", tag = "4")]
     pub join_constraint: i32,
     #[prost(message, repeated, tag = "5")]
-    pub left_join_column: ::prost::alloc::vec::Vec<Column>,
+    pub left_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
     #[prost(message, repeated, tag = "6")]
-    pub right_join_column: ::prost::alloc::vec::Vec<Column>,
+    pub right_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
     #[prost(bool, tag = "7")]
     pub null_equals_null: bool,
     #[prost(message, optional, tag = "8")]
diff --git a/datafusion/proto/src/logical_plan.rs 
b/datafusion/proto/src/logical_plan.rs
index 09ae9e41e..f2000d7a7 100644
--- a/datafusion/proto/src/logical_plan.rs
+++ b/datafusion/proto/src/logical_plan.rs
@@ -38,7 +38,7 @@ use datafusion::{
     datasource::{provider_as_source, source_as_provider},
     prelude::SessionContext,
 };
-use datafusion_common::{context, Column, DataFusionError, OwnedTableReference};
+use datafusion_common::{context, DataFusionError, OwnedTableReference};
 use datafusion_expr::logical_plan::{builder::project, Prepare};
 use datafusion_expr::{
     logical_plan::{
@@ -695,10 +695,16 @@ impl AsLogicalPlan for LogicalPlanNode {
                 LogicalPlanBuilder::from(input).limit(skip, fetch)?.build()
             }
             LogicalPlanType::Join(join) => {
-                let left_keys: Vec<Column> =
-                    join.left_join_column.iter().map(|i| i.into()).collect();
-                let right_keys: Vec<Column> =
-                    join.right_join_column.iter().map(|i| i.into()).collect();
+                let left_keys: Vec<Expr> = join
+                    .left_join_key
+                    .iter()
+                    .map(|expr| parse_expr(expr, ctx))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let right_keys: Vec<Expr> = join
+                    .right_join_key
+                    .iter()
+                    .map(|expr| parse_expr(expr, ctx))
+                    .collect::<Result<Vec<_>, _>>()?;
                 let join_type =
                     protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| 
{
                         proto_error(format!(
@@ -727,17 +733,24 @@ impl AsLogicalPlan for LogicalPlanNode {
                     extension_codec
                 )?);
                 let builder = match join_constraint.into() {
-                    JoinConstraint::On => builder.join(
+                    JoinConstraint::On => builder.join_with_expr_keys(
                         into_logical_plan!(join.right, ctx, extension_codec)?,
                         join_type.into(),
                         (left_keys, right_keys),
                         filter,
                     )?,
-                    JoinConstraint::Using => builder.join_using(
-                        into_logical_plan!(join.right, ctx, extension_codec)?,
-                        join_type.into(),
-                        left_keys,
-                    )?,
+                    JoinConstraint::Using => {
+                        // The equijoin keys in using-join must be column.
+                        let using_keys = left_keys
+                            .into_iter()
+                            .map(|key| key.try_into_col())
+                            .collect::<Result<Vec<_>, _>>()?;
+                        builder.join_using(
+                            into_logical_plan!(join.right, ctx, 
extension_codec)?,
+                            join_type.into(),
+                            using_keys,
+                        )?
+                    }
                 };
 
                 builder.build()
@@ -1100,8 +1113,12 @@ impl AsLogicalPlan for LogicalPlanNode {
                         right.as_ref(),
                         extension_codec,
                     )?;
-                let (left_join_column, right_join_column) =
-                    on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
+                let (left_join_key, right_join_key) = on
+                    .iter()
+                    .map(|(l, r)| Ok((l.try_into()?, r.try_into()?)))
+                    .collect::<Result<Vec<_>, to_proto::Error>>()?
+                    .into_iter()
+                    .unzip();
                 let join_type: protobuf::JoinType = 
join_type.to_owned().into();
                 let join_constraint: protobuf::JoinConstraint =
                     join_constraint.to_owned().into();
@@ -1116,8 +1133,8 @@ impl AsLogicalPlan for LogicalPlanNode {
                             right: Some(Box::new(right)),
                             join_type: join_type.into(),
                             join_constraint: join_constraint.into(),
-                            left_join_column,
-                            right_join_column,
+                            left_join_key,
+                            right_join_key,
                             null_equals_null: *null_equals_null,
                             filter,
                         },
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 761e13fda..7f12e9dec 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -47,9 +47,7 @@ use datafusion_expr::expr::{
 };
 use datafusion_expr::expr_rewriter::normalize_col;
 use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::logical_plan::builder::{
-    project, wrap_projection_for_join_if_necessary,
-};
+use datafusion_expr::logical_plan::builder::project;
 use datafusion_expr::logical_plan::Join as HashJoin;
 use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint;
 use datafusion_expr::logical_plan::{
@@ -852,32 +850,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     }
                     join.build()
                 } else {
-                    // Wrap projection for left input if left join keys 
contain normal expression.
-                    let (left_child, left_join_keys, left_projected) =
-                        wrap_projection_for_join_if_necessary(&left_keys, 
left)?;
-
-                    // Wrap projection for right input if right join keys 
contains normal expression.
-                    let (right_child, right_join_keys, right_projected) =
-                        wrap_projection_for_join_if_necessary(&right_keys, 
right)?;
-
-                    let join_plan_builder = 
LogicalPlanBuilder::from(left_child).join(
-                        right_child,
-                        join_type,
-                        (left_join_keys, right_join_keys),
-                        join_filter,
-                    )?;
-
-                    // Remove temporary projected columns if necessary.
-                    if left_projected || right_projected {
-                        let final_join_result = join_schema
-                            .fields()
-                            .iter()
-                            .map(|field| 
Expr::Column(field.qualified_column()))
-                            .collect::<Vec<_>>();
-                        join_plan_builder.project(final_join_result)?.build()
-                    } else {
-                        join_plan_builder.build()
-                    }
+                    LogicalPlanBuilder::from(left)
+                        .join_with_expr_keys(
+                            right,
+                            join_type,
+                            (left_keys, right_keys),
+                            join_filter,
+                        )?
+                        .build()
                 }
             }
             JoinConstraint::Using(idents) => {
@@ -1045,8 +1025,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         {
             // For query: select id from t1 join t2 using(id), this is legal.
             // We should dedup the fields for cols in using clause.
-            for join_cols in on.iter() {
-                let left_field = 
left.schema().field_from_column(&join_cols.0)?;
+            for join_keys in on.iter() {
+                let join_col = &join_keys.0.try_into_col()?;
+                let left_field = left.schema().field_from_column(join_col)?;
                 fields.retain(|field| {
                     field.unqualified_column().name
                         != left_field.unqualified_column().name
@@ -6006,12 +5987,10 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + Int64(10) = orders.customer_id * 
Int64(2)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ Int64(10)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
+
         quick_test(sql, expected);
     }
 
@@ -6023,12 +6002,9 @@ mod tests {
             ON person.id + 10 = orders.customer_id * 2";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + Int64(10) = orders.customer_id * 
Int64(2)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ Int64(10)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6040,12 +6016,9 @@ mod tests {
             ON person.id + person.age + 10 = orders.customer_id * 2 - 
orders.price";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id * Int64(2) - orders.price\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ person.age + Int64(10)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2) - orders.price\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id * Int64(2) - orders.price\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6057,11 +6030,9 @@ mod tests {
             ON person.id + person.age + 10 = orders.customer_id";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ person.age + Int64(10)\
-        \n        TableScan: person\
-        \n      TableScan: orders";
+        \n  Inner Join: person.id + person.age + Int64(10) = 
orders.customer_id\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6073,11 +6044,9 @@ mod tests {
             ON person.id = orders.customer_id * 2 - orders.price";
 
         let expected = "Projection: person.id, orders.order_id\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id = orders.customer_id * Int64(2) - 
orders.price\
-        \n      TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2) - orders.price\
-        \n        TableScan: orders";
+       \n  Inner Join: person.id = orders.customer_id * Int64(2) - 
orders.price\
+       \n    TableScan: person\
+       \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6106,12 +6075,9 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: person.id, person.first_name, 
person.last_name, person.age, person.state, person.salary, person.birth_date, 
person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + Int64(10) = orders.customer_id * 
Int64(2)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ Int64(10)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6123,12 +6089,9 @@ mod tests {
             ON orders.customer_id * 2 = person.id + 10";
 
         let expected = "Projection: orders.customer_id * Int64(2), person.id + 
Int64(10)\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id + Int64(10) = orders.customer_id * 
Int64(2)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
+ Int64(10)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id * Int64(2)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6156,12 +6119,9 @@ mod tests {
             ON person.id * 2 = orders.customer_id + 10 and person.id * 2 = 
orders.order_id";
 
         let expected = "Projection: person.id, person.age\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: person.id * Int64(2) = orders.customer_id + 
Int64(10), person.id * Int64(2) = orders.order_id\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
* Int64(2)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id + Int64(10)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), 
person.id * Int64(2) = orders.order_id\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6174,11 +6134,9 @@ mod tests {
             ON person.id * 2 = orders.customer_id + 10 and person.id =  
orders.customer_id + 10";
 
         let expected = "Projection: person.id, person.age\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\n    Inner Join: person.id * Int64(2) = 
orders.customer_id + Int64(10), person.id = orders.customer_id + Int64(10)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, person.id 
* Int64(2)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
orders.customer_id + Int64(10)\
-        \n        TableScan: orders";
+        \n  Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10), 
person.id = orders.customer_id + Int64(10)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 
@@ -6596,12 +6554,9 @@ mod tests {
             ON cast(person.id as Int) = cast(orders.customer_id as Int)";
 
         let expected = "Projection: person.id, person.age\
-        \n  Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, 
orders.price, orders.delivered\
-        \n    Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id 
AS Int32)\
-        \n      Projection: person.id, person.first_name, person.last_name, 
person.age, person.state, person.salary, person.birth_date, person.😀, 
CAST(person.id AS Int32) AS CAST(person.id AS Int32)\
-        \n        TableScan: person\
-        \n      Projection: orders.order_id, orders.customer_id, 
orders.o_item_id, orders.qty, orders.price, orders.delivered, 
CAST(orders.customer_id AS Int32) AS CAST(orders.customer_id AS Int32)\
-        \n        TableScan: orders";
+        \n  Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS 
Int32)\
+        \n    TableScan: person\
+        \n    TableScan: orders";
         quick_test(sql, expected);
     }
 


Reply via email to