This is an automated email from the ASF dual-hosted git repository. findepi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push: new ac3a5735ff Derive UDAF equality from Eq, Hash (#17067) ac3a5735ff is described below commit ac3a5735ff3b38b0e9aed65ca6b20ff0d3ad9c02 Author: Piotr Findeisen <piotr.findei...@gmail.com> AuthorDate: Fri Aug 8 07:50:00 2025 +0200 Derive UDAF equality from Eq, Hash (#17067) * Require Eq to use udf_equals_hash The UDF comparison is expected to be reflexive. Require `Eq` for any uses of `udf_equals_hash` short-cut. * Add UdfEq wrapper around Arc to UDF impl The wrapper implements PartialEq, Eq, Hash by forwarding to UDF impl equals and hash_value functions. * Derive UDAF equality from Eq, Hash Reduce boilerplate in cases where implementation of `AggregateUDFImpl::{equals,hash_value}` can be derived using standard `Eq` and `Hash` traits. --- .../tests/user_defined/user_defined_aggregates.rs | 74 ++++----- .../user_defined/user_defined_scalar_functions.rs | 9 +- .../user_defined/user_defined_window_functions.rs | 2 +- datafusion/doc/src/lib.rs | 4 +- datafusion/expr/src/async_udf.rs | 1 + datafusion/expr/src/expr_fn.rs | 49 +----- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/ptr_eq.rs | 3 +- datafusion/expr/src/udaf.rs | 29 ++-- datafusion/expr/src/udf.rs | 25 +-- datafusion/expr/src/udf_eq.rs | 181 +++++++++++++++++++++ datafusion/expr/src/udwf.rs | 6 +- datafusion/expr/src/utils.rs | 10 +- datafusion/ffi/src/udaf/mod.rs | 44 ++--- datafusion/ffi/src/udf/mod.rs | 1 + datafusion/ffi/src/udwf/mod.rs | 1 + .../src/approx_percentile_cont.rs | 7 +- .../src/approx_percentile_cont_with_weight.rs | 30 +--- .../functions-aggregate/src/bit_and_or_xor.rs | 40 +---- datafusion/functions-aggregate/src/first_last.rs | 56 +------ datafusion/functions-aggregate/src/regr.rs | 33 +--- datafusion/functions-aggregate/src/stddev.rs | 24 +-- datafusion/functions-aggregate/src/string_agg.rs | 30 +--- datafusion/functions-window/src/lead_lag.rs | 2 +- datafusion/functions-window/src/nth_value.rs | 2 +- datafusion/functions-window/src/rank.rs | 2 +- .../src/simplify_expressions/expr_simplifier.rs | 21 +-- datafusion/proto/tests/cases/mod.rs | 21 +-- datafusion/sql/tests/sql_integration.rs | 2 +- 29 files changed, 323 insertions(+), 387 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 7f1a12e9cd..cdba41a0d1 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::collections::HashMap; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::mem::{size_of, size_of_val}; use std::sync::{ atomic::{AtomicBool, Ordering}, @@ -55,7 +55,7 @@ use datafusion_common::{assert_contains, exec_datafusion_err}; use datafusion_common::{cast::as_primitive_array, exec_err}; use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr, + col, create_udaf, function::AccumulatorArgs, udf_equals_hash, AggregateUDFImpl, Expr, GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition, }; use datafusion_functions_aggregate::average::AvgAccumulator; @@ -778,7 +778,7 @@ impl Accumulator for FirstSelector { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] struct TestGroupsAccumulator { signature: Signature, result: u64, @@ -817,20 +817,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(Box::new(self.clone())) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::<TestGroupsAccumulator>() { - self.result == other.result && self.signature == other.signature - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.signature.hash(hasher); - self.result.hash(hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } impl Accumulator for TestGroupsAccumulator { @@ -902,6 +889,32 @@ struct MetadataBasedAggregateUdf { metadata: HashMap<String, String>, } +impl PartialEq for MetadataBasedAggregateUdf { + fn eq(&self, other: &Self) -> bool { + let Self { + name, + signature, + metadata, + } = self; + name == &other.name + && signature == &other.signature + && metadata == &other.metadata + } +} +impl Eq for MetadataBasedAggregateUdf {} +impl Hash for MetadataBasedAggregateUdf { + fn hash<H: Hasher>(&self, state: &mut H) { + let Self { + name, + signature, + metadata: _, // unhashable + } = self; + std::any::type_name::<Self>().hash(state); + name.hash(state); + signature.hash(state); + } +} + impl MetadataBasedAggregateUdf { fn new(metadata: HashMap<String, String>) -> Self { // The name we return must be unique. Otherwise we will not call distinct @@ -958,32 +971,7 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf { })) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - metadata, - } = self; - name == &other.name - && signature == &other.signature - && metadata == &other.metadata - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - metadata: _, // unhashable - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 32c2f1d302..bf7f58d51b 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -181,7 +181,7 @@ async fn scalar_udf() -> Result<()> { Ok(()) } -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] struct Simple0ArgsScalarUDF { name: String, signature: Signature, @@ -492,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { } /// Volatile UDF that should append a different value to each row -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AddIndexToStringVolatileScalarUDF { name: String, signature: Signature, @@ -941,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory { // // it also defines custom [ScalarUDFImpl::simplify()] // to replace ScalarUDF expression with one instance contains. -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct ScalarFunctionWrapper { name: String, expr: Expr, @@ -1221,6 +1221,7 @@ impl PartialEq for MyRegexUdf { signature == &other.signature && regex.as_str() == other.regex.as_str() } } +impl Eq for MyRegexUdf {} impl Hash for MyRegexUdf { fn hash<H: Hasher>(&self, state: &mut H) { @@ -1380,7 +1381,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB ctx.sql(sql).await?.collect().await } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] struct MetadataBasedUdf { name: String, signature: Signature, diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 04ed7dc125..f0ff0cbb22 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -525,7 +525,7 @@ impl OddCounter { } fn register(ctx: &mut SessionContext, test_state: Arc<TestState>) { - #[derive(Debug, Clone, PartialEq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimpleWindowUDF { signature: Signature, test_state: PtrEq<Arc<TestState>>, diff --git a/datafusion/doc/src/lib.rs b/datafusion/doc/src/lib.rs index ca74c3b06d..c86a40ece2 100644 --- a/datafusion/doc/src/lib.rs +++ b/datafusion/doc/src/lib.rs @@ -39,7 +39,7 @@ /// thus all text should be in English. /// /// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Documentation { /// The section in the documentation where the UDF will be documented pub doc_section: DocSection, @@ -158,7 +158,7 @@ impl Documentation { } } -#[derive(Debug, Clone, PartialEq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DocSection { /// True to include this doc section in the public /// documentation, false otherwise diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 4b5a55d90c..ad07d0690e 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -69,6 +69,7 @@ impl PartialEq for AsyncScalarUDF { arc_ptr_eq(inner, &other.inner) } } +impl Eq for AsyncScalarUDF {} impl Hash for AsyncScalarUDF { fn hash<H: Hasher>(&self, state: &mut H) { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9e2285f7c0..6e5cd068b3 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -45,7 +45,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::ops::Not; use std::sync::Arc; @@ -403,7 +403,7 @@ pub fn create_udf( /// Implements [`ScalarUDFImpl`] for functions that have a single signature and /// return type. -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] pub struct SimpleScalarUDF { name: String, signature: Signature, @@ -511,11 +511,12 @@ pub fn create_udaf( /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. +#[derive(PartialEq, Eq, Hash)] pub struct SimpleAggregateUDF { name: String, signature: Signature, return_type: DataType, - accumulator: AccumulatorFactoryFunction, + accumulator: PtrEq<AccumulatorFactoryFunction>, state_fields: Vec<FieldRef>, } @@ -547,7 +548,7 @@ impl SimpleAggregateUDF { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -566,7 +567,7 @@ impl SimpleAggregateUDF { name, signature, return_type, - accumulator, + accumulator: accumulator.into(), state_fields, } } @@ -600,41 +601,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.state_fields.clone()) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - name, - signature, - return_type, - accumulator, - state_fields, - } = self; - name == &other.name - && signature == &other.signature - && return_type == &other.return_type - && Arc::ptr_eq(accumulator, &other.accumulator) - && state_fields == &other.state_fields - } - - fn hash_value(&self) -> u64 { - let Self { - name, - signature, - return_type, - accumulator, - state_fields, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - name.hash(&mut hasher); - signature.hash(&mut hasher); - return_type.hash(&mut hasher); - Arc::as_ptr(accumulator).hash(&mut hasher); - state_fields.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// Creates a new UDWF with a specific signature, state type and return type. @@ -661,7 +628,7 @@ pub fn create_udwf( /// Implements [`WindowUDFImpl`] for functions that have a single signature and /// return type. -#[derive(PartialEq, Hash)] +#[derive(PartialEq, Eq, Hash)] pub struct SimpleWindowUDF { name: String, signature: Signature, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 636d7aac59..b4ad838721 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,6 +71,7 @@ pub mod ptr_eq; pub mod test; pub mod tree_node; pub mod type_coercion; +pub mod udf_eq; pub mod utils; pub mod var_provider; pub mod window_frame; diff --git a/datafusion/expr/src/ptr_eq.rs b/datafusion/expr/src/ptr_eq.rs index 5a177b266d..c85b3d9950 100644 --- a/datafusion/expr/src/ptr_eq.rs +++ b/datafusion/expr/src/ptr_eq.rs @@ -34,7 +34,7 @@ pub fn arc_ptr_hash<T: ?Sized>(a: &Arc<T>, hasher: &mut impl Hasher) { std::ptr::hash(Arc::as_ptr(a), hasher) } -/// A wrapper around a pointer that implements `PartialEq` and `Hash` comparing +/// A wrapper around a pointer that implements `Eq` and `Hash` comparing /// the underlying pointer address. #[derive(Clone)] #[allow(private_bounds)] // This is so that PtrEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. @@ -48,6 +48,7 @@ where arc_ptr_eq(&self.0, &other.0) } } +impl<T> Eq for PtrEq<Arc<T>> where T: ?Sized {} impl<T> Hash for PtrEq<Arc<T>> where diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 984c21d581..bd72801372 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -38,9 +38,10 @@ use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; use crate::groups_accumulator::GroupsAccumulator; +use crate::udf_eq::UdfEq; use crate::utils::format_state_name; use crate::utils::AggregateOrderSensitivity; -use crate::{expr_vec_fmt, Accumulator, Expr}; +use crate::{expr_vec_fmt, udf_equals_hash, Accumulator, Expr}; use crate::{Documentation, Signature}; /// Logical representation of a user-defined [aggregate function] (UDAF). @@ -1037,9 +1038,9 @@ pub enum ReversedUDAF { /// AggregateUDF that adds an alias to the underlying function. It is better to /// implement [`AggregateUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedAggregateUDFImpl { - inner: Arc<dyn AggregateUDFImpl>, + inner: UdfEq<Arc<dyn AggregateUDFImpl>>, aliases: Vec<String>, } @@ -1051,7 +1052,10 @@ impl AliasedAggregateUDFImpl { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } @@ -1111,7 +1115,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { .map(|udf| { udf.map(|udf| { Arc::new(AliasedAggregateUDFImpl { - inner: udf, + inner: udf.into(), aliases: self.aliases.clone(), }) as Arc<dyn AggregateUDFImpl> }) @@ -1134,20 +1138,7 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.coerce_types(arg_types) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() { - self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases - } else { - false - } - } - - fn hash_value(&self) -> u64 { - let hasher = &mut DefaultHasher::new(); - self.inner.hash_value().hash(hasher); - self.aliases.hash(hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); fn is_descending(&self) -> Option<bool> { self.inner.is_descending() diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 40e0da2678..272e131a83 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,6 +21,7 @@ use crate::async_udf::AsyncScalarUDF; use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; +use crate::udf_eq::UdfEq; use crate::{udf_equals_hash, ColumnarValue, Documentation, Expr, Signature}; use arrow::datatypes::{DataType, Field, FieldRef}; use datafusion_common::config::ConfigOptions; @@ -743,27 +744,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// ScalarUDF that adds an alias to the underlying function. It is better to /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedScalarUDFImpl { - inner: Arc<dyn ScalarUDFImpl>, + inner: UdfEq<Arc<dyn ScalarUDFImpl>>, aliases: Vec<String>, } -impl PartialEq for AliasedScalarUDFImpl { - fn eq(&self, other: &Self) -> bool { - let Self { inner, aliases } = self; - inner.equals(other.inner.as_ref()) && aliases == &other.aliases - } -} - -impl Hash for AliasedScalarUDFImpl { - fn hash<H: Hasher>(&self, state: &mut H) { - let Self { inner, aliases } = self; - inner.hash_value().hash(state); - aliases.hash(state); - } -} - impl AliasedScalarUDFImpl { pub fn new( inner: Arc<dyn ScalarUDFImpl>, @@ -771,7 +757,10 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } + Self { + inner: inner.into(), + aliases, + } } } diff --git a/datafusion/expr/src/udf_eq.rs b/datafusion/expr/src/udf_eq.rs new file mode 100644 index 0000000000..1871aab3fd --- /dev/null +++ b/datafusion/expr/src/udf_eq.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{AggregateUDFImpl, ScalarUDFImpl, WindowUDFImpl}; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +use std::sync::Arc; + +/// A wrapper around a pointer to UDF that implements `Eq` and `Hash` delegating to +/// corresponding methods on the UDF trait. +#[derive(Clone)] +#[allow(private_bounds)] // This is so that UdfEq can only be used with allowed pointer types (e.g. Arc), without allowing misuse. +pub struct UdfEq<Ptr: UdfPointer>(Ptr); + +impl<Ptr> PartialEq for UdfEq<Ptr> +where + Ptr: UdfPointer, +{ + fn eq(&self, other: &Self) -> bool { + self.0.equals(&other.0) + } +} +impl<Ptr> Eq for UdfEq<Ptr> where Ptr: UdfPointer {} +impl<Ptr> Hash for UdfEq<Ptr> +where + Ptr: UdfPointer, +{ + fn hash<H: Hasher>(&self, state: &mut H) { + self.0.hash_value().hash(state); + } +} + +impl<Ptr> From<Ptr> for UdfEq<Ptr> +where + Ptr: UdfPointer, +{ + fn from(ptr: Ptr) -> Self { + UdfEq(ptr) + } +} + +impl<Ptr> Debug for UdfEq<Ptr> +where + Ptr: UdfPointer + Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl<Ptr> Deref for UdfEq<Ptr> +where + Ptr: UdfPointer, +{ + type Target = Ptr; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +trait UdfPointer: Deref { + fn equals(&self, other: &Self::Target) -> bool; + fn hash_value(&self) -> u64; +} + +macro_rules! impl_for_udf_eq { + ($udf:ty) => { + impl UdfPointer for Arc<$udf> { + fn equals(&self, other: &$udf) -> bool { + self.as_ref().equals(other) + } + + fn hash_value(&self) -> u64 { + self.as_ref().hash_value() + } + } + }; +} + +impl_for_udf_eq!(dyn AggregateUDFImpl + '_); +impl_for_udf_eq!(dyn ScalarUDFImpl + '_); +impl_for_udf_eq!(dyn WindowUDFImpl + '_); + +#[cfg(test)] +mod tests { + use super::*; + use crate::ScalarFunctionArgs; + use arrow::datatypes::DataType; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::{Signature, Volatility}; + use std::any::Any; + use std::hash::DefaultHasher; + + #[derive(Debug)] + struct TestScalarUDF { + signature: Signature, + name: &'static str, + } + impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result<DataType> { + unimplemented!() + } + + fn invoke_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> datafusion_common::Result<ColumnarValue> { + unimplemented!() + } + } + + #[test] + pub fn test_eq_eq_wrapper() { + let signature = Signature::any(1, Volatility::Immutable); + + let a1: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let a2: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "a", + }); + let b: Arc<dyn ScalarUDFImpl> = Arc::new(TestScalarUDF { + signature: signature.clone(), + name: "b", + }); + + // Reflexivity + let wrapper = UdfEq(Arc::clone(&a1)); + assert_eq!(wrapper, wrapper); + + // Two wrappers around equal pointer + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a1))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a1)))); + + // Two wrappers around different pointers but equal in ScalarUDFImpl::equals sense + assert_eq!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&a2))); + assert_eq!(hash(UdfEq(Arc::clone(&a1))), hash(UdfEq(Arc::clone(&a2)))); + + // different functions (not equal) + assert_ne!(UdfEq(Arc::clone(&a1)), UdfEq(Arc::clone(&b))); + } + + fn hash<T: Hash>(value: T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 4be58c3a4a..032daa5c25 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -29,7 +29,7 @@ use std::{ use arrow::datatypes::{DataType, FieldRef}; use crate::expr::WindowFunction; -use crate::ptr_eq::PtrEq; +use crate::udf_eq::UdfEq; use crate::{ function::WindowFunctionSimplification, udf_equals_hash, Expr, PartitionEvaluator, Signature, @@ -479,9 +479,9 @@ impl PartialOrd for dyn WindowUDFImpl { /// WindowUDF that adds an alias to the underlying function. It is better to /// implement [`WindowUDFImpl`], which supports aliases, directly if possible. -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct AliasedWindowUDFImpl { - inner: PtrEq<Arc<dyn WindowUDFImpl>>, + inner: UdfEq<Arc<dyn WindowUDFImpl>>, aliases: Vec<String>, } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 874811665c..80ad0f8784 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1276,9 +1276,9 @@ pub fn collect_subquery_cols( /// # use datafusion_expr_common::signature::Signature; /// # use std::any::Any; /// -/// // Implementing PartialEq & Hash is a prerequisite for using this macro, +/// // Implementing Eq & Hash is a prerequisite for using this macro, /// // but the implementation can be derived. -/// #[derive(Debug, PartialEq, Hash)] +/// #[derive(Debug, PartialEq, Eq, Hash)] /// struct VarcharToTimestampTz { /// safe: bool, /// } @@ -1322,11 +1322,13 @@ macro_rules! udf_equals_hash { ($udf_type:tt) => { fn equals(&self, other: &dyn $udf_type) -> bool { use ::core::any::Any; - use ::core::cmp::PartialEq; + use ::core::cmp::{Eq, PartialEq}; let Some(other) = <dyn Any + 'static>::downcast_ref::<Self>(other.as_any()) else { return false; }; + fn assert_self_impls_eq<T: Eq>() {} + assert_self_impls_eq::<Self>(); PartialEq::eq(self, other) } @@ -1804,7 +1806,7 @@ mod tests { } } - #[derive(Debug, PartialEq, Hash)] + #[derive(Debug, PartialEq, Eq, Hash)] struct StatefulFunctionWithEqHashWithUdfEqualsHash { signature: Signature, state: bool, diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 66e1c28bb9..63f7d26544 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ }; use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::{Hash, Hasher}; use std::{ffi::c_void, sync::Arc}; use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; @@ -49,6 +49,7 @@ use crate::{ util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, volatility::FFI_Volatility, }; +use datafusion::logical_expr::udf_equals_hash; use prost::{DecodeError, Message}; mod accumulator; @@ -384,6 +385,19 @@ pub struct ForeignAggregateUDF { unsafe impl Send for ForeignAggregateUDF {} unsafe impl Sync for ForeignAggregateUDF {} +impl PartialEq for ForeignAggregateUDF { + fn eq(&self, other: &Self) -> bool { + // FFI_AggregateUDF cannot be compared, so identity equality is the best we can do. + std::ptr::eq(self, other) + } +} +impl Eq for ForeignAggregateUDF {} +impl Hash for ForeignAggregateUDF { + fn hash<H: Hasher>(&self, state: &mut H) { + std::ptr::hash(self, state) + } +} + impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { type Error = DataFusionError; @@ -554,33 +568,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - aliases, - udaf, - } = self; - signature == &other.signature - && aliases == &other.aliases - && std::ptr::eq(udaf, &other.udaf) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - aliases, - udaf, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - aliases.hash(&mut hasher); - std::ptr::hash(udaf, &mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[repr(C)] diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 4d634e0be2..8f877d44f8 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -304,6 +304,7 @@ impl PartialEq for ForeignScalarUDF { && signature == &other.signature } } +impl Eq for ForeignScalarUDF {} impl Hash for ForeignScalarUDF { fn hash<H: Hasher>(&self, state: &mut H) { diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs index ec1b6698f5..a5e18cdf1e 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -261,6 +261,7 @@ impl PartialEq for ForeignWindowUDF { std::ptr::eq(self, other) } } +impl Eq for ForeignWindowUDF {} impl Hash for ForeignWindowUDF { fn hash<H: Hasher>(&self, state: &mut H) { std::ptr::hash(self, state) diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 863ee15d89..36c005274d 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -39,8 +39,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, - TypeSignature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, + Signature, TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::tdigest::{ TDigest, TryIntoF64, DEFAULT_MAX_SIZE, @@ -102,6 +102,7 @@ pub fn approx_percentile_cont( description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileCont { signature: Signature, } @@ -336,6 +337,8 @@ impl AggregateUDFImpl for ApproxPercentileCont { fn documentation(&self) -> Option<&Documentation> { self.doc() } + + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index d30ea624ca..9a19f43a52 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -17,7 +17,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -30,7 +30,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, + TypeSignature, }; use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; @@ -100,6 +101,7 @@ pub fn approx_percentile_cont_with_weight( description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] +#[derive(PartialEq, Eq, Hash)] pub struct ApproxPercentileContWithWeight { signature: Signature, approx_percentile_cont: ApproxPercentileCont, @@ -237,29 +239,7 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - approx_percentile_cont, - } = self; - signature == &other.signature - && approx_percentile_cont.equals(&other.approx_percentile_cont) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - approx_percentile_cont, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - hasher.write_u64(approx_percentile_cont.hash_value()); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 8ca5d992a7..8d573580d4 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -20,7 +20,7 @@ use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::{size_of, size_of_val}; use ahash::RandomState; @@ -36,8 +36,8 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::INTEGERS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, - Signature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, + ReversedUDAF, Signature, Volatility, }; use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; @@ -211,7 +211,7 @@ impl Display for BitwiseOperationType { } /// [BitwiseOperation] struct encapsulates information about a bitwise operation. -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] struct BitwiseOperation { signature: Signature, /// `operation` indicates the type of bitwise operation to be performed. @@ -314,37 +314,7 @@ impl AggregateUDFImpl for BitwiseOperation { Some(self.documentation) } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - operation, - func_name, - documentation, - } = self; - signature == &other.signature - && operation == &other.operation - && func_name == &other.func_name - && documentation == &other.documentation - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - operation, - func_name, - documentation, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - operation.hash(&mut hasher); - func_name.hash(&mut hasher); - documentation.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } struct BitAndAccumulator<T: ArrowNumericType> { diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0856237d08..87f14ae634 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::Arc; @@ -45,8 +45,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, ExprFunctionExt, - GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, EmitTo, Expr, + ExprFunctionExt, GroupsAccumulator, ReversedUDAF, Signature, SortExpr, Volatility, }; use datafusion_functions_aggregate_common::utils::get_sort_options; use datafusion_macros::user_doc; @@ -89,6 +89,7 @@ pub fn last_value(expression: Expr, order_by: Vec<SortExpr>) -> Expr { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct FirstValue { signature: Signature, is_input_pre_ordered: bool, @@ -294,29 +295,7 @@ impl AggregateUDFImpl for FirstValue { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - is_input_pre_ordered, - } = self; - signature == &other.signature - && is_input_pre_ordered == &other.is_input_pre_ordered - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - is_input_pre_ordered, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - is_input_pre_ordered.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } // TODO: rename to PrimitiveGroupsAccumulator @@ -1029,6 +1008,7 @@ impl Accumulator for FirstValueAccumulator { ```"#, standard_argument(name = "expression",) )] +#[derive(PartialEq, Eq, Hash)] pub struct LastValue { signature: Signature, is_input_pre_ordered: bool, @@ -1238,29 +1218,7 @@ impl AggregateUDFImpl for LastValue { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - is_input_pre_ordered, - } = self; - signature == &other.signature - && is_input_pre_ordered == &other.is_input_pre_ordered - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - is_input_pre_ordered, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - is_input_pre_ordered.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// This accumulator is used when there is no ordering specified for the diff --git a/datafusion/functions-aggregate/src/regr.rs b/datafusion/functions-aggregate/src/regr.rs index f7e0f0a104..c8dde7aed6 100644 --- a/datafusion/functions-aggregate/src/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -34,11 +34,11 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::type_coercion::aggregates::NUMERICS; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use std::sync::{Arc, LazyLock}; @@ -59,6 +59,7 @@ make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); +#[derive(PartialEq, Eq, Hash)] pub struct Regr { signature: Signature, regr_type: RegrType, @@ -527,33 +528,7 @@ impl AggregateUDFImpl for Regr { self.regr_type.documentation() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - regr_type, - func_name, - } = self; - signature == &other.signature - && regr_type == &other.regr_type - && func_name == &other.func_name - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - regr_type, - func_name, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - regr_type.hash(&mut hasher); - func_name.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } /// `RegrAccumulator` is used to compute linear regression aggregate functions diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 2f9f1cac84..d0512b3815 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::align_of_val; use std::sync::Arc; @@ -31,8 +31,8 @@ use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::format_state_name; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, - Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, + Signature, Volatility, }; use datafusion_functions_aggregate_common::stats::StatsType; use datafusion_macros::user_doc; @@ -62,6 +62,7 @@ make_udaf_expr_and_func!( standard_argument(name = "expression",) )] /// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +#[derive(PartialEq, Eq, Hash)] pub struct Stddev { signature: Signature, alias: Vec<String>, @@ -155,22 +156,7 @@ impl AggregateUDFImpl for Stddev { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { signature, alias } = self; - signature == &other.signature && alias == &other.alias - } - - fn hash_value(&self) -> u64 { - let Self { signature, alias } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - alias.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } make_udaf_expr_and_func!( diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 56c5ee1aaa..7564572744 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -18,7 +18,7 @@ //! [`StringAgg`] accumulator for the `string_agg` function use std::any::Any; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; use std::mem::size_of_val; use crate::array_agg::ArrayAgg; @@ -29,7 +29,8 @@ use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, + udf_equals_hash, Accumulator, AggregateUDFImpl, Documentation, Signature, + TypeSignature, Volatility, }; use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; use datafusion_macros::user_doc; @@ -82,7 +83,7 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp ) )] /// STRING_AGG aggregate expression -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct StringAgg { signature: Signature, array_agg: ArrayAgg, @@ -182,28 +183,7 @@ impl AggregateUDFImpl for StringAgg { self.doc() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { - signature, - array_agg, - } = self; - signature == &other.signature && array_agg.equals(&other.array_agg) - } - - fn hash_value(&self) -> u64 { - let Self { - signature, - array_agg, - } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - hasher.write_u64(array_agg.hash_value()); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Debug)] diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs index 140b797514..8f9a1a7a72 100644 --- a/datafusion/functions-window/src/lead_lag.rs +++ b/datafusion/functions-window/src/lead_lag.rs @@ -120,7 +120,7 @@ impl WindowShiftKind { } /// window shift expression -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct WindowShift { signature: Signature, kind: WindowShiftKind, diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index 783e6e5652..2da2fae23d 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -94,7 +94,7 @@ impl NthValueKind { } } -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct NthValue { signature: Signature, kind: NthValueKind, diff --git a/datafusion/functions-window/src/rank.rs b/datafusion/functions-window/src/rank.rs index 4b29b6dac8..e026bdf594 100644 --- a/datafusion/functions-window/src/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -64,7 +64,7 @@ define_udwf_and_expr!( ); /// Rank calculates the rank in the window function with order by -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct Rank { name: String, signature: Signature, diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index dfe841fe36..d33cefb341 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -2181,7 +2181,7 @@ mod tests { }; use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; - use std::hash::{DefaultHasher, Hash, Hasher}; + use std::hash::Hash; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, @@ -4347,7 +4347,7 @@ mod tests { /// A Mock UDAF which defines `simplify` to be used in tests /// related to UDAF simplification - #[derive(Debug, Clone)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdaf { simplify: bool, } @@ -4406,20 +4406,7 @@ mod tests { } } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { simplify } = self; - simplify == &other.simplify - } - - fn hash_value(&self) -> u64 { - let Self { simplify } = self; - let mut hasher = DefaultHasher::new(); - simplify.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[test] @@ -4443,7 +4430,7 @@ mod tests { /// A Mock UDWF which defines `simplify` to be used in tests /// related to UDWF simplification - #[derive(Debug, Clone, PartialEq, Hash)] + #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct SimplifyMockUdwf { simplify: bool, } diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index eba227a84a..ee5005fdde 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -27,7 +27,7 @@ use datafusion_functions_window_common::field::WindowUDFFieldArgs; use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::any::Any; use std::fmt::Debug; -use std::hash::{DefaultHasher, Hash, Hasher}; +use std::hash::Hash; mod roundtrip_logical_plan; mod roundtrip_physical_plan; @@ -127,22 +127,7 @@ impl AggregateUDFImpl for MyAggregateUDF { unimplemented!() } - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { - let Some(other) = other.as_any().downcast_ref::<Self>() else { - return false; - }; - let Self { signature, result } = self; - signature == &other.signature && result == &other.result - } - - fn hash_value(&self) -> u64 { - let Self { signature, result } = self; - let mut hasher = DefaultHasher::new(); - std::any::type_name::<Self>().hash(&mut hasher); - signature.hash(&mut hasher); - result.hash(&mut hasher); - hasher.finish() - } + udf_equals_hash!(AggregateUDFImpl); } #[derive(Clone, PartialEq, ::prost::Message)] @@ -151,7 +136,7 @@ pub struct MyAggregateUdfNode { pub result: String, } -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] pub(in crate::cases) struct CustomUDWF { signature: Signature, payload: String, diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 2514404250..751254ff20 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -3312,7 +3312,7 @@ fn make_udf(name: &'static str, args: Vec<DataType>, return_type: DataType) -> S } /// Mocked UDF -#[derive(Debug, PartialEq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash)] struct DummyUDF { name: &'static str, signature: Signature, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org