This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5909866bba fix: volatile expressions should not be target of common subexpt elimination (#8520)
5909866bba is described below

commit 5909866bba3e23e5f807972b84de526a4eb16c4c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Dec 14 00:53:56 2023 -0800

    fix: volatile expressions should not be target of common subexpt elimination (#8520)
    
    * fix: volatile expressions should not be target of common subexpt elimination
    
    * Fix clippy
    
    * For review
    
    * Return error for unresolved scalar function
    
    * Improve error message
---
 datafusion/expr/src/expr.rs                        | 75 +++++++++++++++++++++-
 .../optimizer/src/common_subexpr_eliminate.rs      | 18 ++++--
 datafusion/sqllogictest/test_files/functions.slt   |  6 ++
 3 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 958f4f4a34..f0aab95b8f 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
             ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
         }
     }
+
+    /// Whether this function is volatile, i.e. whether it can return different results
+    /// when evaluated multiple times with the same input.
+    pub fn is_volatile(&self) -> Result<bool> {
+        match self {
+            ScalarFunctionDefinition::BuiltIn(fun) => {
+                Ok(fun.volatility() == crate::Volatility::Volatile)
+            }
+            ScalarFunctionDefinition::UDF(udf) => {
+                Ok(udf.signature().volatility == crate::Volatility::Volatile)
+            }
+            ScalarFunctionDefinition::Name(func) => {
+                internal_err!(
+                    "Cannot determine volatility of unresolved function: {func}"
+                )
+            }
+        }
+    }
 }
 
 impl ScalarFunction {
@@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
         .join(", "))
 }
 
+/// Whether the given expression is volatile, i.e. whether it can return different results
+/// when evaluated multiple times with the same input.
+pub fn is_volatile(expr: &Expr) -> Result<bool> {
+    match expr {
+        Expr::ScalarFunction(func) => func.func_def.is_volatile(),
+        _ => Ok(false),
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::expr::Cast;
     use crate::expr_fn::col;
-    use crate::{case, lit, Expr};
+    use crate::{
+        case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
+        ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
+        Volatility,
+    };
     use arrow::datatypes::DataType;
     use datafusion_common::Column;
     use datafusion_common::{Result, ScalarValue};
+    use std::sync::Arc;
 
     #[test]
     fn format_case_when() -> Result<()> {
@@ -1800,4 +1832,45 @@ mod test {
             "UInt32(1) OR UInt32(2)"
         );
     }
+
+    #[test]
+    fn test_is_volatile_scalar_func_definition() {
+        // BuiltIn
+        assert!(
+            ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
+                .is_volatile()
+                .unwrap()
+        );
+        assert!(
+            !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
+                .is_volatile()
+                .unwrap()
+        );
+
+        // UDF
+        let return_type: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
+        let fun: ScalarFunctionImplementation =
+            Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
+        let udf = Arc::new(ScalarUDF::new(
+            "TestScalarUDF",
+            &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
+            &return_type,
+            &fun,
+        ));
+        assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+
+        let udf = Arc::new(ScalarUDF::new(
+            "TestScalarUDF",
+            &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
+            &return_type,
+            &fun,
+        ));
+        assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+
+        // Unresolved function
+        ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
+            .is_volatile()
+            .expect_err("Shouldn't determine volatility of unresolved function");
+    }
 }
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 1d21407a69..1e089257c6 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
 use datafusion_common::{
     internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
 };
-use datafusion_expr::expr::Alias;
+use datafusion_expr::expr::{is_volatile, Alias};
 use datafusion_expr::logical_plan::{
     Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
 };
@@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
         let Projection { expr, input, .. } = projection;
         let input_schema = Arc::clone(input.schema());
         let mut expr_set = ExprSet::new();
+
+        // Visit expr list and build expr identifier to occuring count map (`expr_set`).
         let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
 
         let (mut new_expr, new_input) =
@@ -516,7 +518,7 @@ enum ExprMask {
 }
 
 impl ExprMask {
-    fn ignores(&self, expr: &Expr) -> bool {
+    fn ignores(&self, expr: &Expr) -> Result<bool> {
         let is_normal_minus_aggregates = matches!(
             expr,
             Expr::Literal(..)
@@ -527,12 +529,14 @@ impl ExprMask {
                 | Expr::Wildcard { .. }
         );
 
+        let is_volatile = is_volatile(expr)?;
+
         let is_aggr = matches!(expr, Expr::AggregateFunction(..));
 
-        match self {
-            Self::Normal => is_normal_minus_aggregates || is_aggr,
-            Self::NormalAndAggregates => is_normal_minus_aggregates,
-        }
+        Ok(match self {
+            Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
+            Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
+        })
     }
 }
 
@@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
 
         let (idx, sub_expr_desc) = self.pop_enter_mark();
         // skip exprs should not be recognize.
-        if self.expr_mask.ignores(expr) {
+        if self.expr_mask.ignores(expr)? {
             self.id_array[idx].0 = self.series_number;
             let desc = Self::desc_expr(expr);
             self.visit_stack.push(VisitRecord::ExprItem(desc));
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 4f55ea316b..1903088b07 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -995,3 +995,9 @@ query ?
 SELECT find_in_set(NULL, NULL)
 ----
 NULL
+
+# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
+query B
+SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
+----
+false

Reply via email to