This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ef227f41cb fix: Correct results for grouping sets when columns contain nulls (#12571)
ef227f41cb is described below

commit ef227f41cba69718e16557e164415d20b83a4bd6
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Mon Oct 7 18:24:46 2024 +0200

    fix: Correct results for grouping sets when columns contain nulls (#12571)
    
    * Fix grouping sets behavior when data contains nulls
    
    * PR suggestion comment
    
    * Update new test case
    
    * Add grouping_id to the logical plan
    
    * Add doc comment next to INTERNAL_GROUPING_ID
    
    * Fix unparsing of Aggregate with grouping sets
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/dataframe/mod.rs               |  17 +
 datafusion/core/src/physical_planner.rs            |  14 +-
 datafusion/expr/src/logical_plan/plan.rs           |  56 +++-
 datafusion/expr/src/utils.rs                       |  12 +-
 .../optimizer/src/single_distinct_to_groupby.rs    |   6 +-
 datafusion/physical-plan/src/aggregates/mod.rs     | 370 +++++++++++++--------
 .../physical-plan/src/aggregates/row_hash.rs       |   6 +-
 datafusion/sql/src/unparser/utils.rs               |  17 +-
 datafusion/sqllogictest/test_files/aggregate.slt   |  32 +-
 datafusion/sqllogictest/test_files/group_by.slt    |  11 +-
 .../tests/cases/roundtrip_logical_plan.rs          |   5 +-
 11 files changed, 359 insertions(+), 187 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index f5867881da..67e2a4780d 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -535,9 +535,26 @@ impl DataFrame {
         group_expr: Vec<Expr>,
         aggr_expr: Vec<Expr>,
     ) -> Result<DataFrame> {
+        let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
+        let aggr_expr_len = aggr_expr.len();
         let plan = LogicalPlanBuilder::from(self.plan)
             .aggregate(group_expr, aggr_expr)?
             .build()?;
+        let plan = if is_grouping_set {
+            let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
+            // For grouping sets we add a projection to avoid exposing the internal grouping id
+            let exprs = plan
+                .schema()
+                .columns()
+                .into_iter()
+                .enumerate()
+                .filter(|(idx, _)| *idx != grouping_id_pos)
+                .map(|(_, column)| Expr::Column(column))
+                .collect::<Vec<_>>();
+            LogicalPlanBuilder::from(plan).project(exprs)?.build()?
+        } else {
+            plan
+        };
         Ok(DataFrame {
             session_state: self.session_state,
             plan,
diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 78c70606bf..cf2a157b04 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
                     physical_input_schema.clone(),
                 )?);
 
-                // update group column indices based on partial aggregate plan evaluation
-                let final_group: Vec<Arc<dyn PhysicalExpr>> =
-                    initial_aggr.output_group_expr();
-
                 let can_repartition = !groups.is_empty()
                     && session_state.config().target_partitions() > 1
                     && session_state.config().repartition_aggregations();
@@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
                     AggregateMode::Final
                 };
 
-                let final_grouping_set = PhysicalGroupBy::new_single(
-                    final_group
-                        .iter()
-                        .enumerate()
-                        .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
-                        .collect(),
-                );
+                let final_grouping_set = initial_aggr.group_expr().as_final();
 
                 Arc::new(AggregateExec::try_new(
                     next_partition_mode,
@@ -2345,7 +2335,7 @@ mod tests {
             .expect("hash aggregate");
         assert_eq!(
             "sum(aggregate_test_100.c3)",
-            final_hash_agg.schema().field(2).name()
+            final_hash_agg.schema().field(3).name()
         );
         // we need access to the input to the partial aggregate so that other projects can
         // implement serde
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 19e73140b7..0292274e57 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -21,7 +21,7 @@ use std::cmp::Ordering;
 use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Debug, Display, Formatter};
 use std::hash::{Hash, Hasher};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use super::dml::CopyTo;
 use super::DdlStatement;
@@ -2965,6 +2965,15 @@ impl Aggregate {
                 .into_iter()
                .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
                 .collect::<Vec<_>>();
+            qualified_fields.push((
+                None,
+                Field::new(
+                    Self::INTERNAL_GROUPING_ID,
+                    Self::grouping_id_type(qualified_fields.len()),
+                    false,
+                )
+                .into(),
+            ));
         }
 
         qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
@@ -3016,9 +3025,19 @@ impl Aggregate {
         })
     }
 
+    fn is_grouping_set(&self) -> bool {
+        matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
+    }
+
     /// Get the output expressions.
     fn output_expressions(&self) -> Result<Vec<&Expr>> {
+        static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
         let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
+        if self.is_grouping_set() {
+            exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
+                Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
+            }));
+        }
         exprs.extend(self.aggr_expr.iter());
         debug_assert!(exprs.len() == self.schema.fields().len());
         Ok(exprs)
@@ -3030,6 +3049,41 @@ impl Aggregate {
     pub fn group_expr_len(&self) -> Result<usize> {
         grouping_set_expr_count(&self.group_expr)
     }
+
+    /// Returns the data type of the grouping id.
+    /// The grouping ID value is a bitmask where each set bit
+    /// indicates that the corresponding grouping expression is
+    /// null.
+    pub fn grouping_id_type(group_exprs: usize) -> DataType {
+        if group_exprs <= 8 {
+            DataType::UInt8
+        } else if group_exprs <= 16 {
+            DataType::UInt16
+        } else if group_exprs <= 32 {
+            DataType::UInt32
+        } else {
+            DataType::UInt64
+        }
+    }
+
+    /// Internal column used when the aggregation is a grouping set.
+    ///
+    /// This column contains a bitmask where each bit represents a grouping
+    /// expression. The least significant bit corresponds to the rightmost
+    /// grouping expression. A bit value of 0 indicates that the corresponding
+    /// column is included in the grouping set, while a value of 1 means it is excluded.
+    ///
+    /// For example, for the grouping expressions CUBE(a, b), the grouping ID
+    /// column will have the following values:
+    ///     0b00: Both `a` and `b` are included
+    ///     0b01: `b` is excluded
+    ///     0b10: `a` is excluded
+    ///     0b11: Both `a` and `b` are excluded
+    ///
+    /// This internal column is necessary because excluded columns are replaced
+    /// with `NULL` values. To handle these cases correctly, we must distinguish
+    /// between an actual `NULL` value in a column and a column being excluded from the set.
+    pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
 }
 
// Manual implementation needed because of `schema` field. Comparison excludes this field.
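
To make the bitmask concrete, here is a minimal standalone sketch (not part of
the patch; `grouping_id` is a hypothetical helper) that reproduces the
CUBE(a, b) table from the INTERNAL_GROUPING_ID doc comment above:

    // A set bit means the corresponding grouping expression is excluded from
    // the set; the least significant bit maps to the rightmost expression `b`.
    fn grouping_id(excluded: &[bool]) -> u64 {
        excluded
            .iter()
            .fold(0u64, |acc, &e| (acc << 1) | u64::from(e))
    }

    fn main() {
        assert_eq!(grouping_id(&[false, false]), 0b00); // both `a` and `b` included
        assert_eq!(grouping_id(&[false, true]), 0b01);  // `b` excluded
        assert_eq!(grouping_id(&[true, false]), 0b10);  // `a` excluded
        assert_eq!(grouping_id(&[true, true]), 0b11);   // both excluded
    }

The same fold appears in `group_id_array` in the physical plan changes below.
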
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index fa92759504..02b36d0fea 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
 /// Count the number of distinct exprs in a list of group by expressions. If the
 /// first element is a `GroupingSet` expression then it must be the only expr.
 pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
-    grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
+        if group_expr.len() > 1 {
+            return plan_err!(
+                "Invalid group by expressions, GroupingSet must be the only 
expression"
+            );
+        }
+        // Grouping sets have an additional internal column for the grouping id
+        Ok(grouping_set.distinct_expr().len() + 1)
+    } else {
+        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+    }
 }
 
 /// The [power set] (or powerset) of a set S is the set of all subsets of S, \
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 1c22c2a437..74251e5caa 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -355,7 +355,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), 
(test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, 
count(DISTINCT test.c):Int64]\
+        let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), 
(test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, 
__grouping_id:UInt8, count(DISTINCT test.c):Int64]\
                             \n  TableScan: test [a:UInt32, b:UInt32, 
c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
@@ -373,7 +373,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], 
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT 
test.c):Int64]\
+        let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], 
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, 
count(DISTINCT test.c):Int64]\
                             \n  TableScan: test [a:UInt32, b:UInt32, 
c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
@@ -392,7 +392,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], 
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT 
test.c):Int64]\
+        let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], 
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, 
count(DISTINCT test.c):Int64]\
                             \n  TableScan: test [a:UInt32, b:UInt32, 
c:UInt32]";
 
         assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 9466ff6dd4..f9dd973c81 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -36,10 +36,11 @@ use crate::{
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
+use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array};
 use datafusion_common::stats::Precision;
 use datafusion_common::{internal_err, not_impl_err, Result};
 use datafusion_execution::TaskContext;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, Aggregate};
 use datafusion_physical_expr::{
     equivalence::{collapse_lex_req, ProjectionMapping},
     expressions::Column,
@@ -211,13 +212,99 @@ impl PhysicalGroupBy {
             .collect()
     }
 
+    /// The number of expressions in the output schema.
+    fn num_output_exprs(&self) -> usize {
+        let mut num_exprs = self.expr.len();
+        if !self.is_single() {
+            num_exprs += 1
+        }
+        num_exprs
+    }
+
     /// Return grouping expressions as they occur in the output schema.
     pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        self.expr
-            .iter()
-            .enumerate()
-            .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
-            .collect()
+        let num_output_exprs = self.num_output_exprs();
+        let mut output_exprs = Vec::with_capacity(num_output_exprs);
+        output_exprs.extend(
+            self.expr
+                .iter()
+                .enumerate()
+                .take(num_output_exprs)
+                .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
+        );
+        if !self.is_single() {
+            output_exprs.push(Arc::new(Column::new(
+                Aggregate::INTERNAL_GROUPING_ID,
+                self.expr.len(),
+            )) as _);
+        }
+        output_exprs
+    }
+
+    /// Returns the number of expressions used as grouping keys.
+    fn num_group_exprs(&self) -> usize {
+        if self.is_single() {
+            self.expr.len()
+        } else {
+            self.expr.len() + 1
+        }
+    }
+
+    /// Returns the fields that are used as the grouping keys.
+    fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = Vec::with_capacity(self.num_group_exprs());
+        for ((expr, name), group_expr_nullable) in
+            self.expr.iter().zip(self.exprs_nullable().into_iter())
+        {
+            fields.push(
+                Field::new(
+                    name,
+                    expr.data_type(input_schema)?,
+                    group_expr_nullable || expr.nullable(input_schema)?,
+                )
+                .with_metadata(
+                    get_field_metadata(expr, input_schema).unwrap_or_default(),
+                ),
+            );
+        }
+        if !self.is_single() {
+            fields.push(Field::new(
+                Aggregate::INTERNAL_GROUPING_ID,
+                Aggregate::grouping_id_type(self.expr.len()),
+                false,
+            ));
+        }
+        Ok(fields)
+    }
+
+    /// Returns the output fields of the group by.
+    ///
+    /// This might differ from `group_fields`, which may contain internal expressions
+    /// that should not be part of the output schema.
+    fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+        let mut fields = self.group_fields(input_schema)?;
+        fields.truncate(self.num_output_exprs());
+        Ok(fields)
+    }
+
+    /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
+    /// aggregation.
+    pub fn as_final(&self) -> PhysicalGroupBy {
+        let expr: Vec<_> =
+            self.output_exprs()
+                .into_iter()
+                .zip(
+                    self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
+                        Aggregate::INTERNAL_GROUPING_ID.to_owned(),
+                    )),
+                )
+                .collect();
+        let num_exprs = expr.len();
+        Self {
+            expr,
+            null_expr: vec![],
+            groups: vec![vec![false; num_exprs]],
+        }
     }
 }
 
@@ -321,13 +408,7 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let schema = create_schema(
-            &input.schema(),
-            &group_by.expr,
-            &aggr_expr,
-            group_by.exprs_nullable(),
-            mode,
-        )?;
+        let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;
 
         let schema = Arc::new(schema);
         AggregateExec::try_new_with_schema(
@@ -789,25 +870,12 @@ impl ExecutionPlan for AggregateExec {
 
 fn create_schema(
     input_schema: &Schema,
-    group_expr: &[(Arc<dyn PhysicalExpr>, String)],
+    group_by: &PhysicalGroupBy,
     aggr_expr: &[AggregateFunctionExpr],
-    group_expr_nullable: Vec<bool>,
     mode: AggregateMode,
 ) -> Result<Schema> {
-    let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
-    for (index, (expr, name)) in group_expr.iter().enumerate() {
-        fields.push(
-            Field::new(
-                name,
-                expr.data_type(input_schema)?,
-                // In cases where we have multiple grouping sets, we will use NULL expressions in
-                // order to align the grouping sets. So the field must be nullable even if the underlying
-                // schema field is not.
-                group_expr_nullable[index] || expr.nullable(input_schema)?,
-            )
-            .with_metadata(get_field_metadata(expr, input_schema).unwrap_or_default()),
-        )
-    }
+    let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
+    fields.extend(group_by.output_fields(input_schema)?);
 
     match mode {
         AggregateMode::Partial => {
@@ -833,9 +901,8 @@ fn create_schema(
     ))
 }
 
-fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
-    let group_fields = schema.fields()[0..group_count].to_vec();
-    Arc::new(Schema::new(group_fields))
+fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result<SchemaRef> {
+    Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?)))
 }
 
 /// Determines the lexical ordering requirement for an aggregate expression.
@@ -1142,6 +1209,27 @@ fn evaluate_optional(
         .collect()
 }
 
+fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
+    if group.len() > 64 {
+        return not_impl_err!(
+            "Grouping sets with more than 64 columns are not supported"
+        );
+    }
+    let group_id = group.iter().fold(0u64, |acc, &is_null| {
+        (acc << 1) | if is_null { 1 } else { 0 }
+    });
+    let num_rows = batch.num_rows();
+    if group.len() <= 8 {
+        Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
+    } else if group.len() <= 16 {
+        Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
+    } else if group.len() <= 32 {
+        Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
+    } else {
+        Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
+    }
+}
+
 /// Evaluate a group by expression against a `RecordBatch`
 ///
 /// Arguments:
@@ -1174,23 +1262,24 @@ pub(crate) fn evaluate_group_by(
         })
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(group_by
+    group_by
         .groups
         .iter()
         .map(|group| {
-            group
-                .iter()
-                .enumerate()
-                .map(|(idx, is_null)| {
-                    if *is_null {
-                        Arc::clone(&null_exprs[idx])
-                    } else {
-                        Arc::clone(&exprs[idx])
-                    }
-                })
-                .collect()
+            let mut group_values = Vec::with_capacity(group_by.num_group_exprs());
+            group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
+                if *is_null {
+                    Arc::clone(&null_exprs[idx])
+                } else {
+                    Arc::clone(&exprs[idx])
+                }
+            }));
+            if !group_by.is_single() {
+                group_values.push(group_id_array(group, batch)?);
+            }
+            Ok(group_values)
         })
-        .collect())
+        .collect()
 }
 
 #[cfg(test)]
@@ -1348,21 +1437,21 @@ mod tests {
     ) -> Result<()> {
         let input_schema = input.schema();
 
-        let grouping_set = PhysicalGroupBy {
-            expr: vec![
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
                 (col("a", &input_schema)?, "a".to_string()),
                 (col("b", &input_schema)?, "b".to_string()),
             ],
-            null_expr: vec![
+            vec![
                 (lit(ScalarValue::UInt32(None)), "a".to_string()),
                 (lit(ScalarValue::Float64(None)), "b".to_string()),
             ],
-            groups: vec![
+            vec![
                 vec![false, true],  // (a, NULL)
                 vec![true, false],  // (NULL, b)
                 vec![false, false], // (a,b)
             ],
-        };
+        );
 
         let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
             .schema(Arc::clone(&input_schema))
@@ -1392,63 +1481,56 @@ mod tests {
             // In spill mode, we test with limited memory; if the memory usage exceeds it,
             // we trigger the early emit rule, which outputs the partial aggregate result.
             vec![
-                "+---+-----+-----------------+",
-                "| a | b   | COUNT(1)[count] |",
-                "+---+-----+-----------------+",
-                "|   | 1.0 | 1               |",
-                "|   | 1.0 | 1               |",
-                "|   | 2.0 | 1               |",
-                "|   | 2.0 | 1               |",
-                "|   | 3.0 | 1               |",
-                "|   | 3.0 | 1               |",
-                "|   | 4.0 | 1               |",
-                "|   | 4.0 | 1               |",
-                "| 2 |     | 1               |",
-                "| 2 |     | 1               |",
-                "| 2 | 1.0 | 1               |",
-                "| 2 | 1.0 | 1               |",
-                "| 3 |     | 1               |",
-                "| 3 |     | 2               |",
-                "| 3 | 2.0 | 2               |",
-                "| 3 | 3.0 | 1               |",
-                "| 4 |     | 1               |",
-                "| 4 |     | 2               |",
-                "| 4 | 3.0 | 1               |",
-                "| 4 | 4.0 | 2               |",
-                "+---+-----+-----------------+",
+                "+---+-----+---------------+-----------------+",
+                "| a | b   | __grouping_id | COUNT(1)[count] |",
+                "+---+-----+---------------+-----------------+",
+                "|   | 1.0 | 2             | 1               |",
+                "|   | 1.0 | 2             | 1               |",
+                "|   | 2.0 | 2             | 1               |",
+                "|   | 2.0 | 2             | 1               |",
+                "|   | 3.0 | 2             | 1               |",
+                "|   | 3.0 | 2             | 1               |",
+                "|   | 4.0 | 2             | 1               |",
+                "|   | 4.0 | 2             | 1               |",
+                "| 2 |     | 1             | 1               |",
+                "| 2 |     | 1             | 1               |",
+                "| 2 | 1.0 | 0             | 1               |",
+                "| 2 | 1.0 | 0             | 1               |",
+                "| 3 |     | 1             | 1               |",
+                "| 3 |     | 1             | 2               |",
+                "| 3 | 2.0 | 0             | 2               |",
+                "| 3 | 3.0 | 0             | 1               |",
+                "| 4 |     | 1             | 1               |",
+                "| 4 |     | 1             | 2               |",
+                "| 4 | 3.0 | 0             | 1               |",
+                "| 4 | 4.0 | 0             | 2               |",
+                "+---+-----+---------------+-----------------+",
             ]
         } else {
             vec![
-                "+---+-----+-----------------+",
-                "| a | b   | COUNT(1)[count] |",
-                "+---+-----+-----------------+",
-                "|   | 1.0 | 2               |",
-                "|   | 2.0 | 2               |",
-                "|   | 3.0 | 2               |",
-                "|   | 4.0 | 2               |",
-                "| 2 |     | 2               |",
-                "| 2 | 1.0 | 2               |",
-                "| 3 |     | 3               |",
-                "| 3 | 2.0 | 2               |",
-                "| 3 | 3.0 | 1               |",
-                "| 4 |     | 3               |",
-                "| 4 | 3.0 | 1               |",
-                "| 4 | 4.0 | 2               |",
-                "+---+-----+-----------------+",
+                "+---+-----+---------------+-----------------+",
+                "| a | b   | __grouping_id | COUNT(1)[count] |",
+                "+---+-----+---------------+-----------------+",
+                "|   | 1.0 | 2             | 2               |",
+                "|   | 2.0 | 2             | 2               |",
+                "|   | 3.0 | 2             | 2               |",
+                "|   | 4.0 | 2             | 2               |",
+                "| 2 |     | 1             | 2               |",
+                "| 2 | 1.0 | 0             | 2               |",
+                "| 3 |     | 1             | 3               |",
+                "| 3 | 2.0 | 0             | 2               |",
+                "| 3 | 3.0 | 0             | 1               |",
+                "| 4 |     | 1             | 3               |",
+                "| 4 | 3.0 | 0             | 1               |",
+                "| 4 | 4.0 | 0             | 2               |",
+                "+---+-----+---------------+-----------------+",
             ]
         };
         assert_batches_sorted_eq!(expected, &result);
 
-        let groups = partial_aggregate.group_expr().expr().to_vec();
-
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
 
-        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups
-            .iter()
-            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
-            .collect::<Result<_>>()?;
-
-        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+        let final_grouping_set = grouping_set.as_final();
 
         let task_ctx = if spill {
             new_spill_ctx(4, 3160)
@@ -1468,26 +1550,26 @@ mod tests {
         let result =
             common::collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
         let batch = concat_batches(&result[0].schema(), &result)?;
-        assert_eq!(batch.num_columns(), 3);
+        assert_eq!(batch.num_columns(), 4);
         assert_eq!(batch.num_rows(), 12);
 
         let expected = vec![
-            "+---+-----+----------+",
-            "| a | b   | COUNT(1) |",
-            "+---+-----+----------+",
-            "|   | 1.0 | 2        |",
-            "|   | 2.0 | 2        |",
-            "|   | 3.0 | 2        |",
-            "|   | 4.0 | 2        |",
-            "| 2 |     | 2        |",
-            "| 2 | 1.0 | 2        |",
-            "| 3 |     | 3        |",
-            "| 3 | 2.0 | 2        |",
-            "| 3 | 3.0 | 1        |",
-            "| 4 |     | 3        |",
-            "| 4 | 3.0 | 1        |",
-            "| 4 | 4.0 | 2        |",
-            "+---+-----+----------+",
+            "+---+-----+---------------+----------+",
+            "| a | b   | __grouping_id | COUNT(1) |",
+            "+---+-----+---------------+----------+",
+            "|   | 1.0 | 2             | 2        |",
+            "|   | 2.0 | 2             | 2        |",
+            "|   | 3.0 | 2             | 2        |",
+            "|   | 4.0 | 2             | 2        |",
+            "| 2 |     | 1             | 2        |",
+            "| 2 | 1.0 | 0             | 2        |",
+            "| 3 |     | 1             | 3        |",
+            "| 3 | 2.0 | 0             | 2        |",
+            "| 3 | 3.0 | 0             | 1        |",
+            "| 4 |     | 1             | 3        |",
+            "| 4 | 3.0 | 0             | 1        |",
+            "| 4 | 4.0 | 0             | 2        |",
+            "+---+-----+---------------+----------+",
         ];
 
         assert_batches_sorted_eq!(&expected, &result);
@@ -1503,11 +1585,11 @@ mod tests {
     async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
         let input_schema = input.schema();
 
-        let grouping_set = PhysicalGroupBy {
-            expr: vec![(col("a", &input_schema)?, "a".to_string())],
-            null_expr: vec![],
-            groups: vec![vec![false]],
-        };
+        let grouping_set = PhysicalGroupBy::new(
+            vec![(col("a", &input_schema)?, "a".to_string())],
+            vec![],
+            vec![vec![false]],
+        );
 
         let aggregates: Vec<AggregateFunctionExpr> =
             vec![
@@ -1563,13 +1645,7 @@ mod tests {
 
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
 
-        let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
-            .expr
-            .iter()
-            .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
-            .collect::<Result<_>>()?;
-
-        let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+        let final_grouping_set = grouping_set.as_final();
 
         let merged_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Final,
@@ -1825,11 +1901,11 @@ mod tests {
         let task_ctx = Arc::new(task_ctx);
 
         let groups_none = PhysicalGroupBy::default();
-        let groups_some = PhysicalGroupBy {
-            expr: vec![(col("a", &input_schema)?, "a".to_string())],
-            null_expr: vec![],
-            groups: vec![vec![false]],
-        };
+        let groups_some = PhysicalGroupBy::new(
+            vec![(col("a", &input_schema)?, "a".to_string())],
+            vec![],
+            vec![vec![false]],
+        );
 
         // something that allocates within the aggregator
         let aggregates_v0: Vec<AggregateFunctionExpr> =
@@ -2306,7 +2382,7 @@ mod tests {
         )?);
 
         let aggregate_exec = Arc::new(AggregateExec::try_new(
-            AggregateMode::Partial,
+            AggregateMode::Single,
             groups,
             aggregates.clone(),
             vec![None],
@@ -2318,13 +2394,13 @@ mod tests {
             collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
 
         let expected = [
-            "+-----+-----+-------+----------+",
-            "| a   | b   | const | 1[count] |",
-            "+-----+-----+-------+----------+",
-            "|     | 0.0 |       | 32768    |",
-            "| 0.0 |     |       | 32768    |",
-            "|     |     | 1     | 32768    |",
-            "+-----+-----+-------+----------+",
+            "+-----+-----+-------+---------------+-------+",
+            "| a   | b   | const | __grouping_id | 1     |",
+            "+-----+-----+-------+---------------+-------+",
+            "|     |     | 1     | 6             | 32768 |",
+            "|     | 0.0 |       | 5             | 32768 |",
+            "| 0.0 |     |       | 3             | 32768 |",
+            "+-----+-----+-------+---------------+-------+",
         ];
         assert_batches_sorted_eq!(expected, &output);
 
@@ -2638,30 +2714,30 @@ mod tests {
                     .build()?,
             ];
 
-        let grouping_set = PhysicalGroupBy {
-            expr: vec![
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
                 (col("a", &input_schema)?, "a".to_string()),
                 (col("b", &input_schema)?, "b".to_string()),
             ],
-            null_expr: vec![
+            vec![
                 (lit(ScalarValue::Float32(None)), "a".to_string()),
                 (lit(ScalarValue::Float32(None)), "b".to_string()),
             ],
-            groups: vec![
+            vec![
                 vec![false, true],  // (a, NULL)
                 vec![false, false], // (a,b)
             ],
-        };
+        );
         let aggr_schema = create_schema(
             &input_schema,
-            &grouping_set.expr,
+            &grouping_set,
             &aggr_expr,
-            grouping_set.exprs_nullable(),
             AggregateMode::Final,
         )?;
         let expected_schema = Schema::new(vec![
             Field::new("a", DataType::Float32, false),
             Field::new("b", DataType::Float32, true),
+            Field::new("__grouping_id", DataType::UInt8, false),
             Field::new("COUNT(a)", DataType::Int64, false),
         ]);
         assert_eq!(aggr_schema, expected_schema);
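
The `group_id_array` function added above materializes a constant grouping-id
column per grouping set and batch. A minimal standalone sketch of the same
idea, assuming only the arrow-array crate (`group_id_column` is a hypothetical
helper; for brevity it only covers the UInt8 case, while the patch widens the
type up to UInt64 for larger grouping sets):

    use std::sync::Arc;

    use arrow_array::{Array, ArrayRef, UInt8Array};

    // Build the per-batch grouping-id column for one grouping set. A `true`
    // entry means the column at that position is excluded (nulled out).
    fn group_id_column(group: &[bool], num_rows: usize) -> ArrayRef {
        let id = group
            .iter()
            .fold(0u8, |acc, &is_null| (acc << 1) | u8::from(is_null));
        Arc::new(UInt8Array::from(vec![id; num_rows]))
    }

    fn main() {
        // The grouping set (a, NULL) from the tests above gets id 0b01 on every row.
        let col = group_id_column(&[false, true], 3);
        assert_eq!(col.len(), 3);
    }
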
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 9e4968f112..5121e6cc3b 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -449,13 +449,13 @@ impl GroupedHashAggregateStream {
         let aggregate_arguments = aggregates::aggregate_expressions(
             &agg.aggr_expr,
             &agg.mode,
-            agg_group_by.expr.len(),
+            agg_group_by.num_group_exprs(),
         )?;
         // arguments for aggregating spilled data are the same as those for the final aggregation
         let merging_aggregate_arguments = aggregates::aggregate_expressions(
             &agg.aggr_expr,
             &AggregateMode::Final,
-            agg_group_by.expr.len(),
+            agg_group_by.num_group_exprs(),
         )?;
 
         let filter_expressions = match agg.mode {
@@ -473,7 +473,7 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<_>>()?;
 
-        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
         let spill_expr = group_schema
             .fields
             .into_iter()
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index 8b2530a749..e05df8ba77 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -15,6 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::cmp::Ordering;
+
 use datafusion_common::{
     internal_err,
     tree_node::{Transformed, TreeNode},
@@ -169,10 +171,17 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result<Option<&'a E
         if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
             // For a grouping set expr, we must operate on the expression list from the grouping set
             let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?;
-            Ok(grouping_expr
-                .into_iter()
-                .chain(agg.aggr_expr.iter())
-                .nth(index))
+            match index.cmp(&grouping_expr.len()) {
+                Ordering::Less => Ok(grouping_expr.into_iter().nth(index)),
+                Ordering::Equal => {
+                    internal_err!(
+                        "Tried to unproject column refereing to internal 
grouping id"
+                    )
+                }
+                Ordering::Greater => {
+                    Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1))
+                }
+            }
         } else {
             Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
         }
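
The index arithmetic above follows the aggregate's output layout with grouping
sets: [grouping exprs..., __grouping_id, aggregate exprs...]. A minimal
standalone sketch of that mapping (names are hypothetical, not part of the
patch):

    use std::cmp::Ordering;

    #[derive(Debug, PartialEq)]
    enum OutputColumn {
        GroupExpr(usize),
        InternalGroupingId,
        AggrExpr(usize),
    }

    // Classify output column `index` given `n` grouping expressions.
    fn classify(index: usize, n: usize) -> OutputColumn {
        match index.cmp(&n) {
            Ordering::Less => OutputColumn::GroupExpr(index),
            Ordering::Equal => OutputColumn::InternalGroupingId,
            // Skip the grouping-id slot when indexing into the aggregate exprs.
            Ordering::Greater => OutputColumn::AggrExpr(index - n - 1),
        }
    }

    fn main() {
        // With 2 grouping expressions: [g0, g1, __grouping_id, a0, ...]
        assert_eq!(classify(1, 2), OutputColumn::GroupExpr(1));
        assert_eq!(classify(2, 2), OutputColumn::InternalGroupingId);
        assert_eq!(classify(3, 2), OutputColumn::AggrExpr(0));
    }
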
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index a78ade81ee..250fa85cdd 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3520,6 +3520,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls
 ----
 1 5
 
+# grouping_sets with null values
+query II rowsort
+SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value)
+----
+1 1
+3 3
+4 4
+5 5
+NULL 1
+NULL NULL
+
+
 statement ok
 DROP TABLE integers_with_nulls;
 
@@ -4879,16 +4891,18 @@ query TT
 EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
 ----
 logical_plan
-01)Limit: skip=0, fetch=3
-02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
-03)----TableScan: aggregate_test_100 projection=[c2, c3]
+01)Projection: aggregate_test_100.c2, aggregate_test_100.c3
+02)--Limit: skip=0, fetch=3
+03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
+04)------TableScan: aggregate_test_100 projection=[c2, c3]
 physical_plan
-01)GlobalLimitExec: skip=0, fetch=3
-02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]
-03)----CoalescePartitionsExec
-04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
-05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
+01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3]
+02)--GlobalLimitExec: skip=0, fetch=3
+03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3]
+04)------CoalescePartitionsExec
+05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
 
 query II
 SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt
index f561fa9e9a..a80a0891e9 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -5152,8 +5152,6 @@ drop table test_case_expr
 statement ok
 drop table t;
 
-# TODO: Current grouping set result is not align with Postgres and DuckDB, we might want to change the result
-# See https://github.com/apache/datafusion/issues/12570
 # test multi group by for binary type with nulls
 statement ok
create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb);
@@ -5162,11 +5160,14 @@ query I?I
 select a, b, count(*) from t group by grouping sets ((a, b), (a), (b));
 ----
 1 0a 2
-2 NULL 2
-NULL 0b 4
+2 NULL 1
+NULL 0b 2
 1 NULL 2
-NULL NULL 3
+2 NULL 1
+NULL NULL 2
 NULL 0a 2
+NULL NULL 1
+NULL 0b 2
 
 statement ok
 drop table t;
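
The corrected expectations above make sense once the hidden grouping id is
taken into account: with GROUPING SETS ((a, b), (a), (b)), a row like `2 NULL`
can come either from the set (a, b) with a genuine NULL in `b`, or from the
set (a) where `b` is excluded, and the two no longer collapse into one group.
A minimal standalone sketch of the distinction (hypothetical tuple keys, not
patch code):

    fn main() {
        // Group keys are effectively (a, b, __grouping_id); for GROUPING SETS
        // ((a, b), (a), (b)) the ids are 0b00, 0b01 and 0b10 respectively.
        let from_set_a_b = (Some(2), None::<u32>, 0b00u8); // b is genuinely NULL
        let from_set_a = (Some(2), None::<u32>, 0b01u8);   // b excluded from the set
        assert_ne!(from_set_a_b, from_set_a); // distinct groups, hence two `2 NULL 1` rows
    }
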
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 3b7d0fd296..ce6d1825cd 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -294,8 +294,9 @@ async fn aggregate_grouping_sets() -> Result<()> {
 async fn aggregate_grouping_rollup() -> Result<()> {
     assert_expected_plan(
         "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
-        "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), 
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
-        \n  TableScan: data projection=[a, b, c, e]",
+        "Projection: data.a, data.c, data.e, avg(data.b)\
+        \n  Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), 
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
+        \n    TableScan: data projection=[a, b, c, e]",
         true
     ).await
 }

