This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8ddd99c84 Fix `AggregateStatistics` optimization so it doesn't change
output type (#2674)
8ddd99c84 is described below
commit 8ddd99c8432fdac2c236040973f984a4146f18b7
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Jun 2 17:18:31 2022 -0400
Fix `AggregateStatistics` optimization so it doesn't change output type
(#2674)
* Fix `AggregateStatistics` optimization so it doesn't change output type
* fix test
* Give some constants symbolic names to improve readability
* Consolidate expected differences in COUNT(*) and COUNT(a) in tests
* Simplify how the verification is done
* fmt
---
.../src/physical_optimizer/aggregate_statistics.rs | 161 +++++++++++++++------
datafusion/core/tests/custom_sources.rs | 6 +-
datafusion/expr/src/utils.rs | 6 +-
datafusion/sql/src/planner.rs | 9 +-
4 files changed, 133 insertions(+), 49 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 4cf96d235..bcf4fec07 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -19,6 +19,7 @@
use std::sync::Arc;
use arrow::datatypes::Schema;
+use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use crate::execution::context::SessionConfig;
use crate::physical_plan::aggregates::{AggregateExec, AggregateMode};
@@ -37,6 +38,9 @@ use crate::error::Result;
#[derive(Default)]
pub struct AggregateStatistics {}
+/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`]
+const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))";
+
impl AggregateStatistics {
#[allow(missing_docs)]
pub fn new() -> Self {
@@ -148,10 +152,10 @@ fn take_optimizable_table_count(
.as_any()
.downcast_ref::<expressions::Literal>()
{
- if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
+ if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
- ScalarValue::UInt64(Some(num_rows as u64)),
- "COUNT(UInt8(1))",
+ ScalarValue::Int64(Some(num_rows as i64)),
+ COUNT_STAR_NAME,
));
}
}
@@ -183,7 +187,7 @@ fn take_optimizable_column_count(
{
let expr = format!("COUNT({})", col_expr.name());
return Some((
- ScalarValue::UInt64(Some((num_rows - val) as u64)),
+ ScalarValue::Int64(Some((num_rows - val) as i64)),
expr,
));
}
@@ -254,9 +258,10 @@ mod tests {
use super::*;
use std::sync::Arc;
- use arrow::array::{Int32Array, UInt64Array};
+ use arrow::array::{Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
+ use datafusion_physical_expr::PhysicalExpr;
use crate::error::Result;
use crate::logical_plan::Operator;
@@ -291,40 +296,106 @@ mod tests {
}
/// Checks that the count optimization was applied and we still get the
right result
- async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) ->
Result<()> {
+ async fn assert_count_optim_success(
+ plan: AggregateExec,
+ agg: TestAggregate,
+ ) -> Result<()> {
let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
- let optimized = AggregateStatistics::new().optimize(Arc::new(plan),
&conf)?;
-
- let (col, count) = match nulls {
- false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
3),
- true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
- };
+ let plan = Arc::new(plan) as _;
+ let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan),
&conf)?;
// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
- let result = common::collect(optimized.execute(0, task_ctx)?).await?;
- assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
+
+ // run both the optimized and nonoptimized plan
+ let optimized_result =
+ common::collect(optimized.execute(0,
session_ctx.task_ctx())?).await?;
+ let nonoptimized_result =
+ common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
+ assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+ // and validate the results are the same and expected
+ assert_eq!(optimized_result.len(), 1);
+ check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+ // check the non optimized one too to ensure types and names remain
the same
+ assert_eq!(nonoptimized_result.len(), 1);
+ check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+ Ok(())
+ }
+
+ fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+ let schema = batch.schema();
+ let fields = schema.fields();
+ assert_eq!(fields.len(), 1);
+
+ let field = &fields[0];
+ assert_eq!(field.name(), agg.column_name());
+ assert_eq!(field.data_type(), &DataType::Int64);
+ // note that nullability differs
+
assert_eq!(
- result[0]
+ batch
.column(0)
.as_any()
- .downcast_ref::<UInt64Array>()
+ .downcast_ref::<Int64Array>()
.unwrap()
.values(),
- &[count]
+ &[agg.expected_count()]
);
- Ok(())
}
- fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn
AggregateExpr> {
- // Return appropriate expr depending if COUNT is for col or table
- let expr = match schema {
- None => expressions::lit(ScalarValue::UInt8(Some(1))),
- Some(s) => expressions::col(col.unwrap(), s).unwrap(),
- };
- Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
+ /// Describe the type of aggregate being tested
+ enum TestAggregate {
+ /// Testing COUNT(*) type aggregates
+ CountStar,
+
+ /// Testing for COUNT(column) aggregate
+ ColumnA(Arc<Schema>),
+ }
+
+ impl TestAggregate {
+ fn new_count_star() -> Self {
+ Self::CountStar
+ }
+
+ fn new_count_column(schema: &Arc<Schema>) -> Self {
+ Self::ColumnA(schema.clone())
+ }
+
+ /// Return the appropriate expr depending on whether COUNT is for a column or the table (*)
+ fn count_expr(&self) -> Arc<dyn AggregateExpr> {
+ Arc::new(Count::new(
+ self.column(),
+ self.column_name(),
+ DataType::Int64,
+ ))
+ }
+
+ /// what argument would this aggregate need in the plan?
+ fn column(&self) -> Arc<dyn PhysicalExpr> {
+ match self {
+ Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION),
+ Self::ColumnA(s) => expressions::col("a", s).unwrap(),
+ }
+ }
+
+ /// What name would this aggregate produce in a plan?
+ fn column_name(&self) -> &'static str {
+ match self {
+ Self::CountStar => COUNT_STAR_NAME,
+ Self::ColumnA(_) => "COUNT(a)",
+ }
+ }
+
+ /// What is the expected count?
+ fn expected_count(&self) -> i64 {
+ match self {
+ TestAggregate::CountStar => 3,
+ TestAggregate::ColumnA(_) => 2,
+ }
+ }
}
#[tokio::test]
@@ -332,11 +403,12 @@ mod tests {
// basic test case with the aggregation applied on a source with exact
statistics
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
@@ -344,12 +416,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
- assert_count_optim_success(final_agg, false).await?;
+ assert_count_optim_success(final_agg, agg).await?;
Ok(())
}
@@ -359,11 +431,12 @@ mod tests {
// basic test case with the aggregation applied on a source with exact
statistics
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
@@ -371,12 +444,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
- assert_count_optim_success(final_agg, true).await?;
+ assert_count_optim_success(final_agg, agg).await?;
Ok(())
}
@@ -385,11 +458,12 @@ mod tests {
async fn test_count_partial_indirect_child() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
@@ -400,12 +474,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
- assert_count_optim_success(final_agg, false).await?;
+ assert_count_optim_success(final_agg, agg).await?;
Ok(())
}
@@ -414,11 +488,12 @@ mod tests {
async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
source,
Arc::clone(&schema),
)?;
@@ -429,12 +504,12 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
Arc::new(coalesce),
Arc::clone(&schema),
)?;
- assert_count_optim_success(final_agg, true).await?;
+ assert_count_optim_success(final_agg, agg).await?;
Ok(())
}
@@ -443,6 +518,7 @@ mod tests {
async fn test_count_inexact_stat() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_star();
// adding a filter makes the statistics inexact
let filter = Arc::new(FilterExec::try_new(
@@ -458,7 +534,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
filter,
Arc::clone(&schema),
)?;
@@ -466,7 +542,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(None, None)],
+ vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
@@ -485,6 +561,7 @@ mod tests {
async fn test_count_with_nulls_inexact_stat() -> Result<()> {
let source = mock_data()?;
let schema = source.schema();
+ let agg = TestAggregate::new_count_column(&schema);
// adding a filter makes the statistics inexact
let filter = Arc::new(FilterExec::try_new(
@@ -500,7 +577,7 @@ mod tests {
let partial_agg = AggregateExec::try_new(
AggregateMode::Partial,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
filter,
Arc::clone(&schema),
)?;
@@ -508,7 +585,7 @@ mod tests {
let final_agg = AggregateExec::try_new(
AggregateMode::Final,
vec![],
- vec![count_expr(Some(&schema), Some("a"))],
+ vec![agg.count_expr()],
Arc::new(partial_agg),
Arc::clone(&schema),
)?;
diff --git a/datafusion/core/tests/custom_sources.rs
b/datafusion/core/tests/custom_sources.rs
index 1e4ac6e51..cccac0523 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::array::{Int32Array, PrimitiveArray, UInt64Array};
+use arrow::array::{Int32Array, Int64Array, PrimitiveArray};
use arrow::compute::kernels::aggregate;
use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
@@ -284,12 +284,12 @@ async fn optimizers_catch_all_statistics() {
let expected = RecordBatch::try_new(
Arc::new(Schema::new(vec![
- Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
+ Field::new("COUNT(UInt8(1))", DataType::Int64, false),
Field::new("MIN(test.c1)", DataType::Int32, false),
Field::new("MAX(test.c1)", DataType::Int32, false),
])),
vec![
- Arc::new(UInt64Array::from_slice(&[4])),
+ Arc::new(Int64Array::from_slice(&[4])),
Arc::new(Int32Array::from_slice(&[1])),
Arc::new(Int32Array::from_slice(&[100])),
],
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index ac22d094e..3986eb3e6 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -26,11 +26,15 @@ use crate::logical_plan::{
};
use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::{
- Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
+ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue,
};
use std::collections::HashSet;
use std::sync::Arc;
+/// The value to which `COUNT(*)` is expanded in
+/// `COUNT(<constant>)` expressions
+pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1));
+
/// Recursively walk a list of expression trees, collecting the unique set of
columns
/// referenced in the expression
pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) ->
Result<()> {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 51f160336..1e5daa472 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -30,7 +30,7 @@ use datafusion_expr::logical_plan::{
};
use datafusion_expr::utils::{
expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
expr_to_columns,
- find_aggregate_exprs, find_column_exprs, find_window_exprs,
+ find_aggregate_exprs, find_column_exprs, find_window_exprs,
COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF,
@@ -2122,14 +2122,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<(AggregateFunction, Vec<Expr>)> {
let args = match fun {
+ // Special case rewrite COUNT(*) to COUNT(constant)
AggregateFunction::Count => function
.args
.into_iter()
.map(|a| match a {
FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value(
Value::Number(_, _),
- ))) => Ok(lit(1_u8)),
- FunctionArg::Unnamed(FunctionArgExpr::Wildcard) =>
Ok(lit(1_u8)),
+ ))) => Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())),
+ FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
+ Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
+ }
_ => self.sql_fn_arg_to_logical_expr(a, schema, &mut
HashMap::new()),
})
.collect::<Result<Vec<Expr>>>()?,