This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a8348f711 UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525)
5a8348f711 is described below

commit 5a8348f7111b2b0d39f2bd3fd1b1534338113b9f
Author: Jay Zhan <[email protected]>
AuthorDate: Thu May 16 08:30:58 2024 +0800

    UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525)
    
    * extends args
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * reuse accumulator args
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix example
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion-examples/examples/advanced_udaf.rs      | 15 ++----
 .../examples/simplify_udaf_expression.rs           | 11 ++---
 .../tests/user_defined/user_defined_aggregates.rs  |  2 +-
 datafusion/expr/src/expr_fn.rs                     |  8 +--
 datafusion/expr/src/function.rs                    | 51 +++++++++++++------
 datafusion/expr/src/udaf.rs                        | 57 +++++++++++++---------
 datafusion/functions-aggregate/src/covariance.rs   | 22 +++------
 datafusion/functions-aggregate/src/first_last.rs   | 15 ++----
 .../src/simplify_expressions/expr_simplifier.rs    |  6 ++-
 .../physical-expr-common/src/aggregate/mod.rs      | 43 +++++++++++-----
 10 files changed, 128 insertions(+), 102 deletions(-)

diff --git a/datafusion-examples/examples/advanced_udaf.rs 
b/datafusion-examples/examples/advanced_udaf.rs
index 342a23b6e7..cf28447221 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -31,8 +31,8 @@ use datafusion::error::Result;
 use datafusion::prelude::*;
 use datafusion_common::{cast::as_float64_array, ScalarValue};
 use datafusion_expr::{
-    function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
-    GroupsAccumulator, Signature,
+    function::{AccumulatorArgs, StateFieldsArgs},
+    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
 };
 
 /// This example shows how to use the full AggregateUDFImpl API to implement a 
user
@@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf {
     }
 
     /// This is the description of the state. accumulator's state() must match 
the types here.
-    fn state_fields(
-        &self,
-        _name: &str,
-        value_type: DataType,
-        _ordering_fields: Vec<arrow_schema::Field>,
-    ) -> Result<Vec<arrow_schema::Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> 
Result<Vec<arrow_schema::Field>> {
         Ok(vec![
-            Field::new("prod", value_type, true),
+            Field::new("prod", args.return_type.clone(), true),
             Field::new("n", DataType::UInt32, true),
         ])
     }
 
     /// Tell DataFusion that this aggregate supports the more performant 
`GroupsAccumulator`
     /// which is used for cases when there are grouping columns in the query
-    fn groups_accumulator_supported(&self) -> bool {
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
         true
     }
 
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs 
b/datafusion-examples/examples/simplify_udaf_expression.rs
index 92deb20272..08b6bcab01 100644
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -17,7 +17,7 @@
 
 use arrow_schema::{Field, Schema};
 use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
-use datafusion_expr::function::AggregateFunctionSimplification;
+use datafusion_expr::function::{AggregateFunctionSimplification, 
StateFieldsArgs};
 use datafusion_expr::simplify::SimplifyInfo;
 
 use std::{any::Any, sync::Arc};
@@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf {
         unimplemented!("should not be invoked")
     }
 
-    fn state_fields(
-        &self,
-        _name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<arrow_schema::Field>,
-    ) -> Result<Vec<arrow_schema::Field>> {
+    fn state_fields(&self, _args: StateFieldsArgs) -> 
Result<Vec<arrow_schema::Field>> {
         unimplemented!("should not be invoked")
     }
 
-    fn groups_accumulator_supported(&self) -> bool {
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
         true
     }
 
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs 
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 8f02fb30b0..d199f04ba7 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
         panic!("accumulator shouldn't invoke");
     }
 
-    fn groups_accumulator_supported(&self) -> bool {
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
         true
     }
 
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 1d976a12cc..64763a9736 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -23,6 +23,7 @@ use crate::expr::{
 };
 use crate::function::{
     AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
+    StateFieldsArgs,
 };
 use crate::{
     aggregate_function, conditional_expressions::CaseBuilder, 
logical_plan::Subquery,
@@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
         (self.accumulator)(acc_args)
     }
 
-    fn state_fields(
-        &self,
-        _name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
         Ok(self.state_fields.clone())
     }
 }
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 4e4d77924a..714cfa1af6 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -19,7 +19,7 @@
 
 use crate::ColumnarValue;
 use crate::{Accumulator, Expr, PartitionEvaluator};
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{DataType, Field, Schema};
 use datafusion_common::Result;
 use std::sync::Arc;
 
@@ -41,11 +41,14 @@ pub type ReturnTypeFunction =
 /// [`AccumulatorArgs`] contains information about how an aggregate
 /// function was called, including the types of its arguments and any optional
 /// ordering expressions.
+#[derive(Debug)]
 pub struct AccumulatorArgs<'a> {
     /// The return type of the aggregate function.
     pub data_type: &'a DataType,
+
     /// The schema of the input arguments
     pub schema: &'a Schema,
+
     /// Whether to ignore nulls.
     ///
     /// SQL allows the user to specify `IGNORE NULLS`, for example:
