This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f48053fc5 Minor: Make schema of grouping set columns nullable (#8248)
8f48053fc5 is described below

commit 8f48053fc5f6fc3de27b69cd6f229558d8fc8990
Author: Markus Appel <[email protected]>
AuthorDate: Sat Nov 18 14:41:29 2023 +0100

    Minor: Make schema of grouping set columns nullable (#8248)
    
    * Make output schema of aggregation grouping sets nullable
    
    * Improve
    
    * Fix tests
---
 datafusion/expr/src/logical_plan/plan.rs           | 56 +++++++++++++++++++---
 .../optimizer/src/single_distinct_to_groupby.rs    |  6 +--
 datafusion/sqllogictest/test_files/aggregate.slt   |  7 +--
 3 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index b7537dc02e..a024824c7a 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2294,13 +2294,25 @@ impl Aggregate {
         aggr_expr: Vec<Expr>,
     ) -> Result<Self> {
         let group_expr = enumerate_grouping_sets(group_expr)?;
+
+        let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
+
         let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
-        let all_expr = grouping_expr.iter().chain(aggr_expr.iter());
 
-        let schema = DFSchema::new_with_metadata(
-            exprlist_to_fields(all_expr, &input)?,
-            input.schema().metadata().clone(),
-        )?;
+        let mut fields = exprlist_to_fields(grouping_expr.iter(), &input)?;
+
+        // Even columns that cannot be null will become nullable when used in a grouping set.
+        if is_grouping_set {
+            fields = fields
+                .into_iter()
+                .map(|field| field.with_nullable(true))
+                .collect::<Vec<_>>();
+        }
+
+        fields.extend(exprlist_to_fields(aggr_expr.iter(), &input)?);
+
+        let schema =
+            DFSchema::new_with_metadata(fields, input.schema().metadata().clone())?;
 
         Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema))
     }
@@ -2539,7 +2551,7 @@ pub struct Unnest {
 mod tests {
     use super::*;
     use crate::logical_plan::table_scan;
-    use crate::{col, exists, in_subquery, lit, placeholder};
+    use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::tree_node::TreeNodeVisitor;
     use datafusion_common::{not_impl_err, DFSchema, TableReference};
@@ -3006,4 +3018,36 @@ digraph {
         plan.replace_params_with_values(&[42i32.into()])
             .expect_err("unexpectedly succeeded to replace an invalid placeholder");
     }
+
+    #[test]
+    fn test_nullable_schema_after_grouping_set() {
+        let schema = Schema::new(vec![
+            Field::new("foo", DataType::Int32, false),
+            Field::new("bar", DataType::Int32, false),
+        ]);
+
+        let plan = table_scan(TableReference::none(), &schema, None)
+            .unwrap()
+            .aggregate(
+                vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![
+                    vec![col("foo")],
+                    vec![col("bar")],
+                ]))],
+                vec![count(lit(true))],
+            )
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let output_schema = plan.schema();
+
+        assert!(output_schema
+            .field_with_name(None, "foo")
+            .unwrap()
+            .is_nullable(),);
+        assert!(output_schema
+            .field_with_name(None, "bar")
+            .unwrap()
+            .is_nullable());
+    }
 }
diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs
index be76c069f0..ac18e596b7 100644
--- a/datafusion/optimizer/src/single_distinct_to_groupby.rs
+++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs
@@ -322,7 +322,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
                             \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(&plan, expected)
@@ -340,7 +340,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
                             \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(&plan, expected)
@@ -359,7 +359,7 @@ mod tests {
             .build()?;
 
         // Should not be optimized
-        let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\
+        let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\
                             \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
 
         assert_optimized_plan_equal(&plan, expected)
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 0a495dd2b0..faad6feb3f 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2672,9 +2672,10 @@ query TT
 EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
 ----
 logical_plan
-Limit: skip=0, fetch=3
---Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
-----TableScan: aggregate_test_100 projection=[c2, c3]
+Projection: aggregate_test_100.c2, aggregate_test_100.c3
+--Limit: skip=0, fetch=3
+----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
+------TableScan: aggregate_test_100 projection=[c2, c3]
 physical_plan
 GlobalLimitExec: skip=0, fetch=3
 --AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3]

Reply via email to