This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a6e6d3fab0 Refactor function argument handling in ScalarFunctionDefinition (#8387)
a6e6d3fab0 is described below

commit a6e6d3fab083839239ef81cf3a3546dd8929a541
Author: Alex Huang <[email protected]>
AuthorDate: Fri Dec 1 21:14:19 2023 +0100

    Refactor function argument handling in ScalarFunctionDefinition (#8387)
---
 datafusion/expr/src/expr_schema.rs            | 15 ++----
 datafusion/physical-expr/src/planner.rs       | 66 +++++++++++----------------
 datafusion/proto/src/logical_plan/to_proto.rs | 53 +++++++++++----------
 3 files changed, 58 insertions(+), 76 deletions(-)

diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 2795ac5f09..e5b0185d90 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -83,13 +83,12 @@ impl ExprSchemable for Expr {
             Expr::Cast(Cast { data_type, .. })
             | Expr::TryCast(TryCast { data_type, .. }) => 
Ok(data_type.clone()),
             Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+                let arg_data_types = args
+                    .iter()
+                    .map(|e| e.get_type(schema))
+                    .collect::<Result<Vec<_>>>()?;
                 match func_def {
                     ScalarFunctionDefinition::BuiltIn(fun) => {
-                        let arg_data_types = args
-                            .iter()
-                            .map(|e| e.get_type(schema))
-                            .collect::<Result<Vec<_>>>()?;
-
                         // verify that input data types is consistent with 
function's `TypeSignature`
                         data_types(&arg_data_types, 
&fun.signature()).map_err(|_| {
                             plan_datafusion_err!(
@@ -105,11 +104,7 @@ impl ExprSchemable for Expr {
                         fun.return_type(&arg_data_types)
                     }
                     ScalarFunctionDefinition::UDF(fun) => {
-                        let data_types = args
-                            .iter()
-                            .map(|e| e.get_type(schema))
-                            .collect::<Result<Vec<_>>>()?;
-                        Ok(fun.return_type(&data_types)?)
+                        Ok(fun.return_type(&arg_data_types)?)
                     }
                     ScalarFunctionDefinition::Name(_) => {
                         internal_err!("Function `Expr` with name should be 
resolved.")
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 5501647da2..9c212cb81f 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -348,50 +348,38 @@ pub fn create_physical_expr(
             )))
         }
 
-        Expr::ScalarFunction(ScalarFunction { func_def, args }) => match 
func_def {
-            ScalarFunctionDefinition::BuiltIn(fun) => {
-                let physical_args = args
-                    .iter()
-                    .map(|e| {
-                        create_physical_expr(
-                            e,
-                            input_dfschema,
-                            input_schema,
-                            execution_props,
-                        )
-                    })
-                    .collect::<Result<Vec<_>>>()?;
-                functions::create_physical_expr(
-                    fun,
-                    &physical_args,
-                    input_schema,
-                    execution_props,
-                )
-            }
-            ScalarFunctionDefinition::UDF(fun) => {
-                let mut physical_args = vec![];
-                for e in args {
-                    physical_args.push(create_physical_expr(
-                        e,
-                        input_dfschema,
+        Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+            let mut physical_args = args
+                .iter()
+                .map(|e| {
+                    create_physical_expr(e, input_dfschema, input_schema, 
execution_props)
+                })
+                .collect::<Result<Vec<_>>>()?;
+            match func_def {
+                ScalarFunctionDefinition::BuiltIn(fun) => {
+                    functions::create_physical_expr(
+                        fun,
+                        &physical_args,
                         input_schema,
                         execution_props,
-                    )?);
+                    )
+                }
+                ScalarFunctionDefinition::UDF(fun) => {
+                    // udfs with zero params expect null array as input
+                    if args.is_empty() {
+                        
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+                    }
+                    udf::create_physical_expr(
+                        fun.clone().as_ref(),
+                        &physical_args,
+                        input_schema,
+                    )
                 }
-                // udfs with zero params expect null array as input
-                if args.is_empty() {
-                    
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+                ScalarFunctionDefinition::Name(_) => {
+                    internal_err!("Function `Expr` with name should be 
resolved.")
                 }
-                udf::create_physical_expr(
-                    fun.clone().as_ref(),
-                    &physical_args,
-                    input_schema,
-                )
             }
-            ScalarFunctionDefinition::Name(_) => {
-                internal_err!("Function `Expr` with name should be resolved.")
-            }
-        },
+        }
         Expr::Between(Between {
             expr,
             negated,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index ab8e850014..ecbfaca5db 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -792,40 +792,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                         .to_string(),
                 ))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => match 
func_def {
-                ScalarFunctionDefinition::BuiltIn(fun) => {
-                    let fun: protobuf::ScalarFunction = fun.try_into()?;
-                    let args: Vec<Self> = args
-                        .iter()
-                        .map(|e| e.try_into())
-                        .collect::<Result<Vec<Self>, Error>>()?;
-                    Self {
-                        expr_type: Some(ExprType::ScalarFunction(
-                            protobuf::ScalarFunctionNode {
-                                fun: fun.into(),
+            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+                let args = args
+                    .iter()
+                    .map(|expr| expr.try_into())
+                    .collect::<Result<Vec<_>, Error>>()?;
+                match func_def {
+                    ScalarFunctionDefinition::BuiltIn(fun) => {
+                        let fun: protobuf::ScalarFunction = fun.try_into()?;
+                        Self {
+                            expr_type: Some(ExprType::ScalarFunction(
+                                protobuf::ScalarFunctionNode {
+                                    fun: fun.into(),
+                                    args,
+                                },
+                            )),
+                        }
+                    }
+                    ScalarFunctionDefinition::UDF(fun) => Self {
+                        expr_type: Some(ExprType::ScalarUdfExpr(
+                            protobuf::ScalarUdfExprNode {
+                                fun_name: fun.name().to_string(),
                                 args,
                             },
                         )),
-                    }
-                }
-                ScalarFunctionDefinition::UDF(fun) => Self {
-                    expr_type: Some(ExprType::ScalarUdfExpr(
-                        protobuf::ScalarUdfExprNode {
-                            fun_name: fun.name().to_string(),
-                            args: args
-                                .iter()
-                                .map(|expr| expr.try_into())
-                                .collect::<Result<Vec<_>, Error>>()?,
-                        },
-                    )),
-                },
-                ScalarFunctionDefinition::Name(_) => {
-                    return Err(Error::NotImplemented(
+                    },
+                    ScalarFunctionDefinition::Name(_) => {
+                        return Err(Error::NotImplemented(
                     "Proto serialization error: Trying to serialize a 
unresolved function"
                         .to_string(),
                 ));
+                    }
                 }
-            },
+            }
             Expr::Not(expr) => {
                 let expr = Box::new(protobuf::Not {
                     expr: Some(Box::new(expr.as_ref().try_into()?)),

Reply via email to