This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 787d00047b refactor: move type_coercion to analyzer (#5831)
787d00047b is described below
commit 787d00047be6e4085418a839301e714c78249915
Author: jakevin <[email protected]>
AuthorDate: Wed Apr 5 19:40:52 2023 +0800
refactor: move type_coercion to analyzer (#5831)
---
benchmarks/expected-plans/q11.txt | 8 +-
benchmarks/expected-plans/q17.txt | 98 +++++------
datafusion/core/src/physical_plan/planner.rs | 4 +-
datafusion/core/tests/fifo.rs | 2 +-
datafusion/core/tests/sql/timestamp.rs | 27 +--
datafusion/core/tests/sql/window.rs | 10 +-
.../core/tests/sqllogictests/test_files/dates.slt | 2 +-
datafusion/expr/src/expr_schema.rs | 57 +++++-
datafusion/optimizer/src/analyzer/mod.rs | 5 +-
.../optimizer/src/{ => analyzer}/type_coercion.rs | 192 +++++++++------------
datafusion/optimizer/src/lib.rs | 1 -
datafusion/optimizer/src/optimizer.rs | 2 -
.../src/simplify_expressions/expr_simplifier.rs | 5 +-
13 files changed, 217 insertions(+), 196 deletions(-)
diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt
index a403a31e8b..2f87f7c98e 100644
--- a/benchmarks/expected-plans/q11.txt
+++ b/benchmarks/expected-plans/q11.txt
@@ -3,7 +3,7 @@
+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| logical_plan | Sort: value DESC NULLS FIRST
|
| | Projection: partsupp.ps_partkey,
SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
|
-| | Filter: CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__scalar_sq_1.__value AS
Decimal128(38, 15)) |
+| | Filter: CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.__value
|
| | CrossJoin:
|
| | Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) *
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] |
| | Projection: partsupp.ps_partkey,
partsupp.ps_availqty, partsupp.ps_supplycost
|
@@ -16,7 +16,7 @@
| | Filter: nation.n_name = Utf8("GERMANY")
|
| | TableScan: nation projection=[n_nationkey,
n_name]
|
| | SubqueryAlias: __scalar_sq_1
|
-| | Projection: CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value
|
+| | Projection: CAST(CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) AS
__value |
| | Aggregate: groupBy=[[]],
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) *
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] |
| | Projection: partsupp.ps_availqty,
partsupp.ps_supplycost
|
| | Inner Join: supplier.s_nationkey =
nation.n_nationkey
|
@@ -30,7 +30,7 @@
| physical_plan | SortExec: expr=[value@1 DESC]
|
| | ProjectionExec: expr=[ps_partkey@0 as ps_partkey,
SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value]
|
| | CoalesceBatchesExec: target_batch_size=8192
|
-| | FilterExec: CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > CAST(__value@2 AS
Decimal128(38, 15)) |
+| | FilterExec: CAST(SUM(partsupp.ps_supplycost *
partsupp.ps_availqty)@1 AS Decimal128(38, 15)) > __value@2
|
| | CrossJoinExec
|
| | CoalescePartitionsExec
|
| | AggregateExec: mode=FinalPartitioned,
gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost *
partsupp.ps_availqty)] |
@@ -59,7 +59,7 @@
| | CoalesceBatchesExec:
target_batch_size=8192
|
| | FilterExec: n_name@1 =
GERMANY
|
| | MemoryExec:
partitions=0, partition_sizes=[]
|
-| | ProjectionExec:
expr=[CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) *
0.0001 as __value] |
+| | ProjectionExec:
expr=[CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS
Float64) * 0.0001 AS Decimal128(38, 15)) as __value] |
| | AggregateExec: mode=Final, gby=[],
aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]
|
| | CoalescePartitionsExec
|
| | AggregateExec: mode=Partial, gby=[],
aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)]
|
diff --git a/benchmarks/expected-plans/q17.txt b/benchmarks/expected-plans/q17.txt
index be3e81084a..9924555f6d 100644
--- a/benchmarks/expected-plans/q17.txt
+++ b/benchmarks/expected-plans/q17.txt
@@ -1,49 +1,49 @@
-+---------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| plan_type | plan
|
-+---------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
-| logical_plan | Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) /
Float64(7) AS avg_yearly
|
-| | Aggregate: groupBy=[[]],
aggr=[[SUM(lineitem.l_extendedprice)]]
|
-| | Projection: lineitem.l_extendedprice
|
-| | Inner Join: part.p_partkey = __scalar_sq_1.l_partkey
Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) <
CAST(__scalar_sq_1.__value AS Decimal128(30, 15))
|
-| | Projection: lineitem.l_quantity,
lineitem.l_extendedprice, part.p_partkey
|
-| | Inner Join: lineitem.l_partkey = part.p_partkey
|
-| | TableScan: lineitem projection=[l_partkey,
l_quantity, l_extendedprice]
|
-| | Projection: part.p_partkey
|
-| | Filter: part.p_brand = Utf8("Brand#23") AND
part.p_container = Utf8("MED BOX")
|
-| | TableScan: part projection=[p_partkey,
p_brand, p_container]
|
-| | SubqueryAlias: __scalar_sq_1
|
-| | Projection: lineitem.l_partkey, Float64(0.2) *
CAST(AVG(lineitem.l_quantity) AS Float64) AS __value
|
-| | Aggregate: groupBy=[[lineitem.l_partkey]],
aggr=[[AVG(lineitem.l_quantity)]]
|
-| | TableScan: lineitem projection=[l_partkey,
l_quantity]
|
-| physical_plan | ProjectionExec: expr=[CAST(SUM(lineitem.l_extendedprice)@0
AS Float64) / 7 as avg_yearly]
|
-| | AggregateExec: mode=Final, gby=[],
aggr=[SUM(lineitem.l_extendedprice)]
|
-| | CoalescePartitionsExec
|
-| | AggregateExec: mode=Partial, gby=[],
aggr=[SUM(lineitem.l_extendedprice)]
|
-| | ProjectionExec: expr=[l_extendedprice@1 as
l_extendedprice]
|
-| | CoalesceBatchesExec: target_batch_size=8192
|
-| | HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: "p_partkey", index: 2 }, Column { name: "l_partkey", index:
0 })], filter=BinaryExpr { left: CastExpr { expr: Column { name: "l_quantity",
index: 0 }, cast_type: Decimal128(30, 15), cast_options: CastOptions { safe:
false } }, op: Lt, right: CastExpr { expr: Column { name: "__value", index: 1
}, cast_type: Decimal128(30, 15), cast_options: CastOptions { safe: false } } }
|
-| | CoalesceBatchesExec: target_batch_size=8192
|
-| | RepartitionExec: partitioning=Hash([Column {
name: "p_partkey", index: 2 }], 2), input_partitions=2
|
-| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=2
|
-| | ProjectionExec: expr=[l_quantity@1 as
l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey]
|
-| | CoalesceBatchesExec:
target_batch_size=8192
|
-| | HashJoinExec: mode=Partitioned,
join_type=Inner, on=[(Column { name: "l_partkey", index: 0 }, Column { name:
"p_partkey", index: 0 })]
|
-| | CoalesceBatchesExec:
target_batch_size=8192
|
-| | RepartitionExec:
partitioning=Hash([Column { name: "l_partkey", index: 0 }], 2),
input_partitions=0
|
-| | MemoryExec: partitions=0,
partition_sizes=[]
|
-| | CoalesceBatchesExec:
target_batch_size=8192
|
-| | RepartitionExec:
partitioning=Hash([Column { name: "p_partkey", index: 0 }], 2),
input_partitions=2
|
-| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=0
|
-| | ProjectionExec:
expr=[p_partkey@0 as p_partkey]
|
-| | CoalesceBatchesExec:
target_batch_size=8192
|
-| | FilterExec: p_brand@1 =
Brand#23 AND p_container@2 = MED BOX
|
-| | MemoryExec:
partitions=0, partition_sizes=[]
|
-| | ProjectionExec: expr=[l_partkey@0 as
l_partkey, 0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) as __value]
|
-| | AggregateExec: mode=FinalPartitioned,
gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)]
|
-| | CoalesceBatchesExec:
target_batch_size=8192
|
-| | RepartitionExec:
partitioning=Hash([Column { name: "l_partkey", index: 0 }], 2),
input_partitions=2
|
-| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=0
|
-| | AggregateExec: mode=Partial,
gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)]
|
-| | MemoryExec: partitions=0,
partition_sizes=[]
|
-| |
|
-+---------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
\ No newline at end of file
++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+| plan_type | plan
|
++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
+| logical_plan | Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) /
Float64(7) AS avg_yearly
|
+| | Aggregate: groupBy=[[]],
aggr=[[SUM(lineitem.l_extendedprice)]]
|
+| | Projection: lineitem.l_extendedprice
|
+| | Inner Join: part.p_partkey = __scalar_sq_1.l_partkey
Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.__value
|
+| | Projection: lineitem.l_quantity,
lineitem.l_extendedprice, part.p_partkey
|
+| | Inner Join: lineitem.l_partkey = part.p_partkey
|
+| | TableScan: lineitem projection=[l_partkey,
l_quantity, l_extendedprice]
|
+| | Projection: part.p_partkey
|
+| | Filter: part.p_brand = Utf8("Brand#23") AND
part.p_container = Utf8("MED BOX")
|
+| | TableScan: part projection=[p_partkey,
p_brand, p_container]
|
+| | SubqueryAlias: __scalar_sq_1
|
+| | Projection: lineitem.l_partkey, CAST(Float64(0.2)
* CAST(AVG(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)) AS __value
|
+| | Aggregate: groupBy=[[lineitem.l_partkey]],
aggr=[[AVG(lineitem.l_quantity)]]
|
+| | TableScan: lineitem projection=[l_partkey,
l_quantity]
|
+| physical_plan | ProjectionExec: expr=[CAST(SUM(lineitem.l_extendedprice)@0
AS Float64) / 7 as avg_yearly]
|
+| | AggregateExec: mode=Final, gby=[],
aggr=[SUM(lineitem.l_extendedprice)]
|
+| | CoalescePartitionsExec
|
+| | AggregateExec: mode=Partial, gby=[],
aggr=[SUM(lineitem.l_extendedprice)]
|
+| | ProjectionExec: expr=[l_extendedprice@1 as
l_extendedprice]
|
+| | CoalesceBatchesExec: target_batch_size=8192
|
+| | HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: "p_partkey", index: 2 }, Column { name: "l_partkey", index:
0 })], filter=BinaryExpr { left: CastExpr { expr: Column { name: "l_quantity",
index: 0 }, cast_type: Decimal128(30, 15), cast_options: CastOptions { safe:
false } }, op: Lt, right: Column { name: "__value", index: 1 } } |
+| | CoalesceBatchesExec: target_batch_size=8192
|
+| | RepartitionExec: partitioning=Hash([Column {
name: "p_partkey", index: 2 }], 2), input_partitions=2
|
+| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=2
|
+| | ProjectionExec: expr=[l_quantity@1 as
l_quantity, l_extendedprice@2 as l_extendedprice, p_partkey@3 as p_partkey]
|
+| | CoalesceBatchesExec:
target_batch_size=8192
|
+| | HashJoinExec: mode=Partitioned,
join_type=Inner, on=[(Column { name: "l_partkey", index: 0 }, Column { name:
"p_partkey", index: 0 })]
|
+| | CoalesceBatchesExec:
target_batch_size=8192
|
+| | RepartitionExec:
partitioning=Hash([Column { name: "l_partkey", index: 0 }], 2),
input_partitions=0
|
+| | MemoryExec: partitions=0,
partition_sizes=[]
|
+| | CoalesceBatchesExec:
target_batch_size=8192
|
+| | RepartitionExec:
partitioning=Hash([Column { name: "p_partkey", index: 0 }], 2),
input_partitions=2
|
+| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=0
|
+| | ProjectionExec:
expr=[p_partkey@0 as p_partkey]
|
+| | CoalesceBatchesExec:
target_batch_size=8192
|
+| | FilterExec: p_brand@1 =
Brand#23 AND p_container@2 = MED BOX
|
+| | MemoryExec:
partitions=0, partition_sizes=[]
|
+| | ProjectionExec: expr=[l_partkey@0 as
l_partkey, CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS
Decimal128(30, 15)) as __value]
|
+| | AggregateExec: mode=FinalPartitioned,
gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)]
|
+| | CoalesceBatchesExec:
target_batch_size=8192
|
+| | RepartitionExec:
partitioning=Hash([Column { name: "l_partkey", index: 0 }], 2),
input_partitions=2
|
+| | RepartitionExec:
partitioning=RoundRobinBatch(2), input_partitions=0
|
+| | AggregateExec: mode=Partial,
gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)]
|
+| | MemoryExec: partitions=0,
partition_sizes=[]
|
+| |
|
++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
\ No newline at end of file
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 25afdb4d1c..8c3f61b01e 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -2164,9 +2164,7 @@ mod tests {
assert_contains!(
&e,
- r#"type_coercion
-caused by
-Internal error: Optimizer rule 'type_coercion' failed due to unexpected error:
Error during planning: Can not find compatible types to compare Boolean with
[Struct([Field { name: "foo", data_type: Boolean, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }]), Utf8]. This was likely caused by a
bug in DataFusion's code and we would welcome that you file an bug report in
our issue tracker"#
+ r#"Error during planning: Can not find compatible types to compare
Boolean with [Struct([Field { name: "foo", data_type: Boolean, nullable: false,
dict_id: 0, dict_is_ordered: false, metadata: {} }]), Utf8]"#
);
Ok(())
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 5c12045c43..f5a62ddee4 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -104,7 +104,7 @@ mod unix_test {
// Create a new temporary FIFO file
let tmp_dir = TempDir::new()?;
let fifo_path =
- create_fifo_file(&tmp_dir, &format!("fifo_{:?}.csv",
unbounded_file))?;
+ create_fifo_file(&tmp_dir,
&format!("fifo_{unbounded_file:?}.csv"))?;
// Execution can calculated at least one RecordBatch after the number
of
// "joinable_lines_length" lines are read.
let joinable_lines_length =
diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs
index 6c6b19b38d..afc1932b74 100644
--- a/datafusion/core/tests/sql/timestamp.rs
+++ b/datafusion/core/tests/sql/timestamp.rs
@@ -925,14 +925,12 @@ async fn test_ts_dt_binary_ops() -> Result<()> {
let batch = &plan[0];
let mut res: Option<String> = None;
for row in 0..batch.num_rows() {
- if &array_value_to_string(batch.column(0), row)?
- == "logical_plan after type_coercion"
- {
+ if &array_value_to_string(batch.column(0), row)? ==
"initial_logical_plan" {
res = Some(array_value_to_string(batch.column(1), row)?);
break;
}
}
- assert_eq!(res, Some("Projection: CAST(Utf8(\"2000-01-01\") AS
Timestamp(Nanosecond, None)) >= CAST(CAST(Utf8(\"2000-01-01\") AS Date32) AS
Timestamp(Nanosecond, None))\n EmptyRelation".to_string()));
+ assert_eq!(res, Some("Projection: CAST(Utf8(\"2000-01-01\") AS
Timestamp(Nanosecond, None)) >= CAST(Utf8(\"2000-01-01\") AS Date32)\n
EmptyRelation".to_string()));
//test cast path timestamp date using function
let sql = "select now() >= '2000-01-01'::date";
@@ -942,14 +940,18 @@ async fn test_ts_dt_binary_ops() -> Result<()> {
let batch = &plan[0];
let mut res: Option<String> = None;
for row in 0..batch.num_rows() {
- if &array_value_to_string(batch.column(0), row)?
- == "logical_plan after type_coercion"
- {
+ if &array_value_to_string(batch.column(0), row)? ==
"initial_logical_plan" {
res = Some(array_value_to_string(batch.column(1), row)?);
break;
}
}
- assert_eq!(res, Some("Projection: CAST(now() AS Timestamp(Nanosecond,
None)) >= CAST(CAST(Utf8(\"2000-01-01\") AS Date32) AS Timestamp(Nanosecond,
None))\n EmptyRelation".to_string()));
+ assert_eq!(
+ res,
+ Some(
+ "Projection: now() >= CAST(Utf8(\"2000-01-01\") AS Date32)\n
EmptyRelation"
+ .to_string()
+ )
+ );
let sql = "select now() = current_date()";
let df = ctx.sql(sql).await.unwrap();
@@ -958,14 +960,15 @@ async fn test_ts_dt_binary_ops() -> Result<()> {
let batch = &plan[0];
let mut res: Option<String> = None;
for row in 0..batch.num_rows() {
- if &array_value_to_string(batch.column(0), row)?
- == "logical_plan after type_coercion"
- {
+ if &array_value_to_string(batch.column(0), row)? ==
"initial_logical_plan" {
res = Some(array_value_to_string(batch.column(1), row)?);
break;
}
}
- assert_eq!(res, Some("Projection: CAST(now() AS Timestamp(Nanosecond,
None)) = CAST(currentdate() AS Timestamp(Nanosecond, None))\n
EmptyRelation".to_string()));
+ assert_eq!(
+ res,
+ Some("Projection: now() = currentdate()\n EmptyRelation".to_string())
+ );
Ok(())
}
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index 535e1d8917..9b75edf88d 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -51,14 +51,16 @@ async fn window_frame_creation_type_checking() ->
Result<()> {
// Error is returned from the physical plan.
check_query(
true,
- "Internal error: Operator - is not implemented for types UInt32(1) and
Utf8(\"1 DAY\")."
- ).await?;
+ r#"Execution error: Cannot cast Utf8("1 DAY") to UInt32"#,
+ )
+ .await?;
// Error is returned from the logical plan.
check_query(
false,
- "Internal error: Optimizer rule 'type_coercion' failed due to
unexpected error: Execution error: Cannot cast Utf8(\"1 DAY\") to UInt32."
- ).await
+ r#"Execution error: Cannot cast Utf8("1 DAY") to UInt32"#,
+ )
+ .await
}
fn split_record_batch(batch: RecordBatch, n_split: usize) -> Vec<RecordBatch> {
diff --git a/datafusion/core/tests/sqllogictests/test_files/dates.slt b/datafusion/core/tests/sqllogictests/test_files/dates.slt
index ccd07e14b6..d75fba6f51 100644
--- a/datafusion/core/tests/sqllogictests/test_files/dates.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/dates.slt
@@ -85,6 +85,6 @@ g
h
## Plan error when compare Utf8 and timestamp in where clause
-statement error DataFusion error: Error during planning: The type of
Timestamp\(Nanosecond, Some\("\+00:00"\)\) Plus Utf8 of binary physical should
be same
+statement error DataFusion error: Error during planning:
Timestamp\(Nanosecond, Some\("\+00:00"\)\) \+ Utf8 can't be evaluated because
there isn't a common type to coerce the types to
select i_item_desc from test
where d3_date > now() + '5 days';
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index c7e8123c1b..ff6d9f5057 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -22,10 +22,13 @@ use crate::expr::{
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::binary_operator_data_type;
use crate::type_coercion::other::get_coerce_type_for_case_expression;
-use crate::{aggregate_function, function, window_function};
+use crate::{
+ aggregate_function, function, window_function, LogicalPlan, Projection,
Subquery,
+};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType;
use datafusion_common::{Column, DFField, DFSchema, DataFusionError,
ExprSchema, Result};
+use std::sync::Arc;
/// trait to allow expr to typable with respect to a schema
pub trait ExprSchemable {
@@ -290,13 +293,30 @@ impl ExprSchemable for Expr {
/// This function errors when it is impossible to cast the
/// expression to the target [arrow::datatypes::DataType].
fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) ->
Result<Expr> {
+ let this_type = self.get_type(schema)?;
+ if this_type == *cast_to_type {
+ return Ok(self);
+ }
+
// TODO(kszucs): most of the operations do not validate the type
correctness
// like all of the binary expressions below. Perhaps Expr should track
the
// type of the expression?
- let this_type = self.get_type(schema)?;
- if this_type == *cast_to_type {
- Ok(self)
- } else if can_cast_types(&this_type, cast_to_type) {
+
+ // TODO(jackwener): Handle subqueries separately, need to refactor it.
+ match self {
+ Expr::ScalarSubquery(subquery) => {
+ return Ok(Expr::ScalarSubquery(cast_subquery(subquery,
cast_to_type)?));
+ }
+ Expr::Exists { subquery, negated } => {
+ return Ok(Expr::Exists {
+ subquery: cast_subquery(subquery, cast_to_type)?,
+ negated,
+ });
+ }
+ _ => {}
+ }
+
+ if can_cast_types(&this_type, cast_to_type) {
Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone())))
} else {
Err(DataFusionError::Plan(format!(
@@ -306,6 +326,33 @@ impl ExprSchemable for Expr {
}
}
+fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) ->
Result<Subquery> {
+ let plan = subquery.subquery.as_ref();
+ let new_plan = match plan {
+ LogicalPlan::Projection(projection) => {
+ let cast_expr = projection.expr[0]
+ .clone()
+ .cast_to(cast_to_type, projection.input.schema())?;
+ LogicalPlan::Projection(Projection::try_new(
+ vec![cast_expr],
+ projection.input.clone(),
+ )?)
+ }
+ _ => {
+ let cast_expr =
Expr::Column(plan.schema().field(0).qualified_column())
+ .cast_to(cast_to_type, subquery.subquery.schema())?;
+ LogicalPlan::Projection(Projection::try_new(
+ vec![cast_expr],
+ subquery.subquery.clone(),
+ )?)
+ }
+ };
+ Ok(Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
+ })
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs
index aef46926f5..bb9b01c859 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -17,10 +17,12 @@
mod count_wildcard_rule;
mod inline_table_scan;
+pub(crate) mod type_coercion;
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::analyzer::inline_table_scan::InlineTableScan;
+use crate::analyzer::type_coercion::TypeCoercion;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
use datafusion_common::{DataFusionError, Result};
@@ -64,8 +66,9 @@ impl Analyzer {
/// Create a new analyzer using the recommended list of rules
pub fn new() -> Self {
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> = vec![
- Arc::new(CountWildcardRule::new()),
Arc::new(InlineTableScan::new()),
+ Arc::new(TypeCoercion::new()),
+ Arc::new(CountWildcardRule::new()),
];
Self::with_rules(rules)
}
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
similarity index 91%
rename from datafusion/optimizer/src/type_coercion.rs
rename to datafusion/optimizer/src/analyzer/type_coercion.rs
index a931da4b3e..0038aec933 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
use arrow::datatypes::{DataType, IntervalUnit};
+use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue};
use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like,
WindowFunction};
@@ -43,8 +44,8 @@ use datafusion_expr::{
};
use datafusion_expr::{ExprSchemable, Signature};
+use crate::analyzer::AnalyzerRule;
use crate::utils::{merge_schema, rewrite_preserving_name};
-use crate::{OptimizerConfig, OptimizerRule};
#[derive(Default)]
pub struct TypeCoercion {}
@@ -55,21 +56,17 @@ impl TypeCoercion {
}
}
-impl OptimizerRule for TypeCoercion {
+impl AnalyzerRule for TypeCoercion {
fn name(&self) -> &str {
"type_coercion"
}
- fn try_optimize(
- &self,
- plan: &LogicalPlan,
- _: &dyn OptimizerConfig,
- ) -> Result<Option<LogicalPlan>> {
- Ok(Some(optimize_internal(&DFSchema::empty(), plan)?))
+ fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) ->
Result<LogicalPlan> {
+ analyze_internal(&DFSchema::empty(), &plan)
}
}
-fn optimize_internal(
+fn analyze_internal(
// use the external schema to handle the correlated subqueries case
external_schema: &DFSchema,
plan: &LogicalPlan,
@@ -78,7 +75,7 @@ fn optimize_internal(
let new_inputs = plan
.inputs()
.iter()
- .map(|p| optimize_internal(external_schema, p))
+ .map(|p| analyze_internal(external_schema, p))
.collect::<Result<Vec<_>>>()?;
// get schema representing all available input fields. This is used for
data type
// resolution only, so order does not matter here
@@ -129,14 +126,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
- let new_plan = optimize_internal(&self.schema, &subquery)?;
+ let new_plan = analyze_internal(&self.schema, &subquery)?;
Ok(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
}))
}
Expr::Exists { subquery, negated } => {
- let new_plan = optimize_internal(&self.schema,
&subquery.subquery)?;
+ let new_plan = analyze_internal(&self.schema,
&subquery.subquery)?;
Ok(Expr::Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
@@ -151,7 +148,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
negated,
} => {
let expr_type = expr.get_type(&self.schema)?;
- let new_plan = optimize_internal(&self.schema,
&subquery.subquery)?;
+ let new_plan = analyze_internal(&self.schema,
&subquery.subquery)?;
let subquery_type = new_plan.schema().field(0).data_type();
let expr = if &expr_type == subquery_type {
expr
@@ -747,35 +744,24 @@ mod test {
};
use datafusion_physical_expr::expressions::AvgAccumulator;
- use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
- use crate::{OptimizerContext, OptimizerRule};
-
- use super::coerce_case_expression;
-
- fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) ->
Result<()> {
- let rule = TypeCoercion::new();
- let config = OptimizerContext::default();
- let plan = rule.try_optimize(plan, &config)?.unwrap();
- assert_eq!(expected, &format!("{plan:?}"));
- Ok(())
- }
+ use crate::analyzer::type_coercion::{
+ coerce_case_expression, TypeCoercion, TypeCoercionRewriter,
+ };
+ use crate::test::assert_analyzed_plan_eq;
#[test]
fn simple_case() -> Result<()> {
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Float64,
true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- ),
+ schema: Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Float64, true)],
+ std::collections::HashMap::new(),
+ )?),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
#[test]
@@ -783,13 +769,10 @@ mod test {
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Float64,
true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- ),
+ schema: Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Float64, true)],
+ std::collections::HashMap::new(),
+ )?),
}));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
@@ -797,7 +780,7 @@ mod test {
)?);
let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a <
CAST(UInt32(2) AS Float64)\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
#[test]
@@ -819,7 +802,7 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf],
empty)?);
let expected =
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
#[test]
@@ -838,7 +821,9 @@ mod test {
args: vec![lit("Apple")],
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf],
empty)?);
- let err = assert_optimized_plan_eq(&plan, "").err().unwrap();
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, "")
+ .err()
+ .unwrap();
assert_eq!(
"Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Int32])
failed.\")",
&format!("{err:?}")
@@ -860,7 +845,7 @@ mod test {
empty,
)?);
let expected = "Projection: abs(CAST(Int64(10) AS Float64))\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
#[test]
@@ -881,7 +866,7 @@ mod test {
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty)?);
let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
#[test]
@@ -906,7 +891,9 @@ mod test {
filter: None,
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty)?);
- let err = assert_optimized_plan_eq(&plan, "").err().unwrap();
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, "")
+ .err()
+ .unwrap();
assert_eq!(
"Plan(\"Coercion from [Utf8] to the signature Uniform(1,
[Float64]) failed.\")",
&format!("{err:?}")
@@ -926,7 +913,7 @@ mod test {
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr],
empty)?);
let expected = "Projection: AVG(Int64(12))\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let empty = empty_with_type(DataType::Int32);
let fun: AggregateFunction = AggregateFunction::Avg;
@@ -938,7 +925,7 @@ mod test {
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr],
empty)?);
let expected = "Projection: AVG(a)\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -972,7 +959,7 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected =
"Projection: CAST(Utf8(\"1998-03-18\") AS Date32) +
IntervalDayTime(\"386547056640\")\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -982,41 +969,35 @@ mod test {
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)],
false);
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- ),
+ schema: Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Int64, true)],
+ std::collections::HashMap::new(),
+ )?),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected =
"Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS
Int64), Int64(8)]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)]) })\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
// a in (1,4,8), a is decimal
let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)],
false);
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified(
- "a",
- DataType::Decimal128(12, 4),
- true,
- )],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- ),
+ schema: Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified(
+ "a",
+ DataType::Decimal128(12, 4),
+ true,
+ )],
+ std::collections::HashMap::new(),
+ )?),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected =
"Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS
Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS
Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Int32(1), Int8(4), Int64(8)])
})\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1028,11 +1009,11 @@ mod test {
let plan =
LogicalPlan::Projection(Projection::try_new(vec![expr.clone()],
empty)?);
let expected = "Projection: a IS TRUE\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let empty = empty_with_type(DataType::Int64);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
- let err = assert_optimized_plan_eq(&plan, "");
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, "");
assert!(err.is_err());
assert!(err.unwrap_err().to_string().contains("Int64 IS DISTINCT FROM
Boolean can't be evaluated because there isn't a common type to coerce the
types to"));
@@ -1041,21 +1022,21 @@ mod test {
let empty = empty_with_type(DataType::Boolean);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: a IS NOT TRUE\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
// is false
let expr = col("a").is_false();
let empty = empty_with_type(DataType::Boolean);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: a IS FALSE\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
// is not false
let expr = col("a").is_not_false();
let empty = empty_with_type(DataType::Boolean);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: a IS NOT FALSE\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1069,7 +1050,7 @@ mod test {
let empty = empty_with_type(DataType::Utf8);
let plan =
LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let expr = Box::new(col("a"));
let pattern = Box::new(lit(ScalarValue::Null));
@@ -1078,14 +1059,14 @@ mod test {
let plan =
LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
let expected = "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let expr = Box::new(col("a"));
let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
let like_expr = Expr::Like(Like::new(false, expr, pattern, None));
let empty = empty_with_type(DataType::Int64);
let plan =
LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?);
- let err = assert_optimized_plan_eq(&plan, expected);
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, expected);
assert!(err.is_err());
assert!(err.unwrap_err().to_string().contains(
"There isn't a common type to coerce Int64 and Utf8 in LIKE
expression"
@@ -1098,7 +1079,7 @@ mod test {
let empty = empty_with_type(DataType::Utf8);
let plan =
LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let expr = Box::new(col("a"));
let pattern = Box::new(lit(ScalarValue::Null));
@@ -1107,14 +1088,14 @@ mod test {
let plan =
LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
let expected = "Projection: a ILIKE CAST(NULL AS Utf8) AS a ILIKE NULL
\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let expr = Box::new(col("a"));
let pattern = Box::new(lit(ScalarValue::new_utf8("abc")));
let ilike_expr = Expr::ILike(Like::new(false, expr, pattern, None));
let empty = empty_with_type(DataType::Int64);
let plan =
LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?);
- let err = assert_optimized_plan_eq(&plan, expected);
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, expected);
assert!(err.is_err());
assert!(err.unwrap_err().to_string().contains(
"There isn't a common type to coerce Int64 and Utf8 in ILIKE
expression"
@@ -1130,11 +1111,11 @@ mod test {
let plan =
LogicalPlan::Projection(Projection::try_new(vec![expr.clone()],
empty)?);
let expected = "Projection: a IS UNKNOWN\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
let empty = empty_with_type(DataType::Utf8);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
- let err = assert_optimized_plan_eq(&plan, expected);
+ let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, expected);
assert!(err.is_err());
assert!(err.unwrap_err().to_string().contains("Utf8 IS NOT DISTINCT
FROM Boolean can't be evaluated because there isn't a common type to coerce the
types to"));
@@ -1143,7 +1124,7 @@ mod test {
let empty = empty_with_type(DataType::Boolean);
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1161,7 +1142,7 @@ mod test {
LogicalPlan::Projection(Projection::try_new(vec![expr],
empty.clone())?);
let expected =
"Projection: concat(a, Utf8(\"b\"), CAST(Boolean(true) AS
Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
}
// concat_ws
@@ -1171,7 +1152,7 @@ mod test {
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected =
"Projection: concatwithseparator(Utf8(\"-\"), a, Utf8(\"b\"),
CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS
Utf8))\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
}
Ok(())
@@ -1200,13 +1181,10 @@ mod test {
#[test]
fn test_type_coercion_rewrite() -> Result<()> {
// gt
- let schema = Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- );
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Int64, true)],
+ std::collections::HashMap::new(),
+ )?);
let mut rewriter = TypeCoercionRewriter { schema };
let expr = is_true(lit(12i32).gt(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).gt(lit(13i64)));
@@ -1214,13 +1192,10 @@ mod test {
assert_eq!(expected, result);
// eq
- let schema = Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- );
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Int64, true)],
+ std::collections::HashMap::new(),
+ )?);
let mut rewriter = TypeCoercionRewriter { schema };
let expr = is_true(lit(12i32).eq(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).eq(lit(13i64)));
@@ -1228,13 +1203,10 @@ mod test {
assert_eq!(expected, result);
// lt
- let schema = Arc::new(
- DFSchema::new_with_metadata(
- vec![DFField::new_unqualified("a", DataType::Int64, true)],
- std::collections::HashMap::new(),
- )
- .unwrap(),
- );
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![DFField::new_unqualified("a", DataType::Int64, true)],
+ std::collections::HashMap::new(),
+ )?);
let mut rewriter = TypeCoercionRewriter { schema };
let expr = is_true(lit(12i32).lt(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).lt(lit(13i64)));
@@ -1259,7 +1231,7 @@ mod test {
dbg!(&plan);
let expected =
"Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond,
None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond,
None))\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1426,7 +1398,7 @@ mod test {
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty)?);
let expected = "Projection: IntervalYearMonth(\"12\") +
CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1451,7 +1423,7 @@ mod test {
dbg!(&plan);
let expected =
"Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond,
None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n
EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
@@ -1486,7 +1458,7 @@ mod test {
\n Subquery:\
\n EmptyRelation\
\n EmptyRelation";
- assert_optimized_plan_eq(&plan, expected)?;
+ assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan,
expected)?;
Ok(())
}
}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 65943675ae..42c0ccb484 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -39,7 +39,6 @@ pub mod rewrite_disjunctive_predicate;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
-pub mod type_coercion;
pub mod unwrap_cast_in_comparison;
pub mod utils;
diff --git a/datafusion/optimizer/src/optimizer.rs
b/datafusion/optimizer/src/optimizer.rs
index e4880cb124..6d02c46cc0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -40,7 +40,6 @@ use
crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
use crate::simplify_expressions::SimplifyExpressions;
use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
-use crate::type_coercion::TypeCoercion;
use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
use chrono::{DateTime, Utc};
use datafusion_common::config::ConfigOptions;
@@ -209,7 +208,6 @@ impl Optimizer {
/// Create a new optimizer using the recommended list of rules
pub fn new() -> Self {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
- Arc::new(TypeCoercion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 5ef672c90b..69f2fa3484 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -18,9 +18,8 @@
//! Expression simplification API
use super::utils::*;
-use crate::{
- simplify_expressions::regex::simplify_regex_expr,
type_coercion::TypeCoercionRewriter,
-};
+use crate::analyzer::type_coercion::TypeCoercionRewriter;
+use crate::simplify_expressions::regex::simplify_regex_expr;
use arrow::{
array::new_null_array,
datatypes::{DataType, Field, Schema},