This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a9d4d525df feat(substrait): AggregateRel grouping_expression support (#13173)
a9d4d525df is described below

commit a9d4d525df07dd2fc5eb6adc622a821cf54d44ba
Author: Andrey Koshchiy <[email protected]>
AuthorDate: Sun Nov 3 14:48:56 2024 +0300

    feat(substrait): AggregateRel grouping_expression support (#13173)
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 77 ++++++++++++-----
 datafusion/substrait/src/logical_plan/producer.rs  | 58 +++++++++----
 .../tests/cases/roundtrip_logical_plan.rs          | 13 +++
 ..._no_project_group_expression_ref.substrait.json | 98 ++++++++++++++++++++++
 4 files changed, 210 insertions(+), 36 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 7ccca8616b..890da7361d 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -33,6 +33,7 @@ use datafusion::logical_expr::{
     expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
     ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
 };
+use substrait::proto::aggregate_rel::Grouping;
 use substrait::proto::expression::subquery::set_predicate::PredicateOp;
 use substrait::proto::expression_reference::ExprType;
 use url::Url;
@@ -665,39 +666,48 @@ pub async fn from_substrait_rel(
                 let input = LogicalPlanBuilder::from(
                     from_substrait_rel(ctx, input, extensions).await?,
                 );
-                let mut group_expr = vec![];
-                let mut aggr_expr = vec![];
+                let mut ref_group_exprs = vec![];
+
+                for e in &agg.grouping_expressions {
+                    let x =
+                        from_substrait_rex(ctx, e, input.schema(), extensions).await?;
+                    ref_group_exprs.push(x);
+                }
+
+                let mut group_exprs = vec![];
+                let mut aggr_exprs = vec![];
 
                 match agg.groupings.len() {
                     1 => {
-                        for e in &agg.groupings[0].grouping_expressions {
-                            let x =
-                                from_substrait_rex(ctx, e, input.schema(), extensions)
-                                    .await?;
-                            group_expr.push(x);
-                        }
+                        group_exprs.extend_from_slice(
+                            &from_substrait_grouping(
+                                ctx,
+                                &agg.groupings[0],
+                                &ref_group_exprs,
+                                input.schema(),
+                                extensions,
+                            )
+                            .await?,
+                        );
                     }
                     _ => {
                         let mut grouping_sets = vec![];
                         for grouping in &agg.groupings {
-                            let mut grouping_set = vec![];
-                            for e in &grouping.grouping_expressions {
-                                let x = from_substrait_rex(
-                                    ctx,
-                                    e,
-                                    input.schema(),
-                                    extensions,
-                                )
-                                .await?;
-                                grouping_set.push(x);
-                            }
+                            let grouping_set = from_substrait_grouping(
+                                ctx,
+                                grouping,
+                                &ref_group_exprs,
+                                input.schema(),
+                                extensions,
+                            )
+                            .await?;
                             grouping_sets.push(grouping_set);
                         }
                         // Single-element grouping expression of type Expr::GroupingSet.
                         // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
                         // parsed by the producer and consumer, since Substrait does not have a type dedicated
                         // to ROLLUP. Only vector of Groupings (grouping sets) is available.
-                        group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
+                        group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
                             grouping_sets,
                         )));
                     }
@@ -755,9 +765,9 @@ pub async fn from_substrait_rel(
                             "Aggregate without aggregate function is not supported"
                         ),
                     };
-                    aggr_expr.push(agg_func?.as_ref().clone());
+                    aggr_exprs.push(agg_func?.as_ref().clone());
                 }
-                input.aggregate(group_expr, aggr_expr)?.build()
+                input.aggregate(group_exprs, aggr_exprs)?.build()
             } else {
                 not_impl_err!("Aggregate without an input is not valid")
             }
@@ -2762,6 +2772,29 @@ fn from_substrait_null(
     }
 }
 
