This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0a894cc50e63 [SQL][MINOR] Fix bugs with self references inside
subqueries in rCTEs
0a894cc50e63 is described below
commit 0a894cc50e63ae085ea608b6043fbbc5ad2b7076
Author: pavle-martinovic_data <[email protected]>
AuthorDate: Tue Jun 10 09:04:11 2025 -0700
[SQL][MINOR] Fix bugs with self references inside subqueries in rCTEs
### What changes were proposed in this pull request?
Update the function checkIfSelfReferenceIsPlacedCorrectly to also check
subqueries, and disable the OneRowRelation optimization when the Project
contains a subquery.
### Why are the changes needed?
Fixes two bugs:
- The first bug: an illegal self reference (for example, one under an
aggregate) made inside a subquery was not detected.
- The second bug occurred when a OneRowSubquery was used: this resolved to a
Project with a OneRowRelation as its child, but the Project also contained a
subquery. Having this subquery led to incorrect results, as it had
conflicting definitions of subqueries.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New and old golden file tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #51135 from
Pajaraja/pavle-martinovic_data/FixUnionLoopSubqueryIllegalPlaceReference.
Authored-by: pavle-martinovic_data <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/analysis/ResolveWithCTE.scala | 63 +++++++++++-----------
.../sql/catalyst/plans/logical/cteOperators.scala | 4 +-
.../apache/spark/sql/execution/UnionLoopExec.scala | 2 +-
.../analyzer-results/cte-recursion.sql.out | 52 +++++++++++-------
.../resources/sql-tests/inputs/cte-recursion.sql | 7 +++
.../sql-tests/results/cte-recursion.sql.out | 29 +++++++---
6 files changed, 100 insertions(+), 57 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
index 2a522e98a768..457a484b6209 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala
@@ -319,35 +319,38 @@ object ResolveWithCTE extends Rule[LogicalPlan] {
def checkIfSelfReferenceIsPlacedCorrectly(
plan: LogicalPlan,
cteId: Long,
- allowRecursiveRef: Boolean = true): Unit = plan match {
- case Join(left, right, Inner, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
- case Join(left, right, Cross, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
- case Join(left, right, LeftOuter, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef =
false)
- case Join(left, right, RightOuter, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef =
false)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
- case Join(left, right, LeftSemi, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef =
false)
- case Join(left, right, LeftAnti, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef =
false)
- case Join(left, right, _, _, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef =
false)
- checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef =
false)
- case Aggregate(_, _, child, _) =>
- checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef =
false)
- case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId =>
- throw new AnalysisException(
- errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
- messageParameters = Map.empty)
- case other =>
- other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId,
allowRecursiveRef))
+ allowRecursiveRef: Boolean = true): Unit = {
+ plan match {
+ case Join(left, right, Inner, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+ case Join(left, right, Cross, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+ case Join(left, right, LeftOuter, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef
= false)
+ case Join(left, right, RightOuter, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef =
false)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef)
+ case Join(left, right, LeftSemi, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef
= false)
+ case Join(left, right, LeftAnti, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef
= false)
+ case Join(left, right, _, _, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef =
false)
+ checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef
= false)
+ case Aggregate(_, _, child, _) =>
+ checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef
= false)
+ case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId =>
+ throw new AnalysisException(
+ errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+ messageParameters = Map.empty)
+ case other =>
+ other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId,
allowRecursiveRef))
+ }
+ plan.subqueries.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId,
allowRecursiveRef))
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
index de980f8f6396..072aa4540775 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala
@@ -68,8 +68,8 @@ case class UnionLoop(
* @param loopId The id of the loop, inherited from [[CTERelationRef]] which
got resolved into this
* UnionLoopRef.
* @param output The output attributes of this recursive reference.
- * @param accumulated If false the the reference stands for the result of the
previous iteration.
- * If it is true then then it stands for the union of all
previous iteration
+ * @param accumulated If false the reference stands for the result of the
previous iteration.
+ * If it is true then it stands for the union of all
previous iteration
* results.
*/
case class UnionLoopRef(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
index e14f6f378b02..977cc5f52bb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -210,7 +210,7 @@ case class UnionLoopExec(
//
SQLConf.CTE_RECURSION_ANCHOR_ROWS_LIMIT_TO_CONVERT_TO_LOCAL_RELATION is set to
be
// anything larger than 0. However, we still handle this case in a
special way to
// optimize the case when the flag is set to 0.
- case p @ Project(projectList, _: OneRowRelation) =>
+ case p @ Project(projectList, _: OneRowRelation) if
p.subqueries.isEmpty =>
prevPlan = p
val prevPlanToRefMapping = projectList.zip(r.output).map {
case (fa: Alias, ta) =>
fa.withExprId(ta.exprId).withName(ta.name)
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index f67d7c6f8142..037b484f1263 100644
---
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -471,24 +471,11 @@ WITH RECURSIVE t(col) (
)
SELECT * FROM t LIMIT 5
-- !query analysis
-WithCTE
-:- CTERelationDef xxxx, false
-: +- SubqueryAlias t
-: +- Project [1#x AS col#x]
-: +- UnionLoop xxxx
-: :- Project [1 AS 1#x]
-: : +- OneRowRelation
-: +- Project [scalar-subquery#x [] AS scalarsubquery()#x]
-: : +- Aggregate [max(col#x) AS max(col)#x]
-: : +- SubqueryAlias t
-: : +- Project [1#x AS col#x]
-: : +- UnionLoopRef xxxx, [1#x], false
-: +- OneRowRelation
-+- GlobalLimit 5
- +- LocalLimit 5
- +- Project [col#x]
- +- SubqueryAlias t
- +- CTERelationRef xxxx, true, [col#x], false, false
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "INVALID_RECURSIVE_REFERENCE.PLACE",
+ "sqlState" : "42836"
+}
-- !query
@@ -511,6 +498,35 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}
+-- !query
+WITH RECURSIVE t1(n) AS (
+ SELECT 1
+ UNION ALL
+ SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias t1
+: +- Project [1#x AS n#x]
+: +- UnionLoop xxxx
+: :- Project [1 AS 1#x]
+: : +- OneRowRelation
+: +- Project [scalar-subquery#x [] AS scalarsubquery()#x]
+: : +- Project [(n#x + 1) AS (n + 1)#x]
+: : +- Filter (n#x < 5)
+: : +- SubqueryAlias t1
+: : +- Project [1#x AS n#x]
+: : +- UnionLoopRef xxxx, [1#x], false
+: +- OneRowRelation
++- GlobalLimit 5
+ +- LocalLimit 5
+ +- Project [n#x]
+ +- SubqueryAlias t1
+ +- CTERelationRef xxxx, true, [n#x], false, false
+
+
-- !query
WITH RECURSIVE
t1 AS (
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index 3005ceff503c..208cdbe21eee 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -182,6 +182,13 @@ WITH
)
SELECT * FROM t2;
+-- Self reference is inside OneRowSubquery
+WITH RECURSIVE t1(n) AS (
+ SELECT 1
+ UNION ALL
+ SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5;
-- recursive reference in a nested CTE
WITH RECURSIVE
diff --git
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index 346d6a4140fa..22501f668363 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -504,13 +504,13 @@ WITH RECURSIVE t(col) (
)
SELECT * FROM t LIMIT 5
-- !query schema
-struct<col:int>
+struct<>
-- !query output
-1
-1
-1
-1
-1
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "INVALID_RECURSIVE_REFERENCE.PLACE",
+ "sqlState" : "42836"
+}
-- !query
@@ -535,6 +535,23 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}
+-- !query
+WITH RECURSIVE t1(n) AS (
+ SELECT 1
+ UNION ALL
+ SELECT (SELECT n+1 FROM t1 WHERE n<5)
+)
+SELECT * FROM t1 LIMIT 5
+-- !query schema
+struct<n:int>
+-- !query output
+1
+2
+3
+4
+5
+
+
-- !query
WITH RECURSIVE
t1 AS (
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]