This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 30e8ceff8 Add logical plan support for aggregate expressions with filters (and upgrade to sqlparser 0.23) (#3405)
30e8ceff8 is described below

commit 30e8ceff80dcb5d95d8f399917ac6c846986bdf7
Author: Andy Grove <[email protected]>
AuthorDate: Mon Sep 12 12:19:29 2022 -0600

    Add logical plan support for aggregate expressions with filters (and upgrade to sqlparser 0.23) (#3405)
    
    * Use sqlparser-0.23
    
    * Add filter to aggregate expressions
    
    * clippy
    
    * implement protobuf serde
    
    * clippy
    
    * fix error message
    
    * Update datafusion/expr/src/expr.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_plan/planner.rs       |  7 +++-
 datafusion/expr/src/expr.rs                        | 46 ++++++++++++++++++----
 datafusion/expr/src/expr_fn.rs                     | 10 +++++
 datafusion/expr/src/expr_rewriter.rs               |  5 ++-
 datafusion/expr/src/udaf.rs                        |  1 +
 .../optimizer/src/single_distinct_to_groupby.rs    |  6 ++-
 datafusion/proto/proto/datafusion.proto            |  2 +
 datafusion/proto/src/from_proto.rs                 |  9 +++--
 datafusion/proto/src/lib.rs                        |  4 ++
 datafusion/proto/src/to_proto.rs                   | 35 ++++++++++------
 datafusion/sql/src/planner.rs                      | 29 ++++++++++++--
 datafusion/sql/src/utils.rs                        |  5 ++-
 12 files changed, 130 insertions(+), 29 deletions(-)

diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 34cadd5b4..9f1e488ef 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -194,7 +194,12 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
             args,
             ..
         } => create_function_physical_name(&fun.to_string(), *distinct, args),
-        Expr::AggregateUDF { fun, args } => {
+        Expr::AggregateUDF { fun, args, filter } => {
+            if filter.is_some() {
+                return Err(DataFusionError::Execution(
+                    "aggregate expression with filter is not supported".to_string(),
+                ));
+            }
             let mut names = Vec::with_capacity(args.len());
             for e in args {
                 names.push(create_physical_name(e, false)?);
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index ab45dd67d..8b90fb9e4 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -231,6 +231,8 @@ pub enum Expr {
         args: Vec<Expr>,
         /// Whether this is a DISTINCT aggregation or not
         distinct: bool,
+        /// Optional filter
+        filter: Option<Box<Expr>>,
     },
     /// Represents the call of a window function with arguments.
     WindowFunction {
@@ -251,6 +253,8 @@ pub enum Expr {
         fun: Arc<AggregateUDF>,
         /// List of expressions to feed to the functions as arguments
         args: Vec<Expr>,
+        /// Optional filter applied prior to aggregating
+        filter: Option<Box<Expr>>,
     },
     /// Returns whether the list contains the expr value.
     InList {
@@ -668,10 +672,26 @@ impl fmt::Debug for Expr {
                 fun,
                 distinct,
                 ref args,
+                filter,
                 ..
-            } => fmt_function(f, &fun.to_string(), *distinct, args, true),
-            Expr::AggregateUDF { fun, ref args, .. } => {
-                fmt_function(f, &fun.name, false, args, false)
+            } => {
+                fmt_function(f, &fun.to_string(), *distinct, args, true)?;
+                if let Some(fe) = filter {
+                    write!(f, " FILTER (WHERE {})", fe)?;
+                }
+                Ok(())
+            }
+            Expr::AggregateUDF {
+                fun,
+                ref args,
+                filter,
+                ..
+            } => {
+                fmt_function(f, &fun.name, false, args, false)?;
+                if let Some(fe) = filter {
+                    write!(f, " FILTER (WHERE {})", fe)?;
+                }
+                Ok(())
             }
             Expr::Between {
                 expr,
@@ -1010,14 +1030,26 @@ fn create_name(e: &Expr) -> Result<String> {
             fun,
             distinct,
             args,
-            ..
-        } => create_function_name(&fun.to_string(), *distinct, args),
-        Expr::AggregateUDF { fun, args } => {
+            filter,
+        } => {
+            let name = create_function_name(&fun.to_string(), *distinct, args)?;
+            if let Some(fe) = filter {
+                Ok(format!("{} FILTER (WHERE {})", name, fe))
+            } else {
+                Ok(name)
+            }
+        }
+        Expr::AggregateUDF { fun, args, filter } => {
             let mut names = Vec::with_capacity(args.len());
             for e in args {
                 names.push(create_name(e)?);
             }
-            Ok(format!("{}({})", fun.name, names.join(",")))
+            let filter = if let Some(fe) = filter {
+                format!(" FILTER (WHERE {})", fe)
+            } else {
+                "".to_string()
+            };
+            Ok(format!("{}({}){}", fun.name, names.join(","), filter))
         }
         Expr::GroupingSet(grouping_set) => match grouping_set {
             GroupingSet::Rollup(exprs) => {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index f7eaec39b..6c5cc0ecc 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -66,6 +66,7 @@ pub fn min(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Min,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -75,6 +76,7 @@ pub fn max(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Max,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -84,6 +86,7 @@ pub fn sum(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Sum,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -93,6 +96,7 @@ pub fn avg(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Avg,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -102,6 +106,7 @@ pub fn count(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Count,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -111,6 +116,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::Count,
         distinct: true,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -163,6 +169,7 @@ pub fn approx_distinct(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::ApproxDistinct,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -172,6 +179,7 @@ pub fn approx_median(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::ApproxMedian,
         distinct: false,
         args: vec![expr],
+        filter: None,
     }
 }
 
@@ -181,6 +189,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::ApproxPercentileCont,
         distinct: false,
         args: vec![expr, percentile],
+        filter: None,
     }
 }
 
@@ -194,6 +203,7 @@ pub fn approx_percentile_cont_with_weight(
         fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
         distinct: false,
         args: vec![expr, weight_expr, percentile],
+        filter: None,
     }
 }
 
diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs
index b8b9fced9..533f31ce1 100644
--- a/datafusion/expr/src/expr_rewriter.rs
+++ b/datafusion/expr/src/expr_rewriter.rs
@@ -250,10 +250,12 @@ impl ExprRewritable for Expr {
                 args,
                 fun,
                 distinct,
+                filter,
             } => Expr::AggregateFunction {
                 args: rewrite_vec(args, rewriter)?,
                 fun,
                 distinct,
+                filter,
             },
             Expr::GroupingSet(grouping_set) => match grouping_set {
                 GroupingSet::Rollup(exprs) => {
@@ -271,9 +273,10 @@ impl ExprRewritable for Expr {
                     ))
                 }
             },
-            Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
+            Expr::AggregateUDF { args, fun, filter } => Expr::AggregateUDF {
                 args: rewrite_vec(args, rewriter)?,
                 fun,
+                filter,
             },
             Expr::InList {
                 expr,
diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs
index 00f48dda2..0ecb5280a 100644
--- a/datafusion/expr/src/udaf.rs
+++ b/datafusion/expr/src/udaf.rs
@@ -89,6 +89,7 @@ impl AggregateUDF {
         Expr::AggregateUDF {
             fun: Arc::new(self.clone()),
             args,
+            filter: None,
         }
     }
 }
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 656d3967e..f1982bcf1 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -87,7 +87,9 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                 let new_aggr_exprs = aggr_expr
                     .iter()
                     .map(|aggr_expr| match aggr_expr {
-                        Expr::AggregateFunction { fun, args, .. } => {
+                        Expr::AggregateFunction {
+                            fun, args, filter, ..
+                        } => {
                             // is_single_distinct_agg ensure args.len=1
                             if group_fields_set.insert(args[0].name()?) {
                                 inner_group_exprs
@@ -97,6 +99,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
                                 fun: fun.clone(),
                                 args: vec![col(SINGLE_DISTINCT_ALIAS)],
                                 distinct: false, // intentional to remove distinct here
+                                filter: filter.clone(),
                             })
                         }
                         _ => Ok(aggr_expr.clone()),
@@ -402,6 +405,7 @@ mod tests {
                         fun: AggregateFunction::Max,
                         distinct: true,
                         args: vec![col("b")],
+                        filter: None,
                     },
                 ],
             )?
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 8d4da0250..baabc04cf 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -504,11 +504,13 @@ message AggregateExprNode {
   AggregateFunction aggr_function = 1;
   repeated LogicalExprNode expr = 2;
   bool distinct = 3;
+  LogicalExprNode filter = 4;
 }
 
 message AggregateUDFExprNode {
   string fun_name = 1;
   repeated LogicalExprNode args = 2;
+  LogicalExprNode filter = 3;
 }
 
 message ScalarUDFExprNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index c93d3a877..5402b03ce 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -891,6 +891,7 @@ pub fn parse_expr(
                     .map(|e| parse_expr(e, registry))
                     .collect::<Result<Vec<_>, _>>()?,
                 distinct: expr.distinct,
+                filter: parse_optional_expr(&expr.filter, registry)?.map(Box::new),
             })
         }
         ExprType::Alias(alias) => Ok(Expr::Alias(
@@ -1194,15 +1195,17 @@ pub fn parse_expr(
                     .collect::<Result<Vec<_>, Error>>()?,
             })
         }
-        ExprType::AggregateUdfExpr(protobuf::AggregateUdfExprNode { fun_name, args }) => {
-            let agg_fn = registry.udaf(fun_name.as_str())?;
+        ExprType::AggregateUdfExpr(pb) => {
+            let agg_fn = registry.udaf(pb.fun_name.as_str())?;
 
             Ok(Expr::AggregateUDF {
                 fun: agg_fn,
-                args: args
+                args: pb
+                    .args
                     .iter()
                     .map(|expr| parse_expr(expr, registry))
                     .collect::<Result<Vec<_>, Error>>()?,
+                filter: parse_optional_expr(&pb.filter, registry)?.map(Box::new),
             })
         }
 
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 8e9475329..cce778be2 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -1023,6 +1023,7 @@ mod roundtrip_tests {
             fun: AggregateFunction::Count,
             args: vec![col("bananas")],
             distinct: false,
+            filter: None,
         };
         let ctx = SessionContext::new();
         roundtrip_expr_test(test_expr, ctx);
@@ -1034,6 +1035,7 @@ mod roundtrip_tests {
             fun: AggregateFunction::Count,
             args: vec![col("bananas")],
             distinct: true,
+            filter: None,
         };
         let ctx = SessionContext::new();
         roundtrip_expr_test(test_expr, ctx);
@@ -1045,6 +1047,7 @@ mod roundtrip_tests {
             fun: AggregateFunction::ApproxPercentileCont,
             args: vec![col("bananas"), lit(0.42_f32)],
             distinct: false,
+            filter: None,
         };
 
         let ctx = SessionContext::new();
@@ -1097,6 +1100,7 @@ mod roundtrip_tests {
         let test_expr = Expr::AggregateUDF {
             fun: Arc::new(dummy_agg.clone()),
             args: vec![lit(1.0_f64)],
+            filter: Some(Box::new(lit(true))),
         };
 
         let mut ctx = SessionContext::new();
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 8c43f876e..43d649029 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -585,6 +585,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                 ref fun,
                 ref args,
                 ref distinct,
+                ref filter
             } => {
                 let aggr_function = match fun {
                     AggregateFunction::ApproxDistinct => {
@@ -633,9 +634,13 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                         .map(|v| v.try_into())
                         .collect::<Result<Vec<_>, _>>()?,
                     distinct: *distinct,
+                    filter: match filter {
+                        Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+                        None => None,
+                    }
                 };
                 Self {
-                    expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
+                    expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))),
                 }
             }
             Expr::ScalarVariable(_, _) => return Err(Error::General("Proto serialization error: Scalar Variable not supported".to_string())),
@@ -663,17 +668,23 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                         .collect::<Result<Vec<_>, Error>>()?,
                 })),
             },
-            Expr::AggregateUDF { fun, args } => Self {
-                expr_type: Some(ExprType::AggregateUdfExpr(
-                    protobuf::AggregateUdfExprNode {
-                        fun_name: fun.name.clone(),
-                        args: args.iter().map(|expr| expr.try_into()).collect::<Result<
-                            Vec<_>,
-                            Error,
-                        >>(
-                        )?,
-                    },
-                )),
+            Expr::AggregateUDF { fun, args, filter } => {
+                Self {
+                    expr_type: Some(ExprType::AggregateUdfExpr(
+                        Box::new(protobuf::AggregateUdfExprNode {
+                            fun_name: fun.name.clone(),
+                            args: args.iter().map(|expr| expr.try_into()).collect::<Result<
+                                Vec<_>,
+                                Error,
+                            >>(
+                            )?,
+                            filter: match filter {
+                                Some(e) => Some(Box::new(e.as_ref().try_into()?)),
+                                None => None,
+                            }
+                        },
+                    ))),
+                }
             },
             Expr::Not(expr) => {
                 let expr = Box::new(protobuf::Not {
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 881de1ebe..5d30b670f 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -2089,6 +2089,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 Ok(Expr::ScalarFunction { fun, args })
             }
 
+            SQLExpr::AggregateExpressionWithFilter { expr, filter } => {
+                match self.sql_expr_to_logical_expr(*expr, schema, ctes)? {
+                    Expr::AggregateFunction {
+                        fun, args, distinct, ..
+                    } =>  Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }),
+                    _ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string()))
+                }
+
+            }
+
             SQLExpr::Function(mut function) => {
                 let name = if function.name.0.len() > 1 {
                     // DF doesn't handle compound identifiers
@@ -2185,6 +2195,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         fun,
                         distinct,
                         args,
+                        filter: None
                     });
                 };
 
@@ -2198,7 +2209,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     None => match self.schema_provider.get_aggregate_meta(&name) {
                         Some(fm) => {
                             let args = self.function_args_to_expr(function.args, schema)?;
-                            Ok(Expr::AggregateUDF { fun: fm, args })
+                            Ok(Expr::AggregateUDF { fun: fm, args, filter: None })
                         }
                         _ => Err(DataFusionError::Plan(format!(
                             "Invalid function '{}'",
@@ -2217,7 +2228,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema, ctes),
 
             _ => Err(DataFusionError::NotImplemented(format!(
-                "Unsupported ast node {:?} in sqltorel",
+                "Unsupported ast node in sqltorel: {:?}",
                 sql
             ))),
         }
@@ -2731,7 +2742,7 @@ fn parse_sql_number(n: &str) -> Result<Expr> {
 mod tests {
     use super::*;
     use crate::assert_contains;
-    use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
+    use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect};
     use std::any::Any;
 
     #[test]
@@ -4966,6 +4977,18 @@ mod tests {
         quick_test(sql, expected);
     }
 
+    #[test]
+    fn hive_aggregate_with_filter() -> Result<()> {
+        let dialect = &HiveDialect {};
+        let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person";
+        let plan = logical_plan_with_dialect(sql, dialect)?;
+        let expected = "Projection: #SUM(person.age) FILTER (WHERE #age > Int64(4))\
+        \n  Aggregate: groupBy=[[]], aggr=[[SUM(#person.age) FILTER (WHERE #age > Int64(4))]]\
+        \n    TableScan: person".to_string();
+        assert_eq!(expected, format!("{}", plan.display_indent()));
+        Ok(())
+    }
+
     #[test]
     fn order_by_unaliased_name() {
         // https://github.com/apache/arrow-datafusion/issues/3160
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 25f5c549a..eb58509d0 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -163,6 +163,7 @@ where
                 fun,
                 args,
                 distinct,
+                filter,
             } => Ok(Expr::AggregateFunction {
                 fun: fun.clone(),
                 args: args
@@ -170,6 +171,7 @@ where
                     .map(|e| clone_with_replacement(e, replacement_fn))
                     .collect::<Result<Vec<Expr>>>()?,
                 distinct: *distinct,
+                filter: filter.clone(),
             }),
             Expr::WindowFunction {
                 fun,
@@ -193,12 +195,13 @@ where
                     .collect::<Result<Vec<_>>>()?,
                 window_frame: *window_frame,
             }),
-            Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF {
+            Expr::AggregateUDF { fun, args, filter } => Ok(Expr::AggregateUDF {
                 fun: fun.clone(),
                 args: args
                     .iter()
                     .map(|e| clone_with_replacement(e, replacement_fn))
                     .collect::<Result<Vec<Expr>>>()?,
+                filter: filter.clone(),
             }),
             Expr::Alias(nested_expr, alias_name) => Ok(Expr::Alias(
                 Box::new(clone_with_replacement(nested_expr, replacement_fn)?),

Reply via email to