This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new edbd93aacf Add `ScalarUDFImpl::invoke_with_args` to support passing the return type created for the udf instance (#13290)
edbd93aacf is described below

commit edbd93aacf0b2397cbb1051b1da261fa008c23dd
Author: Joe Isaacs <[email protected]>
AuthorDate: Thu Nov 21 16:08:42 2024 +0000

    Add `ScalarUDFImpl::invoke_with_args` to support passing the return type created for the udf instance (#13290)
    
    * Added support for `ScalarUDFImpl::invoke_with_return_type` where the invoke is passed the return type created for the udf instance
    
    * Do not yet deprecate invoke_batch, add docs to invoke_with_args
    
    * add ticket reference
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/expr/src/lib.rs                         |  2 +-
 datafusion/expr/src/udf.rs                         | 98 +++++++++++++---------
 datafusion/functions/benches/random.rs             |  2 +
 datafusion/functions/src/core/version.rs           |  1 +
 datafusion/functions/src/datetime/to_local_time.rs |  9 +-
 datafusion/functions/src/datetime/to_timestamp.rs  |  4 +-
 datafusion/functions/src/datetime/to_unixtime.rs   |  1 +
 datafusion/functions/src/math/log.rs               | 20 ++---
 datafusion/functions/src/math/power.rs             |  4 +-
 datafusion/functions/src/math/signum.rs            |  2 +
 datafusion/functions/src/regex/regexpcount.rs      | 24 +++---
 datafusion/functions/src/utils.rs                  |  7 +-
 datafusion/physical-expr/src/scalar_function.rs    |  8 +-
 13 files changed, 107 insertions(+), 75 deletions(-)

diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 27b2d71b1f..d8b829f27e 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
 pub use udaf::{
     aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
 };
-pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
+pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
 pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
 pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 1a5d50477b..57b8d9c6b0 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -203,10 +203,7 @@ impl ScalarUDF {
         self.inner.simplify(args, info)
     }
 
-    /// Invoke the function on `args`, returning the appropriate result.
-    ///
-    /// See [`ScalarUDFImpl::invoke`] for more details.
-    #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
+    #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
     pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         #[allow(deprecated)]
         self.inner.invoke(args)
@@ -216,20 +213,27 @@ impl ScalarUDF {
         self.inner.is_nullable(args, schema)
     }
 
-    /// Invoke the function with `args` and number of rows, returning the appropriate result.
-    ///
-    /// See [`ScalarUDFImpl::invoke_batch`] for more details.
+    #[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
     pub fn invoke_batch(
         &self,
         args: &[ColumnarValue],
         number_rows: usize,
     ) -> Result<ColumnarValue> {
+        #[allow(deprecated)]
         self.inner.invoke_batch(args, number_rows)
     }
 
+    /// Invoke the function on `args`, returning the appropriate result.
+    ///
+    /// See [`ScalarUDFImpl::invoke_with_args`] for details.
+    pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        self.inner.invoke_with_args(args)
+    }
+
     /// Invoke the function without `args` but number of rows, returning the appropriate result.
     ///
-    /// See [`ScalarUDFImpl::invoke_no_args`] for more details.
+    /// Note: This method is deprecated and will be removed in future releases.
+    /// User defined functions should implement [`Self::invoke_with_args`] instead.
     #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
     pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
         #[allow(deprecated)]
@@ -324,7 +328,17 @@ where
     }
 }
 
-/// Trait for implementing [`ScalarUDF`].
+pub struct ScalarFunctionArgs<'a> {
+    // The evaluated arguments to the function
+    pub args: &'a [ColumnarValue],
+    // The number of rows in record batch being evaluated
+    pub number_rows: usize,
+    // The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
+    // when creating the physical expression from the logical expression
+    pub return_type: &'a DataType,
+}
+
+/// Trait for implementing user defined scalar functions.
 ///
 /// This trait exposes the full API for implementing user defined functions and
 /// can be used to implement any function.
@@ -332,18 +346,19 @@ where
 /// See [`advanced_udf.rs`] for a full example with complete implementation and
 /// [`ScalarUDF`] for other available options.
 ///
-///
 /// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+///
 /// # Basic Example
 /// ```
 /// # use std::any::Any;
 /// # use std::sync::OnceLock;
 /// # use arrow::datatypes::DataType;
 /// # use datafusion_common::{DataFusionError, plan_err, Result};
-/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
+/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
 /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
 /// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
 ///
+/// /// This struct for a simple UDF that adds one to an int32
 /// #[derive(Debug)]
 /// struct AddOne {
 ///   signature: Signature,
