This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 13569340bc ExprBuilder for Physical Aggregate Expr (#11617)
13569340bc is described below

commit 13569340bce99e4a317ec4d71e5c46d69dfa733d
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Jul 24 22:30:05 2024 +0800

    ExprBuilder for Physical Aggregate Expr (#11617)
    
    * aggregate expr builder
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * replace parts of test
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * continue
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup all
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add sort
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm field
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * address comment
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fix import path
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/core/src/lib.rs                         |   5 +
 .../src/physical_optimizer/aggregate_statistics.rs |  20 +-
 .../combine_partial_final_agg.rs                   |  41 +--
 datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs |  23 +-
 .../physical-expr-common/src/aggregate/mod.rs      | 286 +++++++++++++++------
 datafusion/physical-plan/src/aggregates/mod.rs     | 134 ++++------
 datafusion/physical-plan/src/windows/mod.rs        |  39 ++-
 datafusion/proto/src/physical_plan/mod.rs          |  11 +-
 .../proto/tests/cases/roundtrip_physical_plan.rs   | 177 +++++--------
 9 files changed, 369 insertions(+), 367 deletions(-)

diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs
index 9ab6ed527d..d9ab9e1c07 100644
--- a/datafusion/core/src/lib.rs
+++ b/datafusion/core/src/lib.rs
@@ -545,6 +545,11 @@ pub mod optimizer {
     pub use datafusion_optimizer::*;
 }
 
+/// re-export of [`datafusion_physical_expr_common`] crate
+pub mod physical_expr_common {
+    pub use datafusion_physical_expr_common::*;
+}
+
 /// re-export of [`datafusion_physical_expr`] crate
 pub mod physical_expr {
     pub use datafusion_physical_expr::*;
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs 
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index e7580d3e33..5f08e4512b 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -326,7 +326,7 @@ pub(crate) mod tests {
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_physical_expr::expressions::cast;
     use datafusion_physical_expr::PhysicalExpr;
-    use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
+    use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
     use datafusion_physical_plan::aggregates::AggregateMode;
 
     /// Mock data using a MemoryExec which has an exact count statistic
@@ -419,19 +419,11 @@ pub(crate) mod tests {
 
         // Return appropriate expr depending if COUNT is for col or table (*)
         pub(crate) fn count_expr(&self, schema: &Schema) -> Arc<dyn 
AggregateExpr> {
-            create_aggregate_expr(
-                &count_udaf(),
-                &[self.column()],
-                &[],
-                &[],
-                &[],
-                schema,
-                self.column_name(),
-                false,
-                false,
-                false,
-            )
-            .unwrap()
+            AggregateExprBuilder::new(count_udaf(), vec![self.column()])
+                .schema(Arc::new(schema.clone()))
+                .name(self.column_name())
+                .build()
+                .unwrap()
         }
 
         /// what argument would this aggregate need in the plan?
diff --git 
a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs 
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index ddb7d36fb5..6f3274820c 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -177,7 +177,7 @@ mod tests {
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::sum::sum_udaf;
     use datafusion_physical_expr::expressions::col;
-    use datafusion_physical_plan::udaf::create_aggregate_expr;
+    use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
 
     /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan 
against the expected
     macro_rules! assert_optimized {
@@ -278,19 +278,11 @@ mod tests {
         name: &str,
         schema: &Schema,
     ) -> Arc<dyn AggregateExpr> {
-        create_aggregate_expr(
-            &count_udaf(),
-            &[expr],
-            &[],
-            &[],
-            &[],
-            schema,
-            name,
-            false,
-            false,
-            false,
-        )
-        .unwrap()
+        AggregateExprBuilder::new(count_udaf(), vec![expr])
+            .schema(Arc::new(schema.clone()))
+            .name(name)
+            .build()
+            .unwrap()
     }
 
     #[test]
@@ -368,19 +360,14 @@ mod tests {
     #[test]
     fn aggregations_with_group_combined() -> Result<()> {
         let schema = schema();
-
-        let aggr_expr = vec![create_aggregate_expr(
-            &sum_udaf(),
-            &[col("b", &schema)?],
-            &[],
-            &[],
-            &[],
-            &schema,
-            "Sum(b)",
-            false,
-            false,
-            false,
-        )?];
+        let aggr_expr =
+            vec![
+                AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .name("Sum(b)")
+                    .build()
+                    .unwrap(),
+            ];
         let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
             vec![(col("c", &schema)?, "c".to_string())];
 
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 736560da97..6f286c9aeb 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -35,7 +35,7 @@ use datafusion_common::tree_node::{TreeNode, 
TreeNodeRecursion, TreeNodeVisitor}
 use datafusion_functions_aggregate::sum::sum_udaf;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr::PhysicalSortExpr;
-use datafusion_physical_plan::udaf::create_aggregate_expr;
+use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
 use datafusion_physical_plan::InputOrderMode;
 use test_utils::{add_empty_batches, StringBatchGenerator};
 
@@ -103,19 +103,14 @@ async fn run_aggregate_test(input1: Vec<RecordBatch>, 
group_by_columns: Vec<&str
             .with_sort_information(vec![sort_keys]),
     );
 
-    let aggregate_expr = vec![create_aggregate_expr(
-        &sum_udaf(),
-        &[col("d", &schema).unwrap()],
-        &[],
-        &[],
-        &[],
-        &schema,
-        "sum1",
-        false,
-        false,
-        false,
-    )
-    .unwrap()];
+    let aggregate_expr =
+        vec![
+            AggregateExprBuilder::new(sum_udaf(), vec![col("d", 
&schema).unwrap()])
+                .schema(Arc::clone(&schema))
+                .name("sum1")
+                .build()
+                .unwrap(),
+        ];
     let expr = group_by_columns
         .iter()
         .map(|elem| (col(elem, &schema).unwrap(), elem.to_string()))
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs 
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index 8c5f9f9e5a..b58a5a6faf 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -22,8 +22,8 @@ pub mod stats;
 pub mod tdigest;
 pub mod utils;
 
-use arrow::datatypes::{DataType, Field, Schema};
-use datafusion_common::{not_impl_err, DFSchema, Result};
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
 use datafusion_expr::function::StateFieldsArgs;
 use datafusion_expr::type_coercion::aggregates::check_arg_count;
 use datafusion_expr::ReversedUDAF;
@@ -33,7 +33,7 @@ use datafusion_expr::{
 use std::fmt::Debug;
 use std::{any::Any, sync::Arc};
 
-use self::utils::{down_cast_any_ref, ordering_fields};
+use self::utils::down_cast_any_ref;
 use crate::physical_expr::PhysicalExpr;
 use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
 use crate::utils::reverse_order_bys;
@@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
 /// `is_reversed` is used to indicate whether the aggregation is running in 
reverse order,
 /// it could be used to hint Accumulator to accumulate in the reversed order,
 /// you can just set to false if you are not reversing expression
+///
+/// You can also create this expression using [`AggregateExprBuilder`]
 #[allow(clippy::too_many_arguments)]
 pub fn create_aggregate_expr(
     fun: &AggregateUDF,
@@ -66,45 +68,23 @@ pub fn create_aggregate_expr(
     name: impl Into<String>,
     ignore_nulls: bool,
     is_distinct: bool,
-    is_reversed: bool,
 ) -> Result<Arc<dyn AggregateExpr>> {
-    debug_assert_eq!(sort_exprs.len(), ordering_req.len());
-
-    let input_exprs_types = input_phy_exprs
-        .iter()
-        .map(|arg| arg.data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    check_arg_count(
-        fun.name(),
-        &input_exprs_types,
-        &fun.signature().type_signature,
-    )?;
-
-    let ordering_types = ordering_req
-        .iter()
-        .map(|e| e.expr.data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    let ordering_fields = ordering_fields(ordering_req, &ordering_types);
-    let name = name.into();
-
-    Ok(Arc::new(AggregateFunctionExpr {
-        fun: fun.clone(),
-        args: input_phy_exprs.to_vec(),
-        logical_args: input_exprs.to_vec(),
-        data_type: fun.return_type(&input_exprs_types)?,
-        name,
-        schema: schema.clone(),
-        dfschema: DFSchema::empty(),
-        sort_exprs: sort_exprs.to_vec(),
-        ordering_req: ordering_req.to_vec(),
-        ignore_nulls,
-        ordering_fields,
-        is_distinct,
-        input_type: input_exprs_types[0].clone(),
-        is_reversed,
-    }))
+    let mut builder =
+        AggregateExprBuilder::new(Arc::new(fun.clone()), 
input_phy_exprs.to_vec());
+    builder = builder.sort_exprs(sort_exprs.to_vec());
+    builder = builder.order_by(ordering_req.to_vec());
+    builder = builder.logical_exprs(input_exprs.to_vec());
+    builder = builder.schema(Arc::new(schema.clone()));
+    builder = builder.name(name);
+
+    if ignore_nulls {
+        builder = builder.ignore_nulls();
+    }
+    if is_distinct {
+        builder = builder.distinct();
+    }
+
+    builder.build()
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -121,44 +101,196 @@ pub fn create_aggregate_expr_with_dfschema(
     is_distinct: bool,
     is_reversed: bool,
 ) -> Result<Arc<dyn AggregateExpr>> {
-    debug_assert_eq!(sort_exprs.len(), ordering_req.len());
-
+    let mut builder =
+        AggregateExprBuilder::new(Arc::new(fun.clone()), 
input_phy_exprs.to_vec());
+    builder = builder.sort_exprs(sort_exprs.to_vec());
+    builder = builder.order_by(ordering_req.to_vec());
+    builder = builder.logical_exprs(input_exprs.to_vec());
+    builder = builder.dfschema(dfschema.clone());
     let schema: Schema = dfschema.into();
+    builder = builder.schema(Arc::new(schema));
+    builder = builder.name(name);
+
+    if ignore_nulls {
+        builder = builder.ignore_nulls();
+    }
+    if is_distinct {
+        builder = builder.distinct();
+    }
+    if is_reversed {
+        builder = builder.reversed();
+    }
+
+    builder.build()
+}
+
+/// Builder for physical [`AggregateExpr`]
+///
+/// `AggregateExpr` contains the information necessary to call
+/// an aggregate expression.
+#[derive(Debug, Clone)]
+pub struct AggregateExprBuilder {
+    fun: Arc<AggregateUDF>,
+    /// Physical expressions of the aggregate function
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    /// Logical expressions of the aggregate function, it will be deprecated 
in <https://github.com/apache/datafusion/issues/11359>
+    logical_args: Vec<Expr>,
+    name: String,
+    /// Arrow Schema for the aggregate function
+    schema: SchemaRef,
+    /// Datafusion Schema for the aggregate function
+    dfschema: DFSchema,
+    /// The logical order by expressions, it will be deprecated in 
<https://github.com/apache/datafusion/issues/11359>
+    sort_exprs: Vec<Expr>,
+    /// The physical order by expressions
+    ordering_req: LexOrdering,
+    /// Whether to ignore null values
+    ignore_nulls: bool,
+    /// Whether is distinct aggregate function
+    is_distinct: bool,
+    /// Whether the expression is reversed
+    is_reversed: bool,
+}
+
+impl AggregateExprBuilder {
+    pub fn new(fun: Arc<AggregateUDF>, args: Vec<Arc<dyn PhysicalExpr>>) -> 
Self {
+        Self {
+            fun,
+            args,
+            logical_args: vec![],
+            name: String::new(),
+            schema: Arc::new(Schema::empty()),
+            dfschema: DFSchema::empty(),
+            sort_exprs: vec![],
+            ordering_req: vec![],
+            ignore_nulls: false,
+            is_distinct: false,
+            is_reversed: false,
+        }
+    }
+
+    pub fn build(self) -> Result<Arc<dyn AggregateExpr>> {
+        let Self {
+            fun,
+            args,
+            logical_args,
+            name,
+            schema,
+            dfschema,
+            sort_exprs,
+            ordering_req,
+            ignore_nulls,
+            is_distinct,
+            is_reversed,
+        } = self;
+        if args.is_empty() {
+            return internal_err!("args should not be empty");
+        }
+
+        let mut ordering_fields = vec![];
+
+        debug_assert_eq!(sort_exprs.len(), ordering_req.len());
+        if !ordering_req.is_empty() {
+            let ordering_types = ordering_req
+                .iter()
+                .map(|e| e.expr.data_type(&schema))
+                .collect::<Result<Vec<_>>>()?;
+
+            ordering_fields = utils::ordering_fields(&ordering_req, 
&ordering_types);
+        }
+
+        let input_exprs_types = args
+            .iter()
+            .map(|arg| arg.data_type(&schema))
+            .collect::<Result<Vec<_>>>()?;
+
+        check_arg_count(
+            fun.name(),
+            &input_exprs_types,
+            &fun.signature().type_signature,
+        )?;
 
-    let input_exprs_types = input_phy_exprs
-        .iter()
-        .map(|arg| arg.data_type(&schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    check_arg_count(
-        fun.name(),
-        &input_exprs_types,
-        &fun.signature().type_signature,
-    )?;
-
-    let ordering_types = ordering_req
-        .iter()
-        .map(|e| e.expr.data_type(&schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    let ordering_fields = ordering_fields(ordering_req, &ordering_types);
-
-    Ok(Arc::new(AggregateFunctionExpr {
-        fun: fun.clone(),
-        args: input_phy_exprs.to_vec(),
-        logical_args: input_exprs.to_vec(),
-        data_type: fun.return_type(&input_exprs_types)?,
-        name: name.into(),
-        schema: schema.clone(),
-        dfschema: dfschema.clone(),
-        sort_exprs: sort_exprs.to_vec(),
-        ordering_req: ordering_req.to_vec(),
-        ignore_nulls,
-        ordering_fields,
-        is_distinct,
-        input_type: input_exprs_types[0].clone(),
-        is_reversed,
-    }))
+        let data_type = fun.return_type(&input_exprs_types)?;
+
+        Ok(Arc::new(AggregateFunctionExpr {
+            fun: Arc::unwrap_or_clone(fun),
+            args,
+            logical_args,
+            data_type,
+            name,
+            schema: Arc::unwrap_or_clone(schema),
+            dfschema,
+            sort_exprs,
+            ordering_req,
+            ignore_nulls,
+            ordering_fields,
+            is_distinct,
+            input_type: input_exprs_types[0].clone(),
+            is_reversed,
+        }))
+    }
+
+    pub fn name(mut self, name: impl Into<String>) -> Self {
+        self.name = name.into();
+        self
+    }
+
+    pub fn schema(mut self, schema: SchemaRef) -> Self {
+        self.schema = schema;
+        self
+    }
+
+    pub fn dfschema(mut self, dfschema: DFSchema) -> Self {
+        self.dfschema = dfschema;
+        self
+    }
+
+    pub fn order_by(mut self, order_by: LexOrdering) -> Self {
+        self.ordering_req = order_by;
+        self
+    }
+
+    pub fn reversed(mut self) -> Self {
+        self.is_reversed = true;
+        self
+    }
+
+    pub fn with_reversed(mut self, is_reversed: bool) -> Self {
+        self.is_reversed = is_reversed;
+        self
+    }
+
+    pub fn distinct(mut self) -> Self {
+        self.is_distinct = true;
+        self
+    }
+
+    pub fn with_distinct(mut self, is_distinct: bool) -> Self {
+        self.is_distinct = is_distinct;
+        self
+    }
+
+    pub fn ignore_nulls(mut self) -> Self {
+        self.ignore_nulls = true;
+        self
+    }
+
+    pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
+        self.ignore_nulls = ignore_nulls;
+        self
+    }
+
+    /// This method will be deprecated in 
<https://github.com/apache/datafusion/issues/11359>
+    pub fn sort_exprs(mut self, sort_exprs: Vec<Expr>) -> Self {
+        self.sort_exprs = sort_exprs;
+        self
+    }
+
+    /// This method will be deprecated in 
<https://github.com/apache/datafusion/issues/11359>
+    pub fn logical_exprs(mut self, logical_args: Vec<Expr>) -> Self {
+        self.logical_args = logical_args;
+        self
+    }
 }
 
 /// An aggregate expression that:
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs
index e7cd5cb272..d1152038eb 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1211,7 +1211,7 @@ mod tests {
 
     use crate::common::collect;
     use datafusion_physical_expr_common::aggregate::{
-        create_aggregate_expr, create_aggregate_expr_with_dfschema,
+        create_aggregate_expr_with_dfschema, AggregateExprBuilder,
     };
     use datafusion_physical_expr_common::expressions::Literal;
     use futures::{FutureExt, Stream};
@@ -1351,18 +1351,11 @@ mod tests {
             ],
         };
 
-        let aggregates = vec![create_aggregate_expr(
-            &count_udaf(),
-            &[lit(1i8)],
-            &[datafusion_expr::lit(1i8)],
-            &[],
-            &[],
-            &input_schema,
-            "COUNT(1)",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates = vec![AggregateExprBuilder::new(count_udaf(), 
vec![lit(1i8)])
+            .schema(Arc::clone(&input_schema))
+            .name("COUNT(1)")
+            .logical_exprs(vec![datafusion_expr::lit(1i8)])
+            .build()?];
 
         let task_ctx = if spill {
             new_spill_ctx(4, 1000)
@@ -1501,18 +1494,13 @@ mod tests {
             groups: vec![vec![false]],
         };
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![create_aggregate_expr(
-            &avg_udaf(),
-            &[col("b", &input_schema)?],
-            &[datafusion_expr::col("b")],
-            &[],
-            &[],
-            &input_schema,
-            "AVG(b)",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates: Vec<Arc<dyn AggregateExpr>> =
+            vec![
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", 
&input_schema)?])
+                    .schema(Arc::clone(&input_schema))
+                    .name("AVG(b)")
+                    .build()?,
+            ];
 
         let task_ctx = if spill {
             // set to an appropriate value to trigger spill
@@ -1803,21 +1791,11 @@ mod tests {
     }
 
     // Median(a)
-    fn test_median_agg_expr(schema: &Schema) -> Result<Arc<dyn AggregateExpr>> 
{
-        let args = vec![col("a", schema)?];
-        let fun = median_udaf();
-        datafusion_physical_expr_common::aggregate::create_aggregate_expr(
-            &fun,
-            &args,
-            &[],
-            &[],
-            &[],
-            schema,
-            "MEDIAN(a)",
-            false,
-            false,
-            false,
-        )
+    fn test_median_agg_expr(schema: SchemaRef) -> Result<Arc<dyn 
AggregateExpr>> {
+        AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?])
+            .schema(schema)
+            .name("MEDIAN(a)")
+            .build()
     }
 
     #[tokio::test]
@@ -1840,21 +1818,16 @@ mod tests {
 
         // something that allocates within the aggregator
         let aggregates_v0: Vec<Arc<dyn AggregateExpr>> =
-            vec![test_median_agg_expr(&input_schema)?];
+            vec![test_median_agg_expr(Arc::clone(&input_schema))?];
 
         // use fast-path in `row_hash.rs`.
-        let aggregates_v2: Vec<Arc<dyn AggregateExpr>> = 
vec![create_aggregate_expr(
-            &avg_udaf(),
-            &[col("b", &input_schema)?],
-            &[datafusion_expr::col("b")],
-            &[],
-            &[],
-            &input_schema,
-            "AVG(b)",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates_v2: Vec<Arc<dyn AggregateExpr>> =
+            vec![
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", 
&input_schema)?])
+                    .schema(Arc::clone(&input_schema))
+                    .name("AVG(b)")
+                    .build()?,
+            ];
 
         for (version, groups, aggregates) in [
             (0, groups_none, aggregates_v0),
@@ -1908,18 +1881,13 @@ mod tests {
 
         let groups = PhysicalGroupBy::default();
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![create_aggregate_expr(
-            &avg_udaf(),
-            &[col("a", &schema)?],
-            &[datafusion_expr::col("a")],
-            &[],
-            &[],
-            &schema,
-            "AVG(a)",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates: Vec<Arc<dyn AggregateExpr>> =
+            vec![
+                AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .name("AVG(a)")
+                    .build()?,
+            ];
 
         let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
         let refs = blocking_exec.refs();
@@ -1953,18 +1921,13 @@ mod tests {
         let groups =
             PhysicalGroupBy::new_single(vec![(col("a", &schema)?, 
"a".to_string())]);
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![create_aggregate_expr(
-            &avg_udaf(),
-            &[col("b", &schema)?],
-            &[datafusion_expr::col("b")],
-            &[],
-            &[],
-            &schema,
-            "AVG(b)",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates: Vec<Arc<dyn AggregateExpr>> =
+            vec![
+                AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .name("AVG(b)")
+                    .build()?,
+            ];
 
         let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
         let refs = blocking_exec.refs();
@@ -2388,18 +2351,11 @@ mod tests {
             ],
         );
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = 
vec![create_aggregate_expr(
-            count_udaf().as_ref(),
-            &[lit(1)],
-            &[datafusion_expr::lit(1)],
-            &[],
-            &[],
-            schema.as_ref(),
-            "1",
-            false,
-            false,
-            false,
-        )?];
+        let aggregates: Vec<Arc<dyn AggregateExpr>> =
+            vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)])
+                .schema(Arc::clone(&schema))
+                .name("1")
+                .build()?];
 
         let input_batches = (0..4)
             .map(|_| {
diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index 959796489c..ffe558e215 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -26,16 +26,16 @@ use crate::{
         cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, 
NthValue, Ntile,
         PhysicalSortExpr, RowNumber,
     },
-    udaf, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr,
+    ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr,
 };
 
 use arrow::datatypes::Schema;
 use arrow_schema::{DataType, Field, SchemaRef};
-use datafusion_common::{exec_err, Column, DataFusionError, Result, 
ScalarValue};
-use datafusion_expr::Expr;
+use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
+use datafusion_expr::{col, Expr, SortExpr};
 use datafusion_expr::{
-    BuiltInWindowFunction, PartitionEvaluator, SortExpr, WindowFrame,
-    WindowFunctionDefinition, WindowUDF,
+    BuiltInWindowFunction, PartitionEvaluator, WindowFrame, 
WindowFunctionDefinition,
+    WindowUDF,
 };
 use datafusion_physical_expr::equivalence::collapse_lex_req;
 use datafusion_physical_expr::{
@@ -44,6 +44,7 @@ use datafusion_physical_expr::{
     AggregateExpr, ConstExpr, EquivalenceProperties, LexOrdering,
     PhysicalSortRequirement,
 };
+use datafusion_physical_expr_common::aggregate::AggregateExprBuilder;
 use itertools::Itertools;
 
 mod bounded_window_agg_exec;
@@ -95,7 +96,7 @@ pub fn create_window_expr(
     fun: &WindowFunctionDefinition,
     name: String,
     args: &[Arc<dyn PhysicalExpr>],
-    logical_args: &[Expr],
+    _logical_args: &[Expr],
     partition_by: &[Arc<dyn PhysicalExpr>],
     order_by: &[PhysicalSortExpr],
     window_frame: Arc<WindowFrame>,
@@ -129,7 +130,6 @@ pub fn create_window_expr(
             ))
         }
         WindowFunctionDefinition::AggregateUDF(fun) => {
-            // TODO: Ordering not supported for Window UDFs yet
             // Convert `Vec<PhysicalSortExpr>` into `Vec<Expr::Sort>`
             let sort_exprs = order_by
                 .iter()
@@ -137,28 +137,20 @@ pub fn create_window_expr(
                     let field_name = expr.to_string();
                     let field_name = 
field_name.split('@').next().unwrap_or(&field_name);
                     Expr::Sort(SortExpr {
-                        expr: Box::new(Expr::Column(Column::new(
-                            None::<String>,
-                            field_name,
-                        ))),
+                        expr: Box::new(col(field_name)),
                         asc: !options.descending,
                         nulls_first: options.nulls_first,
                     })
                 })
                 .collect::<Vec<_>>();
 
-            let aggregate = udaf::create_aggregate_expr(
-                fun.as_ref(),
-                args,
-                logical_args,
-                &sort_exprs,
-                order_by,
-                input_schema,
-                name,
-                ignore_nulls,
-                false,
-                false,
-            )?;
+            let aggregate = AggregateExprBuilder::new(Arc::clone(fun), 
args.to_vec())
+                .schema(Arc::new(input_schema.clone()))
+                .name(name)
+                .order_by(order_by.to_vec())
+                .sort_exprs(sort_exprs)
+                .with_ignore_nulls(ignore_nulls)
+                .build()?;
             window_expr_from_aggregate_expr(
                 partition_by,
                 order_by,
@@ -166,6 +158,7 @@ pub fn create_window_expr(
                 aggregate,
             )
         }
+        // TODO: Ordering not supported for Window UDFs yet
         WindowFunctionDefinition::WindowUDF(fun) => 
Arc::new(BuiltInWindowExpr::new(
             create_udwf_window_expr(fun, args, input_schema, name)?,
             partition_by,
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 8c9e5bbd0e..5c4d41f0ec 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -18,6 +18,7 @@
 use std::fmt::Debug;
 use std::sync::Arc;
 
+use datafusion::physical_expr_common::aggregate::AggregateExprBuilder;
 use prost::bytes::BufMut;
 use prost::Message;
 
@@ -58,7 +59,7 @@ use 
datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMerge
 use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
 use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
 use datafusion::physical_plan::{
-    udaf, AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, 
WindowExpr,
+    AggregateExpr, ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr,
 };
 use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result};
 use datafusion_expr::{AggregateUDF, ScalarUDF};
@@ -501,13 +502,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                                                 None => 
registry.udaf(udaf_name)?
                                             };
 
-                                            // TODO: 'logical_exprs' is not 
supported for UDAF yet.
-                                            // approx_percentile_cont and 
approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
-                                            let logical_exprs = &[];
+                                            // TODO: approx_percentile_cont 
and approx_percentile_cont_weight are not supported for UDAF from protobuf yet.
                                             // TODO: `order by` is not 
supported for UDAF yet
-                                            let sort_exprs = &[];
-                                            let ordering_req = &[];
-                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, 
sort_exprs, ordering_req, &physical_schema, name, agg_node.ignore_nulls, 
agg_node.distinct, false)
+                                            AggregateExprBuilder::new(agg_udf, 
input_phy_expr).schema(Arc::clone(&physical_schema)).name(name).with_ignore_nulls(agg_node.ignore_nulls).with_distinct(agg_node.distinct).build()
                                         }
                                     }
                                 }).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 31ed0837d2..3ddc122e3d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -24,6 +24,7 @@ use std::vec;
 
 use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
+use datafusion::physical_expr_common::aggregate::AggregateExprBuilder;
 use prost::Message;
 
 use datafusion::arrow::array::ArrayRef;
@@ -64,7 +65,6 @@ use 
datafusion::physical_plan::placeholder_row::PlaceholderRowExec;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
-use datafusion::physical_plan::udaf::create_aggregate_expr;
 use datafusion::physical_plan::union::{InterleaveExec, UnionExec};
 use datafusion::physical_plan::windows::{
     BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec,
@@ -86,7 +86,7 @@ use datafusion_expr::{
 };
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::nth_value::nth_value_udaf;
-use datafusion_functions_aggregate::string_agg::StringAgg;
+use datafusion_functions_aggregate::string_agg::string_agg_udaf;
 use datafusion_proto::physical_plan::{
     AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
 };
@@ -291,18 +291,13 @@ fn roundtrip_window() -> Result<()> {
     ));
 
     let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new(
-        create_aggregate_expr(
-            &avg_udaf(),
-            &[cast(col("b", &schema)?, &schema, DataType::Float64)?],
-            &[],
-            &[],
-            &[],
-            &schema,
-            "avg(b)",
-            false,
-            false,
-            false,
-        )?,
+        AggregateExprBuilder::new(
+            avg_udaf(),
+            vec![cast(col("b", &schema)?, &schema, DataType::Float64)?],
+        )
+        .schema(Arc::clone(&schema))
+        .name("avg(b)")
+        .build()?,
         &[],
         &[],
         Arc::new(WindowFrame::new(None)),
@@ -315,18 +310,10 @@ fn roundtrip_window() -> Result<()> {
     );
 
     let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?];
-    let sum_expr = create_aggregate_expr(
-        &sum_udaf(),
-        &args,
-        &[],
-        &[],
-        &[],
-        &schema,
-        "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING",
-        false,
-        false,
-        false,
-    )?;
+    let sum_expr = AggregateExprBuilder::new(sum_udaf(), args)
+        .schema(Arc::clone(&schema))
+        .name("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING")
+        .build()?;
 
     let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new(
         sum_expr,
@@ -357,49 +344,28 @@ fn rountrip_aggregate() -> Result<()> {
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
         vec![(col("a", &schema)?, "unused".to_string())];
 
+    let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", 
&schema)?])
+        .schema(Arc::clone(&schema))
+        .name("AVG(b)")
+        .build()?;
+    let nth_expr =
+        AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, 
lit(1u64)])
+            .schema(Arc::clone(&schema))
+            .name("NTH_VALUE(b, 1)")
+            .build()?;
+    let str_agg_expr =
+        AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, 
lit(1u64)])
+            .schema(Arc::clone(&schema))
+            .name("NTH_VALUE(b, 1)")
+            .build()?;
+
     let test_cases: Vec<Vec<Arc<dyn AggregateExpr>>> = vec![
         // AVG
-        vec![create_aggregate_expr(
-            &avg_udaf(),
-            &[col("b", &schema)?],
-            &[],
-            &[],
-            &[],
-            &schema,
-            "AVG(b)",
-            false,
-            false,
-            false,
-        )?],
+        vec![avg_expr],
         // NTH_VALUE
-        vec![create_aggregate_expr(
-            &nth_value_udaf(),
-            &[col("b", &schema)?, lit(1u64)],
-            &[],
-            &[],
-            &[],
-            &schema,
-            "NTH_VALUE(b, 1)",
-            false,
-            false,
-            false,
-        )?],
+        vec![nth_expr],
         // STRING_AGG
-        vec![create_aggregate_expr(
-            &AggregateUDF::new_from_impl(StringAgg::new()),
-            &[
-                cast(col("b", &schema)?, &schema, DataType::Utf8)?,
-                lit(ScalarValue::Utf8(Some(",".to_string()))),
-            ],
-            &[],
-            &[],
-            &[],
-            &schema,
-            "STRING_AGG(name, ',')",
-            false,
-            false,
-            false,
-        )?],
+        vec![str_agg_expr],
     ];
 
     for aggregates in test_cases {
@@ -426,18 +392,13 @@ fn rountrip_aggregate_with_limit() -> Result<()> {
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
         vec![(col("a", &schema)?, "unused".to_string())];
 
-    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
-        &avg_udaf(),
-        &[col("b", &schema)?],
-        &[],
-        &[],
-        &[],
-        &schema,
-        "AVG(b)",
-        false,
-        false,
-        false,
-    )?];
+    let aggregates: Vec<Arc<dyn AggregateExpr>> =
+        vec![
+            AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
+                .schema(Arc::clone(&schema))
+                .name("AVG(b)")
+                .build()?,
+        ];
 
     let agg = AggregateExec::try_new(
         AggregateMode::Final,
@@ -498,18 +459,13 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
     let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
         vec![(col("a", &schema)?, "unused".to_string())];
 
-    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
-        &udaf,
-        &[col("b", &schema)?],
-        &[],
-        &[],
-        &[],
-        &schema,
-        "example_agg",
-        false,
-        false,
-        false,
-    )?];
+    let aggregates: Vec<Arc<dyn AggregateExpr>> =
+        vec![
+            AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?])
+                .schema(Arc::clone(&schema))
+                .name("example_agg")
+                .build()?,
+        ];
 
     roundtrip_test_with_context(
         Arc::new(AggregateExec::try_new(
@@ -994,21 +950,16 @@ fn roundtrip_aggregate_udf_extension_codec() -> 
Result<()> {
         DataType::Int64,
     ));
 
-    let udaf = AggregateUDF::from(MyAggregateUDF::new("result".to_string()));
-    let aggr_args: [Arc<dyn PhysicalExpr>; 1] =
-        [Arc::new(Literal::new(ScalarValue::from(42)))];
-    let aggr_expr = create_aggregate_expr(
-        &udaf,
-        &aggr_args,
-        &[],
-        &[],
-        &[],
-        &schema,
-        "aggregate_udf",
-        false,
-        false,
-        false,
-    )?;
+    let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new(
+        "result".to_string(),
+    )));
+    let aggr_args: Vec<Arc<dyn PhysicalExpr>> =
+        vec![Arc::new(Literal::new(ScalarValue::from(42)))];
+
+    let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), 
aggr_args.clone())
+        .schema(Arc::clone(&schema))
+        .name("aggregate_udf")
+        .build()?;
 
     let filter = Arc::new(FilterExec::try_new(
         Arc::new(BinaryExpr::new(
@@ -1030,18 +981,12 @@ fn roundtrip_aggregate_udf_extension_codec() -> 
Result<()> {
         vec![col("author", &schema)?],
     )?);
 
-    let aggr_expr = create_aggregate_expr(
-        &udaf,
-        &aggr_args,
-        &[],
-        &[],
-        &[],
-        &schema,
-        "aggregate_udf",
-        true,
-        true,
-        false,
-    )?;
+    let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone())
+        .schema(Arc::clone(&schema))
+        .name("aggregate_udf")
+        .distinct()
+        .ignore_nulls()
+        .build()?;
 
     let aggregate = Arc::new(AggregateExec::try_new(
         AggregateMode::Final,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to