This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8802f6314f feat: Implement grouping function using grouping id (#12704)
8802f6314f is described below

commit 8802f6314f94ee99e2235350c148cdeecaa0ab1b
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Wed Oct 16 16:14:19 2024 +0200

    feat: Implement grouping function using grouping id (#12704)
    
    * Implement grouping function using grouping id
    
    This patch adds an Analyzer rule to transform the grouping aggregation
    function into a computation on top of the grouping id that is used
    internally for grouping sets.
    
    * PR comments
---
 datafusion/functions-aggregate/src/grouping.rs     |   2 +-
 datafusion/optimizer/src/analyzer/mod.rs           |   3 +
 .../src/analyzer/resolve_grouping_function.rs      | 247 +++++++++++++++++++++
 datafusion/sqllogictest/test_files/explain.slt     |   1 +
 datafusion/sqllogictest/test_files/grouping.slt    | 214 ++++++++++++++++++
 5 files changed, 466 insertions(+), 1 deletion(-)

diff --git a/datafusion/functions-aggregate/src/grouping.rs 
b/datafusion/functions-aggregate/src/grouping.rs
index 09e9b90b2e..558d3055f1 100644
--- a/datafusion/functions-aggregate/src/grouping.rs
+++ b/datafusion/functions-aggregate/src/grouping.rs
@@ -63,7 +63,7 @@ impl Grouping {
     /// Create a new GROUPING aggregate function.
     pub fn new() -> Self {
         Self {
-            signature: Signature::any(1, Volatility::Immutable),
+            signature: Signature::variadic_any(Volatility::Immutable),
         }
     }
 }
diff --git a/datafusion/optimizer/src/analyzer/mod.rs 
b/datafusion/optimizer/src/analyzer/mod.rs
index 4cd891664e..a9fd4900b2 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan};
 use crate::analyzer::count_wildcard_rule::CountWildcardRule;
 use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
 use crate::analyzer::inline_table_scan::InlineTableScan;
+use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
 use crate::analyzer::subquery::check_subquery_expr;
 use crate::analyzer::type_coercion::TypeCoercion;
 use crate::utils::log_plan;
@@ -44,6 +45,7 @@ pub mod count_wildcard_rule;
 pub mod expand_wildcard_rule;
 pub mod function_rewrite;
 pub mod inline_table_scan;
+pub mod resolve_grouping_function;
 pub mod subquery;
 pub mod type_coercion;
 
@@ -96,6 +98,7 @@ impl Analyzer {
             // Every rule that will generate [Expr::Wildcard] should be placed 
in front of [ExpandWildcardRule].
             Arc::new(ExpandWildcardRule::new()),
             // [Expr::Wildcard] should be expanded before [TypeCoercion]
+            Arc::new(ResolveGroupingFunction::new()),
             Arc::new(TypeCoercion::new()),
             Arc::new(CountWildcardRule::new()),
         ];
diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs 
b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
new file mode 100644
index 0000000000..16ebb8cd39
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Analyzer rule that replaces calls to the `grouping` aggregate function
+//! with an expression computed from the internal grouping id.
+
+use std::cmp::Ordering;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::analyzer::AnalyzerRule;
+
+use arrow::datatypes::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{
+    internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, 
ScalarValue,
+};
+use datafusion_expr::expr::{AggregateFunction, Alias};
+use datafusion_expr::logical_plan::LogicalPlan;
+use datafusion_expr::utils::grouping_set_to_exprlist;
+use datafusion_expr::{
+    bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, 
Aggregate,
+    Expr, Projection,
+};
+use itertools::Itertools;
+
+/// Replaces grouping aggregation function with value derived from internal 
grouping id
+#[derive(Default, Debug)]
+pub struct ResolveGroupingFunction;
+
+impl ResolveGroupingFunction {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl AnalyzerRule for ResolveGroupingFunction {
+    fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> 
Result<LogicalPlan> {
+        plan.transform_up(analyze_internal).data()
+    }
+
+    fn name(&self) -> &str {
+        "resolve_grouping_function"
+    }
+}
+
+/// Create a map from grouping expr to index in the internal grouping id.
+///
+/// For more details on how the grouping id bitmap works, see the
+/// documentation for [`Aggregate::INTERNAL_GROUPING_ID`].
+fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, 
usize>> {
+    Ok(grouping_set_to_exprlist(group_expr)?
+        .into_iter()
+        .rev()
+        .enumerate()
+        .map(|(idx, v)| (v, idx))
+        .collect::<HashMap<_, _>>())
+}
+
+fn replace_grouping_exprs(
+    input: Arc<LogicalPlan>,
+    schema: DFSchemaRef,
+    group_expr: Vec<Expr>,
+    aggr_expr: Vec<Expr>,
+) -> Result<LogicalPlan> {
+    // Create HashMap from Expr to index in the grouping_id bitmap
+    let is_grouping_set = matches!(group_expr.as_slice(), 
[Expr::GroupingSet(_)]);
+    let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
+    let columns = schema.columns();
+    let mut new_agg_expr = Vec::new();
+    let mut projection_exprs = Vec::new();
+    let grouping_id_len = if is_grouping_set { 1 } else { 0 };
+    let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
+    projection_exprs.extend(
+        columns
+            .iter()
+            .take(group_expr_len)
+            .map(|column| Expr::Column(column.clone())),
+    );
+    for (expr, column) in aggr_expr
+        .into_iter()
+        .zip(columns.into_iter().skip(group_expr_len + grouping_id_len))
+    {
+        match expr {
+            Expr::AggregateFunction(ref function) if 
is_grouping_function(&expr) => {
+                let grouping_expr = grouping_function_on_id(
+                    function,
+                    &group_expr_to_bitmap_index,
+                    is_grouping_set,
+                )?;
+                projection_exprs.push(Expr::Alias(Alias::new(
+                    grouping_expr,
+                    column.relation,
+                    column.name,
+                )));
+            }
+            _ => {
+                projection_exprs.push(Expr::Column(column));
+                new_agg_expr.push(expr);
+            }
+        }
+    }
+    // Recreate aggregate without grouping functions
+    let new_aggregate =
+        LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, 
new_agg_expr)?);
+    // Create projection with grouping functions calculations
+    let projection = LogicalPlan::Projection(Projection::try_new(
+        projection_exprs,
+        new_aggregate.into(),
+    )?);
+    Ok(projection)
+}
+
+fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+    // rewrite any subqueries in the plan first
+    let transformed_plan =
+        plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?;
+
+    let transformed_plan = transformed_plan.transform_data(|plan| match plan {
+        LogicalPlan::Aggregate(Aggregate {
+            input,
+            group_expr,
+            aggr_expr,
+            schema,
+            ..
+        }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes(
+            replace_grouping_exprs(input, schema, group_expr, aggr_expr)?,
+        )),
+        _ => Ok(Transformed::no(plan)),
+    })?;
+
+    Ok(transformed_plan)
+}
+
+fn is_grouping_function(expr: &Expr) -> bool {
+    // TODO: Do something better than matching on the name here. Should
+    // grouping be a built-in expression?
+    matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) 
if func.name() == "grouping")
+}
+
+fn contains_grouping_function(exprs: &[Expr]) -> bool {
+    exprs.iter().any(is_grouping_function)
+}
+
+/// Validate that the arguments to the grouping function are in the group by 
clause.
+fn validate_args(
+    function: &AggregateFunction,
+    group_by_expr: &HashMap<&Expr, usize>,
+) -> Result<()> {
+    let expr_not_in_group_by = function
+        .args
+        .iter()
+        .find(|expr| !group_by_expr.contains_key(expr));
+    if let Some(expr) = expr_not_in_group_by {
+        plan_err!(
+            "Argument {} to grouping function is not in grouping columns {}",
+            expr,
+            group_by_expr.keys().map(|e| e.to_string()).join(", ")
+        )
+    } else {
+        Ok(())
+    }
+}
+
+fn grouping_function_on_id(
+    function: &AggregateFunction,
+    group_by_expr: &HashMap<&Expr, usize>,
+    is_grouping_set: bool,
+) -> Result<Expr> {
+    validate_args(function, group_by_expr)?;
+    let args = &function.args;
+
+    // Postgres allows grouping function for group by without grouping sets, 
the result is then
+    // always 0
+    if !is_grouping_set {
+        return Ok(Expr::Literal(ScalarValue::from(0i32)));
+    }
+
+    let group_by_expr_count = group_by_expr.len();
+    let literal = |value: usize| {
+        if group_by_expr_count < 8 {
+            Expr::Literal(ScalarValue::from(value as u8))
+        } else if group_by_expr_count < 16 {
+            Expr::Literal(ScalarValue::from(value as u16))
+        } else if group_by_expr_count < 32 {
+            Expr::Literal(ScalarValue::from(value as u32))
+        } else {
+            Expr::Literal(ScalarValue::from(value as u64))
+        }
+    };
+
+    let grouping_id_column = 
Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
+    // The grouping call is exactly our internal grouping id
+    if args.len() == group_by_expr_count
+        && args
+            .iter()
+            .rev()
+            .enumerate()
+            .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
+    {
+        return Ok(cast(grouping_id_column, DataType::Int32));
+    }
+
+    args.iter()
+        .rev()
+        .enumerate()
+        .map(|(arg_idx, expr)| {
+            group_by_expr.get(expr).map(|group_by_idx| {
+                let group_by_bit =
+                    bitwise_and(grouping_id_column.clone(), literal(1 << 
group_by_idx));
+                match group_by_idx.cmp(&arg_idx) {
+                    Ordering::Less => {
+                        bitwise_shift_left(group_by_bit, literal(arg_idx - 
group_by_idx))
+                    }
+                    Ordering::Greater => {
+                        bitwise_shift_right(group_by_bit, literal(group_by_idx 
- arg_idx))
+                    }
+                    Ordering::Equal => group_by_bit,
+                }
+            })
+        })
+        .collect::<Option<Vec<_>>>()
+        .and_then(|bit_exprs| {
+            bit_exprs
+                .into_iter()
+                .reduce(bitwise_or)
+                .map(|expr| cast(expr, DataType::Int32))
+        })
+        .ok_or_else(|| {
+            internal_datafusion_err!("Grouping sets should contains at least 
one element")
+        })
+}
diff --git a/datafusion/sqllogictest/test_files/explain.slt 
b/datafusion/sqllogictest/test_files/explain.slt
index 6dc92bae82..b1962ffcc1 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -176,6 +176,7 @@ initial_logical_plan
 02)--TableScan: simple_explain_test
 logical_plan after inline_table_scan SAME TEXT AS ABOVE
 logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE
