This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d6ab343791 Add helper function for processing scalar function input (#8962)
d6ab343791 is described below
commit d6ab343791bfc37030bc02e70b5ff9a8e773e72e
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Jan 24 13:53:17 2024 -0800
Add helper function for processing scalar function input (#8962)
* Add helper function for scalar function
* Update datafusion/physical-expr/src/functions.rs
Co-authored-by: Andrew Lamb <[email protected]>
* Fix
* Fix
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-examples/examples/simple_udf.rs | 18 ++-------
.../src/simplify_expressions/expr_simplifier.rs | 27 ++-----------
datafusion/physical-expr/src/functions.rs | 46 ++++++++++++++++++++++
docs/source/library-user-guide/adding-udfs.md | 11 +++---
4 files changed, 59 insertions(+), 43 deletions(-)
diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs
index 491fac272c..dda6ba62e0 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -28,6 +28,7 @@ use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_common::cast::as_float64_array;
use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr::functions::columnar_values_to_array;
use std::sync::Arc;
/// create local execution context with an in-memory table:
@@ -70,22 +71,11 @@ async fn main() -> Result<()> {
// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);
- // Try to obtain row number
- let len = args
- .iter()
- .fold(Option::<usize>::None, |acc, arg| match arg {
- ColumnarValue::Scalar(_) => acc,
- ColumnarValue::Array(a) => Some(a.len()),
- });
-
- let inferred_length = len.unwrap_or(1);
-
- let arg0 = args[0].clone().into_array(inferred_length)?;
- let arg1 = args[1].clone().into_array(inferred_length)?;
+ let args = columnar_values_to_array(args)?;
// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
- let base = as_float64_array(&arg0).expect("cast failed");
- let exponent = as_float64_array(&arg1).expect("cast failed");
+ let base = as_float64_array(&args[0]).expect("cast failed");
+ let exponent = as_float64_array(&args[1]).expect("cast failed");
// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 561fe1d12d..1c12289491 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1376,6 +1376,7 @@ mod tests {
use datafusion_physical_expr::execution_props::ExecutionProps;
use chrono::{DateTime, TimeZone, Utc};
+ use datafusion_physical_expr::functions::columnar_values_to_array;
// ------------------------------
// --- ExprSimplifier tests -----
@@ -1489,30 +1490,10 @@ mod tests {
let return_type = Arc::new(DataType::Int32);
let fun = Arc::new(|args: &[ColumnarValue]| {
- let len = args
- .iter()
- .fold(Option::<usize>::None, |acc, arg| match arg {
- ColumnarValue::Scalar(_) => acc,
- ColumnarValue::Array(a) => Some(a.len()),
- });
-
- let inferred_length = len.unwrap_or(1);
-
- let arg0 = match &args[0] {
- ColumnarValue::Array(array) => array.clone(),
- ColumnarValue::Scalar(scalar) => {
- scalar.to_array_of_size(inferred_length).unwrap()
- }
- };
- let arg1 = match &args[1] {
- ColumnarValue::Array(array) => array.clone(),
- ColumnarValue::Scalar(scalar) => {
- scalar.to_array_of_size(inferred_length).unwrap()
- }
- };
+ let args = columnar_values_to_array(args)?;
- let arg0 = as_int32_array(&arg0)?;
- let arg1 = as_int32_array(&arg1)?;
+ let arg0 = as_int32_array(&args[0])?;
+ let arg1 = as_int32_array(&args[1])?;
// 2. perform the computation
let array = arg0
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index ac959dec6e..2bfdf49912 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -42,6 +42,7 @@ use arrow::{
compute::kernels::length::{bit_length, length},
datatypes::{DataType, Int32Type, Int64Type, Schema},
};
+use arrow_array::Array;
use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
pub use datafusion_expr::FuncMonotonicity;
use datafusion_expr::{
@@ -191,6 +192,51 @@ pub(crate) enum Hint {
AcceptsSingular,
}
+/// A helper function used to infer the length of arguments of Scalar functions and convert
+/// [`ColumnarValue`]s to [`ArrayRef`]s with the inferred length. Note that this function
+/// only works for functions that accept either that all arguments are scalars or all arguments
+/// are arrays with same length. Otherwise, it will return an error.
+pub fn columnar_values_to_array(args: &[ColumnarValue]) ->
Result<Vec<ArrayRef>> {
+ if args.is_empty() {
+ return Ok(vec![]);
+ }
+
+ let len = args
+ .iter()
+ .fold(Option::<usize>::None, |acc, arg| match arg {
+ ColumnarValue::Scalar(_) if acc.is_none() => Some(1),
+ ColumnarValue::Scalar(_) => {
+ if let Some(1) = acc {
+ acc
+ } else {
+ None
+ }
+ }
+ ColumnarValue::Array(a) => {
+ if let Some(l) = acc {
+ if l == a.len() {
+ acc
+ } else {
+ None
+ }
+ } else {
+ Some(a.len())
+ }
+ }
+ });
+
+ let inferred_length = len.ok_or(DataFusionError::Internal(
+ "Arguments has mixed length".to_string(),
+ ))?;
+
+ let args = args
+ .iter()
+ .map(|arg| arg.clone().into_array(inferred_length))
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(args)
+}
+
/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar.
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index 64dc25411d..1824b23f9f 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -41,12 +41,12 @@ use std::sync::Arc;
use datafusion::arrow::array::{ArrayRef, Int64Array};
use datafusion::common::Result;
-
use datafusion::common::cast::as_int64_array;
+use datafusion::physical_plan::functions::columnar_values_to_array;
-pub fn add_one(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn add_one(args: &[ColumnarValue]) -> Result<ArrayRef> {
// Error handling omitted for brevity
-
+ let args = columnar_values_to_array(args)?;
let i64s = as_int64_array(&args[0])?;
let new_array = i64s
@@ -82,7 +82,6 @@ There is a lower level API with more functionality but is more complex, that is
```rust
use datafusion::logical_expr::{Volatility, create_udf};
-use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::arrow::datatypes::DataType;
use std::sync::Arc;
@@ -91,13 +90,13 @@ let udf = create_udf(
vec![DataType::Int64],
Arc::new(DataType::Int64),
Volatility::Immutable,
- make_scalar_function(add_one),
+ Arc::new(add_one),
);
```
[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
-[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`process_scalar_func_inputs`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.process_scalar_func_inputs.html
[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
A few things to note: