This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d6b733492d3e [SPARK-50033][SQL] Add a hint to logical.Aggregate() node
d6b733492d3e is described below

commit d6b733492d3e1643babbcd2b1860e35af7928906
Author: Andrey Gubichev <[email protected]>
AuthorDate: Fri Nov 1 09:59:58 2024 -0700

    [SPARK-50033][SQL] Add a hint to logical.Aggregate() node
    
    ### What changes were proposed in this pull request?
    
    Adds an optional abstract hint field (`hint: Option[AggregateHint]`) to the
    `Aggregate` case class, and adds an extension API that lets users plug in custom
    hint resolution rules.
    
    ### Why are the changes needed?
    
    Allows users to define and resolve their own aggregate hints, and to use those
    hints to guide query optimization and planning.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a new test that exercises a custom aggregate hint.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48523 from agubichev/agg_hint.
    
    Lead-authored-by: Andrey Gubichev <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 32 +++++----
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  2 +-
 .../catalyst/analysis/ColumnResolutionHelper.scala |  2 +-
 .../catalyst/analysis/DeduplicateRelations.scala   |  4 +-
 .../ResolveLateralColumnAliasReference.scala       |  2 +-
 .../analysis/UnsupportedOperationChecker.scala     |  4 +-
 .../catalyst/optimizer/DecorrelateInnerQuery.scala |  2 +-
 .../InsertMapSortInGroupingExpressions.scala       |  2 +-
 .../catalyst/optimizer/OptimizeOneRowPlan.scala    |  2 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 10 +--
 .../optimizer/PropagateEmptyRelation.scala         |  2 +-
 .../optimizer/RemoveRedundantAggregates.scala      |  4 +-
 .../spark/sql/catalyst/optimizer/expressions.scala |  2 +-
 .../spark/sql/catalyst/optimizer/joins.scala       |  8 +--
 .../spark/sql/catalyst/optimizer/subquery.scala    |  4 +-
 .../spark/sql/catalyst/planning/patterns.scala     |  2 +-
 .../plans/logical/basicLogicalOperators.scala      |  3 +-
 .../spark/sql/catalyst/plans/logical/hints.scala   |  2 +
 .../OptimizerStructuralIntegrityCheckerSuite.scala |  2 +-
 .../optimizer/RewriteDistinctAggregatesSuite.scala |  6 +-
 .../apache/spark/sql/SparkSessionExtensions.scala  | 18 +++++
 .../sql/execution/OptimizeMetadataOnlyQuery.scala  |  2 +-
 .../sql/internal/BaseSessionStateBuilder.scala     |  9 +++
 .../org/apache/spark/sql/CTEInlineSuite.scala      | 10 +--
 .../spark/sql/InjectRuntimeFilterSuite.scala       |  2 +-
 .../apache/spark/sql/LateralColumnAliasSuite.scala |  2 +-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 81 ++++++++++++++++++++--
 27 files changed, 164 insertions(+), 57 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6b64f493f405..b94bd31eb3fa 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -250,6 +250,11 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
    */
   val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
 
