alamb commented on code in PR #2674:
URL: https://github.com/apache/arrow-datafusion/pull/2674#discussion_r888298254
##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -291,65 +296,132 @@ mod tests {
}
/// Checks that the count optimization was applied and we still get the
right result
- async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) ->
Result<()> {
+ async fn assert_count_optim_success(
+ plan: AggregateExec,
+ agg: TestAggregate,
+ ) -> Result<()> {
let session_ctx = SessionContext::new();
- let task_ctx = session_ctx.task_ctx();
let conf = session_ctx.copied_config();
- let optimized = AggregateStatistics::new().optimize(Arc::new(plan),
&conf)?;
-
- let (col, count) = match nulls {
- false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false),
3),
- true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
- };
+ let plan = Arc::new(plan) as _;
+ let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan),
&conf)?;
// A ProjectionExec is a sign that the count optimization was applied
assert!(optimized.as_any().is::<ProjectionExec>());
- let result = common::collect(optimized.execute(0, task_ctx)?).await?;
- assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
+
+ // run both the optimized and nonoptimized plan
+ let optimized_result =
+ common::collect(optimized.execute(0,
session_ctx.task_ctx())?).await?;
+ let nonoptimized_result =
+ common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
+ assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+ // and validate the results are the same and expected
+ assert_eq!(optimized_result.len(), 1);
+ check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+ // check the non optimized one too to ensure types and names remain
the same
+ assert_eq!(nonoptimized_result.len(), 1);
+ check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+ Ok(())
+ }
+
+ fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+ let schema = batch.schema();
+ let fields = schema.fields();
+ assert_eq!(fields.len(), 1);
+
+ let field = &fields[0];
+ assert_eq!(field.name(), agg.column_name());
+ assert_eq!(field.data_type(), &DataType::Int64);
+ // note that nullabiolity differs
+
assert_eq!(
- result[0]
+ batch
.column(0)
.as_any()
- .downcast_ref::<UInt64Array>()
+ .downcast_ref::<Int64Array>()
.unwrap()
.values(),
- &[count]
+ &[agg.expected_count()]
);
- Ok(())
}
- fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn
AggregateExpr> {
- // Return appropriate expr depending if COUNT is for col or table
- let expr = match schema {
- None => expressions::lit(ScalarValue::UInt8(Some(1))),
- Some(s) => expressions::col(col.unwrap(), s).unwrap(),
- };
- Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
+ /// Describe the type of aggregate being tested
+ enum TestAggregate {
Review Comment:
This now parameterizes the differences between the tests with an
explicit `enum` rather than relying on implicit assumptions. I think it makes the tests
easier to follow.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]