This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 153356e1a1 refactor: Expr::ScalarFunction to use a struct (#6281)
153356e1a1 is described below
commit 153356e1a1b2bb479a564aed96be242e267ced50
Author: jakevin <[email protected]>
AuthorDate: Mon May 8 23:25:08 2023 +0800
refactor: Expr::ScalarFunction to use a struct (#6281)
---
benchmarks/src/tpch.rs | 62 +++++++++++-----------
datafusion/core/src/datasource/listing/helpers.rs | 4 +-
datafusion/core/src/physical_plan/planner.rs | 4 +-
datafusion/expr/src/expr.rs | 33 ++++++++----
datafusion/expr/src/expr_fn.rs | 60 ++++++++++-----------
datafusion/expr/src/expr_schema.rs | 7 +--
datafusion/expr/src/tree_node/expr.rs | 11 ++--
datafusion/expr/src/utils.rs | 2 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 19 +++----
datafusion/optimizer/src/push_down_filter.rs | 2 +-
.../src/simplify_expressions/expr_simplifier.rs | 31 ++++++-----
.../optimizer/src/simplify_expressions/utils.rs | 41 +++++++-------
datafusion/physical-expr/src/planner.rs | 4 +-
datafusion/proto/src/logical_plan/mod.rs | 23 ++++----
datafusion/proto/src/logical_plan/to_proto.rs | 5 +-
datafusion/sql/src/expr/function.rs | 5 +-
datafusion/sql/src/expr/mod.rs | 11 ++--
datafusion/sql/src/expr/substring.rs | 7 +--
datafusion/sql/src/utils.rs | 17 +++---
19 files changed, 181 insertions(+), 167 deletions(-)
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
index 4e144edcc8..72f5907b07 100644
--- a/benchmarks/src/tpch.rs
+++ b/benchmarks/src/tpch.rs
@@ -29,6 +29,7 @@ use datafusion::common::cast::{
as_int64_array, as_string_array,
};
use datafusion::common::ScalarValue;
+use datafusion::logical_expr::expr::ScalarFunction;
use datafusion::logical_expr::Cast;
use datafusion::prelude::*;
use datafusion::{
@@ -495,38 +496,37 @@ pub async fn transform_actual_result(
})
.collect::<Result<Vec<_>>>()?;
let table = Arc::new(MemTable::try_new(result_schema.clone(),
vec![result])?);
- let mut df = ctx.read_table(table)?
- .select(
- result_schema
- .fields
- .iter()
- .map(|field| {
- match field.data_type() {
- DataType::Decimal128(_, _) => {
- // if decimal, then round it to 2 decimal places
like the answers
- // round() doesn't support the second argument for
decimal places to round to
- // this can be simplified to remove the mul and
div when
- //
https://github.com/apache/arrow-datafusion/issues/2420 is completed
- // cast it back to an over-sized Decimal with 2
precision when done rounding
- let round = Box::new(Expr::ScalarFunction {
- fun:
datafusion::logical_expr::BuiltinScalarFunction::Round,
- args:
vec![col(Field::name(field)).mul(lit(100))],
- }.div(lit(100)));
- Expr::Cast(Cast::new(
- round,
- DataType::Decimal128(15, 2),
- )).alias(field.name())
- }
- DataType::Utf8 => {
- // if string, then trim it like the answers got
trimmed
- trim(col(Field::name(field))).alias(field.name())
- }
- _ => {
- col(field.name())
- }
+ let mut df = ctx.read_table(table)?.select(
+ result_schema
+ .fields
+ .iter()
+ .map(|field| {
+ match field.data_type() {
+ DataType::Decimal128(_, _) => {
+ // if decimal, then round it to 2 decimal places like
the answers
+ // round() doesn't support the second argument for
decimal places to round to
+ // this can be simplified to remove the mul and div
when
+ //
https://github.com/apache/arrow-datafusion/issues/2420 is completed
+ // cast it back to an over-sized Decimal with 2
precision when done rounding
+ let round = Box::new(
+ Expr::ScalarFunction(ScalarFunction::new(
+
datafusion::logical_expr::BuiltinScalarFunction::Round,
+ vec![col(Field::name(field)).mul(lit(100))],
+ ))
+ .div(lit(100)),
+ );
+ Expr::Cast(Cast::new(round, DataType::Decimal128(15,
2)))
+ .alias(field.name())
}
- }).collect()
- )?;
+ DataType::Utf8 => {
+ // if string, then trim it like the answers got trimmed
+ trim(col(Field::name(field))).alias(field.name())
+ }
+ _ => col(field.name()),
+ }
+ })
+ .collect(),
+ )?;
if let Some(x) = QUERY_LIMIT[n - 1] {
df = df.limit(0, Some(x))?;
}
diff --git a/datafusion/core/src/datasource/listing/helpers.rs
b/datafusion/core/src/datasource/listing/helpers.rs
index de7e243e3b..2a09445665 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -95,8 +95,8 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr:
&Expr) -> bool {
| Expr::GroupingSet(_)
| Expr::Case { .. } => VisitRecursion::Continue,
- Expr::ScalarFunction { fun, .. } => {
- match fun.volatility() {
+ Expr::ScalarFunction(scalar_function) => {
+ match scalar_function.fun.volatility() {
Volatility::Immutable => VisitRecursion::Continue,
// TODO: Stable functions could be `applicable`, but that
would require access to the context
Volatility::Stable | Volatility::Volatile => {
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index f685d5a5f4..88e0d6e3d0 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -184,8 +184,8 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
let expr = create_physical_name(expr, false)?;
Ok(format!("{expr}[{key}]"))
}
- Expr::ScalarFunction { fun, args, .. } => {
- create_function_physical_name(&fun.to_string(), false, args)
+ Expr::ScalarFunction(func) => {
+ create_function_physical_name(&func.fun.to_string(), false,
&func.args)
}
Expr::ScalarUDF { fun, args, .. } => {
create_function_physical_name(&fun.name, false, args)
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index d89589a30a..f508efb2ab 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -152,12 +152,7 @@ pub enum Expr {
/// A sort expression, that can be used to sort values.
Sort(Sort),
/// Represents the call of a built-in scalar function with a set of
arguments.
- ScalarFunction {
- /// The function
- fun: built_in_function::BuiltinScalarFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- },
+ ScalarFunction(ScalarFunction),
/// Represents the call of a user-defined scalar function with arguments.
ScalarUDF {
/// The function
@@ -353,6 +348,22 @@ impl Between {
}
}
+/// ScalarFunction expression
+#[derive(Clone, PartialEq, Eq, Hash)]
+pub struct ScalarFunction {
+ /// The function
+ pub fun: built_in_function::BuiltinScalarFunction,
+ /// List of expressions to feed to the functions as arguments
+ pub args: Vec<Expr>,
+}
+
+impl ScalarFunction {
+ /// Create a new ScalarFunction expression
+ pub fn new(fun: built_in_function::BuiltinScalarFunction, args: Vec<Expr>)
-> Self {
+ Self { fun, args }
+ }
+}
+
/// Returns the field of a [`arrow::array::ListArray`] or
[`arrow::array::StructArray`] by key
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct GetIndexedField {
@@ -592,7 +603,7 @@ impl Expr {
Expr::Not(..) => "Not",
Expr::Placeholder { .. } => "Placeholder",
Expr::QualifiedWildcard { .. } => "QualifiedWildcard",
- Expr::ScalarFunction { .. } => "ScalarFunction",
+ Expr::ScalarFunction(..) => "ScalarFunction",
Expr::ScalarSubquery { .. } => "ScalarSubquery",
Expr::ScalarUDF { .. } => "ScalarUDF",
Expr::ScalarVariable(..) => "ScalarVariable",
@@ -927,8 +938,8 @@ impl fmt::Debug for Expr {
write!(f, " NULLS LAST")
}
}
- Expr::ScalarFunction { fun, args, .. } => {
- fmt_function(f, &fun.to_string(), false, args, false)
+ Expr::ScalarFunction(func) => {
+ fmt_function(f, &func.fun.to_string(), false, &func.args,
false)
}
Expr::ScalarUDF { fun, ref args, .. } => {
fmt_function(f, &fun.name, false, args, false)
@@ -1286,8 +1297,8 @@ fn create_name(e: &Expr) -> Result<String> {
let expr = create_name(expr)?;
Ok(format!("{expr}[{key}]"))
}
- Expr::ScalarFunction { fun, args, .. } => {
- create_function_name(&fun.to_string(), false, args)
+ Expr::ScalarFunction(func) => {
+ create_function_name(&func.fun.to_string(), false, &func.args)
}
Expr::ScalarUDF { fun, args, .. } => create_function_name(&fun.name,
false, args),
Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index a233ce8260..a9c0b1fcaf 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -17,7 +17,9 @@
//! Functions for creating logical expressions
-use crate::expr::{AggregateFunction, BinaryExpr, Cast, GroupingSet, TryCast};
+use crate::expr::{
+ AggregateFunction, BinaryExpr, Cast, GroupingSet, ScalarFunction, TryCast,
+};
use crate::{
aggregate_function, built_in_function,
conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
@@ -214,10 +216,10 @@ pub fn in_list(expr: Expr, list: Vec<Expr>, negated:
bool) -> Expr {
/// Concatenates the text representations of all the arguments. NULL arguments
are ignored.
pub fn concat(args: &[Expr]) -> Expr {
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::Concat,
- args: args.to_vec(),
- }
+ Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::Concat,
+ args.to_vec(),
+ ))
}
/// Concatenates all but the first argument, with separators.
@@ -226,26 +228,20 @@ pub fn concat(args: &[Expr]) -> Expr {
pub fn concat_ws(sep: Expr, values: Vec<Expr>) -> Expr {
let mut args = values;
args.insert(0, sep);
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::ConcatWithSeparator,
+ Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::ConcatWithSeparator,
args,
- }
+ ))
}
/// Returns an approximate value of π
pub fn pi() -> Expr {
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::Pi,
- args: vec![],
- }
+ Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Pi,
vec![]))
}
/// Returns a random value in the range 0.0 <= x < 1.0
pub fn random() -> Expr {
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::Random,
- args: vec![],
- }
+ Expr::ScalarFunction(ScalarFunction::new(BuiltinScalarFunction::Random,
vec![]))
}
/// Returns the approximate number of distinct input values.
@@ -441,10 +437,10 @@ macro_rules! scalar_expr {
($ENUM:ident, $FUNC:ident, $($arg:ident)*, $DOC:expr) => {
#[doc = $DOC ]
pub fn $FUNC($($arg: Expr),*) -> Expr {
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::$ENUM,
- args: vec![$($arg),*],
- }
+ Expr::ScalarFunction(ScalarFunction::new(
+ built_in_function::BuiltinScalarFunction::$ENUM,
+ vec![$($arg),*],
+ ))
}
};
}
@@ -453,10 +449,10 @@ macro_rules! nary_scalar_expr {
($ENUM:ident, $FUNC:ident, $DOC:expr) => {
#[doc = $DOC ]
pub fn $FUNC(args: Vec<Expr>) -> Expr {
- Expr::ScalarFunction {
- fun: built_in_function::BuiltinScalarFunction::$ENUM,
+ Expr::ScalarFunction(ScalarFunction::new(
+ built_in_function::BuiltinScalarFunction::$ENUM,
args,
- }
+ ))
}
};
}
@@ -712,7 +708,7 @@ pub fn create_udaf(
/// ```
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
match name.as_ref().parse::<BuiltinScalarFunction>() {
- Ok(fun) => Ok(Expr::ScalarFunction { fun, args }),
+ Ok(fun) => Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))),
Err(e) => Err(e),
}
}
@@ -735,7 +731,9 @@ mod test {
macro_rules! test_unary_scalar_expr {
($ENUM:ident, $FUNC:ident) => {{
- if let Expr::ScalarFunction { fun, args } = $FUNC(col("tableA.a"))
{
+ if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
+ $FUNC(col("tableA.a"))
+ {
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
assert_eq!(1, args.len());
@@ -753,7 +751,7 @@ mod test {
col(stringify!($arg.to_string()))
),*
);
- if let Expr::ScalarFunction { fun, args } = result {
+ if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result
{
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
assert_eq!(expected.len(), args.len());
@@ -773,7 +771,7 @@ mod test {
),*
]
);
- if let Expr::ScalarFunction { fun, args } = result {
+ if let Expr::ScalarFunction(ScalarFunction { fun, args }) = result
{
let name = built_in_function::BuiltinScalarFunction::$ENUM;
assert_eq!(name, fun);
assert_eq!(expected.len(), args.len());
@@ -877,8 +875,8 @@ mod test {
#[test]
fn uuid_function_definitions() {
- if let Expr::ScalarFunction { fun, args } = uuid() {
- let name = built_in_function::BuiltinScalarFunction::Uuid;
+ if let Expr::ScalarFunction(ScalarFunction { fun, args }) = uuid() {
+ let name = BuiltinScalarFunction::Uuid;
assert_eq!(name, fun);
assert_eq!(0, args.len());
} else {
@@ -888,7 +886,9 @@ mod test {
#[test]
fn digest_function_definitions() {
- if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"),
lit("md5")) {
+ if let Expr::ScalarFunction(ScalarFunction { fun, args }) =
+ digest(col("tableA.a"), lit("md5"))
+ {
let name = BuiltinScalarFunction::Digest;
assert_eq!(name, fun);
assert_eq!(2, args.len());
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index ba37cf6d45..4cdf8debca 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -17,7 +17,8 @@
use super::{Between, Expr, Like};
use crate::expr::{
- AggregateFunction, BinaryExpr, Cast, GetIndexedField, Sort, TryCast,
WindowFunction,
+ AggregateFunction, BinaryExpr, Cast, GetIndexedField, ScalarFunction,
Sort, TryCast,
+ WindowFunction,
};
use crate::field_util::get_indexed_field;
use crate::type_coercion::binary::get_result_type;
@@ -101,7 +102,7 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
- Expr::ScalarFunction { fun, args } => {
+ Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
@@ -216,7 +217,7 @@ impl ExprSchemable for Expr {
Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
Expr::ScalarVariable(_, _)
| Expr::TryCast { .. }
- | Expr::ScalarFunction { .. }
+ | Expr::ScalarFunction(..)
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
diff --git a/datafusion/expr/src/tree_node/expr.rs
b/datafusion/expr/src/tree_node/expr.rs
index b0a5e31da0..37df3ce201 100644
--- a/datafusion/expr/src/tree_node/expr.rs
+++ b/datafusion/expr/src/tree_node/expr.rs
@@ -19,7 +19,7 @@
use crate::expr::{
AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField,
GroupingSet,
- Like, Sort, TryCast, WindowFunction,
+ Like, ScalarFunction, Sort, TryCast, WindowFunction,
};
use crate::Expr;
use datafusion_common::tree_node::VisitRecursion;
@@ -51,7 +51,7 @@ impl TreeNode for Expr {
}
Expr::GroupingSet(GroupingSet::Rollup(exprs))
| Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.clone(),
- Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. }
=> {
+ Expr::ScalarFunction (ScalarFunction{ args, .. } )|
Expr::ScalarUDF { args, .. } => {
args.clone()
}
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
@@ -267,10 +267,9 @@ impl TreeNode for Expr {
asc,
nulls_first,
)),
- Expr::ScalarFunction { args, fun } => Expr::ScalarFunction {
- args: transform_vec(args, &mut transform)?,
- fun,
- },
+ Expr::ScalarFunction(ScalarFunction { args, fun }) =>
Expr::ScalarFunction(
+ ScalarFunction::new(fun, transform_vec(args, &mut transform)?),
+ ),
Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
args: transform_vec(args, &mut transform)?,
fun,
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 00e1d07693..41c40352d9 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -295,7 +295,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut
HashSet<Column>) -> Result<()> {
| Expr::Cast { .. }
| Expr::TryCast { .. }
| Expr::Sort { .. }
- | Expr::ScalarFunction { .. }
+ | Expr::ScalarFunction(..)
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 676cb82509..a7e45c9fd2 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -24,7 +24,9 @@ use arrow::datatypes::{DataType, IntervalUnit};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue};
-use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like,
WindowFunction};
+use datafusion_expr::expr::{
+ self, Between, BinaryExpr, Case, Like, ScalarFunction, WindowFunction,
+};
use datafusion_expr::expr_schema::cast_subquery;
use datafusion_expr::logical_plan::Subquery;
use datafusion_expr::type_coercion::binary::{
@@ -379,16 +381,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
- Expr::ScalarFunction { fun, args } => {
+ Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let nex_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&function::signature(&fun),
)?;
- let expr = Expr::ScalarFunction {
- fun,
- args: nex_expr,
- };
+ let expr = Expr::ScalarFunction(ScalarFunction::new(fun,
nex_expr));
Ok(expr)
}
Expr::AggregateFunction(expr::AggregateFunction {
@@ -746,7 +745,7 @@ mod test {
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result,
ScalarValue};
- use datafusion_expr::expr::{self, Like};
+ use datafusion_expr::expr::{self, Like, ScalarFunction};
use datafusion_expr::{
cast, col, concat, concat_ws, create_udaf, is_true,
AccumulatorFunctionImplementation, AggregateFunction, AggregateUDF,
BinaryExpr,
@@ -862,10 +861,8 @@ mod test {
let empty = empty();
let lit_expr = lit(10i64);
let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs;
- let scalar_function_expr = Expr::ScalarFunction {
- fun,
- args: vec![lit_expr],
- };
+ let scalar_function_expr =
+ Expr::ScalarFunction(ScalarFunction::new(fun, vec![lit_expr]));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![scalar_function_expr],
empty,
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 1c0b4b07e5..33ea7d1867 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -184,7 +184,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) ->
Result<bool> {
| Expr::Case(_)
| Expr::Cast(_)
| Expr::TryCast(_)
- | Expr::ScalarFunction { .. }
+ | Expr::ScalarFunction(..)
| Expr::InList { .. } => Ok(VisitRecursion::Continue),
Expr::Sort(_)
| Expr::AggregateFunction(_)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 20bee905c4..5e8248c2b8 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -28,6 +28,7 @@ use arrow::{
};
use datafusion_common::tree_node::{RewriteRecursion, TreeNode,
TreeNodeRewriter};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
and, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like,
Volatility,
@@ -265,7 +266,9 @@ impl<'a> ConstEvaluator<'a> {
| Expr::Wildcard
| Expr::QualifiedWildcard { .. }
| Expr::Placeholder { .. } => false,
- Expr::ScalarFunction { fun, .. } =>
Self::volatility_ok(fun.volatility()),
+ Expr::ScalarFunction(ScalarFunction { fun, .. }) => {
+ Self::volatility_ok(fun.volatility())
+ }
Expr::ScalarUDF { fun, .. } =>
Self::volatility_ok(fun.signature.volatility),
Expr::Literal(_)
| Expr::BinaryExpr { .. }
@@ -1073,33 +1076,33 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for
Simplifier<'a, S> {
}
// log
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Log,
args,
- } => simpl_log(args, <&S>::clone(&info))?,
+ }) => simpl_log(args, <&S>::clone(&info))?,
// power
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Power,
args,
- } => simpl_power(args, <&S>::clone(&info))?,
+ }) => simpl_power(args, <&S>::clone(&info))?,
// concat
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Concat,
args,
- } => simpl_concat(args)?,
+ }) => simpl_concat(args)?,
// concat_ws
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::ConcatWithSeparator,
args,
- } => match &args[..] {
+ }) => match &args[..] {
[delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?,
- _ => Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ConcatWithSeparator,
+ _ => Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::ConcatWithSeparator,
args,
- },
+ )),
},
//
@@ -1385,7 +1388,7 @@ mod tests {
// rand() + (1 + 2) --> rand() + 3
let fun = BuiltinScalarFunction::Random;
assert_eq!(fun.volatility(), Volatility::Volatile);
- let rand = Expr::ScalarFunction { args: vec![], fun };
+ let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let expr = rand.clone() + (lit(1) + lit(2));
let expected = rand + lit(3);
test_evaluate(expr, expected);
@@ -1393,7 +1396,7 @@ mod tests {
// parenthesization matters: can't rewrite
// (rand() + 1) + 2 --> (rand() + 1) + 2
let fun = BuiltinScalarFunction::Random;
- let rand = Expr::ScalarFunction { args: vec![], fun };
+ let rand = Expr::ScalarFunction(ScalarFunction::new(fun, vec![]));
let expr = (rand + lit(1)) + lit(2);
test_evaluate(expr.clone(), expr);
}
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index a69b48a0d0..094647ca5c 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -19,6 +19,7 @@
use crate::simplify_expressions::SimplifyInfo;
use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
expr::{Between, BinaryExpr},
expr_fn::{and, bitwise_and, bitwise_or, concat_ws, or},
@@ -371,20 +372,20 @@ pub fn simpl_log(current_args: Vec<Expr>, info: &dyn
SimplifyInfo) -> Result<Exp
&info.get_data_type(base)?,
)?))
}
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Power,
args,
- } if base == &args[0] => Ok(args[1].clone()),
+ }) if base == &args[0] => Ok(args[1].clone()),
_ => {
if number == base {
Ok(Expr::Literal(ScalarValue::new_one(
&info.get_data_type(number)?,
)?))
} else {
- Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::Log,
- args: vec![base.clone(), number.clone()],
- })
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::Log,
+ vec![base.clone(), number.clone()],
+ )))
}
}
}
@@ -411,14 +412,14 @@ pub fn simpl_power(current_args: Vec<Expr>, info: &dyn
SimplifyInfo) -> Result<E
{
Ok(base.clone())
}
- Expr::ScalarFunction {
+ Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Log,
args,
- } if base == &args[0] => Ok(args[1].clone()),
- _ => Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::Power,
- args: current_args,
- }),
+ }) if base == &args[0] => Ok(args[1].clone()),
+ _ => Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::Power,
+ current_args,
+ ))),
}
}
@@ -463,10 +464,10 @@ pub fn simpl_concat(args: Vec<Expr>) -> Result<Expr> {
new_args.push(lit(contiguous_scalar));
}
- Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::Concat,
- args: new_args,
- })
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::Concat,
+ new_args,
+ )))
}
/// Simplify the `concat_ws` function by
@@ -517,10 +518,10 @@ pub fn simpl_concat_ws(delimiter: &Expr, args: &[Expr])
-> Result<Expr> {
if let Some(val) = contiguous_scalar {
new_args.push(lit(val));
}
- Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::ConcatWithSeparator,
- args: new_args,
- })
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::ConcatWithSeparator,
+ new_args,
+ )))
}
// if the delimiter is null, then the value of the whole
expression is null.
None => Ok(Expr::Literal(ScalarValue::Utf8(None))),
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index f1bb35c2e2..a1fdce4f3a 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -27,7 +27,7 @@ use crate::{
};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::expr::Cast;
+use datafusion_expr::expr::{Cast, ScalarFunction};
use datafusion_expr::{
binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator,
TryCast,
};
@@ -383,7 +383,7 @@ pub fn create_physical_expr(
)))
}
- Expr::ScalarFunction { fun, args } => {
+ Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let physical_args = args
.iter()
.map(|e| {
diff --git a/datafusion/proto/src/logical_plan/mod.rs
b/datafusion/proto/src/logical_plan/mod.rs
index 8789b7302e..04e30ed8d8 100644
--- a/datafusion/proto/src/logical_plan/mod.rs
+++ b/datafusion/proto/src/logical_plan/mod.rs
@@ -1416,7 +1416,7 @@ mod roundtrip_tests {
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Case, Cast, GroupingSet, Like, Sort,
+ self, Between, BinaryExpr, Case, Cast, GroupingSet, Like,
ScalarFunction, Sort,
};
use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
@@ -2458,10 +2458,7 @@ mod roundtrip_tests {
#[test]
fn roundtrip_sqrt() {
- let test_expr = Expr::ScalarFunction {
- fun: Sqrt,
- args: vec![col("col")],
- };
+ let test_expr = Expr::ScalarFunction(ScalarFunction::new(Sqrt,
vec![col("col")]));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx);
}
@@ -2672,16 +2669,16 @@ mod roundtrip_tests {
#[test]
fn roundtrip_substr() {
// substr(string, position)
- let test_expr = Expr::ScalarFunction {
- fun: Substr,
- args: vec![col("col"), lit(1_i64)],
- };
+ let test_expr = Expr::ScalarFunction(ScalarFunction::new(
+ Substr,
+ vec![col("col"), lit(1_i64)],
+ ));
// substr(string, position, count)
- let test_expr_with_count = Expr::ScalarFunction {
- fun: Substr,
- args: vec![col("col"), lit(1_i64), lit(1_i64)],
- };
+ let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new(
+ Substr,
+ vec![col("col"), lit(1_i64), lit(1_i64)],
+ ));
let ctx = SessionContext::new();
roundtrip_expr_test(test_expr, ctx.clone());
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index e757f7830b..1897be4dba 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -36,7 +36,8 @@ use arrow::datatypes::{
};
use datafusion_common::{Column, DFField, DFSchemaRef, OwnedTableReference,
ScalarValue};
use datafusion_expr::expr::{
- self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, Sort,
+ self, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like,
ScalarFunction,
+ Sort,
};
use datafusion_expr::{
logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -679,7 +680,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
.to_string(),
))
}
- Expr::ScalarFunction { ref fun, ref args } => {
+ Expr::ScalarFunction(ScalarFunction { fun, args }) => {
let fun: protobuf::ScalarFunction = fun.try_into()?;
let args: Vec<Self> = args
.iter()
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index bf076a2f30..996221a61e 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -17,6 +17,7 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{DFSchema, DataFusionError, Result};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
@@ -49,7 +50,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
let args =
self.function_args_to_expr(function.args, schema,
planner_context)?;
- return Ok(Expr::ScalarFunction { fun, args });
+ return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)));
};
// then, window function
@@ -153,7 +154,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
planner_context: &mut PlannerContext,
) -> Result<Expr> {
let args = vec![self.sql_expr_to_logical_expr(expr, schema,
planner_context)?];
- Ok(Expr::ScalarFunction { fun, args })
+ Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}
pub(super) fn find_window_func(&self, name: &str) ->
Result<WindowFunction> {
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 67fe8b1e99..eed4487216 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -30,6 +30,7 @@ use crate::planner::{ContextProvider, PlannerContext,
SqlToRel};
use arrow_schema::DataType;
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{Column, DFSchema, DataFusionError, Result,
ScalarValue};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr,
BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetIndexedField, Like, Operator, TryCast,
@@ -121,13 +122,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(value) => {
self.parse_value(value,
planner_context.prepare_param_data_types())
}
- SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::DatePart,
- args: vec![
+ SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction
(ScalarFunction::new(
+ BuiltinScalarFunction::DatePart,
+ vec![
Expr::Literal(ScalarValue::Utf8(Some(format!("{field}")))),
self.sql_expr_to_logical_expr(*expr, schema,
planner_context)?,
],
- }),
+ ))),
SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
SQLExpr::Interval {
@@ -464,7 +465,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
None => vec![arg],
};
- Ok(Expr::ScalarFunction { fun, args })
+ Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
}
fn sql_agg_with_filter_to_expr(
diff --git a/datafusion/sql/src/expr/substring.rs
b/datafusion/sql/src/expr/substring.rs
index 991f82a67b..1a95266615 100644
--- a/datafusion/sql/src/expr/substring.rs
+++ b/datafusion/sql/src/expr/substring.rs
@@ -17,6 +17,7 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{BuiltinScalarFunction, Expr};
use sqlparser::ast::Expr as SQLExpr;
@@ -67,9 +68,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
};
- Ok(Expr::ScalarFunction {
- fun: BuiltinScalarFunction::Substr,
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ BuiltinScalarFunction::Substr,
args,
- })
+ )))
}
}
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 91cef6d471..11fc0a58cf 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -23,7 +23,7 @@ use sqlparser::ast::Ident;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{
AggregateFunction, Between, BinaryExpr, Case, GetIndexedField,
GroupingSet, Like,
- WindowFunction,
+ ScalarFunction, WindowFunction,
};
use datafusion_expr::expr::{Cast, Sort};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
@@ -295,13 +295,14 @@ where
None => None,
},
))),
- Expr::ScalarFunction { fun, args } => Ok(Expr::ScalarFunction {
- fun: fun.clone(),
- args: args
- .iter()
- .map(|e| clone_with_replacement(e, replacement_fn))
- .collect::<Result<Vec<Expr>>>()?,
- }),
+ Expr::ScalarFunction(ScalarFunction { fun, args }) => {
+ Ok(Expr::ScalarFunction(ScalarFunction::new(
+ fun.clone(),
+ args.iter()
+ .map(|e| clone_with_replacement(e, replacement_fn))
+ .collect::<Result<Vec<Expr>>>()?,
+ )))
+ }
Expr::ScalarUDF { fun, args } => Ok(Expr::ScalarUDF {
fun: fun.clone(),
args: args