This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 85ec314336 Add ROLLUP and GROUPING SETS substrait support (#7382)
85ec314336 is described below

commit 85ec31433615735a05d332f87cd0bdfc11aac663
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Tue Aug 29 04:39:38 2023 -0700

    Add ROLLUP and GROUPING SETS substrait support (#7382)
    
    * Add ROLLUP and GROUPING SETS support
    
    * fix: fmt
    
    * clippy
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 43 ++++++++----
 datafusion/substrait/src/logical_plan/producer.rs  | 80 +++++++++++++++++++---
 .../tests/cases/roundtrip_logical_plan.rs          | 17 +++++
 3 files changed, 117 insertions(+), 23 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 32cb1db4c3..32b8f8ea54 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -22,8 +22,10 @@ use datafusion::logical_expr::{
     aggregate_function, window_function::find_df_window_func, BinaryExpr,
     BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
 };
-use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
-use datafusion::logical_expr::{Extension, Like, LogicalPlanBuilder};
+use datafusion::logical_expr::{
+    expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound,
+    WindowFrameUnits,
+};
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
 use datafusion::{
@@ -251,18 +253,35 @@ pub async fn from_substrait_rel(
                 let mut group_expr = vec![];
                 let mut aggr_expr = vec![];
 
-                let groupings = match agg.groupings.len() {
-                    1 => Ok(&agg.groupings[0]),
-                    _ => not_impl_err!(
-                        "Aggregate with multiple grouping sets is not supported"
-                    ),
+                match agg.groupings.len() {
+                    1 => {
+                        for e in &agg.groupings[0].grouping_expressions {
+                            let x =
+                                from_substrait_rex(e, input.schema(), extensions).await?;
+                            group_expr.push(x.as_ref().clone());
+                        }
+                    }
+                    _ => {
+                        let mut grouping_sets = vec![];
+                        for grouping in &agg.groupings {
+                            let mut grouping_set = vec![];
+                            for e in &grouping.grouping_expressions {
+                                let x = from_substrait_rex(e, input.schema(), extensions)
+                                    .await?;
+                                grouping_set.push(x.as_ref().clone());
+                            }
+                            grouping_sets.push(grouping_set);
+                        }
+                        // Single-element grouping expression of type Expr::GroupingSet.
+                        // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
+                        // parsed by the producer and consumer, since Substrait does not have a type dedicated
+                        // to ROLLUP. Only vector of Groupings (grouping sets) is available.
+                        group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
+                            grouping_sets,
+                        )));
+                    }
                 };
 
-                for e in &groupings?.grouping_expressions {
-                    let x = from_substrait_rex(e, input.schema(), extensions).await?;
-                    group_expr.push(x.as_ref().clone());
-                }
-
                 for m in &agg.measures {
                     let filter = match &m.filter {
                         Some(fil) => Some(Box::new(
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 74a0ba63df..138825d061 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err};
 #[allow(unused_imports)]
 use datafusion::logical_expr::aggregate_function;
 use datafusion::logical_expr::expr::{
-    Alias, BinaryExpr, Case, Cast, InList, ScalarFunction as DFScalarFunction, Sort,
-    WindowFunction,
+    Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
+    ScalarFunction as DFScalarFunction, Sort, WindowFunction,
 };
 use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
 use datafusion::prelude::Expr;
@@ -221,12 +221,11 @@ pub fn to_substrait_rel(
         }
         LogicalPlan::Aggregate(agg) => {
             let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?;
-            // Translate aggregate expression to Substrait's groupings (repeated repeated Expression)
-            let grouping = agg
-                .group_expr
-                .iter()
-                .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info))
-                .collect::<Result<Vec<_>>>()?;
+            let groupings = to_substrait_groupings(
+                &agg.group_expr,
+                agg.input.schema(),
+                extension_info,
+            )?;
             let measures = agg
                 .aggr_expr
                 .iter()
@@ -237,9 +236,7 @@ pub fn to_substrait_rel(
                 rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
                     common: None,
                     input: Some(input),
-                    groupings: vec![Grouping {
-                        grouping_expressions: grouping,
-                    }], //groupings,
+                    groupings,
                     measures,
                     advanced_extension: None,
                 }))),
@@ -491,6 +488,67 @@ pub fn operator_to_name(op: Operator) -> &'static str {
     }
 }
 
+pub fn parse_flat_grouping_exprs(
+    exprs: &[Expr],
+    schema: &DFSchemaRef,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<Grouping> {
+    let grouping_expressions = exprs
+        .iter()
+        .map(|e| to_substrait_rex(e, schema, 0, extension_info))
+        .collect::<Result<Vec<_>>>()?;
+    Ok(Grouping {
+        grouping_expressions,
+    })
+}
+
+pub fn to_substrait_groupings(
+    exprs: &Vec<Expr>,
+    schema: &DFSchemaRef,
+    extension_info: &mut (
+        Vec<extensions::SimpleExtensionDeclaration>,
+        HashMap<String, u32>,
+    ),
+) -> Result<Vec<Grouping>> {
+    match exprs.len() {
+        1 => match &exprs[0] {
+            Expr::GroupingSet(gs) => match gs {
+                GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
+                    "GroupingSet CUBE is not yet supported".to_string(),
+                )),
+                GroupingSet::GroupingSets(sets) => Ok(sets
+                    .iter()
+                    .map(|set| parse_flat_grouping_exprs(set, schema, extension_info))
+                    .collect::<Result<Vec<_>>>()?),
+                GroupingSet::Rollup(set) => {
+                    let mut sets: Vec<Vec<Expr>> = vec![vec![]];
+                    for i in 0..set.len() {
+                        sets.push(set[..=i].to_vec());
+                    }
+                    Ok(sets
+                        .iter()
+                        .rev()
+                        .map(|set| parse_flat_grouping_exprs(set, schema, extension_info))
+                        .collect::<Result<Vec<_>>>()?)
+                }
+            },
+            _ => Ok(vec![parse_flat_grouping_exprs(
+                exprs,
+                schema,
+                extension_info,
+            )?]),
+        },
+        _ => Ok(vec![parse_flat_grouping_exprs(
+            exprs,
+            schema,
+            extension_info,
+        )?]),
+    }
+}
+
 #[allow(deprecated)]
 pub fn to_substrait_agg_measure(
     expr: &Expr,
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 90c3d199b7..f4d74ae426 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -209,6 +209,23 @@ async fn aggregate_multiple_keys() -> Result<()> {
     roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
 }
 
+#[tokio::test]
+async fn aggregate_grouping_sets() -> Result<()> {
+    roundtrip(
+        "SELECT a, c, d, avg(b) FROM data GROUP BY GROUPING SETS ((a, c), (a), (d), ())",
+    )
+    .await
+}
+
+#[tokio::test]
+async fn aggregate_grouping_rollup() -> Result<()> {
+    assert_expected_plan(
+        "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
+        "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\
+        \n  TableScan: data projection=[a, b, c, e]"
+    ).await
+}
+
 #[tokio::test]
 async fn decimal_literal() -> Result<()> {
     roundtrip("SELECT * FROM data WHERE b > 2.5").await

Reply via email to