This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ede3de849d Move average unit tests to slt (#10401)
ede3de849d is described below

commit ede3de849d4bbcf47fa3b8cdfe293b7cefea4afd
Author: 张林伟 <[email protected]>
AuthorDate: Tue May 7 15:02:24 2024 +0800

    Move average unit tests to slt (#10401)
---
 datafusion/physical-expr/src/aggregate/average.rs | 107 ---------------------
 datafusion/physical-expr/src/expressions/mod.rs   |  43 +--------
 datafusion/sqllogictest/test_files/aggregate.slt  | 108 ++++++++++++++++++++++
 3 files changed, 109 insertions(+), 149 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 065c2179f4..80fcc9b70c 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -567,110 +567,3 @@ where
             + self.sums.capacity() * std::mem::size_of::<T>()
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::expressions::tests::assert_aggregate;
-    use arrow::array::*;
-    use datafusion_expr::AggregateFunction;
-
-    #[test]
-    fn avg_decimal() {
-        // test agg
-        let array: ArrayRef = Arc::new(
-            (1..7)
-                .map(Some)
-                .collect::<Decimal128Array>()
-                .with_precision_and_scale(10, 0)
-                .unwrap(),
-        );
-
-        assert_aggregate(
-            array,
-            AggregateFunction::Avg,
-            false,
-            ScalarValue::Decimal128(Some(35000), 14, 4),
-        );
-    }
-
-    #[test]
-    fn avg_decimal_with_nulls() {
-        let array: ArrayRef = Arc::new(
-            (1..6)
-                .map(|i| if i == 2 { None } else { Some(i) })
-                .collect::<Decimal128Array>()
-                .with_precision_and_scale(10, 0)
-                .unwrap(),
-        );
-        assert_aggregate(
-            array,
-            AggregateFunction::Avg,
-            false,
-            ScalarValue::Decimal128(Some(32500), 14, 4),
-        );
-    }
-
-    #[test]
-    fn avg_decimal_all_nulls() {
-        // test agg
-        let array: ArrayRef = Arc::new(
-            std::iter::repeat::<Option<i128>>(None)
-                .take(6)
-                .collect::<Decimal128Array>()
-                .with_precision_and_scale(10, 0)
-                .unwrap(),
-        );
-        assert_aggregate(
-            array,
-            AggregateFunction::Avg,
-            false,
-            ScalarValue::Decimal128(None, 14, 4),
-        );
-    }
-
-    #[test]
-    fn avg_i32() {
-        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
-    }
-
-    #[test]
-    fn avg_i32_with_nulls() {
-        let a: ArrayRef = Arc::new(Int32Array::from(vec![
-            Some(1),
-            None,
-            Some(3),
-            Some(4),
-            Some(5),
-        ]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64));
-    }
-
-    #[test]
-    fn avg_i32_all_nulls() {
-        let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None));
-    }
-
-    #[test]
-    fn avg_u32() {
-        let a: ArrayRef =
-            Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64));
-    }
-
-    #[test]
-    fn avg_f32() {
-        let a: ArrayRef =
-            Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
-    }
-
-    #[test]
-    fn avg_f64() {
-        let a: ArrayRef =
-            Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
-        assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
-    }
-}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 0cd2ac2c9e..3efa965d14 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -99,14 +99,11 @@ pub use try_cast::{try_cast, TryCastExpr};
 pub(crate) mod tests {
     use std::sync::Arc;
 
-    use crate::expressions::{col, create_aggregate_expr, try_cast};
     use crate::AggregateExpr;
     use arrow::record_batch::RecordBatch;
     use arrow_array::ArrayRef;
-    use arrow_schema::{Field, Schema};
     use datafusion_common::{Result, ScalarValue};
-    use datafusion_expr::type_coercion::aggregates::coerce_types;
-    use datafusion_expr::{AggregateFunction, EmitTo};
+    use datafusion_expr::EmitTo;
 
     /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the
     /// result.
@@ -201,44 +198,6 @@ pub(crate) mod tests {
         }};
     }
 
-    /// Assert `function(array) == expected` performing any necessary type coercion
-    pub fn assert_aggregate(
-        array: ArrayRef,
-        function: AggregateFunction,
-        distinct: bool,
-        expected: ScalarValue,
-    ) {
-        let data_type = array.data_type();
-        let sig = function.signature();
-        let coerced = coerce_types(&function, &[data_type.clone()], &sig).unwrap();
-
-        let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]);
-        let batch =
-            RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap();
-
-        let input = try_cast(
-            col("a", &input_schema).unwrap(),
-            &input_schema,
-            coerced[0].clone(),
-        )
-        .unwrap();
-
-        let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]);
-        let agg = create_aggregate_expr(
-            &function,
-            distinct,
-            &[input],
-            &[],
-            &schema,
-            "agg",
-            false,
-        )
-        .unwrap();
-
-        let result = aggregate(&batch, agg).unwrap();
-        assert_eq!(expected, result);
-    }
-
     /// macro to perform an aggregation with two inputs and verify the result.
     #[macro_export]
     macro_rules! generic_test_op2 {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 50c9c5eb95..f9c41864c8 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1704,6 +1704,114 @@ select avg(c1) from test
 ----
 1.75
 
+# avg_decimal
+statement ok
+create table t (c1 decimal(10, 0)) as values (1), (2), (3), (4), (5), (6);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.5 Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_decimal_with_nulls
+statement ok
+create table t (c1 decimal(10, 0)) as values (1), (NULL), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.25 Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_decimal_all_nulls
+statement ok
+create table t (c1 decimal(10, 0)) as values (NULL), (NULL), (NULL), (NULL), (NULL), (NULL);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+NULL Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_i32
+statement ok
+create table t (c1 int) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_i32_with_nulls
+statement ok
+create table t (c1 int) as values (1), (NULL), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.25 Float64
+
+statement ok
+drop table t;
+
+# avg_i32_all_nulls
+statement ok
+create table t (c1 int) as values (NULL), (NULL);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+NULL Float64
+
+statement ok
+drop table t;
+
+# avg_u32
+statement ok
+create table t (c1 int unsigned) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_f32
+statement ok
+create table t (c1 float) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_f64
+statement ok
+create table t (c1 double) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
 # simple_mean
 query R
 select mean(c1) from test


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to