@@ -356,7 +371,7 @@ where
 ///      }
 ///   }
 /// }
-///  
+///
 /// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
 ///
 /// fn get_doc() -> &'static Documentation {
@@ -383,7 +398,9 @@ where
 ///      Ok(DataType::Int32)
 ///    }
 ///    // The actual implementation would add one to the argument
-///    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
+///    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+///         unimplemented!()
+///    }
 ///    fn documentation(&self) -> Option<&Documentation> {
 ///         Some(get_doc())
 ///     }
@@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 
     /// Invoke the function on `args`, returning the appropriate result
     ///
-    /// The function will be invoked passed with the slice of [`ColumnarValue`]
-    /// (either scalar or array).
-    ///
-    /// If the function does not take any arguments, please use [invoke_no_args]
-    /// instead and return [not_impl_err] for this function.
-    ///
-    ///
-    /// # Performance
-    ///
-    /// For the best performance, the implementations of `invoke` should handle
-    /// the common case when one or more of their arguments are constant values
-    /// (aka  [`ColumnarValue::Scalar`]).
-    ///
-    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
-    /// to arrays, which will likely be simpler code, but be slower.
-    ///
-    /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
-    #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
+    /// Note: This method is deprecated and will be removed in future releases.
+    /// User defined functions should implement [`Self::invoke_with_args`] instead.
+    #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
     fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
         not_impl_err!(
             "Function {} does not implement invoke but called",
@@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     /// Invoke the function with `args` and the number of rows,
     /// returning the appropriate result.
     ///
-    /// The function will be invoked with the slice of [`ColumnarValue`]
-    /// (either scalar or array).
-    ///
-    /// # Performance
+    /// Note: See notes on  [`Self::invoke_with_args`]
     ///
-    /// For the best performance, the implementations should handle the common case
-    /// when one or more of their arguments are constant values (aka
-    /// [`ColumnarValue::Scalar`]).
+    /// Note: This method is deprecated and will be removed in future releases.
+    /// User defined functions should implement [`Self::invoke_with_args`] instead.
     ///
-    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
-    /// to arrays, which will likely be simpler code, but be slower.
+    /// See <https://github.com/apache/datafusion/issues/13515> for more details.
     fn invoke_batch(
         &self,
         args: &[ColumnarValue],
@@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
         }
     }
 
+    /// Invoke the function returning the appropriate result.
+    ///
+    /// # Performance
+    ///
+    /// For the best performance, the implementations should handle the common case
+    /// when one or more of their arguments are constant values (aka
+    /// [`ColumnarValue::Scalar`]).
+    ///
+    /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
+    /// to arrays, which will likely be simpler code, but be slower.
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        #[allow(deprecated)]
+        self.invoke_batch(args.args, args.number_rows)
+    }
+
     /// Invoke the function without `args`, instead the number of rows are provided,
     /// returning the appropriate result.
-    #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
+    ///
+    /// Note: This method is deprecated and will be removed in future releases.
+    /// User defined functions should implement [`Self::invoke_with_args`] instead.
+    #[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
     fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
         not_impl_err!(
             "Function {} does not implement invoke_no_args but called",
@@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
         args: &[ColumnarValue],
         number_rows: usize,
     ) -> Result<ColumnarValue> {
+        #[allow(deprecated)]
         self.inner.invoke_batch(args, number_rows)
     }
 
diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs
index 5df5d9c7de..bc20e0ff11 100644
--- a/datafusion/functions/benches/random.rs
+++ b/datafusion/functions/benches/random.rs
@@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     c.bench_function("random_1M_rows_batch_8192", |b| {
         b.iter(|| {
             for _ in 0..iterations {
+                #[allow(deprecated)] // TODO: migrate to invoke_with_args
                 black_box(random_func.invoke_batch(&[], 8192).unwrap());
             }
         })
@@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
     c.bench_function("random_1M_rows_batch_128", |b| {
         b.iter(|| {
             for _ in 0..iterations_128 {
+                #[allow(deprecated)] // TODO: migrate to invoke_with_args
                 black_box(random_func.invoke_batch(&[], 128).unwrap());
             }
         })
diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs
index 36cf07e9e5..eac0aa38f0 100644
--- a/datafusion/functions/src/core/version.rs
+++ b/datafusion/functions/src/core/version.rs
@@ -121,6 +121,7 @@ mod test {
     #[tokio::test]
     async fn test_version_udf() {
         let version_udf = ScalarUDF::from(VersionFunc::new());
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let version = version_udf.invoke_batch(&[], 1).unwrap();
 
         if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs
index fef1eb9a60..5048b8fd47 100644
--- a/datafusion/functions/src/datetime/to_local_time.rs
+++ b/datafusion/functions/src/datetime/to_local_time.rs
@@ -431,7 +431,7 @@ mod tests {
     use arrow::datatypes::{DataType, TimeUnit};
     use chrono::NaiveDateTime;
     use datafusion_common::ScalarValue;
-    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
 
     use super::{adjust_to_local_time, ToLocalTimeFunc};
 
@@ -558,7 +558,11 @@ mod tests {
 
     fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
         let res = ToLocalTimeFunc::new()
-            .invoke_batch(&[ColumnarValue::Scalar(input)], 1)
+            .invoke_with_args(ScalarFunctionArgs {
+                args: &[ColumnarValue::Scalar(input)],
+                number_rows: 1,
+                return_type: &expected.data_type(),
+            })
             .unwrap();
         match res {
             ColumnarValue::Scalar(res) => {
@@ -617,6 +621,7 @@ mod tests {
                 .map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
                 .collect::<TimestampNanosecondArray>();
             let batch_size = input.len();
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let result = ToLocalTimeFunc::new()
                 .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
                 .unwrap();
diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs
index f15fad701c..78a7bf505d 100644
--- a/datafusion/functions/src/datetime/to_timestamp.rs
+++ b/datafusion/functions/src/datetime/to_timestamp.rs
@@ -1008,7 +1008,7 @@ mod tests {
             for array in arrays {
                 let rt = udf.return_type(&[array.data_type()]).unwrap();
                 assert!(matches!(rt, Timestamp(_, Some(_))));
-
+                #[allow(deprecated)] // TODO: migrate to invoke_with_args
                 let res = udf
                     .invoke_batch(&[array.clone()], 1)
                     .expect("that to_timestamp parsed values without error");
@@ -1051,7 +1051,7 @@ mod tests {
             for array in arrays {
                 let rt = udf.return_type(&[array.data_type()]).unwrap();
                 assert!(matches!(rt, Timestamp(_, None)));
-
+                #[allow(deprecated)] // TODO: migrate to invoke_with_args
                 let res = udf
                     .invoke_batch(&[array.clone()], 1)
                     .expect("that to_timestamp parsed values without error");
diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs
index dd90ce6a6c..c291596c25 100644
--- a/datafusion/functions/src/datetime/to_unixtime.rs
+++ b/datafusion/functions/src/datetime/to_unixtime.rs
@@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
             DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
                 .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
                 .cast_to(&DataType::Int64, None),
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             DataType::Utf8 => ToTimestampSecondsFunc::new()
                 .invoke_batch(args, batch_size)?
                 .cast_to(&DataType::Int64, None),
diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index 9110f9f532..14b6dc3e05 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -277,7 +277,7 @@ mod tests {
             ]))), // num
             ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let _ = LogFunc::new().invoke_batch(&args, 4);
     }
 
@@ -286,7 +286,7 @@ mod tests {
         let args = [
             ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new().invoke_batch(&args, 1);
         result.expect_err("expected error");
     }
@@ -296,7 +296,7 @@ mod tests {
         let args = [
             ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 1)
             .expect("failed to initialize function log");
@@ -320,7 +320,7 @@ mod tests {
         let args = [
             ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 1)
             .expect("failed to initialize function log");
@@ -345,7 +345,7 @@ mod tests {
             ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
             ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 1)
             .expect("failed to initialize function log");
@@ -370,7 +370,7 @@ mod tests {
             ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
             ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 1)
             .expect("failed to initialize function log");
@@ -396,7 +396,7 @@ mod tests {
                 10.0, 100.0, 1000.0, 10000.0,
             ]))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function log");
@@ -425,7 +425,7 @@ mod tests {
                 10.0, 100.0, 1000.0, 10000.0,
             ]))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function log");
@@ -455,7 +455,7 @@ mod tests {
                 8.0, 4.0, 81.0, 625.0,
             ]))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function log");
@@ -485,7 +485,7 @@ mod tests {
                 8.0, 4.0, 81.0, 625.0,
             ]))), // num
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = LogFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function log");
diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs
index a24c613f52..acf5f84df9 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -205,7 +205,7 @@ mod tests {
             ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
             ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = PowerFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function power");
@@ -232,7 +232,7 @@ mod tests {
             ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
             ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
         ];
-
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = PowerFunc::new()
             .invoke_batch(&args, 4)
             .expect("failed to initialize function power");
diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs
index 7f21297712..33ff630f30 100644
--- a/datafusion/functions/src/math/signum.rs
+++ b/datafusion/functions/src/math/signum.rs
@@ -167,6 +167,7 @@ mod test {
             f32::NEG_INFINITY,
         ]));
         let batch_size = array.len();
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = SignumFunc::new()
             .invoke_batch(&[ColumnarValue::Array(array)], batch_size)
             .expect("failed to initialize function signum");
@@ -207,6 +208,7 @@ mod test {
             f64::NEG_INFINITY,
         ]));
         let batch_size = array.len();
+        #[allow(deprecated)] // TODO: migrate to invoke_with_args
         let result = SignumFunc::new()
             .invoke_batch(&[ColumnarValue::Array(array)], batch_size)
             .expect("failed to initialize function signum");
diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs
index 8da154430f..819463795b 100644
--- a/datafusion/functions/src/regex/regexpcount.rs
+++ b/datafusion/functions/src/regex/regexpcount.rs
@@ -655,7 +655,7 @@ mod tests {
             let v_sv = ScalarValue::Utf8(Some(v.to_string()));
             let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
             let expected = expected.get(pos).cloned();
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
                 1,
@@ -670,7 +670,7 @@ mod tests {
             // largeutf8
             let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
             let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
                 1,
@@ -685,7 +685,7 @@ mod tests {
             // utf8view
             let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
             let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
                 1,
@@ -711,7 +711,7 @@ mod tests {
             let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
             let start_sv = ScalarValue::Int64(Some(start));
             let expected = expected.get(pos).cloned();
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -730,7 +730,7 @@ mod tests {
             // largeutf8
             let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
             let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -749,7 +749,7 @@ mod tests {
             // utf8view
             let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
             let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -781,7 +781,7 @@ mod tests {
             let start_sv = ScalarValue::Int64(Some(start));
             let flags_sv = ScalarValue::Utf8(Some(flags.to_string()));
             let expected = expected.get(pos).cloned();
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -802,7 +802,7 @@ mod tests {
             let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
             let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
             let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -823,7 +823,7 @@ mod tests {
             let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
             let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
             let flags_sv = ScalarValue::Utf8View(Some(flags.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -905,7 +905,7 @@ mod tests {
             let start_sv = ScalarValue::Int64(Some(start));
             let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string()));
             let expected = expected.get(pos).cloned();
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -926,7 +926,7 @@ mod tests {
             let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
             let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string()));
             let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
@@ -947,7 +947,7 @@ mod tests {
             let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
             let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string()));
             let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string()));
-
+            #[allow(deprecated)] // TODO: migrate to invoke_with_args
             let re = RegexpCountFunc::new().invoke_batch(
                 &[
                     ColumnarValue::Scalar(v_sv),
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index 87180cb77d..8b47350041 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -146,9 +146,10 @@ pub mod test {
             match expected {
                 Ok(expected) => {
                     assert_eq!(return_type.is_ok(), true);
-                    assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE);
+                    let return_type = return_type.unwrap();
+                    assert_eq!(return_type, $EXPECTED_DATA_TYPE);
 
-                    let result = func.invoke_batch($ARGS, cardinality);
+                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type});
                     assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());
 
                     let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array");
@@ -169,7 +170,7 @@ pub mod test {
                     }
                     else {
                         // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
-                        match func.invoke_batch($ARGS, cardinality) {
+                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type.unwrap()}) {
                             Ok(_) => assert!(false, "expected error"),
                             Err(error) => {
                                 
assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 9bf168e8a1..74d0ecdadd 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -43,7 +43,7 @@ use datafusion_common::{internal_err, DFSchema, Result, ScalarValue};
 use datafusion_expr::interval_arithmetic::Interval;
 use datafusion_expr::sort_properties::ExprProperties;
 use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
-use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF};
+use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF};
 
 /// Physical expression of a scalar function
 #[derive(Eq, PartialEq, Hash)]
@@ -141,7 +141,11 @@ impl PhysicalExpr for ScalarFunctionExpr {
             .collect::<Result<Vec<_>>>()?;
 
         // evaluate the function
-        let output = self.fun.invoke_batch(&inputs, batch.num_rows())?;
+        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
+            args: inputs.as_slice(),
+            number_rows: batch.num_rows(),
+            return_type: &self.return_type,
+        })?;
 
         if let ColumnarValue::Array(array) = &output {
             if array.len() != batch.num_rows() {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to