This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 156ebff  Generic constant expression evaluation (#1153)
156ebff is described below

commit 156ebff70f96346742c0654ea4af76b9d1036530
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Oct 27 19:19:47 2021 -0400

    Generic constant expression evaluation (#1153)
    
    * Generic constant expression evaluation
    
    * Better list of evaluatable expressions
    
    * Fixup comments
    
    * Use Null type
---
 datafusion/src/optimizer/constant_folding.rs | 279 ++++++++++---------
 datafusion/src/optimizer/utils.rs            | 397 ++++++++++++++++++++++++++-
 datafusion/src/test_util.rs                  |  46 ++++
 datafusion/tests/sql.rs                      |  46 +---
 4 files changed, 586 insertions(+), 182 deletions(-)

diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs
index d67d7d1..74fdc72 100644
--- a/datafusion/src/optimizer/constant_folding.rs
+++ b/datafusion/src/optimizer/constant_folding.rs
@@ -15,12 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Boolean comparison rule rewrites redundant comparison expression involving boolean literal into
-//! unary expression.
+//! Constant folding and algebraic simplification
 
 use std::sync::Arc;
 
-use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos;
 use arrow::datatypes::DataType;
 
 use crate::error::Result;
@@ -30,11 +28,11 @@ use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::utils;
 use crate::physical_plan::functions::BuiltinScalarFunction;
 use crate::scalar::ScalarValue;
-use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS};
 
-/// Optimizer that simplifies comparison expressions involving boolean literals.
+/// Simplifies plans by rewriting [`Expr`]s evaluating constants
+/// and applying algebraic simplifications
 ///
-/// Recursively go through all expressions and simplify the following cases:
+/// Example transformations that are applied:
 /// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type
 /// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type
 /// * `true = true` and `false = false` to `true`
@@ -61,14 +59,16 @@ impl OptimizerRule for ConstantFolding {
         // projected columns. With just the projected schema, it's not possible to infer types for
         // expressions that references non-projected columns within the same project plan or its
         // children plans.
-        let mut rewriter = ConstantRewriter {
+        let mut simplifier = Simplifier {
             schemas: plan.all_schemas(),
             execution_props,
         };
 
+        let mut const_evaluator = utils::ConstEvaluator::new();
+
         match plan {
             LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter {
-                predicate: predicate.clone().rewrite(&mut rewriter)?,
+                predicate: predicate.clone().rewrite(&mut simplifier)?,
                 input: Arc::new(self.optimize(input, execution_props)?),
             }),
             // Rest: recurse into plan, apply optimization where possible
@@ -96,7 +96,18 @@ impl OptimizerRule for ConstantFolding {
                 let expr = plan
                     .expressions()
                     .into_iter()
-                    .map(|e| e.rewrite(&mut rewriter))
+                    .map(|e| {
+                        // TODO iterate until no changes are made
+                        // during rewrite (evaluating constants can
+                        // enable new simplifications and
+                        // simplifications can enable new constant
+                        // evaluation)
+                        let new_e = e
+                            // fold constants and then simplify
+                            .rewrite(&mut const_evaluator)?
+                            .rewrite(&mut simplifier)?;
+                        Ok(new_e)
+                    })
                     .collect::<Result<Vec<_>>>()?;
 
                 utils::from_plan(plan, &expr, &new_inputs)
@@ -112,13 +123,17 @@ impl OptimizerRule for ConstantFolding {
     }
 }
 
-struct ConstantRewriter<'a> {
+/// Simplifies [`Expr`]s by applying algebraic transformation rules
+///
+/// For example
+/// `true && col` --> `col` where `col` is a boolean type
+struct Simplifier<'a> {
     /// input schemas
     schemas: Vec<&'a DFSchemaRef>,
     execution_props: &'a ExecutionProps,
 }
 
-impl<'a> ConstantRewriter<'a> {
+impl<'a> Simplifier<'a> {
     fn is_boolean_type(&self, expr: &Expr) -> bool {
         for schema in &self.schemas {
             if let Ok(DataType::Boolean) = expr.get_type(schema) {
@@ -130,7 +145,7 @@ impl<'a> ConstantRewriter<'a> {
     }
 }
 
-impl<'a> ExprRewriter for ConstantRewriter<'a> {
+impl<'a> ExprRewriter for Simplifier<'a> {
     /// rewrite the expression simplifying any constant expressions
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         let new_expr = match expr {
@@ -205,14 +220,15 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
                 },
                 _ => Expr::BinaryExpr { left, op, right },
             },
+            // Not(Not(expr)) --> expr
             Expr::Not(inner) => {
-                // Not(Not(expr)) --> expr
                 if let Expr::Not(negated_inner) = *inner {
                     *negated_inner
                 } else {
                     Expr::Not(inner)
                 }
             }
+            // convert now() --> the time in `ExecutionProps`
             Expr::ScalarFunction {
                 fun: BuiltinScalarFunction::Now,
                 ..
@@ -221,56 +237,8 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
                     .query_execution_start_time
                     .timestamp_nanos(),
             ))),
-            Expr::ScalarFunction {
-                fun: BuiltinScalarFunction::ToTimestamp,
-                args,
-            } => {
-                if !args.is_empty() {
-                    match &args[0] {
-                        Expr::Literal(ScalarValue::Utf8(Some(val))) => {
-                            match string_to_timestamp_nanos(val) {
-                                Ok(timestamp) => Expr::Literal(
-                                    ScalarValue::TimestampNanosecond(Some(timestamp)),
-                                ),
-                                _ => Expr::ScalarFunction {
-                                    fun: BuiltinScalarFunction::ToTimestamp,
-                                    args,
-                                },
-                            }
-                        }
-                        _ => Expr::ScalarFunction {
-                            fun: BuiltinScalarFunction::ToTimestamp,
-                            args,
-                        },
-                    }
-                } else {
-                    Expr::ScalarFunction {
-                        fun: BuiltinScalarFunction::ToTimestamp,
-                        args,
-                    }
-                }
-            }
-            Expr::Cast {
-                expr: inner,
-                data_type,
-            } => match inner.as_ref() {
-                Expr::Literal(val) => {
-                    let scalar_array = val.to_array();
-                    let cast_array = kernels::cast::cast_with_options(
-                        &scalar_array,
-                        &data_type,
-                        &DEFAULT_CAST_OPTIONS,
-                    )?;
-                    let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
-                    Expr::Literal(cast_scalar)
-                }
-                _ => Expr::Cast {
-                    expr: inner,
-                    data_type,
-                },
-            },
             expr => {
-                // no rewrite possible
+                // no additional rewrites possible
                 expr
             }
         };
@@ -281,12 +249,13 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::logical_plan::{
-        col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder,
+    use crate::{
+        assert_contains,
+        logical_plan::{col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder},
     };
 
     use arrow::datatypes::*;
-    use chrono::{DateTime, Utc};
+    use chrono::{DateTime, TimeZone, Utc};
 
     fn test_table_scan() -> Result<LogicalPlan> {
         let schema = Schema::new(vec![
@@ -311,7 +280,7 @@ mod tests {
     #[test]
     fn optimize_expr_not_not() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -327,7 +296,7 @@ mod tests {
     #[test]
     fn optimize_expr_null_comparison() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -363,7 +332,7 @@ mod tests {
     #[test]
     fn optimize_expr_eq() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -394,7 +363,7 @@ mod tests {
     #[test]
     fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -434,7 +403,7 @@ mod tests {
     #[test]
     fn optimize_expr_not_eq() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -470,7 +439,7 @@ mod tests {
     #[test]
     fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -506,7 +475,7 @@ mod tests {
     #[test]
     fn optimize_expr_case_when_then_else() -> Result<()> {
         let schema = expr_test_schema();
-        let mut rewriter = ConstantRewriter {
+        let mut rewriter = Simplifier {
             schemas: vec![&schema],
             execution_props: &ExecutionProps::new(),
         };
@@ -669,6 +638,20 @@ mod tests {
         Ok(())
     }
 
+    // expect optimizing will result in an error, returning the error string
+    fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>) -> String {
+        let rule = ConstantFolding::new();
+        let execution_props = ExecutionProps {
+            query_execution_start_time: *date_time,
+        };
+
+        let err = rule
+            .optimize(plan, &execution_props)
+            .expect_err("expected optimization to fail");
+
+        err.to_string()
+    }
+
     fn get_optimized_plan_formatted(
         plan: &LogicalPlan,
         date_time: &DateTime<Utc>,
@@ -684,15 +667,19 @@ mod tests {
         return format!("{:?}", optimized_plan);
     }
 
+    /// Create a to_timestamp expr
+    fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
+        Expr::ScalarFunction {
+            args: vec![lit(arg.into())],
+            fun: BuiltinScalarFunction::ToTimestamp,
+        }
+    }
+
     #[test]
-    fn to_timestamp_expr() {
+    fn to_timestamp_expr_folded() {
         let table_scan = test_table_scan().unwrap();
-        let proj = vec![Expr::ScalarFunction {
-            args: vec![Expr::Literal(ScalarValue::Utf8(Some(
-                "2020-09-08T12:00:00+00:00".to_string(),
-            )))],
-            fun: BuiltinScalarFunction::ToTimestamp,
-        }];
+        let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")];
+
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(proj)
             .unwrap()
@@ -702,55 +689,30 @@ mod tests {
         let expected = "Projection: TimestampNanosecond(1599566400000000000)\
             \n  TableScan: test projection=None"
             .to_string();
-        let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
+        let actual = get_optimized_plan_formatted(&plan, &Utc::now());
         assert_eq!(expected, actual);
     }
 
     #[test]
     fn to_timestamp_expr_wrong_arg() {
         let table_scan = test_table_scan().unwrap();
-        let proj = vec![Expr::ScalarFunction {
-            args: vec![Expr::Literal(ScalarValue::Utf8(Some(
-                "I'M NOT A TIMESTAMP".to_string(),
-            )))],
-            fun: BuiltinScalarFunction::ToTimestamp,
-        }];
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .project(proj)
-            .unwrap()
-            .build()
-            .unwrap();
-
-        let expected = "Projection: totimestamp(Utf8(\"I\'M NOT A TIMESTAMP\"))\
-            \n  TableScan: test projection=None";
-        let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
-        assert_eq!(expected, actual);
-    }
-
-    #[test]
-    fn to_timestamp_expr_no_arg() {
-        let table_scan = test_table_scan().unwrap();
-        let proj = vec![Expr::ScalarFunction {
-            args: vec![],
-            fun: BuiltinScalarFunction::ToTimestamp,
-        }];
+        let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")];
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(proj)
             .unwrap()
             .build()
             .unwrap();
 
-        let expected = "Projection: totimestamp()\
-            \n  TableScan: test projection=None";
-        let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
-        assert_eq!(expected, actual);
+        let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp";
+        let actual = get_optimized_plan_err(&plan, &Utc::now());
+        assert_contains!(actual, expected);
     }
 
     #[test]
     fn cast_expr() {
         let table_scan = test_table_scan().unwrap();
         let proj = vec![Expr::Cast {
-            expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))),
+            expr: Box::new(lit("0")),
             data_type: DataType::Int32,
         }];
         let plan = LogicalPlanBuilder::from(table_scan)
@@ -761,7 +723,7 @@ mod tests {
 
         let expected = "Projection: Int32(0)\
             \n  TableScan: test projection=None";
-        let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
+        let actual = get_optimized_plan_formatted(&plan, &Utc::now());
         assert_eq!(expected, actual);
     }
 
@@ -769,7 +731,7 @@ mod tests {
     fn cast_expr_wrong_arg() {
         let table_scan = test_table_scan().unwrap();
         let proj = vec![Expr::Cast {
-            expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))),
+            expr: Box::new(lit("")),
             data_type: DataType::Int32,
         }];
         let plan = LogicalPlanBuilder::from(table_scan)
@@ -778,20 +740,24 @@ mod tests {
             .build()
             .unwrap();
 
-        let expected = "Projection: Int32(NULL)\
-            \n  TableScan: test projection=None";
-        let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
-        assert_eq!(expected, actual);
+        let expected =
+            "Cannot cast string '' to value of arrow::datatypes::types::Int32Type type";
+        let actual = get_optimized_plan_err(&plan, &Utc::now());
+        assert_contains!(actual, expected);
+    }
+
+    fn now_expr() -> Expr {
+        Expr::ScalarFunction {
+            args: vec![],
+            fun: BuiltinScalarFunction::Now,
+        }
     }
 
     #[test]
     fn single_now_expr() {
         let table_scan = test_table_scan().unwrap();
-        let proj = vec![Expr::ScalarFunction {
-            args: vec![],
-            fun: BuiltinScalarFunction::Now,
-        }];
-        let time = chrono::Utc::now();
+        let proj = vec![now_expr()];
+        let time = Utc::now();
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(proj)
             .unwrap()
@@ -811,19 +777,10 @@ mod tests {
     #[test]
     fn multiple_now_expr() {
         let table_scan = test_table_scan().unwrap();
-        let time = chrono::Utc::now();
+        let time = Utc::now();
         let proj = vec![
-            Expr::ScalarFunction {
-                args: vec![],
-                fun: BuiltinScalarFunction::Now,
-            },
-            Expr::Alias(
-                Box::new(Expr::ScalarFunction {
-                    args: vec![],
-                    fun: BuiltinScalarFunction::Now,
-                }),
-                "t2".to_string(),
-            ),
+            now_expr(),
+            Expr::Alias(Box::new(now_expr()), "t2".to_string()),
         ];
         let plan = LogicalPlanBuilder::from(table_scan)
             .project(proj)
@@ -831,6 +788,7 @@ mod tests {
             .build()
             .unwrap();
 
+        // expect the same timestamp to appear in both exprs
         let actual = get_optimized_plan_formatted(&plan, &time);
         let expected = format!(
             "Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\
@@ -841,4 +799,59 @@ mod tests {
 
         assert_eq!(actual, expected);
     }
+
+    #[test]
+    fn simplify_and_eval() {
+        // demonstrate a case where the evaluation needs to run prior
+        // to the simplifier for it to work
+        let table_scan = test_table_scan().unwrap();
+        let time = Utc::now();
+        // (true or false) != col --> !col
+        let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(proj)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let actual = get_optimized_plan_formatted(&plan, &time);
+        let expected = "Projection: NOT #test.a\
+                        \n  TableScan: test projection=None";
+
+        assert_eq!(actual, expected);
+    }
+
+    fn cast_to_int64_expr(expr: Expr) -> Expr {
+        Expr::Cast {
+            expr: expr.into(),
+            data_type: DataType::Int64,
+        }
+    }
+
+    #[test]
+    fn now_less_than_timestamp() {
+        let table_scan = test_table_scan().unwrap();
+
+        let ts_string = "2020-09-08T12:05:00+00:00";
+        let time = chrono::Utc.timestamp_nanos(1599566400000000000i64);
+
+        //  now() < cast(to_timestamp(...) as Int64) + 50000
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(
+                now_expr()
+                    .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)),
+            )
+            .unwrap()
+            .build()
+            .unwrap();
+
+        // Note that constant folder should be able to run again and fold
+        // this whole expression down to a single constant;
+        // https://github.com/apache/arrow-datafusion/issues/1160
+        let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
+                        \n  TableScan: test projection=None";
+        let actual = get_optimized_plan_formatted(&plan, &time);
+
+        assert_eq!(expected, actual);
+    }
 }
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 1da584b..fdc9a17 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -17,12 +17,18 @@
 
 //! Collection of utility functions that are leveraged by the query optimizer rules
 
+use arrow::array::new_null_array;
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::record_batch::RecordBatch;
+
 use super::optimizer::OptimizerRule;
-use crate::execution::context::ExecutionProps;
+use crate::execution::context::{ExecutionContextState, ExecutionProps};
 use crate::logical_plan::{
-    build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder,
-    Operator, Partitioning, Recursion,
+    build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan,
+    LogicalPlanBuilder, Operator, Partitioning, Recursion, RewriteRecursion,
 };
+use crate::physical_plan::functions::Volatility;
+use crate::physical_plan::planner::DefaultPhysicalPlanner;
 use crate::prelude::lit;
 use crate::scalar::ScalarValue;
 use crate::{
@@ -493,11 +499,196 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
     }
 }
 
+/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
+///
+/// Note it does not handle other algebraic rewrites such as `(a and true)` --> `a`
+///
+/// ```
+/// # use datafusion::prelude::*;
+/// # use datafusion::optimizer::utils::ConstEvaluator;
+/// let mut const_evaluator = ConstEvaluator::new();
+///
+/// // (1 + 2) + a
+/// let expr = (lit(1) + lit(2)) + col("a");
+///
+/// // is rewritten to (3 + a);
+/// let rewritten = expr.rewrite(&mut const_evaluator).unwrap();
+/// assert_eq!(rewritten, lit(3) + col("a"));
+/// ```
+pub struct ConstEvaluator {
+    /// can_evaluate is used during the depth-first-search of the
+    /// Expr tree to track if any siblings (or their descendants) were
+    /// non evaluatable (e.g. had a column reference or volatile
+    /// function)
+    ///
+    /// Specifically, can_evaluate[N] represents the state of
+    /// traversal when we are N levels deep in the tree, one entry for
+    /// this Expr and each of its parents.
+    ///
+    /// After visiting all children, if `can_evaluate.last()` is true, that
+    /// means there were no non evaluatable children (or their
+    /// descendants) so this Expr can be evaluated
+    can_evaluate: Vec<bool>,
+
+    ctx_state: ExecutionContextState,
+    planner: DefaultPhysicalPlanner,
+    input_schema: DFSchema,
+    input_batch: RecordBatch,
+}
+
+impl ExprRewriter for ConstEvaluator {
+    fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
+        // Default to being able to evaluate this node
+        self.can_evaluate.push(true);
+
+        // if this expr is not ok to evaluate, mark entire parent
+        // stack as not ok (as all parents have at least one child or
+        // descendant that is non evaluatable)
+
+        if !Self::can_evaluate(expr) {
+            // walk back up the stack, marking this node and its ancestors as not evaluatable
+            let parent_iter = self.can_evaluate.iter_mut().rev();
+            for p in parent_iter {
+                if !*p {
+                    // optimization: if we find an element on the
+                    // stack already marked, we know all elements further up
+                    // (its ancestors) are also marked
+                    break;
+                }
+                *p = false;
+            }
+        }
+
+        // NB: do not short circuit recursion even if we find a non
+        // evaluatable node (so we can fold other children, args to
+        // functions, etc)
+        Ok(RewriteRecursion::Continue)
+    }
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        if self.can_evaluate.pop().unwrap() {
+            let scalar = self.evaluate_to_scalar(expr)?;
+            Ok(Expr::Literal(scalar))
+        } else {
+            Ok(expr)
+        }
+    }
+}
+
+impl ConstEvaluator {
+    /// Create a new `ConstEvaluator`.
+    pub fn new() -> Self {
+        let planner = DefaultPhysicalPlanner::default();
+        let ctx_state = ExecutionContextState::new();
+        let input_schema = DFSchema::empty();
+
+        // The dummy column name is unused and doesn't matter as only
+        // expressions without column references can be evaluated
+        static DUMMY_COL_NAME: &str = ".";
+        let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
+
+        // Need a single "input" row to produce a single output row
+        let col = new_null_array(&DataType::Null, 1);
+        let input_batch =
+            RecordBatch::try_new(std::sync::Arc::new(schema), vec![col]).unwrap();
+
+        Self {
+            can_evaluate: vec![],
+            ctx_state,
+            planner,
+            input_schema,
+            input_batch,
+        }
+    }
+
+    /// Can a function of the specified volatility be evaluated?
+    fn volatility_ok(volatility: Volatility) -> bool {
+        match volatility {
+            Volatility::Immutable => true,
+            // To evaluate stable functions, need ExecutionProps, see
+            // Simplifier for code that does that.
+            Volatility::Stable => false,
+            Volatility::Volatile => false,
+        }
+    }
+
+    /// Can the expression be evaluated at plan time, (assuming all of
+    /// its children can also be evaluated)?
+    fn can_evaluate(expr: &Expr) -> bool {
+        // check for reasons we can't evaluate this node
+        //
+        // NOTE all expr types are listed here so when new ones are
+        // added they can be checked for their ability to be evaluated
+        // at plan time
+        match expr {
+            // Has no runtime cost, but needed during planning
+            Expr::Alias(..) => false,
+            Expr::AggregateFunction { .. } => false,
+            Expr::AggregateUDF { .. } => false,
+            Expr::ScalarVariable(_) => false,
+            Expr::Column(_) => false,
+            Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
+            Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility),
+            Expr::WindowFunction { .. } => false,
+            Expr::Sort { .. } => false,
+            Expr::Wildcard => false,
+
+            Expr::Literal(_) => true,
+            Expr::BinaryExpr { .. } => true,
+            Expr::Not(_) => true,
+            Expr::IsNotNull(_) => true,
+            Expr::IsNull(_) => true,
+            Expr::Negative(_) => true,
+            Expr::Between { .. } => true,
+            Expr::Case { .. } => true,
+            Expr::Cast { .. } => true,
+            Expr::TryCast { .. } => true,
+            Expr::InList { .. } => true,
+        }
+    }
+
+    /// Internal helper to evaluate an Expr
+    fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
+        if let Expr::Literal(s) = expr {
+            return Ok(s);
+        }
+
+        let phys_expr = self.planner.create_physical_expr(
+            &expr,
+            &self.input_schema,
+            &self.input_batch.schema(),
+            &self.ctx_state,
+        )?;
+        let col_val = phys_expr.evaluate(&self.input_batch)?;
+        match col_val {
+            crate::physical_plan::ColumnarValue::Array(a) => {
+                if a.len() != 1 {
+                    Err(DataFusionError::Execution(format!(
+                        "Could not evaluate the expressison, found a result of length {}",
+                        a.len()
+                    )))
+                } else {
+                    Ok(ScalarValue::try_from_array(&a, 0)?)
+                }
+            }
+            crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::logical_plan::col;
-    use arrow::datatypes::DataType;
+    use crate::{
+        logical_plan::{col, create_udf, lit_timestamp_nano},
+        physical_plan::{
+            functions::{make_scalar_function, BuiltinScalarFunction},
+            udf::ScalarUDF,
+        },
+    };
+    use arrow::{
+        array::{ArrayRef, Int32Array},
+        datatypes::DataType,
+    };
     use std::collections::HashSet;
 
     #[test]
@@ -521,4 +712,200 @@ mod tests {
         assert!(accum.contains(&Column::from_name("a")));
         Ok(())
     }
+
+    /// Expressions built only from literals fold to a single literal.
+    /// Expressions referencing a column are kept, with their constant operands folded.
+    #[test]
+    fn test_const_evaluator() {
+        // true --> true
+        test_evaluate(lit(true), lit(true));
+        // true or true --> true
+        test_evaluate(lit(true).or(lit(true)), lit(true));
+        // true or false --> true
+        test_evaluate(lit(true).or(lit(false)), lit(true));
+
+        // "foo" == "foo" --> true
+        test_evaluate(lit("foo").eq(lit("foo")), lit(true));
+        // "foo" != "foo" --> false
+        test_evaluate(lit("foo").not_eq(lit("foo")), lit(false));
+
+        // c = 1 --> c = 1
+        test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1)));
+        // c = 1 + 2 --> c = 3
+        test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3)));
+        // (foo != foo) OR (c = 1) --> false OR (c = 1)
+        test_evaluate(
+            (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))),
+            lit(false).or(col("c").eq(lit(1))),
+        );
+    }
+
+    /// Built-in scalar functions with constant arguments are folded, including nested calls.
+    /// Calls over columns, and the volatile `random()` and stable `now()`, are left in place.
+    /// Constant sibling sub-expressions of those calls are still folded.
+    #[test]
+    fn test_const_evaluator_scalar_functions() {
+        // concat("foo", "bar") --> "foobar"
+        let expr = Expr::ScalarFunction {
+            args: vec![lit("foo"), lit("bar")],
+            fun: BuiltinScalarFunction::Concat,
+        };
+        test_evaluate(expr, lit("foobar"));
+
+        // ensure arguments are also constant folded
+        // concat("foo", concat("bar", "baz")) --> "foobarbaz"
+        let concat1 = Expr::ScalarFunction {
+            args: vec![lit("bar"), lit("baz")],
+            fun: BuiltinScalarFunction::Concat,
+        };
+        let expr = Expr::ScalarFunction {
+            args: vec![lit("foo"), concat1],
+            fun: BuiltinScalarFunction::Concat,
+        };
+        test_evaluate(expr, lit("foobarbaz"));
+
+        // Check non string arguments
+        // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64)
+        let expr = Expr::ScalarFunction {
+            args: vec![lit("2020-09-08T12:00:00+00:00")],
+            fun: BuiltinScalarFunction::ToTimestamp,
+        };
+        test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64));
+
+        // check that non foldable arguments are not folded
+        // to_timestamp(a) --> to_timestamp(a) [no rewrite possible]
+        let expr = Expr::ScalarFunction {
+            args: vec![col("a")],
+            fun: BuiltinScalarFunction::ToTimestamp,
+        };
+        test_evaluate(expr.clone(), expr);
+
+        // volatile / stable functions should not be evaluated
+        // rand() + (1 + 2) --> rand() + 3
+        let fun = BuiltinScalarFunction::Random;
+        assert_eq!(fun.volatility(), Volatility::Volatile);
+        let rand = Expr::ScalarFunction { args: vec![], fun };
+        let expr = rand.clone() + (lit(1) + lit(2));
+        let expected = rand + lit(3);
+        test_evaluate(expr, expected);
+
+        // parenthesization matters: can't rewrite
+        // (rand() + 1) + 2 --> (rand() + 1) + 2
+        let fun = BuiltinScalarFunction::Random;
+        assert_eq!(fun.volatility(), Volatility::Volatile);
+        let rand = Expr::ScalarFunction { args: vec![], fun };
+        let expr = (rand + lit(1)) + lit(2);
+        test_evaluate(expr.clone(), expr);
+
+        // volatile / stable functions should not be evaluated
+        // now() + (1 + 2) --> now() + 3
+        let fun = BuiltinScalarFunction::Now;
+        assert_eq!(fun.volatility(), Volatility::Stable);
+        let now = Expr::ScalarFunction { args: vec![], fun };
+        let expr = now.clone() + (lit(1) + lit(2));
+        let expected = now + lit(3);
+        test_evaluate(expr, expected);
+    }
+
+    /// An immutable UDF with constant arguments is folded to a literal.
+    /// Stable and volatile UDFs are kept, but their constant arguments are still folded.
+    #[test]
+    fn test_const_evaluator_udfs() {
+        let args = vec![lit(1) + lit(2), lit(30) + lit(40)];
+        let folded_args = vec![lit(3), lit(70)];
+
+        // immutable UDF should get folded
+        // udf_add(1+2, 30+40) --> 73
+        let expr = Expr::ScalarUDF {
+            args: args.clone(),
+            fun: make_udf_add(Volatility::Immutable),
+        };
+        test_evaluate(expr, lit(73));
+
+        // stable UDF should have args folded
+        // udf_add(1+2, 30+40) --> udf_add(3, 70)
+        let fun = make_udf_add(Volatility::Stable);
+        let expr = Expr::ScalarUDF {
+            args: args.clone(),
+            fun: Arc::clone(&fun),
+        };
+        let expected_expr = Expr::ScalarUDF {
+            args: folded_args.clone(),
+            fun: Arc::clone(&fun),
+        };
+        test_evaluate(expr, expected_expr);
+
+        // volatile UDF should have args folded
+        // udf_add(1+2, 30+40) --> udf_add(3, 70)
+        let fun = make_udf_add(Volatility::Volatile);
+        let expr = Expr::ScalarUDF {
+            args,
+            fun: Arc::clone(&fun),
+        };
+        let expected_expr = Expr::ScalarUDF {
+            args: folded_args,
+            fun: Arc::clone(&fun),
+        };
+        test_evaluate(expr, expected_expr);
+    }
+
+    /// Make a UDF named `udf_add` that adds its two `Int32` arguments together,
+    /// with the specified volatility. A null in either argument yields null.
+    fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
+        let input_types = vec![DataType::Int32, DataType::Int32];
+        let return_type = Arc::new(DataType::Int32);
+
+        let fun = |args: &[ArrayRef]| {
+            // 1. downcast the inputs (declared as Int32 in `input_types`)
+            let arg0 = args[0]
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("cast failed");
+            let arg1 = args[1]
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .expect("cast failed");
+
+            // 2. perform the computation, propagating nulls
+            let array = arg0
+                .iter()
+                .zip(arg1.iter())
+                .map(|pair| match pair {
+                    (Some(lhs), Some(rhs)) => Some(lhs + rhs),
+                    // one or both args were Null
+                    _ => None,
+                })
+                .collect::<Int32Array>();
+
+            Ok(Arc::new(array) as ArrayRef)
+        };
+
+        let fun = make_scalar_function(fun);
+        Arc::new(create_udf(
+            "udf_add",
+            input_types,
+            return_type,
+            volatility,
+            fun,
+        ))
+    }
+
+    /// Rewrite `input_expr` with a fresh `ConstEvaluator` and assert that the
+    /// result equals `expected_expr`.
+    ///
+    /// Panics if the rewrite returns an error, or on mismatch. The mismatch
+    /// message prints the input, the expected and the actual expressions.
+    fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
+        let mut const_evaluator = ConstEvaluator::new();
+        // clone so `input_expr` is still available for the failure message
+        let evaluated_expr = input_expr
+            .clone()
+            .rewrite(&mut const_evaluator)
+            .expect("successfully evaluated");
+
+        assert_eq!(
+            evaluated_expr, expected_expr,
+            "Mismatch evaluating {}\n  Expected:{}\n  Got:{}",
+            input_expr, expected_expr, evaluated_expr
+        );
+    }
 }
diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs
index 0c9498a..03e0054 100644
--- a/datafusion/src/test_util.rs
+++ b/datafusion/src/test_util.rs
@@ -88,6 +88,52 @@ macro_rules! assert_batches_sorted_eq {
     };
 }
 
