alamb commented on code in PR #8733:
URL: https://github.com/apache/arrow-datafusion/pull/8733#discussion_r1445312366


##########
datafusion-examples/examples/advanced_udaf.rs:
##########
@@ -0,0 +1,228 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, Float32Array},
+    record_batch::RecordBatch,
+};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{cast::as_float64_array, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
+
+/// This example shows how to use the full AggregateUDFImpl API to implement a 
user
+/// defined aggregate function. As in the `simple_udaf.rs` example, this 
struct implements
+/// a function `accumulator` that returns the `Accumulator` instance.
+///
+/// To do so, we must implement the `AggregateUDFImpl` trait.
+#[derive(Debug, Clone)]
+struct GeoMeanUdf {
+    signature: Signature,
+}
+
+impl GeoMeanUdf {
+    /// Create a new instance of the GeoMeanUdf struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take one arguments of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and will always return the 
same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl AggregateUDFImpl for GeoMeanUdf {
+    /// We implement as_any so that we can downcast the AggregateUDFImpl trait 
object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "geo_mean"
+    }
+
+    /// Return the "signature" of this function -- namely that types of 
arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the accumulator factory; DataFusion uses it to create new 
accumulators.
+    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+        Ok(Box::new(GeometricMean::new()))
+    }
+
+    /// This is the description of the state. accumulator's state() must match 
the types here.
+    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+        Ok(vec![DataType::Float64, DataType::UInt32])
+    }
+}
+
+/// A UDAF has state across multiple rows, and thus we require a `struct` with 
that state.
+#[derive(Debug)]
+struct GeometricMean {
+    n: u32,
+    prod: f64,
+}
+
+impl GeometricMean {
+    // how the struct is initialized
+    pub fn new() -> Self {
+        GeometricMean { n: 0, prod: 1.0 }
+    }
+}
+
+// UDAFs are built using the trait `Accumulator`, that offers DataFusion the 
necessary functions
+// to use them.
+impl Accumulator for GeometricMean {
+    // This function serializes our state to `ScalarValue`, which DataFusion 
uses
+    // to pass this state between execution stages.
+    // Note that this can be arbitrary data.
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![
+            ScalarValue::from(self.prod),
+            ScalarValue::from(self.n),
+        ])
+    }
+
+    // DataFusion expects this function to return the final value of this 
aggregator.
+    // in this case, this is the formula of the geometric mean
+    fn evaluate(&self) -> Result<ScalarValue> {
+        let value = self.prod.powf(1.0 / self.n as f64);
+        Ok(ScalarValue::from(value))
+    }
+
+    // DataFusion calls this function to update the accumulator's state for a 
batch
+    // of inputs rows. In this case the product is updated with values from 
the first column
+    // and the count is updated based on the row count
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = &values[0];
+        (0..arr.len()).try_for_each(|index| {
+            let v = ScalarValue::try_from_array(arr, index)?;
+
+            if let ScalarValue::Float64(Some(value)) = v {
+                self.prod *= value;
+                self.n += 1;
+            } else {
+                unreachable!("")
+            }
+            Ok(())
+        })
+    }
+
+    // Optimization hint: this trait also supports `update_batch` and 
`merge_batch`,

Review Comment:
   this optimization hint seems out of place -- I think this comment should say something more like
   
   ```rust
   // Merge the output of `Self::state()` from other instances of this accumulator
   // into this accumulator's state
   ```



##########
datafusion/proto/tests/cases/roundtrip_physical_plan.rs:
##########
@@ -374,18 +374,24 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
         }
     }
 
-    let rt_func: ReturnTypeFunction = Arc::new(move |_| 
Ok(Arc::new(DataType::Int64)));
+    let return_type = DataType::Int64;
     let accumulator: AccumulatorFactoryFunction = Arc::new(|_| 
Ok(Box::new(Example)));
-    let st_func: StateTypeFunction =
-        Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64])));
-
-    let udaf = AggregateUDF::new(
-        "example",
-        &Signature::exact(vec![DataType::Int64], Volatility::Immutable),
-        &rt_func,
-        &accumulator,
-        &st_func,
-    );
+    let state_type = vec![DataType::Int64];
+
+    // let udaf = AggregateUDF::new(

Review Comment:
   Did you mean to leave this commented out?



