This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new f48a9971b Evaluate expressions after type coercion (#3444)
f48a9971b is described below
commit f48a9971bb4822427395d641c071b7ea825ec496
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Sep 12 21:29:05 2022 +0200
Evaluate expressions after type coercion (#3444)
* Evaluate expressions after type coercion
* Fix some explains
* Fix some explains
* Fix some explains
* Update test
* Update test
* Update test
* Update more tests
* Fix tests
* Use supported date string
---
datafusion/core/tests/sql/aggregates.rs | 10 +--
datafusion/core/tests/sql/decimal.rs | 114 ++++++++++++-------------
datafusion/core/tests/sql/explain_analyze.rs | 2 +-
datafusion/core/tests/sql/subqueries.rs | 8 +-
datafusion/optimizer/src/type_coercion.rs | 69 +++++++++++----
datafusion/optimizer/tests/integration-test.rs | 6 +-
6 files changed, 121 insertions(+), 88 deletions(-)
diff --git a/datafusion/core/tests/sql/aggregates.rs
b/datafusion/core/tests/sql/aggregates.rs
index 357addbc0..b7f24992c 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
assert_eq!(results.len(), 1);
let expected = vec![
-
"+--------------+-------------------------+-------------------------+-------------------------+",
- "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) |
Int64(1) + AVG(test.c1) |",
-
"+--------------+-------------------------+-------------------------+-------------------------+",
- "| 1.5 | 2.5 | 3.5 |
2.5 |",
-
"+--------------+-------------------------+-------------------------+-------------------------+",
+
"+--------------+---------------------------+---------------------------+---------------------------+",
+ "| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) +
Float64(2) | Float64(1) + AVG(test.c1) |",
+
"+--------------+---------------------------+---------------------------+---------------------------+",
+ "| 1.5 | 2.5 | 3.5
| 2.5 |",
+
"+--------------+---------------------------+---------------------------+---------------------------+",
];
assert_batches_sorted_eq!(expected, &results);
diff --git a/datafusion/core/tests/sql/decimal.rs
b/datafusion/core/tests/sql/decimal.rs
index 7c74cdd52..db686deb7 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+------------------------------+",
- "| decimal_simple.c1 + Int64(1) |",
- "+------------------------------+",
- "| 1.000010 |",
- "| 1.000020 |",
- "| 1.000020 |",
- "| 1.000030 |",
- "| 1.000030 |",
- "| 1.000030 |",
- "| 1.000040 |",
- "| 1.000040 |",
- "| 1.000040 |",
- "| 1.000040 |",
- "| 1.000050 |",
- "| 1.000050 |",
- "| 1.000050 |",
- "| 1.000050 |",
- "| 1.000050 |",
- "+------------------------------+",
+ "+----------------------------------------------------+",
+ "| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
+ "+----------------------------------------------------+",
+ "| 1.000010 |",
+ "| 1.000020 |",
+ "| 1.000020 |",
+ "| 1.000030 |",
+ "| 1.000030 |",
+ "| 1.000030 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000040 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "| 1.000050 |",
+ "+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
@@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+------------------------------+",
- "| decimal_simple.c1 - Int64(1) |",
- "+------------------------------+",
- "| -0.999990 |",
- "| -0.999980 |",
- "| -0.999980 |",
- "| -0.999970 |",
- "| -0.999970 |",
- "| -0.999970 |",
- "| -0.999960 |",
- "| -0.999960 |",
- "| -0.999960 |",
- "| -0.999960 |",
- "| -0.999950 |",
- "| -0.999950 |",
- "| -0.999950 |",
- "| -0.999950 |",
- "| -0.999950 |",
- "+------------------------------+",
+ "+----------------------------------------------------+",
+ "| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
+ "+----------------------------------------------------+",
+ "| -0.999990 |",
+ "| -0.999980 |",
+ "| -0.999980 |",
+ "| -0.999970 |",
+ "| -0.999970 |",
+ "| -0.999970 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999960 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "| -0.999950 |",
+ "+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
- "+-------------------------------+",
- "| decimal_simple.c1 * Int64(20) |",
- "+-------------------------------+",
- "| 0.000200 |",
- "| 0.000400 |",
- "| 0.000400 |",
- "| 0.000600 |",
- "| 0.000600 |",
- "| 0.000600 |",
- "| 0.000800 |",
- "| 0.000800 |",
- "| 0.000800 |",
- "| 0.000800 |",
- "| 0.001000 |",
- "| 0.001000 |",
- "| 0.001000 |",
- "| 0.001000 |",
- "| 0.001000 |",
- "+-------------------------------+",
+ "+-----------------------------------------------------+",
+ "| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
+ "+-----------------------------------------------------+",
+ "| 0.000200 |",
+ "| 0.000400 |",
+ "| 0.000400 |",
+ "| 0.000600 |",
+ "| 0.000600 |",
+ "| 0.000600 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.000800 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "| 0.001000 |",
+ "+-----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/explain_analyze.rs
b/datafusion/core/tests/sql/explain_analyze.rs
index 91dd9401e..7f465c4c6 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name,
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue,
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone,
#customer.c_comment\
- \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS
Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) -
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
+ \n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name,
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address,
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS
Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) -
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
diff --git a/datafusion/core/tests/sql/subqueries.rs
b/datafusion/core/tests/sql/subqueries.rs
index 0d9fe37f9..1ae5bc68e 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -336,9 +336,9 @@ order by s_name;
Projection: #part.p_partkey AS p_partkey, alias=__sq_1
Filter: #part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name],
partial_filters=[#part.p_name LIKE Utf8("forest%")]
- Projection: #lineitem.l_partkey, #lineitem.l_suppkey,
CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS
Decimal128(38, 17)) AS __value, alias=__sq_3
+ Projection: #lineitem.l_partkey, #lineitem.l_suppkey,
Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS
Decimal128(38, 17)) AS __value, alias=__sq_3
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]],
aggr=[[SUM(#lineitem.l_quantity)]]
- Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS
Date32)
+ Filter: #lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey,
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >=
CAST(Utf8("1994-01-01") AS Date32)]"#
.to_string();
assert_eq!(actual, expected);
@@ -393,7 +393,7 @@ order by cntrycode;"#;
TableScan: orders projection=[o_custkey]
Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
- Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) >
CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1),
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"),
Utf8("18"), Utf8("17")])
+ Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) >
Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"),
Utf8("17")])
TableScan: customer projection=[c_phone, c_acctbal],
partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone,
Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"),
Utf8("30"), Utf8("18"), Utf8("17")])]"#
.to_string();
assert_eq!(actual, expected);
@@ -453,7 +453,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: #nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[#nation.n_name = Utf8("GERMANY")]
- Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty)
AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS
__value, alias=__sq_1
+ Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty)
AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value,
alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS
Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: #supplier.s_nationkey = #nation.n_nationkey
Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
diff --git a/datafusion/optimizer/src/type_coercion.rs
b/datafusion/optimizer/src/type_coercion.rs
index 77580c063..72ee5d19a 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -17,6 +17,7 @@
//! Optimizer rule for type validation and coercion
+use crate::simplify_expressions::ConstEvaluator;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
@@ -26,6 +27,7 @@ use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
+use datafusion_physical_expr::execution_props::ExecutionProps;
use std::sync::Arc;
#[derive(Default)]
@@ -64,8 +66,14 @@ impl OptimizerRule for TypeCoercion {
},
);
+ let mut execution_props = ExecutionProps::new();
+ execution_props.query_execution_start_time =
+ optimizer_config.query_execution_start_time;
+ let const_evaluator = ConstEvaluator::try_new(&execution_props)?;
+
let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
+ const_evaluator,
};
let new_expr = plan
@@ -78,11 +86,12 @@ impl OptimizerRule for TypeCoercion {
}
}
-struct TypeCoercionRewriter {
+struct TypeCoercionRewriter<'a> {
schema: DFSchemaRef,
+ const_evaluator: ConstEvaluator<'a>,
}
-impl ExprRewriter for TypeCoercionRewriter {
+impl ExprRewriter for TypeCoercionRewriter<'_> {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
@@ -106,7 +115,7 @@ impl ExprRewriter for TypeCoercionRewriter {
}
_ => {
let coerced_type = coerce_types(&left_type, &op,
&right_type)?;
- Ok(Expr::BinaryExpr {
+ let expr = Expr::BinaryExpr {
left: Box::new(
left.clone().cast_to(&coerced_type,
&self.schema)?,
),
@@ -114,7 +123,9 @@ impl ExprRewriter for TypeCoercionRewriter {
right: Box::new(
right.clone().cast_to(&coerced_type,
&self.schema)?,
),
- })
+ };
+
+ expr.rewrite(&mut self.const_evaluator)
}
}
}
@@ -133,12 +144,13 @@ impl ExprRewriter for TypeCoercionRewriter {
expr_type, low_type
))
})?;
- Ok(Expr::Between {
+ let expr = Expr::Between {
expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
negated,
low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
- })
+ };
+ expr.rewrite(&mut self.const_evaluator)
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
@@ -146,10 +158,11 @@ impl ExprRewriter for TypeCoercionRewriter {
&self.schema,
&fun.signature,
)?;
- Ok(Expr::ScalarUDF {
+ let expr = Expr::ScalarUDF {
fun,
args: new_expr,
- })
+ };
+ expr.rewrite(&mut self.const_evaluator)
}
expr => Ok(expr),
}
@@ -188,7 +201,8 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
- use datafusion_common::{DFSchema, Result, ScalarValue};
+ use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
+ use datafusion_expr::{col, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
@@ -199,17 +213,23 @@ mod test {
#[test]
fn simple_case() -> Result<()> {
- let expr = lit(1.2_f64).lt(lit(2_u32));
+ let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
+ schema: Arc::new(
+ DFSchema::new_with_metadata(
+ vec![DFField::new(None, "a", DataType::Float64, true)],
+ std::collections::HashMap::new(),
+ )
+ .unwrap(),
+ ),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
- "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n
EmptyRelation",
+ "Projection: #a < Float64(2)\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
@@ -217,10 +237,16 @@ mod test {
#[test]
fn nested_case() -> Result<()> {
- let expr = lit(1.2_f64).lt(lit(2_u32));
+ let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
- schema: Arc::new(DFSchema::empty()),
+ schema: Arc::new(
+ DFSchema::new_with_metadata(
+ vec![DFField::new(None, "a", DataType::Float64, true)],
+ std::collections::HashMap::new(),
+ )
+ .unwrap(),
+ ),
}));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
@@ -230,8 +256,11 @@ mod test {
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
- assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR
Float64(1.2) < CAST(UInt32(2) AS Float64)\
- \n EmptyRelation", &format!("{:?}", plan));
+ assert_eq!(
+ "Projection: #a < Float64(2) OR #a < Float64(2)\
+ \n EmptyRelation",
+ &format!("{:?}", plan)
+ );
Ok(())
}
@@ -240,7 +269,11 @@ mod test {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
- let fun: ScalarFunctionImplementation = Arc::new(move |_|
unimplemented!());
+ let fun: ScalarFunctionImplementation = Arc::new(move |_| {
+ Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+ "a".to_string(),
+ ))))
+ });
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
@@ -255,7 +288,7 @@ mod test {
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
- "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n
EmptyRelation",
+ "Projection: Utf8(\"a\")\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
diff --git a/datafusion/optimizer/tests/integration-test.rs
b/datafusion/optimizer/tests/integration-test.rs
index 87a0bab68..cb31600b9 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -83,7 +83,7 @@ fn between_date32_plus_interval() -> Result<()> {
let plan = test_sql(sql)?;
let expected =
"Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]],
aggr=[[COUNT(UInt8(1))]]\
- \n Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32)
AND #test.col_date32 <= Date32(\"10393\")\
+ \n Filter: #test.col_date32 >= Date32(\"10303\") AND
#test.col_date32 <= Date32(\"10393\")\
\n TableScan: test projection=[col_date32]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
@@ -92,11 +92,11 @@ fn between_date32_plus_interval() -> Result<()> {
#[test]
fn between_date64_plus_interval() -> Result<()> {
let sql = "SELECT count(1) FROM test \
- WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) +
INTERVAL '90 days'";
+ WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as
date) + INTERVAL '90 days'";
let plan = test_sql(sql)?;
let expected =
"Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]],
aggr=[[COUNT(UInt8(1))]]\
- \n Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64)
AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\
+ \n Filter: #test.col_date64 >= Date64(\"890179200000\") AND
#test.col_date64 <= Date64(\"897955200000\")\
\n TableScan: test projection=[col_date64]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())