+  /**
+   * Override to provide additional rules for the "Hints" resolution batch.
+   */
+  val hintResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
   /**
    * Override to provide rules to do post-hoc resolution. Note that these 
rules will be executed
    * in an individual batch. This batch is to run right after the normal 
resolution batch and
@@ -278,8 +283,9 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
     Batch("Disable Hints", Once,
       new ResolveHints.DisableHints),
     Batch("Hints", fixedPoint,
-      ResolveHints.ResolveJoinStrategyHints,
-      ResolveHints.ResolveCoalesceHints),
+      Seq(ResolveHints.ResolveJoinStrategyHints,
+        ResolveHints.ResolveCoalesceHints) ++
+        hintResolutionRules: _*),
     Batch("Simple Sanity Check", Once,
       LookupFunctions),
     Batch("Keep Legacy Outputs", Once,
@@ -474,7 +480,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
 
     def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning(
       _.containsPattern(UNRESOLVED_ALIAS), ruleId) {
-      case Aggregate(groups, aggs, child) if child.resolved && 
hasUnresolvedAlias(aggs) =>
+      case Aggregate(groups, aggs, child, _) if child.resolved && 
hasUnresolvedAlias(aggs) =>
         Aggregate(groups, assignAliases(aggs), child)
 
       case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
@@ -689,7 +695,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
     def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsDownWithPruning(
       _.containsPattern(GROUPING_ANALYTICS), ruleId) {
       case h @ UnresolvedHaving(_, agg @ Aggregate(
-        GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, _))
+        GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, _, _))
         if agg.childrenResolved && aggExprs.forall(_.resolved) =>
         tryResolveHavingCondition(h, agg, selectedGroupByExprs, groupByExprs)
 
@@ -699,7 +705,7 @@ class Analyzer(override val catalogManager: CatalogManager) 
extends RuleExecutor
       case a if !a.childrenResolved => a
 
       // Ensure group by expressions and aggregate expressions have been 
resolved.
-      case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), 
aggExprs, child)
+      case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), 
aggExprs, child, _)
         if aggExprs.forall(_.resolved) =>
         constructAggregate(selectedGroupByExprs, groupByExprs, aggExprs, child)
 
@@ -1841,7 +1847,7 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
 
       // Replace the index with the corresponding expression in 
aggregateExpressions. The index is
       // a 1-base position of aggregateExpressions, which is output columns 
(select expression)
-      case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+      case Aggregate(groups, aggs, child, _) if aggs.forall(_.resolved) &&
         groups.exists(containUnresolvedOrdinal) =>
         val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
         Aggregate(newGroups, aggs, child)
@@ -2827,15 +2833,15 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
         val nestedGenerator = projectList.find(hasNestedGenerator).get
         throw 
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
 
-      case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
+      case Aggregate(_, aggList, _, _) if aggList.exists(hasNestedGenerator) =>
         val nestedGenerator = aggList.find(hasNestedGenerator).get
         throw 
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
 
-      case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
+      case Aggregate(_, aggList, _, _) if aggList.count(hasGenerator) > 1 =>
         val generators = aggList.filter(hasGenerator).map(trimAlias)
         throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
 
-      case Aggregate(groupList, aggList, child) if 
canRewriteGenerator(aggList) &&
+      case Aggregate(groupList, aggList, child, _) if 
canRewriteGenerator(aggList) &&
           aggList.exists(hasGenerator) =>
         // If generator in the aggregate list was visited, set the boolean 
flag true.
         var generatorVisited = false
@@ -3201,7 +3207,7 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
 
       // Aggregate with Having clause. This rule works with an unresolved 
Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, 
aggregateExprs, child))
+      case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, 
aggregateExprs, child, _))
         if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
@@ -3226,7 +3232,7 @@ class Analyzer(override val catalogManager: 
CatalogManager) extends RuleExecutor
 
       // Aggregate without Having clause.
       // Make sure the lateral column aliases are properly handled first.
-      case a @ Aggregate(groupingExprs, aggregateExprs, child)
+      case a @ Aggregate(groupingExprs, aggregateExprs, child, _)
         if hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) &&
           
!aggregateExprs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
@@ -3885,9 +3891,9 @@ object CleanupAliases extends Rule[LogicalPlan] with 
AliasHelper {
       val cleanedProjectList = projectList.map(trimNonTopLevelAliases)
       Project(cleanedProjectList, child)
 
-    case Aggregate(grouping, aggs, child) =>
+    case Aggregate(grouping, aggs, child, hint) =>
       val cleanedAggs = aggs.map(trimNonTopLevelAliases)
-      Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+      Aggregate(grouping.map(trimAliases), cleanedAggs, child, hint)
 
     case Window(windowExprs, partitionSpec, orderSpec, child) =>
       val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index efb63ea181a8..16899b656f30 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -757,7 +757,7 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
           case p @ Project(projectList, _) =>
             checkForUnspecifiedWindow(projectList)
 
-          case agg@Aggregate(_, aggregateExpressions, _) if
+          case agg@Aggregate(_, aggregateExpressions, _, _) if
             PlanHelper.specialExpressionsInUnsupportedOperator(agg).isEmpty =>
             checkForUnspecifiedWindow(aggregateExpressions)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 222e2c461999..d9c723aecbe8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -72,7 +72,7 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
               newProject.copyTagsFrom(p)
               (newExprs, newProject)
 
-            case a @ Aggregate(groupExprs, aggExprs, child) =>
+            case a @ Aggregate(groupExprs, aggExprs, child, _) =>
               if (missingAttrs.forall(attr => 
groupExprs.exists(_.semanticEquals(attr)))) {
                 // All the missing attributes are grouping expressions, valid 
case.
                 (newExprs,
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 8181078c519f..ca5a6eee9bc9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -361,14 +361,14 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         if findAliases(projectList).size == projectList.size =>
         Nil
 
-      case oldVersion @ Aggregate(_, aggregateExpressions, _)
+      case oldVersion @ Aggregate(_, aggregateExpressions, _, _)
           if 
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(aggregateExpressions = 
newAliases(aggregateExpressions))
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
       // We don't search the child plan recursively for the same reason as the 
above Project.
-      case _ @ Aggregate(_, aggregateExpressions, _)
+      case _ @ Aggregate(_, aggregateExpressions, _, _)
         if findAliases(aggregateExpressions).size == aggregateExpressions.size 
=>
         Nil
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
index c249a3506f2d..da8065eab606 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
@@ -196,7 +196,7 @@ object ResolveLateralColumnAliasReference extends 
Rule[LogicalPlan] {
         if ruleApplicableOnOperator(aggOriginal, 
aggOriginal.aggregateExpressions)
           && aggOriginal.aggregateExpressions.exists(
             _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
-        val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) =
+        val agg @ Aggregate(groupingExpressions, aggregateExpressions, _, _) =
           aggOriginal.mapChildren(apply0)
 
         // Check if current Aggregate is eligible to lift up with Project: the 
aggregate
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index c8dc43209afd..69639b69290c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -244,7 +244,7 @@ object UnsupportedOperationChecker extends Logging {
      * data.
      */
     def containsCompleteData(subplan: LogicalPlan): Boolean = {
-      val aggs = subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming 
=> a }
+      val aggs = subplan.collect { case a@Aggregate(_, _, _, _) if 
a.isStreaming => a }
       // Either the subplan has no streaming source, or it has aggregation 