+logical_plan after resolve_grouping_function SAME TEXT AS ABOVE
 logical_plan after type_coercion SAME TEXT AS ABOVE
 logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
 analyzed_logical_plan SAME TEXT AS ABOVE
diff --git a/datafusion/sqllogictest/test_files/grouping.slt 
b/datafusion/sqllogictest/test_files/grouping.slt
new file mode 100644
index 0000000000..64d040d012
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/grouping.slt
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+statement ok
+CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values
+('a','A',1), ('b','B',2)
+
+# grouping_with_grouping_sets
+query TTIIII
+select
+  c1,
+  c2,
+  grouping(c1) as g0,
+  grouping(c2) as g1,
+  grouping(c1, c2) as g2,
+  grouping(c2, c1) as g3
+from
+  test
+group by
+  grouping sets (
+    (c1, c2),
+    (c1),
+    (c2),
+    ()
+  )
+order by
+  c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_cube
+query TTIIII
+select
+  c1,
+  c2,
+  grouping(c1) as g0,
+  grouping(c2) as g1,
+  grouping(c1, c2) as g2,
+  grouping(c2, c1) as g3
+from
+  test
+group by
+  cube(c1, c2)
+order by
+  c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_rollup
+query TTIIII
+select
+  c1,
+  c2,
+  grouping(c1) as g0,
+  grouping(c2) as g1,
+  grouping(c1, c2) as g2,
+  grouping(c2, c1) as g3
+from
+  test
+group by
+  rollup(c1, c2)
+order by
+  c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL NULL 1 1 3 3
+
+query TTIIIIIIII
+select
+  c1,
+  c2,
+  c3,
+  grouping(c1) as g0,
+  grouping(c2) as g1,
+  grouping(c1, c2) as g2,
+  grouping(c2, c1) as g3,
+  grouping(c1, c2, c3) as g4,
+  grouping(c2, c3, c1) as g5,
+  grouping(c3, c2, c1) as g6
+from
+  test
+group by
+  rollup(c1, c2, c3)
+order by
+  c1, c2, g0, g1, g2, g3, g4, g5, g6;
+----
+a A 1 0 0 0 0 0 0 0
+a A NULL 0 0 0 0 1 2 4
+a NULL NULL 0 1 1 2 3 6 6
+b B 2 0 0 0 0 0 0 0
+b B NULL 0 0 0 0 1 2 4
+b NULL NULL 0 1 1 2 3 6 6
+NULL NULL NULL 1 1 3 3 7 7 7
+
+# grouping_with_add
+query TTI
+select
+  c1,
+  c2,
+  grouping(c1)+grouping(c2) as g0
+from
+  test
+group by
+  rollup(c1, c2)
+order by
+  c1, c2, g0;
+----
+a A 0
+a NULL 1
+b B 0
+b NULL 1
+NULL NULL 2
+
+# grouping_with_window_function
+query TTIII
+select
+  c1,
+  c2,
+  count(c1) as cnt,
+  grouping(c1)+ grouping(c2) as g0,
+  rank() over (
+    partition by grouping(c1)+grouping(c2),
+    case when grouping(c2) = 0 then c1 end
+    order by
+      count(c1) desc
+  ) as rank_within_parent
+from
+  test
+group by
+  rollup(c1, c2)
+order by
+  c1,
+  c2,
+  cnt,
+  g0 desc,
+  rank_within_parent;
+----
+a A 1 0 1
+a NULL 1 1 1
+b B 1 0 1
+b NULL 1 1 1
+NULL NULL 2 2 1
+
+# grouping_with_non_columns
+query TIIIII
+select
+  c1,
+  c3 + 1 as c3_add_one,
+  grouping(c1) as g0,
+  grouping(c3 + 1) as g1,
+  grouping(c1, c3 + 1) as g2,
+  grouping(c3 + 1, c1) as g3
+from
+  test
+group by
+  grouping sets (
+    (c1, c3 + 1),
+    (c3 + 1),
+    (c1)
+  )
+order by
+  c1, c3_add_one, g0, g1, g2, g3;
+----
+a 2 0 0 0 0
+a NULL 0 1 1 2
+b 3 0 0 0 0
+b NULL 0 1 1 2
+NULL 2 1 0 2 1
+NULL 3 1 0 2 1
+
+# postgres allows grouping function for GROUP BY without GROUPING 
SETS/ROLLUP/CUBE
+query TI
+select c1, grouping(c1) from test group by c1 order by c1;
+----
+a 0
+b 0
+
+statement error c2.*not in grouping columns
+select c1, grouping(c2) from test group by c1;
+
+statement error c2.*not in grouping columns
+select c1, grouping(c1, c2) from test group by CUBE(c1);
+
+statement error zero arguments
+select c1, grouping() from test group by CUBE(c1);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to