This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 78ca43c9a6 Don't error in simplify_expressions rule (#8957)
78ca43c9a6 is described below

commit 78ca43c9a6db5e72850573eacf06bb4aedad4561
Author: Huaijin <[email protected]>
AuthorDate: Wed Jan 24 19:30:37 2024 +0800

    Don't error in simplify_expressions rule (#8957)
    
    * Don't error in simplify_expressions rule
    
    * apply reviews
    
    * apply reviers
---
 .../src/simplify_expressions/expr_simplifier.rs    | 153 ++++++++++++---------
 .../src/simplify_expressions/simplify_exprs.rs     |  63 ---------
 datafusion/physical-expr/src/expressions/case.rs   |  10 ++
 .../sqllogictest/test_files/arrow_typeof.slt       |   4 +-
 datafusion/sqllogictest/test_files/math.slt        |  38 ++---
 datafusion/sqllogictest/test_files/scalar.slt      |   4 +-
 datafusion/sqllogictest/test_files/select.slt      |  11 ++
 7 files changed, 129 insertions(+), 154 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 35450b1f32..95536e9fc5 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -35,11 +35,10 @@ use arrow::{
 };
 use datafusion_common::{
     cast::{as_large_list_array, as_list_array},
-    plan_err,
     tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter},
 };
 use datafusion_common::{
-    exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, 
ScalarValue,
+    internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::{
     and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, 
Expr, Like,
@@ -253,6 +252,14 @@ struct ConstEvaluator<'a> {
     input_batch: RecordBatch,
 }
 
+/// The result of attempting to simplify an expression with ConstEvaluator
+enum ConstSimplifyResult {
+    // Expr was simplified and contains the new expression
+    Simplified(ScalarValue),
+    // Evaluation encountered an error, contains the original expression
+    SimplifyRuntimeError(DataFusionError, Expr),
+}
+
 impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
     type N = Expr;
 
@@ -285,7 +292,17 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
 
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         match self.can_evaluate.pop() {
-            Some(true) => Ok(Expr::Literal(self.evaluate_to_scalar(expr)?)),
+            // Certain expressions such as `CASE` and `COALESCE` are short-circuiting
+            // and may not evaluate all of their subexpressions. Thus, if any error
+            // is encountered during simplification, return the original expression
+            // so that normal evaluation can occur.
+            Some(true) => {
+                let result = self.evaluate_to_scalar(expr);
+                match result {
+                    ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)),
+                    ConstSimplifyResult::SimplifyRuntimeError(_, expr) => 
Ok(expr),
+                }
+            }
             Some(false) => Ok(expr),
             _ => internal_err!("Failed to pop can_evaluate"),
         }
@@ -380,29 +397,40 @@ impl<'a> ConstEvaluator<'a> {
     }
 
     /// Internal helper to evaluates an Expr
-    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> 
Result<ScalarValue> {
+    pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> 
ConstSimplifyResult {
         if let Expr::Literal(s) = expr {
-            return Ok(s);
+            return ConstSimplifyResult::Simplified(s);
         }
 
         let phys_expr =
-            create_physical_expr(&expr, &self.input_schema, 
self.execution_props)?;
-        let col_val = phys_expr.evaluate(&self.input_batch)?;
+            match create_physical_expr(&expr, &self.input_schema, 
self.execution_props) {
+                Ok(e) => e,
+                Err(err) => return 
ConstSimplifyResult::SimplifyRuntimeError(err, expr),
+            };
+        let col_val = match phys_expr.evaluate(&self.input_batch) {
+            Ok(v) => v,
+            Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, 
expr),
+        };
         match col_val {
             ColumnarValue::Array(a) => {
                 if a.len() != 1 {
-                    exec_err!(
-                        "Could not evaluate the expression, found a result of 
length {}",
-                        a.len()
+                    ConstSimplifyResult::SimplifyRuntimeError(
+                        DataFusionError::Execution(format!("Could not evaluate 
the expression, found a result of length {}", a.len())),
+                        expr,
                     )
                 } else if as_list_array(&a).is_ok() || 
as_large_list_array(&a).is_ok() {
-                    Ok(ScalarValue::List(a.as_list().to_owned().into()))
+                    ConstSimplifyResult::Simplified(ScalarValue::List(
+                        a.as_list().to_owned().into(),
+                    ))
                 } else {
                     // Non-ListArray
-                    ScalarValue::try_from_array(&a, 0)
+                    match ScalarValue::try_from_array(&a, 0) {
+                        Ok(s) => ConstSimplifyResult::Simplified(s),
+                        Err(err) => 
ConstSimplifyResult::SimplifyRuntimeError(err, expr),
+                    }
                 }
             }
-            ColumnarValue::Scalar(s) => Ok(s),
+            ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s),
         }
     }
 }
@@ -800,18 +828,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
                 op: Divide,
                 right,
             }) if is_null(&right) => *right,