with Complete mode
       !subplan.isStreaming || (aggs.nonEmpty && outputMode == 
InternalOutputModes.Complete)
     }
@@ -264,7 +264,7 @@ object UnsupportedOperationChecker extends Logging {
       // Operations that cannot exists anywhere in a streaming plan
       subPlan match {
 
-        case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+        case Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
           val distinctAggExprs = aggregateExpressions.flatMap { expr =>
             expr.collect { case ae: AggregateExpression if ae.isDistinct => ae 
}
           }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index 6c0d7189862d..9758f37efc2d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -773,7 +773,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
               orderSpec = newOrderSpec, newChild)
             (newWindow, joinCond, outerReferenceMap)
 
-          case a @ Aggregate(groupingExpressions, aggregateExpressions, child) 
=>
+          case a @ Aggregate(groupingExpressions, aggregateExpressions, child, 
_) =>
             val outerReferences = collectOuterReferences(a.expressions)
             val newOuterReferences = parentOuterReferences ++ outerReferences
             val (newChild, joinCond, outerReferenceMap) =
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
index 15ef025afd43..b6ced6c49a36 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
@@ -53,7 +53,7 @@ object InsertMapSortInGroupingExpressions extends 
Rule[LogicalPlan] {
     }
 
     plan transformUpWithNewOutput {
-      case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
+      case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
           if agg.groupingExpressions.exists(shouldAddMapSort) =>
         val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
         val newGroupingKeys = groupingExprs.map { expr =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
index 61c08eb8f8b6..8e066d1cd634 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
@@ -50,7 +50,7 @@ object OptimizeOneRowPlan extends Rule[LogicalPlan] {
         isChildEligible(child, enableForStreaming) => child
       case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L) 
&&
         isChildEligible(child, enableForStreaming) => child
-      case agg @ Aggregate(_, _, child) if agg.groupOnly && 
child.maxRows.exists(_ <= 1L) &&
+      case agg @ Aggregate(_, _, child, _) if agg.groupOnly && 
child.maxRows.exists(_ <= 1L) &&
         isChildEligible(child, enableForStreaming) =>
         Project(agg.aggregateExpressions, child)
       case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) &&
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dcfb64ae51fb..7074f8c4f089 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       // optimize it again, to save optimization time and avoid breaking 
broadcast/subquery reuse.
       case d: DynamicPruningSubquery => d
       case s @ ScalarSubquery(
-        PhysicalOperation(projections, predicates, a @ Aggregate(group, _, 
child)),
+        PhysicalOperation(projections, predicates, a @ Aggregate(group, _, 
child, _)),
         _, _, _, _, mayHaveCountBug, _)
         if 
conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG)
 &&
           mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
@@ -960,7 +960,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       d.copy(child = prunedChild(child, d.references))
 
     // Prunes the unused columns from child of 
Aggregate/Expand/Generate/ScriptTransformation
-    case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) 
=>
+    case a @ Aggregate(_, _, child, _) if 
!child.outputSet.subsetOf(a.references) =>
       a.copy(child = prunedChild(child, a.references))
     case f @ FlatMapGroupsInPandas(_, _, _, child) if 
