alamb commented on code in PR #2674:
URL: https://github.com/apache/arrow-datafusion/pull/2674#discussion_r886837655
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -37,6 +38,9 @@ use crate::error::Result;
#[derive(Default)]
pub struct AggregateStatistics {}
+/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`]
+const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))";
Review Comment:
This constant was hard-coded in a few places, and I think this symbolic name
makes it easier to understand what the code is doing.
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -148,10 +152,10 @@ fn take_optimizable_table_count(
.as_any()
.downcast_ref::<expressions::Literal>()
{
- if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
+ if lit_expr.value() == &COUNT_STAR_EXPANSION {
Review Comment:
There was an implicit coupling between the SQL planner and this file, which
I have now made explicit with a named constant
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -293,38 +297,80 @@ mod tests {
/// Checks that the count optimization was applied and we still get the
right result
async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) ->
Result<()> {
let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
- let optimized = AggregateStatistics::new().optimize(Arc::new(plan),
&conf)?;
+ let plan = Arc::new(plan) as _;
+ let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan),
&conf)?;
let (col, count) = match nulls {
- false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
3),
- true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
+ false => (Field::new(COUNT_STAR_NAME, DataType::Int64, false), 3),
+ true => (Field::new("COUNT(a)", DataType::Int64, false), 2),
};
// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
+ let task_ctx = session_ctx.task_ctx();
let result = common::collect(optimized.execute(0, task_ctx)?).await?;
assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
assert_eq!(
result[0]
.column(0)
.as_any()
- .downcast_ref::<UInt64Array>()
+ .downcast_ref::<Int64Array>()
.unwrap()
.values(),
&[count]
);
+
+ // Validate that the optimized plan returns the exact same
+ // answer (both schema and data) as the original plan
+ let task_ctx = session_ctx.task_ctx();
Review Comment:
This test would have caught this issue when it was introduced in
https://github.com/apache/arrow-datafusion/pull/2636
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -293,38 +297,80 @@ mod tests {
/// Checks that the count optimization was applied and we still get the
right result
async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) ->
Result<()> {
let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
- let optimized = AggregateStatistics::new().optimize(Arc::new(plan),
&conf)?;
+ let plan = Arc::new(plan) as _;
+ let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan),
&conf)?;
let (col, count) = match nulls {
- false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
3),
- true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
+ false => (Field::new(COUNT_STAR_NAME, DataType::Int64, false), 3),
+ true => (Field::new("COUNT(a)", DataType::Int64, false), 2),
};
// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
+ let task_ctx = session_ctx.task_ctx();
let result = common::collect(optimized.execute(0, task_ctx)?).await?;
assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
assert_eq!(
result[0]
.column(0)
.as_any()
- .downcast_ref::<UInt64Array>()
+ .downcast_ref::<Int64Array>()
.unwrap()
.values(),
&[count]
);
+
+ // Validate that the optimized plan returns the exact same
+ // answer (both schema and data) as the original plan
+ let task_ctx = session_ctx.task_ctx();
+ let plan_result = common::collect(plan.execute(0, task_ctx)?).await?;
+ assert_eq!(normalize(result), normalize(plan_result));
Ok(())
}
+ /// Normalize record batches for comparison:
+ /// 1. Sets nullable to `true`
+ fn normalize(batches: Vec<RecordBatch>) -> Vec<RecordBatch> {
+ let schema = normalize_schema(&batches[0].schema());
+ batches
+ .into_iter()
+ .map(|batch| {
+ RecordBatch::try_new(schema.clone(), batch.columns().to_vec())
+ .expect("Error creating record batch")
+ })
+ .collect()
+ }
+ fn normalize_schema(schema: &Schema) -> Arc<Schema> {
+ let nullable = true;
+ let normalized_fields = schema
+ .fields()
+ .iter()
+ .map(|f| {
+ Field::new(f.name(), f.data_type().clone(), nullable)
+ .with_metadata(f.metadata().cloned())
+ })
+ .collect();
+ Arc::new(Schema::new_with_metadata(
+ normalized_fields,
+ schema.metadata().clone(),
+ ))
+ }
+
fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn
AggregateExpr> {
- // Return appropriate expr depending if COUNT is for col or table
- let expr = match schema {
- None => expressions::lit(ScalarValue::UInt8(Some(1))),
- Some(s) => expressions::col(col.unwrap(), s).unwrap(),
+ // Return appropriate expr depending if COUNT is for col or table (*)
+ let (expr, name) = match schema {
+ None => (
+ expressions::lit(COUNT_STAR_EXPANSION),
+ COUNT_STAR_NAME.to_string(),
+ ),
+ Some(s) => (
+ expressions::col(col.unwrap(), s).unwrap(),
+ format!("COUNT({})", col.unwrap()),
+ ),
};
- Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
+
+ Arc::new(Count::new(expr, name, DataType::Int64))
Review Comment:
Now that the schema is checked, we can't use an arbitrary column name; we
need to use the actual name the plan would produce.
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -293,38 +297,80 @@ mod tests {
/// Checks that the count optimization was applied and we still get the
right result
async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) ->
Result<()> {
let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
- let optimized = AggregateStatistics::new().optimize(Arc::new(plan),
&conf)?;
+ let plan = Arc::new(plan) as _;
+ let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan),
&conf)?;
let (col, count) = match nulls {
- false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
3),
- true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
+ false => (Field::new(COUNT_STAR_NAME, DataType::Int64, false), 3),
+ true => (Field::new("COUNT(a)", DataType::Int64, false), 2),
};
// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
+ let task_ctx = session_ctx.task_ctx();
let result = common::collect(optimized.execute(0, task_ctx)?).await?;
assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
assert_eq!(
result[0]
.column(0)
.as_any()
- .downcast_ref::<UInt64Array>()
+ .downcast_ref::<Int64Array>()
.unwrap()
.values(),
&[count]
);
+
+ // Validate that the optimized plan returns the exact same
+ // answer (both schema and data) as the original plan
+ let task_ctx = session_ctx.task_ctx();
+ let plan_result = common::collect(plan.execute(0, task_ctx)?).await?;
+ assert_eq!(normalize(result), normalize(plan_result));
Ok(())
}
+ /// Normalize record batches for comparison:
+ /// 1. Sets nullable to `true`
+ fn normalize(batches: Vec<RecordBatch>) -> Vec<RecordBatch> {
Review Comment:
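This is awkward but necessary to pass the tests: the helper only forces every
field to be nullable, so that the optimized and original plan results can be
compared for equality.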
##########
datafusion/sql/src/planner.rs:
##########
@@ -2197,14 +2197,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<(AggregateFunction, Vec<Expr>)> {
let args = match fun {
+ // Special case rewrite COUNT(*) to COUNT(constant)
AggregateFunction::Count => function
.args
.into_iter()
.map(|a| match a {
FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value(
Value::Number(_, _),
- ))) => Ok(lit(1_u8)),
- FunctionArg::Unnamed(FunctionArgExpr::Wildcard) =>
Ok(lit(1_u8)),
+ ))) => Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())),
Review Comment:
This is a readability improvement: naming the constant makes what is
happening more explicit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]