This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 191d8b72f Add optimizer rule for type coercion (binary operations only) (#3222)
191d8b72f is described below

commit 191d8b72ff7500872953b2749937abcd1c06f848
Author: Andy Grove <[email protected]>
AuthorDate: Tue Sep 6 06:16:37 2022 -0600

    Add optimizer rule for type coercion (binary operations only) (#3222)
    
    * Add binary type coercion to logical plan and do not allow CAST to change an expression name
    
    * fix tests
    
    * update avro tests
    
    * add reference to GitHub issue
    
    * unignore timestamp_add_interval_months
    
    * fix: update tests to use correct column types
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/execution/context.rs         |   4 +
 datafusion/core/src/physical_plan/planner.rs     |  12 +-
 datafusion/core/tests/dataframe_functions.rs     |  16 +--
 datafusion/core/tests/parquet_pruning.rs         |   1 +
 datafusion/core/tests/sql/aggregates.rs          |  29 +++-
 datafusion/core/tests/sql/avro.rs                |  64 ++++-----
 datafusion/core/tests/sql/decimal.rs             | 106 +++++++-------
 datafusion/core/tests/sql/explain_analyze.rs     |   6 +-
 datafusion/core/tests/sql/expr.rs                |   4 +-
 datafusion/core/tests/sql/functions.rs           |  56 ++++----
 datafusion/core/tests/sql/joins.rs               |  10 +-
 datafusion/core/tests/sql/parquet.rs             |  24 ++--
 datafusion/core/tests/sql/predicates.rs          |   2 +-
 datafusion/core/tests/sql/subqueries.rs          |   8 +-
 datafusion/core/tests/sql/timestamp.rs           |  60 ++++----
 datafusion/core/tests/sql/window.rs              |  18 +--
 datafusion/expr/src/expr.rs                      |  29 +++-
 datafusion/optimizer/src/lib.rs                  |   1 +
 datafusion/optimizer/src/simplify_expressions.rs |  12 +-
 datafusion/optimizer/src/type_coercion.rs        | 170 +++++++++++++++++++++++
 datafusion/physical-expr/src/planner.rs          |   8 +-
 21 files changed, 427 insertions(+), 213 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index 2a9d01124..a6a508d36 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -111,6 +111,7 @@ use 
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
 use 
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use 
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
+use datafusion_optimizer::type_coercion::TypeCoercion;
 use datafusion_sql::{
     parser::DFParser,
     planner::{ContextProvider, SqlToRel},
@@ -1433,6 +1434,9 @@ impl SessionState {
         }
         rules.push(Arc::new(ReduceOuterJoin::new()));
         rules.push(Arc::new(FilterPushDown::new()));
+        // we do type coercion after filter push down so that we don't push CAST filters to Parquet
+        // until https://github.com/apache/arrow-datafusion/issues/3289 is resolved
+        rules.push(Arc::new(TypeCoercion::new()));
         rules.push(Arc::new(LimitPushDown::new()));
         rules.push(Arc::new(SingleDistinctToGroupBy::new()));
 
diff --git a/datafusion/core/src/physical_plan/planner.rs 
b/datafusion/core/src/physical_plan/planner.rs
index 8d7e0e9e4..a2d868361 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -128,13 +128,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> 
Result<String> {
             name += "END";
             Ok(name)
         }
-        Expr::Cast { expr, data_type } => {
-            let expr = create_physical_name(expr, false)?;
-            Ok(format!("CAST({} AS {:?})", expr, data_type))
+        Expr::Cast { expr, .. } => {
+            // CAST does not change the expression name
+            create_physical_name(expr, false)
         }
-        Expr::TryCast { expr, data_type } => {
-            let expr = create_physical_name(expr, false)?;
-            Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+        Expr::TryCast { expr, .. } => {
+            // CAST does not change the expression name
+            create_physical_name(expr, false)
         }
         Expr::Not(expr) => {
             let expr = create_physical_name(expr, false)?;
diff --git a/datafusion/core/tests/dataframe_functions.rs 
b/datafusion/core/tests/dataframe_functions.rs
index 19694285c..0d3631b18 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -667,14 +667,14 @@ async fn test_fn_substr() -> Result<()> {
 async fn test_cast() -> Result<()> {
     let expr = cast(col("b"), DataType::Float64);
     let expected = vec![
-        "+-------------------------+",
-        "| CAST(test.b AS Float64) |",
-        "+-------------------------+",
-        "| 1                       |",
-        "| 10                      |",
-        "| 10                      |",
-        "| 100                     |",
-        "+-------------------------+",
+        "+--------+",
+        "| test.b |",
+        "+--------+",
+        "| 1      |",
+        "| 10     |",
+        "| 10     |",
+        "| 100    |",
+        "+--------+",
     ];
 
     assert_fn_batches!(expr, expected);
diff --git a/datafusion/core/tests/parquet_pruning.rs 
b/datafusion/core/tests/parquet_pruning.rs
index 0c3acf9fb..b6c286763 100644
--- a/datafusion/core/tests/parquet_pruning.rs
+++ b/datafusion/core/tests/parquet_pruning.rs
@@ -647,6 +647,7 @@ impl ContextWithParquet {
         let pretty_input = pretty_format_batches(&input).unwrap().to_string();
 
         let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing 
plan");
+
         let physical_plan = self
             .ctx
             .create_physical_plan(&logical_plan)
diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index 7e0e785da..357addbc0 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -462,11 +462,11 @@ async fn csv_query_external_table_sum() {
         "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM 
aggregate_test_100";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        
"+-------------------------------------------+-------------------------------------------+",
-        "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | 
SUM(CAST(aggregate_test_100.c8 AS Int64)) |",
-        
"+-------------------------------------------+-------------------------------------------+",
-        "| 13060                                     | 3017641                 
                  |",
-        
"+-------------------------------------------+-------------------------------------------+",
+        "+----------------------------+----------------------------+",
+        "| SUM(aggregate_test_100.c7) | SUM(aggregate_test_100.c8) |",
+        "+----------------------------+----------------------------+",
+        "| 13060                      | 3017641                    |",
+        "+----------------------------+----------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 }
@@ -555,6 +555,7 @@ async fn csv_query_count_one() {
 }
 
 #[tokio::test]
+#[ignore] // https://github.com/apache/arrow-datafusion/issues/3353
 async fn csv_query_approx_count() -> Result<()> {
     let ctx = SessionContext::new();
     register_aggregate_csv(&ctx).await?;
@@ -571,6 +572,24 @@ async fn csv_query_approx_count() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_count_dupe_expr_aliased() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_aggregate_csv(&ctx).await?;
+    let sql =
+        "SELECT approx_distinct(c9) a, approx_distinct(c9) b FROM 
aggregate_test_100";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----+-----+",
+        "| a   | b   |",
+        "+-----+-----+",
+        "| 100 | 100 |",
+        "+-----+-----+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 // This test executes the APPROX_PERCENTILE_CONT aggregation against the test
 // data, asserting the estimated quantiles are ±5% their actual values.
 //
diff --git a/datafusion/core/tests/sql/avro.rs 
b/datafusion/core/tests/sql/avro.rs
index f4ff4cd7c..8fdef28bd 100644
--- a/datafusion/core/tests/sql/avro.rs
+++ b/datafusion/core/tests/sql/avro.rs
@@ -37,18 +37,18 @@ async fn avro_query() {
     let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----+-----------------------------------------+",
-        "| id | CAST(alltypes_plain.string_col AS Utf8) |",
-        "+----+-----------------------------------------+",
-        "| 4  | 0                                       |",
-        "| 5  | 1                                       |",
-        "| 6  | 0                                       |",
-        "| 7  | 1                                       |",
-        "| 2  | 0                                       |",
-        "| 3  | 1                                       |",
-        "| 0  | 0                                       |",
-        "| 1  | 1                                       |",
-        "+----+-----------------------------------------+",
+        "+----+---------------------------+",
+        "| id | alltypes_plain.string_col |",
+        "+----+---------------------------+",
+        "| 4  | 0                         |",
+        "| 5  | 1                         |",
+        "| 6  | 0                         |",
+        "| 7  | 1                         |",
+        "| 2  | 0                         |",
+        "| 3  | 1                         |",
+        "| 0  | 0                         |",
+        "| 1  | 1                         |",
+        "+----+---------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
@@ -84,26 +84,26 @@ async fn avro_query_multiple_files() {
     let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----+-----------------------------------------+",
-        "| id | CAST(alltypes_plain.string_col AS Utf8) |",
-        "+----+-----------------------------------------+",
-        "| 4  | 0                                       |",
-        "| 5  | 1                                       |",
-        "| 6  | 0                                       |",
-        "| 7  | 1                                       |",
-        "| 2  | 0                                       |",
-        "| 3  | 1                                       |",
-        "| 0  | 0                                       |",
-        "| 1  | 1                                       |",
-        "| 4  | 0                                       |",
-        "| 5  | 1                                       |",
-        "| 6  | 0                                       |",
-        "| 7  | 1                                       |",
-        "| 2  | 0                                       |",
-        "| 3  | 1                                       |",
-        "| 0  | 0                                       |",
-        "| 1  | 1                                       |",
-        "+----+-----------------------------------------+",
+        "+----+---------------------------+",
+        "| id | alltypes_plain.string_col |",
+        "+----+---------------------------+",
+        "| 4  | 0                         |",
+        "| 5  | 1                         |",
+        "| 6  | 0                         |",
+        "| 7  | 1                         |",
+        "| 2  | 0                         |",
+        "| 3  | 1                         |",
+        "| 0  | 0                         |",
+        "| 1  | 1                         |",
+        "| 4  | 0                         |",
+        "| 5  | 1                         |",
+        "| 6  | 0                         |",
+        "| 7  | 1                         |",
+        "| 2  | 0                         |",
+        "| 3  | 1                         |",
+        "| 0  | 0                         |",
+        "| 1  | 1                         |",
+        "+----+---------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/decimal.rs 
b/datafusion/core/tests/sql/decimal.rs
index 8ded8752d..7c74cdd52 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+------------------------------------------+",
-        "| CAST(Float64(1.23) AS Decimal128(10, 4)) |",
-        "+------------------------------------------+",
-        "| 1.2300                                   |",
-        "+------------------------------------------+",
+        "+---------------+",
+        "| Float64(1.23) |",
+        "+---------------+",
+        "| 1.2300        |",
+        "+---------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -42,11 +42,11 @@ async fn decimal_cast() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        
"+---------------------------------------------------------------------+",
-        "| CAST(CAST(Float64(1.23) AS Decimal128(10, 3)) AS Decimal128(10, 4)) 
|",
-        
"+---------------------------------------------------------------------+",
-        "| 1.2300                                                              
|",
-        
"+---------------------------------------------------------------------+",
+        "+---------------+",
+        "| Float64(1.23) |",
+        "+---------------+",
+        "| 1.2300        |",
+        "+---------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -57,11 +57,11 @@ async fn decimal_cast() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+--------------------------------------------+",
-        "| CAST(Float64(1.2345) AS Decimal128(24, 2)) |",
-        "+--------------------------------------------+",
-        "| 1.23                                       |",
-        "+--------------------------------------------+",
+        "+-----------------+",
+        "| Float64(1.2345) |",
+        "+-----------------+",
+        "| 1.23            |",
+        "+-----------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+----------------------------------------------------------------+",
-        "| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
-        "+----------------------------------------------------------------+",
-        "| 1.000000000000                                                 |",
-        "| 2.000000000000                                                 |",
-        "| 2.000000000000                                                 |",
-        "| 3.000000000000                                                 |",
-        "| 3.000000000000                                                 |",
-        "| 3.000000000000                                                 |",
-        "| 4.000000000000                                                 |",
-        "| 4.000000000000                                                 |",
-        "| 4.000000000000                                                 |",
-        "| 4.000000000000                                                 |",
-        "| 5.000000000000                                                 |",
-        "| 5.000000000000                                                 |",
-        "| 5.000000000000                                                 |",
-        "| 5.000000000000                                                 |",
-        "| 5.000000000000                                                 |",
-        "+----------------------------------------------------------------+",
+        "+--------------------------------------+",
+        "| decimal_simple.c1 / Float64(0.00001) |",
+        "+--------------------------------------+",
+        "| 1.000000000000                       |",
+        "| 2.000000000000                       |",
+        "| 2.000000000000                       |",
+        "| 3.000000000000                       |",
+        "| 3.000000000000                       |",
+        "| 3.000000000000                       |",
+        "| 4.000000000000                       |",
+        "| 4.000000000000                       |",
+        "| 4.000000000000                       |",
+        "| 4.000000000000                       |",
+        "| 5.000000000000                       |",
+        "| 5.000000000000                       |",
+        "| 5.000000000000                       |",
+        "| 5.000000000000                       |",
+        "| 5.000000000000                       |",
+        "+--------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -609,25 +609,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+----------------------------------------------------------------+",
-        "| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
-        "+----------------------------------------------------------------+",
-        "| 0.0000040                                                      |",
-        "| 0.0000050                                                      |",
-        "| 0.0000090                                                      |",
-        "| 0.0000020                                                      |",
-        "| 0.0000050                                                      |",
-        "| 0.0000010                                                      |",
-        "| 0.0000040                                                      |",
-        "| 0.0000000                                                      |",
-        "| 0.0000000                                                      |",
-        "| 0.0000040                                                      |",
-        "| 0.0000020                                                      |",
-        "| 0.0000080                                                      |",
-        "| 0.0000030                                                      |",
-        "| 0.0000080                                                      |",
-        "| 0.0000000                                                      |",
-        "+----------------------------------------------------------------+",
+        "+--------------------------------------+",
+        "| decimal_simple.c5 % Float64(0.00001) |",
+        "+--------------------------------------+",
+        "| 0.0000040                            |",
+        "| 0.0000050                            |",
+        "| 0.0000090                            |",
+        "| 0.0000020                            |",
+        "| 0.0000050                            |",
+        "| 0.0000010                            |",
+        "| 0.0000040                            |",
+        "| 0.0000000                            |",
+        "| 0.0000000                            |",
+        "| 0.0000040                            |",
+        "| 0.0000020                            |",
+        "| 0.0000080                            |",
+        "| 0.0000030                            |",
+        "| 0.0000080                            |",
+        "| 0.0000000                            |",
+        "+--------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index 894d45564..a63839c2f 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
     let expected = "\
     Sort: #revenue DESC NULLS FIRST\
     \n  Projection: #customer.c_custkey, #customer.c_name, 
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, 
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, 
#customer.c_comment\
-    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) - 
#lineitem.l_discount)]]\
+    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS 
Float64) - #lineitem.l_discount)]]\
     \n      Inner Join: #customer.c_nationkey = #nation.n_nationkey\
     \n        Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
     \n          Inner Join: #customer.c_custkey = #orders.o_custkey\
@@ -663,7 +663,7 @@ order by
     \n          Filter: #lineitem.l_returnflag = Utf8(\"R\")\
     \n            TableScan: lineitem projection=[l_orderkey, l_extendedprice, 
l_discount, l_returnflag], partial_filters=[#lineitem.l_returnflag = 
Utf8(\"R\")]\
     \n        TableScan: nation projection=[n_nationkey, n_name]";
-    assert_eq!(format!("{:?}", plan.unwrap()), expected);
+    assert_eq!(expected, format!("{:?}", plan.unwrap()),);
 
     Ok(())
 }
@@ -694,7 +694,7 @@ async fn test_physical_plan_display_indent() {
         "            RepartitionExec: partitioning=Hash([Column { name: 
\"c1\", index: 0 }], 9000)",
         "              AggregateExec: mode=Partial, gby=[c1@0 as c1], 
aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
         "                CoalesceBatchesExec: target_batch_size=4096",
-        "                  FilterExec: c12@1 < CAST(10 AS Float64)",
+        "                  FilterExec: c12@1 < 10",
         "                    RepartitionExec: 
partitioning=RoundRobinBatch(9000)",
         "                      CsvExec: 
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, 
limit=None, projection=[c1, c12]",
     ];
diff --git a/datafusion/core/tests/sql/expr.rs 
b/datafusion/core/tests/sql/expr.rs
index 0c59724bd..3ca2c4738 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -247,8 +247,8 @@ async fn query_not() -> Result<()> {
 async fn csv_query_sum_cast() {
     let ctx = SessionContext::new();
     register_aggregate_csv_by_sql(&ctx).await;
-    // c8 = i32; c9 = i64
-    let sql = "SELECT c8 + c9 FROM aggregate_test_100";
+    // c8 = i32; c6 = i64
+    let sql = "SELECT c8 + c6 FROM aggregate_test_100";
     // check that the physical and logical schemas are equal
     execute(&ctx, sql).await;
 }
diff --git a/datafusion/core/tests/sql/functions.rs 
b/datafusion/core/tests/sql/functions.rs
index e7bcb24c7..802810d64 100644
--- a/datafusion/core/tests/sql/functions.rs
+++ b/datafusion/core/tests/sql/functions.rs
@@ -43,12 +43,12 @@ async fn csv_query_cast() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+-----------------------------------------+",
-        "| CAST(aggregate_test_100.c12 AS Float32) |",
-        "+-----------------------------------------+",
-        "| 0.39144436                              |",
-        "| 0.3887028                               |",
-        "+-----------------------------------------+",
+        "+------------------------+",
+        "| aggregate_test_100.c12 |",
+        "+------------------------+",
+        "| 0.39144436             |",
+        "| 0.3887028              |",
+        "+------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
@@ -64,12 +64,12 @@ async fn csv_query_cast_literal() -> Result<()> {
     let actual = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+--------------------+---------------------------+",
-        "| c12                | CAST(Int64(1) AS Float32) |",
-        "+--------------------+---------------------------+",
-        "| 0.9294097332465232 | 1                         |",
-        "| 0.3114712539863804 | 1                         |",
-        "+--------------------+---------------------------+",
+        "+--------------------+----------+",
+        "| c12                | Int64(1) |",
+        "+--------------------+----------+",
+        "| 0.9294097332465232 | 1        |",
+        "| 0.3114712539863804 | 1        |",
+        "+--------------------+----------+",
     ];
 
     assert_batches_eq!(expected, &actual);
@@ -98,14 +98,14 @@ async fn query_concat() -> Result<()> {
     let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----------------------------------------------------+",
-        "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |",
-        "+----------------------------------------------------+",
-        "| -hi-0                                              |",
-        "| a-hi-1                                             |",
-        "| aa-hi-                                             |",
-        "| aaa-hi-3                                           |",
-        "+----------------------------------------------------+",
+        "+--------------------------------------+",
+        "| concat(test.c1,Utf8(\"-hi-\"),test.c2) |",
+        "+--------------------------------------+",
+        "| -hi-0                                |",
+        "| a-hi-1                               |",
+        "| aa-hi-                               |",
+        "| aaa-hi-3                             |",
+        "+--------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
     Ok(())
@@ -133,14 +133,14 @@ async fn query_array() -> Result<()> {
     let sql = "SELECT make_array(c1, cast(c2 as varchar)) FROM test";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+------------------------------------------+",
-        "| makearray(test.c1,CAST(test.c2 AS Utf8)) |",
-        "+------------------------------------------+",
-        "| [, 0]                                    |",
-        "| [a, 1]                                   |",
-        "| [aa, ]                                   |",
-        "| [aaa, 3]                                 |",
-        "+------------------------------------------+",
+        "+----------------------------+",
+        "| makearray(test.c1,test.c2) |",
+        "+----------------------------+",
+        "| [, 0]                      |",
+        "| [a, 1]                     |",
+        "| [aa, ]                     |",
+        "| [aaa, 3]                   |",
+        "+----------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
     Ok(())
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index b899ac220..4ff29ea39 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1438,9 +1438,9 @@ async fn reduce_left_join_1() -> Result<()> {
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "    Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
-        "      Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]",
+        "      Filter: CAST(#t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "      Filter: #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
+        "      Filter: CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, 
t2_name:Utf8;N, t2_int:UInt32;N]",
         "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -1481,7 +1481,7 @@ async fn reduce_left_join_2() -> Result<()> {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id, 
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        "    Filter: #t2.t2_int < Int64(10) OR #t1.t1_int > Int64(2) AND 
#t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, 
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR CAST(#t1.t1_int 
AS Int64) > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
         "      Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
@@ -1528,9 +1528,9 @@ async fn reduce_left_join_3() -> Result<()> {
         "      Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, alias=t3 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "        Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, alias=t3 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "          Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
-        "            Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]",
+        "            Filter: CAST(#t1.t1_id AS Int64) < Int64(100) 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "              TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "            Filter: #t2.t2_int < Int64(3) AND #t2.t2_id < Int64(100) 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "            Filter: CAST(#t2.t2_int AS Int64) < Int64(3) AND 
CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
         "              TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "      TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
diff --git a/datafusion/core/tests/sql/parquet.rs 
b/datafusion/core/tests/sql/parquet.rs
index 51304f608..8bec4f1dd 100644
--- a/datafusion/core/tests/sql/parquet.rs
+++ b/datafusion/core/tests/sql/parquet.rs
@@ -31,18 +31,18 @@ async fn parquet_query() {
     let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+----+-----------------------------------------+",
-        "| id | CAST(alltypes_plain.string_col AS Utf8) |",
-        "+----+-----------------------------------------+",
-        "| 4  | 0                                       |",
-        "| 5  | 1                                       |",
-        "| 6  | 0                                       |",
-        "| 7  | 1                                       |",
-        "| 2  | 0                                       |",
-        "| 3  | 1                                       |",
-        "| 0  | 0                                       |",
-        "| 1  | 1                                       |",
-        "+----+-----------------------------------------+",
+        "+----+---------------------------+",
+        "| id | alltypes_plain.string_col |",
+        "+----+---------------------------+",
+        "| 4  | 0                         |",
+        "| 5  | 1                         |",
+        "| 6  | 0                         |",
+        "| 7  | 1                         |",
+        "| 2  | 0                         |",
+        "| 3  | 1                         |",
+        "| 0  | 0                         |",
+        "| 1  | 1                         |",
+        "+----+---------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/predicates.rs 
b/datafusion/core/tests/sql/predicates.rs
index f7bdc41a9..3c11b690d 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -428,7 +428,7 @@ async fn multiple_or_predicates() -> Result<()> {
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: #lineitem.l_partkey [l_partkey:Int64]",
         "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[#part.p_partkey = 
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, 
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
-        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND 
#lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND 
Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= 
Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN 
Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND 
#lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int6 [...]
+        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND 
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN 
Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND 
#lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= 
CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR 
#part.p_brand = Utf8(\"Brand#34\") AN [...]
         "        CrossJoin: [l_partkey:Int64, l_quantity:Float64, 
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
         "          TableScan: lineitem projection=[l_partkey, l_quantity] 
[l_partkey:Int64, l_quantity:Float64]",
         "          TableScan: part projection=[p_partkey, p_brand, p_size] 
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
diff --git a/datafusion/core/tests/sql/subqueries.rs 
b/datafusion/core/tests/sql/subqueries.rs
index d85a26932..58561de12 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -328,7 +328,7 @@ order by s_name;
         Filter: #nation.n_name = Utf8("CANADA")
           TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("CANADA")]
       Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
-        Filter: #partsupp.ps_availqty > #__sq_3.__value
+        Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value
           Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, 
#partsupp.ps_suppkey = #__sq_3.l_suppkey
             Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
@@ -436,18 +436,16 @@ order by value desc;
         .create_logical_plan(sql)
         .map_err(|e| format!("{:?} at {}", e, "error"))
         .unwrap();
-    println!("before:\n{}", plan.display_indent());
     let plan = ctx
         .optimize(&plan)
         .map_err(|e| format!("{:?} at {}", e, "error"))
         .unwrap();
     let actual = format!("{}", plan.display_indent());
-    println!("after:\n{}", actual);
     let expected = r#"Sort: #value DESC NULLS FIRST
   Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * 
partsupp.ps_availqty) AS value
     Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) > 
#__sq_1.__value
       CrossJoin:
-        Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]]
+        Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
           Inner Join: #supplier.s_nationkey = #nation.n_nationkey
             Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty, ps_supplycost]
@@ -455,7 +453,7 @@ order by value desc;
             Filter: #nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("GERMANY")]
         Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * 
Float64(0.0001) AS __value, alias=__sq_1
-          Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * 
#partsupp.ps_availqty)]]
+          Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost * 
CAST(#partsupp.ps_availqty AS Float64))]]
             Inner Join: #supplier.s_nationkey = #nation.n_nationkey
               Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
                 TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty, ps_supplycost]
diff --git a/datafusion/core/tests/sql/timestamp.rs 
b/datafusion/core/tests/sql/timestamp.rs
index 123342c42..847d63e81 100644
--- a/datafusion/core/tests/sql/timestamp.rs
+++ b/datafusion/core/tests/sql/timestamp.rs
@@ -1176,11 +1176,11 @@ async fn to_timestamp_i32() -> Result<()> {
     let results = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+--------------------------------------+",
-        "| totimestamp(CAST(Int64(1) AS Int32)) |",
-        "+--------------------------------------+",
-        "| 1970-01-01 00:00:00.000000001        |",
-        "+--------------------------------------+",
+        "+-------------------------------+",
+        "| totimestamp(Int64(1))         |",
+        "+-------------------------------+",
+        "| 1970-01-01 00:00:00.000000001 |",
+        "+-------------------------------+",
     ];
 
     assert_batches_eq!(expected, &results);
@@ -1196,11 +1196,11 @@ async fn to_timestamp_micros_i32() -> Result<()> {
     let results = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+--------------------------------------------+",
-        "| totimestampmicros(CAST(Int64(1) AS Int32)) |",
-        "+--------------------------------------------+",
-        "| 1970-01-01 00:00:00.000001                 |",
-        "+--------------------------------------------+",
+        "+-----------------------------+",
+        "| totimestampmicros(Int64(1)) |",
+        "+-----------------------------+",
+        "| 1970-01-01 00:00:00.000001  |",
+        "+-----------------------------+",
     ];
 
     assert_batches_eq!(expected, &results);
@@ -1216,11 +1216,11 @@ async fn to_timestamp_millis_i32() -> Result<()> {
     let results = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+--------------------------------------------+",
-        "| totimestampmillis(CAST(Int64(1) AS Int32)) |",
-        "+--------------------------------------------+",
-        "| 1970-01-01 00:00:00.001                    |",
-        "+--------------------------------------------+",
+        "+-----------------------------+",
+        "| totimestampmillis(Int64(1)) |",
+        "+-----------------------------+",
+        "| 1970-01-01 00:00:00.001     |",
+        "+-----------------------------+",
     ];
 
     assert_batches_eq!(expected, &results);
@@ -1236,11 +1236,11 @@ async fn to_timestamp_seconds_i32() -> Result<()> {
     let results = execute_to_batches(&ctx, sql).await;
 
     let expected = vec![
-        "+---------------------------------------------+",
-        "| totimestampseconds(CAST(Int64(1) AS Int32)) |",
-        "+---------------------------------------------+",
-        "| 1970-01-01 00:00:01                         |",
-        "+---------------------------------------------+",
+        "+------------------------------+",
+        "| totimestampseconds(Int64(1)) |",
+        "+------------------------------+",
+        "| 1970-01-01 00:00:01          |",
+        "+------------------------------+",
     ];
 
     assert_batches_eq!(expected, &results);
@@ -1512,11 +1512,11 @@ async fn cast_timestamp_before_1970() -> Result<()> {
     let sql = "select cast('1969-01-01T00:00:00Z' as timestamp);";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+-------------------------------------------------------------------+",
-        "| CAST(Utf8(\"1969-01-01T00:00:00Z\") AS Timestamp(Nanosecond, None)) |",
-        "+-------------------------------------------------------------------+",
-        "| 1969-01-01 00:00:00                                               |",
-        "+-------------------------------------------------------------------+",
+        "+------------------------------+",
+        "| Utf8(\"1969-01-01T00:00:00Z\") |",
+        "+------------------------------+",
+        "| 1969-01-01 00:00:00          |",
+        "+------------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
@@ -1524,11 +1524,11 @@ async fn cast_timestamp_before_1970() -> Result<()> {
     let sql = "select cast('1969-01-01T00:00:00.1Z' as timestamp);";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+---------------------------------------------------------------------+",
-        "| CAST(Utf8(\"1969-01-01T00:00:00.1Z\") AS Timestamp(Nanosecond, None)) |",
-        "+---------------------------------------------------------------------+",
-        "| 1969-01-01 00:00:00.100                                             |",
-        "+---------------------------------------------------------------------+",
+        "+--------------------------------+",
+        "| Utf8(\"1969-01-01T00:00:00.1Z\") |",
+        "+--------------------------------+",
+        "| 1969-01-01 00:00:00.100        |",
+        "+--------------------------------+",
     ];
 
     assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/window.rs 
b/datafusion/core/tests/sql/window.rs
index 1c909fa71..6a1f39a03 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -63,15 +63,15 @@ async fn csv_query_window_with_partition_by() -> Result<()> 
{
                limit 5";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
-        "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
-        "| c9        | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |",
-        "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
-        "| 28774375  | -16110                                    | -16110                                    | 1                                           | -16110                                    | -16110                                    |",
-        "| 63044568  | 3917                                      | 3917                                      | 1                                           | 3917                                      | 3917                                      |",
-        "| 141047417 | -38455                                    | -19227.5                                  | 2                                           | -16974                                    | -21481                                    |",
-        "| 141680161 | -1114                                     | -1114                                     | 1                                           | -1114                                     | -1114                                     |",
-        "| 145294611 | 15673                                     | 15673                                     | 1                                           | 15673                                     | 15673                                     |",
-        "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
+        "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
+        "| c9        | SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4) | COUNT(aggregate_test_100.c4) | MAX(aggregate_test_100.c4) | MIN(aggregate_test_100.c4) |",
+        "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
+        "| 28774375  | -16110                     | -16110                     | 1                            | -16110                     | -16110                     |",
+        "| 63044568  | 3917                       | 3917                       | 1                            | 3917                       | 3917                       |",
+        "| 141047417 | -38455                     | -19227.5                   | 2                            | -16974                     | -21481                     |",
+        "| 141680161 | -1114                      | -1114                      | 1                            | -1114                      | -1114                      |",
+        "| 145294611 | 15673                      | 15673                      | 1                            | 15673                      | 15673                      |",
+        "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
     Ok(())
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 8c6f26887..6226887e8 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -923,13 +923,13 @@ fn create_name(e: &Expr) -> Result<String> {
             name += "END";
             Ok(name)
         }
-        Expr::Cast { expr, data_type } => {
-            let expr = create_name(expr)?;
-            Ok(format!("CAST({} AS {:?})", expr, data_type))
+        Expr::Cast { expr, .. } => {
+            // CAST does not change the expression name
+            create_name(expr)
         }
-        Expr::TryCast { expr, data_type } => {
-            let expr = create_name(expr)?;
-            Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+        Expr::TryCast { expr, .. } => {
+            // TRY_CAST does not change the expression name
+            create_name(expr)
         }
         Expr::Not(expr) => {
             let expr = create_name(expr)?;
@@ -1086,7 +1086,8 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
 #[cfg(test)]
 mod test {
     use crate::expr_fn::col;
-    use crate::{case, lit};
+    use crate::{case, lit, Expr};
+    use arrow::datatypes::DataType;
     use datafusion_common::{Result, ScalarValue};
 
     #[test]
@@ -1101,6 +1102,20 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn format_cast() -> Result<()> {
+        let expr = Expr::Cast {
+            expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))),
+            data_type: DataType::Utf8,
+        };
+        assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{}", expr));
+        assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{:?}", expr));
+        // note that CAST intentionally has a name that is different from its `Display`
+        // representation. CAST does not change the name of expressions.
+        assert_eq!("Float32(1.23)", expr.name()?);
+        Ok(())
+    }
+
     #[test]
     fn test_not() {
         assert_eq!(lit(1).not(), !lit(1));
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 8d6da350a..171381659 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -31,6 +31,7 @@ pub mod scalar_subquery_to_join;
 pub mod simplify_expressions;
 pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
+pub mod type_coercion;
 pub mod utils;
 
 pub mod pre_cast_lit_in_comparison;
diff --git a/datafusion/optimizer/src/simplify_expressions.rs 
b/datafusion/optimizer/src/simplify_expressions.rs
index 334ec6182..1c826d7c3 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -1902,7 +1902,7 @@ mod tests {
             .build()
             .unwrap();
 
-        let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\
+        let expected = "Projection: Int32(0) AS Utf8(\"0\")\
             \n  TableScan: test";
         let actual = get_optimized_plan_formatted(&plan, &Utc::now());
         assert_eq!(expected, actual);
@@ -1949,7 +1949,7 @@ mod tests {
             time.timestamp_nanos()
         );
 
-        assert_eq!(actual, expected);
+        assert_eq!(expected, actual);
     }
 
     #[test]
@@ -1971,7 +1971,7 @@ mod tests {
             "Projection: NOT #test.a AS Boolean(true) OR Boolean(false) != 
test.a\
                         \n  TableScan: test";
 
-        assert_eq!(actual, expected);
+        assert_eq!(expected, actual);
     }
 
     #[test]
@@ -1993,7 +1993,7 @@ mod tests {
 
         // Note that constant folder runs and folds the entire
         // expression down to a single constant (true)
-        let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) < 
CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
+        let expected = "Filter: Boolean(true) AS now() < 
totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) + Int32(50000)\
                         \n  TableScan: test";
         let actual = get_optimized_plan_formatted(&plan, &time);
 
@@ -2025,11 +2025,11 @@ mod tests {
 
         // Note that constant folder runs and folds the entire
         // expression down to a single constant (true)
-        let expected = r#"Projection: Date32("18636") AS 
CAST(totimestamp(Utf8("2020-09-08T12:05:00+00:00")) AS Date32) + 
IntervalDayTime("528280977408")
+        let expected = r#"Projection: Date32("18636") AS 
totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
   TableScan: test"#;
         let actual = get_optimized_plan_formatted(&plan, &time);
 
-        assert_eq!(actual, expected);
+        assert_eq!(expected, actual);
     }
 
     #[test]
diff --git a/datafusion/optimizer/src/type_coercion.rs 
b/datafusion/optimizer/src/type_coercion.rs
new file mode 100644
index 000000000..d9f161599
--- /dev/null
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule for type validation and coercion
+
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchema, DFSchemaRef, Result};
+use datafusion_expr::binary_rule::coerce_types;
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
+use datafusion_expr::logical_plan::builder::build_join_schema;
+use datafusion_expr::logical_plan::JoinType;
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::ExprSchemable;
+use datafusion_expr::{Expr, LogicalPlan};
+
+#[derive(Default)]
+pub struct TypeCoercion {}
+
+impl TypeCoercion {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for TypeCoercion {
+    fn name(&self) -> &str {
+        "TypeCoercion"
+    }
+
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        // optimize child plans first
+        let new_inputs = plan
+            .inputs()
+            .iter()
+            .map(|p| self.optimize(p, optimizer_config))
+            .collect::<Result<Vec<_>>>()?;
+
+        let schema = match new_inputs.len() {
+            1 => new_inputs[0].schema().clone(),
+            2 => DFSchemaRef::new(build_join_schema(
+                new_inputs[0].schema(),
+                new_inputs[1].schema(),
+                &JoinType::Inner,
+            )?),
+            _ => DFSchemaRef::new(DFSchema::empty()),
+        };
+
+        let mut expr_rewrite = TypeCoercionRewriter { schema };
+
+        let new_expr = plan
+            .expressions()
+            .into_iter()
+            .map(|expr| expr.rewrite(&mut expr_rewrite))
+            .collect::<Result<Vec<_>>>()?;
+
+        from_plan(plan, &new_expr, &new_inputs)
+    }
+}
+
+struct TypeCoercionRewriter {
+    schema: DFSchemaRef,
+}
+
+impl ExprRewriter for TypeCoercionRewriter {
+    fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
+        Ok(RewriteRecursion::Continue)
+    }
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        match &expr {
+            Expr::BinaryExpr { left, op, right } => {
+                let left_type = left.get_type(&self.schema)?;
+                let right_type = right.get_type(&self.schema)?;
+                match right_type {
+                    DataType::Interval(_) => {
+                        // we don't want to cast intervals because that breaks
+                        // the logic in the physical planner
+                        Ok(expr)
+                    }
+                    _ => {
+                        let coerced_type = coerce_types(&left_type, op, &right_type)?;
+                        let left = left.clone().cast_to(&coerced_type, &self.schema)?;
+                        let right = right.clone().cast_to(&coerced_type, &self.schema)?;
+                        match (&left, &right) {
+                            (Expr::Cast { .. }, _) | (_, Expr::Cast { .. }) => {
+                                Ok(Expr::BinaryExpr {
+                                    left: Box::new(left),
+                                    op: *op,
+                                    right: Box::new(right),
+                                })
+                            }
+                            _ => {
+                                // no cast was added so we return the original expression
+                                Ok(expr)
+                            }
+                        }
+                    }
+                }
+            }
+            _ => Ok(expr),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::type_coercion::TypeCoercion;
+    use crate::{OptimizerConfig, OptimizerRule};
+    use datafusion_common::{DFSchema, Result};
+    use datafusion_expr::logical_plan::{EmptyRelation, Projection};
+    use datafusion_expr::{lit, LogicalPlan};
+    use std::sync::Arc;
+
+    #[test]
+    fn simple_case() -> Result<()> {
+        let expr = lit(1.2_f64).lt(lit(2_u32));
+        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        }));
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn nested_case() -> Result<()> {
+        let expr = lit(1.2_f64).lt(lit(2_u32));
+        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        }));
+        let plan = LogicalPlan::Projection(Projection::try_new(
+            vec![expr.clone().or(expr)],
+            empty,
+            None,
+        )?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\
+            \n  EmptyRelation", &format!("{:?}", plan));
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index 4226364c9..c344982b3 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -47,7 +47,13 @@ pub fn create_physical_expr(
     input_schema: &Schema,
     execution_props: &ExecutionProps,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    assert_eq!(input_schema.fields.len(), input_dfschema.fields().len());
+    if input_schema.fields.len() != input_dfschema.fields().len() {
+        return Err(DataFusionError::Internal(
+            "create_physical_expr passed Arrow schema and DataFusion \
+            schema with different number of fields"
+                .to_string(),
+        ));
+    }
     match e {
         Expr::Alias(expr, ..) => Ok(create_physical_expr(
             expr,

Reply via email to