This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d83b3b26e8 Chore: Move `aggregate statistics` optimizer test from core
to optimizer crate (#12783)
d83b3b26e8 is described below
commit d83b3b26e897b87fb63c6928311526c8a2fe1aee
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Oct 9 11:38:48 2024 +0800
Chore: Move `aggregate statistics` optimizer test from core to optimizer
crate (#12783)
* move test from core to optimizer crate
Signed-off-by: jayzhan211 <[email protected]>
* cleanup
Signed-off-by: jayzhan211 <[email protected]>
* upd
Signed-off-by: jayzhan211 <[email protected]>
* clippy
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion-cli/Cargo.lock | 6 +-
.../physical_optimizer/aggregate_statistics.rs | 325 ------------------
datafusion/core/tests/physical_optimizer/mod.rs | 1 -
datafusion/physical-optimizer/Cargo.toml | 6 +
.../physical-optimizer/src/aggregate_statistics.rs | 376 ++++++++++++++++++++-
.../physical-optimizer/src/topk_aggregation.rs | 12 +-
6 files changed, 388 insertions(+), 338 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 8a6ccacbb3..3e0f9e69e1 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1521,9 +1521,11 @@ dependencies = [
name = "datafusion-physical-optimizer"
version = "42.0.0"
dependencies = [
+ "arrow",
"arrow-schema",
"datafusion-common",
"datafusion-execution",
+ "datafusion-expr-common",
"datafusion-physical-expr",
"datafusion-physical-plan",
"itertools",
@@ -2879,9 +2881,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
-version = "1.0.86"
+version = "1.0.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
+checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a"
dependencies = [
"unicode-ident",
]
diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
deleted file mode 100644
index bbf4dcd2b7..0000000000
--- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
+++ /dev/null
@@ -1,325 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Tests for the physical optimizer
-
-use datafusion_common::config::ConfigOptions;
-use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
-use datafusion_physical_optimizer::PhysicalOptimizerRule;
-use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::ExecutionPlan;
-use std::sync::Arc;
-
-use datafusion::error::Result;
-use datafusion::logical_expr::Operator;
-use datafusion::prelude::SessionContext;
-use datafusion::test_util::TestAggregate;
-use datafusion_physical_plan::aggregates::PhysicalGroupBy;
-use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
-use datafusion_physical_plan::common;
-use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::memory::MemoryExec;
-
-use arrow::array::Int32Array;
-use arrow::datatypes::{DataType, Field, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion_common::cast::as_int64_array;
-use datafusion_physical_expr::expressions::{self, cast};
-use datafusion_physical_plan::aggregates::AggregateMode;
-
-/// Mock data using a MemoryExec which has an exact count statistic
-fn mock_data() -> Result<Arc<MemoryExec>> {
- let schema = Arc::new(Schema::new(vec![
- Field::new("a", DataType::Int32, true),
- Field::new("b", DataType::Int32, true),
- ]));
-
- let batch = RecordBatch::try_new(
- Arc::clone(&schema),
- vec![
- Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
- Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
- ],
- )?;
-
- Ok(Arc::new(MemoryExec::try_new(
- &[vec![batch]],
- Arc::clone(&schema),
- None,
- )?))
-}
-
-/// Checks that the count optimization was applied and we still get the right
result
-async fn assert_count_optim_success(
- plan: AggregateExec,
- agg: TestAggregate,
-) -> Result<()> {
- let session_ctx = SessionContext::new();
- let state = session_ctx.state();
- let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
-
- let optimized =
- AggregateStatistics::new().optimize(Arc::clone(&plan),
state.config_options())?;
-
- // A ProjectionExec is a sign that the count optimization was applied
- assert!(optimized.as_any().is::<ProjectionExec>());
-
- // run both the optimized and nonoptimized plan
- let optimized_result =
- common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
- let nonoptimized_result =
- common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
- assert_eq!(optimized_result.len(), nonoptimized_result.len());
-
- // and validate the results are the same and expected
- assert_eq!(optimized_result.len(), 1);
- check_batch(optimized_result.into_iter().next().unwrap(), &agg);
- // check the non optimized one too to ensure types and names remain the
same
- assert_eq!(nonoptimized_result.len(), 1);
- check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
-
- Ok(())
-}
-
-fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
- let schema = batch.schema();
- let fields = schema.fields();
- assert_eq!(fields.len(), 1);
-
- let field = &fields[0];
- assert_eq!(field.name(), agg.column_name());
- assert_eq!(field.data_type(), &DataType::Int64);
- // note that nullabiolity differs
-
- assert_eq!(
- as_int64_array(batch.column(0)).unwrap().values(),
- &[agg.expected_count()]
- );
-}
-
-#[tokio::test]
-async fn test_count_partial_direct_child() -> Result<()> {
- // basic test case with the aggregation applied on a source with exact
statistics
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_star();
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- source,
- Arc::clone(&schema),
- )?;
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(partial_agg),
- Arc::clone(&schema),
- )?;
-
- assert_count_optim_success(final_agg, agg).await?;
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
- // basic test case with the aggregation applied on a source with exact
statistics
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_column(&schema);
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- source,
- Arc::clone(&schema),
- )?;
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(partial_agg),
- Arc::clone(&schema),
- )?;
-
- assert_count_optim_success(final_agg, agg).await?;
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_indirect_child() -> Result<()> {
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_star();
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- source,
- Arc::clone(&schema),
- )?;
-
- // We introduce an intermediate optimization step between the partial and
final aggregtator
- let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(coalesce),
- Arc::clone(&schema),
- )?;
-
- assert_count_optim_success(final_agg, agg).await?;
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_column(&schema);
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- source,
- Arc::clone(&schema),
- )?;
-
- // We introduce an intermediate optimization step between the partial and
final aggregtator
- let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(coalesce),
- Arc::clone(&schema),
- )?;
-
- assert_count_optim_success(final_agg, agg).await?;
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_count_inexact_stat() -> Result<()> {
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_star();
-
- // adding a filter makes the statistics inexact
- let filter = Arc::new(FilterExec::try_new(
- expressions::binary(
- expressions::col("a", &schema)?,
- Operator::Gt,
- cast(expressions::lit(1u32), &schema, DataType::Int32)?,
- &schema,
- )?,
- source,
- )?);
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- filter,
- Arc::clone(&schema),
- )?;
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(partial_agg),
- Arc::clone(&schema),
- )?;
-
- let conf = ConfigOptions::new();
- let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg),
&conf)?;
-
- // check that the original ExecutionPlan was not replaced
- assert!(optimized.as_any().is::<AggregateExec>());
-
- Ok(())
-}
-
-#[tokio::test]
-async fn test_count_with_nulls_inexact_stat() -> Result<()> {
- let source = mock_data()?;
- let schema = source.schema();
- let agg = TestAggregate::new_count_column(&schema);
-
- // adding a filter makes the statistics inexact
- let filter = Arc::new(FilterExec::try_new(
- expressions::binary(
- expressions::col("a", &schema)?,
- Operator::Gt,
- cast(expressions::lit(1u32), &schema, DataType::Int32)?,
- &schema,
- )?,
- source,
- )?);
-
- let partial_agg = AggregateExec::try_new(
- AggregateMode::Partial,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- filter,
- Arc::clone(&schema),
- )?;
-
- let final_agg = AggregateExec::try_new(
- AggregateMode::Final,
- PhysicalGroupBy::default(),
- vec![agg.count_expr(&schema)],
- vec![None],
- Arc::new(partial_agg),
- Arc::clone(&schema),
- )?;
-
- let conf = ConfigOptions::new();
- let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg),
&conf)?;
-
- // check that the original ExecutionPlan was not replaced
- assert!(optimized.as_any().is::<AggregateExec>());
-
- Ok(())
-}
diff --git a/datafusion/core/tests/physical_optimizer/mod.rs
b/datafusion/core/tests/physical_optimizer/mod.rs
index 4ec981bf2a..c06783aa02 100644
--- a/datafusion/core/tests/physical_optimizer/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/mod.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-mod aggregate_statistics;
mod combine_partial_final_agg;
mod limit_pushdown;
mod limited_distinct_aggregation;
diff --git a/datafusion/physical-optimizer/Cargo.toml
b/datafusion/physical-optimizer/Cargo.toml
index acf3eee105..e7bf4a80fc 100644
--- a/datafusion/physical-optimizer/Cargo.toml
+++ b/datafusion/physical-optimizer/Cargo.toml
@@ -32,9 +32,15 @@ rust-version = { workspace = true }
workspace = true
[dependencies]
+arrow = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-execution = { workspace = true }
+datafusion-expr-common = { workspace = true, default-features = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-plan = { workspace = true }
itertools = { workspace = true }
+
+[dev-dependencies]
+datafusion-functions-aggregate = { workspace = true }
+tokio = { workspace = true }
diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs
b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index a11b498b95..fd21362fd3 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -20,15 +20,15 @@ use std::sync::Arc;
use datafusion_common::config::ConfigOptions;
use datafusion_common::scalar::ScalarValue;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_plan::aggregates::AggregateExec;
+use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
use datafusion_physical_plan::{expressions, ExecutionPlan};
use crate::PhysicalOptimizerRule;
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
-use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
/// Optimizer that uses available statistics for aggregate functions
#[derive(Default, Debug)]
@@ -146,3 +146,373 @@ fn take_optimizable_value_from_statistics(
let value = agg_expr.fun().value_from_stats(statistics_args);
value.map(|val| (val, agg_expr.name().to_string()))
}
+
+#[cfg(test)]
+mod tests {
+ use crate::aggregate_statistics::AggregateStatistics;
+ use crate::PhysicalOptimizerRule;
+ use datafusion_common::config::ConfigOptions;
+ use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
+ use datafusion_execution::TaskContext;
+ use datafusion_functions_aggregate::count::count_udaf;
+ use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+ use datafusion_physical_expr::PhysicalExpr;
+ use datafusion_physical_plan::aggregates::AggregateExec;
+ use datafusion_physical_plan::projection::ProjectionExec;
+ use datafusion_physical_plan::udaf::AggregateFunctionExpr;
+ use datafusion_physical_plan::ExecutionPlan;
+ use std::sync::Arc;
+
+ use datafusion_common::Result;
+ use datafusion_expr_common::operator::Operator;
+
+ use datafusion_physical_plan::aggregates::PhysicalGroupBy;
+ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
+ use datafusion_physical_plan::common;
+ use datafusion_physical_plan::filter::FilterExec;
+ use datafusion_physical_plan::memory::MemoryExec;
+
+ use arrow::array::Int32Array;
+ use arrow::datatypes::{DataType, Field, Schema};
+ use arrow::record_batch::RecordBatch;
+ use datafusion_common::cast::as_int64_array;
+ use datafusion_physical_expr::expressions::{self, cast};
+ use datafusion_physical_plan::aggregates::AggregateMode;
+
+ /// Describe the type of aggregate being tested
+ pub enum TestAggregate {
+ /// Testing COUNT(*) type aggregates
+ CountStar,
+
+ /// Testing for COUNT(column) aggregate
+ ColumnA(Arc<Schema>),
+ }
+
+ impl TestAggregate {
+ /// Create a new COUNT(*) aggregate
+ pub fn new_count_star() -> Self {
+ Self::CountStar
+ }
+
+ /// Create a new COUNT(column) aggregate
+ pub fn new_count_column(schema: &Arc<Schema>) -> Self {
+ Self::ColumnA(Arc::clone(schema))
+ }
+
+ /// Return appropriate expr depending if COUNT is for col or table (*)
+ pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr {
+ AggregateExprBuilder::new(count_udaf(), vec![self.column()])
+ .schema(Arc::new(schema.clone()))
+ .alias(self.column_name())
+ .build()
+ .unwrap()
+ }
+
+ /// what argument would this aggregate need in the plan?
+ fn column(&self) -> Arc<dyn PhysicalExpr> {
+ match self {
+ Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
+ Self::ColumnA(s) => expressions::col("a", s).unwrap(),
+ }
+ }
+
+ /// What name would this aggregate produce in a plan?
+ pub fn column_name(&self) -> &'static str {
+ match self {
+ Self::CountStar => "COUNT(*)",
+ Self::ColumnA(_) => "COUNT(a)",
+ }
+ }
+
+ /// What is the expected count?
+ pub fn expected_count(&self) -> i64 {
+ match self {
+ TestAggregate::CountStar => 3,
+ TestAggregate::ColumnA(_) => 2,
+ }
+ }
+ }
+
+ /// Mock data using a MemoryExec which has an exact count statistic
+ fn mock_data() -> Result<Arc<MemoryExec>> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Int32, true),
+ ]));
+
+ let batch = RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
+ Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
+ ],
+ )?;
+
+ Ok(Arc::new(MemoryExec::try_new(
+ &[vec![batch]],
+ Arc::clone(&schema),
+ None,
+ )?))
+ }
+
+ /// Checks that the count optimization was applied and we still get the
right result
+ async fn assert_count_optim_success(
+ plan: AggregateExec,
+ agg: TestAggregate,
+ ) -> Result<()> {
+ let task_ctx = Arc::new(TaskContext::default());
+ let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
+
+ let config = ConfigOptions::new();
+ let optimized =
+ AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?;
+
+ // A ProjectionExec is a sign that the count optimization was applied
+ assert!(optimized.as_any().is::<ProjectionExec>());
+
+ // run both the optimized and nonoptimized plan
+ let optimized_result =
+ common::collect(optimized.execute(0,
Arc::clone(&task_ctx))?).await?;
+ let nonoptimized_result = common::collect(plan.execute(0,
task_ctx)?).await?;
+ assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+ // and validate the results are the same and expected
+ assert_eq!(optimized_result.len(), 1);
+ check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+ // check the non optimized one too to ensure types and names remain
the same
+ assert_eq!(nonoptimized_result.len(), 1);
+ check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+ Ok(())
+ }
+
+ fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+ let schema = batch.schema();
+ let fields = schema.fields();
+ assert_eq!(fields.len(), 1);
+
+ let field = &fields[0];
+ assert_eq!(field.name(), agg.column_name());
+ assert_eq!(field.data_type(), &DataType::Int64);
+        // note that nullability differs
+
+ assert_eq!(
+ as_int64_array(batch.column(0)).unwrap().values(),
+ &[agg.expected_count()]
+ );
+ }
+
+ #[tokio::test]
+ async fn test_count_partial_direct_child() -> Result<()> {
+ // basic test case with the aggregation applied on a source with exact
statistics
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ source,
+ Arc::clone(&schema),
+ )?;
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(partial_agg),
+ Arc::clone(&schema),
+ )?;
+
+ assert_count_optim_success(final_agg, agg).await?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
+ // basic test case with the aggregation applied on a source with exact
statistics
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ source,
+ Arc::clone(&schema),
+ )?;
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(partial_agg),
+ Arc::clone(&schema),
+ )?;
+
+ assert_count_optim_success(final_agg, agg).await?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_count_partial_indirect_child() -> Result<()> {
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ source,
+ Arc::clone(&schema),
+ )?;
+
+ // We introduce an intermediate optimization step between the partial
and final aggregator
+ let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(coalesce),
+ Arc::clone(&schema),
+ )?;
+
+ assert_count_optim_success(final_agg, agg).await?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ source,
+ Arc::clone(&schema),
+ )?;
+
+ // We introduce an intermediate optimization step between the partial
and final aggregator
+ let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(coalesce),
+ Arc::clone(&schema),
+ )?;
+
+ assert_count_optim_success(final_agg, agg).await?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_count_inexact_stat() -> Result<()> {
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
+
+ // adding a filter makes the statistics inexact
+ let filter = Arc::new(FilterExec::try_new(
+ expressions::binary(
+ expressions::col("a", &schema)?,
+ Operator::Gt,
+ cast(expressions::lit(1u32), &schema, DataType::Int32)?,
+ &schema,
+ )?,
+ source,
+ )?);
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ filter,
+ Arc::clone(&schema),
+ )?;
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(partial_agg),
+ Arc::clone(&schema),
+ )?;
+
+ let conf = ConfigOptions::new();
+ let optimized =
+ AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+ // check that the original ExecutionPlan was not replaced
+ assert!(optimized.as_any().is::<AggregateExec>());
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_count_with_nulls_inexact_stat() -> Result<()> {
+ let source = mock_data()?;
+ let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
+
+ // adding a filter makes the statistics inexact
+ let filter = Arc::new(FilterExec::try_new(
+ expressions::binary(
+ expressions::col("a", &schema)?,
+ Operator::Gt,
+ cast(expressions::lit(1u32), &schema, DataType::Int32)?,
+ &schema,
+ )?,
+ source,
+ )?);
+
+ let partial_agg = AggregateExec::try_new(
+ AggregateMode::Partial,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ filter,
+ Arc::clone(&schema),
+ )?;
+
+ let final_agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ PhysicalGroupBy::default(),
+ vec![agg.count_expr(&schema)],
+ vec![None],
+ Arc::new(partial_agg),
+ Arc::clone(&schema),
+ )?;
+
+ let conf = ConfigOptions::new();
+ let optimized =
+ AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+ // check that the original ExecutionPlan was not replaced
+ assert!(optimized.as_any().is::<AggregateExec>());
+
+ Ok(())
+ }
+}
diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs
b/datafusion/physical-optimizer/src/topk_aggregation.rs
index 5dec99535c..c8a28ed0ec 100644
--- a/datafusion/physical-optimizer/src/topk_aggregation.rs
+++ b/datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -19,19 +19,17 @@
use std::sync::Arc;
-use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::sorts::sort::SortExec;
-use datafusion_physical_plan::ExecutionPlan;
-
-use arrow_schema::DataType;
+use crate::PhysicalOptimizerRule;
+use arrow::datatypes::DataType;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
-
-use crate::PhysicalOptimizerRule;
+use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::sorts::sort::SortExec;
+use datafusion_physical_plan::ExecutionPlan;
use itertools::Itertools;
/// An optimizer rule that passes a `limit` hint to aggregations if the whole
result is not needed
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]