This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 492fcd83258b [SPARK-50683][SQL] Inline the common expression in With if used once
492fcd83258b is described below

commit 492fcd83258bcd5a41f60a7b44b0ab6d4c9916b0
Author: zml1206 <[email protected]>
AuthorDate: Thu Jan 2 12:37:15 2025 +0800

    [SPARK-50683][SQL] Inline the common expression in With if used once
    
    ### What changes were proposed in this pull request?
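    Inline the definition of a common expression in a `With` expression directly at its reference site when it is referenced only once. Previously it was first materialized as a `_common_expr_N` alias in an extra `Project`. Common expressions that are referenced more than once and are not cheap are still extracted as before.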
    
    ### Why are the changes needed?
    
    This simplifies the optimized plan by avoiding an unnecessary `Project` node for a common expression that is evaluated only once.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
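    Unit tests in `RewriteWithExpressionSuite`:
    - Added a test for the single-use case.
    - Updated an existing nested-`With` test so that its inner common expression is referenced twice and is therefore still extracted.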
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49310 from zml1206/with.
    
    Authored-by: zml1206 <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../catalyst/optimizer/RewriteWithExpression.scala  | 21 +++++++++++++++------
 .../optimizer/RewriteWithExpressionSuite.scala      | 14 ++++++++++++--
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 40189a9f6102..5d85e89e1eab 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
 
   private def applyInternal(p: LogicalPlan): LogicalPlan = {
     val inputPlans = p.children
+    val commonExprIdSet = p.expressions
+      .flatMap(_.collect { case r: CommonExpressionRef => r.id })
+      .groupBy(identity)
+      .transform((_, v) => v.size)
+      .filter(_._2 > 1)
+      .keySet
     val commonExprsPerChild = 
Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
     var newPlan: LogicalPlan = p.mapExpressions { expr =>
-      rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
+      rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, 
commonExprIdSet)
     }
     val newChildren = inputPlans.zip(commonExprsPerChild).map { case 
(inputPlan, commonExprs) =>
       if (commonExprs.isEmpty) {
@@ -96,6 +102,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
       e: Expression,
       inputPlans: Seq[LogicalPlan],
       commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
+      commonExprIdSet: Set[CommonExpressionId],
       isNestedWith: Boolean = false): Expression = {
     if (!e.containsPattern(WITH_EXPRESSION)) return e
     e match {
@@ -103,9 +110,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
       case w: With if !isNestedWith =>
         // Rewrite nested With expressions first
         val child = rewriteWithExprAndInputPlans(
-          w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
+          w.child, inputPlans, commonExprsPerChild, commonExprIdSet, 
isNestedWith = true)
         val defs = w.defs.map(rewriteWithExprAndInputPlans(
-          _, inputPlans, commonExprsPerChild, isNestedWith = true))
+          _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = 
true))
         val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
 
         defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), 
index) =>
@@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
               "Cannot rewrite canonicalized Common expression definitions")
           }
 
-          if (CollapseProject.isCheap(child)) {
+          if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) 
{
             refToExpr(id) = child
           } else {
             val childPlanIndex = inputPlans.indexWhere(
@@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
 
       case c: ConditionalExpression =>
         val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
-          rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, 
isNestedWith))
+          rewriteWithExprAndInputPlans(
+            _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
         val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
         // Use transformUp to handle nested With.
         newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
         }
 
       case other => other.mapChildren(
-        rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, 
isNestedWith)
+        rewriteWithExprAndInputPlans(
+          _, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
       )
     }
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 9f0a7fdaf315..8918b58ca1b5 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
     val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
     val ref2 = new CommonExpressionRef(commonExprDef2)
     // The inner main expression references the outer expression
-    val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
+    val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
     val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
     comparePlans(
       Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
@@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
         .select(star(), (a + a).as("_common_expr_2"))
         // The final Project contains the final result expression, which 
references both common
         // expressions.
-        .select(($"_common_expr_0" + ($"_common_expr_2" + 
$"_common_expr_0")).as("col"))
+        .select(($"_common_expr_0" +
+          ($"_common_expr_2" + $"_common_expr_2" + 
$"_common_expr_0")).as("col"))
         .analyze
     )
   }
@@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
     val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
     intercept[AssertionError](Optimizer.execute(wrongPlan))
   }
+
+  test("SPARK-50683: inline the common expression in With if used once") {
+    val a = testRelation.output.head
+    val exprDef = CommonExpressionDef(a + a)
+    val exprRef = new CommonExpressionRef(exprDef)
+    val expr = With(exprRef + 1, Seq(exprDef))
+    val plan = testRelation.select(expr.as("col"))
+    comparePlans(Optimizer.execute(plan), testRelation.select((a + a + 
1).as("col")))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to