This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 16346022ed Support UDAF to align Builtin aggregate function (#10493)
16346022ed is described below
commit 16346022ed4564211c0fb4bf20ac165f2c481a90
Author: Jay Zhan <[email protected]>
AuthorDate: Wed May 15 11:40:25 2024 +0800
Support UDAF to align Builtin aggregate function (#10493)
* align udaf and builtin
Signed-off-by: jayzhan211 <[email protected]>
* add more
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
.../src/physical_optimizer/aggregate_statistics.rs | 76 +++++++++++++---------
datafusion/core/src/physical_planner.rs | 31 +++------
datafusion/expr/src/expr.rs | 12 +---
.../optimizer/src/analyzer/count_wildcard_rule.rs | 44 +++++++++----
.../optimizer/src/single_distinct_to_groupby.rs | 25 +++++++
.../physical-expr-common/src/aggregate/mod.rs | 9 +++
datafusion/physical-plan/src/windows/mod.rs | 1 +
datafusion/proto/src/physical_plan/mod.rs | 2 +-
.../proto/tests/cases/roundtrip_physical_plan.rs | 1 +
datafusion/sql/src/expr/function.rs | 9 ++-
datafusion/substrait/src/logical_plan/producer.rs | 5 +-
11 files changed, 131 insertions(+), 84 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 5057488603..1a82dac465 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -30,6 +30,7 @@ use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
+use datafusion_physical_plan::udaf::AggregateFunctionExpr;
/// Optimizer that uses available statistics for aggregate functions
#[derive(Default)]
@@ -57,13 +58,9 @@ impl PhysicalOptimizerRule for AggregateStatistics {
let mut projections = vec![];
for expr in partial_agg_exec.aggr_expr() {
if let Some((non_null_rows, name)) =
- take_optimizable_column_count(&**expr, &stats)
+ take_optimizable_column_and_table_count(&**expr, &stats)
{
projections.push((expressions::lit(non_null_rows),
name.to_owned()));
- } else if let Some((num_rows, name)) =
- take_optimizable_table_count(&**expr, &stats)
- {
- projections.push((expressions::lit(num_rows),
name.to_owned()));
} else if let Some((min, name)) =
take_optimizable_min(&**expr, &stats) {
projections.push((expressions::lit(min), name.to_owned()));
} else if let Some((max, name)) =
take_optimizable_max(&**expr, &stats) {
@@ -137,43 +134,48 @@ fn take_optimizable(node: &dyn ExecutionPlan) ->
Option<Arc<dyn ExecutionPlan>>
None
}
-/// If this agg_expr is a count that is exactly defined in the statistics,
return it.
-fn take_optimizable_table_count(
+/// If this agg_expr is a count that can be exactly derived from the
statistics, return it.
+fn take_optimizable_column_and_table_count(
agg_expr: &dyn AggregateExpr,
stats: &Statistics,
) -> Option<(ScalarValue, String)> {
- if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
- &stats.num_rows,
- agg_expr.as_any().downcast_ref::<expressions::Count>(),
- ) {
- // TODO implementing Eq on PhysicalExpr would help a lot here
- if casted_expr.expressions().len() == 1 {
- if let Some(lit_expr) = casted_expr.expressions()[0]
- .as_any()
- .downcast_ref::<expressions::Literal>()
- {
- if lit_expr.value() == &COUNT_STAR_EXPANSION {
- return Some((
- ScalarValue::Int64(Some(num_rows as i64)),
- casted_expr.name().to_owned(),
- ));
+ let col_stats = &stats.column_statistics;
+ if let Some(agg_expr) =
agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+ if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() {
+ if let Precision::Exact(num_rows) = stats.num_rows {
+ let exprs = agg_expr.expressions();
+ if exprs.len() == 1 {
+ // TODO optimize with exprs other than Column
+ if let Some(col_expr) =
+ exprs[0].as_any().downcast_ref::<expressions::Column>()
+ {
+ let current_val =
&col_stats[col_expr.index()].null_count;
+ if let &Precision::Exact(val) = current_val {
+ return Some((
+ ScalarValue::Int64(Some((num_rows - val) as
i64)),
+ agg_expr.name().to_string(),
+ ));
+ }
+ } else if let Some(lit_expr) =
+
exprs[0].as_any().downcast_ref::<expressions::Literal>()
+ {
+ if lit_expr.value() == &COUNT_STAR_EXPANSION {
+ return Some((
+ ScalarValue::Int64(Some(num_rows as i64)),
+ agg_expr.name().to_string(),
+ ));
+ }
+ }
}
}
}
}
- None
-}
-
-/// If this agg_expr is a count that can be exactly derived from the
statistics, return it.
-fn take_optimizable_column_count(
- agg_expr: &dyn AggregateExpr,
- stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
- let col_stats = &stats.column_statistics;
- if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
+ // TODO: Remove this after removing Builtin Count
+ else if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
&stats.num_rows,
agg_expr.as_any().downcast_ref::<expressions::Count>(),
) {
+ // TODO implementing Eq on PhysicalExpr would help a lot here
if casted_expr.expressions().len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
@@ -187,6 +189,16 @@ fn take_optimizable_column_count(
casted_expr.name().to_string(),
));
}
+ } else if let Some(lit_expr) = casted_expr.expressions()[0]
+ .as_any()
+ .downcast_ref::<expressions::Literal>()
+ {
+ if lit_expr.value() == &COUNT_STAR_EXPANSION {
+ return Some((
+ ScalarValue::Int64(Some(num_rows as i64)),
+ casted_expr.name().to_owned(),
+ ));
+ }
}
}
}
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index d4a9a949fc..406196a591 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -252,31 +252,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) ->
Result<String> {
func_def,
distinct,
args,
- filter,
+ filter: _,
order_by,
null_treatment: _,
- }) => match func_def {
- AggregateFunctionDefinition::BuiltIn(..) =>
create_function_physical_name(
- func_def.name(),
- *distinct,
- args,
- order_by.as_ref(),
- ),
- AggregateFunctionDefinition::UDF(fun) => {
- // TODO: Add support for filter by in AggregateUDF
- if filter.is_some() {
- return exec_err!(
- "aggregate expression with filter is not supported"
- );
- }
-
- let names = args
- .iter()
- .map(|e| create_physical_name(e, false))
- .collect::<Result<Vec<_>>>()?;
- Ok(format!("{}({})", fun.name(), names.join(",")))
- }
- },
+ }) => create_function_physical_name(
+ func_def.name(),
+ *distinct,
+ args,
+ order_by.as_ref(),
+ ),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
@@ -1941,6 +1925,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
physical_input_schema,
name,
ignore_nulls,
+ *distinct,
)?;
(agg_expr, filter, physical_sort_exprs)
}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 36953742c1..a0bd0086aa 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1892,16 +1892,8 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) ->
Result<()> {
order_by,
null_treatment,
}) => {
- match func_def {
- AggregateFunctionDefinition::BuiltIn(..) => {
- write_function_name(w, func_def.name(), *distinct, args)?;
- }
- AggregateFunctionDefinition::UDF(fun) => {
- write!(w, "{}(", fun.name())?;
- write_names_join(w, args, ",")?;
- write!(w, ")")?;
- }
- };
+ write_function_name(w, func_def.name(), *distinct, args)?;
+
if let Some(fe) = filter {
write!(w, " FILTER (WHERE {fe})")?;
};
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index a607d49ef9..dfbd5f5632 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -25,7 +25,9 @@ use datafusion_expr::expr::{
AggregateFunction, AggregateFunctionDefinition, WindowFunction,
};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
+use datafusion_expr::{
+ aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
+};
/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
///
@@ -54,23 +56,37 @@ fn is_wildcard(expr: &Expr) -> bool {
}
fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
- matches!(
- &aggregate_function.func_def,
- AggregateFunctionDefinition::BuiltIn(
- datafusion_expr::aggregate_function::AggregateFunction::Count,
- )
- ) && aggregate_function.args.len() == 1
- && is_wildcard(&aggregate_function.args[0])
+ match aggregate_function {
+ AggregateFunction {
+ func_def: AggregateFunctionDefinition::UDF(udf),
+ args,
+ ..
+ } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])
=> true,
+ AggregateFunction {
+ func_def:
+ AggregateFunctionDefinition::BuiltIn(
+
datafusion_expr::aggregate_function::AggregateFunction::Count,
+ ),
+ args,
+ ..
+ } if args.len() == 1 && is_wildcard(&args[0]) => true,
+ _ => false,
+ }
}
fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
- matches!(
- &window_function.fun,
+ let args = &window_function.args;
+ match window_function.fun {
WindowFunctionDefinition::AggregateFunction(
- datafusion_expr::aggregate_function::AggregateFunction::Count,
- )
- ) && window_function.args.len() == 1
- && is_wildcard(&window_function.args[0])
+ aggregate_function::AggregateFunction::Count,
+ ) if args.len() == 1 && is_wildcard(&args[0]) => true,
+ WindowFunctionDefinition::AggregateUDF(ref udaf)
+ if udaf.name() == "COUNT" && args.len() == 1 &&
is_wildcard(&args[0]) =>
+ {
+ true
+ }
+ _ => false,
+ }
}
fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 5c82cf93cb..4f9c1ad645 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -90,6 +90,31 @@ fn is_single_distinct_agg(plan: &LogicalPlan) ->
Result<bool> {
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
+ } else if let Expr::AggregateFunction(AggregateFunction {
+ func_def: AggregateFunctionDefinition::UDF(fun),
+ distinct,
+ args,
+ filter,
+ order_by,
+ null_treatment: _,
+ }) = expr
+ {
+ if filter.is_some() || order_by.is_some() {
+ return Ok(false);
+ }
+ aggregate_count += 1;
+ if *distinct {
+ for e in args {
+ fields_set.insert(e.canonical_name());
+ }
+ } else if fun.name() != "SUM"
+ && fun.name() != "MIN"
+ && fun.name() != "MAX"
+ {
+ return Ok(false);
+ }
+ } else {
+ return Ok(false);
}
}
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index d2e3414fbf..05641b373b 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -34,6 +34,7 @@ use self::utils::{down_cast_any_ref, ordering_fields};
/// Creates a physical expression of the UDAF, that includes all necessary
type coercion.
/// This function errors when `args`' can't be coerced to a valid argument
type of the UDAF.
+#[allow(clippy::too_many_arguments)]
pub fn create_aggregate_expr(
fun: &AggregateUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
@@ -42,6 +43,7 @@ pub fn create_aggregate_expr(
schema: &Schema,
name: impl Into<String>,
ignore_nulls: bool,
+ is_distinct: bool,
) -> Result<Arc<dyn AggregateExpr>> {
let input_exprs_types = input_phy_exprs
.iter()
@@ -71,6 +73,7 @@ pub fn create_aggregate_expr(
ordering_req: ordering_req.to_vec(),
ignore_nulls,
ordering_fields,
+ is_distinct,
}))
}
@@ -162,6 +165,7 @@ pub struct AggregateFunctionExpr {
ordering_req: LexOrdering,
ignore_nulls: bool,
ordering_fields: Vec<Field>,
+ is_distinct: bool,
}
impl AggregateFunctionExpr {
@@ -169,6 +173,11 @@ impl AggregateFunctionExpr {
pub fn fun(&self) -> &AggregateUDF {
&self.fun
}
+
+ /// Return if the aggregation is distinct
+ pub fn is_distinct(&self) -> bool {
+ self.is_distinct
+ }
}
impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/physical-plan/src/windows/mod.rs
b/datafusion/physical-plan/src/windows/mod.rs
index ff60329ce1..d1223f7880 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -103,6 +103,7 @@ pub fn create_window_expr(
input_schema,
name,
ignore_nulls,
+ false,
)?;
window_expr_from_aggregate_expr(
partition_by,
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 1c5ba861d2..4de0b7c06d 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -525,7 +525,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
let sort_exprs = &[];
let ordering_req = &[];
let ignore_nulls = false;
-
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs,
ordering_req, &physical_schema, name, ignore_nulls)
+
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs,
ordering_req, &physical_schema, name, ignore_nulls, false)
}
}
}).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index c2018352c7..30a28081ed 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -426,6 +426,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
&schema,
"example_agg",
false,
+ false,
)?];
roundtrip_test_with_context(
diff --git a/datafusion/sql/src/expr/function.rs
b/datafusion/sql/src/expr/function.rs
index 3adf296078..dc0ddd4714 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -229,12 +229,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)?;
let order_by = (!order_by.is_empty()).then_some(order_by);
let args = self.function_args_to_expr(args, schema,
planner_context)?;
- // TODO: Support filter and distinct for UDAFs
+ let filter: Option<Box<Expr>> = filter
+ .map(|e| self.sql_expr_to_logical_expr(*e, schema,
planner_context))
+ .transpose()?
+ .map(Box::new);
return
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
fm,
args,
- false,
- None,
+ distinct,
+ filter,
order_by,
null_treatment,
)));
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index db5d341bc2..6f0738c38d 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -722,7 +722,10 @@ pub fn to_substrait_agg_measure(
arguments,
sorts,
output_type: None,
- invocation: AggregationInvocation::All as i32,
+ invocation: match distinct {
+ true => AggregationInvocation::Distinct as i32,
+ false => AggregationInvocation::All as i32,
+ },
phase: AggregationPhase::Unspecified as i32,
args: vec![],
options: vec![],
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]