This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 01c294b05f3a [SPARK-45760][SQL] Add With expression to avoid duplicating expressions 01c294b05f3a is described below commit 01c294b05f3a9b7bd87cda0ee8b0160f5f58bb24 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Wed Nov 8 00:57:31 2023 +0800 [SPARK-45760][SQL] Add With expression to avoid duplicating expressions ### What changes were proposed in this pull request? Sometimes we need to duplicate expressions when rewriting the plan. This is OK for small queries, as codegen has common-subexpression-elimination (CSE) to avoid evaluating the same expression more than once. However, when the query is big, duplicating expressions can lead to a very big expression tree and make catalyst rules very slow, or even cause an OOM when updating a leaf node (which requires copying all tree nodes). This PR introduces a new expression to do expression-level CTE: it adds a Project to pre-evaluate the common expressions, so that they appear only once on the query plan tree, and are evaluated only once. `NullIf` now uses this new expression to avoid duplicating the `left` child expression. ### Why are the changes needed? To make catalyst more efficient. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? new test suite ### Was this patch authored or co-authored using generative AI tooling? No Closes #43623 from cloud-fan/with. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../explain-results/function_count_if.explain | 5 +- .../explain-results/function_regexp_substr.explain | 5 +- .../sql/connect/ProtoToParsedPlanTestSuite.scala | 15 +- .../spark/sql/catalyst/expressions/With.scala | 63 +++++++++ .../sql/catalyst/expressions/nullExpressions.scala | 6 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 + .../catalyst/optimizer/RewriteWithExpression.scala | 90 ++++++++++++ .../spark/sql/catalyst/trees/TreePatterns.scala | 2 + .../optimizer/RewriteWithExpressionSuite.scala | 157 +++++++++++++++++++++ 9 files changed, 338 insertions(+), 8 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain index 1c23bbf6bce5..f2ada15eccb7 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain @@ -1,2 +1,3 @@ -Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a > 0))#0L] -+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] + +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain index 69fc760c8291..1811f770f829 100644 --- 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain @@ -1,2 +1,3 @@ -Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0] -+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0] ++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0] + +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index 9fdaffcba670..e0c4e21503e9 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -29,7 +29,9 @@ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog -import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions +import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions, RewriteWithExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.SessionHolder @@ -181,8 +183,15 @@ class 
ProtoToParsedPlanTestSuite val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark)) val catalystPlan = analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker) - val actual = - removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString) + val finalAnalyzedPlan = { + object Helper extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Finish Analysis", Once, ReplaceExpressions) :: + Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + } + Helper.execute(catalystPlan) + } + val actual = removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString) val goldenFile = goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain") Try(readGoldenFile(goldenFile)) match { case Success(expected) if expected == actual => // Test passes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala new file mode 100644 index 000000000000..bfed63af1740 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.types.DataType + +/** + * An expression holder that keeps a list of common expressions and allow the actual expression to + * reference these common expressions. The common expressions are guaranteed to be evaluated only + * once even if it's referenced more than once. This is similar to CTE but is expression-level. + */ +case class With(child: Expression, defs: Seq[CommonExpressionDef]) + extends Expression with Unevaluable { + override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) + override def dataType: DataType = child.dataType + override def nullable: Boolean = child.nullable + override def children: Seq[Expression] = child +: defs + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + copy(child = newChildren.head, defs = newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])) + } +} + +/** + * A wrapper of common expression to carry the id. + */ +case class CommonExpressionDef(child: Expression, id: Long = CommonExpressionDef.newId) + extends UnaryExpression with Unevaluable { + override def dataType: DataType = child.dataType + override protected def withNewChildInternal(newChild: Expression): Expression = + copy(child = newChild) +} + +/** + * A reference to the common expression by its id. Only resolved common expressions can be + * referenced, so that we can determine the data type and nullable of the reference node. 
+ */ +case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean) + extends LeafExpression with Unevaluable { + def this(exprDef: CommonExpressionDef) = this(exprDef.id, exprDef.dataType, exprDef.nullable) + override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF) +} + +object CommonExpressionDef { + private[sql] val curId = new java.util.concurrent.atomic.AtomicLong() + def newId: Long = curId.getAndIncrement() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 948cb6fbedd3..0e9e375b8acf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -154,7 +154,11 @@ case class NullIf(left: Expression, right: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { def this(left: Expression, right: Expression) = { - this(left, right, If(EqualTo(left, right), Literal.create(null, left.dataType), left)) + this(left, right, { + val commonExpr = CommonExpressionDef(left) + val ref = new CommonExpressionRef(commonExpr) + With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref), Seq(commonExpr)) + }) } override def parameters: Seq[Expression] = Seq(left, right) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 48ecb9aee211..decef766ae97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -147,6 +147,9 @@ abstract class Optimizer(catalogManager: CatalogManager) val batches = ( Batch("Finish Analysis", Once, FinishAnalysis) :: + 
// We must run this batch after `ReplaceExpressions`, as `RuntimeReplaceable` expression + // may produce `With` expressions that need to be rewritten. + Batch("Rewrite With expression", Once, RewriteWithExpression) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala new file mode 100644 index 000000000000..c5bd71b4a7d1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} + +/** + * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or + * just inline them if they are cheap. + * + * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its + * usage, we should support aggregate/window functions as well. + */ +object RewriteWithExpression extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) { + case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => + var newChildren = p.children + var newPlan: LogicalPlan = p.transformExpressionsUp { + case With(child, defs) => + val refToExpr = mutable.HashMap.empty[Long, Expression] + val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias]) + + defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => + if (CollapseProject.isCheap(child)) { + refToExpr(id) = child + } else { + val childProjectionIndex = newChildren.indexWhere( + c => child.references.subsetOf(c.outputSet) + ) + if (childProjectionIndex == -1) { + // When we cannot rewrite the common expressions, force to inline them so that the + // query can still run. This can happen if the join condition contains `With` and + // the common expression references columns from both join sides. + // TODO: things can go wrong if the common expression is nondeterministic. 
We + // don't fix it for now to match the old buggy behavior when certain + // `RuntimeReplaceable` did not use the `With` expression. + // TODO: we should calculate the ref count and also inline the common expression + // if it's ref count is 1. + refToExpr(id) = child + } else { + val alias = Alias(child, s"_common_expr_$index")() + childProjections(childProjectionIndex) += alias + refToExpr(id) = alias.toAttribute + } + } + } + + newChildren = newChildren.zip(childProjections).map { case (child, projections) => + if (projections.nonEmpty) { + Project(child.output ++ projections, child) + } else { + child + } + } + + child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) { + case ref: CommonExpressionRef => refToExpr(ref.id) + } + } + + newPlan = newPlan.withNewChildren(newChildren) + if (p.output == newPlan.output) { + newPlan + } else { + Project(p.output, newPlan) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 8b714d5a5d28..9b3337d1a940 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -36,6 +36,7 @@ object TreePattern extends Enumeration { val CASE_WHEN: Value = Value val CAST: Value = Value val COALESCE: Value = Value + val COMMON_EXPR_REF: Value = Value val CONCAT: Value = Value val COUNT: Value = Value val CREATE_NAMED_STRUCT: Value = Value @@ -132,6 +133,7 @@ object TreePattern extends Enumeration { val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WINDOW_GROUP_LIMIT: Value = Value + val WITH_EXPRESSION: Value = Value val WITH_WINDOW_DEFINITION: Value = Value // Unresolved expression patterns (Alphabetically ordered) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala new file mode 100644 index 000000000000..c625379eb5ff --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.IntegerType + +class RewriteWithExpressionSuite extends PlanTest { + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + } + + private val testRelation = LocalRelation($"a".int, $"b".int) + private val testRelation2 = LocalRelation($"x".int, $"y".int) + + test("simple common expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col")) + comparePlans(Optimizer.execute(plan), testRelation.select((a + a).as("col"))) + } + + test("non-cheap common expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col")) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .select(($"$commonExprName" * $"$commonExprName").as("col")) + .analyze + ) + } + + test("nested WITH expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val innerExpr = With(ref + ref, Seq(commonExprDef)) + val innerCommonExprName = "_common_expr_0" + + val b = testRelation.output.last + 
val outerCommonExprDef = CommonExpressionDef(innerExpr + b) + val outerRef = new CommonExpressionRef(outerCommonExprDef) + val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef)) + val outerCommonExprName = "_common_expr_0" + + val plan = testRelation.select(outerExpr.as("col")) + val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b) + .as(outerCommonExprName) + val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)( + exprId = rewrittenOuterExpr.exprId) + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*) + .select((testRelation.output :+ $"$innerCommonExprName" :+ rewrittenOuterExpr): _*) + .select((outerExprAttr * outerExprAttr).as("col")) + .analyze + ) + } + + test("WITH expression in filter") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef))) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .where($"$commonExprName" < 10 && $"$commonExprName" > 0) + .select(testRelation.output: _*) + .analyze + ) + } + + test("WITH expression in join condition: only reference left child") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val plan = testRelation.join(testRelation2, condition = Some(condition)) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .join(testRelation2, condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)) + .select((testRelation.output ++ testRelation2.output): 
_*) + .analyze + ) + } + + test("WITH expression in join condition: only reference right child") { + val x = testRelation2.output.head + val commonExprDef = CommonExpressionDef(x + x) + val ref = new CommonExpressionRef(commonExprDef) + val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val plan = testRelation.join(testRelation2, condition = Some(condition)) + val commonExprName = "_common_expr_0" + comparePlans( + Optimizer.execute(plan), + testRelation + .join( + testRelation2.select((testRelation2.output :+ (x + x).as(commonExprName)): _*), + condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0) + ) + .select((testRelation.output ++ testRelation2.output): _*) + .analyze + ) + } + + test("WITH expression in join condition: reference both children") { + val a = testRelation.output.head + val x = testRelation2.output.head + val commonExprDef = CommonExpressionDef(a + x) + val ref = new CommonExpressionRef(commonExprDef) + val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val plan = testRelation.join(testRelation2, condition = Some(condition)) + comparePlans( + Optimizer.execute(plan), + testRelation + .join( + testRelation2, + // Can't pre-evaluate, have to inline + condition = Some((a + x) < 10 && (a + x) > 0) + ) + ) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org