This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5a8348f711 UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525)
5a8348f711 is described below
commit 5a8348f7111b2b0d39f2bd3fd1b1534338113b9f
Author: Jay Zhan <[email protected]>
AuthorDate: Thu May 16 08:30:58 2024 +0800
UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` (#10525)
* extends args
Signed-off-by: jayzhan211 <[email protected]>
* reuse accumulator args
Signed-off-by: jayzhan211 <[email protected]>
* fix example
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion-examples/examples/advanced_udaf.rs | 15 ++----
.../examples/simplify_udaf_expression.rs | 11 ++---
.../tests/user_defined/user_defined_aggregates.rs | 2 +-
datafusion/expr/src/expr_fn.rs | 8 +--
datafusion/expr/src/function.rs | 51 +++++++++++++------
datafusion/expr/src/udaf.rs | 57 +++++++++++++---------
datafusion/functions-aggregate/src/covariance.rs | 22 +++------
datafusion/functions-aggregate/src/first_last.rs | 15 ++----
.../src/simplify_expressions/expr_simplifier.rs | 6 ++-
.../physical-expr-common/src/aggregate/mod.rs | 43 +++++++++++-----
10 files changed, 128 insertions(+), 102 deletions(-)
diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs
index 342a23b6e7..cf28447221 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -31,8 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
- function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
- GroupsAccumulator, Signature,
+ function::{AccumulatorArgs, StateFieldsArgs},
+ Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};
/// This example shows how to use the full AggregateUDFImpl API to implement a user
@@ -92,21 +92,16 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}
/// This is the description of the state. accumulator's state() must match the types here.
- fn state_fields(
- &self,
- _name: &str,
- value_type: DataType,
- _ordering_fields: Vec<arrow_schema::Field>,
- ) -> Result<Vec<arrow_schema::Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
Ok(vec![
- Field::new("prod", value_type, true),
+ Field::new("prod", args.return_type.clone(), true),
Field::new("n", DataType::UInt32, true),
])
}
/// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator`
/// which is used for cases when there are grouping columns in the query
- fn groups_accumulator_supported(&self) -> bool {
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs
index 92deb20272..08b6bcab01 100644
--- a/datafusion-examples/examples/simplify_udaf_expression.rs
+++ b/datafusion-examples/examples/simplify_udaf_expression.rs
@@ -17,7 +17,7 @@
use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
-use datafusion_expr::function::AggregateFunctionSimplification;
+use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;
use std::{any::Any, sync::Arc};
@@ -70,16 +70,11 @@ impl AggregateUDFImpl for BetterAvgUdaf {
unimplemented!("should not be invoked")
}
- fn state_fields(
- &self,
- _name: &str,
- _value_type: DataType,
- _ordering_fields: Vec<arrow_schema::Field>,
- ) -> Result<Vec<arrow_schema::Field>> {
+ fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}
- fn groups_accumulator_supported(&self) -> bool {
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 8f02fb30b0..d199f04ba7 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -725,7 +725,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
panic!("accumulator shouldn't invoke");
}
- fn groups_accumulator_supported(&self) -> bool {
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
true
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 1d976a12cc..64763a9736 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -23,6 +23,7 @@ use crate::expr::{
};
use crate::function::{
AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
+ StateFieldsArgs,
};
use crate::{
aggregate_function, conditional_expressions::CaseBuilder,
logical_plan::Subquery,
@@ -690,12 +691,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}
- fn state_fields(
- &self,
- _name: &str,
- _value_type: DataType,
- _ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
+ fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(self.state_fields.clone())
}
}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 4e4d77924a..714cfa1af6 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -19,7 +19,7 @@
use crate::ColumnarValue;
use crate::{Accumulator, Expr, PartitionEvaluator};
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Result;
use std::sync::Arc;
@@ -41,11 +41,14 @@ pub type ReturnTypeFunction =
/// [`AccumulatorArgs`] contains information about how an aggregate
/// function was called, including the types of its arguments and any optional
/// ordering expressions.
+#[derive(Debug)]
pub struct AccumulatorArgs<'a> {
/// The return type of the aggregate function.
pub data_type: &'a DataType,
+
/// The schema of the input arguments
pub schema: &'a Schema,
+
/// Whether to ignore nulls.
///
/// SQL allows the user to specify `IGNORE NULLS`, for example:
@@ -66,22 +69,40 @@ pub struct AccumulatorArgs<'a> {
///
/// If no `ORDER BY` is specified, `sort_exprs`` will be empty.
pub sort_exprs: &'a [Expr],
+
+ /// Whether the aggregate function is distinct.
+ ///
+ /// ```sql
+ /// SELECT COUNT(DISTINCT column1) FROM t;
+ /// ```
+ pub is_distinct: bool,
+
+ /// The input type of the aggregate function.
+ pub input_type: &'a DataType,
+
+ /// The number of arguments the aggregate function takes.
+ pub args_num: usize,
}
-impl<'a> AccumulatorArgs<'a> {
- pub fn new(
- data_type: &'a DataType,
- schema: &'a Schema,
- ignore_nulls: bool,
- sort_exprs: &'a [Expr],
- ) -> Self {
- Self {
- data_type,
- schema,
- ignore_nulls,
- sort_exprs,
- }
- }
+/// [`StateFieldsArgs`] contains information about the fields that an
+/// aggregate function's accumulator should have. Used for [`AggregateUDFImpl::state_fields`].
+///
+/// [`AggregateUDFImpl::state_fields`]: crate::udaf::AggregateUDFImpl::state_fields
+pub struct StateFieldsArgs<'a> {
+ /// The name of the aggregate function.
+ pub name: &'a str,
+
+ /// The input type of the aggregate function.
+ pub input_type: &'a DataType,
+
+ /// The return type of the aggregate function.
+ pub return_type: &'a DataType,
+
+ /// The ordering fields of the aggregate function.
+ pub ordering_fields: &'a [Field],
+
+ /// Whether the aggregate function is distinct.
+ pub is_distinct: bool,
}
/// Factory that returns an accumulator for the given aggregate function.
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 95121d78e7..4fd8d51679 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,7 +17,9 @@
//! [`AggregateUDF`]: User Defined Aggregate Functions
-use crate::function::{AccumulatorArgs, AggregateFunctionSimplification};
+use crate::function::{
+ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
+};
use crate::groups_accumulator::GroupsAccumulator;
use crate::utils::format_state_name;
use crate::{Accumulator, Expr};
@@ -177,18 +179,13 @@ impl AggregateUDF {
/// for more details.
///
/// This is used to support multi-phase aggregations
- pub fn state_fields(
- &self,
- name: &str,
- value_type: DataType,
- ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
- self.inner.state_fields(name, value_type, ordering_fields)
+ pub fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ self.inner.state_fields(args)
}
/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details.
- pub fn groups_accumulator_supported(&self) -> bool {
- self.inner.groups_accumulator_supported()
+ pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+ self.inner.groups_accumulator_supported(args)
}
/// See [`AggregateUDFImpl::create_groups_accumulator`] for more details.
@@ -232,7 +229,7 @@ where
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
-/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}};
/// # use arrow::datatypes::Schema;
/// # use arrow::datatypes::Field;
/// #[derive(Debug, Clone)]
@@ -261,9 +258,9 @@ where
/// }
/// // This is the accumulator factory; DataFusion uses it to create new accumulators.
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
-/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
+/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
/// Ok(vec![
-/// Field::new("value", value_type, true),
+/// Field::new("value", args.return_type.clone(), true),
/// Field::new("ordering", DataType::UInt32, true)
/// ])
/// }
@@ -319,19 +316,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// The name of the fields must be unique within the query and thus should
/// be derived from `name`. See [`format_state_name`] for a utility function
/// to generate a unique name.
- fn state_fields(
- &self,
- name: &str,
- value_type: DataType,
- ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
- let value_fields = vec![Field::new(
- format_state_name(name, "value"),
- value_type,
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ let fields = vec![Field::new(
+ format_state_name(args.name, "value"),
+ args.return_type.clone(),
true,
)];
- Ok(value_fields.into_iter().chain(ordering_fields).collect())
+ Ok(fields
+ .into_iter()
+ .chain(args.ordering_fields.to_vec())
+ .collect())
}
/// If the aggregate expression has a specialized
@@ -344,7 +339,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// `Self::accumulator` for certain queries, such as when this aggregate is
/// used as a window function or when there no GROUP BY columns in the
/// query.
- fn groups_accumulator_supported(&self) -> bool {
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
false
}
@@ -389,6 +384,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
None
}
+
+ /// Returns the reverse expression of the aggregate function.
+ fn reverse_expr(&self) -> ReversedUDAF {
+ ReversedUDAF::NotSupported
+ }
+}
+
+pub enum ReversedUDAF {
+ /// The expression is the same as the original expression, like SUM, COUNT
+ Identical,
+ /// The expression does not support reverse calculation, like ArrayAgg
+ NotSupported,
+ /// The expression is different from the original expression
+ Reversed(Arc<dyn AggregateUDFImpl>),
}
/// AggregateUDF that adds an alias to the underlying function. It is better to
diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs
index 1210e1529d..6f03b256fd 100644
--- a/datafusion/functions-aggregate/src/covariance.rs
+++ b/datafusion/functions-aggregate/src/covariance.rs
@@ -30,8 +30,10 @@ use datafusion_common::{
ScalarValue,
};
use datafusion_expr::{
- function::AccumulatorArgs, type_coercion::aggregates::NUMERICS,
- utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility,
+ function::{AccumulatorArgs, StateFieldsArgs},
+ type_coercion::aggregates::NUMERICS,
+ utils::format_state_name,
+ Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr_common::aggregate::stats::StatsType;
@@ -101,12 +103,8 @@ impl AggregateUDFImpl for CovarianceSample {
Ok(DataType::Float64)
}
- fn state_fields(
- &self,
- name: &str,
- _value_type: DataType,
- _ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
@@ -176,12 +174,8 @@ impl AggregateUDFImpl for CovariancePopulation {
Ok(DataType::Float64)
}
- fn state_fields(
- &self,
- name: &str,
- _value_type: DataType,
- _ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+ let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean1"), DataType::Float64, true),
diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs
index e3b685e903..5d3d483440 100644
--- a/datafusion/functions-aggregate/src/first_last.rs
+++ b/datafusion/functions-aggregate/src/first_last.rs
@@ -24,7 +24,7 @@ use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
-use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
@@ -147,18 +147,13 @@ impl AggregateUDFImpl for FirstValue {
.map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _)
}
- fn state_fields(
- &self,
- name: &str,
- value_type: DataType,
- ordering_fields: Vec<Field>,
- ) -> Result<Vec<Field>> {
+ fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let mut fields = vec![Field::new(
- format_state_name(name, "first_value"),
- value_type,
+ format_state_name(args.name, "first_value"),
+ args.return_type.clone(),
true,
)];
- fields.extend(ordering_fields);
+ fields.extend(args.ordering_fields.to_vec());
fields.push(Field::new("is_set", DataType::Boolean, true));
Ok(fields)
}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 55052542a8..455d659fb2 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1759,7 +1759,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{
- function::AggregateFunctionSimplification, interval_arithmetic::Interval, *,
+ function::{AccumulatorArgs, AggregateFunctionSimplification},
+ interval_arithmetic::Interval,
+ *,
};
use std::{
collections::HashMap,
@@ -3783,7 +3785,7 @@ mod tests {
unimplemented!("not needed for tests")
}
- fn groups_accumulator_supported(&self) -> bool {
+ fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {
unimplemented!("not needed for testing")
}
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 05641b373b..da24f335b2 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -20,6 +20,7 @@ pub mod utils;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
+use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
@@ -74,6 +75,7 @@ pub fn create_aggregate_expr(
ignore_nulls,
ordering_fields,
is_distinct,
+ input_type: input_exprs_types[0].clone(),
}))
}
@@ -166,6 +168,7 @@ pub struct AggregateFunctionExpr {
ignore_nulls: bool,
ordering_fields: Vec<Field>,
is_distinct: bool,
+ input_type: DataType,
}
impl AggregateFunctionExpr {
@@ -191,11 +194,15 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn state_fields(&self) -> Result<Vec<Field>> {
- self.fun.state_fields(
- self.name(),
- self.data_type.clone(),
- self.ordering_fields.clone(),
- )
+ let args = StateFieldsArgs {
+ name: &self.name,
+ input_type: &self.input_type,
+ return_type: &self.data_type,
+ ordering_fields: &self.ordering_fields,
+ is_distinct: self.is_distinct,
+ };
+
+ self.fun.state_fields(args)
}
fn field(&self) -> Result<Field> {
@@ -203,12 +210,15 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- let acc_args = AccumulatorArgs::new(
- &self.data_type,
- &self.schema,
- self.ignore_nulls,
- &self.sort_exprs,
- );
+ let acc_args = AccumulatorArgs {
+ data_type: &self.data_type,
+ schema: &self.schema,
+ ignore_nulls: self.ignore_nulls,
+ sort_exprs: &self.sort_exprs,
+ is_distinct: self.is_distinct,
+ input_type: &self.input_type,
+ args_num: self.args.len(),
+ };
self.fun.accumulator(acc_args)
}
@@ -273,7 +283,16 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn groups_accumulator_supported(&self) -> bool {
- self.fun.groups_accumulator_supported()
+ let args = AccumulatorArgs {
+ data_type: &self.data_type,
+ schema: &self.schema,
+ ignore_nulls: self.ignore_nulls,
+ sort_exprs: &self.sort_exprs,
+ is_distinct: self.is_distinct,
+ input_type: &self.input_type,
+ args_num: self.args.len(),
+ };
+ self.fun.groups_accumulator_supported(args)
}
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]