This is an automated email from the ASF dual-hosted git repository.

findepi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 070517a87f Derive UDF equality from PartialEq, Hash (#16842)
070517a87f is described below

commit 070517a87f15fc9e5f64423e29ee155d4a6a868d
Author: Piotr Findeisen <piotr.findei...@gmail.com>
AuthorDate: Fri Jul 25 21:21:44 2025 +0200

    Derive UDF equality from PartialEq, Hash (#16842)
    
    * Derive UDF equality from PartialEq, Hash
    
    Reduce boilerplate in cases where implementation of
    `{ScalarUDFImpl,AggregateUDFImpl,WindowUDFImpl}::{equals,hash_value}` can
    be derived using standard `PartialEq` and `Hash` traits.
    
    This is code complexity reduction.
    
    While valuable on its own, this also prepares for more automatic
    derivation of UDF equals/hash and/or removal of default implementations
    (which currently are error-prone).
    
    * udf_equals_hash example
    
    * test udf_equals_hash
    
    * empty: roll the dice 🎲
---
 .../user_defined/user_defined_scalar_functions.rs  | 171 +++++--------------
 datafusion/expr/src/async_udf.rs                   |  38 +++--
 datafusion/expr/src/expr_fn.rs                     |  67 ++++----
 datafusion/expr/src/udf.rs                         |  33 ++--
 datafusion/expr/src/utils.rs                       | 182 ++++++++++++++++++++-
 datafusion/ffi/src/udf/mod.rs                      |  66 ++++----
 datafusion/proto/tests/cases/mod.rs                |  32 +---
 datafusion/sql/tests/sql_integration.rs            |  38 +----
 8 files changed, 335 insertions(+), 292 deletions(-)

diff --git 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index c5f9bdeb69..dd8283613a 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -17,7 +17,7 @@
 
 use std::any::Any;
 use std::collections::HashMap;
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
 use std::sync::Arc;
 
 use arrow::array::{as_string_array, create_array, record_batch, Int8Array, 
UInt64Array};
@@ -43,9 +43,9 @@ use datafusion_common::{
 use datafusion_expr::expr::FieldMetadata;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
 use datafusion_expr::{
-    lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, 
CreateFunctionBody,
-    LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, 
ScalarFunctionArgs,
-    ScalarUDF, ScalarUDFImpl, Signature, Volatility,
+    lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, 
CreateFunction,
+    CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, 
ReturnFieldArgs,
+    ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_functions_nested::range::range_udf;
 use parking_lot::Mutex;
@@ -181,6 +181,7 @@ async fn scalar_udf() -> Result<()> {
     Ok(())
 }
 
+#[derive(PartialEq, Hash)]
 struct Simple0ArgsScalarUDF {
     name: String,
     signature: Signature,
@@ -218,33 +219,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
         Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        name == &other.name
-            && signature == &other.signature
-            && return_type == &other.return_type
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        return_type.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[tokio::test]
@@ -517,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> 
Result<()> {
 }
 
 /// Volatile UDF that should append a different value to each row
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
 struct AddIndexToStringVolatileScalarUDF {
     name: String,
     signature: Signature,
@@ -586,33 +561,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
         Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        name == &other.name
-            && signature == &other.signature
-            && return_type == &other.return_type
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        return_type.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[tokio::test]
@@ -992,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
 //
 // it also defines custom [ScalarUDFImpl::simplify()]
 // to replace ScalarUDF expression with one instance contains.
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
 struct ScalarFunctionWrapper {
     name: String,
     expr: Expr,
@@ -1031,37 +980,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
         Ok(ExprSimplifyResult::Simplified(replacement))
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            expr,
-            signature,
-            return_type,
-        } = self;
-        name == &other.name
-            && expr == &other.expr
-            && signature == &other.signature
-            && return_type == &other.return_type
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            expr,
-            signature,
-            return_type,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        expr.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        return_type.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 impl ScalarFunctionWrapper {
@@ -1296,6 +1215,21 @@ struct MyRegexUdf {
     regex: Regex,
 }
 
+impl PartialEq for MyRegexUdf {
+    fn eq(&self, other: &Self) -> bool {
+        let Self { signature, regex } = self;
+        signature == &other.signature && regex.as_str() == other.regex.as_str()
+    }
+}
+
+impl Hash for MyRegexUdf {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self { signature, regex } = self;
+        signature.hash(state);
+        regex.as_str().hash(state);
+    }
+}
+
 impl MyRegexUdf {
     fn new(pattern: &str) -> Self {
         Self {
@@ -1348,19 +1282,7 @@ impl ScalarUDFImpl for MyRegexUdf {
         }
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
-            self.regex.as_str() == other.regex.as_str()
-        } else {
-            false
-        }
-    }
-
-    fn hash_value(&self) -> u64 {
-        let hasher = &mut DefaultHasher::new();
-        self.regex.as_str().hash(hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[tokio::test]
@@ -1458,13 +1380,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql: 
&str) -> Result<Vec<RecordB
     ctx.sql(sql).await?.collect().await
 }
 
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 struct MetadataBasedUdf {
     name: String,
     signature: Signature,
     metadata: HashMap<String, String>,
 }
 
+impl Hash for MetadataBasedUdf {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self {
+            name,
+            signature,
+            metadata: _, // unhashable
+        } = self;
+        name.hash(state);
+        signature.hash(state);
+    }
+}
+
 impl MetadataBasedUdf {
     fn new(metadata: HashMap<String, String>) -> Self {
         // The name we return must be unique. Otherwise we will not call 
distinct
@@ -1537,32 +1471,7 @@ impl ScalarUDFImpl for MetadataBasedUdf {
         }
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            signature,
-            metadata,
-        } = self;
-        name == &other.name
-            && signature == &other.signature
-            && metadata == &other.metadata
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            signature,
-            metadata: _, // unhashable
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[tokio::test]
diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs
index 24ed124bb2..753ad7b778 100644
--- a/datafusion/expr/src/async_udf.rs
+++ b/datafusion/expr/src/async_udf.rs
@@ -15,7 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
+use crate::utils::{arc_ptr_eq, arc_ptr_hash};
+use crate::{
+    udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, 
ScalarUDFImpl,
+};
 use arrow::array::ArrayRef;
 use arrow::datatypes::{DataType, FieldRef};
 use async_trait::async_trait;
@@ -26,7 +29,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
 use datafusion_expr_common::signature::Signature;
 use std::any::Any;
 use std::fmt::{Debug, Display};
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
 use std::sync::Arc;
 
 /// A scalar UDF that can invoke using async methods
@@ -62,6 +65,21 @@ pub struct AsyncScalarUDF {
     inner: Arc<dyn AsyncScalarUDFImpl>,
 }
 
+impl PartialEq for AsyncScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        let Self { inner } = self;
+        // TODO when MSRV >= 1.86.0, switch to 
`inner.equals(other.inner.as_ref())` leveraging trait upcasting.
+        arc_ptr_eq(inner, &other.inner)
+    }
+}
+
+impl Hash for AsyncScalarUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self { inner } = self;
+        arc_ptr_hash(inner, state);
+    }
+}
+
 impl AsyncScalarUDF {
     pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
         Self { inner }
@@ -113,21 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF {
         internal_err!("async functions should not be called directly")
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self { inner } = self;
-        // TODO when MSRV >= 1.86.0, switch to 
`inner.equals(other.inner.as_ref())` leveraging trait upcasting
-        Arc::ptr_eq(inner, &other.inner)
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self { inner } = self;
-        let mut hasher = DefaultHasher::new();
-        Arc::as_ptr(inner).hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 impl Display for AsyncScalarUDF {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index c0351a9dca..1d8d183807 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -26,10 +26,11 @@ use crate::function::{
     StateFieldsArgs,
 };
 use crate::select_expr::SelectExpr;
+use crate::utils::{arc_ptr_eq, arc_ptr_hash};
 use crate::{
     conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
-    AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, 
ScalarFunctionArgs,
-    ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
+    udf_equals_hash, AggregateUDF, Expr, LogicalPlan, Operator, 
PartitionEvaluator,
+    ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, 
Volatility,
 };
 use crate::{
     AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, 
WindowUDFImpl,
@@ -409,6 +410,36 @@ pub struct SimpleScalarUDF {
     fun: ScalarFunctionImplementation,
 }
 
+impl PartialEq for SimpleScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        let Self {
+            name,
+            signature,
+            return_type,
+            fun,
+        } = self;
+        name == &other.name
+            && signature == &other.signature
+            && return_type == &other.return_type
+            && arc_ptr_eq(fun, &other.fun)
+    }
+}
+
+impl Hash for SimpleScalarUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self {
+            name,
+            signature,
+            return_type,
+            fun,
+        } = self;
+        name.hash(state);
+        signature.hash(state);
+        return_type.hash(state);
+        arc_ptr_hash(fun, state);
+    }
+}
+
 impl Debug for SimpleScalarUDF {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         f.debug_struct("SimpleScalarUDF")
@@ -476,37 +507,7 @@ impl ScalarUDFImpl for SimpleScalarUDF {
         (self.fun)(&args.args)
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            signature,
-            return_type,
-            fun,
-        } = self;
-        name == &other.name
-            && signature == &other.signature
-            && return_type == &other.return_type
-            && Arc::ptr_eq(fun, &other.fun)
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            signature,
-            return_type,
-            fun,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        return_type.hash(&mut hasher);
-        Arc::as_ptr(fun).hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 /// Creates a new UDAF with a specific signature, state type and return type.
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3a94981ae4..171c4e041f 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -21,7 +21,7 @@ use crate::async_udf::AsyncScalarUDF;
 use crate::expr::schema_name_from_exprs_comma_separated_without_space;
 use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
 use crate::sort_properties::{ExprProperties, SortProperties};
-use crate::{ColumnarValue, Documentation, Expr, Signature};
+use crate::{udf_equals_hash, ColumnarValue, Documentation, Expr, Signature};
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue};
 use datafusion_expr_common::interval_arithmetic::Interval;
@@ -747,6 +747,21 @@ struct AliasedScalarUDFImpl {
     aliases: Vec<String>,
 }
 
+impl PartialEq for AliasedScalarUDFImpl {
+    fn eq(&self, other: &Self) -> bool {
+        let Self { inner, aliases } = self;
+        inner.equals(other.inner.as_ref()) && aliases == &other.aliases
+    }
+}
+
+impl Hash for AliasedScalarUDFImpl {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self { inner, aliases } = self;
+        inner.hash_value().hash(state);
+        aliases.hash(state);
+    }
+}
+
 impl AliasedScalarUDFImpl {
     pub fn new(
         inner: Arc<dyn ScalarUDFImpl>,
@@ -831,21 +846,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
         self.inner.coerce_types(arg_types)
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        if let Some(other) = 
other.as_any().downcast_ref::<AliasedScalarUDFImpl>() {
-            self.inner.equals(other.inner.as_ref()) && self.aliases == 
other.aliases
-        } else {
-            false
-        }
-    }
-
-    fn hash_value(&self) -> u64 {
-        let hasher = &mut DefaultHasher::new();
-        std::any::type_name::<Self>().hash(hasher);
-        self.inner.hash_value().hash(hasher);
-        self.aliases.hash(hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 
     fn documentation(&self) -> Option<&Documentation> {
         self.inner.documentation()
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 8950f5e450..e554152328 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -19,6 +19,7 @@
 
 use std::cmp::Ordering;
 use std::collections::{BTreeSet, HashSet};
+use std::hash::Hasher;
 use std::sync::Arc;
 
 use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
@@ -1260,6 +1261,94 @@ pub fn collect_subquery_cols(
     })
 }
 
+/// Generates implementation of `equals` and `hash_value` methods for a trait, 
delegating
+/// to [`PartialEq`] and [`Hash`] implementations on Self.
+/// Meant to be used with traits representing user-defined functions (UDFs).
+///
+/// Example showing generation of [`ScalarUDFImpl::equals`] and 
[`ScalarUDFImpl::hash_value`]
+/// implementations.
+///
+/// ```
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_expr::{udf_equals_hash, ScalarFunctionArgs, 
ScalarUDFImpl};
+/// # use datafusion_expr_common::columnar_value::ColumnarValue;
+/// # use datafusion_expr_common::signature::Signature;
+/// # use std::any::Any;
+///
+/// // Implementing PartialEq & Hash is a prerequisite for using this macro,
+/// // but the implementation can be derived.
+/// #[derive(Debug, PartialEq, Hash)]
+/// struct VarcharToTimestampTz {
+///     safe: bool,
+/// }
+///
+/// impl ScalarUDFImpl for VarcharToTimestampTz {
+///     /* other methods omitted for brevity */
+/// #    fn as_any(&self) -> &dyn Any {
+/// #        self
+/// #    }
+/// #
+/// #    fn name(&self) -> &str {
+/// #        "varchar_to_timestamp_tz"
+/// #    }
+/// #
+/// #    fn signature(&self) -> &Signature {
+/// #        todo!()
+/// #    }
+/// #
+/// #    fn return_type(
+/// #        &self,
+/// #        _arg_types: &[DataType],
+/// #    ) -> datafusion_common::Result<DataType> {
+/// #        todo!()
+/// #    }
+/// #
+/// #    fn invoke_with_args(
+/// #        &self,
+/// #        args: ScalarFunctionArgs,
+/// #    ) -> datafusion_common::Result<ColumnarValue> {
+/// #        todo!()
+/// #    }
+/// #
+///     udf_equals_hash!(ScalarUDFImpl);
+/// }
+/// ```
+///
+/// [`ScalarUDFImpl::equals`]: crate::ScalarUDFImpl::equals
+/// [`ScalarUDFImpl::hash_value`]: crate::ScalarUDFImpl::hash_value
+#[macro_export]
+macro_rules! udf_equals_hash {
+    ($udf_type:tt) => {
+        fn equals(&self, other: &dyn $udf_type) -> bool {
+            use ::core::any::Any;
+            use ::core::cmp::PartialEq;
+            let Some(other) = <dyn Any + 
'static>::downcast_ref::<Self>(other.as_any())
+            else {
+                return false;
+            };
+            PartialEq::eq(self, other)
+        }
+
+        fn hash_value(&self) -> u64 {
+            use ::std::any::type_name;
+            use ::std::hash::{DefaultHasher, Hash, Hasher};
+            let hasher = &mut DefaultHasher::new();
+            type_name::<Self>().hash(hasher);
+            Hash::hash(self, hasher);
+            Hasher::finish(hasher)
+        }
+    };
+}
+
+pub fn arc_ptr_eq<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
+    // Not necessarily equivalent to `Arc::ptr_eq` for fat pointers.
+    std::ptr::eq(Arc::as_ptr(a), Arc::as_ptr(b))
+}
+
+pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) {
+    std::ptr::hash(Arc::as_ptr(a), hasher)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1268,9 +1357,13 @@ mod tests {
         expr::WindowFunction,
         expr_vec_fmt, grouping_set, lit, rollup,
         test::function_stub::{max_udaf, min_udaf, sum_udaf},
-        Cast, ExprFunctionExt, WindowFunctionDefinition,
+        Cast, ExprFunctionExt, ScalarFunctionArgs, ScalarUDFImpl,
+        WindowFunctionDefinition,
     };
     use arrow::datatypes::{UnionFields, UnionMode};
+    use datafusion_expr_common::columnar_value::ColumnarValue;
+    use datafusion_expr_common::signature::Volatility;
+    use std::any::Any;
 
     #[test]
     fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
@@ -1690,4 +1783,91 @@ mod tests {
             DataType::List(Arc::new(Field::new("my_union", union_type, true)));
         assert!(!can_hash(&list_union_type));
     }
+
+    #[test]
+    fn test_udf_equals_hash() {
+        #[derive(Debug, PartialEq, Hash)]
+        struct StatefulFunctionWithEqHash {
+            signature: Signature,
+            state: bool,
+        }
+        impl ScalarUDFImpl for StatefulFunctionWithEqHash {
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+            fn name(&self) -> &str {
+                "StatefulFunctionWithEqHash"
+            }
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> 
{
+                todo!()
+            }
+            fn invoke_with_args(
+                &self,
+                _args: ScalarFunctionArgs,
+            ) -> Result<ColumnarValue> {
+                todo!()
+            }
+        }
+
+        #[derive(Debug, PartialEq, Hash)]
+        struct StatefulFunctionWithEqHashWithUdfEqualsHash {
+            signature: Signature,
+            state: bool,
+        }
+        impl ScalarUDFImpl for StatefulFunctionWithEqHashWithUdfEqualsHash {
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+            fn name(&self) -> &str {
+                "StatefulFunctionWithEqHashWithUdfEqualsHash"
+            }
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> 
{
+                todo!()
+            }
+            fn invoke_with_args(
+                &self,
+                _args: ScalarFunctionArgs,
+            ) -> Result<ColumnarValue> {
+                todo!()
+            }
+            udf_equals_hash!(ScalarUDFImpl);
+        }
+
+        let signature = Signature::exact(vec![DataType::Utf8], 
Volatility::Immutable);
+
+        // Sadly, without `udf_equals_hash!` macro, the equals and hash_value 
ignore state fields,
+        // even though the struct implements `PartialEq` and `Hash`.
+        let a: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash {
+            signature: signature.clone(),
+            state: true,
+        });
+        let b: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash {
+            signature: signature.clone(),
+            state: false,
+        });
+        assert!(a.equals(b.as_ref()));
+        assert_eq!(a.hash_value(), b.hash_value());
+
+        // With udf_equals_hash! macro, the equals and hash_value compare the state,
+        // because they delegate to the struct's `PartialEq` and `Hash` implementations.
+        let a: Box<dyn ScalarUDFImpl> =
+            Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash {
+                signature: signature.clone(),
+                state: true,
+            });
+        let b: Box<dyn ScalarUDFImpl> =
+            Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash {
+                signature: signature.clone(),
+                state: false,
+            });
+        assert!(!a.equals(b.as_ref()));
+        // This could be true, but it's very unlikely that boolean true and 
false hash the same
+        assert_ne!(a.hash_value(), b.hash_value());
+    }
 }
diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs
index 09cd0df128..1c835bd3ec 100644
--- a/datafusion/ffi/src/udf/mod.rs
+++ b/datafusion/ffi/src/udf/mod.rs
@@ -32,7 +32,7 @@ use arrow::{
     ffi::{from_ffi, to_ffi, FFI_ArrowSchema},
 };
 use arrow_schema::FieldRef;
-use datafusion::logical_expr::ReturnFieldArgs;
+use datafusion::logical_expr::{udf_equals_hash, ReturnFieldArgs};
 use datafusion::{
     error::DataFusionError,
     logical_expr::type_coercion::functions::data_types_with_scalar_udf,
@@ -46,7 +46,7 @@ use datafusion::{
 use return_type_args::{
     FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned,
 };
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::{Hash, Hasher};
 use std::{ffi::c_void, sync::Arc};
 
 pub mod return_type_args;
@@ -287,6 +287,36 @@ pub struct ForeignScalarUDF {
 unsafe impl Send for ForeignScalarUDF {}
 unsafe impl Sync for ForeignScalarUDF {}
 
+impl PartialEq for ForeignScalarUDF {
+    fn eq(&self, other: &Self) -> bool {
+        let Self {
+            name,
+            aliases,
+            udf,
+            signature,
+        } = self;
+        name == &other.name
+            && aliases == &other.aliases
+            && std::ptr::eq(udf, &other.udf)
+            && signature == &other.signature
+    }
+}
+
+impl Hash for ForeignScalarUDF {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        let Self {
+            name,
+            aliases,
+            udf,
+            signature,
+        } = self;
+        name.hash(state);
+        aliases.hash(state);
+        std::ptr::hash(udf, state);
+        signature.hash(state);
+    }
+}
+
 impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF {
     type Error = DataFusionError;
 
@@ -409,37 +439,7 @@ impl ScalarUDFImpl for ForeignScalarUDF {
         }
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            aliases,
-            udf,
-            signature,
-        } = self;
-        name == &other.name
-            && aliases == &other.aliases
-            && std::ptr::eq(udf, &other.udf)
-            && signature == &other.signature
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            aliases,
-            udf,
-            signature,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        aliases.hash(&mut hasher);
-        std::ptr::hash(udf, &mut hasher);
-        signature.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[cfg(test)]
diff --git a/datafusion/proto/tests/cases/mod.rs 
b/datafusion/proto/tests/cases/mod.rs
index 6158b727df..ab08f5b9be 100644
--- a/datafusion/proto/tests/cases/mod.rs
+++ b/datafusion/proto/tests/cases/mod.rs
@@ -20,8 +20,8 @@ use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::plan_err;
 use datafusion_expr::function::AccumulatorArgs;
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarFunctionArgs, 
ScalarUDFImpl,
-    Signature, Volatility, WindowUDFImpl,
+    udf_equals_hash, Accumulator, AggregateUDFImpl, PartitionEvaluator,
+    ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, WindowUDFImpl,
 };
 use datafusion_functions_window_common::field::WindowUDFFieldArgs;
 use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
@@ -82,33 +82,7 @@ impl ScalarUDFImpl for MyRegexUdf {
         &self.aliases
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            signature,
-            pattern,
-            aliases,
-        } = self;
-        signature == &other.signature
-            && pattern == &other.pattern
-            && aliases == &other.aliases
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            signature,
-            pattern,
-            aliases,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        signature.hash(&mut hasher);
-        pattern.hash(&mut hasher);
-        aliases.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/sql/tests/sql_integration.rs 
b/datafusion/sql/tests/sql_integration.rs
index 1e857b6a07..dd5ec4a201 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use std::any::Any;
-use std::hash::{DefaultHasher, Hash, Hasher};
+use std::hash::Hash;
 #[cfg(test)]
 use std::sync::Arc;
 use std::vec;
@@ -25,9 +25,9 @@ use arrow::datatypes::{TimeUnit::Nanosecond, *};
 use common::MockContextProvider;
 use datafusion_common::{assert_contains, DataFusionError, Result};
 use datafusion_expr::{
-    col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, 
ColumnarValue,
-    CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, 
Signature,
-    Volatility,
+    col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, 
udf_equals_hash,
+    ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF,
+    ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_functions::{string, unicode};
 use datafusion_sql::{
@@ -3312,7 +3312,7 @@ fn make_udf(name: &'static str, args: Vec<DataType>, 
return_type: DataType) -> S
 }
 
 /// Mocked UDF
-#[derive(Debug)]
+#[derive(Debug, PartialEq, Hash)]
 struct DummyUDF {
     name: &'static str,
     signature: Signature,
@@ -3350,33 +3350,7 @@ impl ScalarUDFImpl for DummyUDF {
         panic!("dummy - not implemented")
     }
 
-    fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
-        let Some(other) = other.as_any().downcast_ref::<Self>() else {
-            return false;
-        };
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        name == &other.name
-            && signature == &other.signature
-            && return_type == &other.return_type
-    }
-
-    fn hash_value(&self) -> u64 {
-        let Self {
-            name,
-            signature,
-            return_type,
-        } = self;
-        let mut hasher = DefaultHasher::new();
-        std::any::type_name::<Self>().hash(&mut hasher);
-        name.hash(&mut hasher);
-        signature.hash(&mut hasher);
-        return_type.hash(&mut hasher);
-        hasher.finish()
-    }
+    udf_equals_hash!(ScalarUDFImpl);
 }
 
 fn parse_decimals_parser_options() -> ParserOptions {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to