alamb commented on code in PR #2674:
URL: https://github.com/apache/arrow-datafusion/pull/2674#discussion_r888298254


##########
datafusion/core/src/physical_optimizer/aggregate_statistics.rs:
##########
@@ -291,65 +296,132 @@ mod tests {
     }
 
     /// Checks that the count optimization was applied and we still get the 
right result
-    async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> 
Result<()> {
+    async fn assert_count_optim_success(
+        plan: AggregateExec,
+        agg: TestAggregate,
+    ) -> Result<()> {
         let session_ctx = SessionContext::new();
-        let task_ctx = session_ctx.task_ctx();
         let conf = session_ctx.copied_config();
-        let optimized = AggregateStatistics::new().optimize(Arc::new(plan), 
&conf)?;
-
-        let (col, count) = match nulls {
-            false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 
3),
-            true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
-        };
+        let plan = Arc::new(plan) as _;
+        let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), 
&conf)?;
 
         // A ProjectionExec is a sign that the count optimization was applied
         assert!(optimized.as_any().is::<ProjectionExec>());
-        let result = common::collect(optimized.execute(0, task_ctx)?).await?;
-        assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
+
+        // run both the optimized and nonoptimized plan
+        let optimized_result =
+            common::collect(optimized.execute(0, 
session_ctx.task_ctx())?).await?;
+        let nonoptimized_result =
+            common::collect(plan.execute(0, session_ctx.task_ctx())?).await?;
+        assert_eq!(optimized_result.len(), nonoptimized_result.len());
+
+        //  and validate the results are the same and expected
+        assert_eq!(optimized_result.len(), 1);
+        check_batch(optimized_result.into_iter().next().unwrap(), &agg);
+        // check the non optimized one too to ensure types and names remain 
the same
+        assert_eq!(nonoptimized_result.len(), 1);
+        check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg);
+
+        Ok(())
+    }
+
+    fn check_batch(batch: RecordBatch, agg: &TestAggregate) {
+        let schema = batch.schema();
+        let fields = schema.fields();
+        assert_eq!(fields.len(), 1);
+
+        let field = &fields[0];
+        assert_eq!(field.name(), agg.column_name());
+        assert_eq!(field.data_type(), &DataType::Int64);
+        // note that nullability differs
+
         assert_eq!(
-            result[0]
+            batch
                 .column(0)
                 .as_any()
-                .downcast_ref::<UInt64Array>()
+                .downcast_ref::<Int64Array>()
                 .unwrap()
                 .values(),
-            &[count]
+            &[agg.expected_count()]
         );
-        Ok(())
     }
 
-    fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn 
AggregateExpr> {
-        // Return appropriate expr depending if COUNT is for col or table
-        let expr = match schema {
-            None => expressions::lit(ScalarValue::UInt8(Some(1))),
-            Some(s) => expressions::col(col.unwrap(), s).unwrap(),
-        };
-        Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
+    /// Describe the type of aggregate being tested
+    enum TestAggregate {

Review Comment:
   This now parameterizes the differences between the tests with an 
explicit `enum` rather than relying on implicit assumptions. I think it makes 
the tests easier to follow.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to