This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new f48a9971b Evaluate expressions after type coercion (#3444)
f48a9971b is described below

commit f48a9971bb4822427395d641c071b7ea825ec496
Author: Daniël Heres <[email protected]>
AuthorDate: Mon Sep 12 21:29:05 2022 +0200

    Evaluate expressions after type coercion (#3444)
    
    * Evaluate expressions after type coercion
    
    * Fix some explains
    
    * Fix some explains
    
    * Fix some explains
    
    * Update test
    
    * Update test
    
    * Update test
    
    * Update more tests
    
    * Fix tests
    
    * Use supported date string
---
 datafusion/core/tests/sql/aggregates.rs        |  10 +--
 datafusion/core/tests/sql/decimal.rs           | 114 ++++++++++++-------------
 datafusion/core/tests/sql/explain_analyze.rs   |   2 +-
 datafusion/core/tests/sql/subqueries.rs        |   8 +-
 datafusion/optimizer/src/type_coercion.rs      |  69 +++++++++++----
 datafusion/optimizer/tests/integration-test.rs |   6 +-
 6 files changed, 121 insertions(+), 88 deletions(-)

diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs
index 357addbc0..b7f24992c 100644
--- a/datafusion/core/tests/sql/aggregates.rs
+++ b/datafusion/core/tests/sql/aggregates.rs
@@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
     assert_eq!(results.len(), 1);
 
     let expected = vec![
-        
"+--------------+-------------------------+-------------------------+-------------------------+",
-        "| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | 
Int64(1) + AVG(test.c1) |",
-        
"+--------------+-------------------------+-------------------------+-------------------------+",
-        "| 1.5          | 2.5                     | 3.5                     | 
2.5                     |",
-        
"+--------------+-------------------------+-------------------------+-------------------------+",
+        
"+--------------+---------------------------+---------------------------+---------------------------+",
+        "| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + 
Float64(2) | Float64(1) + AVG(test.c1) |",
+        
"+--------------+---------------------------+---------------------------+---------------------------+",
+        "| 1.5          | 2.5                       | 3.5                      
 | 2.5                       |",
+        
"+--------------+---------------------------+---------------------------+---------------------------+",
     ];
     assert_batches_sorted_eq!(expected, &results);
 
diff --git a/datafusion/core/tests/sql/decimal.rs 
b/datafusion/core/tests/sql/decimal.rs
index 7c74cdd52..db686deb7 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+------------------------------+",
-        "| decimal_simple.c1 + Int64(1) |",
-        "+------------------------------+",
-        "| 1.000010                     |",
-        "| 1.000020                     |",
-        "| 1.000020                     |",
-        "| 1.000030                     |",
-        "| 1.000030                     |",
-        "| 1.000030                     |",
-        "| 1.000040                     |",
-        "| 1.000040                     |",
-        "| 1.000040                     |",
-        "| 1.000040                     |",
-        "| 1.000050                     |",
-        "| 1.000050                     |",
-        "| 1.000050                     |",
-        "| 1.000050                     |",
-        "| 1.000050                     |",
-        "+------------------------------+",
+        "+----------------------------------------------------+",
+        "| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
+        "+----------------------------------------------------+",
+        "| 1.000010                                           |",
+        "| 1.000020                                           |",
+        "| 1.000020                                           |",
+        "| 1.000030                                           |",
+        "| 1.000030                                           |",
+        "| 1.000030                                           |",
+        "| 1.000040                                           |",
+        "| 1.000040                                           |",
+        "| 1.000040                                           |",
+        "| 1.000040                                           |",
+        "| 1.000050                                           |",
+        "| 1.000050                                           |",
+        "| 1.000050                                           |",
+        "| 1.000050                                           |",
+        "| 1.000050                                           |",
+        "+----------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
     // array decimal(10,6) + array decimal(12,7) => decimal(13,7)
@@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+------------------------------+",
-        "| decimal_simple.c1 - Int64(1) |",
-        "+------------------------------+",
-        "| -0.999990                    |",
-        "| -0.999980                    |",
-        "| -0.999980                    |",
-        "| -0.999970                    |",
-        "| -0.999970                    |",
-        "| -0.999970                    |",
-        "| -0.999960                    |",
-        "| -0.999960                    |",
-        "| -0.999960                    |",
-        "| -0.999960                    |",
-        "| -0.999950                    |",
-        "| -0.999950                    |",
-        "| -0.999950                    |",
-        "| -0.999950                    |",
-        "| -0.999950                    |",
-        "+------------------------------+",
+        "+----------------------------------------------------+",
+        "| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
+        "+----------------------------------------------------+",
+        "| -0.999990                                          |",
+        "| -0.999980                                          |",
+        "| -0.999980                                          |",
+        "| -0.999970                                          |",
+        "| -0.999970                                          |",
+        "| -0.999970                                          |",
+        "| -0.999960                                          |",
+        "| -0.999960                                          |",
+        "| -0.999960                                          |",
+        "| -0.999960                                          |",
+        "| -0.999950                                          |",
+        "| -0.999950                                          |",
+        "| -0.999950                                          |",
+        "| -0.999950                                          |",
+        "| -0.999950                                          |",
+        "+----------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
@@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
         actual[0].schema().field(0).data_type()
     );
     let expected = vec![
-        "+-------------------------------+",
-        "| decimal_simple.c1 * Int64(20) |",
-        "+-------------------------------+",
-        "| 0.000200                      |",
-        "| 0.000400                      |",
-        "| 0.000400                      |",
-        "| 0.000600                      |",
-        "| 0.000600                      |",
-        "| 0.000600                      |",
-        "| 0.000800                      |",
-        "| 0.000800                      |",
-        "| 0.000800                      |",
-        "| 0.000800                      |",
-        "| 0.001000                      |",
-        "| 0.001000                      |",
-        "| 0.001000                      |",
-        "| 0.001000                      |",
-        "| 0.001000                      |",
-        "+-------------------------------+",
+        "+-----------------------------------------------------+",
+        "| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
+        "+-----------------------------------------------------+",
+        "| 0.000200                                            |",
+        "| 0.000400                                            |",
+        "| 0.000400                                            |",
+        "| 0.000600                                            |",
+        "| 0.000600                                            |",
+        "| 0.000600                                            |",
+        "| 0.000800                                            |",
+        "| 0.000800                                            |",
+        "| 0.000800                                            |",
+        "| 0.000800                                            |",
+        "| 0.001000                                            |",
+        "| 0.001000                                            |",
+        "| 0.001000                                            |",
+        "| 0.001000                                            |",
+        "| 0.001000                                            |",
+        "+-----------------------------------------------------+",
     ];
     assert_batches_eq!(expected, &actual);
 
diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index 91dd9401e..7f465c4c6 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
     let expected = "\
     Sort: #revenue DESC NULLS FIRST\
     \n  Projection: #customer.c_custkey, #customer.c_name, 
#SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, 
#customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, 
#customer.c_comment\
-    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS 
Decimal128(38, 4)) * CAST(CAST(Int64(1) AS Decimal128(23, 2)) - 
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
+    \n    Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, 
#customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, 
#customer.c_comment]], aggr=[[SUM(CAST(#lineitem.l_extendedprice AS 
Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - 
CAST(#lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)))]]\
     \n      Inner Join: #customer.c_nationkey = #nation.n_nationkey\
     \n        Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
     \n          Inner Join: #customer.c_custkey = #orders.o_custkey\
diff --git a/datafusion/core/tests/sql/subqueries.rs 
b/datafusion/core/tests/sql/subqueries.rs
index 0d9fe37f9..1ae5bc68e 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -336,9 +336,9 @@ order by s_name;
               Projection: #part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: #part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name], 
partial_filters=[#part.p_name LIKE Utf8("forest%")]
-            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, 
CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
+            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, 
Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
               Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], 
aggr=[[SUM(#lineitem.l_quantity)]]
-                Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS 
Date32)
+                Filter: #lineitem.l_shipdate >= Date32("8766")
                   TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
CAST(Utf8("1994-01-01") AS Date32)]"#
         .to_string();
     assert_eq!(actual, expected);
@@ -393,7 +393,7 @@ order by cntrycode;"#;
                 TableScan: orders projection=[o_custkey]
               Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
                 Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
-                  Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), 
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), 
Utf8("18"), Utf8("17")])
+                  Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])
                     TableScan: customer projection=[c_phone, c_acctbal], 
partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, 
Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), 
Utf8("30"), Utf8("18"), Utf8("17")])]"#
         .to_string();
     assert_eq!(actual, expected);
@@ -453,7 +453,7 @@ order by value desc;
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: #nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("GERMANY")]
-        Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) 
AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS 
__value, alias=__sq_1
+        Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) 
AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, 
alias=__sq_1
           Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS 
Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: #supplier.s_nationkey = #nation.n_nationkey
               Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
diff --git a/datafusion/optimizer/src/type_coercion.rs 
b/datafusion/optimizer/src/type_coercion.rs
index 77580c063..72ee5d19a 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -17,6 +17,7 @@
 
 //! Optimizer rule for type validation and coercion
 
+use crate::simplify_expressions::ConstEvaluator;
 use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::DataType;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
@@ -26,6 +27,7 @@ use datafusion_expr::type_coercion::data_types;
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{Expr, LogicalPlan};
 use datafusion_expr::{ExprSchemable, Signature};
+use datafusion_physical_expr::execution_props::ExecutionProps;
 use std::sync::Arc;
 
 #[derive(Default)]
@@ -64,8 +66,14 @@ impl OptimizerRule for TypeCoercion {
             },
         );
 
+        let mut execution_props = ExecutionProps::new();
+        execution_props.query_execution_start_time =
+            optimizer_config.query_execution_start_time;
+        let const_evaluator = ConstEvaluator::try_new(&execution_props)?;
+
         let mut expr_rewrite = TypeCoercionRewriter {
             schema: Arc::new(schema),
+            const_evaluator,
         };
 
         let new_expr = plan
