This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d83b3b26e8 Chore: Move `aggregate statistics` optimizer test from core 
to optimizer crate (#12783)
d83b3b26e8 is described below

commit d83b3b26e897b87fb63c6928311526c8a2fe1aee
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Oct 9 11:38:48 2024 +0800

    Chore: Move `aggregate statistics` optimizer test from core to optimizer 
crate (#12783)
    
    * move test from core to optimizer crate
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * cleanup
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * upd
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * clippy
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion-cli/Cargo.lock                          |   6 +-
 .../physical_optimizer/aggregate_statistics.rs     | 325 ------------------
 datafusion/core/tests/physical_optimizer/mod.rs    |   1 -
 datafusion/physical-optimizer/Cargo.toml           |   6 +
 .../physical-optimizer/src/aggregate_statistics.rs | 376 ++++++++++++++++++++-
 .../physical-optimizer/src/topk_aggregation.rs     |  12 +-
 6 files changed, 388 insertions(+), 338 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 8a6ccacbb3..3e0f9e69e1 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1521,9 +1521,11 @@ dependencies = [
 name = "datafusion-physical-optimizer"
 version = "42.0.0"
 dependencies = [
+ "arrow",
  "arrow-schema",
  "datafusion-common",
  "datafusion-execution",
+ "datafusion-expr-common",
  "datafusion-physical-expr",
  "datafusion-physical-plan",
  "itertools",
@@ -2879,9 +2881,9 @@ dependencies = [
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.86"
+version = "1.0.87"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
+checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a"
 dependencies = [
  "unicode-ident",
 ]
diff --git a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs 
b/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
deleted file mode 100644
index bbf4dcd2b7..0000000000
--- a/datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
+++ /dev/null
@@ -1,325 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Tests for the physical optimizer
-
-use datafusion_common::config::ConfigOptions;
-use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
-use datafusion_physical_optimizer::PhysicalOptimizerRule;
-use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::ExecutionPlan;
-use std::sync::Arc;
-
-use datafusion::error::Result;
-use datafusion::logical_expr::Operator;
-use datafusion::prelude::SessionContext;
-use datafusion::test_util::TestAggregate;
-use datafusion_physical_plan::aggregates::PhysicalGroupBy;
-use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
-use datafusion_physical_plan::common;
-use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::memory::MemoryExec;
-
-use arrow::array::Int32Array;
-use arrow::datatypes::{DataType, Field, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion_common::cast::as_int64_array;
-use datafusion_physical_expr::expressions::{self, cast};
-use datafusion_physical_plan::aggregates::AggregateMode;
-
-/// Mock data using a MemoryExec which has an exact count statistic
-fn mock_data() -> Result<Arc<MemoryExec>> {
-    let schema = Arc::new(Schema::new(vec![
-        Field::new("a", DataType::Int32, true),
-        Field::new("b", DataType::Int32, true),
-    ]));
-
-    let batch = RecordBatch::try_new(
-        Arc::clone(&schema),
-        vec![
-            Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
-            Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
-        ],
-    )?;
-
-    Ok(Arc::new(MemoryExec::try_new(
-        &[vec![batch]],
-        Arc::clone(&schema),
-        None,
-    )?))
-}
-
-/// Checks that the count optimization was applied and we still get the right 
result
-async fn assert_count_optim_success(
-    plan: AggregateExec,
-    agg: TestAggregate,
-) -> Result<()> {
-    let session_ctx = SessionContext::new();
-    let state = session_ctx.state();
-    let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
-
-    let optimized =
-        AggregateStatistics::new().optimize(Arc::clone(&plan), 
state.config_options())?;
-
-    // A ProjectionExec is a sign that the count optimization was applied
-    assert!(optimized.as_any().is::<ProjectionExec>());
-
-    // run both the optimized and nonoptimized plan
-    let optimized_result =
-        common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?;
-    let nonoptimized_result =
-        common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
-    assert_eq!(optimized_result.len(), nonoptimized_result.len());
-
-    //  and validate the results are the same and expected
-    assert_eq!(optimized_result.len(), 1);
-    check_batch(optimized_result.into_iter().next().unwrap(), &agg);
-    // check the non optimized one too to ensure types and names remain the 
same
-    assert_eq!(nonoptimized_result.len(), 1);
-    check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
-
-    Ok(())
-}
-
-fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
-    let schema = batch.schema();
-    let fields = schema.fields();
-    assert_eq!(fields.len(), 1);
-
-    let field = &fields[0];
-    assert_eq!(field.name(), agg.column_name());
-    assert_eq!(field.data_type(), &DataType::Int64);
-    // note that nullabiolity differs
-
-    assert_eq!(
-        as_int64_array(batch.column(0)).unwrap().values(),
-        &[agg.expected_count()]
-    );
-}
-
-#[tokio::test]
-async fn test_count_partial_direct_child() -> Result<()> {
-    // basic test case with the aggregation applied on a source with exact 
statistics
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_star();
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        source,
-        Arc::clone(&schema),
-    )?;
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    assert_count_optim_success(final_agg, agg).await?;
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
-    // basic test case with the aggregation applied on a source with exact 
statistics
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_column(&schema);
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        source,
-        Arc::clone(&schema),
-    )?;
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    assert_count_optim_success(final_agg, agg).await?;
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_indirect_child() -> Result<()> {
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_star();
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        source,
-        Arc::clone(&schema),
-    )?;
-
-    // We introduce an intermediate optimization step between the partial and 
final aggregtator
-    let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(coalesce),
-        Arc::clone(&schema),
-    )?;
-
-    assert_count_optim_success(final_agg, agg).await?;
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_column(&schema);
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        source,
-        Arc::clone(&schema),
-    )?;
-
-    // We introduce an intermediate optimization step between the partial and 
final aggregtator
-    let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(coalesce),
-        Arc::clone(&schema),
-    )?;
-
-    assert_count_optim_success(final_agg, agg).await?;
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_inexact_stat() -> Result<()> {
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_star();
-
-    // adding a filter makes the statistics inexact
-    let filter = Arc::new(FilterExec::try_new(
-        expressions::binary(
-            expressions::col("a", &schema)?,
-            Operator::Gt,
-            cast(expressions::lit(1u32), &schema, DataType::Int32)?,
-            &schema,
-        )?,
-        source,
-    )?);
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        filter,
-        Arc::clone(&schema),
-    )?;
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    let conf = ConfigOptions::new();
-    let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), 
&conf)?;
-
-    // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
-
-    Ok(())
-}
-
-#[tokio::test]
-async fn test_count_with_nulls_inexact_stat() -> Result<()> {
-    let source = mock_data()?;
-    let schema = source.schema();
-    let agg = TestAggregate::new_count_column(&schema);
-
-    // adding a filter makes the statistics inexact
-    let filter = Arc::new(FilterExec::try_new(
-        expressions::binary(
-            expressions::col("a", &schema)?,
-            Operator::Gt,
-            cast(expressions::lit(1u32), &schema, DataType::Int32)?,
-            &schema,
-        )?,
-        source,
-    )?);
-
-    let partial_agg = AggregateExec::try_new(
-        AggregateMode::Partial,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        filter,
-        Arc::clone(&schema),
-    )?;
-
-    let final_agg = AggregateExec::try_new(
-        AggregateMode::Final,
-        PhysicalGroupBy::default(),
-        vec![agg.count_expr(&schema)],
-        vec![None],
-        Arc::new(partial_agg),
-        Arc::clone(&schema),
-    )?;
-
-    let conf = ConfigOptions::new();
-    let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), 
&conf)?;
-
-    // check that the original ExecutionPlan was not replaced
-    assert!(optimized.as_any().is::<AggregateExec>());
-
-    Ok(())
-}
diff --git a/datafusion/core/tests/physical_optimizer/mod.rs 
b/datafusion/core/tests/physical_optimizer/mod.rs
index 4ec981bf2a..c06783aa02 100644
--- a/datafusion/core/tests/physical_optimizer/mod.rs
+++ b/datafusion/core/tests/physical_optimizer/mod.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod aggregate_statistics;
 mod combine_partial_final_agg;
 mod limit_pushdown;
 mod limited_distinct_aggregation;
