This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 32fe176128 Minor: Consolidate UDF tests (#7704)
32fe176128 is described below
commit 32fe1761287c9b58294a54805a0fcafc4ab046b8
Author: Andrew Lamb <[email protected]>
AuthorDate: Tue Oct 3 07:34:22 2023 -0400
Minor: Consolidate UDF tests (#7704)
* Minor: Consolidate user defined functions
* cleanup
* move more tests
* more
* cleanup use
---
datafusion/core/src/execution/context.rs | 86 +---------
datafusion/core/tests/sql/expr.rs | 16 +-
datafusion/core/tests/sql/mod.rs | 55 +------
datafusion/core/tests/user_defined/mod.rs | 3 +
.../tests/user_defined/user_defined_aggregates.rs | 94 +++++++++++
.../user_defined_scalar_functions.rs} | 174 ++++++++++++++-------
6 files changed, 222 insertions(+), 206 deletions(-)
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 6cfb73a510..4bdd40a914 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -2222,16 +2222,13 @@ mod tests {
use crate::execution::context::QueryPlanner;
use crate::execution::memory_pool::MemoryConsumer;
use crate::execution::runtime_env::RuntimeConfig;
- use crate::physical_plan::expressions::AvgAccumulator;
use crate::test;
use crate::test_util::parquet_test_data;
use crate::variable::VarType;
- use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use arrow_schema::{Field, Schema};
use async_trait::async_trait;
- use datafusion_expr::{create_udaf, create_udf, Expr, Volatility};
- use datafusion_physical_expr::functions::make_scalar_function;
+ use datafusion_expr::Expr;
use std::fs::File;
use std::path::PathBuf;
use std::sync::Weak;
@@ -2330,87 +2327,6 @@ mod tests {
Ok(())
}
- #[tokio::test]
- async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
- let ctx = SessionContext::new();
- ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
- .unwrap();
-
- let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
- let myfunc = make_scalar_function(myfunc);
-
- ctx.register_udf(create_udf(
- "MY_FUNC",
- vec![DataType::Int32],
- Arc::new(DataType::Int32),
- Volatility::Immutable,
- myfunc,
- ));
-
- // doesn't work as it was registered with non lowercase
- let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
- .await
- .unwrap_err();
- assert!(err
- .to_string()
- .contains("Error during planning: Invalid function \'my_func\'"));
-
- // Can call it if you put quotes
- let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
-
- let expected = [
- "+--------------+",
- "| MY_FUNC(t.i) |",
- "+--------------+",
- "| 1 |",
- "+--------------+",
- ];
- assert_batches_eq!(expected, &result);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
- let ctx = SessionContext::new();
- ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
- .unwrap();
-
- // Note capitalization
- let my_avg = create_udaf(
- "MY_AVG",
- vec![DataType::Float64],
- Arc::new(DataType::Float64),
- Volatility::Immutable,
- Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
- Arc::new(vec![DataType::UInt64, DataType::Float64]),
- );
-
- ctx.register_udaf(my_avg);
-
- // doesn't work as it was registered as non lowercase
- let err = plan_and_collect(&ctx, "SELECT MY_AVG(i) FROM t")
- .await
- .unwrap_err();
- assert!(err
- .to_string()
- .contains("Error during planning: Invalid function \'my_avg\'"));
-
- // Can call it if you put quotes
- let result = plan_and_collect(&ctx, "SELECT \"MY_AVG\"(i) FROM t").await?;
-
- let expected = [
- "+-------------+",
- "| MY_AVG(t.i) |",
- "+-------------+",
- "| 1.0 |",
- "+-------------+",
- ];
- assert_batches_eq!(expected, &result);
-
- Ok(())
- }
-
#[tokio::test]
async fn query_csv_with_custom_partition_extension() -> Result<()> {
let tmp_dir = TempDir::new()?;
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index 044b3b57ea..af33cfea65 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -616,7 +616,7 @@ async fn test_array_cast_expressions() -> Result<()> {
#[tokio::test]
async fn test_random_expression() -> Result<()> {
- let ctx = create_ctx();
+ let ctx = SessionContext::new();
let sql = "SELECT random() r1";
let actual = execute(&ctx, sql).await;
let r1 = actual[0][0].parse::<f64>().unwrap();
@@ -627,7 +627,7 @@ async fn test_random_expression() -> Result<()> {
#[tokio::test]
async fn test_uuid_expression() -> Result<()> {
- let ctx = create_ctx();
+ let ctx = SessionContext::new();
let sql = "SELECT uuid()";
let actual = execute(&ctx, sql).await;
let uuid = actual[0][0].parse::<uuid::Uuid>().unwrap();
@@ -886,18 +886,6 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> {
Ok(())
}
-#[tokio::test]
-async fn csv_query_avg_sqrt() -> Result<()> {
- let ctx = create_ctx();
- register_aggregate_csv(&ctx).await?;
- let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
- let mut actual = execute(&ctx, sql).await;
- actual.sort();
- let expected = vec![vec!["0.6706002946036462"]];
- assert_float_eq(&expected, &actual);
- Ok(())
-}
-
#[tokio::test]
async fn nested_subquery() -> Result<()> {
let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 7d175b6526..4529889270 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -26,6 +26,7 @@ use chrono::prelude::*;
use chrono::Duration;
use datafusion::datasource::TableProvider;
+use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{Aggregate, LogicalPlan, TableScan};
use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
@@ -34,15 +35,9 @@ use datafusion::prelude::*;
use datafusion::test_util;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion::{datasource::MemTable, physical_plan::collect};
-use datafusion::{
- error::{DataFusionError, Result},
- physical_plan::ColumnarValue,
-};
use datafusion::{execution::context::SessionContext, physical_plan::displayable};
-use datafusion_common::cast::as_float64_array;
use datafusion_common::plan_err;
use datafusion_common::{assert_contains, assert_not_contains};
-use datafusion_expr::Volatility;
use object_store::path::Path;
use std::fs::File;
use std::io::Write;
@@ -101,54 +96,6 @@ pub mod select;
mod sql_api;
pub mod subqueries;
pub mod timestamp;
-pub mod udf;
-
-fn assert_float_eq<T>(expected: &[Vec<T>], received: &[Vec<String>])
-where
- T: AsRef<str>,
-{
- expected
- .iter()
- .flatten()
- .zip(received.iter().flatten())
- .for_each(|(l, r)| {
- let (l, r) = (
- l.as_ref().parse::<f64>().unwrap(),
- r.as_str().parse::<f64>().unwrap(),
- );
- if l.is_nan() || r.is_nan() {
- assert!(l.is_nan() && r.is_nan());
- } else if (l - r).abs() > 2.0 * f64::EPSILON {
- panic!("{l} != {r}")
- }
- });
-}
-
-fn create_ctx() -> SessionContext {
- let ctx = SessionContext::new();
-
- // register a custom UDF
- ctx.register_udf(create_udf(
- "custom_sqrt",
- vec![DataType::Float64],
- Arc::new(DataType::Float64),
- Volatility::Immutable,
- Arc::new(custom_sqrt),
- ));
-
- ctx
-}
-
-fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
- let arg = &args[0];
- if let ColumnarValue::Array(v) = arg {
- let input = as_float64_array(v).expect("cast failed");
- let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect();
- Ok(ColumnarValue::Array(Arc::new(array)))
- } else {
- unimplemented!()
- }
-}
fn create_join_context(
column_left: &str,
diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs
index ab6f51c47b..09c7c3d326 100644
--- a/datafusion/core/tests/user_defined/mod.rs
+++ b/datafusion/core/tests/user_defined/mod.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.
+/// Tests for user defined Scalar functions
+mod user_defined_scalar_functions;
+
/// Tests for User Defined Aggregate Functions
mod user_defined_aggregates;
diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
index 3b7b4d0e87..fb0ecd02c6 100644
--- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs
+++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs
@@ -19,11 +19,14 @@
//! user defined aggregate functions
use arrow::{array::AsArray, datatypes::Fields};
+use arrow_array::Int32Array;
+use arrow_schema::Schema;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
+use datafusion::datasource::MemTable;
use datafusion::{
arrow::{
array::{ArrayRef, Float64Array, TimestampNanosecondArray},
@@ -43,6 +46,8 @@ use datafusion::{
use datafusion_common::{
assert_contains, cast::as_primitive_array, exec_err, DataFusionError,
};
+use datafusion_expr::create_udaf;
+use datafusion_physical_expr::expressions::AvgAccumulator;
/// Test to show the contents of the setup
#[tokio::test]
@@ -204,6 +209,95 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
}
+/// tests the creation, registration and usage of a UDAF
+#[tokio::test]
+async fn simple_udaf() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+
+ let batch1 = RecordBatch::try_new(
+ Arc::new(schema.clone()),
+ vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
+ )?;
+ let batch2 = RecordBatch::try_new(
+ Arc::new(schema.clone()),
+ vec![Arc::new(Int32Array::from(vec![4, 5]))],
+ )?;
+
+ let ctx = SessionContext::new();
+
+ let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
+ ctx.register_table("t", Arc::new(provider))?;
+
+ // define a udaf, using a DataFusion's accumulator
+ let my_avg = create_udaf(
+ "my_avg",
+ vec![DataType::Float64],
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+ Arc::new(vec![DataType::UInt64, DataType::Float64]),
+ );
+
+ ctx.register_udaf(my_avg);
+
+ let result = ctx.sql("SELECT MY_AVG(a) FROM t").await?.collect().await?;
+
+ let expected = [
+ "+-------------+",
+ "| my_avg(t.a) |",
+ "+-------------+",
+ "| 3.0 |",
+ "+-------------+",
+ ];
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
+ let ctx = SessionContext::new();
+ let arr = Int32Array::from(vec![1]);
+ let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+ ctx.register_batch("t", batch).unwrap();
+
+ // Note capitalization
+ let my_avg = create_udaf(
+ "MY_AVG",
+ vec![DataType::Float64],
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
+ Arc::new(vec![DataType::UInt64, DataType::Float64]),
+ );
+
+ ctx.register_udaf(my_avg);
+
+ // doesn't work as it was registered as non lowercase
+ let err = ctx.sql("SELECT MY_AVG(i) FROM t").await.unwrap_err();
+ assert!(err
+ .to_string()
+ .contains("Error during planning: Invalid function \'my_avg\'"));
+
+ // Can call it if you put quotes
+ let result = ctx
+ .sql("SELECT \"MY_AVG\"(i) FROM t")
+ .await?
+ .collect()
+ .await?;
+
+ let expected = [
+ "+-------------+",
+ "| MY_AVG(t.i) |",
+ "+-------------+",
+ "| 1.0 |",
+ "+-------------+",
+ ];
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+}
+
/// Returns an context with a table "t" and the "first" and "time_sum"
/// aggregate functions registered.
///
diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
similarity index 68%
rename from datafusion/core/tests/sql/udf.rs
rename to datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 97512d0249..1c7e713729 100644
--- a/datafusion/core/tests/sql/udf.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -15,26 +15,56 @@
// specific language governing permissions and limitations
// under the License.
-use super::*;
use arrow::compute::kernels::numeric::add;
+use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion::prelude::*;
use datafusion::{
execution::registry::FunctionRegistry,
- physical_plan::{expressions::AvgAccumulator, functions::make_scalar_function},
+ physical_plan::functions::make_scalar_function, test_util,
};
-use datafusion_common::{cast::as_int32_array, ScalarValue};
-use datafusion_expr::{create_udaf, Accumulator, LogicalPlanBuilder};
+use datafusion_common::cast::as_float64_array;
+use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
+use datafusion_expr::{
+ create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
+};
+use std::sync::Arc;
/// test that casting happens on udfs.
/// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
/// physical plan have the same schema.
#[tokio::test]
async fn csv_query_custom_udf_with_cast() -> Result<()> {
- let ctx = create_ctx();
+ let ctx = create_udf_context();
register_aggregate_csv(&ctx).await?;
let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100";
- let actual = execute(&ctx, sql).await;
- let expected = vec![vec!["0.6584408483418833"]];
- assert_float_eq(&expected, &actual);
+ let actual = plan_and_collect(&ctx, sql).await.unwrap();
+ let expected = [
+ "+------------------------------------------+",
+ "| AVG(custom_sqrt(aggregate_test_100.c11)) |",
+ "+------------------------------------------+",
+ "| 0.6584408483418833 |",
+ "+------------------------------------------+",
+ ];
+ assert_batches_eq!(&expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_avg_sqrt() -> Result<()> {
+ let ctx = create_udf_context();
+ register_aggregate_csv(&ctx).await?;
+ // Note it is a different column (c12) than above (c11)
+ let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100";
+ let actual = plan_and_collect(&ctx, sql).await.unwrap();
+ let expected = [
+ "+------------------------------------------+",
+ "| AVG(custom_sqrt(aggregate_test_100.c12)) |",
+ "+------------------------------------------+",
+ "| 0.6706002946036462 |",
+ "+------------------------------------------+",
+ ];
+ assert_batches_eq!(&expected, &actual);
Ok(())
}
@@ -212,51 +242,6 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
Ok(())
}
-/// tests the creation, registration and usage of a UDAF
-#[tokio::test]
-async fn simple_udaf() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-
- let batch1 = RecordBatch::try_new(
- Arc::new(schema.clone()),
- vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
- )?;
- let batch2 = RecordBatch::try_new(
- Arc::new(schema.clone()),
- vec![Arc::new(Int32Array::from(vec![4, 5]))],
- )?;
-
- let ctx = SessionContext::new();
-
- let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?;
- ctx.register_table("t", Arc::new(provider))?;
-
- // define a udaf, using a DataFusion's accumulator
- let my_avg = create_udaf(
- "my_avg",
- vec![DataType::Float64],
- Arc::new(DataType::Float64),
- Volatility::Immutable,
- Arc::new(|_| Ok(Box::<AvgAccumulator>::default())),
- Arc::new(vec![DataType::UInt64, DataType::Float64]),
- );
-
- ctx.register_udaf(my_avg);
-
- let result = plan_and_collect(&ctx, "SELECT MY_AVG(a) FROM t").await?;
-
- let expected = [
- "+-------------+",
- "| my_avg(t.a) |",
- "+-------------+",
- "| 3.0 |",
- "+-------------+",
- ];
- assert_batches_eq!(expected, &result);
-
- Ok(())
-}
-
#[tokio::test]
async fn udaf_as_window_func() -> Result<()> {
#[derive(Debug)]
@@ -314,3 +299,86 @@ async fn udaf_as_window_func() -> Result<()> {
assert_eq!(format!("{:?}", dataframe.logical_plan()), expected);
Ok(())
}
+
+#[tokio::test]
+async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
+ let ctx = SessionContext::new();
+ let arr = Int32Array::from(vec![1]);
+ let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
+ ctx.register_batch("t", batch).unwrap();
+
+ let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
+ let myfunc = make_scalar_function(myfunc);
+
+ ctx.register_udf(create_udf(
+ "MY_FUNC",
+ vec![DataType::Int32],
+ Arc::new(DataType::Int32),
+ Volatility::Immutable,
+ myfunc,
+ ));
+
+ // doesn't work as it was registered with non lowercase
+ let err = plan_and_collect(&ctx, "SELECT MY_FUNC(i) FROM t")
+ .await
+ .unwrap_err();
+ assert!(err
+ .to_string()
+ .contains("Error during planning: Invalid function \'my_func\'"));
+
+ // Can call it if you put quotes
+ let result = plan_and_collect(&ctx, "SELECT \"MY_FUNC\"(i) FROM t").await?;
+
+ let expected = [
+ "+--------------+",
+ "| MY_FUNC(t.i) |",
+ "+--------------+",
+ "| 1 |",
+ "+--------------+",
+ ];
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+}
+
+fn create_udf_context() -> SessionContext {
+ let ctx = SessionContext::new();
+ // register a custom UDF
+ ctx.register_udf(create_udf(
+ "custom_sqrt",
+ vec![DataType::Float64],
+ Arc::new(DataType::Float64),
+ Volatility::Immutable,
+ Arc::new(custom_sqrt),
+ ));
+
+ ctx
+}
+
+fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ let arg = &args[0];
+ if let ColumnarValue::Array(v) = arg {
+ let input = as_float64_array(v).expect("cast failed");
+ let array: Float64Array = input.iter().map(|v| v.map(|x| x.sqrt())).collect();
+ Ok(ColumnarValue::Array(Arc::new(array)))
+ } else {
+ unimplemented!()
+ }
+}
+
+async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
+ let testdata = datafusion::test_util::arrow_test_data();
+ let schema = test_util::aggr_test_schema();
+ ctx.register_csv(
+ "aggregate_test_100",
+ &format!("{testdata}/csv/aggregate_test_100.csv"),
+ CsvReadOptions::new().schema(&schema),
+ )
+ .await?;
+ Ok(())
+}
+
+/// Execute SQL and return results as a RecordBatch
+async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+ ctx.sql(sql).await?.collect().await
+}