This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new df08177de2fd [SPARK-48416][SQL] Support nested correlated With 
expression
df08177de2fd is described below

commit df08177de2fd2b177caf79ca533eb0cd2c6a4ba6
Author: Wenchen Fan <[email protected]>
AuthorDate: Thu Dec 12 15:36:09 2024 -0800

    [SPARK-48416][SQL] Support nested correlated With expression
    
    ### What changes were proposed in this pull request?
    
    The inner `With` may reference common expressions of an outer `With`. This 
PR supports this case by making the rule `RewriteWithExpression` rewrite only 
top-level `With` expressions, and by running the rule repeatedly so that an 
inner `With` expression becomes a top-level `With` after one iteration and 
gets rewritten in the next iteration.
    
    ### Why are the changes needed?
    
    To support optimized filter pushdown with `With` expressions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Updated the unit tests in `RewriteWithExpressionSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #49093 from cloud-fan/with.
    
    Lead-authored-by: Wenchen Fan <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  2 +-
 .../catalyst/optimizer/RewriteWithExpression.scala | 25 ++++-----
 .../optimizer/RewriteWithExpressionSuite.scala     | 61 ++++++++++++++--------
 3 files changed, 50 insertions(+), 38 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7ec467badce5..31c1f8917763 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -160,7 +160,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Finish Analysis", FixedPoint(1), FinishAnalysis) ::
     // We must run this batch after `ReplaceExpressions`, as 
`RuntimeReplaceable` expression
     // may produce `With` expressions that need to be rewritten.
-    Batch("Rewrite With expression", Once, RewriteWithExpression) ::
+    Batch("Rewrite With expression", fixedPoint, RewriteWithExpression) ::
     
//////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     
//////////////////////////////////////////////////////////////////////////////////////////
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 393a66f7c1e4..d0c5d8158644 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -85,21 +85,19 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
 
   private def rewriteWithExprAndInputPlans(
       e: Expression,
-      inputPlans: Array[LogicalPlan]): Expression = {
+      inputPlans: Array[LogicalPlan],
+      isNestedWith: Boolean = false): Expression = {
     if (!e.containsPattern(WITH_EXPRESSION)) return e
     e match {
-      case w: With =>
+      // Do not handle nested With in one pass. Leave it to the next rule 
executor batch.
+      case w: With if !isNestedWith =>
         // Rewrite nested With expressions first
-        val child = rewriteWithExprAndInputPlans(w.child, inputPlans)
-        val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans))
+        val child = rewriteWithExprAndInputPlans(w.child, inputPlans, 
isNestedWith = true)
+        val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans, 
isNestedWith = true))
         val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]
         val childProjections = 
Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias])
 
         defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), 
index) =>
-          if (child.containsPattern(COMMON_EXPR_REF)) {
-            throw SparkException.internalError(
-              "Common expression definition cannot reference other Common 
expression definitions")
-          }
           if (id.canonicalized) {
             throw SparkException.internalError(
               "Cannot rewrite canonicalized Common expression definitions")
@@ -148,10 +146,9 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
         }
 
         child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
-          case ref: CommonExpressionRef =>
-            if (!refToExpr.contains(ref.id)) {
-              throw SparkException.internalError("Undefined common expression 
id " + ref.id)
-            }
+          // `child` may contain nested With and we only replace 
`CommonExpressionRef` that
+          // references common expressions in the current `With`.
+          case ref: CommonExpressionRef if refToExpr.contains(ref.id) =>
             if (ref.id.canonicalized) {
               throw SparkException.internalError(
                 "Cannot rewrite canonicalized Common expression references")
@@ -161,7 +158,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
 
       case c: ConditionalExpression =>
         val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
-          rewriteWithExprAndInputPlans(_, inputPlans))
+          rewriteWithExprAndInputPlans(_, inputPlans, isNestedWith))
         val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
         // Use transformUp to handle nested With.
         newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
@@ -174,7 +171,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
             }
         }
 
-      case other => other.mapChildren(rewriteWithExprAndInputPlans(_, 
inputPlans))
+      case other => other.mapChildren(rewriteWithExprAndInputPlans(_, 
inputPlans, isNestedWith))
     }
   }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 0aeca961aa51..0be6ae649464 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -29,7 +28,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once,
+    val batches = Batch("Rewrite With expression", FixedPoint(5),
       PullOutGroupingExpressions,
       RewriteWithExpression) :: Nil
   }
@@ -84,13 +83,11 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref * ref
     }
 
-    val plan = testRelation.select(outerExpr.as("col"))
     comparePlans(
-      Optimizer.execute(plan),
+      Optimizer.execute(testRelation.select(outerExpr.as("col"))),
       testRelation
-        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
-        .select((testRelation.output ++ Seq($"_common_expr_0",
-          ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))): 
_*)
+        .select(star(), (a + a).as("_common_expr_0"))
+        .select(a, b, ($"_common_expr_0" + $"_common_expr_0" + 
b).as("_common_expr_1"))
         .select(($"_common_expr_1" * $"_common_expr_1").as("col"))
         .analyze
     )
@@ -104,42 +101,60 @@ class RewriteWithExpressionSuite extends PlanTest {
     val outerExpr = With(b + b) { case Seq(ref) =>
       ref * ref + innerExpr
     }
-
-    val plan = testRelation.select(outerExpr.as("col"))
-    val rewrittenInnerExpr = (a + a).as("_common_expr_0")
-    val rewrittenOuterExpr = (b + b).as("_common_expr_1")
-    val finalExpr = rewrittenOuterExpr.toAttribute * 
rewrittenOuterExpr.toAttribute +
-      (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
+    val finalExpr = $"_common_expr_1" * $"_common_expr_1" + ($"_common_expr_0" 
+ $"_common_expr_0")
     comparePlans(
-      Optimizer.execute(plan),
+      Optimizer.execute(testRelation.select(outerExpr.as("col"))),
       testRelation
-        .select((testRelation.output :+ rewrittenInnerExpr): _*)
-        .select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ 
rewrittenOuterExpr): _*)
+        .select(star(), (b + b).as("_common_expr_1"))
+        .select(star(), (a + a).as("_common_expr_0"))
         .select(finalExpr.as("col"))
         .analyze
     )
   }
 
-  test("correlated nested WITH expression is not supported") {
+  test("correlated nested WITH expression is supported") {
     val Seq(a, b) = testRelation.output
     val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
     val outerRef = new CommonExpressionRef(outerCommonExprDef)
+    val rewrittenOuterExpr = (b + b).as("_common_expr_0")
 
     // The inner expression definition references the outer expression
     val commonExprDef1 = CommonExpressionDef(a + a + outerRef, 
CommonExpressionId(1))
     val ref1 = new CommonExpressionRef(commonExprDef1)
     val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))
-
     val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef))
-    
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col"))))
+    comparePlans(
+      Optimizer.execute(testRelation.select(outerExpr1.as("col"))),
+      testRelation
+        // The first Project contains the common expression of the outer With
+        .select(star(), rewrittenOuterExpr)
+        // The second Project contains the common expression of the inner 
With, which references
+        // the common expression of the outer With.
+        .select(star(), (a + a + $"_common_expr_0").as("_common_expr_1"))
+        // The final Project contains the final result expression, which 
references both common
+        // expressions.
+        .select(($"_common_expr_0" + ($"_common_expr_1" + 
$"_common_expr_1")).as("col"))
+        .analyze
+    )
 
-    val commonExprDef2 = CommonExpressionDef(a + a)
+    val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
     val ref2 = new CommonExpressionRef(commonExprDef2)
     // The inner main expression references the outer expression
-    val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1))
-
+    val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
     val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
-    
intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col"))))
+    comparePlans(
+      Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
+      testRelation
+        // The first Project contains the common expression of the outer With
+        .select(star(), rewrittenOuterExpr)
+        // The second Project contains the common expression of the inner 
With, which does not
+        // reference the common expression of the outer With.
+        .select(star(), (a + a).as("_common_expr_2"))
+        // The final Project contains the final result expression, which 
references both common
+        // expressions.
+        .select(($"_common_expr_0" + ($"_common_expr_2" + 
$"_common_expr_0")).as("col"))
+        .analyze
+    )
   }
 
   test("WITH expression in filter") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to