This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a9d4d525df feat(substrait): AggregateRel grouping_expression support (#13173)
a9d4d525df is described below
commit a9d4d525df07dd2fc5eb6adc622a821cf54d44ba
Author: Andrey Koshchiy <[email protected]>
AuthorDate: Sun Nov 3 14:48:56 2024 +0300
feat(substrait): AggregateRel grouping_expression support (#13173)
---
datafusion/substrait/src/logical_plan/consumer.rs | 77 ++++++++++++-----
datafusion/substrait/src/logical_plan/producer.rs | 58 +++++++++----
.../tests/cases/roundtrip_logical_plan.rs | 13 +++
..._no_project_group_expression_ref.substrait.json | 98 ++++++++++++++++++++++
4 files changed, 210 insertions(+), 36 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 7ccca8616b..890da7361d 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -33,6 +33,7 @@ use datafusion::logical_expr::{
expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values,
};
+use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use substrait::proto::expression_reference::ExprType;
use url::Url;
@@ -665,39 +666,48 @@ pub async fn from_substrait_rel(
let input = LogicalPlanBuilder::from(
from_substrait_rel(ctx, input, extensions).await?,
);
- let mut group_expr = vec![];
- let mut aggr_expr = vec![];
+ let mut ref_group_exprs = vec![];
+
+ for e in &agg.grouping_expressions {
+ let x =
+ from_substrait_rex(ctx, e, input.schema(), extensions).await?;
+ ref_group_exprs.push(x);
+ }
+
+ let mut group_exprs = vec![];
+ let mut aggr_exprs = vec![];
match agg.groupings.len() {
1 => {
- for e in &agg.groupings[0].grouping_expressions {
- let x =
- from_substrait_rex(ctx, e, input.schema(), extensions)
- .await?;
- group_expr.push(x);
- }
+ group_exprs.extend_from_slice(
+ &from_substrait_grouping(
+ ctx,
+ &agg.groupings[0],
+ &ref_group_exprs,
+ input.schema(),
+ extensions,
+ )
+ .await?,
+ );
}
_ => {
let mut grouping_sets = vec![];
for grouping in &agg.groupings {
- let mut grouping_set = vec![];
- for e in &grouping.grouping_expressions {
- let x = from_substrait_rex(
- ctx,
- e,
- input.schema(),
- extensions,
- )
- .await?;
- grouping_set.push(x);
- }
+ let grouping_set = from_substrait_grouping(
+ ctx,
+ grouping,
+ &ref_group_exprs,
+ input.schema(),
+ extensions,
+ )
+ .await?;
grouping_sets.push(grouping_set);
}
// Single-element grouping expression of type Expr::GroupingSet.
// Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
// parsed by the producer and consumer, since Substrait does not have a type dedicated
// to ROLLUP. Only vector of Groupings (grouping sets) is available.
- group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
+ group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets(
grouping_sets,
)));
}
@@ -755,9 +765,9 @@ pub async fn from_substrait_rel(
"Aggregate without aggregate function is not supported"
),
};
- aggr_expr.push(agg_func?.as_ref().clone());
+ aggr_exprs.push(agg_func?.as_ref().clone());
}
- input.aggregate(group_expr, aggr_expr)?.build()
+ input.aggregate(group_exprs, aggr_exprs)?.build()
} else {
not_impl_err!("Aggregate without an input is not valid")
}
@@ -2762,6 +2772,29 @@ fn from_substrait_null(
}
}
+#[allow(deprecated)]
+async fn from_substrait_grouping(
+ ctx: &SessionContext,
+ grouping: &Grouping,
+ expressions: &[Expr],
+ input_schema: &DFSchemaRef,
+ extensions: &Extensions,
+) -> Result<Vec<Expr>> {
+ let mut group_exprs = vec![];
+ if !grouping.grouping_expressions.is_empty() {
+ for e in &grouping.grouping_expressions {
+ let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?;
+ group_exprs.push(expr);
+ }
+ return Ok(group_exprs);
+ }
+ for idx in &grouping.expression_references {
+ let e = &expressions[*idx as usize];
+ group_exprs.push(e.clone());
+ }
+ Ok(group_exprs)
+}
+
fn from_substrait_field_reference(
field_ref: &FieldReference,
input_schema: &DFSchema,
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index c73029f130..4d864e4334 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -361,7 +361,7 @@ pub fn to_substrait_rel(
}
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
- let groupings = to_substrait_groupings(
+ let (grouping_expressions, groupings) = to_substrait_groupings(
ctx,
&agg.group_expr,
agg.input.schema(),
@@ -377,7 +377,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
- grouping_expressions: vec![],
+ grouping_expressions,
groupings,
measures,
advanced_extension: None,
@@ -774,14 +774,20 @@ pub fn parse_flat_grouping_exprs(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
+ ref_group_exprs: &mut Vec<Expression>,
) -> Result<Grouping> {
- let grouping_expressions = exprs
- .iter()
- .map(|e| to_substrait_rex(ctx, e, schema, 0, extensions))
- .collect::<Result<Vec<_>>>()?;
+ let mut expression_references = vec![];
+ let mut grouping_expressions = vec![];
+
+ for e in exprs {
+ let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
+ grouping_expressions.push(rex.clone());
+ ref_group_exprs.push(rex);
+ expression_references.push((ref_group_exprs.len() - 1) as u32);
+ }
Ok(Grouping {
grouping_expressions,
- expression_references: vec![],
+ expression_references,
})
}
@@ -790,8 +796,9 @@ pub fn to_substrait_groupings(
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
-) -> Result<Vec<Grouping>> {
- match exprs.len() {
+) -> Result<(Vec<Expression>, Vec<Grouping>)> {
+ let mut ref_group_exprs = vec![];
+ let groupings = match exprs.len() {
1 => match &exprs[0] {
Expr::GroupingSet(gs) => match gs {
GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
@@ -799,7 +806,15 @@ pub fn to_substrait_groupings(
)),
GroupingSet::GroupingSets(sets) => Ok(sets
.iter()
- .map(|set| parse_flat_grouping_exprs(ctx, set, schema, extensions))
+ .map(|set| {
+ parse_flat_grouping_exprs(
+ ctx,
+ set,
+ schema,
+ extensions,
+ &mut ref_group_exprs,
+ )
+ })
.collect::<Result<Vec<_>>>()?),
GroupingSet::Rollup(set) => {
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
@@ -810,19 +825,34 @@ pub fn to_substrait_groupings(
.iter()
.rev()
.map(|set| {
- parse_flat_grouping_exprs(ctx, set, schema, extensions)
+ parse_flat_grouping_exprs(
+ ctx,
+ set,
+ schema,
+ extensions,
+ &mut ref_group_exprs,
+ )
})
.collect::<Result<Vec<_>>>()?)
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
- ctx, exprs, schema, extensions,
+ ctx,
+ exprs,
+ schema,
+ extensions,
+ &mut ref_group_exprs,
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
- ctx, exprs, schema, extensions,
+ ctx,
+ exprs,
+ schema,
+ extensions,
+ &mut ref_group_exprs,
)?]),
- }
+ }?;
+ Ok((ref_group_exprs, groupings))
}
#[allow(deprecated)]
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 8fbdefe285..5687c9af54 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -665,6 +665,19 @@ async fn aggregate_wo_projection_consume() -> Result<()> {
.await
}
+#[tokio::test]
+async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> {
+ let proto_plan =
+ read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json");
+
+ assert_expected_plan_substrait(
+ proto_plan,
+ "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\
+ \n TableScan: data projection=[a]",
+ )
+ .await
+}
+
#[tokio::test]
async fn aggregate_wo_projection_sorted_consume() -> Result<()> {
let proto_plan =
diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json
new file mode 100644
index 0000000000..b6f14afd6f
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json
@@ -0,0 +1,98 @@
+{
+ "extensionUris": [
+ {
+ "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml"
+ }
+ ],
+ "extensions": [
+ {
+ "extensionFunction": {
+ "functionAnchor": 185,
+ "name": "count:any"
+ }
+ }
+ ],
+ "relations": [
+ {
+ "root": {
+ "input": {
+ "aggregate": {
+ "input": {
+ "read": {
+ "common": {
+ "direct": {}
+ },
+ "baseSchema": {
+ "names": [
+ "a"
+ ],
+ "struct": {
+ "types": [
+ {
+ "i64": {
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ "namedTable": {
+ "names": [
+ "data"
+ ]
+ }
+ }
+ },
+ "grouping_expressions": [
+ {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ ],
+ "groupings": [
+ {
+ "expression_references": [0]
+ }
+ ],
+ "measures": [
+ {
+ "measure": {
+ "functionReference": 185,
+ "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT",
+ "outputType": {
+ "i64": {}
+ },
+ "arguments": [
+ {
+ "value": {
+ "selection": {
+ "directReference": {
+ "structField": {}
+ },
+ "rootReference": {}
+ }
+ }
+ }
+ ]
+ }
+ }
+ ]
+ }
+ },
+ "names": [
+ "a",
+ "countA"
+ ]
+ }
+ }
+ ],
+ "version": {
+ "minorNumber": 54,
+ "producer": "subframe"
+ }
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]