This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 156ebff Generic constant expression evaluation (#1153)
156ebff is described below
commit 156ebff70f96346742c0654ea4af76b9d1036530
Author: Andrew Lamb <[email protected]>
AuthorDate: Wed Oct 27 19:19:47 2021 -0400
Generic constant expression evaluation (#1153)
* Generic constant expression evaluation
* Better list of evaluatable expressions
* Fixup comments
* Use Null type
---
datafusion/src/optimizer/constant_folding.rs | 279 ++++++++++---------
datafusion/src/optimizer/utils.rs | 397 ++++++++++++++++++++++++++-
datafusion/src/test_util.rs | 46 ++++
datafusion/tests/sql.rs | 46 +---
4 files changed, 586 insertions(+), 182 deletions(-)
diff --git a/datafusion/src/optimizer/constant_folding.rs
b/datafusion/src/optimizer/constant_folding.rs
index d67d7d1..74fdc72 100644
--- a/datafusion/src/optimizer/constant_folding.rs
+++ b/datafusion/src/optimizer/constant_folding.rs
@@ -15,12 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-//! Boolean comparison rule rewrites redundant comparison expression involving
boolean literal into
-//! unary expression.
+//! Constant folding and algebraic simplification
use std::sync::Arc;
-use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos;
use arrow::datatypes::DataType;
use crate::error::Result;
@@ -30,11 +28,11 @@ use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::physical_plan::functions::BuiltinScalarFunction;
use crate::scalar::ScalarValue;
-use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS};
-/// Optimizer that simplifies comparison expressions involving boolean
literals.
+/// Simplifies plans by rewriting [`Expr`]`s evaluating constants
+/// and applying algebraic simplifications
///
-/// Recursively go through all expressions and simplify the following cases:
+/// Example transformations that are applied:
/// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean
type
/// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean
type
/// * `true = true` and `false = false` to `true`
@@ -61,14 +59,16 @@ impl OptimizerRule for ConstantFolding {
// projected columns. With just the projected schema, it's not
possible to infer types for
// expressions that references non-projected columns within the same
project plan or its
// children plans.
- let mut rewriter = ConstantRewriter {
+ let mut simplifier = Simplifier {
schemas: plan.all_schemas(),
execution_props,
};
+ let mut const_evaluator = utils::ConstEvaluator::new();
+
match plan {
LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter
{
- predicate: predicate.clone().rewrite(&mut rewriter)?,
+ predicate: predicate.clone().rewrite(&mut simplifier)?,
input: Arc::new(self.optimize(input, execution_props)?),
}),
// Rest: recurse into plan, apply optimization where possible
@@ -96,7 +96,18 @@ impl OptimizerRule for ConstantFolding {
let expr = plan
.expressions()
.into_iter()
- .map(|e| e.rewrite(&mut rewriter))
+ .map(|e| {
+ // TODO iterate until no changes are made
+ // during rewrite (evaluating constants can
+ // enable new simplifications and
+ // simplifications can enable new constant
+ // evaluation)
+ let new_e = e
+ // fold constants and then simplify
+ .rewrite(&mut const_evaluator)?
+ .rewrite(&mut simplifier)?;
+ Ok(new_e)
+ })
.collect::<Result<Vec<_>>>()?;
utils::from_plan(plan, &expr, &new_inputs)
@@ -112,13 +123,17 @@ impl OptimizerRule for ConstantFolding {
}
}
-struct ConstantRewriter<'a> {
+/// Simplifies [`Expr`]s by applying algebraic transformation rules
+///
+/// For example
+/// `true && col` --> `col` where `col` is of boolean type
+struct Simplifier<'a> {
/// input schemas
schemas: Vec<&'a DFSchemaRef>,
execution_props: &'a ExecutionProps,
}
-impl<'a> ConstantRewriter<'a> {
+impl<'a> Simplifier<'a> {
fn is_boolean_type(&self, expr: &Expr) -> bool {
for schema in &self.schemas {
if let Ok(DataType::Boolean) = expr.get_type(schema) {
@@ -130,7 +145,7 @@ impl<'a> ConstantRewriter<'a> {
}
}
-impl<'a> ExprRewriter for ConstantRewriter<'a> {
+impl<'a> ExprRewriter for Simplifier<'a> {
/// rewrite the expression simplifying any constant expressions
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
let new_expr = match expr {
@@ -205,14 +220,15 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
},
_ => Expr::BinaryExpr { left, op, right },
},
+ // Not(Not(expr)) --> expr
Expr::Not(inner) => {
- // Not(Not(expr)) --> expr
if let Expr::Not(negated_inner) = *inner {
*negated_inner
} else {
Expr::Not(inner)
}
}
+ // convert now() --> the time in `ExecutionProps`
Expr::ScalarFunction {
fun: BuiltinScalarFunction::Now,
..
@@ -221,56 +237,8 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
.query_execution_start_time
.timestamp_nanos(),
))),
- Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ToTimestamp,
- args,
- } => {
- if !args.is_empty() {
- match &args[0] {
- Expr::Literal(ScalarValue::Utf8(Some(val))) => {
- match string_to_timestamp_nanos(val) {
- Ok(timestamp) => Expr::Literal(
-
ScalarValue::TimestampNanosecond(Some(timestamp)),
- ),
- _ => Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ToTimestamp,
- args,
- },
- }
- }
- _ => Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ToTimestamp,
- args,
- },
- }
- } else {
- Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ToTimestamp,
- args,
- }
- }
- }
- Expr::Cast {
- expr: inner,
- data_type,
- } => match inner.as_ref() {
- Expr::Literal(val) => {
- let scalar_array = val.to_array();
- let cast_array = kernels::cast::cast_with_options(
- &scalar_array,
- &data_type,
- &DEFAULT_CAST_OPTIONS,
- )?;
- let cast_scalar = ScalarValue::try_from_array(&cast_array,
0)?;
- Expr::Literal(cast_scalar)
- }
- _ => Expr::Cast {
- expr: inner,
- data_type,
- },
- },
expr => {
- // no rewrite possible
+ // no additional rewrites possible
expr
}
};
@@ -281,12 +249,13 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::logical_plan::{
- col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder,
+ use crate::{
+ assert_contains,
+ logical_plan::{col, lit, max, min, DFField, DFSchema,
LogicalPlanBuilder},
};
use arrow::datatypes::*;
- use chrono::{DateTime, Utc};
+ use chrono::{DateTime, TimeZone, Utc};
fn test_table_scan() -> Result<LogicalPlan> {
let schema = Schema::new(vec![
@@ -311,7 +280,7 @@ mod tests {
#[test]
fn optimize_expr_not_not() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -327,7 +296,7 @@ mod tests {
#[test]
fn optimize_expr_null_comparison() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -363,7 +332,7 @@ mod tests {
#[test]
fn optimize_expr_eq() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -394,7 +363,7 @@ mod tests {
#[test]
fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -434,7 +403,7 @@ mod tests {
#[test]
fn optimize_expr_not_eq() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -470,7 +439,7 @@ mod tests {
#[test]
fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -506,7 +475,7 @@ mod tests {
#[test]
fn optimize_expr_case_when_then_else() -> Result<()> {
let schema = expr_test_schema();
- let mut rewriter = ConstantRewriter {
+ let mut rewriter = Simplifier {
schemas: vec![&schema],
execution_props: &ExecutionProps::new(),
};
@@ -669,6 +638,20 @@ mod tests {
Ok(())
}
+ // expect optimizing will result in an error, returning the error string
+ fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>)
-> String {
+ let rule = ConstantFolding::new();
+ let execution_props = ExecutionProps {
+ query_execution_start_time: *date_time,
+ };
+
+ let err = rule
+ .optimize(plan, &execution_props)
+ .expect_err("expected optimization to fail");
+
+ err.to_string()
+ }
+
fn get_optimized_plan_formatted(
plan: &LogicalPlan,
date_time: &DateTime<Utc>,
@@ -684,15 +667,19 @@ mod tests {
return format!("{:?}", optimized_plan);
}
+ /// Create a to_timestamp expr
+ fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
+ Expr::ScalarFunction {
+ args: vec![lit(arg.into())],
+ fun: BuiltinScalarFunction::ToTimestamp,
+ }
+ }
+
#[test]
- fn to_timestamp_expr() {
+ fn to_timestamp_expr_folded() {
let table_scan = test_table_scan().unwrap();
- let proj = vec![Expr::ScalarFunction {
- args: vec![Expr::Literal(ScalarValue::Utf8(Some(
- "2020-09-08T12:00:00+00:00".to_string(),
- )))],
- fun: BuiltinScalarFunction::ToTimestamp,
- }];
+ let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")];
+
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)
.unwrap()
@@ -702,55 +689,30 @@ mod tests {
let expected = "Projection: TimestampNanosecond(1599566400000000000)\
\n TableScan: test projection=None"
.to_string();
- let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
+ let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
}
#[test]
fn to_timestamp_expr_wrong_arg() {
let table_scan = test_table_scan().unwrap();
- let proj = vec![Expr::ScalarFunction {
- args: vec![Expr::Literal(ScalarValue::Utf8(Some(
- "I'M NOT A TIMESTAMP".to_string(),
- )))],
- fun: BuiltinScalarFunction::ToTimestamp,
- }];
- let plan = LogicalPlanBuilder::from(table_scan)
- .project(proj)
- .unwrap()
- .build()
- .unwrap();
-
- let expected = "Projection: totimestamp(Utf8(\"I\'M NOT A
TIMESTAMP\"))\
- \n TableScan: test projection=None";
- let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
- assert_eq!(expected, actual);
- }
-
- #[test]
- fn to_timestamp_expr_no_arg() {
- let table_scan = test_table_scan().unwrap();
- let proj = vec![Expr::ScalarFunction {
- args: vec![],
- fun: BuiltinScalarFunction::ToTimestamp,
- }];
+ let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")];
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)
.unwrap()
.build()
.unwrap();
- let expected = "Projection: totimestamp()\
- \n TableScan: test projection=None";
- let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
- assert_eq!(expected, actual);
+ let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp";
+ let actual = get_optimized_plan_err(&plan, &Utc::now());
+ assert_contains!(actual, expected);
}
#[test]
fn cast_expr() {
let table_scan = test_table_scan().unwrap();
let proj = vec![Expr::Cast {
- expr:
Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))),
+ expr: Box::new(lit("0")),
data_type: DataType::Int32,
}];
let plan = LogicalPlanBuilder::from(table_scan)
@@ -761,7 +723,7 @@ mod tests {
let expected = "Projection: Int32(0)\
\n TableScan: test projection=None";
- let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
+ let actual = get_optimized_plan_formatted(&plan, &Utc::now());
assert_eq!(expected, actual);
}
@@ -769,7 +731,7 @@ mod tests {
fn cast_expr_wrong_arg() {
let table_scan = test_table_scan().unwrap();
let proj = vec![Expr::Cast {
- expr:
Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))),
+ expr: Box::new(lit("")),
data_type: DataType::Int32,
}];
let plan = LogicalPlanBuilder::from(table_scan)
@@ -778,20 +740,24 @@ mod tests {
.build()
.unwrap();
- let expected = "Projection: Int32(NULL)\
- \n TableScan: test projection=None";
- let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now());
- assert_eq!(expected, actual);
+ let expected =
+ "Cannot cast string '' to value of
arrow::datatypes::types::Int32Type type";
+ let actual = get_optimized_plan_err(&plan, &Utc::now());
+ assert_contains!(actual, expected);
+ }
+
+ fn now_expr() -> Expr {
+ Expr::ScalarFunction {
+ args: vec![],
+ fun: BuiltinScalarFunction::Now,
+ }
}
#[test]
fn single_now_expr() {
let table_scan = test_table_scan().unwrap();
- let proj = vec![Expr::ScalarFunction {
- args: vec![],
- fun: BuiltinScalarFunction::Now,
- }];
- let time = chrono::Utc::now();
+ let proj = vec![now_expr()];
+ let time = Utc::now();
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)
.unwrap()
@@ -811,19 +777,10 @@ mod tests {
#[test]
fn multiple_now_expr() {
let table_scan = test_table_scan().unwrap();
- let time = chrono::Utc::now();
+ let time = Utc::now();
let proj = vec![
- Expr::ScalarFunction {
- args: vec![],
- fun: BuiltinScalarFunction::Now,
- },
- Expr::Alias(
- Box::new(Expr::ScalarFunction {
- args: vec![],
- fun: BuiltinScalarFunction::Now,
- }),
- "t2".to_string(),
- ),
+ now_expr(),
+ Expr::Alias(Box::new(now_expr()), "t2".to_string()),
];
let plan = LogicalPlanBuilder::from(table_scan)
.project(proj)
@@ -831,6 +788,7 @@ mod tests {
.build()
.unwrap();
+ // expect the same timestamp appears in both exprs
let actual = get_optimized_plan_formatted(&plan, &time);
let expected = format!(
"Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS
t2\
@@ -841,4 +799,59 @@ mod tests {
assert_eq!(actual, expected);
}
+
+ #[test]
+ fn simplify_and_eval() {
+ // demonstrate a case where the evaluation needs to run prior
+ // to the simplifier for it to work
+ let table_scan = test_table_scan().unwrap();
+ let time = Utc::now();
+ // (true or false) != col --> !col
+ let proj = vec![lit(true).or(lit(false)).not_eq(col("a"))];
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .project(proj)
+ .unwrap()
+ .build()
+ .unwrap();
+
+ let actual = get_optimized_plan_formatted(&plan, &time);
+ let expected = "Projection: NOT #test.a\
+ \n TableScan: test projection=None";
+
+ assert_eq!(actual, expected);
+ }
+
+ fn cast_to_int64_expr(expr: Expr) -> Expr {
+ Expr::Cast {
+ expr: expr.into(),
+ data_type: DataType::Int64,
+ }
+ }
+
+ #[test]
+ fn now_less_than_timestamp() {
+ let table_scan = test_table_scan().unwrap();
+
+ let ts_string = "2020-09-08T12:05:00+00:00";
+ let time = chrono::Utc.timestamp_nanos(1599566400000000000i64);
+
+ // now() < cast(to_timestamp(...) as int) + 50000
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .filter(
+ now_expr()
+ .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) +
lit(50000)),
+ )
+ .unwrap()
+ .build()
+ .unwrap();
+
+ // Note that constant folder should be able to run again and fold
+ // this whole expression down to a single constant;
+ // https://github.com/apache/arrow-datafusion/issues/1160
+ let expected = "Filter: TimestampNanosecond(1599566400000000000) <
CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\
+ \n TableScan: test projection=None";
+ let actual = get_optimized_plan_formatted(&plan, &time);
+
+ assert_eq!(expected, actual);
+ }
}
diff --git a/datafusion/src/optimizer/utils.rs
b/datafusion/src/optimizer/utils.rs
index 1da584b..fdc9a17 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -17,12 +17,18 @@
//! Collection of utility functions that are leveraged by the query optimizer
rules
+use arrow::array::new_null_array;
+use arrow::datatypes::{DataType, Field, Schema};
+use arrow::record_batch::RecordBatch;
+
use super::optimizer::OptimizerRule;
-use crate::execution::context::ExecutionProps;
+use crate::execution::context::{ExecutionContextState, ExecutionProps};
use crate::logical_plan::{
- build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan,
LogicalPlanBuilder,
- Operator, Partitioning, Recursion,
+ build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter,
LogicalPlan,
+ LogicalPlanBuilder, Operator, Partitioning, Recursion, RewriteRecursion,
};
+use crate::physical_plan::functions::Volatility;
+use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::prelude::lit;
use crate::scalar::ScalarValue;
use crate::{
@@ -493,11 +499,196 @@ pub fn rewrite_expression(expr: &Expr, expressions:
&[Expr]) -> Result<Expr> {
}
}
+/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time.
+///
+/// Note it does not handle other algebraic rewrites such as `(a and false)`
--> `false`
+///
+/// ```
+/// # use datafusion::prelude::*;
+/// # use datafusion::optimizer::utils::ConstEvaluator;
+/// let mut const_evaluator = ConstEvaluator::new();
+///
+/// // (1 + 2) + a
+/// let expr = (lit(1) + lit(2)) + col("a");
+///
+/// // is rewritten to (3 + a);
+/// let rewritten = expr.rewrite(&mut const_evaluator).unwrap();
+/// assert_eq!(rewritten, lit(3) + col("a"));
+/// ```
+pub struct ConstEvaluator {
+ /// can_evaluate is used during the depth-first-search of the
+ /// Expr tree to track if any siblings (or their descendants) were
+ /// non evaluatable (e.g. had a column reference or volatile
+ /// function)
+ ///
+ /// Specifically, can_evaluate[N] represents the state of
+ /// traversal when we are N levels deep in the tree, one entry for
+ /// this Expr and each of its parents.
+ ///
+ /// After visiting all siblings if can_evaluate.top() is true, that
+ /// means there were no non evaluatable siblings (or their
+ /// descendants) so this Expr can be evaluated
+ can_evaluate: Vec<bool>,
+
+ ctx_state: ExecutionContextState,
+ planner: DefaultPhysicalPlanner,
+ input_schema: DFSchema,
+ input_batch: RecordBatch,
+}
+
+impl ExprRewriter for ConstEvaluator {
+ fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
+ // Default to being able to evaluate this node
+ self.can_evaluate.push(true);
+
+ // if this expr is not ok to evaluate, mark entire parent
+ // stack as not ok (as all parents have at least one child or
+ // descendant that is non evaluatable)
+
+ if !Self::can_evaluate(expr) {
+ // walk back up stack, marking first parent that is not mutable
+ let parent_iter = self.can_evaluate.iter_mut().rev();
+ for p in parent_iter {
+ if !*p {
+ // optimization: if we find an element on the
+ // stack already marked, know all elements above are also
marked
+ break;
+ }
+ *p = false;
+ }
+ }
+
+ // NB: do not short circuit recursion even if we find a non
+ // evaluatable node (so we can fold other children, args to
+ // functions, etc)
+ Ok(RewriteRecursion::Continue)
+ }
+
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if self.can_evaluate.pop().unwrap() {
+ let scalar = self.evaluate_to_scalar(expr)?;
+ Ok(Expr::Literal(scalar))
+ } else {
+ Ok(expr)
+ }
+ }
+}
+
+impl ConstEvaluator {
+ /// Create a new `ConstEvaluator`.
+ pub fn new() -> Self {
+ let planner = DefaultPhysicalPlanner::default();
+ let ctx_state = ExecutionContextState::new();
+ let input_schema = DFSchema::empty();
+
+ // The dummy column name is unused and doesn't matter as only
+ // expressions without column references can be evaluated
+ static DUMMY_COL_NAME: &str = ".";
+ let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME,
DataType::Null, true)]);
+
+ // Need a single "input" row to produce a single output row
+ let col = new_null_array(&DataType::Null, 1);
+ let input_batch =
+ RecordBatch::try_new(std::sync::Arc::new(schema),
vec![col]).unwrap();
+
+ Self {
+ can_evaluate: vec![],
+ ctx_state,
+ planner,
+ input_schema,
+ input_batch,
+ }
+ }
+
+ /// Can a function of the specified volatility be evaluated?
+ fn volatility_ok(volatility: Volatility) -> bool {
+ match volatility {
+ Volatility::Immutable => true,
+ // To evaluate stable functions, need ExecutionProps, see
+ // Simplifier for code that does that.
+ Volatility::Stable => false,
+ Volatility::Volatile => false,
+ }
+ }
+
+ /// Can the expression be evaluated at plan time, (assuming all of
+ /// its children can also be evaluated)?
+ fn can_evaluate(expr: &Expr) -> bool {
+ // check for reasons we can't evaluate this node
+ //
+ // NOTE all expr types are listed here so when new ones are
+ // added they can be checked for their ability to be evaluated
+ // at plan time
+ match expr {
+ // Has no runtime cost, but needed during planning
+ Expr::Alias(..) => false,
+ Expr::AggregateFunction { .. } => false,
+ Expr::AggregateUDF { .. } => false,
+ Expr::ScalarVariable(_) => false,
+ Expr::Column(_) => false,
+ Expr::ScalarFunction { fun, .. } =>
Self::volatility_ok(fun.volatility()),
+ Expr::ScalarUDF { fun, .. } =>
Self::volatility_ok(fun.signature.volatility),
+ Expr::WindowFunction { .. } => false,
+ Expr::Sort { .. } => false,
+ Expr::Wildcard => false,
+
+ Expr::Literal(_) => true,
+ Expr::BinaryExpr { .. } => true,
+ Expr::Not(_) => true,
+ Expr::IsNotNull(_) => true,
+ Expr::IsNull(_) => true,
+ Expr::Negative(_) => true,
+ Expr::Between { .. } => true,
+ Expr::Case { .. } => true,
+ Expr::Cast { .. } => true,
+ Expr::TryCast { .. } => true,
+ Expr::InList { .. } => true,
+ }
+ }
+
+ /// Internal helper to evaluates an Expr
+ fn evaluate_to_scalar(&self, expr: Expr) -> Result<ScalarValue> {
+ if let Expr::Literal(s) = expr {
+ return Ok(s);
+ }
+
+ let phys_expr = self.planner.create_physical_expr(
+ &expr,
+ &self.input_schema,
+ &self.input_batch.schema(),
+ &self.ctx_state,
+ )?;
+ let col_val = phys_expr.evaluate(&self.input_batch)?;
+ match col_val {
+ crate::physical_plan::ColumnarValue::Array(a) => {
+ if a.len() != 1 {
+ Err(DataFusionError::Execution(format!(
+ "Could not evaluate the expressison, found a result of
length {}",
+ a.len()
+ )))
+ } else {
+ Ok(ScalarValue::try_from_array(&a, 0)?)
+ }
+ }
+ crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
- use crate::logical_plan::col;
- use arrow::datatypes::DataType;
+ use crate::{
+ logical_plan::{col, create_udf, lit_timestamp_nano},
+ physical_plan::{
+ functions::{make_scalar_function, BuiltinScalarFunction},
+ udf::ScalarUDF,
+ },
+ };
+ use arrow::{
+ array::{ArrayRef, Int32Array},
+ datatypes::DataType,
+ };
use std::collections::HashSet;
#[test]
@@ -521,4 +712,200 @@ mod tests {
assert!(accum.contains(&Column::from_name("a")));
Ok(())
}
+
+ #[test]
+ fn test_const_evaluator() {
+ // true --> true
+ test_evaluate(lit(true), lit(true));
+ // true or true --> true
+ test_evaluate(lit(true).or(lit(true)), lit(true));
+ // true or false --> true
+ test_evaluate(lit(true).or(lit(false)), lit(true));
+
+ // "foo" == "foo" --> true
+ test_evaluate(lit("foo").eq(lit("foo")), lit(true));
+ // "foo" != "foo" --> false
+ test_evaluate(lit("foo").not_eq(lit("foo")), lit(false));
+
+ // c = 1 --> c = 1
+ test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1)));
+ // c = 1 + 2 --> c + 3
+ test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3)));
+ // (foo != foo) OR (c = 1) --> false OR (c = 1)
+ test_evaluate(
+ (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))),
+ lit(false).or(col("c").eq(lit(1))),
+ );
+ }
+
+ #[test]
+ fn test_const_evaluator_scalar_functions() {
+ // concat("foo", "bar") --> "foobar"
+ let expr = Expr::ScalarFunction {
+ args: vec![lit("foo"), lit("bar")],
+ fun: BuiltinScalarFunction::Concat,
+ };
+ test_evaluate(expr, lit("foobar"));
+
+ // ensure arguments are also constant folded
+ // concat("foo", concat("bar", "baz")) --> "foobarbaz"
+ let concat1 = Expr::ScalarFunction {
+ args: vec![lit("bar"), lit("baz")],
+ fun: BuiltinScalarFunction::Concat,
+ };
+ let expr = Expr::ScalarFunction {
+ args: vec![lit("foo"), concat1],
+ fun: BuiltinScalarFunction::Concat,
+ };
+ test_evaluate(expr, lit("foobarbaz"));
+
+ // Check non string arguments
+ // to_timestamp("2020-09-08T12:00:00+00:00") -->
timestamp(1599566400000000000i64)
+ let expr = Expr::ScalarFunction {
+ args: vec![lit("2020-09-08T12:00:00+00:00")],
+ fun: BuiltinScalarFunction::ToTimestamp,
+ };
+ test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64));
+
+ // check that exprs with non foldable arguments are not folded
+ // to_timestamp(a) --> to_timestamp(a) [no rewrite possible]
+ let expr = Expr::ScalarFunction {
+ args: vec![col("a")],
+ fun: BuiltinScalarFunction::ToTimestamp,
+ };
+ test_evaluate(expr.clone(), expr);
+
+ // check that exprs with non foldable arguments are not folded
+ // to_timestamp(a) --> to_timestamp(a) [no rewrite possible]
+ let expr = Expr::ScalarFunction {
+ args: vec![col("a")],
+ fun: BuiltinScalarFunction::ToTimestamp,
+ };
+ test_evaluate(expr.clone(), expr);
+
+ // volatile / stable functions should not be evaluated
+ // rand() + (1 + 2) --> rand() + 3
+ let fun = BuiltinScalarFunction::Random;
+ assert_eq!(fun.volatility(), Volatility::Volatile);
+ let rand = Expr::ScalarFunction { args: vec![], fun };
+ let expr = rand.clone() + (lit(1) + lit(2));
+ let expected = rand + lit(3);
+ test_evaluate(expr, expected);
+
+ // parenthesization matters: can't rewrite
+ // (rand() + 1) + 2 --> (rand() + 1) + 2)
+ let fun = BuiltinScalarFunction::Random;
+ assert_eq!(fun.volatility(), Volatility::Volatile);
+ let rand = Expr::ScalarFunction { args: vec![], fun };
+ let expr = (rand + lit(1)) + lit(2);
+ test_evaluate(expr.clone(), expr);
+
+ // volatile / stable functions should not be evaluated
+ // now() + (1 + 2) --> now() + 3
+ let fun = BuiltinScalarFunction::Now;
+ assert_eq!(fun.volatility(), Volatility::Stable);
+ let now = Expr::ScalarFunction { args: vec![], fun };
+ let expr = now.clone() + (lit(1) + lit(2));
+ let expected = now + lit(3);
+ test_evaluate(expr, expected);
+ }
+
+ #[test]
+ fn test_const_evaluator_udfs() {
+ let args = vec![lit(1) + lit(2), lit(30) + lit(40)];
+ let folded_args = vec![lit(3), lit(70)];
+
+ // immutable UDF should get folded
+ // udf_add(1+2, 30+40) --> 73
+ let expr = Expr::ScalarUDF {
+ args: args.clone(),
+ fun: make_udf_add(Volatility::Immutable),
+ };
+ test_evaluate(expr, lit(73));
+
+ // stable UDF should have args folded
+ // udf_add(1+2, 30+40) --> udf_add(3, 70)
+ let fun = make_udf_add(Volatility::Stable);
+ let expr = Expr::ScalarUDF {
+ args: args.clone(),
+ fun: Arc::clone(&fun),
+ };
+ let expected_expr = Expr::ScalarUDF {
+ args: folded_args.clone(),
+ fun: Arc::clone(&fun),
+ };
+ test_evaluate(expr, expected_expr);
+
+ // volatile UDF should have args folded
+ // udf_add(1+2, 30+40) --> udf_add(3, 70)
+ let fun = make_udf_add(Volatility::Volatile);
+ let expr = Expr::ScalarUDF {
+ args,
+ fun: Arc::clone(&fun),
+ };
+ let expected_expr = Expr::ScalarUDF {
+ args: folded_args,
+ fun: Arc::clone(&fun),
+ };
+ test_evaluate(expr, expected_expr);
+ }
+
+ // Make a UDF that adds its two values together, with the specified
volatility
+ fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
+ let input_types = vec![DataType::Int32, DataType::Int32];
+ let return_type = Arc::new(DataType::Int32);
+
+ let fun = |args: &[ArrayRef]| {
+ let arg0 = &args[0]
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .expect("cast failed");
+ let arg1 = &args[1]
+ .as_any()
+ .downcast_ref::<Int32Array>()
+ .expect("cast failed");
+
+ // 2. perform the computation
+ let array = arg0
+ .iter()
+ .zip(arg1.iter())
+ .map(|args| {
+ if let (Some(arg0), Some(arg1)) = args {
+ Some(arg0 + arg1)
+ } else {
+ // one or both args were Null
+ None
+ }
+ })
+ .collect::<Int32Array>();
+
+ Ok(Arc::new(array) as ArrayRef)
+ };
+
+ let fun = make_scalar_function(fun);
+ Arc::new(create_udf(
+ "udf_add",
+ input_types,
+ return_type,
+ volatility,
+ fun,
+ ))
+ }
+
+ // udfs
+ // validate that even a volatile function's arguments will be evaluated
+
+ fn test_evaluate(input_expr: Expr, expected_expr: Expr) {
+ let mut const_evaluator = ConstEvaluator::new();
+ let evaluated_expr = input_expr
+ .clone()
+ .rewrite(&mut const_evaluator)
+ .expect("successfully evaluated");
+
+ assert_eq!(
+ evaluated_expr, expected_expr,
+ "Mismatch evaluating {}\n Expected:{}\n Got:{}",
+ input_expr, expected_expr, evaluated_expr
+ );
+ }
}
diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs
index 0c9498a..03e0054 100644
--- a/datafusion/src/test_util.rs
+++ b/datafusion/src/test_util.rs
@@ -88,6 +88,52 @@ macro_rules! assert_batches_sorted_eq {
};
}
+/// A macro to assert that one string is contained within another with
+/// a nice error message if they are not.
+///
+/// Usage: `assert_contains!(actual, expected)`
+///
+/// Is a macro so test error
+/// messages are on the same line as the failure;
+///
+/// Both arguments must be convertible into Strings (Into<String>)
+#[macro_export]
+macro_rules! assert_contains {
+ ($ACTUAL: expr, $EXPECTED: expr) => {
+ let actual_value: String = $ACTUAL.into();
+ let expected_value: String = $EXPECTED.into();
+ assert!(
+ actual_value.contains(&expected_value),
+ "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
+ expected_value,
+ actual_value
+ );
+ };
+}
+
+/// A macro to assert that one string is NOT contained within another with
+/// a nice error message if they are.
+///
+/// Usage: `assert_not_contains!(actual, unexpected)`
+///
+/// Is a macro so test error
+/// messages are on the same line as the failure;
+///
+/// Both arguments must be convertible into Strings (Into<String>)
+#[macro_export]
+macro_rules! assert_not_contains {
+ ($ACTUAL: expr, $UNEXPECTED: expr) => {
+ let actual_value: String = $ACTUAL.into();
+ let unexpected_value: String = $UNEXPECTED.into();
+ assert!(
+ !actual_value.contains(&unexpected_value),
+ "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
+ unexpected_value,
+ actual_value
+ );
+ };
+}
+
/// Returns the arrow test data directory, which is by default stored
/// in a git submodule rooted at `testing/data`.
///
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index e484152..f3dba3f 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -34,6 +34,8 @@ use arrow::{
use datafusion::assert_batches_eq;
use datafusion::assert_batches_sorted_eq;
+use datafusion::assert_contains;
+use datafusion::assert_not_contains;
use datafusion::logical_plan::LogicalPlan;
use datafusion::physical_plan::functions::Volatility;
use datafusion::physical_plan::metrics::MetricValue;
@@ -47,50 +49,6 @@ use datafusion::{
};
use datafusion::{execution::context::ExecutionContext,
physical_plan::displayable};
-/// A macro to assert that one string is contained within another with
-/// a nice error message if they are not.
-///
-/// Usage: `assert_contains!(actual, expected)`
-///
-/// Is a macro so test error
-/// messages are on the same line as the failure;
-///
-/// Both arguments must be convertable into Strings (Into<String>)
-macro_rules! assert_contains {
- ($ACTUAL: expr, $EXPECTED: expr) => {
- let actual_value: String = $ACTUAL.into();
- let expected_value: String = $EXPECTED.into();
- assert!(
- actual_value.contains(&expected_value),
- "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
- expected_value,
- actual_value
- );
- };
-}
-
-/// A macro to assert that one string is NOT contained within another with
-/// a nice error message if they are are.
-///
-/// Usage: `assert_not_contains!(actual, unexpected)`
-///
-/// Is a macro so test error
-/// messages are on the same line as the failure;
-///
-/// Both arguments must be convertable into Strings (Into<String>)
-macro_rules! assert_not_contains {
- ($ACTUAL: expr, $UNEXPECTED: expr) => {
- let actual_value: String = $ACTUAL.into();
- let unexpected_value: String = $UNEXPECTED.into();
- assert!(
- !actual_value.contains(&unexpected_value),
- "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
- unexpected_value,
- actual_value
- );
- };
-}
-
#[tokio::test]
async fn nyc() -> Result<()> {
// schema for nyxtaxi csv files