This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 85ec314336 Add ROLLUP and GROUPING SETS substrait support (#7382)
85ec314336 is described below
commit 85ec31433615735a05d332f87cd0bdfc11aac663
Author: Nuttiiya Seekhao <[email protected]>
AuthorDate: Tue Aug 29 04:39:38 2023 -0700
Add ROLLUP and GROUPING SETS substrait support (#7382)
* Add ROLLUP and GROUPING SETS support
* fix: fmt
* clippy
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/substrait/src/logical_plan/consumer.rs | 43 ++++++++----
datafusion/substrait/src/logical_plan/producer.rs | 80 +++++++++++++++++++---
.../tests/cases/roundtrip_logical_plan.rs | 17 +++++
3 files changed, 117 insertions(+), 23 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index 32cb1db4c3..32b8f8ea54 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -22,8 +22,10 @@ use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
};
-use datafusion::logical_expr::{expr, Cast, WindowFrameBound, WindowFrameUnits};
-use datafusion::logical_expr::{Extension, Like, LogicalPlanBuilder};
+use datafusion::logical_expr::{
+ expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, WindowFrameBound,
+ WindowFrameUnits,
+};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
@@ -251,18 +253,35 @@ pub async fn from_substrait_rel(
let mut group_expr = vec![];
let mut aggr_expr = vec![];
- let groupings = match agg.groupings.len() {
- 1 => Ok(&agg.groupings[0]),
- _ => not_impl_err!(
- "Aggregate with multiple grouping sets is not supported"
- ),
+ match agg.groupings.len() {
+ 1 => {
+ for e in &agg.groupings[0].grouping_expressions {
+ let x =
+ from_substrait_rex(e, input.schema(), extensions).await?;
+ group_expr.push(x.as_ref().clone());
+ }
+ }
+ _ => {
+ let mut grouping_sets = vec![];
+ for grouping in &agg.groupings {
+ let mut grouping_set = vec![];
+ for e in &grouping.grouping_expressions {
+ let x = from_substrait_rex(e, input.schema(), extensions)
+ .await?;
+ grouping_set.push(x.as_ref().clone());
+ }
+ grouping_sets.push(grouping_set);
+ }
+ // Single-element grouping expression of type Expr::GroupingSet.
+ // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when
+ // parsed by the producer and consumer, since Substrait does not have a type dedicated
+ // to ROLLUP. Only vector of Groupings (grouping sets) is available.
+ group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets(
+ grouping_sets,
+ )));
+ }
};
- for e in &groupings?.grouping_expressions {
- let x = from_substrait_rex(e, input.schema(), extensions).await?;
- group_expr.push(x.as_ref().clone());
- }
-
for m in &agg.measures {
let filter = match &m.filter {
Some(fil) => Some(Box::new(
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 74a0ba63df..138825d061 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err};
#[allow(unused_imports)]
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
- Alias, BinaryExpr, Case, Cast, InList, ScalarFunction as DFScalarFunction, Sort,
- WindowFunction,
+ Alias, BinaryExpr, Case, Cast, GroupingSet, InList,
+ ScalarFunction as DFScalarFunction, Sort, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::Expr;
@@ -221,12 +221,11 @@ pub fn to_substrait_rel(
}
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?;
- // Translate aggregate expression to Substrait's groupings (repeated repeated Expression)
- let grouping = agg
- .group_expr
- .iter()
- .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info))
- .collect::<Result<Vec<_>>>()?;
+ let groupings = to_substrait_groupings(
+ &agg.group_expr,
+ agg.input.schema(),
+ extension_info,
+ )?;
let measures = agg
.aggr_expr
.iter()
@@ -237,9 +236,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
- groupings: vec![Grouping {
- grouping_expressions: grouping,
- }], //groupings,
+ groupings,
measures,
advanced_extension: None,
}))),
@@ -491,6 +488,67 @@ pub fn operator_to_name(op: Operator) -> &'static str {
}
}
+pub fn parse_flat_grouping_exprs(
+ exprs: &[Expr],
+ schema: &DFSchemaRef,
+ extension_info: &mut (
+ Vec<extensions::SimpleExtensionDeclaration>,
+ HashMap<String, u32>,
+ ),
+) -> Result<Grouping> {
+ let grouping_expressions = exprs
+ .iter()
+ .map(|e| to_substrait_rex(e, schema, 0, extension_info))
+ .collect::<Result<Vec<_>>>()?;
+ Ok(Grouping {
+ grouping_expressions,
+ })
+}
+
+pub fn to_substrait_groupings(
+ exprs: &Vec<Expr>,
+ schema: &DFSchemaRef,
+ extension_info: &mut (
+ Vec<extensions::SimpleExtensionDeclaration>,
+ HashMap<String, u32>,
+ ),
+) -> Result<Vec<Grouping>> {
+ match exprs.len() {
+ 1 => match &exprs[0] {
+ Expr::GroupingSet(gs) => match gs {
+ GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented(
+ "GroupingSet CUBE is not yet supported".to_string(),
+ )),
+ GroupingSet::GroupingSets(sets) => Ok(sets
+ .iter()
+ .map(|set| parse_flat_grouping_exprs(set, schema, extension_info))
+ .collect::<Result<Vec<_>>>()?),
+ GroupingSet::Rollup(set) => {
+ let mut sets: Vec<Vec<Expr>> = vec![vec![]];
+ for i in 0..set.len() {
+ sets.push(set[..=i].to_vec());
+ }
+ Ok(sets
+ .iter()
+ .rev()
+ .map(|set| parse_flat_grouping_exprs(set, schema, extension_info))
+ .collect::<Result<Vec<_>>>()?)
+ }
+ },
+ _ => Ok(vec![parse_flat_grouping_exprs(
+ exprs,
+ schema,
+ extension_info,
+ )?]),
+ },
+ _ => Ok(vec![parse_flat_grouping_exprs(
+ exprs,
+ schema,
+ extension_info,
+ )?]),
+ }
+}
+
#[allow(deprecated)]
pub fn to_substrait_agg_measure(
expr: &Expr,
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 90c3d199b7..f4d74ae426 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -209,6 +209,23 @@ async fn aggregate_multiple_keys() -> Result<()> {
roundtrip("SELECT a, c, avg(b) FROM data GROUP BY a, c").await
}
+#[tokio::test]
+async fn aggregate_grouping_sets() -> Result<()> {
+ roundtrip(
+ "SELECT a, c, d, avg(b) FROM data GROUP BY GROUPING SETS ((a, c), (a), (d), ())",
+ )
+ .await
+}
+
+#[tokio::test]
+async fn aggregate_grouping_rollup() -> Result<()> {
+ assert_expected_plan(
+ "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
+ "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\
+ \n TableScan: data projection=[a, b, c, e]"
+ ).await
+}
+
#[tokio::test]
async fn decimal_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE b > 2.5").await