@@ -78,11 +86,12 @@ impl OptimizerRule for TypeCoercion {
     }
 }
 
-struct TypeCoercionRewriter {
+struct TypeCoercionRewriter<'a> {
     schema: DFSchemaRef,
+    const_evaluator: ConstEvaluator<'a>,
 }
 
-impl ExprRewriter for TypeCoercionRewriter {
+impl ExprRewriter for TypeCoercionRewriter<'_> {
     fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
         Ok(RewriteRecursion::Continue)
     }
@@ -106,7 +115,7 @@ impl ExprRewriter for TypeCoercionRewriter {
                     }
                     _ => {
                         let coerced_type = coerce_types(&left_type, &op, 
&right_type)?;
-                        Ok(Expr::BinaryExpr {
+                        let expr = Expr::BinaryExpr {
                             left: Box::new(
                                 left.clone().cast_to(&coerced_type, 
&self.schema)?,
                             ),
@@ -114,7 +123,9 @@ impl ExprRewriter for TypeCoercionRewriter {
                             right: Box::new(
                                 right.clone().cast_to(&coerced_type, 
&self.schema)?,
                             ),
-                        })
+                        };
+
+                        expr.rewrite(&mut self.const_evaluator)
                     }
                 }
             }
@@ -133,12 +144,13 @@ impl ExprRewriter for TypeCoercionRewriter {
                             expr_type, low_type
                         ))
                     })?;
-                Ok(Expr::Between {
+                let expr = Expr::Between {
                     expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
                     negated,
                     low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
                     high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
-                })
+                };
+                expr.rewrite(&mut self.const_evaluator)
             }
             Expr::ScalarUDF { fun, args } => {
                 let new_expr = coerce_arguments_for_signature(
@@ -146,10 +158,11 @@ impl ExprRewriter for TypeCoercionRewriter {
                     &self.schema,
                     &fun.signature,
                 )?;
-                Ok(Expr::ScalarUDF {
+                let expr = Expr::ScalarUDF {
                     fun,
                     args: new_expr,
-                })
+                };
+                expr.rewrite(&mut self.const_evaluator)
             }
             expr => Ok(expr),
         }
@@ -188,7 +201,8 @@ mod test {
     use crate::type_coercion::TypeCoercion;
     use crate::{OptimizerConfig, OptimizerRule};
     use arrow::datatypes::DataType;
-    use datafusion_common::{DFSchema, Result, ScalarValue};
+    use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
+    use datafusion_expr::{col, ColumnarValue};
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
@@ -199,17 +213,23 @@ mod test {
 
     #[test]
     fn simple_case() -> Result<()> {
-        let expr = lit(1.2_f64).lt(lit(2_u32));
+        let expr = col("a").lt(lit(2_u32));
         let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
             produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
+            schema: Arc::new(
+                DFSchema::new_with_metadata(
+                    vec![DFField::new(None, "a", DataType::Float64, true)],
+                    std::collections::HashMap::new(),
+                )
+                .unwrap(),
+            ),
         }));
         let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], 