+#[allow(deprecated)]
+async fn from_substrait_grouping(
+    ctx: &SessionContext,
+    grouping: &Grouping,
+    expressions: &[Expr],
+    input_schema: &DFSchemaRef,
+    extensions: &Extensions,
+) -> Result<Vec<Expr>> {
+    let mut group_exprs = vec![];
+    if !grouping.grouping_expressions.is_empty() {
+        for e in &grouping.grouping_expressions {
+            let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?;
+            group_exprs.push(expr);
+        }
+        return Ok(group_exprs);
+    }
+    for idx in &grouping.expression_references {
+        let e = &expressions[*idx as usize];
+        group_exprs.push(e.clone());
+    }
+    Ok(group_exprs)
+}
+
 fn from_substrait_field_reference(
     field_ref: &FieldReference,
     input_schema: &DFSchema,
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index c73029f130..4d864e4334 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -361,7 +361,7 @@ pub fn to_substrait_rel(
         }
         LogicalPlan::Aggregate(agg) => {
             let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
-            let groupings = to_substrait_groupings(
+            let (grouping_expressions, groupings) = to_substrait_groupings(
                 ctx,
                 &agg.group_expr,
                 agg.input.schema(),
@@ -377,7 +377,7 @@ pub fn to_substrait_rel(
                 rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
                     common: None,
                     input: Some(input),
-                    grouping_expressions: vec![],
+                    grouping_expressions,
                     groupings,
                     measures,
                     advanced_extension: None,
@@ -774,14 +774,20 @@ pub fn parse_flat_grouping_exprs(
     exprs: &[Expr],
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
+    ref_group_exprs: &mut Vec<Expression>,
 ) -> Result<Grouping> {
-    let grouping_expressions = exprs
-        .iter()
-        .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions))
-        .collect::<Result<Vec<_>>>()?;
+    let mut expression_references = vec![];
+    let mut grouping_expressions = vec![];
+
+    for e in exprs {
+        let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
+        grouping_expressions.push(rex.clone());
+        ref_group_exprs.push(rex);
+        expression_references.push((ref_group_exprs.len() - 1) as u32);
+    }
     Ok(Grouping {
         grouping_expressions,
-        expression_references: vec![],
+        expression_references,
     })
 }
 
@@ -790,8 +796,9 @@ pub fn to_substrait_groupings(
     exprs: &[Expr],
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
-) -> Result<Vec<Grouping>> {
-    match exprs.len() {
+) -> Result<(Vec<Expression>, Vec<Grouping>)> {
+    let mut ref_group_exprs = vec![];
+    let groupings = match exprs.len() {
         1 => match &exprs[0] {
             Expr::GroupingSet(gs) => match gs {
                 GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
@@ -799,7 +806,15 @@ pub fn to_substrait_groupings(
                 )),
                 GroupingSet::GroupingSets(sets) => Ok(sets
                     .iter()
-                    .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions))
+                    .map(|set| {
+                        parse_flat_grouping_exprs(
+                            ctx,
+                            set,
+                            schema,
+                            extensions,
+                            &mut ref_group_exprs,
+                        )
+                    })
                     .collect::<Result<Vec<_>>>()?),
                 GroupingSet::Rollup(set) => {
                     let mut sets: Vec<Vec<Expr>> = vec![vec![]];
@@ -810,19 +825,34 @@ pub fn to_substrait_groupings(
                         .iter()
                         .rev()
                         .map(|set| {
-                            parse_flat_grouping_exprs(ctx, set, schema, extensions)
+                            parse_flat_grouping_exprs(
+                                ctx,
+                                set,
+                                schema,
+                                extensions,
+                                &mut ref_group_exprs,
+                            )
                         })
                         .collect::<Result<Vec<_>>>()?)
                 }
             },
             _ => Ok(vec![parse_flat_grouping_exprs(
-                ctx, exprs, schema, extensions,
+                ctx,
+                exprs,
+                schema,
+                extensions,
+                &mut ref_group_exprs,
             )?]),
         },
         _ => Ok(vec![parse_flat_grouping_exprs(
-            ctx, exprs, schema, extensions,
+            ctx,
+            exprs,
+            schema,
+            extensions,
+            &mut ref_group_exprs,
         )?]),
-    }
+    }?;
+    Ok((ref_group_exprs, groupings))
 }
 
 #[allow(deprecated)]
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 8fbdefe285..5687c9af54 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -665,6 +665,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
     .await
 }
 
+#[tokio::test]
+async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
+    let proto_plan =
+        read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");
+
+    assert_expected_plan_substrait(
+        proto_plan,
+        "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
+        \n  TableScan: data projection=[a]",
+    )
+    .await
+}
+
 #[tokio::test]
 async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
     let proto_plan =
diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json
new file mode 100644
index 0000000000..b6f14afd6f
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json
@@ -0,0 +1,98 @@
+{
+  "extensionUris": [
+    {
+      "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "functionAnchor": 185,
+        "name": "count:any"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "aggregate": {
+            "input": {
+              "read": {
+                "common": {
+                  "direct": {}
+                },
+                "baseSchema": {
+                  "names": [
+                    "a"
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "i64": {
+                          "nullability": "NULLABILITY_NULLABLE"
+                        }
+                      }
+                    ],
+                    "nullability": "NULLABILITY_NULLABLE"
+                  }
+                },
+                "namedTable": {
+                  "names": [
+                    "data"
+                  ]
+                }
+              }
+            },
+            "grouping_expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {}
+                  },
+                  "rootReference": {}
+                }
+              }
+            ],
+            "groupings": [
+              {
+                "expression_references": [0]
+              }
+            ],
+            "measures": [
+              {
+                "measure": {
+                  "functionReference": 185,
+                  "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+                  "outputType": {
+                    "i64": {}
+                  },
+                  "arguments": [
+                    {
+                      "value": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {}
+                          },
+                          "rootReference": {}
+                        }
+                      }
+                    }
+                  ]
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "a",
+          "countA"
+        ]
+      }
+    }
+  ],
+  "version": {
+    "minorNumber": 54,
+    "producer": "subframe"
+  }
+}
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to