This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new dfd4442da8 Make FirstValue an UDAF, Change
`AggregateUDFImpl::accumulator` signature, support ORDER BY for UDAFs (#9874)
dfd4442da8 is described below
commit dfd4442da8332116c7e1f9fcd9de7c6da856442b
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Apr 3 09:26:31 2024 +0800
Make FirstValue an UDAF, Change `AggregateUDFImpl::accumulator` signature,
support ORDER BY for UDAFs (#9874)
* first draft
Signed-off-by: jayzhan211 <[email protected]>
* clippy fix
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* use one vector for ordering req
Signed-off-by: jayzhan211 <[email protected]>
* add sort exprs to accumulator
Signed-off-by: jayzhan211 <[email protected]>
* clippy
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* fix doc test
Signed-off-by: jayzhan211 <[email protected]>
* change to ref
Signed-off-by: jayzhan211 <[email protected]>
* fix typo
Signed-off-by: jayzhan211 <[email protected]>
* fix doc
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* move schema and logical ordering exprs
Signed-off-by: jayzhan211 <[email protected]>
* remove redundant info
Signed-off-by: jayzhan211 <[email protected]>
* rename
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* add ignore nulls
Signed-off-by: jayzhan211 <[email protected]>
* fix conflict
Signed-off-by: jayzhan211 <[email protected]>
* backup
Signed-off-by: jayzhan211 <[email protected]>
* complete return_type
Signed-off-by: jayzhan211 <[email protected]>
* complete replace
Signed-off-by: jayzhan211 <[email protected]>
* split to first value udf
Signed-off-by: jayzhan211 <[email protected]>
* replace accumulator
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* small fix
Signed-off-by: jayzhan211 <[email protected]>
* remove ordering types
Signed-off-by: jayzhan211 <[email protected]>
* make state fields more flexible
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* replace done
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* rm comments
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* rm test1
Signed-off-by: jayzhan211 <[email protected]>
* fix state fields
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* args struct for accumulator
Signed-off-by: jayzhan211 <[email protected]>
* simplify
Signed-off-by: jayzhan211 <[email protected]>
* add sig
Signed-off-by: jayzhan211 <[email protected]>
* add comments
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* fix docs
Signed-off-by: jayzhan211 <[email protected]>
* use exprs utils
Signed-off-by: jayzhan211 <[email protected]>
* rm state type
Signed-off-by: jayzhan211 <[email protected]>
* add comment
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion-examples/examples/advanced_udaf.rs | 19 ++-
datafusion/core/src/execution/context/mod.rs | 20 ++++
datafusion/core/src/physical_planner.rs | 50 +++++---
.../tests/user_defined/user_defined_aggregates.rs | 20 ++--
datafusion/expr/src/expr.rs | 3 +-
datafusion/expr/src/expr_fn.rs | 128 ++++++++++++++++++---
datafusion/expr/src/function.rs | 41 ++++++-
datafusion/expr/src/tree_node/expr.rs | 1 +
datafusion/expr/src/udaf.rs | 101 ++++++++++------
datafusion/optimizer/src/analyzer/type_coercion.rs | 15 ++-
.../optimizer/src/common_subexpr_eliminate.rs | 4 +-
datafusion/physical-expr/src/aggregate/build_in.rs | 1 +
.../physical-expr/src/aggregate/first_last.rs | 58 +++++++++-
datafusion/physical-expr/src/aggregate/utils.rs | 2 +-
datafusion/physical-expr/src/lib.rs | 2 +
datafusion/physical-plan/src/aggregates/mod.rs | 2 +
datafusion/physical-plan/src/udaf.rs | 75 +++++++-----
datafusion/physical-plan/src/windows/mod.rs | 15 ++-
datafusion/proto/src/logical_plan/from_proto.rs | 1 +
datafusion/proto/src/physical_plan/mod.rs | 6 +-
.../proto/tests/cases/roundtrip_logical_plan.rs | 1 +
.../proto/tests/cases/roundtrip_physical_plan.rs | 6 +-
datafusion/sql/src/expr/function.rs | 11 +-
datafusion/substrait/src/logical_plan/consumer.rs | 2 +-
24 files changed, 450 insertions(+), 134 deletions(-)
diff --git a/datafusion-examples/examples/advanced_udaf.rs
b/datafusion-examples/examples/advanced_udaf.rs
index 10164a850b..342a23b6e7 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_physical_expr::NullState;
use std::{any::Any, sync::Arc};
@@ -30,7 +31,8 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::{cast::as_float64_array, ScalarValue};
use datafusion_expr::{
- Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
+ function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
+ GroupsAccumulator, Signature,
};
/// This example shows how to use the full AggregateUDFImpl API to implement a
user
@@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf {
/// is supported, DataFusion will use this row oriented
/// accumulator when the aggregate function is used as a window function
/// or when there are only aggregates (no GROUP BY columns) in the plan.
- fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
Ok(Box::new(GeometricMean::new()))
}
/// This is the description of the state. accumulator's state() must match
the types here.
- fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
- Ok(vec![DataType::Float64, DataType::UInt32])
+ fn state_fields(
+ &self,
+ _name: &str,
+ value_type: DataType,
+ _ordering_fields: Vec<arrow_schema::Field>,
+ ) -> Result<Vec<arrow_schema::Field>> {
+ Ok(vec![
+ Field::new("prod", value_type, true),
+ Field::new("n", DataType::UInt32, true),
+ ])
}
/// Tell DataFusion that this aggregate supports the more performant
`GroupsAccumulator`
@@ -191,7 +201,6 @@ impl Accumulator for GeometricMean {
// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
- use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/core/src/execution/context/mod.rs
b/datafusion/core/src/execution/context/mod.rs
index f8bf0d2ee1..4eaaf94ecf 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -69,11 +69,14 @@ use datafusion_common::{
OwnedTableReference, SchemaReference,
};
use datafusion_execution::registry::SerializerRegistry;
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::{create_first_value, Signature, Volatility};
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
+use datafusion_physical_expr::create_first_value_accumulator;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions,
SqlToRel},
@@ -82,6 +85,7 @@ use datafusion_sql::{
use async_trait::async_trait;
use chrono::{DateTime, Utc};
+use log::debug;
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;
use url::Url;
@@ -1451,6 +1455,22 @@ impl SessionState {
datafusion_functions_array::register_all(&mut new_self)
.expect("can not register array expressions");
+ let first_value = create_first_value(
+ "FIRST_VALUE",
+ Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
+ Arc::new(create_first_value_accumulator),
+ );
+
+ match new_self.register_udaf(Arc::new(first_value)) {
+ Ok(Some(existing_udaf)) => {
+ debug!("Overwrite existing UDAF: {}", existing_udaf.name());
+ }
+ Ok(None) => {}
+ Err(err) => {
+ panic!("Failed to register UDAF: {}", err);
+ }
+ }
+
new_self
}
/// Returns new [`SessionState`] using the provided
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 4733c1433a..275d639a7a 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
distinct,
args,
filter,
- order_by,
+ order_by: _,
null_treatment: _,
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(..) => {
create_function_physical_name(func_def.name(), *distinct, args)
}
AggregateFunctionDefinition::UDF(fun) => {
- // TODO: Add support for filter and order by in AggregateUDF
+ // TODO: Add support for filter in AggregateUDF
if filter.is_some() {
return exec_err!(
"aggregate expression with filter is not supported"
);
}
- if order_by.is_some() {
- return exec_err!(
- "aggregate expression with order_by is not supported"
- );
- }
+
let names = args
.iter()
.map(|e| create_physical_name(e, false))
@@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
)?),
None => None,
};
- let order_by = match order_by {
- Some(e) => Some(create_physical_sort_exprs(
- e,
- logical_input_schema,
- execution_props,
- )?),
- None => None,
- };
+
let ignore_nulls = null_treatment
.unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
== NullTreatment::IgnoreNulls;
let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
- let ordering_reqs = order_by.clone().unwrap_or(vec![]);
+ let physical_sort_exprs = match order_by {
+ Some(exprs) => Some(create_physical_sort_exprs(
+ exprs,
+ logical_input_schema,
+ execution_props,
+ )?),
+ None => None,
+ };
+ let ordering_reqs: Vec<PhysicalSortExpr> =
+ physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = aggregates::create_aggregate_expr(
fun,
*distinct,
@@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
name,
ignore_nulls,
)?;
- (agg_expr, filter, order_by)
+ (agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::UDF(fun) => {
+ let sort_exprs = order_by.clone().unwrap_or(vec![]);
+ let physical_sort_exprs = match order_by {
+ Some(exprs) => Some(create_physical_sort_exprs(
+ exprs,
+ logical_input_schema,
+ execution_props,
+ )?),
+ None => None,
+ };
+ let ordering_reqs: Vec<PhysicalSortExpr> =
+ physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
+ &sort_exprs,
+ &ordering_reqs,
physical_input_schema,
name,
- );
- (agg_expr?, filter, order_by)
+ ignore_nulls,
+ )?;
+ (agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::Name(_) => {
return internal_err!(
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index a58a8cf516..6085fca876 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -45,7 +45,8 @@ use datafusion::{
};
use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
use datafusion_expr::{
- create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
+ create_udaf, function::AccumulatorArgs, AggregateUDFImpl,
GroupsAccumulator,
+ SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
@@ -491,7 +492,7 @@ impl TimeSum {
// Returns the same type as its input
let return_type = timestamp_type.clone();
- let state_type = vec![timestamp_type.clone()];
+ let state_fields = vec![Field::new("sum", timestamp_type, true)];
let volatility = Volatility::Immutable;
@@ -505,7 +506,7 @@ impl TimeSum {
return_type,
volatility,
accumulator,
- state_type,
+ state_fields,
));
// register the selector as "time_sum"
@@ -591,6 +592,11 @@ impl FirstSelector {
fn register(ctx: &mut SessionContext) {
let return_type = Self::output_datatype();
let state_type = Self::state_datatypes();
+ let state_fields = state_type
+ .into_iter()
+ .enumerate()
+ .map(|(i, t)| Field::new(format!("{i}"), t, true))
+ .collect::<Vec<_>>();
// Possible input signatures
let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
@@ -607,7 +613,7 @@ impl FirstSelector {
Signature::one_of(signatures, volatility),
return_type,
accumulator,
- state_type,
+ state_fields,
));
// register the selector as "first"
@@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
Ok(DataType::UInt64)
}
- fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+ fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
// should use groups accumulator
panic!("accumulator shouldn't invoke");
}
- fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
- Ok(vec![DataType::UInt64])
- }
-
fn groups_accumulator_supported(&self) -> bool {
true
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 7ede4cd8ff..427c3fde7c 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -577,6 +577,7 @@ impl AggregateFunction {
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Option<Vec<Expr>>,
+ null_treatment: Option<NullTreatment>,
) -> Self {
Self {
func_def: AggregateFunctionDefinition::UDF(udf),
@@ -584,7 +585,7 @@ impl AggregateFunction {
distinct,
filter,
order_by,
- null_treatment: None,
+ null_treatment,
}
}
}
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index db9eb84c21..5294ca7545 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -21,15 +21,17 @@ use crate::expr::{
AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList,
InSubquery,
Placeholder, ScalarFunction, TryCast,
};
-use crate::function::PartitionEvaluatorFactory;
+use crate::function::{
+ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
+};
+use crate::udaf::format_state_name;
use crate::{
aggregate_function, built_in_function,
conditional_expressions::CaseBuilder,
- logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
- BuiltinScalarFunction, Expr, LogicalPlan, Operator,
ScalarFunctionImplementation,
- ScalarUDF, Signature, Volatility,
+ logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr,
LogicalPlan,
+ Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
};
use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF,
WindowUDFImpl};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field};
use datafusion_common::{Column, Result};
use std::any::Any;
use std::fmt::Debug;
@@ -695,16 +697,32 @@ pub fn create_udaf(
) -> AggregateUDF {
let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t|
t.as_ref().clone());
let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t|
t.as_ref().clone());
+ let state_fields = state_type
+ .into_iter()
+ .enumerate()
+ .map(|(i, t)| Field::new(format!("{i}"), t, true))
+ .collect::<Vec<_>>();
AggregateUDF::from(SimpleAggregateUDF::new(
name,
input_type,
return_type,
volatility,
accumulator,
- state_type,
+ state_fields,
))
}
+/// Creates a new UDAF with a specific signature, state type and return type.
+/// The signature and state type must match the `Accumulator`'s implementation.
+/// TODO: We plan to move aggregate function to its own crate. This function
will be deprecated then.
+pub fn create_first_value(
+ name: &str,
+ signature: Signature,
+ accumulator: AccumulatorFactoryFunction,
+) -> AggregateUDF {
+ AggregateUDF::from(FirstValue::new(name, signature, accumulator))
+}
+
/// Implements [`AggregateUDFImpl`] for functions that have a single signature
and
/// return type.
pub struct SimpleAggregateUDF {
@@ -712,7 +730,7 @@ pub struct SimpleAggregateUDF {
signature: Signature,
return_type: DataType,
accumulator: AccumulatorFactoryFunction,
- state_type: Vec<DataType>,
+ state_fields: Vec<Field>,
}
impl Debug for SimpleAggregateUDF {
@@ -734,7 +752,7 @@ impl SimpleAggregateUDF {
return_type: DataType,
volatility: Volatility,
accumulator: AccumulatorFactoryFunction,
- state_type: Vec<DataType>,
+ state_fields: Vec<Field>,
) -> Self {
let name = name.into();
let signature = Signature::exact(input_type, volatility);
@@ -743,7 +761,7 @@ impl SimpleAggregateUDF {
signature,
return_type,
accumulator,
- state_type,
+ state_fields,
}
}
@@ -752,7 +770,7 @@ impl SimpleAggregateUDF {
signature: Signature,
return_type: DataType,
accumulator: AccumulatorFactoryFunction,
- state_type: Vec<DataType>,
+ state_fields: Vec<Field>,
) -> Self {
let name = name.into();
Self {
@@ -760,7 +778,7 @@ impl SimpleAggregateUDF {
signature,
return_type,
accumulator,
- state_type,
+ state_fields,
}
}
}
@@ -782,12 +800,92 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
Ok(self.return_type.clone())
}
- fn accumulator(&self, arg: &DataType) -> Result<Box<dyn
crate::Accumulator>> {
- (self.accumulator)(arg)
+ fn accumulator(
+ &self,
+ acc_args: AccumulatorArgs,
+ ) -> Result<Box<dyn crate::Accumulator>> {
+ (self.accumulator)(acc_args)
+ }
+
+ fn state_fields(
+ &self,
+ _name: &str,
+ _value_type: DataType,
+ _ordering_fields: Vec<Field>,
+ ) -> Result<Vec<Field>> {
+ Ok(self.state_fields.clone())
+ }
+}
+
+pub struct FirstValue {
+ name: String,
+ signature: Signature,
+ accumulator: AccumulatorFactoryFunction,
+}
+
+impl Debug for FirstValue {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ f.debug_struct("FirstValue")
+ .field("name", &self.name)
+ .field("signature", &self.signature)
+ .field("accumulator", &"<FUNC>")
+ .finish()
+ }
+}
+
+impl FirstValue {
+ pub fn new(
+ name: impl Into<String>,
+ signature: Signature,
+ accumulator: AccumulatorFactoryFunction,
+ ) -> Self {
+ let name = name.into();
+ Self {
+ name,
+ signature,
+ accumulator,
+ }
+ }
+}
+
+impl AggregateUDFImpl for FirstValue {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn accumulator(
+ &self,
+ acc_args: AccumulatorArgs,
+ ) -> Result<Box<dyn crate::Accumulator>> {
+ (self.accumulator)(acc_args)
}
- fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
- Ok(self.state_type.clone())
+ fn state_fields(
+ &self,
+ name: &str,
+ value_type: DataType,
+ ordering_fields: Vec<Field>,
+ ) -> Result<Vec<Field>> {
+ let mut fields = vec![Field::new(
+ format_state_name(name, "first_value"),
+ value_type,
+ true,
+ )];
+ fields.extend(ordering_fields);
+ fields.push(Field::new("is_set", DataType::Boolean, true));
+ Ok(fields)
}
}
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index adf4dd3fef..7598c805ad 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -17,8 +17,9 @@
//! Function module contains typing and signature for built-in and user
defined functions.
-use crate::{Accumulator, ColumnarValue, PartitionEvaluator};
-use arrow::datatypes::DataType;
+use crate::ColumnarValue;
+use crate::{Accumulator, Expr, PartitionEvaluator};
+use arrow::datatypes::{DataType, Schema};
use datafusion_common::Result;
use std::sync::Arc;
@@ -37,10 +38,40 @@ pub type ScalarFunctionImplementation =
pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
-/// Factory that returns an accumulator for the given aggregate, given
-/// its return datatype.
+/// Arguments passed to create an accumulator
+pub struct AccumulatorArgs<'a> {
+ // default arguments
+ /// the return type of the function
+ pub data_type: &'a DataType,
+ /// the schema of the input arguments
+ pub schema: &'a Schema,
+ /// whether to ignore nulls
+ pub ignore_nulls: bool,
+
+ // ordering arguments
+ /// the expressions of `order by`, if no ordering is required, this will
be an empty slice
+ pub sort_exprs: &'a [Expr],
+}
+
+impl<'a> AccumulatorArgs<'a> {
+ pub fn new(
+ data_type: &'a DataType,
+ schema: &'a Schema,
+ ignore_nulls: bool,
+ sort_exprs: &'a [Expr],
+ ) -> Self {
+ Self {
+ data_type,
+ schema,
+ ignore_nulls,
+ sort_exprs,
+ }
+ }
+}
+
+/// Factory that returns an accumulator for the given aggregate function.
pub type AccumulatorFactoryFunction =
- Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
+ Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
/// Factory that creates a PartitionEvaluator for the given window
/// function
diff --git a/datafusion/expr/src/tree_node/expr.rs
b/datafusion/expr/src/tree_node/expr.rs
index 1c672851e9..0909d8f662 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -379,6 +379,7 @@ impl TreeNode for Expr {
false,
new_filter,
new_order_by,
+ null_treatment,
)))
}
AggregateFunctionDefinition::Name(_) => {
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index c46dd9cd3a..ba80f39dde 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,16 +17,16 @@
//! [`AggregateUDF`]: User Defined Aggregate Functions
+use crate::function::AccumulatorArgs;
use crate::groups_accumulator::GroupsAccumulator;
use crate::{Accumulator, Expr};
-use crate::{
- AccumulatorFactoryFunction, ReturnTypeFunction, Signature,
StateTypeFunction,
-};
-use arrow::datatypes::DataType;
+use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
+use arrow::datatypes::{DataType, Field};
use datafusion_common::{not_impl_err, Result};
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
+use std::vec;
/// Logical representation of a user-defined [aggregate function] (UDAF).
///
@@ -90,14 +90,12 @@ impl AggregateUDF {
signature: &Signature,
return_type: &ReturnTypeFunction,
accumulator: &AccumulatorFactoryFunction,
- state_type: &StateTypeFunction,
) -> Self {
Self::new_from_impl(AggregateUDFLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
accumulator: accumulator.clone(),
- state_type: state_type.clone(),
})
}
@@ -131,12 +129,14 @@ impl AggregateUDF {
/// This utility allows using the UDAF without requiring access to
/// the registry, such as with the DataFrame API.
pub fn call(&self, args: Vec<Expr>) -> Expr {
+ // TODO: Support distinct, filter, order by and null_treatment
Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
Arc::new(self.clone()),
args,
false,
None,
None,
+ None,
))
}
@@ -166,16 +166,21 @@ impl AggregateUDF {
self.inner.return_type(args)
}
- /// Return an accumulator the given aggregate, given
- /// its return datatype.
- pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn
Accumulator>> {
- self.inner.accumulator(return_type)
+ /// Return an accumulator for the given aggregate, given its [`AccumulatorArgs`]
+ pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ self.inner.accumulator(acc_args)
}
- /// Return the type of the intermediate state used by this aggregator,
given
- /// its return datatype. Supports multi-phase aggregations
- pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
- self.inner.state_type(return_type)
+ /// Return the fields of the intermediate state used by this aggregator,
given
+ /// its state name, value type and ordering fields. See
[`AggregateUDFImpl::state_fields`]
+ /// for more details. Supports multi-phase aggregations
+ pub fn state_fields(
+ &self,
+ name: &str,
+ value_type: DataType,
+ ordering_fields: Vec<Field>,
+ ) -> Result<Vec<Field>> {
+ self.inner.state_fields(name, value_type, ordering_fields)
}
/// See [`AggregateUDFImpl::groups_accumulator_supported`] for more
details.
@@ -213,8 +218,10 @@ where
/// # use std::any::Any;
/// # use arrow::datatypes::DataType;
/// # use datafusion_common::{DataFusionError, plan_err, Result};
-/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
-/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator,
function::AccumulatorArgs};
+/// # use arrow::datatypes::Schema;
+/// # use arrow::datatypes::Field;
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
@@ -240,9 +247,12 @@ where
/// Ok(DataType::Float64)
/// }
/// // This is the accumulator factory; DataFusion uses it to create new
accumulators.
-/// fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>>
{ unimplemented!() }
-/// fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
-/// Ok(vec![DataType::Float64, DataType::UInt32])
+/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> { unimplemented!() }
+/// fn state_fields(&self, _name: &str, value_type: DataType,
_ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
+/// Ok(vec![
+/// Field::new("value", value_type, true),
+/// Field::new("ordering", DataType::UInt32, true)
+/// ])
/// }
/// }
///
@@ -269,15 +279,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// Return a new [`Accumulator`] that aggregates values for a specific
/// group during query execution.
- fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
+ ///
+ /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`]
for more details.
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>>;
+
+ /// Return the fields of the intermediate state.
+ ///
+ /// name: the name of the state
+ ///
+ /// value_type: the type of the value, it should be the result of the
`return_type`
+ ///
+ /// ordering_fields: the fields used for ordering, empty if no ordering
expression is provided
+ fn state_fields(
+ &self,
+ name: &str,
+ value_type: DataType,
+ ordering_fields: Vec<Field>,
+ ) -> Result<Vec<Field>> {
+ let value_fields = vec![Field::new(
+ format_state_name(name, "value"),
+ value_type,
+ true,
+ )];
- /// Return the type used to serialize the [`Accumulator`]'s intermediate
state.
- /// See [`Accumulator::state()`] for more details
- fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>>;
+ Ok(value_fields.into_iter().chain(ordering_fields).collect())
+ }
/// If the aggregate expression has a specialized
/// [`GroupsAccumulator`] implementation. If this returns true,
- /// `[Self::create_groups_accumulator`] will be called.
+ /// `[Self::create_groups_accumulator]` will be called.
fn groups_accumulator_supported(&self) -> bool {
false
}
@@ -337,12 +367,8 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
self.inner.return_type(arg_types)
}
- fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
- self.inner.accumulator(arg)
- }
-
- fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
- self.inner.state_type(return_type)
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ self.inner.accumulator(acc_args)
}
fn aliases(&self) -> &[String] {
@@ -361,8 +387,6 @@ pub struct AggregateUDFLegacyWrapper {
return_type: ReturnTypeFunction,
/// actual implementation
accumulator: AccumulatorFactoryFunction,
- /// the accumulator's state's description as a function of the return type
- state_type: StateTypeFunction,
}
impl Debug for AggregateUDFLegacyWrapper {
@@ -394,12 +418,13 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
Ok(res.as_ref().clone())
}
- fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
- (self.accumulator)(arg)
+ fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn
Accumulator>> {
+ (self.accumulator)(acc_args)
}
+}
- fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
- let res = (self.state_type)(return_type)?;
- Ok(res.as_ref().clone())
- }
+/// returns the name of the state
+/// TODO: Remove duplicated function in physical-expr
+pub(crate) fn format_state_name(name: &str, state_name: &str) -> String {
+ format!("{name}[{state_name}]")
}
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index b7b7c4f20e..fbbd9a9456 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -366,7 +366,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(
- fun, new_expr, false, filter, order_by,
+ fun,
+ new_expr,
+ false,
+ filter,
+ order_by,
+ null_treatment,
),
)))
}
@@ -896,6 +901,7 @@ mod test {
false,
None,
None,
+ None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty)?);
let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n
EmptyRelation";
@@ -906,7 +912,6 @@ mod test {
fn agg_udaf_invalid_input() -> Result<()> {
let empty = empty();
let return_type = DataType::Float64;
- let state_type = vec![DataType::UInt64, DataType::Float64];
let accumulator: AccumulatorFactoryFunction =
Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
@@ -914,7 +919,10 @@ mod test {
Signature::uniform(1, vec![DataType::Float64],
Volatility::Immutable),
return_type,
accumulator,
- state_type,
+ vec![
+ Field::new("count", DataType::UInt64, true),
+ Field::new("avg", DataType::Float64, true),
+ ],
));
let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
Arc::new(my_avg),
@@ -922,6 +930,7 @@ mod test {
false,
None,
None,
+ None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf],
empty)?);
let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, "")
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 77613aa662..c3c0569df7 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -806,7 +806,6 @@ mod test {
let return_type = DataType::UInt32;
let accumulator: AccumulatorFactoryFunction = Arc::new(|_|
unimplemented!());
- let state_type = vec![DataType::UInt32];
let udf_agg = |inner: Expr| {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
@@ -814,12 +813,13 @@ mod test {
Signature::exact(vec![DataType::UInt32],
Volatility::Stable),
return_type.clone(),
accumulator.clone(),
- state_type.clone(),
+ vec![Field::new("value", DataType::UInt32, true)],
))),
vec![inner],
false,
None,
None,
+ None,
))
};
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index cee6798638..c549e62193 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -367,6 +367,7 @@ pub fn create_aggregate_expr(
input_phy_types[0].clone(),
ordering_req.to_vec(),
ordering_types,
+ vec![],
)
.with_ignore_nulls(ignore_nulls),
),
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs
b/datafusion/physical-expr/src/aggregate/first_last.rs
index 6d6e32a149..26bd219f65 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -21,7 +21,7 @@ use std::any::Any;
use std::sync::Arc;
use crate::aggregate::utils::{down_cast_any_ref, get_sort_options,
ordering_fields};
-use crate::expressions::format_state_name;
+use crate::expressions::{self, format_state_name};
use crate::{
reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr,
PhysicalSortExpr,
};
@@ -29,11 +29,13 @@ use crate::{
use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use arrow::compute::{self, lexsort_to_indices, SortColumn};
use arrow::datatypes::{DataType, Field};
+use arrow_schema::SortOptions;
use datafusion_common::utils::{compare_rows, get_arrayref_at_indices,
get_row_at_idx};
use datafusion_common::{
arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
};
-use datafusion_expr::Accumulator;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{Accumulator, Expr};
/// FIRST_VALUE aggregate expression
#[derive(Debug, Clone)]
@@ -45,6 +47,7 @@ pub struct FirstValue {
ordering_req: LexOrdering,
requirement_satisfied: bool,
ignore_nulls: bool,
+ state_fields: Vec<Field>,
}
impl FirstValue {
@@ -55,6 +58,7 @@ impl FirstValue {
input_data_type: DataType,
ordering_req: LexOrdering,
order_by_data_types: Vec<DataType>,
+ state_fields: Vec<Field>,
) -> Self {
let requirement_satisfied = ordering_req.is_empty();
Self {
@@ -65,6 +69,7 @@ impl FirstValue {
ordering_req,
requirement_satisfied,
ignore_nulls: false,
+ state_fields,
}
}
@@ -149,6 +154,10 @@ impl AggregateExpr for FirstValue {
}
fn state_fields(&self) -> Result<Vec<Field>> {
+ if !self.state_fields.is_empty() {
+ return Ok(self.state_fields.clone());
+ }
+
let mut fields = vec![Field::new(
format_state_name(&self.name, "first_value"),
self.input_data_type.clone(),
@@ -384,6 +393,50 @@ impl Accumulator for FirstValueAccumulator {
}
}
+pub fn create_first_value_accumulator(
+ acc_args: AccumulatorArgs,
+) -> Result<Box<dyn Accumulator>> {
+ let mut all_sort_orders = vec![];
+
+ // Construct PhysicalSortExpr objects from Expr objects:
+ let mut sort_exprs = vec![];
+ for expr in acc_args.sort_exprs {
+ if let Expr::Sort(sort) = expr {
+ if let Expr::Column(col) = sort.expr.as_ref() {
+ let name = &col.name;
+ let e = expressions::col(name, acc_args.schema)?;
+ sort_exprs.push(PhysicalSortExpr {
+ expr: e,
+ options: SortOptions {
+ descending: !sort.asc,
+ nulls_first: sort.nulls_first,
+ },
+ });
+ }
+ }
+ }
+ if !sort_exprs.is_empty() {
+ all_sort_orders.extend(sort_exprs);
+ }
+
+ let ordering_req = all_sort_orders;
+
+ let ordering_dtypes = ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(acc_args.schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ let requirement_satisfied = ordering_req.is_empty();
+
+ FirstValueAccumulator::try_new(
+ acc_args.data_type,
+ &ordering_dtypes,
+ ordering_req,
+ acc_args.ignore_nulls,
+ )
+ .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied))
as _)
+}
+
/// LAST_VALUE aggregate expression
#[derive(Debug, Clone)]
pub struct LastValue {
@@ -471,6 +524,7 @@ impl LastValue {
input_data_type,
reverse_order_bys(&ordering_req),
order_by_data_types,
+ vec![],
)
}
}
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs
b/datafusion/physical-expr/src/aggregate/utils.rs
index 60d59c16be..613f6118e9 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -188,7 +188,7 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
}
/// Construct corresponding fields for lexicographical ordering requirement expression
-pub(crate) fn ordering_fields(
+pub fn ordering_fields(
ordering_req: &[PhysicalSortExpr],
// Data type of each expression in the ordering requirement
data_types: &[DataType],
diff --git a/datafusion/physical-expr/src/lib.rs
b/datafusion/physical-expr/src/lib.rs
index 7819d51161..655771270a 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -58,3 +58,5 @@ pub use sort_expr::{
PhysicalSortRequirement,
};
pub use utils::{reverse_order_bys, split_conjunction};
+
+pub use aggregate::first_last::create_first_value_accumulator;
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index e263876b07..f8ad03bf6d 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -2026,6 +2026,7 @@ mod tests {
DataType::Float64,
ordering_req.clone(),
vec![DataType::Float64],
+ vec![],
))]
} else {
vec![Arc::new(LastValue::new(
@@ -2209,6 +2210,7 @@ mod tests {
DataType::Float64,
sort_expr_reverse.clone(),
vec![DataType::Float64],
+ vec![],
)),
Arc::new(LastValue::new(
col_b.clone(),
diff --git a/datafusion/physical-plan/src/udaf.rs
b/datafusion/physical-plan/src/udaf.rs
index fd9279dfd5..74a5603c0c 100644
--- a/datafusion/physical-plan/src/udaf.rs
+++ b/datafusion/physical-plan/src/udaf.rs
@@ -17,22 +17,20 @@
//! This module contains functions and structs supporting user-defined aggregate functions.
-use datafusion_expr::GroupsAccumulator;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{Expr, GroupsAccumulator};
use fmt::Debug;
use std::any::Any;
use std::fmt;
-use arrow::{
- datatypes::Field,
- datatypes::{DataType, Schema},
-};
+use arrow::datatypes::{DataType, Field, Schema};
-use super::{expressions::format_state_name, Accumulator, AggregateExpr};
+use super::{Accumulator, AggregateExpr};
use datafusion_common::{not_impl_err, Result};
pub use datafusion_expr::AggregateUDF;
-use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr};
-use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
+use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref,
ordering_fields};
use std::sync::Arc;
/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
@@ -40,19 +38,34 @@ use std::sync::Arc;
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
- input_schema: &Schema,
+ sort_exprs: &[Expr],
+ ordering_req: &[PhysicalSortExpr],
+ schema: &Schema,
name: impl Into<String>,
+ ignore_nulls: bool,
) -> Result<Arc<dyn AggregateExpr>> {
let input_exprs_types = input_phy_exprs
.iter()
- .map(|arg| arg.data_type(input_schema))
+ .map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;
+ let ordering_types = ordering_req
+ .iter()
+ .map(|e| e.expr.data_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ let ordering_fields = ordering_fields(ordering_req, &ordering_types);
+
Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
args: input_phy_exprs.to_vec(),
data_type: fun.return_type(&input_exprs_types)?,
name: name.into(),
+ schema: schema.clone(),
+ sort_exprs: sort_exprs.to_vec(),
+ ordering_req: ordering_req.to_vec(),
+ ignore_nulls,
+ ordering_fields,
}))
}
@@ -64,6 +77,13 @@ pub struct AggregateFunctionExpr {
/// Output / return type of this aggregate
data_type: DataType,
name: String,
+ schema: Schema,
+ // The logical order by expressions
+ sort_exprs: Vec<Expr>,
+ // The physical order by expressions
+ ordering_req: LexOrdering,
+ ignore_nulls: bool,
+ ordering_fields: Vec<Field>,
}
impl AggregateFunctionExpr {
@@ -84,21 +104,11 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn state_fields(&self) -> Result<Vec<Field>> {
- let fields = self
- .fun
- .state_type(&self.data_type)?
- .iter()
- .enumerate()
- .map(|(i, data_type)| {
- Field::new(
- format_state_name(&self.name, &format!("{i}")),
- data_type.clone(),
- true,
- )
- })
- .collect::<Vec<Field>>();
-
- Ok(fields)
+ self.fun.state_fields(
+ self.name(),
+ self.data_type.clone(),
+ self.ordering_fields.clone(),
+ )
}
fn field(&self) -> Result<Field> {
@@ -106,11 +116,18 @@ impl AggregateExpr for AggregateFunctionExpr {
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- self.fun.accumulator(&self.data_type)
+ let acc_args = AccumulatorArgs::new(
+ &self.data_type,
+ &self.schema,
+ self.ignore_nulls,
+ &self.sort_exprs,
+ );
+
+ self.fun.accumulator(acc_args)
}
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- let accumulator = self.fun.accumulator(&self.data_type)?;
+ let accumulator = self.create_accumulator()?;
// Accumulators that have window frame startings different
// than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to
@@ -175,6 +192,10 @@ impl AggregateExpr for AggregateFunctionExpr {
fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.fun.create_groups_accumulator()
}
+
+ fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
+ (!self.ordering_req.is_empty()).then_some(&self.ordering_req)
+ }
}
impl PartialEq<dyn Any> for AggregateFunctionExpr {
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index 21f42f41fb..c5c845614c 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -92,8 +92,19 @@ pub fn create_window_expr(
))
}
WindowFunctionDefinition::AggregateUDF(fun) => {
- let aggregate =
- udaf::create_aggregate_expr(fun.as_ref(), args, input_schema,
name)?;
+ // TODO: Ordering not supported for Window UDFs yet
+ let sort_exprs = &[];
+ let ordering_req = &[];
+
+ let aggregate = udaf::create_aggregate_expr(
+ fun.as_ref(),
+ args,
+ sort_exprs,
+ ordering_req,
+ input_schema,
+ name,
+ ignore_nulls,
+ )?;
window_expr_from_aggregate_expr(
partition_by,
order_by,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 3694418412..6a536b2fa3 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1389,6 +1389,7 @@ pub fn parse_expr(
false,
parse_optional_expr(pb.filter.as_deref(), registry,
codec)?.map(Box::new),
parse_vec_expr(&pb.order_by, registry, codec)?,
+ None,
)))
}
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 00dacffe06..4d5d6cadad 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -517,7 +517,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
}
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
let agg_udf =
registry.udaf(udaf_name)?;
-
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr,
&physical_schema, name)
+ // TODO: `order by` is not supported for UDAF yet
+ let sort_exprs = &[];
+ let ordering_req = &[];
+ let ignore_nulls = false;
+
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs,
ordering_req, &physical_schema, name, ignore_nulls)
}
}
}).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 4cd133dc21..f136e31455 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1772,6 +1772,7 @@ fn roundtrip_aggregate_udf() {
false,
Some(Box::new(lit(true))),
None,
+ None,
));
let ctx = SessionContext::new();
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 0238291c77..5dacf692e9 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -412,14 +412,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
let return_type = DataType::Int64;
let accumulator: AccumulatorFactoryFunction = Arc::new(|_|
Ok(Box::new(Example)));
- let state_type = vec![DataType::Int64];
let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
"example",
Signature::exact(vec![DataType::Int64], Volatility::Immutable),
return_type,
accumulator,
- state_type,
+ vec![Field::new("value", DataType::Int64, true)],
));
let ctx = SessionContext::new();
@@ -431,8 +430,11 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![udaf::create_aggregate_expr(
&udaf,
&[col("b", &schema)?],
+ &[],
+ &[],
&schema,
"example_agg",
+ false,
)?];
roundtrip_test_with_context(
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 582404b297..e97eb1a32b 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -221,9 +221,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
} else {
// User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function
if let Some(fm) = self.context_provider.get_aggregate_meta(&name) {
+ let order_by =
+ self.order_by_to_sort_expr(&order_by, schema,
planner_context, true)?;
+ let order_by = (!order_by.is_empty()).then_some(order_by);
let args = self.function_args_to_expr(args, schema,
planner_context)?;
+ // TODO: Support filter and distinct for UDAFs
return
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
- fm, args, false, None, None,
+ fm,
+ args,
+ false,
+ None,
+ order_by,
+ null_treatment,
)));
}
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index e68f3f9928..73782ab27f 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -754,7 +754,7 @@ pub async fn from_substrait_agg_func(
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
Ok(Arc::new(Expr::AggregateFunction(
- expr::AggregateFunction::new_udf(fun, args, distinct, filter,
order_by),
+ expr::AggregateFunction::new_udf(fun, args, distinct, filter,
order_by, None),
)))
} else if let Ok(fun) =
aggregate_function::AggregateFunction::from_str(function_name)
{