This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 847e78a  Fix bugs with nullability during rewrites: Combine `simplify` and `Simplifier` (#1401)
847e78a is described below

commit 847e78a675703c24933af5d6a429c2576bc14e9d
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jan 8 05:27:43 2022 -0500

    Fix bugs with nullability during rewrites: Combine `simplify` and `Simplifier` (#1401)
    
    * Combine simplify and Simplifier
    
    * Make nullable more functional
---
 datafusion/src/optimizer/simplify_expressions.rs | 924 +++++++++++------------
 1 file changed, 439 insertions(+), 485 deletions(-)

diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs
index ff2c05c..7040d34 100644
--- a/datafusion/src/optimizer/simplify_expressions.rs
+++ b/datafusion/src/optimizer/simplify_expressions.rs
@@ -45,33 +45,18 @@ use crate::{error::Result, logical_plan::Operator};
 ///
 pub struct SimplifyExpressions {}
 
-fn expr_contains(expr: &Expr, needle: &Expr) -> bool {
+/// returns true if `needle` is found in a chain of search_op
+/// expressions. Such as: (A AND B) AND C
+fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
     match expr {
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } => expr_contains(left, needle) || expr_contains(right, needle),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } => expr_contains(left, needle) || expr_contains(right, needle),
+        Expr::BinaryExpr { left, op, right } if *op == search_op => {
+            expr_contains(left, needle, search_op)
+                || expr_contains(right, needle, search_op)
+        }
         _ => expr == needle,
     }
 }
 
-fn as_binary_expr(expr: &Expr) -> Option<&Expr> {
-    match expr {
-        Expr::BinaryExpr { .. } => Some(expr),
-        _ => None,
-    }
-}
-
-fn operator_is_boolean(op: Operator) -> bool {
-    op == Operator::And || op == Operator::Or
-}
-
 fn is_one(s: &Expr) -> bool {
     match s {
         Expr::Literal(ScalarValue::Int8(Some(1)))
@@ -95,6 +80,22 @@ fn is_true(expr: &Expr) -> bool {
     }
 }
 
+/// returns true if expr is a
+/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise
+fn is_bool_lit(expr: &Expr) -> bool {
+    matches!(expr, Expr::Literal(ScalarValue::Boolean(_)))
+}
+
+/// Return a literal NULL value
+fn lit_null() -> Expr {
+    Expr::Literal(ScalarValue::Boolean(None))
+}
+
+/// returns true if expr is a `Not(_)`, false otherwise
+fn is_not(expr: &Expr) -> bool {
+    matches!(expr, Expr::Not(_))
+}
+
 fn is_null(expr: &Expr) -> bool {
     match expr {
         Expr::Literal(v) => v.is_null(),
@@ -109,160 +110,27 @@ fn is_false(expr: &Expr) -> bool {
     }
 }
 
-fn simplify(expr: &Expr) -> Expr {
-    match expr {
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if is_true(left) || is_true(right) => lit(true),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if is_false(left) => simplify(right),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if is_false(right) => simplify(left),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if left == right => simplify(left),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if is_false(left) || is_false(right) => lit(false),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if is_true(right) => simplify(left),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if is_true(left) => simplify(right),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if left == right => simplify(right),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Multiply,
-            right,
-        } if is_one(left) => simplify(right),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Multiply,
-            right,
-        } if is_one(right) => simplify(left),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Divide,
-            right,
-        } if is_one(right) => simplify(left),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Divide,
-            right,
-        } if left == right && is_null(left) => *left.clone(),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Divide,
-            right,
-        } if left == right => lit(1),
+/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
+fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
+    match haystack {
         Expr::BinaryExpr { left, op, right }
-            if left == right && operator_is_boolean(*op) =>
+            if op == &target_op
+                && (needle == left.as_ref() || needle == right.as_ref()) =>
         {
-            simplify(left)
+            true
         }
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if expr_contains(left, right) => as_binary_expr(left)
-            .map(|x| match x {
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::Or,
-                    right: _,
-                } => simplify(&x.clone()),
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::And,
-                    right: _,
-                } => simplify(&*right.clone()),
-                _ => expr.clone(),
-            })
-            .unwrap_or_else(|| expr.clone()),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::Or,
-            right,
-        } if expr_contains(right, left) => as_binary_expr(right)
-            .map(|x| match x {
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::Or,
-                    right: _,
-                } => simplify(&*right.clone()),
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::And,
-                    right: _,
-                } => simplify(&*left.clone()),
-                _ => expr.clone(),
-            })
-            .unwrap_or_else(|| expr.clone()),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if expr_contains(left, right) => as_binary_expr(left)
-            .map(|x| match x {
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::Or,
-                    right: _,
-                } => simplify(&*right.clone()),
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::And,
-                    right: _,
-                } => simplify(&x.clone()),
-                _ => expr.clone(),
-            })
-            .unwrap_or_else(|| expr.clone()),
-        Expr::BinaryExpr {
-            left,
-            op: Operator::And,
-            right,
-        } if expr_contains(right, left) => as_binary_expr(right)
-            .map(|x| match x {
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::Or,
-                    right: _,
-                } => simplify(&*left.clone()),
-                Expr::BinaryExpr {
-                    left: _,
-                    op: Operator::And,
-                    right: _,
-                } => simplify(&x.clone()),
-                _ => expr.clone(),
-            })
-            .unwrap_or_else(|| expr.clone()),
-        Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
-            left: Box::new(simplify(left)),
-            op: *op,
-            right: Box::new(simplify(right)),
-        },
-        _ => expr.clone(),
+        _ => false,
+    }
+}
+
+/// returns the contained boolean value in `expr` as
+/// `Expr::Literal(ScalarValue::Boolean(v))`.
+///
+/// panics if expr is not a literal boolean
+fn as_bool_lit(expr: Expr) -> Option<bool> {
+    match expr {
+        Expr::Literal(ScalarValue::Boolean(v)) => v,
+        _ => panic!("Expected boolean literal, got {:?}", expr),
     }
 }
 
