This is an automated email from the ASF dual-hosted git repository. findepi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new 070517a87f Derive UDF equality from PartialEq, Hash (#16842) 070517a87f is described below commit 070517a87f15fc9e5f64423e29ee155d4a6a868d Author: Piotr Findeisen <piotr.findei...@gmail.com> AuthorDate: Fri Jul 25 21:21:44 2025 +0200 Derive UDF equality from PartialEq, Hash (#16842) * Derive UDF equality from PartialEq, Hash Reduce boilerplate in cases where implementation of `{ScalarUDFImpl,AggregateUDFImpl,WindowUDFImpl}::{equals,hash_value}` can be derived using standard `PartialEq` and `Hash` traits. This is code complexity reduction. While valuable on its own, this also prepares for more automatic derivation of UDF equals/hash and/or removal of default implementations (which currently are error-prone). * udf_equals_hash example * test udf_equals_hash * empty: roll the dice 🎲 --- .../user_defined/user_defined_scalar_functions.rs | 171 +++++-------------- datafusion/expr/src/async_udf.rs | 38 +++-- datafusion/expr/src/expr_fn.rs | 67 ++++---- datafusion/expr/src/udf.rs | 33 ++-- datafusion/expr/src/utils.rs | 182 ++++++++++++++++++++- datafusion/ffi/src/udf/mod.rs | 66 ++++---- datafusion/proto/tests/cases/mod.rs | 32 +--- datafusion/sql/tests/sql_integration.rs | 38 +---- 8 files changed, 335 insertions(+), 292 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index c5f9bdeb69..dd8283613a 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::collections::HashMap; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array}; @@ -43,9 +43,9 @@ use datafusion_common::{ use datafusion_expr::expr::FieldMetadata; use 
datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, - LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction, + CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, + ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -181,6 +181,7 @@ async fn scalar_udf() -> Result<()> { Ok(()) } +#[derive(PartialEq, Hash)] struct Simple0ArgsScalarUDF { name: String, signature: Signature, @@ -218,33 +219,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - return_type, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[tokio::test] @@ -517,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { } /// Volatile UDF that should append a different value to each row -#[derive(Debug)] +#[derive(Debug, PartialEq, Hash)] struct AddIndexToStringVolatileScalarUDF { name: String, signature: Signature, @@ -586,33 +561,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) } - fn 
equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - return_type, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[tokio::test] @@ -992,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory { // // it also defines custom [ScalarUDFImpl::simplify()] // to replace ScalarUDF expression with one instance contains. -#[derive(Debug)] +#[derive(Debug, PartialEq, Hash)] struct ScalarFunctionWrapper { name: String, expr: Expr, @@ -1031,37 +980,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { Ok(ExprSimplifyResult::Simplified(replacement)) } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - expr, - signature, - return_type, - } = self; - name == &other.name - && expr == &other.expr - && signature == &other.signature - && return_type == &other.return_type - } - - fn hash_value(&self) -> u64 { - let Self { - name, - expr, - signature, - return_type, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - expr.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } impl ScalarFunctionWrapper { @@ -1296,6 +1215,21 @@ struct MyRegexUdf { regex: Regex, } +impl PartialEq for MyRegexUdf { + fn eq(&self, other: &Self) -> bool { + let Self { signature, regex } = self; + signature == 
&other.signature && regex.as_str() == other.regex.as_str() + } +} + +impl Hash for MyRegexUdf { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { signature, regex } = self; + signature.hash(state); + regex.as_str().hash(state); + } +} + impl MyRegexUdf { fn new(pattern: &str) -> Self { Self { @@ -1348,19 +1282,7 @@ impl ScalarUDFImpl for MyRegexUdf { } } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() { - self.regex.as_str() == other.regex.as_str() - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.regex.as_str().hash(hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[tokio::test] @@ -1458,13 +1380,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB ctx.sql(sql).await?.collect().await } -#[derive(Debug)] +#[derive(Debug, PartialEq)] struct MetadataBasedUdf { name: String, signature: Signature, metadata: HashMap<String, String>, } +impl Hash for MetadataBasedUdf { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedUdf { fn new(metadata: HashMap<String, String>) -> Self { // The name we return must be unique. 
Otherwise we will not call distinct @@ -1537,32 +1471,7 @@ impl ScalarUDFImpl for MetadataBasedUdf { } } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - metadata, - } = self; - name == &other.name - && signature == &other.signature - && metadata == &other.metadata - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - metadata: _, // unhashable - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[tokio::test] diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 24ed124bb2..753ad7b778 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +use crate::utils::{arc_ptr_eq, arc_ptr_hash}; +use crate::{ + udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, +}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, FieldRef}; use async_trait::async_trait; @@ -26,7 +29,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::signature::Signature; use std::any::Any; use std::fmt::{Debug, Display}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::sync::Arc; /// A scalar UDF that can invoke using async methods @@ -62,6 +65,21 @@ pub struct AsyncScalarUDF { inner: Arc<dyn AsyncScalarUDFImpl>, } +impl PartialEq for AsyncScalarUDF { + fn eq(&self, other: &Self) -> bool { + let Self { inner } = self; + // TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting. 
+ arc_ptr_eq(inner, &other.inner) + } +} + +impl Hash for AsyncScalarUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { inner } = self; + arc_ptr_hash(inner, state); + } +} + impl AsyncScalarUDF { pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self { Self { inner } @@ -113,21 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF { internal_err!("async functions should not be called directly") } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { inner } = self; - // TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting - Arc::ptr_eq(inner, &other.inner) - } - - fn hash_value(&self) -> u64 { - let Self { inner } = self; - let mut hasher = DefaultHasher::new(); - Arc::as_ptr(inner).hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } impl Display for AsyncScalarUDF { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c0351a9dca..1d8d183807 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,10 +26,11 @@ use crate::function::{ StateFieldsArgs, }; use crate::select_expr::SelectExpr; +use crate::utils::{arc_ptr_eq, arc_ptr_hash}; use crate::{ conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs, - ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, + udf_equals_hash, AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; use crate::{ AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, @@ -409,6 +410,36 @@ pub struct SimpleScalarUDF { fun: ScalarFunctionImplementation, } +impl PartialEq for SimpleScalarUDF { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, 
+ return_type, + fun, + } = self; + name == &other.name + && signature == &other.signature + && return_type == &other.return_type + && arc_ptr_eq(fun, &other.fun) + } +} + +impl Hash for SimpleScalarUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { + name, + signature, + return_type, + fun, + } = self; + name.hash(state); + signature.hash(state); + return_type.hash(state); + arc_ptr_hash(fun, state); + } +} + impl Debug for SimpleScalarUDF { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { f.debug_struct("SimpleScalarUDF") @@ -476,37 +507,7 @@ impl ScalarUDFImpl for SimpleScalarUDF { (self.fun)(&args.args) } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - return_type, - fun, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - && Arc::ptr_eq(fun, &other.fun) - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - fun, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - Arc::as_ptr(fun).hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } /// Creates a new UDAF with a specific signature, state type and return type. 
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3a94981ae4..171c4e041f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; -use crate::{ColumnarValue, Documentation, Expr, Signature}; +use crate::{udf_equals_hash, ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; @@ -747,6 +747,21 @@ struct AliasedScalarUDFImpl { aliases: Vec<String>, } +impl PartialEq for AliasedScalarUDFImpl { + fn eq(&self, other: &Self) -> bool { + let Self { inner, aliases } = self; + inner.equals(other.inner.as_ref()) && aliases == &other.aliases + } +} + +impl Hash for AliasedScalarUDFImpl { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { inner, aliases } = self; + inner.hash_value().hash(state); + aliases.hash(state); + } +} + impl AliasedScalarUDFImpl { pub fn new( inner: Arc<dyn ScalarUDFImpl>, @@ -831,21 +846,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.coerce_types(arg_types) } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::<AliasedScalarUDFImpl>() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - std::any::type_name::<Self>().hash(hasher); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() diff --git a/datafusion/expr/src/utils.rs 
b/datafusion/expr/src/utils.rs index 8950f5e450..e554152328 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -19,6 +19,7 @@ use std::cmp::Ordering; use std::collections::{BTreeSet, HashSet}; +use std::hash::Hasher; use std::sync::Arc; use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams}; @@ -1260,6 +1261,94 @@ pub fn collect_subquery_cols( }) } +/// Generates implementation of `equals` and `hash_value` methods for a trait, delegating +/// to [`PartialEq`] and [`Hash`] implementations on Self. +/// Meant to be used with traits representing user-defined functions (UDFs). +/// +/// Example showing generation of [`ScalarUDFImpl::equals`] and [`ScalarUDFImpl::hash_value`] +/// implementations. +/// +/// ``` +/// # use arrow::datatypes::DataType; +/// # use datafusion_expr::{udf_equals_hash, ScalarFunctionArgs, ScalarUDFImpl}; +/// # use datafusion_expr_common::columnar_value::ColumnarValue; +/// # use datafusion_expr_common::signature::Signature; +/// # use std::any::Any; +/// +/// // Implementing PartialEq & Hash is a prerequisite for using this macro, +/// // but the implementation can be derived. 
+/// #[derive(Debug, PartialEq, Hash)] +/// struct VarcharToTimestampTz { +/// safe: bool, +/// } +/// +/// impl ScalarUDFImpl for VarcharToTimestampTz { +/// /* other methods omitted for brevity */ +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # +/// # fn name(&self) -> &str { +/// # "varchar_to_timestamp_tz" +/// # } +/// # +/// # fn signature(&self) -> &Signature { +/// # todo!() +/// # } +/// # +/// # fn return_type( +/// # &self, +/// # _arg_types: &[DataType], +/// # ) -> datafusion_common::Result<DataType> { +/// # todo!() +/// # } +/// # +/// # fn invoke_with_args( +/// # &self, +/// # args: ScalarFunctionArgs, +/// # ) -> datafusion_common::Result<ColumnarValue> { +/// # todo!() +/// # } +/// # +/// udf_equals_hash!(ScalarUDFImpl); +/// } +/// ``` +/// +/// [`ScalarUDFImpl::equals`]: crate::ScalarUDFImpl::equals +/// [`ScalarUDFImpl::hash_value`]: crate::ScalarUDFImpl::hash_value +#[macro_export] +macro_rules! udf_equals_hash { + ($udf_type:tt) => { + fn equals(&self, other: &dyn $udf_type) -> bool { + use ::core::any::Any; + use ::core::cmp::PartialEq; + let Some(other) = <dyn Any + 'static>::downcast_ref::<Self>(other.as_any()) + else { + return false; + }; + PartialEq::eq(self, other) + } + + fn hash_value(&self) -> u64 { + use ::std::any::type_name; + use ::std::hash::{DefaultHasher, Hash, Hasher}; + let hasher = &mut DefaultHasher::new(); + type_name::<Self>().hash(hasher); + Hash::hash(self, hasher); + Hasher::finish(hasher) + } + }; +} + +pub fn arc_ptr_eq<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool { + // Not necessarily equivalent to `Arc::ptr_eq` for fat pointers. 
+ std::ptr::eq(Arc::as_ptr(a), Arc::as_ptr(b)) +} + +pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) { + std::ptr::hash(Arc::as_ptr(a), hasher) +} + #[cfg(test)] mod tests { use super::*; @@ -1268,9 +1357,13 @@ mod tests { expr::WindowFunction, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::{max_udaf, min_udaf, sum_udaf}, - Cast, ExprFunctionExt, WindowFunctionDefinition, + Cast, ExprFunctionExt, ScalarFunctionArgs, ScalarUDFImpl, + WindowFunctionDefinition, }; use arrow::datatypes::{UnionFields, UnionMode}; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::Volatility; + use std::any::Any; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1690,4 +1783,91 @@ mod tests { DataType::List(Arc::new(Field::new("my_union", union_type, true))); assert!(!can_hash(&list_union_type)); } + + #[test] + fn test_udf_equals_hash() { + #[derive(Debug, PartialEq, Hash)] + struct StatefulFunctionWithEqHash { + signature: Signature, + state: bool, + } + impl ScalarUDFImpl for StatefulFunctionWithEqHash { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "StatefulFunctionWithEqHash" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + todo!() + } + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result<ColumnarValue> { + todo!() + } + } + + #[derive(Debug, PartialEq, Hash)] + struct StatefulFunctionWithEqHashWithUdfEqualsHash { + signature: Signature, + state: bool, + } + impl ScalarUDFImpl for StatefulFunctionWithEqHashWithUdfEqualsHash { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "StatefulFunctionWithEqHashWithUdfEqualsHash" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + todo!() + } + fn invoke_with_args( + &self, + _args: 
ScalarFunctionArgs, + ) -> Result<ColumnarValue> { + todo!() + } + udf_equals_hash!(ScalarUDFImpl); + } + + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + + // Sadly, without `udf_equals_hash!` macro, the equals and hash_value ignore state fields, + // even though the struct implements `PartialEq` and `Hash`. + let a: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash { + signature: signature.clone(), + state: true, + }); + let b: Box<dyn ScalarUDFImpl> = Box::new(StatefulFunctionWithEqHash { + signature: signature.clone(), + state: false, + }); + assert!(a.equals(b.as_ref())); + assert_eq!(a.hash_value(), b.hash_value()); + + // With udf_equals_hash! macro, the equals and hash_value compare the state. + // even though the struct implements `PartialEq` and `Hash`. + let a: Box<dyn ScalarUDFImpl> = + Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash { + signature: signature.clone(), + state: true, + }); + let b: Box<dyn ScalarUDFImpl> = + Box::new(StatefulFunctionWithEqHashWithUdfEqualsHash { + signature: signature.clone(), + state: false, + }); + assert!(!a.equals(b.as_ref())); + // This could be true, but it's very unlikely that boolean true and false hash the same + assert_ne!(a.hash_value(), b.hash_value()); + } } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 09cd0df128..1c835bd3ec 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -32,7 +32,7 @@ use arrow::{ ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; use arrow_schema::FieldRef; -use datafusion::logical_expr::ReturnFieldArgs; +use datafusion::logical_expr::{udf_equals_hash, ReturnFieldArgs}; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -46,7 +46,7 @@ use datafusion::{ use return_type_args::{ FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; -use std::hash::{DefaultHasher, Hash, Hasher}; +use 
std::hash::{Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; pub mod return_type_args; @@ -287,6 +287,36 @@ pub struct ForeignScalarUDF { unsafe impl Send for ForeignScalarUDF {} unsafe impl Sync for ForeignScalarUDF {} +impl PartialEq for ForeignScalarUDF { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + aliases, + udf, + signature, + } = self; + name == &other.name + && aliases == &other.aliases + && std::ptr::eq(udf, &other.udf) + && signature == &other.signature + } +} + +impl Hash for ForeignScalarUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { + name, + aliases, + udf, + signature, + } = self; + name.hash(state); + aliases.hash(state); + std::ptr::hash(udf, state); + signature.hash(state); + } +} + impl TryFrom<&FFI_ScalarUDF> for ForeignScalarUDF { type Error = DataFusionError; @@ -409,37 +439,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { } } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - aliases, - udf, - signature, - } = self; - name == &other.name - && aliases == &other.aliases - && std::ptr::eq(udf, &other.udf) - && signature == &other.signature - } - - fn hash_value(&self) -> u64 { - let Self { - name, - aliases, - udf, - signature, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - aliases.hash(&mut hasher); - std::ptr::hash(udf, &mut hasher); - signature.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[cfg(test)] diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index 6158b727df..ab08f5b9be 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -20,8 +20,8 @@ use datafusion::logical_expr::ColumnarValue; use datafusion_common::plan_err; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, 
AggregateUDFImpl, PartitionEvaluator, ScalarFunctionArgs, ScalarUDFImpl, - Signature, Volatility, WindowUDFImpl, + udf_equals_hash, Accumulator, AggregateUDFImpl, PartitionEvaluator, + ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, WindowUDFImpl, }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; @@ -82,33 +82,7 @@ impl ScalarUDFImpl for MyRegexUdf { &self.aliases } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - pattern, - aliases, - } = self; - signature == &other.signature - && pattern == &other.pattern - && aliases == &other.aliases - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - pattern, - aliases, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - pattern.hash(&mut hasher); - aliases.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 1e857b6a07..dd5ec4a201 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -16,7 +16,7 @@ // under the License. 
use std::any::Any; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; #[cfg(test)] use std::sync::Arc; use std::vec; @@ -25,9 +25,9 @@ use arrow::datatypes::{TimeUnit::Nanosecond, *}; use common::MockContextProvider; use datafusion_common::{assert_contains, DataFusionError, Result}; use datafusion_expr::{ - col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, ColumnarValue, - CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, + col, logical_plan::LogicalPlan, test::function_stub::sum_udaf, udf_equals_hash, + ColumnarValue, CreateIndex, DdlStatement, ScalarFunctionArgs, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -3312,7 +3312,7 @@ fn make_udf(name: &'static str, args: Vec<DataType>, return_type: DataType) -> S } /// Mocked UDF -#[derive(Debug)] +#[derive(Debug, PartialEq, Hash)] struct DummyUDF { name: &'static str, signature: Signature, @@ -3350,33 +3350,7 @@ impl ScalarUDFImpl for DummyUDF { panic!("dummy - not implemented") } - fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - return_type, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(ScalarUDFImpl); } fn parse_decimals_parser_options() -> ParserOptions { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org