This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ef227f41cb fix: Correct results for grouping sets when columns contain
nulls (#12571)
ef227f41cb is described below
commit ef227f41cba69718e16557e164415d20b83a4bd6
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Mon Oct 7 18:24:46 2024 +0200
fix: Correct results for grouping sets when columns contain nulls (#12571)
* Fix grouping sets behavior when data contains nulls
* PR suggestion comment
* Update new test case
* Add grouping_id to the logical plan
* Add doc comment next to INTERNAL_GROUPING_ID
* Fix unparsing of Aggregate with grouping sets
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/dataframe/mod.rs | 17 +
datafusion/core/src/physical_planner.rs | 14 +-
datafusion/expr/src/logical_plan/plan.rs | 56 +++-
datafusion/expr/src/utils.rs | 12 +-
.../optimizer/src/single_distinct_to_groupby.rs | 6 +-
datafusion/physical-plan/src/aggregates/mod.rs | 370 +++++++++++++--------
.../physical-plan/src/aggregates/row_hash.rs | 6 +-
datafusion/sql/src/unparser/utils.rs | 17 +-
datafusion/sqllogictest/test_files/aggregate.slt | 32 +-
datafusion/sqllogictest/test_files/group_by.slt | 11 +-
.../tests/cases/roundtrip_logical_plan.rs | 5 +-
11 files changed, 359 insertions(+), 187 deletions(-)
diff --git a/datafusion/core/src/dataframe/mod.rs
b/datafusion/core/src/dataframe/mod.rs
index f5867881da..67e2a4780d 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -535,9 +535,26 @@ impl DataFrame {
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<DataFrame> {
+ let is_grouping_set = matches!(group_expr.as_slice(),
[Expr::GroupingSet(_)]);
+ let aggr_expr_len = aggr_expr.len();
let plan = LogicalPlanBuilder::from(self.plan)
.aggregate(group_expr, aggr_expr)?
.build()?;
+ let plan = if is_grouping_set {
+ let grouping_id_pos = plan.schema().fields().len() - 1 -
aggr_expr_len;
+ // For grouping sets we do a project to not expose the internal
grouping id
+ let exprs = plan
+ .schema()
+ .columns()
+ .into_iter()
+ .enumerate()
+ .filter(|(idx, _)| *idx != grouping_id_pos)
+ .map(|(_, column)| Expr::Column(column))
+ .collect::<Vec<_>>();
+ LogicalPlanBuilder::from(plan).project(exprs)?.build()?
+ } else {
+ plan
+ };
Ok(DataFrame {
session_state: self.session_state,
plan,
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index 78c70606bf..cf2a157b04 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);
- // update group column indices based on partial aggregate plan
evaluation
- let final_group: Vec<Arc<dyn PhysicalExpr>> =
- initial_aggr.output_group_expr();
-
let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
@@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
AggregateMode::Final
};
- let final_grouping_set = PhysicalGroupBy::new_single(
- final_group
- .iter()
- .enumerate()
- .map(|(i, expr)| (expr.clone(),
groups.expr()[i].1.clone()))
- .collect(),
- );
+ let final_grouping_set = initial_aggr.group_expr().as_final();
Arc::new(AggregateExec::try_new(
next_partition_mode,
@@ -2345,7 +2335,7 @@ mod tests {
.expect("hash aggregate");
assert_eq!(
"sum(aggregate_test_100.c3)",
- final_hash_agg.schema().field(2).name()
+ final_hash_agg.schema().field(3).name()
);
// we need access to the input to the partial aggregate so that other
projects can
// implement serde
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index 19e73140b7..0292274e57 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -21,7 +21,7 @@ use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
use std::hash::{Hash, Hasher};
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
use super::dml::CopyTo;
use super::DdlStatement;
@@ -2965,6 +2965,15 @@ impl Aggregate {
.into_iter()
.map(|(q, f)| (q,
f.as_ref().clone().with_nullable(true).into()))
.collect::<Vec<_>>();
+ qualified_fields.push((
+ None,
+ Field::new(
+ Self::INTERNAL_GROUPING_ID,
+ Self::grouping_id_type(qualified_fields.len()),
+ false,
+ )
+ .into(),
+ ));
}
qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(),
&input)?);
@@ -3016,9 +3025,19 @@ impl Aggregate {
})
}
+ fn is_grouping_set(&self) -> bool {
+ matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
+ }
+
/// Get the output expressions.
fn output_expressions(&self) -> Result<Vec<&Expr>> {
+ static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
+ if self.is_grouping_set() {
+ exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
+ Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
+ }));
+ }
exprs.extend(self.aggr_expr.iter());
debug_assert!(exprs.len() == self.schema.fields().len());
Ok(exprs)
@@ -3030,6 +3049,41 @@ impl Aggregate {
pub fn group_expr_len(&self) -> Result<usize> {
grouping_set_expr_count(&self.group_expr)
}
+
+ /// Returns the data type of the grouping id.
+ /// The grouping ID value is a bitmask where each set bit
+ /// indicates that the corresponding grouping expression is
+ /// null
+ pub fn grouping_id_type(group_exprs: usize) -> DataType {
+ if group_exprs <= 8 {
+ DataType::UInt8
+ } else if group_exprs <= 16 {
+ DataType::UInt16
+ } else if group_exprs <= 32 {
+ DataType::UInt32
+ } else {
+ DataType::UInt64
+ }
+ }
+
+ /// Internal column used when the aggregation is a grouping set.
+ ///
+ /// This column contains a bitmask where each bit represents a grouping
+ /// expression. The least significant bit corresponds to the rightmost
+ /// grouping expression. A bit value of 0 indicates that the corresponding
+ /// column is included in the grouping set, while a value of 1 means it is
excluded.
+ ///
+ /// For example, for the grouping expressions CUBE(a, b), the grouping ID
+ /// column will have the following values:
+ /// 0b00: Both `a` and `b` are included
+ /// 0b01: `b` is excluded
+ /// 0b10: `a` is excluded
+ /// 0b11: Both `a` and `b` are excluded
+ ///
+ /// This internal column is necessary because excluded columns are replaced
+ /// with `NULL` values. To handle these cases correctly, we must
distinguish
+ /// between an actual `NULL` value in a column and a column being excluded
from the set.
+ pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
}
// Manual implementation needed because of `schema` field. Comparison excludes
this field.
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index fa92759504..02b36d0fea 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut
HashSet<Column>) -> Result
/// Count the number of distinct exprs in a list of group by expressions. If
the
/// first element is a `GroupingSet` expression then it must be the only expr.
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
- grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+ if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
+ if group_expr.len() > 1 {
+ return plan_err!(
+ "Invalid group by expressions, GroupingSet must be the only
expression"
+ );
+ }
+ // Grouping sets have an additional internal column for the grouping id
+ Ok(grouping_set.distinct_expr().len() + 1)
+ } else {
+ grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
+ }
}
/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs
b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index 1c22c2a437..74251e5caa 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -355,7 +355,7 @@ mod tests {
.build()?;
// Should not be optimized
- let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a),
(test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N,
count(DISTINCT test.c):Int64]\
+ let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a),
(test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N,
__grouping_id:UInt8, count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
assert_optimized_plan_equal(plan, expected)
@@ -373,7 +373,7 @@ mod tests {
.build()?;
// Should not be optimized
- let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]],
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT
test.c):Int64]\
+ let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]],
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8,
count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
assert_optimized_plan_equal(plan, expected)
@@ -392,7 +392,7 @@ mod tests {
.build()?;
// Should not be optimized
- let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]],
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT
test.c):Int64]\
+ let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]],
aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8,
count(DISTINCT test.c):Int64]\
\n TableScan: test [a:UInt32, b:UInt32,
c:UInt32]";
assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs
b/datafusion/physical-plan/src/aggregates/mod.rs
index 9466ff6dd4..f9dd973c81 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -36,10 +36,11 @@ use crate::{
use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array};
use datafusion_common::stats::Precision;
use datafusion_common::{internal_err, not_impl_err, Result};
use datafusion_execution::TaskContext;
-use datafusion_expr::Accumulator;
+use datafusion_expr::{Accumulator, Aggregate};
use datafusion_physical_expr::{
equivalence::{collapse_lex_req, ProjectionMapping},
expressions::Column,
@@ -211,13 +212,99 @@ impl PhysicalGroupBy {
.collect()
}
+ /// The number of expressions in the output schema.
+ fn num_output_exprs(&self) -> usize {
+ let mut num_exprs = self.expr.len();
+ if !self.is_single() {
+ num_exprs += 1
+ }
+ num_exprs
+ }
+
/// Return grouping expressions as they occur in the output schema.
pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- self.expr
- .iter()
- .enumerate()
- .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
- .collect()
+ let num_output_exprs = self.num_output_exprs();
+ let mut output_exprs = Vec::with_capacity(num_output_exprs);
+ output_exprs.extend(
+ self.expr
+ .iter()
+ .enumerate()
+ .take(num_output_exprs)
+ .map(|(index, (_, name))| Arc::new(Column::new(name, index))
as _),
+ );
+ if !self.is_single() {
+ output_exprs.push(Arc::new(Column::new(
+ Aggregate::INTERNAL_GROUPING_ID,
+ self.expr.len(),
+ )) as _);
+ }
+ output_exprs
+ }
+
+ /// Returns the number of expressions used as grouping keys.
+ fn num_group_exprs(&self) -> usize {
+ if self.is_single() {
+ self.expr.len()
+ } else {
+ self.expr.len() + 1
+ }
+ }
+
+ /// Returns the fields that are used as the grouping keys.
+ fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+ let mut fields = Vec::with_capacity(self.num_group_exprs());
+ for ((expr, name), group_expr_nullable) in
+ self.expr.iter().zip(self.exprs_nullable().into_iter())
+ {
+ fields.push(
+ Field::new(
+ name,
+ expr.data_type(input_schema)?,
+ group_expr_nullable || expr.nullable(input_schema)?,
+ )
+ .with_metadata(
+ get_field_metadata(expr, input_schema).unwrap_or_default(),
+ ),
+ );
+ }
+ if !self.is_single() {
+ fields.push(Field::new(
+ Aggregate::INTERNAL_GROUPING_ID,
+ Aggregate::grouping_id_type(self.expr.len()),
+ false,
+ ));
+ }
+ Ok(fields)
+ }
+
+ /// Returns the output fields of the group by.
+ ///
+ /// This might be different from the `group_fields` that might contain
internal expressions that
+ /// should not be part of the output schema.
+ fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
+ let mut fields = self.group_fields(input_schema)?;
+ fields.truncate(self.num_output_exprs());
+ Ok(fields)
+ }
+
+ /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is
used for a partial
+ /// aggregation.
+ pub fn as_final(&self) -> PhysicalGroupBy {
+ let expr: Vec<_> =
+ self.output_exprs()
+ .into_iter()
+ .zip(
+ self.expr.iter().map(|t|
t.1.clone()).chain(std::iter::once(
+ Aggregate::INTERNAL_GROUPING_ID.to_owned(),
+ )),
+ )
+ .collect();
+ let num_exprs = expr.len();
+ Self {
+ expr,
+ null_expr: vec![],
+ groups: vec![vec![false; num_exprs]],
+ }
}
}
@@ -321,13 +408,7 @@ impl AggregateExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
- let schema = create_schema(
- &input.schema(),
- &group_by.expr,
- &aggr_expr,
- group_by.exprs_nullable(),
- mode,
- )?;
+ let schema = create_schema(&input.schema(), &group_by, &aggr_expr,
mode)?;
let schema = Arc::new(schema);
AggregateExec::try_new_with_schema(
@@ -789,25 +870,12 @@ impl ExecutionPlan for AggregateExec {
fn create_schema(
input_schema: &Schema,
- group_expr: &[(Arc<dyn PhysicalExpr>, String)],
+ group_by: &PhysicalGroupBy,
aggr_expr: &[AggregateFunctionExpr],
- group_expr_nullable: Vec<bool>,
mode: AggregateMode,
) -> Result<Schema> {
- let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
- for (index, (expr, name)) in group_expr.iter().enumerate() {
- fields.push(
- Field::new(
- name,
- expr.data_type(input_schema)?,
- // In cases where we have multiple grouping sets, we will use
NULL expressions in
- // order to align the grouping sets. So the field must be
nullable even if the underlying
- // schema field is not.
- group_expr_nullable[index] || expr.nullable(input_schema)?,
- )
- .with_metadata(get_field_metadata(expr,
input_schema).unwrap_or_default()),
- )
- }
+ let mut fields = Vec::with_capacity(group_by.num_output_exprs() +
aggr_expr.len());
+ fields.extend(group_by.output_fields(input_schema)?);
match mode {
AggregateMode::Partial => {
@@ -833,9 +901,8 @@ fn create_schema(
))
}
-fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
- let group_fields = schema.fields()[0..group_count].to_vec();
- Arc::new(Schema::new(group_fields))
+fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) ->
Result<SchemaRef> {
+ Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?)))
}
/// Determines the lexical ordering requirement for an aggregate expression.
@@ -1142,6 +1209,27 @@ fn evaluate_optional(
.collect()
}
+fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
+ if group.len() > 64 {
+ return not_impl_err!(
+ "Grouping sets with more than 64 columns are not supported"
+ );
+ }
+ let group_id = group.iter().fold(0u64, |acc, &is_null| {
+ (acc << 1) | if is_null { 1 } else { 0 }
+ });
+ let num_rows = batch.num_rows();
+ if group.len() <= 8 {
+ Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows])))
+ } else if group.len() <= 16 {
+ Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows])))
+ } else if group.len() <= 32 {
+ Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows])))
+ } else {
+ Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows])))
+ }
+}
+
/// Evaluate a group by expression against a `RecordBatch`
///
/// Arguments:
@@ -1174,23 +1262,24 @@ pub(crate) fn evaluate_group_by(
})
.collect::<Result<Vec<_>>>()?;
- Ok(group_by
+ group_by
.groups
.iter()
.map(|group| {
- group
- .iter()
- .enumerate()
- .map(|(idx, is_null)| {
- if *is_null {
- Arc::clone(&null_exprs[idx])
- } else {
- Arc::clone(&exprs[idx])
- }
- })
- .collect()
+ let mut group_values =
Vec::with_capacity(group_by.num_group_exprs());
+ group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
+ if *is_null {
+ Arc::clone(&null_exprs[idx])
+ } else {
+ Arc::clone(&exprs[idx])
+ }
+ }));
+ if !group_by.is_single() {
+ group_values.push(group_id_array(group, batch)?);
+ }
+ Ok(group_values)
})
- .collect())
+ .collect()
}
#[cfg(test)]
@@ -1348,21 +1437,21 @@ mod tests {
) -> Result<()> {
let input_schema = input.schema();
- let grouping_set = PhysicalGroupBy {
- expr: vec![
+ let grouping_set = PhysicalGroupBy::new(
+ vec![
(col("a", &input_schema)?, "a".to_string()),
(col("b", &input_schema)?, "b".to_string()),
],
- null_expr: vec![
+ vec![
(lit(ScalarValue::UInt32(None)), "a".to_string()),
(lit(ScalarValue::Float64(None)), "b".to_string()),
],
- groups: vec![
+ vec![
vec![false, true], // (a, NULL)
vec![true, false], // (NULL, b)
vec![false, false], // (a,b)
],
- };
+ );
let aggregates = vec![AggregateExprBuilder::new(count_udaf(),
vec![lit(1i8)])
.schema(Arc::clone(&input_schema))
@@ -1392,63 +1481,56 @@ mod tests {
// In spill mode, we test with the limited memory, if the mem
usage exceeds,
// we trigger the early emit rule, which turns out the partial
aggregate result.
vec![
- "+---+-----+-----------------+",
- "| a | b | COUNT(1)[count] |",
- "+---+-----+-----------------+",
- "| | 1.0 | 1 |",
- "| | 1.0 | 1 |",
- "| | 2.0 | 1 |",
- "| | 2.0 | 1 |",
- "| | 3.0 | 1 |",
- "| | 3.0 | 1 |",
- "| | 4.0 | 1 |",
- "| | 4.0 | 1 |",
- "| 2 | | 1 |",
- "| 2 | | 1 |",
- "| 2 | 1.0 | 1 |",
- "| 2 | 1.0 | 1 |",
- "| 3 | | 1 |",
- "| 3 | | 2 |",
- "| 3 | 2.0 | 2 |",
- "| 3 | 3.0 | 1 |",
- "| 4 | | 1 |",
- "| 4 | | 2 |",
- "| 4 | 3.0 | 1 |",
- "| 4 | 4.0 | 2 |",
- "+---+-----+-----------------+",
+ "+---+-----+---------------+-----------------+",
+ "| a | b | __grouping_id | COUNT(1)[count] |",
+ "+---+-----+---------------+-----------------+",
+ "| | 1.0 | 2 | 1 |",
+ "| | 1.0 | 2 | 1 |",
+ "| | 2.0 | 2 | 1 |",
+ "| | 2.0 | 2 | 1 |",
+ "| | 3.0 | 2 | 1 |",
+ "| | 3.0 | 2 | 1 |",
+ "| | 4.0 | 2 | 1 |",
+ "| | 4.0 | 2 | 1 |",
+ "| 2 | | 1 | 1 |",
+ "| 2 | | 1 | 1 |",
+ "| 2 | 1.0 | 0 | 1 |",
+ "| 2 | 1.0 | 0 | 1 |",
+ "| 3 | | 1 | 1 |",
+ "| 3 | | 1 | 2 |",
+ "| 3 | 2.0 | 0 | 2 |",
+ "| 3 | 3.0 | 0 | 1 |",
+ "| 4 | | 1 | 1 |",
+ "| 4 | | 1 | 2 |",
+ "| 4 | 3.0 | 0 | 1 |",
+ "| 4 | 4.0 | 0 | 2 |",
+ "+---+-----+---------------+-----------------+",
]
} else {
vec![
- "+---+-----+-----------------+",
- "| a | b | COUNT(1)[count] |",
- "+---+-----+-----------------+",
- "| | 1.0 | 2 |",
- "| | 2.0 | 2 |",
- "| | 3.0 | 2 |",
- "| | 4.0 | 2 |",
- "| 2 | | 2 |",
- "| 2 | 1.0 | 2 |",
- "| 3 | | 3 |",
- "| 3 | 2.0 | 2 |",
- "| 3 | 3.0 | 1 |",
- "| 4 | | 3 |",
- "| 4 | 3.0 | 1 |",
- "| 4 | 4.0 | 2 |",
- "+---+-----+-----------------+",
+ "+---+-----+---------------+-----------------+",
+ "| a | b | __grouping_id | COUNT(1)[count] |",
+ "+---+-----+---------------+-----------------+",
+ "| | 1.0 | 2 | 2 |",
+ "| | 2.0 | 2 | 2 |",
+ "| | 3.0 | 2 | 2 |",
+ "| | 4.0 | 2 | 2 |",
+ "| 2 | | 1 | 2 |",
+ "| 2 | 1.0 | 0 | 2 |",
+ "| 3 | | 1 | 3 |",
+ "| 3 | 2.0 | 0 | 2 |",
+ "| 3 | 3.0 | 0 | 1 |",
+ "| 4 | | 1 | 3 |",
+ "| 4 | 3.0 | 0 | 1 |",
+ "| 4 | 4.0 | 0 | 2 |",
+ "+---+-----+---------------+-----------------+",
]
};
assert_batches_sorted_eq!(expected, &result);
- let groups = partial_aggregate.group_expr().expr().to_vec();
-
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
- let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = groups
- .iter()
- .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
- .collect::<Result<_>>()?;
-
- let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+ let final_grouping_set = grouping_set.as_final();
let task_ctx = if spill {
new_spill_ctx(4, 3160)
@@ -1468,26 +1550,26 @@ mod tests {
let result =
common::collect(merged_aggregate.execute(0,
Arc::clone(&task_ctx))?).await?;
let batch = concat_batches(&result[0].schema(), &result)?;
- assert_eq!(batch.num_columns(), 3);
+ assert_eq!(batch.num_columns(), 4);
assert_eq!(batch.num_rows(), 12);
let expected = vec![
- "+---+-----+----------+",
- "| a | b | COUNT(1) |",
- "+---+-----+----------+",
- "| | 1.0 | 2 |",
- "| | 2.0 | 2 |",
- "| | 3.0 | 2 |",
- "| | 4.0 | 2 |",
- "| 2 | | 2 |",
- "| 2 | 1.0 | 2 |",
- "| 3 | | 3 |",
- "| 3 | 2.0 | 2 |",
- "| 3 | 3.0 | 1 |",
- "| 4 | | 3 |",
- "| 4 | 3.0 | 1 |",
- "| 4 | 4.0 | 2 |",
- "+---+-----+----------+",
+ "+---+-----+---------------+----------+",
+ "| a | b | __grouping_id | COUNT(1) |",
+ "+---+-----+---------------+----------+",
+ "| | 1.0 | 2 | 2 |",
+ "| | 2.0 | 2 | 2 |",
+ "| | 3.0 | 2 | 2 |",
+ "| | 4.0 | 2 | 2 |",
+ "| 2 | | 1 | 2 |",
+ "| 2 | 1.0 | 0 | 2 |",
+ "| 3 | | 1 | 3 |",
+ "| 3 | 2.0 | 0 | 2 |",
+ "| 3 | 3.0 | 0 | 1 |",
+ "| 4 | | 1 | 3 |",
+ "| 4 | 3.0 | 0 | 1 |",
+ "| 4 | 4.0 | 0 | 2 |",
+ "+---+-----+---------------+----------+",
];
assert_batches_sorted_eq!(&expected, &result);
@@ -1503,11 +1585,11 @@ mod tests {
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) ->
Result<()> {
let input_schema = input.schema();
- let grouping_set = PhysicalGroupBy {
- expr: vec![(col("a", &input_schema)?, "a".to_string())],
- null_expr: vec![],
- groups: vec![vec![false]],
- };
+ let grouping_set = PhysicalGroupBy::new(
+ vec![(col("a", &input_schema)?, "a".to_string())],
+ vec![],
+ vec![vec![false]],
+ );
let aggregates: Vec<AggregateFunctionExpr> =
vec![
@@ -1563,13 +1645,7 @@ mod tests {
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
- let final_group: Vec<(Arc<dyn PhysicalExpr>, String)> = grouping_set
- .expr
- .iter()
- .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone())))
- .collect::<Result<_>>()?;
-
- let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+ let final_grouping_set = grouping_set.as_final();
let merged_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
@@ -1825,11 +1901,11 @@ mod tests {
let task_ctx = Arc::new(task_ctx);
let groups_none = PhysicalGroupBy::default();
- let groups_some = PhysicalGroupBy {
- expr: vec![(col("a", &input_schema)?, "a".to_string())],
- null_expr: vec![],
- groups: vec![vec![false]],
- };
+ let groups_some = PhysicalGroupBy::new(
+ vec![(col("a", &input_schema)?, "a".to_string())],
+ vec![],
+ vec![vec![false]],
+ );
// something that allocates within the aggregator
let aggregates_v0: Vec<AggregateFunctionExpr> =
@@ -2306,7 +2382,7 @@ mod tests {
)?);
let aggregate_exec = Arc::new(AggregateExec::try_new(
- AggregateMode::Partial,
+ AggregateMode::Single,
groups,
aggregates.clone(),
vec![None],
@@ -2318,13 +2394,13 @@ mod tests {
collect(aggregate_exec.execute(0,
Arc::new(TaskContext::default()))?).await?;
let expected = [
- "+-----+-----+-------+----------+",
- "| a | b | const | 1[count] |",
- "+-----+-----+-------+----------+",
- "| | 0.0 | | 32768 |",
- "| 0.0 | | | 32768 |",
- "| | | 1 | 32768 |",
- "+-----+-----+-------+----------+",
+ "+-----+-----+-------+---------------+-------+",
+ "| a | b | const | __grouping_id | 1 |",
+ "+-----+-----+-------+---------------+-------+",
+ "| | | 1 | 6 | 32768 |",
+ "| | 0.0 | | 5 | 32768 |",
+ "| 0.0 | | | 3 | 32768 |",
+ "+-----+-----+-------+---------------+-------+",
];
assert_batches_sorted_eq!(expected, &output);
@@ -2638,30 +2714,30 @@ mod tests {
.build()?,
];
- let grouping_set = PhysicalGroupBy {
- expr: vec![
+ let grouping_set = PhysicalGroupBy::new(
+ vec![
(col("a", &input_schema)?, "a".to_string()),
(col("b", &input_schema)?, "b".to_string()),
],
- null_expr: vec![
+ vec![
(lit(ScalarValue::Float32(None)), "a".to_string()),
(lit(ScalarValue::Float32(None)), "b".to_string()),
],
- groups: vec![
+ vec![
vec![false, true], // (a, NULL)
vec![false, false], // (a,b)
],
- };
+ );
let aggr_schema = create_schema(
&input_schema,
- &grouping_set.expr,
+ &grouping_set,
&aggr_expr,
- grouping_set.exprs_nullable(),
AggregateMode::Final,
)?;
let expected_schema = Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, true),
+ Field::new("__grouping_id", DataType::UInt8, false),
Field::new("COUNT(a)", DataType::Int64, false),
]);
assert_eq!(aggr_schema, expected_schema);
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs
b/datafusion/physical-plan/src/aggregates/row_hash.rs
index 9e4968f112..5121e6cc3b 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -449,13 +449,13 @@ impl GroupedHashAggregateStream {
let aggregate_arguments = aggregates::aggregate_expressions(
&agg.aggr_expr,
&agg.mode,
- agg_group_by.expr.len(),
+ agg_group_by.num_group_exprs(),
)?;
// arguments for aggregating spilled data is the same as the one for
final aggregation
let merging_aggregate_arguments = aggregates::aggregate_expressions(
&agg.aggr_expr,
&AggregateMode::Final,
- agg_group_by.expr.len(),
+ agg_group_by.num_group_exprs(),
)?;
let filter_expressions = match agg.mode {
@@ -473,7 +473,7 @@ impl GroupedHashAggregateStream {
.map(create_group_accumulator)
.collect::<Result<_>>()?;
- let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+ let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
let spill_expr = group_schema
.fields
.into_iter()
diff --git a/datafusion/sql/src/unparser/utils.rs
b/datafusion/sql/src/unparser/utils.rs
index 8b2530a749..e05df8ba77 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.
+use std::cmp::Ordering;
+
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
@@ -169,10 +171,17 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column)
-> Result<Option<&'a E
if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) {
// For grouping set expr, we must operate by expression list from
the grouping set
let grouping_expr =
grouping_set_to_exprlist(agg.group_expr.as_slice())?;
- Ok(grouping_expr
- .into_iter()
- .chain(agg.aggr_expr.iter())
- .nth(index))
+ match index.cmp(&grouping_expr.len()) {
+ Ordering::Less => Ok(grouping_expr.into_iter().nth(index)),
+ Ordering::Equal => {
+ internal_err!(
+ "Tried to unproject column refereing to internal
grouping id"
+ )
+ }
+ Ordering::Greater => {
+ Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1))
+ }
+ }
} else {
Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index))
}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt
b/datafusion/sqllogictest/test_files/aggregate.slt
index a78ade81ee..250fa85cdd 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3520,6 +3520,18 @@ SELECT MIN(value), MAX(value) FROM integers_with_nulls
----
1 5
+# grouping_sets with null values
+query II rowsort
+SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value)
+----
+1 1
+3 3
+4 4
+5 5
+NULL 1
+NULL NULL
+
+
statement ok
DROP TABLE integers_with_nulls;
@@ -4879,16 +4891,18 @@ query TT
EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
----
logical_plan
-01)Limit: skip=0, fetch=3
-02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2,
aggregate_test_100.c3)]], aggr=[[]]
-03)----TableScan: aggregate_test_100 projection=[c2, c3]
+01)Projection: aggregate_test_100.c2, aggregate_test_100.c3
+02)--Limit: skip=0, fetch=3
+03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2,
aggregate_test_100.c3)]], aggr=[[]]
+04)------TableScan: aggregate_test_100 projection=[c2, c3]
physical_plan
-01)GlobalLimitExec: skip=0, fetch=3
-02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]
-03)----CoalescePartitionsExec
-04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as
c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
-05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-06)----------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2,
c3], has_header=true
+01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3]
+02)--GlobalLimitExec: skip=0, fetch=3
+03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2
as __grouping_id], aggr=[], lim=[3]
+04)------CoalescePartitionsExec
+05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0
as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4),
input_partitions=1
+07)------------CsvExec: file_groups={1 group:
[[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2,
c3], has_header=true
query II
SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
diff --git a/datafusion/sqllogictest/test_files/group_by.slt
b/datafusion/sqllogictest/test_files/group_by.slt
index f561fa9e9a..a80a0891e9 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -5152,8 +5152,6 @@ drop table test_case_expr
statement ok
drop table t;
-# TODO: Current grouping set result is not align with Postgres and DuckDB, we
might want to change the result
-# See https://github.com/apache/datafusion/issues/12570
# test multi group by for binary type with nulls
statement ok
create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null,
0xb), (null, 0xb);
@@ -5162,11 +5160,14 @@ query I?I
select a, b, count(*) from t group by grouping sets ((a, b), (a), (b));
----
1 0a 2
-2 NULL 2
-NULL 0b 4
+2 NULL 1
+NULL 0b 2
1 NULL 2
-NULL NULL 3
+2 NULL 1
+NULL NULL 2
NULL 0a 2
+NULL NULL 1
+NULL 0b 2
statement ok
drop table t;
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 3b7d0fd296..ce6d1825cd 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -294,8 +294,9 @@ async fn aggregate_grouping_sets() -> Result<()> {
async fn aggregate_grouping_rollup() -> Result<()> {
assert_expected_plan(
"SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)",
- "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e),
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
- \n TableScan: data projection=[a, b, c, e]",
+ "Projection: data.a, data.c, data.e, avg(data.b)\
+ \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e),
(data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\
+ \n TableScan: data projection=[a, b, c, e]",
true
).await
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]