!child.outputSet.subsetOf(f.references) =>
       f.copy(child = prunedChild(child, f.references))
@@ -1653,7 +1653,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
     case j @ Join(originLeft, originRight, _, cond, _) if 
cond.forall(_.deterministic) =>
       j.copy(left = recursiveRemoveSort(originLeft, true),
         right = recursiveRemoveSort(originRight, true))
-    case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
+    case g @ Aggregate(_, aggs, originChild, _) if isOrderIrrelevantAggs(aggs) 
=>
       g.copy(child = recursiveRemoveSort(originChild, true))
   }
 
@@ -2482,7 +2482,7 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
 object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
+    case a @ Aggregate(grouping, _, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
       if (newGrouping.nonEmpty) {
         a.copy(groupingExpressions = newGrouping)
@@ -2538,7 +2538,7 @@ object GenerateOptimization extends Rule[LogicalPlan] {
 object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsPattern(AGGREGATE), ruleId) {
-    case a @ Aggregate(grouping, _, _) if grouping.size > 1 =>
+    case a @ Aggregate(grouping, _, _, _) if grouping.size > 1 =>
       val newGrouping = ExpressionSet(grouping).toSeq
       if (newGrouping.size == grouping.size) {
         a
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index d23d43acc217..86316494f6ff 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -169,7 +169,7 @@ abstract class PropagateEmptyRelationBase extends 
Rule[LogicalPlan] with CastSup
       // Aggregation on empty LocalRelation generated from a streaming source 
is not eliminated
       // as stateful streaming aggregation need to perform other state 
management operations other
       // than just processing the input data.
-      case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
+      case Aggregate(ge, _, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
       // Generators like Hive-style UDTF may return their records within 
`close`.
       case Generate(_: Explode, _, _, _, _, _) => empty(p)
       case Expand(_, _, _) => empty(p)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
index badf4065f5fb..d6a4bd030c9d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
@@ -30,10 +30,10 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
 object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(AGGREGATE), ruleId) {
-    case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper, 
lower) =>
+    case upper @ Aggregate(_, _, lower: Aggregate, _) if 
isLowerRedundant(upper, lower) =>
       val projectList = 
lower.aggregateExpressions.filter(upper.references.contains(_))
       upper.copy(child = Project(projectList, lower.child))
-    case agg @ Aggregate(groupingExps, _, child)
+    case agg @ Aggregate(groupingExps, _, child, _)
         if agg.groupOnly && 
child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
       Project(agg.aggregateExpressions, child)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index c0cd976b9e9b..06fc366ce6bb 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -244,7 +244,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] 
{
   }
 
   private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet = 
plan match {
-    case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+    case Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
       ExpressionSet.apply(groupingExpressions)
     case _ => ExpressionSet(Seq.empty)
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 6802adaa2ea2..5fb30e810649 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -211,17 +211,17 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with 
PredicateHelper {
       val newJoinType = buildNewJoinType(f, j)
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType 
= newJoinType))
 
-    case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _))
+    case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _), _)
         if a.references.subsetOf(left.outputSet) && allDuplicateAgnostic(a) =>
       a.copy(child = left)
