This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 742e3c5b3f Remove ScalarFunctionDefinition (#10325)
742e3c5b3f is described below

commit 742e3c5b3f5d4d961d50cc77033ab7774b90c56f
Author: 张林伟 <[email protected]>
AuthorDate: Wed May 8 00:49:57 2024 +0800

    Remove ScalarFunctionDefinition (#10325)
    
    * Remove ScalarFunctionDefinition
    
    * Fix test
    
    * rename func_def to func
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/datasource/listing/helpers.rs  | 18 +++-----
 .../src/physical_optimizer/projection_pushdown.rs  | 19 ++------
 datafusion/expr/src/expr.rs                        | 54 ++++------------------
 datafusion/expr/src/expr_schema.rs                 | 16 +++----
 datafusion/expr/src/lib.rs                         |  2 +-
 datafusion/expr/src/tree_node.rs                   | 12 ++---
 datafusion/functions-array/src/rewrite.rs          |  8 ++--
 datafusion/functions/src/math/log.rs               | 19 ++------
 datafusion/functions/src/math/power.rs             | 17 ++-----
 datafusion/functions/src/string/concat.rs          |  4 +-
 datafusion/functions/src/string/concat_ws.rs       |  4 +-
 datafusion/optimizer/src/analyzer/type_coercion.rs | 24 ++++------
 .../optimizer/src/optimize_projections/mod.rs      |  4 +-
 datafusion/optimizer/src/push_down_filter.rs       |  7 +--
 .../src/simplify_expressions/expr_simplifier.rs    | 32 ++++++-------
 datafusion/physical-expr/src/planner.rs            | 22 ++++-----
 datafusion/physical-expr/src/scalar_function.rs    | 36 ++++++---------
 datafusion/proto/src/logical_plan/to_proto.rs      | 34 ++++++--------
 datafusion/proto/src/physical_plan/from_proto.rs   |  3 +-
 datafusion/proto/src/physical_plan/to_proto.rs     |  7 +--
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  9 ++--
 datafusion/sql/src/unparser/expr.rs                |  4 +-
 22 files changed, 124 insertions(+), 231 deletions(-)
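
For downstream code that builds Expr::ScalarFunction values directly, the change is
mechanical: the single-variant ScalarFunctionDefinition::UDF wrapper is gone and the
Arc<ScalarUDF> now lives in the `func` field. A minimal sketch of the call-site
migration (the function and variable names below are illustrative, not part of this
commit):

    // Sketch only: how an Expr::ScalarFunction is built after this change.
    use std::sync::Arc;

    use datafusion_expr::{expr::ScalarFunction, Expr, ScalarUDF};

    fn call_udf(udf: Arc<ScalarUDF>, args: Vec<Expr>) -> Expr {
        // Previously the UDF had to be wrapped:
        //     ScalarFunction { func_def: ScalarFunctionDefinition::UDF(udf), args }
        // Now the Arc<ScalarUDF> is stored directly; new_udf keeps the same signature.
        Expr::ScalarFunction(ScalarFunction::new_udf(udf, args))
    }

Pattern matches change the same way: ScalarFunction { func, args } binds the
Arc<ScalarUDF> directly instead of going through the removed enum.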

diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 09d9aa8811..0cffa05131 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -38,7 +38,7 @@ use log::{debug, trace};
 
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion_common::{Column, DFSchema, DataFusionError};
-use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{Expr, Volatility};
 use datafusion_physical_expr::create_physical_expr;
 use object_store::path::Path;
 use object_store::{ObjectMeta, ObjectStore};
@@ -89,16 +89,12 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
             | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),
 
             Expr::ScalarFunction(scalar_function) => {
-                match &scalar_function.func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
-                        match fun.signature().volatility {
-                            Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
-                            // TODO: Stable functions could be `applicable`, but that would require access to the context
-                            Volatility::Stable | Volatility::Volatile => {
-                                is_applicable = false;
-                                Ok(TreeNodeRecursion::Stop)
-                            }
-                        }
+                match scalar_function.func.signature().volatility {
+                    Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
+                    // TODO: Stable functions could be `applicable`, but that would require access to the context
+                    Volatility::Stable | Volatility::Volatile => {
+                        is_applicable = false;
+                        Ok(TreeNodeRecursion::Stop)
                     }
                 }
             }
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 160dd3a1c4..0190f35cc9 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -1301,8 +1301,7 @@ mod tests {
     use datafusion_execution::object_store::ObjectStoreUrl;
     use datafusion_execution::{SendableRecordBatchStream, TaskContext};
     use datafusion_expr::{
-        ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
-        Signature, Volatility,
+        ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
     };
     use datafusion_physical_expr::expressions::{
         BinaryExpr, CaseExpr, CastExpr, NegativeExpr,
@@ -1363,9 +1362,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1431,9 +1428,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1502,9 +1497,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1570,9 +1563,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b_new", 1)),
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c154cd999a..9789dd345f 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -26,11 +26,11 @@ use std::sync::Arc;
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
 use crate::utils::expr_to_columns;
-use crate::window_frame;
 use crate::{
     aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
     Signature,
 };
+use crate::{window_frame, Volatility};
 
 use arrow::datatypes::{DataType, FieldRef};
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -399,18 +399,11 @@ impl Between {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-/// Defines which implementation of a function for DataFusion to call.
-pub enum ScalarFunctionDefinition {
-    /// Resolved to a user defined function
-    UDF(Arc<crate::ScalarUDF>),
-}
-
 /// ScalarFunction expression invokes a built-in scalar function
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct ScalarFunction {
     /// The function
-    pub func_def: ScalarFunctionDefinition,
+    pub func: Arc<crate::ScalarUDF>,
     /// List of expressions to feed to the functions as arguments
     pub args: Vec<Expr>,
 }
@@ -418,41 +411,14 @@ pub struct ScalarFunction {
 impl ScalarFunction {
     // return the Function's name
     pub fn name(&self) -> &str {
-        self.func_def.name()
-    }
-}
-
-impl ScalarFunctionDefinition {
-    /// Function's name for display
-    pub fn name(&self) -> &str {
-        match self {
-            ScalarFunctionDefinition::UDF(udf) => udf.name(),
-        }
-    }
-
-    /// Whether this function is volatile, i.e. whether it can return different results
-    /// when evaluated multiple times with the same input.
-    pub fn is_volatile(&self) -> Result<bool> {
-        match self {
-            ScalarFunctionDefinition::UDF(udf) => {
-                Ok(udf.signature().volatility == crate::Volatility::Volatile)
-            }
-        }
+        self.func.name()
     }
 }
 
 impl ScalarFunction {
     /// Create a new ScalarFunction expression with a user-defined function (UDF)
     pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
-        Self {
-            func_def: ScalarFunctionDefinition::UDF(udf),
-            args,
-        }
-    }
-
-    /// Create a new ScalarFunction expression with a user-defined function (UDF)
-    pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>) -> Self {
-        Self { func_def, args }
+        Self { func: udf, args }
     }
 }
 
@@ -1299,7 +1265,7 @@ impl Expr {
     /// results when evaluated multiple times with the same input.
     pub fn is_volatile(&self) -> Result<bool> {
         self.exists(|expr| {
-            Ok(matches!(expr, Expr::ScalarFunction(func) if func.func_def.is_volatile()?))
+            Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
         })
     }
 
@@ -1334,9 +1300,7 @@ impl Expr {
     /// and thus any side effects (like divide by zero) may not be encountered
     pub fn short_circuits(&self) -> bool {
         match self {
-            Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
-                matches!(func_def, ScalarFunctionDefinition::UDF(fun) if fun.short_circuits())
-            }
+            Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(),
             Expr::BinaryExpr(BinaryExpr { op, .. }) => {
                 matches!(op, Operator::And | Operator::Or)
             }
@@ -2071,7 +2035,7 @@ mod test {
     }
 
     #[test]
-    fn test_is_volatile_scalar_func_definition() {
+    fn test_is_volatile_scalar_func() {
         // UDF
         #[derive(Debug)]
         struct TestScalarUDF {
@@ -2100,7 +2064,7 @@ mod test {
         let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
             signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
         }));
-        assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+        assert_ne!(udf.signature().volatility, Volatility::Volatile);
 
         let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
             signature: Signature::uniform(
@@ -2109,7 +2073,7 @@ mod test {
                 Volatility::Volatile,
             ),
         }));
-        assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+        assert_eq!(udf.signature().volatility, Volatility::Volatile);
     }
 
     use super::*;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index f93f085749..4aca52d67c 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
 use crate::expr::{
     AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
     GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
-    ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
+    Sort, TryCast, Unnest, WindowFunction,
 };
 use crate::field_util::GetFieldAccessSchema;
 use crate::type_coercion::binary::get_result_type;
@@ -133,20 +133,18 @@ impl ExprSchemable for Expr {
                     }
                 }
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 let arg_data_types = args
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                match func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
                         // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
-                        data_types(&arg_data_types, fun.signature()).map_err(|_| {
+                        data_types(&arg_data_types, func.signature()).map_err(|_| {
                             plan_datafusion_err!(
                                 "{}",
                                 utils::generate_signature_error_msg(
-                                    fun.name(),
-                                    fun.signature().clone(),
+                                    func.name(),
+                                    func.signature().clone(),
                                     &arg_data_types,
                                 )
                             )
@@ -154,9 +152,7 @@ impl ExprSchemable for Expr {
 
                         // perform additional function arguments validation (due to limited
                         // expressiveness of `TypeSignature`), then infer return type
-                        Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
-                    }
-                }
+                        Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
             }
             Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
                 let data_types = args
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index de4f310292..e2b68388ab 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -63,7 +63,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
 pub use columnar_value::ColumnarValue;
 pub use expr::{
     Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
-    Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
+    Like, TryCast, WindowFunctionDefinition,
 };
 pub use expr_fn::*;
 pub use expr_schema::ExprSchemable;
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index ae3ca9afc4..710164eca3 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -20,7 +20,7 @@
 use crate::expr::{
     AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
     Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
-    ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
+    ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
 };
 use crate::{Expr, GetFieldAccess};
 
@@ -281,11 +281,11 @@ impl TreeNode for Expr {
                 nulls_first,
             }) => transform_box(expr, &mut f)?
                 .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))),
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
-                transform_vec(args, &mut f)?.map_data(|new_args| match func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
-                        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args)))
-                    }
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                transform_vec(args, &mut f)?.map_data(|new_args| {
+                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        func, new_args,
+                    )))
                 })?
             }
             Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
index 32d15b5563..416e79cbc0 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -182,20 +182,20 @@ impl FunctionRewrite for ArrayFunctionRewriter {
 /// Returns true if expr is a function call to the specified named function.
 /// Returns false otherwise.
 fn is_func(expr: &Expr, func_name: &str) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
+    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
         return false;
     };
 
-    func_def.name() == func_name
+    func.name() == func_name
 }
 
 /// Returns true if expr is a function call with one of the specified names
 fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
+    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
         return false;
     };
 
-    func_names.contains(&func_def.name())
+    func_names.contains(&func.name())
 }
 
 /// returns Some(col) if this is Expr::Column
diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index f451321ea1..e6c698ad1a 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -24,9 +24,7 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{
-    lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
-};
+use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
 
 use arrow::array::{ArrayRef, Float32Array, Float64Array};
 use datafusion_expr::TypeSignature::*;
@@ -178,8 +176,8 @@ impl ScalarUDFImpl for LogFunc {
                     &info.get_data_type(&base)?,
                 )?)))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, mut args })
-                if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
+            Expr::ScalarFunction(ScalarFunction { func, mut args })
+                if is_pow(&func) && args.len() == 2 && base == args[0] =>
             {
                 let b = args.pop().unwrap(); // length checked above
                 Ok(ExprSimplifyResult::Simplified(b))
@@ -207,15 +205,8 @@ impl ScalarUDFImpl for LogFunc {
 }
 
 /// Returns true if the function is `PowerFunc`
-fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
-    match func_def {
-        ScalarFunctionDefinition::UDF(fun) => fun
-            .as_ref()
-            .inner()
-            .as_any()
-            .downcast_ref::<PowerFunc>()
-            .is_some(),
-    }
+fn is_pow(func: &ScalarUDF) -> bool {
+    func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs
index 8cc6b4c02a..7677e8b2af 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -23,7 +23,7 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
+use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
 
 use arrow::array::{ArrayRef, Float64Array, Int64Array};
 use datafusion_expr::TypeSignature::*;
@@ -140,8 +140,8 @@ impl ScalarUDFImpl for PowerFunc {
             Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
                 Ok(ExprSimplifyResult::Simplified(base))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, mut args })
-                if is_log(&func_def) && args.len() == 2 && base == args[0] =>
+            Expr::ScalarFunction(ScalarFunction { func, mut args })
+                if is_log(&func) && args.len() == 2 && base == args[0] =>
             {
                 let b = args.pop().unwrap(); // length checked above
                 Ok(ExprSimplifyResult::Simplified(b))
@@ -152,15 +152,8 @@ impl ScalarUDFImpl for PowerFunc {
 }
 
 /// Return true if this function call is a call to `Log`
-fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
-    match func_def {
-        ScalarFunctionDefinition::UDF(fun) => fun
-            .as_ref()
-            .inner()
-            .as_any()
-            .downcast_ref::<LogFunc>()
-            .is_some(),
-    }
+fn is_log(func: &ScalarUDF) -> bool {
+    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index 55b7c2f222..6d15e22067 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -25,7 +25,7 @@ use datafusion_common::cast::as_string_array;
 use datafusion_common::{internal_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 use crate::string::common::*;
@@ -182,7 +182,7 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
     if !args.eq(&new_args) {
         Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
             ScalarFunction {
-                func_def: ScalarFunctionDefinition::UDF(concat()),
+                func: concat(),
                 args: new_args,
             },
         )))
diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs
index 1d27712b2c..4d05f4e707 100644
--- a/datafusion/functions/src/string/concat_ws.rs
+++ b/datafusion/functions/src/string/concat_ws.rs
@@ -26,7 +26,7 @@ use datafusion_common::cast::as_string_array;
 use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 use crate::string::common::*;
@@ -266,7 +266,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyRes
 
                     Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
                         ScalarFunction {
-                            func_def: ScalarFunctionDefinition::UDF(concat_ws()),
+                            func: concat_ws(),
                             args: new_args,
                         },
                     )))
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 10479f29a5..61b1d1d77b 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -46,8 +46,7 @@ use datafusion_expr::utils::merge_schema;
 use datafusion_expr::{
     is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not,
     type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator,
-    ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, WindowFrameBound,
-    WindowFrameUnits,
+    ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits,
 };
 
 use crate::analyzer::AnalyzerRule;
@@ -303,19 +302,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 let case = coerce_case_expression(case, self.schema)?;
                 Ok(Transformed::yes(Expr::Case(case)))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    let new_expr = coerce_arguments_for_signature(
-                        args,
-                        self.schema,
-                        fun.signature(),
-                    )?;
-                    let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
-                    Ok(Transformed::yes(Expr::ScalarFunction(
-                        ScalarFunction::new_udf(fun, new_expr),
-                    )))
-                }
-            },
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                let new_expr =
+                    coerce_arguments_for_signature(args, self.schema, func.signature())?;
+                let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?;
+                Ok(Transformed::yes(Expr::ScalarFunction(
+                    ScalarFunction::new_udf(func, new_expr),
+                )))
+            }
             Expr::AggregateFunction(expr::AggregateFunction {
                 func_def,
                 args,
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 0f2aaa6cbc..aa2d005379 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -555,8 +555,8 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
                 .map(|expr| rewrite_expr(expr, input))
                 .collect::<Result<Option<_>>>()?
                 .map(|new_args| {
-                    Expr::ScalarFunction(ScalarFunction::new_func_def(
-                        scalar_fn.func_def.clone(),
+                    Expr::ScalarFunction(ScalarFunction::new_udf(
+                        scalar_fn.func.clone(),
                         new_args,
                     ))
                 }));
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 2355ee604e..f58345237b 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -35,7 +35,7 @@ use datafusion_expr::logical_plan::{
 use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
 use datafusion_expr::{
     and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
-    ScalarFunctionDefinition, TableProviderFilterPushDown,
+    TableProviderFilterPushDown,
 };
 
 use crate::optimizer::ApplyOrder;
@@ -228,10 +228,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
         | Expr::ScalarSubquery(_)
         | Expr::OuterReferenceColumn(_, _)
         | Expr::Unnest(_)
-        | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
-            func_def: ScalarFunctionDefinition::UDF(_),
-            ..
-        }) => {
+        | Expr::ScalarFunction(_) => {
             is_evaluate = false;
             Ok(TreeNodeRecursion::Stop)
         }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 0f711d6a2c..5122de4f09 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -35,8 +35,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarV
 use datafusion_expr::expr::{InList, InSubquery};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
-    and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
-    ScalarFunctionDefinition, Volatility,
+    and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
 };
 use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
 use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
@@ -595,11 +594,9 @@ impl<'a> ConstEvaluator<'a> {
             | Expr::GroupingSet(_)
             | Expr::Wildcard { .. }
             | Expr::Placeholder(_) => false,
-            Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    Self::volatility_ok(fun.signature().volatility)
-                }
-            },
+            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
+                Self::volatility_ok(func.signature().volatility)
+            }
             Expr::Literal(_)
             | Expr::Unnest(_)
             | Expr::BinaryExpr { .. }
@@ -1373,18 +1370,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 // Do a first pass at simplification
                 out_expr.rewrite(self)?
             }
-            Expr::ScalarFunction(ScalarFunction {
-                func_def: ScalarFunctionDefinition::UDF(udf),
-                args,
-            }) => match udf.simplify(args, info)? {
-                ExprSimplifyResult::Original(args) => {
-                    Transformed::no(Expr::ScalarFunction(ScalarFunction {
-                        func_def: ScalarFunctionDefinition::UDF(udf),
-                        args,
-                    }))
+            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
+                match udf.simplify(args, info)? {
+                    ExprSimplifyResult::Original(args) => {
+                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
+                            func: udf,
+                            args,
+                        }))
+                    }
+                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
                 }
-                ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
-            },
+            }
 
             //
             // Rules for Between
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 2621b817b2..ab57a8e800 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -28,7 +28,7 @@ use datafusion_expr::var_provider::is_system_variables;
 use datafusion_expr::var_provider::VarType;
 use datafusion_expr::{
     binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like,
-    Operator, ScalarFunctionDefinition, TryCast,
+    Operator, TryCast,
 };
 
 use crate::scalar_function;
@@ -305,21 +305,17 @@ pub fn create_physical_expr(
             }
         },
 
-        Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+        Expr::ScalarFunction(ScalarFunction { func, args }) => {
             let physical_args =
                 create_physical_exprs(args, input_dfschema, execution_props)?;
 
-            match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    scalar_function::create_physical_expr(
-                        fun.clone().as_ref(),
-                        &physical_args,
-                        input_schema,
-                        args,
-                        input_dfschema,
-                    )
-                }
-            }
+            scalar_function::create_physical_expr(
+                func.clone().as_ref(),
+                &physical_args,
+                input_schema,
+                args,
+                input_dfschema,
+            )
         }
         Expr::Between(Between {
             expr,
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 6b84b81e9f..180f2a7946 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -40,10 +40,7 @@ use arrow::record_batch::RecordBatch;
 
 use datafusion_common::{internal_err, DFSchema, Result};
 use datafusion_expr::type_coercion::functions::data_types;
-use datafusion_expr::{
-    expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
-    ScalarUDF,
-};
+use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
 
 use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal};
 use crate::sort_properties::SortProperties;
@@ -51,7 +48,7 @@ use crate::PhysicalExpr;
 
 /// Physical expression of a scalar function
 pub struct ScalarFunctionExpr {
-    fun: ScalarFunctionDefinition,
+    fun: Arc<ScalarUDF>,
     name: String,
     args: Vec<Arc<dyn PhysicalExpr>>,
     return_type: DataType,
@@ -78,7 +75,7 @@ impl ScalarFunctionExpr {
     /// Create a new Scalar function
     pub fn new(
         name: &str,
-        fun: ScalarFunctionDefinition,
+        fun: Arc<ScalarUDF>,
         args: Vec<Arc<dyn PhysicalExpr>>,
         return_type: DataType,
         monotonicity: Option<FuncMonotonicity>,
@@ -93,7 +90,7 @@ impl ScalarFunctionExpr {
     }
 
     /// Get the scalar function implementation
-    pub fn fun(&self) -> &ScalarFunctionDefinition {
+    pub fn fun(&self) -> &ScalarUDF {
         &self.fun
     }
 
@@ -146,22 +143,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
             .collect::<Result<Vec<_>>>()?;
 
         // evaluate the function
-        match self.fun {
-            ScalarFunctionDefinition::UDF(ref fun) => {
-                let output = match self.args.is_empty() {
-                    true => fun.invoke_no_args(batch.num_rows()),
-                    false => fun.invoke(&inputs),
-                }?;
-
-                if let ColumnarValue::Array(array) = &output {
-                    if array.len() != batch.num_rows() {
-                        return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
+        let output = match self.args.is_empty() {
+            true => self.fun.invoke_no_args(batch.num_rows()),
+            false => self.fun.invoke(&inputs),
+        }?;
+
+        if let ColumnarValue::Array(array) = &output {
+            if array.len() != batch.num_rows() {
+                return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
                         batch.num_rows(), array.len());
-                    }
-                }
-                Ok(output)
             }
         }
+        Ok(output)
     }
 
     fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -233,10 +226,9 @@ pub fn create_physical_expr(
     let return_type =
         fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
 
-    let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone()));
     Ok(Arc::new(ScalarFunctionExpr::new(
         fun.name(),
-        fun_def,
+        Arc::new(fun.clone()),
         input_phy_exprs.to_vec(),
         return_type,
         fun.monotonicity()?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index dcec2a3b85..80acd12e4e 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -36,8 +36,8 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::{
     self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess,
-    GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction,
-    ScalarFunctionDefinition, Sort, Unnest,
+    GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, Sort,
+    Unnest,
 };
 use datafusion_expr::{
     logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -763,25 +763,19 @@ pub fn serialize_expr(
                 "Proto serialization error: Scalar Variable not supported".to_string(),
             ))
         }
-        Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+        Expr::ScalarFunction(ScalarFunction { func, args }) => {
             let args = serialize_exprs(args, codec)?;
-            match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    let mut buf = Vec::new();
-                    let _ = codec.try_encode_udf(fun.as_ref(), &mut buf);
-
-                    let fun_definition = if buf.is_empty() { None } else { Some(buf) };
-
-                    protobuf::LogicalExprNode {
-                        expr_type: Some(ExprType::ScalarUdfExpr(
-                            protobuf::ScalarUdfExprNode {
-                                fun_name: fun.name().to_string(),
-                                fun_definition,
-                                args,
-                            },
-                        )),
-                    }
-                }
+            let mut buf = Vec::new();
+            let _ = codec.try_encode_udf(func.as_ref(), &mut buf);
+
+            let fun_definition = if buf.is_empty() { None } else { Some(buf) };
+
+            protobuf::LogicalExprNode {
+                expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
+                    fun_name: func.name().to_string(),
+                    fun_definition,
+                    args,
+                })),
             }
         }
         Expr::Not(expr) => {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 33e632b0d9..4bd07fae49 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -53,7 +53,6 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue};
-use datafusion_expr::ScalarFunctionDefinition;
 
 use crate::common::proto_error;
 use crate::convert_required;
@@ -342,7 +341,7 @@ pub fn parse_physical_expr(
                 Some(buf) => codec.try_decode_udf(&e.name, buf)?,
                 None => registry.udf(e.name.as_str())?,
             };
-            let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone());
+            let scalar_fun_def = udf.clone();
 
             let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?;
 
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index a0a0ee7205..3bc71f5f4c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -56,7 +56,6 @@ use datafusion_common::{
     stats::Precision,
     DataFusionError, JoinSide, Result,
 };
-use datafusion_expr::ScalarFunctionDefinition;
 
 use crate::logical_plan::csv_writer_options_to_proto;
 use crate::protobuf::{
@@ -540,11 +539,7 @@ pub fn serialize_physical_expr(
         let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
 
         let mut buf = Vec::new();
-        match expr.fun() {
-            ScalarFunctionDefinition::UDF(udf) => {
-                codec.try_encode_udf(udf, &mut buf)?;
-            }
-        }
+        codec.try_encode_udf(expr.fun(), &mut buf)?;
 
         let fun_definition = if buf.is_empty() { None } else { Some(buf) };
         Ok(protobuf::PhysicalExprNode {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 5e446f93fe..c2018352c7 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,9 +75,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_expr::{
-    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
-    ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
-    WindowFrame, WindowFrameBound,
+    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
+    ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
 };
 use datafusion_proto::physical_plan::{
     AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
@@ -618,7 +617,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
         scalar_fn.clone(),
     );
 
-    let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone()));
+    let fun_def = Arc::new(udf.clone());
 
     let expr = ScalarFunctionExpr::new(
         "dummy",
@@ -750,7 +749,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
     let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
     let udf_expr = Arc::new(ScalarFunctionExpr::new(
         udf.name(),
-        ScalarFunctionDefinition::UDF(Arc::new(udf.clone())),
+        Arc::new(udf.clone()),
         vec![col("text", &schema)?],
         DataType::Int64,
         None,
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index c619c62668..804fa6d306 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -108,8 +108,8 @@ impl Unparser<'_> {
                     negated: *negated,
                 })
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
-                let func_name = func_def.name();
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                let func_name = func.name();
 
                 let args = args
                     .iter()

