This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 310d687fcbd6 [SPARK-55137][SQL] Refactor
`GroupingAnalyticsTransformer` and `Analyzer` code
310d687fcbd6 is described below
commit 310d687fcbd6bbe70644d491b9f8e8bca09b9222
Author: mihailoale-db <[email protected]>
AuthorDate: Fri Jan 23 20:27:03 2026 +0800
[SPARK-55137][SQL] Refactor `GroupingAnalyticsTransformer` and `Analyzer`
code
### What changes were proposed in this pull request?
In this PR I propose to refactor the `GroupingAnalyticsTransformer` and
`Analyzer` code.
### Why are the changes needed?
So that the refactored code can be reused in the single-pass resolver implementation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53921 from mihailoale-db/refactorgroupanalytics.
Authored-by: mihailoale-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 50 ++++------
.../analysis/GroupingAnalyticsTransformer.scala | 102 ++++++++++++---------
2 files changed, 75 insertions(+), 77 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6e899e958f15..74ee4622f9d9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -655,29 +655,6 @@ class Analyzer(
e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID])
}
- private def replaceGroupingFunc(
- expr: Expression,
- groupByExprs: Seq[Expression],
- gid: Expression): Expression = {
- expr transform {
- case e: GroupingID =>
- if (e.groupByExprs.isEmpty ||
- e.groupByExprs.map(_.canonicalized) ==
groupByExprs.map(_.canonicalized)) {
- Alias(gid, toPrettySQL(e))()
- } else {
- throw QueryCompilationErrors.groupingIDMismatchError(e,
groupByExprs)
- }
- case e @ Grouping(col: Expression) =>
- val idx = groupByExprs.indexWhere(_.semanticEquals(col))
- if (idx >= 0) {
- Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length
- 1 - idx)),
- Literal(1L)), ByteType), toPrettySQL(e))()
- } else {
- throw QueryCompilationErrors.groupingColInvalidError(col,
groupByExprs)
- }
- }
- }
-
/*
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
*/
@@ -710,14 +687,7 @@ class Analyzer(
private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
plan.collectFirst {
- case a: Aggregate =>
- // this Aggregate should have grouping id as the last grouping key.
- val gid = a.groupingExpressions.last
- if (!gid.isInstanceOf[AttributeReference]
- || gid.asInstanceOf[AttributeReference].name !=
VirtualColumn.groupingIdName) {
- throw
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
- }
- a.groupingExpressions.take(a.groupingExpressions.length - 1)
+ case a: Aggregate =>
GroupingAnalyticsTransformer.collectGroupingExpressions(a)
}.getOrElse {
throw
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
}
@@ -800,7 +770,13 @@ class Analyzer(
case f @ Filter(cond, child) if hasGroupingFunction(cond) &&
cond.resolved =>
val groupingExprs = findGroupingExprs(child)
// The unresolved grouping id will be resolved by ResolveReferences
- val newCond = replaceGroupingFunc(cond, groupingExprs,
VirtualColumn.groupingIdAttribute)
+ val newCond = GroupingAnalyticsTransformer.replaceGroupingFunction(
+ expression = cond,
+ groupByExpressions = groupingExprs,
+ gid = VirtualColumn.groupingIdAttribute,
+ newAlias = (child, name, qualifier) =>
+ Alias(child, name.get)(qualifier = qualifier)
+ )
f.copy(condition = newCond)
// We should make sure all [[SortOrder]]s have been resolved.
@@ -809,7 +785,15 @@ class Analyzer(
val groupingExprs = findGroupingExprs(child)
val gid = VirtualColumn.groupingIdAttribute
// The unresolved grouping id will be resolved by ResolveReferences
- val newOrder = order.map(replaceGroupingFunc(_, groupingExprs,
gid).asInstanceOf[SortOrder])
+ val newOrder = order.map { expression =>
+ GroupingAnalyticsTransformer.replaceGroupingFunction(
+ expression = expression,
+ groupByExpressions = groupingExprs,
+ gid = gid,
+ newAlias = (child, name, qualifier) =>
+ Alias(child, name.get)(qualifier = qualifier)
+ ).asInstanceOf[SortOrder]
+ }
s.copy(order = newOrder)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
index 2de2a1aaf3b1..ade9e1d2faf9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
@@ -20,11 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{
- Aggregate,
- Expand,
- LogicalPlan
-}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand,
LogicalPlan}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.ByteType
@@ -115,6 +111,63 @@ object GroupingAnalyticsTransformer extends SQLConfHelper
with AliasHelper {
aggregate
}
+ /**
+ * Replace [[GROUPING]] and [[GROUPING_ID]] functions with expressions that
extract bits from
+ * the grouping ID attribute to determine which grouping set is active.
+ */
+ def replaceGroupingFunction(
+ expression: Expression,
+ groupByExpressions: Seq[Expression],
+ gid: Expression,
+ newAlias: (Expression, Option[String], Seq[String]) => Alias):
Expression = {
+ val canonicalizedGroupByExpressions =
groupByExpressions.map(_.canonicalized)
+
+ expression transform {
+ case groupingId: GroupingID =>
+ if (groupingId.groupByExprs.isEmpty ||
+ groupingId.groupByExprs.map(_.canonicalized) ==
canonicalizedGroupByExpressions) {
+ newAlias(gid, Some(toPrettySQL(groupingId)), Seq.empty)
+ } else {
+ throw QueryCompilationErrors.groupingIDMismatchError(groupingId,
groupByExpressions)
+ }
+ case grouping @ Grouping(column: Expression) =>
+ val index = groupByExpressions.indexWhere(_.semanticEquals(column))
+ if (index >= 0) {
+ newAlias(
+ Cast(
+ BitwiseAnd(
+ ShiftRight(gid, Literal(groupByExpressions.length - 1 -
index)),
+ Literal(1L)
+ ),
+ ByteType
+ ).withTimeZone(conf.sessionLocalTimeZone),
+ Some(toPrettySQL(grouping)),
+ Seq.empty
+ )
+ } else {
+ throw QueryCompilationErrors.groupingColInvalidError(column,
groupByExpressions)
+ }
+ }
+ }
+
+ /**
+ * Collect the last grouping expression since the provided [[Aggregate]]
should have grouping id
+ * as the last grouping key.
+ */
+ def collectGroupingExpressions(aggregate: Aggregate): Seq[Expression] = {
+ val gid = aggregate.groupingExpressions.last
+ gid match {
+ case attributeReference: AttributeReference =>
+ if (attributeReference.name != VirtualColumn.groupingIdName) {
+ throw
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
+ }
+ case _ =>
+ throw
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
+ }
+
+ aggregate.groupingExpressions.take(aggregate.groupingExpressions.length -
1)
+ }
+
/**
* Create new aliases for all group by expressions to prevent null values
set by [[Expand]]
* from being used in aggregates instead of original values.
@@ -205,45 +258,6 @@ object GroupingAnalyticsTransformer extends SQLConfHelper
with AliasHelper {
}
}
- /**
- * Replace [[GROUPING]] and [[GROUPING_ID]] functions with expressions that
extract bits from
- * the grouping ID attribute to determine which grouping set is active.
- */
- private def replaceGroupingFunction(
- expression: Expression,
- groupByExpressions: Seq[Expression],
- gid: Expression,
- newAlias: (Expression, Option[String], Seq[String]) => Alias):
Expression = {
- val canonicalizedGroupByExpressions =
groupByExpressions.map(_.canonicalized)
-
- expression transform {
- case groupingId: GroupingID =>
- if (groupingId.groupByExprs.isEmpty ||
- groupingId.groupByExprs.map(_.canonicalized) ==
canonicalizedGroupByExpressions) {
- newAlias(gid, Some(toPrettySQL(groupingId)), Seq.empty)
- } else {
- throw QueryCompilationErrors.groupingIDMismatchError(groupingId,
groupByExpressions)
- }
- case grouping @ Grouping(column: Expression) =>
- val index = groupByExpressions.indexWhere(_.semanticEquals(column))
- if (index >= 0) {
- newAlias(
- Cast(
- BitwiseAnd(
- ShiftRight(gid, Literal(groupByExpressions.length - 1 -
index)),
- Literal(1L)
- ),
- ByteType
- ).withTimeZone(conf.sessionLocalTimeZone),
- Some(toPrettySQL(grouping)),
- Seq.empty
- )
- } else {
- throw QueryCompilationErrors.groupingColInvalidError(column,
groupByExpressions)
- }
- }
- }
-
/**
* Replace group by expressions with their corresponding expanded attributes
from the
* [[Expand]] operator output. Leaves aggregate expressions unchanged.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]