alamb commented on a change in pull request #1387:
URL: https://github.com/apache/arrow-datafusion/pull/1387#discussion_r764404370
##########
File path: datafusion/src/physical_plan/aggregates.rs
##########
@@ -262,6 +266,199 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
mod tests {
use super::*;
use crate::error::Result;
+ use crate::physical_plan::expressions::{
+ ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum,
+ };
+
+    /// Checks that `create_aggregate_expr` builds the expected physical expression
+    /// for COUNT, ARRAY_AGG and APPROX_DISTINCT over a range of input types:
+    /// the concrete expression type, the output name, and the output field.
+    #[test]
+    fn test_count_arragg_approx_expr() -> Result<()> {
+        let funcs = [
+            AggregateFunction::Count,
+            AggregateFunction::ArrayAgg,
+            AggregateFunction::ApproxDistinct,
+        ];
+        let data_types = [
+            DataType::UInt32,
+            DataType::Int32,
+            DataType::Float32,
+            DataType::Float64,
+            DataType::Decimal(10, 2),
+            DataType::Utf8,
+        ];
+        for fun in &funcs {
+            for data_type in &data_types {
+                let input_schema =
+                    Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+                let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
+                    expressions::Column::new_with_schema("c1", &input_schema)?,
+                )];
+                let result_agg_phy_exprs = create_aggregate_expr(
+                    fun,
+                    false,
+                    &input_phy_exprs,
+                    &input_schema,
+                    "c1",
+                )?;
+                let field = result_agg_phy_exprs.field()?;
+                match fun {
+                    AggregateFunction::Count => {
+                        assert!(result_agg_phy_exprs.as_any().is::<Count>());
+                        assert_eq!(Field::new("c1", DataType::UInt64, true), field);
+                    }
+                    AggregateFunction::ApproxDistinct => {
+                        assert!(result_agg_phy_exprs.as_any().is::<ApproxDistinct>());
+                        assert_eq!(Field::new("c1", DataType::UInt64, false), field);
+                    }
+                    AggregateFunction::ArrayAgg => {
+                        assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>());
+                        // ARRAY_AGG wraps the input type in a nullable "item" list field.
+                        let item = Field::new("item", data_type.clone(), true);
+                        assert_eq!(
+                            Field::new("c1", DataType::List(Box::new(item)), false),
+                            field
+                        );
+                    }
+                    // Fail loudly if `funcs` gains a variant without an arm here,
+                    // instead of silently checking nothing for it.
+                    other => unreachable!("unexpected aggregate function {:?}", other),
+                }
+                // Every aggregate keeps the requested output name.
+                assert_eq!("c1", result_agg_phy_exprs.name());
+            }
+        }
+        Ok(())
+    }
+
+    /// Checks that `create_aggregate_expr` builds MIN/MAX physical expressions
+    /// whose output field has the input's data type, name "c1" and is nullable.
+    #[test]
+    fn test_min_max_expr() -> Result<()> {
+        let funcs = [AggregateFunction::Min, AggregateFunction::Max];
+        let data_types = [
+            DataType::UInt32,
+            DataType::Int32,
+            DataType::Float32,
+            DataType::Float64,
+            DataType::Decimal(10, 2),
+            DataType::Utf8,
+        ];
+        for fun in &funcs {
+            for data_type in &data_types {
+                let input_schema =
+                    Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
+                let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
+                    expressions::Column::new_with_schema("c1", &input_schema)?,
+                )];
+                let result_agg_phy_exprs = create_aggregate_expr(
+                    fun,
+                    false,
+                    &input_phy_exprs,
+                    &input_schema,
+                    "c1",
+                )?;
+                // Only the concrete expression type differs between MIN and MAX.
+                match fun {
+                    AggregateFunction::Min => {
+                        assert!(result_agg_phy_exprs.as_any().is::<Min>());
+                    }
+                    AggregateFunction::Max => {
+                        assert!(result_agg_phy_exprs.as_any().is::<Max>());
+                    }
+                    // Fail loudly if `funcs` gains a variant without an arm here.
+                    other => unreachable!("unexpected aggregate function {:?}", other),
+                }
+                assert_eq!("c1", result_agg_phy_exprs.name());
+                assert_eq!(
+                    Field::new("c1", data_type.clone(), true),
+                    result_agg_phy_exprs.field()?
+                );
+            }
+        }
+        Ok(())
+    }
+
+ #[test]
+ fn test_sum_avg_expr() -> Result<()> {
+ let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
+ let data_types = vec![
+ DataType::UInt32,
+ DataType::UInt64,
+ DataType::Int32,
+ DataType::Int64,
+ DataType::Float32,
+ DataType::Float64,
+ ];
+ for fun in funcs {
+ for data_type in &data_types {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(),
true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ )];
+ let result_agg_phy_exprs = create_aggregate_expr(
+ &fun,
+ false,
+ &input_phy_exprs[0..1],
+ &input_schema,
+ "c1",
+ )?;
+ match fun {
+ AggregateFunction::Sum => {
+ assert!(result_agg_phy_exprs.as_any().is::<Sum>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ let mut expect_type = data_type.clone();
+ if matches!(
+ data_type,
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ ) {
+ expect_type = DataType::UInt64;
+ } else if matches!(
+ data_type,
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ ) {
+ expect_type = DataType::Int64;
+ } else if matches!(
+ data_type,
+ DataType::Float32 | DataType::Float64
+ ) {
+ expect_type = data_type.clone();
+ }
Review comment:
https://github.com/apache/arrow-datafusion/pull/1416 <-- PR with proposed cleanup
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]