This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 053dc5a90 move type coercion of agg and agg_udaf to logical phase (#3768)
053dc5a90 is described below

commit 053dc5a90654bfe68dfa65f71040d3ba2dffb515
Author: Kun Liu <[email protected]>
AuthorDate: Wed Oct 12 05:35:38 2022 +0800

    move type coercion of agg and agg_udaf to logical phase (#3768)
---
 datafusion/core/src/physical_plan/udaf.rs          |  13 +-
 datafusion/optimizer/src/type_coercion.rs          | 186 +++++++++++++++++++-
 datafusion/physical-expr/src/aggregate/build_in.rs | 188 ++++++++++++---------
 .../physical-expr/src/aggregate/coercion_rule.rs   |  54 ------
 datafusion/physical-expr/src/aggregate/mod.rs      |   1 -
 5 files changed, 297 insertions(+), 145 deletions(-)

diff --git a/datafusion/core/src/physical_plan/udaf.rs 
b/datafusion/core/src/physical_plan/udaf.rs
index e017bb5ad..659ff560d 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -26,9 +26,7 @@ use arrow::{
     datatypes::{DataType, Schema},
 };
 
-use super::{
-    expressions::format_state_name, type_coercion::coerce, Accumulator, 
AggregateExpr,
-};
+use super::{expressions::format_state_name, Accumulator, AggregateExpr};
 use crate::error::Result;
 use crate::physical_plan::PhysicalExpr;
 pub use datafusion_expr::AggregateUDF;
@@ -43,18 +41,15 @@ pub fn create_aggregate_expr(
     input_schema: &Schema,
     name: impl Into<String>,
 ) -> Result<Arc<dyn AggregateExpr>> {
-    // coerce
-    let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, 
&fun.signature)?;
-
-    let coerced_exprs_types = coerced_phy_exprs
+    let input_exprs_types = input_phy_exprs
         .iter()
         .map(|arg| arg.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
 
     Ok(Arc::new(AggregateFunctionExpr {
         fun: fun.clone(),
-        args: coerced_phy_exprs.clone(),
-        data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
+        args: input_phy_exprs.to_vec(),
+        data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(),
         name: name.into(),
     }))
 }
diff --git a/datafusion/optimizer/src/type_coercion.rs 
b/datafusion/optimizer/src/type_coercion.rs
index 5438632ba..89d5d660b 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
 };
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
-    function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, 
is_unknown,
-    Expr, LogicalPlan, Operator,
+    aggregate_function, function, is_false, is_not_false, is_not_true, 
is_not_unknown,
+    is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, 
Operator,
 };
 use datafusion_expr::{ExprSchemable, Signature};
 use std::sync::Arc;
@@ -407,6 +407,39 @@ impl ExprRewriter for TypeCoercionRewriter {
                 };
                 Ok(expr)
             }
+            Expr::AggregateFunction {
+                fun,
+                args,
+                distinct,
+                filter,
+            } => {
+                let new_expr = coerce_agg_exprs_for_signature(
+                    &fun,
+                    &args,
+                    &self.schema,
+                    &aggregate_function::signature(&fun),
+                )?;
+                let expr = Expr::AggregateFunction {
+                    fun,
+                    args: new_expr,
+                    distinct,
+                    filter,
+                };
+                Ok(expr)
+            }
+            Expr::AggregateUDF { fun, args, filter } => {
+                let new_expr = coerce_arguments_for_signature(
+                    args.as_slice(),
+                    &self.schema,
+                    &fun.signature,
+                )?;
+                let expr = Expr::AggregateUDF {
+                    fun,
+                    args: new_expr,
+                    filter,
+                };
+                Ok(expr)
+            }
             expr => Ok(expr),
         }
     }
@@ -448,6 +481,33 @@ fn coerce_arguments_for_signature(
         .collect::<Result<Vec<_>>>()
 }
 
