This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 78ca43c9a6 Don't error in simplify_expressions rule (#8957)
78ca43c9a6 is described below
commit 78ca43c9a6db5e72850573eacf06bb4aedad4561
Author: Huaijin <[email protected]>
AuthorDate: Wed Jan 24 19:30:37 2024 +0800
Don't error in simplify_expressions rule (#8957)
* Don't error in simplify_expressions rule
* apply reviews
* apply reviews
---
.../src/simplify_expressions/expr_simplifier.rs | 153 ++++++++++++---------
.../src/simplify_expressions/simplify_exprs.rs | 63 ---------
datafusion/physical-expr/src/expressions/case.rs | 10 ++
.../sqllogictest/test_files/arrow_typeof.slt | 4 +-
datafusion/sqllogictest/test_files/math.slt | 38 ++---
datafusion/sqllogictest/test_files/scalar.slt | 4 +-
datafusion/sqllogictest/test_files/select.slt | 11 ++
7 files changed, 129 insertions(+), 154 deletions(-)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 35450b1f32..95536e9fc5 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -35,11 +35,10 @@ use arrow::{
};
use datafusion_common::{
cast::{as_large_list_array, as_list_array},
- plan_err,
tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{
- exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue,
+ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue,
Expr, Like,
@@ -253,6 +252,14 @@ struct ConstEvaluator<'a> {
input_batch: RecordBatch,
}
+/// The simplify result of ConstEvaluator
+enum ConstSimplifyResult {
+ // Expr was simplified and contains the resulting constant value
+ Simplified(ScalarValue),
+ // Evaluation encountered an error; contains the error and the original expression
+ SimplifyRuntimeError(DataFusionError, Expr),
+}
+
impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
type N = Expr;
@@ -285,7 +292,17 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match self.can_evaluate.pop() {
- Some(true) => Ok(Expr::Literal(self.evaluate_to_scalar(expr)?)),
+ // Certain expressions such as `CASE` and `COALESCE` are short
circuiting
+ // and may not evaluate all their sub expressions. Thus, if
+ // any error is encountered during simplification, return the
original
+ // expression so that normal evaluation can occur
+ Some(true) => {
+ let result = self.evaluate_to_scalar(expr);
+ match result {
+ ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)),
+ ConstSimplifyResult::SimplifyRuntimeError(_, expr) =>
Ok(expr),
+ }
+ }
Some(false) => Ok(expr),
_ => internal_err!("Failed to pop can_evaluate"),
}
@@ -380,29 +397,40 @@ impl<'a> ConstEvaluator<'a> {
}
/// Internal helper to evaluates an Expr
- pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) ->
Result<ScalarValue> {
+ pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) ->
ConstSimplifyResult {
if let Expr::Literal(s) = expr {
- return Ok(s);
+ return ConstSimplifyResult::Simplified(s);
}
let phys_expr =
- create_physical_expr(&expr, &self.input_schema,
self.execution_props)?;
- let col_val = phys_expr.evaluate(&self.input_batch)?;
+ match create_physical_expr(&expr, &self.input_schema,
self.execution_props) {
+ Ok(e) => e,
+ Err(err) => return
ConstSimplifyResult::SimplifyRuntimeError(err, expr),
+ };
+ let col_val = match phys_expr.evaluate(&self.input_batch) {
+ Ok(v) => v,
+ Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err,
expr),
+ };
match col_val {
ColumnarValue::Array(a) => {
if a.len() != 1 {
- exec_err!(
- "Could not evaluate the expression, found a result of
length {}",
- a.len()
+ ConstSimplifyResult::SimplifyRuntimeError(
+ DataFusionError::Execution(format!("Could not evaluate
the expression, found a result of length {}", a.len())),
+ expr,
)
} else if as_list_array(&a).is_ok() ||
as_large_list_array(&a).is_ok() {
- Ok(ScalarValue::List(a.as_list().to_owned().into()))
+ ConstSimplifyResult::Simplified(ScalarValue::List(
+ a.as_list().to_owned().into(),
+ ))
} else {
// Non-ListArray
- ScalarValue::try_from_array(&a, 0)
+ match ScalarValue::try_from_array(&a, 0) {
+ Ok(s) => ConstSimplifyResult::Simplified(s),
+ Err(err) =>
ConstSimplifyResult::SimplifyRuntimeError(err, expr),
+ }
}
}
- ColumnarValue::Scalar(s) => Ok(s),
+ ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s),
}
}
}
@@ -800,18 +828,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
op: Divide,
right,
}) if is_null(&right) => *right,
- // A / 0 -> Divide by zero error if A is not null and not floating
- // (float / 0 -> inf | -inf | NAN)
- Expr::BinaryExpr(BinaryExpr {
- left,
- op: Divide,
- right,
- }) if !info.nullable(&left)?
- && !info.get_data_type(&left)?.is_floating()
- && is_zero(&right) =>
- {
- return plan_err!("Divide by zero");
- }
//
// Rules for Modulo
@@ -840,21 +856,6 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
{
lit(0)
}
- // A % 0 --> Divide by zero Error (if A is not floating and not
null)
- // A % 0 --> NAN (if A is floating and not null)
- Expr::BinaryExpr(BinaryExpr {
- left,
- op: Modulo,
- right,
- }) if !info.nullable(&left)? && is_zero(&right) => {
- match info.get_data_type(&left)? {
- DataType::Float32 => lit(f32::NAN),
- DataType::Float64 => lit(f64::NAN),
- _ => {
- return plan_err!("Divide by zero");
- }
- }
- }
//
// Rules for BitwiseAnd
@@ -1321,9 +1322,7 @@ mod tests {
array::{ArrayRef, Int32Array},
datatypes::{DataType, Field, Schema},
};
- use datafusion_common::{
- assert_contains, cast::as_int32_array, plan_datafusion_err, DFField,
ToDFSchema,
- };
+ use datafusion_common::{assert_contains, cast::as_int32_array, DFField,
ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::execution_props::ExecutionProps;
@@ -1796,27 +1795,6 @@ mod tests {
assert_eq!(simplify(expr), expected);
}
- #[test]
- fn test_simplify_divide_zero_by_zero() {
- // 0 / 0 -> Divide by zero
- let expr = lit(0) / lit(0);
- let err = try_simplify(expr).unwrap_err();
-
- let _expected = plan_datafusion_err!("Divide by zero");
-
- assert!(matches!(err, ref _expected), "{err}");
- }
-
- #[test]
- fn test_simplify_divide_by_zero() {
- // A / 0 -> DivideByZeroError
- let expr = col("c2_non_null") / lit(0);
- assert_eq!(
- try_simplify(expr).unwrap_err().strip_backtrace(),
- "Error during planning: Divide by zero"
- );
- }
-
#[test]
fn test_simplify_modulo_by_null() {
let null = lit(ScalarValue::Null);
@@ -1841,6 +1819,26 @@ mod tests {
assert_eq!(simplify(expr), expected);
}
+ #[test]
+ fn test_simplify_divide_zero_by_zero() {
+ // because divide by 0 may occur inside a short-circuit expression,
+ // we should not simplify this and should raise the error at runtime
+ let expr = lit(0) / lit(0);
+ let expected = expr.clone();
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_simplify_divide_by_zero() {
+ // because divide by 0 may occur inside a short-circuit expression,
+ // we should not simplify this and should raise the error at runtime
+ let expr = col("c2_non_null") / lit(0);
+ let expected = expr.clone();
+
+ assert_eq!(simplify(expr), expected);
+ }
+
#[test]
fn test_simplify_modulo_by_one_non_null() {
let expr = col("c2_non_null") % lit(1);
@@ -2235,11 +2233,12 @@ mod tests {
#[test]
fn test_simplify_modulo_by_zero_non_null() {
+ // because modulo by 0 may occur inside a short-circuit expression,
+ // we should not simplify this and should raise the error at runtime.
let expr = col("c2_non_null") % lit(0);
- assert_eq!(
- try_simplify(expr).unwrap_err().strip_backtrace(),
- "Error during planning: Divide by zero"
- );
+ let expected = expr.clone();
+
+ assert_eq!(simplify(expr), expected);
}
#[test]
@@ -3496,4 +3495,22 @@ mod tests {
let output = simplify_with_guarantee(expr.clone(), guarantees);
assert_eq!(&output, &expr_x);
}
+
+ #[test]
+ fn test_expression_partial_simplify_1() {
+ // (1 + 2) + (4 / 0) -> 3 + (4 / 0)
+ let expr = (lit(1) + lit(2)) + (lit(4) / lit(0));
+ let expected = (lit(3)) + (lit(4) / lit(0));
+
+ assert_eq!(simplify(expr), expected);
+ }
+
+ #[test]
+ fn test_expression_partial_simplify_2() {
+ // (1 > 2) and (4 / 0) -> false
+ let expr = (lit(1).gt(lit(2))).and(lit(4) / lit(0));
+ let expected = lit(false);
+
+ assert_eq!(simplify(expr), expected);
+ }
}
diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
index 43a41b1185..cfd02547b8 100644
--- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
+++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs
@@ -138,28 +138,6 @@ mod tests {
ExprSchemable, JoinType,
};
- /// A macro to assert that one string is contained within another with
- /// a nice error message if they are not.
- ///
- /// Usage: `assert_contains!(actual, expected)`
- ///
- /// Is a macro so test error
- /// messages are on the same line as the failure;
- ///
- /// Both arguments must be convertable into Strings (Into<String>)
- macro_rules! assert_contains {
- ($ACTUAL: expr, $EXPECTED: expr) => {
- let actual_value: String = $ACTUAL.into();
- let expected_value: String = $EXPECTED.into();
- assert!(
- actual_value.contains(&expected_value),
- "Can not find expected in
actual.\n\nExpected:\n{}\n\nActual:\n{}",
- expected_value,
- actual_value
- );
- };
- }
-
fn test_table_scan() -> LogicalPlan {
let schema = Schema::new(vec![
Field::new("a", DataType::Boolean, false),
@@ -425,18 +403,6 @@ mod tests {
assert_optimized_plan_eq(&plan, expected)
}
- // expect optimizing will result in an error, returning the error string
- fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime<Utc>)
-> String {
- let config =
OptimizerContext::new().with_query_execution_start_time(*date_time);
- let rule = SimplifyExpressions::new();
-
- let err = rule
- .try_optimize(plan, &config)
- .expect_err("expected optimization to fail");
-
- err.to_string()
- }
-
fn get_optimized_plan_formatted(
plan: &LogicalPlan,
date_time: &DateTime<Utc>,
@@ -468,21 +434,6 @@ mod tests {
Ok(())
}
- #[test]
- fn to_timestamp_expr_wrong_arg() -> Result<()> {
- let table_scan = test_table_scan();
- let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")];
- let plan = LogicalPlanBuilder::from(table_scan)
- .project(proj)?
- .build()?;
-
- let expected =
- "Error parsing timestamp from 'I'M NOT A TIMESTAMP': error parsing
date";
- let actual = get_optimized_plan_err(&plan, &Utc::now());
- assert_contains!(actual, expected);
- Ok(())
- }
-
#[test]
fn cast_expr() -> Result<()> {
let table_scan = test_table_scan();
@@ -498,20 +449,6 @@ mod tests {
Ok(())
}
- #[test]
- fn cast_expr_wrong_arg() -> Result<()> {
- let table_scan = test_table_scan();
- let proj = vec![Expr::Cast(Cast::new(Box::new(lit("")),
DataType::Int32))];
- let plan = LogicalPlanBuilder::from(table_scan)
- .project(proj)?
- .build()?;
-
- let expected = "Cannot cast string '' to value of Int32 type";
- let actual = get_optimized_plan_err(&plan, &Utc::now());
- assert_contains!(actual, expected);
- Ok(())
- }
-
#[test]
fn multiple_now_expr() -> Result<()> {
let table_scan = test_table_scan();
diff --git a/datafusion/physical-expr/src/expressions/case.rs
b/datafusion/physical-expr/src/expressions/case.rs
index 414ddd0921..6a168e2f1e 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -148,6 +148,11 @@ impl CaseExpr {
// Make sure we only consider rows that have not been matched yet
let when_match = and(&when_match, &remainder)?;
+ // When no remaining rows match the WHEN clause, skip evaluating its THEN clause
+ if when_match.true_count() == 0 {
+ continue;
+ }
+
let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_match)?;
@@ -214,6 +219,11 @@ impl CaseExpr {
// Make sure we only consider rows that have not been matched yet
let when_value = and(&when_value, &remainder)?;
+ // When no remaining rows match the WHEN clause, skip evaluating its THEN clause
+ if when_value.true_count() == 0 {
+ continue;
+ }
+
let then_value = self.when_then_expr[i]
.1
.evaluate_selection(batch, &when_value)?;
diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt
b/datafusion/sqllogictest/test_files/arrow_typeof.slt
index afc28ecc39..8b3bd7eac9 100644
--- a/datafusion/sqllogictest/test_files/arrow_typeof.slt
+++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt
@@ -405,7 +405,7 @@ select arrow_cast([1], 'FixedSizeList(1, Int64)');
----
[1]
-query error DataFusion error: Optimizer rule 'simplify_expressions' failed
+query error DataFusion error: Arrow error: Cast error: Cannot cast to
FixedSizeList\(4\): value at index 0 has length 3
select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)');
query ?
@@ -421,4 +421,4 @@ FixedSizeList(Field { name: "item", data_type: Int64,
nullable: true, dict_id: 0
query ?
select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)');
----
-[1, 2, 3]
\ No newline at end of file
+[1, 2, 3]
diff --git a/datafusion/sqllogictest/test_files/math.slt
b/datafusion/sqllogictest/test_files/math.slt
index 0fa7ff9c20..5f3e1dd9ee 100644
--- a/datafusion/sqllogictest/test_files/math.slt
+++ b/datafusion/sqllogictest/test_files/math.slt
@@ -121,7 +121,7 @@ statement error DataFusion error: Error during planning: No
function matches the
SELECT abs(1, 2);
# abs: unsupported argument type
-statement error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nThis feature is not implemented: Unsupported data type Utf8
for function abs
+query error DataFusion error: This feature is not implemented: Unsupported
data type Utf8 for function abs
SELECT abs('foo');
@@ -293,52 +293,52 @@ select c1*0, c2*0, c3*0, c4*0, c5*0, c6*0, c7*0, c8*0
from test_non_nullable_int
----
0 0 0 0 0 0 0 0
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c1/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c2/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c3/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c4/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c5/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c6/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c7/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c8/0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c1%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c2%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c3%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c4%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c5%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c6%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c7%0 FROM test_non_nullable_integer
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c8%0 FROM test_non_nullable_integer
statement ok
@@ -556,10 +556,10 @@ SELECT c1*0 FROM test_non_nullable_decimal
----
0
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c1/0 FROM test_non_nullable_decimal
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nError during planning: Divide by zero
+query error DataFusion error: Arrow error: Divide by zero error
SELECT c1%0 FROM test_non_nullable_decimal
statement ok
diff --git a/datafusion/sqllogictest/test_files/scalar.slt
b/datafusion/sqllogictest/test_files/scalar.slt
index 9b30699e3f..3e8ebe54c0 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1527,7 +1527,7 @@ SELECT not(true), not(false)
----
false true
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nInternal error: NOT 'Literal \{ value: Int64\(1\) \}' can't
be evaluated because the expression's type is Int64, not boolean or NULL
+query error
SELECT not(1), not(0)
query ?B
@@ -1535,7 +1535,7 @@ SELECT null, not(null)
----
NULL NULL
-query error DataFusion error: Optimizer rule 'simplify_expressions'
failed\ncaused by\nInternal error: NOT 'Literal \{ value: Utf8\("hi"\) \}'
can't be evaluated because the expression's type is Utf8, not boolean or NULL
+query error
SELECT NOT('hi')
# test_negative_expressions()
diff --git a/datafusion/sqllogictest/test_files/select.slt
b/datafusion/sqllogictest/test_files/select.slt
index 9ffddc6e2d..faa5370c70 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1175,3 +1175,14 @@ SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 /
x from t;
statement ok
DROP TABLE t;
+
+query I
+SELECT CASE 1 WHEN 2 THEN 4 / 0 END;
+----
+NULL
+
+query error DataFusion error: Arrow error: Parser error: Error parsing
timestamp from 'I AM NOT A TIMESTAMP': error parsing date
+SELECT to_timestamp('I AM NOT A TIMESTAMP');
+
+query error DataFusion error: Arrow error: Cast error: Cannot cast string ''
to value of Int32 type
+SELECT CAST('' AS int);