This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 6b76a35cad consider volatile function in simply_expression (#13128)
6b76a35cad is described below
commit 6b76a35cadb17b33a5140f6f67e0491aabaa409e
Author: Lordworms <[email protected]>
AuthorDate: Fri Nov 1 11:12:22 2024 -0700
consider volatile function in simply_expression (#13128)
* consider volatile function in simply_expression
* refactor and fix bugs
* fix clippy
* refactor
* refactor
* format
* fix clippy
* Resolve logical conflict
* simplify more
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../src/simplify_expressions/expr_simplifier.rs | 73 ++++++++++++++++++++--
.../optimizer/src/simplify_expressions/utils.rs | 13 ++--
2 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ce6734616b..40be1f8539 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -862,8 +862,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
right,
}) if has_common_conjunction(&left, &right) => {
let lhs: IndexSet<Expr> =
iter_conjunction_owned(*left).collect();
- let (common, rhs): (Vec<_>, Vec<_>) =
- iter_conjunction_owned(*right).partition(|e|
lhs.contains(e));
+ let (common, rhs): (Vec<_>, Vec<_>) =
iter_conjunction_owned(*right)
+ .partition(|e| lhs.contains(e) && !e.is_volatile());
let new_rhs = rhs.into_iter().reduce(and);
let new_lhs = lhs.into_iter().filter(|e|
!common.contains(e)).reduce(and);
@@ -1682,8 +1682,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
}
+/// Returns true if `lhs` and `rhs` share at least one conjunct (as yielded by
+/// `iter_conjunction`) that is not volatile.
+///
+/// Volatile conjuncts are never counted as shared: two structurally equal
+/// volatile expressions may evaluate to different values, so they cannot be
+/// factored out as a common term.
fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
- let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect();
- iter_conjunction(rhs).any(|e| lhs.contains(&e))
+ let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect();
+ iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile())
}
// TODO: We might not need this after defer pattern for Box is stabilized.
https://github.com/rust-lang/rust/issues/87121
@@ -3978,4 +3978,69 @@ mod tests {
unimplemented!("not needed for tests")
}
}
+ /// Test-only scalar UDF that takes no arguments and is declared with
+ /// `Volatility::Volatile`; used to exercise how the simplifier treats
+ /// volatile expressions.
+ #[derive(Debug)]
+ struct VolatileUdf {
+ signature: Signature,
+ }
+
+ impl VolatileUdf {
+ /// Builds the UDF with an exact, zero-argument, volatile signature.
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::exact(vec![], Volatility::Volatile),
+ }
+ }
+ }
+ impl ScalarUDFImpl for VolatileUdf {
+ fn as_any(&self) -> &dyn std::any::Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "VolatileUdf"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ // Arbitrary fixed return type; presumably only the declared volatility
+ // matters to the simplifier tests that use this UDF.
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Int16)
+ }
+ }
+ /// Checks that repeated volatile sub-expressions are neither absorbed nor
+ /// merged by `simplify`, while shared non-volatile terms still are.
+ #[test]
+ fn test_optimize_volatile_conditions() {
+ let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new()));
+ // Each occurrence of `rand` is a separate call to the volatile UDF.
+ let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![]));
+ {
+ // `rand = 0 OR (column1 = 2 AND rand = 0)`: the repeated term is
+ // volatile, so it must not be absorbed; the expression is unchanged.
+ let expr = rand
+ .clone()
+ .eq(lit(0))
+ .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
+
+ assert_eq!(simplify(expr.clone()), expr);
+ }
+
+ {
+ // `column1 = 2 OR (column1 = 2 AND rand = 0)`: the shared term is
+ // non-volatile, so absorption still reduces this to `column1 = 2`.
+ let expr = col("column1")
+ .eq(lit(2))
+ .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0))));
+
+ assert_eq!(simplify(expr), col("column1").eq(lit(2)));
+ }
+
+ {
+ // `(column1 = 2 AND rand = 0) OR (column1 = 2 AND rand = 0)`: only the
+ // non-volatile conjunct is factored out; the two volatile conjuncts
+ // remain as separate OR operands instead of being deduplicated.
+ let expr =
(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col(
+ "column1",
+ )
+ .eq(lit(2))
+ .and(rand.clone().eq(lit(0))));
+
+ assert_eq!(
+ simplify(expr),
+ col("column1")
+ .eq(lit(2))
+ .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0))))
+ );
+ }
+ }
}
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 38bfc1a934..c30c3631c1 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [
/// returns true if `needle` is found in a chain of search_op
/// expressions. Such as: (A AND B) AND C
+///
+/// This is a purely structural search and ignores volatility; the public
+/// `expr_contains` wrapper additionally returns false for volatile needles.
-pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
+fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) ->
bool {
match expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op
=> {
- expr_contains(left, needle, search_op)
- || expr_contains(right, needle, search_op)
+ expr_contains_inner(left, needle, search_op)
+ || expr_contains_inner(right, needle, search_op)
}
+ // Any other node (including a BinaryExpr with a different operator)
+ // ends the chain and is compared against `needle` as a whole.
_ => expr == needle,
}
}
+/// Returns true if `needle` is found in a chain of `search_op` expressions,
+/// such as `(A AND B) AND C`, and `needle` is not volatile.
+///
+/// A volatile `needle` always yields false: two structurally equal volatile
+/// expressions may evaluate to different values, so finding one must not be
+/// used to treat the other as redundant.
+pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
+ expr_contains_inner(expr, needle, search_op) && !needle.is_volatile()
+}
+
/// Deletes all 'needles' or remains one 'needle' that are found in a chain of
xor
/// expressions. Such as: A ^ (A ^ (B ^ A))
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool)
-> Expr {
@@ -206,7 +211,7 @@ pub fn is_false(expr: &Expr) -> bool {
/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
+/// where OP is `target_op`.
+///
+/// Always returns false when `needle` is volatile, because two structurally
+/// equal volatile expressions may evaluate to different values.
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool
{
- matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op
== &target_op && (needle == left.as_ref() || needle == right.as_ref()))
+ matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op
== &target_op && (needle == left.as_ref() || needle == right.as_ref()) &&
!needle.is_volatile())
}
/// returns true if `not_expr` is !`expr` (not)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]