This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new dfd4442da8 Make FirstValue an UDAF, Change 
`AggregateUDFImpl::accumulator` signature, support ORDER BY for UDAFs (#9874)
dfd4442da8 is described below

commit dfd4442da8332116c7e1f9fcd9de7c6da856442b
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Apr 3 09:26:31 2024 +0800

    Make FirstValue an UDAF, Change `AggregateUDFImpl::accumulator` signature, 
support ORDER BY for UDAFs (#9874)
    
    * first draft
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy fix
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * use one vector for ordering req
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add sort exprs to accumulator
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix doc test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * change to ref
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix typo
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix doc
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * move schema and logical ordering exprs
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * remove redundant info
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rename
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add ignore nulls
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix conflict
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * backup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * complete return_type
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * complete replace
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * split to first value udf
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * replace accumulator
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * small fix
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * remove ordering types
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * make state fields more flexible
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * replace done
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm comments
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm test1
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix state fields
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * args struct for accumulator
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * simplify
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add sig
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add comments
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix docs
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * use exprs utils
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm state type
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion-examples/examples/advanced_udaf.rs      |  19 ++-
 datafusion/core/src/execution/context/mod.rs       |  20 ++++
 datafusion/core/src/physical_planner.rs            |  50 +++++---
 .../tests/user_defined/user_defined_aggregates.rs  |  20 ++--
 datafusion/expr/src/expr.rs                        |   3 +-
 datafusion/expr/src/expr_fn.rs                     | 128 ++++++++++++++++++---
 datafusion/expr/src/function.rs                    |  41 ++++++-
 datafusion/expr/src/tree_node/expr.rs              |   1 +
 datafusion/expr/src/udaf.rs                        | 101 ++++++++++------
 datafusion/optimizer/src/analyzer/type_coercion.rs |  15 ++-
 .../optimizer/src/common_subexpr_eliminate.rs      |   4 +-
 datafusion/physical-expr/src/aggregate/build_in.rs |   1 +
 .../physical-expr/src/aggregate/first_last.rs      |  58 +++++++++-
 datafusion/physical-expr/src/aggregate/utils.rs    |   2 +-
 datafusion/physical-expr/src/lib.rs                |   2 +
 datafusion/physical-plan/src/aggregates/mod.rs     |   2 +
 datafusion/physical-plan/src/udaf.rs               |  75 +++++++-----
 datafusion/physical-plan/src/windows/mod.rs        |  15 ++-
 datafusion/proto/src/logical_plan/from_proto.rs    |   1 +
 datafusion/proto/src/physical_plan/mod.rs          |   6 +-
 .../proto/tests/cases/roundtrip_logical_plan.rs    |   1 +
 .../proto/tests/cases/roundtrip_physical_plan.rs   |   6 +-
 datafusion/sql/src/expr/function.rs                |  11 +-
 datafusion/substrait/src/logical_plan/consumer.rs  |   2 +-
 24 files changed, 450 insertions(+), 134 deletions(-)

diff --git a/datafusion-examples/examples/advanced_udaf.rs 
b/datafusion-examples/examples/advanced_udaf.rs
index 10164a850b..342a23b6e7 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow_schema::{Field, Schema};
 use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
 use datafusion_physical_expr::NullState;
 use std::{any::Any, sync::Arc};
@@ -30,7 +31,8 @@ use datafusion::error::Result;
 use datafusion::prelude::*;
 use datafusion_common::{cast::as_float64_array, ScalarValue};
 use datafusion_expr::{
-    Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
+    function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl,
+    GroupsAccumulator, Signature,
 };
 
 /// This example shows how to use the full AggregateUDFImpl API to implement a 
user
@@ -85,13 +87,21 @@ impl AggregateUDFImpl for GeoMeanUdaf {
     /// is supported, DataFusion will use this row oriented
     /// accumulator when the aggregate function is used as a window function
     /// or when there are only aggregates (no GROUP BY columns) in the plan.
-    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
         Ok(Box::new(GeometricMean::new()))
     }
 
     /// This is the description of the state. accumulator's state() must match 
the types here.
-    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
-        Ok(vec![DataType::Float64, DataType::UInt32])
+    fn state_fields(
+        &self,
+        _name: &str,
+        value_type: DataType,
+        _ordering_fields: Vec<arrow_schema::Field>,
+    ) -> Result<Vec<arrow_schema::Field>> {
+        Ok(vec![
+            Field::new("prod", value_type, true),
+            Field::new("n", DataType::UInt32, true),
+        ])
     }
 
     /// Tell DataFusion that this aggregate supports the more performant 
`GroupsAccumulator`
@@ -191,7 +201,6 @@ impl Accumulator for GeometricMean {
 
 // create local session context with an in-memory table
 fn create_context() -> Result<SessionContext> {
-    use datafusion::arrow::datatypes::{Field, Schema};
     use datafusion::datasource::MemTable;
     // define a schema.
     let schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/core/src/execution/context/mod.rs 
b/datafusion/core/src/execution/context/mod.rs
index f8bf0d2ee1..4eaaf94ecf 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -69,11 +69,14 @@ use datafusion_common::{
     OwnedTableReference, SchemaReference,
 };
 use datafusion_execution::registry::SerializerRegistry;
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::{create_first_value, Signature, Volatility};
 use datafusion_expr::{
     logical_plan::{DdlStatement, Statement},
     var_provider::is_system_variables,
     Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
 };
+use datafusion_physical_expr::create_first_value_accumulator;
 use datafusion_sql::{
     parser::{CopyToSource, CopyToStatement, DFParser},
     planner::{object_name_to_table_reference, ContextProvider, ParserOptions, 
SqlToRel},
@@ -82,6 +85,7 @@ use datafusion_sql::{
 
 use async_trait::async_trait;
 use chrono::{DateTime, Utc};
+use log::debug;
 use parking_lot::RwLock;
 use sqlparser::dialect::dialect_from_str;
 use url::Url;
@@ -1451,6 +1455,22 @@ impl SessionState {
         datafusion_functions_array::register_all(&mut new_self)
             .expect("can not register array expressions");
 
+        let first_value = create_first_value(
+            "FIRST_VALUE",
+            Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
+            Arc::new(create_first_value_accumulator),
+        );
+
+        match new_self.register_udaf(Arc::new(first_value)) {
+            Ok(Some(existing_udaf)) => {
+                debug!("Overwrite existing UDAF: {}", existing_udaf.name());
+            }
+            Ok(None) => {}
+            Err(err) => {
+                panic!("Failed to register UDAF: {}", err);
+            }
+        }
+
         new_self
     }
     /// Returns new [`SessionState`] using the provided
diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index 4733c1433a..275d639a7a 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -247,24 +247,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> 
Result<String> {
             distinct,
             args,
             filter,
-            order_by,
+            order_by: _,
             null_treatment: _,
         }) => match func_def {
             AggregateFunctionDefinition::BuiltIn(..) => {
                 create_function_physical_name(func_def.name(), *distinct, args)
             }
             AggregateFunctionDefinition::UDF(fun) => {
-                // TODO: Add support for filter and order by in AggregateUDF
+                // TODO: Add support for filter by in AggregateUDF
                 if filter.is_some() {
                     return exec_err!(
                         "aggregate expression with filter is not supported"
                     );
                 }
-                if order_by.is_some() {
-                    return exec_err!(
-                        "aggregate expression with order_by is not supported"
-                    );
-                }
+
                 let names = args
                     .iter()
                     .map(|e| create_physical_name(e, false))
@@ -1667,20 +1663,22 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
                 )?),
                 None => None,
             };
-            let order_by = match order_by {
-                Some(e) => Some(create_physical_sort_exprs(
-                    e,
-                    logical_input_schema,
-                    execution_props,
-                )?),
-                None => None,
-            };
+
             let ignore_nulls = null_treatment
                 .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls)
                 == NullTreatment::IgnoreNulls;
             let (agg_expr, filter, order_by) = match func_def {
                 AggregateFunctionDefinition::BuiltIn(fun) => {
-                    let ordering_reqs = order_by.clone().unwrap_or(vec![]);
+                    let physical_sort_exprs = match order_by {
+                        Some(exprs) => Some(create_physical_sort_exprs(
+                            exprs,
+                            logical_input_schema,
+                            execution_props,
+                        )?),
+                        None => None,
+                    };
+                    let ordering_reqs: Vec<PhysicalSortExpr> =
+                        physical_sort_exprs.clone().unwrap_or(vec![]);
                     let agg_expr = aggregates::create_aggregate_expr(
                         fun,
                         *distinct,
@@ -1690,16 +1688,30 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
                         name,
                         ignore_nulls,
                     )?;
-                    (agg_expr, filter, order_by)
+                    (agg_expr, filter, physical_sort_exprs)
                 }
                 AggregateFunctionDefinition::UDF(fun) => {
+                    let sort_exprs = order_by.clone().unwrap_or(vec![]);
+                    let physical_sort_exprs = match order_by {
+                        Some(exprs) => Some(create_physical_sort_exprs(
+                            exprs,
+                            logical_input_schema,
+                            execution_props,
+                        )?),
+                        None => None,
+                    };
+                    let ordering_reqs: Vec<PhysicalSortExpr> =
+                        physical_sort_exprs.clone().unwrap_or(vec![]);
                     let agg_expr = udaf::create_aggregate_expr(
                         fun,
                         &args,
+                        &sort_exprs,
+                        &ordering_reqs,
                         physical_input_schema,
                         name,
-                    );
-                    (agg_expr?, filter, order_by)
+                        ignore_nulls,
+                    )?;
+                    (agg_expr, filter, physical_sort_exprs)
                 }
                 AggregateFunctionDefinition::Name(_) => {
                     return internal_err!(
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs 
b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index a58a8cf516..6085fca876 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -45,7 +45,8 @@ use datafusion::{
 };
 use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err};
 use datafusion_expr::{
-    create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF,
+    create_udaf, function::AccumulatorArgs, AggregateUDFImpl, 
GroupsAccumulator,
+    SimpleAggregateUDF,
 };
 use datafusion_physical_expr::expressions::AvgAccumulator;
 
@@ -491,7 +492,7 @@ impl TimeSum {
         // Returns the same type as its input
         let return_type = timestamp_type.clone();
 
-        let state_type = vec![timestamp_type.clone()];
+        let state_fields = vec![Field::new("sum", timestamp_type, true)];
 
         let volatility = Volatility::Immutable;
 
@@ -505,7 +506,7 @@ impl TimeSum {
             return_type,
             volatility,
             accumulator,
-            state_type,
+            state_fields,
         ));
 
         // register the selector as "time_sum"
@@ -591,6 +592,11 @@ impl FirstSelector {
     fn register(ctx: &mut SessionContext) {
         let return_type = Self::output_datatype();
         let state_type = Self::state_datatypes();
+        let state_fields = state_type
+            .into_iter()
+            .enumerate()
+            .map(|(i, t)| Field::new(format!("{i}"), t, true))
+            .collect::<Vec<_>>();
 
         // Possible input signatures
         let signatures = vec![TypeSignature::Exact(Self::input_datatypes())];
@@ -607,7 +613,7 @@ impl FirstSelector {
             Signature::one_of(signatures, volatility),
             return_type,
             accumulator,
-            state_type,
+            state_fields,
         ));
 
         // register the selector as "first"
@@ -717,15 +723,11 @@ impl AggregateUDFImpl for TestGroupsAccumulator {
         Ok(DataType::UInt64)
     }
 
-    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> {
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
         // should use groups accumulator
         panic!("accumulator shouldn't invoke");
     }
 
-    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
-        Ok(vec![DataType::UInt64])
-    }
-
     fn groups_accumulator_supported(&self) -> bool {
         true
     }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 7ede4cd8ff..427c3fde7c 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -577,6 +577,7 @@ impl AggregateFunction {
         distinct: bool,
         filter: Option<Box<Expr>>,
         order_by: Option<Vec<Expr>>,
+        null_treatment: Option<NullTreatment>,
     ) -> Self {
         Self {
             func_def: AggregateFunctionDefinition::UDF(udf),
@@ -584,7 +585,7 @@ impl AggregateFunction {
             distinct,
             filter,
             order_by,
-            null_treatment: None,
+            null_treatment,
         }
     }
 }
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index db9eb84c21..5294ca7545 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -21,15 +21,17 @@ use crate::expr::{
     AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, 
InSubquery,
     Placeholder, ScalarFunction, TryCast,
 };
-use crate::function::PartitionEvaluatorFactory;
+use crate::function::{
+    AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory,
+};
+use crate::udaf::format_state_name;
 use crate::{
     aggregate_function, built_in_function, 
conditional_expressions::CaseBuilder,
-    logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
-    BuiltinScalarFunction, Expr, LogicalPlan, Operator, 
ScalarFunctionImplementation,
-    ScalarUDF, Signature, Volatility,
+    logical_plan::Subquery, AggregateUDF, BuiltinScalarFunction, Expr, 
LogicalPlan,
+    Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
 };
 use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, 
WindowUDFImpl};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, Field};
 use datafusion_common::{Column, Result};
 use std::any::Any;
 use std::fmt::Debug;
@@ -695,16 +697,32 @@ pub fn create_udaf(
 ) -> AggregateUDF {
     let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| 
t.as_ref().clone());
     let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| 
t.as_ref().clone());
+    let state_fields = state_type
+        .into_iter()
+        .enumerate()
+        .map(|(i, t)| Field::new(format!("{i}"), t, true))
+        .collect::<Vec<_>>();
     AggregateUDF::from(SimpleAggregateUDF::new(
         name,
         input_type,
         return_type,
         volatility,
         accumulator,
-        state_type,
+        state_fields,
     ))
 }
 
+/// Creates a new `FirstValue` UDAF from a name, a signature and an accumulator factory.
+/// The signature must match the `Accumulator`'s implementation; the return type and
+/// state fields are defined by `FirstValue`. TODO: We plan to move aggregate function to its own crate. This function 
will be deprecated then.
+pub fn create_first_value(
+    name: &str,
+    signature: Signature,
+    accumulator: AccumulatorFactoryFunction,
+) -> AggregateUDF {
+    AggregateUDF::from(FirstValue::new(name, signature, accumulator))
+}
+
 /// Implements [`AggregateUDFImpl`] for functions that have a single signature 
and
 /// return type.
 pub struct SimpleAggregateUDF {
@@ -712,7 +730,7 @@ pub struct SimpleAggregateUDF {
     signature: Signature,
     return_type: DataType,
     accumulator: AccumulatorFactoryFunction,
-    state_type: Vec<DataType>,
+    state_fields: Vec<Field>,
 }
 
 impl Debug for SimpleAggregateUDF {
@@ -734,7 +752,7 @@ impl SimpleAggregateUDF {
         return_type: DataType,
         volatility: Volatility,
         accumulator: AccumulatorFactoryFunction,
-        state_type: Vec<DataType>,
+        state_fields: Vec<Field>,
     ) -> Self {
         let name = name.into();
         let signature = Signature::exact(input_type, volatility);
@@ -743,7 +761,7 @@ impl SimpleAggregateUDF {
             signature,
             return_type,
             accumulator,
-            state_type,
+            state_fields,
         }
     }
 
@@ -752,7 +770,7 @@ impl SimpleAggregateUDF {
         signature: Signature,
         return_type: DataType,
         accumulator: AccumulatorFactoryFunction,
-        state_type: Vec<DataType>,
+        state_fields: Vec<Field>,
     ) -> Self {
         let name = name.into();
         Self {
@@ -760,7 +778,7 @@ impl SimpleAggregateUDF {
             signature,
             return_type,
             accumulator,
-            state_type,
+            state_fields,
         }
     }
 }
@@ -782,12 +800,92 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
         Ok(self.return_type.clone())
     }
 
-    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn 
crate::Accumulator>> {
-        (self.accumulator)(arg)
+    fn accumulator(
+        &self,
+        acc_args: AccumulatorArgs,
+    ) -> Result<Box<dyn crate::Accumulator>> {
+        (self.accumulator)(acc_args)
+    }
+
+    fn state_fields(
+        &self,
+        _name: &str,
+        _value_type: DataType,
+        _ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        Ok(self.state_fields.clone())
+    }
+}
+
+pub struct FirstValue {
+    name: String,
+    signature: Signature,
+    accumulator: AccumulatorFactoryFunction,
+}
+
+impl Debug for FirstValue {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        f.debug_struct("FirstValue")
+            .field("name", &self.name)
+            .field("signature", &self.signature)
+            .field("accumulator", &"<FUNC>")
+            .finish()
+    }
+}
+
+impl FirstValue {
+    pub fn new(
+        name: impl Into<String>,
+        signature: Signature,
+        accumulator: AccumulatorFactoryFunction,
+    ) -> Self {
+        let name = name.into();
+        Self {
+            name,
+            signature,
+            accumulator,
+        }
+    }
+}
+
+impl AggregateUDFImpl for FirstValue {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn accumulator(
+        &self,
+        acc_args: AccumulatorArgs,
+    ) -> Result<Box<dyn crate::Accumulator>> {
+        (self.accumulator)(acc_args)
     }
 
-    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
-        Ok(self.state_type.clone())
+    fn state_fields(
+        &self,
+        name: &str,
+        value_type: DataType,
+        ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        let mut fields = vec![Field::new(
+            format_state_name(name, "first_value"),
+            value_type,
+            true,
+        )];
+        fields.extend(ordering_fields);
+        fields.push(Field::new("is_set", DataType::Boolean, true));
+        Ok(fields)
     }
 }
 
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index adf4dd3fef..7598c805ad 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -17,8 +17,9 @@
 
 //! Function module contains typing and signature for built-in and user 
defined functions.
 
-use crate::{Accumulator, ColumnarValue, PartitionEvaluator};
-use arrow::datatypes::DataType;
+use crate::ColumnarValue;
+use crate::{Accumulator, Expr, PartitionEvaluator};
+use arrow::datatypes::{DataType, Schema};
 use datafusion_common::Result;
 use std::sync::Arc;
 
@@ -37,10 +38,40 @@ pub type ScalarFunctionImplementation =
 pub type ReturnTypeFunction =
     Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
 
-/// Factory that returns an accumulator for the given aggregate, given
-/// its return datatype.
+/// Arguments passed to create an accumulator
+pub struct AccumulatorArgs<'a> {
+    // default arguments
+    /// the return type of the function
+    pub data_type: &'a DataType,
+    /// the schema of the input arguments
+    pub schema: &'a Schema,
+    /// whether to ignore nulls
+    pub ignore_nulls: bool,
+
+    // ordering arguments
+    /// the expressions of `order by`, if no ordering is required, this will 
be an empty slice
+    pub sort_exprs: &'a [Expr],
+}
+
+impl<'a> AccumulatorArgs<'a> {
+    pub fn new(
+        data_type: &'a DataType,
+        schema: &'a Schema,
+        ignore_nulls: bool,
+        sort_exprs: &'a [Expr],
+    ) -> Self {
+        Self {
+            data_type,
+            schema,
+            ignore_nulls,
+            sort_exprs,
+        }
+    }
+}
+
+/// Factory that returns an accumulator for the given aggregate function.
 pub type AccumulatorFactoryFunction =
-    Arc<dyn Fn(&DataType) -> Result<Box<dyn Accumulator>> + Send + Sync>;
+    Arc<dyn Fn(AccumulatorArgs) -> Result<Box<dyn Accumulator>> + Send + Sync>;
 
 /// Factory that creates a PartitionEvaluator for the given window
 /// function
diff --git a/datafusion/expr/src/tree_node/expr.rs 
b/datafusion/expr/src/tree_node/expr.rs
index 1c672851e9..0909d8f662 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -379,6 +379,7 @@ impl TreeNode for Expr {
                             false,
                             new_filter,
                             new_order_by,
+                            null_treatment,
                         )))
                     }
                     AggregateFunctionDefinition::Name(_) => {
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index c46dd9cd3a..ba80f39dde 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -17,16 +17,16 @@
 
 //! [`AggregateUDF`]: User Defined Aggregate Functions
 
+use crate::function::AccumulatorArgs;
 use crate::groups_accumulator::GroupsAccumulator;
 use crate::{Accumulator, Expr};
-use crate::{
-    AccumulatorFactoryFunction, ReturnTypeFunction, Signature, 
StateTypeFunction,
-};
-use arrow::datatypes::DataType;
+use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
+use arrow::datatypes::{DataType, Field};
 use datafusion_common::{not_impl_err, Result};
 use std::any::Any;
 use std::fmt::{self, Debug, Formatter};
 use std::sync::Arc;
+use std::vec;
 
 /// Logical representation of a user-defined [aggregate function] (UDAF).
 ///
@@ -90,14 +90,12 @@ impl AggregateUDF {
         signature: &Signature,
         return_type: &ReturnTypeFunction,
         accumulator: &AccumulatorFactoryFunction,
-        state_type: &StateTypeFunction,
     ) -> Self {
         Self::new_from_impl(AggregateUDFLegacyWrapper {
             name: name.to_owned(),
             signature: signature.clone(),
             return_type: return_type.clone(),
             accumulator: accumulator.clone(),
-            state_type: state_type.clone(),
         })
     }
 
@@ -131,12 +129,14 @@ impl AggregateUDF {
     /// This utility allows using the UDAF without requiring access to
     /// the registry, such as with the DataFrame API.
     pub fn call(&self, args: Vec<Expr>) -> Expr {
+        // TODO: Support dictinct, filter, order by and null_treatment
         Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf(
             Arc::new(self.clone()),
             args,
             false,
             None,
             None,
+            None,
         ))
     }
 
@@ -166,16 +166,21 @@ impl AggregateUDF {
         self.inner.return_type(args)
     }
 
-    /// Return an accumulator the given aggregate, given
-    /// its return datatype.
-    pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn 
Accumulator>> {
-        self.inner.accumulator(return_type)
+    /// Return an accumulator the given aggregate, given its return datatype
+    pub fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        self.inner.accumulator(acc_args)
     }
 
-    /// Return the type of the intermediate state used by this aggregator, 
given
-    /// its return datatype. Supports multi-phase aggregations
-    pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        self.inner.state_type(return_type)
+    /// Return the fields of the intermediate state used by this aggregator, 
given
+    /// its state name, value type and ordering fields. See 
[`AggregateUDFImpl::state_fields`]
+    /// for more details. Supports multi-phase aggregations
+    pub fn state_fields(
+        &self,
+        name: &str,
+        value_type: DataType,
+        ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        self.inner.state_fields(name, value_type, ordering_fields)
     }
 
     /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more 
details.
@@ -213,8 +218,10 @@ where
 /// # use std::any::Any;
 /// # use arrow::datatypes::DataType;
 /// # use datafusion_common::{DataFusionError, plan_err, Result};
-/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
-/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr};
+/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, 
function::AccumulatorArgs};
+/// # use arrow::datatypes::Schema;
+/// # use arrow::datatypes::Field;
 /// #[derive(Debug, Clone)]
 /// struct GeoMeanUdf {
 ///   signature: Signature
@@ -240,9 +247,12 @@ where
 ///      Ok(DataType::Float64)
 ///    }
 ///    // This is the accumulator factory; DataFusion uses it to create new 
accumulators.
-///    fn accumulator(&self, _arg: &DataType) -> Result<Box<dyn Accumulator>> 
{ unimplemented!() }
-///    fn state_type(&self, _return_type: &DataType) -> Result<Vec<DataType>> {
-///        Ok(vec![DataType::Float64, DataType::UInt32])
+///    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> { unimplemented!() }
+///    fn state_fields(&self, _name: &str, value_type: DataType, 
_ordering_fields: Vec<Field>) -> Result<Vec<Field>> {
+///        Ok(vec![
+///             Field::new("value", value_type, true),
+///             Field::new("ordering", DataType::UInt32, true)
+///        ])
 ///    }
 /// }
 ///
@@ -269,15 +279,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
 
     /// Return a new [`Accumulator`] that aggregates values for a specific
     /// group during query execution.
-    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>>;
+    ///
+    /// `acc_args`: the arguments to the accumulator. See [`AccumulatorArgs`] 
for more details.
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>>;
+
+    /// Return the fields of the intermediate state.
+    ///
+    /// `name`: the name of the aggregate expression; the default
+    /// implementation uses it as the prefix of the value field's name
+    ///
+    /// `value_type`: the type of the value; this should be the result of
+    /// `return_type`
+    ///
+    /// `ordering_fields`: the fields used for ordering; empty if no ordering
+    /// expression is provided
+    fn state_fields(
+        &self,
+        name: &str,
+        value_type: DataType,
+        ordering_fields: Vec<Field>,
+    ) -> Result<Vec<Field>> {
+        let value_fields = vec![Field::new(
+            format_state_name(name, "value"),
+            value_type,
+            true,
+        )];
 
-    /// Return the type used to serialize the  [`Accumulator`]'s intermediate 
state.
-    /// See [`Accumulator::state()`] for more details
-    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>>;
+        Ok(value_fields.into_iter().chain(ordering_fields).collect())
+    }
 
     /// If the aggregate expression has a specialized
     /// [`GroupsAccumulator`] implementation. If this returns true,
-    /// `[Self::create_groups_accumulator`] will be called.
+    /// `[Self::create_groups_accumulator]` will be called.
     fn groups_accumulator_supported(&self) -> bool {
         false
     }
@@ -337,12 +367,8 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
         self.inner.return_type(arg_types)
     }
 
-    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
-        self.inner.accumulator(arg)
-    }
-
-    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        self.inner.state_type(return_type)
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        self.inner.accumulator(acc_args)
     }
 
     fn aliases(&self) -> &[String] {
@@ -361,8 +387,6 @@ pub struct AggregateUDFLegacyWrapper {
     return_type: ReturnTypeFunction,
     /// actual implementation
     accumulator: AccumulatorFactoryFunction,
-    /// the accumulator's state's description as a function of the return type
-    state_type: StateTypeFunction,
 }
 
 impl Debug for AggregateUDFLegacyWrapper {
@@ -394,12 +418,13 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
         Ok(res.as_ref().clone())
     }
 
-    fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
-        (self.accumulator)(arg)
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
+        (self.accumulator)(acc_args)
     }
+}
 
-    fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
-        let res = (self.state_type)(return_type)?;
-        Ok(res.as_ref().clone())
-    }
+/// Returns the name of a state field, formatted as `{name}[{state_name}]`.
+/// TODO: Remove the duplicated function in physical-expr
+pub(crate) fn format_state_name(name: &str, state_name: &str) -> String {
+    format!("{name}[{state_name}]")
 }
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs 
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index b7b7c4f20e..fbbd9a9456 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -366,7 +366,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                     )?;
                     Ok(Transformed::yes(Expr::AggregateFunction(
                         expr::AggregateFunction::new_udf(
-                            fun, new_expr, false, filter, order_by,
+                            fun,
+                            new_expr,
+                            false,
+                            filter,
+                            order_by,
+                            null_treatment,
                         ),
                     )))
                 }
@@ -896,6 +901,7 @@ mod test {
             false,
             None,
             None,
+            None,
         ));
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], 
empty)?);
         let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n  
EmptyRelation";
@@ -906,7 +912,6 @@ mod test {
     fn agg_udaf_invalid_input() -> Result<()> {
         let empty = empty();
         let return_type = DataType::Float64;
-        let state_type = vec![DataType::UInt64, DataType::Float64];
         let accumulator: AccumulatorFactoryFunction =
             Arc::new(|_| Ok(Box::<AvgAccumulator>::default()));
         let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
@@ -914,7 +919,10 @@ mod test {
             Signature::uniform(1, vec![DataType::Float64], 
Volatility::Immutable),
             return_type,
             accumulator,
-            state_type,
+            vec![
+                Field::new("count", DataType::UInt64, true),
+                Field::new("avg", DataType::Float64, true),
+            ],
         ));
         let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf(
             Arc::new(my_avg),
@@ -922,6 +930,7 @@ mod test {
             false,
             None,
             None,
+            None,
         ));
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], 
empty)?);
         let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), 
&plan, "")
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs 
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 77613aa662..c3c0569df7 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -806,7 +806,6 @@ mod test {
 
         let return_type = DataType::UInt32;
         let accumulator: AccumulatorFactoryFunction = Arc::new(|_| 
unimplemented!());
-        let state_type = vec![DataType::UInt32];
         let udf_agg = |inner: Expr| {
             
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
                 
Arc::new(AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
@@ -814,12 +813,13 @@ mod test {
                     Signature::exact(vec![DataType::UInt32], 
Volatility::Stable),
                     return_type.clone(),
                     accumulator.clone(),
-                    state_type.clone(),
+                    vec![Field::new("value", DataType::UInt32, true)],
                 ))),
                 vec![inner],
                 false,
                 None,
                 None,
+                None,
             ))
         };
 
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs
index cee6798638..c549e62193 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -367,6 +367,7 @@ pub fn create_aggregate_expr(
                 input_phy_types[0].clone(),
                 ordering_req.to_vec(),
                 ordering_types,
+                vec![],
             )
             .with_ignore_nulls(ignore_nulls),
         ),
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs 
b/datafusion/physical-expr/src/aggregate/first_last.rs
index 6d6e32a149..26bd219f65 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -21,7 +21,7 @@ use std::any::Any;
 use std::sync::Arc;
 
 use crate::aggregate::utils::{down_cast_any_ref, get_sort_options, 
ordering_fields};
-use crate::expressions::format_state_name;
+use crate::expressions::{self, format_state_name};
 use crate::{
     reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, 
PhysicalSortExpr,
 };
@@ -29,11 +29,13 @@ use crate::{
 use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
 use arrow::compute::{self, lexsort_to_indices, SortColumn};
 use arrow::datatypes::{DataType, Field};
+use arrow_schema::SortOptions;
 use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, 
get_row_at_idx};
 use datafusion_common::{
     arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
 };
-use datafusion_expr::Accumulator;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{Accumulator, Expr};
 
 /// FIRST_VALUE aggregate expression
 #[derive(Debug, Clone)]
@@ -45,6 +47,7 @@ pub struct FirstValue {
     ordering_req: LexOrdering,
     requirement_satisfied: bool,
     ignore_nulls: bool,
+    state_fields: Vec<Field>,
 }
 
 impl FirstValue {
@@ -55,6 +58,7 @@ impl FirstValue {
         input_data_type: DataType,
         ordering_req: LexOrdering,
         order_by_data_types: Vec<DataType>,
+        state_fields: Vec<Field>,
     ) -> Self {
         let requirement_satisfied = ordering_req.is_empty();
         Self {
@@ -65,6 +69,7 @@ impl FirstValue {
             ordering_req,
             requirement_satisfied,
             ignore_nulls: false,
+            state_fields,
         }
     }
 
@@ -149,6 +154,10 @@ impl AggregateExpr for FirstValue {
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
+        if !self.state_fields.is_empty() {
+            return Ok(self.state_fields.clone());
+        }
+
         let mut fields = vec![Field::new(
             format_state_name(&self.name, "first_value"),
             self.input_data_type.clone(),
@@ -384,6 +393,50 @@ impl Accumulator for FirstValueAccumulator {
     }
 }
 
+pub fn create_first_value_accumulator(
+    acc_args: AccumulatorArgs,
+) -> Result<Box<dyn Accumulator>> {
+    let mut all_sort_orders = vec![];
+
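+    // Construct PhysicalSortExpr objects from Expr objects. Only sort
+    // expressions over plain columns are converted; any other sort expression
+    // is skipped without an error.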
+    let mut sort_exprs = vec![];
+    for expr in acc_args.sort_exprs {
+        if let Expr::Sort(sort) = expr {
+            if let Expr::Column(col) = sort.expr.as_ref() {
+                let name = &col.name;
+                let e = expressions::col(name, acc_args.schema)?;
+                sort_exprs.push(PhysicalSortExpr {
+                    expr: e,
+                    options: SortOptions {
+                        descending: !sort.asc,
+                        nulls_first: sort.nulls_first,
+                    },
+                });
+            }
+        }
+    }
+    if !sort_exprs.is_empty() {
+        all_sort_orders.extend(sort_exprs);
+    }
+
+    let ordering_req = all_sort_orders;
+
+    let ordering_dtypes = ordering_req
+        .iter()
+        .map(|e| e.expr.data_type(acc_args.schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let requirement_satisfied = ordering_req.is_empty();
+
+    FirstValueAccumulator::try_new(
+        acc_args.data_type,
+        &ordering_dtypes,
+        ordering_req,
+        acc_args.ignore_nulls,
+    )
+    .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) 
as _)
+}
+
 /// LAST_VALUE aggregate expression
 #[derive(Debug, Clone)]
 pub struct LastValue {
@@ -471,6 +524,7 @@ impl LastValue {
             input_data_type,
             reverse_order_bys(&ordering_req),
             order_by_data_types,
+            vec![],
         )
     }
 }
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs 
b/datafusion/physical-expr/src/aggregate/utils.rs
index 60d59c16be..613f6118e9 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -188,7 +188,7 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
 }
 
 /// Construct corresponding fields for lexicographical ordering requirement 
expression
-pub(crate) fn ordering_fields(
+pub fn ordering_fields(
     ordering_req: &[PhysicalSortExpr],
     // Data type of each expression in the ordering requirement
     data_types: &[DataType],
diff --git a/datafusion/physical-expr/src/lib.rs 
b/datafusion/physical-expr/src/lib.rs
index 7819d51161..655771270a 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -58,3 +58,5 @@ pub use sort_expr::{
     PhysicalSortRequirement,
 };
 pub use utils::{reverse_order_bys, split_conjunction};
+
+pub use aggregate::first_last::create_first_value_accumulator;
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index e263876b07..f8ad03bf6d 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -2026,6 +2026,7 @@ mod tests {
                 DataType::Float64,
                 ordering_req.clone(),
                 vec![DataType::Float64],
+                vec![],
             ))]
         } else {
             vec![Arc::new(LastValue::new(
@@ -2209,6 +2210,7 @@ mod tests {
                 DataType::Float64,
                 sort_expr_reverse.clone(),
                 vec![DataType::Float64],
+                vec![],
             )),
             Arc::new(LastValue::new(
                 col_b.clone(),
diff --git a/datafusion/physical-plan/src/udaf.rs 
b/datafusion/physical-plan/src/udaf.rs
index fd9279dfd5..74a5603c0c 100644
--- a/datafusion/physical-plan/src/udaf.rs
+++ b/datafusion/physical-plan/src/udaf.rs
@@ -17,22 +17,20 @@
 
 //! This module contains functions and structs supporting user-defined 
aggregate functions.
 
-use datafusion_expr::GroupsAccumulator;
+use datafusion_expr::function::AccumulatorArgs;
+use datafusion_expr::{Expr, GroupsAccumulator};
 use fmt::Debug;
 use std::any::Any;
 use std::fmt;
 
-use arrow::{
-    datatypes::Field,
-    datatypes::{DataType, Schema},
-};
+use arrow::datatypes::{DataType, Field, Schema};
 
-use super::{expressions::format_state_name, Accumulator, AggregateExpr};
+use super::{Accumulator, AggregateExpr};
 use datafusion_common::{not_impl_err, Result};
 pub use datafusion_expr::AggregateUDF;
-use datafusion_physical_expr::PhysicalExpr;
+use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr};
 
-use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
+use datafusion_physical_expr::aggregate::utils::{down_cast_any_ref, 
ordering_fields};
 use std::sync::Arc;
 
 /// Creates a physical expression of the UDAF, that includes all necessary 
type coercion.
@@ -40,19 +38,34 @@ use std::sync::Arc;
 pub fn create_aggregate_expr(
     fun: &AggregateUDF,
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
-    input_schema: &Schema,
+    sort_exprs: &[Expr],
+    ordering_req: &[PhysicalSortExpr],
+    schema: &Schema,
     name: impl Into<String>,
+    ignore_nulls: bool,
 ) -> Result<Arc<dyn AggregateExpr>> {
     let input_exprs_types = input_phy_exprs
         .iter()
-        .map(|arg| arg.data_type(input_schema))
+        .map(|arg| arg.data_type(schema))
         .collect::<Result<Vec<_>>>()?;
 
+    let ordering_types = ordering_req
+        .iter()
+        .map(|e| e.expr.data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let ordering_fields = ordering_fields(ordering_req, &ordering_types);
+
     Ok(Arc::new(AggregateFunctionExpr {
         fun: fun.clone(),
         args: input_phy_exprs.to_vec(),
         data_type: fun.return_type(&input_exprs_types)?,
         name: name.into(),
+        schema: schema.clone(),
+        sort_exprs: sort_exprs.to_vec(),
+        ordering_req: ordering_req.to_vec(),
+        ignore_nulls,
+        ordering_fields,
     }))
 }
 
@@ -64,6 +77,13 @@ pub struct AggregateFunctionExpr {
     /// Output / return type of this aggregate
     data_type: DataType,
     name: String,
+    schema: Schema,
+    // The logical order by expressions
+    sort_exprs: Vec<Expr>,
+    // The physical order by expressions
+    ordering_req: LexOrdering,
+    ignore_nulls: bool,
+    ordering_fields: Vec<Field>,
 }
 
 impl AggregateFunctionExpr {
@@ -84,21 +104,11 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn state_fields(&self) -> Result<Vec<Field>> {
-        let fields = self
-            .fun
-            .state_type(&self.data_type)?
-            .iter()
-            .enumerate()
-            .map(|(i, data_type)| {
-                Field::new(
-                    format_state_name(&self.name, &format!("{i}")),
-                    data_type.clone(),
-                    true,
-                )
-            })
-            .collect::<Vec<Field>>();
-
-        Ok(fields)
+        self.fun.state_fields(
+            self.name(),
+            self.data_type.clone(),
+            self.ordering_fields.clone(),
+        )
     }
 
     fn field(&self) -> Result<Field> {
@@ -106,11 +116,18 @@ impl AggregateExpr for AggregateFunctionExpr {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        self.fun.accumulator(&self.data_type)
+        let acc_args = AccumulatorArgs::new(
+            &self.data_type,
+            &self.schema,
+            self.ignore_nulls,
+            &self.sort_exprs,
+        );
+
+        self.fun.accumulator(acc_args)
     }
 
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        let accumulator = self.fun.accumulator(&self.data_type)?;
+        let accumulator = self.create_accumulator()?;
 
         // Accumulators that have window frame startings different
         // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to
@@ -175,6 +192,10 @@ impl AggregateExpr for AggregateFunctionExpr {
     fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
         self.fun.create_groups_accumulator()
     }
+
+    fn order_bys(&self) -> Option<&[PhysicalSortExpr]> {
+        (!self.ordering_req.is_empty()).then_some(&self.ordering_req)
+    }
 }
 
 impl PartialEq<dyn Any> for AggregateFunctionExpr {
diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index 21f42f41fb..c5c845614c 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -92,8 +92,19 @@ pub fn create_window_expr(
             ))
         }
         WindowFunctionDefinition::AggregateUDF(fun) => {
-            let aggregate =
-                udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, 
name)?;
+            // TODO: ORDER BY is not yet supported for aggregate UDFs used as
+            // window functions
+            let sort_exprs = &[];
+            let ordering_req = &[];
+
+            let aggregate = udaf::create_aggregate_expr(
+                fun.as_ref(),
+                args,
+                sort_exprs,
+                ordering_req,
+                input_schema,
+                name,
+                ignore_nulls,
+            )?;
             window_expr_from_aggregate_expr(
                 partition_by,
                 order_by,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index 3694418412..6a536b2fa3 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1389,6 +1389,7 @@ pub fn parse_expr(
                 false,
                 parse_optional_expr(pb.filter.as_deref(), registry, 
codec)?.map(Box::new),
                 parse_vec_expr(&pb.order_by, registry, codec)?,
+                None,
             )))
         }
 
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 00dacffe06..4d5d6cadad 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -517,7 +517,11 @@ impl AsExecutionPlan for PhysicalPlanNode {
                                         }
                                         
AggregateFunction::UserDefinedAggrFunction(udaf_name) => {
                                             let agg_udf = 
registry.udaf(udaf_name)?;
-                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, 
&physical_schema, name)
+                                            // TODO: `order by` and `ignore nulls` are not supported for UDAFs
+                                            // here yet, so both are hard-coded below
+                                            let sort_exprs = &[];
+                                            let ordering_req = &[];
+                                            let ignore_nulls = false;
+                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, 
ordering_req, &physical_schema, name, ignore_nulls)
                                         }
                                     }
                                 }).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index 4cd133dc21..f136e31455 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -1772,6 +1772,7 @@ fn roundtrip_aggregate_udf() {
         false,
         Some(Box::new(lit(true))),
         None,
+        None,
     ));
 
     let ctx = SessionContext::new();
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 0238291c77..5dacf692e9 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -412,14 +412,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
 
     let return_type = DataType::Int64;
     let accumulator: AccumulatorFactoryFunction = Arc::new(|_| 
Ok(Box::new(Example)));
-    let state_type = vec![DataType::Int64];
 
     let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature(
         "example",
         Signature::exact(vec![DataType::Int64], Volatility::Immutable),
         return_type,
         accumulator,
-        state_type,
+        vec![Field::new("value", DataType::Int64, true)],
     ));
 
     let ctx = SessionContext::new();
@@ -431,8 +430,11 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
     let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![udaf::create_aggregate_expr(
         &udaf,
         &[col("b", &schema)?],
+        &[],
+        &[],
         &schema,
         "example_agg",
+        false,
     )?];
 
     roundtrip_test_with_context(
diff --git a/datafusion/sql/src/expr/function.rs 
b/datafusion/sql/src/expr/function.rs
index 582404b297..e97eb1a32b 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -221,9 +221,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         } else {
             // User defined aggregate functions (UDAF) have precedence in case 
it has the same name as a scalar built-in function
             if let Some(fm) = self.context_provider.get_aggregate_meta(&name) {
+                let order_by =
+                    self.order_by_to_sort_expr(&order_by, schema, 
planner_context, true)?;
+                let order_by = (!order_by.is_empty()).then_some(order_by);
                 let args = self.function_args_to_expr(args, schema, 
planner_context)?;
+                // TODO: Support filter and distinct for UDAFs
                 return 
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
-                    fm, args, false, None, None,
+                    fm,
+                    args,
+                    false,
+                    None,
+                    order_by,
+                    null_treatment,
                 )));
             }
 
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index e68f3f9928..73782ab27f 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -754,7 +754,7 @@ pub async fn from_substrait_agg_func(
     // try udaf first, then built-in aggr fn.
     if let Ok(fun) = ctx.udaf(function_name) {
         Ok(Arc::new(Expr::AggregateFunction(
-            expr::AggregateFunction::new_udf(fun, args, distinct, filter, 
order_by),
+            expr::AggregateFunction::new_udf(fun, args, distinct, filter, 
order_by, None),
         )))
     } else if let Ok(fun) = 
aggregate_function::AggregateFunction::from_str(function_name)
     {

Reply via email to