This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ddd99c84 Fix `AggregateStatistics` optimization so it doesn't change 
output type (#2674)
8ddd99c84 is described below

commit 8ddd99c8432fdac2c236040973f984a4146f18b7
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jun 2 17:18:31 2022 -0400

    Fix `AggregateStatistics` optimization so it doesn't change output type 
(#2674)
    
    * Fix `AggregateStatistics` optimization so it doesn't change output type
    
    * fix test
    
    * Give some constants symbolic names to improve readability
    
    * Consolidate expected differences in COUNT(*) and COUNT(a) in tests
    
    * Simplify how the verification is done
    
    * fmt
---
 .../src/physical_optimizer/aggregate_statistics.rs | 161 +++++++++++++++------
 datafusion/core/tests/custom_sources.rs            |   6 +-
 datafusion/expr/src/utils.rs                       |   6 +-
 datafusion/sql/src/planner.rs                      |   9 +-
 4 files changed, 133 insertions(+), 49 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs 
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 4cf96d235..bcf4fec07 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -19,6 +19,7 @@
 use std::sync::Arc;
 
 use arrow::datatypes::Schema;
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
 
 use crate::execution::context::SessionConfig;
 use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
@@ -37,6 +38,9 @@ use crate::error::Result;
 #[derive(Default)]
 pub struct AggregateStatistics {}
 
+/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`]
+const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))";
+
 impl AggregateStatistics {
     #[allow(missing_docs)]
     pub fn new() -> Self {
@@ -148,10 +152,10 @@ fn take_optimizable_table_count(
                 .as_any()
                 .downcast_ref::<expressions::Literal>()
             {
-                if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
+                if lit_expr.value() == &COUNT_STAR_EXPANSION {
                     return Some((
-                        ScalarValue::UInt64(Some(num_rows as u64)),
-                        "COUNT(UInt8(1))",
+                        ScalarValue::Int64(Some(num_rows as i64)),
+                        COUNT_STAR_NAME,
                     ));
                 }
             }
@@ -183,7 +187,7 @@ fn take_optimizable_column_count(
                 {
                     let expr = format!("COUNT({})", col_expr.name());
                     return Some((
-                        ScalarValue::UInt64(Some((num_rows - val) as u64)),
+                        ScalarValue::Int64(Some((num_rows - val) as i64)),
                         expr,
                     ));
                 }
@@ -254,9 +258,10 @@ mod tests {
     use super::*;
     use std::sync::Arc;
 
-    use arrow::array::{Int32Array, UInt64Array};
+    use arrow::array::{Int32Array, Int64Array};
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
+    use datafusion_physical_expr::PhysicalExpr;
 
     use crate::error::Result;
     use crate::logical_plan::Operator;
@@ -291,40 +296,106 @@ mod tests {
     }
 
     /// Checks that the count optimization was applied and we still get the 
right result
-    async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> 
Result<()> {
+    async fn assert_count_optim_success(
+        plan: AggregateExec,
+        agg: TestAggregate,
+    ) -> Result<()> {
         let session_ctx = SessionContext::new();
-        let task_ctx = session_ctx.task_ctx();
         let conf = session_ctx.copied_config();
-        let optimized = AggregateStatistics::new().optimize(Arc::new(plan), 
&conf)?;
-
-        let (col, count) = match nulls {
-            false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 
3),
-            true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
-        };
+        let plan = Arc::new(plan) as _;
+        let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), 
&conf)?;
 
         // A ProjectionExec is a sign that the count optimization was applied
         assert!(optimized.as_any().is::<ProjectionExec>());
-        let result = common::collect(optimized.execute(0, task_ctx)?).await?;
-        assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
+
+        // run both the optimized and nonoptimized plan
+        let optimized_result =
+            common::collect(optimized.execute(0, 
session_ctx.task_ctx())?).await?;
+        let nonoptimized_result =
+            common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
+        assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+        //  and validate the results are the same and expected
+        assert_eq!(optimized_result.len(), 1);
+        check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+        // check the non optimized one too to ensure types and names remain 
the same
+        assert_eq!(nonoptimized_result.len(), 1);
+        check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+        Ok(())
+    }
+
+    fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+        let schema = batch.schema();
+        let fields = schema.fields();
+        assert_eq!(fields.len(), 1);
+
+        let field = &fields[0];
+        assert_eq!(field.name(), agg.column_name());
+        assert_eq!(field.data_type(), &DataType::Int64);
+        // note that nullability differs
+
         assert_eq!(
-            result[0]
+            batch
                 .column(0)
                 .as_any()
-                .downcast_ref::<UInt64Array>()
+                .downcast_ref::<Int64Array>()
                 .unwrap()
                 .values(),
-            &[count]
+            &[agg.expected_count()]
         );
-        Ok(())
     }
 
-    fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn 
AggregateExpr> {
-        // Return appropriate expr depending if COUNT is for col or table
-        let expr = match schema {
-            None => expressions::lit(ScalarValue::UInt8(Some(1))),
-            Some(s) => expressions::col(col.unwrap(), s).unwrap(),
-        };
-        Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
+    /// Describe the type of aggregate being tested
+    enum TestAggregate {
+        /// Testing COUNT(*) type aggregates
+        CountStar,
+
+        /// Testing for COUNT(column) aggregate
+        ColumnA(Arc<Schema>),
+    }
+
+    impl TestAggregate {
+        fn new_count_star() -> Self {
+            Self::CountStar
+        }
+
+        fn new_count_column(schema: &Arc<Schema>) -> Self {
+            Self::ColumnA(schema.clone())
+        }
+
+        /// Return appropriate expr depending if COUNT is for col or table (*)
+        fn count_expr(&self) -> Arc<dyn AggregateExpr> {
+            Arc::new(Count::new(
+                self.column(),
+                self.column_name(),
+                DataType::Int64,
+            ))
+        }
+
+        /// what argument would this aggregate need in the plan?
+        fn column(&self) -> Arc<dyn PhysicalExpr> {
+            match self {
+                Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
+                Self::ColumnA(s) => expressions::col("a", s).unwrap(),
+            }
+        }
+
+        /// What name would this aggregate produce in a plan?
+        fn column_name(&self) -> &'static str {
+            match self {
+                Self::CountStar => COUNT_STAR_NAME,
+                Self::ColumnA(_) => "COUNT(a)",
+            }
+        }
+
+        /// What is the expected count?
+        fn expected_count(&self) -> i64 {
+            match self {
+                TestAggregate::CountStar => 3,
+                TestAggregate::ColumnA(_) => 2,
+            }
+        }
     }
 
     #[tokio::test]
@@ -332,11 +403,12 @@ mod tests {
         // basic test case with the aggregation applied on a source with exact 
statistics
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
         )?;
@@ -344,12 +416,12 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg, false).await?;
+        assert_count_optim_success(final_agg, agg).await?;
 
         Ok(())
     }
@@ -359,11 +431,12 @@ mod tests {
         // basic test case with the aggregation applied on a source with exact 
statistics
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
         )?;
@@ -371,12 +444,12 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg, true).await?;
+        assert_count_optim_success(final_agg, agg).await?;
 
         Ok(())
     }
@@ -385,11 +458,12 @@ mod tests {
     async fn test_count_partial_indirect_child() -> Result<()> {
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
         )?;
@@ -400,12 +474,12 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             Arc::new(coalesce),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg, false).await?;
+        assert_count_optim_success(final_agg, agg).await?;
 
         Ok(())
     }
@@ -414,11 +488,12 @@ mod tests {
     async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
 
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             source,
             Arc::clone(&schema),
         )?;
@@ -429,12 +504,12 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             Arc::new(coalesce),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg, true).await?;
+        assert_count_optim_success(final_agg, agg).await?;
 
         Ok(())
     }
@@ -443,6 +518,7 @@ mod tests {
     async fn test_count_inexact_stat() -> Result<()> {
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_star();
 
         // adding a filter makes the statistics inexact
         let filter = Arc::new(FilterExec::try_new(
@@ -458,7 +534,7 @@ mod tests {
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             filter,
             Arc::clone(&schema),
         )?;
@@ -466,7 +542,7 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(None, None)],
+            vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
@@ -485,6 +561,7 @@ mod tests {
     async fn test_count_with_nulls_inexact_stat() -> Result<()> {
         let source = mock_data()?;
         let schema = source.schema();
+        let agg = TestAggregate::new_count_column(&schema);
 
         // adding a filter makes the statistics inexact
         let filter = Arc::new(FilterExec::try_new(
@@ -500,7 +577,7 @@ mod tests {
         let partial_agg = AggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             filter,
             Arc::clone(&schema),
         )?;
@@ -508,7 +585,7 @@ mod tests {
         let final_agg = AggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr(Some(&schema), Some("a"))],
+            vec![agg.count_expr()],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
diff --git a/datafusion/core/tests/custom_sources.rs 
b/datafusion/core/tests/custom_sources.rs
index 1e4ac6e51..cccac0523 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{Int32Array, PrimitiveArray, UInt64Array};
+use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
 use arrow::compute::kernels::aggregate;
 use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
@@ -284,12 +284,12 @@ async fn optimizers_catch_all_statistics() {
 
     let expected = RecordBatch::try_new(
         Arc::new(Schema::new(vec![
-            Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
+            Field::new("COUNT(UInt8(1))", DataType::Int64, false),
             Field::new("MIN(test.c1)", DataType::Int32, false),
             Field::new("MAX(test.c1)", DataType::Int32, false),
         ])),
         vec![
-            Arc::new(UInt64Array::from_slice(&[4])),
+            Arc::new(Int64Array::from_slice(&[4])),
             Arc::new(Int32Array::from_slice(&[1])),
             Arc::new(Int32Array::from_slice(&[100])),
         ],
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index ac22d094e..3986eb3e6 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -26,11 +26,15 @@ use crate::logical_plan::{
 };
 use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
 use datafusion_common::{
-    Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
+    Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, 
ScalarValue,
 };
 use std::collections::HashSet;
 use std::sync::Arc;
 
+///  The value to which `COUNT(*)` is expanded in
+///  `COUNT(<constant>)` expressions
+pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));
+
 /// Recursively walk a list of expression trees, collecting the unique set of 
columns
 /// referenced in the expression
 pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> 
Result<()> {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 51f160336..1e5daa472 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -30,7 +30,7 @@ use datafusion_expr::logical_plan::{
 };
 use datafusion_expr::utils::{
     expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, 
expr_to_columns,
-    find_aggregate_exprs, find_column_exprs, find_window_exprs,
+    find_aggregate_exprs, find_column_exprs, find_window_exprs, 
COUNT_STAR_EXPANSION,
 };
 use datafusion_expr::{
     and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
@@ -2122,14 +2122,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
     ) -> Result<(AggregateFunction, Vec<Expr>)> {
         let args = match fun {
+            // Special case rewrite COUNT(*) to COUNT(constant)
             AggregateFunction::Count => function
                 .args
                 .into_iter()
                 .map(|a| match a {
                     FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value(
                         Value::Number(_, _),
-                    ))) => Ok(lit(1_u8)),
-                    FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => 
Ok(lit(1_u8)),
+                    ))) => Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())),
+                    FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
+                        Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
+                    }
                     _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut 
HashMap::new()),
                 })
                 .collect::<Result<Vec<Expr>>>()?,

Reply via email to