@@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> {
     ///
     /// If no `ORDER BY` is specified, `sort_exprs`` will be empty.
     pub sort_exprs: &'a [Expr],
+
+    /// Whether the aggregate function is distinct.
+    ///
+    /// ```sql
+    /// SELECT COUNT(DISTINCT column1) FROM t;
+    /// ```
+    pub is_distinct: bool,
+
+    /// The input type of the aggregate function.
+    pub input_type: &'a DataType,
+
+    /// The number of arguments the aggregate function takes.
+    pub args_num: usize,
 }
 
-impl<'a> AccumulatorArgs<'a> {
-    pub fn new(
-        data_type: &'a DataType,
-        schema: &'a Schema,
-        ignore_nulls: bool,
-        sort_exprs: &'a [Expr],
-    ) -> Self {
-        Self {
-            data_type,
-            schema,
-            ignore_nulls,
-            sort_exprs,
-        }
-    }
+/// [`StateFieldsArgs`] contains information about the fields that an
+/// aggregate function's accumulator should have. Used for 
[`AggregateUDFImpl::state_fields`].
+///
+/// [`AggregateUDFImpl::state_fields`]: 
crate::udaf::AggregateUDFImpl::state_fields
+pub struct StateFieldsArgs<'a> {
+    /// The name of the aggregate function.
+    pub name: &'a str,
+
+    /// The input type of the aggregate function.
+    pub input_type: &'a DataType,
+
+    /// The return type of the aggregate function.
+    pub return_type: &'a DataType,
+
+    /// The ordering fields of the aggregate function.
+    pub ordering_fields: &'a [Field],
+
+    /// Whether the aggregate function is distinct.
+    pub is_distinct: bool,
 }
 
 /// Factory that returns an accumulator for the given aggregate function.
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 95121d78e7..4fd8d51679 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,7 +17,9 @@
 
 //! [`AggregateUDF`]: User Defined Aggregate Functions
 
-use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
+use crate::function::{
+    AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
+};
 use crate::groups_accumulator::GroupsAccumulator;
 use crate::utils::format_state_name;
 use crate::{Accumulator, Expr};
@@ -177,18 +179,13 @@ impl AggregateUDF {
     /// for more details.
     ///
     /// This is used to support multi-phase aggregations
-    pub fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
-        self.inner.state_fields(name, value_type, ordering_fields)
+    pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        self.inner.state_fields(args)
     }
 
     /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more 
details.
-    pub fn groups_accumulator_supported(&self) -> bool {
-        self.inner.groups_accumulator_supported()
+    pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        self.inner.groups_accumulator_supported(args)
     }
 
     /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
@@ -232,7 +229,7 @@ where
 /// # use arrow::datatypes::DataType;
 /// # use datafusion_common::{DataFusionError, plan_err, Result};
 /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
-/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, 
function::AccumulatorArgs};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, 
function::{AccumulatorArgs, StateFieldsArgs}};
 /// # use arrow::datatypes::Schema;
 /// # use arrow::datatypes::Field;
 /// #[derive(Debug, Clone)]
@@ -261,9 +258,9 @@ where
 ///    }
 ///    // This is the accumulator factory; DataFusion uses it to create new 
accumulators.
 ///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> { unimplemented!() }
-///    fn state_fields(&self, _name: &str, value_type: DataType, 
_ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
+///    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
 ///        Ok(vec![
-///             Field::new("value", value_type, true),
+///             Field::new("value", args.return_type.clone(), true),
 ///             Field::new("ordering", DataType::UInt32, true)
 ///        ])
 ///    }
@@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     /// The name of the fields must be unique within the query and thus should
     /// be derived from `name`. See [`format_state_name`] for a utility 
function
     /// to generate a unique name.
-    fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
-        let value_fields = vec![Field::new(
-            format_state_name(name, "value"),
-            value_type,
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let fields = vec![Field::new(
+            format_state_name(args.name, "value"),
+            args.return_type.clone(),
             true,
         )];
 
-        Ok(value_fields.into_iter().chain(ordering_fields).collect())
+        Ok(fields
+            .into_iter()
+            .chain(args.ordering_fields.to_vec())
+            .collect())
     }
 
     /// If the aggregate expression has a specialized
@@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     /// `Self::accumulator` for certain queries, such as when this aggregate is
     /// used as a window function or when there no GROUP BY columns in the
     /// query.
-    fn groups_accumulator_supported(&self) -> bool {
+    fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
         false
     }
 
@@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
     fn simplify(&self) -> Option<AggregateFunctionSimplification> {
         None
     }
+
+    /// Returns the reverse expression of the aggregate function.
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::NotSupported
+    }
+}
+
+pub enum ReversedUDAF {
+    /// The expression is the same as the original expression, like SUM, COUNT
+    Identical,
+    /// The expression does not support reverse calculation, like ArrayAgg
+    NotSupported,
+    /// The expression is different from the original expression
+    Reversed(Arc<dyn AggregateUDFImpl>),
 }
 
 /// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/functions-aggregate/src/covariance.rs 
