This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b76a35cad consider volatile function in simply_expression (#13128)
6b76a35cad is described below

commit 6b76a35cadb17b33a5140f6f67e0491aabaa409e
Author: Lordworms <[email protected]>
AuthorDate: Fri Nov 1 11:12:22 2024 -0700

    consider volatile function in simply_expression (#13128)
    
    * consider volatile function in simply_expression
    
    * refactor and fix bugs
    
    * fix clippy
    
    * refactor
    
    * refactor
    
    * format
    
    * fix clippy
    
    * Resolve logical conflict
    
    * simplify more
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 73 ++++++++++++++++++++--
 .../optimizer/src/simplify_expressions/utils.rs    | 13 ++--
 2 files changed, 78 insertions(+), 8 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ce6734616b..40be1f8539 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -862,8 +862,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 right,
             }) if has_common_conjunction(&left, &right) => {
                 let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
-                let (common, rhs): (Vec<_>, Vec<_>) =
-                    iter_conjunction_owned(*right).partition(|e| lhs.contains(e));
+                let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right)
+                    .partition(|e| lhs.contains(e) && !e.is_volatile());
 
                 let new_rhs = rhs.into_iter().reduce(and);
                 let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
@@ -1682,8 +1682,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
 }
 
 fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
-    let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect();
-    iter_conjunction(rhs).any(|e| lhs.contains(&e))
+    let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
+    iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
 }
 
 // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
@@ -3978,4 +3978,69 @@ mod tests {
             unimplemented!("not needed for tests")
         }
     }
+    #[derive(Debug)]
+    struct VolatileUdf {
+        signature: Signature,
+    }
+
+    impl VolatileUdf {
+        pub fn new() -> Self {
+            Self {
+                signature: Signature::exact(vec![], Volatility::Volatile),
+            }
+        }
+    }
+    impl ScalarUDFImpl for VolatileUdf {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "VolatileUdf"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Int16)
+        }
+    }
+    #[test]
+    fn test_optimize_volatile_conditions() {
+        let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
+        let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
+        {
+            let expr = rand
+                .clone()
+                .eq(lit(0))
+                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
+
+            assert_eq!(simplify(expr.clone()), expr);
+        }
+
+        {
+            let expr = col("column1")
+                .eq(lit(2))
+                .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
+
+            assert_eq!(simplify(expr), col("column1").eq(lit(2)));
+        }
+
+        {
+            let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
+                "column1",
+            )
+            .eq(lit(2))
+            .and(rand.clone().eq(lit(0))));
+
+            assert_eq!(
+                simplify(expr),
+                col("column1")
+                    .eq(lit(2))
+                    .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
+            );
+        }
+    }
 }
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 38bfc1a934..c30c3631c1 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [
 
 /// returns true if `needle` is found in a chain of search_op
 /// expressions. Such as: (A AND B) AND C
-pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
+fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
     match expr {
         Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
-            expr_contains(left, needle, search_op)
-                || expr_contains(right, needle, search_op)
+            expr_contains_inner(left, needle, search_op)
+                || expr_contains_inner(right, needle, search_op)
         }
         _ => expr == needle,
     }
 }
 
+/// check volatile calls and return if expr contains needle
+pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
+    expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
+}
+
 /// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor
 /// expressions. Such as: A ^ (A ^ (B ^ A))
 pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
@@ -206,7 +211,7 @@ pub fn is_false(expr: &Expr) -> bool {
 
 /// returns true if `haystack` looks like (needle OP X) or (X OP needle)
 pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
-    matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()))
+    matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile())
 }
 
 /// returns true if `not_expr` is !`expr` (not)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to