This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 053dc5a90 move type coercion of agg and agg_udaf to logical phase (#3768)
053dc5a90 is described below
commit 053dc5a90654bfe68dfa65f71040d3ba2dffb515
Author: Kun Liu <[email protected]>
AuthorDate: Wed Oct 12 05:35:38 2022 +0800
move type coercion of agg and agg_udaf to logical phase (#3768)
---
datafusion/core/src/physical_plan/udaf.rs | 13 +-
datafusion/optimizer/src/type_coercion.rs | 186 +++++++++++++++++++-
datafusion/physical-expr/src/aggregate/build_in.rs | 188 ++++++++++++---------
.../physical-expr/src/aggregate/coercion_rule.rs | 54 ------
datafusion/physical-expr/src/aggregate/mod.rs | 1 -
5 files changed, 297 insertions(+), 145 deletions(-)
diff --git a/datafusion/core/src/physical_plan/udaf.rs
b/datafusion/core/src/physical_plan/udaf.rs
index e017bb5ad..659ff560d 100644
--- a/datafusion/core/src/physical_plan/udaf.rs
+++ b/datafusion/core/src/physical_plan/udaf.rs
@@ -26,9 +26,7 @@ use arrow::{
datatypes::{DataType, Schema},
};
-use super::{
- expressions::format_state_name, type_coercion::coerce, Accumulator,
AggregateExpr,
-};
+use super::{expressions::format_state_name, Accumulator, AggregateExpr};
use crate::error::Result;
use crate::physical_plan::PhysicalExpr;
pub use datafusion_expr::AggregateUDF;
@@ -43,18 +41,15 @@ pub fn create_aggregate_expr(
input_schema: &Schema,
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
- // coerce
- let coerced_phy_exprs = coerce(input_phy_exprs, input_schema,
&fun.signature)?;
-
- let coerced_exprs_types = coerced_phy_exprs
+ let input_exprs_types = input_phy_exprs
.iter()
.map(|arg| arg.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
Ok(Arc::new(AggregateFunctionExpr {
fun: fun.clone(),
- args: coerced_phy_exprs.clone(),
- data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
+ args: input_phy_exprs.to_vec(),
+ data_type: (fun.return_type)(&input_exprs_types)?.as_ref().clone(),
name: name.into(),
}))
}
diff --git a/datafusion/optimizer/src/type_coercion.rs
b/datafusion/optimizer/src/type_coercion.rs
index 5438632ba..89d5d660b 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
- function, is_false, is_not_false, is_not_true, is_not_unknown, is_true,
is_unknown,
- Expr, LogicalPlan, Operator,
+ aggregate_function, function, is_false, is_not_false, is_not_true,
is_not_unknown,
+ is_true, is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan,
Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
@@ -407,6 +407,39 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
+ Expr::AggregateFunction {
+ fun,
+ args,
+ distinct,
+ filter,
+ } => {
+ let new_expr = coerce_agg_exprs_for_signature(
+ &fun,
+ &args,
+ &self.schema,
+ &aggregate_function::signature(&fun),
+ )?;
+ let expr = Expr::AggregateFunction {
+ fun,
+ args: new_expr,
+ distinct,
+ filter,
+ };
+ Ok(expr)
+ }
+ Expr::AggregateUDF { fun, args, filter } => {
+ let new_expr = coerce_arguments_for_signature(
+ args.as_slice(),
+ &self.schema,
+ &fun.signature,
+ )?;
+ let expr = Expr::AggregateUDF {
+ fun,
+ args: new_expr,
+ filter,
+ };
+ Ok(expr)
+ }
expr => Ok(expr),
}
}
@@ -448,6 +481,33 @@ fn coerce_arguments_for_signature(
.collect::<Result<Vec<_>>>()
}
+/// Coerces the arguments of a built-in aggregate function to the types its
+/// `signature` accepts.
+///
+/// The target types are computed by `type_coercion::aggregates::coerce_types`;
+/// each argument is then passed through `Expr::cast_to` with its target type
+/// (this uses `cast_to`, not a `try_cast`). An empty argument list is returned
+/// unchanged without consulting `signature`.
+///
+/// # Errors
+/// Fails if an argument's type cannot be computed against `schema`, if the
+/// argument types are not accepted by `signature`, or if `cast_to` rejects the
+/// conversion.
+fn coerce_agg_exprs_for_signature(
+    agg_fun: &AggregateFunction,
+    input_exprs: &[Expr],
+    schema: &DFSchema,
+    signature: &Signature,
+) -> Result<Vec<Expr>> {
+    if input_exprs.is_empty() {
+        return Ok(vec![]);
+    }
+    let current_types = input_exprs
+        .iter()
+        .map(|e| e.get_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let coerced_types =
+        type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?;
+
+    // `coerce_types` yields one target type per argument, so index `i` lines
+    // up with the i-th input expression.
+    input_exprs
+        .iter()
+        .enumerate()
+        .map(|(i, expr)| expr.clone().cast_to(&coerced_types[i], schema))
+        .collect::<Result<Vec<_>>>()
+}
+
#[cfg(test)]
mod test {
use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
@@ -456,7 +516,9 @@ mod test {
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{
- cast, col, concat, concat_ws, is_true, BuiltinScalarFunction,
ColumnarValue,
+ cast, col, concat, concat_ws, create_udaf, is_true,
+ AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF,
+ BuiltinScalarFunction, ColumnarValue, StateTypeFunction,
};
use datafusion_expr::{
lit,
@@ -464,6 +526,7 @@ mod test {
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation,
ScalarUDF,
Signature, Volatility,
};
+ use datafusion_physical_expr::expressions::AvgAccumulator;
use std::sync::Arc;
#[test]
@@ -596,6 +659,123 @@ mod test {
Ok(())
}
+    #[test]
+    fn agg_udaf() -> Result<()> {
+        // A UDAF declared with a Float64 input type, called with an Int64
+        // literal: the type-coercion rule is expected to insert a CAST.
+        let empty = empty();
+        let my_avg = create_udaf(
+            "MY_AVG",
+            DataType::Float64,
+            Arc::new(DataType::Float64),
+            Volatility::Immutable,
+            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+            Arc::new(vec![DataType::UInt64, DataType::Float64]),
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            args: vec![lit(10i64)],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_udaf_invalid_input() -> Result<()> {
+        // A UDAF whose signature accepts exactly one Float64 argument, called
+        // with a Utf8 literal: coercion must fail with a Plan error.
+        let empty = empty();
+        let return_type: ReturnTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
+        let state_type: StateTypeFunction =
+            Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
+        let accumulator: AccumulatorFunctionImplementation =
+            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
+        let my_avg = AggregateUDF::new(
+            "MY_AVG",
+            &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
+            &return_type,
+            &accumulator,
+            &state_type,
+        );
+        let udaf = Expr::AggregateUDF {
+            fun: Arc::new(my_avg),
+            args: vec![lit("10")],
+            filter: None,
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config);
+        assert!(plan.is_err());
+        assert_eq!(
+            "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.\")",
+            &format!("{:?}", plan.err().unwrap())
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_function_case() -> Result<()> {
+        // AVG over an Int64 literal: the argument is expected to be left as-is.
+        let empty = empty();
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![lit(12i64)],
+            distinct: false,
+            filter: None,
+        };
+        let plan =
+            LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: AVG(Int64(12))\n EmptyRelation",
+            &format!("{:?}", plan)
+        );
+
+        // AVG over an Int32 column `a`: also expected to be left as-is.
+        let empty = empty_with_type(DataType::Int32);
+        let fun: AggregateFunction = AggregateFunction::Avg;
+        let agg_expr = Expr::AggregateFunction {
+            fun,
+            args: vec![col("a")],
+            distinct: false,
+            filter: None,
+        };
+        let plan =
+            LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty, None)?);
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: AVG(a)\n EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn agg_function_invalid_input() -> Result<()> {
+        // AVG over a Utf8 literal is rejected; the error surfaces from
+        // `Projection::try_new` itself, before any optimizer rule runs.
+        let agg_expr = Expr::AggregateFunction {
+            fun: AggregateFunction::Avg,
+            args: vec![lit("1")],
+            distinct: false,
+            filter: None,
+        };
+        let result = Projection::try_new(vec![agg_expr], empty(), None);
+        assert!(result.is_err());
+        let err = result.err().unwrap();
+        assert_eq!(
+            "Plan(\"The function Avg does not support inputs of type Utf8.\")",
+            &format!("{:?}", err)
+        );
+        Ok(())
+    }
+
#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs
b/datafusion/physical-expr/src/aggregate/build_in.rs
index e3154488c..597b51575 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -26,11 +26,9 @@
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min,
([f32]) -> f32, ([f64]) -> f64.
-use crate::aggregate::coercion_rule::coerce_exprs;
use crate::{expressions, AggregateExpr, PhysicalExpr};
use arrow::datatypes::Schema;
use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::aggregate_function;
use datafusion_expr::aggregate_function::return_type;
pub use datafusion_expr::AggregateFunction;
use std::sync::Arc;
@@ -45,89 +43,72 @@ pub fn create_aggregate_expr(
name: impl Into<String>,
) -> Result<Arc<dyn AggregateExpr>> {
let name = name.into();
- // get the coerced phy exprs if some expr need to be wrapped with the try
cast.
- let coerced_phy_exprs = coerce_exprs(
- fun,
- input_phy_exprs,
- input_schema,
- &aggregate_function::signature(fun),
- )?;
- if coerced_phy_exprs.is_empty() {
- return Err(DataFusionError::Plan(format!(
- "Invalid or wrong number of arguments passed to aggregate: '{}'",
- name,
- )));
- }
- let coerced_exprs_types = coerced_phy_exprs
- .iter()
- .map(|e| e.data_type(input_schema))
- .collect::<Result<Vec<_>>>()?;
-
// get the result data type for this aggregate function
let input_phy_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
let return_type = return_type(fun, &input_phy_types)?;
+ let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
(AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Count, true) =>
Arc::new(expressions::DistinctCount::new(
- coerced_exprs_types,
- coerced_phy_exprs,
+ input_phy_types,
+ input_phy_exprs,
name,
return_type,
)),
(AggregateFunction::Grouping, _) =>
Arc::new(expressions::Grouping::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Sum, true) =>
Arc::new(expressions::DistinctSum::new(
- vec![coerced_phy_exprs[0].clone()],
+ vec![input_phy_exprs[0].clone()],
name,
return_type,
)),
(AggregateFunction::ApproxDistinct, _) => {
Arc::new(expressions::ApproxDistinct::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
- coerced_exprs_types[0].clone(),
+ input_phy_types[0].clone(),
))
}
(AggregateFunction::ArrayAgg, false) =>
Arc::new(expressions::ArrayAgg::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
- coerced_exprs_types[0].clone(),
+ input_phy_types[0].clone(),
)),
(AggregateFunction::ArrayAgg, true) => {
Arc::new(expressions::DistinctArrayAgg::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
- coerced_exprs_types[0].clone(),
+ input_phy_types[0].clone(),
))
}
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
(AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
@@ -137,7 +118,7 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::Variance, false) =>
Arc::new(expressions::Variance::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
@@ -146,21 +127,17 @@ pub fn create_aggregate_expr(
"VAR(DISTINCT) aggregations are not available".to_string(),
));
}
- (AggregateFunction::VariancePop, false) => {
- Arc::new(expressions::VariancePop::new(
- coerced_phy_exprs[0].clone(),
- name,
- return_type,
- ))
- }
+ (AggregateFunction::VariancePop, false) => Arc::new(
+ expressions::VariancePop::new(input_phy_exprs[0].clone(), name,
return_type),
+ ),
(AggregateFunction::VariancePop, true) => {
return Err(DataFusionError::NotImplemented(
"VAR_POP(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::Covariance, false) =>
Arc::new(expressions::Covariance::new(
- coerced_phy_exprs[0].clone(),
- coerced_phy_exprs[1].clone(),
+ input_phy_exprs[0].clone(),
+ input_phy_exprs[1].clone(),
name,
return_type,
)),
@@ -171,8 +148,8 @@ pub fn create_aggregate_expr(
}
(AggregateFunction::CovariancePop, false) => {
Arc::new(expressions::CovariancePop::new(
- coerced_phy_exprs[0].clone(),
- coerced_phy_exprs[1].clone(),
+ input_phy_exprs[0].clone(),
+ input_phy_exprs[1].clone(),
name,
return_type,
))
@@ -183,7 +160,7 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::Stddev, false) =>
Arc::new(expressions::Stddev::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
@@ -193,7 +170,7 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::StddevPop, false) =>
Arc::new(expressions::StddevPop::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
@@ -204,8 +181,8 @@ pub fn create_aggregate_expr(
}
(AggregateFunction::Correlation, false) => {
Arc::new(expressions::Correlation::new(
- coerced_phy_exprs[0].clone(),
- coerced_phy_exprs[1].clone(),
+ input_phy_exprs[0].clone(),
+ input_phy_exprs[1].clone(),
name,
return_type,
))
@@ -216,17 +193,17 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::ApproxPercentileCont, false) => {
- if coerced_phy_exprs.len() == 2 {
+ if input_phy_exprs.len() == 2 {
Arc::new(expressions::ApproxPercentileCont::new(
// Pass in the desired percentile expr
- coerced_phy_exprs,
+ input_phy_exprs,
name,
return_type,
)?)
} else {
Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
// Pass in the desired percentile expr
- coerced_phy_exprs,
+ input_phy_exprs,
name,
return_type,
)?)
@@ -241,7 +218,7 @@ pub fn create_aggregate_expr(
(AggregateFunction::ApproxPercentileContWithWeight, false) => {
Arc::new(expressions::ApproxPercentileContWithWeight::new(
// Pass in the desired percentile expr
- coerced_phy_exprs,
+ input_phy_exprs,
name,
return_type,
)?)
@@ -254,7 +231,7 @@ pub fn create_aggregate_expr(
}
(AggregateFunction::ApproxMedian, false) => {
Arc::new(expressions::ApproxMedian::try_new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)?)
@@ -265,7 +242,7 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::Median, false) =>
Arc::new(expressions::Median::new(
- coerced_phy_exprs[0].clone(),
+ input_phy_exprs[0].clone(),
name,
return_type,
)),
@@ -281,13 +258,14 @@ pub fn create_aggregate_expr(
mod tests {
use super::*;
use crate::expressions::{
- ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg,
Correlation,
- Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev,
Sum,
- Variance,
+ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont,
ArrayAgg, Avg,
+ Correlation, Count, Covariance, DistinctArrayAgg, DistinctCount, Max,
Min,
+ Stddev, Sum, Variance,
};
use arrow::datatypes::{DataType, Field};
use datafusion_common::ScalarValue;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
+ use datafusion_expr::{aggregate_function, type_coercion, Signature};
#[test]
fn test_count_arragg_approx_expr() -> Result<()> {
@@ -311,7 +289,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -344,9 +322,9 @@ mod tests {
DataType::List(Box::new(Field::new(
"item",
data_type.clone(),
- true
+ true,
))),
- false
+ false,
),
result_agg_phy_exprs.field().unwrap()
);
@@ -354,7 +332,7 @@ mod tests {
_ => {}
};
- let result_distinct = create_aggregate_expr(
+ let result_distinct = create_physical_agg_expr_for_test(
&fun,
true,
&input_phy_exprs[0..1],
@@ -387,9 +365,9 @@ mod tests {
DataType::List(Box::new(Field::new(
"item",
data_type.clone(),
- true
+ true,
))),
- false
+ false,
),
result_agg_phy_exprs.field().unwrap()
);
@@ -412,7 +390,7 @@ mod tests {
),
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&AggregateFunction::ApproxPercentileCont,
false,
&input_phy_exprs[..],
@@ -441,7 +419,7 @@ mod tests {
),
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
];
- let err = create_aggregate_expr(
+ let err = create_physical_agg_expr_for_test(
&AggregateFunction::ApproxPercentileCont,
false,
&input_phy_exprs[..],
@@ -472,7 +450,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -521,7 +499,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -583,7 +561,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -621,7 +599,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -659,7 +637,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -697,7 +675,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -744,7 +722,7 @@ mod tests {
.unwrap(),
),
];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..2],
@@ -791,7 +769,7 @@ mod tests {
.unwrap(),
),
];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..2],
@@ -838,7 +816,7 @@ mod tests {
.unwrap(),
),
];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..2],
@@ -876,7 +854,7 @@ mod tests {
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
)];
- let result_agg_phy_exprs = create_aggregate_expr(
+ let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
@@ -1065,4 +1043,58 @@ mod tests {
let observed = return_type(&AggregateFunction::Stddev,
&[DataType::Utf8]);
assert!(observed.is_err());
}
+
+    /// Test helper: coerces `input_phy_exprs` to the signature of `fun`
+    /// (via `coerce_exprs_for_test`) and then builds the aggregate with
+    /// `create_aggregate_expr`, which does not coerce its inputs itself.
+    ///
+    /// Returns a `Plan` error when coercion yields no arguments, e.g. for an
+    /// empty `input_phy_exprs`.
+    fn create_physical_agg_expr_for_test(
+        fun: &AggregateFunction,
+        distinct: bool,
+        input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+        input_schema: &Schema,
+        name: impl Into<String>,
+    ) -> Result<Arc<dyn AggregateExpr>> {
+        let name = name.into();
+        let coerced_phy_exprs = coerce_exprs_for_test(
+            fun,
+            input_phy_exprs,
+            input_schema,
+            &aggregate_function::signature(fun),
+        )?;
+        if coerced_phy_exprs.is_empty() {
+            return Err(DataFusionError::Plan(format!(
+                "Invalid or wrong number of arguments passed to aggregate: '{}'",
+                name,
+            )));
+        }
+        create_aggregate_expr(fun, distinct, &coerced_phy_exprs, input_schema, name)
+    }
+
+    /// Test helper: computes the target type of each of `input_exprs` with
+    /// `type_coercion::aggregates::coerce_types` for `agg_fun`/`signature`,
+    /// then passes each input through `try_cast` with its target type.
+    /// An empty `input_exprs` is returned unchanged without consulting
+    /// `signature`.
+    fn coerce_exprs_for_test(
+        agg_fun: &AggregateFunction,
+        input_exprs: &[Arc<dyn PhysicalExpr>],
+        schema: &Schema,
+        signature: &Signature,
+    ) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
+        if input_exprs.is_empty() {
+            return Ok(vec![]);
+        }
+        let input_types = input_exprs
+            .iter()
+            .map(|e| e.data_type(schema))
+            .collect::<Result<Vec<_>>>()?;
+
+        // get the coerced data types
+        let coerced_types =
+            type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?;
+
+        // pair each input with its coerced type and apply `try_cast`
+        input_exprs
+            .iter()
+            .zip(coerced_types)
+            .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type))
+            .collect::<Result<Vec<_>>>()
+    }
}
diff --git a/datafusion/physical-expr/src/aggregate/coercion_rule.rs
b/datafusion/physical-expr/src/aggregate/coercion_rule.rs
deleted file mode 100644
index a8c68390a..000000000
--- a/datafusion/physical-expr/src/aggregate/coercion_rule.rs
+++ /dev/null
@@ -1,54 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Define coercion rules for Aggregate function.
-
-use crate::expressions::try_cast;
-use crate::PhysicalExpr;
-use arrow::datatypes::Schema;
-use datafusion_common::Result;
-use datafusion_expr::{type_coercion, AggregateFunction, Signature};
-use std::sync::Arc;
-
-/// Returns the coerced exprs for each `input_exprs`.
-/// Get the coerced data type from `aggregate_rule::coerce_types` and add
`try_cast` if the
-/// data type of `input_exprs` need to be coerced.
-pub fn coerce_exprs(
- agg_fun: &AggregateFunction,
- input_exprs: &[Arc<dyn PhysicalExpr>],
- schema: &Schema,
- signature: &Signature,
-) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
- if input_exprs.is_empty() {
- return Ok(vec![]);
- }
- let input_types = input_exprs
- .iter()
- .map(|e| e.data_type(schema))
- .collect::<Result<Vec<_>>>()?;
-
- // get the coerced data types
- let coerced_types =
- type_coercion::aggregates::coerce_types(agg_fun, &input_types,
signature)?;
-
- // try cast if need
- input_exprs
- .iter()
- .zip(coerced_types.into_iter())
- .map(|(expr, coerced_type)| try_cast(expr.clone(), schema,
coerced_type))
- .collect::<Result<Vec<_>>>()
-}
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs
b/datafusion/physical-expr/src/aggregate/mod.rs
index ec338eb68..f63746874 100644
--- a/datafusion/physical-expr/src/aggregate/mod.rs
+++ b/datafusion/physical-expr/src/aggregate/mod.rs
@@ -31,7 +31,6 @@ pub(crate) mod approx_percentile_cont_with_weight;
pub(crate) mod array_agg;
pub(crate) mod array_agg_distinct;
pub(crate) mod average;
-pub(crate) mod coercion_rule;
pub(crate) mod correlation;
pub(crate) mod count;
pub(crate) mod count_distinct;