This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 32fe176128 Minor: Consolidate UDF tests (#7704)
32fe176128 is described below

commit 32fe1761287c9b58294a54805a0fcafc4ab046b8
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Oct 3 07:34:22 2023 -0400

    Minor: Consolidate UDF tests (#7704)
    
    * Minor: Consolidate user defined functions
    
    * cleanup
    
    * move more tests
    
    * more
    
    * cleanup use
---
 datafusion/core/src/execution/context.rs           |  86 +---------
 datafusion/core/tests/sql/expr.rs                  |  16 +-
 datafusion/core/tests/sql/mod.rs                   |  55 +------
 datafusion/core/tests/user_defined/mod.rs          |   3 +
 .../tests/user_defined/user_defined_aggregates.rs  |  94 +++++++++++
 .../user_defined_scalar_functions.rs}              | 174 ++++++++++++++-------
 6 files changed, 222 insertions(+), 206 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 6cfb73a510..4bdd40a914 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -2222,16 +2222,13 @@ mod tests {
     use crate::execution::context::QueryPlanner;
     use crate::execution::memory_pool::MemoryConsumer;
     use crate::execution::runtime_env::RuntimeConfig;
-    use crate::physical_plan::expressions::AvgAccumulator;
     use crate::test;
     use crate::test_util::parquet_test_data;
     use crate::variable::VarType;
-    use arrow::array::ArrayRef;
     use arrow::record_batch::RecordBatch;
     use arrow_schema::{Field, Schema};
     use async_trait::async_trait;
-    use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
-    use datafusion_physical_expr::functions::make_scalar_function;
+    use datafusion_expr::Expr;
     use std::fs::File;
     use std::path::PathBuf;
     use std::sync::Weak;
@@ -2330,87 +2327,6 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
-        let ctx = SessionContext::new();
-        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
-            .unwrap();
-
-        let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
-        let myfunc = make_scalar_function(myfunc);
-
-        ctx.register_udf(create_udf(
-            "MY_FUNC",
-            vec![DataType::Int32],
-            Arc::new(DataType::Int32),
-            Volatility::Immutable,
-            myfunc,
-        ));
-
-        // doesn't work as it was registered with non lowercase
-        let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
-            .await
-            .unwrap_err();
-        assert!(err
-            .to_string()
-            .contains("Error during planning: Invalid function \'my_func\'"));
-
-        // Can call it if you put quotes
-        let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
-
-        let expected = [
-            "+--------------+",
-            "| MY_FUNC(t.i) |",
-            "+--------------+",
-            "| 1            |",
-            "+--------------+",
-        ];
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
-        let ctx = SessionContext::new();
-        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
-            .unwrap();
-
-        // Note capitalization
-        let my_avg = create_udaf(
-            "MY_AVG",
-            vec![DataType::Float64],
-            Arc::new(DataType::Float64),
-            Volatility::Immutable,
-            Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
-            Arc::new(vec![DataType::UInt64, DataType::Float64]),
-        );
-
-        ctx.register_udaf(my_avg);
-
-        // doesn't work as it was registered as non lowercase
-        let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
-            .await
-            .unwrap_err();
-        assert!(err
-            .to_string()
-            .contains("Error during planning: Invalid function \'my_avg\'"));
-
-        // Can call it if you put quotes
-        let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
-
-        let expected = [
-            "+-------------+",
-            "| MY_AVG(t.i) |",
-            "+-------------+",
-            "| 1.0         |",
-            "+-------------+",
-        ];
-        assert_batches_eq!(expected, &result);
-
-        Ok(())
-    }
-
     #[tokio::test]
     async fn query_csv_with_custom_partition_extension() -> Result<()> {
         let tmp_dir = TempDir::new()?;
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 044b3b57ea..af33cfea65 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -616,7 +616,7 @@ async fn test_array_cast_expressions() -> Result<()> {
 
 #[tokio::test]
 async fn test_random_expression() -> Result<()> {
-    let ctx = create_ctx();
+    let ctx = SessionContext::new();
     let sql = "SELECT random() r1";
     let actual = execute(&ctx, sql).await;
     let r1 = actual[0][0].parse::<f64>().unwrap();
@@ -627,7 +627,7 @@ async fn test_random_expression() -> Result<()> {
 
 #[tokio::test]
 async fn test_uuid_expression() -> Result<()> {
-    let ctx = create_ctx();
+    let ctx = SessionContext::new();
     let sql = "SELECT uuid()";
     let actual = execute(&ctx, sql).await;
     let uuid = actual[0][0].parse::<uuid::Uuid>().unwrap();
@@ -886,18 +886,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> {
     Ok(())
 }
 
-#[tokio::test]
-async fn csv_query_avg_sqrt() -> Result<()> {
-    let ctx = create_ctx();
-    register_aggregate_csv(&ctx).await?;
-    let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
-    let mut actual = execute(&ctx, sql).await;
-    actual.sort();
-    let expected = vec![vec!["0.6706002946036462"]];
-    assert_float_eq(&expected, &actual);
-    Ok(())
-}
-
 #[tokio::test]
 async fn nested_subquery() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 7d175b6526..4529889270 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -26,6 +26,7 @@ use chrono::prelude::*;
 use chrono::Duration;
 
 use datafusion::datasource::TableProvider;
+use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan};
 use datafusion::physical_plan::metrics::MetricValue;
 use datafusion::physical_plan::ExecutionPlan;
@@ -34,15 +35,9 @@ use datafusion::prelude::*;
 use datafusion::test_util;
 use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
 use datafusion::{datasource::MemTable, physical_plan::collect};
-use datafusion::{
-    error::{DataFusionError, Result},
-    physical_plan::ColumnarValue,
-};
 use datafusion::{execution::context::SessionContext, physical_plan::displayable};
-use datafusion_common::cast::as_float64_array;
 use datafusion_common::plan_err;
 use datafusion_common::{assert_contains, assert_not_contains};
-use datafusion_expr::Volatility;
 use object_store::path::Path;
 use std::fs::File;
 use std::io::Write;
@@ -101,54 +96,6 @@ pub mod select;
 mod sql_api;
 pub mod subqueries;
 pub mod timestamp;
-pub mod udf;
-
-fn assert_float_eq<T>(expected: &[Vec<T>], received: &[Vec<String>])
-where
-    T: AsRef<str>,
-{
-    expected
-        .iter()
-        .flatten()
-        .zip(received.iter().flatten())
-        .for_each(|(l, r)| {
-            let (l, r) = (
-                l.as_ref().parse::<f64>().unwrap(),
-                r.as_str().parse::<f64>().unwrap(),
-            );
-            if l.is_nan() || r.is_nan() {
-                assert!(l.is_nan() && r.is_nan());
-            } else if (l - r).abs() > 2.0 * f64::EPSILON {
-                panic!("{l} != {r}")
-            }
-        });
-}
-
-fn create_ctx() -> SessionContext {
-    let ctx = SessionContext::new();
-
-    // register a custom UDF
-    ctx.register_udf(create_udf(
-        "custom_sqrt",
-        vec![DataType::Float64],
-        Arc::new(DataType::Float64),
-        Volatility::Immutable,
-        Arc::new(custom_sqrt),
-    ));
-
-    ctx
-}
-
-fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
-    let arg = &args[0];
-    if let ColumnarValue::Array(v) = arg {
-        let input = as_float64_array(v).expect("cast failed");
-        let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect();
-        Ok(ColumnarValue::Array(Arc::new(array)))
-    } else {
-        unimplemented!()
-    }
-}
 
 fn create_join_context(
     column_left: &str,
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index ab6f51c47b..09c7c3d326 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+/// Tests for user defined Scalar functions
+mod user_defined_scalar_functions;
+
 /// Tests for User Defined Aggregate Functions
 mod user_defined_aggregates;
 
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 3b7b4d0e87..fb0ecd02c6 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -19,11 +19,14 @@
 //! user defined aggregate functions
 
 use arrow::{array::AsArray, datatypes::Fields};
+use arrow_array::Int32Array;
+use arrow_schema::Schema;
 use std::sync::{
     atomic::{AtomicBool, Ordering},
     Arc,
 };
 
+use datafusion::datasource::MemTable;
 use datafusion::{
     arrow::{
         array::{ArrayRef, Float64Array, TimestampNanosecondArray},
@@ -43,6 +46,8 @@ use datafusion::{
 use datafusion_common::{
     assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
 };
+use datafusion_expr::create_udaf;
+use datafusion_physical_expr::expressions::AvgAccumulator;
 
 /// Test to show the contents of the setup
 #[tokio::test]
@@ -204,6 +209,95 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
 }
 
+/// tests the creation, registration and usage of a UDAF
+#[tokio::test]
+async fn simple_udaf() -> Result<()> {
+    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+    let batch1 = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![Arc::new(Int32Array::from(vec![4, 5]))],
+    )?;
+
+    let ctx = SessionContext::new();
+
+    let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("t", Arc::new(provider))?;
+
+    // define a udaf, using a DataFusion's accumulator
+    let my_avg = create_udaf(
+        "my_avg",
+        vec![DataType::Float64],
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+        Arc::new(vec![DataType::UInt64, DataType::Float64]),
+    );
+
+    ctx.register_udaf(my_avg);
+
+    let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?;
+
+    let expected = [
+        "+-------------+",
+        "| my_avg(t.a) |",
+        "+-------------+",
+        "| 3.0         |",
+        "+-------------+",
+    ];
+    assert_batches_eq!(expected, &result);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
+    let ctx = SessionContext::new();
+    let arr = Int32Array::from(vec![1]);
+    let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+    ctx.register_batch("t", batch).unwrap();
+
+    // Note capitalization
+    let my_avg = create_udaf(
+        "MY_AVG",
+        vec![DataType::Float64],
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+        Arc::new(vec![DataType::UInt64, DataType::Float64]),
+    );
+
+    ctx.register_udaf(my_avg);
+
+    // doesn't work as it was registered as non lowercase
+    let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err();
+    assert!(err
+        .to_string()
+        .contains("Error during planning: Invalid function \'my_avg\'"));
+
+    // Can call it if you put quotes
+    let result = ctx
+        .sql("SELECT \"MY_AVG\"(i) FROM t")
+        .await?
+        .collect()
+        .await?;
+
+    let expected = [
+        "+-------------+",
+        "| MY_AVG(t.i) |",
+        "+-------------+",
+        "| 1.0         |",
+        "+-------------+",
+    ];
+    assert_batches_eq!(expected, &result);
+
+    Ok(())
+}
+
 /// Returns an context with a table "t" and the "first" and "time_sum"
 /// aggregate functions registered.
 ///
diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
similarity index 68%
rename from datafusion/core/tests/sql/udf.rs
rename to datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 97512d0249..1c7e713729 100644
--- a/datafusion/core/tests/sql/udf.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -15,26 +15,56 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use super::*;
 use arrow::compute::kernels::numeric::add;
+use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion::prelude::*;
 use datafusion::{
     execution::registry::FunctionRegistry,
-    physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
+    physical_plan::functions::make_scalar_function, test_util,
 };
-use datafusion_common::{cast::as_int32_array, ScalarValue};
-use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder};
+use datafusion_common::cast::as_float64_array;
+use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
+use datafusion_expr::{
+    create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
+};
+use std::sync::Arc;
 
 /// test that casting happens on udfs.
 /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
 /// physical plan have the same schema.
 #[tokio::test]
 async fn csv_query_custom_udf_with_cast() -> Result<()> {
-    let ctx = create_ctx();
+    let ctx = create_udf_context();
     register_aggregate_csv(&ctx).await?;
     let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100";
-    let actual = execute(&ctx, sql).await;
-    let expected = vec![vec!["0.6584408483418833"]];
-    assert_float_eq(&expected, &actual);
+    let actual = plan_and_collect(&ctx, sql).await.unwrap();
+    let expected = [
+        "+------------------------------------------+",
+        "| AVG(custom_sqrt(aggregate_test_100.c11)) |",
+        "+------------------------------------------+",
+        "| 0.6584408483418833                       |",
+        "+------------------------------------------+",
+    ];
+    assert_batches_eq!(&expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_avg_sqrt() -> Result<()> {
+    let ctx = create_udf_context();
+    register_aggregate_csv(&ctx).await?;
+    // Note it is a different column (c12) than above (c11)
+    let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
+    let actual = plan_and_collect(&ctx, sql).await.unwrap();
+    let expected = [
+        "+------------------------------------------+",
+        "| AVG(custom_sqrt(aggregate_test_100.c12)) |",
+        "+------------------------------------------+",
+        "| 0.6706002946036462                       |",
+        "+------------------------------------------+",
+    ];
+    assert_batches_eq!(&expected, &actual);
     Ok(())
 }
 
@@ -212,51 +242,6 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
     Ok(())
 }
 
-/// tests the creation, registration and usage of a UDAF
-#[tokio::test]
-async fn simple_udaf() -> Result<()> {
-    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-
-    let batch1 = RecordBatch::try_new(
-        Arc::new(schema.clone()),
-        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
-    )?;
-    let batch2 = RecordBatch::try_new(
-        Arc::new(schema.clone()),
-        vec![Arc::new(Int32Array::from(vec![4, 5]))],
-    )?;
-
-    let ctx = SessionContext::new();
-
-    let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
-
-    // define a udaf, using a DataFusion's accumulator
-    let my_avg = create_udaf(
-        "my_avg",
-        vec![DataType::Float64],
-        Arc::new(DataType::Float64),
-        Volatility::Immutable,
-        Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
-        Arc::new(vec![DataType::UInt64, DataType::Float64]),
-    );
-
-    ctx.register_udaf(my_avg);
-
-    let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?;
-
-    let expected = [
-        "+-------------+",
-        "| my_avg(t.a) |",
-        "+-------------+",
-        "| 3.0         |",
-        "+-------------+",
-    ];
-    assert_batches_eq!(expected, &result);
-
-    Ok(())
-}
-
 #[tokio::test]
 async fn udaf_as_window_func() -> Result<()> {
     #[derive(Debug)]
@@ -314,3 +299,86 @@ async fn udaf_as_window_func() -> Result<()> {
     assert_eq!(format!("{:?}", dataframe.logical_plan()), expected);
     Ok(())
 }
+
+#[tokio::test]
+async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
+    let ctx = SessionContext::new();
+    let arr = Int32Array::from(vec![1]);
+    let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+    ctx.register_batch("t", batch).unwrap();
+
+    let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
+    let myfunc = make_scalar_function(myfunc);
+
+    ctx.register_udf(create_udf(
+        "MY_FUNC",
+        vec![DataType::Int32],
+        Arc::new(DataType::Int32),
+        Volatility::Immutable,
+        myfunc,
+    ));
+
+    // doesn't work as it was registered with non lowercase
+    let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
+        .await
+        .unwrap_err();
+    assert!(err
+        .to_string()
+        .contains("Error during planning: Invalid function \'my_func\'"));
+
+    // Can call it if you put quotes
+    let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
+
+    let expected = [
+        "+--------------+",
+        "| MY_FUNC(t.i) |",
+        "+--------------+",
+        "| 1            |",
+        "+--------------+",
+    ];
+    assert_batches_eq!(expected, &result);
+
+    Ok(())
+}
+
+fn create_udf_context() -> SessionContext {
+    let ctx = SessionContext::new();
+    // register a custom UDF
+    ctx.register_udf(create_udf(
+        "custom_sqrt",
+        vec![DataType::Float64],
+        Arc::new(DataType::Float64),
+        Volatility::Immutable,
+        Arc::new(custom_sqrt),
+    ));
+
+    ctx
+}
+
+fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let arg = &args[0];
+    if let ColumnarValue::Array(v) = arg {
+        let input = as_float64_array(v).expect("cast failed");
+        let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect();
+        Ok(ColumnarValue::Array(Arc::new(array)))
+    } else {
+        unimplemented!()
+    }
+}
+
+async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
+    let testdata = datafusion::test_util::arrow_test_data();
+    let schema = test_util::aggr_test_schema();
+    ctx.register_csv(
+        "aggregate_test_100",
+        &format!("{testdata}/csv/aggregate_test_100.csv"),
+        CsvReadOptions::new().schema(&schema),
+    )
+    .await?;
+    Ok(())
+}
+
+/// Execute SQL and return results as a RecordBatch
+async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+    ctx.sql(sql).await?.collect().await
+}

Reply via email to