empty, None)?);
         let rule = TypeCoercion::new();
         let mut config = OptimizerConfig::default();
         let plan = rule.optimize(&plan, &mut config)?;
         assert_eq!(
-            "Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n  
EmptyRelation",
+            "Projection: #a < Float64(2)\n  EmptyRelation",
             &format!("{:?}", plan)
         );
         Ok(())
@@ -217,10 +237,16 @@ mod test {
 
     #[test]
     fn nested_case() -> Result<()> {
-        let expr = lit(1.2_f64).lt(lit(2_u32));
+        let expr = col("a").lt(lit(2_u32));
         let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
             produce_one_row: false,
-            schema: Arc::new(DFSchema::empty()),
+            schema: Arc::new(
+                DFSchema::new_with_metadata(
+                    vec![DFField::new(None, "a", DataType::Float64, true)],
+                    std::collections::HashMap::new(),
+                )
+                .unwrap(),
+            ),
         }));
         let plan = LogicalPlan::Projection(Projection::try_new(
             vec![expr.clone().or(expr)],
@@ -230,8 +256,11 @@ mod test {
         let rule = TypeCoercion::new();
         let mut config = OptimizerConfig::default();
         let plan = rule.optimize(&plan, &mut config)?;
-        assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR 
Float64(1.2) < CAST(UInt32(2) AS Float64)\
-            \n  EmptyRelation", &format!("{:?}", plan));
+        assert_eq!(
+            "Projection: #a < Float64(2) OR #a < Float64(2)\
+            \n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
         Ok(())
     }
 
@@ -240,7 +269,11 @@ mod test {
         let empty = empty();
         let return_type: ReturnTypeFunction =
             Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation = Arc::new(move |_| 
unimplemented!());
+        let fun: ScalarFunctionImplementation = Arc::new(move |_| {
+            Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
+                "a".to_string(),
+            ))))
+        });
         let udf = Expr::ScalarUDF {
             fun: Arc::new(ScalarUDF::new(
                 "TestScalarUDF",
@@ -255,7 +288,7 @@ mod test {
         let mut config = OptimizerConfig::default();
         let plan = rule.optimize(&plan, &mut config)?;
         assert_eq!(
-            "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n  
EmptyRelation",
+            "Projection: Utf8(\"a\")\n  EmptyRelation",
             &format!("{:?}", plan)
         );
         Ok(())
diff --git a/datafusion/optimizer/tests/integration-test.rs 
b/datafusion/optimizer/tests/integration-test.rs
index 87a0bab68..cb31600b9 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -83,7 +83,7 @@ fn between_date32_plus_interval() -> Result<()> {
     let plan = test_sql(sql)?;
     let expected =
         "Projection: #COUNT(UInt8(1))\n  Aggregate: groupBy=[[]], 
aggr=[[COUNT(UInt8(1))]]\
-        \n    Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32) 
AND #test.col_date32 <= Date32(\"10393\")\
+        \n    Filter: #test.col_date32 >= Date32(\"10303\") AND 
#test.col_date32 <= Date32(\"10393\")\
         \n      TableScan: test projection=[col_date32]";
     assert_eq!(expected, format!("{:?}", plan));
     Ok(())
@@ -92,11 +92,11 @@ fn between_date32_plus_interval() -> Result<()> {
 #[test]
 fn between_date64_plus_interval() -> Result<()> {
     let sql = "SELECT count(1) FROM test \
-    WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) + 
INTERVAL '90 days'";
+    WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as 
date) + INTERVAL '90 days'";
     let plan = test_sql(sql)?;
     let expected =
         "Projection: #COUNT(UInt8(1))\n  Aggregate: groupBy=[[]], 
aggr=[[COUNT(UInt8(1))]]\
-        \n    Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64) 
AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\
+        \n    Filter: #test.col_date64 >= Date64(\"890179200000\") AND 
#test.col_date64 <= Date64(\"897955200000\")\
         \n      TableScan: test projection=[col_date64]";
     assert_eq!(expected, format!("{:?}", plan));
     Ok(())

Reply via email to