-            // A / 0 -> Divide by zero error if A is not null and not floating
-            // (float / 0 -> inf | -inf | NAN)
-            Expr::BinaryExpr(BinaryExpr {
-                left,
-                op: Divide,
-                right,
-            }) if !info.nullable(&left)?
-                && !info.get_data_type(&left)?.is_floating()
-                && is_zero(&right) =>
-            {
-                return plan_err!("Divide by zero");
-            }
 
             //
             // Rules for Modulo
@@ -840,21 +856,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> {
             {
                 lit(0)
             }
-            // A % 0 --> Divide by zero Error (if A is not floating and not 
null)
-            // A % 0 --> NAN (if A is floating and not null)
-            Expr::BinaryExpr(BinaryExpr {
-                left,
-                op: Modulo,
-                right,
-            }) if !info.nullable(&left)? && is_zero(&right) => {
-                match info.get_data_type(&left)? {
-                    DataType::Float32 => lit(f32::NAN),
-                    DataType::Float64 => lit(f64::NAN),
-                    _ => {
-                        return plan_err!("Divide by zero");
-                    }
-                }
-            }
 
             //
             // Rules for BitwiseAnd
@@ -1321,9 +1322,7 @@ mod tests {
         array::{ArrayRef, Int32Array},
         datatypes::{DataType, Field, Schema},
     };
-    use datafusion_common::{
-        assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, 
ToDFSchema,
-    };
+    use datafusion_common::{assert_contains, cast::as_int32_array, DFField, 
ToDFSchema};
     use datafusion_expr::{interval_arithmetic::Interval, *};
     use datafusion_physical_expr::execution_props::ExecutionProps;
 
@@ -1796,27 +1795,6 @@ mod tests {
         assert_eq!(simplify(expr), expected);
     }
 
