This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c9935ae52e fix: common_subexpr_eliminate rule should not apply to short-circuit expression (#8928)
c9935ae52e is described below

commit c9935ae52ebdc54a3578d789e2c1c4cd29ba54bd
Author: Huaijin <[email protected]>
AuthorDate: Tue Jan 23 03:16:39 2024 +0800

    fix: common_subexpr_eliminate rule should not apply to short-circuit expression (#8928)
    
    * fix: common_subexpr_eliminate rule should not apply to short-circuit expression
    
    * add more tests
    
    * format
    
    * minor
    
    * apply reviews
    
    * add some commont
    
    * fmt
---
 datafusion/expr/src/expr.rs                        | 48 ++++++++++++++++++++++
 .../optimizer/src/common_subexpr_eliminate.rs      | 17 +++++---
 datafusion/sqllogictest/test_files/select.slt      | 44 ++++++++++++++++++++
 3 files changed, 103 insertions(+), 6 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 9aeebb190e..c5d158d876 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1266,6 +1266,54 @@ impl Expr {
             Ok(Transformed::Yes(expr))
         })
     }
+
+    /// Returns true if some of this expression's subexpressions may not be evaluated
+    /// and thus any side effects (like divide by zero) may not be encountered
+    pub fn short_circuits(&self) -> bool {
+        match self {
+            Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
+                matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if 
*fun == BuiltinScalarFunction::Coalesce)
+            }
+            Expr::BinaryExpr(BinaryExpr { op, .. }) => {
+                matches!(op, Operator::And | Operator::Or)
+            }
+            Expr::Case { .. } => true,
+            // Use explicit pattern match instead of a default
+            // implementation, so that in the future if someone adds
+            // new Expr types, they will check here as well
+            Expr::AggregateFunction(..)
+            | Expr::Alias(..)
+            | Expr::Between(..)
+            | Expr::Cast(..)
+            | Expr::Column(..)
+            | Expr::Exists(..)
+            | Expr::GetIndexedField(..)
+            | Expr::GroupingSet(..)
+            | Expr::InList(..)
+            | Expr::InSubquery(..)
+            | Expr::IsFalse(..)
+            | Expr::IsNotFalse(..)
+            | Expr::IsNotNull(..)
+            | Expr::IsNotTrue(..)
+            | Expr::IsNotUnknown(..)
+            | Expr::IsNull(..)
+            | Expr::IsTrue(..)
+            | Expr::IsUnknown(..)
+            | Expr::Like(..)
+            | Expr::ScalarSubquery(..)
+            | Expr::ScalarVariable(_, _)
+            | Expr::SimilarTo(..)
+            | Expr::Not(..)
+            | Expr::Negative(..)
+            | Expr::OuterReferenceColumn(_, _)
+            | Expr::TryCast(..)
+            | Expr::Wildcard { .. }
+            | Expr::WindowFunction(..)
+            | Expr::Literal(..)
+            | Expr::Sort(..)
+            | Expr::Placeholder(..) => false,
+        }
+    }
 }
 
 // modifies expr if it is a placeholder with datatype of right
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index f29c7406ac..fe71171ce5 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -616,8 +616,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
 
     fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
         // related to https://github.com/apache/arrow-datafusion/issues/8814
-        // If the expr contain volatile expression or is a case expression, skip it.
-        if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? {
+        // If the expr contains a volatile expression or is a short-circuit expression, skip it.
+        if expr.short_circuits() || is_volatile_expression(expr)? {
             return Ok(VisitRecursion::Skip);
         }
         self.visit_stack
@@ -696,7 +696,13 @@ struct CommonSubexprRewriter<'a> {
 impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
     type N = Expr;
 
-    fn pre_visit(&mut self, _: &Expr) -> Result<RewriteRecursion> {
+    fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
+        // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate
+        // the `id_array`, which records the expr's identifier used to rewrite expr. So if we
+        // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too.
+        if expr.short_circuits() || is_volatile_expression(expr)? {
+            return Ok(RewriteRecursion::Stop);
+        }
         if self.curr_index >= self.id_array.len()
             || self.max_series_number > self.id_array[self.curr_index].0
         {
@@ -1249,12 +1255,11 @@ mod test {
         let table_scan = test_table_scan()?;
 
         let plan = LogicalPlanBuilder::from(table_scan)
-            .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))?
+            .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))?
             .build()?;
 
         let expected = "Projection: test.a, test.b, test.c\
-        \n  Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND 
Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\
-        \n    Projection: Int32(1) > test.a AS Int32(1) > 
test.atest.aInt32(1), test.a, test.b, test.c\
+        \n  Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - 
Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n    
Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, 
test.b, test.c\
         \n      TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index ca48c07b09..9ffddc6e2d 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1129,5 +1129,49 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
 0 0
 0 0
 
+# Expressions that short-circuit should not be refactored out, as that may cause side effects (divide by zero)
+# at plan time that would not actually happen during execution, so common sub-expressions must not be extracted
+# from the following three queries
+query TT
+explain select coalesce(1, y/x), coalesce(2, y/x) from t;
+----
+logical_plan
+Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), 
CAST(t.y / t.x AS Int64))
+--TableScan: t projection=[x, y]
+physical_plan
+ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as 
coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as 
coalesce(Int64(2),t.y / t.x)]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
+----
+logical_plan
+Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y 
> Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND 
Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) 
AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x
+--TableScan: t projection=[x, y]
+physical_plan
+ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) 
AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 
1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y 
< Int64(1) / t.x]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+query TT
+EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;
+----
+logical_plan
+Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y 
= Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR 
Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) 
OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x
+--TableScan: t projection=[x, y]
+physical_plan
+ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) 
OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 
/ CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < 
Int64(1) / t.x]
+--MemoryExec: partitions=1, partition_sizes=[1]
+
+# Due to the reason described in https://github.com/apache/arrow-datafusion/issues/8927,
+# the following queries will fail
+query error
+select coalesce(1, y/x), coalesce(2, y/x) from t;
+
+query error
+SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t;
+
+query error
+SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t;
+
 statement ok
 DROP TABLE t;

Reply via email to