This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 847e78a Fix bugs with nullability during rewrites: Combine `simplify`
and `Simplifier` (#1401)
847e78a is described below
commit 847e78a675703c24933af5d6a429c2576bc14e9d
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jan 8 05:27:43 2022 -0500
Fix bugs with nullability during rewrites: Combine `simplify` and
`Simplifier` (#1401)
* Combine simplify and Simplifier
* Make nullable more functional
---
datafusion/src/optimizer/simplify_expressions.rs | 924 +++++++++++------------
1 file changed, 439 insertions(+), 485 deletions(-)
diff --git a/datafusion/src/optimizer/simplify_expressions.rs
b/datafusion/src/optimizer/simplify_expressions.rs
index ff2c05c..7040d34 100644
--- a/datafusion/src/optimizer/simplify_expressions.rs
+++ b/datafusion/src/optimizer/simplify_expressions.rs
@@ -45,33 +45,18 @@ use crate::{error::Result, logical_plan::Operator};
///
pub struct SimplifyExpressions {}
-fn expr_contains(expr: &Expr, needle: &Expr) -> bool {
+/// returns true if `needle` is found in a chain of search_op
+/// expressions. Such as: (A AND B) AND C
+fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
match expr {
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } => expr_contains(left, needle) || expr_contains(right, needle),
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } => expr_contains(left, needle) || expr_contains(right, needle),
+ Expr::BinaryExpr { left, op, right } if *op == search_op => {
+ expr_contains(left, needle, search_op)
+ || expr_contains(right, needle, search_op)
+ }
_ => expr == needle,
}
}
-fn as_binary_expr(expr: &Expr) -> Option<&Expr> {
- match expr {
- Expr::BinaryExpr { .. } => Some(expr),
- _ => None,
- }
-}
-
-fn operator_is_boolean(op: Operator) -> bool {
- op == Operator::And || op == Operator::Or
-}
-
fn is_one(s: &Expr) -> bool {
match s {
Expr::Literal(ScalarValue::Int8(Some(1)))
@@ -95,6 +80,22 @@ fn is_true(expr: &Expr) -> bool {
}
}
+/// returns true if expr is a
+/// `Expr::Literal(ScalarValue::Boolean(v))`, false otherwise
+fn is_bool_lit(expr: &Expr) -> bool {
+ matches!(expr, Expr::Literal(ScalarValue::Boolean(_)))
+}
+
+/// Return a literal NULL value
+fn lit_null() -> Expr {
+ Expr::Literal(ScalarValue::Boolean(None))
+}
+
+/// returns true if expr is a `Not(_)`, false otherwise
+fn is_not(expr: &Expr) -> bool {
+ matches!(expr, Expr::Not(_))
+}
+
fn is_null(expr: &Expr) -> bool {
match expr {
Expr::Literal(v) => v.is_null(),
@@ -109,160 +110,27 @@ fn is_false(expr: &Expr) -> bool {
}
}
-fn simplify(expr: &Expr) -> Expr {
- match expr {
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if is_true(left) || is_true(right) => lit(true),
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if is_false(left) => simplify(right),
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if is_false(right) => simplify(left),
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if left == right => simplify(left),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if is_false(left) || is_false(right) => lit(false),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if is_true(right) => simplify(left),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if is_true(left) => simplify(right),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if left == right => simplify(right),
- Expr::BinaryExpr {
- left,
- op: Operator::Multiply,
- right,
- } if is_one(left) => simplify(right),
- Expr::BinaryExpr {
- left,
- op: Operator::Multiply,
- right,
- } if is_one(right) => simplify(left),
- Expr::BinaryExpr {
- left,
- op: Operator::Divide,
- right,
- } if is_one(right) => simplify(left),
- Expr::BinaryExpr {
- left,
- op: Operator::Divide,
- right,
- } if left == right && is_null(left) => *left.clone(),
- Expr::BinaryExpr {
- left,
- op: Operator::Divide,
- right,
- } if left == right => lit(1),
+/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
+fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
+ match haystack {
Expr::BinaryExpr { left, op, right }
- if left == right && operator_is_boolean(*op) =>
+ if op == &target_op
+ && (needle == left.as_ref() || needle == right.as_ref()) =>
{
- simplify(left)
+ true
}
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if expr_contains(left, right) => as_binary_expr(left)
- .map(|x| match x {
- Expr::BinaryExpr {
- left: _,
- op: Operator::Or,
- right: _,
- } => simplify(&x.clone()),
- Expr::BinaryExpr {
- left: _,
- op: Operator::And,
- right: _,
- } => simplify(&*right.clone()),
- _ => expr.clone(),
- })
- .unwrap_or_else(|| expr.clone()),
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- } if expr_contains(right, left) => as_binary_expr(right)
- .map(|x| match x {
- Expr::BinaryExpr {
- left: _,
- op: Operator::Or,
- right: _,
- } => simplify(&*right.clone()),
- Expr::BinaryExpr {
- left: _,
- op: Operator::And,
- right: _,
- } => simplify(&*left.clone()),
- _ => expr.clone(),
- })
- .unwrap_or_else(|| expr.clone()),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if expr_contains(left, right) => as_binary_expr(left)
- .map(|x| match x {
- Expr::BinaryExpr {
- left: _,
- op: Operator::Or,
- right: _,
- } => simplify(&*right.clone()),
- Expr::BinaryExpr {
- left: _,
- op: Operator::And,
- right: _,
- } => simplify(&x.clone()),
- _ => expr.clone(),
- })
- .unwrap_or_else(|| expr.clone()),
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- } if expr_contains(right, left) => as_binary_expr(right)
- .map(|x| match x {
- Expr::BinaryExpr {
- left: _,
- op: Operator::Or,
- right: _,
- } => simplify(&*left.clone()),
- Expr::BinaryExpr {
- left: _,
- op: Operator::And,
- right: _,
- } => simplify(&x.clone()),
- _ => expr.clone(),
- })
- .unwrap_or_else(|| expr.clone()),
- Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
- left: Box::new(simplify(left)),
- op: *op,
- right: Box::new(simplify(right)),
- },
- _ => expr.clone(),
+ _ => false,
+ }
+}
+
+/// returns the boolean value `v` contained in `expr` when `expr` is
+/// `Expr::Literal(ScalarValue::Boolean(v))` (`None` for a NULL literal).
+///
+/// panics if expr is not a literal boolean
+fn as_bool_lit(expr: Expr) -> Option<bool> {
+ match expr {
+ Expr::Literal(ScalarValue::Boolean(v)) => v,
+ _ => panic!("Expected boolean literal, got {:?}", expr),
}
}
@@ -281,11 +149,9 @@ impl OptimizerRule for SimplifyExpressions {
// projected columns. With just the projected schema, it's not
possible to infer types for
// expressions that references non-projected columns within the same
project plan or its
// children plans.
- let mut simplifier =
- super::simplify_expressions::Simplifier::new(plan.all_schemas());
+ let mut simplifier = Simplifier::new(plan.all_schemas());
- let mut const_evaluator =
- super::simplify_expressions::ConstEvaluator::new(execution_props);
+ let mut const_evaluator = ConstEvaluator::new(execution_props);
let new_inputs = plan
.inputs()
@@ -301,9 +167,6 @@ impl OptimizerRule for SimplifyExpressions {
// Constant folding should not change expression name.
let name = &e.name(plan.schema());
- // TODO combine simplify into Simplifier
- let e = simplify(&e);
-
// TODO iterate until no changes are made
// during rewrite (evaluating constants can
// enable new simplifications and
@@ -316,7 +179,6 @@ impl OptimizerRule for SimplifyExpressions {
let new_name = &new_e.name(plan.schema());
- // TODO simplify this logic
if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) {
if expr_name != new_expr_name {
Ok(new_e.alias(expr_name))
@@ -554,212 +416,252 @@ impl<'a> Simplifier<'a> {
false
}
- fn boolean_folding_for_or(
- const_bool: &Option<bool>,
- bool_expr: Box<Expr>,
- left_right_order: bool,
- ) -> Expr {
- // See if we can fold 'const_bool OR bool_expr' to a constant boolean
- match const_bool {
- // TRUE or expr (including NULL) = TRUE
- Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))),
- // FALSE or expr (including NULL) = expr
- Some(false) => *bool_expr,
- None => match *bool_expr {
- // NULL or TRUE = TRUE
- Expr::Literal(ScalarValue::Boolean(Some(true))) => {
- Expr::Literal(ScalarValue::Boolean(Some(true)))
- }
- // NULL or FALSE = NULL
- Expr::Literal(ScalarValue::Boolean(Some(false))) => {
- Expr::Literal(ScalarValue::Boolean(None))
- }
- // NULL or NULL = NULL
- Expr::Literal(ScalarValue::Boolean(None)) => {
- Expr::Literal(ScalarValue::Boolean(None))
- }
- // NULL or expr can be either NULL or TRUE
- // So let us not rewrite it
- _ => {
- let mut left =
-
Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool)));
- let mut right = bool_expr;
- if !left_right_order {
- std::mem::swap(&mut left, &mut right);
- }
-
- Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- }
- }
- },
- }
- }
-
- fn boolean_folding_for_and(
- const_bool: &Option<bool>,
- bool_expr: Box<Expr>,
- left_right_order: bool,
- ) -> Expr {
- // See if we can fold 'const_bool AND bool_expr' to a constant boolean
- match const_bool {
- // TRUE and expr (including NULL) = expr
- Some(true) => *bool_expr,
- // FALSE and expr (including NULL) = FALSE
- Some(false) => Expr::Literal(ScalarValue::Boolean(Some(false))),
- None => match *bool_expr {
- // NULL and TRUE = NULL
- Expr::Literal(ScalarValue::Boolean(Some(true))) => {
- Expr::Literal(ScalarValue::Boolean(None))
- }
- // NULL and FALSE = FALSE
- Expr::Literal(ScalarValue::Boolean(Some(false))) => {
- Expr::Literal(ScalarValue::Boolean(Some(false)))
- }
- // NULL and NULL = NULL
- Expr::Literal(ScalarValue::Boolean(None)) => {
- Expr::Literal(ScalarValue::Boolean(None))
- }
- // NULL and expr can either be NULL or FALSE
- // So let us not rewrite it
- _ => {
- let mut left =
-
Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool)));
- let mut right = bool_expr;
- if !left_right_order {
- std::mem::swap(&mut left, &mut right);
- }
-
- Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- }
- }
- },
- }
+ /// Returns true if expr is nullable
+ fn nullable(&self, expr: &Expr) -> Result<bool> {
+ self.schemas
+ .iter()
+ .find_map(|schema| {
+ // expr may be from another input, so ignore errors
+ // by converting to None to keep trying
+ expr.nullable(schema.as_ref()).ok()
+ })
+ .ok_or_else(|| {
+ // This means we weren't able to compute `Expr::nullable` with
+ // *any* input schemas, signalling a problem
+ DataFusionError::Internal(format!(
+ "Could not find find columns in '{}' during simplify",
+ expr
+ ))
+ })
}
}
impl<'a> ExprRewriter for Simplifier<'a> {
/// rewrite the expression simplifying any constant expressions
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ use Expr::*;
+ use Operator::{And, Divide, Eq, Multiply, NotEq, Or};
+
let new_expr = match expr {
- Expr::BinaryExpr { left, op, right } => match op {
- Operator::Eq => match (left.as_ref(), right.as_ref()) {
- (
- Expr::Literal(ScalarValue::Boolean(l)),
- Expr::Literal(ScalarValue::Boolean(r)),
- ) => match (l, r) {
- (Some(l), Some(r)) => {
- Expr::Literal(ScalarValue::Boolean(Some(l == r)))
- }
- _ => Expr::Literal(ScalarValue::Boolean(None)),
- },
- (Expr::Literal(ScalarValue::Boolean(b)), _)
- if self.is_boolean_type(&right) =>
- {
- match b {
- Some(true) => *right,
- Some(false) => Expr::Not(right),
- None => Expr::Literal(ScalarValue::Boolean(None)),
- }
- }
- (_, Expr::Literal(ScalarValue::Boolean(b)))
- if self.is_boolean_type(&left) =>
- {
- match b {
- Some(true) => *left,
- Some(false) => Expr::Not(left),
- None => Expr::Literal(ScalarValue::Boolean(None)),
- }
- }
- _ => Expr::BinaryExpr {
- left,
- op: Operator::Eq,
- right,
- },
- },
- Operator::NotEq => match (left.as_ref(), right.as_ref()) {
- (
- Expr::Literal(ScalarValue::Boolean(l)),
- Expr::Literal(ScalarValue::Boolean(r)),
- ) => match (l, r) {
- (Some(l), Some(r)) => {
- Expr::Literal(ScalarValue::Boolean(Some(l != r)))
- }
- _ => Expr::Literal(ScalarValue::Boolean(None)),
- },
- (Expr::Literal(ScalarValue::Boolean(b)), _)
- if self.is_boolean_type(&right) =>
- {
- match b {
- Some(true) => Expr::Not(right),
- Some(false) => *right,
- None => Expr::Literal(ScalarValue::Boolean(None)),
- }
- }
- (_, Expr::Literal(ScalarValue::Boolean(b)))
- if self.is_boolean_type(&left) =>
- {
- match b {
- Some(true) => Expr::Not(left),
- Some(false) => *left,
- None => Expr::Literal(ScalarValue::Boolean(None)),
- }
- }
- _ => Expr::BinaryExpr {
- left,
- op: Operator::NotEq,
- right,
- },
- },
- Operator::Or => match (left.as_ref(), right.as_ref()) {
- (Expr::Literal(ScalarValue::Boolean(b)), _)
- if self.is_boolean_type(&right) =>
- {
- Self::boolean_folding_for_or(b, right, true)
- }
- (_, Expr::Literal(ScalarValue::Boolean(b)))
- if self.is_boolean_type(&left) =>
- {
- Self::boolean_folding_for_or(b, left, false)
- }
- _ => Expr::BinaryExpr {
- left,
- op: Operator::Or,
- right,
- },
- },
- Operator::And => match (left.as_ref(), right.as_ref()) {
- (Expr::Literal(ScalarValue::Boolean(b)), _)
- if self.is_boolean_type(&right) =>
- {
- Self::boolean_folding_for_and(b, right, true)
- }
- (_, Expr::Literal(ScalarValue::Boolean(b)))
- if self.is_boolean_type(&left) =>
- {
- Self::boolean_folding_for_and(b, left, false)
- }
- _ => Expr::BinaryExpr {
- left,
- op: Operator::And,
- right,
- },
- },
- _ => Expr::BinaryExpr { left, op, right },
- },
- // Not(Not(expr)) --> expr
- Expr::Not(inner) => {
- if let Expr::Not(negated_inner) = *inner {
- *negated_inner
- } else {
- Expr::Not(inner)
+ //
+ // Rules for Eq
+ //
+
+ // true = A --> A
+ // false = A --> !A
+ // null = A --> null
+ BinaryExpr {
+ left,
+ op: Eq,
+ right,
+ } if is_bool_lit(&left) && self.is_boolean_type(&right) => {
+ match as_bool_lit(*left) {
+ Some(true) => *right,
+ Some(false) => Not(right),
+ None => lit_null(),
}
}
+ // A = true --> A
+ // A = false --> !A
+ // A = null --> null
+ BinaryExpr {
+ left,
+ op: Eq,
+ right,
+ } if is_bool_lit(&right) && self.is_boolean_type(&left) => {
+ match as_bool_lit(*right) {
+ Some(true) => *left,
+ Some(false) => Not(left),
+ None => lit_null(),
+ }
+ }
+
+ //
+ // Rules for NotEq
+ //
+
+ // true != A --> !A
+ // false != A --> A
+ // null != A --> null
+ BinaryExpr {
+ left,
+ op: NotEq,
+ right,
+ } if is_bool_lit(&left) && self.is_boolean_type(&right) => {
+ match as_bool_lit(*left) {
+ Some(true) => Not(right),
+ Some(false) => *right,
+ None => lit_null(),
+ }
+ }
+ // A != true --> !A
+ // A != false --> A
+ // A != null --> null,
+ BinaryExpr {
+ left,
+ op: NotEq,
+ right,
+ } if is_bool_lit(&right) && self.is_boolean_type(&left) => {
+ match as_bool_lit(*right) {
+ Some(true) => Not(left),
+ Some(false) => *left,
+ None => lit_null(),
+ }
+ }
+
+ //
+ // Rules for OR
+ //
+
+ // true OR A --> true (even if A is null)
+ BinaryExpr {
+ left,
+ op: Or,
+ right: _,
+ } if is_true(&left) => *left,
+ // false OR A --> A
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if is_false(&left) => *right,
+ // A OR true --> true (even if A is null)
+ BinaryExpr {
+ left: _,
+ op: Or,
+ right,
+ } if is_true(&right) => *right,
+ // A OR false --> A
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if is_false(&right) => *left,
+ // (..A..) OR A --> (..A..)
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if expr_contains(&left, &right, Or) => *left,
+ // A OR (..A..) --> (..A..)
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if expr_contains(&right, &left, Or) => *right,
+ // A OR (A AND B) --> A (if B not null)
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if !self.nullable(&right)? && is_op_with(And, &right, &left) =>
*left,
+ // (A AND B) OR A --> A (if B not null)
+ BinaryExpr {
+ left,
+ op: Or,
+ right,
+ } if !self.nullable(&left)? && is_op_with(And, &left, &right) =>
*right,
+
+ //
+ // Rules for AND
+ //
+
+ // true AND A --> A
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if is_true(&left) => *right,
+ // false AND A --> false (even if A is null)
+ BinaryExpr {
+ left,
+ op: And,
+ right: _,
+ } if is_false(&left) => *left,
+ // A AND true --> A
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if is_true(&right) => *left,
+ // A AND false --> false (even if A is null)
+ BinaryExpr {
+ left: _,
+ op: And,
+ right,
+ } if is_false(&right) => *right,
+ // (..A..) AND A --> (..A..)
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if expr_contains(&left, &right, And) => *left,
+ // A AND (..A..) --> (..A..)
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if expr_contains(&right, &left, And) => *right,
+ // A AND (A OR B) --> A (if B not null)
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if !self.nullable(&right)? && is_op_with(Or, &right, &left) =>
*left,
+ // (A OR B) AND A --> A (if B not null)
+ BinaryExpr {
+ left,
+ op: And,
+ right,
+ } if !self.nullable(&left)? && is_op_with(Or, &left, &right) =>
*right,
+
+ //
+ // Rules for Multiply
+ //
+ BinaryExpr {
+ left,
+ op: Multiply,
+ right,
+ } if is_one(&right) => *left,
+ BinaryExpr {
+ left,
+ op: Multiply,
+ right,
+ } if is_one(&left) => *right,
+
+ //
+ // Rules for Divide
+ //
+
+ // A / 1 --> A
+ BinaryExpr {
+ left,
+ op: Divide,
+ right,
+ } if is_one(&right) => *left,
+ // null / null --> null
+ BinaryExpr {
+ left,
+ op: Divide,
+ right,
+ } if left == right && is_null(&left) => *left,
+ // A / A --> 1 (if A is not nullable)
+ BinaryExpr {
+ left,
+ op: Divide,
+ right,
+ } if !self.nullable(&left)? && left == right => lit(1),
+
+ //
+ // Rules for Not
+ //
+
+ // !(!A) --> A
+ Not(inner) if is_not(&inner) => match *inner {
+ Not(negated_inner) => *negated_inner,
+ _ => unreachable!(),
+ },
+
expr => {
// no additional rewrites possible
expr
@@ -791,8 +693,8 @@ mod tests {
let expr_b = lit(true).or(col("c2"));
let expected = lit(true);
- assert_eq!(simplify(&expr_a), expected);
- assert_eq!(simplify(&expr_b), expected);
+ assert_eq!(simplify(expr_a), expected);
+ assert_eq!(simplify(expr_b), expected);
}
#[test]
@@ -801,8 +703,8 @@ mod tests {
let expr_b = col("c2").or(lit(false));
let expected = col("c2");
- assert_eq!(simplify(&expr_a), expected);
- assert_eq!(simplify(&expr_b), expected);
+ assert_eq!(simplify(expr_a), expected);
+ assert_eq!(simplify(expr_b), expected);
}
#[test]
@@ -810,7 +712,7 @@ mod tests {
let expr = col("c2").or(col("c2"));
let expected = col("c2");
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -819,8 +721,8 @@ mod tests {
let expr_b = col("c2").and(lit(false));
let expected = lit(false);
- assert_eq!(simplify(&expr_a), expected);
- assert_eq!(simplify(&expr_b), expected);
+ assert_eq!(simplify(expr_a), expected);
+ assert_eq!(simplify(expr_b), expected);
}
#[test]
@@ -828,7 +730,7 @@ mod tests {
let expr = col("c2").and(col("c2"));
let expected = col("c2");
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -837,8 +739,8 @@ mod tests {
let expr_b = col("c2").and(lit(true));
let expected = col("c2");
- assert_eq!(simplify(&expr_a), expected);
- assert_eq!(simplify(&expr_b), expected);
+ assert_eq!(simplify(expr_a), expected);
+ assert_eq!(simplify(expr_b), expected);
}
#[test]
@@ -847,8 +749,8 @@ mod tests {
let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2"));
let expected = col("c2");
- assert_eq!(simplify(&expr_a), expected);
- assert_eq!(simplify(&expr_b), expected);
+ assert_eq!(simplify(expr_a), expected);
+ assert_eq!(simplify(expr_b), expected);
}
#[test]
@@ -856,15 +758,24 @@ mod tests {
let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
let expected = col("c2");
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
fn test_simplify_divide_by_same() {
let expr = binary_expr(col("c2"), Operator::Divide, col("c2"));
+ // if c2 is null, c2 / c2 = null, so can't simplify
+ let expected = expr.clone();
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_divide_by_same_non_null() {
+ let expr = binary_expr(col("c2_non_null"), Operator::Divide,
col("c2_non_null"));
let expected = lit(1);
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -873,21 +784,21 @@ mod tests {
let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
let expected = col("c2").gt(lit(5));
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
fn test_simplify_composed_and() {
- // ((c > 5) AND (d < 6)) AND (c > 5)
+ // ((c2 > 5) AND (c1 < 6)) AND (c2 > 5)
let expr = binary_expr(
- binary_expr(col("c2").gt(lit(5)), Operator::And,
col("d").lt(lit(6))),
+ binary_expr(col("c2").gt(lit(5)), Operator::And,
col("c1").lt(lit(6))),
Operator::And,
col("c2").gt(lit(5)),
);
let expected =
- binary_expr(col("c2").gt(lit(5)), Operator::And,
col("d").lt(lit(6)));
+ binary_expr(col("c2").gt(lit(5)), Operator::And,
col("c1").lt(lit(6)));
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -900,20 +811,91 @@ mod tests {
);
let expected = expr.clone();
- assert_eq!(simplify(&expr), expected);
+ assert_eq!(simplify(expr), expected);
}
#[test]
fn test_simplify_or_and() {
- // (c > 5) OR ((d < 6) AND (c > 5) -- can remove
- let expr = binary_expr(
- col("c2").gt(lit(5)),
+ let l = col("c2").gt(lit(5));
+ let r = binary_expr(col("c1").lt(lit(6)), Operator::And,
col("c2").gt(lit(5)));
+
+ // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
+ let expr = binary_expr(l.clone(), Operator::Or, r.clone());
+
+ // no rewrites if c1 can be null
+ let expected = expr.clone();
+ assert_eq!(simplify(expr), expected);
+
+ // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5)
+ let expr = binary_expr(l, Operator::Or, r);
+
+ // no rewrites if c1 can be null
+ let expected = expr.clone();
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_or_and_non_null() {
+ let l = col("c2_non_null").gt(lit(5));
+ let r = binary_expr(
+ col("c1_non_null").lt(lit(6)),
+ Operator::And,
+ col("c2_non_null").gt(lit(5)),
+ );
+
+ // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
+ let expr = binary_expr(l.clone(), Operator::Or, r.clone());
+
+ // This is only true if `c1 < 6` is not nullable / can not be null.
+ let expected = col("c2_non_null").gt(lit(5));
+
+ assert_eq!(simplify(expr), expected);
+
+ // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5
+ let expr = binary_expr(l, Operator::Or, r);
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_and_or() {
+ let l = col("c2").gt(lit(5));
+ let r = binary_expr(col("c1").lt(lit(6)), Operator::Or,
col("c2").gt(lit(5)));
+
+ // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
+ let expr = binary_expr(l.clone(), Operator::And, r.clone());
+
+ // no rewrites if c1 can be null
+ let expected = expr.clone();
+ assert_eq!(simplify(expr), expected);
+
+ // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
+ let expr = binary_expr(l, Operator::And, r);
+ let expected = expr.clone();
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_and_or_non_null() {
+ let l = col("c2_non_null").gt(lit(5));
+ let r = binary_expr(
+ col("c1_non_null").lt(lit(6)),
Operator::Or,
- binary_expr(col("d").lt(lit(6)), Operator::And,
col("c2").gt(lit(5))),
+ col("c2_non_null").gt(lit(5)),
);
- let expected = col("c2").gt(lit(5));
- assert_eq!(simplify(&expr), expected);
+ // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
+ let expr = binary_expr(l.clone(), Operator::And, r.clone());
+
+ // This is only true if `c1 < 6` is not nullable / can not be null.
+ let expected = col("c2_non_null").gt(lit(5));
+
+ assert_eq!(simplify(expr), expected);
+
+ // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5
+ let expr = binary_expr(l, Operator::And, r);
+
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -921,7 +903,7 @@ mod tests {
let expr = binary_expr(lit_null(), Operator::And, lit(false));
let expr_eq = lit(false);
- assert_eq!(simplify(&expr), expr_eq);
+ assert_eq!(simplify(expr), expr_eq);
}
#[test]
@@ -930,16 +912,16 @@ mod tests {
let expr_plus = binary_expr(null.clone(), Operator::Divide,
null.clone());
let expr_eq = null;
- assert_eq!(simplify(&expr_plus), expr_eq);
+ assert_eq!(simplify(expr_plus), expr_eq);
}
#[test]
- fn test_simplify_do_not_simplify_arithmetic_expr() {
+ fn test_simplify_simplify_arithmetic_expr() {
let expr_plus = binary_expr(lit(1), Operator::Plus, lit(1));
let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
- assert_eq!(simplify(&expr_plus), expr_plus);
- assert_eq!(simplify(&expr_eq), expr_eq);
+ assert_eq!(simplify(expr_plus), lit(2));
+ assert_eq!(simplify(expr_eq), lit(true));
}
// ------------------------------
@@ -1182,11 +1164,17 @@ mod tests {
// ----- Simplifier tests -------
// ------------------------------
- // TODO rename to simplify
- fn do_simplify(expr: Expr) -> Expr {
+ fn simplify(expr: Expr) -> Expr {
let schema = expr_test_schema();
let mut rewriter = Simplifier::new(vec![&schema]);
- expr.rewrite(&mut rewriter).expect("expected to simplify")
+
+ let execution_props = ExecutionProps::new();
+ let mut const_evaluator = ConstEvaluator::new(&execution_props);
+
+ expr.rewrite(&mut rewriter)
+ .expect("expected to simplify")
+ .rewrite(&mut const_evaluator)
+ .expect("expected to const evaluate")
}
fn expr_test_schema() -> DFSchemaRef {
@@ -1194,6 +1182,8 @@ mod tests {
DFSchema::new(vec![
DFField::new(None, "c1", DataType::Utf8, true),
DFField::new(None, "c2", DataType::Boolean, true),
+ DFField::new(None, "c1_non_null", DataType::Utf8, false),
+ DFField::new(None, "c2_non_null", DataType::Boolean, false),
])
.unwrap(),
)
@@ -1201,20 +1191,20 @@ mod tests {
#[test]
fn simplify_expr_not_not() {
- assert_eq!(do_simplify(col("c2").not().not().not()), col("c2").not(),);
+ assert_eq!(simplify(col("c2").not().not().not()), col("c2").not(),);
}
#[test]
fn simplify_expr_null_comparison() {
// x = null is always null
assert_eq!(
- do_simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
+ simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
lit(ScalarValue::Boolean(None)),
);
// null != null is always null
assert_eq!(
- do_simplify(
+ simplify(
lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
),
lit(ScalarValue::Boolean(None)),
@@ -1222,13 +1212,13 @@ mod tests {
// x != null is always null
assert_eq!(
- do_simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
+ simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
lit(ScalarValue::Boolean(None)),
);
// null = x is always null
assert_eq!(
- do_simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
+ simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
lit(ScalarValue::Boolean(None)),
);
}
@@ -1239,16 +1229,16 @@ mod tests {
assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
// true = true -> true
- assert_eq!(do_simplify(lit(true).eq(lit(true))), lit(true));
+ assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
// true = false -> false
- assert_eq!(do_simplify(lit(true).eq(lit(false))), lit(false),);
+ assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
// c2 = true -> c2
- assert_eq!(do_simplify(col("c2").eq(lit(true))), col("c2"));
+ assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
// c2 = false => !c2
- assert_eq!(do_simplify(col("c2").eq(lit(false))), col("c2").not(),);
+ assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
}
#[test]
@@ -1262,25 +1252,8 @@ mod tests {
// Make sure c1 column to be used in tests is not boolean type
assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
- // don't fold c1 = true
- assert_eq!(
- do_simplify(col("c1").eq(lit(true))),
- col("c1").eq(lit(true)),
- );
-
- // don't fold c1 = false
- assert_eq!(
- do_simplify(col("c1").eq(lit(false))),
- col("c1").eq(lit(false)),
- );
-
- // test constant operands
- assert_eq!(do_simplify(lit(1).eq(lit(true))), lit(1).eq(lit(true)),);
-
- assert_eq!(
- do_simplify(lit("a").eq(lit(false))),
- lit("a").eq(lit(false)),
- );
+ // don't fold c1 = foo
+ assert_eq!(simplify(col("c1").eq(lit("foo"))),
col("c1").eq(lit("foo")),);
}
#[test]
@@ -1290,15 +1263,15 @@ mod tests {
assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
// c2 != true -> !c2
- assert_eq!(do_simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
+ assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
// c2 != false -> c2
- assert_eq!(do_simplify(col("c2").not_eq(lit(false))), col("c2"),);
+ assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
// test constant
- assert_eq!(do_simplify(lit(true).not_eq(lit(true))), lit(false),);
+ assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
- assert_eq!(do_simplify(lit(true).not_eq(lit(false))), lit(true),);
+ assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
}
#[test]
@@ -1311,44 +1284,25 @@ mod tests {
assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
assert_eq!(
- do_simplify(col("c1").not_eq(lit(true))),
- col("c1").not_eq(lit(true)),
- );
-
- assert_eq!(
- do_simplify(col("c1").not_eq(lit(false))),
- col("c1").not_eq(lit(false)),
- );
-
- // test constants
- assert_eq!(
- do_simplify(lit(1).not_eq(lit(true))),
- lit(1).not_eq(lit(true)),
- );
-
- assert_eq!(
- do_simplify(lit("a").not_eq(lit(false))),
- lit("a").not_eq(lit(false)),
+ simplify(col("c1").not_eq(lit("foo"))),
+ col("c1").not_eq(lit("foo")),
);
}
#[test]
fn simplify_expr_case_when_then_else() {
assert_eq!(
- do_simplify(Expr::Case {
+ simplify(Expr::Case {
expr: None,
when_then_expr: vec![(
Box::new(col("c2").not_eq(lit(false))),
- Box::new(lit("ok").eq(lit(true))),
+ Box::new(lit("ok").eq(lit("not_ok"))),
)],
else_expr: Some(Box::new(col("c2").eq(lit(true)))),
}),
Expr::Case {
expr: None,
- when_then_expr: vec![(
- Box::new(col("c2")),
- Box::new(lit("ok").eq(lit(true)))
- )],
+ when_then_expr: vec![(Box::new(col("c2")),
Box::new(lit(false)))],
else_expr: Some(Box::new(col("c2"))),
}
);
@@ -1362,22 +1316,22 @@ mod tests {
#[test]
fn simplify_expr_bool_or() {
// col || true is always true
- assert_eq!(do_simplify(col("c2").or(lit(true))), lit(true),);
+ assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
// col || false is always col
- assert_eq!(do_simplify(col("c2").or(lit(false))), col("c2"),);
+ assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
// true || null is always true
- assert_eq!(do_simplify(lit(true).or(lit_null())), lit(true),);
+ assert_eq!(simplify(lit(true).or(lit_null())), lit(true),);
// null || true is always true
- assert_eq!(do_simplify(lit_null().or(lit(true))), lit(true),);
+ assert_eq!(simplify(lit_null().or(lit(true))), lit(true),);
// false || null is always null
- assert_eq!(do_simplify(lit(false).or(lit_null())), lit_null(),);
+ assert_eq!(simplify(lit(false).or(lit_null())), lit_null(),);
// null || false is always null
- assert_eq!(do_simplify(lit_null().or(lit(false))), lit_null(),);
+ assert_eq!(simplify(lit_null().or(lit(false))), lit_null(),);
// ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
// it can be either NULL or TRUE depending on the value of `c1
BETWEEN Int32(0) AND Int32(10)`
@@ -1389,28 +1343,28 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.or(lit_null());
- let result = do_simplify(expr.clone());
+ let result = simplify(expr.clone());
assert_eq!(expr, result);
}
#[test]
fn simplify_expr_bool_and() {
// col & true is always col
- assert_eq!(do_simplify(col("c2").and(lit(true))), col("c2"),);
+ assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
// col & false is always false
- assert_eq!(do_simplify(col("c2").and(lit(false))), lit(false),);
+ assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
// true && null is always null
- assert_eq!(do_simplify(lit(true).and(lit_null())), lit_null(),);
+ assert_eq!(simplify(lit(true).and(lit_null())), lit_null(),);
// null && true is always null
- assert_eq!(do_simplify(lit_null().and(lit(true))), lit_null(),);
+ assert_eq!(simplify(lit_null().and(lit(true))), lit_null(),);
// false && null is always false
- assert_eq!(do_simplify(lit(false).and(lit_null())), lit(false),);
+ assert_eq!(simplify(lit(false).and(lit_null())), lit(false),);
// null && false is always false
- assert_eq!(do_simplify(lit_null().and(lit(false))), lit(false),);
+ assert_eq!(simplify(lit_null().and(lit(false))), lit(false),);
// c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
// it can be either NULL or FALSE depending on the value of `c1
BETWEEN Int32(0) AND Int32(10)`
@@ -1422,7 +1376,7 @@ mod tests {
high: Box::new(lit(10)),
};
let expr = expr.and(lit_null());
- let result = do_simplify(expr.clone());
+ let result = simplify(expr.clone());
assert_eq!(expr, result);
}
@@ -1473,12 +1427,12 @@ mod tests {
);
}
- // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6)
#[test]
fn test_simplify_optimized_plan_with_composed_and() {
let table_scan = test_table_scan();
+ // ((a > 5) AND (b < 6)) AND (a > 5) --> (a > 5) AND (b < 6)
let plan = LogicalPlanBuilder::from(table_scan)
- .project(vec![col("a")])
+ .project(vec![col("a"), col("b")])
.unwrap()
.filter(and(
and(col("a").gt(lit(5)), col("b").lt(lit(6))),
@@ -1492,7 +1446,7 @@ mod tests {
&plan,
"\
Filter: #test.a > Int32(5) AND #test.b < Int32(6) AS test.a >
Int32(5) AND test.b < Int32(6) AND test.a > Int32(5)\
- \n Projection: #test.a\
+ \n Projection: #test.a, #test.b\
\n TableScan: test projection=None",
);
}