This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b2cbc7809e Add trait based ScalarUDF API (#8578)
b2cbc7809e is described below
commit b2cbc7809ee0656099169307a73aadff23ab1030
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu Dec 28 15:07:32 2023 -0500
Add trait based ScalarUDF API (#8578)
* Introduce new trait based ScalarUDF API
* change name to `Self::new_from_impl`
* Improve documentation, add link to advanced_udf.rs in the user guide
* typo
* Improve docs for aliases
* Apply suggestions from code review
Co-authored-by: Liang-Chi Hsieh <[email protected]>
* improve docs
---------
Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
datafusion-examples/README.md | 3 +-
datafusion-examples/examples/advanced_udf.rs | 243 +++++++++++++++++++++
datafusion-examples/examples/simple_udf.rs | 6 +
datafusion/expr/src/expr.rs | 55 +++--
datafusion/expr/src/expr_fn.rs | 85 ++++++-
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/udf.rs | 169 +++++++++++++-
datafusion/optimizer/src/analyzer/type_coercion.rs | 64 +++---
docs/source/library-user-guide/adding-udfs.md | 9 +-
9 files changed, 562 insertions(+), 74 deletions(-)
diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index 057cdd4752..1296c74ea2 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -59,8 +59,9 @@ cargo run --example csv_sql
- [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure
`object_store` and run a query against files stored in AWS S3
- [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store`
and run a query against files via HTTP
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom
Query Optimizer pass
+- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined
Scalar Function (UDF)
+- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more
complicated User Defined Scalar Function (UDF)
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User
Defined Aggregate Function (UDAF)
-- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined
(scalar) Function (UDF)
- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User
Defined Window Function (UDWF)
## Distributed
diff --git a/datafusion-examples/examples/advanced_udf.rs
b/datafusion-examples/examples/advanced_udf.rs
new file mode 100644
index 0000000000..6ebf88a0b6
--- /dev/null
+++ b/datafusion-examples/examples/advanced_udf.rs
@@ -0,0 +1,243 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::{
+ arrow::{
+ array::{ArrayRef, Float32Array, Float64Array},
+ datatypes::DataType,
+ record_batch::RecordBatch,
+ },
+ logical_expr::Volatility,
+};
+use std::any::Any;
+
+use arrow::array::{new_null_array, Array, AsArray};
+use arrow::compute;
+use arrow::datatypes::Float64Type;
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_common::{internal_err, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
+use std::sync::Arc;
+
+/// This example shows how to use the full ScalarUDFImpl API to implement a
user
+/// defined function. As in the `simple_udf.rs` example, this struct implements
+/// a function that takes two arguments and returns the first argument raised
to
+/// the power of the second argument `a^b`.
+///
+/// To do so, we must implement the `ScalarUDFImpl` trait.
+struct PowUdf {
+ signature: Signature,
+ aliases: Vec<String>,
+}
+
+impl PowUdf {
+ /// Create a new instance of the `PowUdf` struct
+ fn new() -> Self {
+ Self {
+ signature: Signature::exact(
+ // this function will always take two arguments of type f64
+ vec![DataType::Float64, DataType::Float64],
+ // this function is deterministic and will always return the
same
+ // result for the same input
+ Volatility::Immutable,
+ ),
+ // we will also add an alias of "my_pow"
+ aliases: vec!["my_pow".to_string()],
+ }
+ }
+}
+
+impl ScalarUDFImpl for PowUdf {
+ /// We implement as_any so that we can downcast the ScalarUDFImpl trait
object
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ /// Return the name of this function
+ fn name(&self) -> &str {
+ "pow"
+ }
+
+ /// Return the "signature" of this function -- namely what types of
arguments it will take
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ /// What is the type of value that will be returned by this function? In
+ /// this case it will always be a constant value, but it could also be a
+ /// function of the input types.
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Float64)
+ }
+
+ /// This is the function that actually calculates the results.
+ ///
+ /// This is the same way that functions built into DataFusion are invoked,
+ /// which permits important special cases when one or both of the arguments
+ /// are single values (constants). For example `pow(a, 2)`
+ ///
+ /// However, it also means the implementation is more complex than when
+ /// using `create_udf`.
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ // DataFusion has arranged for the correct inputs to be passed to this
+ // function, but we check again to make sure
+ assert_eq!(args.len(), 2);
+ let (base, exp) = (&args[0], &args[1]);
+ assert_eq!(base.data_type(), DataType::Float64);
+ assert_eq!(exp.data_type(), DataType::Float64);
+
+ match (base, exp) {
+ // For demonstration purposes we also implement the scalar / scalar
+ // case here, but it is not typically required for high
performance.
+ //
+ // For performance it is most important to optimize cases where at
+ // least one argument is an array. If all arguments are constants,
+ // the DataFusion expression simplification logic will often invoke
+ // this path once during planning, and simply use the result during
+ // execution.
+ (
+ ColumnarValue::Scalar(ScalarValue::Float64(base)),
+ ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+ ) => {
+ // compute the output. Note DataFusion treats `None` as NULL.
+ let res = match (base, exp) {
+ (Some(base), Some(exp)) => Some(base.powf(*exp)),
+ // one or both arguments were NULL
+ _ => None,
+ };
+ Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
+ }
+ // special case if the exponent is a constant
+ (
+ ColumnarValue::Array(base_array),
+ ColumnarValue::Scalar(ScalarValue::Float64(exp)),
+ ) => {
+ let result_array = match exp {
+ // a ^ null = null
+ None => new_null_array(base_array.data_type(),
base_array.len()),
+ // a ^ exp
+ Some(exp) => {
+ // DataFusion has ensured both arguments are Float64:
+ let base_array =
base_array.as_primitive::<Float64Type>();
+ // calculate the result for every row. The `unary`
+ // kernel creates very fast "vectorized" code and
+ // handles things like null values for us.
+ let res: Float64Array =
+ compute::unary(base_array, |base| base.powf(*exp));
+ Arc::new(res)
+ }
+ };
+ Ok(ColumnarValue::Array(result_array))
+ }
+
+ // special case if the base is a constant (note this code is quite
+ // similar to the previous case, so we omit comments)
+ (
+ ColumnarValue::Scalar(ScalarValue::Float64(base)),
+ ColumnarValue::Array(exp_array),
+ ) => {
+ let res = match base {
+ None => new_null_array(exp_array.data_type(),
exp_array.len()),
+ Some(base) => {
+ let exp_array =
exp_array.as_primitive::<Float64Type>();
+ let res: Float64Array =
+ compute::unary(exp_array, |exp| base.powf(exp));
+ Arc::new(res)
+ }
+ };
+ Ok(ColumnarValue::Array(res))
+ }
+ // Both arguments are arrays so we have to perform the calculation
for every row
+ (ColumnarValue::Array(base_array),
ColumnarValue::Array(exp_array)) => {
+ let res: Float64Array = compute::binary(
+ base_array.as_primitive::<Float64Type>(),
+ exp_array.as_primitive::<Float64Type>(),
+ |base, exp| base.powf(exp),
+ )?;
+ Ok(ColumnarValue::Array(Arc::new(res)))
+ }
+ // if the types were not float, it is a bug in DataFusion
+ _ => {
+ use datafusion_common::DataFusionError;
+ internal_err!("Invalid argument types to pow function")
+ }
+ }
+ }
+
+ /// We will also add an alias of "my_pow"
+ fn aliases(&self) -> &[String] {
+ &self.aliases
+ }
+}
+
+/// In this example we register `PowUdf` as a user defined function
+/// and invoke it via the DataFrame API and SQL
+#[tokio::main]
+async fn main() -> Result<()> {
+ let ctx = create_context()?;
+
+ // create the UDF
+ let pow = ScalarUDF::from(PowUdf::new());
+
+ // register the UDF with the context so it can be invoked by name and from
SQL
+ ctx.register_udf(pow.clone());
+
+ // get a DataFrame from the context for scanning the "t" table
+ let df = ctx.table("t").await?;
+
+ // Call pow(a, 10) using the DataFrame API
+ let df = df.select(vec![pow.call(vec![col("a"), lit(10i32)])])?;
+
+ // note that the second argument is passed as an i32, not f64. DataFusion
+ // automatically coerces the types to match the UDF's defined signature.
+
+ // print the results
+ df.show().await?;
+
+ // You can also invoke both pow(2, 10) and its alias my_pow(a, b) using
SQL
+ let sql_df = ctx.sql("SELECT pow(2, 10), my_pow(a, b) FROM t").await?;
+ sql_df.show().await?;
+
+ Ok(())
+}
+
+/// create local execution context with an in-memory table:
+///
+/// ```text
+/// +-----+-----+
+/// | a | b |
+/// +-----+-----+
+/// | 2.1 | 1.0 |
+/// | 3.1 | 2.0 |
+/// | 4.1 | 3.0 |
+/// | 5.1 | 4.0 |
+/// +-----+-----+
+/// ```
+fn create_context() -> Result<SessionContext> {
+ // define data.
+ let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
+ let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
+ let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;
+
+ // declare a new context. In Spark API, this corresponds to a new
SparkSession
+ let ctx = SessionContext::new();
+
+ // declare a table in memory. In Spark API, this corresponds to
createDataFrame(...).
+ ctx.register_batch("t", batch)?;
+ Ok(ctx)
+}
diff --git a/datafusion-examples/examples/simple_udf.rs
b/datafusion-examples/examples/simple_udf.rs
index 5919917865..39e1e13ce3 100644
--- a/datafusion-examples/examples/simple_udf.rs
+++ b/datafusion-examples/examples/simple_udf.rs
@@ -140,5 +140,11 @@ async fn main() -> Result<()> {
// print the results
df.show().await?;
+ // Given that `pow` is registered in the context, we can also use it in
SQL:
+ let sql_df = ctx.sql("SELECT pow(a, b) FROM t").await?;
+
+ // print the results
+ sql_df.show().await?;
+
Ok(())
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index b46e9ec8f6..0ec19bcadb 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1724,13 +1724,13 @@ mod test {
use crate::expr::Cast;
use crate::expr_fn::col;
use crate::{
- case, lit, BuiltinScalarFunction, ColumnarValue, Expr,
ReturnTypeFunction,
- ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF,
Signature,
- Volatility,
+ case, lit, BuiltinScalarFunction, ColumnarValue, Expr,
ScalarFunctionDefinition,
+ ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
+ use std::any::Any;
use std::sync::Arc;
#[test]
@@ -1848,24 +1848,41 @@ mod test {
);
// UDF
- let return_type: ReturnTypeFunction =
- Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
- let fun: ScalarFunctionImplementation =
- Arc::new(move |_|
Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
- let udf = Arc::new(ScalarUDF::new(
- "TestScalarUDF",
- &Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
- &return_type,
- &fun,
- ));
+ struct TestScalarUDF {
+ signature: Signature,
+ }
+ impl ScalarUDFImpl for TestScalarUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "TestScalarUDF"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>
{
+ Ok(DataType::Utf8)
+ }
+
+ fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>
{
+ Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+ }
+ }
+ let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+ signature: Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
+ }));
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
- let udf = Arc::new(ScalarUDF::new(
- "TestScalarUDF",
- &Signature::uniform(1, vec![DataType::Float32],
Volatility::Volatile),
- &return_type,
- &fun,
- ));
+ let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
+ signature: Signature::uniform(
+ 1,
+ vec![DataType::Float32],
+ Volatility::Volatile,
+ ),
+ }));
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
// Unresolved function
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index cedf1d8451..eed41d97cc 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -22,15 +22,16 @@ use crate::expr::{
Placeholder, ScalarFunction, TryCast,
};
use crate::function::PartitionEvaluatorFactory;
-use crate::WindowUDF;
use crate::{
aggregate_function, built_in_function,
conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction,
Volatility,
};
+use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF};
use arrow::datatypes::DataType;
use datafusion_common::{Column, Result};
+use std::any::Any;
use std::ops::Not;
use std::sync::Arc;
@@ -944,11 +945,18 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
CaseBuilder::new(None, vec![when], vec![then], None)
}
-/// Creates a new UDF with a specific signature and specific return type.
-/// This is a helper function to create a new UDF.
-/// The function `create_udf` returns a subset of all possible
`ScalarFunction`:
-/// * the UDF has a fixed return type
-/// * the UDF has a fixed signature (e.g. [f64, f64])
+/// Convenience method to create a new user defined scalar function (UDF) with
a
+/// specific signature and specific return type.
+///
+/// Note this function does not expose all available features of [`ScalarUDF`],
+/// such as
+///
+/// * computing return types based on input types
+/// * multiple [`Signature`]s
+/// * aliases
+///
+/// See [`ScalarUDF`] for details and examples on how to use the full
+/// functionality.
pub fn create_udf(
name: &str,
input_types: Vec<DataType>,
@@ -956,13 +964,66 @@ pub fn create_udf(
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUDF {
- let return_type: ReturnTypeFunction = Arc::new(move |_|
Ok(return_type.clone()));
- ScalarUDF::new(
+ let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t|
t.as_ref().clone());
+ ScalarUDF::from(SimpleScalarUDF::new(
name,
- &Signature::exact(input_types, volatility),
- &return_type,
- &fun,
- )
+ input_types,
+ return_type,
+ volatility,
+ fun,
+ ))
+}
+
+/// Implements [`ScalarUDFImpl`] for functions that have a single signature and
+/// return type.
+pub struct SimpleScalarUDF {
+ name: String,
+ signature: Signature,
+ return_type: DataType,
+ fun: ScalarFunctionImplementation,
+}
+
+impl SimpleScalarUDF {
+ /// Create a new `SimpleScalarUDF` from a name, input types, return type
and
+ /// implementation. Implementing [`ScalarUDFImpl`] allows more flexibility
+ pub fn new(
+ name: impl Into<String>,
+ input_types: Vec<DataType>,
+ return_type: DataType,
+ volatility: Volatility,
+ fun: ScalarFunctionImplementation,
+ ) -> Self {
+ let name = name.into();
+ let signature = Signature::exact(input_types, volatility);
+ Self {
+ name,
+ signature,
+ return_type,
+ fun,
+ }
+ }
+}
+
+impl ScalarUDFImpl for SimpleScalarUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ &self.name
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ Ok(self.return_type.clone())
+ }
+
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ (self.fun)(args)
+ }
}
/// Creates a new UDAF with a specific signature, state type and return type.
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 48532e13dc..bf8e9e2954 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -80,7 +80,7 @@ pub use signature::{
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::AggregateUDF;
-pub use udf::ScalarUDF;
+pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::WindowUDF;
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
pub use window_function::{BuiltInWindowFunction, WindowFunction};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3a18ca2d25..2ec80a4a9e 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,9 +17,12 @@
//! [`ScalarUDF`]: Scalar User Defined Functions
-use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
+use crate::{
+ ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation,
Signature,
+};
use arrow::datatypes::DataType;
use datafusion_common::Result;
+use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
@@ -27,11 +30,19 @@ use std::sync::Arc;
/// Logical representation of a Scalar User Defined Function.
///
-/// A scalar function produces a single row output for each row of input.
+/// A scalar function produces a single row output for each row of input. This
+/// struct contains the information DataFusion needs to plan and invoke
+/// functions you supply such as name, type signature, return type, and actual
+/// implementation.
///
-/// This struct contains the information DataFusion needs to plan and invoke
-/// functions such name, type signature, return type, and actual
implementation.
///
+/// 1. For simple (less performant) use cases, use [`create_udf`] and
[`simple_udf.rs`].
+///
+/// 2. For advanced use cases, use [`ScalarUDFImpl`] and [`advanced_udf.rs`].
+///
+/// [`create_udf`]: crate::expr_fn::create_udf
+/// [`simple_udf.rs`]:
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udf.rs
+/// [`advanced_udf.rs`]:
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
#[derive(Clone)]
pub struct ScalarUDF {
/// The name of the function
@@ -79,7 +90,11 @@ impl std::hash::Hash for ScalarUDF {
}
impl ScalarUDF {
- /// Create a new ScalarUDF
+ /// Create a new ScalarUDF from low level details.
+ ///
+ /// See [`ScalarUDFImpl`] for a more convenient way to create a
+ /// `ScalarUDF` using trait objects
+ #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl
instead")]
pub fn new(
name: &str,
signature: &Signature,
@@ -95,6 +110,34 @@ impl ScalarUDF {
}
}
+ /// Create a new `ScalarUDF` from a [`ScalarUDFImpl`] trait object
+ ///
+ /// Note this is the same as using the `From` impl (`ScalarUDF::from`)
+ pub fn new_from_impl<F>(fun: F) -> ScalarUDF
+ where
+ F: ScalarUDFImpl + Send + Sync + 'static,
+ {
+ // TODO change the internal implementation to use the trait object
+ let arc_fun = Arc::new(fun);
+ let captured_self = arc_fun.clone();
+ let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
+ let return_type = captured_self.return_type(arg_types)?;
+ Ok(Arc::new(return_type))
+ });
+
+ let captured_self = arc_fun.clone();
+ let func: ScalarFunctionImplementation =
+ Arc::new(move |args| captured_self.invoke(args));
+
+ Self {
+ name: arc_fun.name().to_string(),
+ signature: arc_fun.signature().clone(),
+ return_type: return_type.clone(),
+ fun: func,
+ aliases: arc_fun.aliases().to_vec(),
+ }
+ }
+
/// Adds additional names that can be used to invoke this function, in
addition to `name`
pub fn with_aliases(
mut self,
@@ -105,7 +148,9 @@ impl ScalarUDF {
self
}
- /// creates a logical expression with a call of the UDF
+ /// Returns a [`Expr`] logical expression to call this UDF with specified
+ /// arguments.
+ ///
/// This utility allows using the UDF without requiring access to the
registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
@@ -124,22 +169,126 @@ impl ScalarUDF {
&self.aliases
}
- /// Returns this function's signature (what input types are accepted)
+ /// Returns this function's [`Signature`] (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
}
- /// Return the type of the function given its input types
+ /// The datatype this function returns given the input argument input types
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
}
- /// Return the actual implementation
+ /// Return an [`Arc`] to the function implementation
pub fn fun(&self) -> ScalarFunctionImplementation {
self.fun.clone()
}
+}
- // TODO maybe add an invoke() method that runs the actual function?
+impl<F> From<F> for ScalarUDF
+where
+ F: ScalarUDFImpl + Send + Sync + 'static,
+{
+ fn from(fun: F) -> Self {
+ Self::new_from_impl(fun)
+ }
+}
+
+/// Trait for implementing [`ScalarUDF`].
+///
+/// This trait exposes the full API for implementing user defined functions and
+/// can be used to implement any function.
+///
+/// See [`advanced_udf.rs`] for a full example with complete implementation and
+/// [`ScalarUDF`] for other available options.
+///
+///
+/// [`advanced_udf.rs`]:
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+/// # Basic Example
+/// ```
+/// # use std::any::Any;
+/// # use arrow::datatypes::DataType;
+/// # use datafusion_common::{DataFusionError, plan_err, Result};
+/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
+/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
+/// struct AddOne {
+/// signature: Signature
+/// };
+///
+/// impl AddOne {
+/// fn new() -> Self {
+/// Self {
+/// signature: Signature::uniform(1, vec![DataType::Int32],
Volatility::Immutable)
+/// }
+/// }
+/// }
+///
+/// /// Implement the ScalarUDFImpl trait for AddOne
+/// impl ScalarUDFImpl for AddOne {
+/// fn as_any(&self) -> &dyn Any { self }
+/// fn name(&self) -> &str { "add_one" }
+/// fn signature(&self) -> &Signature { &self.signature }
+/// fn return_type(&self, args: &[DataType]) -> Result<DataType> {
+/// if !matches!(args.get(0), Some(&DataType::Int32)) {
+/// return plan_err!("add_one only accepts Int32 arguments");
+/// }
+/// Ok(DataType::Int32)
+/// }
+/// // The actual implementation would add one to the argument
+/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!() }
+/// }
+///
+/// // Create a new ScalarUDF from the implementation
+/// let add_one = ScalarUDF::from(AddOne::new());
+///
+/// // Call the function `add_one(col)`
+/// let expr = add_one.call(vec![col("a")]);
+/// ```
+pub trait ScalarUDFImpl {
+ /// Returns this object as an [`Any`] trait object
+ fn as_any(&self) -> &dyn Any;
+
+ /// Returns this function's name
+ fn name(&self) -> &str;
+
+ /// Returns the function's [`Signature`] for information about what input
+ /// types are accepted and the function's Volatility.
+ fn signature(&self) -> &Signature;
+
+ /// What [`DataType`] will be returned by this function, given the types of
+ /// the arguments
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+
+ /// Invoke the function on `args`, returning the appropriate result
+ ///
+ /// The function will be invoked with the slice of [`ColumnarValue`]
+ /// (either scalar or array).
+ ///
+ /// # Zero Argument Functions
+ /// If the function has zero parameters (e.g. `now()`) it will be passed a
+ /// single element slice which is a null array to indicate the batch's
row
+ /// count (so the function can know the resulting array size).
+ ///
+ /// # Performance
+ ///
+ /// For the best performance, the implementations of `invoke` should handle
+ /// the common case when one or more of their arguments are constant values
+ /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`]
+ /// and treating all arguments as arrays will work, but will be slower.
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
+
+ /// Returns any aliases (alternate names) for this function.
+ ///
+ /// Aliases can be used to invoke the same function using different names.
+ /// For example in some databases `now()` and `current_timestamp()` are
+ /// aliases for the same function. This behavior can be obtained by
+ /// returning `current_timestamp` as an alias for the `now` function.
+ ///
+ /// Note: `aliases` should only include names other than [`Self::name`].
+ /// Defaults to `[]` (no aliases)
+ fn aliases(&self) -> &[String] {
+ &[]
+ }
}
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index c5e1180b9f..b6298f5b55 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -738,7 +738,8 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef)
-> Result<Case> {
#[cfg(test)]
mod test {
- use std::sync::Arc;
+ use std::any::Any;
+ use std::sync::{Arc, OnceLock};
use arrow::array::{FixedSizeListArray, Int32Array};
use arrow::datatypes::{DataType, TimeUnit};
@@ -750,13 +751,13 @@ mod test {
use datafusion_expr::{
cast, col, concat, concat_ws, create_udaf, is_true,
AccumulatorFactoryFunction,
AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction,
Case,
- ColumnarValue, ExprSchemable, Filter, Operator, StateTypeFunction,
Subquery,
+ ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl,
StateTypeFunction,
+ Subquery,
};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
- Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation,
ScalarUDF,
- Signature, Volatility,
+ Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature,
Volatility,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
@@ -808,22 +809,36 @@ mod test {
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)
}
+ static TEST_SIGNATURE: OnceLock<Signature> = OnceLock::new();
+
+ struct TestScalarUDF {}
+ impl ScalarUDFImpl for TestScalarUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "TestScalarUDF"
+ }
+ fn signature(&self) -> &Signature {
+ TEST_SIGNATURE.get_or_init(|| {
+ Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable)
+ })
+ }
+ fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Utf8)
+ }
+
+ fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
+ }
+ }
+
#[test]
fn scalar_udf() -> Result<()> {
let empty = empty();
- let return_type: ReturnTypeFunction =
- Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
- let fun: ScalarFunctionImplementation =
- Arc::new(move |_|
Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
- let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
- Arc::new(ScalarUDF::new(
- "TestScalarUDF",
- &Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
- &return_type,
- &fun,
- )),
- vec![lit(123_i32)],
- ));
+
+ let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit(123_i32)]);
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf],
empty)?);
let expected =
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n
EmptyRelation";
@@ -833,24 +848,13 @@ mod test {
#[test]
fn scalar_udf_invalid_input() -> Result<()> {
let empty = empty();
- let return_type: ReturnTypeFunction =
- Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
- let fun: ScalarFunctionImplementation = Arc::new(move |_|
unimplemented!());
- let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
- Arc::new(ScalarUDF::new(
- "TestScalarUDF",
- &Signature::uniform(1, vec![DataType::Int32],
Volatility::Stable),
- &return_type,
- &fun,
- )),
- vec![lit("Apple")],
- ));
+ let udf = ScalarUDF::from(TestScalarUDF {}).call(vec![lit("Apple")]);
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf],
empty)?);
let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()),
&plan, "")
.err()
.unwrap();
assert_eq!(
- "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to
the signature Uniform(1, [Int32]) failed.",
+ "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to
the signature Uniform(1, [Float32]) failed.",
err.strip_backtrace()
);
Ok(())
diff --git a/docs/source/library-user-guide/adding-udfs.md
b/docs/source/library-user-guide/adding-udfs.md
index 11cf52eb3f..c51e4de323 100644
--- a/docs/source/library-user-guide/adding-udfs.md
+++ b/docs/source/library-user-guide/adding-udfs.md
@@ -76,7 +76,9 @@ The challenge however is that DataFusion doesn't know about
this function. We ne
### Registering a Scalar UDF
-To register a Scalar UDF, you need to wrap the function implementation in a
`ScalarUDF` struct and then register it with the `SessionContext`. DataFusion
provides the `create_udf` and `make_scalar_function` helper functions to make
this easier.
+To register a Scalar UDF, you need to wrap the function implementation in a
[`ScalarUDF`] struct and then register it with the `SessionContext`.
+DataFusion provides the [`create_udf`] and [`make_scalar_function`] helper
functions to make this easier.
+There is a lower level API with more functionality, but it is more complex; it
is documented in [`advanced_udf.rs`].
```rust
use datafusion::logical_expr::{Volatility, create_udf};
@@ -93,6 +95,11 @@ let udf = create_udf(
);
```
+[`scalarudf`]:
https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.ScalarUDF.html
+[`create_udf`]:
https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udf.html
+[`make_scalar_function`]:
https://docs.rs/datafusion/latest/datafusion/physical_expr/functions/fn.make_scalar_function.html
+[`advanced_udf.rs`]:
https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
+
A few things to note:
- The first argument is the name of the function. This is the name that will
be used in SQL queries.