This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f8dcc64ca3 Refactor: Unify `Expr::ScalarFunction` and
`Expr::ScalarUDF`, introduce unresolved functions by name (#8258)
f8dcc64ca3 is described below
commit f8dcc64ca3be4db315aa2e4d4da953ec8a3c87bb
Author: Yongting You <[email protected]>
AuthorDate: Sun Nov 26 03:49:18 2023 -0800
Refactor: Unify `Expr::ScalarFunction` and `Expr::ScalarUDF`, introduce
unresolved functions by name (#8258)
* Refactor Expr::ScalarFunction
* Remove Expr::ScalarUDF
* review comments
* make name() return &str
* fix fmt
* fix after merge
---
datafusion/core/src/datasource/listing/helpers.rs | 54 ++++++------
datafusion/core/src/physical_planner.rs | 18 ++--
datafusion/expr/src/expr.rs | 79 ++++++++++-------
datafusion/expr/src/expr_fn.rs | 98 ++++++++++++----------
datafusion/expr/src/expr_schema.rs | 64 +++++++-------
datafusion/expr/src/lib.rs | 2 +-
datafusion/expr/src/tree_node/expr.rs | 25 ++++--
datafusion/expr/src/udf.rs | 5 +-
datafusion/expr/src/utils.rs | 1 -
datafusion/optimizer/src/analyzer/type_coercion.rs | 75 +++++++++--------
datafusion/optimizer/src/push_down_filter.rs | 32 +++++--
.../src/simplify_expressions/expr_simplifier.rs | 59 +++++++++----
.../optimizer/src/simplify_expressions/utils.rs | 14 +++-
datafusion/physical-expr/src/planner.rs | 70 +++++++++-------
datafusion/proto/src/logical_plan/from_proto.rs | 2 +-
datafusion/proto/src/logical_plan/to_proto.rs | 52 +++++++-----
.../proto/tests/cases/roundtrip_logical_plan.rs | 7 +-
datafusion/sql/src/expr/function.rs | 4 +-
datafusion/sql/src/expr/value.rs | 9 +-
datafusion/substrait/src/logical_plan/consumer.rs | 7 +-
datafusion/substrait/src/logical_plan/producer.rs | 13 ++-
21 files changed, 419 insertions(+), 271 deletions(-)
diff --git a/datafusion/core/src/datasource/listing/helpers.rs
b/datafusion/core/src/datasource/listing/helpers.rs
index 322d65d564..f9b02f4d0c 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -38,9 +38,8 @@ use super::PartitionedFile;
use crate::datasource::listing::ListingTableUrl;
use crate::execution::context::SessionState;
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
-use datafusion_common::{Column, DFField, DFSchema, DataFusionError};
-use datafusion_expr::expr::ScalarUDF;
-use datafusion_expr::{Expr, Volatility};
+use datafusion_common::{internal_err, Column, DFField, DFSchema,
DataFusionError};
+use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use object_store::path::Path;
@@ -54,13 +53,13 @@ use object_store::{ObjectMeta, ObjectStore};
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
- Ok(match expr {
+ match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
if is_applicable {
- VisitRecursion::Skip
+ Ok(VisitRecursion::Skip)
} else {
- VisitRecursion::Stop
+ Ok(VisitRecursion::Stop)
}
}
Expr::Literal(_)
@@ -89,25 +88,32 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr:
&Expr) -> bool {
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
- | Expr::Case { .. } => VisitRecursion::Continue,
+ | Expr::Case { .. } => Ok(VisitRecursion::Continue),
Expr::ScalarFunction(scalar_function) => {
- match scalar_function.fun.volatility() {
- Volatility::Immutable => VisitRecursion::Continue,
- // TODO: Stable functions could be `applicable`, but that
would require access to the context
- Volatility::Stable | Volatility::Volatile => {
- is_applicable = false;
- VisitRecursion::Stop
+ match &scalar_function.func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ match fun.volatility() {
+ Volatility::Immutable =>
Ok(VisitRecursion::Continue),
+ // TODO: Stable functions could be `applicable`,
but that would require access to the context
+ Volatility::Stable | Volatility::Volatile => {
+ is_applicable = false;
+ Ok(VisitRecursion::Stop)
+ }
+ }
}
- }
- }
- Expr::ScalarUDF(ScalarUDF { fun, .. }) => {
- match fun.signature().volatility {
- Volatility::Immutable => VisitRecursion::Continue,
- // TODO: Stable functions could be `applicable`, but that
would require access to the context
- Volatility::Stable | Volatility::Volatile => {
- is_applicable = false;
- VisitRecursion::Stop
+ ScalarFunctionDefinition::UDF(fun) => {
+ match fun.signature().volatility {
+ Volatility::Immutable =>
Ok(VisitRecursion::Continue),
+ // TODO: Stable functions could be `applicable`,
but that would require access to the context
+ Volatility::Stable | Volatility::Volatile => {
+ is_applicable = false;
+ Ok(VisitRecursion::Stop)
+ }
+ }
+ }
+ ScalarFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be
resolved.")
}
}
}
@@ -123,9 +129,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr:
&Expr) -> bool {
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => {
is_applicable = false;
- VisitRecursion::Stop
+ Ok(VisitRecursion::Stop)
}
- })
+ }
})
.unwrap();
is_applicable
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 82d96c98e6..09f0e11dc2 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -83,13 +83,13 @@ use datafusion_common::{
use datafusion_expr::dml::{CopyOptions, CopyTo};
use datafusion_expr::expr::{
self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast,
- GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, ScalarUDF,
TryCast,
- WindowFunction,
+ GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast,
WindowFunction,
};
use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols};
use
datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
- DescribeTable, DmlStatement, StringifiedPlan, WindowFrame,
WindowFrameBound, WriteOp,
+ DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan,
WindowFrame,
+ WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::expressions::Literal;
use datafusion_sql::utils::window_expr_common_partition_keys;
@@ -217,11 +217,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
Ok(name)
}
- Expr::ScalarFunction(func) => {
- create_function_physical_name(&func.fun.to_string(), false,
&func.args)
- }
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- create_function_physical_name(fun.name(), false, args)
+ Expr::ScalarFunction(expr::ScalarFunction { func_def, args }) => {
+ // function should be resolved during `AnalyzerRule`s
+ if let ScalarFunctionDefinition::Name(_) = func_def {
+ return internal_err!("Function `Expr` with name should be
resolved.");
+ }
+
+ create_function_physical_name(func_def.name(), false, args)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
create_function_physical_name(&fun.to_string(), false, args)
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 2b2d30af3b..13e488dac0 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -148,10 +148,8 @@ pub enum Expr {
TryCast(TryCast),
/// A sort expression, that can be used to sort values.
Sort(Sort),
- /// Represents the call of a built-in scalar function with a set of
arguments.
+ /// Represents the call of a scalar function with a set of arguments.
ScalarFunction(ScalarFunction),
- /// Represents the call of a user-defined scalar function with arguments.
- ScalarUDF(ScalarUDF),
/// Represents the call of an aggregate built-in function with arguments.
AggregateFunction(AggregateFunction),
/// Represents the call of a window function with arguments.
@@ -338,37 +336,61 @@ impl Between {
}
}
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+/// Defines which implementation of a function for DataFusion to call.
+pub enum ScalarFunctionDefinition {
+ /// Resolved to a `BuiltinScalarFunction`
+    /// There is a plan to migrate `BuiltinScalarFunction` to a UDF-based
implementation (issue#8045)
+    /// This variant is planned to be removed in the long term
+ BuiltIn {
+ fun: built_in_function::BuiltinScalarFunction,
+ name: Arc<str>,
+ },
+ /// Resolved to a user defined function
+ UDF(Arc<crate::ScalarUDF>),
+    /// A scalar function constructed by name. This variant cannot be
executed directly
+    /// and instead must be resolved to one of the other variants prior to
physical planning.
+ Name(Arc<str>),
+}
+
/// ScalarFunction expression invokes a built-in scalar function
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ScalarFunction {
/// The function
- pub fun: built_in_function::BuiltinScalarFunction,
+ pub func_def: ScalarFunctionDefinition,
/// List of expressions to feed to the functions as arguments
pub args: Vec<Expr>,
}
+impl ScalarFunctionDefinition {
+ /// Function's name for display
+ pub fn name(&self) -> &str {
+ match self {
+ ScalarFunctionDefinition::BuiltIn { name, .. } => name.as_ref(),
+ ScalarFunctionDefinition::UDF(udf) => udf.name(),
+ ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
+ }
+ }
+}
+
impl ScalarFunction {
/// Create a new ScalarFunction expression
pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec<Expr>)
-> Self {
- Self { fun, args }
+ Self {
+ func_def: ScalarFunctionDefinition::BuiltIn {
+ fun,
+ name: Arc::from(fun.to_string()),
+ },
+ args,
+ }
}
-}
-/// ScalarUDF expression invokes a user-defined scalar function [`ScalarUDF`]
-///
-/// [`ScalarUDF`]: crate::ScalarUDF
-#[derive(Clone, PartialEq, Eq, Hash, Debug)]
-pub struct ScalarUDF {
- /// The function
- pub fun: Arc<crate::ScalarUDF>,
- /// List of expressions to feed to the functions as arguments
- pub args: Vec<Expr>,
-}
-
-impl ScalarUDF {
- /// Create a new ScalarUDF expression
- pub fn new(fun: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
- Self { fun, args }
+ /// Create a new ScalarFunction expression with a user-defined function
(UDF)
+ pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
+ Self {
+ func_def: ScalarFunctionDefinition::UDF(udf),
+ args,
+ }
}
}
@@ -736,7 +758,6 @@ impl Expr {
Expr::Placeholder(_) => "Placeholder",
Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
- Expr::ScalarUDF(..) => "ScalarUDF",
Expr::ScalarVariable(..) => "ScalarVariable",
Expr::Sort { .. } => "Sort",
Expr::TryCast { .. } => "TryCast",
@@ -1198,11 +1219,8 @@ impl fmt::Display for Expr {
write!(f, " NULLS LAST")
}
}
- Expr::ScalarFunction(func) => {
- fmt_function(f, &func.fun.to_string(), false, &func.args, true)
- }
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- fmt_function(f, fun.name(), false, args, true)
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ fmt_function(f, func_def.name(), false, args, true)
}
Expr::WindowFunction(WindowFunction {
fun,
@@ -1534,11 +1552,8 @@ fn create_name(e: &Expr) -> Result<String> {
}
}
}
- Expr::ScalarFunction(func) => {
- create_function_name(&func.fun.to_string(), false, &func.args)
- }
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- create_function_name(fun.name(), false, args)
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ create_function_name(func_def.name(), false, args)
}
Expr::WindowFunction(WindowFunction {
fun,
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 674d2a34df..4da6857594 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -1014,7 +1014,7 @@ pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) ->
Result<Expr> {
#[cfg(test)]
mod test {
use super::*;
- use crate::lit;
+ use crate::{lit, ScalarFunctionDefinition};
#[test]
fn filter_is_null_and_is_not_null() {
@@ -1029,8 +1029,10 @@ mod test {
macro_rules! test_unary_scalar_expr {
($ENUM:ident, $FUNC:ident) => {{
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
- $FUNC(col("tableA.a"))
+ if let Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ args,
+ }) = $FUNC(col("tableA.a"))
{
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
@@ -1042,42 +1044,42 @@ mod test {
}
macro_rules! test_scalar_expr {
- ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
- let expected = [$(stringify!($arg)),*];
- let result = $FUNC(
+ ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
+ let expected = [$(stringify!($arg)),*];
+ let result = $FUNC(
+ $(
+ col(stringify!($arg.to_string()))
+ ),*
+ );
+ if let Expr::ScalarFunction(ScalarFunction { func_def:
ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result {
+ let name = built_in_function::BuiltinScalarFunction::$ENUM;
+ assert_eq!(name, fun);
+ assert_eq!(expected.len(), args.len());
+ } else {
+ assert!(false, "unexpected: {:?}", result);
+ }
+ };
+}
+
+ macro_rules! test_nary_scalar_expr {
+ ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
+ let expected = [$(stringify!($arg)),*];
+ let result = $FUNC(
+ vec![
$(
col(stringify!($arg.to_string()))
),*
- );
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result
{
- let name = built_in_function::BuiltinScalarFunction::$ENUM;
- assert_eq!(name, fun);
- assert_eq!(expected.len(), args.len());
- } else {
- assert!(false, "unexpected: {:?}", result);
- }
- };
- }
-
- macro_rules! test_nary_scalar_expr {
- ($ENUM:ident, $FUNC:ident, $($arg:ident),*) => {
- let expected = [$(stringify!($arg)),*];
- let result = $FUNC(
- vec![
- $(
- col(stringify!($arg.to_string()))
- ),*
- ]
- );
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result
{
- let name = built_in_function::BuiltinScalarFunction::$ENUM;
- assert_eq!(name, fun);
- assert_eq!(expected.len(), args.len());
- } else {
- assert!(false, "unexpected: {:?}", result);
- }
- };
- }
+ ]
+ );
+ if let Expr::ScalarFunction(ScalarFunction { func_def:
ScalarFunctionDefinition::BuiltIn{fun, ..}, args }) = result {
+ let name = built_in_function::BuiltinScalarFunction::$ENUM;
+ assert_eq!(name, fun);
+ assert_eq!(expected.len(), args.len());
+ } else {
+ assert!(false, "unexpected: {:?}", result);
+ }
+ };
+}
#[test]
fn scalar_function_definitions() {
@@ -1207,7 +1209,11 @@ mod test {
#[test]
fn uuid_function_definitions() {
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() {
+ if let Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ args,
+ }) = uuid()
+ {
let name = BuiltinScalarFunction::Uuid;
assert_eq!(name, fun);
assert_eq!(0, args.len());
@@ -1218,8 +1224,10 @@ mod test {
#[test]
fn digest_function_definitions() {
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
- digest(col("tableA.a"), lit("md5"))
+ if let Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ args,
+ }) = digest(col("tableA.a"), lit("md5"))
{
let name = BuiltinScalarFunction::Digest;
assert_eq!(name, fun);
@@ -1231,8 +1239,10 @@ mod test {
#[test]
fn encode_function_definitions() {
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
- encode(col("tableA.a"), lit("base64"))
+ if let Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ args,
+ }) = encode(col("tableA.a"), lit("base64"))
{
let name = BuiltinScalarFunction::Encode;
assert_eq!(name, fun);
@@ -1244,8 +1254,10 @@ mod test {
#[test]
fn decode_function_definitions() {
- if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
- decode(col("tableA.a"), lit("hex"))
+ if let Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ args,
+ }) = decode(col("tableA.a"), lit("hex"))
{
let name = BuiltinScalarFunction::Decode;
assert_eq!(name, fun);
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 0d06a12951..d5d9c848b2 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -18,8 +18,8 @@
use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess,
- GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
ScalarUDF, Sort,
- TryCast, WindowFunction,
+ GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
+ ScalarFunctionDefinition, Sort, TryCast, WindowFunction,
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
@@ -82,32 +82,39 @@ impl ExprSchemable for Expr {
Expr::Case(case) => case.when_then_expr[0].1.get_type(schema),
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) =>
Ok(data_type.clone()),
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- Ok(fun.return_type(&data_types)?)
- }
- Expr::ScalarFunction(ScalarFunction { fun, args }) => {
- let arg_data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
-
- // verify that input data types is consistent with function's
`TypeSignature`
- data_types(&arg_data_types, &fun.signature()).map_err(|_| {
- plan_datafusion_err!(
- "{}",
- utils::generate_signature_error_msg(
- &format!("{fun}"),
- fun.signature(),
- &arg_data_types,
- )
- )
- })?;
-
- fun.return_type(&arg_data_types)
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ match func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ let arg_data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ // verify that input data types is consistent with
function's `TypeSignature`
+ data_types(&arg_data_types,
&fun.signature()).map_err(|_| {
+ plan_datafusion_err!(
+ "{}",
+ utils::generate_signature_error_msg(
+ &format!("{fun}"),
+ fun.signature(),
+ &arg_data_types,
+ )
+ )
+ })?;
+
+ fun.return_type(&arg_data_types)
+ }
+ ScalarFunctionDefinition::UDF(fun) => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(fun.return_type(&data_types)?)
+ }
+ ScalarFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be
resolved.")
+ }
+ }
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
@@ -243,7 +250,6 @@ impl ExprSchemable for Expr {
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
| Expr::ScalarFunction(..)
- | Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::AggregateUDF { .. }
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index b9976f90c5..6172d17365 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -63,7 +63,7 @@ pub use built_in_function::BuiltinScalarFunction;
pub use columnar_value::ColumnarValue;
pub use expr::{
Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField,
GroupingSet,
- Like, TryCast,
+ Like, ScalarFunctionDefinition, TryCast,
};
pub use expr_fn::*;
pub use expr_schema::ExprSchemable;
diff --git a/datafusion/expr/src/tree_node/expr.rs
b/datafusion/expr/src/tree_node/expr.rs
index 6b86de37ba..474b5f7689 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -20,12 +20,12 @@
use crate::expr::{
AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast,
GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
ScalarFunction,
- ScalarUDF, Sort, TryCast, WindowFunction,
+ ScalarFunctionDefinition, Sort, TryCast, WindowFunction,
};
use crate::{Expr, GetFieldAccess};
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
-use datafusion_common::Result;
+use datafusion_common::{internal_err, DataFusionError, Result};
impl TreeNode for Expr {
fn apply_children<F>(&self, op: &mut F) -> Result<VisitRecursion>
@@ -64,7 +64,7 @@ impl TreeNode for Expr {
}
Expr::GroupingSet(GroupingSet::Rollup(exprs))
| Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(),
- Expr::ScalarFunction (ScalarFunction{ args, .. } )|
Expr::ScalarUDF(ScalarUDF { args, .. }) => {
+ Expr::ScalarFunction (ScalarFunction{ args, .. } ) => {
args.clone()
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
@@ -276,12 +276,19 @@ impl TreeNode for Expr {
asc,
nulls_first,
)),
- Expr::ScalarFunction(ScalarFunction { args, fun }) =>
Expr::ScalarFunction(
- ScalarFunction::new(fun, transform_vec(args, &mut transform)?),
- ),
- Expr::ScalarUDF(ScalarUDF { args, fun }) => {
- Expr::ScalarUDF(ScalarUDF::new(fun, transform_vec(args, &mut
transform)?))
- }
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } =>
Expr::ScalarFunction(
+ ScalarFunction::new(fun, transform_vec(args, &mut
transform)?),
+ ),
+ ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction(
+ ScalarFunction::new_udf(fun, transform_vec(args, &mut
transform)?),
+ ),
+ ScalarFunctionDefinition::Name(_) => {
+ return internal_err!(
+ "Function `Expr` with name should be resolved."
+ );
+ }
+ },
Expr::WindowFunction(WindowFunction {
args,
fun,
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 22e56caaaf..bc910b928a 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -95,7 +95,10 @@ impl ScalarUDF {
/// creates a logical expression with a call of the UDF
/// This utility allows using the UDF without requiring access to the
registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
- Expr::ScalarUDF(crate::expr::ScalarUDF::new(Arc::new(self.clone()),
args))
+ Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf(
+ Arc::new(self.clone()),
+ args,
+ ))
}
/// Returns this function's name
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index ff95ff10e7..d8668fba8e 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -283,7 +283,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut
HashSet<Column>) -> Result<()> {
| Expr::TryCast { .. }
| Expr::Sort { .. }
| Expr::ScalarFunction(..)
- | Expr::ScalarUDF(..)
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 2c5e8c8b1c..6628e8961e 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -29,7 +29,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like,
ScalarFunction,
- ScalarUDF, WindowFunction,
+ WindowFunction,
};
use datafusion_expr::expr_rewriter::rewrite_preserving_name;
use datafusion_expr::expr_schema::cast_subquery;
@@ -45,7 +45,8 @@ use datafusion_expr::type_coercion::{is_datetime,
is_utf8_or_large_utf8};
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
type_coercion, window_function, AggregateFunction, BuiltinScalarFunction,
Expr,
- LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound,
WindowFrameUnits,
+ LogicalPlan, Operator, Projection, ScalarFunctionDefinition, WindowFrame,
+ WindowFrameBound, WindowFrameUnits,
};
use datafusion_expr::{ExprSchemable, Signature};
@@ -319,24 +320,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let case = coerce_case_expression(case, &self.schema)?;
Ok(Expr::Case(case))
}
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- let new_expr = coerce_arguments_for_signature(
- args.as_slice(),
- &self.schema,
- fun.signature(),
- )?;
- Ok(Expr::ScalarUDF(ScalarUDF::new(fun, new_expr)))
- }
- Expr::ScalarFunction(ScalarFunction { fun, args }) => {
- let new_args = coerce_arguments_for_signature(
- args.as_slice(),
- &self.schema,
- &fun.signature(),
- )?;
- let new_args =
- coerce_arguments_for_fun(new_args.as_slice(),
&self.schema, &fun)?;
- Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args)))
- }
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ let new_args = coerce_arguments_for_signature(
+ args.as_slice(),
+ &self.schema,
+ &fun.signature(),
+ )?;
+ let new_args = coerce_arguments_for_fun(
+ new_args.as_slice(),
+ &self.schema,
+ &fun,
+ )?;
+ Ok(Expr::ScalarFunction(ScalarFunction::new(fun,
new_args)))
+ }
+ ScalarFunctionDefinition::UDF(fun) => {
+ let new_expr = coerce_arguments_for_signature(
+ args.as_slice(),
+ &self.schema,
+ fun.signature(),
+ )?;
+ Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun,
new_expr)))
+ }
+ ScalarFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be
resolved.")
+ }
+ },
Expr::AggregateFunction(expr::AggregateFunction {
fun,
args,
@@ -838,7 +847,7 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation =
Arc::new(move |_|
Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
- let udf = Expr::ScalarUDF(expr::ScalarUDF::new(
+ let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32],
Volatility::Stable),
@@ -859,7 +868,7 @@ mod test {
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_|
unimplemented!());
- let udf = Expr::ScalarUDF(expr::ScalarUDF::new(
+ let udf = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Int32],
Volatility::Stable),
@@ -873,9 +882,9 @@ mod test {
.err()
.unwrap();
assert_eq!(
- "type_coercion\ncaused by\nError during planning: Coercion from
[Utf8] to the signature Uniform(1, [Int32]) failed.",
- err.strip_backtrace()
- );
+ "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to
the signature Uniform(1, [Int32]) failed.",
+ err.strip_backtrace()
+ );
Ok(())
}
@@ -1246,10 +1255,10 @@ mod test {
None,
),
)));
- let expr = Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::MakeArray,
- args: vec![val.clone()],
- });
+ let expr = Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::MakeArray,
+ vec![val.clone()],
+ ));
let schema = Arc::new(DFSchema::new_with_metadata(
vec![DFField::new_unqualified(
"item",
@@ -1278,10 +1287,10 @@ mod test {
&schema,
)?;
- let expected = Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::MakeArray,
- args: vec![expected_casted_expr],
- });
+ let expected = Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::MakeArray,
+ vec![expected_casted_expr],
+ ));
assert_eq!(result, expected);
Ok(())
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 05f4072e38..7a2c6a8d8c 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -28,7 +28,8 @@ use datafusion_expr::{
and,
expr_rewriter::replace_col,
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
- or, BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
+ or, BinaryExpr, Expr, Filter, Operator, ScalarFunctionDefinition,
+ TableProviderFilterPushDown,
};
use itertools::Itertools;
use std::collections::{HashMap, HashSet};
@@ -221,7 +222,10 @@ fn can_evaluate_as_join_condition(predicate: &Expr) ->
Result<bool> {
| Expr::InSubquery(_)
| Expr::ScalarSubquery(_)
| Expr::OuterReferenceColumn(_, _)
- | Expr::ScalarUDF(..) => {
+ | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
+ func_def: ScalarFunctionDefinition::UDF(_),
+ ..
+ }) => {
is_evaluate = false;
Ok(VisitRecursion::Stop)
}
@@ -977,10 +981,26 @@ fn is_volatile_expression(e: &Expr) -> bool {
let mut is_volatile = false;
e.apply(&mut |expr| {
Ok(match expr {
- Expr::ScalarFunction(f) if f.fun.volatility() ==
Volatility::Volatile => {
- is_volatile = true;
- VisitRecursion::Stop
- }
+ Expr::ScalarFunction(f) => match &f.func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. }
+ if fun.volatility() == Volatility::Volatile =>
+ {
+ is_volatile = true;
+ VisitRecursion::Stop
+ }
+ ScalarFunctionDefinition::UDF(fun)
+ if fun.signature().volatility == Volatility::Volatile =>
+ {
+ is_volatile = true;
+ VisitRecursion::Stop
+ }
+ ScalarFunctionDefinition::Name(_) => {
+ return internal_err!(
+ "Function `Expr` with name should be resolved."
+ );
+ }
+ _ => VisitRecursion::Continue,
+ },
_ => VisitRecursion::Continue,
})
})
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ad64625f7f..3310bfed75 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -40,8 +40,8 @@ use datafusion_common::{
exec_err, internal_err, DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::{
- and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case,
ColumnarValue, Expr,
- Like, Volatility,
+ and, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue,
Expr, Like,
+ ScalarFunctionDefinition, Volatility,
};
use datafusion_expr::{
expr::{InList, InSubquery, ScalarFunction},
@@ -344,12 +344,15 @@ impl<'a> ConstEvaluator<'a> {
| Expr::GroupingSet(_)
| Expr::Wildcard { .. }
| Expr::Placeholder(_) => false,
- Expr::ScalarFunction(ScalarFunction { fun, .. }) => {
- Self::volatility_ok(fun.volatility())
- }
- Expr::ScalarUDF(expr::ScalarUDF { fun, .. }) => {
- Self::volatility_ok(fun.signature().volatility)
- }
+ Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match
func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ Self::volatility_ok(fun.volatility())
+ }
+ ScalarFunctionDefinition::UDF(fun) => {
+ Self::volatility_ok(fun.signature().volatility)
+ }
+ ScalarFunctionDefinition::Name(_) => false,
+ },
Expr::Literal(_)
| Expr::BinaryExpr { .. }
| Expr::Not(_)
@@ -1200,25 +1203,41 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
// log
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::Log,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::Log,
+ ..
+ },
args,
}) => simpl_log(args, <&S>::clone(&info))?,
// power
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::Power,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::Power,
+ ..
+ },
args,
}) => simpl_power(args, <&S>::clone(&info))?,
// concat
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::Concat,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::Concat,
+ ..
+ },
args,
}) => simpl_concat(args)?,
// concat_ws
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::ConcatWithSeparator,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::ConcatWithSeparator,
+ ..
+ },
args,
}) => match &args[..] {
[delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?,
@@ -1550,7 +1569,7 @@ mod tests {
// immutable UDF should get folded
// udf_add(1+2, 30+40) --> 73
- let expr = Expr::ScalarUDF(expr::ScalarUDF::new(
+ let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
make_udf_add(Volatility::Immutable),
args.clone(),
));
@@ -1559,15 +1578,21 @@ mod tests {
// stable UDF should be entirely folded
// udf_add(1+2, 30+40) --> 73
let fun = make_udf_add(Volatility::Stable);
- let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun),
args.clone()));
+ let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
+ Arc::clone(&fun),
+ args.clone(),
+ ));
test_evaluate(expr, lit(73));
// volatile UDF should have args folded
// udf_add(1+2, 30+40) --> udf_add(3, 70)
let fun = make_udf_add(Volatility::Volatile);
- let expr = Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun),
args));
- let expected_expr =
- Expr::ScalarUDF(expr::ScalarUDF::new(Arc::clone(&fun),
folded_args));
+ let expr =
+
Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args));
+ let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf(
+ Arc::clone(&fun),
+ folded_args,
+ ));
test_evaluate(expr, expected_expr);
}
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 17e5d97c30..e69207b688 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -23,7 +23,7 @@ use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
expr::{Between, BinaryExpr, InList},
expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or},
- lit, BuiltinScalarFunction, Expr, Like, Operator,
+ lit, BuiltinScalarFunction, Expr, Like, Operator, ScalarFunctionDefinition,
};
pub static POWS_OF_TEN: [i128; 38] = [
@@ -365,7 +365,11 @@ pub fn simpl_log(current_args: Vec<Expr>, info: &dyn
SimplifyInfo) -> Result<Exp
)?))
}
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::Power,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::Power,
+ ..
+ },
args,
}) if base == &args[0] => Ok(args[1].clone()),
_ => {
@@ -405,7 +409,11 @@ pub fn simpl_power(current_args: Vec<Expr>, info: &dyn
SimplifyInfo) -> Result<E
Ok(base.clone())
}
Expr::ScalarFunction(ScalarFunction {
- fun: BuiltinScalarFunction::Log,
+ func_def:
+ ScalarFunctionDefinition::BuiltIn {
+ fun: BuiltinScalarFunction::Log,
+ ..
+ },
args,
}) if base == &args[0] => Ok(args[1].clone()),
_ => Ok(Expr::ScalarFunction(ScalarFunction::new(
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index f318cd3b0f..5c5cc8e36f 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -29,10 +29,10 @@ use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DFSchema, DataFusionError,
Result,
ScalarValue,
};
-use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction, ScalarUDF};
+use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction};
use datafusion_expr::{
binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField,
Like,
- Operator, TryCast,
+ Operator, ScalarFunctionDefinition, TryCast,
};
use std::sync::Arc;
@@ -348,36 +348,50 @@ pub fn create_physical_expr(
)))
}
- Expr::ScalarFunction(ScalarFunction { fun, args }) => {
- let physical_args = args
- .iter()
- .map(|e| {
- create_physical_expr(e, input_dfschema, input_schema,
execution_props)
- })
- .collect::<Result<Vec<_>>>()?;
- functions::create_physical_expr(
- fun,
- &physical_args,
- input_schema,
- execution_props,
- )
- }
- Expr::ScalarUDF(ScalarUDF { fun, args }) => {
- let mut physical_args = vec![];
- for e in args {
- physical_args.push(create_physical_expr(
- e,
- input_dfschema,
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ let physical_args = args
+ .iter()
+ .map(|e| {
+ create_physical_expr(
+ e,
+ input_dfschema,
+ input_schema,
+ execution_props,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+ functions::create_physical_expr(
+ fun,
+ &physical_args,
input_schema,
execution_props,
- )?);
+ )
+ }
+ ScalarFunctionDefinition::UDF(fun) => {
+ let mut physical_args = vec![];
+ for e in args {
+ physical_args.push(create_physical_expr(
+ e,
+ input_dfschema,
+ input_schema,
+ execution_props,
+ )?);
+ }
+ // udfs with zero params expect null array as input
+ if args.is_empty() {
+
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+ }
+ udf::create_physical_expr(
+ fun.clone().as_ref(),
+ &physical_args,
+ input_schema,
+ )
}
- // udfs with zero params expect null array as input
- if args.is_empty() {
- physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+ ScalarFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be resolved.")
}
- udf::create_physical_expr(fun.clone().as_ref(), &physical_args,
input_schema)
- }
+ },
Expr::Between(Between {
expr,
negated,
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs
b/datafusion/proto/src/logical_plan/from_proto.rs
index 8069e017f7..d4a64287b0 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1723,7 +1723,7 @@ pub fn parse_expr(
}
ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args
}) => {
let scalar_fn = registry.udf(fun_name.as_str())?;
- Ok(Expr::ScalarUDF(expr::ScalarUDF::new(
+ Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
scalar_fn,
args.iter()
.map(|expr| parse_expr(expr, registry))
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 750eb03e83..508cde98ae 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -45,7 +45,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::{
self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField,
GroupingSet,
- InList, Like, Placeholder, ScalarFunction, ScalarUDF, Sort,
+ InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -756,29 +756,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.to_string(),
))
}
- Expr::ScalarFunction(ScalarFunction { fun, args }) => {
- let fun: protobuf::ScalarFunction = fun.try_into()?;
- let args: Vec<Self> = args
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<Self>, Error>>()?;
- Self {
- expr_type: Some(ExprType::ScalarFunction(
- protobuf::ScalarFunctionNode {
- fun: fun.into(),
- args,
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
+ ScalarFunctionDefinition::BuiltIn { fun, .. } => {
+ let fun: protobuf::ScalarFunction = fun.try_into()?;
+ let args: Vec<Self> = args
+ .iter()
+ .map(|e| e.try_into())
+ .collect::<Result<Vec<Self>, Error>>()?;
+ Self {
+ expr_type: Some(ExprType::ScalarFunction(
+ protobuf::ScalarFunctionNode {
+ fun: fun.into(),
+ args,
+ },
+ )),
+ }
+ }
+ ScalarFunctionDefinition::UDF(fun) => Self {
+ expr_type: Some(ExprType::ScalarUdfExpr(
+ protobuf::ScalarUdfExprNode {
+ fun_name: fun.name().to_string(),
+ args: args
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Error>>()?,
},
)),
+ },
+ ScalarFunctionDefinition::Name(_) => {
+ return Err(Error::NotImplemented(
+ "Proto serialization error: Trying to serialize a
unresolved function"
+ .to_string(),
+ ));
}
- }
- Expr::ScalarUDF(ScalarUDF { fun, args }) => Self {
- expr_type:
Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
- fun_name: fun.name().to_string(),
- args: args
- .iter()
- .map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, Error>>()?,
- })),
},
Expr::AggregateUDF(expr::AggregateUDF {
fun,
diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
index acc7f07bfa..3ab001298e 100644
--- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs
@@ -39,7 +39,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err};
use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError,
ScalarValue};
use datafusion_expr::expr::{
self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like,
ScalarFunction,
- ScalarUDF, Sort,
+ Sort,
};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
@@ -1402,7 +1402,10 @@ fn roundtrip_scalar_udf() {
scalar_fn,
);
- let test_expr = Expr::ScalarUDF(ScalarUDF::new(Arc::new(udf.clone()),
vec![lit("")]));
+ let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf(
+ Arc::new(udf.clone()),
+ vec![lit("")],
+ ));
let ctx = SessionContext::new();
ctx.register_udf(udf);
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index c77ef64718..24ba4d1b50 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -19,7 +19,7 @@ use crate::planner::{ContextProvider, PlannerContext,
SqlToRel};
use datafusion_common::{
not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError,
Result,
};
-use datafusion_expr::expr::{ScalarFunction, ScalarUDF};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
@@ -66,7 +66,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// user-defined function (UDF) should have precedence in case it has
the same name as a scalar built-in function
if let Some(fm) = self.context_provider.get_function_meta(&name) {
let args = self.function_args_to_expr(args, schema,
planner_context)?;
- return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
+ return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args)));
}
// next, scalar built-in
diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs
index 0f086bca68..f33e9e8ddf 100644
--- a/datafusion/sql/src/expr/value.rs
+++ b/datafusion/sql/src/expr/value.rs
@@ -24,8 +24,8 @@ use datafusion_common::{
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{BinaryExpr, Placeholder};
-use datafusion_expr::BuiltinScalarFunction;
use datafusion_expr::{lit, Expr, Operator};
+use datafusion_expr::{BuiltinScalarFunction, ScalarFunctionDefinition};
use log::debug;
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
use sqlparser::parser::ParserError::ParserError;
@@ -143,8 +143,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Expr::Literal(_) => {
values.push(value);
}
- Expr::ScalarFunction(ref scalar_function) => {
- if scalar_function.fun == BuiltinScalarFunction::MakeArray
{
+ Expr::ScalarFunction(ScalarFunction {
+ func_def: ScalarFunctionDefinition::BuiltIn { fun, .. },
+ ..
+ }) => {
+ if fun == BuiltinScalarFunction::MakeArray {
values.push(value);
} else {
return not_impl_err!(
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index f4c36557da..5cb72adaca 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -843,10 +843,9 @@ pub async fn from_substrait_rex(
};
args.push(arg_expr?.as_ref().clone());
}
- Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction {
- fun,
- args,
- })))
+ Ok(Arc::new(Expr::ScalarFunction(expr::ScalarFunction::new(
+ fun, args,
+ ))))
}
ScalarFunctionType::Op(op) => {
if f.arguments.len() != 2 {
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 4b6aded78b..95604e6d2d 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -34,7 +34,7 @@ use datafusion::common::{exec_err, internal_err,
not_impl_err};
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
- ScalarFunction as DFScalarFunction, Sort, WindowFunction,
+ ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort,
WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan,
Operator};
use datafusion::prelude::Expr;
@@ -822,7 +822,7 @@ pub fn to_substrait_rex(
Ok(substrait_or_list)
}
}
- Expr::ScalarFunction(DFScalarFunction { fun, args }) => {
+ Expr::ScalarFunction(DFScalarFunction { func_def, args }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument {
@@ -834,7 +834,14 @@ pub fn to_substrait_rex(
)?)),
});
}
- let function_anchor = _register_function(fun.to_string(),
extension_info);
+
+ // function should be resolved during `AnalyzerRule`
+ if let ScalarFunctionDefinition::Name(_) = func_def {
+ return internal_err!("Function `Expr` with name should be
resolved.");
+ }
+
+ let function_anchor =
+ _register_function(func_def.name().to_string(),
extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,