-    case a @ Aggregate(_, _, Join(_, right, RightOuter, _, _))
+    case a @ Aggregate(_, _, Join(_, right, RightOuter, _, _), _)
         if a.references.subsetOf(right.outputSet) && allDuplicateAgnostic(a) =>
       a.copy(child = right)
-    case a @ Aggregate(_, _, p @ Project(projectList, Join(left, _, LeftOuter, 
_, _)))
+    case a @ Aggregate(_, _, p @ Project(projectList, Join(left, _, LeftOuter, 
_, _)), _)
         if projectList.forall(_.deterministic) && 
p.references.subsetOf(left.outputSet) &&
           allDuplicateAgnostic(a) =>
       a.copy(child = p.copy(child = left))
-    case a @ Aggregate(_, _, p @ Project(projectList, Join(_, right, 
RightOuter, _, _)))
+    case a @ Aggregate(_, _, p @ Project(projectList, Join(_, right, 
RightOuter, _, _)), _)
         if projectList.forall(_.deterministic) && 
p.references.subsetOf(right.outputSet) &&
           allDuplicateAgnostic(a) =>
       a.copy(child = p.copy(child = right))
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index d9795cf33827..5a4e9f37c395 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -421,7 +421,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
         } else {
           p
         }
-      case a @ Aggregate(grouping, expressions, child) =>
+      case a @ Aggregate(grouping, expressions, child, _) =>
         val referencesToAdd = missingReferences(a)
         if (referencesToAdd.nonEmpty) {
           Aggregate(grouping ++ referencesToAdd, expressions ++ 
referencesToAdd, child)
@@ -952,7 +952,7 @@ object RewriteCorrelatedScalarSubquery extends 
Rule[LogicalPlan] with AliasHelpe
    * subqueries.
    */
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
-    case a @ Aggregate(grouping, expressions, child) =>
+    case a @ Aggregate(grouping, expressions, child, _) =>
       val subqueries = ArrayBuffer.empty[ScalarSubquery]
       val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, 
subqueries))
       if (subqueries.nonEmpty) {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index e48b44a603ad..a666b977030e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -291,7 +291,7 @@ object PhysicalAggregation {
     (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], 
LogicalPlan)
 
   def unapply(a: Any): Option[ReturnType] = a match {
-    case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+    case logical.Aggregate(groupingExpressions, resultExpressions, child, _) =>
       // A single aggregate expression might appear multiple times in 
resultExpressions.
       // In order to avoid evaluating an individual aggregate function 
multiple times, we'll
       // build a set of semantically distinct aggregate expressions and 
re-write expressions so
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 7c549a32aca0..58c57a1692d8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1197,7 +1197,8 @@ case class Range(
 case class Aggregate(
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: LogicalPlan)
+    child: LogicalPlan,
+    hint: Option[AggregateHint] = None)
   extends UnaryNode {
 
   override lazy val resolved: Boolean = {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index ff7c79fbe89c..82260755977f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -197,6 +197,8 @@ case object NO_BROADCAST_AND_REPLICATION extends 
JoinStrategyHint {
   override def hintAliases: Set[String] = Set.empty
 }
 
+abstract class AggregateHint;
+
 /**
  * The callback for implementing customized strategies of handling hint errors.
  */
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
index 36a3fa3f2743..94ed80916eed 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
@@ -36,7 +36,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends 
PlanTest {
       case Project(projectList, child) =>
         val newAttr = UnresolvedAttribute("unresolvedAttr")
         Project(projectList ++ Seq(newAttr), child)
-      case agg @ Aggregate(Nil, aggregateExpressions, child) =>
+      case agg @ Aggregate(Nil, aggregateExpressions, child, _) =>
         // Project cannot host AggregateExpression
         Project(aggregateExpressions, child)
     }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 9cb5ee46e0f3..08dd4011f04d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -31,7 +31,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
   val testRelation2 = LocalRelation($"a".double, $"b".int, $"c".int, $"d".int, 
$"e".int)
 
   private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
-    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
+    case Aggregate(_, _, Aggregate(_, _, _: Expand, _), _) =>
     case _ => fail(s"Plan is not rewritten:\n$rewrite")
   }
 
@@ -87,7 +87,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
 
     val rewrite = RewriteDistinctAggregates(input)
     rewrite match {
-      case Aggregate(_, _, _: LocalRelation) =>
+      case Aggregate(_, _, _: LocalRelation, _) =>
       case _ => fail(s"Plan is not as expected:\n$rewrite")
     }
   }
@@ -104,7 +104,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
 
     val rewrite = RewriteDistinctAggregates(input)
     rewrite match {
-      case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+      case Aggregate(_, _, Aggregate(_, _, e: Expand, _), _) =>
         assert(e.projections.size == 3)
       case _ => fail(s"Plan is not rewritten:\n$rewrite")
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 677dba008257..ec85c73c5ce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -223,6 +223,24 @@ class SparkSessionExtensions {
     resolutionRuleBuilders += builder
   }
 
+  private[this] val hintResolutionRuleBuilders = 
mutable.Buffer.empty[RuleBuilder]
+
+  /**
+   * Build the analyzer hint resolution rules using the given [[SparkSession]].
+   */
+  private[sql] def buildHintResolutionRules(session: SparkSession): 
Seq[Rule[LogicalPlan]] = {
+    hintResolutionRuleBuilders.map(_.apply(session)).toSeq
+  }
+
+  /**
+   * Inject an analyzer hint resolution rule builder into the 
[[SparkSession]]. These analyzer
+   * rules will be executed as part of the early resolution phase of the 
analyzer, together with
+   * other hint resolution rules.
+   */
+  def injectHintResolutionRule(builder: RuleBuilder): Unit = {
+    hintResolutionRuleBuilders += builder
+  }
+
   private[this] val postHocResolutionRuleBuilders = 
mutable.Buffer.empty[RuleBuilder]
 
   /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index f48dfbf57b33..a37258977481 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -52,7 +52,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) 
extends Rule[Logic
 
     plan.transform {
       case a @ Aggregate(_, aggExprs, child @ PhysicalOperation(
-          projectList, filters, PartitionedRelation(partAttrs, rel))) =>
+          projectList, filters, PartitionedRelation(partAttrs, rel)), _) =>
         // We only apply this optimization when only partitioned attributes 
are scanned.
         if (AttributeSet((projectList ++ 
filters).flatMap(_.references)).subsetOf(partAttrs)) {
           // The project list and filters all only refer to partition 
attributes, which means the
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 51100063f32d..f22d4fe32668 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -197,6 +197,8 @@ abstract class BaseSessionStateBuilder(
    * Note: this depends on the `conf` and `catalog` fields.
    */
   protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+    override val hintResolutionRules: Seq[Rule[LogicalPlan]] =
+      customHintResolutionRules
     override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
       new FindDataSourceTable(session) +:
         new ResolveSQLOnFile(session) +:
@@ -239,6 +241,13 @@ abstract class BaseSessionStateBuilder(
     extensions.buildResolutionRules(session)
   }
 
+  /**
+   * Custom hint resolution rules to add to the Analyzer.
+   */
+  protected def customHintResolutionRules: Seq[Rule[LogicalPlan]] = {
+    extensions.buildHintResolutionRules(session)
+  }
+
   /**
    * Custom post resolution rules to add to the Analyzer. Prefer overriding 
this instead of
    * creating your own Analyzer.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
index 7a2ce1d7836b..f22d90d9f35d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
@@ -511,7 +511,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df1, Row(2, 2) :: Nil)
       df1.queryExecution.analyzed match {
-        case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 
2)
+        case Aggregate(_, _, WithCTE(_, cteDefs), _) => assert(cteDefs.length 
== 2)
         case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got 
$other")
       }
 
@@ -530,7 +530,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df2, Row(2, 2) :: Nil)
       df2.queryExecution.analyzed match {
-        case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _, 
_, _)) =>
+        case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _, 
_, _), _) =>
           assert(cteDefs.length == 1)
         case other => fail(s"Expect pattern Aggregate(Join(_, WithCTE(_))) but 
got $other")
       }
@@ -560,7 +560,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df3, Row(4, 4) :: Nil)
       df3.queryExecution.analyzed match {
-        case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union, 
cteDefs)), _, _, _)) =>
+        case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union, 
cteDefs)), _, _, _), _) =>
           assert(cteDefs.length == 2)
         case other => fail(
           s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got 
$other")
@@ -585,7 +585,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df4, Row(4, 4) :: Nil)
       df4.queryExecution.analyzed match {
-        case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)), 
_, _, _))
+        case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)), 
_, _, _), _)
           if children.head.find(_.isInstanceOf[WithCTE]).isDefined =>
           assert(
             children.head.collect {
@@ -618,7 +618,7 @@ abstract class CTEInlineSuiteBase
          """.stripMargin)
       checkAnswer(df5, Row(4, 4) :: Nil)
       df5.queryExecution.analyzed match {
-        case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length == 
2)
+        case Aggregate(_, _, WithCTE(_, cteDefs), _) => assert(cteDefs.length 
== 2)
         case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got 
$other")
       }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index bc16a6947510..7d7185ae6c13 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -255,7 +255,7 @@ class InjectRuntimeFilterSuite extends QueryTest with 
SQLTestUtils with SharedSp
       case Filter(condition, _) => condition.collect {
         case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
         => subquery.plan.collect {
-          case Aggregate(_, aggregateExpressions, _) =>
+          case Aggregate(_, aggregateExpressions, _, _) =>
             aggregateExpressions.map {
               case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _, 
_, _, _),
               _) =>
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
index 3f921618297d..d7177e19a617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
@@ -854,7 +854,7 @@ class LateralColumnAliasSuite extends 
LateralColumnAliasSuiteBase {
        |""".stripMargin
     val analyzedPlan = sql(query).queryExecution.analyzed
     analyzedPlan.collect {
-      case Aggregate(_, aggregateExpressions, _) =>
+      case Aggregate(_, aggregateExpressions, _, _) =>
         val extracted = aggregateExpressions.collect {
           case Alias(child, _) => child
           case a: Attribute => a
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index ba87028a7147..5ec557462de1 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,10 +29,10 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, 
InternalRow, TableIden
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
Final, Max, Partial}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, CompoundBody, 
ParserInterface}
-import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, 
LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.{PlanTest, SQLHelper}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint, 
ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -43,6 +43,7 @@ import 
org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, 
WriteFilesExecBase, WriteFilesSpec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, 
BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
 import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
@@ -53,7 +54,8 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * Test cases for the [[SparkSessionExtensions]].
  */
-class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with 
AdaptiveSparkPlanHelper {
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with 
AdaptiveSparkPlanHelper
+  with PlanTest {
   private def create(
       builder: SparkSessionExtensionsProvider): 
Seq[SparkSessionExtensionsProvider] = Seq(builder)
 
@@ -159,7 +161,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with 
SQLHelper with Adapt
   }
 
   test("inject custom hint rule") {
-    withSession(Seq(_.injectPostHocResolutionRule(MyHintRule))) { session =>
+    withSession(Seq(_.injectHintResolutionRule(MyHintRule))) { session =>
       assert(
         
session.range(1).hint("CONVERT_TO_EMPTY").logicalPlan.isInstanceOf[LocalRelation],
         "plan is expected to be a local relation"
@@ -543,6 +545,23 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       }
     }
   }
+
+  test("custom aggregate hint") {
+    // The custom hint lets the optimizer replace MAX(id) in an aggregate without grouping keys
+    // with the literal bound supplied by the hint.
+    withSession(Seq(_.injectHintResolutionRule(CustomerAggregateHintResolutionRule),
+      _.injectOptimizerRule(CustomAggregateRule))) { session =>
+      val res = session.range(10).agg(max("id")).as("max_id")
+        .hint("MAX_VALUE", "id", 10)
+        .queryExecution.optimizedPlan
+      res match {
+        case agg: Aggregate =>
+          compareExpressions(Alias(Literal(10L), "max(id)")(), agg.aggregateExpressions.head)
+          assert(agg.hint.isEmpty, "the hint should be consumed by CustomAggregateRule")
+        case other => fail(s"Expected an Aggregate but got:\n$other")
+      }
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1231,3 +1247,76 @@ object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
     }
   }
 }
+
+// Example of an Aggregate hint that tells that 'attribute' values are no larger than 'max'.
+// We will use them to rewrite MAX(attribute) with 'max' constant.
+case class CustomAggHint(attribute: AttributeReference, max: Int) extends AggregateHint
+
+// Resolves the MAX_VALUE(<column name>, <max>) hint by attaching a CustomAggHint to each
+// aggregate without grouping keys whose only aggregate function is MAX over that column.
+case class CustomerAggregateHintResolutionRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  val MY_HINT_NAME = Set("MAX_VALUE")
+
+  private def isMax(expr: NamedExpression, attribute: String): Option[AttributeReference] = {
+    expr match {
+      case Alias(AggregateExpression(Max(a @ AttributeReference(name, _, _, _)), _, _, _, _), _)
+        if name.equalsIgnoreCase(attribute) =>
+        Some(a)
+      case _ => None
+    }
+  }
+
+  private def applyMaxValueHint(
+      plan: LogicalPlan,
+      attribute: String,
+      max: Int): LogicalPlan = {
+    val newPlan = plan match {
+      case a @ Aggregate(keys, aggs, _, None) if keys.isEmpty && aggs.size == 1 =>
+        isMax(aggs.head, attribute) match {
+          case Some(attr) => a.copy(hint = Some(CustomAggHint(attr, max)))
+          case None => a
+        }
+      case _ => plan
+    }
+    newPlan.mapChildren { child =>
+      applyMaxValueHint(child, attribute, max)
+    }
+  }
+
+  // Extracts (column name, max) from the hint parameters, which may be plain values or
+  // wrapped in `Literal`s depending on how the hint was built. None if they do not fit.
+  private def hintArguments(parameters: Seq[Any]): Option[(String, Int)] = {
+    val values = parameters.map {
+      case Literal(value, _) => value
+      case other => other
+    }
+    values match {
+      case Seq(attribute @ (_: String | _: UTF8String), max: Int) =>
+        Some((attribute.toString, max))
+      case _ => None
+    }
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+    case h: UnresolvedHint if MY_HINT_NAME.contains(h.name.toUpperCase(Locale.ROOT)) =>
+      // Use the column and bound given in the hint instead of hard-coded values; a hint
+      // with unexpected parameters is left unresolved.
+      hintArguments(h.parameters) match {
+        case Some((attribute, max)) => applyMaxValueHint(h.child, attribute, max)
+        case None => h
+      }
+  }
+}
+
+// Logical rule that replaces the MAX aggregation function (in Aggregates with CustomAggHint)
+// with just the constant from the hint.
+case class CustomAggregateRule(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformDown {
+      case a @ Aggregate(groupingKeys, aggregates, _, Some(CustomAggHint(_, max)))
+        if groupingKeys.isEmpty && aggregates.size == 1 =>
+        a.copy(aggregateExpressions = Seq(Alias(Cast(Literal(max), aggregates.head.dataType),
+          aggregates.head.name)()), hint = None)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to