This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5909866bba fix: volatile expressions should not be target of common
subexpt elimination (#8520)
5909866bba is described below
commit 5909866bba3e23e5f807972b84de526a4eb16c4c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Dec 14 00:53:56 2023 -0800
fix: volatile expressions should not be target of common subexpt
elimination (#8520)
* fix: volatile expressions should not be target of common subexpt
elimination
* Fix clippy
* For review
* Return error for unresolved scalar function
* Improve error message
---
datafusion/expr/src/expr.rs | 75 +++++++++++++++++++++-
.../optimizer/src/common_subexpr_eliminate.rs | 18 ++++--
datafusion/sqllogictest/test_files/functions.slt | 6 ++
3 files changed, 91 insertions(+), 8 deletions(-)
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 958f4f4a34..f0aab95b8f 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
}
}
+
+ /// Whether this function is volatile, i.e. whether it can return different results
+ /// when evaluated multiple times with the same input.
+ pub fn is_volatile(&self) -> Result<bool> {
+ match self {
+ ScalarFunctionDefinition::BuiltIn(fun) => {
+ Ok(fun.volatility() == crate::Volatility::Volatile)
+ }
+ ScalarFunctionDefinition::UDF(udf) => {
+ Ok(udf.signature().volatility == crate::Volatility::Volatile)
+ }
+ ScalarFunctionDefinition::Name(func) => {
+ internal_err!(
+ "Cannot determine volatility of unresolved function: {func}"
+ )
+ }
+ }
+ }
}
impl ScalarFunction {
@@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
.join(", "))
}
+/// Whether the given expression is volatile, i.e. whether it can return different results
+/// when evaluated multiple times with the same input.
+pub fn is_volatile(expr: &Expr) -> Result<bool> {
+ match expr {
+ Expr::ScalarFunction(func) => func.func_def.is_volatile(),
+ _ => Ok(false),
+ }
+}
+
#[cfg(test)]
mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
- use crate::{case, lit, Expr};
+ use crate::{
+ case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
+ ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
+ Volatility,
+ };
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
+ use std::sync::Arc;
#[test]
fn format_case_when() -> Result<()> {
@@ -1800,4 +1832,45 @@ mod test {
"UInt32(1) OR UInt32(2)"
);
}
+
+ #[test]
+ fn test_is_volatile_scalar_func_definition() {
+ // BuiltIn
+ assert!(
+ ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
+ .is_volatile()
+ .unwrap()
+ );
+ assert!(
+ !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
+ .is_volatile()
+ .unwrap()
+ );
+
+ // UDF
+ let return_type: ReturnTypeFunction =
+ Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
+ let fun: ScalarFunctionImplementation =
+ Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
+ let udf = Arc::new(ScalarUDF::new(
+ "TestScalarUDF",
+ &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
+ &return_type,
+ &fun,
+ ));
+ assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+
+ let udf = Arc::new(ScalarUDF::new(
+ "TestScalarUDF",
+ &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
+ &return_type,
+ &fun,
+ ));
+ assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+
+ // Unresolved function
+ ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
+ .is_volatile()
+ .expect_err("Shouldn't determine volatility of unresolved function");
+ }
}
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 1d21407a69..1e089257c6 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
use datafusion_common::{
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
};
-use datafusion_expr::expr::Alias;
+use datafusion_expr::expr::{is_volatile, Alias};
use datafusion_expr::logical_plan::{
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
};
@@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
let Projection { expr, input, .. } = projection;
let input_schema = Arc::clone(input.schema());
let mut expr_set = ExprSet::new();
+
+ // Visit expr list and build expr identifier to occuring count map (`expr_set`).
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
let (mut new_expr, new_input) =
@@ -516,7 +518,7 @@ enum ExprMask {
}
impl ExprMask {
- fn ignores(&self, expr: &Expr) -> bool {
+ fn ignores(&self, expr: &Expr) -> Result<bool> {
let is_normal_minus_aggregates = matches!(
expr,
Expr::Literal(..)
@@ -527,12 +529,14 @@ impl ExprMask {
| Expr::Wildcard { .. }
);
+ let is_volatile = is_volatile(expr)?;
+
let is_aggr = matches!(expr, Expr::AggregateFunction(..));
- match self {
- Self::Normal => is_normal_minus_aggregates || is_aggr,
- Self::NormalAndAggregates => is_normal_minus_aggregates,
- }
+ Ok(match self {
+ Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
+ Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
+ })
}
}
@@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
let (idx, sub_expr_desc) = self.pop_enter_mark();
// skip exprs should not be recognize.
- if self.expr_mask.ignores(expr) {
+ if self.expr_mask.ignores(expr)? {
self.id_array[idx].0 = self.series_number;
let desc = Self::desc_expr(expr);
self.visit_stack.push(VisitRecord::ExprItem(desc));
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 4f55ea316b..1903088b07 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -995,3 +995,9 @@ query ?
SELECT find_in_set(NULL, NULL)
----
NULL
+
+# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
+query B
+SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
+----
+false