This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 47f8687  [SPARK-35075][SQL] Add traversal pruning for subquery related rules
47f8687 is described below

commit 47f86875f73edc3bec56d3610ab46a16ec37091c
Author: Yingyi Bu <yingyi...@databricks.com>
AuthorDate: Fri Apr 23 12:42:55 2021 +0800

    [SPARK-35075][SQL] Add traversal pruning for subquery related rules

    ### What changes were proposed in this pull request?

    Added the following TreePattern enums:
    - DYNAMIC_PRUNING_SUBQUERY
    - EXISTS_SUBQUERY
    - IN_SUBQUERY
    - LIST_SUBQUERY
    - PLAN_EXPRESSION
    - SCALAR_SUBQUERY
    - FILTER

    Used them in the following rules:
    - ResolveSubquery
    - UpdateOuterReferences
    - OptimizeSubqueries
    - RewritePredicateSubquery
    - PullupCorrelatedPredicates
    - RewriteCorrelatedScalarSubquery (not the rule itself but an internal transform call, the full support is in SPARK-35148)
    - InsertAdaptiveSparkPlan
    - PlanAdaptiveSubqueries

    ### Why are the changes needed?

    Reduce the number of tree traversals and hence improve the query compilation latency.

    ### How was this patch tested?

    Existing tests.

    Closes #32247 from sigmod/subquery.

    Authored-by: Yingyi Bu <yingyi...@databricks.com>
    Signed-off-by: Gengliang Wang <ltn...@gmail.com>
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 14 +++++++-------
 .../spark/sql/catalyst/expressions/DynamicPruning.scala   |  3 +++
 .../spark/sql/catalyst/expressions/predicates.scala       |  3 ++-
 .../apache/spark/sql/catalyst/expressions/subquery.scala  | 13 +++++++++++++
 .../apache/spark/sql/catalyst/optimizer/Optimizer.scala   |  4 +++-
 .../apache/spark/sql/catalyst/optimizer/subquery.scala    | 15 ++++++++++-----
 .../catalyst/plans/logical/basicLogicalOperators.scala    |  4 +++-
 .../spark/sql/catalyst/rules/RuleIdCollection.scala       |  3 +++
 .../apache/spark/sql/catalyst/trees/TreePatterns.scala    |  7 +++++++
 .../sql/execution/adaptive/InsertAdaptiveSparkPlan.scala  |  4 ++++
 .../sql/execution/adaptive/PlanAdaptiveSubqueries.scala   |  5 ++++-
 .../dynamicpruning/PlanDynamicPruningFilters.scala        |  3 ++-
 .../scala/org/apache/spark/sql/execution/subquery.scala   |  3 ++-
 13 files changed, 63 insertions(+), 18 deletions(-)
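Before the per-file hunks, the mechanism the patch builds on: every tree node caches a bit set of the TreePatterns occurring anywhere in its subtree, and the `*WithPruning` transform variants test those bits before descending. A minimal, self-contained Scala sketch of that idea (hypothetical `Node`/`Pattern` types, not Spark's actual `TreeNode`):

```scala
import scala.collection.BitSet

object Pattern extends Enumeration {
  val SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY, FILTER = Value
}

// Hypothetical tree node; Spark's TreeNode caches the same kind of bit set.
case class Node(name: String, patterns: Set[Pattern.Value], children: Seq[Node]) {
  // Pattern bits of this node unioned with those of all descendants,
  // computed once and shared by every pruned traversal over this subtree.
  lazy val treePatternBits: BitSet =
    children.foldLeft(BitSet(patterns.map(_.id).toSeq: _*))(_ | _.treePatternBits)

  def containsAnyPattern(ps: Pattern.Value*): Boolean = ps.exists(p => treePatternBits(p.id))

  // Like transformDown, but skips whole subtrees the guard rules out.
  def transformDownWithPruning(cond: Node => Boolean)(rule: PartialFunction[Node, Node]): Node =
    if (!cond(this)) this // no relevant pattern anywhere below: O(1) skip
    else {
      val changed = rule.applyOrElse(this, identity[Node])
      changed.copy(children = changed.children.map(_.transformDownWithPruning(cond)(rule)))
    }
}

// The rule body only ever runs where a FILTER can actually occur.
object PruningDemo extends App {
  val plan = Node("Project", Set.empty,
    Seq(Node("Filter", Set(Pattern.FILTER), Nil), Node("Scan", Set.empty, Nil)))
  val rewritten = plan.transformDownWithPruning(_.containsAnyPattern(Pattern.FILTER)) {
    case n if n.name == "Filter" => n.copy(name = "RewrittenFilter")
  }
  println(rewritten)
}
```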
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c2c146c..87b8d52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -39,9 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  EXPRESSION_WITH_RANDOM_SEED, NATURAL_LIKE_JOIN, WINDOW_EXPRESSION
-}
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils}
 import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
@@ -2179,7 +2177,8 @@ class Analyzer(override val catalogManager: CatalogManager)
    * outer plan to get evaluated.
    */
   private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = {
-    plan transformExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY,
+      EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
       case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
         resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
       case e @ Exists(sub, _, exprId) if !sub.resolved =>
@@ -2196,7 +2195,8 @@ class Analyzer(override val catalogManager: CatalogManager)
     /**
      * Resolve and rewrite all subqueries in an operator tree..
      */
-    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+      _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, IN_SUBQUERY), ruleId) {
       // In case of HAVING (a filter after an aggregate) we use both the aggregate and
       // its child for resolution.
       case f @ Filter(_, a: Aggregate) if f.childrenResolved =>
@@ -3790,9 +3790,9 @@ object UpdateOuterReferences extends Rule[LogicalPlan] {
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = {
-    plan resolveOperators {
+    plan.resolveOperatorsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, FILTER), ruleId) {
       case f @ Filter(_, a: Aggregate) if f.resolved =>
-        f transformExpressions {
+        f.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
           case s: SubqueryExpression if s.children.nonEmpty =>
             // Collect the aliases from output of aggregate.
             val outerAliases = a.aggregateExpressions collect { case a: Alias => a }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
index de4b874..1c185dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DynamicPruning.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 
 trait DynamicPruning extends Predicate
@@ -69,6 +70,8 @@ case class DynamicPruningSubquery(
       pruningKey.dataType == buildKeys(broadcastKeyIndex).dataType
   }
 
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(DYNAMIC_PRUNING_SUBQUERY)
+
   override def toString: String = s"dynamicpruning#${exprId.id} $conditionString"
 
   override lazy val canonicalized: DynamicPruning = {
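The pruned calls in Analyzer.scala pass a second argument, `ruleId`, which enables a pruning dimension orthogonal to the pattern bits: a subtree can record that a given rule already ran on it without effect, so re-running that (idempotent) rule is skipped. A hedged sketch of the bookkeeping, with illustrative names only (Spark's actual tracking lives inside `TreeNode` and is more involved):

```scala
// Per-subtree bookkeeping for rules that proved ineffective, kept as a
// bitmask over rule ids (assuming fewer than 64 participating rules here;
// the real tracking structure is larger).
final class IneffectiveRuleTracker {
  private var mask: Long = 0L
  def markIneffective(ruleId: Int): Unit = mask |= 1L << ruleId
  def isIneffective(ruleId: Int): Boolean = (mask & (1L << ruleId)) != 0L
}

object RuleIdPruningDemo extends App {
  val tracker = new IneffectiveRuleTracker
  val resolveSubqueryId = 7 // hypothetical id handed out at registration time

  // Shape of the check a pruned transform performs before visiting a subtree:
  // skip when the pattern bits rule the subtree out OR the rule is already
  // known to be ineffective there.
  if (!tracker.isIneffective(resolveSubqueryId)) {
    // ... run the rule here; if the subtree comes back unchanged, remember it:
    tracker.markIneffective(resolveSubqueryId)
  }
  println(tracker.isIneffective(resolveSubqueryId)) // true: the next run skips
}
```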
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 34d1d8f..cb710ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, INSET, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN, IN_SUBQUERY, INSET, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -342,6 +342,7 @@ case class InSubquery(values: Seq[Expression], query: ListQuery)
     values.head
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(IN_SUBQUERY)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (values.length != query.childOutputs.length) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 2bedf84..ac939bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY,
+  PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.BitSet
 
@@ -38,6 +40,11 @@ abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
     bits
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal
+
+  // Subclasses can override this function to provide more TreePatterns.
+  def nodePatternsInternal(): Seq[TreePattern] = Seq()
+
   /** The id of the subquery expression. */
   def exprId: ExprId
 
@@ -247,6 +254,8 @@ case class ScalarSubquery(
 
   override protected def withNewChildrenInternal(
     newChildren: IndexedSeq[Expression]): ScalarSubquery = copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
 }
 
 object ScalarSubquery {
@@ -295,6 +304,8 @@ case class ListQuery(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): ListQuery =
     copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(LIST_SUBQUERY)
 }
 
 /**
@@ -340,4 +351,6 @@ case class Exists(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Exists =
     copy(children = newChildren)
+
+  final override def nodePatternsInternal: Seq[TreePattern] = Seq(EXISTS_SUBQUERY)
 }
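The PlanExpression change above is the structural core of the patch: the base class pins PLAN_EXPRESSION for every subquery expression and lets each subclass contribute its specific pattern through nodePatternsInternal. Stripped of the surrounding Expression machinery, the shape is roughly this (a sketch reusing the patch's names; everything else is elided):

```scala
object TreePattern extends Enumeration {
  type TreePattern = Value
  val PLAN_EXPRESSION, SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY = Value
}
import TreePattern._

abstract class PlanExpressionSketch {
  // `final` guarantees every subquery expression is reachable through the
  // single PLAN_EXPRESSION bit, which rules like OptimizeSubqueries prune on.
  final val nodePatterns: Seq[TreePattern] = Seq(PLAN_EXPRESSION) ++ nodePatternsInternal
  // Subclasses override this to contribute their specific pattern.
  def nodePatternsInternal: Seq[TreePattern] = Seq()
}

class ScalarSubquerySketch extends PlanExpressionSketch {
  final override def nodePatternsInternal: Seq[TreePattern] = Seq(SCALAR_SUBQUERY)
}

object PlanExpressionDemo extends App {
  println(new ScalarSubquerySketch().nodePatterns) // List(PLAN_EXPRESSION, SCALAR_SUBQUERY)
}
```

One subtlety: nodePatterns is a val initialized during superclass construction, so overrides of nodePatternsInternal must not read subclass fields; returning a constant Seq, as every override in the patch does, is safe.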
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5343fce..09e7cff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -283,7 +284,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
         case other => other
       }
     }
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressionsWithPruning(
+      _.containsPattern(PLAN_EXPRESSION), ruleId) {
       case s: SubqueryExpression =>
         val Subquery(newPlan, _) = Optimizer.this.execute(Subquery.fromExpression(s))
         // At this point we have an optimized subquery plan that we are going to attach
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9381796..fa87894 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, FILTER, IN_SUBQUERY,
+  LIST_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -94,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    t => t.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY) && t.containsPattern(FILTER)) {
     case Filter(condition, child)
       if SubqueryExpression.hasInOrCorrelatedExistsSubquery(condition) =>
       val (withSubquery, withoutSubquery) =
@@ -164,7 +167,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       plan: LogicalPlan): (Option[Expression], LogicalPlan) = {
     var newPlan = plan
     val newExprs = exprs.map { e =>
-      e transformDown {
+      e.transformDownWithPruning(_.containsAnyPattern(EXISTS_SUBQUERY, IN_SUBQUERY)) {
         case Exists(sub, conditions, _) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
           newPlan =
@@ -303,7 +306,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
       }
     }
 
-    plan transformExpressions {
+    plan.transformExpressionsWithPruning(_.containsAnyPattern(
+      SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
       case ScalarSubquery(sub, children, exprId) if children.nonEmpty =>
         val (newPlan, newCond) = decorrelate(sub, outerPlans)
         ScalarSubquery(newPlan, getJoinCondition(newCond, children), exprId)
@@ -319,7 +323,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
   /**
    * Pull up the correlated predicates and rewrite all subqueries in an operator tree..
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAnyPattern(SCALAR_SUBQUERY, EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case f @ Filter(_, a: Aggregate) => rewriteSubQueries(f, Seq(a, a.child))
     // Only a few unary nodes (Project/Filter/Aggregate) can contain subqueries.
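Note how the RewritePredicateSubquery guard above composes two tests: the subtree must be able to contain both a Filter operator and an Exists/ListQuery expression. A standalone illustration of that composition over the cached bits (hypothetical bit positions, not Spark's actual ids):

```scala
import scala.collection.BitSet

object GuardDemo extends App {
  // Hypothetical bit positions for the three patterns involved.
  val EXISTS_SUBQUERY = 0; val LIST_SUBQUERY = 1; val FILTER = 2

  def containsAnyPattern(bits: BitSet, ids: Int*): Boolean = ids.exists(bits)
  def containsPattern(bits: BitSet, id: Int): Boolean = bits(id)

  // Both conditions must hold before the rewrite is attempted; either
  // missing bit prunes the subtree in O(1).
  def shouldVisit(bits: BitSet): Boolean =
    containsAnyPattern(bits, EXISTS_SUBQUERY, LIST_SUBQUERY) && containsPattern(bits, FILTER)

  println(shouldVisit(BitSet(EXISTS_SUBQUERY, FILTER))) // true
  println(shouldVisit(BitSet(EXISTS_SUBQUERY)))         // false: no Filter below
}
```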
@@ -341,7 +346,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   private def extractCorrelatedScalarSubqueries[E <: Expression](
       expression: E,
       subqueries: ArrayBuffer[ScalarSubquery]): E = {
-    val newExpression = expression transform {
+    val newExpression = expression.transformWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
       case s: ScalarSubquery if s.children.nonEmpty =>
         subqueries += s
         s.plan.output.head
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 49e3e3c..bb999ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.catalyst.trees.TreePattern.{
-  INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
+  FILTER, INNER_LIKE_JOIN, JOIN, LEFT_SEMI_OR_ANTI_JOIN, NATURAL_LIKE_JOIN, OUTER_JOIN, TreePattern
 }
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -166,6 +166,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
 
   override def maxRows: Option[Long] = child.maxRows
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(FILTER)
+
   override protected lazy val validConstraints: ExpressionSet = {
     val predicates = splitConjunctivePredicates(condition)
       .filterNot(SubqueryExpression.hasCorrelatedSubquery)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index d745f50..884e259 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -42,12 +42,15 @@ object RuleIdCollection {
     // Catalyst Analyzer rules
     "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveNaturalAndUsingJoin" ::
     "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveRandomSeed" ::
+    "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery" ::
     "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
     "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" ::
+    "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
     // Catalyst Optimizer rules
     "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
     "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
     "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
+    "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
     "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
     "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
     "org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
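The RuleIdCollection additions are what make the `ruleId` arguments in the earlier hunks work: only registered rules receive a stable integer id under which "this rule was ineffective on this subtree" can be recorded. A simplified sketch of that contract (assumed shape; the real collection is hand-curated):

```scala
object RuleIdCollectionSketch {
  // In this sketch, ids are simply positions in the registration list.
  private val registeredRules: Seq[String] = Seq(
    "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubquery",
    "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences",
    "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries")

  private val idByName: Map[String, Int] = registeredRules.zipWithIndex.toMap

  // A rule that passes a ruleId to a pruned transform must be listed above;
  // otherwise there is no slot for recording its per-subtree effectiveness.
  def getRuleId(ruleName: String): Int =
    idByName.getOrElse(ruleName, sys.error(s"Rule $ruleName is not registered"))
}
```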
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 3dc1aff..7d725fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -24,15 +24,22 @@ object TreePattern extends Enumeration {
   // Enum Ids start from 0.
   // Expression patterns (alphabetically ordered)
   val ATTRIBUTE_REFERENCE = Value(0)
+  val DYNAMIC_PRUNING_SUBQUERY: Value = Value
+  val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED = Value
   val IN: Value = Value
+  val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
+  val LIST_SUBQUERY: Value = Value
   val LITERAL: Value = Value
   val NULL_LITERAL: Value = Value
+  val PLAN_EXPRESSION: Value = Value
+  val SCALAR_SUBQUERY: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
 
   // Logical plan patterns (alphabetically ordered)
+  val FILTER: Value = Value
   val INNER_LIKE_JOIN: Value = Value
   val JOIN: Value = Value
   val LEFT_SEMI_OR_ANTI_JOIN: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
index d98b7c2..1065519 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
 import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -113,6 +114,9 @@ case class InsertAdaptiveSparkPlan(
    */
   private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
     val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
+    if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
+      return subqueryMap.toMap
+    }
     plan.foreach(_.expressions.foreach(_.foreach {
       case expressions.ScalarSubquery(p, _, exprId)
           if !subqueryMap.contains(exprId.id) =>
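The buildSubqueryMap hunk above shows the cheapest form of pruning: a single bit test against the root's cached pattern bits guards the full expression walk. The same guard shape reduced to its essentials (a sketch; the hypothetical PlanLike trait stands in for SparkPlan):

```scala
import scala.collection.BitSet
import scala.collection.mutable

object EarlyExitSketch {
  // Hypothetical bit positions.
  val SCALAR_SUBQUERY = 0; val IN_SUBQUERY = 1; val DYNAMIC_PRUNING_SUBQUERY = 2

  trait PlanLike {
    def treePatternBits: BitSet
    def containsAnyPattern(ids: Int*): Boolean = ids.exists(treePatternBits)
  }

  def buildSubqueryMap(plan: PlanLike): Map[Long, String] = {
    val subqueryMap = mutable.HashMap.empty[Long, String]
    // One O(1) bit test replaces a walk over every expression in the plan
    // for the common case of a query without subqueries.
    if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
      return subqueryMap.toMap
    }
    // ... otherwise traverse the plan and register each distinct subquery once ...
    subqueryMap.toMap
  }
}
```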
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
index 13ff236..a2e4397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
+  SCALAR_SUBQUERY}
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
 
@@ -27,7 +29,8 @@ case class PlanAdaptiveSubqueries(
     subqueryMap: Map[Long, BaseSubqueryExec]) extends Rule[SparkPlan] {
 
   def apply(plan: SparkPlan): SparkPlan = {
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(
+      _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
       case expressions.ScalarSubquery(_, _, exprId) =>
         execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
       case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
index d3bc4ae..9a05e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DYNAMIC_PRUNING_SUBQUERY
 import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan, SubqueryBroadcastExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins._
@@ -49,7 +50,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession)
       return plan
     }
 
-    plan transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsPattern(DYNAMIC_PRUNING_SUBQUERY)) {
       case DynamicPruningSubquery(
           value, buildPlan, buildKeys, broadcastKeyIndex, onlyInBroadcast, exprId) =>
         val sparkPlan = QueryExecution.createSparkPlan(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 15b8501..f96e9ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression,
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{IN_SUBQUERY, SCALAR_SUBQUERY}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
 
@@ -176,7 +177,7 @@ case class InSubqueryExec(
  */
 case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
-    plan.transformAllExpressions {
+    plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY)) {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = QueryExecution.prepareExecutedPlan(sparkSession, subquery.plan)
         ScalarSubquery(
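A closing note on the TreePatterns.scala hunk above: the object relies on scala.Enumeration id assignment. ATTRIBUTE_REFERENCE pins id 0 and every subsequent Value takes the next integer, which becomes that pattern's bit position, so inserting the new members alphabetically shifts later ids but keeps them dense. A small demonstration:

```scala
object EnumIdDemo extends Enumeration {
  val ATTRIBUTE_REFERENCE = Value(0)                    // pins the starting id
  val DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY = Value // auto-assigned 1 and 2
}

object EnumIdMain extends App {
  EnumIdDemo.values.foreach(v => println(s"$v -> bit ${v.id}"))
  // ATTRIBUTE_REFERENCE -> bit 0
  // DYNAMIC_PRUNING_SUBQUERY -> bit 1
  // EXISTS_SUBQUERY -> bit 2
}
```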