This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 92d9274e6a Support compute return types from argument values (not just
their DataTypes) (#8985)
92d9274e6a is described below
commit 92d9274e6a2c677fe07938e531be4b45532a89a9
Author: Junhao Liu <[email protected]>
AuthorDate: Thu Feb 15 10:28:05 2024 -0600
Support compute return types from argument values (not just their
DataTypes) (#8985)
* ScalarValue return types from argument values
* change file name
* try using ?Sized
* use Ok
* move method default impl outside trait
* Use type trait for ExprSchemable
* fix nit
* Proposed Return Type from Expr suggestions (#1)
* Improve return_type_from_args
* Rework example
* Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
---------
Co-authored-by: Junhao Liu <[email protected]>
* Apply suggestions from code review
Co-authored-by: Alex Huang <[email protected]>
* Fix tests + clippy
* rework types to use dyn trait
* fmt
* docs
* Apply suggestions from code review
Co-authored-by: Jeffrey Vo <[email protected]>
* Add docs explaining what happens when both `return_type` and
`return_type_from_exprs` are called
* clippy
* fix doc -- comedy of errors
---------
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Alex Huang <[email protected]>
Co-authored-by: Jeffrey Vo <[email protected]>
---
.../user_defined/user_defined_scalar_functions.rs | 142 ++++++++++++++++++++-
datafusion/expr/src/expr_schema.rs | 40 +++---
datafusion/expr/src/udf.rs | 69 ++++++++--
datafusion/optimizer/src/analyzer/type_coercion.rs | 16 +--
datafusion/physical-expr/src/planner.rs | 14 +-
datafusion/physical-expr/src/udf.rs | 17 +--
6 files changed, 245 insertions(+), 53 deletions(-)
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index a86c76b9b6..9812789740 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
-use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result,
ScalarValue};
+use datafusion_common::{
+ assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array,
not_impl_err,
+ plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
+};
use datafusion_expr::{
- create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder,
ScalarUDF,
- ScalarUDFImpl, Signature, Volatility,
+ create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
+ LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
+use std::any::Any;
use std::iter;
use std::sync::Arc;
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() ->
Result<()> {
Ok(())
}
+#[derive(Debug)]
+struct TakeUDF {
+ signature: Signature,
+}
+
+impl TakeUDF {
+ fn new() -> Self {
+ Self {
+ signature: Signature::any(3, Volatility::Immutable),
+ }
+ }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input
values
+impl ScalarUDFImpl for TakeUDF {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+ fn name(&self) -> &str {
+ "take"
+ }
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+ fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+ not_impl_err!("Not called because the return_type_from_exprs is
implemented")
+ }
+
+ /// This function returns the type of the first or second argument based on
+ /// the third argument:
+ ///
+ /// 1. If the third argument is '0', return the type of the first argument
+ /// 2. If the third argument is '1', return the type of the second argument
+ fn return_type_from_exprs(
+ &self,
+ arg_exprs: &[Expr],
+ schema: &dyn ExprSchema,
+ ) -> Result<DataType> {
+ if arg_exprs.len() != 3 {
+ return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+ }
+
+ let take_idx = if let
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+ arg_exprs.get(2)
+ {
+ if *idx == 0 || *idx == 1 {
+ *idx as usize
+ } else {
+ return plan_err!("The third argument must be 0 or 1, got:
{idx}");
+ }
+ } else {
+ return plan_err!(
+ "The third argument must be a literal of type int64, but got
{:?}",
+ arg_exprs.get(2)
+ );
+ };
+
+ arg_exprs.get(take_idx).unwrap().get_type(schema)
+ }
+
+ // The actual implementation
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ let take_idx = match &args[2] {
+ ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v
as usize,
+ _ => unreachable!(),
+ };
+ match &args[take_idx] {
+ ColumnarValue::Array(array) =>
Ok(ColumnarValue::Array(array.clone())),
+ ColumnarValue::Scalar(_) => unimplemented!(),
+ }
+ }
+}
+
+#[tokio::test]
+async fn verify_udf_return_type() -> Result<()> {
+ // Create a new ScalarUDF from the implementation
+ let take = ScalarUDF::from(TakeUDF::new());
+
+ // SELECT
+ // take(smallint_col, double_col, 0) as take0,
+ // take(smallint_col, double_col, 1) as take1
+ // FROM alltypes_plain;
+ let exprs = vec![
+ take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+ .alias("take0"),
+ take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+ .alias("take1"),
+ ];
+
+ let ctx = SessionContext::new();
+ register_alltypes_parquet(&ctx).await?;
+
+ let df = ctx.table("alltypes_plain").await?.select(exprs)?;
+
+ let schema = df.schema();
+
+ // The output schema should be
+ // * type of column smallint_col (int32)
+ // * type of column double_col (float64)
+ assert_eq!(schema.field(0).data_type(), &DataType::Int32);
+ assert_eq!(schema.field(1).data_type(), &DataType::Float64);
+
+ let expected = [
+ "+-------+-------+",
+ "| take0 | take1 |",
+ "+-------+-------+",
+ "| 0 | 0.0 |",
+ "| 0 | 0.0 |",
+ "| 0 | 0.0 |",
+ "| 0 | 0.0 |",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "| 1 | 10.1 |",
+ "+-------+-------+",
+ ];
+ assert_batches_sorted_eq!(&expected, &df.collect().await?);
+
+ Ok(())
+}
+
fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
@@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) ->
Result<()> {
Ok(())
}
+async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
+ let testdata = datafusion::test_util::parquet_test_data();
+ ctx.register_parquet(
+ "alltypes_plain",
+ &format!("{testdata}/alltypes_plain.parquet"),
+ ParquetReadOptions::default(),
+ )
+ .await?;
+ Ok(())
+}
+
/// Execute SQL and return results as a RecordBatch
async fn plan_and_collect(ctx: &SessionContext, sql: &str) ->
Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 517d7a35f7..491b4a8522 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -28,8 +28,8 @@ use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
- internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
- DataFusionError, ExprSchema, Result,
+ internal_err, plan_datafusion_err, plan_err, Column, DFField,
DataFusionError,
+ ExprSchema, Result,
};
use std::collections::HashMap;
use std::sync::Arc;
@@ -37,26 +37,28 @@ use std::sync::Arc;
/// trait to allow expr to typable with respect to a schema
pub trait ExprSchemable {
/// given a schema, return the type of the expr
- fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
+ fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
/// given a schema, return the nullability of the expr
- fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
+ fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
/// given a schema, return the expr's optional metadata
- fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String,
String>>;
+ fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String,
String>>;
/// convert to a field with respect to a schema
- fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
+ fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>;
/// cast to a type with respect to a schema
- fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) ->
Result<Expr>;
+ fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) ->
Result<Expr>;
}
impl ExprSchemable for Expr {
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
///
- /// Note: [DFSchema] implements [ExprSchema].
+ /// Note: [`DFSchema`] implements [ExprSchema].
+ ///
+ /// [`DFSchema`]: datafusion_common::DFSchema
///
/// # Examples
///
@@ -90,7 +92,7 @@ impl ExprSchemable for Expr {
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
- fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+ fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match
&data_type {
@@ -136,7 +138,7 @@ impl ExprSchemable for Expr {
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
- Ok(fun.return_type(&arg_data_types)?)
+ Ok(fun.return_type_from_exprs(args, schema)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
@@ -213,14 +215,16 @@ impl ExprSchemable for Expr {
/// Returns the nullability of the expression based on [ExprSchema].
///
- /// Note: [DFSchema] implements [ExprSchema].
+ /// Note: [`DFSchema`] implements [ExprSchema].
+ ///
+ /// [`DFSchema`]: datafusion_common::DFSchema
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// nullability. This happens when the expression refers to a
/// column that does not exist in the schema.
- fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
+ fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
match self {
Expr::Alias(Alias { expr, .. })
| Expr::Not(expr)
@@ -327,7 +331,7 @@ impl ExprSchemable for Expr {
}
}
- fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String,
String>> {
+ fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String,
String>> {
match self {
Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
@@ -339,7 +343,7 @@ impl ExprSchemable for Expr {
///
/// So for example, a projected expression `col(c1) + col(c2)` is
/// placed in an output field **named** col("c1 + c2")
- fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
+ fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
c.relation.clone(),
@@ -370,7 +374,7 @@ impl ExprSchemable for Expr {
///
/// This function errors when it is impossible to cast the
/// expression to the target [arrow::datatypes::DataType].
- fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) ->
Result<Expr> {
+ fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) ->
Result<Expr> {
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
return Ok(self);
@@ -394,10 +398,10 @@ impl ExprSchemable for Expr {
}
/// return the schema [`Field`] for the type referenced by `get_indexed_field`
-fn field_for_index<S: ExprSchema>(
+fn field_for_index(
expr: &Expr,
field: &GetFieldAccess,
- schema: &S,
+ schema: &dyn ExprSchema,
) -> Result<Field> {
let expr_dt = expr.get_type(schema)?;
match field {
@@ -457,7 +461,7 @@ mod tests {
use super::*;
use crate::{col, lit};
use arrow::datatypes::{DataType, Fields};
- use datafusion_common::{Column, ScalarValue, TableReference};
+ use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
macro_rules! test_is_expr_nullable {
($EXPR_TYPE:ident) => {{
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3017e1ec02..5b5d92a628 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,12 +17,13 @@
//! [`ScalarUDF`]: Scalar User Defined Functions
+use crate::ExprSchemable;
use crate::{
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
ScalarFunctionImplementation, Signature,
};
use arrow::datatypes::DataType;
-use datafusion_common::Result;
+use datafusion_common::{ExprSchema, Result};
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
@@ -110,7 +111,7 @@ impl ScalarUDF {
///
/// If you implement [`ScalarUDFImpl`] directly you should return aliases
directly.
pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>)
-> Self {
- Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
+ Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(),
aliases))
}
/// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -146,10 +147,17 @@ impl ScalarUDF {
}
/// The datatype this function returns given the input argument input
types.
+ /// This function is used when the input arguments are [`Expr`]s.
///
- /// See [`ScalarUDFImpl::return_type`] for more details.
- pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
- self.inner.return_type(args)
+ ///
+ /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
+ pub fn return_type_from_exprs(
+ &self,
+ args: &[Expr],
+ schema: &dyn ExprSchema,
+ ) -> Result<DataType> {
+ // Always delegate; the trait's default impl falls back to `return_type`
+ self.inner.return_type_from_exprs(args, schema)
}
/// Invoke the function on `args`, returning the appropriate result.
@@ -246,9 +254,54 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn signature(&self) -> &Signature;
/// What [`DataType`] will be returned by this function, given the types of
- /// the arguments
+ /// the arguments.
+ ///
+ /// # Notes
+ ///
+ /// If you provide an implementation for [`Self::return_type_from_exprs`],
+ /// DataFusion will not call `return_type` (this function). In this case it
+ /// is recommended to return [`DataFusionError::Internal`].
+ ///
+ /// [`DataFusionError::Internal`]:
datafusion_common::DataFusionError::Internal
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
+ /// What [`DataType`] will be returned by this function, given the
+ /// arguments?
+ ///
+ /// Note most UDFs should implement [`Self::return_type`] and not this
+ /// function. The output type for most functions only depends on the types
+ /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+ ///
+ /// By default, this function calls [`Self::return_type`] with the
+ /// types of each argument.
+ ///
+ /// This method can be overridden for functions that return different
+ /// *types* based on the *values* of their arguments.
+ ///
+ /// For example, the following two function calls get the same argument
+ /// types (something and a `Utf8` string) but return different types based
+ /// on the value of the second argument:
+ ///
+ /// * `arrow_cast(x, 'Int16')` --> `Int16`
+ /// * `arrow_cast(x, 'Float32')` --> `Float32`
+ ///
+ /// # Notes:
+ ///
+ /// This function must consistently return the same type for the same
+ /// logical input even if the input is simplified (e.g. it must return the
same
+ /// value for `('foo' | 'bar')` as it does for `('foobar')`).
+ fn return_type_from_exprs(
+ &self,
+ args: &[Expr],
+ schema: &dyn ExprSchema,
+ ) -> Result<DataType> {
+ let arg_types = args
+ .iter()
+ .map(|arg| arg.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ self.return_type(&arg_types)
+ }
+
/// Invoke the function on `args`, returning the appropriate result
///
/// The function will be invoked passed with the slice of [`ColumnarValue`]
@@ -290,13 +343,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
#[derive(Debug)]
struct AliasedScalarUDFImpl {
- inner: ScalarUDF,
+ inner: Arc<dyn ScalarUDFImpl>,
aliases: Vec<String>,
}
impl AliasedScalarUDFImpl {
pub fn new(
- inner: ScalarUDF,
+ inner: Arc<dyn ScalarUDFImpl>,
new_aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
let mut aliases = inner.aliases().to_vec();
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 662e0fc7c2..fba77047dd 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -681,17 +681,17 @@ fn coerce_case_expression(case: Case, schema:
&DFSchemaRef) -> Result<Case> {
let case_type = case
.expr
.as_ref()
- .map(|expr| expr.get_type(&schema))
+ .map(|expr| expr.get_type(schema))
.transpose()?;
let then_types = case
.when_then_expr
.iter()
- .map(|(_when, then)| then.get_type(&schema))
+ .map(|(_when, then)| then.get_type(schema))
.collect::<Result<Vec<_>>>()?;
let else_type = case
.else_expr
.as_ref()
- .map(|expr| expr.get_type(&schema))
+ .map(|expr| expr.get_type(schema))
.transpose()?;
// find common coercible types
@@ -701,7 +701,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef)
-> Result<Case> {
let when_types = case
.when_then_expr
.iter()
- .map(|(when, _then)| when.get_type(&schema))
+ .map(|(when, _then)| when.get_type(schema))
.collect::<Result<Vec<_>>>()?;
let coerced_type =
get_coerce_type_for_case_expression(&when_types,
Some(case_type));
@@ -727,7 +727,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef)
-> Result<Case> {
let case_expr = case
.expr
.zip(case_when_coerce_type.as_ref())
- .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type,
&schema))
+ .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type,
schema))
.transpose()?
.map(Box::new);
let when_then = case
@@ -735,7 +735,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef)
-> Result<Case> {
.into_iter()
.map(|(when, then)| {
let when_type =
case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
- let when = when.cast_to(when_type, &schema).map_err(|e| {
+ let when = when.cast_to(when_type, schema).map_err(|e| {
DataFusionError::Context(
format!(
"WHEN expressions in CASE couldn't be \
@@ -744,13 +744,13 @@ fn coerce_case_expression(case: Case, schema:
&DFSchemaRef) -> Result<Case> {
Box::new(e),
)
})?;
- let then = then.cast_to(&then_else_coerce_type, &schema)?;
+ let then = then.cast_to(&then_else_coerce_type, schema)?;
Ok((Box::new(when), Box::new(then)))
})
.collect::<Result<Vec<_>>>()?;
let else_expr = case
.else_expr
- .map(|expr| expr.cast_to(&then_else_coerce_type, &schema))
+ .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
.transpose()?
.map(Box::new);
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 6408af5cda..b8491aea2d 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -272,11 +272,15 @@ pub fn create_physical_expr(
execution_props,
)
}
- ScalarFunctionDefinition::UDF(fun) =>
udf::create_physical_expr(
- fun.clone().as_ref(),
- &physical_args,
- input_schema,
- ),
+ ScalarFunctionDefinition::UDF(fun) => {
+ let return_type = fun.return_type_from_exprs(args,
input_dfschema)?;
+
+ udf::create_physical_expr(
+ fun.clone().as_ref(),
+ &physical_args,
+ return_type,
+ )
+ }
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
}
diff --git a/datafusion/physical-expr/src/udf.rs
b/datafusion/physical-expr/src/udf.rs
index e0117fecb4..d9c7c9e5c2 100644
--- a/datafusion/physical-expr/src/udf.rs
+++ b/datafusion/physical-expr/src/udf.rs
@@ -17,28 +17,24 @@
//! UDF support
use crate::{PhysicalExpr, ScalarFunctionExpr};
-use arrow::datatypes::Schema;
+use arrow_schema::DataType;
use datafusion_common::Result;
pub use datafusion_expr::ScalarUDF;
use std::sync::Arc;
/// Create a physical expression of the UDF.
-/// This function errors when `args`' can't be coerced to a valid argument
type of the UDF.
+///
+/// Arguments: the UDF, its physical argument expressions, and its return type.
pub fn create_physical_expr(
fun: &ScalarUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
- input_schema: &Schema,
+ return_type: DataType,
) -> Result<Arc<dyn PhysicalExpr>> {
- let input_exprs_types = input_phy_exprs
- .iter()
- .map(|e| e.data_type(input_schema))
- .collect::<Result<Vec<_>>>()?;
-
Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun(),
input_phy_exprs.to_vec(),
- fun.return_type(&input_exprs_types)?,
+ return_type,
fun.monotonicity()?,
fun.signature().type_signature.supports_zero_argument(),
)))
@@ -46,7 +42,6 @@ pub fn create_physical_expr(
#[cfg(test)]
mod tests {
- use arrow::datatypes::Schema;
use arrow_schema::DataType;
use datafusion_common::Result;
use datafusion_expr::{
@@ -102,7 +97,7 @@ mod tests {
// create and register the udf
let udf = ScalarUDF::from(TestScalarUDF::new());
- let p_expr = create_physical_expr(&udf, &[], &Schema::empty())?;
+ let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?;
assert_eq!(
p_expr