This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 2ee37d24bc3c [SPARK-50739][SQL][FOLLOW] Simplify ResolveRecursiveCTESuite with dsl
2ee37d24bc3c is described below

commit 2ee37d24bc3c1d7bbf0a2dff8fc6884c5547038b
Author: Wenchen Fan <[email protected]>
AuthorDate: Sat Jan 18 14:54:09 2025 -0800

    [SPARK-50739][SQL][FOLLOW] Simplify ResolveRecursiveCTESuite with dsl
    
    ### What changes were proposed in this pull request?
    
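    A follow-up to https://github.com/apache/spark/pull/49351 that simplifies `ResolveRecursiveCTESuite` by building the test plans with the Catalyst DSL. It also removes the `invalidRecursiveCteError` helper from `QueryCompilationErrors`; `ResolveWithCTE` now throws the `INVALID_RECURSIVE_CTE` `AnalysisException` directly.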
    
    ### Why are the changes needed?
    
    code cleanup
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    N/A
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #49557 from cloud-fan/clean.
    
    Authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit 66dd7ddc4be352eebe20d7a3d53ff4370f564754)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/catalyst/analysis/ResolveWithCTE.scala     |   7 +-
 .../spark/sql/errors/QueryCompilationErrors.scala  |   8 --
 .../analysis/ResolveRecursiveCTESuite.scala        | 129 +++++++++------------
 3 files changed, 57 insertions(+), 87 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
index f9e5c58b6ff8..3ad88514e17c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
 
 import scala.collection.mutable
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
-import org.apache.spark.sql.errors.QueryCompilationErrors
 
 /**
  * Updates CTE references with the resolve output attributes of corresponding 
CTE definitions.
@@ -144,8 +144,9 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
                 // Project (as UnresolvedSubqueryColumnAliases have not been 
substituted with the
                 // Project yet), leaving us with cases of SubqueryAlias->Union 
and SubqueryAlias->
                 // UnresolvedSubqueryColumnAliases->Union. The same applies to 
Distinct Union.
-                throw QueryCompilationErrors.invalidRecursiveCteError(
-                  "Unsupported recursive CTE UNION placement.")
+                throw new AnalysisException(
+                  errorClass = "INVALID_RECURSIVE_CTE",
+                  messageParameters = Map.empty)
             }
         }
         withCTE.copy(cteDefs = newCTEDefs)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index aa72a37fcb7e..e6f8f0d73d7d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -4332,12 +4332,4 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
       origin = origin
     )
   }
-
-  def invalidRecursiveCteError(error: String): Throwable = {
-    new AnalysisException(
-      errorClass = "INVALID_RECURSIVE_CTE",
-      messageParameters = Map(
-        "error" -> error
-      ))
-  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala
index 7d5a2c5babc7..e76e261223eb 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveRecursiveCTESuite.scala
@@ -18,102 +18,79 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import 
org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.ResolveSubqueryColumnAliases
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
 class ResolveRecursiveCTESuite extends AnalysisTest {
   // Motivated by:
   // WITH RECURSIVE t AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM t;
   test("ResolveWithCTE rule on recursive CTE without 
UnresolvedSubqueryColumnAliases") {
-    // The analyzer will repeat ResolveWithCTE rule twice.
-    val rules = Seq(ResolveWithCTE, ResolveWithCTE)
-    val analyzer = new RuleExecutor[LogicalPlan] {
-      override val batches = Seq(Batch("Resolution", Once, rules: _*))
-    }
-    // Since cteDef IDs need to be the same, cteDef for each case will be 
created by copying
-    // this one with its child replaced.
-    val cteDef = CTERelationDef(OneRowRelation())
-    val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation())
-
-    def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = {
-      val recursionPart = SubqueryAlias("t",
-          CTERelationRef(cteDef.id, false, Seq(), false, recursive = true))
-
-      val cteDefFinal = cteDef.copy(child =
-        SubqueryAlias("t", Union(Seq(anchor, recursionPart))))
-
+    val cteId = 0
+    val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation())
+
+    def getBeforePlan(): LogicalPlan = {
+      val cteRef = CTERelationRef(
+        cteId,
+        _resolved = false,
+        output = Seq(),
+        isStreaming = false)
+      val recursion = cteRef.copy(recursive = true).subquery("t")
       WithCTE(
-        SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, 
recursive = false)),
-        Seq(cteDefFinal))
+        cteRef.copy(recursive = false),
+        Seq(CTERelationDef(anchor.union(recursion).subquery("t"), cteId)))
     }
 
-    def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = {
-      val saRecursion = SubqueryAlias("t",
-        UnionLoopRef(cteDef.id, anchor.output, false))
-
-      val cteDefFinal = cteDef.copy(child =
-        SubqueryAlias("t", UnionLoop(cteDef.id, anchor, saRecursion)))
-
-      val outerCteRef = CTERelationRef(cteDefFinal.id, true, 
cteDefFinal.output, false,
-        recursive = false)
-
-      WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal))
+    def getAfterPlan(): LogicalPlan = {
+      val recursion = UnionLoopRef(cteId, anchor.output, accumulated = 
false).subquery("t")
+      val cteDef = CTERelationDef(UnionLoop(cteId, anchor, 
recursion).subquery("t"), cteId)
+      val cteRef = CTERelationRef(
+        cteId,
+        _resolved = true,
+        output = cteDef.output,
+        isStreaming = false)
+      WithCTE(cteRef, Seq(cteDef))
     }
 
-    val beforePlan = getBeforePlan(cteDef)
-    val afterPlan = getAfterPlan(cteDef)
-
-    comparePlans(analyzer.execute(beforePlan), afterPlan)
+    comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
   }
 
   // Motivated by:
   // WITH RECURSIVE t(n) AS (SELECT 1 UNION ALL SELECT * FROM t) SELECT * FROM 
t;
   test("ResolveWithCTE rule on recursive CTE with 
UnresolvedSubqueryColumnAliases") {
-    // The analyzer will repeat ResolveWithCTE rule twice.
-    val rules = Seq(ResolveWithCTE, ResolveSubqueryColumnAliases, 
ResolveWithCTE)
-    val analyzer = new RuleExecutor[LogicalPlan] {
-      override val batches = Seq(Batch("Resolution", Once, rules: _*))
+    val cteId = 0
+    val anchor = Project(Seq(Alias(Literal(1), "c")()), OneRowRelation())
+
+    def getBeforePlan(): LogicalPlan = {
+      val cteRef = CTERelationRef(
+        cteId,
+        _resolved = false,
+        output = Seq(),
+        isStreaming = false)
+      val recursion = cteRef.copy(recursive = true).subquery("t")
+      val cteDef = CTERelationDef(
+        UnresolvedSubqueryColumnAliases(Seq("n"), 
anchor.union(recursion)).subquery("t"),
+        cteId)
+      WithCTE(cteRef.copy(recursive = false), Seq(cteDef))
     }
-    // Since cteDef IDs need to be the same, cteDef for each case will be 
created by copying
-    // this one with its child replaced.
-    val cteDef = CTERelationDef(OneRowRelation())
-    val anchor = Project(Seq(Alias(Literal(1), "1")()), OneRowRelation())
-
-    def getBeforePlan(cteDef: CTERelationDef): LogicalPlan = {
-      val recursionPart = SubqueryAlias("t",
-          CTERelationRef(cteDef.id, false, Seq(), false, recursive = true))
 
-      val cteDefFinal = cteDef.copy(child =
-        SubqueryAlias("t",
-          UnresolvedSubqueryColumnAliases(Seq("n"),
-            Union(Seq(anchor, recursionPart)))))
-
-      WithCTE(
-        SubqueryAlias("t", CTERelationRef(cteDefFinal.id, false, Seq(), false, 
recursive = false)),
-        Seq(cteDefFinal))
-    }
-
-    def getAfterPlan(cteDef: CTERelationDef): LogicalPlan = {
-      val saRecursion = SubqueryAlias("t",
-        Project(Seq(Alias(anchor.output.head, "n")()),
-          UnionLoopRef(cteDef.id, anchor.output, false)))
-
-      val cteDefFinal = cteDef.copy(child =
-        SubqueryAlias("t",
-          Project(Seq(Alias(anchor.output.head, "n")()),
-            UnionLoop(cteDef.id, anchor, saRecursion))))
-
-      val outerCteRef = CTERelationRef(cteDefFinal.id, true, 
cteDefFinal.output, false,
-        recursive = false)
-
-      WithCTE(SubqueryAlias("t", outerCteRef), Seq(cteDefFinal))
+    def getAfterPlan(): LogicalPlan = {
+      val col = anchor.output.head
+      val recursion = UnionLoopRef(cteId, anchor.output, accumulated = false)
+        .select(col.as("n"))
+        .subquery("t")
+      val cteDef = CTERelationDef(
+        UnionLoop(cteId, anchor, recursion).select(col.as("n")).subquery("t"),
+        cteId)
+      val cteRef = CTERelationRef(
+        cteId,
+        _resolved = true,
+        output = cteDef.output,
+        isStreaming = false)
+      WithCTE(cteRef, Seq(cteDef))
     }
 
-    val beforePlan = getBeforePlan(cteDef)
-    val afterPlan = getAfterPlan(cteDef)
-
-    comparePlans(analyzer.execute(beforePlan), afterPlan)
+    comparePlans(getAnalyzer.execute(getBeforePlan()), getAfterPlan())
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to