diff --git a/datafusion/physical-optimizer/Cargo.toml 
b/datafusion/physical-optimizer/Cargo.toml
index acf3eee105..e7bf4a80fc 100644
--- a/datafusion/physical-optimizer/Cargo.toml
+++ b/datafusion/physical-optimizer/Cargo.toml
@@ -32,9 +32,15 @@ rust-version = { workspace = true }
 workspace = true
 
 [dependencies]
+arrow = { workspace = true }
 arrow-schema = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 datafusion-execution = { workspace = true }
+datafusion-expr-common = { workspace = true, default-features = true }
 datafusion-physical-expr = { workspace = true }
 datafusion-physical-plan = { workspace = true }
 itertools = { workspace = true }
+
+[dev-dependencies]
+datafusion-functions-aggregate = { workspace = true }
+tokio = { workspace = true }
diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs 
b/datafusion/physical-optimizer/src/aggregate_statistics.rs
index a11b498b95..fd21362fd3 100644
--- a/datafusion/physical-optimizer/src/aggregate_statistics.rs
+++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs
@@ -20,15 +20,15 @@ use std::sync::Arc;
 
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::scalar::ScalarValue;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::Result;
 use datafusion_physical_plan::aggregates::AggregateExec;
+use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
 use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
 use datafusion_physical_plan::{expressions, ExecutionPlan};
 
 use crate::PhysicalOptimizerRule;
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
-use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs};
 
 /// Optimizer that uses available statistics for aggregate functions
 #[derive(Default, Debug)]
