This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 16346022ed Support UDAF to align Builtin aggregate function (#10493)
16346022ed is described below

commit 16346022ed4564211c0fb4bf20ac165f2c481a90
Author: Jay Zhan <[email protected]>
AuthorDate: Wed May 15 11:40:25 2024 +0800

    Support UDAF to align Builtin aggregate function (#10493)
    
    * align udaf and builtin
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * add more
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../src/physical_optimizer/aggregate_statistics.rs | 76 +++++++++++++---------
 datafusion/core/src/physical_planner.rs            | 31 +++------
 datafusion/expr/src/expr.rs                        | 12 +---
 .../optimizer/src/analyzer/count_wildcard_rule.rs  | 44 +++++++++----
 .../optimizer/src/single_distinct_to_groupby.rs    | 25 +++++++
 .../physical-expr-common/src/aggregate/mod.rs      |  9 +++
 datafusion/physical-plan/src/windows/mod.rs        |  1 +
 datafusion/proto/src/physical_plan/mod.rs          |  2 +-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  1 +
 datafusion/sql/src/expr/function.rs                |  9 ++-
 datafusion/substrait/src/logical_plan/producer.rs  |  5 +-
 11 files changed, 131 insertions(+), 84 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs 
b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
index 5057488603..1a82dac465 100644
--- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs
@@ -30,6 +30,7 @@ use datafusion_common::stats::Precision;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
 use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
+use datafusion_physical_plan::udaf::AggregateFunctionExpr;
 
 /// Optimizer that uses available statistics for aggregate functions
 #[derive(Default)]
@@ -57,13 +58,9 @@ impl PhysicalOptimizerRule for AggregateStatistics {
             let mut projections = vec![];
             for expr in partial_agg_exec.aggr_expr() {
                 if let Some((non_null_rows, name)) =
-                    take_optimizable_column_count(&**expr, &stats)
+                    take_optimizable_column_and_table_count(&**expr, &stats)
                 {
                     projections.push((expressions::lit(non_null_rows), 
name.to_owned()));
-                } else if let Some((num_rows, name)) =
-                    take_optimizable_table_count(&**expr, &stats)
-                {
-                    projections.push((expressions::lit(num_rows), 
name.to_owned()));
                 } else if let Some((min, name)) = 
take_optimizable_min(&**expr, &stats) {
                     projections.push((expressions::lit(min), name.to_owned()));
                 } else if let Some((max, name)) = 
take_optimizable_max(&**expr, &stats) {
@@ -137,43 +134,48 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> 
Option<Arc<dyn ExecutionPlan>>
     None
 }
 
-/// If this agg_expr is a count that is exactly defined in the statistics, 
return it.
-fn take_optimizable_table_count(
+/// If this agg_expr is a count that can be exactly derived from the 
statistics, return it.
+fn take_optimizable_column_and_table_count(
     agg_expr: &dyn AggregateExpr,
     stats: &Statistics,
 ) -> Option<(ScalarValue, String)> {
-    if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
-        &stats.num_rows,
-        agg_expr.as_any().downcast_ref::<expressions::Count>(),
-    ) {
-        // TODO implementing Eq on PhysicalExpr would help a lot here
-        if casted_expr.expressions().len() == 1 {
-            if let Some(lit_expr) = casted_expr.expressions()[0]
-                .as_any()
-                .downcast_ref::<expressions::Literal>()
-            {
-                if lit_expr.value() == &COUNT_STAR_EXPANSION {
-                    return Some((
-                        ScalarValue::Int64(Some(num_rows as i64)),
-                        casted_expr.name().to_owned(),
-                    ));
+    let col_stats = &stats.column_statistics;
+    if let Some(agg_expr) = 
agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
+        if agg_expr.fun().name() == "COUNT" && !agg_expr.is_distinct() {
+            if let Precision::Exact(num_rows) = stats.num_rows {
+                let exprs = agg_expr.expressions();
+                if exprs.len() == 1 {
+                    // TODO optimize with exprs other than Column
+                    if let Some(col_expr) =
+                        exprs[0].as_any().downcast_ref::<expressions::Column>()
+                    {
+                        let current_val = 
&col_stats[col_expr.index()].null_count;
+                        if let &Precision::Exact(val) = current_val {
+                            return Some((
+                                ScalarValue::Int64(Some((num_rows - val) as 
i64)),
+                                agg_expr.name().to_string(),
+                            ));
+                        }
+                    } else if let Some(lit_expr) =
+                        
exprs[0].as_any().downcast_ref::<expressions::Literal>()
+                    {
+                        if lit_expr.value() == &COUNT_STAR_EXPANSION {
+                            return Some((
+                                ScalarValue::Int64(Some(num_rows as i64)),
+                                agg_expr.name().to_string(),
+                            ));
+                        }
+                    }
                 }
             }
         }
     }
-    None
-}
-
-/// If this agg_expr is a count that can be exactly derived from the 
statistics, return it.
-fn take_optimizable_column_count(
-    agg_expr: &dyn AggregateExpr,
-    stats: &Statistics,
-) -> Option<(ScalarValue, String)> {
-    let col_stats = &stats.column_statistics;
-    if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
+    // TODO: Remove this after removing Builtin Count
+    else if let (&Precision::Exact(num_rows), Some(casted_expr)) = (
         &stats.num_rows,
         agg_expr.as_any().downcast_ref::<expressions::Count>(),
     ) {
+        // TODO implementing Eq on PhysicalExpr would help a lot here
         if casted_expr.expressions().len() == 1 {
             // TODO optimize with exprs other than Column
             if let Some(col_expr) = casted_expr.expressions()[0]
@@ -187,6 +189,16 @@ fn take_optimizable_column_count(
                         casted_expr.name().to_string(),
                     ));
                 }
+            } else if let Some(lit_expr) = casted_expr.expressions()[0]
+                .as_any()
+                .downcast_ref::<expressions::Literal>()
+            {
+                if lit_expr.value() == &COUNT_STAR_EXPANSION {
+                    return Some((
+                        ScalarValue::Int64(Some(num_rows as i64)),
+                        casted_expr.name().to_owned(),
+                    ));
+                }
             }
         }
     }
diff --git a/datafusion/core/src/physical_planner.rs 
b/datafusion/core/src/physical_planner.rs
index d4a9a949fc..406196a591 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -252,31 +252,15 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> 
Result<String> {
             func_def,
             distinct,
             args,
-            filter,
+            filter: _,
             order_by,
             null_treatment: _,
-        }) => match func_def {
-            AggregateFunctionDefinition::BuiltIn(..) => 
create_function_physical_name(
-                func_def.name(),
-                *distinct,
-                args,
-                order_by.as_ref(),
-            ),
-            AggregateFunctionDefinition::UDF(fun) => {
-                // TODO: Add support for filter by in AggregateUDF
-                if filter.is_some() {
-                    return exec_err!(
-                        "aggregate expression with filter is not supported"
-                    );
-                }
-
-                let names = args
-                    .iter()
-                    .map(|e| create_physical_name(e, false))
-                    .collect::<Result<Vec<_>>>()?;
-                Ok(format!("{}({})", fun.name(), names.join(",")))
-            }
-        },
+        }) => create_function_physical_name(
+            func_def.name(),
+            *distinct,
+            args,
+            order_by.as_ref(),
+        ),
         Expr::GroupingSet(grouping_set) => match grouping_set {
             GroupingSet::Rollup(exprs) => Ok(format!(
                 "ROLLUP ({})",
@@ -1941,6 +1925,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
                         physical_input_schema,
                         name,
                         ignore_nulls,
+                        *distinct,
                     )?;
                     (agg_expr, filter, physical_sort_exprs)
                 }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 36953742c1..a0bd0086aa 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1892,16 +1892,8 @@ fn write_name<W: Write>(w: &mut W, e: &Expr) -> 
Result<()> {
             order_by,
             null_treatment,
         }) => {
-            match func_def {
-                AggregateFunctionDefinition::BuiltIn(..) => {
-                    write_function_name(w, func_def.name(), *distinct, args)?;
-                }
-                AggregateFunctionDefinition::UDF(fun) => {
-                    write!(w, "{}(", fun.name())?;
-                    write_names_join(w, args, ",")?;
-                    write!(w, ")")?;
-                }
-            };
+            write_function_name(w, func_def.name(), *distinct, args)?;
+
             if let Some(fe) = filter {
                 write!(w, " FILTER (WHERE {fe})")?;
             };
diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs 
b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
index a607d49ef9..dfbd5f5632 100644
--- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
+++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
@@ -25,7 +25,9 @@ use datafusion_expr::expr::{
     AggregateFunction, AggregateFunctionDefinition, WindowFunction,
 };
 use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition};
+use datafusion_expr::{
+    aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition,
+};
 
 /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
 ///
@@ -54,23 +56,37 @@ fn is_wildcard(expr: &Expr) -> bool {
 }
 
 fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool {
-    matches!(
-        &aggregate_function.func_def,
-        AggregateFunctionDefinition::BuiltIn(
-            datafusion_expr::aggregate_function::AggregateFunction::Count,
-        )
-    ) && aggregate_function.args.len() == 1
-        && is_wildcard(&aggregate_function.args[0])
+    match aggregate_function {
+        AggregateFunction {
+            func_def: AggregateFunctionDefinition::UDF(udf),
+            args,
+            ..
+        } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) 
=> true,
+        AggregateFunction {
+            func_def:
+                AggregateFunctionDefinition::BuiltIn(
+                    
datafusion_expr::aggregate_function::AggregateFunction::Count,
+                ),
+            args,
+            ..
+        } if args.len() == 1 && is_wildcard(&args[0]) => true,
+        _ => false,
+    }
 }
 
 fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool {
-    matches!(
-        &window_function.fun,
+    let args = &window_function.args;
+    match window_function.fun {
         WindowFunctionDefinition::AggregateFunction(
-            datafusion_expr::aggregate_function::AggregateFunction::Count,
-        )
-    ) && window_function.args.len() == 1
-        && is_wildcard(&window_function.args[0])
+            aggregate_function::AggregateFunction::Count,
+        ) if args.len() == 1 && is_wildcard(&args[0]) => true,
+        WindowFunctionDefinition::AggregateUDF(ref udaf)
+            if udaf.name() == "COUNT" && args.len() == 1 && 
is_wildcard(&args[0]) =>
+        {
+            true
+        }
+        _ => false,
+    }
 }
 
 fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs 
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 5c82cf93cb..4f9c1ad645 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -90,6 +90,31 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> 
Result<bool> {
                     } else if !matches!(fun, Sum | Min | Max) {
                         return Ok(false);
                     }