@@ -281,11 +149,9 @@ impl OptimizerRule for SimplifyExpressions {
         // projected columns. With just the projected schema, it's not possible to infer types for
         // expressions that references non-projected columns within the same project plan or its
         // children plans.
-        let mut simplifier =
-            super::simplify_expressions::Simplifier::new(plan.all_schemas());
+        let mut simplifier = Simplifier::new(plan.all_schemas());
 
-        let mut const_evaluator =
-            super::simplify_expressions::ConstEvaluator::new(execution_props);
+        let mut const_evaluator = ConstEvaluator::new(execution_props);
 
         let new_inputs = plan
             .inputs()
@@ -301,9 +167,6 @@ impl OptimizerRule for SimplifyExpressions {
                 // Constant folding should not change expression name.
                 let name = &e.name(plan.schema());
 
-                // TODO combine simplify into Simplifier
-                let e = simplify(&e);
-
                 // TODO iterate until no changes are made
                 // during rewrite (evaluating constants can
                 // enable new simplifications and
@@ -316,7 +179,6 @@ impl OptimizerRule for SimplifyExpressions {
 
                 let new_name = &new_e.name(plan.schema());
 
-                // TODO simplify this logic
                 if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) {
                     if expr_name != new_expr_name {
                         Ok(new_e.alias(expr_name))
@@ -554,212 +416,252 @@ impl<'a> Simplifier<'a> {
         false
     }
 
-    fn boolean_folding_for_or(
-        const_bool: &Option<bool>,
-        bool_expr: Box<Expr>,
-        left_right_order: bool,
-    ) -> Expr {
-        // See if we can fold 'const_bool OR bool_expr' to a constant boolean
-        match const_bool {
-            // TRUE or expr (including NULL) = TRUE
-            Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))),
-            // FALSE or expr (including NULL) = expr
-            Some(false) => *bool_expr,
-            None => match *bool_expr {
-                // NULL or TRUE = TRUE
-                Expr::Literal(ScalarValue::Boolean(Some(true))) => {
-                    Expr::Literal(ScalarValue::Boolean(Some(true)))
-                }
-                // NULL or FALSE = NULL
-                Expr::Literal(ScalarValue::Boolean(Some(false))) => {
-                    Expr::Literal(ScalarValue::Boolean(None))
-                }
-                // NULL or NULL = NULL
-                Expr::Literal(ScalarValue::Boolean(None)) => {
-                    Expr::Literal(ScalarValue::Boolean(None))
-                }
-                // NULL or expr can be either NULL or TRUE
-                // So let us not rewrite it
-                _ => {
-                    let mut left =
-                        Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool)));
-                    let mut right = bool_expr;
-                    if !left_right_order {
-                        std::mem::swap(&mut left, &mut right);
-                    }
-
-                    Expr::BinaryExpr {
-                        left,
-                        op: Operator::Or,
-                        right,
-                    }
-                }
-            },
-        }
-    }
-
-    fn boolean_folding_for_and(
-        const_bool: &Option<bool>,
-        bool_expr: Box<Expr>,
-        left_right_order: bool,
-    ) -> Expr {
-        // See if we can fold 'const_bool AND bool_expr' to a constant boolean
-        match const_bool {
-            // TRUE and expr (including NULL) = expr
-            Some(true) => *bool_expr,
-            // FALSE and expr (including NULL) = FALSE
-            Some(false) => Expr::Literal(ScalarValue::Boolean(Some(false))),
-            None => match *bool_expr {
-                // NULL and TRUE = NULL
-                Expr::Literal(ScalarValue::Boolean(Some(true))) => {
-                    Expr::Literal(ScalarValue::Boolean(None))
-                }
-                // NULL and FALSE = FALSE
-                Expr::Literal(ScalarValue::Boolean(Some(false))) => {
-                    Expr::Literal(ScalarValue::Boolean(Some(false)))
-                }
-                // NULL and NULL = NULL
-                Expr::Literal(ScalarValue::Boolean(None)) => {
-                    Expr::Literal(ScalarValue::Boolean(None))
-                }
-                // NULL and expr can either be NULL or FALSE
-                // So let us not rewrite it
-                _ => {
-                    let mut left =
-                        Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool)));
-                    let mut right = bool_expr;
-                    if !left_right_order {
-                        std::mem::swap(&mut left, &mut right);
-                    }
-
-                    Expr::BinaryExpr {
-                        left,
-                        op: Operator::And,
-                        right,
-                    }
-                }
-            },
-        }
+    /// Returns true if expr is nullable
+    fn nullable(&self, expr: &Expr) -> Result<bool> {
+        self.schemas
+            .iter()
+            .find_map(|schema| {
+                // expr may be from another input, so ignore errors
+                // by converting to None to keep trying
+                expr.nullable(schema.as_ref()).ok()
+            })
+            .ok_or_else(|| {
+                // This means we weren't able to compute `Expr::nullable` with
+                // *any* input schemas, signalling a problem
+                DataFusionError::Internal(format!(
+                    "Could not find find columns in '{}' during simplify",
+                    expr
+                ))
+            })
     }
 }
 
 impl<'a> ExprRewriter for Simplifier<'a> {
     /// rewrite the expression simplifying any constant expressions
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        use Expr::*;
+        use Operator::{And, Divide, Eq, Multiply, NotEq, Or};
+
         let new_expr = match expr {
-            Expr::BinaryExpr { left, op, right } => match op {
-                Operator::Eq => match (left.as_ref(), right.as_ref()) {
-                    (
-                        Expr::Literal(ScalarValue::Boolean(l)),
-                        Expr::Literal(ScalarValue::Boolean(r)),
-                    ) => match (l, r) {
-                        (Some(l), Some(r)) => {
-                            Expr::Literal(ScalarValue::Boolean(Some(l == r)))
-                        }
-                        _ => Expr::Literal(ScalarValue::Boolean(None)),
-                    },
-                    (Expr::Literal(ScalarValue::Boolean(b)), _)
-                        if self.is_boolean_type(&right) =>
-                    {
-                        match b {
-                            Some(true) => *right,
-                            Some(false) => Expr::Not(right),
-                            None => Expr::Literal(ScalarValue::Boolean(None)),
-                        }
-                    }
-                    (_, Expr::Literal(ScalarValue::Boolean(b)))
-                        if self.is_boolean_type(&left) =>
-                    {
-                        match b {
-                            Some(true) => *left,
-                            Some(false) => Expr::Not(left),
-                            None => Expr::Literal(ScalarValue::Boolean(None)),
-                        }
-                    }
-                    _ => Expr::BinaryExpr {
-                        left,
-                        op: Operator::Eq,
-                        right,
-                    },
-                },
-                Operator::NotEq => match (left.as_ref(), right.as_ref()) {
-                    (
-                        Expr::Literal(ScalarValue::Boolean(l)),
-                        Expr::Literal(ScalarValue::Boolean(r)),
-                    ) => match (l, r) {
-                        (Some(l), Some(r)) => {
-                            Expr::Literal(ScalarValue::Boolean(Some(l != r)))
-                        }
-                        _ => Expr::Literal(ScalarValue::Boolean(None)),
-                    },
-                    (Expr::Literal(ScalarValue::Boolean(b)), _)
-                        if self.is_boolean_type(&right) =>
-                    {
-                        match b {
-                            Some(true) => Expr::Not(right),
-                            Some(false) => *right,
-                            None => Expr::Literal(ScalarValue::Boolean(None)),
-                        }
-                    }
-                    (_, Expr::Literal(ScalarValue::Boolean(b)))
-                        if self.is_boolean_type(&left) =>
-                    {
-                        match b {
-                            Some(true) => Expr::Not(left),
-                            Some(false) => *left,
-                            None => Expr::Literal(ScalarValue::Boolean(None)),
-                        }
-                    }
-                    _ => Expr::BinaryExpr {
-                        left,
-                        op: Operator::NotEq,
-                        right,
-                    },
-                },
-                Operator::Or => match (left.as_ref(), right.as_ref()) {
-                    (Expr::Literal(ScalarValue::Boolean(b)), _)
-                        if self.is_boolean_type(&right) =>
-                    {
-                        Self::boolean_folding_for_or(b, right, true)
-                    }
-                    (_, Expr::Literal(ScalarValue::Boolean(b)))
-                        if self.is_boolean_type(&left) =>
-                    {
-                        Self::boolean_folding_for_or(b, left, false)
-                    }
-                    _ => Expr::BinaryExpr {
-                        left,
-                        op: Operator::Or,
-                        right,
-                    },
-                },
-                Operator::And => match (left.as_ref(), right.as_ref()) {
-                    (Expr::Literal(ScalarValue::Boolean(b)), _)
-                        if self.is_boolean_type(&right) =>
-                    {
-                        Self::boolean_folding_for_and(b, right, true)
-                    }
-                    (_, Expr::Literal(ScalarValue::Boolean(b)))
-                        if self.is_boolean_type(&left) =>
-                    {
-                        Self::boolean_folding_for_and(b, left, false)
-                    }
-                    _ => Expr::BinaryExpr {
-                        left,
-                        op: Operator::And,
-                        right,
-                    },
-                },
-                _ => Expr::BinaryExpr { left, op, right },
-            },
-            // Not(Not(expr)) --> expr
-            Expr::Not(inner) => {
-                if let Expr::Not(negated_inner) = *inner {
-                    *negated_inner
-                } else {
-                    Expr::Not(inner)
+            //
+            // Rules for Eq
+            //
+
+            // true = A  --> A
+            // false = A --> !A
+            // null = A --> null
+            BinaryExpr {
+                left,
+                op: Eq,
+                right,
+            } if is_bool_lit(&left) && self.is_boolean_type(&right) => {
+                match as_bool_lit(*left) {
+                    Some(true) => *right,
+                    Some(false) => Not(right),
+                    None => lit_null(),
                 }
             }
+            // A = true  --> A
+            // A = false --> !A
+            // A = null --> null
+            BinaryExpr {
+                left,
+                op: Eq,
+                right,
+            } if is_bool_lit(&right) && self.is_boolean_type(&left) => {
+                match as_bool_lit(*right) {
+                    Some(true) => *left,
+                    Some(false) => Not(left),
+                    None => lit_null(),
+                }
+            }
+
+            //
+            // Rules for NotEq
+            //
+
+            // true != A  --> !A
+            // false != A --> A
+            // null != A --> null
+            BinaryExpr {
+                left,
+                op: NotEq,
+                right,
+            } if is_bool_lit(&left) && self.is_boolean_type(&right) => {
+                match as_bool_lit(*left) {
+                    Some(true) => Not(right),
+                    Some(false) => *right,
+                    None => lit_null(),
+                }
+            }
+            // A != true  --> !A
+            // A != false --> A
+            // A != null --> null,
+            BinaryExpr {
+                left,
+                op: NotEq,
+                right,
+            } if is_bool_lit(&right) && self.is_boolean_type(&left) => {
+                match as_bool_lit(*right) {
+                    Some(true) => Not(left),
+                    Some(false) => *left,
+                    None => lit_null(),
+                }
+            }
+
+            //
+            // Rules for OR
+            //
+
+            // true OR A --> true (even if A is null)
+            BinaryExpr {
+                left,
+                op: Or,
+                right: _,
+            } if is_true(&left) => *left,
+            // false OR A --> A
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if is_false(&left) => *right,
+            // A OR true --> true (even if A is null)
+            BinaryExpr {
+                left: _,
+                op: Or,
+                right,
+            } if is_true(&right) => *right,
+            // A OR false --> A
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if is_false(&right) => *left,
+            // (..A..) OR A --> (..A..)
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if expr_contains(&left, &right, Or) => *left,
+            // A OR (..A..) --> (..A..)
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if expr_contains(&right, &left, Or) => *right,
+            // A OR (A AND B) --> A (if B not null)
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if !self.nullable(&right)? && is_op_with(And, &right, &left) => *left,
+            // (A AND B) OR A --> A (if B not null)
+            BinaryExpr {
+                left,
+                op: Or,
+                right,
+            } if !self.nullable(&left)? && is_op_with(And, &left, &right) => *right,
+
+            //
+            // Rules for AND
+            //
+
+            // true AND A --> A
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if is_true(&left) => *right,
+            // false AND A --> false (even if A is null)
+            BinaryExpr {
+                left,
+                op: And,
+                right: _,
+            } if is_false(&left) => *left,
+            // A AND true --> A
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if is_true(&right) => *left,
+            // A AND false --> false (even if A is null)
+            BinaryExpr {
+                left: _,
+                op: And,
+                right,
+            } if is_false(&right) => *right,
+            // (..A..) AND A --> (..A..)
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if expr_contains(&left, &right, And) => *left,
+            // A AND (..A..) --> (..A..)
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if expr_contains(&right, &left, And) => *right,
+            // A AND (A OR B) --> A (if B not null)
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if !self.nullable(&right)? && is_op_with(Or, &right, &left) => *left,
+            // (A OR B) AND A --> A (if B not null)
+            BinaryExpr {
+                left,
+                op: And,
+                right,
+            } if !self.nullable(&left)? && is_op_with(Or, &left, &right) => *right,
+
+            //
+            // Rules for Multiply
+            //
+            BinaryExpr {
+                left,
+                op: Multiply,
+                right,
+            } if is_one(&right) => *left,
+            BinaryExpr {
+                left,
+                op: Multiply,
+                right,
+            } if is_one(&left) => *right,
+
+            //
+            // Rules for Divide
+            //
+
+            // A / 1 --> A
+            BinaryExpr {
+                left,
+                op: Divide,
+                right,
+            } if is_one(&right) => *left,
+            // A / null --> null
+            BinaryExpr {
+                left,
+                op: Divide,
+                right,
+            } if left == right && is_null(&left) => *left,
+            // A / A --> 1 (if a is not nullable)
+            BinaryExpr {
+                left,
+                op: Divide,
+                right,
+            } if !self.nullable(&left)? && left == right => lit(1),
+
+            //
+            // Rules for Not
+            //
+
+            // !(!A) --> A
+            Not(inner) if is_not(&inner) => match *inner {
+                Not(negated_inner) => *negated_inner,
+                _ => unreachable!(),
+            },
+
             expr => {
                 // no additional rewrites possible
                 expr
@@ -791,8 +693,8 @@ mod tests {
         let expr_b = lit(true).or(col("c2"));
         let expected = lit(true);
 
-        assert_eq!(simplify(&expr_a), expected);
-        assert_eq!(simplify(&expr_b), expected);
+        assert_eq!(simplify(expr_a), expected);
+        assert_eq!(simplify(expr_b), expected);
     }
 
     #[test]
@@ -801,8 +703,8 @@ mod tests {
         let expr_b = col("c2").or(lit(false));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr_a), expected);
-        assert_eq!(simplify(&expr_b), expected);
+        assert_eq!(simplify(expr_a), expected);
+        assert_eq!(simplify(expr_b), expected);
     }
 
     #[test]
@@ -810,7 +712,7 @@ mod tests {
         let expr = col("c2").or(col("c2"));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -819,8 +721,8 @@ mod tests {
         let expr_b = col("c2").and(lit(false));
         let expected = lit(false);
 
-        assert_eq!(simplify(&expr_a), expected);
-        assert_eq!(simplify(&expr_b), expected);
+        assert_eq!(simplify(expr_a), expected);
+        assert_eq!(simplify(expr_b), expected);
     }
 
     #[test]
@@ -828,7 +730,7 @@ mod tests {
         let expr = col("c2").and(col("c2"));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -837,8 +739,8 @@ mod tests {
         let expr_b = col("c2").and(lit(true));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr_a), expected);
-        assert_eq!(simplify(&expr_b), expected);
+        assert_eq!(simplify(expr_a), expected);
+        assert_eq!(simplify(expr_b), expected);
     }
 
     #[test]
@@ -847,8 +749,8 @@ mod tests {
         let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2"));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr_a), expected);
-        assert_eq!(simplify(&expr_b), expected);
+        assert_eq!(simplify(expr_a), expected);
+        assert_eq!(simplify(expr_b), expected);
     }
 
     #[test]
@@ -856,15 +758,24 @@ mod tests {
         let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
         let expected = col("c2");
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
     fn test_simplify_divide_by_same() {
         let expr = binary_expr(col("c2"), Operator::Divide, col("c2"));
+        // if c2 is null, c2 / c2 = null, so can't simplify
+        let expected = expr.clone();
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_divide_by_same_non_null() {
+        let expr = binary_expr(col("c2_non_null"), Operator::Divide, col("c2_non_null"));
         let expected = lit(1);
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -873,21 +784,21 @@ mod tests {
         let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5)));
         let expected = col("c2").gt(lit(5));
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
     fn test_simplify_composed_and() {
-        // ((c > 5) AND (d < 6)) AND (c > 5)
+        // ((c > 5) AND (c1 < 6)) AND (c > 5)
         let expr = binary_expr(
-            binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6))),
+            binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))),
             Operator::And,
             col("c2").gt(lit(5)),
         );
         let expected =
