This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 191d8b72f Add optimizer rule for type coercion (binary operations only) (#3222)
191d8b72f is described below
commit 191d8b72ff7500872953b2749937abcd1c06f848
Author: Andy Grove <[email protected]>
AuthorDate: Tue Sep 6 06:16:37 2022 -0600
Add optimizer rule for type coercion (binary operations only) (#3222)
* Add binary type coercion to logical plan and do not allow CAST to change an expression name
* fix tests
* update avro tests
* add reference to GitHub issue
* unignore timestamp_add_interval_months
* fix: update tests to use correct column types
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/execution/context.rs | 4 +
datafusion/core/src/physical_plan/planner.rs | 12 +-
datafusion/core/tests/dataframe_functions.rs | 16 +--
datafusion/core/tests/parquet_pruning.rs | 1 +
datafusion/core/tests/sql/aggregates.rs | 29 +++-
datafusion/core/tests/sql/avro.rs | 64 ++++-----
datafusion/core/tests/sql/decimal.rs | 106 +++++++-------
datafusion/core/tests/sql/explain_analyze.rs | 6 +-
datafusion/core/tests/sql/expr.rs | 4 +-
datafusion/core/tests/sql/functions.rs | 56 ++++----
datafusion/core/tests/sql/joins.rs | 10 +-
datafusion/core/tests/sql/parquet.rs | 24 ++--
datafusion/core/tests/sql/predicates.rs | 2 +-
datafusion/core/tests/sql/subqueries.rs | 8 +-
datafusion/core/tests/sql/timestamp.rs | 60 ++++----
datafusion/core/tests/sql/window.rs | 18 +--
datafusion/expr/src/expr.rs | 29 +++-
datafusion/optimizer/src/lib.rs | 1 +
datafusion/optimizer/src/simplify_expressions.rs | 12 +-
datafusion/optimizer/src/type_coercion.rs | 170 +++++++++++++++++++++++
datafusion/physical-expr/src/planner.rs | 8 +-
21 files changed, 427 insertions(+), 213 deletions(-)
diff --git a/datafusion/core/src/execution/context.rs
b/datafusion/core/src/execution/context.rs
index 2a9d01124..a6a508d36 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -111,6 +111,7 @@ use
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
+use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
@@ -1433,6 +1434,9 @@ impl SessionState {
}
rules.push(Arc::new(ReduceOuterJoin::new()));
rules.push(Arc::new(FilterPushDown::new()));
+ // we do type coercion after filter push down so that we don't push
CAST filters to Parquet
+ // until https://github.com/apache/arrow-datafusion/issues/3289 is
resolved
+ rules.push(Arc::new(TypeCoercion::new()));
rules.push(Arc::new(LimitPushDown::new()));
rules.push(Arc::new(SingleDistinctToGroupBy::new()));
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index 8d7e0e9e4..a2d868361 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -128,13 +128,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
name += "END";
Ok(name)
}
- Expr::Cast { expr, data_type } => {
- let expr = create_physical_name(expr, false)?;
- Ok(format!("CAST({} AS {:?})", expr, data_type))
+ Expr::Cast { expr, .. } => {
+ // CAST does not change the expression name
+ create_physical_name(expr, false)
}
- Expr::TryCast { expr, data_type } => {
- let expr = create_physical_name(expr, false)?;
- Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+ Expr::TryCast { expr, .. } => {
+ // CAST does not change the expression name
+ create_physical_name(expr, false)
}
Expr::Not(expr) => {
let expr = create_physical_name(expr, false)?;
diff --git a/datafusion/core/tests/dataframe_functions.rs
b/datafusion/core/tests/dataframe_functions.rs
index 19694285c..0d3631b18 100644
--- a/datafusion/core/tests/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe_functions.rs
@@ -667,14 +667,14 @@ async fn test_fn_substr() -> Result<()> {
async fn test_cast() -> Result<()> {
let expr = cast(col("b"), DataType::Float64);
let expected = vec![
- "+-------------------------+",
- "| CAST(test.b AS Float64) |",
- "+-------------------------+",
- "| 1 |",
- "| 10 |",
- "| 10 |",
- "| 100 |",
- "+-------------------------+",
+ "+--------+",
+ "| test.b |",
+ "+--------+",
+ "| 1 |",
+ "| 10 |",
+ "| 10 |",
+ "| 100 |",
+ "+--------+",
];
assert_fn_batches!(expr, expected);
diff --git a/datafusion/core/tests/parquet_pruning.rs
b/datafusion/core/tests/parquet_pruning.rs
index 0c3acf9fb..b6c286763 100644
--- a/datafusion/core/tests/parquet_pruning.rs
+++ b/datafusion/core/tests/parquet_pruning.rs
@@ -647,6 +647,7 @@ impl ContextWithParquet {
let pretty_input = pretty_format_batches(&input).unwrap().to_string();
let logical_plan = self.ctx.optimize(&logical_plan).expect("optimizing
plan");
+
let physical_plan = self
.ctx
.create_physical_plan(&logical_plan)
diff --git a/datafusion/core/tests/sql/aggregates.rs
b/datafusion/core/tests/sql/aggregates.rs
index 7e0e785da..357addbc0 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -462,11 +462,11 @@ async fn csv_query_external_table_sum() {
"SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM
aggregate_test_100";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
-
"+-------------------------------------------+-------------------------------------------+",
- "| SUM(CAST(aggregate_test_100.c7 AS Int64)) |
SUM(CAST(aggregate_test_100.c8 AS Int64)) |",
-
"+-------------------------------------------+-------------------------------------------+",
- "| 13060 | 3017641
|",
-
"+-------------------------------------------+-------------------------------------------+",
+ "+----------------------------+----------------------------+",
+ "| SUM(aggregate_test_100.c7) | SUM(aggregate_test_100.c8) |",
+ "+----------------------------+----------------------------+",
+ "| 13060 | 3017641 |",
+ "+----------------------------+----------------------------+",
];
assert_batches_eq!(expected, &actual);
}
@@ -555,6 +555,7 @@ async fn csv_query_count_one() {
}
#[tokio::test]
+#[ignore] // https://github.com/apache/arrow-datafusion/issues/3353
async fn csv_query_approx_count() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;
@@ -571,6 +572,24 @@ async fn csv_query_approx_count() -> Result<()> {
Ok(())
}
+#[tokio::test]
+async fn csv_query_approx_count_dupe_expr_aliased() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv(&ctx).await?;
+ let sql =
+ "SELECT approx_distinct(c9) a, approx_distinct(c9) b FROM
aggregate_test_100";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+-----+",
+ "| a | b |",
+ "+-----+-----+",
+ "| 100 | 100 |",
+ "+-----+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
// This test executes the APPROX_PERCENTILE_CONT aggregation against the test
// data, asserting the estimated quantiles are ±5% their actual values.
//
diff --git a/datafusion/core/tests/sql/avro.rs
b/datafusion/core/tests/sql/avro.rs
index f4ff4cd7c..8fdef28bd 100644
--- a/datafusion/core/tests/sql/avro.rs
+++ b/datafusion/core/tests/sql/avro.rs
@@ -37,18 +37,18 @@ async fn avro_query() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----+-----------------------------------------+",
- "| id | CAST(alltypes_plain.string_col AS Utf8) |",
- "+----+-----------------------------------------+",
- "| 4 | 0 |",
- "| 5 | 1 |",
- "| 6 | 0 |",
- "| 7 | 1 |",
- "| 2 | 0 |",
- "| 3 | 1 |",
- "| 0 | 0 |",
- "| 1 | 1 |",
- "+----+-----------------------------------------+",
+ "+----+---------------------------+",
+ "| id | alltypes_plain.string_col |",
+ "+----+---------------------------+",
+ "| 4 | 0 |",
+ "| 5 | 1 |",
+ "| 6 | 0 |",
+ "| 7 | 1 |",
+ "| 2 | 0 |",
+ "| 3 | 1 |",
+ "| 0 | 0 |",
+ "| 1 | 1 |",
+ "+----+---------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -84,26 +84,26 @@ async fn avro_query_multiple_files() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----+-----------------------------------------+",
- "| id | CAST(alltypes_plain.string_col AS Utf8) |",
- "+----+-----------------------------------------+",
- "| 4 | 0 |",
- "| 5 | 1 |",
- "| 6 | 0 |",
- "| 7 | 1 |",
- "| 2 | 0 |",
- "| 3 | 1 |",
- "| 0 | 0 |",
- "| 1 | 1 |",
- "| 4 | 0 |",
- "| 5 | 1 |",
- "| 6 | 0 |",
- "| 7 | 1 |",
- "| 2 | 0 |",
- "| 3 | 1 |",
- "| 0 | 0 |",
- "| 1 | 1 |",
- "+----+-----------------------------------------+",
+ "+----+---------------------------+",
+ "| id | alltypes_plain.string_col |",
+ "+----+---------------------------+",
+ "| 4 | 0 |",
+ "| 5 | 1 |",
+ "| 6 | 0 |",
+ "| 7 | 1 |",
+ "| 2 | 0 |",
+ "| 3 | 1 |",
+ "| 0 | 0 |",
+ "| 1 | 1 |",
+ "| 4 | 0 |",
+ "| 5 | 1 |",
+ "| 6 | 0 |",
+ "| 7 | 1 |",
+ "| 2 | 0 |",
+ "| 3 | 1 |",
+ "| 0 | 0 |",
+ "| 1 | 1 |",
+ "+----+---------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/decimal.rs
b/datafusion/core/tests/sql/decimal.rs
index 8ded8752d..7c74cdd52 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -27,11 +27,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+------------------------------------------+",
- "| CAST(Float64(1.23) AS Decimal128(10, 4)) |",
- "+------------------------------------------+",
- "| 1.2300 |",
- "+------------------------------------------+",
+ "+---------------+",
+ "| Float64(1.23) |",
+ "+---------------+",
+ "| 1.2300 |",
+ "+---------------+",
];
assert_batches_eq!(expected, &actual);
@@ -42,11 +42,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
-
"+---------------------------------------------------------------------+",
- "| CAST(CAST(Float64(1.23) AS Decimal128(10, 3)) AS Decimal128(10, 4))
|",
-
"+---------------------------------------------------------------------+",
- "| 1.2300
|",
-
"+---------------------------------------------------------------------+",
+ "+---------------+",
+ "| Float64(1.23) |",
+ "+---------------+",
+ "| 1.2300 |",
+ "+---------------+",
];
assert_batches_eq!(expected, &actual);
@@ -57,11 +57,11 @@ async fn decimal_cast() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+--------------------------------------------+",
- "| CAST(Float64(1.2345) AS Decimal128(24, 2)) |",
- "+--------------------------------------------+",
- "| 1.23 |",
- "+--------------------------------------------+",
+ "+-----------------+",
+ "| Float64(1.2345) |",
+ "+-----------------+",
+ "| 1.23 |",
+ "+-----------------+",
];
assert_batches_eq!(expected, &actual);
@@ -550,25 +550,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+----------------------------------------------------------------+",
- "| decimal_simple.c1 / CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
- "+----------------------------------------------------------------+",
- "| 1.000000000000 |",
- "| 2.000000000000 |",
- "| 2.000000000000 |",
- "| 3.000000000000 |",
- "| 3.000000000000 |",
- "| 3.000000000000 |",
- "| 4.000000000000 |",
- "| 4.000000000000 |",
- "| 4.000000000000 |",
- "| 4.000000000000 |",
- "| 5.000000000000 |",
- "| 5.000000000000 |",
- "| 5.000000000000 |",
- "| 5.000000000000 |",
- "| 5.000000000000 |",
- "+----------------------------------------------------------------+",
+ "+--------------------------------------+",
+ "| decimal_simple.c1 / Float64(0.00001) |",
+ "+--------------------------------------+",
+ "| 1.000000000000 |",
+ "| 2.000000000000 |",
+ "| 2.000000000000 |",
+ "| 3.000000000000 |",
+ "| 3.000000000000 |",
+ "| 3.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 4.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "| 5.000000000000 |",
+ "+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -609,25 +609,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+----------------------------------------------------------------+",
- "| decimal_simple.c5 % CAST(Float64(0.00001) AS Decimal128(5, 5)) |",
- "+----------------------------------------------------------------+",
- "| 0.0000040 |",
- "| 0.0000050 |",
- "| 0.0000090 |",
- "| 0.0000020 |",
- "| 0.0000050 |",
- "| 0.0000010 |",
- "| 0.0000040 |",
- "| 0.0000000 |",
- "| 0.0000000 |",
- "| 0.0000040 |",
- "| 0.0000020 |",
- "| 0.0000080 |",
- "| 0.0000030 |",
- "| 0.0000080 |",
- "| 0.0000000 |",
- "+----------------------------------------------------------------+",
+ "+--------------------------------------+",
+ "| decimal_simple.c5 % Float64(0.00001) |",
+ "+--------------------------------------+",
+ "| 0.0000040 |",
+ "| 0.0000050 |",
+ "| 0.0000090 |",
+ "| 0.0000020 |",
+ "| 0.0000050 |",
+ "| 0.0000010 |",
+ "| 0.0000040 |",
+ "| 0.0000000 |",
+ "| 0.0000000 |",
+ "| 0.0000040 |",
+ "| 0.0000020 |",
+ "| 0.0000080 |",
+ "| 0.0000030 |",
+ "| 0.0000080 |",
+ "| 0.0000000 |",
+ "+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/explain_analyze.rs
b/datafusion/core/tests/sql/explain_analyze.rs
index 894d45564..a63839c2f 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name,
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue,
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone,
#customer.c_comment\
- \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Int64(1) -
#lineitem.l_discount)]]\
+ \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS
Float64) - #lineitem.l_discount)]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
@@ -663,7 +663,7 @@ order by
\n Filter: #lineitem.l_returnflag = Utf8(\"R\")\
\n TableScan: lineitem projection=[l_orderkey, l_extendedprice,
l_discount, l_returnflag], partial_filters=[#lineitem.l_returnflag =
Utf8(\"R\")]\
\n TableScan: nation projection=[n_nationkey, n_name]";
- assert_eq!(format!("{:?}", plan.unwrap()), expected);
+ assert_eq!(expected, format!("{:?}", plan.unwrap()),);
Ok(())
}
@@ -694,7 +694,7 @@ async fn test_physical_plan_display_indent() {
" RepartitionExec: partitioning=Hash([Column { name:
\"c1\", index: 0 }], 9000)",
" AggregateExec: mode=Partial, gby=[c1@0 as c1],
aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]",
" CoalesceBatchesExec: target_batch_size=4096",
- " FilterExec: c12@1 < CAST(10 AS Float64)",
+ " FilterExec: c12@1 < 10",
" RepartitionExec:
partitioning=RoundRobinBatch(9000)",
" CsvExec:
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true,
limit=None, projection=[c1, c12]",
];
diff --git a/datafusion/core/tests/sql/expr.rs
b/datafusion/core/tests/sql/expr.rs
index 0c59724bd..3ca2c4738 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -247,8 +247,8 @@ async fn query_not() -> Result<()> {
async fn csv_query_sum_cast() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
- // c8 = i32; c9 = i64
- let sql = "SELECT c8 + c9 FROM aggregate_test_100";
+ // c8 = i32; c6 = i64
+ let sql = "SELECT c8 + c6 FROM aggregate_test_100";
// check that the physical and logical schemas are equal
execute(&ctx, sql).await;
}
diff --git a/datafusion/core/tests/sql/functions.rs
b/datafusion/core/tests/sql/functions.rs
index e7bcb24c7..802810d64 100644
--- a/datafusion/core/tests/sql/functions.rs
+++ b/datafusion/core/tests/sql/functions.rs
@@ -43,12 +43,12 @@ async fn csv_query_cast() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+-----------------------------------------+",
- "| CAST(aggregate_test_100.c12 AS Float32) |",
- "+-----------------------------------------+",
- "| 0.39144436 |",
- "| 0.3887028 |",
- "+-----------------------------------------+",
+ "+------------------------+",
+ "| aggregate_test_100.c12 |",
+ "+------------------------+",
+ "| 0.39144436 |",
+ "| 0.3887028 |",
+ "+------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -64,12 +64,12 @@ async fn csv_query_cast_literal() -> Result<()> {
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+--------------------+---------------------------+",
- "| c12 | CAST(Int64(1) AS Float32) |",
- "+--------------------+---------------------------+",
- "| 0.9294097332465232 | 1 |",
- "| 0.3114712539863804 | 1 |",
- "+--------------------+---------------------------+",
+ "+--------------------+----------+",
+ "| c12 | Int64(1) |",
+ "+--------------------+----------+",
+ "| 0.9294097332465232 | 1 |",
+ "| 0.3114712539863804 | 1 |",
+ "+--------------------+----------+",
];
assert_batches_eq!(expected, &actual);
@@ -98,14 +98,14 @@ async fn query_concat() -> Result<()> {
let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----------------------------------------------------+",
- "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |",
- "+----------------------------------------------------+",
- "| -hi-0 |",
- "| a-hi-1 |",
- "| aa-hi- |",
- "| aaa-hi-3 |",
- "+----------------------------------------------------+",
+ "+--------------------------------------+",
+ "| concat(test.c1,Utf8(\"-hi-\"),test.c2) |",
+ "+--------------------------------------+",
+ "| -hi-0 |",
+ "| a-hi-1 |",
+ "| aa-hi- |",
+ "| aaa-hi-3 |",
+ "+--------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
@@ -133,14 +133,14 @@ async fn query_array() -> Result<()> {
let sql = "SELECT make_array(c1, cast(c2 as varchar)) FROM test";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------------------------------------+",
- "| makearray(test.c1,CAST(test.c2 AS Utf8)) |",
- "+------------------------------------------+",
- "| [, 0] |",
- "| [a, 1] |",
- "| [aa, ] |",
- "| [aaa, 3] |",
- "+------------------------------------------+",
+ "+----------------------------+",
+ "| makearray(test.c1,test.c2) |",
+ "+----------------------------+",
+ "| [, 0] |",
+ "| [a, 1] |",
+ "| [aa, ] |",
+ "| [aaa, 3] |",
+ "+----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
diff --git a/datafusion/core/tests/sql/joins.rs
b/datafusion/core/tests/sql/joins.rs
index b899ac220..4ff29ea39 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1438,9 +1438,9 @@ async fn reduce_left_join_1() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id,
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
- " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N]",
+ " Filter: CAST(#t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: #t2.t2_id < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
+ " Filter: CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N,
t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
@@ -1481,7 +1481,7 @@ async fn reduce_left_join_2() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, #t2.t2_id,
#t2.t2_name, #t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: #t2.t2_int < Int64(10) OR #t1.t1_int > Int64(2) AND
#t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: CAST(#t2.t2_int AS Int64) < Int64(10) OR CAST(#t1.t1_int
AS Int64) > Int64(2) AND #t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
" Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
@@ -1528,9 +1528,9 @@ async fn reduce_left_join_3() -> Result<()> {
" Projection: #t3.t1_id, #t3.t1_name, #t3.t1_int, alias=t3
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: #t1.t1_id, #t1.t1_name, #t1.t1_int, alias=t3
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Inner Join: #t1.t1_id = #t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
- " Filter: #t1.t1_id < Int64(100) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: CAST(#t1.t1_id AS Int64) < Int64(100)
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: #t2.t2_int < Int64(3) AND #t2.t2_id < Int64(100)
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: CAST(#t2.t2_int AS Int64) < Int64(3) AND
CAST(#t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
diff --git a/datafusion/core/tests/sql/parquet.rs
b/datafusion/core/tests/sql/parquet.rs
index 51304f608..8bec4f1dd 100644
--- a/datafusion/core/tests/sql/parquet.rs
+++ b/datafusion/core/tests/sql/parquet.rs
@@ -31,18 +31,18 @@ async fn parquet_query() {
let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----+-----------------------------------------+",
- "| id | CAST(alltypes_plain.string_col AS Utf8) |",
- "+----+-----------------------------------------+",
- "| 4 | 0 |",
- "| 5 | 1 |",
- "| 6 | 0 |",
- "| 7 | 1 |",
- "| 2 | 0 |",
- "| 3 | 1 |",
- "| 0 | 0 |",
- "| 1 | 1 |",
- "+----+-----------------------------------------+",
+ "+----+---------------------------+",
+ "| id | alltypes_plain.string_col |",
+ "+----+---------------------------+",
+ "| 4 | 0 |",
+ "| 5 | 1 |",
+ "| 6 | 0 |",
+ "| 7 | 1 |",
+ "| 2 | 0 |",
+ "| 3 | 1 |",
+ "| 0 | 0 |",
+ "| 1 | 1 |",
+ "+----+---------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/predicates.rs
b/datafusion/core/tests/sql/predicates.rs
index f7bdc41a9..3c11b690d 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -428,7 +428,7 @@ async fn multiple_or_predicates() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS
#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[#part.p_partkey =
#lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N,
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
- " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND
#lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND
Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >=
Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN
Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND
#lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int6 [...]
+ " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND
#lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size BETWEEN
Int64(1) AND Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND
#lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <=
CAST(Int64(20) AS Float64) AND #part.p_size BETWEEN Int64(1) AND Int64(10) OR
#part.p_brand = Utf8(\"Brand#34\") AN [...]
" CrossJoin: [l_partkey:Int64, l_quantity:Float64,
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity]
[l_partkey:Int64, l_quantity:Float64]",
" TableScan: part projection=[p_partkey, p_brand, p_size]
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
diff --git a/datafusion/core/tests/sql/subqueries.rs
b/datafusion/core/tests/sql/subqueries.rs
index d85a26932..58561de12 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -328,7 +328,7 @@ order by s_name;
Filter: #nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[#nation.n_name = Utf8("CANADA")]
Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
- Filter: #partsupp.ps_availqty > #__sq_3.__value
+ Filter: CAST(#partsupp.ps_availqty AS Float64) > #__sq_3.__value
Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey,
#partsupp.ps_suppkey = #__sq_3.l_suppkey
Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty]
@@ -436,18 +436,16 @@ order by value desc;
.create_logical_plan(sql)
.map_err(|e| format!("{:?} at {}", e, "error"))
.unwrap();
- println!("before:\n{}", plan.display_indent());
let plan = ctx
.optimize(&plan)
.map_err(|e| format!("{:?} at {}", e, "error"))
.unwrap();
let actual = format!("{}", plan.display_indent());
- println!("after:\n{}", actual);
let expected = r#"Sort: #value DESC NULLS FIRST
Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS value
Filter: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) >
#__sq_1.__value
CrossJoin:
- Aggregate: groupBy=[[#partsupp.ps_partkey]],
aggr=[[SUM(#partsupp.ps_supplycost * #partsupp.ps_availqty)]]
+ Aggregate: groupBy=[[#partsupp.ps_partkey]],
aggr=[[SUM(#partsupp.ps_supplycost * CAST(#partsupp.ps_availqty AS Float64))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty, ps_supplycost]
@@ -455,7 +453,7 @@ order by value desc;
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[#nation.n_name = Utf8("GERMANY")]
Projection: #SUM(partsupp.ps_supplycost * partsupp.ps_availqty) *
Float64(0.0001) AS __value, alias=__sq_1
- Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost *
#partsupp.ps_availqty)]]
+ Aggregate: groupBy=[[]], aggr=[[SUM(#partsupp.ps_supplycost *
CAST(#partsupp.ps_availqty AS Float64))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty, ps_supplycost]
diff --git a/datafusion/core/tests/sql/timestamp.rs
b/datafusion/core/tests/sql/timestamp.rs
index 123342c42..847d63e81 100644
--- a/datafusion/core/tests/sql/timestamp.rs
+++ b/datafusion/core/tests/sql/timestamp.rs
@@ -1176,11 +1176,11 @@ async fn to_timestamp_i32() -> Result<()> {
let results = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+--------------------------------------+",
- "| totimestamp(CAST(Int64(1) AS Int32)) |",
- "+--------------------------------------+",
- "| 1970-01-01 00:00:00.000000001 |",
- "+--------------------------------------+",
+ "+-------------------------------+",
+ "| totimestamp(Int64(1)) |",
+ "+-------------------------------+",
+ "| 1970-01-01 00:00:00.000000001 |",
+ "+-------------------------------+",
];
assert_batches_eq!(expected, &results);
@@ -1196,11 +1196,11 @@ async fn to_timestamp_micros_i32() -> Result<()> {
let results = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+--------------------------------------------+",
- "| totimestampmicros(CAST(Int64(1) AS Int32)) |",
- "+--------------------------------------------+",
- "| 1970-01-01 00:00:00.000001 |",
- "+--------------------------------------------+",
+ "+-----------------------------+",
+ "| totimestampmicros(Int64(1)) |",
+ "+-----------------------------+",
+ "| 1970-01-01 00:00:00.000001 |",
+ "+-----------------------------+",
];
assert_batches_eq!(expected, &results);
@@ -1216,11 +1216,11 @@ async fn to_timestamp_millis_i32() -> Result<()> {
let results = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+--------------------------------------------+",
- "| totimestampmillis(CAST(Int64(1) AS Int32)) |",
- "+--------------------------------------------+",
- "| 1970-01-01 00:00:00.001 |",
- "+--------------------------------------------+",
+ "+-----------------------------+",
+ "| totimestampmillis(Int64(1)) |",
+ "+-----------------------------+",
+ "| 1970-01-01 00:00:00.001 |",
+ "+-----------------------------+",
];
assert_batches_eq!(expected, &results);
@@ -1236,11 +1236,11 @@ async fn to_timestamp_seconds_i32() -> Result<()> {
let results = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+---------------------------------------------+",
- "| totimestampseconds(CAST(Int64(1) AS Int32)) |",
- "+---------------------------------------------+",
- "| 1970-01-01 00:00:01 |",
- "+---------------------------------------------+",
+ "+------------------------------+",
+ "| totimestampseconds(Int64(1)) |",
+ "+------------------------------+",
+ "| 1970-01-01 00:00:01 |",
+ "+------------------------------+",
];
assert_batches_eq!(expected, &results);
@@ -1512,11 +1512,11 @@ async fn cast_timestamp_before_1970() -> Result<()> {
let sql = "select cast('1969-01-01T00:00:00Z' as timestamp);";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
-
"+-------------------------------------------------------------------+",
- "| CAST(Utf8(\"1969-01-01T00:00:00Z\") AS Timestamp(Nanosecond, None))
|",
-
"+-------------------------------------------------------------------+",
- "| 1969-01-01 00:00:00
|",
-
"+-------------------------------------------------------------------+",
+ "+------------------------------+",
+ "| Utf8(\"1969-01-01T00:00:00Z\") |",
+ "+------------------------------+",
+ "| 1969-01-01 00:00:00 |",
+ "+------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -1524,11 +1524,11 @@ async fn cast_timestamp_before_1970() -> Result<()> {
let sql = "select cast('1969-01-01T00:00:00.1Z' as timestamp);";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
-
"+---------------------------------------------------------------------+",
- "| CAST(Utf8(\"1969-01-01T00:00:00.1Z\") AS Timestamp(Nanosecond,
None)) |",
-
"+---------------------------------------------------------------------+",
- "| 1969-01-01 00:00:00.100
|",
-
"+---------------------------------------------------------------------+",
+ "+--------------------------------+",
+ "| Utf8(\"1969-01-01T00:00:00.1Z\") |",
+ "+--------------------------------+",
+ "| 1969-01-01 00:00:00.100 |",
+ "+--------------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/window.rs
b/datafusion/core/tests/sql/window.rs
index 1c909fa71..6a1f39a03 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -63,15 +63,15 @@ async fn csv_query_window_with_partition_by() -> Result<()>
{
limit 5";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
-
"+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
- "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) |
AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS
Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) |
MIN(CAST(aggregate_test_100.c4 AS Int32)) |",
-
"+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
- "| 28774375 | -16110 | -16110
| 1 |
-16110 | -16110
|",
- "| 63044568 | 3917 | 3917
| 1 |
3917 | 3917
|",
- "| 141047417 | -38455 | -19227.5
| 2 |
-16974 | -21481
|",
- "| 141680161 | -1114 | -1114
| 1 |
-1114 | -1114
|",
- "| 145294611 | 15673 | 15673
| 1 |
15673 | 15673
|",
-
"+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+",
+
"+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
+ "| c9 | SUM(aggregate_test_100.c4) | AVG(aggregate_test_100.c4)
| COUNT(aggregate_test_100.c4) | MAX(aggregate_test_100.c4) |
MIN(aggregate_test_100.c4) |",
+
"+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
+ "| 28774375 | -16110 | -16110
| 1 | -16110 | -16110
|",
+ "| 63044568 | 3917 | 3917
| 1 | 3917 | 3917
|",
+ "| 141047417 | -38455 | -19227.5
| 2 | -16974 | -21481
|",
+ "| 141680161 | -1114 | -1114
| 1 | -1114 | -1114
|",
+ "| 145294611 | 15673 | 15673
| 1 | 15673 | 15673
|",
+
"+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 8c6f26887..6226887e8 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -923,13 +923,13 @@ fn create_name(e: &Expr) -> Result<String> {
name += "END";
Ok(name)
}
- Expr::Cast { expr, data_type } => {
- let expr = create_name(expr)?;
- Ok(format!("CAST({} AS {:?})", expr, data_type))
+ Expr::Cast { expr, .. } => {
+ // CAST does not change the expression name
+ create_name(expr)
}
- Expr::TryCast { expr, data_type } => {
- let expr = create_name(expr)?;
- Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+ Expr::TryCast { expr, .. } => {
+ // TRY_CAST does not change the expression name
+ create_name(expr)
}
Expr::Not(expr) => {
let expr = create_name(expr)?;
@@ -1086,7 +1086,8 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
#[cfg(test)]
mod test {
use crate::expr_fn::col;
- use crate::{case, lit};
+ use crate::{case, lit, Expr};
+ use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
#[test]
@@ -1101,6 +1102,20 @@ mod test {
Ok(())
}
+ #[test]
+ fn format_cast() -> Result<()> {
+ let expr = Expr::Cast {
+ expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))),
+ data_type: DataType::Utf8,
+ };
+ assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{}", expr));
+ assert_eq!("CAST(Float32(1.23) AS Utf8)", format!("{:?}", expr));
+ // note that CAST intentionally has a name that is different from its
`Display`
+ // representation. CAST does not change the name of expressions.
+ assert_eq!("Float32(1.23)", expr.name()?);
+ Ok(())
+ }
+
#[test]
fn test_not() {
assert_eq!(lit(1).not(), !lit(1));
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 8d6da350a..171381659 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -31,6 +31,7 @@ pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
+pub mod type_coercion;
pub mod utils;
pub mod pre_cast_lit_in_comparison;
diff --git a/datafusion/optimizer/src/simplify_expressions.rs
b/datafusion/optimizer/src/simplify_expressions.rs
index 334ec6182..1c826d7c3 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -1902,7 +1902,7 @@ mod tests {
.build()
.unwrap();
- let expected = "Projection: Int32(0) AS CAST(Utf8(\"0\") AS Int32)\
+ let expected = "Projection: Int32(0) AS Utf8(\"0\")\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
@@ -1949,7 +1949,7 @@ mod tests {
time.timestamp_nanos()
);
- assert_eq!(actual, expected);
+ assert_eq!(expected, actual);
}
#[test]
@@ -1971,7 +1971,7 @@ mod tests {
"Projection: NOT #test.a AS Boolean(true) OR Boolean(false) !=
test.a\
\n TableScan: test";
- assert_eq!(actual, expected);
+ assert_eq!(expected, actual);
}
#[test]
@@ -1993,7 +1993,7 @@ mod tests {
// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
- let expected = "Filter: Boolean(true) AS CAST(now() AS Int64) <
CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
+ let expected = "Filter: Boolean(true) AS now() <
totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) + Int32(50000)\
\n TableScan: test";
let actual = get_optimized_plan_formatted(&plan, &time);
@@ -2025,11 +2025,11 @@ mod tests {
// Note that constant folder runs and folds the entire
// expression down to a single constant (true)
- let expected = r#"Projection: Date32("18636") AS
CAST(totimestamp(Utf8("2020-09-08T12:05:00+00:00")) AS Date32) +
IntervalDayTime("528280977408")
+ let expected = r#"Projection: Date32("18636") AS
totimestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408")
TableScan: test"#;
let actual = get_optimized_plan_formatted(&plan, &time);
- assert_eq!(actual, expected);
+ assert_eq!(expected, actual);
}
#[test]
diff --git a/datafusion/optimizer/src/type_coercion.rs
b/datafusion/optimizer/src/type_coercion.rs
new file mode 100644
index 000000000..d9f161599
--- /dev/null
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule for type validation and coercion
+
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchema, DFSchemaRef, Result};
+use datafusion_expr::binary_rule::coerce_types;
+use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter,
RewriteRecursion};
+use datafusion_expr::logical_plan::builder::build_join_schema;
+use datafusion_expr::logical_plan::JoinType;
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::ExprSchemable;
+use datafusion_expr::{Expr, LogicalPlan};
+
+#[derive(Default)]
+pub struct TypeCoercion {}
+
+impl TypeCoercion {
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl OptimizerRule for TypeCoercion {
+ fn name(&self) -> &str {
+ "TypeCoercion"
+ }
+
+ fn optimize(
+ &self,
+ plan: &LogicalPlan,
+ optimizer_config: &mut OptimizerConfig,
+ ) -> Result<LogicalPlan> {
+ // optimize child plans first
+ let new_inputs = plan
+ .inputs()
+ .iter()
+ .map(|p| self.optimize(p, optimizer_config))
+ .collect::<Result<Vec<_>>>()?;
+
+ let schema = match new_inputs.len() {
+ 1 => new_inputs[0].schema().clone(),
+ 2 => DFSchemaRef::new(build_join_schema(
+ new_inputs[0].schema(),
+ new_inputs[1].schema(),
+ &JoinType::Inner,
+ )?),
+ _ => DFSchemaRef::new(DFSchema::empty()),
+ };
+
+ let mut expr_rewrite = TypeCoercionRewriter { schema };
+
+ let new_expr = plan
+ .expressions()
+ .into_iter()
+ .map(|expr| expr.rewrite(&mut expr_rewrite))
+ .collect::<Result<Vec<_>>>()?;
+
+ from_plan(plan, &new_expr, &new_inputs)
+ }
+}
+
+struct TypeCoercionRewriter {
+ schema: DFSchemaRef,
+}
+
+impl ExprRewriter for TypeCoercionRewriter {
+ fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
+ Ok(RewriteRecursion::Continue)
+ }
+
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ match &expr {
+ Expr::BinaryExpr { left, op, right } => {
+ let left_type = left.get_type(&self.schema)?;
+ let right_type = right.get_type(&self.schema)?;
+ match right_type {
+ DataType::Interval(_) => {
+ // we don't want to cast intervals because that breaks
+ // the logic in the physical planner
+ Ok(expr)
+ }
+ _ => {
+ let coerced_type = coerce_types(&left_type, op,
&right_type)?;
+ let left = left.clone().cast_to(&coerced_type,
&self.schema)?;
+ let right = right.clone().cast_to(&coerced_type,
&self.schema)?;
+ match (&left, &right) {
+ (Expr::Cast { .. }, _) | (_, Expr::Cast { .. }) =>
{
+ Ok(Expr::BinaryExpr {
+ left: Box::new(left),
+ op: *op,
+ right: Box::new(right),
+ })
+ }
+ _ => {
+ // no cast was added so we return the original
expression
+ Ok(expr)
+ }
+ }
+ }
+ }
+ }
+ _ => Ok(expr),
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::type_coercion::TypeCoercion;
+ use crate::{OptimizerConfig, OptimizerRule};
+ use datafusion_common::{DFSchema, Result};
+ use datafusion_expr::logical_plan::{EmptyRelation, Projection};
+ use datafusion_expr::{lit, LogicalPlan};
+ use std::sync::Arc;
+
+ #[test]
+ fn simple_case() -> Result<()> {
+ let expr = lit(1.2_f64).lt(lit(2_u32));
+ let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: Arc::new(DFSchema::empty()),
+ }));
+ let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty, None)?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config)?;
+ assert_eq!(
+ "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n
EmptyRelation",
+ &format!("{:?}", plan)
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn nested_case() -> Result<()> {
+ let expr = lit(1.2_f64).lt(lit(2_u32));
+ let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: Arc::new(DFSchema::empty()),
+ }));
+ let plan = LogicalPlan::Projection(Projection::try_new(
+ vec![expr.clone().or(expr)],
+ empty,
+ None,
+ )?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config)?;
+ assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR
Float64(1.2) < CAST(UInt32(2) AS Float64)\
+ \n EmptyRelation", &format!("{:?}", plan));
+ Ok(())
+ }
+}
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 4226364c9..c344982b3 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -47,7 +47,13 @@ pub fn create_physical_expr(
input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
- assert_eq!(input_schema.fields.len(), input_dfschema.fields().len());
+ if input_schema.fields.len() != input_dfschema.fields().len() {
+ return Err(DataFusionError::Internal(
+ "create_physical_expr passed Arrow schema and DataFusion \
+ schema with different number of fields"
+ .to_string(),
+ ));
+ }
match e {
Expr::Alias(expr, ..) => Ok(create_physical_expr(
expr,