This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a154884545 [MINOR]: Add size check for aggregate (#8813)
a154884545 is described below
commit a154884545cfdeb1a6c20872b3882a5624cd1119
Author: Mustafa Akur <[email protected]>
AuthorDate: Thu Jan 11 08:57:36 2024 +0300
[MINOR]: Add size check for aggregate (#8813)
* Add size check for aggregate
* Fix failing tests
* Minor changes
---
.../combine_partial_final_agg.rs | 6 +++--
.../limited_distinct_aggregation.rs | 27 +++++++++++-----------
datafusion/physical-plan/src/aggregates/mod.rs | 10 ++++++--
datafusion/physical-plan/src/limit.rs | 2 +-
4 files changed, 27 insertions(+), 18 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
index 7359a64630..61eb2381c6 100644
--- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
+++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs
@@ -269,12 +269,13 @@ mod tests {
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
) -> Arc<dyn ExecutionPlan> {
let schema = input.schema();
+ let n_aggr = aggr_expr.len();
Arc::new(
AggregateExec::try_new(
AggregateMode::Partial,
group_by,
aggr_expr,
- vec![],
+ vec![None; n_aggr],
input,
schema,
)
@@ -288,12 +289,13 @@ mod tests {
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
) -> Arc<dyn ExecutionPlan> {
let schema = input.schema();
+ let n_aggr = aggr_expr.len();
Arc::new(
AggregateExec::try_new(
AggregateMode::Final,
group_by,
aggr_expr,
- vec![],
+ vec![None; n_aggr],
input,
schema,
)
diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
index 540f9a6a13..9855247151 100644
--- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
+++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs
@@ -305,7 +305,7 @@ mod tests {
AggregateMode::Partial,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
@@ -313,7 +313,7 @@ mod tests {
AggregateMode::Final,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
Arc::new(partial_agg), /* input */
schema.clone(), /* input_schema */
)?;
@@ -355,7 +355,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
@@ -396,7 +396,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
@@ -437,7 +437,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
@@ -445,7 +445,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
Arc::new(group_by_agg), /* input */
schema.clone(), /* input_schema */
)?;
@@ -487,7 +487,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec![]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
@@ -549,13 +549,14 @@ mod tests {
cast(expressions::lit(1u32), &schema, DataType::Int32)?,
&schema,
)?);
+ let agg = TestAggregate::new_count_star();
let single_agg = AggregateExec::try_new(
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
- vec![], /* aggr_expr */
- vec![filter_expr], /* filter_expr */
- source, /* input */
- schema.clone(), /* input_schema */
+ vec![agg.count_expr()], /* aggr_expr */
+ vec![filter_expr], /* filter_expr */
+ source, /* input */
+ schema.clone(), /* input_schema */
)?;
let limit_exec = LocalLimitExec::new(
Arc::new(single_agg),
@@ -565,7 +566,7 @@ mod tests {
// TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out
let expected = [
"LocalLimitExec: fetch=10",
- "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[]",
+ "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]",
"MemoryExec: partitions=1, partition_sizes=[1]",
];
let plan: Arc<dyn ExecutionPlan> = Arc::new(limit_exec);
@@ -588,7 +589,7 @@ mod tests {
AggregateMode::Single,
build_group_by(&schema.clone(), vec!["a".to_string()]),
vec![], /* aggr_expr */
- vec![None], /* filter_expr */
+ vec![], /* filter_expr */
source, /* input */
schema.clone(), /* input_schema */
)?;
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 0b94dd01cf..4f37be7263 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -37,7 +37,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::stats::Precision;
-use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
+use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_physical_expr::{
@@ -321,6 +321,11 @@ impl AggregateExec {
input_schema: SchemaRef,
schema: SchemaRef,
) -> Result<Self> {
+ // Make sure arguments are consistent in size
+ if aggr_expr.len() != filter_expr.len() {
+ return internal_err!("Inconsistent aggregate expr: {:?} and filter expr: {:?} for AggregateExec, their size should match", aggr_expr, filter_expr);
+ }
+
let input_eq_properties = input.equivalence_properties();
// Get GROUP BY expressions:
let groupby_exprs = group_by.input_exprs();
@@ -1795,11 +1800,12 @@ mod tests {
(1, groups_some.clone(), aggregates_v1),
(2, groups_some, aggregates_v2),
] {
+ let n_aggr = aggregates.len();
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups,
aggregates,
- vec![None; 3],
+ vec![None; n_aggr],
input.clone(),
input_schema.clone(),
)?);
diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index 37e8ffd761..c31d5f62c7 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -877,7 +877,7 @@ mod tests {
AggregateMode::Final,
build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
vec![],
- vec![None],
+ vec![],
csv.clone(),
csv.schema().clone(),
)?;