@@ -146,3 +146,373 @@ fn take_optimizable_value_from_statistics(
     let value = agg_expr.fun().value_from_stats(statistics_args);
     value.map(|val| (val, agg_expr.name().to_string()))
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::aggregate_statistics::AggregateStatistics;
+    use crate::PhysicalOptimizerRule;
+    use datafusion_common::config::ConfigOptions;
+    use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
+    use datafusion_execution::TaskContext;
+    use datafusion_functions_aggregate::count::count_udaf;
+    use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion_physical_expr::PhysicalExpr;
+    use datafusion_physical_plan::aggregates::AggregateExec;
+    use datafusion_physical_plan::projection::ProjectionExec;
+    use datafusion_physical_plan::udaf::AggregateFunctionExpr;
+    use datafusion_physical_plan::ExecutionPlan;
+    use std::sync::Arc;
+
+    use datafusion_common::Result;
+    use datafusion_expr_common::operator::Operator;
+
+    use datafusion_physical_plan::aggregates::PhysicalGroupBy;
+    use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
+    use datafusion_physical_plan::common;
+    use datafusion_physical_plan::filter::FilterExec;
+    use datafusion_physical_plan::memory::MemoryExec;
+
+    use arrow::array::Int32Array;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use datafusion_common::cast::as_int64_array;
+    use datafusion_physical_expr::expressions::{self, cast};
+    use datafusion_physical_plan::aggregates::AggregateMode;
+
+    /// Describe the type of aggregate being tested
+    pub enum TestAggregate {
+        /// Testing COUNT(*) type aggregates
+        CountStar,
+
+        /// Testing for COUNT(column) aggregate
+        ColumnA(Arc<Schema>),
+    }
+
+    impl TestAggregate {
+        /// Create a new COUNT(*) aggregate
+        pub fn new_count_star() -> Self {
+            Self::CountStar
+        }
+
+        /// Create a new COUNT(column) aggregate
+        pub fn new_count_column(schema: &Arc<Schema>) -> Self {
+            Self::ColumnA(Arc::clone(schema))
+        }
+
+        /// Return appropriate expr depending if COUNT is for col or table (*)
+        pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr {
+            AggregateExprBuilder::new(count_udaf(), vec![self.column()])
+                .schema(Arc::new(schema.clone()))
+                .alias(self.column_name())
+                .build()
+                .unwrap()
+        }
+
+        /// what argument would this aggregate need in the plan?
+        fn column(&self) -> Arc<dyn PhysicalExpr> {
+            match self {
+                Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
+                Self::ColumnA(s) => expressions::col("a", s).unwrap(),
+            }
+        }
+
+        /// What name would this aggregate produce in a plan?
+        pub fn column_name(&self) -> &'static str {
+            match self {
+                Self::CountStar => "COUNT(*)",
+                Self::ColumnA(_) => "COUNT(a)",
+            }
+        }
+
+        /// What is the expected count?
+        pub fn expected_count(&self) -> i64 {
+            match self {
+                TestAggregate::CountStar => 3,
+                TestAggregate::ColumnA(_) => 2,
+            }
+        }
+    }
+
+    /// Mock data using a MemoryExec which has an exact count statistic
+    fn mock_data() -> Result<Arc<MemoryExec>> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
+                Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
+            ],
+        )?;
+
+        Ok(Arc::new(MemoryExec::try_new(
+            &[vec![batch]],
+            Arc::clone(&schema),
+            None,
+        )?))
+    }
+
+    /// Checks that the count optimization was applied and we still get the 
right result
+    async fn assert_count_optim_success(
+        plan: AggregateExec,
+        agg: TestAggregate,
+    ) -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(plan);
+
+        let config = ConfigOptions::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?;
+
+        // A ProjectionExec is a sign that the count optimization was applied
+        assert!(optimized.as_any().is::<ProjectionExec>());
+
+        // run both the optimized and nonoptimized plan
+        let optimized_result =
+            common::collect(optimized.execute(0, 
Arc::clone(&task_ctx))?).await?;
+        let nonoptimized_result = common::collect(plan.execute(0, 
task_ctx)?).await?;
+        assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+        //  and validate the results are the same and expected
+        assert_eq!(optimized_result.len(), 1);
+        check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+        // check the non optimized one too to ensure types and names remain 
the same
+        assert_eq!(nonoptimized_result.len(), 1);
+        check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+        Ok(())
+    }
+
+    fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+        let schema = batch.schema();
+        let fields = schema.fields();
+        assert_eq!(fields.len(), 1);
+
+        let field = &fields[0];
+        assert_eq!(field.name(), agg.column_name());
+        assert_eq!(field.data_type(), &DataType::Int64);
+        // note that nullability differs
+
+        assert_eq!(
+            as_int64_array(batch.column(0)).unwrap().values(),
+            &[agg.expected_count()]
+        );
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_direct_child() -> Result<()> {
+        // basic test case with the aggregation applied on a source with exact 
statistics
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
+        // basic test case with the aggregation applied on a source with exact 
statistics
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_indirect_child() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        // We introduce an intermediate optimization step between the partial and final aggregator
+        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(coalesce),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        // We introduce an intermediate optimization step between the partial and final aggregator
+        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(coalesce),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, agg).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_inexact_stat() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
+
+        // adding a filter makes the statistics inexact
+        let filter = Arc::new(FilterExec::try_new(
+            expressions::binary(
+                expressions::col("a", &schema)?,
+                Operator::Gt,
+                cast(expressions::lit(1u32), &schema, DataType::Int32)?,
+                &schema,
+            )?,
+            source,
+        )?);
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            filter,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        let conf = ConfigOptions::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+        // check that the original ExecutionPlan was not replaced
+        assert!(optimized.as_any().is::<AggregateExec>());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_with_nulls_inexact_stat() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
+
+        // adding a filter makes the statistics inexact
+        let filter = Arc::new(FilterExec::try_new(
+            expressions::binary(
+                expressions::col("a", &schema)?,
+                Operator::Gt,
+                cast(expressions::lit(1u32), &schema, DataType::Int32)?,
+                &schema,
+            )?,
+            source,
+        )?);
+
+        let partial_agg = AggregateExec::try_new(
+            AggregateMode::Partial,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            filter,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::default(),
+            vec![agg.count_expr(&schema)],
+            vec![None],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        let conf = ConfigOptions::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+        // check that the original ExecutionPlan was not replaced
+        assert!(optimized.as_any().is::<AggregateExec>());
+
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs 
b/datafusion/physical-optimizer/src/topk_aggregation.rs
index 5dec99535c..c8a28ed0ec 100644
--- a/datafusion/physical-optimizer/src/topk_aggregation.rs
+++ b/datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -19,19 +19,17 @@
 
 use std::sync::Arc;
 
-use datafusion_physical_plan::aggregates::AggregateExec;
-use datafusion_physical_plan::sorts::sort::SortExec;
-use datafusion_physical_plan::ExecutionPlan;
-
-use arrow_schema::DataType;
+use crate::PhysicalOptimizerRule;
+use arrow::datatypes::DataType;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::Result;
 use datafusion_physical_expr::expressions::Column;
-
-use crate::PhysicalOptimizerRule;
+use datafusion_physical_plan::aggregates::AggregateExec;
 use datafusion_physical_plan::execution_plan::CardinalityEffect;
 use datafusion_physical_plan::projection::ProjectionExec;
+use datafusion_physical_plan::sorts::sort::SortExec;
+use datafusion_physical_plan::ExecutionPlan;
 use itertools::Itertools;
 
 /// An optimizer rule that passes a `limit` hint to aggregations if the whole 
result is not needed


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to