This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b2cbc7809e Add trait based ScalarUDF API (#8578)
b2cbc7809e is described below

commit b2cbc7809ee0656099169307a73aadff23ab1030
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Dec 28 15:07:32 2023 -0500

    Add trait based ScalarUDF API (#8578)
    
    * Introduce new trait based ScalarUDF API
    
    * change name to `Self::new_from_impl`
    
    * Improve documentation, add link to advanced_udf.rs in the user guide
    
    * typo
    
    * Improve docs for aliases
    
    * Apply suggestions from code review
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    
    * improve docs
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
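    
    For orientation, a minimal sketch of the new trait based API follows. The
    `AddOne` struct and `add_one` name are illustrative placeholders modeled
    on the doc example added to `datafusion/expr/src/udf.rs` below; see
    `advanced_udf.rs` for a complete, working implementation.
    
        use std::any::Any;
        use arrow::datatypes::DataType;
        use datafusion_common::Result;
        use datafusion_expr::{
            col, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
        };
    
        // Hypothetical `add_one(Int32) -> Int32` function, for illustration only.
        struct AddOne {
            signature: Signature,
        }
    
        impl ScalarUDFImpl for AddOne {
            fn as_any(&self) -> &dyn Any { self }
            fn name(&self) -> &str { "add_one" }
            fn signature(&self) -> &Signature { &self.signature }
            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
                Ok(DataType::Int32)
            }
            fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
                // a real implementation would add one to its argument here
                unimplemented!()
            }
        }
    
        fn main() {
            // wrap the impl in a `ScalarUDF` (equivalent to `ScalarUDF::new_from_impl`)
            let add_one = ScalarUDF::from(AddOne {
                signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
            });
            // build a logical expression that calls `add_one(a)`
            let _expr = add_one.call(vec![col("a")]);
        }
    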
---
 datafusion-examples/README.md                      |   3 +-
 datafusion-examples/examples/advanced_udf.rs       | 243 +++++++++++++++++++++
 datafusion-examples/examples/simple_udf.rs         |   6 +
 datafusion/expr/src/expr.rs                        |  55 +++--
 datafusion/expr/src/expr_fn.rs                     |  85 ++++++-
 datafusion/expr/src/lib.rs                         |   2 +-
 datafusion/expr/src/udf.rs                         | 169 +++++++++++++-
 datafusion/optimizer/src/analyzer/type_coercion.rs |  64 +++---
 docs/source/library-user-guide/adding-udfs.md      |   9 +-
 9 files changed, 562 insertions(+), 74 deletions(-)

diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 057cdd4752..1296c74ea2 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -59,8 +59,9 @@ cargo run --example csv_sql
 - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3
 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files via HTTP
 - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
+- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
+- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF)
 - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
-- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
 - [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
 
 ## Distributed
diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs
new file mode 100644
index 0000000000..6ebf88a0b6
--- /dev/null
+++ b/datafusion-examples/examples/advanced_udf.rs
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{
+    arrow::{
+        array::{ArrayRef, Float32Array, Float64Array},
+        datatypes::DataType,
+        record_batch::RecordBatch,
+    },
+    logical_expr::Volatility,
+};
+use std::any::Any;
+
+use arrow::array::{new_null_array, Array, AsArray};
+use arrow::compute;
+use arrow::datatypes::Float64Type;
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{internal_err, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
+use std::sync::Arc;
+
+/// This example shows how to use the full ScalarUDFImpl API to implement a user
+/// defined function. As in the `simple_udf.rs` example, this struct implements
+/// a function that takes two arguments and returns the first argument raised to
+/// the power of the second argument `a^b`.
+///
+/// To do so, we must implement the `ScalarUDFImpl` trait.
+struct PowUdf {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl PowUdf {
+    /// Create a new instance of the `PowUdf` struct
+    fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                // this function will always take two arguments of type f64
+                vec![DataType::Float64, DataType::Float64],
+                // this function is deterministic and will always return the same
+                // result for the same input
+                Volatility::Immutable,
+            ),
+            // we will also add an alias of "my_pow"
+            aliases: vec!["my_pow".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for PowUdf {
+    /// We implement as_any so that we can downcast the ScalarUDFImpl trait object
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Return the name of this function
+    fn name(&self) -> &str {
+        "pow"
+    }
+
+    /// Return the "signature" of this function -- namely what types of arguments it will take
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// What is the type of value that will be returned by this function? In
+    /// this case it will always be a constant value, but it could also be a
+    /// function of the input types.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    /// This is the function that actually calculates the results.
+    ///
+    /// This is the same way that functions built into DataFusion are invoked,
+    /// which permits important special cases when one or both of the arguments
+    /// are single values (constants). For example `pow(a, 2)`
+    ///
+    /// However, it also means the implementation is more complex than when
+    /// using `create_udf`.
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // DataFusion has arranged for the correct inputs to be passed to this
+        // function, but we check again to make sure
+        assert_eq!(args.len(), 2);
+        let (base, exp) = (&args[0], &args[1]);
+        assert_eq!(base.data_type(), DataType::Float64);
+        assert_eq!(exp.data_type(), DataType::Float64);
+
+        match (base, exp) {
+            // For demonstration purposes we also implement the scalar / scalar
+            // case here, but it is not typically required for high performance.
+            //
+            // For performance it is most important to optimize cases where at
+            // least one argument is an array. If all arguments are constants,
+            // the DataFusion expression simplification logic will often invoke
+            // this path once during planning, and simply use the result during
+            // execution.
+            (
+                ColumnarValue::Scalar(ScalarValue::Float64(base)),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                // compute the output. Note DataFusion treats `None` as NULL.
+                let res = match (base, exp) {
+                    (Some(base), Some(exp)) => Some(base.powf(*exp)),
+                    // one or both arguments were NULL
+                    _ => None,
+                };
+                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
+            }
+            // special case if the exponent is a constant
+            (
+                ColumnarValue::Array(base_array),
+                ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+            ) => {
+                let result_array = match exp {
+                    // a ^ null = null
+                    None => new_null_array(base_array.data_type(), base_array.len()),
+                    // a ^ exp
+                    Some(exp) => {
+                        // DataFusion has ensured both arguments are Float64:
+                        let base_array = base_array.as_primitive::<Float64Type>();
+                        // calculate the result for every row. The `unary`
+                        // kernel creates very fast "vectorized" code and
+                        // handles things like null values for us.
+                        let res: Float64Array =
+                            compute::unary(base_array, |base| base.powf(*exp));
+                        Arc::new(res)
+                    }
+                };
+                Ok(ColumnarValue::Array(result_array))
+            }
+
+            // special case if the base is a constant (note this code is quite
+            // similar to the previous case, so we omit comments)
+            (
+                ColumnarValue::Scalar(ScalarValue::Float64(base)),
+                ColumnarValue::Array(exp_array),
+            ) => {
+                let res = match base {
+                    None => new_null_array(exp_array.data_type(), exp_array.len()),
+                    Some(base) => {
+                        let exp_array = exp_array.as_primitive::<Float64Type>();
+                        let res: Float64Array =
+                            compute::unary(exp_array, |exp| base.powf(exp));
+                        Arc::new(res)
+                    }
+                };
+                Ok(ColumnarValue::Array(res))
+            }
+            // Both arguments are arrays so we have to perform the calculation for every row
+            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
+                let res: Float64Array = compute::binary(
+                    base_array.as_primitive::<Float64Type>(),
+                    exp_array.as_primitive::<Float64Type>(),
+                    |base, exp| base.powf(exp),
+                )?;
+                Ok(ColumnarValue::Array(Arc::new(res)))
+            }
+            // if the types were not float, it is a bug in DataFusion
+            _ => {
+                use datafusion_common::DataFusionError;
+                internal_err!("Invalid argument types to pow function")
+            }
+        }
+    }
+
+    /// We will also add an alias of "my_pow"
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// In this example we register `PowUdf` as a user defined function
+/// and invoke it via the DataFrame API and SQL
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    // create the UDF
+    let pow = ScalarUDF::from(PowUdf::new());
+
+    // register the UDF with the context so it can be invoked by name and from SQL
+    ctx.register_udf(pow.clone());
+
+    // get a DataFrame from the context for scanning the "t" table
+    let df = ctx.table("t").await?;
+
+    // Call pow(a, 10) using the DataFrame API
+    let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?;
+
+    // note that the second argument is passed as an i32, not f64. DataFusion
+    // automatically coerces the types to match the UDF's defined signature.
+
+    // print the results
+    df.show().await?;
+
+    // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using SQL
+    let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?;
+    sql_df.show().await?;
+
+    Ok(())
+}
+
+/// create local execution context with an in-memory table:
+///
+/// ```text
+/// +-----+-----+
+/// | a   | b   |
+/// +-----+-----+
+/// | 2.1 | 1.0 |
+/// | 3.1 | 2.0 |
+/// | 4.1 | 3.0 |
+/// | 5.1 | 4.0 |
+/// +-----+-----+
+/// ```
+fn create_context() -> Result<SessionContext> {
+    // define data.
+    let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
+    let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
+    let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
+
+    // declare a new context. In Spark API, this corresponds to a new SparkSession
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
+    ctx.register_batch("t", batch)?;
+    Ok(ctx)
+}
diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs
index 5919917865..39e1e13ce3 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -140,5 +140,11 @@ async fn main() -> Result<()> {
     // print the results
     df.show().await?;
 
+    // Given that `pow` is registered in the context, we can also use it in SQL:
+    let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?;
+
+    // print the results
+    sql_df.show().await?;
+
     Ok(())
 }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b46e9ec8f6..0ec19bcadb 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1724,13 +1724,13 @@ mod test {
     use crate::expr::Cast;
     use crate::expr_fn::col;
     use crate::{
-        case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
-        ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
-        Volatility,
+        case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ScalarFunctionDefinition,
+        ScalarUDF, ScalarUDFImpl, Signature, Volatility,
     };
     use arrow::datatypes::DataType;
     use datafusion_common::Column;
     use datafusion_common::{Result, ScalarValue};
+    use std::any::Any;
     use std::sync::Arc;
 
     #[test]
@@ -1848,24 +1848,41 @@ mod test {
         );
 
         // UDF
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation =
-            Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
-        let udf = Arc::new(ScalarUDF::new(
-            "TestScalarUDF",
-            &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
-            &return_type,
-            &fun,
-        ));
+        struct TestScalarUDF {
+            signature: Signature,
+        }
+        impl ScalarUDFImpl for TestScalarUDF {
+            fn as_any(&self) -> &dyn Any {
+                self
+            }
+            fn name(&self) -> &str {
+                "TestScalarUDF"
+            }
+
+            fn signature(&self) -> &Signature {
+                &self.signature
+            }
+
+            fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+                Ok(DataType::Utf8)
+            }
+
+            fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+                Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+            }
+        }
+        let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
+        }));
         assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
 
-        let udf = Arc::new(ScalarUDF::new(
-            "TestScalarUDF",
-            &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
-            &return_type,
-            &fun,
-        ));
+        let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+            signature: Signature::uniform(
+                1,
+                vec![DataType::Float32],
+                Volatility::Volatile,
+            ),
+        }));
         assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
 
         // Unresolved function
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index cedf1d8451..eed41d97cc 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -22,15 +22,16 @@ use crate::expr::{
     Placeholder, ScalarFunction, TryCast,
 };
 use crate::function::PartitionEvaluatorFactory;
-use crate::WindowUDF;
 use crate::{
     aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
     logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
     BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
     ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
 };
+use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF};
 use arrow::datatypes::DataType;
 use datafusion_common::{Column, Result};
+use std::any::Any;
 use std::ops::Not;
 use std::sync::Arc;
 
@@ -944,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
     CaseBuilder::new(None, vec![when], vec![then], None)
 }
 
-/// Creates a new UDF with a specific signature and specific return type.
-/// This is a helper function to create a new UDF.
-/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
-/// * the UDF has a fixed return type
-/// * the UDF has a fixed signature (e.g. [f64, f64])
+/// Convenience method to create a new user defined scalar function (UDF) with a
+/// specific signature and specific return type.
+///
+/// Note this function does not expose all available features of [`ScalarUDF`],
+/// such as
+///
+/// * computing return types based on input types
+/// * multiple [`Signature`]s
+/// * aliases
+///
+/// See [`ScalarUDF`] for details and examples on how to use the full
+/// functionality.
 pub fn create_udf(
     name: &str,
     input_types: Vec<DataType>,
@@ -956,13 +964,66 @@ pub fn create_udf(
     volatility: Volatility,
     fun: ScalarFunctionImplementation,
 ) -> ScalarUDF {
-    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
-    ScalarUDF::new(
+    let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone());
+    ScalarUDF::from(SimpleScalarUDF::new(
         name,
-        &Signature::exact(input_types, volatility),
-        &return_type,
-        &fun,
-    )
+        input_types,
+        return_type,
+        volatility,
+        fun,
+    ))
+}
+
+/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
+/// return type.
+pub struct SimpleScalarUDF {
+    name: String,
+    signature: Signature,
+    return_type: DataType,
+    fun: ScalarFunctionImplementation,
+}
+
+impl SimpleScalarUDF {
+    /// Create a new `SimpleScalarUDF` from a name, input types, return type and
+    /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
+    pub fn new(
+        name: impl Into<String>,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        volatility: Volatility,
+        fun: ScalarFunctionImplementation,
+    ) -> Self {
+        let name = name.into();
+        let signature = Signature::exact(input_types, volatility);
+        Self {
+            name,
+            signature,
+            return_type,
+            fun,
+        }
+    }
+}
+
+impl ScalarUDFImpl for SimpleScalarUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        (self.fun)(args)
+    }
 }
 
 /// Creates a new UDAF with a specific signature, state type and return type.
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 48532e13dc..bf8e9e2954 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -80,7 +80,7 @@ pub use signature::{
 };
 pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
 pub use udaf::AggregateUDF;
-pub use udf::ScalarUDF;
+pub use udf::{ScalarUDF, ScalarUDFImpl};
 pub use udwf::WindowUDF;
 pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
 pub use window_function::{BuiltInWindowFunction, WindowFunction};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3a18ca2d25..2ec80a4a9e 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,9 +17,12 @@
 
 //! [`ScalarUDF`]: Scalar User Defined Functions
 
-use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
+use crate::{
+    ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature,
+};
 use arrow::datatypes::DataType;
 use datafusion_common::Result;
+use std::any::Any;
 use std::fmt;
 use std::fmt::Debug;
 use std::fmt::Formatter;
@@ -27,11 +30,19 @@ use std::sync::Arc;
 
 /// Logical representation of a Scalar User Defined Function.
 ///
-/// A scalar function produces a single row output for each row of input.
+/// A scalar function produces a single row output for each row of input. This
+/// struct contains the information DataFusion needs to plan and invoke
+/// functions you supply, such as the name, type signature, return type, and
+/// actual implementation.
 ///
-/// This struct contains the information DataFusion needs to plan and invoke
-/// functions such name, type signature, return type, and actual implementation.
 ///
+/// 1. For simple (less performant) use cases, use [`create_udf`] and [`simple_udf.rs`].
+///
+/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
+///
+/// [`create_udf`]: crate::expr_fn::create_udf
+/// [`simple_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
 #[derive(Clone)]
 pub struct ScalarUDF {
     /// The name of the function
@@ -79,7 +90,11 @@ impl std::hash::Hash for ScalarUDF {
 }
 
 impl ScalarUDF {
-    /// Create a new ScalarUDF
+    /// Create a new ScalarUDF from low level details.
+    ///
+    /// See [`ScalarUDFImpl`] for a more convenient way to create a
+    /// `ScalarUDF` using trait objects
+    #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")]
     pub fn new(
         name: &str,
         signature: &Signature,
@@ -95,6 +110,34 @@ impl ScalarUDF {
         }
     }
 
+    /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
+    ///
+    /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
+    pub fn new_from_impl<F>(fun: F) -> ScalarUDF
+    where
+        F: ScalarUDFImpl + Send + Sync + 'static,
+    {
+        // TODO change the internal implementation to use the trait object
+        let arc_fun = Arc::new(fun);
+        let captured_self = arc_fun.clone();
+        let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
+            let return_type = captured_self.return_type(arg_types)?;
+            Ok(Arc::new(return_type))
+        });
+
+        let captured_self = arc_fun.clone();
+        let func: ScalarFunctionImplementation =
+            Arc::new(move |args| captured_self.invoke(args));
+
+        Self {
+            name: arc_fun.name().to_string(),
+            signature: arc_fun.signature().clone(),
+            return_type: return_type.clone(),
+            fun: func,
+            aliases: arc_fun.aliases().to_vec(),
+        }
+    }
+
     /// Adds additional names that can be used to invoke this function, in addition to `name`
     pub fn with_aliases(
         mut self,
@@ -105,7 +148,9 @@ impl ScalarUDF {
         self
     }
 
-    /// creates a logical expression with a call of the UDF
+    /// Returns an [`Expr`] logical expression to call this UDF with the specified
+    /// arguments.
+    ///
     /// This utility allows using the UDF without requiring access to the registry.
     pub fn call(&self, args: Vec<Expr>) -> Expr {
         Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
@@ -124,22 +169,126 @@ impl ScalarUDF {
         &self.aliases
     }
 
-    /// Returns this function's signature (what input types are accepted)
+    /// Returns this function's [`Signature`] (what input types are accepted)
     pub fn signature(&self) -> &Signature {
         &self.signature
     }
 
-    /// Return the type of the function given its input types
+    /// The datatype this function returns given the input argument types
     pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
         // Old API returns an Arc of the datatype for some reason
         let res = (self.return_type)(args)?;
         Ok(res.as_ref().clone())
     }
 
-    /// Return the actual implementation
+    /// Return an [`Arc`] to the function implementation
     pub fn fun(&self) -> ScalarFunctionImplementation {
         self.fun.clone()
     }
+}
 
-    // TODO maybe add an invoke() method that runs the actual function?
+impl<F> From<F> for ScalarUDF
+where
+    F: ScalarUDFImpl + Send + Sync + 'static,
+{
+    fn from(fun: F) -> Self {
+        Self::new_from_impl(fun)
+    }
+}
+
+/// Trait for implementing [`ScalarUDF`].
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udf.rs`] for a full example with complete implementation and
+/// [`ScalarUDF`] for other available options.
+///
+///
+/// [`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
+/// struct AddOne {
+///   signature: Signature
+/// };
+///
+/// impl AddOne {
+///   fn new() -> Self {
+///     Self {
+///       signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable)
+///      }
+///   }
+/// }
+///
+/// /// Implement the ScalarUDFImpl trait for AddOne
+/// impl ScalarUDFImpl for AddOne {
+///    fn as_any(&self) -> &dyn Any { self }
+///    fn name(&self) -> &str { "add_one" }
+///    fn signature(&self) -> &Signature { &self.signature }
+///    fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+///      if !matches!(args.get(0), Some(&DataType::Int32)) {
+///        return plan_err!("add_one only accepts Int32 arguments");
+///      }
+///      Ok(DataType::Int32)
+///    }
+///    // The actual implementation would add one to the argument
+///    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
+/// }
+///
+/// // Create a new ScalarUDF from the implementation
+/// let add_one = ScalarUDF::from(AddOne::new());
+///
+/// // Call the function `add_one(col)`
+/// let expr = add_one.call(vec![col("a")]);
+/// ```
+pub trait ScalarUDFImpl {
+    /// Returns this object as an [`Any`] trait object
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns this function's name
+    fn name(&self) -> &str;
+
+    /// Returns the function's [`Signature`] for information about what input
+    /// types are accepted and the function's Volatility.
+    fn signature(&self) -> &Signature;
+
+    /// What [`DataType`] will be returned by this function, given the types of
+    /// the arguments
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+    /// Invoke the function on `args`, returning the appropriate result
+    ///
+    /// The function will be invoked with the slice of [`ColumnarValue`]
+    /// (either scalar or array).
+    ///
+    /// # Zero Argument Functions
+    /// If the function has zero parameters (e.g. `now()`) it will be passed a
+    /// single element slice which is a null array to indicate the batch's row
+    /// count (so the function can know the resulting array size).
+    ///
+    /// # Performance
+    ///
+    /// For the best performance, the implementations of `invoke` should handle
+    /// the common case when one or more of their arguments are constant values
+    /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`]
+    /// and treating all arguments as arrays will work, but will be slower.
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
+
+    /// Returns any aliases (alternate names) for this function.
+    ///
+    /// Aliases can be used to invoke the same function using different names.
+    /// For example in some databases `now()` and `current_timestamp()` are
+    /// aliases for the same function. This behavior can be obtained by
+    /// returning `current_timestamp` as an alias for the `now` function.
+    ///
+    /// Note: `aliases` should only include names other than [`Self::name`].
+    /// Defaults to `[]` (no aliases)
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
 }
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index c5e1180b9f..b6298f5b55 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -738,7 +738,8 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
 
 #[cfg(test)]
 mod test {
-    use std::sync::Arc;
+    use std::any::Any;
+    use std::sync::{Arc, OnceLock};
 
     use arrow::array::{FixedSizeListArray, Int32Array};
     use arrow::datatypes::{DataType, TimeUnit};
@@ -750,13 +751,13 @@ mod test {
     use datafusion_expr::{
         cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction,
         AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case,
-        ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction, Subquery,
+        ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction,
+        Subquery,
     };
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
-        Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
-        Signature, Volatility,
+        Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility,
     };
     use datafusion_physical_expr::expressions::AvgAccumulator;
 
@@ -808,22 +809,36 @@ mod test {
         assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
     }
 
+    static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();
+
+    struct TestScalarUDF {}
+    impl ScalarUDFImpl for TestScalarUDF {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "TestScalarUDF"
+        }
+        fn signature(&self) -> &Signature {
+            TEST_SIGNATURE.get_or_init(|| {
+                Signature::uniform(1, vec![DataType::Float32], Volatility::Stable)
+            })
+        }
+        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Utf8)
+        }
+
+        fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+        }
+    }
+
     #[test]
     fn scalar_udf() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation =
-            Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
-        let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
-            Arc::new(ScalarUDF::new(
-                "TestScalarUDF",
-                &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
-                &return_type,
-                &fun,
-            )),
-            vec![lit(123_i32)],
-        ));
+
+        let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
         let expected =
             "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n  EmptyRelation";
@@ -833,24 +848,13 @@ mod test {
     #[test]
     fn scalar_udf_invalid_input() -> Result<()> {
         let empty = empty();
-        let return_type: ReturnTypeFunction =
-            Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
-        let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
-        let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
-            Arc::new(ScalarUDF::new(
-                "TestScalarUDF",
-                &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable),
-                &return_type,
-                &fun,
-            )),
-            vec![lit("Apple")],
-        ));
+        let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]);
         let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?);
         let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "")
             .err()
             .unwrap();
         assert_eq!(
-    "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.",
+    "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float32]) failed.",
     err.strip_backtrace()
     );
         Ok(())
diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md
index 11cf52eb3f..c51e4de323 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -76,7 +76,9 @@ The challenge however is that DataFusion doesn't know about this function. We ne
 
 ### Registering a Scalar UDF
 
-To register a Scalar UDF, you need to wrap the function implementation in a `ScalarUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udf` and `make_scalar_function` helper functions to make this easier.
+To register a Scalar UDF, you need to wrap the function implementation in a [`ScalarUDF`] struct and then register it with the `SessionContext`.
+DataFusion provides [`create_udf`] and other helper functions to make this easier.
+There is a lower level API, documented in [`advanced_udf.rs`], that offers more functionality but is more complex to use.
 
 ```rust
 use datafusion::logical_expr::{Volatility, create_udf};
@@ -93,6 +95,11 @@ let udf = create_udf(
 );
 ```
 
+[`scalarudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
+[`create_udf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
+[`make_scalar_function`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`advanced_udf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+
 A few things to note:
 
- The first argument is the name of the function. This is the name that will be used in SQL queries.