-    #[test]
-    fn test_simplify_divide_zero_by_zero() {
-        // 0 / 0 -> Divide by zero
-        let expr = lit(0) / lit(0);
-        let err = try_simplify(expr).unwrap_err();
-
-        let _expected = plan_datafusion_err!("Divide by zero");
-
-        assert!(matches!(err, ref _expected), "{err}");
-    }
-
-    #[test]
-    fn test_simplify_divide_by_zero() {
-        // A / 0 -> DivideByZeroError
-        let expr = col("c2_non_null") / lit(0);
-        assert_eq!(
-            try_simplify(expr).unwrap_err().strip_backtrace(),
-            "Error during planning: Divide by zero"
-        );
-    }
-
     #[test]
     fn test_simplify_modulo_by_null() {
         let null = lit(ScalarValue::Null);
@@ -1841,6 +1819,26 @@ mod tests {
         assert_eq!(simplify(expr), expected);
     }
 
+    #[test]
+    fn test_simplify_divide_zero_by_zero() {
+        // Division by zero may occur inside a short-circuiting expression,
+        // so do not simplify it here; let the error be raised at runtime.
+        let expr = lit(0) / lit(0);
+        let expected = expr.clone();
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_simplify_divide_by_zero() {
+        // Division by zero may occur inside a short-circuiting expression,
+        // so do not simplify it here; let the error be raised at runtime.
+        let expr = col("c2_non_null") / lit(0);
+        let expected = expr.clone();
+
+        assert_eq!(simplify(expr), expected);
+    }
+
     #[test]
     fn test_simplify_modulo_by_one_non_null() {
         let expr = col("c2_non_null") % lit(1);
@@ -2235,11 +2233,12 @@ mod tests {
 
     #[test]
     fn test_simplify_modulo_by_zero_non_null() {
+        // Modulo by zero may occur inside a short-circuiting expression,
+        // so do not simplify it here; let the error be raised at runtime.
         let expr = col("c2_non_null") % lit(0);
-        assert_eq!(
-            try_simplify(expr).unwrap_err().strip_backtrace(),
-            "Error during planning: Divide by zero"
-        );
+        let expected = expr.clone();
+
+        assert_eq!(simplify(expr), expected);
     }
 
     #[test]
@@ -3496,4 +3495,22 @@ mod tests {
         let output = simplify_with_guarantee(expr.clone(), guarantees);
         assert_eq!(&output, &expr_x);
     }
+
+    #[test]
+    fn test_expression_partial_simplify_1() {
+        // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
+        let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
+        let expected = (lit(3)) + (lit(4) / lit(0));
+
+        assert_eq!(simplify(expr), expected);
+    }
+
+    #[test]
+    fn test_expression_partial_simplify_2() {
+        // (1 > 2) and (4 / 0) -> false
+        let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
+        let expected = lit(false);
+
+        assert_eq!(simplify(expr), expected);
+    }
 }
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs 
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 43a41b1185..cfd02547b8 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -138,28 +138,6 @@ mod tests {
         ExprSchemable, JoinType,
     };
 
-    /// A macro to assert that one string is contained within another with
-    /// a nice error message if they are not.
-    ///
-    /// Usage: `assert_contains!(actual, expected)`
-    ///
-    /// Is a macro so test error
-    /// messages are on the same line as the failure;
-    ///
-    /// Both arguments must be convertable into Strings (Into<String>)
-    macro_rules! assert_contains {
-        ($ACTUAL: expr, $EXPECTED: expr) => {
-            let actual_value: String = $ACTUAL.into();
-            let expected_value: String = $EXPECTED.into();
-            assert!(
-                actual_value.contains(&expected_value),
-                "Can not find expected in 
actual.\n\nExpected:\n{}\n\nActual:\n{}",
-                expected_value,
-                actual_value
-            );
-        };
-    }
-
     fn test_table_scan() -> LogicalPlan {
         let schema = Schema::new(vec![
             Field::new("a", DataType::Boolean, false),
@@ -425,18 +403,6 @@ mod tests {
         assert_optimized_plan_eq(&plan, expected)
     }
 
-    // expect optimizing will result in an error, returning the error string
-    fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>) 
-> String {
-        let config = 
OptimizerContext::new().with_query_execution_start_time(*date_time);
-        let rule = SimplifyExpressions::new();
-
-        let err = rule
-            .try_optimize(plan, &config)
-            .expect_err("expected optimization to fail");
-
-        err.to_string()
-    }
-
     fn get_optimized_plan_formatted(
         plan: &LogicalPlan,
         date_time: &DateTime<Utc>,
@@ -468,21 +434,6 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn to_timestamp_expr_wrong_arg() -> Result<()> {
-        let table_scan = test_table_scan();
-        let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")];
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .project(proj)?
-            .build()?;
-
-        let expected =
-            "Error parsing timestamp from 'I'M NOT A TIMESTAMP': error parsing 
date";
-        let actual = get_optimized_plan_err(&plan, &Utc::now());
-        assert_contains!(actual, expected);
-        Ok(())
-    }
-
     #[test]
     fn cast_expr() -> Result<()> {
         let table_scan = test_table_scan();
@@ -498,20 +449,6 @@ mod tests {
         Ok(())
     }
 
-    #[test]
-    fn cast_expr_wrong_arg() -> Result<()> {
-        let table_scan = test_table_scan();
-        let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")), 
DataType::Int32))];
-        let plan = LogicalPlanBuilder::from(table_scan)
-            .project(proj)?
-            .build()?;
-
-        let expected = "Cannot cast string '' to value of Int32 type";
-        let actual = get_optimized_plan_err(&plan, &Utc::now());
-        assert_contains!(actual, expected);
-        Ok(())
-    }
-
     #[test]
     fn multiple_now_expr() -> Result<()> {
         let table_scan = test_table_scan();
diff --git a/datafusion/physical-expr/src/expressions/case.rs 
b/datafusion/physical-expr/src/expressions/case.rs
index 414ddd0921..6a168e2f1e 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -148,6 +148,11 @@ impl CaseExpr {
             // Make sure we only consider rows that have not been matched yet
             let when_match = and(&when_match, &remainder)?;
 
+            // If no remaining rows match this WHEN clause, skip evaluating its THEN clause
+            if when_match.true_count() == 0 {
+                continue;
+            }
+
             let then_value = self.when_then_expr[i]
                 .1
                 .evaluate_selection(batch, &when_match)?;
@@ -214,6 +219,11 @@ impl CaseExpr {
             // Make sure we only consider rows that have not been matched yet
             let when_value = and(&when_value, &remainder)?;
 
+            // If no remaining rows match this WHEN clause, skip evaluating its THEN clause
+            if when_value.true_count() == 0 {
+                continue;
+            }
+
             let then_value = self.when_then_expr[i]
                 .1
                 .evaluate_selection(batch, &when_value)?;
diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt 
b/datafusion/sqllogictest/test_files/arrow_typeof.slt
index afc28ecc39..8b3bd7eac9 100644
--- a/datafusion/sqllogictest/test_files/arrow_typeof.slt
+++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt
@@ -405,7 +405,7 @@ select arrow_cast([1], 'FixedSizeList(1, Int64)');
 ----
 [1]
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' failed
+query error DataFusion error: Arrow error: Cast error: Cannot cast to 
FixedSizeList\(4\): value at index 0 has length 3
 select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)');
 
 query ?
@@ -421,4 +421,4 @@ FixedSizeList(Field { name: "item", data_type: Int64, 
nullable: true, dict_id: 0
 query ?
 select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)');
 ----
-[1, 2, 3]
\ No newline at end of file
+[1, 2, 3]
diff --git a/datafusion/sqllogictest/test_files/math.slt 
b/datafusion/sqllogictest/test_files/math.slt
index 0fa7ff9c20..5f3e1dd9ee 100644
--- a/datafusion/sqllogictest/test_files/math.slt
+++ b/datafusion/sqllogictest/test_files/math.slt
@@ -121,7 +121,7 @@ statement error DataFusion error: Error during planning: No 
function matches the
 SELECT abs(1, 2);
 
 # abs: unsupported argument type
-statement error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nThis feature is not implemented: Unsupported data type Utf8 
for function abs
+query error DataFusion error: This feature is not implemented: Unsupported 
data type Utf8 for function abs
 SELECT abs('foo');
 
 
@@ -293,52 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0 
from test_non_nullable_int
 ----
 0 0 0 0 0 0 0 0
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c1/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c2/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c3/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c4/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c5/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c6/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c7/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c8/0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c1%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c2%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c3%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c4%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c5%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c6%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c7%0 FROM test_non_nullable_integer
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c8%0 FROM test_non_nullable_integer
 
 statement ok
@@ -556,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal
 ----
 0
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c1/0 FROM test_non_nullable_decimal 
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
 SELECT c1%0 FROM test_non_nullable_decimal 
 
 statement ok
diff --git a/datafusion/sqllogictest/test_files/scalar.slt 
b/datafusion/sqllogictest/test_files/scalar.slt
index 9b30699e3f..3e8ebe54c0 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1527,7 +1527,7 @@ SELECT not(true), not(false)
 ----
 false true
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nInternal error: NOT 'Literal \{ value: Int64\(1\) \}' can't 
be evaluated because the expression's type is Int64, not boolean or NULL
+query error
 SELECT not(1), not(0)
 
 query ?B
@@ -1535,7 +1535,7 @@ SELECT null, not(null)
 ----
 NULL NULL
 
-query error DataFusion error: Optimizer rule 'simplify_expressions' 
failed\ncaused by\nInternal error: NOT 'Literal \{ value: Utf8\("hi"\) \}' 
can't be evaluated because the expression's type is Utf8, not boolean or NULL
+query error
 SELECT NOT('hi')
 
 # test_negative_expressions()
diff --git a/datafusion/sqllogictest/test_files/select.slt 
b/datafusion/sqllogictest/test_files/select.slt
index 9ffddc6e2d..faa5370c70 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1175,3 +1175,14 @@ SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / 
x from t;
 
 statement ok
 DROP TABLE t;
+
+query I
+SELECT CASE 1 WHEN 2 THEN 4 / 0 END;
+----
+NULL
+
+query error DataFusion error: Arrow error: Parser error: Error parsing 
timestamp from 'I AM NOT A TIMESTAMP': error parsing date
+SELECT to_timestamp('I AM NOT A TIMESTAMP');
+
+query error DataFusion error: Arrow error: Cast error: Cannot cast string '' 
to value of Int32 type
+SELECT CAST('' AS int);

Reply via email to