This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 310d687fcbd6 [SPARK-55137][SQL] Refactor `GroupingAnalyticsTransformer` and `Analyzer` code
310d687fcbd6 is described below

commit 310d687fcbd6bbe70644d491b9f8e8bca09b9222
Author: mihailoale-db <[email protected]>
AuthorDate: Fri Jan 23 20:27:03 2026 +0800

    [SPARK-55137][SQL] Refactor `GroupingAnalyticsTransformer` and `Analyzer` code
    
    ### What changes were proposed in this pull request?
    In this PR I propose to refactor the `GroupingAnalyticsTransformer` and `Analyzer` code.
    
    ### Why are the changes needed?
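    So that this logic can be reused by the single-pass resolver implementation of grouping analytics. `replaceGroupingFunction` becomes public, and the grouping-expression collection logic moves from `Analyzer` into `GroupingAnalyticsTransformer.collectGroupingExpressions`, so both resolvers can share them.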
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #53921 from mihailoale-db/refactorgroupanalytics.
    
    Authored-by: mihailoale-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  50 ++++------
 .../analysis/GroupingAnalyticsTransformer.scala    | 102 ++++++++++++---------
 2 files changed, 75 insertions(+), 77 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6e899e958f15..74ee4622f9d9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -655,29 +655,6 @@ class Analyzer(
       e.exists (g => g.isInstanceOf[Grouping] || g.isInstanceOf[GroupingID])
     }
 
-    private def replaceGroupingFunc(
-        expr: Expression,
-        groupByExprs: Seq[Expression],
-        gid: Expression): Expression = {
-      expr transform {
-        case e: GroupingID =>
-          if (e.groupByExprs.isEmpty ||
-              e.groupByExprs.map(_.canonicalized) == 
groupByExprs.map(_.canonicalized)) {
-            Alias(gid, toPrettySQL(e))()
-          } else {
-            throw QueryCompilationErrors.groupingIDMismatchError(e, 
groupByExprs)
-          }
-        case e @ Grouping(col: Expression) =>
-          val idx = groupByExprs.indexWhere(_.semanticEquals(col))
-          if (idx >= 0) {
-            Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length 
- 1 - idx)),
-              Literal(1L)), ByteType), toPrettySQL(e))()
-          } else {
-            throw QueryCompilationErrors.groupingColInvalidError(col, 
groupByExprs)
-          }
-      }
-    }
-
     /*
      * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
      */
@@ -710,14 +687,7 @@ class Analyzer(
 
     private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
       plan.collectFirst {
-        case a: Aggregate =>
-          // this Aggregate should have grouping id as the last grouping key.
-          val gid = a.groupingExpressions.last
-          if (!gid.isInstanceOf[AttributeReference]
-            || gid.asInstanceOf[AttributeReference].name != 
VirtualColumn.groupingIdName) {
-            throw 
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
-          }
-          a.groupingExpressions.take(a.groupingExpressions.length - 1)
+        case a: Aggregate => 
GroupingAnalyticsTransformer.collectGroupingExpressions(a)
       }.getOrElse {
         throw 
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
       }
@@ -800,7 +770,13 @@ class Analyzer(
       case f @ Filter(cond, child) if hasGroupingFunction(cond) && 
cond.resolved =>
         val groupingExprs = findGroupingExprs(child)
         // The unresolved grouping id will be resolved by ResolveReferences
-        val newCond = replaceGroupingFunc(cond, groupingExprs, 
VirtualColumn.groupingIdAttribute)
+        val newCond = GroupingAnalyticsTransformer.replaceGroupingFunction(
+          expression = cond,
+          groupByExpressions = groupingExprs,
+          gid = VirtualColumn.groupingIdAttribute,
+          newAlias = (child, name, qualifier) =>
+            Alias(child, name.get)(qualifier = qualifier)
+        )
         f.copy(condition = newCond)
 
       // We should make sure all [[SortOrder]]s have been resolved.
@@ -809,7 +785,15 @@ class Analyzer(
         val groupingExprs = findGroupingExprs(child)
         val gid = VirtualColumn.groupingIdAttribute
         // The unresolved grouping id will be resolved by ResolveReferences
-        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, 
gid).asInstanceOf[SortOrder])
+        val newOrder = order.map { expression =>
+          GroupingAnalyticsTransformer.replaceGroupingFunction(
+            expression = expression,
+            groupByExpressions = groupingExprs,
+            gid = gid,
+            newAlias = (child, name, qualifier) =>
+              Alias(child, name.get)(qualifier = qualifier)
+          ).asInstanceOf[SortOrder]
+        }
         s.copy(order = newOrder)
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
index 2de2a1aaf3b1..ade9e1d2faf9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/GroupingAnalyticsTransformer.scala
@@ -20,11 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{
-  Aggregate,
-  Expand,
-  LogicalPlan
-}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, 
LogicalPlan}
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.ByteType
@@ -115,6 +111,63 @@ object GroupingAnalyticsTransformer extends SQLConfHelper 
with AliasHelper {
     aggregate
   }
 
+  /**
+   * Replace [[GROUPING]] and [[GROUPING_ID]] functions with expressions that 
extract bits from
+   * the grouping ID attribute to determine which grouping set is active.
+   */
+  def replaceGroupingFunction(
+      expression: Expression,
+      groupByExpressions: Seq[Expression],
+      gid: Expression,
+      newAlias: (Expression, Option[String], Seq[String]) => Alias): 
Expression = {
+    val canonicalizedGroupByExpressions = 
groupByExpressions.map(_.canonicalized)
+
+    expression transform {
+      case groupingId: GroupingID =>
+        if (groupingId.groupByExprs.isEmpty ||
+          groupingId.groupByExprs.map(_.canonicalized) == 
canonicalizedGroupByExpressions) {
+          newAlias(gid, Some(toPrettySQL(groupingId)), Seq.empty)
+        } else {
+          throw QueryCompilationErrors.groupingIDMismatchError(groupingId, 
groupByExpressions)
+        }
+      case grouping @ Grouping(column: Expression) =>
+        val index = groupByExpressions.indexWhere(_.semanticEquals(column))
+        if (index >= 0) {
+          newAlias(
+            Cast(
+              BitwiseAnd(
+                ShiftRight(gid, Literal(groupByExpressions.length - 1 - 
index)),
+                Literal(1L)
+              ),
+              ByteType
+            ).withTimeZone(conf.sessionLocalTimeZone),
+            Some(toPrettySQL(grouping)),
+            Seq.empty
+          )
+        } else {
+          throw QueryCompilationErrors.groupingColInvalidError(column, 
groupByExpressions)
+        }
+    }
+  }
+
+  /**
+   * Collect the last grouping expression since the provided [[Aggregate]] 
should have grouping id
+   * as the last grouping key.
+   */
+  def collectGroupingExpressions(aggregate: Aggregate): Seq[Expression] = {
+    val gid = aggregate.groupingExpressions.last
+    gid match {
+      case attributeReference: AttributeReference =>
+        if (attributeReference.name != VirtualColumn.groupingIdName) {
+          throw 
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
+        }
+      case _ =>
+        throw 
QueryCompilationErrors.groupingMustWithGroupingSetsOrCubeOrRollupError()
+    }
+
+    aggregate.groupingExpressions.take(aggregate.groupingExpressions.length - 
1)
+  }
+
   /**
    * Create new aliases for all group by expressions to prevent null values 
set by [[Expand]]
    * from being used in aggregates instead of original values.
@@ -205,45 +258,6 @@ object GroupingAnalyticsTransformer extends SQLConfHelper 
with AliasHelper {
     }
   }
 
-  /**
-   * Replace [[GROUPING]] and [[GROUPING_ID]] functions with expressions that 
extract bits from
-   * the grouping ID attribute to determine which grouping set is active.
-   */
-  private def replaceGroupingFunction(
-      expression: Expression,
-      groupByExpressions: Seq[Expression],
-      gid: Expression,
-      newAlias: (Expression, Option[String], Seq[String]) => Alias): 
Expression = {
-    val canonicalizedGroupByExpressions = 
groupByExpressions.map(_.canonicalized)
-
-    expression transform {
-      case groupingId: GroupingID =>
-        if (groupingId.groupByExprs.isEmpty ||
-          groupingId.groupByExprs.map(_.canonicalized) == 
canonicalizedGroupByExpressions) {
-          newAlias(gid, Some(toPrettySQL(groupingId)), Seq.empty)
-        } else {
-          throw QueryCompilationErrors.groupingIDMismatchError(groupingId, 
groupByExpressions)
-        }
-      case grouping @ Grouping(column: Expression) =>
-        val index = groupByExpressions.indexWhere(_.semanticEquals(column))
-        if (index >= 0) {
-          newAlias(
-            Cast(
-              BitwiseAnd(
-                ShiftRight(gid, Literal(groupByExpressions.length - 1 - 
index)),
-                Literal(1L)
-              ),
-              ByteType
-            ).withTimeZone(conf.sessionLocalTimeZone),
-            Some(toPrettySQL(grouping)),
-            Seq.empty
-          )
-        } else {
-          throw QueryCompilationErrors.groupingColInvalidError(column, 
groupByExpressions)
-        }
-    }
-  }
-
   /**
    * Replace group by expressions with their corresponding expanded attributes 
from the
    * [[Expand]] operator output. Leaves aggregate expressions unchanged.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to