+                } else if let Expr::AggregateFunction(AggregateFunction {
+                    func_def: AggregateFunctionDefinition::UDF(fun),
+                    distinct,
+                    args,
+                    filter,
+                    order_by,
+                    null_treatment: _,
+                }) = expr
+                {
+                    if filter.is_some() || order_by.is_some() {
+                        return Ok(false);
+                    }
+                    aggregate_count += 1;
+                    if *distinct {
+                        for e in args {
+                            fields_set.insert(e.canonical_name());
+                        }
+                    } else if fun.name() != "SUM"
+                        && fun.name() != "MIN"
+                        && fun.name() != "MAX"
+                    {
+                        return Ok(false);
+                    }
+                } else {
+                    return Ok(false);
                 }
             }
             Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs 
b/datafusion/physical-expr-common/src/aggregate/mod.rs
index d2e3414fbf..05641b373b 100644
--- a/datafusion/physical-expr-common/src/aggregate/mod.rs
+++ b/datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -34,6 +34,7 @@ use self::utils::{down_cast_any_ref, ordering_fields};
 
 /// Creates a physical expression of the UDAF, that includes all necessary 
type coercion.
 /// This function errors when `args`' can't be coerced to a valid argument 
type of the UDAF.
+#[allow(clippy::too_many_arguments)]
 pub fn create_aggregate_expr(
     fun: &AggregateUDF,
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
@@ -42,6 +43,7 @@ pub fn create_aggregate_expr(
     schema: &Schema,
     name: impl Into<String>,
     ignore_nulls: bool,
+    is_distinct: bool,
 ) -> Result<Arc<dyn AggregateExpr>> {
     let input_exprs_types = input_phy_exprs
         .iter()
@@ -71,6 +73,7 @@ pub fn create_aggregate_expr(
         ordering_req: ordering_req.to_vec(),
         ignore_nulls,
         ordering_fields,
+        is_distinct,
     }))
 }
 
