This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ece45554af `AggregateUDFImpl::schema_name` and
`AggregateUDFImpl::display_name` for customizable name (#14695)
ece45554af is described below
commit ece45554af6e022d2e569c946cbdfd0f19c02aee
Author: Jay Zhan <[email protected]>
AuthorDate: Mon Feb 17 20:12:44 2025 +0800
`AggregateUDFImpl::schema_name` and `AggregateUDFImpl::display_name` for
customizable name (#14695)
* udaf schema_name
* doc
* fix proto
* fmt
* fix
* fmt
* doc
* add displayname
* doc
---
datafusion-examples/examples/advanced_udaf.rs | 10 +-
datafusion/core/src/physical_planner.rs | 16 +--
datafusion/core/tests/execution/logical_plan.rs | 14 +--
datafusion/expr/src/expr.rs | 120 +++++++++------------
datafusion/expr/src/expr_fn.rs | 8 +-
datafusion/expr/src/expr_schema.rs | 9 +-
datafusion/expr/src/tree_node.rs | 20 ++--
datafusion/expr/src/udaf.rs | 104 +++++++++++++++++-
datafusion/functions-nested/src/planner.rs | 45 ++++----
.../optimizer/src/analyzer/count_wildcard_rule.rs | 7 +-
.../src/analyzer/resolve_grouping_function.rs | 3 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 17 +--
.../optimizer/src/single_distinct_to_groupby.rs | 18 ++--
datafusion/proto/src/logical_plan/to_proto.rs | 17 +--
datafusion/sql/src/unparser/expr.rs | 15 ++-
datafusion/substrait/src/logical_plan/producer.rs | 16 +--
16 files changed, 280 insertions(+), 159 deletions(-)
diff --git a/datafusion-examples/examples/advanced_udaf.rs
b/datafusion-examples/examples/advanced_udaf.rs
index fd65c3352b..9cda726db7 100644
--- a/datafusion-examples/examples/advanced_udaf.rs
+++ b/datafusion-examples/examples/advanced_udaf.rs
@@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
// In real-world scenarios, you might create UDFs from built-in expressions.
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
- aggregate_function.args,
- aggregate_function.distinct,
- aggregate_function.filter,
- aggregate_function.order_by,
- aggregate_function.null_treatment,
+ aggregate_function.params.args,
+ aggregate_function.params.distinct,
+ aggregate_function.params.filter,
+ aggregate_function.params.order_by,
+ aggregate_function.params.null_treatment,
)))
};
Some(Box::new(simplify))
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 2303574e88..bce1aab16e 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -70,7 +70,8 @@ use datafusion_common::{
};
use datafusion_expr::dml::{CopyTo, InsertOp};
use datafusion_expr::expr::{
- physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction,
+ physical_name, AggregateFunction, AggregateFunctionParams, Alias,
GroupingSet,
+ WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use
datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
@@ -1579,11 +1580,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
match e {
Expr::AggregateFunction(AggregateFunction {
func,
- distinct,
- args,
- filter,
- order_by,
- null_treatment,
+ params:
+ AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ },
}) => {
let name = if let Some(name) = name {
name
diff --git a/datafusion/core/tests/execution/logical_plan.rs
b/datafusion/core/tests/execution/logical_plan.rs
index a521902389..a17bb5eec8 100644
--- a/datafusion/core/tests/execution/logical_plan.rs
+++ b/datafusion/core/tests/execution/logical_plan.rs
@@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans};
use datafusion_execution::TaskContext;
-use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::logical_plan::{LogicalPlan, Values};
use datafusion_expr::{Aggregate, AggregateUDF, Expr};
use datafusion_functions_aggregate::count::Count;
@@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> {
vec![],
vec![Expr::AggregateFunction(AggregateFunction {
func: Arc::new(AggregateUDF::new_from_impl(Count::new())),
- args: vec![input_col_ref],
- distinct: false,
- filter: None,
- order_by: None,
- null_treatment: None,
+ params: AggregateFunctionParams {
+ args: vec![input_col_ref],
+ distinct: false,
+ filter: None,
+ order_by: None,
+ null_treatment: None,
+ },
})],
)?);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 305519a1f4..84ff36a931 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort {
pub struct AggregateFunction {
/// Name of the function
pub func: Arc<crate::AggregateUDF>,
- /// List of expressions to feed to the functions as arguments
+ pub params: AggregateFunctionParams,
+}
+
+#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
+pub struct AggregateFunctionParams {
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
pub distinct: bool,
@@ -719,11 +723,13 @@ impl AggregateFunction {
) -> Self {
Self {
func,
- args,
- distinct,
- filter,
- order_by,
- null_treatment,
+ params: AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ },
}
}
}
@@ -1864,19 +1870,25 @@ impl NormalizeEq for Expr {
(
Expr::AggregateFunction(AggregateFunction {
func: self_func,
- args: self_args,
- distinct: self_distinct,
- filter: self_filter,
- order_by: self_order_by,
- null_treatment: self_null_treatment,
+ params:
+ AggregateFunctionParams {
+ args: self_args,
+ distinct: self_distinct,
+ filter: self_filter,
+ order_by: self_order_by,
+ null_treatment: self_null_treatment,
+ },
}),
Expr::AggregateFunction(AggregateFunction {
func: other_func,
- args: other_args,
- distinct: other_distinct,
- filter: other_filter,
- order_by: other_order_by,
- null_treatment: other_null_treatment,
+ params:
+ AggregateFunctionParams {
+ args: other_args,
+ distinct: other_distinct,
+ filter: other_filter,
+ order_by: other_order_by,
+ null_treatment: other_null_treatment,
+ },
}),
) => {
self_func.name() == other_func.name()
@@ -2154,11 +2166,14 @@ impl HashNode for Expr {
}
Expr::AggregateFunction(AggregateFunction {
func,
- args: _args,
- distinct,
- filter: _filter,
- order_by: _order_by,
- null_treatment,
+ params:
+ AggregateFunctionParams {
+ args: _args,
+ distinct,
+ filter: _,
+ order_by: _,
+ null_treatment,
+ },
}) => {
func.hash(state);
distinct.hash(state);
@@ -2264,35 +2279,15 @@ impl Display for SchemaDisplay<'_> {
| Expr::Placeholder(_)
| Expr::Wildcard { .. } => write!(f, "{}", self.0),
- Expr::AggregateFunction(AggregateFunction {
- func,
- args,
- distinct,
- filter,
- order_by,
- null_treatment,
- }) => {
- write!(
- f,
- "{}({}{})",
- func.name(),
- if *distinct { "DISTINCT " } else { "" },
- schema_name_from_exprs_comma_separated_without_space(args)?
- )?;
-
- if let Some(null_treatment) = null_treatment {
- write!(f, " {}", null_treatment)?;
+ Expr::AggregateFunction(AggregateFunction { func, params }) => {
+ match func.schema_name(params) {
+ Ok(name) => {
+ write!(f, "{name}")
+ }
+ Err(e) => {
+ write!(f, "got error from schema_name {}", e)
+ }
}
-
- if let Some(filter) = filter {
- write!(f, " FILTER (WHERE {filter})")?;
- };
-
- if let Some(order_by) = order_by {
- write!(f, " ORDER BY [{}]",
schema_name_from_sorts(order_by)?)?;
- };
-
- Ok(())
}
// Expr is not shown since it is aliased
Expr::Alias(Alias {
@@ -2653,26 +2648,15 @@ impl Display for Expr {
)?;
Ok(())
}
- Expr::AggregateFunction(AggregateFunction {
- func,
- distinct,
- ref args,
- filter,
- order_by,
- null_treatment,
- ..
- }) => {
- fmt_function(f, func.name(), *distinct, args, true)?;
- if let Some(nt) = null_treatment {
- write!(f, " {}", nt)?;
- }
- if let Some(fe) = filter {
- write!(f, " FILTER (WHERE {fe})")?;
- }
- if let Some(ob) = order_by {
- write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?;
+ Expr::AggregateFunction(AggregateFunction { func, params }) => {
+ match func.display_name(params) {
+ Ok(name) => {
+ write!(f, "{}", name)
+ }
+ Err(e) => {
+ write!(f, "got error from display_name {}", e)
+ }
}
- Ok(())
}
Expr::Between(Between {
expr,
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index eb5f98930a..a0425bf847 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -826,10 +826,10 @@ impl ExprFuncBuilder {
let fun_expr = match fun {
ExprFuncKind::Aggregate(mut udaf) => {
- udaf.order_by = order_by;
- udaf.filter = filter.map(Box::new);
- udaf.distinct = distinct;
- udaf.null_treatment = null_treatment;
+ udaf.params.order_by = order_by;
+ udaf.params.filter = filter.map(Box::new);
+ udaf.params.distinct = distinct;
+ udaf.params.null_treatment = null_treatment;
Expr::AggregateFunction(udaf)
}
ExprFuncKind::Window(mut udwf) => {
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 4979142713..becb7c1439 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -17,8 +17,8 @@
use super::{Between, Expr, Like};
use crate::expr::{
- AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery,
Placeholder,
- ScalarFunction, TryCast, Unnest, WindowFunction,
+ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast,
InList,
+ InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
};
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf,
data_types_with_window_udf,
@@ -153,7 +153,10 @@ impl ExprSchemable for Expr {
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema,
window_function)
.map(|(return_type, _)| return_type),
- Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
+ Expr::AggregateFunction(AggregateFunction {
+ func,
+ params: AggregateFunctionParams { args, .. },
+ }) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index eacace5ed0..7801d56413 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -18,8 +18,9 @@
//! Tree node implementation for Logical Expressions
use crate::expr::{
- AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet,
InList,
- InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
WindowFunction,
+ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr,
Case, Cast,
+ GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction,
TryCast, Unnest,
+ WindowFunction,
};
use crate::{Expr, ExprFunctionExt};
@@ -87,7 +88,7 @@ impl TreeNode for Expr {
}) => (expr, low, high).apply_ref_elements(f),
Expr::Case(Case { expr, when_then_expr, else_expr }) =>
(expr, when_then_expr, else_expr).apply_ref_elements(f),
- Expr::AggregateFunction(AggregateFunction { args, filter,
order_by, .. }) =>
+ Expr::AggregateFunction(AggregateFunction { params:
AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
(args, filter, order_by).apply_ref_elements(f),
Expr::WindowFunction(WindowFunction {
args,
@@ -241,12 +242,15 @@ impl TreeNode for Expr {
},
),
Expr::AggregateFunction(AggregateFunction {
- args,
func,
- distinct,
- filter,
- order_by,
- null_treatment,
+ params:
+ AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ },
}) => (args, filter, order_by).map_elements(f)?.map_data(
|(new_args, new_filter, new_order_by)| {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 7ffc6623ea..bf8f34f949 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -19,7 +19,7 @@
use std::any::Any;
use std::cmp::Ordering;
-use std::fmt::{self, Debug, Formatter};
+use std::fmt::{self, Debug, Formatter, Write};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;
@@ -29,7 +29,10 @@ use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue,
Statistics};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
-use crate::expr::AggregateFunction;
+use crate::expr::{
+ schema_name_from_exprs_comma_separated_without_space,
schema_name_from_sorts,
+ AggregateFunction, AggregateFunctionParams,
+};
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
@@ -165,6 +168,16 @@ impl AggregateUDF {
self.inner.name()
}
+ /// See [`AggregateUDFImpl::schema_name`] for more details.
+ pub fn schema_name(&self, params: &AggregateFunctionParams) ->
Result<String> {
+ self.inner.schema_name(params)
+ }
+
+ /// See [`AggregateUDFImpl::display_name`] for more details.
+ pub fn display_name(&self, params: &AggregateFunctionParams) ->
Result<String> {
+ self.inner.display_name(params)
+ }
+
pub fn is_nullable(&self) -> bool {
self.inner.is_nullable()
}
@@ -382,6 +395,93 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
/// Returns this function's name
fn name(&self) -> &str;
+ /// Returns the name of the column this expression would create
+ ///
+ /// See [`Expr::schema_name`] for details
+ ///
+ /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..]
+ fn schema_name(&self, params: &AggregateFunctionParams) -> Result<String> {
+ let AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ } = params;
+
+ let mut schema_name = String::new();
+
+ schema_name.write_fmt(format_args!(
+ "{}({}{})",
+ self.name(),
+ if *distinct { "DISTINCT " } else { "" },
+ schema_name_from_exprs_comma_separated_without_space(args)?
+ ))?;
+
+ if let Some(null_treatment) = null_treatment {
+ schema_name.write_fmt(format_args!(" {}", null_treatment))?;
+ }
+
+ if let Some(filter) = filter {
+ schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?;
+ };
+
+ if let Some(order_by) = order_by {
+ schema_name.write_fmt(format_args!(
+ " ORDER BY [{}]",
+ schema_name_from_sorts(order_by)?
+ ))?;
+ };
+
+ Ok(schema_name)
+ }
+
+ /// Returns the user-defined display name of function, given the arguments
+ ///
+ /// This can be used to customize the output column name generated by this
+ /// function.
+ ///
+ /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]`
+ fn display_name(&self, params: &AggregateFunctionParams) -> Result<String>
{
+ let AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ } = params;
+
+ let mut schema_name = String::new();
+
+ schema_name.write_fmt(format_args!(
+ "{}({}{})",
+ self.name(),
+ if *distinct { "DISTINCT " } else { "" },
+ args.iter()
+ .map(|arg| format!("{arg}"))
+ .collect::<Vec<String>>()
+ .join(", ")
+ ))?;
+
+ if let Some(nt) = null_treatment {
+ schema_name.write_fmt(format_args!(" {}", nt))?;
+ }
+ if let Some(fe) = filter {
+ schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
+ }
+ if let Some(ob) = order_by {
+ schema_name.write_fmt(format_args!(
+ " ORDER BY [{}]",
+ ob.iter()
+ .map(|o| format!("{o}"))
+ .collect::<Vec<String>>()
+ .join(", ")
+ ))?;
+ }
+
+ Ok(schema_name)
+ }
+
/// Returns the function's [`Signature`] for information about what input
/// types are accepted and the function's Volatility.
fn signature(&self) -> &Signature;
diff --git a/datafusion/functions-nested/src/planner.rs
b/datafusion/functions-nested/src/planner.rs
index 5ca51ac20f..d55176a42c 100644
--- a/datafusion/functions-nested/src/planner.rs
+++ b/datafusion/functions-nested/src/planner.rs
@@ -17,8 +17,11 @@
//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]
+use std::sync::Arc;
+
use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
-use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams,
ScalarFunction};
+use datafusion_expr::AggregateUDF;
use datafusion_expr::{
planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
sqlparser, Expr, ExprSchemable, GetFieldAccess,
@@ -150,22 +153,26 @@ impl ExprPlanner for FieldAccessPlanner {
GetFieldAccess::ListIndex { key: index } => {
match expr {
// Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
- Expr::AggregateFunction(agg_func) if
is_array_agg(&agg_func) => {
- Ok(PlannerResult::Planned(Expr::AggregateFunction(
- datafusion_expr::expr::AggregateFunction::new_udf(
- nth_value_udaf(),
- agg_func
- .args
- .into_iter()
- .chain(std::iter::once(*index))
- .collect(),
- agg_func.distinct,
- agg_func.filter,
- agg_func.order_by,
- agg_func.null_treatment,
- ),
- )))
- }
+ Expr::AggregateFunction(AggregateFunction {
+ func,
+ params:
+ AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ },
+ }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
+ Expr::AggregateFunction(AggregateFunction::new_udf(
+ nth_value_udaf(),
+
args.into_iter().chain(std::iter::once(*index)).collect(),
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ )),
+ )),
_ => Ok(PlannerResult::Planned(array_element(expr,
*index))),
}
}
@@ -184,6 +191,6 @@ impl ExprPlanner for FieldAccessPlanner {
}
}
-fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
- agg_func.func.name() == "array_agg"
+fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
+ func.name() == "array_agg"
}
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index 95b6f9dc76..7e73474cf6 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -21,7 +21,7 @@ use crate::utils::NamePreserver;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::Result;
-use datafusion_expr::expr::{AggregateFunction, WindowFunction};
+use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams,
WindowFunction};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
@@ -55,8 +55,7 @@ fn is_count_star_aggregate(aggregate_function:
&AggregateFunction) -> bool {
matches!(aggregate_function,
AggregateFunction {
func,
- args,
- ..
+ params: AggregateFunctionParams { args, .. },
} if func.name() == "count" && (args.len() == 1 &&
is_wildcard(&args[0]) || args.is_empty()))
}
@@ -81,7 +80,7 @@ fn analyze_internal(plan: LogicalPlan) ->
Result<Transformed<LogicalPlan>> {
Expr::AggregateFunction(mut aggregate_function)
if is_count_star_aggregate(&aggregate_function) =>
{
- aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)];
+ aggregate_function.params.args =
vec![lit(COUNT_STAR_EXPANSION)];
Ok(Transformed::yes(Expr::AggregateFunction(
aggregate_function,
)))
diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
index 16ebb8cd39..f8a8185636 100644
--- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
+++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
@@ -163,6 +163,7 @@ fn validate_args(
group_by_expr: &HashMap<&Expr, usize>,
) -> Result<()> {
let expr_not_in_group_by = function
+ .params
.args
.iter()
.find(|expr| !group_by_expr.contains_key(expr));
@@ -183,7 +184,7 @@ fn grouping_function_on_id(
is_grouping_set: bool,
) -> Result<Expr> {
validate_args(function, group_by_expr)?;
- let args = &function.args;
+ let args = &function.params.args;
// Postgres allows grouping function for group by without grouping sets, the result is then
// always 0
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs
b/datafusion/optimizer/src/analyzer/type_coercion.rs
index f7dc4befb1..c7c84dc3d8 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -33,8 +33,8 @@ use datafusion_common::{
DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
TableReference,
};
use datafusion_expr::expr::{
- self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like,
- ScalarFunction, Sort, WindowFunction,
+ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists,
InList,
+ InSubquery, Like, ScalarFunction, Sort, WindowFunction,
};
use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema;
use datafusion_expr::expr_schema::cast_subquery;
@@ -506,11 +506,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
}
Expr::AggregateFunction(expr::AggregateFunction {
func,
- args,
- distinct,
- filter,
- order_by,
- null_treatment,
+ params:
+ AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment,
+ },
}) => {
let new_expr =
coerce_arguments_for_signature_with_aggregate_udf(
args,
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index c8f3a4bc78..191377fc27 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -26,6 +26,7 @@ use datafusion_common::{
internal_err, tree_node::Transformed, DataFusionError, HashSet, Result,
};
use datafusion_expr::builder::project;
+use datafusion_expr::expr::AggregateFunctionParams;
use datafusion_expr::{
col,
expr::AggregateFunction,
@@ -68,11 +69,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) ->
Result<bool> {
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
func,
- distinct,
- args,
- filter,
- order_by,
- null_treatment: _,
+ params:
+ AggregateFunctionParams {
+ distinct,
+ args,
+ filter,
+ order_by,
+ null_treatment: _,
+ },
}) = expr
{
if filter.is_some() || order_by.is_some() {
@@ -179,9 +183,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
func,
- mut args,
- distinct,
- ..
+ params: AggregateFunctionParams { mut args,
distinct, .. }
}) => {
if distinct {
if args.len() != 1 {
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs
b/datafusion/proto/src/logical_plan/to_proto.rs
index 6d1d4f3061..2284372716 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -22,8 +22,8 @@
use datafusion_common::{TableReference, UnnestOptions};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::expr::{
- self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like,
Placeholder,
- ScalarFunction, Unnest,
+ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast,
GroupingSet, InList,
+ Like, Placeholder, ScalarFunction, Unnest,
};
use datafusion_expr::WriteOp;
use datafusion_expr::{
@@ -348,11 +348,14 @@ pub fn serialize_expr(
}
Expr::AggregateFunction(expr::AggregateFunction {
ref func,
- ref args,
- ref distinct,
- ref filter,
- ref order_by,
- null_treatment: _,
+ params:
+ AggregateFunctionParams {
+ ref args,
+ ref distinct,
+ ref filter,
+ ref order_by,
+ null_treatment: _,
+ },
}) => {
let mut buf = Vec::new();
let _ = codec.try_encode_udaf(func, &mut buf);
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 7c1bcbd5ac..90630cf82f 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion_expr::expr::Unnest;
+use datafusion_expr::expr::{AggregateFunctionParams, Unnest};
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval,
ObjectName,
@@ -284,9 +284,15 @@ impl Unparser<'_> {
}),
Expr::AggregateFunction(agg) => {
let func_name = agg.func.name();
+ let AggregateFunctionParams {
+ distinct,
+ args,
+ filter,
+ ..
+ } = &agg.params;
- let args = self.function_args_to_sql(&agg.args)?;
- let filter = match &agg.filter {
+ let args = self.function_args_to_sql(args)?;
+ let filter = match filter {
Some(filter) =>
Some(Box::new(self.expr_to_sql_inner(filter)?)),
None => None,
};
@@ -297,8 +303,7 @@ impl Unparser<'_> {
span: Span::empty(),
}]),
args:
ast::FunctionArguments::List(ast::FunctionArgumentList {
- duplicate_treatment: agg
- .distinct
+ duplicate_treatment: distinct
.then_some(ast::DuplicateTreatment::Distinct),
args,
clauses: vec![],
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 42c2261749..d795a86956 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -52,7 +52,8 @@ use datafusion::common::{
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::SessionState;
use datafusion::logical_expr::expr::{
- Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery,
WindowFunction,
+ AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet,
InList,
+ InSubquery, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan,
Operator};
use datafusion::prelude::Expr;
@@ -1208,11 +1209,14 @@ pub fn from_aggregate_function(
) -> Result<Measure> {
let expr::AggregateFunction {
func,
- args,
- distinct,
- filter,
- order_by,
- null_treatment: _null_treatment,
+ params:
+ AggregateFunctionParams {
+ args,
+ distinct,
+ filter,
+ order_by,
+ null_treatment: _null_treatment,
+ },
} = agg_fn;
let sorts = if let Some(order_by) = order_by {
order_by
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]