This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8802f6314f feat: Implement grouping function using grouping id (#12704)
8802f6314f is described below
commit 8802f6314f94ee99e2235350c148cdeecaa0ab1b
Author: Emil Ejbyfeldt <[email protected]>
AuthorDate: Wed Oct 16 16:14:19 2024 +0200
feat: Implement grouping function using grouping id (#12704)
* Implement grouping function using grouping id
This patch adds an Analyzer rule to transform the grouping aggregation
function into a computation on top of the grouping id that is used
internally for grouping sets.
* PR comments
---
datafusion/functions-aggregate/src/grouping.rs | 2 +-
datafusion/optimizer/src/analyzer/mod.rs | 3 +
.../src/analyzer/resolve_grouping_function.rs | 247 +++++++++++++++++++++
datafusion/sqllogictest/test_files/explain.slt | 1 +
datafusion/sqllogictest/test_files/grouping.slt | 214 ++++++++++++++++++
5 files changed, 466 insertions(+), 1 deletion(-)
diff --git a/datafusion/functions-aggregate/src/grouping.rs
b/datafusion/functions-aggregate/src/grouping.rs
index 09e9b90b2e..558d3055f1 100644
--- a/datafusion/functions-aggregate/src/grouping.rs
+++ b/datafusion/functions-aggregate/src/grouping.rs
@@ -63,7 +63,7 @@ impl Grouping {
/// Create a new GROUPING aggregate function.
pub fn new() -> Self {
Self {
- signature: Signature::any(1, Volatility::Immutable),
+ signature: Signature::variadic_any(Volatility::Immutable),
}
}
}
diff --git a/datafusion/optimizer/src/analyzer/mod.rs
b/datafusion/optimizer/src/analyzer/mod.rs
index 4cd891664e..a9fd4900b2 100644
--- a/datafusion/optimizer/src/analyzer/mod.rs
+++ b/datafusion/optimizer/src/analyzer/mod.rs
@@ -34,6 +34,7 @@ use datafusion_expr::{Expr, LogicalPlan};
use crate::analyzer::count_wildcard_rule::CountWildcardRule;
use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule;
use crate::analyzer::inline_table_scan::InlineTableScan;
+use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction;
use crate::analyzer::subquery::check_subquery_expr;
use crate::analyzer::type_coercion::TypeCoercion;
use crate::utils::log_plan;
@@ -44,6 +45,7 @@ pub mod count_wildcard_rule;
pub mod expand_wildcard_rule;
pub mod function_rewrite;
pub mod inline_table_scan;
+pub mod resolve_grouping_function;
pub mod subquery;
pub mod type_coercion;
@@ -96,6 +98,7 @@ impl Analyzer {
// Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule].
Arc::new(ExpandWildcardRule::new()),
// [Expr::Wildcard] should be expanded before [TypeCoercion]
+ Arc::new(ResolveGroupingFunction::new()),
Arc::new(TypeCoercion::new()),
Arc::new(CountWildcardRule::new()),
];
diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
new file mode 100644
index 0000000000..16ebb8cd39
--- /dev/null
+++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Analyzer rule that replaces calls to the `grouping` aggregate function
+//! with expressions computed from the internal grouping id.
+
+use std::cmp::Ordering;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use crate::analyzer::AnalyzerRule;
+
+use arrow::datatypes::DataType;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::{
+ internal_datafusion_err, plan_err, Column, DFSchemaRef, Result,
ScalarValue,
+};
+use datafusion_expr::expr::{AggregateFunction, Alias};
+use datafusion_expr::logical_plan::LogicalPlan;
+use datafusion_expr::utils::grouping_set_to_exprlist;
+use datafusion_expr::{
+ bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast,
Aggregate,
+ Expr, Projection,
+};
+use itertools::Itertools;
+
+/// Analyzer rule that rewrites calls to the `grouping` aggregate function into
+/// expressions computed from the aggregate's internal grouping id.
+#[derive(Default, Debug)]
+pub struct ResolveGroupingFunction;
+
+impl ResolveGroupingFunction {
+    /// Create a new instance of the rule.
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl AnalyzerRule for ResolveGroupingFunction {
+    fn name(&self) -> &str {
+        "resolve_grouping_function"
+    }
+
+    /// Apply `analyze_internal` to every node of `plan`, visiting children
+    /// before their parents.
+    fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
+        plan.transform_up(analyze_internal).data()
+    }
+}
+
+/// Map each grouping expression to its bit index in the internal grouping id.
+///
+/// The last expression returned by `grouping_set_to_exprlist` gets index 0, the
+/// one before it gets index 1, and so on. For more details on how the grouping
+/// id bitmap works, see the documentation of [`Aggregate::INTERNAL_GROUPING_ID`].
+fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result<HashMap<&Expr, usize>> {
+    let exprs = grouping_set_to_exprlist(group_expr)?;
+    let mut index_by_expr = HashMap::with_capacity(exprs.len());
+    // Walk from the last expression to the first so that the last one gets bit 0.
+    for (bit, expr) in exprs.into_iter().rev().enumerate() {
+        index_by_expr.insert(expr, bit);
+    }
+    Ok(index_by_expr)
+}
+
+/// Rewrite an aggregate that calls `grouping` into an aggregate without those
+/// calls, topped by a projection that computes them from the grouping id.
+///
+/// The projection passes through the group by columns and the remaining
+/// aggregate columns. Each `grouping` call is replaced by an aliased expression
+/// carrying the qualifier and name of the column it used to produce.
+fn replace_grouping_exprs(
+    input: Arc<LogicalPlan>,
+    schema: DFSchemaRef,
+    group_expr: Vec<Expr>,
+    aggr_expr: Vec<Expr>,
+) -> Result<LogicalPlan> {
+    let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
+    // Expr -> index in the grouping id bitmap
+    let bitmap_index = group_expr_to_bitmap_index(&group_expr)?;
+
+    // This assumes the aggregate output columns are laid out as: the group by
+    // columns, then (for grouping sets only) the internal grouping id, then one
+    // column per aggregate expression.
+    let columns = schema.columns();
+    let grouping_id_len = usize::from(is_grouping_set);
+    let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len;
+
+    let mut projection_exprs: Vec<Expr> = columns[..group_expr_len]
+        .iter()
+        .cloned()
+        .map(Expr::Column)
+        .collect();
+    let mut new_agg_expr = Vec::new();
+
+    let aggr_columns = columns.into_iter().skip(group_expr_len + grouping_id_len);
+    for (expr, column) in aggr_expr.into_iter().zip(aggr_columns) {
+        let grouping_replacement = match &expr {
+            Expr::AggregateFunction(function) if is_grouping_function(&expr) => Some(
+                grouping_function_on_id(function, &bitmap_index, is_grouping_set)?,
+            ),
+            _ => None,
+        };
+        match grouping_replacement {
+            // grouping(...) is computed by the projection, so it is dropped from
+            // the aggregate
+            Some(grouping_expr) => projection_exprs.push(Expr::Alias(Alias::new(
+                grouping_expr,
+                column.relation,
+                column.name,
+            ))),
+            None => {
+                projection_exprs.push(Expr::Column(column));
+                new_agg_expr.push(expr);
+            }
+        }
+    }
+
+    // Rebuild the aggregate without the grouping calls and put the projection
+    // that computes them on top of it.
+    let aggregate = Aggregate::try_new(input, group_expr, new_agg_expr)?;
+    let projection = Projection::try_new(
+        projection_exprs,
+        Arc::new(LogicalPlan::Aggregate(aggregate)),
+    )?;
+    Ok(LogicalPlan::Projection(projection))
+}
+
+/// Rewrite a single plan node if it is an aggregate that calls `grouping`,
+/// after first rewriting the plans of any subqueries it contains.
+fn analyze_internal(plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> {
+    // Subqueries are separate plan trees, so rewrite them first.
+    let with_subqueries =
+        plan.map_subqueries(|subquery| subquery.transform_up(analyze_internal))?;
+
+    with_subqueries.transform_data(|node| match node {
+        LogicalPlan::Aggregate(aggregate)
+            if contains_grouping_function(&aggregate.aggr_expr) =>
+        {
+            let Aggregate {
+                input,
+                group_expr,
+                aggr_expr,
+                schema,
+                ..
+            } = aggregate;
+            replace_grouping_exprs(input, schema, group_expr, aggr_expr)
+                .map(Transformed::yes)
+        }
+        other => Ok(Transformed::no(other)),
+    })
+}
+
+/// Return `true` if `expr` is a call to the aggregate function named `grouping`.
+fn is_grouping_function(expr: &Expr) -> bool {
+    // TODO: matching on the function name is fragile; should `grouping` be a
+    // built-in expression instead?
+    if let Expr::AggregateFunction(AggregateFunction { func, .. }) = expr {
+        func.name() == "grouping"
+    } else {
+        false
+    }
+}
+
+/// Return `true` if at least one expression in `exprs` is a `grouping` call.
+fn contains_grouping_function(exprs: &[Expr]) -> bool {
+    for expr in exprs {
+        if is_grouping_function(expr) {
+            return true;
+        }
+    }
+    false
+}
+
+/// Validate that every argument of the `grouping` call is one of the grouping
+/// expressions in `group_by_expr` (grouping expression -> bit index).
+///
+/// # Errors
+///
+/// Returns a plan error naming the first argument that is not a grouping
+/// expression. The grouping expressions in the message are listed in
+/// descending bit-index order, so the text does not depend on `HashMap`
+/// iteration order.
+fn validate_args(
+    function: &AggregateFunction,
+    group_by_expr: &HashMap<&Expr, usize>,
+) -> Result<()> {
+    let expr_not_in_group_by = function
+        .args
+        .iter()
+        .find(|expr| !group_by_expr.contains_key(expr));
+    match expr_not_in_group_by {
+        Some(expr) => {
+            // Sort so the error message is deterministic across runs.
+            let grouping_columns = group_by_expr
+                .iter()
+                .sorted_by_key(|(_, idx)| std::cmp::Reverse(**idx))
+                .map(|(group_expr, _)| group_expr.to_string())
+                .join(", ");
+            plan_err!(
+                "Argument {} to grouping function is not in grouping columns {}",
+                expr,
+                grouping_columns
+            )
+        }
+        None => Ok(()),
+    }
+}
+
+/// Build the expression that evaluates `grouping(args...)` from the internal
+/// grouping id column.
+///
+/// `group_by_expr` maps every grouping expression to its bit index in the
+/// grouping id. Bit `i` of the result (least significant first) is the grouping
+/// id bit of the `i`-th argument counted from the last argument. For example,
+/// `grouping(a, b)` places `a`'s bit at position 1 and `b`'s bit at position 0.
+/// The result is cast to `Int32`. Without grouping sets, the result is the
+/// literal `0`.
+///
+/// # Errors
+///
+/// * A plan error if an argument is not one of the grouping expressions.
+/// * An internal error if no bit expression can be built, which can only
+///   happen for an empty argument list that does not take the fast path.
+fn grouping_function_on_id(
+    function: &AggregateFunction,
+    group_by_expr: &HashMap<&Expr, usize>,
+    is_grouping_set: bool,
+) -> Result<Expr> {
+    validate_args(function, group_by_expr)?;
+    let args = &function.args;
+
+    // Postgres allows the grouping function for GROUP BY without grouping
+    // sets; the result is then always 0.
+    if !is_grouping_set {
+        return Ok(Expr::Literal(ScalarValue::from(0i32)));
+    }
+
+    // The literal width is chosen from the number of grouping expressions.
+    // NOTE(review): presumably meant to match the type of the grouping id
+    // column; confirm these thresholds agree with `Aggregate`'s grouping id type.
+    let group_by_expr_count = group_by_expr.len();
+    let literal = |value: usize| {
+        if group_by_expr_count < 8 {
+            Expr::Literal(ScalarValue::from(value as u8))
+        } else if group_by_expr_count < 16 {
+            Expr::Literal(ScalarValue::from(value as u16))
+        } else if group_by_expr_count < 32 {
+            Expr::Literal(ScalarValue::from(value as u32))
+        } else {
+            Expr::Literal(ScalarValue::from(value as u64))
+        }
+    };
+
+    let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID));
+
+    // Fast path: the arguments are exactly the grouping expressions, in the
+    // same bit order, so the result is the grouping id itself.
+    if args.len() == group_by_expr_count
+        && args
+            .iter()
+            .rev()
+            .enumerate()
+            .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx))
+    {
+        return Ok(cast(grouping_id_column, DataType::Int32));
+    }
+
+    // General case: mask out each argument's bit in the grouping id, shift it
+    // to the argument's position in the result, and OR all bits together.
+    args.iter()
+        .rev()
+        .enumerate()
+        .map(|(arg_idx, expr)| {
+            group_by_expr.get(expr).map(|group_by_idx| {
+                let group_by_bit =
+                    bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx));
+                match group_by_idx.cmp(&arg_idx) {
+                    Ordering::Less => {
+                        bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx))
+                    }
+                    Ordering::Greater => {
+                        bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx))
+                    }
+                    Ordering::Equal => group_by_bit,
+                }
+            })
+        })
+        .collect::<Option<Vec<_>>>()
+        .and_then(|bit_exprs| {
+            bit_exprs
+                .into_iter()
+                .reduce(bitwise_or)
+                .map(|expr| cast(expr, DataType::Int32))
+        })
+        .ok_or_else(|| {
+            internal_datafusion_err!("grouping function should have at least one argument")
+        })
+}
diff --git a/datafusion/sqllogictest/test_files/explain.slt
b/datafusion/sqllogictest/test_files/explain.slt
index 6dc92bae82..b1962ffcc1 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -176,6 +176,7 @@ initial_logical_plan
02)--TableScan: simple_explain_test
logical_plan after inline_table_scan SAME TEXT AS ABOVE
logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE
+logical_plan after resolve_grouping_function SAME TEXT AS ABOVE
logical_plan after type_coercion SAME TEXT AS ABOVE
logical_plan after count_wildcard_rule SAME TEXT AS ABOVE
analyzed_logical_plan SAME TEXT AS ABOVE
diff --git a/datafusion/sqllogictest/test_files/grouping.slt
b/datafusion/sqllogictest/test_files/grouping.slt
new file mode 100644
index 0000000000..64d040d012
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/grouping.slt
@@ -0,0 +1,214 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+statement ok
+CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values
+('a','A',1), ('b','B',2)
+
+# grouping_with_grouping_sets
+query TTIIII
+select
+ c1,
+ c2,
+ grouping(c1) as g0,
+ grouping(c2) as g1,
+ grouping(c1, c2) as g2,
+ grouping(c2, c1) as g3
+from
+ test
+group by
+ grouping sets (
+ (c1, c2),
+ (c1),
+ (c2),
+ ()
+ )
+order by
+ c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_cube
+query TTIIII
+select
+ c1,
+ c2,
+ grouping(c1) as g0,
+ grouping(c2) as g1,
+ grouping(c1, c2) as g2,
+ grouping(c2, c1) as g3
+from
+ test
+group by
+ cube(c1, c2)
+order by
+ c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL A 1 0 2 1
+NULL B 1 0 2 1
+NULL NULL 1 1 3 3
+
+# grouping_with_rollup
+query TTIIII
+select
+ c1,
+ c2,
+ grouping(c1) as g0,
+ grouping(c2) as g1,
+ grouping(c1, c2) as g2,
+ grouping(c2, c1) as g3
+from
+ test
+group by
+ rollup(c1, c2)
+order by
+ c1, c2, g0, g1, g2, g3;
+----
+a A 0 0 0 0
+a NULL 0 1 1 2
+b B 0 0 0 0
+b NULL 0 1 1 2
+NULL NULL 1 1 3 3
+
+query TTIIIIIIII
+select
+ c1,
+ c2,
+ c3,
+ grouping(c1) as g0,
+ grouping(c2) as g1,
+ grouping(c1, c2) as g2,
+ grouping(c2, c1) as g3,
+ grouping(c1, c2, c3) as g4,
+ grouping(c2, c3, c1) as g5,
+ grouping(c3, c2, c1) as g6
+from
+ test
+group by
+ rollup(c1, c2, c3)
+order by
+ c1, c2, g0, g1, g2, g3, g4, g5, g6;
+----
+a A 1 0 0 0 0 0 0 0
+a A NULL 0 0 0 0 1 2 4
+a NULL NULL 0 1 1 2 3 6 6
+b B 2 0 0 0 0 0 0 0
+b B NULL 0 0 0 0 1 2 4
+b NULL NULL 0 1 1 2 3 6 6
+NULL NULL NULL 1 1 3 3 7 7 7
+
+# grouping_with_add
+query TTI
+select
+ c1,
+ c2,
+ grouping(c1)+grouping(c2) as g0
+from
+ test
+group by
+ rollup(c1, c2)
+order by
+ c1, c2, g0;
+----
+a A 0
+a NULL 1
+b B 0
+b NULL 1
+NULL NULL 2
+
+# grouping_with_window_function
+query TTIII
+select
+ c1,
+ c2,
+ count(c1) as cnt,
+ grouping(c1)+ grouping(c2) as g0,
+ rank() over (
+ partition by grouping(c1)+grouping(c2),
+ case when grouping(c2) = 0 then c1 end
+ order by
+ count(c1) desc
+ ) as rank_within_parent
+from
+ test
+group by
+ rollup(c1, c2)
+order by
+ c1,
+ c2,
+ cnt,
+ g0 desc,
+ rank_within_parent;
+----
+a A 1 0 1
+a NULL 1 1 1
+b B 1 0 1
+b NULL 1 1 1
+NULL NULL 2 2 1
+
+# grouping_with_non_columns
+query TIIIII
+select
+ c1,
+ c3 + 1 as c3_add_one,
+ grouping(c1) as g0,
+ grouping(c3 + 1) as g1,
+ grouping(c1, c3 + 1) as g2,
+ grouping(c3 + 1, c1) as g3
+from
+ test
+group by
+ grouping sets (
+ (c1, c3 + 1),
+ (c3 + 1),
+ (c1)
+ )
+order by
+ c1, c3_add_one, g0, g1, g2, g3;
+----
+a 2 0 0 0 0
+a NULL 0 1 1 2
+b 3 0 0 0 0
+b NULL 0 1 1 2
+NULL 2 1 0 2 1
+NULL 3 1 0 2 1
+
+# postgres allows grouping function for GROUP BY without GROUPING SETS/ROLLUP/CUBE
+query TI
+select c1, grouping(c1) from test group by c1 order by c1;
+----
+a 0
+b 0
+
+statement error c2.*not in grouping columns
+select c1, grouping(c2) from test group by c1;
+
+statement error c2.*not in grouping columns
+select c1, grouping(c1, c2) from test group by CUBE(c1);
+
+statement error zero arguments
+select c1, grouping() from test group by CUBE(c1);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]