-            binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6)));
+            binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6)));
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -900,20 +811,91 @@ mod tests {
         );
         let expected = expr.clone();
 
-        assert_eq!(simplify(&expr), expected);
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
     fn test_simplify_or_and() {
-        // (c > 5) OR ((d < 6) AND (c > 5) -- can remove
-        let expr = binary_expr(
-            col("c2").gt(lit(5)),
+        let l = col("c2").gt(lit(5));
+        let r = binary_expr(col("c1").lt(lit(6)), Operator::And, 
col("c2").gt(lit(5)));
+
+        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5))
+        let expr = binary_expr(l.clone(), Operator::Or, r.clone());
+
+        // no rewrites if c1 can be null
+        let expected = expr.clone();
+        assert_eq!(simplify(expr), expected);
+
+        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) -- TODO: the call below still builds `l OR r`; swap the operands
+        let expr = binary_expr(l, Operator::Or, r);
+
+        // no rewrites if c1 can be null
+        let expected = expr.clone();
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_or_and_non_null() {
+        let l = col("c2_non_null").gt(lit(5));
+        let r = binary_expr(
+            col("c1_non_null").lt(lit(6)),
+            Operator::And,
+            col("c2_non_null").gt(lit(5)),
+        );
+
+        // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5
+        let expr = binary_expr(l.clone(), Operator::Or, r.clone());
+
+        // This is only true if `c1 < 6` is not nullable / can not be null.
+        let expected = col("c2_non_null").gt(lit(5));
+
+        assert_eq!(simplify(expr), expected);
+
+        // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5 -- TODO: the call below still builds `l OR r`; swap the operands
+        let expr = binary_expr(l, Operator::Or, r);
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_and_or() {
+        let l = col("c2").gt(lit(5));
+        let r = binary_expr(col("c1").lt(lit(6)), Operator::Or, 
col("c2").gt(lit(5)));
+
+        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5))
+        let expr = binary_expr(l.clone(), Operator::And, r.clone());
+
+        // no rewrites if c1 can be null
+        let expected = expr.clone();
+        assert_eq!(simplify(expr), expected);
+
+        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) -- TODO: the call below still builds `l AND r`; swap the operands
+        let expr = binary_expr(l, Operator::And, r);
+        let expected = expr.clone();
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_and_or_non_null() {
+        let l = col("c2_non_null").gt(lit(5));
+        let r = binary_expr(
+            col("c1_non_null").lt(lit(6)),
             Operator::Or,
-            binary_expr(col("d").lt(lit(6)), Operator::And, 
col("c2").gt(lit(5))),
+            col("c2_non_null").gt(lit(5)),
         );
-        let expected = col("c2").gt(lit(5));
 
-        assert_eq!(simplify(&expr), expected);
+        // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5
+        let expr = binary_expr(l.clone(), Operator::And, r.clone());
+
+        // This is only true if `c1 < 6` is not nullable / can not be null.
+        let expected = col("c2_non_null").gt(lit(5));
+
+        assert_eq!(simplify(expr), expected);
+
+        // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 -- TODO: the call below still builds `l AND r`; swap the operands
+        let expr = binary_expr(l, Operator::And, r);
+
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -921,7 +903,7 @@ mod tests {
         let expr = binary_expr(lit_null(), Operator::And, lit(false));
         let expr_eq = lit(false);
 
-        assert_eq!(simplify(&expr), expr_eq);
+        assert_eq!(simplify(expr), expr_eq);
     }
 
     #[test]
@@ -930,16 +912,16 @@ mod tests {
         let expr_plus = binary_expr(null.clone(), Operator::Divide, 
null.clone());
         let expr_eq = null;
 
-        assert_eq!(simplify(&expr_plus), expr_eq);
+        assert_eq!(simplify(expr_plus), expr_eq);
     }
 
     #[test]
-    fn test_simplify_do_not_simplify_arithmetic_expr() {
+    fn test_simplify_simplify_arithmetic_expr() {
         let expr_plus = binary_expr(lit(1), Operator::Plus, lit(1));
         let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1));
 
-        assert_eq!(simplify(&expr_plus), expr_plus);
-        assert_eq!(simplify(&expr_eq), expr_eq);
+        assert_eq!(simplify(expr_plus), lit(2));
+        assert_eq!(simplify(expr_eq), lit(true));
     }
 
     // ------------------------------