@@ -162,6 +165,7 @@ pub struct AggregateFunctionExpr {
     ordering_req: LexOrdering,
     ignore_nulls: bool,
     ordering_fields: Vec<Field>,
+    is_distinct: bool,
 }
 
 impl AggregateFunctionExpr {
@@ -169,6 +173,11 @@ impl AggregateFunctionExpr {
     pub fn fun(&self) -> &AggregateUDF {
         &self.fun
     }
+
+    /// Returns `true` if the aggregation is distinct
+    pub fn is_distinct(&self) -> bool {
+        self.is_distinct
+    }
 }
 
 impl AggregateExpr for AggregateFunctionExpr {
diff --git a/datafusion/physical-plan/src/windows/mod.rs 
b/datafusion/physical-plan/src/windows/mod.rs
index ff60329ce1..d1223f7880 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -103,6 +103,7 @@ pub fn create_window_expr(
                 input_schema,
                 name,
                 ignore_nulls,
+                false,
             )?;
             window_expr_from_aggregate_expr(
                 partition_by,
diff --git a/datafusion/proto/src/physical_plan/mod.rs 
b/datafusion/proto/src/physical_plan/mod.rs
index 1c5ba861d2..4de0b7c06d 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -525,7 +525,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                                             let sort_exprs = &[];
                                             let ordering_req = &[];
                                             let ignore_nulls = false;
-                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, 
ordering_req, &physical_schema, name, ignore_nulls)
+                                            
udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, 
ordering_req, &physical_schema, name, ignore_nulls, false)
                                         }
                                     }
                                 }).transpose()?.ok_or_else(|| {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index c2018352c7..30a28081ed 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -426,6 +426,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> {
         &schema,
         "example_agg",
         false,
+        false,
     )?];
 
     roundtrip_test_with_context(
diff --git a/datafusion/sql/src/expr/function.rs 
b/datafusion/sql/src/expr/function.rs
index 3adf296078..dc0ddd4714 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -229,12 +229,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 )?;
                 let order_by = (!order_by.is_empty()).then_some(order_by);
                 let args = self.function_args_to_expr(args, schema, 
planner_context)?;
-                // TODO: Support filter and distinct for UDAFs
+                let filter: Option<Box<Expr>> = filter
+                    .map(|e| self.sql_expr_to_logical_expr(*e, schema, 
planner_context))
+                    .transpose()?
+                    .map(Box::new);
                 return 
Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf(
                     fm,
                     args,
-                    false,
-                    None,
+                    distinct,
+                    filter,
                     order_by,
                     null_treatment,
                 )));
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index db5d341bc2..6f0738c38d 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -722,7 +722,10 @@ pub fn to_substrait_agg_measure(
                             arguments,
                             sorts,
                             output_type: None,
-                            invocation: AggregationInvocation::All as i32,
+                            invocation: match distinct {
+                                true => AggregationInvocation::Distinct as i32,
+                                false => AggregationInvocation::All as i32,
+                            },
                             phase: AggregationPhase::Unspecified as i32,
                             args: vec![],
                             options: vec![],


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to