This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ede3de849d Move average unit tests to slt (#10401)
ede3de849d is described below
commit ede3de849d4bbcf47fa3b8cdfe293b7cefea4afd
Author: 张林伟 <[email protected]>
AuthorDate: Tue May 7 15:02:24 2024 +0800
Move average unit tests to slt (#10401)
---
datafusion/physical-expr/src/aggregate/average.rs | 107 ---------------------
datafusion/physical-expr/src/expressions/mod.rs | 43 +--------
datafusion/sqllogictest/test_files/aggregate.slt | 108 ++++++++++++++++++++++
3 files changed, 109 insertions(+), 149 deletions(-)
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index 065c2179f4..80fcc9b70c 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -567,110 +567,3 @@ where
+ self.sums.capacity() * std::mem::size_of::<T>()
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::expressions::tests::assert_aggregate;
- use arrow::array::*;
- use datafusion_expr::AggregateFunction;
-
- #[test]
- fn avg_decimal() {
- // test agg
- let array: ArrayRef = Arc::new(
- (1..7)
- .map(Some)
- .collect::<Decimal128Array>()
- .with_precision_and_scale(10, 0)
- .unwrap(),
- );
-
- assert_aggregate(
- array,
- AggregateFunction::Avg,
- false,
- ScalarValue::Decimal128(Some(35000), 14, 4),
- );
- }
-
- #[test]
- fn avg_decimal_with_nulls() {
- let array: ArrayRef = Arc::new(
- (1..6)
- .map(|i| if i == 2 { None } else { Some(i) })
- .collect::<Decimal128Array>()
- .with_precision_and_scale(10, 0)
- .unwrap(),
- );
- assert_aggregate(
- array,
- AggregateFunction::Avg,
- false,
- ScalarValue::Decimal128(Some(32500), 14, 4),
- );
- }
-
- #[test]
- fn avg_decimal_all_nulls() {
- // test agg
- let array: ArrayRef = Arc::new(
- std::iter::repeat::<Option<i128>>(None)
- .take(6)
- .collect::<Decimal128Array>()
- .with_precision_and_scale(10, 0)
- .unwrap(),
- );
- assert_aggregate(
- array,
- AggregateFunction::Avg,
- false,
- ScalarValue::Decimal128(None, 14, 4),
- );
- }
-
- #[test]
- fn avg_i32() {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
- }
-
- #[test]
- fn avg_i32_with_nulls() {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![
- Some(1),
- None,
- Some(3),
- Some(4),
- Some(5),
- ]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64));
- }
-
- #[test]
- fn avg_i32_all_nulls() {
- let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None));
- }
-
- #[test]
- fn avg_u32() {
- let a: ArrayRef =
- Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64));
- }
-
- #[test]
- fn avg_f32() {
- let a: ArrayRef =
- Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
- }
-
- #[test]
- fn avg_f64() {
- let a: ArrayRef =
- Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]));
- assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64));
- }
-}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 0cd2ac2c9e..3efa965d14 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -99,14 +99,11 @@ pub use try_cast::{try_cast, TryCastExpr};
pub(crate) mod tests {
use std::sync::Arc;
- use crate::expressions::{col, create_aggregate_expr, try_cast};
use crate::AggregateExpr;
use arrow::record_batch::RecordBatch;
use arrow_array::ArrayRef;
- use arrow_schema::{Field, Schema};
use datafusion_common::{Result, ScalarValue};
- use datafusion_expr::type_coercion::aggregates::coerce_types;
- use datafusion_expr::{AggregateFunction, EmitTo};
+ use datafusion_expr::EmitTo;
/// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the
/// result.
@@ -201,44 +198,6 @@ pub(crate) mod tests {
}};
}
- /// Assert `function(array) == expected` performing any necessary type coercion
- pub fn assert_aggregate(
- array: ArrayRef,
- function: AggregateFunction,
- distinct: bool,
- expected: ScalarValue,
- ) {
- let data_type = array.data_type();
- let sig = function.signature();
- let coerced = coerce_types(&function, &[data_type.clone()], &sig).unwrap();
-
- let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]);
- let batch =
- RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap();
-
- let input = try_cast(
- col("a", &input_schema).unwrap(),
- &input_schema,
- coerced[0].clone(),
- )
- .unwrap();
-
- let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]);
- let agg = create_aggregate_expr(
- &function,
- distinct,
- &[input],
- &[],
- &schema,
- "agg",
- false,
- )
- .unwrap();
-
- let result = aggregate(&batch, agg).unwrap();
- assert_eq!(expected, result);
- }
-
/// macro to perform an aggregation with two inputs and verify the result.
#[macro_export]
macro_rules! generic_test_op2 {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 50c9c5eb95..f9c41864c8 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -1704,6 +1704,114 @@ select avg(c1) from test
----
1.75
+# avg_decimal
+statement ok
+create table t (c1 decimal(10, 0)) as values (1), (2), (3), (4), (5), (6);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.5 Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_decimal_with_nulls
+statement ok
+create table t (c1 decimal(10, 0)) as values (1), (NULL), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.25 Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_decimal_all_nulls
+statement ok
+create table t (c1 decimal(10, 0)) as values (NULL), (NULL), (NULL), (NULL), (NULL), (NULL);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+NULL Decimal128(14, 4)
+
+statement ok
+drop table t;
+
+# avg_i32
+statement ok
+create table t (c1 int) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_i32_with_nulls
+statement ok
+create table t (c1 int) as values (1), (NULL), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3.25 Float64
+
+statement ok
+drop table t;
+
+# avg_i32_all_nulls
+statement ok
+create table t (c1 int) as values (NULL), (NULL);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+NULL Float64
+
+statement ok
+drop table t;
+
+# avg_u32
+statement ok
+create table t (c1 int unsigned) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_f32
+statement ok
+create table t (c1 float) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
+# avg_f64
+statement ok
+create table t (c1 double) as values (1), (2), (3), (4), (5);
+
+query RT
+select avg(c1), arrow_typeof(avg(c1)) from t;
+----
+3 Float64
+
+statement ok
+drop table t;
+
# simple_mean
query R
select mean(c1) from test
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]