This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 492fcd83258b [SPARK-50683][SQL] Inline the common expression in With if used once
492fcd83258b is described below
commit 492fcd83258bcd5a41f60a7b44b0ab6d4c9916b0
Author: zml1206 <[email protected]>
AuthorDate: Thu Jan 2 12:37:15 2025 +0800
[SPARK-50683][SQL] Inline the common expression in With if used once
### What changes were proposed in this pull request?
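When a common expression defined in a `With` expression is referenced only once within the plan node's expressions, inline it directly at its reference site. Previously it was projected as a separate `_common_expr_N` column in an extra `Project`. Common expressions that are referenced more than once, and are not cheap, are still projected as before.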
### Why are the changes needed?
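This simplifies the optimized plan. It avoids adding an unnecessary `Project` just to materialize a common expression that is evaluated only once anyway.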
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
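Unit tests: added a new test to `RewriteWithExpressionSuite` and updated an existing nested-`With` test so that its inner common expression is referenced twice and is still projected.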
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49310 from zml1206/with.
Authored-by: zml1206 <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalyst/optimizer/RewriteWithExpression.scala | 21 +++++++++++++++------
.../optimizer/RewriteWithExpressionSuite.scala | 14 ++++++++++++--
2 files changed, 27 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 40189a9f6102..5d85e89e1eab 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
private def applyInternal(p: LogicalPlan): LogicalPlan = {
val inputPlans = p.children
+ val commonExprIdSet = p.expressions
+ .flatMap(_.collect { case r: CommonExpressionRef => r.id })
+ .groupBy(identity)
+ .transform((_, v) => v.size)
+ .filter(_._2 > 1)
+ .keySet
val commonExprsPerChild =
Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
var newPlan: LogicalPlan = p.mapExpressions { expr =>
- rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
+ rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild,
commonExprIdSet)
}
val newChildren = inputPlans.zip(commonExprsPerChild).map { case
(inputPlan, commonExprs) =>
if (commonExprs.isEmpty) {
@@ -96,6 +102,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
e: Expression,
inputPlans: Seq[LogicalPlan],
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
+ commonExprIdSet: Set[CommonExpressionId],
isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
@@ -103,9 +110,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
case w: With if !isNestedWith =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(
- w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
+ w.child, inputPlans, commonExprsPerChild, commonExprIdSet,
isNestedWith = true)
val defs = w.defs.map(rewriteWithExprAndInputPlans(
- _, inputPlans, commonExprsPerChild, isNestedWith = true))
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith =
true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id),
index) =>
@@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
"Cannot rewrite canonicalized Common expression definitions")
}
- if (CollapseProject.isCheap(child)) {
+ if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id))
{
refToExpr(id) = child
} else {
val childPlanIndex = inputPlans.indexWhere(
@@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
- rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild,
isNestedWith))
+ rewriteWithExprAndInputPlans(
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}
case other => other.mapChildren(
- rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild,
isNestedWith)
+ rewriteWithExprAndInputPlans(
+ _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
)
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 9f0a7fdaf315..8918b58ca1b5 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
- val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
+ val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
@@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which
references both common
// expressions.
- .select(($"_common_expr_0" + ($"_common_expr_2" +
$"_common_expr_0")).as("col"))
+ .select(($"_common_expr_0" +
+ ($"_common_expr_2" + $"_common_expr_2" +
$"_common_expr_0")).as("col"))
.analyze
)
}
@@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
intercept[AssertionError](Optimizer.execute(wrongPlan))
}
+
+ test("SPARK-50683: inline the common expression in With if used once") {
+ val a = testRelation.output.head
+ val exprDef = CommonExpressionDef(a + a)
+ val exprRef = new CommonExpressionRef(exprDef)
+ val expr = With(exprRef + 1, Seq(exprDef))
+ val plan = testRelation.select(expr.as("col"))
+ comparePlans(Optimizer.execute(plan), testRelation.select((a + a +
1).as("col")))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]