This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d6b733492d3e [SPARK-50033][SQL] Add a hint to logical.Aggregate() node
d6b733492d3e is described below
commit d6b733492d3e1643babbcd2b1860e35af7928906
Author: Andrey Gubichev <[email protected]>
AuthorDate: Fri Nov 1 09:59:58 2024 -0700
[SPARK-50033][SQL] Add a hint to logical.Aggregate() node
### What changes were proposed in this pull request?
Adds an abstract hint to the Aggregate() case class, and adds an
extension API that allows custom hint resolution rules.
### Why are the changes needed?
Allows users to define and resolve their own aggregate hints, and to use the
hints to guide optimizer/planner behavior.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added a new test for a custom aggregate hint.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48523 from agubichev/agg_hint.
Lead-authored-by: Andrey Gubichev <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 32 +++++----
.../sql/catalyst/analysis/CheckAnalysis.scala | 2 +-
.../catalyst/analysis/ColumnResolutionHelper.scala | 2 +-
.../catalyst/analysis/DeduplicateRelations.scala | 4 +-
.../ResolveLateralColumnAliasReference.scala | 2 +-
.../analysis/UnsupportedOperationChecker.scala | 4 +-
.../catalyst/optimizer/DecorrelateInnerQuery.scala | 2 +-
.../InsertMapSortInGroupingExpressions.scala | 2 +-
.../catalyst/optimizer/OptimizeOneRowPlan.scala | 2 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 10 +--
.../optimizer/PropagateEmptyRelation.scala | 2 +-
.../optimizer/RemoveRedundantAggregates.scala | 4 +-
.../spark/sql/catalyst/optimizer/expressions.scala | 2 +-
.../spark/sql/catalyst/optimizer/joins.scala | 8 +--
.../spark/sql/catalyst/optimizer/subquery.scala | 4 +-
.../spark/sql/catalyst/planning/patterns.scala | 2 +-
.../plans/logical/basicLogicalOperators.scala | 3 +-
.../spark/sql/catalyst/plans/logical/hints.scala | 2 +
.../OptimizerStructuralIntegrityCheckerSuite.scala | 2 +-
.../optimizer/RewriteDistinctAggregatesSuite.scala | 6 +-
.../apache/spark/sql/SparkSessionExtensions.scala | 18 +++++
.../sql/execution/OptimizeMetadataOnlyQuery.scala | 2 +-
.../sql/internal/BaseSessionStateBuilder.scala | 9 +++
.../org/apache/spark/sql/CTEInlineSuite.scala | 10 +--
.../spark/sql/InjectRuntimeFilterSuite.scala | 2 +-
.../apache/spark/sql/LateralColumnAliasSuite.scala | 2 +-
.../spark/sql/SparkSessionExtensionSuite.scala | 81 ++++++++++++++++++++--
27 files changed, 164 insertions(+), 57 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6b64f493f405..b94bd31eb3fa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -250,6 +250,11 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
*/
val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+ /**
+ * Override to provide additional rules for the "Hints" resolution batch.
+ */
+ val hintResolutionRules: Seq[Rule[LogicalPlan]] = Nil
+
/**
* Override to provide rules to do post-hoc resolution. Note that these
rules will be executed
* in an individual batch. This batch is to run right after the normal
resolution batch and
@@ -278,8 +283,9 @@ class Analyzer(override val catalogManager: CatalogManager)
extends RuleExecutor
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
- ResolveHints.ResolveJoinStrategyHints,
- ResolveHints.ResolveCoalesceHints),
+ Seq(ResolveHints.ResolveJoinStrategyHints,
+ ResolveHints.ResolveCoalesceHints) ++
+ hintResolutionRules: _*),
Batch("Simple Sanity Check", Once,
LookupFunctions),
Batch("Keep Legacy Outputs", Once,
@@ -474,7 +480,7 @@ class Analyzer(override val catalogManager: CatalogManager)
extends RuleExecutor
def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsUpWithPruning(
_.containsPattern(UNRESOLVED_ALIAS), ruleId) {
- case Aggregate(groups, aggs, child) if child.resolved &&
hasUnresolvedAlias(aggs) =>
+ case Aggregate(groups, aggs, child, _) if child.resolved &&
hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
case Pivot(groupByOpt, pivotColumn, pivotValues, aggregates, child)
@@ -689,7 +695,7 @@ class Analyzer(override val catalogManager: CatalogManager)
extends RuleExecutor
def apply(plan: LogicalPlan): LogicalPlan =
plan.resolveOperatorsDownWithPruning(
_.containsPattern(GROUPING_ANALYTICS), ruleId) {
case h @ UnresolvedHaving(_, agg @ Aggregate(
- GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, _))
+ GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, _, _))
if agg.childrenResolved && aggExprs.forall(_.resolved) =>
tryResolveHavingCondition(h, agg, selectedGroupByExprs, groupByExprs)
@@ -699,7 +705,7 @@ class Analyzer(override val catalogManager: CatalogManager)
extends RuleExecutor
case a if !a.childrenResolved => a
// Ensure group by expressions and aggregate expressions have been
resolved.
- case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs),
aggExprs, child)
+ case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs),
aggExprs, child, _)
if aggExprs.forall(_.resolved) =>
constructAggregate(selectedGroupByExprs, groupByExprs, aggExprs, child)
@@ -1841,7 +1847,7 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
// Replace the index with the corresponding expression in
aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns
(select expression)
- case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+ case Aggregate(groups, aggs, child, _) if aggs.forall(_.resolved) &&
groups.exists(containUnresolvedOrdinal) =>
val newGroups = groups.map(resolveGroupByExpressionOrdinal(_, aggs))
Aggregate(newGroups, aggs, child)
@@ -2827,15 +2833,15 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
- case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
+ case Aggregate(_, aggList, _, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
throw
QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))
- case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
+ case Aggregate(_, aggList, _, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
- case Aggregate(groupList, aggList, child) if
canRewriteGenerator(aggList) &&
+ case Aggregate(groupList, aggList, child, _) if
canRewriteGenerator(aggList) &&
aggList.exists(hasGenerator) =>
// If generator in the aggregate list was visited, set the boolean
flag true.
var generatorVisited = false
@@ -3201,7 +3207,7 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
// Aggregate with Having clause. This rule works with an unresolved
Aggregate because
// a resolved Aggregate will not have Window Functions.
- case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs,
aggregateExprs, child))
+ case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs,
aggregateExprs, child, _))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
@@ -3226,7 +3232,7 @@ class Analyzer(override val catalogManager:
CatalogManager) extends RuleExecutor
// Aggregate without Having clause.
// Make sure the lateral column aliases are properly handled first.
- case a @ Aggregate(groupingExprs, aggregateExprs, child)
+ case a @ Aggregate(groupingExprs, aggregateExprs, child, _)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) &&
!aggregateExprs.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
@@ -3885,9 +3891,9 @@ object CleanupAliases extends Rule[LogicalPlan] with
AliasHelper {
val cleanedProjectList = projectList.map(trimNonTopLevelAliases)
Project(cleanedProjectList, child)
- case Aggregate(grouping, aggs, child) =>
+ case Aggregate(grouping, aggs, child, hint) =>
val cleanedAggs = aggs.map(trimNonTopLevelAliases)
- Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+ Aggregate(grouping.map(trimAliases), cleanedAggs, child, hint)
case Window(windowExprs, partitionSpec, orderSpec, child) =>
val cleanedWindowExprs = windowExprs.map(trimNonTopLevelAliases)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index efb63ea181a8..16899b656f30 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -757,7 +757,7 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
case p @ Project(projectList, _) =>
checkForUnspecifiedWindow(projectList)
- case agg@Aggregate(_, aggregateExpressions, _) if
+ case agg@Aggregate(_, aggregateExpressions, _, _) if
PlanHelper.specialExpressionsInUnsupportedOperator(agg).isEmpty =>
checkForUnspecifiedWindow(aggregateExpressions)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 222e2c461999..d9c723aecbe8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -72,7 +72,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
newProject.copyTagsFrom(p)
(newExprs, newProject)
- case a @ Aggregate(groupExprs, aggExprs, child) =>
+ case a @ Aggregate(groupExprs, aggExprs, child, _) =>
if (missingAttrs.forall(attr =>
groupExprs.exists(_.semanticEquals(attr)))) {
// All the missing attributes are grouping expressions, valid
case.
(newExprs,
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 8181078c519f..ca5a6eee9bc9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -361,14 +361,14 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
if findAliases(projectList).size == projectList.size =>
Nil
- case oldVersion @ Aggregate(_, aggregateExpressions, _)
+ case oldVersion @ Aggregate(_, aggregateExpressions, _, _)
if
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(aggregateExpressions =
newAliases(aggregateExpressions))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))
// We don't search the child plan recursively for the same reason as the
above Project.
- case _ @ Aggregate(_, aggregateExpressions, _)
+ case _ @ Aggregate(_, aggregateExpressions, _, _)
if findAliases(aggregateExpressions).size == aggregateExpressions.size
=>
Nil
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
index c249a3506f2d..da8065eab606 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
@@ -196,7 +196,7 @@ object ResolveLateralColumnAliasReference extends
Rule[LogicalPlan] {
if ruleApplicableOnOperator(aggOriginal,
aggOriginal.aggregateExpressions)
&& aggOriginal.aggregateExpressions.exists(
_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
- val agg @ Aggregate(groupingExpressions, aggregateExpressions, _) =
+ val agg @ Aggregate(groupingExpressions, aggregateExpressions, _, _) =
aggOriginal.mapChildren(apply0)
// Check if current Aggregate is eligible to lift up with Project: the
aggregate
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index c8dc43209afd..69639b69290c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -244,7 +244,7 @@ object UnsupportedOperationChecker extends Logging {
* data.
*/
def containsCompleteData(subplan: LogicalPlan): Boolean = {
- val aggs = subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming
=> a }
+ val aggs = subplan.collect { case a@Aggregate(_, _, _, _) if
a.isStreaming => a }
// Either the subplan has no streaming source, or it has aggregation
with Complete mode
!subplan.isStreaming || (aggs.nonEmpty && outputMode ==
InternalOutputModes.Complete)
}
@@ -264,7 +264,7 @@ object UnsupportedOperationChecker extends Logging {
// Operations that cannot exists anywhere in a streaming plan
subPlan match {
- case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ case Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
val distinctAggExprs = aggregateExpressions.flatMap { expr =>
expr.collect { case ae: AggregateExpression if ae.isDistinct => ae
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index 6c0d7189862d..9758f37efc2d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -773,7 +773,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
orderSpec = newOrderSpec, newChild)
(newWindow, joinCond, outerReferenceMap)
- case a @ Aggregate(groupingExpressions, aggregateExpressions, child)
=>
+ case a @ Aggregate(groupingExpressions, aggregateExpressions, child,
_) =>
val outerReferences = collectOuterReferences(a.expressions)
val newOuterReferences = parentOuterReferences ++ outerReferences
val (newChild, joinCond, outerReferenceMap) =
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
index 15ef025afd43..b6ced6c49a36 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala
@@ -53,7 +53,7 @@ object InsertMapSortInGroupingExpressions extends
Rule[LogicalPlan] {
}
plan transformUpWithNewOutput {
- case agg @ Aggregate(groupingExprs, aggregateExpressions, child)
+ case agg @ Aggregate(groupingExprs, aggregateExpressions, child, _)
if agg.groupingExpressions.exists(shouldAddMapSort) =>
val exprToMapSort = new mutable.HashMap[Expression, NamedExpression]
val newGroupingKeys = groupingExprs.map { expr =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
index 61c08eb8f8b6..8e066d1cd634 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlan.scala
@@ -50,7 +50,7 @@ object OptimizeOneRowPlan extends Rule[LogicalPlan] {
isChildEligible(child, enableForStreaming) => child
case Sort(_, false, child) if child.maxRowsPerPartition.exists(_ <= 1L)
&&
isChildEligible(child, enableForStreaming) => child
- case agg @ Aggregate(_, _, child) if agg.groupOnly &&
child.maxRows.exists(_ <= 1L) &&
+ case agg @ Aggregate(_, _, child, _) if agg.groupOnly &&
child.maxRows.exists(_ <= 1L) &&
isChildEligible(child, enableForStreaming) =>
Project(agg.aggregateExpressions, child)
case agg: Aggregate if agg.child.maxRows.exists(_ <= 1L) &&
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dcfb64ae51fb..7074f8c4f089 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
// optimize it again, to save optimization time and avoid breaking
broadcast/subquery reuse.
case d: DynamicPruningSubquery => d
case s @ ScalarSubquery(
- PhysicalOperation(projections, predicates, a @ Aggregate(group, _,
child)),
+ PhysicalOperation(projections, predicates, a @ Aggregate(group, _,
child, _)),
_, _, _, _, mayHaveCountBug, _)
if
conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG)
&&
mayHaveCountBug.nonEmpty && mayHaveCountBug.get =>
@@ -960,7 +960,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
d.copy(child = prunedChild(child, d.references))
// Prunes the unused columns from child of
Aggregate/Expand/Generate/ScriptTransformation
- case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references)
=>
+ case a @ Aggregate(_, _, child, _) if
!child.outputSet.subsetOf(a.references) =>
a.copy(child = prunedChild(child, a.references))
case f @ FlatMapGroupsInPandas(_, _, _, child) if
!child.outputSet.subsetOf(f.references) =>
f.copy(child = prunedChild(child, f.references))
@@ -1653,7 +1653,7 @@ object EliminateSorts extends Rule[LogicalPlan] {
case j @ Join(originLeft, originRight, _, cond, _) if
cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft, true),
right = recursiveRemoveSort(originRight, true))
- case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
+ case g @ Aggregate(_, aggs, originChild, _) if isOrderIrrelevantAggs(aggs)
=>
g.copy(child = recursiveRemoveSort(originChild, true))
}
@@ -2482,7 +2482,7 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
- case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
+ case a @ Aggregate(grouping, _, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
if (newGrouping.nonEmpty) {
a.copy(groupingExpressions = newGrouping)
@@ -2538,7 +2538,7 @@ object GenerateOptimization extends Rule[LogicalPlan] {
object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
- case a @ Aggregate(grouping, _, _) if grouping.size > 1 =>
+ case a @ Aggregate(grouping, _, _, _) if grouping.size > 1 =>
val newGrouping = ExpressionSet(grouping).toSeq
if (newGrouping.size == grouping.size) {
a
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index d23d43acc217..86316494f6ff 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -169,7 +169,7 @@ abstract class PropagateEmptyRelationBase extends
Rule[LogicalPlan] with CastSup
// Aggregation on empty LocalRelation generated from a streaming source
is not eliminated
// as stateful streaming aggregation need to perform other state
management operations other
// than just processing the input data.
- case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
+ case Aggregate(ge, _, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
// Generators like Hive-style UDTF may return their records within
`close`.
case Generate(_: Explode, _, _, _, _, _) => empty(p)
case Expand(_, _, _) => empty(p)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
index badf4065f5fb..d6a4bd030c9d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
@@ -30,10 +30,10 @@ import
org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
- case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper,
lower) =>
+ case upper @ Aggregate(_, _, lower: Aggregate, _) if
isLowerRedundant(upper, lower) =>
val projectList =
lower.aggregateExpressions.filter(upper.references.contains(_))
upper.copy(child = Project(projectList, lower.child))
- case agg @ Aggregate(groupingExps, _, child)
+ case agg @ Aggregate(groupingExps, _, child, _)
if agg.groupOnly &&
child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
Project(agg.aggregateExpressions, child)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index c0cd976b9e9b..06fc366ce6bb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -244,7 +244,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan]
{
}
private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet =
plan match {
- case Aggregate(groupingExpressions, aggregateExpressions, child) =>
+ case Aggregate(groupingExpressions, aggregateExpressions, child, _) =>
ExpressionSet.apply(groupingExpressions)
case _ => ExpressionSet(Seq.empty)
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 6802adaa2ea2..5fb30e810649 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -211,17 +211,17 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with
PredicateHelper {
val newJoinType = buildNewJoinType(f, j)
if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType
= newJoinType))
- case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _))
+ case a @ Aggregate(_, _, Join(left, _, LeftOuter, _, _), _)
if a.references.subsetOf(left.outputSet) && allDuplicateAgnostic(a) =>
a.copy(child = left)
- case a @ Aggregate(_, _, Join(_, right, RightOuter, _, _))
+ case a @ Aggregate(_, _, Join(_, right, RightOuter, _, _), _)
if a.references.subsetOf(right.outputSet) && allDuplicateAgnostic(a) =>
a.copy(child = right)
- case a @ Aggregate(_, _, p @ Project(projectList, Join(left, _, LeftOuter,
_, _)))
+ case a @ Aggregate(_, _, p @ Project(projectList, Join(left, _, LeftOuter,
_, _)), _)
if projectList.forall(_.deterministic) &&
p.references.subsetOf(left.outputSet) &&
allDuplicateAgnostic(a) =>
a.copy(child = p.copy(child = left))
- case a @ Aggregate(_, _, p @ Project(projectList, Join(_, right,
RightOuter, _, _)))
+ case a @ Aggregate(_, _, p @ Project(projectList, Join(_, right,
RightOuter, _, _)), _)
if projectList.forall(_.deterministic) &&
p.references.subsetOf(right.outputSet) &&
allDuplicateAgnostic(a) =>
a.copy(child = p.copy(child = right))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index d9795cf33827..5a4e9f37c395 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -421,7 +421,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan]
with PredicateHelper
} else {
p
}
- case a @ Aggregate(grouping, expressions, child) =>
+ case a @ Aggregate(grouping, expressions, child, _) =>
val referencesToAdd = missingReferences(a)
if (referencesToAdd.nonEmpty) {
Aggregate(grouping ++ referencesToAdd, expressions ++
referencesToAdd, child)
@@ -952,7 +952,7 @@ object RewriteCorrelatedScalarSubquery extends
Rule[LogicalPlan] with AliasHelpe
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
- case a @ Aggregate(grouping, expressions, child) =>
+ case a @ Aggregate(grouping, expressions, child, _) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_,
subqueries))
if (subqueries.nonEmpty) {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index e48b44a603ad..a666b977030e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -291,7 +291,7 @@ object PhysicalAggregation {
(Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression],
LogicalPlan)
def unapply(a: Any): Option[ReturnType] = a match {
- case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ case logical.Aggregate(groupingExpressions, resultExpressions, child, _) =>
// A single aggregate expression might appear multiple times in
resultExpressions.
// In order to avoid evaluating an individual aggregate function
multiple times, we'll
// build a set of semantically distinct aggregate expressions and
re-write expressions so
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 7c549a32aca0..58c57a1692d8 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1197,7 +1197,8 @@ case class Range(
case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: LogicalPlan)
+ child: LogicalPlan,
+ hint: Option[AggregateHint] = None)
extends UnaryNode {
override lazy val resolved: Boolean = {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index ff7c79fbe89c..82260755977f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -197,6 +197,8 @@ case object NO_BROADCAST_AND_REPLICATION extends
JoinStrategyHint {
override def hintAliases: Set[String] = Set.empty
}
+abstract class AggregateHint;
+
/**
* The callback for implementing customized strategies of handling hint errors.
*/
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
index 36a3fa3f2743..94ed80916eed 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
@@ -36,7 +36,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends
PlanTest {
case Project(projectList, child) =>
val newAttr = UnresolvedAttribute("unresolvedAttr")
Project(projectList ++ Seq(newAttr), child)
- case agg @ Aggregate(Nil, aggregateExpressions, child) =>
+ case agg @ Aggregate(Nil, aggregateExpressions, child, _) =>
// Project cannot host AggregateExpression
Project(aggregateExpressions, child)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 9cb5ee46e0f3..08dd4011f04d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -31,7 +31,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
val testRelation2 = LocalRelation($"a".double, $"b".int, $"c".int, $"d".int,
$"e".int)
private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
- case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
+ case Aggregate(_, _, Aggregate(_, _, _: Expand, _), _) =>
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}
@@ -87,7 +87,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
val rewrite = RewriteDistinctAggregates(input)
rewrite match {
- case Aggregate(_, _, _: LocalRelation) =>
+ case Aggregate(_, _, _: LocalRelation, _) =>
case _ => fail(s"Plan is not as expected:\n$rewrite")
}
}
@@ -104,7 +104,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
val rewrite = RewriteDistinctAggregates(input)
rewrite match {
- case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+ case Aggregate(_, _, Aggregate(_, _, e: Expand, _), _) =>
assert(e.projections.size == 3)
case _ => fail(s"Plan is not rewritten:\n$rewrite")
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 677dba008257..ec85c73c5ce0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -223,6 +223,24 @@ class SparkSessionExtensions {
resolutionRuleBuilders += builder
}
+ private[this] val hintResolutionRuleBuilders =
mutable.Buffer.empty[RuleBuilder]
+
+ /**
+ * Build the analyzer hint resolution rules using the given [[SparkSession]].
+ */
+ private[sql] def buildHintResolutionRules(session: SparkSession):
Seq[Rule[LogicalPlan]] = {
+ hintResolutionRuleBuilders.map(_.apply(session)).toSeq
+ }
+
+ /**
+ * Inject an analyzer hint resolution rule builder into the
[[SparkSession]]. These analyzer
+ * rules will be executed as part of the early resolution phase of the
analyzer, together with
+ * other hint resolution rules.
+ */
+ def injectHintResolutionRule(builder: RuleBuilder): Unit = {
+ hintResolutionRuleBuilders += builder
+ }
+
private[this] val postHocResolutionRuleBuilders =
mutable.Buffer.empty[RuleBuilder]
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index f48dfbf57b33..a37258977481 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -52,7 +52,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog)
extends Rule[Logic
plan.transform {
case a @ Aggregate(_, aggExprs, child @ PhysicalOperation(
- projectList, filters, PartitionedRelation(partAttrs, rel))) =>
+ projectList, filters, PartitionedRelation(partAttrs, rel)), _) =>
// We only apply this optimization when only partitioned attributes
are scanned.
if (AttributeSet((projectList ++
filters).flatMap(_.references)).subsetOf(partAttrs)) {
// The project list and filters all only refer to partition
attributes, which means the
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 51100063f32d..f22d4fe32668 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -197,6 +197,8 @@ abstract class BaseSessionStateBuilder(
* Note: this depends on the `conf` and `catalog` fields.
*/
protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+ override val hintResolutionRules: Seq[Rule[LogicalPlan]] =
+ customHintResolutionRules
override val extendedResolutionRules: Seq[Rule[LogicalPlan]] =
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
@@ -239,6 +241,13 @@ abstract class BaseSessionStateBuilder(
extensions.buildResolutionRules(session)
}
+ /**
+ * Custom hint resolution rules to add to the Analyzer.
+ */
+ protected def customHintResolutionRules: Seq[Rule[LogicalPlan]] = {
+ extensions.buildHintResolutionRules(session)
+ }
+
/**
* Custom post resolution rules to add to the Analyzer. Prefer overriding
this instead of
* creating your own Analyzer.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
index 7a2ce1d7836b..f22d90d9f35d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala
@@ -511,7 +511,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df1, Row(2, 2) :: Nil)
df1.queryExecution.analyzed match {
- case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length ==
2)
+ case Aggregate(_, _, WithCTE(_, cteDefs), _) => assert(cteDefs.length
== 2)
case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got
$other")
}
@@ -530,7 +530,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df2, Row(2, 2) :: Nil)
df2.queryExecution.analyzed match {
- case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _,
_, _)) =>
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_, cteDefs)), _,
_, _), _) =>
assert(cteDefs.length == 1)
case other => fail(s"Expect pattern Aggregate(Join(_, WithCTE(_))) but
got $other")
}
@@ -560,7 +560,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df3, Row(4, 4) :: Nil)
df3.queryExecution.analyzed match {
- case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union,
cteDefs)), _, _, _)) =>
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, WithCTE(_: Union,
cteDefs)), _, _, _), _) =>
assert(cteDefs.length == 2)
case other => fail(
s"Expect pattern Aggregate(Join(_, (WithCTE(Union(_, _))))) but got
$other")
@@ -585,7 +585,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df4, Row(4, 4) :: Nil)
df4.queryExecution.analyzed match {
- case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)),
_, _, _))
+ case Aggregate(_, _, Join(_, SubqueryAlias(_, Union(children, _, _)),
_, _, _), _)
if children.head.find(_.isInstanceOf[WithCTE]).isDefined =>
assert(
children.head.collect {
@@ -618,7 +618,7 @@ abstract class CTEInlineSuiteBase
""".stripMargin)
checkAnswer(df5, Row(4, 4) :: Nil)
df5.queryExecution.analyzed match {
- case Aggregate(_, _, WithCTE(_, cteDefs)) => assert(cteDefs.length ==
2)
+ case Aggregate(_, _, WithCTE(_, cteDefs), _) => assert(cteDefs.length
== 2)
case other => fail(s"Expect pattern Aggregate(WithCTE(_)) but got
$other")
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index bc16a6947510..7d7185ae6c13 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -255,7 +255,7 @@ class InjectRuntimeFilterSuite extends QueryTest with
SQLTestUtils with SharedSp
case Filter(condition, _) => condition.collect {
case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
=> subquery.plan.collect {
- case Aggregate(_, aggregateExpressions, _) =>
+ case Aggregate(_, aggregateExpressions, _, _) =>
aggregateExpressions.map {
case Alias(AggregateExpression(bfAgg : BloomFilterAggregate, _,
_, _, _),
_) =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
index 3f921618297d..d7177e19a617 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
@@ -854,7 +854,7 @@ class LateralColumnAliasSuite extends
LateralColumnAliasSuiteBase {
|""".stripMargin
val analyzedPlan = sql(query).queryExecution.analyzed
analyzedPlan.collect {
- case Aggregate(_, aggregateExpressions, _) =>
+ case Aggregate(_, aggregateExpressions, _, _) =>
val extracted = aggregateExpressions.collect {
case Alias(child, _) => child
case a: Attribute => a
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index ba87028a7147..5ec557462de1 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,10 +29,10 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier,
InternalRow, TableIden
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
Final, Max, Partial}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, CompoundBody,
ParserInterface}
-import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit,
LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.{PlanTest, SQLHelper}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AggregateHint,
ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning,
SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
@@ -43,6 +43,7 @@ import
org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec,
WriteFilesExecBase, WriteFilesSpec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec,
BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.COLUMN_BATCH_SIZE
import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS
@@ -53,7 +54,8 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Test cases for the [[SparkSessionExtensions]].
*/
-class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with
AdaptiveSparkPlanHelper {
+class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with
AdaptiveSparkPlanHelper
+ with PlanTest {
private def create(
builder: SparkSessionExtensionsProvider):
Seq[SparkSessionExtensionsProvider] = Seq(builder)
@@ -159,7 +161,7 @@ class SparkSessionExtensionSuite extends SparkFunSuite with
SQLHelper with Adapt
}
test("inject custom hint rule") {
- withSession(Seq(_.injectPostHocResolutionRule(MyHintRule))) { session =>
+ withSession(Seq(_.injectHintResolutionRule(MyHintRule))) { session =>
assert(
session.range(1).hint("CONVERT_TO_EMPTY").logicalPlan.isInstanceOf[LocalRelation],
"plan is expected to be a local relation"
@@ -543,6 +545,20 @@ class SparkSessionExtensionSuite extends SparkFunSuite
with SQLHelper with Adapt
}
}
}
+
+ test("custom aggregate hint") {
+ // The custom hint allows us to replace the aggregate (without grouping
keys) with just
+ // a Literal.
+
withSession(Seq(_.injectHintResolutionRule(CustomerAggregateHintResolutionRule),
+ _.injectOptimizerRule(CustomAggregateRule))) { session =>
+ val res = session.range(10).agg(max("id")).as("max_id")
+ .hint("MAX_VALUE", "id", 10)
+ .queryExecution.optimizedPlan
+ assert(res.isInstanceOf[Aggregate])
+ val expectedAlias = Alias(Literal(10L), "max(id)")()
+ compareExpressions(expectedAlias,
res.asInstanceOf[Aggregate].aggregateExpressions.head)
+ }
+ }
}
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1231,3 +1247,58 @@ object MyQueryPostPlannerStrategyRule extends
Rule[SparkPlan] {
}
}
}
+
+
+// Example of an Aggregate hint indicating that 'attribute' values are no
larger than 'max'.
+// We will use them to rewrite MAX(attribute) with 'max' constant.
+case class CustomAggHint(attribute: AttributeReference, max: Int) extends
AggregateHint
+
+// Attaches the CustomAggHint to the aggregate node without grouping keys if
the aggregate
+// function is MAX over the specified column.
+case class CustomerAggregateHintResolutionRule(spark: SparkSession) extends
Rule[LogicalPlan] {
+ val MY_HINT_NAME = Set("MAX_VALUE")
+
+ def isMax(expr: NamedExpression, attribute: String):
Option[AttributeReference] = {
+ expr match {
+ case Alias(AggregateExpression(Max(a @ AttributeReference(name, _, _,
_)), _, _, _, _), _)
+ if name.equalsIgnoreCase(attribute) =>
+ Some(a)
+ case _ => None
+ }
+ }
+
+ private def applyMaxValueHint(
+ plan: LogicalPlan,
+ attribute: String,
+ max: Int): LogicalPlan = {
+ val newPlan = plan match {
+ case a @ Aggregate(keys, aggs, _, None) if keys.isEmpty && aggs.size ==
1 =>
+ isMax(aggs.head, attribute) match {
+ case Some(attr) => a.copy(hint = Some(CustomAggHint(attr, max)))
+ case None => a
+ }
+ case _ => plan
+ }
+ newPlan.mapChildren { child =>
+ applyMaxValueHint(child, attribute, max)
+ }
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
+ case h: UnresolvedHint if
MY_HINT_NAME.contains(h.name.toUpperCase(Locale.ROOT)) =>
+ applyMaxValueHint(h.child, "id", 10)
+ }
+}
+
+// Logical rule that replaces the MAX aggregation function (in Aggregates with
CustomAggHint)
+// with just the constant from the hint.
+case class CustomAggregateRule(spark: SparkSession) extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ plan transformDown {
+ case a @ Aggregate(groupingKeys, aggregates, _, Some(CustomAggHint(_,
max)))
+ if groupingKeys.isEmpty && aggregates.size == 1 =>
+ a.copy(aggregateExpressions = Seq(Alias(Cast(Literal(max),
aggregates.head.dataType),
+ aggregates.head.name)()), hint = None)
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]