+/// Returns the coerced exprs for each `input_exprs`.
+/// Get the coerced data type from `aggregate_rule::coerce_types` and add 
`try_cast` if the
+/// data type of `input_exprs` need to be coerced.
+fn coerce_agg_exprs_for_signature(
+    agg_fun: &AggregateFunction,
+    input_exprs: &[Expr],
+    schema: &DFSchema,
+    signature: &Signature,
+) -> Result<Vec<Expr>> {
+    if input_exprs.is_empty() {
+        return Ok(vec![]);
+    }
+    let current_types = input_exprs
+        .iter()
+        .map(|e| e.get_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let coerced_types =
+        type_coercion::aggregates::coerce_types(agg_fun, &current_types, 
signature)?;
+
+    input_exprs
+        .iter()
+        .enumerate()
+        .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema))
+        .collect::<Result<Vec<_>>>()
+}
+
 #[cfg(test)]
 mod test {
     use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
@@ -456,7 +516,9 @@ mod test {
     use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
     use datafusion_expr::expr_rewriter::ExprRewritable;
     use datafusion_expr::{
-        cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, 
ColumnarValue,
+        cast, col, concat, concat_ws, create_udaf, is_true,
+        AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF,
+        BuiltinScalarFunction, ColumnarValue, StateTypeFunction,
     };
     use datafusion_expr::{
         lit,
@@ -464,6 +526,7 @@ mod test {
         Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, 
ScalarUDF,
         Signature, Volatility,
     };
+    use datafusion_physical_expr::expressions::AvgAccumulator;
     use std::sync::Arc;
 
     #[test]
@@ -596,6 +659,123 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn agg_udaf() -> Result<()> {
+        let empty = empty();
+        let my_avg = create_udaf(
+            "MY_AVG",
+            DataType::Float64,
+            Arc::new(DataType::Float64),
+            Volatility::Immutable,
+            Arc::new(|_| 
Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+            Arc::new(vec![DataType::UInt64, DataType::Float64]),
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            args: vec![lit(10i64)],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], 
empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_udaf_invalid_input() -> Result<()> {
+        let empty = empty();
+        let return_type: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
+        let state_type: StateTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, 
DataType::Float64])));
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| 
Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
+        let my_avg = AggregateUDF::new(
+            "MY_AVG",
+            &Signature::uniform(1, vec![DataType::Float64], 
Volatility::Immutable),
+            &return_type,
+            &accumulator,
+            &state_type,
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            args: vec![lit("10")],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], 
empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config);
+        assert!(plan.is_err());
+        assert_eq!(
+            "Plan(\"Coercion from [Utf8] to the signature Uniform(1, 
[Float64]) failed.\")",
+            &format!("{:?}", plan.err().unwrap())
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_function_case() -> Result<()> {
+        let empty = empty();
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![lit(12i64)],
+            distinct: false,
+            filter: None,
+        };
+        let plan =
+            LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, 
None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: AVG(Int64(12))\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+
+        let empty = empty_with_type(DataType::Int32);
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![col("a")],
+            distinct: false,
+            filter: None,
+        };
+        let plan =
+            LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, 
None)?);
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: AVG(a)\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_function_invalid_input() -> Result<()> {
+        let empty = empty();
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![lit("1")],
+            distinct: false,
+            filter: None,
+        };
+        let expr = Projection::try_new(vec![agg_expr], empty, None);
+        assert!(expr.is_err());
+        assert_eq!(
+            "Plan(\"The function Avg does not support inputs of type Utf8.\")",
+            &format!("{:?}", expr.err().unwrap())
+        );
+        Ok(())
+    }
+
     #[test]
     fn binary_op_date32_add_interval() -> Result<()> {
         //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs 
b/datafusion/physical-expr/src/aggregate/build_in.rs
index e3154488c..597b51575 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -26,11 +26,9 @@
 //! * Signature: see `Signature`
 //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, 
([f32]) -> f32, ([f64]) -> f64.
 
-use crate::aggregate::coercion_rule::coerce_exprs;
 use crate::{expressions, AggregateExpr, PhysicalExpr};
 use arrow::datatypes::Schema;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::aggregate_function;
 use datafusion_expr::aggregate_function::return_type;
 pub use datafusion_expr::AggregateFunction;
 use std::sync::Arc;
@@ -45,89 +43,72 @@ pub fn create_aggregate_expr(
     name: impl Into<String>,
 ) -> Result<Arc<dyn AggregateExpr>> {
     let name = name.into();
-    // get the coerced phy exprs if some expr need to be wrapped with the try 
cast.
-    let coerced_phy_exprs = coerce_exprs(
-        fun,
-        input_phy_exprs,
-        input_schema,
-        &aggregate_function::signature(fun),
-    )?;
-    if coerced_phy_exprs.is_empty() {
-        return Err(DataFusionError::Plan(format!(
-            "Invalid or wrong number of arguments passed to aggregate: '{}'",
-            name,
-        )));
-    }
-    let coerced_exprs_types = coerced_phy_exprs
-        .iter()
-        .map(|e| e.data_type(input_schema))
-        .collect::<Result<Vec<_>>>()?;
-
     // get the result data type for this aggregate function
     let input_phy_types = input_phy_exprs
         .iter()
         .map(|e| e.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
     let return_type = return_type(fun, &input_phy_types)?;
+    let input_phy_exprs = input_phy_exprs.to_vec();
 
     Ok(match (fun, distinct) {
         (AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
         (AggregateFunction::Count, true) => 
Arc::new(expressions::DistinctCount::new(
-            coerced_exprs_types,
-            coerced_phy_exprs,
+            input_phy_types,
+            input_phy_exprs,
             name,
             return_type,
         )),
         (AggregateFunction::Grouping, _) => 
Arc::new(expressions::Grouping::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
         (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
         (AggregateFunction::Sum, true) => 
Arc::new(expressions::DistinctSum::new(
-            vec![coerced_phy_exprs[0].clone()],
+            vec![input_phy_exprs[0].clone()],
             name,
             return_type,
         )),
         (AggregateFunction::ApproxDistinct, _) => {
             Arc::new(expressions::ApproxDistinct::new(
-                coerced_phy_exprs[0].clone(),
+                input_phy_exprs[0].clone(),
                 name,
-                coerced_exprs_types[0].clone(),
+                input_phy_types[0].clone(),
             ))
         }
         (AggregateFunction::ArrayAgg, false) => 
Arc::new(expressions::ArrayAgg::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
-            coerced_exprs_types[0].clone(),
+            input_phy_types[0].clone(),
         )),
         (AggregateFunction::ArrayAgg, true) => {
             Arc::new(expressions::DistinctArrayAgg::new(
-                coerced_phy_exprs[0].clone(),
+                input_phy_exprs[0].clone(),
                 name,
-                coerced_exprs_types[0].clone(),
+                input_phy_types[0].clone(),
             ))
         }
         (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
         (AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
         (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
@@ -137,7 +118,7 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::Variance, false) => 
Arc::new(expressions::Variance::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
@@ -146,21 +127,17 @@ pub fn create_aggregate_expr(
                 "VAR(DISTINCT) aggregations are not available".to_string(),
             ));
         }
-        (AggregateFunction::VariancePop, false) => {
-            Arc::new(expressions::VariancePop::new(
-                coerced_phy_exprs[0].clone(),
-                name,
-                return_type,
-            ))
-        }
+        (AggregateFunction::VariancePop, false) => Arc::new(
+            expressions::VariancePop::new(input_phy_exprs[0].clone(), name, 
return_type),
+        ),
         (AggregateFunction::VariancePop, true) => {
             return Err(DataFusionError::NotImplemented(
                 "VAR_POP(DISTINCT) aggregations are not available".to_string(),
             ));
         }
         (AggregateFunction::Covariance, false) => 
Arc::new(expressions::Covariance::new(
-            coerced_phy_exprs[0].clone(),
-            coerced_phy_exprs[1].clone(),
+            input_phy_exprs[0].clone(),
+            input_phy_exprs[1].clone(),
             name,
             return_type,
         )),
@@ -171,8 +148,8 @@ pub fn create_aggregate_expr(
         }
         (AggregateFunction::CovariancePop, false) => {
             Arc::new(expressions::CovariancePop::new(
-                coerced_phy_exprs[0].clone(),
-                coerced_phy_exprs[1].clone(),
+                input_phy_exprs[0].clone(),
+                input_phy_exprs[1].clone(),
                 name,
                 return_type,
             ))
@@ -183,7 +160,7 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::Stddev, false) => 
Arc::new(expressions::Stddev::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
@@ -193,7 +170,7 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::StddevPop, false) => 
Arc::new(expressions::StddevPop::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
@@ -204,8 +181,8 @@ pub fn create_aggregate_expr(
         }
         (AggregateFunction::Correlation, false) => {
             Arc::new(expressions::Correlation::new(
-                coerced_phy_exprs[0].clone(),
-                coerced_phy_exprs[1].clone(),
+                input_phy_exprs[0].clone(),
+                input_phy_exprs[1].clone(),
                 name,
                 return_type,
             ))
@@ -216,17 +193,17 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::ApproxPercentileCont, false) => {
-            if coerced_phy_exprs.len() == 2 {
+            if input_phy_exprs.len() == 2 {
                 Arc::new(expressions::ApproxPercentileCont::new(
                     // Pass in the desired percentile expr
-                    coerced_phy_exprs,
+                    input_phy_exprs,
                     name,
                     return_type,
                 )?)
             } else {
                 Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
                     // Pass in the desired percentile expr
-                    coerced_phy_exprs,
+                    input_phy_exprs,
                     name,
                     return_type,
                 )?)
@@ -241,7 +218,7 @@ pub fn create_aggregate_expr(
         (AggregateFunction::ApproxPercentileContWithWeight, false) => {
             Arc::new(expressions::ApproxPercentileContWithWeight::new(
                 // Pass in the desired percentile expr
-                coerced_phy_exprs,
+                input_phy_exprs,
                 name,
                 return_type,
             )?)
@@ -254,7 +231,7 @@ pub fn create_aggregate_expr(
         }
         (AggregateFunction::ApproxMedian, false) => {
             Arc::new(expressions::ApproxMedian::try_new(
-                coerced_phy_exprs[0].clone(),
+                input_phy_exprs[0].clone(),
                 name,
                 return_type,
             )?)
@@ -265,7 +242,7 @@ pub fn create_aggregate_expr(
             ));
         }
         (AggregateFunction::Median, false) => 
Arc::new(expressions::Median::new(
-            coerced_phy_exprs[0].clone(),
+            input_phy_exprs[0].clone(),
             name,
             return_type,
         )),
@@ -281,13 +258,14 @@ pub fn create_aggregate_expr(
 mod tests {
     use super::*;
     use crate::expressions::{
-        ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, 
Correlation,
-        Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, 
Sum,
-        Variance,
+        try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, 
ArrayAgg, Avg,
+        Correlation, Count, Covariance, DistinctArrayAgg, DistinctCount, Max, 
Min,
+        Stddev, Sum, Variance,
     };
     use arrow::datatypes::{DataType, Field};
     use datafusion_common::ScalarValue;
     use datafusion_expr::type_coercion::aggregates::NUMERICS;
+    use datafusion_expr::{aggregate_function, type_coercion, Signature};
 
     #[test]
     fn test_count_arragg_approx_expr() -> Result<()> {
@@ -311,7 +289,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -344,9 +322,9 @@ mod tests {
                                 DataType::List(Box::new(Field::new(
                                     "item",
                                     data_type.clone(),
-                                    true
+                                    true,
                                 ))),
-                                false
+                                false,
                             ),
                             result_agg_phy_exprs.field().unwrap()
                         );
@@ -354,7 +332,7 @@ mod tests {
                     _ => {}
                 };
 
-                let result_distinct = create_aggregate_expr(
+                let result_distinct = create_physical_agg_expr_for_test(
                     &fun,
                     true,
                     &input_phy_exprs[0..1],
@@ -387,9 +365,9 @@ mod tests {
                                 DataType::List(Box::new(Field::new(
                                     "item",
                                     data_type.clone(),
-                                    true
+                                    true,
                                 ))),
-                                false
+                                false,
                             ),
                             result_agg_phy_exprs.field().unwrap()
                         );
@@ -412,7 +390,7 @@ mod tests {
                 ),
                 
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
             ];
-            let result_agg_phy_exprs = create_aggregate_expr(
+            let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                 &AggregateFunction::ApproxPercentileCont,
                 false,
                 &input_phy_exprs[..],
@@ -441,7 +419,7 @@ mod tests {
                 ),
                 
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
             ];
-            let err = create_aggregate_expr(
+            let err = create_physical_agg_expr_for_test(
                 &AggregateFunction::ApproxPercentileCont,
                 false,
                 &input_phy_exprs[..],
@@ -472,7 +450,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -521,7 +499,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -583,7 +561,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -621,7 +599,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -659,7 +637,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -697,7 +675,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -744,7 +722,7 @@ mod tests {
                             .unwrap(),
                     ),
                 ];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..2],
@@ -791,7 +769,7 @@ mod tests {
                             .unwrap(),
                     ),
                 ];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..2],
@@ -838,7 +816,7 @@ mod tests {
                             .unwrap(),
                     ),
                 ];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..2],
@@ -876,7 +854,7 @@ mod tests {
                 let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = 
vec![Arc::new(
                     expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(),
                 )];
-                let result_agg_phy_exprs = create_aggregate_expr(
+                let result_agg_phy_exprs = create_physical_agg_expr_for_test(
                     &fun,
                     false,
                     &input_phy_exprs[0..1],
@@ -1065,4 +1043,58 @@ mod tests {
         let observed = return_type(&AggregateFunction::Stddev, 
&[DataType::Utf8]);
         assert!(observed.is_err());
     }
+
+    // Helper function
+    // Create aggregate expr with type coercion
+    fn create_physical_agg_expr_for_test(
+        fun: &AggregateFunction,
+        distinct: bool,
+        input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+        input_schema: &Schema,
+        name: impl Into<String>,
+    ) -> Result<Arc<dyn AggregateExpr>> {
+        let name = name.into();
+        let coerced_phy_exprs = coerce_exprs_for_test(
+            fun,
+            input_phy_exprs,
+            input_schema,
+            &aggregate_function::signature(fun),
+        )?;
+        if coerced_phy_exprs.is_empty() {
+            return Err(DataFusionError::Plan(format!(
+                "Invalid or wrong number of arguments passed to aggregate: 
'{}'",
+                name,
+            )));
+        }
+        create_aggregate_expr(fun, distinct, &coerced_phy_exprs, input_schema, 
name)
+    }
+
+    // Returns the coerced exprs for each `input_exprs`.
+    // Get the coerced data type from `aggregate_rule::coerce_types` and add 
`try_cast` if the
+    // data type of `input_exprs` need to be coerced.
+    fn coerce_exprs_for_test(
+        agg_fun: &AggregateFunction,
+        input_exprs: &[Arc<dyn PhysicalExpr>],
+        schema: &Schema,
+        signature: &Signature,
+    ) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
+        if input_exprs.is_empty() {
+            return Ok(vec![]);
+        }
+        let input_types = input_exprs
+            .iter()
+            .map(|e| e.data_type(schema))
+            .collect::<Result<Vec<_>>>()?;
+
+        // get the coerced data types
+        let coerced_types =
+            type_coercion::aggregates::coerce_types(agg_fun, &input_types, 
signature)?;
+
+        // try cast if need
+        input_exprs
+            .iter()
+            .zip(coerced_types.into_iter())
+            .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, 
coerced_type))
+            .collect::<Result<Vec<_>>>()
+    }
 }
diff --git a/datafusion/physical-expr/src/aggregate/coercion_rule.rs 
b/datafusion/physical-expr/src/aggregate/coercion_rule.rs
deleted file mode 100644
index a8c68390a..000000000
--- a/datafusion/physical-expr/src/aggregate/coercion_rule.rs
+++ /dev/null
@@ -1,54 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Define coercion rules for Aggregate function.
-
-use crate::expressions::try_cast;
-use crate::PhysicalExpr;
-use arrow::datatypes::Schema;
-use datafusion_common::Result;
-use datafusion_expr::{type_coercion, AggregateFunction, Signature};
-use std::sync::Arc;
-
-/// Returns the coerced exprs for each `input_exprs`.
-/// Get the coerced data type from `aggregate_rule::coerce_types` and add 
`try_cast` if the
-/// data type of `input_exprs` need to be coerced.
-pub fn coerce_exprs(
-    agg_fun: &AggregateFunction,
-    input_exprs: &[Arc<dyn PhysicalExpr>],
-    schema: &Schema,
-    signature: &Signature,
-) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
-    if input_exprs.is_empty() {
-        return Ok(vec![]);
-    }
-    let input_types = input_exprs
-        .iter()
-        .map(|e| e.data_type(schema))
-        .collect::<Result<Vec<_>>>()?;
-
-    // get the coerced data types
-    let coerced_types =
-        type_coercion::aggregates::coerce_types(agg_fun, &input_types, 
signature)?;
-
-    // try cast if need
-    input_exprs
-        .iter()
-        .zip(coerced_types.into_iter())
-        .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, 
coerced_type))
-        .collect::<Result<Vec<_>>>()
-}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs 
b/datafusion/physical-expr/src/aggregate/mod.rs
index ec338eb68..f63746874 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -31,7 +31,6 @@ pub(crate) mod approx_percentile_cont_with_weight;
 pub(crate) mod array_agg;
 pub(crate) mod array_agg_distinct;
 pub(crate) mod average;
-pub(crate) mod coercion_rule;
 pub(crate) mod correlation;
 pub(crate) mod count;
 pub(crate) mod count_distinct;

Reply via email to