This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a154884545 [MINOR]: Add size check for aggregate (#8813)
a154884545 is described below

commit a154884545cfdeb1a6c20872b3882a5624cd1119
Author: Mustafa Akur <[email protected]>
AuthorDate: Thu Jan 11 08:57:36 2024 +0300

    [MINOR]: Add size check for aggregate (#8813)
    
    * Add size check for aggregate
    
    * Fix failing tests
    
    * Minor changes
---
 .../combine_partial_final_agg.rs                   |  6 +++--
 .../limited_distinct_aggregation.rs                | 27 +++++++++++-----------
 datafusion/physical-plan/src/aggregates/mod.rs     | 10 ++++++--
 datafusion/physical-plan/src/limit.rs              |  2 +-
 4 files changed, 27 insertions(+), 18 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index 7359a64630..61eb2381c6 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -269,12 +269,13 @@ mod tests {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     ) -> Arc<dyn ExecutionPlan> {
         let schema = input.schema();
+        let n_aggr = aggr_expr.len();
         Arc::new(
             AggregateExec::try_new(
                 AggregateMode::Partial,
                 group_by,
                 aggr_expr,
-                vec![],
+                vec![None; n_aggr],
                 input,
                 schema,
             )
@@ -288,12 +289,13 @@ mod tests {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     ) -> Arc<dyn ExecutionPlan> {
         let schema = input.schema();
+        let n_aggr = aggr_expr.len();
         Arc::new(
             AggregateExec::try_new(
                 AggregateMode::Final,
                 group_by,
                 aggr_expr,
-                vec![],
+                vec![None; n_aggr],
                 input,
                 schema,
             )
diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
index 540f9a6a13..9855247151 100644
--- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
@@ -305,7 +305,7 @@ mod tests {
             AggregateMode::Partial,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
@@ -313,7 +313,7 @@ mod tests {
             AggregateMode::Final,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],                /* aggr_expr */
-            vec![None],            /* filter_expr */
+            vec![],                /* filter_expr */
             Arc::new(partial_agg), /* input */
             schema.clone(),        /* input_schema */
         )?;
@@ -355,7 +355,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
@@ -396,7 +396,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
@@ -437,7 +437,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
@@ -445,7 +445,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],                 /* aggr_expr */
-            vec![None],             /* filter_expr */
+            vec![],                 /* filter_expr */
             Arc::new(group_by_agg), /* input */
             schema.clone(),         /* input_schema */
         )?;
@@ -487,7 +487,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec![]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
@@ -549,13 +549,14 @@ mod tests {
             cast(expressions::lit(1u32), &schema, DataType::Int32)?,
             &schema,
         )?);
+        let agg = TestAggregate::new_count_star();
         let single_agg = AggregateExec::try_new(
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
-            vec![],            /* aggr_expr */
-            vec![filter_expr], /* filter_expr */
-            source,            /* input */
-            schema.clone(),    /* input_schema */
+            vec![agg.count_expr()], /* aggr_expr */
+            vec![filter_expr],      /* filter_expr */
+            source,                 /* input */
+            schema.clone(),         /* input_schema */
         )?;
         let limit_exec = LocalLimitExec::new(
             Arc::new(single_agg),
@@ -565,7 +566,7 @@ mod tests {
         // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out
         let expected = [
             "LocalLimitExec: fetch=10",
-            "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]",
+            "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]",
             "MemoryExec: partitions=1, partition_sizes=[1]",
         ];
         let plan: Arc<dyn ExecutionPlan> = Arc::new(limit_exec);
@@ -588,7 +589,7 @@ mod tests {
             AggregateMode::Single,
             build_group_by(&schema.clone(), vec!["a".to_string()]),
             vec![],         /* aggr_expr */
-            vec![None],     /* filter_expr */
+            vec![],         /* filter_expr */
             source,         /* input */
             schema.clone(), /* input_schema */
         )?;
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 0b94dd01cf..4f37be7263 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -37,7 +37,7 @@ use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::stats::Precision;
-use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
+use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_execution::TaskContext;
 use datafusion_expr::Accumulator;
 use datafusion_physical_expr::{
@@ -321,6 +321,11 @@ impl AggregateExec {
         input_schema: SchemaRef,
         schema: SchemaRef,
     ) -> Result<Self> {
+        // Make sure arguments are consistent in size
+        if aggr_expr.len() != filter_expr.len() {
+            return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
+        }
+
         let input_eq_properties = input.equivalence_properties();
         // Get GROUP BY expressions:
         let groupby_exprs = group_by.input_exprs();
@@ -1795,11 +1800,12 @@ mod tests {
             (1, groups_some.clone(), aggregates_v1),
             (2, groups_some, aggregates_v2),
         ] {
+            let n_aggr = aggregates.len();
             let partial_aggregate = Arc::new(AggregateExec::try_new(
                 AggregateMode::Partial,
                 groups,
                 aggregates,
-                vec![None; 3],
+                vec![None; n_aggr],
                 input.clone(),
                 input_schema.clone(),
             )?);
diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index 37e8ffd761..c31d5f62c7 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -877,7 +877,7 @@ mod tests {
             AggregateMode::Final,
             build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
             vec![],
-            vec![None],
+            vec![],
             csv.clone(),
             csv.schema().clone(),
         )?;

Reply via email to