##########
datafusion-examples/examples/advanced_udaf.rs:
##########
@@ -0,0 +1,228 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, Float32Array},
+    record_batch::RecordBatch,
+};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{cast::as_float64_array, ScalarValue};
+use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
+
+/// This example shows how to use the full AggregateUDFImpl API to implement a 
user
+/// defined aggregate function. As in the `simple_udaf.rs` example, this 
struct implements
+/// a function `accumulator` that returns the `Accumulator` instance.
+///
+/// To do so, we must implement the `AggregateUDFImpl` trait.
+#[derive(Debug, Clone)]
+struct GeoMeanUdf {
+    signature: Signature,
+}
+
+impl GeoMeanUdf {
+    /// Create a new instance of the GeoMeanUdf struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take one arguments of type f64
+                vec![DataType::Float64],
+                // this function is deterministic and will always return the 
same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl AggregateUDFImpl for GeoMeanUdf {
+    /// We implement as_any so that we can downcast the AggregateUDFImpl trait 
object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "geo_mean"
+    }
+
+    /// Return the "signature" of this function -- namely that types of 
arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the accumulator factory; DataFusion uses it to create new 
accumulators.
+    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {

Review Comment:
   👍  
   
   While this "advanced" usage isn't much more advanced than the current "simple" UDAF, I think this PR now provides a home for (and a plausible way to implement) the full `GroupsAccumulator` API for UDAFs, which is the powerful, very performant API used by the built-in aggregate functions in DataFusion.



##########
datafusion/expr/src/udaf.rs:
##########
@@ -117,33 +132,176 @@ impl AggregateUDF {
     }
 
     /// Returns this function's name
+    ///
+    /// See [`AggregateUDFImpl::name`] for more details.
     pub fn name(&self) -> &str {
-        &self.name
+        self.inner.name()
     }
 
     /// Returns this function's signature (what input types are accepted)
+    ///
+    /// See [`AggregateUDFImpl::signature`] for more details.
     pub fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     /// Return the type of the function given its input types
+    ///
+    /// See [`AggregateUDFImpl::return_type`] for more details.
     pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        // Old API returns an Arc of the datatype for some reason
-        let res = (self.return_type)(args)?;
-        Ok(res.as_ref().clone())
+        self.inner.return_type(args)
     }
 
     /// Return an accumualator the given aggregate, given
     /// its return datatype.
     pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn 
Accumulator>> {
-        (self.accumulator)(return_type)
+        self.inner.accumulator(return_type)
     }
 
     /// Return the type of the intermediate state used by this aggregator, 
given
     /// its return datatype. Supports multi-phase aggregations
     pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        // old API returns an Arc for some reason, try and unwrap it here
+        self.inner.state_type(return_type)
+    }
+}
+
+impl<F> From<F> for AggregateUDF
+where
+    F: AggregateUDFImpl + Send + Sync + 'static,
+{
+    fn from(fun: F) -> Self {
+        Self::new_from_impl(fun)
+    }
+}
+
+/// Trait for implementing [`AggregateUDF`].
+///
+/// This trait exposes the full API for implementing user defined aggregate 
functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udaf.rs`] for a full example with complete implementation 
and
+/// [`AggregateUDF`] for other available options.
+///
+///
+/// [`advanced_udaf.rs`]: 
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
+/// #[derive(Debug, Clone)]
+/// struct GeoMeanUdf {
+///   signature: Signature
+/// };
+///
+/// impl GeoMeanUdf {
+///   fn new() -> Self {
+///     Self {
+///       signature: Signature::uniform(1, vec![DataType::Float64], 
Volatility::Immutable)
+///      }
+///   }
+/// }
+///
+/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
+/// impl AggregateUDFImpl for GeoMeanUdf {
+///    fn as_any(&self) -> &dyn Any { self }
+///    fn name(&self) -> &str { "geo_mean" }
+///    fn signature(&self) -> &Signature { &self.signature }
+///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+///      if !matches!(args.get(0), Some(&DataType::Float64)) {
+///        return plan_err!("add_one only accepts Float64 arguments");
+///      }
+///      Ok(DataType::Float64)
+///    }
+///    // This is the accumulator factory; DataFusion uses it to create new 
accumulators.
+///    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> 
{ unimplemented!() }
+///    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+///        Ok(vec![DataType::Float64, DataType::UInt32])
+///    }
+/// }
+///
+/// // Create a new AggregateUDF from the implementation
+/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
+///
+/// // Call the function `geo_mean(col)`
+/// let expr = geometric_mean.call(vec![col("a")]);
+/// ```
+pub trait AggregateUDFImpl: Debug + Send + Sync {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns the function's [`Signature`] for information about what input
+    /// types are accepted and the function's Volatility.
+    fn signature(&self) -> &Signature;
+
+    /// What [`DataType`] will be returned by this function, given the types of
+    /// the arguments
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+    /// This is the accumulator factory [`AccumulatorFactoryFunction`];
+    /// DataFusion uses it to create new accumulators.
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;

Review Comment:
   ```suggestion
       /// Return a new [`Accumulator`] that aggregates values for a specific
       /// group during query execution.
       fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
   ```



##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -899,23 +898,57 @@ mod test {
 
     #[test]
     fn aggregate() -> Result<()> {
+        #[derive(Debug, Clone)]

Review Comment:
   Maybe we could use `SimpleAggregateUDF` here as well, though I'm not sure whether that is any better or worse.



##########
datafusion/optimizer/src/analyzer/type_coercion.rs:
##########
@@ -902,19 +902,17 @@ mod test {
     #[test]
     fn agg_udaf_invalid_input() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
-        let state_type: StateTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, 
DataType::Float64])));
+        let return_type = DataType::Float64;
+        let state_type = vec![DataType::UInt64, DataType::Float64];
         let accumulator: AccumulatorFactoryFunction =
             Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