+/// A macro to assert that one string is contained within another with
+/// a nice error message if they are not.
+///
+/// Usage: `assert_contains!(actual, expected)`
+///
+/// It is a macro (rather than a function) so that test failure messages
+/// point at the line of the failing assertion.
+///
+/// Both arguments must be convertible into Strings (`Into<String>`).
+/// The expansion is a single block, so the macro can be used in either
+/// statement or expression position.
+#[macro_export]
+macro_rules! assert_contains {
+    ($ACTUAL: expr, $EXPECTED: expr) => {{
+        let actual_value: String = $ACTUAL.into();
+        let expected_value: String = $EXPECTED.into();
+        assert!(
+            actual_value.contains(&expected_value),
+            "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
+            expected_value,
+            actual_value
+        );
+    }};
+}
+
+/// A macro to assert that one string is NOT contained within another with
+/// a nice error message if it is.
+///
+/// Usage: `assert_not_contains!(actual, unexpected)`
+///
+/// It is a macro (rather than a function) so that test failure messages
+/// point at the line of the failing assertion.
+///
+/// Both arguments must be convertible into Strings (`Into<String>`).
+/// The expansion is a single block, so the macro can be used in either
+/// statement or expression position.
+#[macro_export]
+macro_rules! assert_not_contains {
+    ($ACTUAL: expr, $UNEXPECTED: expr) => {{
+        let actual_value: String = $ACTUAL.into();
+        let unexpected_value: String = $UNEXPECTED.into();
+        assert!(
+            !actual_value.contains(&unexpected_value),
+            "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
+            unexpected_value,
+            actual_value
+        );
+    }};
+}
+
 /// Returns the arrow test data directory, which is by default stored
 /// in a git submodule rooted at `testing/data`.
 ///
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index e484152..f3dba3f 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -34,6 +34,8 @@ use arrow::{
 
 use datafusion::assert_batches_eq;
 use datafusion::assert_batches_sorted_eq;
+use datafusion::assert_contains;
+use datafusion::assert_not_contains;
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::functions::Volatility;
 use datafusion::physical_plan::metrics::MetricValue;
@@ -47,50 +49,6 @@ use datafusion::{
 };
 use datafusion::{execution::context::ExecutionContext, 
physical_plan::displayable};
 
-/// A macro to assert that one string is contained within another with
-/// a nice error message if they are not.
-///
-/// Usage: `assert_contains!(actual, expected)`
-///
-/// Is a macro so test error
-/// messages are on the same line as the failure;
-///
-/// Both arguments must be convertable into Strings (Into<String>)
-macro_rules! assert_contains {
-    ($ACTUAL: expr, $EXPECTED: expr) => {
-        let actual_value: String = $ACTUAL.into();
-        let expected_value: String = $EXPECTED.into();
-        assert!(
-            actual_value.contains(&expected_value),
-            "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
-            expected_value,
-            actual_value
-        );
-    };
-}
-
-/// A macro to assert that one string is NOT contained within another with
-/// a nice error message if they are are.
-///
-/// Usage: `assert_not_contains!(actual, unexpected)`
-///
-/// Is a macro so test error
-/// messages are on the same line as the failure;
-///
-/// Both arguments must be convertable into Strings (Into<String>)
-macro_rules! assert_not_contains {
-    ($ACTUAL: expr, $UNEXPECTED: expr) => {
-        let actual_value: String = $ACTUAL.into();
-        let unexpected_value: String = $UNEXPECTED.into();
-        assert!(
-            !actual_value.contains(&unexpected_value),
-            "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
-            unexpected_value,
-            actual_value
-        );
-    };
-}
-
 #[tokio::test]
 async fn nyc() -> Result<()> {
     // schema for nyxtaxi csv files

Reply via email to