This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d594e6257b Relax join keys constraint from Column to any physical expression for physical join operators (#8991)
d594e6257b is described below

commit d594e6257b34a5ad47112e26d41516aaeb19e6dd
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jan 29 10:02:22 2024 -0800

    Relax join keys constraint from Column to any physical expression for physical join operators (#8991)
    
    * Relax SortMergeJoin join keys
    
    * More
    
    * More
    
    * More
    
    * More
    
    * Fix clippy
    
    * Fix more clippy
    
    * More
    
    * More
    
    * Fix
    
    * Fix
    
    * Use collect_columns
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/physical_optimizer/enforce_distribution.rs | 291 ++++++++++++---------
 .../core/src/physical_optimizer/enforce_sorting.rs |  19 +-
 .../core/src/physical_optimizer/join_selection.rs  |  79 +++---
 .../src/physical_optimizer/projection_pushdown.rs  |  49 +++-
 .../replace_with_order_preserving_variants.rs      |   2 +-
 datafusion/core/src/physical_planner.rs            |  18 +-
 datafusion/core/tests/fuzz_cases/join_fuzz.rs      |   8 +-
 datafusion/physical-expr/src/equivalence/class.rs  |  26 +-
 .../physical-expr/src/equivalence/properties.rs    |   7 +-
 datafusion/physical-plan/src/joins/hash_join.rs    | 197 +++++++-------
 .../physical-plan/src/joins/sort_merge_join.rs     | 152 +++++------
 .../physical-plan/src/joins/symmetric_hash_join.rs |  76 +++---
 datafusion/physical-plan/src/joins/test_utils.rs   |  12 +-
 datafusion/physical-plan/src/joins/utils.rs        | 177 +++++++++----
 datafusion/proto/proto/datafusion.proto            |   4 +-
 datafusion/proto/src/generated/prost.rs            |   4 +-
 datafusion/proto/src/physical_plan/mod.rs          |  73 ++++--
 .../proto/tests/cases/roundtrip_physical_plan.rs   |   8 +-
 18 files changed, 691 insertions(+), 511 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs 
b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index 0c5c2d78b6..fab26c49c2 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{Column, NoOp};
 use datafusion_physical_expr::utils::map_columns_before_projection;
 use datafusion_physical_expr::{
     physical_exprs_equal, EquivalenceProperties, LexRequirementRef, 
PhysicalExpr,
-    PhysicalSortRequirement,
+    PhysicalExprRef, PhysicalSortRequirement,
 };
 use datafusion_physical_plan::sorts::sort::SortExec;
 use datafusion_physical_plan::unbounded_output;
@@ -285,19 +285,21 @@ fn adjust_input_keys_ordering(
     {
         match mode {
             PartitionMode::Partitioned => {
-                let join_constructor =
-                    |new_conditions: (Vec<(Column, Column)>, 
Vec<SortOptions>)| {
-                        HashJoinExec::try_new(
-                            left.clone(),
-                            right.clone(),
-                            new_conditions.0,
-                            filter.clone(),
-                            join_type,
-                            PartitionMode::Partitioned,
-                            *null_equals_null,
-                        )
-                        .map(|e| Arc::new(e) as _)
-                    };
+                let join_constructor = |new_conditions: (
+                    Vec<(PhysicalExprRef, PhysicalExprRef)>,
+                    Vec<SortOptions>,
+                )| {
+                    HashJoinExec::try_new(
+                        left.clone(),
+                        right.clone(),
+                        new_conditions.0,
+                        filter.clone(),
+                        join_type,
+                        PartitionMode::Partitioned,
+                        *null_equals_null,
+                    )
+                    .map(|e| Arc::new(e) as _)
+                };
                 return reorder_partitioned_join_keys(
                     requirements,
                     on,
@@ -346,18 +348,20 @@ fn adjust_input_keys_ordering(
         ..
     }) = plan.as_any().downcast_ref::<SortMergeJoinExec>()
     {
-        let join_constructor =
-            |new_conditions: (Vec<(Column, Column)>, Vec<SortOptions>)| {
-                SortMergeJoinExec::try_new(
-                    left.clone(),
-                    right.clone(),
-                    new_conditions.0,
-                    *join_type,
-                    new_conditions.1,
-                    *null_equals_null,
-                )
-                .map(|e| Arc::new(e) as _)
-            };
+        let join_constructor = |new_conditions: (
+            Vec<(PhysicalExprRef, PhysicalExprRef)>,
+            Vec<SortOptions>,
+        )| {
+            SortMergeJoinExec::try_new(
+                left.clone(),
+                right.clone(),
+                new_conditions.0,
+                *join_type,
+                new_conditions.1,
+                *null_equals_null,
+            )
+            .map(|e| Arc::new(e) as _)
+        };
         return reorder_partitioned_join_keys(
             requirements,
             on,
@@ -408,12 +412,14 @@ fn adjust_input_keys_ordering(
 
 fn reorder_partitioned_join_keys<F>(
     mut join_plan: PlanWithKeyRequirements,
-    on: &[(Column, Column)],
+    on: &[(PhysicalExprRef, PhysicalExprRef)],
     sort_options: Vec<SortOptions>,
     join_constructor: &F,
 ) -> Result<PlanWithKeyRequirements>
 where
-    F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn 
ExecutionPlan>>,
+    F: Fn(
+        (Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec<SortOptions>),
+    ) -> Result<Arc<dyn ExecutionPlan>>,
 {
     let parent_required = &join_plan.data;
     let join_key_pairs = extract_join_keys(on);
@@ -788,10 +794,10 @@ fn expected_expr_positions(
     Some(indexes)
 }
 
-fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs {
+fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) -> 
JoinKeyPairs {
     let (left_keys, right_keys) = on
         .iter()
-        .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
+        .map(|(l, r)| (l.clone() as _, r.clone() as _))
         .unzip();
     JoinKeyPairs {
         left_keys,
@@ -802,16 +808,11 @@ fn extract_join_keys(on: &[(Column, Column)]) -> 
JoinKeyPairs {
 fn new_join_conditions(
     new_left_keys: &[Arc<dyn PhysicalExpr>],
     new_right_keys: &[Arc<dyn PhysicalExpr>],
-) -> Vec<(Column, Column)> {
+) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
     new_left_keys
         .iter()
         .zip(new_right_keys.iter())
-        .map(|(l_key, r_key)| {
-            (
-                l_key.as_any().downcast_ref::<Column>().unwrap().clone(),
-                r_key.as_any().downcast_ref::<Column>().unwrap().clone(),
-            )
-        })
+        .map(|(l_key, r_key)| (l_key.clone(), r_key.clone()))
         .collect()
 }
 
@@ -1886,8 +1887,8 @@ pub(crate) mod tests {
 
         // Join on (a == b1)
         let join_on = vec![(
-            Column::new_with_schema("a", &schema()).unwrap(),
-            Column::new_with_schema("b1", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
 
         for join_type in join_types {
@@ -1905,8 +1906,9 @@ pub(crate) mod tests {
                 | JoinType::LeftAnti => {
                     // Join on (a == c)
                     let top_join_on = vec![(
-                        Column::new_with_schema("a", &join.schema()).unwrap(),
-                        Column::new_with_schema("c", &schema()).unwrap(),
+                        Arc::new(Column::new_with_schema("a", 
&join.schema()).unwrap())
+                            as _,
+                        Arc::new(Column::new_with_schema("c", 
&schema()).unwrap()) as _,
                     )];
                     let top_join = hash_join_exec(
                         join.clone(),
@@ -1966,8 +1968,9 @@ pub(crate) mod tests {
                     // This time we use (b1 == c) for top join
                     // Join on (b1 == c)
                     let top_join_on = vec![(
-                        Column::new_with_schema("b1", &join.schema()).unwrap(),
-                        Column::new_with_schema("c", &schema()).unwrap(),
+                        Arc::new(Column::new_with_schema("b1", 
&join.schema()).unwrap())
+                            as _,
+                        Arc::new(Column::new_with_schema("c", 
&schema()).unwrap()) as _,
                     )];
 
                     let top_join =
@@ -2031,8 +2034,8 @@ pub(crate) mod tests {
 
         // Join on (a == b)
         let join_on = vec![(
-            Column::new_with_schema("a", &schema()).unwrap(),
-            Column::new_with_schema("b", &schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _,
         )];
         let join = hash_join_exec(left, right.clone(), &join_on, 
&JoinType::Inner);
 
@@ -2045,8 +2048,8 @@ pub(crate) mod tests {
 
         // Join on (a1 == c)
         let top_join_on = vec![(
-            Column::new_with_schema("a1", &projection.schema()).unwrap(),
-            Column::new_with_schema("c", &schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a1", 
&projection.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
         )];
 
         let top_join = hash_join_exec(
@@ -2076,8 +2079,8 @@ pub(crate) mod tests {
 
         // Join on (a2 == c)
         let top_join_on = vec![(
-            Column::new_with_schema("a2", &projection.schema()).unwrap(),
-            Column::new_with_schema("c", &schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a2", 
&projection.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
         )];
 
         let top_join = hash_join_exec(projection, right, &top_join_on, 
&JoinType::Inner);
@@ -2110,8 +2113,8 @@ pub(crate) mod tests {
 
         // Join on (a == b)
         let join_on = vec![(
-            Column::new_with_schema("a", &schema()).unwrap(),
-            Column::new_with_schema("b", &schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _,
         )];
 
         let join = hash_join_exec(left, right.clone(), &join_on, 
&JoinType::Inner);
@@ -2128,8 +2131,8 @@ pub(crate) mod tests {
 
         // Join on (a == c)
         let top_join_on = vec![(
-            Column::new_with_schema("a", &projection2.schema()).unwrap(),
-            Column::new_with_schema("c", &schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", 
&projection2.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
         )];
 
         let top_join = hash_join_exec(projection2, right, &top_join_on, 
&JoinType::Inner);
@@ -2174,8 +2177,8 @@ pub(crate) mod tests {
 
         // Join on (a1 == a2)
         let join_on = vec![(
-            Column::new_with_schema("a1", &left.schema()).unwrap(),
-            Column::new_with_schema("a2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("a2", &right.schema()).unwrap()) 
as _,
         )];
         let join = hash_join_exec(left, right.clone(), &join_on, 
&JoinType::Inner);
 
@@ -2221,12 +2224,12 @@ pub(crate) mod tests {
         // Join on (b1 == b && a1 == a)
         let join_on = vec![
             (
-                Column::new_with_schema("b1", &left.schema()).unwrap(),
-                Column::new_with_schema("b", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b1", 
&left.schema()).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("a1", &left.schema()).unwrap(),
-                Column::new_with_schema("a", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a1", 
&left.schema()).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let join = hash_join_exec(left, right.clone(), &join_on, 
&JoinType::Inner);
@@ -2265,16 +2268,16 @@ pub(crate) mod tests {
         // Join on (a == a1 and b == b1 and c == c1)
         let join_on = vec![
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("c", &schema()).unwrap(),
-                Column::new_with_schema("c1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("c1", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let bottom_left_join =
@@ -2293,16 +2296,16 @@ pub(crate) mod tests {
         // Join on (c == c1 and b == b1 and a == a1)
         let join_on = vec![
             (
-                Column::new_with_schema("c", &schema()).unwrap(),
-                Column::new_with_schema("c1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("c1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let bottom_right_join =
@@ -2311,16 +2314,31 @@ pub(crate) mod tests {
         // Join on (B == b1 and C == c and AA = a1)
         let top_join_on = vec![
             (
-                Column::new_with_schema("B", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("B", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("C", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("C", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("AA", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("AA", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
         ];
 
@@ -2382,16 +2400,16 @@ pub(crate) mod tests {
         // Join on (a == a1 and b == b1 and c == c1)
         let join_on = vec![
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("c", &schema()).unwrap(),
-                Column::new_with_schema("c1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("c1", 
&right.schema()).unwrap()) as _,
             ),
         ];
 
@@ -2414,16 +2432,16 @@ pub(crate) mod tests {
         // Join on (c == c1 and b == b1 and a == a1)
         let join_on = vec![
             (
-                Column::new_with_schema("c", &schema()).unwrap(),
-                Column::new_with_schema("c1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("c1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let bottom_right_join = ensure_distribution_helper(
@@ -2435,16 +2453,31 @@ pub(crate) mod tests {
         // Join on (B == b1 and C == c and AA = a1)
         let top_join_on = vec![
             (
-                Column::new_with_schema("B", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("B", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("C", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("C", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("AA", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("AA", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
         ];
 
@@ -2512,12 +2545,12 @@ pub(crate) mod tests {
         // Join on (a == a1 and b == b1)
         let join_on = vec![
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let bottom_left_join = ensure_distribution_helper(
@@ -2539,16 +2572,16 @@ pub(crate) mod tests {
         // Join on (c == c1 and b == b1 and a == a1)
         let join_on = vec![
             (
-                Column::new_with_schema("c", &schema()).unwrap(),
-                Column::new_with_schema("c1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("c1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema()).unwrap(),
-                Column::new_with_schema("b1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("b1", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("a", &schema()).unwrap(),
-                Column::new_with_schema("a1", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as 
_,
+                Arc::new(Column::new_with_schema("a1", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let bottom_right_join = ensure_distribution_helper(
@@ -2560,16 +2593,31 @@ pub(crate) mod tests {
         // Join on (B == b1 and C == c and AA = a1)
         let top_join_on = vec![
             (
-                Column::new_with_schema("B", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("B", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("b1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("C", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("C", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("c", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
             (
-                Column::new_with_schema("AA", 
&bottom_left_projection.schema()).unwrap(),
-                Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("AA", 
&bottom_left_projection.schema())
+                        .unwrap(),
+                ) as _,
+                Arc::new(
+                    Column::new_with_schema("a1", 
&bottom_right_join.schema()).unwrap(),
+                ) as _,
             ),
         ];
 
@@ -2648,8 +2696,8 @@ pub(crate) mod tests {
 
         // Join on (a == b1)
         let join_on = vec![(
-            Column::new_with_schema("a", &schema()).unwrap(),
-            Column::new_with_schema("b1", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
 
         for join_type in join_types {
@@ -2660,8 +2708,8 @@ pub(crate) mod tests {
 
             // Top join on (a == c)
             let top_join_on = vec![(
-                Column::new_with_schema("a", &join.schema()).unwrap(),
-                Column::new_with_schema("c", &schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a", 
&join.schema()).unwrap()) as _,
+                Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as 
_,
             )];
             let top_join = sort_merge_join_exec(
                 join.clone(),
@@ -2783,8 +2831,9 @@ pub(crate) mod tests {
                     // This time we use (b1 == c) for top join
                     // Join on (b1 == c)
                     let top_join_on = vec![(
-                        Column::new_with_schema("b1", &join.schema()).unwrap(),
-                        Column::new_with_schema("c", &schema()).unwrap(),
+                        Arc::new(Column::new_with_schema("b1", 
&join.schema()).unwrap())
+                            as _,
+                        Arc::new(Column::new_with_schema("c", 
&schema()).unwrap()) as _,
                     )];
                     let top_join = sort_merge_join_exec(
                         join,
@@ -2933,12 +2982,12 @@ pub(crate) mod tests {
         // Join on (b3 == b2 && a3 == a2)
         let join_on = vec![
             (
-                Column::new_with_schema("b3", &left.schema()).unwrap(),
-                Column::new_with_schema("b2", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("b3", 
&left.schema()).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b2", 
&right.schema()).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("a3", &left.schema()).unwrap(),
-                Column::new_with_schema("a2", &right.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("a3", 
&left.schema()).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a2", 
&right.schema()).unwrap()) as _,
             ),
         ];
         let join = sort_merge_join_exec(left, right.clone(), &join_on, 
&JoinType::Inner);
diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs 
b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
index 3aa9cdad18..5c46e64a22 100644
--- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -985,8 +985,8 @@ mod tests {
         let right_input = parquet_exec_sorted(&right_schema, 
parquet_sort_exprs);
 
         let on = vec![(
-            Column::new_with_schema("col_a", &left_schema)?,
-            Column::new_with_schema("c", &right_schema)?,
+            Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _,
+            Arc::new(Column::new_with_schema("c", &right_schema)?) as _,
         )];
         let join = hash_join_exec(left_input, right_input, on, None, 
&JoinType::Inner)?;
         let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], 
join);
@@ -1639,8 +1639,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", 
&left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", 
&right.schema()).unwrap()) as _,
         )];
 
         let join_types = vec![
@@ -1711,8 +1712,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", 
&left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", 
&right.schema()).unwrap()) as _,
         )];
 
         let join_types = vec![
@@ -1785,8 +1787,9 @@ mod tests {
 
         // Join on (nullable_col == col_a)
         let join_on = vec![(
-            Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
-            Column::new_with_schema("col_a", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("nullable_col", 
&left.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("col_a", 
&right.schema()).unwrap()) as _,
         )];
 
         let join = sort_merge_join_exec(left, right, &join_on, 
&JoinType::Inner);
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs 
b/datafusion/core/src/physical_optimizer/join_selection.rs
index 083cd5ecab..02626056f6 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -690,7 +690,7 @@ mod tests_statistical {
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::{stats::Precision, JoinType, ScalarValue};
     use datafusion_physical_expr::expressions::Column;
-    use datafusion_physical_expr::PhysicalExpr;
+    use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
 
     /// Return statistcs for empty table
     fn empty_statistics() -> Statistics {
@@ -860,8 +860,10 @@ mod tests_statistical {
                 Arc::clone(&big),
                 Arc::clone(&small),
                 vec![(
-                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                    Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
+                    Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()),
+                    Arc::new(
+                        Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
+                    ),
                 )],
                 None,
                 &JoinType::Left,
@@ -914,8 +916,10 @@ mod tests_statistical {
                 Arc::clone(&small),
                 Arc::clone(&big),
                 vec![(
-                    Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
-                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                    Arc::new(
+                        Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
+                    ),
+                    Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()),
                 )],
                 None,
                 &JoinType::Left,
@@ -970,8 +974,13 @@ mod tests_statistical {
                     Arc::clone(&big),
                     Arc::clone(&small),
                     vec![(
-                        Column::new_with_schema("big_col", 
&big.schema()).unwrap(),
-                        Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
+                        Arc::new(
+                            Column::new_with_schema("big_col", 
&big.schema()).unwrap(),
+                        ),
+                        Arc::new(
+                            Column::new_with_schema("small_col", 
&small.schema())
+                                .unwrap(),
+                        ),
                     )],
                     None,
                     &join_type,
@@ -1040,8 +1049,8 @@ mod tests_statistical {
             Arc::clone(&big),
             Arc::clone(&small),
             vec![(
-                Column::new_with_schema("big_col", &big.schema()).unwrap(),
-                Column::new_with_schema("small_col", &small.schema()).unwrap(),
+                Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()),
+                Arc::new(Column::new_with_schema("small_col", 
&small.schema()).unwrap()),
             )],
             None,
             &JoinType::Inner,
@@ -1056,8 +1065,10 @@ mod tests_statistical {
             Arc::clone(&medium),
             Arc::new(child_join),
             vec![(
-                Column::new_with_schema("medium_col", 
&medium.schema()).unwrap(),
-                Column::new_with_schema("small_col", &child_schema).unwrap(),
+                Arc::new(
+                    Column::new_with_schema("medium_col", 
&medium.schema()).unwrap(),
+                ),
+                Arc::new(Column::new_with_schema("small_col", 
&child_schema).unwrap()),
             )],
             None,
             &JoinType::Left,
@@ -1094,8 +1105,10 @@ mod tests_statistical {
                 Arc::clone(&small),
                 Arc::clone(&big),
                 vec![(
-                    Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
-                    Column::new_with_schema("big_col", &big.schema()).unwrap(),
+                    Arc::new(
+                        Column::new_with_schema("small_col", 
&small.schema()).unwrap(),
+                    ),
+                    Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()),
                 )],
                 None,
                 &JoinType::Inner,
@@ -1178,8 +1191,8 @@ mod tests_statistical {
         ));
 
         let join_on = vec![(
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("small_col", 
&small.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             small.clone(),
@@ -1190,8 +1203,8 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("small_col", 
&small.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             big.clone(),
@@ -1202,8 +1215,8 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("small_col", 
&small.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("empty_col", 
&empty.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             small.clone(),
@@ -1214,8 +1227,8 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-            Column::new_with_schema("small_col", &small.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("empty_col", 
&empty.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("small_col", 
&small.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             empty.clone(),
@@ -1244,8 +1257,9 @@ mod tests_statistical {
         ));
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("bigger_col", 
&bigger.schema()).unwrap())
+                as _,
         )];
         check_join_partition_mode(
             big.clone(),
@@ -1256,8 +1270,9 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("bigger_col", 
&bigger.schema()).unwrap())
+                as _,
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             bigger.clone(),
@@ -1268,8 +1283,8 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("empty_col", 
&empty.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(
             empty.clone(),
@@ -1280,8 +1295,8 @@ mod tests_statistical {
         );
 
         let join_on = vec![(
-            Column::new_with_schema("big_col", &big.schema()).unwrap(),
-            Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("big_col", 
&big.schema()).unwrap()) as _,
+            Arc::new(Column::new_with_schema("empty_col", 
&empty.schema()).unwrap()) as _,
         )];
         check_join_partition_mode(big, empty, join_on, false, 
PartitionMode::Partitioned);
     }
@@ -1289,7 +1304,7 @@ mod tests_statistical {
     fn check_join_partition_mode(
         left: Arc<StatisticsExec>,
         right: Arc<StatisticsExec>,
-        on: Vec<(Column, Column)>,
+        on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
         is_swapped: bool,
         expected_mode: PartitionMode,
     ) {
@@ -1748,8 +1763,8 @@ mod hash_join_tests {
             Arc::clone(&left_exec),
             Arc::clone(&right_exec),
             vec![(
-                Column::new_with_schema("a", &left_exec.schema())?,
-                Column::new_with_schema("b", &right_exec.schema())?,
+                Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
+                Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
             )],
             None,
             &t.initial_join_type,
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs 
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 2d20c487e4..301a97bba4 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
 use arrow_schema::SchemaRef;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
-use datafusion_common::JoinSide;
+use datafusion_common::{DataFusionError, JoinSide};
 use datafusion_physical_expr::expressions::{Column, Literal};
 use datafusion_physical_expr::{
-    Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+    Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+    PhysicalSortRequirement,
 };
 use datafusion_physical_plan::streaming::StreamingTableExec;
 use datafusion_physical_plan::union::UnionExec;
@@ -1000,8 +1001,8 @@ fn join_table_borders(
 fn update_join_on(
     proj_left_exprs: &[(Column, String)],
     proj_right_exprs: &[(Column, String)],
-    hash_join_on: &[(Column, Column)],
-) -> Option<Vec<(Column, Column)>> {
+    hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
+) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
     // TODO: Clippy wants the "map" call removed, but doing so generates
     //       a compilation error. Remove the clippy directive once this
     //       issue is fixed.
@@ -1024,17 +1025,41 @@ fn update_join_on(
 /// operation based on a set of equi-join conditions (`hash_join_on`) and a
 /// list of projection expressions (`projection_exprs`).
 fn new_columns_for_join_on(
-    hash_join_on: &[&Column],
+    hash_join_on: &[&PhysicalExprRef],
     projection_exprs: &[(Column, String)],
-) -> Option<Vec<Column>> {
+) -> Option<Vec<PhysicalExprRef>> {
     let new_columns = hash_join_on
         .iter()
         .filter_map(|on| {
-            projection_exprs
-                .iter()
-                .enumerate()
-                .find(|(_, (proj_column, _))| on.name() == proj_column.name())
-                .map(|(index, (_, alias))| Column::new(alias, index))
+            // Rewrite all columns in `on`
+            (*on)
+                .clone()
+                .transform(&|expr| {
+                    if let Some(column) = 
expr.as_any().downcast_ref::<Column>() {
+                        // Find the column in the projection expressions
+                        let new_column = projection_exprs
+                            .iter()
+                            .enumerate()
+                            .find(|(_, (proj_column, _))| {
+                                column.name() == proj_column.name()
+                            })
+                            .map(|(index, (_, alias))| Column::new(alias, 
index));
+                        if let Some(new_column) = new_column {
+                            Ok(Transformed::Yes(Arc::new(new_column)))
+                        } else {
+                            // If the column is not found in the projection 
expressions,
+                            // it means that the column is not projected. In 
this case,
+                            // we cannot push the projection down.
+                            Err(DataFusionError::Internal(format!(
+                                "Column {:?} not found in projection 
expressions",
+                                column
+                            )))
+                        }
+                    } else {
+                        Ok(Transformed::No(expr))
+                    }
+                })
+                .ok()
         })
         .collect::<Vec<_>>();
     (new_columns.len() == hash_join_on.len()).then_some(new_columns)
@@ -2018,7 +2043,7 @@ mod tests {
         let join: Arc<dyn ExecutionPlan> = 
Arc::new(SymmetricHashJoinExec::try_new(
             left_csv,
             right_csv,
-            vec![(Column::new("b", 1), Column::new("c", 2))],
+            vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c", 
2)))],
             // b_left-(1+a_right)<=a_right+c_left
             Some(JoinFilter::new(
                 Arc::new(BinaryExpr::new(
diff --git 
a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
 
b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
index 4656b5b270..bc9bd0010d 100644
--- 
a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
+++ 
b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
@@ -1440,7 +1440,7 @@ mod tests {
             HashJoinExec::try_new(
                 left,
                 right,
-                vec![(left_col.clone(), right_col.clone())],
+                vec![(Arc::new(left_col.clone()), 
Arc::new(right_col.clone()))],
                 None,
                 &JoinType::Inner,
                 PartitionMode::Partitioned,
diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index d383ddce92..d4ef40493d 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner {
                     let [physical_left, physical_right]: [Arc<dyn 
ExecutionPlan>; 2] = left_right.try_into().map_err(|_| 
DataFusionError::Internal("`create_initial_plan_multi` is 
broken".to_string()))?;
                     let left_df_schema = left.schema();
                     let right_df_schema = right.schema();
+                    let execution_props = session_state.execution_props();
                     let join_on = keys
                         .iter()
                         .map(|(l, r)| {
-                            let l = l.try_into_col()?;
-                            let r = r.try_into_col()?;
-                            Ok((
-                                Column::new(&l.name, 
left_df_schema.index_of_column(&l)?),
-                                Column::new(&r.name, 
right_df_schema.index_of_column(&r)?),
-                            ))
+                            let l = create_physical_expr(
+                                l,
+                                left_df_schema,
+                                execution_props
+                            )?;
+                            let r = create_physical_expr(
+                                r,
+                                right_df_schema,
+                                execution_props
+                            )?;
+                            Ok((l, r))
                         })
                         .collect::<Result<join_utils::JoinOn>>()?;
 
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index ac86364f42..1c819ac466 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -109,12 +109,12 @@ async fn run_join_test(
         let schema2 = input2[0].schema();
         let on_columns = vec![
             (
-                Column::new_with_schema("a", &schema1).unwrap(),
-                Column::new_with_schema("a", &schema2).unwrap(),
+                Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
             ),
             (
-                Column::new_with_schema("b", &schema1).unwrap(),
-                Column::new_with_schema("b", &schema2).unwrap(),
+                Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+                Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
             ),
         ];
 
diff --git a/datafusion/physical-expr/src/equivalence/class.rs 
b/datafusion/physical-expr/src/equivalence/class.rs
index f0bd1740d5..1f79701871 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -19,7 +19,7 @@ use super::{add_offset_to_expr, collapse_lex_req, 
ProjectionMapping};
 use crate::{
     expressions::Column, physical_expr::deduplicate_physical_exprs,
     physical_exprs_bag_equal, physical_exprs_contains, LexOrdering, 
LexOrderingRef,
-    LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr,
+    LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, 
PhysicalSortExpr,
     PhysicalSortRequirement,
 };
 use datafusion_common::tree_node::TreeNode;
@@ -427,7 +427,7 @@ impl EquivalenceGroup {
         right_equivalences: &Self,
         join_type: &JoinType,
         left_size: usize,
-        on: &[(Column, Column)],
+        on: &[(PhysicalExprRef, PhysicalExprRef)],
     ) -> Self {
         match join_type {
             JoinType::Inner | JoinType::Left | JoinType::Full | 
JoinType::Right => {
@@ -445,9 +445,25 @@ impl EquivalenceGroup {
                 // are equal in the resulting table.
                 if join_type == &JoinType::Inner {
                     for (lhs, rhs) in on.iter() {
-                        let index = rhs.index() + left_size;
-                        let new_lhs = Arc::new(lhs.clone()) as _;
-                        let new_rhs = Arc::new(Column::new(rhs.name(), index)) 
as _;
+                        let new_lhs = lhs.clone() as _;
+                        // Rewrite rhs to point to the right side of the join:
+                        let new_rhs = rhs
+                            .clone()
+                            .transform(&|expr| {
+                                if let Some(column) =
+                                    expr.as_any().downcast_ref::<Column>()
+                                {
+                                    let new_column = Arc::new(Column::new(
+                                        column.name(),
+                                        column.index() + left_size,
+                                    ))
+                                        as _;
+                                    return Ok(Transformed::Yes(new_column));
+                                }
+
+                                Ok(Transformed::No(expr))
+                            })
+                            .unwrap();
                         result.add_equal_conditions(&new_lhs, &new_rhs);
                     }
                 }
diff --git a/datafusion/physical-expr/src/equivalence/properties.rs 
b/datafusion/physical-expr/src/equivalence/properties.rs
index cd0ae09a92..2471d9249e 100644
--- a/datafusion/physical-expr/src/equivalence/properties.rs
+++ b/datafusion/physical-expr/src/equivalence/properties.rs
@@ -23,11 +23,12 @@ use super::ordering::collapse_lex_ordering;
 use crate::equivalence::{
     collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, 
ProjectionMapping,
 };
-use crate::expressions::{Column, Literal};
+use crate::expressions::Literal;
 use crate::sort_properties::{ExprOrdering, SortProperties};
 use crate::{
     physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement,
-    LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+    LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+    PhysicalSortRequirement,
 };
 
 use arrow_schema::SchemaRef;
@@ -1099,7 +1100,7 @@ pub fn join_equivalence_properties(
     join_schema: SchemaRef,
     maintains_input_order: &[bool],
     probe_side: Option<JoinSide>,
-    on: &[(Column, Column)],
+    on: &[(PhysicalExprRef, PhysicalExprRef)],
 ) -> EquivalenceProperties {
     let left_size = left.schema.fields.len();
     let mut result = EquivalenceProperties::new(join_schema);
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs 
b/datafusion/physical-plan/src/joins/hash_join.rs
index 0c213f4257..cd8b17d135 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -30,7 +30,6 @@ use crate::joins::utils::{
 };
 use crate::{
     coalesce_partitions::CoalescePartitionsExec,
-    expressions::Column,
     expressions::PhysicalSortExpr,
     hash_utils::create_hashes,
     joins::utils::{
@@ -39,8 +38,8 @@ use crate::{
         BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn, 
StatefulStreamResult,
     },
     metrics::{ExecutionPlanMetricsSet, MetricsSet},
-    DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
-    RecordBatchStream, SendableRecordBatchStream, Statistics,
+    DisplayFormatType, Distribution, ExecutionPlan, Partitioning, 
RecordBatchStream,
+    SendableRecordBatchStream, Statistics,
 };
 use crate::{handle_state, DisplayAs};
 
@@ -67,7 +66,7 @@ use datafusion_common::{
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
-use datafusion_physical_expr::EquivalenceProperties;
+use datafusion_physical_expr::{EquivalenceProperties, PhysicalExprRef};
 
 use ahash::RandomState;
 use futures::{ready, Stream, StreamExt, TryStreamExt};
@@ -278,7 +277,7 @@ pub struct HashJoinExec {
     /// right (probe) side which are filtered by the hash table
     pub right: Arc<dyn ExecutionPlan>,
     /// Set of equijoin columns from the relations: `(left_col, right_col)`
-    pub on: Vec<(Column, Column)>,
+    pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
     /// Filters which are applied while finding matching rows
     pub filter: Option<JoinFilter>,
     /// How the join is performed (`OUTER`, `INNER`, etc)
@@ -369,7 +368,7 @@ impl HashJoinExec {
     }
 
     /// Set of common columns used to join on
-    pub fn on(&self) -> &[(Column, Column)] {
+    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
         &self.on
     }
 
@@ -451,16 +450,8 @@ impl ExecutionPlan for HashJoinExec {
                 Distribution::UnspecifiedDistribution,
             ],
             PartitionMode::Partitioned => {
-                let (left_expr, right_expr) = self
-                    .on
-                    .iter()
-                    .map(|(l, r)| {
-                        (
-                            Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
-                            Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
-                        )
-                    })
-                    .unzip();
+                let (left_expr, right_expr) =
+                    self.on.iter().map(|(l, r)| (l.clone(), 
r.clone())).unzip();
                 vec![
                     Distribution::HashPartitioned(left_expr),
                     Distribution::HashPartitioned(right_expr),
@@ -697,7 +688,7 @@ async fn collect_left_input(
     partition: Option<usize>,
     random_state: RandomState,
     left: Arc<dyn ExecutionPlan>,
-    on_left: Vec<Column>,
+    on_left: Vec<PhysicalExprRef>,
     context: Arc<TaskContext>,
     metrics: BuildProbeJoinMetrics,
     reservation: MemoryReservation,
@@ -793,7 +784,7 @@ async fn collect_left_input(
 /// as a chain head for rows with equal hash values.
 #[allow(clippy::too_many_arguments)]
 pub fn update_hash<T>(
-    on: &[Column],
+    on: &[PhysicalExprRef],
     batch: &RecordBatch,
     hash_map: &mut T,
     offset: usize,
@@ -955,9 +946,9 @@ struct HashJoinStream {
     /// Input schema
     schema: Arc<Schema>,
     /// equijoin columns from the left (build side)
-    on_left: Vec<Column>,
+    on_left: Vec<PhysicalExprRef>,
     /// equijoin columns from the right (probe side)
-    on_right: Vec<Column>,
+    on_right: Vec<PhysicalExprRef>,
     /// optional join filter
     filter: Option<JoinFilter>,
     /// type of the join (left, right, semi, etc)
@@ -1043,8 +1034,8 @@ fn lookup_join_hashmap(
     build_hashmap: &JoinHashMap,
     build_input_buffer: &RecordBatch,
     probe_batch: &RecordBatch,
-    build_on: &[Column],
-    probe_on: &[Column],
+    build_on: &[PhysicalExprRef],
+    probe_on: &[PhysicalExprRef],
     null_equals_null: bool,
     hashes_buffer: &[u64],
     limit: usize,
@@ -1437,6 +1428,7 @@ mod tests {
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
+    use datafusion_physical_expr::PhysicalExpr;
 
     use hashbrown::raw::RawTable;
     use rstest::*;
@@ -1529,15 +1521,8 @@ mod tests {
     ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
         let partition_count = 4;
 
-        let (left_expr, right_expr) = on
-            .iter()
-            .map(|(l, r)| {
-                (
-                    Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
-                    Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
-                )
-            })
-            .unzip();
+        let (left_expr, right_expr) =
+            on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip();
 
         let join = HashJoinExec::try_new(
             Arc::new(RepartitionExec::try_new(
@@ -1588,8 +1573,8 @@ mod tests {
         );
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) = join_collect(
@@ -1635,8 +1620,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) = partitioned_join_collect(
@@ -1679,8 +1664,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (columns, batches) =
@@ -1718,8 +1703,8 @@ mod tests {
             ("c2", &vec![80, 90, 70]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (columns, batches) =
@@ -1760,12 +1745,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1822,12 +1807,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1884,8 +1869,8 @@ mod tests {
             ("c2", &vec![80, 90, 70]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (columns, batches) =
@@ -1934,8 +1919,8 @@ mod tests {
         );
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -2016,8 +2001,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b1", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
 
         let join = join(left, right, on, &JoinType::Left, false).unwrap();
@@ -2059,8 +2044,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
         let join = join(left, right, on, &JoinType::Full, false).unwrap();
@@ -2099,8 +2084,8 @@ mod tests {
         );
         let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", 
&vec![]));
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b1", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
         let schema = right.schema();
         let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, 
None).unwrap());
@@ -2136,8 +2121,8 @@ mod tests {
         );
         let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", 
&vec![]));
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
         let schema = right.schema();
         let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, 
None).unwrap());
@@ -2177,8 +2162,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) = join_collect(
@@ -2221,8 +2206,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) = partitioned_join_collect(
@@ -2278,8 +2263,8 @@ mod tests {
         let right = build_semi_anti_right_table();
         // left_table left semi join right_table on left_table.b1 = 
right_table.b2
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let join = join(left, right, on, &JoinType::LeftSemi, false)?;
@@ -2314,8 +2299,8 @@ mod tests {
 
         // left_table left semi join right_table on left_table.b1 = 
right_table.b2 and right_table.a2 != 10
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let column_indices = vec![ColumnIndex {
@@ -2401,8 +2386,8 @@ mod tests {
 
         // left_table right semi join right_table on left_table.b1 = 
right_table.b2
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let join = join(left, right, on, &JoinType::RightSemi, false)?;
@@ -2438,8 +2423,8 @@ mod tests {
 
         // left_table right semi join right_table on left_table.b1 = 
right_table.b2 on left_table.a1!=9
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let column_indices = vec![ColumnIndex {
@@ -2527,8 +2512,8 @@ mod tests {
         let right = build_semi_anti_right_table();
         // left_table left anti join right_table on left_table.b1 = 
right_table.b2
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let join = join(left, right, on, &JoinType::LeftAnti, false)?;
@@ -2561,8 +2546,8 @@ mod tests {
         let right = build_semi_anti_right_table();
         // left_table left anti join right_table on left_table.b1 = 
right_table.b2 and right_table.a2!=8
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let column_indices = vec![ColumnIndex {
@@ -2654,8 +2639,8 @@ mod tests {
         let left = build_semi_anti_left_table();
         let right = build_semi_anti_right_table();
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let join = join(left, right, on, &JoinType::RightAnti, false)?;
@@ -2689,8 +2674,8 @@ mod tests {
         let right = build_semi_anti_right_table();
         // left_table right anti join right_table on left_table.b1 = 
right_table.b2 and left_table.a1!=13
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let column_indices = vec![ColumnIndex {
@@ -2797,8 +2782,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) =
@@ -2836,8 +2821,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (columns, batches) =
@@ -2876,8 +2861,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
         let join = join(left, right, on, &JoinType::Full, false)?;
@@ -2930,7 +2915,7 @@ mod tests {
         );
 
         // Join key column for both join sides
-        let key_column = Column::new("a", 0);
+        let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _;
 
         let join_hash_map = JoinHashMap::new(hashmap_left, next);
 
@@ -2981,8 +2966,8 @@ mod tests {
         );
         let on = vec![(
             // join on a=b so there are duplicate column names on unjoined 
columns
-            Column::new_with_schema("a", &left.schema()).unwrap(),
-            Column::new_with_schema("b", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as 
_,
+            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
 
         let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -3045,8 +3030,8 @@ mod tests {
             ("c", &vec![7, 5, 6, 4]),
         );
         let on = vec![(
-            Column::new_with_schema("a", &left.schema()).unwrap(),
-            Column::new_with_schema("b", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as 
_,
+            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
         let filter = prepare_join_filter();
 
@@ -3086,8 +3071,8 @@ mod tests {
             ("c", &vec![7, 5, 6, 4]),
         );
         let on = vec![(
-            Column::new_with_schema("a", &left.schema()).unwrap(),
-            Column::new_with_schema("b", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as 
_,
+            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
         let filter = prepare_join_filter();
 
@@ -3130,8 +3115,8 @@ mod tests {
             ("c", &vec![7, 5, 6, 4]),
         );
         let on = vec![(
-            Column::new_with_schema("a", &left.schema()).unwrap(),
-            Column::new_with_schema("b", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as 
_,
+            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
         let filter = prepare_join_filter();
 
@@ -3173,8 +3158,8 @@ mod tests {
             ("c", &vec![7, 5, 6, 4]),
         );
         let on = vec![(
-            Column::new_with_schema("a", &left.schema()).unwrap(),
-            Column::new_with_schema("b", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as 
_,
+            Arc::new(Column::new_with_schema("b", &right.schema()).unwrap()) 
as _,
         )];
         let filter = prepare_join_filter();
 
@@ -3223,8 +3208,8 @@ mod tests {
         let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, 
None).unwrap());
 
         let on = vec![(
-            Column::new_with_schema("date", &left.schema()).unwrap(),
-            Column::new_with_schema("date", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("date", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("date", 
&right.schema()).unwrap()) as _,
         )];
 
         let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -3261,8 +3246,8 @@ mod tests {
         let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", 
&vec![]));
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b1", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap()) 
as _,
         )];
         let schema = right.schema();
         let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", 
&vec![]));
@@ -3317,8 +3302,8 @@ mod tests {
             ("c2", &vec![0, 0, 0, 0, 0]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
         let join_types = vec![
@@ -3451,8 +3436,8 @@ mod tests {
             ("c2", &vec![14, 15]),
         );
         let on = vec![(
-            Column::new_with_schema("a1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
         let join_types = vec![
@@ -3520,8 +3505,8 @@ mod tests {
             .unwrap(),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left_batch.schema())?,
-            Column::new_with_schema("b2", &right_batch.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left_batch.schema())?) as 
_,
+            Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as 
_,
         )];
 
         let join_types = vec![
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index f6fdc6d77c..675e90fb63 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -30,7 +30,7 @@ use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
-use crate::expressions::{Column, PhysicalSortExpr};
+use crate::expressions::PhysicalSortExpr;
 use crate::joins::utils::{
     build_join_schema, calculate_join_output_ordering, check_join_is_valid,
     estimate_join_statistics, partitioned_join_output_partitioning, JoinOn,
@@ -52,7 +52,9 @@ use datafusion_common::{
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::join_equivalence_properties;
-use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
+use datafusion_physical_expr::{
+    EquivalenceProperties, PhysicalExprRef, PhysicalSortRequirement,
+};
 
 use futures::{Stream, StreamExt};
 
@@ -120,11 +122,11 @@ impl SortMergeJoinExec {
             .zip(sort_options.iter())
             .map(|((l, r), sort_op)| {
                 let left = PhysicalSortExpr {
-                    expr: Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+                    expr: l.clone(),
                     options: *sort_op,
                 };
                 let right = PhysicalSortExpr {
-                    expr: Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+                    expr: r.clone(),
                     options: *sort_op,
                 };
                 (left, right)
@@ -189,7 +191,7 @@ impl SortMergeJoinExec {
     }
 
     /// Set of common columns used to join on
-    pub fn on(&self) -> &[(Column, Column)] {
+    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
         &self.on
     }
 
@@ -236,16 +238,8 @@ impl ExecutionPlan for SortMergeJoinExec {
     }
 
     fn required_input_distribution(&self) -> Vec<Distribution> {
-        let (left_expr, right_expr) = self
-            .on
-            .iter()
-            .map(|(l, r)| {
-                (
-                    Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
-                    Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
-                )
-            })
-            .unzip();
+        let (left_expr, right_expr) =
+            self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip();
         vec![
             Distribution::HashPartitioned(left_expr),
             Distribution::HashPartitioned(right_expr),
@@ -483,7 +477,7 @@ struct StreamedBatch {
 }
 
 impl StreamedBatch {
-    fn new(batch: RecordBatch, on_column: &[Column]) -> Self {
+    fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
         let join_arrays = join_arrays(&batch, on_column);
         StreamedBatch {
             batch,
@@ -547,7 +541,11 @@ struct BufferedBatch {
 }
 
 impl BufferedBatch {
-    fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> 
Self {
+    fn new(
+        batch: RecordBatch,
+        range: Range<usize>,
+        on_column: &[PhysicalExprRef],
+    ) -> Self {
         let join_arrays = join_arrays(&batch, on_column);
 
         // Estimation is calculated as
@@ -609,9 +607,9 @@ struct SMJStream {
     /// The comparison result of current streamed row and buffered batches
     pub current_ordering: Ordering,
     /// Join key columns of streamed
-    pub on_streamed: Vec<Column>,
+    pub on_streamed: Vec<PhysicalExprRef>,
     /// Join key columns of buffered
-    pub on_buffered: Vec<Column>,
+    pub on_buffered: Vec<PhysicalExprRef>,
     /// Staging output array builders
     pub output_record_batches: Vec<RecordBatch>,
     /// Staging output size, including output batches and staging joined 
results
@@ -736,8 +734,8 @@ impl SMJStream {
         null_equals_null: bool,
         streamed: SendableRecordBatchStream,
         buffered: SendableRecordBatchStream,
-        on_streamed: Vec<Column>,
-        on_buffered: Vec<Column>,
+        on_streamed: Vec<Arc<dyn PhysicalExpr>>,
+        on_buffered: Vec<Arc<dyn PhysicalExpr>>,
         join_type: JoinType,
         batch_size: usize,
         join_metrics: SortMergeJoinMetrics,
@@ -1218,10 +1216,14 @@ impl BufferedData {
 }
 
 /// Get join array refs of given batch and join columns
-fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
+fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> 
Vec<ArrayRef> {
     on_column
         .iter()
-        .map(|c| batch.column(c.index()).clone())
+        .map(|c| {
+            let num_rows = batch.num_rows();
+            let c = c.evaluate(batch).unwrap();
+            c.into_array(num_rows).unwrap()
+        })
         .collect()
 }
 
@@ -1582,8 +1584,8 @@ mod tests {
         );
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Inner).await?;
@@ -1616,12 +1618,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1654,12 +1656,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1693,12 +1695,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1731,12 +1733,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
         let (_, batches) = join_collect_with_options(
@@ -1783,12 +1785,12 @@ mod tests {
         );
         let on = vec![
             (
-                Column::new_with_schema("a1", &left.schema())?,
-                Column::new_with_schema("a1", &right.schema())?,
+                Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
             ),
             (
-                Column::new_with_schema("b2", &left.schema())?,
-                Column::new_with_schema("b2", &right.schema())?,
+                Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+                Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
             ),
         ];
 
@@ -1824,8 +1826,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Left).await?;
@@ -1856,8 +1858,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Right).await?;
@@ -1888,8 +1890,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema()).unwrap(),
-            Column::new_with_schema("b2", &right.schema()).unwrap(),
+            Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap()) 
as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) 
as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Full).await?;
@@ -1920,8 +1922,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::LeftAnti).await?;
@@ -1951,8 +1953,8 @@ mod tests {
             ("c2", &vec![70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::LeftSemi).await?;
@@ -1984,8 +1986,8 @@ mod tests {
         );
         let on = vec![(
             // join on a=b so there are duplicate column names on unjoined 
columns
-            Column::new_with_schema("a", &left.schema())?,
-            Column::new_with_schema("b", &right.schema())?,
+            Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Inner).await?;
@@ -2016,8 +2018,8 @@ mod tests {
         );
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Inner).await?;
@@ -2048,8 +2050,8 @@ mod tests {
         );
 
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b1", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Inner).await?;
@@ -2079,8 +2081,8 @@ mod tests {
             ("c2", &vec![50, 60, 70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Left).await?;
@@ -2115,8 +2117,8 @@ mod tests {
             ("c2", &vec![60, 70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Right).await?;
@@ -2159,8 +2161,8 @@ mod tests {
         let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
         let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Left).await?;
@@ -2208,8 +2210,8 @@ mod tests {
         let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
         let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Right).await?;
@@ -2257,8 +2259,8 @@ mod tests {
         let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
         let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]);
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
 
         let (_, batches) = join_collect(left, right, on, 
JoinType::Full).await?;
@@ -2296,8 +2298,8 @@ mod tests {
             ("c2", &vec![50, 60, 70, 80, 90]),
         );
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
         let sort_options = vec![SortOptions::default(); on.len()];
 
@@ -2376,8 +2378,8 @@ mod tests {
         let right =
             build_table_from_batches(vec![right_batch_1, right_batch_2, 
right_batch_3]);
         let on = vec![(
-            Column::new_with_schema("b1", &left.schema())?,
-            Column::new_with_schema("b2", &right.schema())?,
+            Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+            Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
         )];
         let sort_options = vec![SortOptions::default(); on.len()];
 
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs 
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 00950f0825..3f907930d6 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -46,11 +46,11 @@ use crate::joins::utils::{
     JoinHashMapType, JoinOn, StatefulStreamResult,
 };
 use crate::{
-    expressions::{Column, PhysicalSortExpr},
+    expressions::PhysicalSortExpr,
     joins::StreamJoinPartitionMode,
     metrics::{ExecutionPlanMetricsSet, MetricsSet},
     DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties, 
ExecutionPlan,
-    Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, 
Statistics,
+    Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 
 use arrow::array::{
@@ -72,7 +72,7 @@ use 
datafusion_physical_expr::equivalence::join_equivalence_properties;
 use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
 
 use ahash::RandomState;
-use datafusion_physical_expr::PhysicalSortRequirement;
+use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
 use futures::Stream;
 use hashbrown::HashSet;
 use parking_lot::Mutex;
@@ -171,7 +171,7 @@ pub struct SymmetricHashJoinExec {
     /// Right side stream
     pub(crate) right: Arc<dyn ExecutionPlan>,
     /// Set of common columns used to join on
-    pub(crate) on: Vec<(Column, Column)>,
+    pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
     /// Filters applied when finding matching rows
     pub(crate) filter: Option<JoinFilter>,
     /// How the join is performed
@@ -261,7 +261,7 @@ impl SymmetricHashJoinExec {
     }
 
     /// Set of common columns used to join on
-    pub fn on(&self) -> &[(Column, Column)] {
+    pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
         &self.on
     }
 
@@ -367,7 +367,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
                 let (left_expr, right_expr) = self
                     .on
                     .iter()
-                    .map(|(l, r)| (Arc::new(l.clone()) as _, 
Arc::new(r.clone()) as _))
+                    .map(|(l, r)| (l.clone() as _, r.clone() as _))
                     .unzip();
                 vec![
                     Distribution::HashPartitioned(left_expr),
@@ -874,8 +874,8 @@ fn lookup_join_hashmap(
     build_hashmap: &PruningJoinHashMap,
     build_batch: &RecordBatch,
     probe_batch: &RecordBatch,
-    build_on: &[Column],
-    probe_on: &[Column],
+    build_on: &[PhysicalExprRef],
+    probe_on: &[PhysicalExprRef],
     random_state: &RandomState,
     null_equals_null: bool,
     hashes_buffer: &mut Vec<u64>,
@@ -952,7 +952,7 @@ pub struct OneSideHashJoiner {
     /// Input record batch buffer
     pub input_buffer: RecordBatch,
     /// Columns from the side
-    pub(crate) on: Vec<Column>,
+    pub(crate) on: Vec<PhysicalExprRef>,
     /// Hashmap
     pub(crate) hashmap: PruningJoinHashMap,
     /// Reuse the hashes buffer
@@ -979,7 +979,11 @@ impl OneSideHashJoiner {
         size += std::mem::size_of_val(&self.deleted_offset);
         size
     }
-    pub fn new(build_side: JoinSide, on: Vec<Column>, schema: SchemaRef) -> 
Self {
+    pub fn new(
+        build_side: JoinSide,
+        on: Vec<PhysicalExprRef>,
+        schema: SchemaRef,
+    ) -> Self {
         Self {
             build_side,
             input_buffer: RecordBatch::new_empty(schema),
@@ -1447,8 +1451,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1515,8 +1519,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1569,8 +1573,8 @@ mod tests {
             create_memory_table(left_partition, right_partition, vec![], 
vec![])?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1621,8 +1625,8 @@ mod tests {
             create_memory_table(left_partition, right_partition, vec![], 
vec![])?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
         experiment(left, right, None, join_type, on, task_ctx).await?;
         Ok(())
@@ -1670,8 +1674,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1731,8 +1735,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1792,8 +1796,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1855,8 +1859,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1914,8 +1918,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -1981,8 +1985,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
@@ -2040,8 +2044,8 @@ mod tests {
         let left_schema = &left_partition[0].schema();
         let right_schema = &right_partition[0].schema();
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
         let left_sorted = vec![PhysicalSortExpr {
             expr: col("lt1", left_schema)?,
@@ -2124,8 +2128,8 @@ mod tests {
         let left_schema = &left_partition[0].schema();
         let right_schema = &right_partition[0].schema();
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
         let left_sorted = vec![PhysicalSortExpr {
             expr: col("li1", left_schema)?,
@@ -2217,8 +2221,8 @@ mod tests {
         )?;
 
         let on = vec![(
-            Column::new_with_schema("lc1", left_schema)?,
-            Column::new_with_schema("rc1", right_schema)?,
+            Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+            Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
         )];
 
         let intermediate_schema = Schema::new(vec![
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs 
b/datafusion/physical-plan/src/joins/test_utils.rs
index 477e2de421..37faae8737 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -78,15 +78,9 @@ pub async fn partitioned_sym_join_with_filter(
 ) -> Result<Vec<RecordBatch>> {
     let partition_count = 4;
 
-    let left_expr = on
-        .iter()
-        .map(|(l, _)| Arc::new(l.clone()) as _)
-        .collect::<Vec<_>>();
+    let left_expr = on.iter().map(|(l, _)| l.clone() as _).collect::<Vec<_>>();
 
-    let right_expr = on
-        .iter()
-        .map(|(_, r)| Arc::new(r.clone()) as _)
-        .collect::<Vec<_>>();
+    let right_expr = on.iter().map(|(_, r)| r.clone() as 
_).collect::<Vec<_>>();
 
     let join = SymmetricHashJoinExec::try_new(
         Arc::new(RepartitionExec::try_new(
@@ -133,7 +127,7 @@ pub async fn partitioned_hash_join_with_filter(
     let partition_count = 4;
     let (left_expr, right_expr) = on
         .iter()
-        .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
+        .map(|(l, r)| (l.clone() as _, r.clone() as _))
         .unzip();
 
     let join = Arc::new(HashJoinExec::try_new(
diff --git a/datafusion/physical-plan/src/joins/utils.rs 
b/datafusion/physical-plan/src/joins/utils.rs
index cd987ab40d..e6e3f83fd7 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -45,11 +45,12 @@ use datafusion_common::{
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_physical_expr::equivalence::add_offset_to_expr;
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::utils::merge_vectors;
+use datafusion_physical_expr::utils::{collect_columns, merge_vectors};
 use datafusion_physical_expr::{
-    LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr,
+    LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, 
PhysicalSortExpr,
 };
 
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
 use hashbrown::raw::RawTable;
@@ -377,9 +378,9 @@ impl fmt::Debug for JoinHashMap {
 }
 
 /// The on clause of the join, as vector of (left, right) columns.
-pub type JoinOn = Vec<(Column, Column)>;
+pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>;
 /// Reference for JoinOn.
-pub type JoinOnRef<'a> = &'a [(Column, Column)];
+pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)];
 
 /// Checks whether the schemas "left" and "right" and columns "on" represent a 
valid join.
 /// They are valid whenever their columns' intersection equals the set `on`
@@ -405,12 +406,18 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema, 
on: JoinOnRef) -> Resu
 fn check_join_set_is_valid(
     left: &HashSet<Column>,
     right: &HashSet<Column>,
-    on: &[(Column, Column)],
+    on: &[(PhysicalExprRef, PhysicalExprRef)],
 ) -> Result<()> {
-    let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
+    let on_left = &on
+        .iter()
+        .flat_map(|on| collect_columns(&on.0))
+        .collect::<HashSet<_>>();
     let left_missing = on_left.difference(left).collect::<HashSet<_>>();
 
-    let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
+    let on_right = &on
+        .iter()
+        .flat_map(|on| collect_columns(&on.1))
+        .collect::<HashSet<_>>();
     let right_missing = on_right.difference(right).collect::<HashSet<_>>();
 
     if !left_missing.is_empty() | !right_missing.is_empty() {
@@ -466,21 +473,41 @@ pub fn adjust_right_output_partitioning(
 /// Replaces the right column (first index in the `on_column` tuple) with
 /// the left column (zeroth index in the tuple) inside `right_ordering`.
 fn replace_on_columns_of_right_ordering(
-    on_columns: &[(Column, Column)],
+    on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
     right_ordering: &mut [PhysicalSortExpr],
-    left_columns_len: usize,
-) {
+) -> Result<()> {
     for (left_col, right_col) in on_columns {
-        let right_col =
-            Column::new(right_col.name(), right_col.index() + 
left_columns_len);
         for item in right_ordering.iter_mut() {
-            if let Some(col) = item.expr.as_any().downcast_ref::<Column>() {
-                if right_col.eq(col) {
-                    item.expr = Arc::new(left_col.clone()) as _;
+            let new_expr = item.expr.clone().transform(&|e| {
+                if e.eq(right_col) {
+                    Ok(Transformed::Yes(left_col.clone()))
+                } else {
+                    Ok(Transformed::No(e))
                 }
-            }
+            })?;
+            item.expr = new_expr;
         }
     }
+    Ok(())
+}
+
+fn offset_ordering(
+    ordering: LexOrderingRef,
+    join_type: &JoinType,
+    offset: usize,
+) -> Vec<PhysicalSortExpr> {
+    match join_type {
+        // In the case below, right ordering should be offseted with the left
+        // side length, since we append the right table to the left table.
+        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => 
ordering
+            .iter()
+            .map(|sort_expr| PhysicalSortExpr {
+                expr: add_offset_to_expr(sort_expr.expr.clone(), offset),
+                options: sort_expr.options,
+            })
+            .collect(),
+        _ => ordering.to_vec(),
+    }
 }
 
 /// Calculate the output ordering of a given join operation.
@@ -488,35 +515,24 @@ pub fn calculate_join_output_ordering(
     left_ordering: LexOrderingRef,
     right_ordering: LexOrderingRef,
     join_type: JoinType,
-    on_columns: &[(Column, Column)],
+    on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
     left_columns_len: usize,
     maintains_input_order: &[bool],
     probe_side: Option<JoinSide>,
 ) -> Option<LexOrdering> {
-    let mut right_ordering = match join_type {
-        // In the case below, right ordering should be offseted with the left
-        // side length, since we append the right table to the left table.
-        JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => 
{
-            right_ordering
-                .iter()
-                .map(|sort_expr| PhysicalSortExpr {
-                    expr: add_offset_to_expr(sort_expr.expr.clone(), 
left_columns_len),
-                    options: sort_expr.options,
-                })
-                .collect()
-        }
-        _ => right_ordering.to_vec(),
-    };
     let output_ordering = match maintains_input_order {
         [true, false] => {
             // Special case, we can prefix ordering of right side with the 
ordering of left side.
             if join_type == JoinType::Inner && probe_side == 
Some(JoinSide::Left) {
                 replace_on_columns_of_right_ordering(
                     on_columns,
-                    &mut right_ordering,
-                    left_columns_len,
-                );
-                merge_vectors(left_ordering, &right_ordering)
+                    &mut right_ordering.to_vec(),
+                )
+                .ok()?;
+                merge_vectors(
+                    left_ordering,
+                    &offset_ordering(right_ordering, &join_type, 
left_columns_len),
+                )
             } else {
                 left_ordering.to_vec()
             }
@@ -526,12 +542,15 @@ pub fn calculate_join_output_ordering(
             if join_type == JoinType::Inner && probe_side == 
Some(JoinSide::Right) {
                 replace_on_columns_of_right_ordering(
                     on_columns,
-                    &mut right_ordering,
-                    left_columns_len,
-                );
-                merge_vectors(&right_ordering, left_ordering)
+                    &mut right_ordering.to_vec(),
+                )
+                .ok()?;
+                merge_vectors(
+                    &offset_ordering(right_ordering, &join_type, 
left_columns_len),
+                    left_ordering,
+                )
             } else {
-                right_ordering.to_vec()
+                offset_ordering(right_ordering, &join_type, left_columns_len)
             }
         }
         // Doesn't maintain ordering, output ordering is None.
@@ -810,10 +829,19 @@ fn estimate_join_cardinality(
             let (left_col_stats, right_col_stats) = on
                 .iter()
                 .map(|(left, right)| {
-                    (
-                        left_stats.column_statistics[left.index()].clone(),
-                        right_stats.column_statistics[right.index()].clone(),
-                    )
+                    match (
+                        left.as_any().downcast_ref::<Column>(),
+                        right.as_any().downcast_ref::<Column>(),
+                    ) {
+                        (Some(left), Some(right)) => (
+                            left_stats.column_statistics[left.index()].clone(),
+                            
right_stats.column_statistics[right.index()].clone(),
+                        ),
+                        _ => (
+                            ColumnStatistics::new_unknown(),
+                            ColumnStatistics::new_unknown(),
+                        ),
+                    }
                 })
                 .unzip::<_, _, Vec<_>, Vec<_>>();
 
@@ -1476,7 +1504,11 @@ mod tests {
     use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
     use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
 
-    fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> 
Result<()> {
+    fn check(
+        left: &[Column],
+        right: &[Column],
+        on: &[(PhysicalExprRef, PhysicalExprRef)],
+    ) -> Result<()> {
         let left = left
             .iter()
             .map(|x| x.to_owned())
@@ -1492,7 +1524,10 @@ mod tests {
     fn check_valid() -> Result<()> {
         let left = vec![Column::new("a", 0), Column::new("b1", 1)];
         let right = vec![Column::new("a", 0), Column::new("b2", 1)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
+        let on = &[(
+            Arc::new(Column::new("a", 0)) as _,
+            Arc::new(Column::new("a", 0)) as _,
+        )];
 
         check(&left, &right, on)?;
         Ok(())
@@ -1502,7 +1537,10 @@ mod tests {
     fn check_not_in_right() {
         let left = vec![Column::new("a", 0), Column::new("b", 1)];
         let right = vec![Column::new("b", 0)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
+        let on = &[(
+            Arc::new(Column::new("a", 0)) as _,
+            Arc::new(Column::new("a", 0)) as _,
+        )];
 
         assert!(check(&left, &right, on).is_err());
     }
@@ -1544,7 +1582,10 @@ mod tests {
     fn check_not_in_left() {
         let left = vec![Column::new("b", 0)];
         let right = vec![Column::new("a", 0)];
-        let on = &[(Column::new("a", 0), Column::new("a", 0))];
+        let on = &[(
+            Arc::new(Column::new("a", 0)) as _,
+            Arc::new(Column::new("a", 0)) as _,
+        )];
 
         assert!(check(&left, &right, on).is_err());
     }
@@ -1554,7 +1595,10 @@ mod tests {
         // column "a" would appear both in left and right
         let left = vec![Column::new("a", 0), Column::new("c", 1)];
         let right = vec![Column::new("a", 0), Column::new("b", 1)];
-        let on = &[(Column::new("a", 0), Column::new("b", 1))];
+        let on = &[(
+            Arc::new(Column::new("a", 0)) as _,
+            Arc::new(Column::new("b", 1)) as _,
+        )];
 
         assert!(check(&left, &right, on).is_ok());
     }
@@ -1563,7 +1607,10 @@ mod tests {
     fn check_in_right() {
         let left = vec![Column::new("a", 0), Column::new("c", 1)];
         let right = vec![Column::new("b", 0)];
-        let on = &[(Column::new("a", 0), Column::new("b", 0))];
+        let on = &[(
+            Arc::new(Column::new("a", 0)) as _,
+            Arc::new(Column::new("b", 0)) as _,
+        )];
 
         assert!(check(&left, &right, on).is_ok());
     }
@@ -1835,7 +1882,10 @@ mod tests {
 
             // We should also be able to use join_cardinality to get the same 
results
             let join_type = JoinType::Inner;
-            let join_on = vec![(Column::new("a", 0), Column::new("b", 0))];
+            let join_on = vec![(
+                Arc::new(Column::new("a", 0)) as _,
+                Arc::new(Column::new("b", 0)) as _,
+            )];
             let partial_join_stats = estimate_join_cardinality(
                 &join_type,
                 create_stats(Some(left_num_rows), left_col_stats.clone(), 
false),
@@ -1957,8 +2007,14 @@ mod tests {
 
         for (join_type, expected_num_rows) in cases {
             let join_on = vec![
-                (Column::new("a", 0), Column::new("c", 0)),
-                (Column::new("b", 1), Column::new("d", 1)),
+                (
+                    Arc::new(Column::new("a", 0)) as _,
+                    Arc::new(Column::new("c", 0)) as _,
+                ),
+                (
+                    Arc::new(Column::new("b", 1)) as _,
+                    Arc::new(Column::new("d", 1)) as _,
+                ),
             ];
 
             let partial_join_stats = estimate_join_cardinality(
@@ -2005,8 +2061,14 @@ mod tests {
         ];
 
         let join_on = vec![
-            (Column::new("a", 0), Column::new("c", 0)),
-            (Column::new("x", 2), Column::new("y", 2)),
+            (
+                Arc::new(Column::new("a", 0)) as _,
+                Arc::new(Column::new("c", 0)) as _,
+            ),
+            (
+                Arc::new(Column::new("x", 2)) as _,
+                Arc::new(Column::new("y", 2)) as _,
+            ),
         ];
 
         let cases = vec![
@@ -2071,7 +2133,10 @@ mod tests {
             },
         ];
         let join_type = JoinType::Inner;
-        let on_columns = [(Column::new("b", 1), Column::new("x", 0))];
+        let on_columns = [(
+            Arc::new(Column::new("b", 1)) as _,
+            Arc::new(Column::new("x", 0)) as _,
+        )];
         let left_columns_len = 5;
         let maintains_input_orders = [[true, false], [false, true]];
         let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index c8468e1709..1d5ca59171 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1581,8 +1581,8 @@ message PhysicalColumn {
 }
 
 message JoinOn {
-  PhysicalColumn left = 1;
-  PhysicalColumn right = 2;
+  PhysicalExprNode left = 1;
+  PhysicalExprNode right = 2;
 }
 
 message EmptyExecNode {
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index a5582cc2dc..485dbd48b8 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2244,9 +2244,9 @@ pub struct PhysicalColumn {
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct JoinOn {
     #[prost(message, optional, tag = "1")]
-    pub left: ::core::option::Option<PhysicalColumn>,
+    pub left: ::core::option::Option<PhysicalExprNode>,
     #[prost(message, optional, tag = "2")]
-    pub right: ::core::option::Option<PhysicalColumn>,
+    pub right: ::core::option::Option<PhysicalExprNode>,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index f39f885b78..d2961875d8 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -31,6 +31,7 @@ use datafusion::datasource::physical_plan::ParquetExec;
 use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::execution::FunctionRegistry;
+use datafusion::physical_expr::PhysicalExprRef;
 use datafusion::physical_plan::aggregates::{create_aggregate_expr, 
AggregateMode};
 use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
 use datafusion::physical_plan::analyze::AnalyzeExec;
@@ -38,7 +39,7 @@ use 
datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::explain::ExplainExec;
-use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
+use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::insert::FileSinkExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
@@ -64,6 +65,7 @@ use prost::Message;
 
 use crate::common::str_to_byte;
 use crate::common::{byte_to_string, proto_error};
+use crate::convert_required;
 use crate::physical_plan::from_proto::{
     parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
     parse_protobuf_file_scan_config,
@@ -75,7 +77,6 @@ use crate::protobuf::repartition_exec_node::PartitionMethod;
 use crate::protobuf::{
     self, window_agg_exec_node, PhysicalPlanNode, 
PhysicalSortExprNodeCollection,
 };
-use crate::{convert_required, into_required};
 
 use self::from_proto::parse_physical_window_expr;
 
@@ -506,12 +507,22 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     runtime,
                     extension_codec,
                 )?;
-                let on: Vec<(Column, Column)> = hashjoin
+                let left_schema = left.schema();
+                let right_schema = right.schema();
+                let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin
                     .on
                     .iter()
                     .map(|col| {
-                        let left = into_required!(col.left)?;
-                        let right = into_required!(col.right)?;
+                        let left = parse_physical_expr(
+                            &col.left.clone().unwrap(),
+                            registry,
+                            left_schema.as_ref(),
+                        )?;
+                        let right = parse_physical_expr(
+                            &col.right.clone().unwrap(),
+                            registry,
+                            right_schema.as_ref(),
+                        )?;
                         Ok((left, right))
                     })
                     .collect::<Result<_>>()?;
@@ -595,12 +606,22 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     runtime,
                     extension_codec,
                 )?;
+                let left_schema = left.schema();
+                let right_schema = right.schema();
                 let on = sym_join
                     .on
                     .iter()
                     .map(|col| {
-                        let left = into_required!(col.left)?;
-                        let right = into_required!(col.right)?;
+                        let left = parse_physical_expr(
+                            &col.left.clone().unwrap(),
+                            registry,
+                            left_schema.as_ref(),
+                        )?;
+                        let right = parse_physical_expr(
+                            &col.right.clone().unwrap(),
+                            registry,
+                            right_schema.as_ref(),
+                        )?;
                         Ok((left, right))
                     })
                     .collect::<Result<_>>()?;
@@ -647,7 +668,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     })
                     .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
 
-                let left_schema = left.schema();
                 let left_sort_exprs = parse_physical_sort_exprs(
                     &sym_join.left_sort_exprs,
                     registry,
@@ -659,7 +679,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     Some(left_sort_exprs)
                 };
 
-                let right_schema = right.schema();
                 let right_sort_exprs = parse_physical_sort_exprs(
                     &sym_join.right_sort_exprs,
                     registry,
@@ -1144,17 +1163,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
             let on: Vec<protobuf::JoinOn> = exec
                 .on()
                 .iter()
-                .map(|tuple| protobuf::JoinOn {
-                    left: Some(protobuf::PhysicalColumn {
-                        name: tuple.0.name().to_string(),
-                        index: tuple.0.index() as u32,
-                    }),
-                    right: Some(protobuf::PhysicalColumn {
-                        name: tuple.1.name().to_string(),
-                        index: tuple.1.index() as u32,
-                    }),
+                .map(|tuple| {
+                    let l = tuple.0.to_owned().try_into()?;
+                    let r = tuple.1.to_owned().try_into()?;
+                    Ok::<_, DataFusionError>(protobuf::JoinOn {
+                        left: Some(l),
+                        right: Some(r),
+                    })
                 })
-                .collect();
+                .collect::<Result<_>>()?;
             let join_type: protobuf::JoinType = 
exec.join_type().to_owned().into();
             let filter = exec
                 .filter()
@@ -1214,17 +1231,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
             let on = exec
                 .on()
                 .iter()
-                .map(|tuple| protobuf::JoinOn {
-                    left: Some(protobuf::PhysicalColumn {
-                        name: tuple.0.name().to_string(),
-                        index: tuple.0.index() as u32,
-                    }),
-                    right: Some(protobuf::PhysicalColumn {
-                        name: tuple.1.name().to_string(),
-                        index: tuple.1.index() as u32,
-                    }),
+                .map(|tuple| {
+                    let l = tuple.0.to_owned().try_into()?;
+                    let r = tuple.1.to_owned().try_into()?;
+                    Ok::<_, DataFusionError>(protobuf::JoinOn {
+                        left: Some(l),
+                        right: Some(r),
+                    })
                 })
-                .collect();
+                .collect::<Result<_>>()?;
             let join_type: protobuf::JoinType = 
exec.join_type().to_owned().into();
             let filter = exec
                 .filter()
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index eba3db298f..f2f1b0ea0d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -191,8 +191,8 @@ fn roundtrip_hash_join() -> Result<()> {
     let schema_left = Schema::new(vec![field_a.clone()]);
     let schema_right = Schema::new(vec![field_a]);
     let on = vec![(
-        Column::new("col", schema_left.index_of("col")?),
-        Column::new("col", schema_right.index_of("col")?),
+        Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+        Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
     )];
 
     let schema_left = Arc::new(schema_left);
@@ -916,8 +916,8 @@ fn roundtrip_sym_hash_join() -> Result<()> {
     let schema_left = Schema::new(vec![field_a.clone()]);
     let schema_right = Schema::new(vec![field_a]);
     let on = vec![(
-        Column::new("col", schema_left.index_of("col")?),
-        Column::new("col", schema_right.index_of("col")?),
+        Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+        Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
     )];
 
     let schema_left = Arc::new(schema_left);

Reply via email to