alamb commented on a change in pull request #1387:
URL: https://github.com/apache/arrow-datafusion/pull/1387#discussion_r764400376
##########
File path: datafusion/src/physical_plan/aggregates.rs
##########
@@ -262,6 +266,199 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
mod tests {
use super::*;
use crate::error::Result;
+ use crate::physical_plan::expressions::{
+ ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum,
+ };
+
+ #[test]
+ fn test_count_arragg_approx_expr() -> Result<()> {
+ let funcs = vec![
+ AggregateFunction::Count,
+ AggregateFunction::ArrayAgg,
+ AggregateFunction::ApproxDistinct,
+ ];
+ let data_types = vec![
+ DataType::UInt32,
+ DataType::Int32,
+ DataType::Float32,
+ DataType::Float64,
+ DataType::Decimal(10, 2),
+ DataType::Utf8,
+ ];
+ for fun in funcs {
+ for data_type in &data_types {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(),
true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ )];
+ let result_agg_phy_exprs = create_aggregate_expr(
+ &fun,
+ false,
+ &input_phy_exprs[0..1],
+ &input_schema,
+ "c1",
+ )?;
+ match fun {
+ AggregateFunction::Count => {
+ assert!(result_agg_phy_exprs.as_any().is::<Count>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", DataType::UInt64, true),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ AggregateFunction::ApproxDistinct => {
+
assert!(result_agg_phy_exprs.as_any().is::<ApproxDistinct>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", DataType::UInt64, false),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ AggregateFunction::ArrayAgg => {
+
assert!(result_agg_phy_exprs.as_any().is::<ArrayAgg>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new(
+ "c1",
+ DataType::List(Box::new(Field::new(
+ "item",
+ data_type.clone(),
+ true
+ ))),
+ false
+ ),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ _ => {}
+ };
+ }
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn test_min_max_expr() -> Result<()> {
+ let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
+ let data_types = vec![
+ DataType::UInt32,
+ DataType::Int32,
+ DataType::Float32,
+ DataType::Float64,
+ DataType::Decimal(10, 2),
+ DataType::Utf8,
+ ];
+ for fun in funcs {
+ for data_type in &data_types {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(),
true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ )];
+ let result_agg_phy_exprs = create_aggregate_expr(
+ &fun,
+ false,
+ &input_phy_exprs[0..1],
+ &input_schema,
+ "c1",
+ )?;
+ match fun {
+ AggregateFunction::Min => {
+ assert!(result_agg_phy_exprs.as_any().is::<Min>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", data_type.clone(), true),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ AggregateFunction::Max => {
+ assert!(result_agg_phy_exprs.as_any().is::<Max>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ assert_eq!(
+ Field::new("c1", data_type.clone(), true),
+ result_agg_phy_exprs.field().unwrap()
+ );
+ }
+ _ => {}
+ };
+ }
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn test_sum_avg_expr() -> Result<()> {
+ let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
+ let data_types = vec![
+ DataType::UInt32,
+ DataType::UInt64,
+ DataType::Int32,
+ DataType::Int64,
+ DataType::Float32,
+ DataType::Float64,
+ ];
+ for fun in funcs {
+ for data_type in &data_types {
+ let input_schema =
+ Schema::new(vec![Field::new("c1", data_type.clone(),
true)]);
+ let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> =
vec![Arc::new(
+ expressions::Column::new_with_schema("c1",
&input_schema).unwrap(),
+ )];
+ let result_agg_phy_exprs = create_aggregate_expr(
+ &fun,
+ false,
+ &input_phy_exprs[0..1],
+ &input_schema,
+ "c1",
+ )?;
+ match fun {
+ AggregateFunction::Sum => {
+ assert!(result_agg_phy_exprs.as_any().is::<Sum>());
+ assert_eq!("c1", result_agg_phy_exprs.name());
+ let mut expect_type = data_type.clone();
+ if matches!(
+ data_type,
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ ) {
+ expect_type = DataType::UInt64;
+ } else if matches!(
+ data_type,
+ DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64
+ ) {
+ expect_type = DataType::Int64;
+ } else if matches!(
+ data_type,
+ DataType::Float32 | DataType::Float64
+ ) {
+ expect_type = data_type.clone();
+ }
Review comment:
Just FYI, you can write this kind of logic more concisely with
something like this (untested and abbreviated):
```rust
let expect_type = match data_type {
    DataType::UInt8 | .... => DataType::UInt64,
    DataType::Int8 | .... => DataType::Int64,
    _ => data_type.clone(),
};
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]