This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a6e6d3fab0 Refactor function argument handling in ScalarFunctionDefinition (#8387)
a6e6d3fab0 is described below
commit a6e6d3fab083839239ef81cf3a3546dd8929a541
Author: Alex Huang <[email protected]>
AuthorDate: Fri Dec 1 21:14:19 2023 +0100
Refactor function argument handling in ScalarFunctionDefinition (#8387)
---
datafusion/expr/src/expr_schema.rs | 15 ++----
datafusion/physical-expr/src/planner.rs | 66 +++++++++++----------------
datafusion/proto/src/logical_plan/to_proto.rs | 53 +++++++++++----------
3 files changed, 58 insertions(+), 76 deletions(-)
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 2795ac5f09..e5b0185d90 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -83,13 +83,12 @@ impl ExprSchemable for Expr {
Expr::Cast(Cast { data_type, .. })
| Expr::TryCast(TryCast { data_type, .. }) =>
Ok(data_type.clone()),
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ let arg_data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
- let arg_data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
-
// verify that input data types is consistent with
function's `TypeSignature`
data_types(&arg_data_types,
&fun.signature()).map_err(|_| {
plan_datafusion_err!(
@@ -105,11 +104,7 @@ impl ExprSchemable for Expr {
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- Ok(fun.return_type(&data_types)?)
+ Ok(fun.return_type(&arg_data_types)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 5501647da2..9c212cb81f 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -348,50 +348,38 @@ pub fn create_physical_expr(
)))
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
- ScalarFunctionDefinition::BuiltIn(fun) => {
- let physical_args = args
- .iter()
- .map(|e| {
- create_physical_expr(
- e,
- input_dfschema,
- input_schema,
- execution_props,
- )
- })
- .collect::<Result<Vec<_>>>()?;
- functions::create_physical_expr(
- fun,
- &physical_args,
- input_schema,
- execution_props,
- )
- }
- ScalarFunctionDefinition::UDF(fun) => {
- let mut physical_args = vec![];
- for e in args {
- physical_args.push(create_physical_expr(
- e,
- input_dfschema,
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ let mut physical_args = args
+ .iter()
+ .map(|e| {
+ create_physical_expr(e, input_dfschema, input_schema,
execution_props)
+ })
+ .collect::<Result<Vec<_>>>()?;
+ match func_def {
+ ScalarFunctionDefinition::BuiltIn(fun) => {
+ functions::create_physical_expr(
+ fun,
+ &physical_args,
input_schema,
execution_props,
- )?);
+ )
+ }
+ ScalarFunctionDefinition::UDF(fun) => {
+ // udfs with zero params expect null array as input
+ if args.is_empty() {
+
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+ }
+ udf::create_physical_expr(
+ fun.clone().as_ref(),
+ &physical_args,
+ input_schema,
+ )
}
- // udfs with zero params expect null array as input
- if args.is_empty() {
-
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
+ ScalarFunctionDefinition::Name(_) => {
+ internal_err!("Function `Expr` with name should be
resolved.")
}
- udf::create_physical_expr(
- fun.clone().as_ref(),
- &physical_args,
- input_schema,
- )
}
- ScalarFunctionDefinition::Name(_) => {
- internal_err!("Function `Expr` with name should be resolved.")
- }
- },
+ }
Expr::Between(Between {
expr,
negated,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index ab8e850014..ecbfaca5db 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -792,40 +792,39 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.to_string(),
))
}
- Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
- ScalarFunctionDefinition::BuiltIn(fun) => {
- let fun: protobuf::ScalarFunction = fun.try_into()?;
- let args: Vec<Self> = args
- .iter()
- .map(|e| e.try_into())
- .collect::<Result<Vec<Self>, Error>>()?;
- Self {
- expr_type: Some(ExprType::ScalarFunction(
- protobuf::ScalarFunctionNode {
- fun: fun.into(),
+ Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+ let args = args
+ .iter()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Error>>()?;
+ match func_def {
+ ScalarFunctionDefinition::BuiltIn(fun) => {
+ let fun: protobuf::ScalarFunction = fun.try_into()?;
+ Self {
+ expr_type: Some(ExprType::ScalarFunction(
+ protobuf::ScalarFunctionNode {
+ fun: fun.into(),
+ args,
+ },
+ )),
+ }
+ }
+ ScalarFunctionDefinition::UDF(fun) => Self {
+ expr_type: Some(ExprType::ScalarUdfExpr(
+ protobuf::ScalarUdfExprNode {
+ fun_name: fun.name().to_string(),
args,
},
)),
- }
- }
- ScalarFunctionDefinition::UDF(fun) => Self {
- expr_type: Some(ExprType::ScalarUdfExpr(
- protobuf::ScalarUdfExprNode {
- fun_name: fun.name().to_string(),
- args: args
- .iter()
- .map(|expr| expr.try_into())
- .collect::<Result<Vec<_>, Error>>()?,
- },
- )),
- },
- ScalarFunctionDefinition::Name(_) => {
- return Err(Error::NotImplemented(
+ },
+ ScalarFunctionDefinition::Name(_) => {
+ return Err(Error::NotImplemented(
"Proto serialization error: Trying to serialize a
unresolved function"
.to_string(),
));
+ }
}
- },
+ }
Expr::Not(expr) => {
let expr = Box::new(protobuf::Not {
expr: Some(Box::new(expr.as_ref().try_into()?)),