b/datafusion/functions-aggregate/src/covariance.rs
index 1210e1529d..6f03b256fd 100644
--- a/datafusion/functions-aggregate/src/covariance.rs
+++ b/datafusion/functions-aggregate/src/covariance.rs
@@ -30,8 +30,10 @@ use datafusion_common::{
     ScalarValue,
 };
 use datafusion_expr::{
-    function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
-    utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, 
Volatility,
+    function::{AccumulatorArgs, StateFieldsArgs},
+    type_coercion::aggregates::NUMERICS,
+    utils::format_state_name,
+    Accumulator, AggregateUDFImpl, Signature, Volatility,
 };
 use datafusion_physical_expr_common::aggregate::stats::StatsType;
 
@@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample {
         Ok(DataType::Float64)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let name = args.name;
         Ok(vec![
             Field::new(format_state_name(name, "count"), DataType::UInt64, 
true),
             Field::new(format_state_name(name, "mean1"), DataType::Float64, 
true),
@@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation {
         Ok(DataType::Float64)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        _value_type: DataType,
-        _ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        let name = args.name;
         Ok(vec![
             Field::new(format_state_name(name, "count"), DataType::UInt64, 
true),
             Field::new(format_state_name(name, "mean1"), DataType::Float64, 
true),
diff --git a/datafusion/functions-aggregate/src/first_last.rs 
b/datafusion/functions-aggregate/src/first_last.rs
index e3b685e903..5d3d483440 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, 
get_arrayref_at_indices, get_row_at
 use datafusion_common::{
     arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
 };
-use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::type_coercion::aggregates::NUMERICS;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::{
@@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue {
         .map(|acc| 
Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
     }
 
-    fn state_fields(
-        &self,
-        name: &str,
-        value_type: DataType,
-        ordering_fields: Vec<Field>,
-    ) -> Result<Vec<Field>> {
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
         let mut fields = vec![Field::new(
-            format_state_name(name, "first_value"),
-            value_type,
+            format_state_name(args.name, "first_value"),
+            args.return_type.clone(),
             true,
         )];
-        fields.extend(ordering_fields);
+        fields.extend(args.ordering_fields.to_vec());
         fields.push(Field::new("is_set", DataType::Boolean, true));
         Ok(fields)
     }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 55052542a8..455d659fb2 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> 
Result<Expr> {
 mod tests {
     use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
     use datafusion_expr::{
-        function::AggregateFunctionSimplification, 
interval_arithmetic::Interval, *,
+        function::{AccumulatorArgs, AggregateFunctionSimplification},
+        interval_arithmetic::Interval,
+        *,
     };
     use std::{
         collections::HashMap,
@@ -3783,7 +3785,7 @@ mod tests {
             unimplemented!("not needed for tests")
         }
 
-        fn groups_accumulator_supported(&self) -> bool {
+        fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool 
{
             unimplemented!("not needed for testing")
         }
 
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs 
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 05641b373b..da24f335b2 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -20,6 +20,7 @@ pub mod utils;
 
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::type_coercion::aggregates::check_arg_count;
 use datafusion_expr::{
     function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, 
GroupsAccumulator,
@@ -74,6 +75,7 @@ pub fn create_aggregate_expr(
         ignore_nulls,
         ordering_fields,
         is_distinct,
+        input_type: input_exprs_types[0].clone(),
     }))
 }
 
@@ -166,6 +168,7 @@ pub struct AggregateFunctionExpr {
     ignore_nulls: bool,
     ordering_fields: Vec<Field>,
     is_distinct: bool,
+    input_type: DataType,
 }
 
 impl AggregateFunctionExpr {
@@ -191,11 +194,15 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
-        self.fun.state_fields(
-            self.name(),
-            self.data_type.clone(),
-            self.ordering_fields.clone(),
-        )
+        let args = StateFieldsArgs {
+            name: &self.name,
+            input_type: &self.input_type,
+            return_type: &self.data_type,
+            ordering_fields: &self.ordering_fields,
+            is_distinct: self.is_distinct,
+        };
+
+        self.fun.state_fields(args)
     }
 
     fn field(&self) -> Result<Field> {
@@ -203,12 +210,15 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        let acc_args = AccumulatorArgs::new(
-            &self.data_type,
-            &self.schema,
-            self.ignore_nulls,
-            &self.sort_exprs,
-        );
+        let acc_args = AccumulatorArgs {
+            data_type: &self.data_type,
+            schema: &self.schema,
+            ignore_nulls: self.ignore_nulls,
+            sort_exprs: &self.sort_exprs,
+            is_distinct: self.is_distinct,
+            input_type: &self.input_type,
+            args_num: self.args.len(),
+        };
 
         self.fun.accumulator(acc_args)
     }
@@ -273,7 +283,16 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn groups_accumulator_supported(&self) -> bool {
-        self.fun.groups_accumulator_supported()
+        let args = AccumulatorArgs {
+            data_type: &self.data_type,
+            schema: &self.schema,
+            ignore_nulls: self.ignore_nulls,
+            sort_exprs: &self.sort_exprs,
+            is_distinct: self.is_distinct,
+            input_type: &self.input_type,
+            args_num: self.args.len(),
+        };
+        self.fun.groups_accumulator_supported(args)
     }
 
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to