This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 92d9274e6a Support compute return types from argument values (not just 
their DataTypes) (#8985)
92d9274e6a is described below

commit 92d9274e6a2c677fe07938e531be4b45532a89a9
Author: Junhao Liu <[email protected]>
AuthorDate: Thu Feb 15 10:28:05 2024 -0600

    Support compute return types from argument values (not just their 
DataTypes) (#8985)
    
    * ScalarValue return types from argument values
    
    * change file name
    
    * try using ?Sized
    
    * use Ok
    
    * move method default impl outside trait
    
    * Use type trait for ExprSchemable
    
    * fix nit
    
    * Proposed Return Type from Expr suggestions (#1)
    
    * Improve return_type_from_args
    
    * Rework example
    
    * Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
    
    ---------
    
    Co-authored-by: Junhao Liu <[email protected]>
    
    * Apply suggestions from code review
    
    Co-authored-by: Alex Huang <[email protected]>
    
    * Fix tests + clippy
    
    * rework types to use dyn trait
    
    * fmt
    
    * docs
    
    * Apply suggestions from code review
    
    Co-authored-by: Jeffrey Vo <[email protected]>
    
    * Add docs explaining what happens when both `return_type` and 
`return_type_from_exprs` are called
    
    * clippy
    
    * fix doc -- comedy of errors
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
    Co-authored-by: Alex Huang <[email protected]>
    Co-authored-by: Jeffrey Vo <[email protected]>
---
 .../user_defined/user_defined_scalar_functions.rs  | 142 ++++++++++++++++++++-
 datafusion/expr/src/expr_schema.rs                 |  40 +++---
 datafusion/expr/src/udf.rs                         |  69 ++++++++--
 datafusion/optimizer/src/analyzer/type_coercion.rs |  16 +--
 datafusion/physical-expr/src/planner.rs            |  14 +-
 datafusion/physical-expr/src/udf.rs                |  17 +--
 6 files changed, 245 insertions(+), 53 deletions(-)

diff --git 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index a86c76b9b6..9812789740 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
 use datafusion::prelude::*;
 use datafusion::{execution::registry::FunctionRegistry, test_util};
 use datafusion_common::cast::as_float64_array;
-use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, 
ScalarValue};
+use datafusion_common::{
+    assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, 
not_impl_err,
+    plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
+};
 use datafusion_expr::{
-    create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, 
ScalarUDF,
-    ScalarUDFImpl, Signature, Volatility,
+    create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
+    LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use rand::{thread_rng, Rng};
+use std::any::Any;
 use std::iter;
 use std::sync::Arc;
 
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> 
Result<()> {
     Ok(())
 }
 
+#[derive(Debug)]
+struct TakeUDF {
+    signature: Signature,
+}
+
+impl TakeUDF {
+    fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+        }
+    }
+}
+
+/// Implement a ScalarUDFImpl whose return type is a function of the input 
values
+impl ScalarUDFImpl for TakeUDF {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "take"
+    }
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+        not_impl_err!("Not called because the return_type_from_exprs is 
implemented")
+    }
+
+    /// This function returns the type of the first or second argument based on
+    /// the third argument:
+    ///
+    /// 1. If the third argument is '0', return the type of the first argument
+    /// 2. If the third argument is '1', return the type of the second argument
+    fn return_type_from_exprs(
+        &self,
+        arg_exprs: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        if arg_exprs.len() != 3 {
+            return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
+        }
+
+        let take_idx = if let 
Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
+            arg_exprs.get(2)
+        {
+            if *idx == 0 || *idx == 1 {
+                *idx as usize
+            } else {
+                return plan_err!("The third argument must be 0 or 1, got: 
{idx}");
+            }
+        } else {
+            return plan_err!(
+                "The third argument must be a literal of type int64, but got 
{:?}",
+                arg_exprs.get(2)
+            );
+        };
+
+        arg_exprs.get(take_idx).unwrap().get_type(schema)
+    }
+
+    // The actual implementation
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        let take_idx = match &args[2] {
+            ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v 
as usize,
+            _ => unreachable!(),
+        };
+        match &args[take_idx] {
+            ColumnarValue::Array(array) => 
Ok(ColumnarValue::Array(array.clone())),
+            ColumnarValue::Scalar(_) => unimplemented!(),
+        }
+    }
+}
+
+#[tokio::test]
+async fn verify_udf_return_type() -> Result<()> {
+    // Create a new ScalarUDF from the implementation
+    let take = ScalarUDF::from(TakeUDF::new());
+
+    // SELECT
+    //   take(smallint_col, double_col, 0) as take0,
+    //   take(smallint_col, double_col, 1) as take1
+    // FROM alltypes_plain;
+    let exprs = vec![
+        take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
+            .alias("take0"),
+        take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
+            .alias("take1"),
+    ];
+
+    let ctx = SessionContext::new();
+    register_alltypes_parquet(&ctx).await?;
+
+    let df = ctx.table("alltypes_plain").await?.select(exprs)?;
+
+    let schema = df.schema();
+
+    // The output schema should be
+    // * type of column smallint_col (int32)
+    // * type of column double_col (float64)
+    assert_eq!(schema.field(0).data_type(), &DataType::Int32);
+    assert_eq!(schema.field(1).data_type(), &DataType::Float64);
+
+    let expected = [
+        "+-------+-------+",
+        "| take0 | take1 |",
+        "+-------+-------+",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 0     | 0.0   |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "| 1     | 10.1  |",
+        "+-------+-------+",
+    ];
+    assert_batches_sorted_eq!(&expected, &df.collect().await?);
+
+    Ok(())
+}
+
 fn create_udf_context() -> SessionContext {
     let ctx = SessionContext::new();
     // register a custom UDF
@@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> 
Result<()> {
     Ok(())
 }
 
+async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
+    let testdata = datafusion::test_util::parquet_test_data();
+    ctx.register_parquet(
+        "alltypes_plain",
+        &format!("{testdata}/alltypes_plain.parquet"),
+        ParquetReadOptions::default(),
+    )
+    .await?;
+    Ok(())
+}
+
 /// Execute SQL and return results as a RecordBatch
 async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> 
