This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 1e94415739c  [SPARK-45586][SQL] Reduce compiler latency for plans with large expression trees
1e94415739c is described below

commit 1e94415739ccfc4222a067459d3cb8be480530b4
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Oct 19 10:24:58 2023 +0800

    [SPARK-45586][SQL] Reduce compiler latency for plans with large expression trees

    ### What changes were proposed in this pull request?

    * Included rule ID pruning when traversing the expression trees in `TypeCoercionRule` (this avoids re-traversing the expression tree in future iterations of the rule).
    * Improved `EquivalentExpressions`:
      * Since `supportedExpression()` only checks for the existence of certain patterns in the tree, changed it to check the `TreePatternBits` instead of recursing with `.exists()`.
      * When creating an `ExpressionEquals` object, calculating the height requires recursing through all of the expression's children, which is O(n^2) when done for every expression in the expression tree. The height is now cached in `TreeNode`, so computing it for every expression in the tree is O(n).
    * More targeted `TreePatternBits` pruning in `ResolveTimeZone` and `ConstantPropagation`.

    ### Why are the changes needed?

    This PR improves some analyzer and optimizer rules to address inefficiencies when handling extremely large expression trees.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    There should be no plan changes, so no unit tests were modified.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43420 from kelvinjian-db/SPARK-45586-large-expr-trees.

    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  4 +++-
 .../sql/catalyst/analysis/timeZoneAnalysis.scala   |  8 ++++----
 .../expressions/EquivalentExpressions.scala        | 23 +++++++---------------
 .../spark/sql/catalyst/optimizer/expressions.scala |  6 ++++--
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  2 ++
 5 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c26569866e5..b34fd873621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -1215,7 +1216,8 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
       } else {
         beforeMapChildren
       }
-      withPropagatedTypes.transformExpressionsUp(typeCoercionFn)
+      withPropagatedTypes.transformExpressionsUpWithPruning(
+        AlwaysProcess.fn, ruleId)(typeCoercionFn)
     }
   }
 }
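The `TypeCoercion` change above relies on rule ID pruning: once a rule has been applied to a subtree without producing any further change, the subtree is marked so that later runs of the same rule can skip it instead of walking it again. Below is a minimal, self-contained sketch of that idea; the `Node` class, the `transformUpWithPruning` signature, and the integer rule IDs are illustrative stand-ins, not Spark's actual `TreeNode` internals.

```scala
import scala.collection.mutable

// Toy tree node, used only to illustrate rule ID pruning.
case class Node(name: String, children: Seq[Node] = Nil) {
  // Rule IDs known to make no further change anywhere in this subtree.
  private val ineffectiveRules: mutable.Set[Int] = mutable.Set.empty

  def transformUpWithPruning(ruleId: Int)(rule: PartialFunction[Node, Node]): Node = {
    if (ineffectiveRules.contains(ruleId)) {
      this // prune: a previous run of this rule already left this subtree unchanged
    } else {
      val newChildren = children.map(_.transformUpWithPruning(ruleId)(rule))
      val afterChildren = if (newChildren == children) this else copy(children = newChildren)
      val result = rule.applyOrElse(afterChildren, identity[Node])
      if (result eq afterChildren) {
        result.ineffectiveRules += ruleId // remember, so the next run skips this subtree
      }
      result
    }
  }
}

object RuleIdPruningDemo extends App {
  val tree = Node("add", Seq(Node("cast", Seq(Node("col"))), Node("lit")))
  val noOpRule: PartialFunction[Node, Node] = { case n @ Node("lit", _) => n }
  val once = tree.transformUpWithPruning(ruleId = 42)(noOpRule)
  // The second run with the same rule ID returns immediately at the root.
  assert(once.transformUpWithPruning(ruleId = 42)(noOpRule) eq once)
}
```

In the diff above, `AlwaysProcess.fn` is passed as the pruning condition, so the savings for `TypeCoercionRule` come entirely from the `ruleId` bookkeeping rather than from tree-pattern checks.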
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
index 11a5bc99b6c..01d88f050ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.TIME_ZONE_AWARE_EXPRESSION
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -40,10 +40,10 @@ object ResolveTimeZone extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
-    _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION), ruleId
-  )(transformTimeZoneExprs)
+    _.containsPattern(TIME_ZONE_AWARE_EXPRESSION), ruleId)(transformTimeZoneExprs)
 
-  def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
+  def resolveTimeZones(e: Expression): Expression = e.transformWithPruning(
+    _.containsPattern(TIME_ZONE_AWARE_EXPRESSION))(transformTimeZoneExprs)
 }
 
 /**
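The `EquivalentExpressions` and `TreeNode` changes that follow lean on two per-node caches: a precomputed `TreePatternBits` set describing which patterns occur anywhere in a subtree (so `containsPattern` is a constant-time lookup rather than an `.exists()` traversal), and a lazily computed subtree height (so asking every node for its height costs O(n) overall instead of O(n^2)). Here is a rough standalone sketch of both ideas; the `Expr` hierarchy and the integer pattern constants are hypothetical, not Spark's `TreePattern`/`TreeNode` implementation.

```scala
object SubtreeCachesSketch {
  // Hypothetical pattern IDs, standing in for TreePattern.LITERAL, BINARY_COMPARISON, etc.
  val LITERAL = 0
  val BINARY_COMPARISON = 1

  abstract class Expr {
    def children: Seq[Expr]
    protected def nodePatterns: Seq[Int] = Nil

    // Union of this node's patterns and its children's cached bits, computed once per node.
    lazy val patternBits: Set[Int] =
      children.foldLeft(nodePatterns.toSet)(_ ++ _.patternBits)

    def containsPattern(p: Int): Boolean = patternBits.contains(p)

    // Cached subtree height; each node's value reuses its children's cached values.
    lazy val height: Int = children.map(_.height).reduceOption(_ max _).getOrElse(0) + 1
  }

  case class Lit(value: Int) extends Expr {
    def children: Seq[Expr] = Nil
    override protected def nodePatterns: Seq[Int] = Seq(LITERAL)
  }

  case class LessThan(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
    override protected def nodePatterns: Seq[Int] = Seq(BINARY_COMPARISON)
  }

  def main(args: Array[String]): Unit = {
    val expr = LessThan(Lit(1), LessThan(Lit(2), Lit(3)))
    println(expr.containsPattern(BINARY_COMPARISON)) // true, answered from the cached bits
    println(expr.height)                             // 3
  }
}
```

With the height cached on the tree node itself, `ExpressionEquals.height` can simply delegate to `e.height`, which is what the diff below does.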
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 1a84859cc3a..8738015ce91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -22,7 +22,7 @@ import java.util.Objects
 
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LAMBDA_VARIABLE, PLAN_EXPRESSION}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
@@ -163,18 +163,13 @@ class EquivalentExpressions(
     case _ => Nil
   }
 
-  private def supportedExpression(e: Expression) = {
-    !e.exists {
-      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
-      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
-      case _: LambdaVariable => true
-
+  private def supportedExpression(e: Expression): Boolean = {
+    // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
+    // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
+    !(e.containsPattern(LAMBDA_VARIABLE) ||
       // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
      // can cause error like NPE.
-      case _: PlanExpression[_] => Utils.isInRunningSparkTask
-
-      case _ => false
-    }
+      (e.containsPattern(PLAN_EXPRESSION) && Utils.isInRunningSparkTask))
   }
 
   /**
@@ -244,13 +239,9 @@ class EquivalentExpressions(
  * Wrapper around an Expression that provides semantic equality.
  */
 case class ExpressionEquals(e: Expression) {
-  private def getHeight(tree: Expression): Int = {
-    tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
-  }
-
   // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
   // only be a parent of expr2 if expr1.height is larger than expr2.height.
-  lazy val height = getHeight(e)
+  def height: Int = e.height
 
   override def equals(o: Any): Boolean = o match {
     case other: ExpressionEquals => e.semanticEquals(other.e) && height == other.height
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index cc14789f6f5..91d5e180c59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -116,7 +116,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
  */
 object ConstantPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
-    _.containsAllPatterns(LITERAL, FILTER), ruleId) {
+    _.containsAllPatterns(LITERAL, FILTER, BINARY_COMPARISON), ruleId) {
     case f: Filter =>
       val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
       if (newCondition.isDefined) {
@@ -147,6 +147,8 @@ object ConstantPropagation extends Rule[LogicalPlan] {
 
   private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
     : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = condition match {
+    case _ if !condition.containsAllPatterns(LITERAL, BINARY_COMPARISON) =>
+      (None, AttributeMap.empty)
     case e @ EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) =>
       (None, AttributeMap(Map(left -> (right, e))))
 
@@ -206,7 +208,7 @@ object ConstantPropagation extends Rule[LogicalPlan] {
       equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
     val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
     val predicates = equalityPredicates.values.map(_._2).toSet
-    condition transform {
+    condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
       case b: BinaryComparison if !predicates.contains(b) => b transform {
         case a: AttributeReference => constantsMap.getOrElse(a, a)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a34ad10f36a..cc470d0de6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -173,6 +173,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
 
   lazy val containsChild: Set[TreeNode[_]] = children.toSet
 
+  lazy val height: Int = children.map(_.height).reduceOption(_ max _).getOrElse(0) + 1
+
   // Copied from Scala 2.13.1
   // github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73
   // to prevent the issue https://github.com/scala/bug/issues/10495

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org