This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new abc72f2da462 [SPARK-52232][SQL] Fix non-deterministic queries to
produce different results at every step
abc72f2da462 is described below
commit abc72f2da462b07afdc03c4181bb7c2e63fe00b8
Author: pavle-martinovic_data <[email protected]>
AuthorDate: Thu May 29 21:11:04 2025 +0800
[SPARK-52232][SQL] Fix non-deterministic queries to produce different
results at every step
### What changes were proposed in this pull request?
Enable non-deterministic queries to work with rCTEs. Fix a bug where
non-deterministic queries produce the same result in every iteration.
### Why are the changes needed?
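Currently, recursive CTEs create a new plan for each iteration of the
recursion from the same analyzed recursive part. Because the seed of a random
expression is fixed during analysis and analysis is not re-run per iteration,
these expressions "roll back" to their initial seed, causing things like
"rand()" to produce the same result in every iteration.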
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in golden files.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50957 from Pajaraja/pavle-martinovic_data/RandomInrCTEs.
Lead-authored-by: pavle-martinovic_data <[email protected]>
Co-authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/collectionOperations.scala | 3 +
.../spark/sql/catalyst/expressions/misc.scala | 2 +
.../catalyst/expressions/randomExpressions.scala | 14 +++
.../apache/spark/sql/execution/UnionLoopExec.scala | 17 +++-
.../analyzer-results/cte-recursion.sql.out | 72 ++++++++++++++
.../resources/sql-tests/inputs/cte-recursion.sql | 56 ++++++++++-
.../sql-tests/results/cte-recursion.sql.out | 108 +++++++++++++++++++++
7 files changed, 269 insertions(+), 3 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 84e52282b632..b4978fbe1f70 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1267,6 +1267,9 @@ case class Shuffle(child: Expression, randomSeed:
Option[Long] = None) extends U
override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
+ override def withShiftedSeed(shift: Long): Shuffle =
+ copy(randomSeed = randomSeed.map(_ + shift))
+
override lazy val resolved: Boolean =
childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e7d3701544c5..dcbca34b240b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -260,6 +260,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends
LeafExpression with Non
override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
+ override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ +
shift))
+
override lazy val resolved: Boolean = randomSeed.isDefined
override def nullable: Boolean = false
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index fa6eb2c11189..06cc6e55c8ec 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression {
def seedExpression: Expression
def withNewSeed(seed: Long): Expression
+ def withShiftedSeed(shift: Long): Expression
}
private[catalyst] object ExpressionWithRandomSeed {
@@ -114,6 +115,9 @@ case class Rand(child: Expression, hideSeed: Boolean =
false) extends Nondetermi
override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType),
hideSeed)
+ override def withShiftedSeed(shift: Long): Rand =
+ Rand(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
+
override protected def evalInternal(input: InternalRow): Double =
rng.nextDouble()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -165,6 +169,9 @@ case class Randn(child: Expression, hideSeed: Boolean =
false) extends Nondeterm
override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType),
hideSeed)
+ override def withShiftedSeed(shift: Long): Randn =
+ Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
+
override protected def evalInternal(input: InternalRow): Double =
rng.nextGaussian()
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -268,6 +275,9 @@ case class Uniform(min: Expression, max: Expression,
seedExpression: Expression,
override def withNewSeed(newSeed: Long): Expression =
Uniform(min, max, Literal(newSeed, LongType), hideSeed)
+ override def withShiftedSeed(shift: Long): Expression =
+ Uniform(min, max, Literal(seed + shift, LongType), hideSeed)
+
override def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression):
Expression =
Uniform(newFirst, newSecond, newThird, hideSeed)
@@ -348,6 +358,10 @@ case class RandStr(
override def withNewSeed(newSeed: Long): Expression =
RandStr(length, Literal(newSeed, LongType), hideSeed)
+
+ override def withShiftedSeed(shift: Long): Expression =
+ RandStr(length, Literal(seed + shift, LongType), hideSeed)
+
override def withNewChildrenInternal(newFirst: Expression, newSecond:
Expression): Expression =
RandStr(newFirst, newSecond, hideSeed)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
index 33188db5d23b..d44d3b0b6ef0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.SparkException
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
InterpretedMutableProjection, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute,
ExpressionWithRandomSeed, InterpretedMutableProjection, Literal}
import
org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation,
LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef}
@@ -183,11 +183,24 @@ case class UnionLoopExec(
// Main loop for obtaining the result of the recursive query.
while (prevCount > 0 && !limitReached) {
var prevPlan: LogicalPlan = null
+
+ // If the recursive part contains non-deterministic expressions that
depends on a seed, we
+ // need to create a new seed since the seed for this expression is set
in the analysis, and
+ // we avoid re-triggering the analysis for every iterative step.
+ val recursionReseeded = if (currentLevel == 1 ||
recursion.deterministic) {
+ recursion
+ } else {
+ recursion.transformAllExpressionsWithSubqueries {
+ case e: ExpressionWithRandomSeed =>
+ e.withShiftedSeed(currentLevel - 1)
+ }
+ }
+
// the current plan is created by substituting UnionLoopRef node with
the project node of
// the previous plan.
// This way we support only UNION ALL case. Additional case should be
added for UNION case.
// One way of supporting UNION case can be seen at SPARK-24497 PR from
Peter Toth.
- val newRecursion = recursion.transformWithSubqueries {
+ val newRecursion = recursionReseeded.transformWithSubqueries {
case r: UnionLoopRef if r.loopId == loopId =>
prevDF.queryExecution.optimizedPlan match {
case l: LocalRelation =>
diff --git
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index 4aff03883865..dc2b5a20fde5 100644
---
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -1629,3 +1629,75 @@ WithCTE
+- Project [n#x]
+- SubqueryAlias t1
+- CTERelationRef xxxx, true, [n#x], false, false
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+ UNION ALL
+ SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT randstr(10, 82374)
+ UNION ALL
+ SELECT randstr(10, 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT UUID(82374)
+ UNION ALL
+ SELECT UUID(237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT ARRAY(1,2,3,4,5)
+ UNION ALL
+ SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index fba8861083be..8ef0c391a3fc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -586,4 +586,58 @@ WITH RECURSIVE t1 AS (
SELECT 1 AS n
UNION ALL
SELECT n+1 FROM t2 WHERE n < 5)
-SELECT * FROM t1;
\ No newline at end of file
+SELECT * FROM t1;
+
+-- Non-deterministic query with rand with seed
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with uniform with seed
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+ UNION ALL
+ SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with randn with seed
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with randstr
+WITH RECURSIVE randoms(val) AS (
+ SELECT randstr(10, 82374)
+ UNION ALL
+ SELECT randstr(10, 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with UUID
+WITH RECURSIVE randoms(val) AS (
+ SELECT UUID(82374)
+ UNION ALL
+ SELECT UUID(237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with shuffle
+WITH RECURSIVE randoms(val) AS (
+ SELECT ARRAY(1,2,3,4,5)
+ UNION ALL
+ SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
\ No newline at end of file
diff --git
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index 06f440a3f633..d6939ab84b57 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -1475,3 +1475,111 @@ struct<n:int>
3
4
5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+1
+3
+4
+4
+5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+ UNION ALL
+ SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+1
+3
+4
+4
+5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+ UNION ALL
+ SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+-2
+2
+2
+5
+6
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT randstr(10, 82374)
+ UNION ALL
+ SELECT randstr(10, 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:string>
+-- !query output
+IpXzdTW03I
+Zj7uI2Ex6e
+dBlWnfo7rO
+fmfDBMf60f
+kFeBV7dQWi
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT UUID(82374)
+ UNION ALL
+ SELECT UUID(237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:string>
+-- !query output
+19974dca-21f6-47ef-b58c-73908ab52aa0
+4ea190e3-c088-4ddd-a545-fb431059ae3c
+8b88900e-f862-468c-8d3b-828188116155
+be4f5346-1c7f-4697-8a2c-1343347872c5
+d0032efe-ae60-461b-8582-f6a7c649f238
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+ SELECT ARRAY(1,2,3,4,5)
+ UNION ALL
+ SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+ FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:array<int>>
+-- !query output
+[1,2,3,4,5]
+[1,2,3,5,4]
+[2,1,5,3,4]
+[4,3,2,5,1]
+[4,5,1,2,3]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]