Result<Vec<RecordBatch>> {
     ctx.sql(sql).await?.collect().await
diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs
index 517d7a35f7..491b4a8522 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -28,8 +28,8 @@ use crate::{utils, LogicalPlan, Projection, Subquery};
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field};
 use datafusion_common::{
-    internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
-    DataFusionError, ExprSchema, Result,
+    internal_err, plan_datafusion_err, plan_err, Column, DFField, 
DataFusionError,
+    ExprSchema, Result,
 };
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -37,26 +37,28 @@ use std::sync::Arc;
 /// trait to allow expr to typable with respect to a schema
 pub trait ExprSchemable {
     /// given a schema, return the type of the expr
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
 
     /// given a schema, return the nullability of the expr
-    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
+    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
 
     /// given a schema, return the expr's optional metadata
-    fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, 
String>>;
+    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, 
String>>;
 
     /// convert to a field with respect to a schema
-    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
+    fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>;
 
     /// cast to a type with respect to a schema
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> 
Result<Expr>;
+    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> 
Result<Expr>;
 }
 
 impl ExprSchemable for Expr {
     /// Returns the [arrow::datatypes::DataType] of the expression
     /// based on [ExprSchema]
     ///
-    /// Note: [DFSchema] implements [ExprSchema].
+    /// Note: [`DFSchema`] implements [ExprSchema].
+    ///
+    /// [`DFSchema`]: datafusion_common::DFSchema
     ///
     /// # Examples
     ///
@@ -90,7 +92,7 @@ impl ExprSchemable for Expr {
     /// expression refers to a column that does not exist in the
     /// schema, or when the expression is incorrectly typed
     /// (e.g. `[utf8] + [bool]`).
-    fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
         match self {
             Expr::Alias(Alias { expr, name, .. }) => match &**expr {
                 Expr::Placeholder(Placeholder { data_type, .. }) => match 
&data_type {
@@ -136,7 +138,7 @@ impl ExprSchemable for Expr {
                         fun.return_type(&arg_data_types)
                     }
                     ScalarFunctionDefinition::UDF(fun) => {
-                        Ok(fun.return_type(&arg_data_types)?)
+                        Ok(fun.return_type_from_exprs(args, schema)?)
                     }
                     ScalarFunctionDefinition::Name(_) => {
                         internal_err!("Function `Expr` with name should be 
resolved.")
@@ -213,14 +215,16 @@ impl ExprSchemable for Expr {
 
     /// Returns the nullability of the expression based on [ExprSchema].
     ///
-    /// Note: [DFSchema] implements [ExprSchema].
+    /// Note: [`DFSchema`] implements [ExprSchema].
+    ///
+    /// [`DFSchema`]: datafusion_common::DFSchema
     ///
     /// # Errors
     ///
     /// This function errors when it is not possible to compute its
     /// nullability.  This happens when the expression refers to a
     /// column that does not exist in the schema.
-    fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
+    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
         match self {
             Expr::Alias(Alias { expr, .. })
             | Expr::Not(expr)
@@ -327,7 +331,7 @@ impl ExprSchemable for Expr {
         }
     }
 
-    fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, 
String>> {
+    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, 
String>> {
         match self {
             Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
             Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
@@ -339,7 +343,7 @@ impl ExprSchemable for Expr {
     ///
     /// So for example, a projected expression `col(c1) + col(c2)` is
     /// placed in an output field **named** col("c1 + c2")
-    fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
+    fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> {
         match self {
             Expr::Column(c) => Ok(DFField::new(
                 c.relation.clone(),
@@ -370,7 +374,7 @@ impl ExprSchemable for Expr {
     ///
     /// This function errors when it is impossible to cast the
     /// expression to the target [arrow::datatypes::DataType].
-    fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> 
Result<Expr> {
+    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> 
Result<Expr> {
         let this_type = self.get_type(schema)?;
         if this_type == *cast_to_type {
             return Ok(self);
@@ -394,10 +398,10 @@ impl ExprSchemable for Expr {
 }
 
 /// return the schema [`Field`] for the type referenced by `get_indexed_field`
-fn field_for_index<S: ExprSchema>(
+fn field_for_index(
     expr: &Expr,
     field: &GetFieldAccess,
-    schema: &S,
+    schema: &dyn ExprSchema,
 ) -> Result<Field> {
     let expr_dt = expr.get_type(schema)?;
     match field {
@@ -457,7 +461,7 @@ mod tests {
     use super::*;
     use crate::{col, lit};
     use arrow::datatypes::{DataType, Fields};
-    use datafusion_common::{Column, ScalarValue, TableReference};
+    use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
 
     macro_rules! test_is_expr_nullable {
         ($EXPR_TYPE:ident) => {{
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 3017e1ec02..5b5d92a628 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -17,12 +17,13 @@
 
 //! [`ScalarUDF`]: Scalar User Defined Functions
 
+use crate::ExprSchemable;
 use crate::{
     ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
     ScalarFunctionImplementation, Signature,
 };
 use arrow::datatypes::DataType;
-use datafusion_common::Result;
+use datafusion_common::{ExprSchema, Result};
 use std::any::Any;
 use std::fmt;
 use std::fmt::Debug;
@@ -110,7 +111,7 @@ impl ScalarUDF {
     ///
     /// If you implement [`ScalarUDFImpl`] directly you should return aliases 
directly.
     pub fn with_aliases(self, aliases: impl IntoIterator<Item = &'static str>) 
-> Self {
-        Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases))
+        Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), 
aliases))
     }
 
     /// Returns a [`Expr`] logical expression to call this UDF with specified
@@ -146,10 +147,17 @@ impl ScalarUDF {
     }
 
     /// The datatype this function returns given the input argument input 
types.
+    /// This function is used when the input arguments are [`Expr`]s.
     ///
-    /// See [`ScalarUDFImpl::return_type`] for more details.
-    pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
-        self.inner.return_type(args)
+    ///
+    /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details.
+    pub fn return_type_from_exprs(
+        &self,
+        args: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        // If the implementation provides a return_type_from_exprs, use it
+        self.inner.return_type_from_exprs(args, schema)
     }
 
     /// Invoke the function on `args`, returning the appropriate result.
@@ -246,9 +254,54 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     fn signature(&self) -> &Signature;
 
     /// What [`DataType`] will be returned by this function, given the types of
-    /// the arguments
+    /// the arguments.
+    ///
+    /// # Notes
+    ///
+    /// If you provide an implementation for [`Self::return_type_from_exprs`],
+    /// DataFusion will not call `return_type` (this function). In this case it
+    /// is recommended to return [`DataFusionError::Internal`].
+    ///
+    /// [`DataFusionError::Internal`]: 
datafusion_common::DataFusionError::Internal
     fn return_type(&self, arg_types: &[DataType]) -> Result<DataType>;
 
+    /// What [`DataType`] will be returned by this function, given the
+    /// arguments?
+    ///
+    /// Note most UDFs should implement [`Self::return_type`] and not this
+    /// function. The output type for most functions only depends on the types
+    /// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+    ///
+    /// By default, this function calls [`Self::return_type`] with the
+    /// types of each argument.
+    ///
+    /// This method can be overridden for functions that return different
+    /// *types* based on the *values* of their arguments.
+    ///
+    /// For example, the following two function calls get the same argument
+    /// types (something and a `Utf8` string) but return different types based
+    /// on the value of the second argument:
+    ///
+    /// * `arrow_cast(x, 'Int16')` --> `Int16`
+    /// * `arrow_cast(x, 'Float32')` --> `Float32`
+    ///
+    /// # Notes:
+    ///
+    /// This function must consistently return the same type for the same
+    /// logical input even if the input is simplified (e.g. it must return the 
same
+    /// value for `('foo' | 'bar')` as it does for `('foobar')`).
+    fn return_type_from_exprs(
+        &self,
+        args: &[Expr],
+        schema: &dyn ExprSchema,
+    ) -> Result<DataType> {
+        let arg_types = args
+            .iter()
+            .map(|arg| arg.get_type(schema))
+            .collect::<Result<Vec<_>>>()?;
+        self.return_type(&arg_types)
+    }
+
     /// Invoke the function on `args`, returning the appropriate result
     ///
     /// The function will be invoked passed with the slice of [`ColumnarValue`]
@@ -290,13 +343,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible.
 #[derive(Debug)]
 struct AliasedScalarUDFImpl {
-    inner: ScalarUDF,
+    inner: Arc<dyn ScalarUDFImpl>,
     aliases: Vec<String>,
 }
 
 impl AliasedScalarUDFImpl {
     pub fn new(
-        inner: ScalarUDF,
+        inner: Arc<dyn ScalarUDFImpl>,
         new_aliases: impl IntoIterator<Item = &'static str>,
     ) -> Self {
         let mut aliases = inner.aliases().to_vec();
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs 
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 662e0fc7c2..fba77047dd 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -681,17 +681,17 @@ fn coerce_case_expression(case: Case, schema: 
&DFSchemaRef) -> Result<Case> {
     let case_type = case
         .expr
         .as_ref()
-        .map(|expr| expr.get_type(&schema))
+        .map(|expr| expr.get_type(schema))
         .transpose()?;
     let then_types = case
         .when_then_expr
         .iter()
-        .map(|(_when, then)| then.get_type(&schema))
+        .map(|(_when, then)| then.get_type(schema))
         .collect::<Result<Vec<_>>>()?;
     let else_type = case
         .else_expr
         .as_ref()
-        .map(|expr| expr.get_type(&schema))
+        .map(|expr| expr.get_type(schema))
         .transpose()?;
 
     // find common coercible types
@@ -701,7 +701,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) 
-> Result<Case> {
             let when_types = case
                 .when_then_expr
                 .iter()
-                .map(|(when, _then)| when.get_type(&schema))
+                .map(|(when, _then)| when.get_type(schema))
                 .collect::<Result<Vec<_>>>()?;
             let coerced_type =
                 get_coerce_type_for_case_expression(&when_types, 
Some(case_type));
@@ -727,7 +727,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) 
-> Result<Case> {
     let case_expr = case
         .expr
         .zip(case_when_coerce_type.as_ref())
-        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, 
&schema))
+        .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, 
schema))
         .transpose()?
         .map(Box::new);
     let when_then = case
@@ -735,7 +735,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) 
-> Result<Case> {
         .into_iter()
         .map(|(when, then)| {
             let when_type = 
case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean);
-            let when = when.cast_to(when_type, &schema).map_err(|e| {
+            let when = when.cast_to(when_type, schema).map_err(|e| {
                 DataFusionError::Context(
                     format!(
                         "WHEN expressions in CASE couldn't be \
@@ -744,13 +744,13 @@ fn coerce_case_expression(case: Case, schema: 
&DFSchemaRef) -> Result<Case> {
                     Box::new(e),
                 )
             })?;
-            let then = then.cast_to(&then_else_coerce_type, &schema)?;
+            let then = then.cast_to(&then_else_coerce_type, schema)?;
             Ok((Box::new(when), Box::new(then)))
         })
         .collect::<Result<Vec<_>>>()?;
     let else_expr = case
         .else_expr
-        .map(|expr| expr.cast_to(&then_else_coerce_type, &schema))
+        .map(|expr| expr.cast_to(&then_else_coerce_type, schema))
         .transpose()?
         .map(Box::new);
 
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index 6408af5cda..b8491aea2d 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -272,11 +272,15 @@ pub fn create_physical_expr(
                         execution_props,
                     )
                 }
-                ScalarFunctionDefinition::UDF(fun) => 
udf::create_physical_expr(
-                    fun.clone().as_ref(),
-                    &physical_args,
-                    input_schema,
-                ),
+                ScalarFunctionDefinition::UDF(fun) => {
+                    let return_type = fun.return_type_from_exprs(args, 
input_dfschema)?;
+
+                    udf::create_physical_expr(
+                        fun.clone().as_ref(),
+                        &physical_args,
+                        return_type,
+                    )
+                }
                 ScalarFunctionDefinition::Name(_) => {
                     internal_err!("Function `Expr` with name should be 
resolved.")
                 }
diff --git a/datafusion/physical-expr/src/udf.rs 
b/datafusion/physical-expr/src/udf.rs
index e0117fecb4..d9c7c9e5c2 100644
--- a/datafusion/physical-expr/src/udf.rs
+++ b/datafusion/physical-expr/src/udf.rs
@@ -17,28 +17,24 @@
 
 //! UDF support
 use crate::{PhysicalExpr, ScalarFunctionExpr};
-use arrow::datatypes::Schema;
+use arrow_schema::DataType;
 use datafusion_common::Result;
 pub use datafusion_expr::ScalarUDF;
 use std::sync::Arc;
 
 /// Create a physical expression of the UDF.
-/// This function errors when `args`' can't be coerced to a valid argument 
type of the UDF.
+///
+/// Arguments:
 pub fn create_physical_expr(
     fun: &ScalarUDF,
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
-    input_schema: &Schema,
+    return_type: DataType,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    let input_exprs_types = input_phy_exprs
-        .iter()
-        .map(|e| e.data_type(input_schema))
-        .collect::<Result<Vec<_>>>()?;
-
     Ok(Arc::new(ScalarFunctionExpr::new(
         fun.name(),
         fun.fun(),
         input_phy_exprs.to_vec(),
-        fun.return_type(&input_exprs_types)?,
+        return_type,
         fun.monotonicity()?,
         fun.signature().type_signature.supports_zero_argument(),
     )))
@@ -46,7 +42,6 @@ pub fn create_physical_expr(
 
 #[cfg(test)]
 mod tests {
-    use arrow::datatypes::Schema;
     use arrow_schema::DataType;
     use datafusion_common::Result;
     use datafusion_expr::{
@@ -102,7 +97,7 @@ mod tests {
         // create and register the udf
         let udf = ScalarUDF::from(TestScalarUDF::new());
 
-        let p_expr = create_physical_expr(&udf, &[], &Schema::empty())?;
+        let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?;
 
         assert_eq!(
             p_expr

Reply via email to