@@ -1182,11 +1164,17 @@ mod tests {
     // ----- Simplifier tests -------
     // ------------------------------
 
-    // TODO rename to simplify
-    fn do_simplify(expr: Expr) -> Expr {
+    fn simplify(expr: Expr) -> Expr {
         let schema = expr_test_schema();
         let mut rewriter = Simplifier::new(vec![&schema]);
-        expr.rewrite(&mut rewriter).expect("expected to simplify")
+
+        let execution_props = ExecutionProps::new();
+        let mut const_evaluator = ConstEvaluator::new(&execution_props);
+
+        expr.rewrite(&mut rewriter)
+            .expect("expected to simplify")
+            .rewrite(&mut const_evaluator)
+            .expect("expected to const evaluate")
     }
 
     fn expr_test_schema() -> DFSchemaRef {
@@ -1194,6 +1182,8 @@ mod tests {
             DFSchema::new(vec![
                 DFField::new(None, "c1", DataType::Utf8, true),
                 DFField::new(None, "c2", DataType::Boolean, true),
+                DFField::new(None, "c1_non_null", DataType::Utf8, false),
+                DFField::new(None, "c2_non_null", DataType::Boolean, false),
             ])
             .unwrap(),
         )
@@ -1201,20 +1191,20 @@ mod tests {
 
     #[test]
     fn simplify_expr_not_not() {
-        assert_eq!(do_simplify(col("c2").not().not().not()), col("c2").not(),);
+        assert_eq!(simplify(col("c2").not().not().not()), col("c2").not(),);
     }
 
     #[test]
     fn simplify_expr_null_comparison() {
         // x = null is always null
         assert_eq!(
-            do_simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
+            simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))),
             lit(ScalarValue::Boolean(None)),
         );
 
         // null != null is always null
         assert_eq!(
-            do_simplify(
+            simplify(
                 
lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None)))
             ),
             lit(ScalarValue::Boolean(None)),
@@ -1222,13 +1212,13 @@ mod tests {
 
         // x != null is always null
         assert_eq!(
-            do_simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
+            simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))),
             lit(ScalarValue::Boolean(None)),
         );
 
         // null = x is always null
         assert_eq!(
-            do_simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
+            simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))),
             lit(ScalarValue::Boolean(None)),
         );
     }
