alamb commented on code in PR #2716:
URL: https://github.com/apache/arrow-datafusion/pull/2716#discussion_r895151112
##########
datafusion/expr/src/expr_fn.rs:
##########
@@ -226,6 +227,21 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr
{
Expr::ScalarSubquery(Subquery { subquery })
}
+/// Create a grouping set
+pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
+ Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
+}
+
+/// Create a grouping set
Review Comment:
```suggestion
/// Create a grouping set for all combinations of `exprs`
```
##########
datafusion/core/tests/sql/aggregates.rs:
##########
@@ -476,6 +476,205 @@ async fn csv_query_approx_percentile_cont() -> Result<()>
{
Ok(())
}
+#[tokio::test]
+async fn csv_query_cube_avg() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+
+ let sql = "SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE
(c1, c2) ORDER BY c1, c2";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+----------------------------+",
+ "| c1 | c2 | AVG(aggregate_test_100.c3) |",
+ "+----+----+----------------------------+",
+ "| a | 1 | -17.6 |",
+ "| a | 2 | -15.333333333333334 |",
+ "| a | 3 | -4.5 |",
+ "| a | 4 | -32 |",
+ "| a | 5 | -32 |",
+ "| a | | -18.333333333333332 |",
+ "| b | 1 | 31.666666666666668 |",
+ "| b | 2 | 25.5 |",
+ "| b | 3 | -42 |",
+ "| b | 4 | -44.6 |",
+ "| b | 5 | -0.2 |",
+ "| b | | -5.842105263157895 |",
+ "| c | 1 | 47.5 |",
+ "| c | 2 | -55.57142857142857 |",
+ "| c | 3 | 47.5 |",
+ "| c | 4 | -10.75 |",
+ "| c | 5 | 12 |",
+ "| c | | -1.3333333333333333 |",
+ "| d | 1 | -8.142857142857142 |",
+ "| d | 2 | 109.33333333333333 |",
+ "| d | 3 | 41.333333333333336 |",
+ "| d | 4 | 54 |",
+ "| d | 5 | -49.5 |",
+ "| d | | 25.444444444444443 |",
+ "| e | 1 | 75.66666666666667 |",
+ "| e | 2 | 37.8 |",
+ "| e | 3 | 48 |",
+ "| e | 4 | 37.285714285714285 |",
+ "| e | 5 | -11 |",
+ "| e | | 40.333333333333336 |",
+ "| | 1 | 16.681818181818183 |",
+ "| | 2 | 8.363636363636363 |",
+ "| | 3 | 20.789473684210527 |",
+ "| | 4 | 1.2608695652173914 |",
+ "| | 5 | -13.857142857142858 |",
+ "| | | 7.81 |",
+ "+----+----+----------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_rollup_avg() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_aggregate_csv_by_sql(&ctx).await;
+
+ let sql = "SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100 GROUP BY
ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+----+----+------+----------------------------+",
+ "| c1 | c2 | c3 | AVG(aggregate_test_100.c4) |",
+ "+----+----+------+----------------------------+",
+ "| a | 1 | -85 | -15154 |",
+ "| a | 1 | -56 | 8692 |",
+ "| a | 1 | -25 | 15295 |",
+ "| a | 1 | -5 | 12636 |",
+ "| a | 1 | 83 | -14704 |",
+ "| a | 1 | | 1353 |",
+ "| a | 2 | -48 | -18025 |",
+ "| a | 2 | -43 | 13080 |",
+ "| a | 2 | 45 | 15673 |",
+ "| a | 2 | | 3576 |",
+ "| a | 3 | -72 | -11122 |",
+ "| a | 3 | -12 | -9168 |",
+ "| a | 3 | 13 | 22338.5 |",
+ "| a | 3 | 14 | 28162 |",
+ "| a | 3 | 17 | -22796 |",
+ "| a | 3 | | 4958.833333333333 |",
+ "| a | 4 | -101 | 11640 |",
+ "| a | 4 | -54 | -2376 |",
+ "| a | 4 | -38 | 20744 |",
+ "| a | 4 | 65 | -28462 |",
+ "| a | 4 | | 386.5 |",
+ "| a | 5 | -101 | -12484 |",
+ "| a | 5 | -31 | -12907 |",
+ "| a | 5 | 36 | -16974 |",
+ "| a | 5 | | -14121.666666666666 |",
+ "| a | | | 306.04761904761904 |",
+ "| b | 1 | 12 | 7652 |",
+ "| b | 1 | 29 | -18218 |",
+ "| b | 1 | 54 | -18410 |",
+ "| b | 1 | | -9658.666666666666 |",
+ "| b | 2 | -60 | -21739 |",
+ "| b | 2 | 31 | 23127 |",
+ "| b | 2 | 63 | 21456 |",
+ "| b | 2 | 68 | 15874 |",
+ "| b | 2 | | 9679.5 |",
+ "| b | 3 | -101 | -13217 |",
+ "| b | 3 | 17 | 14457 |",
+ "| b | 3 | | 620 |",
+ "| b | 4 | -117 | 19316 |",
+ "| b | 4 | -111 | -1967 |",
+ "| b | 4 | -59 | 25286 |",
+ "| b | 4 | 17 | -28070 |",
+ "| b | 4 | 47 | 20690 |",
+ "| b | 4 | | 7051 |",
+ "| b | 5 | -82 | 22080 |",
+ "| b | 5 | -44 | 15788 |",
+ "| b | 5 | -5 | 24896 |",
+ "| b | 5 | 62 | 16337 |",
+ "| b | 5 | 68 | 21576 |",
+ "| b | 5 | | 20135.4 |",
+ "| b | | | 7732.315789473684 |",
+ "| c | 1 | -24 | -24085 |",
+ "| c | 1 | 41 | -4667 |",
+ "| c | 1 | 70 | 27752 |",
+ "| c | 1 | 103 | -22186 |",
+ "| c | 1 | | -5796.5 |",
+ "| c | 2 | -117 | -30187 |",
+ "| c | 2 | -107 | -2904 |",
+ "| c | 2 | -106 | -1114 |",
+ "| c | 2 | -60 | -16312 |",
+ "| c | 2 | -29 | 25305 |",
+ "| c | 2 | 1 | 18109 |",
+ "| c | 2 | 29 | -3855 |",
+ "| c | 2 | | -1565.4285714285713 |",
+ "| c | 3 | -2 | -18655 |",
+ "| c | 3 | 22 | 13741 |",
+ "| c | 3 | 73 | -9565 |",
+ "| c | 3 | 97 | 29106 |",
+ "| c | 3 | | 3656.75 |",
+ "| c | 4 | -90 | -2935 |",
+ "| c | 4 | -79 | 5281 |",
+ "| c | 4 | 3 | -30508 |",
+ "| c | 4 | 123 | 16620 |",
+ "| c | 4 | | -2885.5 |",
+ "| c | 5 | -94 | -15880 |",
+ "| c | 5 | 118 | 19208 |",
+ "| c | 5 | | 1664 |",
+ "| c | | | -1320.5238095238096 |",
+ "| d | 1 | -99 | 5613 |",
+ "| d | 1 | -98 | 13630 |",
+ "| d | 1 | -72 | 25590 |",
+ "| d | 1 | -8 | 27138 |",
+ "| d | 1 | 38 | 18384 |",
+ "| d | 1 | 57 | 28781 |",
+ "| d | 1 | 125 | 31106 |",
+ "| d | 1 | | 21463.14285714286 |",
+ "| d | 2 | 93 | -12642 |",
+ "| d | 2 | 113 | 3917 |",
+ "| d | 2 | 122 | 10130 |",
+ "| d | 2 | | 468.3333333333333 |",
+ "| d | 3 | -76 | 8809 |",
+ "| d | 3 | 77 | 15091 |",
+ "| d | 3 | 123 | 29533 |",
+ "| d | 3 | | 17811 |",
+ "| d | 4 | 5 | -7688 |",
+ "| d | 4 | 55 | -1471 |",
+ "| d | 4 | 102 | -24558 |",
+ "| d | 4 | | -11239 |",
+ "| d | 5 | -59 | 2045 |",
+ "| d | 5 | -40 | 22614 |",
+ "| d | 5 | | 12329.5 |",
+ "| d | | | 10890.111111111111 |",
+ "| e | 1 | 36 | -21481 |",
+ "| e | 1 | 71 | -5479 |",
+ "| e | 1 | 120 | 10837 |",
+ "| e | 1 | | -5374.333333333333 |",
+ "| e | 2 | -61 | -2888 |",
+ "| e | 2 | 49 | 24495 |",
+ "| e | 2 | 52 | 5666 |",
+ "| e | 2 | 97 | 18167 |",
+ "| e | 2 | | 10221.2 |",
+ "| e | 3 | -95 | 13611 |",
+ "| e | 3 | 71 | 194 |",
+ "| e | 3 | 104 | -25136 |",
+ "| e | 3 | 112 | -6823 |",
+ "| e | 3 | | -4538.5 |",
+ "| e | 4 | -56 | -31500 |",
+ "| e | 4 | -53 | 13788 |",
+ "| e | 4 | 30 | -16110 |",
+ "| e | 4 | 73 | -22501 |",
+ "| e | 4 | 74 | -12612 |",
+ "| e | 4 | 96 | -30336 |",
+ "| e | 4 | 97 | -13181 |",
+ "| e | 4 | | -16064.57142857143 |",
+ "| e | 5 | -86 | 32514 |",
+ "| e | 5 | 64 | -26526 |",
+ "| e | 5 | | 2994 |",
Review Comment:
👍
##########
datafusion/core/src/lib.rs:
##########
@@ -204,6 +204,7 @@
/// DataFusion crate version
pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION");
+extern crate core;
Review Comment:
Why is `extern crate core;` necessary here?
##########
datafusion/expr/src/expr_fn.rs:
##########
@@ -226,6 +227,21 @@ pub fn scalar_subquery(subquery: Arc<LogicalPlan>) -> Expr
{
Expr::ScalarSubquery(Subquery { subquery })
}
+/// Create a grouping set
+pub fn grouping_set(exprs: Vec<Vec<Expr>>) -> Expr {
+ Expr::GroupingSet(GroupingSet::GroupingSets(exprs))
+}
+
+/// Create a grouping set
+pub fn cube(exprs: Vec<Expr>) -> Expr {
+ Expr::GroupingSet(GroupingSet::Cube(exprs))
+}
+
+/// Create a grouping set
Review Comment:
```suggestion
/// Create a grouping set for rollup
```
##########
datafusion/expr/src/expr.rs:
##########
@@ -270,6 +270,27 @@ pub enum GroupingSet {
GroupingSets(Vec<Vec<Expr>>),
}
+impl GroupingSet {
+ /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP`
this
+ /// is just the underlying list of exprs. For `GROUPING SET` we need to
deduplicate
+ /// the exprs in the underlying sets.
+ pub fn distinct_expr(&self) -> Vec<Expr> {
+ match self {
+ GroupingSet::Rollup(exprs) => exprs.clone(),
+ GroupingSet::Cube(exprs) => exprs.clone(),
+ GroupingSet::GroupingSets(groups) => {
+ let mut exprs: Vec<Expr> = vec![];
+ for exp in groups.iter().flatten() {
+ if !exprs.contains(exp) {
Review Comment:
This is O(N^2) in the total number of expressions across all grouping sets (each `contains` call
is a linear scan of `exprs`) -- probably not an issue, but I figured I would point it out
##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -90,18 +121,24 @@ impl AggregateExec {
/// Create a new hash aggregate execution plan
pub fn try_new(
mode: AggregateMode,
- group_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
+ grouping_set: PhysicalGroupingSet,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
- let schema = create_schema(&input.schema(), &group_expr, &aggr_expr,
mode)?;
+ let schema = create_schema(
+ &input.schema(),
+ &grouping_set.expr,
+ &aggr_expr,
+ grouping_set.groups.iter().flatten().any(|is_null| *is_null),
Review Comment:
I wonder if extracting this code into a method such as
`PhysicalGroupingSet::contains_null()` might make the code easier to read. The same
comment applies to the other places where `PhysicalGroupingSet::groups` is referenced
as well.
Given the size of this PR already, definitely could be done as a follow on
##########
datafusion/expr/src/utils.rs:
##########
@@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut
HashSet<Column>) -> Result
Ok(())
}
+/// Find all distinct exprs in a list of group by expressions. We assume that
if the
+/// first element is a `GroupingSet` expression then that is the only expr.
Review Comment:
```suggestion
/// Find all distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression then it must be the only
/// expression.
```
##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -110,12 +111,15 @@ impl GroupedHashAggregateStreamV2 {
// The expressions to evaluate the batch, one vec of expressions per
aggregation.
// Assume create_schema() always put group columns in front of aggr
columns, we set
// col_idx_base to group expression count.
- let aggregate_expressions =
- aggregates::aggregate_expressions(&aggr_expr, &mode,
group_expr.len())?;
+ let aggregate_expressions = aggregates::aggregate_expressions(
Review Comment:
FYI @yjshen -- it would be really nice to try and consolidate `row_hash`
and `hash` -- filed https://github.com/apache/arrow-datafusion/issues/2723 to
track 👍
##########
datafusion/optimizer/src/single_distinct_to_groupby.rs:
##########
@@ -250,6 +280,29 @@ mod tests {
Ok(())
}
+ #[test]
+ fn single_distinct_and_grouping_set() -> Result<()> {
Review Comment:
Given there is special handling for CUBE and ROLLUP in this pass, I suggest
adding test coverage for those cases too
##########
datafusion/core/src/physical_plan/aggregates/mod.rs:
##########
@@ -65,13 +66,43 @@ pub enum AggregateMode {
FinalPartitioned,
}
+/// Represents a GROUPING SET in the physical plan
Review Comment:
👍
Thank you -- this is well commented and easy to understand
```suggestion
/// Represents the `GROUP BY` clause in the physical plan (including the more
/// general GROUPING SETS)
```
I was thinking that many people encountering SQL will not be familiar with
GROUPING SETS but would be very familiar with GROUP BY -- what would you think
about calling this structure `PhysicalGroupBy` or something similar to make the
connection to `GROUP BY` clearer?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]