This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a894cc50e63 [SQL][MINOR] Fix bugs with self references inside 
subqueries in rCTEs
0a894cc50e63 is described below

commit 0a894cc50e63ae085ea608b6043fbbc5ad2b7076
Author: pavle-martinovic_data <[email protected]>
AuthorDate: Tue Jun 10 09:04:11 2025 -0700

    [SQL][MINOR] Fix bugs with self references inside subqueries in rCTEs
    
    ### What changes were proposed in this pull request?
    
    Update the function checkIfSelfReferenceIsPlacedCorrectly so that it also checks
    subqueries, and disable the OneRowRelation optimization in UnionLoopExec when the
    Project contains a subquery.
    
    ### Why are the changes needed?
    
    Fixes two bugs:
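    - The first bug occurs when a self reference is placed illegally (for example,
    under an aggregate) inside a subquery: the placement check did not descend into
    subqueries, so such references were not rejected.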
    - The second bug occurs when a one-row scalar subquery is used in the recursive
    step: it resolves to a Project whose child is a OneRowRelation but whose project
    list also contains a subquery. Applying the OneRowRelation optimization to such a
    Project led to incorrect results, because of conflicting definitions of the
    subquery.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a new golden-file test and updated the results of existing golden-file
    tests (cte-recursion.sql).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #51135 from 
Pajaraja/pavle-martinovic_data/FixUnionLoopSubqueryIllegalPlaceReference.
    
    Authored-by: pavle-martinovic_data <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../sql/catalyst/analysis/ResolveWithCTE.scala     | 63 +++++++++++-----------
 .../sql/catalyst/plans/logical/cteOperators.scala  |  4 +-
 .../apache/spark/sql/execution/UnionLoopExec.scala |  2 +-
 .../analyzer-results/cte-recursion.sql.out         | 52 +++++++++++-------
 .../resources/sql-tests/inputs/cte-recursion.sql   |  7 +++
 .../sql-tests/results/cte-recursion.sql.out        | 29 +++++++---
 6 files changed, 100 insertions(+), 57 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
index 2a522e98a768..457a484b6209 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
@@ -319,35 +319,38 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
   def checkIfSelfReferenceIsPlacedCorrectly(
       plan: LogicalPlan,
       cteId: Long,
-      allowRecursiveRef: Boolean = true): Unit = plan match {
-    case Join(left, right, Inner, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
-    case Join(left, right, Cross, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
-    case Join(left, right, LeftOuter, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = 
false)
-    case Join(left, right, RightOuter, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = 
false)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
-    case Join(left, right, LeftSemi, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = 
false)
-    case Join(left, right, LeftAnti, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = 
false)
-    case Join(left, right, _, _, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = 
false)
-      checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = 
false)
-    case Aggregate(_, _, child, _) =>
-      checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef = 
false)
-    case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId =>
-      throw new AnalysisException(
-        errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
-        messageParameters = Map.empty)
-    case other =>
-      other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, 
allowRecursiveRef))
+      allowRecursiveRef: Boolean = true): Unit = {
+    plan match {
+      case Join(left, right, Inner, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+      case Join(left, right, Cross, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+      case Join(left, right, LeftOuter, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef 
= false)
+      case Join(left, right, RightOuter, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = 
false)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+      case Join(left, right, LeftSemi, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef 
= false)
+      case Join(left, right, LeftAnti, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef 
= false)
+      case Join(left, right, _, _, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = 
false)
+        checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef 
= false)
+      case Aggregate(_, _, child, _) =>
+        checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef 
= false)
+      case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId =>
+        throw new AnalysisException(
+          errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+          messageParameters = Map.empty)
+      case other =>
+        other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, 
allowRecursiveRef))
+    }
+    plan.subqueries.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, 
allowRecursiveRef))
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
index de980f8f6396..072aa4540775 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
@@ -68,8 +68,8 @@ case class UnionLoop(
  * @param loopId The id of the loop, inherited from [[CTERelationRef]] which 
got resolved into this
  *               UnionLoopRef.
  * @param output The output attributes of this recursive reference.
- * @param accumulated If false the the reference stands for the result of the 
previous iteration.
- *                    If it is true then then it stands for the union of all 
previous iteration
+ * @param accumulated If false the reference stands for the result of the 
previous iteration.
+ *                    If it is true then it stands for the union of all 
previous iteration
  *                    results.
  */
 case class UnionLoopRef(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
index e14f6f378b02..977cc5f52bb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -210,7 +210,7 @@ case class UnionLoopExec(
             // 
SQLConf.CTE_RECURSION_ANCHOR_ROWS_LIMIT_TO_CONVERT_TO_LOCAL_RELATION is set to 
be
             // anything larger than 0. However, we still handle this case in a 
special way to
             // optimize the case when the flag is set to 0.
-            case p @ Project(projectList, _: OneRowRelation) =>
+            case p @ Project(projectList, _: OneRowRelation) if 
p.subqueries.isEmpty =>
               prevPlan = p
               val prevPlanToRefMapping = projectList.zip(r.output).map {
                 case (fa: Alias, ta) => 
fa.withExprId(ta.exprId).withName(ta.name)
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index f67d7c6f8142..037b484f1263 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -471,24 +471,11 @@ WITH RECURSIVE t(col) (
 )
 SELECT * FROM t LIMIT 5
 -- !query analysis
-WithCTE
-:- CTERelationDef xxxx, false
-:  +- SubqueryAlias t
-:     +- Project [1#x AS col#x]
-:        +- UnionLoop xxxx
-:           :- Project [1 AS 1#x]
-:           :  +- OneRowRelation
-:           +- Project [scalar-subquery#x [] AS scalarsubquery()#x]
-:              :  +- Aggregate [max(col#x) AS max(col)#x]
-:              :     +- SubqueryAlias t
-:              :        +- Project [1#x AS col#x]
-:              :           +- UnionLoopRef xxxx, [1#x], false
-:              +- OneRowRelation
-+- GlobalLimit 5
-   +- LocalLimit 5
-      +- Project [col#x]
-         +- SubqueryAlias t
-            +- CTERelationRef xxxx, true, [col#x], false, false
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_RECURSIVE_REFERENCE.PLACE",
+  "sqlState" : "42836"
+}
 
 
 -- !query
@@ -511,6 +498,35 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+WITH RECURSIVE t1(n) AS (
+    SELECT 1
+    UNION ALL
+    SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+:  +- SubqueryAlias t1
+:     +- Project [1#x AS n#x]
+:        +- UnionLoop xxxx
+:           :- Project [1 AS 1#x]
+:           :  +- OneRowRelation
+:           +- Project [scalar-subquery#x [] AS scalarsubquery()#x]
+:              :  +- Project [(n#x + 1) AS (n + 1)#x]
+:              :     +- Filter (n#x < 5)
+:              :        +- SubqueryAlias t1
+:              :           +- Project [1#x AS n#x]
+:              :              +- UnionLoopRef xxxx, [1#x], false
+:              +- OneRowRelation
++- GlobalLimit 5
+   +- LocalLimit 5
+      +- Project [n#x]
+         +- SubqueryAlias t1
+            +- CTERelationRef xxxx, true, [n#x], false, false
+
+
 -- !query
 WITH RECURSIVE
   t1 AS (
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql 
b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index 3005ceff503c..208cdbe21eee 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -182,6 +182,13 @@ WITH
   )
 SELECT * FROM t2;
 
+-- Self reference is inside OneRowSubquery
+WITH RECURSIVE t1(n) AS (
+    SELECT 1
+    UNION ALL
+    SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5;
 
 -- recursive reference in a nested CTE
 WITH RECURSIVE
diff --git 
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out 
b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index 346d6a4140fa..22501f668363 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -504,13 +504,13 @@ WITH RECURSIVE t(col) (
 )
 SELECT * FROM t LIMIT 5
 -- !query schema
-struct<col:int>
+struct<>
 -- !query output
-1
-1
-1
-1
-1
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+  "errorClass" : "INVALID_RECURSIVE_REFERENCE.PLACE",
+  "sqlState" : "42836"
+}
 
 
 -- !query
@@ -535,6 +535,23 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+WITH RECURSIVE t1(n) AS (
+    SELECT 1
+    UNION ALL
+    SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5
+-- !query schema
+struct<n:int>
+-- !query output
+1
+2
+3
+4
+5
+
+
 -- !query
 WITH RECURSIVE
   t1 AS (


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to