This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new d6ab343791 Add helper function for processing scalar function input 
(#8962)
d6ab343791 is described below

commit d6ab343791bfc37030bc02e70b5ff9a8e773e72e
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Jan 24 13:53:17 2024 -0800

    Add helper function for processing scalar function input (#8962)
    
    * Add helper function for scalar function
    
    * Update datafusion/physical-expr/src/functions.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * Fix
    
    * Fix
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion-examples/examples/simple_udf.rs         | 18 ++-------
 .../src/simplify_expressions/expr_simplifier.rs    | 27 ++-----------
 datafusion/physical-expr/src/functions.rs          | 46 ++++++++++++++++++++++
 docs/source/library-user-guide/adding-udfs.md      | 11 +++---
 4 files changed, 59 insertions(+), 43 deletions(-)

diff --git a/datafusion-examples/examples/simple_udf.rs 
b/datafusion-examples/examples/simple_udf.rs
index 491fac272c..dda6ba62e0 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -28,6 +28,7 @@ use datafusion::error::Result;
 use datafusion::prelude::*;
 use datafusion_common::cast::as_float64_array;
 use datafusion_expr::ColumnarValue;
+use datafusion_physical_expr::functions::columnar_values_to_array;
 use std::sync::Arc;
 
 /// create local execution context with an in-memory table:
@@ -70,22 +71,11 @@ async fn main() -> Result<()> {
         // this is guaranteed by DataFusion based on the function's signature.
         assert_eq!(args.len(), 2);
 
-        // Try to obtain row number
-        let len = args
-            .iter()
-            .fold(Option::<usize>::None, |acc, arg| match arg {
-                ColumnarValue::Scalar(_) => acc,
-                ColumnarValue::Array(a) => Some(a.len()),
-            });
-
-        let inferred_length = len.unwrap_or(1);
-
-        let arg0 = args[0].clone().into_array(inferred_length)?;
-        let arg1 = args[1].clone().into_array(inferred_length)?;
+        let args = columnar_values_to_array(args)?;
 
         // 1. cast both arguments to f64. These casts MUST be aligned with the 
signature or this function panics!
-        let base = as_float64_array(&arg0).expect("cast failed");
-        let exponent = as_float64_array(&arg1).expect("cast failed");
+        let base = as_float64_array(&args[0]).expect("cast failed");
+        let exponent = as_float64_array(&args[1]).expect("cast failed");
 
         // this is guaranteed by DataFusion. We place it just to make it 
obvious.
         assert_eq!(exponent.len(), base.len());
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 561fe1d12d..1c12289491 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1376,6 +1376,7 @@ mod tests {
     use datafusion_physical_expr::execution_props::ExecutionProps;
 
     use chrono::{DateTime, TimeZone, Utc};
+    use datafusion_physical_expr::functions::columnar_values_to_array;
 
     // ------------------------------
     // --- ExprSimplifier tests -----
@@ -1489,30 +1490,10 @@ mod tests {
         let return_type = Arc::new(DataType::Int32);
 
         let fun = Arc::new(|args: &[ColumnarValue]| {
-            let len = args
-                .iter()
-                .fold(Option::<usize>::None, |acc, arg| match arg {
-                    ColumnarValue::Scalar(_) => acc,
-                    ColumnarValue::Array(a) => Some(a.len()),
-                });
-
-            let inferred_length = len.unwrap_or(1);
-
-            let arg0 = match &args[0] {
-                ColumnarValue::Array(array) => array.clone(),
-                ColumnarValue::Scalar(scalar) => {
-                    scalar.to_array_of_size(inferred_length).unwrap()
-                }
-            };
-            let arg1 = match &args[1] {
-                ColumnarValue::Array(array) => array.clone(),
-                ColumnarValue::Scalar(scalar) => {
-                    scalar.to_array_of_size(inferred_length).unwrap()
-                }
-            };
+            let args = columnar_values_to_array(args)?;
 
-            let arg0 = as_int32_array(&arg0)?;
-            let arg1 = as_int32_array(&arg1)?;
+            let arg0 = as_int32_array(&args[0])?;
+            let arg1 = as_int32_array(&args[1])?;
 
             // 2. perform the computation
             let array = arg0
diff --git a/datafusion/physical-expr/src/functions.rs 
b/datafusion/physical-expr/src/functions.rs
index ac959dec6e..2bfdf49912 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -42,6 +42,7 @@ use arrow::{
     compute::kernels::length::{bit_length, length},
     datatypes::{DataType, Int32Type, Int64Type, Schema},
 };
+use arrow_array::Array;
 use datafusion_common::{internal_err, DataFusionError, Result, ScalarValue};
 pub use datafusion_expr::FuncMonotonicity;
 use datafusion_expr::{
@@ -191,6 +192,51 @@ pub(crate) enum Hint {
     AcceptsSingular,
 }
 
+/// Infers the length shared by a scalar function's arguments and converts
+/// the [`ColumnarValue`]s to [`ArrayRef`]s of that length. This only works
+/// when the arguments are either all scalars or all arrays of the same
+/// length; otherwise an error is returned.
+pub fn columnar_values_to_array(args: &[ColumnarValue]) -> 
Result<Vec<ArrayRef>> {
+    if args.is_empty() {
+        return Ok(vec![]);
+    }
+
+    let len = args
+        .iter()
+        .fold(Option::<usize>::None, |acc, arg| match arg {
+            ColumnarValue::Scalar(_) if acc.is_none() => Some(1),
+            ColumnarValue::Scalar(_) => {
+                if let Some(1) = acc {
+                    acc
+                } else {
+                    None
+                }
+            }
+            ColumnarValue::Array(a) => {
+                if let Some(l) = acc {
+                    if l == a.len() {
+                        acc
+                    } else {
+                        None
+                    }
+                } else {
+                    Some(a.len())
+                }
+            }
+        });
+
+    let inferred_length = len.ok_or(DataFusionError::Internal(
+        "Arguments has mixed length".to_string(),
+    ))?;
+
+    let args = args
+        .iter()
+        .map(|arg| arg.clone().into_array(inferred_length))
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(args)
+}
+
 /// Decorates a function to handle [`ScalarValue`]s by converting them to 
arrays before calling the function
 /// and vice-versa after evaluation.
 /// Note that this function makes a scalar function with no arguments or all 
scalar inputs return a scalar.
diff --git a/docs/source/library-user-guide/adding-udfs.md 
b/docs/source/library-user-guide/adding-udfs.md
index 64dc25411d..1824b23f9f 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -41,12 +41,12 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{ArrayRef, Int64Array};
 use datafusion::common::Result;
-
 use datafusion::common::cast::as_int64_array;
+use datafusion::physical_plan::functions::columnar_values_to_array;
 
-pub fn add_one(args: &[ArrayRef]) -> Result<ArrayRef> {
+pub fn add_one(args: &[ColumnarValue]) -> Result<ArrayRef> {
     // Error handling omitted for brevity
-
+    let args = columnar_values_to_array(args)?;
     let i64s = as_int64_array(&args[0])?;
 
     let new_array = i64s
@@ -82,7 +82,6 @@ There is a lower level API with more functionality but is 
more complex, that is
 
 ```rust
 use datafusion::logical_expr::{Volatility, create_udf};
-use datafusion::physical_plan::functions::make_scalar_function;
 use datafusion::arrow::datatypes::DataType;
 use std::sync::Arc;
 
@@ -91,13 +90,13 @@ let udf = create_udf(
     vec![DataType::Int64],
     Arc::new(DataType::Int64),
     Volatility::Immutable,
-    make_scalar_function(add_one),
+    Arc::new(add_one),
 );
 ```
 
 [`scalarudf`]: 
https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
 [`create_udf`]: 
https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
-[`make_scalar_function`]: 
https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`columnar_values_to_array`]: 
https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.columnar_values_to_array.html
 [`advanced_udf.rs`]: 
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
 
 A few things to note:

Reply via email to