@@ -1239,16 +1229,16 @@ mod tests {
         assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
 
         // true = true -> true
-        assert_eq!(do_simplify(lit(true).eq(lit(true))), lit(true));
+        assert_eq!(simplify(lit(true).eq(lit(true))), lit(true));
 
         // true = false -> false
-        assert_eq!(do_simplify(lit(true).eq(lit(false))), lit(false),);
+        assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),);
 
         // c2 = true -> c2
-        assert_eq!(do_simplify(col("c2").eq(lit(true))), col("c2"));
+        assert_eq!(simplify(col("c2").eq(lit(true))), col("c2"));
 
         // c2 = false => !c2
-        assert_eq!(do_simplify(col("c2").eq(lit(false))), col("c2").not(),);
+        assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),);
     }
 
     #[test]
@@ -1262,25 +1252,8 @@ mod tests {
         // Make sure c1 column to be used in tests is not boolean type
         assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
 
-        // don't fold c1 = true
-        assert_eq!(
-            do_simplify(col("c1").eq(lit(true))),
-            col("c1").eq(lit(true)),
-        );
-
-        // don't fold c1 = false
-        assert_eq!(
-            do_simplify(col("c1").eq(lit(false))),
-            col("c1").eq(lit(false)),
-        );
-
-        // test constant operands
-        assert_eq!(do_simplify(lit(1).eq(lit(true))), lit(1).eq(lit(true)),);
-
-        assert_eq!(
-            do_simplify(lit("a").eq(lit(false))),
-            lit("a").eq(lit(false)),
-        );
+        // don't fold c1 = foo
+        assert_eq!(simplify(col("c1").eq(lit("foo"))), 
col("c1").eq(lit("foo")),);
     }
 
     #[test]
@@ -1290,15 +1263,15 @@ mod tests {
         assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean);
 
         // c2 != true -> !c2
-        assert_eq!(do_simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
+        assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),);
 
         // c2 != false -> c2
-        assert_eq!(do_simplify(col("c2").not_eq(lit(false))), col("c2"),);
+        assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),);
 
         // test constant
-        assert_eq!(do_simplify(lit(true).not_eq(lit(true))), lit(false),);
+        assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),);
 
-        assert_eq!(do_simplify(lit(true).not_eq(lit(false))), lit(true),);
+        assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),);
     }
 
     #[test]
@@ -1311,44 +1284,25 @@ mod tests {
         assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8);
 
         assert_eq!(
-            do_simplify(col("c1").not_eq(lit(true))),
-            col("c1").not_eq(lit(true)),
-        );
-
-        assert_eq!(
-            do_simplify(col("c1").not_eq(lit(false))),
-            col("c1").not_eq(lit(false)),
-        );
-
-        // test constants
-        assert_eq!(
-            do_simplify(lit(1).not_eq(lit(true))),
-            lit(1).not_eq(lit(true)),
-        );
-
-        assert_eq!(
-            do_simplify(lit("a").not_eq(lit(false))),
-            lit("a").not_eq(lit(false)),
+            simplify(col("c1").not_eq(lit("foo"))),
+            col("c1").not_eq(lit("foo")),
         );
     }
 
     #[test]
     fn simplify_expr_case_when_then_else() {
         assert_eq!(
-            do_simplify(Expr::Case {
+            simplify(Expr::Case {
                 expr: None,
                 when_then_expr: vec![(
                     Box::new(col("c2").not_eq(lit(false))),
-                    Box::new(lit("ok").eq(lit(true))),
+                    Box::new(lit("ok").eq(lit("not_ok"))),
                 )],
                 else_expr: Some(Box::new(col("c2").eq(lit(true)))),
             }),
             Expr::Case {
                 expr: None,
-                when_then_expr: vec![(
-                    Box::new(col("c2")),
-                    Box::new(lit("ok").eq(lit(true)))
-                )],
+                when_then_expr: vec![(Box::new(col("c2")), 
Box::new(lit(false)))],
                 else_expr: Some(Box::new(col("c2"))),
             }
         );
@@ -1362,22 +1316,22 @@ mod tests {
     #[test]
     fn simplify_expr_bool_or() {
         // col || true is always true
-        assert_eq!(do_simplify(col("c2").or(lit(true))), lit(true),);
+        assert_eq!(simplify(col("c2").or(lit(true))), lit(true),);
 
         // col || false is always col
-        assert_eq!(do_simplify(col("c2").or(lit(false))), col("c2"),);
+        assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),);
 
         // true || null is always true
-        assert_eq!(do_simplify(lit(true).or(lit_null())), lit(true),);
+        assert_eq!(simplify(lit(true).or(lit_null())), lit(true),);
 
         // null || true is always true
-        assert_eq!(do_simplify(lit_null().or(lit(true))), lit(true),);
+        assert_eq!(simplify(lit_null().or(lit(true))), lit(true),);
 
         // false || null is always null
-        assert_eq!(do_simplify(lit(false).or(lit_null())), lit_null(),);
+        assert_eq!(simplify(lit(false).or(lit_null())), lit_null(),);
 
         // null || false is always null
-        assert_eq!(do_simplify(lit_null().or(lit(false))), lit_null(),);
+        assert_eq!(simplify(lit_null().or(lit(false))), lit_null(),);
 
         // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL)
         // it can be either NULL or  TRUE depending on the value of `c1 
BETWEEN Int32(0) AND Int32(10)`
@@ -1389,28 +1343,28 @@ mod tests {
             high: Box::new(lit(10)),
         };
         let expr = expr.or(lit_null());
-        let result = do_simplify(expr.clone());
+        let result = simplify(expr.clone());
         assert_eq!(expr, result);
     }
 
     #[test]
     fn simplify_expr_bool_and() {
         // col & true is always col
-        assert_eq!(do_simplify(col("c2").and(lit(true))), col("c2"),);
+        assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),);
         // col & false is always false
-        assert_eq!(do_simplify(col("c2").and(lit(false))), lit(false),);
+        assert_eq!(simplify(col("c2").and(lit(false))), lit(false),);
 
         // true && null is always null
-        assert_eq!(do_simplify(lit(true).and(lit_null())), lit_null(),);
+        assert_eq!(simplify(lit(true).and(lit_null())), lit_null(),);
 
         // null && true is always null
-        assert_eq!(do_simplify(lit_null().and(lit(true))), lit_null(),);
+        assert_eq!(simplify(lit_null().and(lit(true))), lit_null(),);
 
         // false && null is always false
-        assert_eq!(do_simplify(lit(false).and(lit_null())), lit(false),);
+        assert_eq!(simplify(lit(false).and(lit_null())), lit(false),);
 
         // null && false is always false
-        assert_eq!(do_simplify(lit_null().and(lit(false))), lit(false),);
+        assert_eq!(simplify(lit_null().and(lit(false))), lit(false),);
 
         // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL)
         // it can be either NULL or FALSE depending on the value of `c1 
BETWEEN Int32(0) AND Int32(10)`
@@ -1422,7 +1376,7 @@ mod tests {
             high: Box::new(lit(10)),
         };
         let expr = expr.and(lit_null());
-        let result = do_simplify(expr.clone());
+        let result = simplify(expr.clone());
         assert_eq!(expr, result);
     }
 
@@ -1473,12 +1427,12 @@ mod tests {
         );
     }
 
-    // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6)
     #[test]
     fn test_simplify_optimized_plan_with_composed_and() {
         let table_scan = test_table_scan();
+        // ((a > 5) AND (b < 6)) AND (a > 5) --> (a > 5) AND (b < 6)
         let plan = LogicalPlanBuilder::from(table_scan)
-            .project(vec![col("a")])
+            .project(vec![col("a"), col("b")])
             .unwrap()
             .filter(and(
                 and(col("a").gt(lit(5)), col("b").lt(lit(6))),
@@ -1492,7 +1446,7 @@ mod tests {
             &plan,
             "\
             Filter: #test.a > Int32(5) AND #test.b < Int32(6) AS test.a > 
Int32(5) AND test.b < Int32(6) AND test.a > Int32(5)\
-            \n  Projection: #test.a\
+            \n  Projection: #test.a, #test.b\
                \n    TableScan: test projection=None",
         );
     }

Reply via email to