-        let my_avg = AggregateUDF::new(
-            "MY_AVG",
-            &Signature::uniform(1, vec![DataType::Float64], 
Volatility::Immutable),
-            &return_type,
-            &accumulator,
-            &state_type,
-        );
+        let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
+            "MY_AVG".to_string(),

Review Comment:
   Since `new_with_signature` takes `impl Into<String>`, passing the string literal directly (without `.to_string()`) should also work:
   
   ```suggestion
               "MY_AVG",
   ```



##########
datafusion/expr/src/udaf.rs:
##########
@@ -117,33 +132,176 @@ impl AggregateUDF {
     }
 
     /// Returns this function's name
+    ///
+    /// See [`AggregateUDFImpl::name`] for more details.
     pub fn name(&self) -> &str {
-        &self.name
+        self.inner.name()
     }
 
     /// Returns this function's signature (what input types are accepted)
+    ///
+    /// See [`AggregateUDFImpl::signature`] for more details.
     pub fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     /// Return the type of the function given its input types
+    ///
+    /// See [`AggregateUDFImpl::return_type`] for more details.
     pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        // Old API returns an Arc of the datatype for some reason
-        let res = (self.return_type)(args)?;
-        Ok(res.as_ref().clone())
+        self.inner.return_type(args)
     }
 
     /// Return an accumualator the given aggregate, given
     /// its return datatype.
     pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn 
Accumulator>> {
-        (self.accumulator)(return_type)
+        self.inner.accumulator(return_type)
     }
 
     /// Return the type of the intermediate state used by this aggregator, 
given
     /// its return datatype. Supports multi-phase aggregations
     pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        // old API returns an Arc for some reason, try and unwrap it here
+        self.inner.state_type(return_type)
+    }
+}
+
+impl<F> From<F> for AggregateUDF
+where
+    F: AggregateUDFImpl + Send + Sync + 'static,
+{
+    fn from(fun: F) -> Self {
+        Self::new_from_impl(fun)
+    }
+}
+
+/// Trait for implementing [`AggregateUDF`].
+///
+/// This trait exposes the full API for implementing user defined aggregate 
functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udaf.rs`] for a full example with complete implementation 
and
+/// [`AggregateUDF`] for other available options.
+///
+///
+/// [`advanced_udaf.rs`]: 
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
+/// #[derive(Debug, Clone)]
+/// struct GeoMeanUdf {
+///   signature: Signature
+/// };
+///
+/// impl GeoMeanUdf {
+///   fn new() -> Self {
+///     Self {
+///       signature: Signature::uniform(1, vec![DataType::Float64], 
Volatility::Immutable)
+///      }
+///   }
+/// }
+///
+/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf
+/// impl AggregateUDFImpl for GeoMeanUdf {
+///    fn as_any(&self) -> &dyn Any { self }
+///    fn name(&self) -> &str { "geo_mean" }
+///    fn signature(&self) -> &Signature { &self.signature }
+///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+///      if !matches!(args.get(0), Some(&DataType::Float64)) {
+///        return plan_err!("add_one only accepts Float64 arguments");
+///      }
+///      Ok(DataType::Float64)
+///    }
+///    // This is the accumulator factory; DataFusion uses it to create new 
accumulators.
+///    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> 
{ unimplemented!() }
+///    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
+///        Ok(vec![DataType::Float64, DataType::UInt32])
+///    }
+/// }
+///
+/// // Create a new AggregateUDF from the implementation
+/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new());
+///
+/// // Call the function `geo_mean(col)`
+/// let expr = geometric_mean.call(vec![col("a")]);
+/// ```
+pub trait AggregateUDFImpl: Debug + Send + Sync {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns the function's [`Signature`] for information about what input
+    /// types are accepted and the function's Volatility.
+    fn signature(&self) -> &Signature;
+
+    /// What [`DataType`] will be returned by this function, given the types of
+    /// the arguments
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+    /// This is the accumulator factory [`AccumulatorFactoryFunction`];
+    /// DataFusion uses it to create new accumulators.
+    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
+
+    /// This is the description of the state.
+    /// accumulator's state() must match the types here.

Review Comment:
   ```suggestion
       /// Return the types used to serialize the [`Accumulator`]'s intermediate state.
       /// See [`Accumulator::state()`] for more details
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to