This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 742e3c5b3f Remove ScalarFunctionDefinition (#10325)
742e3c5b3f is described below
commit 742e3c5b3f5d4d961d50cc77033ab7774b90c56f
Author: 张林伟 <[email protected]>
AuthorDate: Wed May 8 00:49:57 2024 +0800
Remove ScalarFunctionDefinition (#10325)
* Remove ScalarFunctionDefinition
* Fix test
* rename func_def to func
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/datasource/listing/helpers.rs | 18 +++-----
.../src/physical_optimizer/projection_pushdown.rs | 19 ++------
datafusion/expr/src/expr.rs | 54 ++++------------------
datafusion/expr/src/expr_schema.rs | 16 +++----
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/tree_node.rs | 12 ++---
datafusion/functions-array/src/rewrite.rs | 8 ++--
datafusion/functions/src/math/log.rs | 19 ++------
datafusion/functions/src/math/power.rs | 17 ++-----
datafusion/functions/src/string/concat.rs | 4 +-
datafusion/functions/src/string/concat_ws.rs | 4 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 24 ++++------
.../optimizer/src/optimize_projections/mod.rs | 4 +-
datafusion/optimizer/src/push_down_filter.rs | 7 +--
.../src/simplify_expressions/expr_simplifier.rs | 32 ++++++-------
datafusion/physical-expr/src/planner.rs | 22 ++++-----
datafusion/physical-expr/src/scalar_function.rs | 36 ++++++---------
datafusion/proto/src/logical_plan/to_proto.rs | 34 ++++++--------
datafusion/proto/src/physical_plan/from_proto.rs | 3 +-
datafusion/proto/src/physical_plan/to_proto.rs | 7 +--
.../proto/tests/cases/roundtrip_physical_plan.rs | 9 ++--
datafusion/sql/src/unparser/expr.rs | 4 +-
22 files changed, 124 insertions(+), 231 deletions(-)
diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 09d9aa8811..0cffa05131 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -38,7 +38,7 @@ use log::{debug, trace};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, DataFusionError};
-use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{Expr, Volatility};
use datafusion_physical_expr::create_physical_expr;
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
@@ -89,16 +89,12 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr:
&Expr) -> bool {
| Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),
Expr::ScalarFunction(scalar_function) => {
- match &scalar_function.func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- match fun.signature().volatility {
- Volatility::Immutable =>
Ok(TreeNodeRecursion::Continue),
- // TODO: Stable functions could be `applicable`,
but that would require access to the context
- Volatility::Stable | Volatility::Volatile => {
- is_applicable = false;
- Ok(TreeNodeRecursion::Stop)
- }
- }
+ match scalar_function.func.signature().volatility {
+ Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
+ // TODO: Stable functions could be `applicable`, but that
would require access to the context
+ Volatility::Stable | Volatility::Volatile => {
+ is_applicable = false;
+ Ok(TreeNodeRecursion::Stop)
}
}
}
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 160dd3a1c4..0190f35cc9 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -1301,8 +1301,7 @@ mod tests {
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{
- ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF,
ScalarUDFImpl,
- Signature, Volatility,
+ ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_physical_expr::expressions::{
BinaryExpr, CaseExpr, CastExpr, NegativeExpr,
@@ -1363,9 +1362,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
- DummyUDF::new(),
- ))),
+ Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
@@ -1431,9 +1428,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
- DummyUDF::new(),
- ))),
+ Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
@@ -1502,9 +1497,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
- DummyUDF::new(),
- ))),
+ Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
@@ -1570,9 +1563,7 @@ mod tests {
Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
Arc::new(ScalarFunctionExpr::new(
"scalar_expr",
-
ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
- DummyUDF::new(),
- ))),
+ Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
vec![
Arc::new(BinaryExpr::new(
Arc::new(Column::new("b_new", 1)),
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c154cd999a..9789dd345f 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -26,11 +26,11 @@ use std::sync::Arc;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
use crate::utils::expr_to_columns;
-use crate::window_frame;
use crate::{
aggregate_function, built_in_window_function, udaf, ExprSchemable,
Operator,
Signature,
};
+use crate::{window_frame, Volatility};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -399,18 +399,11 @@ impl Between {
}
}
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-/// Defines which implementation of a function for DataFusion to call.
-pub enum ScalarFunctionDefinition {
- /// Resolved to a user defined function
- UDF(Arc<crate::ScalarUDF>),
-}
-
/// ScalarFunction expression invokes a built-in scalar function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
- pub func_def: ScalarFunctionDefinition,
+ pub func: Arc<crate::ScalarUDF>,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
}
@@ -418,41 +411,14 @@ pub struct ScalarFunction {
impl ScalarFunction {
// return the Function's name
pub fn name(&self) -> &str {
- self.func_def.name()
- }
-}
-
-impl ScalarFunctionDefinition {
- /// Function's name for display
- pub fn name(&self) -> &str {
- match self {
- ScalarFunctionDefinition::UDF(udf) => udf.name(),
- }
- }
-
- /// Whether this function is volatile, i.e. whether it can return
different results
- /// when evaluated multiple times with the same input.
- pub fn is_volatile(&self) -> Result<bool> {
- match self {
- ScalarFunctionDefinition::UDF(udf) => {
- Ok(udf.signature().volatility == crate::Volatility::Volatile)
- }
- }
+ self.func.name()
}
}
impl ScalarFunction {
/// Create a new ScalarFunction expression with a user-defined function
(UDF)
pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
- Self {
- func_def: ScalarFunctionDefinition::UDF(udf),
- args,
- }
- }
-
- /// Create a new ScalarFunction expression with a user-defined function
(UDF)
- pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>)
-> Self {
- Self { func_def, args }
+ Self { func: udf, args }
}
}
@@ -1299,7 +1265,7 @@ impl Expr {
/// results when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
self.exists(|expr| {
- Ok(matches!(expr, Expr::ScalarFunction(func) if
func.func_def.is_volatile()?))
+ Ok(matches!(expr, Expr::ScalarFunction(func) if
func.func.signature().volatility == Volatility::Volatile ))
})
}
@@ -1334,9 +1300,7 @@ impl Expr {
/// and thus any side effects (like divide by zero) may not be encountered
pub fn short_circuits(&self) -> bool {
match self {
- Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
- matches!(func_def, ScalarFunctionDefinition::UDF(fun) if
fun.short_circuits())
- }
+ Expr::ScalarFunction(ScalarFunction { func, .. }) =>
func.short_circuits(),
Expr::BinaryExpr(BinaryExpr { op, .. }) => {
matches!(op, Operator::And | Operator::Or)
}
@@ -2071,7 +2035,7 @@ mod test {
}
#[test]
- fn test_is_volatile_scalar_func_definition() {
+ fn test_is_volatile_scalar_func() {
// UDF
#[derive(Debug)]
struct TestScalarUDF {
@@ -2100,7 +2064,7 @@ mod test {
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
}));
- assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+ assert_ne!(udf.signature().volatility, Volatility::Volatile);
let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
signature: Signature::uniform(
@@ -2109,7 +2073,7 @@ mod test {
Volatility::Volatile,
),
}));
- assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+ assert_eq!(udf.signature().volatility, Volatility::Volatile);
}
use super::*;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index f93f085749..4aca52d67c 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder,
ScalarFunction,
- ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
+ Sort, TryCast, Unnest, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
@@ -133,20 +133,18 @@ impl ExprSchemable for Expr {
}
}
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
let arg_data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
- match func_def {
- ScalarFunctionDefinition::UDF(fun) => {
// verify that function is invoked with correct number
and type of arguments as defined in `TypeSignature`
- data_types(&arg_data_types,
fun.signature()).map_err(|_| {
+ data_types(&arg_data_types,
func.signature()).map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
- fun.name(),
- fun.signature().clone(),
+ func.name(),
+ func.signature().clone(),
&arg_data_types,
)
)
@@ -154,9 +152,7 @@ impl ExprSchemable for Expr {
// perform additional function arguments validation
(due to limited
// expressiveness of `TypeSignature`), then infer
return type
- Ok(fun.return_type_from_exprs(args, schema,
&arg_data_types)?)
- }
- }
+ Ok(func.return_type_from_exprs(args, schema,
&arg_data_types)?)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index de4f310292..e2b68388ab 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -63,7 +63,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField,
GroupingSet,
- Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
+ Like, TryCast, WindowFunctionDefinition,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index ae3ca9afc4..710164eca3 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -20,7 +20,7 @@
use crate::expr::{
AggregateFunction, AggregateFunctionDefinition, Alias, Between,
BinaryExpr, Case,
Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
- ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest,
WindowFunction,
+ ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
};
use crate::{Expr, GetFieldAccess};
@@ -281,11 +281,11 @@ impl TreeNode for Expr {
nulls_first,
}) => transform_box(expr, &mut f)?
.update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))),
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
- transform_vec(args, &mut f)?.map_data(|new_args| match
func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun,
new_args)))
- }
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
+ transform_vec(args, &mut f)?.map_data(|new_args| {
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+ func, new_args,
+ )))
})?
}
Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
index 32d15b5563..416e79cbc0 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -182,20 +182,20 @@ impl FunctionRewrite for ArrayFunctionRewriter {
/// Returns true if expr is a function call to the specified named function.
/// Returns false otherwise.
fn is_func(expr: &Expr, func_name: &str) -> bool {
- let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else
{
+ let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
return false;
};
- func_def.name() == func_name
+ func.name() == func_name
}
/// Returns true if expr is a function call with one of the specified names
fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
- let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else
{
+ let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
return false;
};
- func_names.contains(&func_def.name())
+ func_names.contains(&func.name())
}
/// returns Some(col) if this is Expr::Column
diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index f451321ea1..e6c698ad1a 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -24,9 +24,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{
- lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
-};
+use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
use arrow::array::{ArrayRef, Float32Array, Float64Array};
use datafusion_expr::TypeSignature::*;
@@ -178,8 +176,8 @@ impl ScalarUDFImpl for LogFunc {
&info.get_data_type(&base)?,
)?)))
}
- Expr::ScalarFunction(ScalarFunction { func_def, mut args })
- if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
+ Expr::ScalarFunction(ScalarFunction { func, mut args })
+ if is_pow(&func) && args.len() == 2 && base == args[0] =>
{
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
@@ -207,15 +205,8 @@ impl ScalarUDFImpl for LogFunc {
}
/// Returns true if the function is `PowerFunc`
-fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
- match func_def {
- ScalarFunctionDefinition::UDF(fun) => fun
- .as_ref()
- .inner()
- .as_any()
- .downcast_ref::<PowerFunc>()
- .is_some(),
- }
+fn is_pow(func: &ScalarUDF) -> bool {
+ func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
}
#[cfg(test)]
diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs
index 8cc6b4c02a..7677e8b2af 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -23,7 +23,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
+use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
use arrow::array::{ArrayRef, Float64Array, Int64Array};
use datafusion_expr::TypeSignature::*;
@@ -140,8 +140,8 @@ impl ScalarUDFImpl for PowerFunc {
Expr::Literal(value) if value ==
ScalarValue::new_one(&exponent_type)? => {
Ok(ExprSimplifyResult::Simplified(base))
}
- Expr::ScalarFunction(ScalarFunction { func_def, mut args })
- if is_log(&func_def) && args.len() == 2 && base == args[0] =>
+ Expr::ScalarFunction(ScalarFunction { func, mut args })
+ if is_log(&func) && args.len() == 2 && base == args[0] =>
{
let b = args.pop().unwrap(); // length checked above
Ok(ExprSimplifyResult::Simplified(b))
@@ -152,15 +152,8 @@ impl ScalarUDFImpl for PowerFunc {
}
/// Return true if this function call is a call to `Log`
-fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
- match func_def {
- ScalarFunctionDefinition::UDF(fun) => fun
- .as_ref()
- .inner()
- .as_any()
- .downcast_ref::<LogFunc>()
- .is_some(),
- }
+fn is_log(func: &ScalarUDF) -> bool {
+ func.inner().as_any().downcast_ref::<LogFunc>().is_some()
}
#[cfg(test)]
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index 55b7c2f222..6d15e22067 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -25,7 +25,7 @@ use datafusion_common::cast::as_string_array;
use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition,
Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use crate::string::common::*;
@@ -182,7 +182,7 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
if !args.eq(&new_args) {
Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
ScalarFunction {
- func_def: ScalarFunctionDefinition::UDF(concat()),
+ func: concat(),
args: new_args,
},
)))
diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs
index 1d27712b2c..4d05f4e707 100644
--- a/datafusion/functions/src/string/concat_ws.rs
+++ b/datafusion/functions/src/string/concat_ws.rs
@@ -26,7 +26,7 @@ use datafusion_common::cast::as_string_array;
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition,
Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};
use crate::string::common::*;
@@ -266,7 +266,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) ->
Result<ExprSimplifyRes
Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
ScalarFunction {
- func_def:
ScalarFunctionDefinition::UDF(concat_ws()),
+ func: concat_ws(),
args: new_args,
},
)))
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 10479f29a5..61b1d1d77b 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -46,8 +46,7 @@ use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
not,
type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan,
Operator,
- ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame,
WindowFrameBound,
- WindowFrameUnits,
+ ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits,
};
use crate::analyzer::AnalyzerRule;
@@ -303,19 +302,14 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
let case = coerce_case_expression(case, self.schema)?;
Ok(Transformed::yes(Expr::Case(case)))
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- let new_expr = coerce_arguments_for_signature(
- args,
- self.schema,
- fun.signature(),
- )?;
- let new_expr = coerce_arguments_for_fun(new_expr,
self.schema, &fun)?;
- Ok(Transformed::yes(Expr::ScalarFunction(
- ScalarFunction::new_udf(fun, new_expr),
- )))
- }
- },
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
+ let new_expr =
+ coerce_arguments_for_signature(args, self.schema,
func.signature())?;
+ let new_expr = coerce_arguments_for_fun(new_expr, self.schema,
&func)?;
+ Ok(Transformed::yes(Expr::ScalarFunction(
+ ScalarFunction::new_udf(func, new_expr),
+ )))
+ }
Expr::AggregateFunction(expr::AggregateFunction {
func_def,
args,
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 0f2aaa6cbc..aa2d005379 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -555,8 +555,8 @@ fn rewrite_expr(expr: &Expr, input: &Projection) ->
Result<Option<Expr>> {
.map(|expr| rewrite_expr(expr, input))
.collect::<Result<Option<_>>>()?
.map(|new_args| {
- Expr::ScalarFunction(ScalarFunction::new_func_def(
- scalar_fn.func_def.clone(),
+ Expr::ScalarFunction(ScalarFunction::new_udf(
+ scalar_fn.func.clone(),
new_args,
))
}));
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 2355ee604e..f58345237b 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -35,7 +35,7 @@ use datafusion_expr::logical_plan::{
use datafusion_expr::utils::{conjunction, split_conjunction,
split_conjunction_owned};
use datafusion_expr::{
and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder,
Operator,
- ScalarFunctionDefinition, TableProviderFilterPushDown,
+ TableProviderFilterPushDown,
};
use crate::optimizer::ApplyOrder;
@@ -228,10 +228,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) ->
Result<bool> {
| Expr::ScalarSubquery(_)
| Expr::OuterReferenceColumn(_, _)
| Expr::Unnest(_)
- | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
- func_def: ScalarFunctionDefinition::UDF(_),
- ..
- }) => {
+ | Expr::ScalarFunction(_) => {
is_evaluate = false;
Ok(TreeNodeRecursion::Stop)
}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 0f711d6a2c..5122de4f09 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -35,8 +35,7 @@ use datafusion_common::{internal_err, DFSchema,
DataFusionError, Result, ScalarV
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
- and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
- ScalarFunctionDefinition, Volatility,
+ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
Volatility,
};
use datafusion_expr::{expr::ScalarFunction,
interval_arithmetic::NullableInterval};
use datafusion_physical_expr::{create_physical_expr,
execution_props::ExecutionProps};
@@ -595,11 +594,9 @@ impl<'a> ConstEvaluator<'a> {
| Expr::GroupingSet(_)
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => false,
- Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match
func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- Self::volatility_ok(fun.signature().volatility)
- }
- },
+ Expr::ScalarFunction(ScalarFunction { func, .. }) => {
+ Self::volatility_ok(func.signature().volatility)
+ }
Expr::Literal(_)
| Expr::Unnest(_)
| Expr::BinaryExpr { .. }
@@ -1373,18 +1370,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
// Do a first pass at simplification
out_expr.rewrite(self)?
}
- Expr::ScalarFunction(ScalarFunction {
- func_def: ScalarFunctionDefinition::UDF(udf),
- args,
- }) => match udf.simplify(args, info)? {
- ExprSimplifyResult::Original(args) => {
- Transformed::no(Expr::ScalarFunction(ScalarFunction {
- func_def: ScalarFunctionDefinition::UDF(udf),
- args,
- }))
+ Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
+ match udf.simplify(args, info)? {
+ ExprSimplifyResult::Original(args) => {
+ Transformed::no(Expr::ScalarFunction(ScalarFunction {
+ func: udf,
+ args,
+ }))
+ }
+ ExprSimplifyResult::Simplified(expr) =>
Transformed::yes(expr),
}
- ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
- },
+ }
//
// Rules for Between
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 2621b817b2..ab57a8e800 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -28,7 +28,7 @@ use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::var_provider::VarType;
use datafusion_expr::{
binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField,
Like,
- Operator, ScalarFunctionDefinition, TryCast,
+ Operator, TryCast,
};
use crate::scalar_function;
@@ -305,21 +305,17 @@ pub fn create_physical_expr(
}
},
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
let physical_args =
create_physical_exprs(args, input_dfschema, execution_props)?;
- match func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- scalar_function::create_physical_expr(
- fun.clone().as_ref(),
- &physical_args,
- input_schema,
- args,
- input_dfschema,
- )
- }
- }
+ scalar_function::create_physical_expr(
+ func.clone().as_ref(),
+ &physical_args,
+ input_schema,
+ args,
+ input_dfschema,
+ )
}
Expr::Between(Between {
expr,
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 6b84b81e9f..180f2a7946 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -40,10 +40,7 @@ use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, DFSchema, Result};
use datafusion_expr::type_coercion::functions::data_types;
-use datafusion_expr::{
- expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity,
ScalarFunctionDefinition,
- ScalarUDF,
-};
+use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity,
ScalarUDF};
use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal};
use crate::sort_properties::SortProperties;
@@ -51,7 +48,7 @@ use crate::PhysicalExpr;
/// Physical expression of a scalar function
pub struct ScalarFunctionExpr {
- fun: ScalarFunctionDefinition,
+ fun: Arc<ScalarUDF>,
name: String,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
@@ -78,7 +75,7 @@ impl ScalarFunctionExpr {
/// Create a new Scalar function
pub fn new(
name: &str,
- fun: ScalarFunctionDefinition,
+ fun: Arc<ScalarUDF>,
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
@@ -93,7 +90,7 @@ impl ScalarFunctionExpr {
}
/// Get the scalar function implementation
- pub fn fun(&self) -> &ScalarFunctionDefinition {
+ pub fn fun(&self) -> &ScalarUDF {
&self.fun
}
@@ -146,22 +143,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
.collect::<Result<Vec<_>>>()?;
// evaluate the function
- match self.fun {
- ScalarFunctionDefinition::UDF(ref fun) => {
- let output = match self.args.is_empty() {
- true => fun.invoke_no_args(batch.num_rows()),
- false => fun.invoke(&inputs),
- }?;
-
- if let ColumnarValue::Array(array) = &output {
- if array.len() != batch.num_rows() {
- return internal_err!("UDF returned a different number
of rows than expected. Expected: {}, Got: {}",
+ let output = match self.args.is_empty() {
+ true => self.fun.invoke_no_args(batch.num_rows()),
+ false => self.fun.invoke(&inputs),
+ }?;
+
+ if let ColumnarValue::Array(array) = &output {
+ if array.len() != batch.num_rows() {
+ return internal_err!("UDF returned a different number of rows
than expected. Expected: {}, Got: {}",
batch.num_rows(), array.len());
- }
- }
- Ok(output)
}
}
+ Ok(output)
}
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -233,10 +226,9 @@ pub fn create_physical_expr(
let return_type =
fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
- let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone()));
Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
- fun_def,
+ Arc::new(fun.clone()),
input_phy_exprs.to_vec(),
return_type,
fun.monotonicity()?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index dcec2a3b85..80acd12e4e 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -36,8 +36,8 @@ use datafusion_common::{
};
use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast,
GetFieldAccess,
- GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction,
- ScalarFunctionDefinition, Sort, Unnest,
+ GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction,
Sort,
+ Unnest,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -763,25 +763,19 @@ pub fn serialize_expr(
"Proto serialization error: Scalar Variable not
supported".to_string(),
))
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
let args = serialize_exprs(args, codec)?;
- match func_def {
- ScalarFunctionDefinition::UDF(fun) => {
- let mut buf = Vec::new();
- let _ = codec.try_encode_udf(fun.as_ref(), &mut buf);
-
- let fun_definition = if buf.is_empty() { None } else {
Some(buf) };
-
- protobuf::LogicalExprNode {
- expr_type: Some(ExprType::ScalarUdfExpr(
- protobuf::ScalarUdfExprNode {
- fun_name: fun.name().to_string(),
- fun_definition,
- args,
- },
- )),
- }
- }
+ let mut buf = Vec::new();
+ let _ = codec.try_encode_udf(func.as_ref(), &mut buf);
+
+ let fun_definition = if buf.is_empty() { None } else { Some(buf) };
+
+ protobuf::LogicalExprNode {
+ expr_type:
Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
+ fun_name: func.name().to_string(),
+ fun_definition,
+ args,
+ })),
}
}
Expr::Not(expr) => {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 33e632b0d9..4bd07fae49 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -53,7 +53,6 @@ use
datafusion_common::file_options::json_writer::JsonWriterOptions;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result,
ScalarValue};
-use datafusion_expr::ScalarFunctionDefinition;
use crate::common::proto_error;
use crate::convert_required;
@@ -342,7 +341,7 @@ pub fn parse_physical_expr(
Some(buf) => codec.try_decode_udf(&e.name, buf)?,
None => registry.udf(e.name.as_str())?,
};
- let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone());
+ let scalar_fun_def = udf.clone();
let args = parse_physical_exprs(&e.args, registry, input_schema,
codec)?;
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index a0a0ee7205..3bc71f5f4c 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -56,7 +56,6 @@ use datafusion_common::{
stats::Precision,
DataFusionError, JoinSide, Result,
};
-use datafusion_expr::ScalarFunctionDefinition;
use crate::logical_plan::csv_writer_options_to_proto;
use crate::protobuf::{
@@ -540,11 +539,7 @@ pub fn serialize_physical_expr(
let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
let mut buf = Vec::new();
- match expr.fun() {
- ScalarFunctionDefinition::UDF(udf) => {
- codec.try_encode_udf(udf, &mut buf)?;
- }
- }
+ codec.try_encode_udf(expr.fun(), &mut buf)?;
let fun_definition = if buf.is_empty() { None } else { Some(buf) };
Ok(protobuf::PhysicalExprNode {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 5e446f93fe..c2018352c7 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,9 +75,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
use datafusion_expr::{
- Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
- ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature,
SimpleAggregateUDF,
- WindowFrame, WindowFrameBound,
+ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
ScalarUDF,
+ ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame,
WindowFrameBound,
};
use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
@@ -618,7 +617,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
scalar_fn.clone(),
);
- let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone()));
+ let fun_def = Arc::new(udf.clone());
let expr = ScalarFunctionExpr::new(
"dummy",
@@ -750,7 +749,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
let udf_expr = Arc::new(ScalarFunctionExpr::new(
udf.name(),
- ScalarFunctionDefinition::UDF(Arc::new(udf.clone())),
+ Arc::new(udf.clone()),
vec![col("text", &schema)?],
DataType::Int64,
None,
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index c619c62668..804fa6d306 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -108,8 +108,8 @@ impl Unparser<'_> {
negated: *negated,
})
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
- let func_name = func_def.name();
+ Expr::ScalarFunction(ScalarFunction { func, args }) => {
+ let func_name = func.name();
let args = args
.iter()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]