This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new abc72f2da462 [SPARK-52232][SQL] Fix non-deterministic queries to 
produce different results at every step
abc72f2da462 is described below

commit abc72f2da462b07afdc03c4181bb7c2e63fe00b8
Author: pavle-martinovic_data <[email protected]>
AuthorDate: Thu May 29 21:11:04 2025 +0800

    [SPARK-52232][SQL] Fix non-deterministic queries to produce different 
results at every step
    
    ### What changes were proposed in this pull request?
    
    Enable non-deterministic queries to work with rCTEs. Fix a bug where 
non-deterministic queries produce the same result in every iteration.
    
    ### Why are the changes needed?
    
    Currently, recursive CTEs create a new plan for each iteration of the 
recursion, so expressions that use randomness "roll back" to their initial 
seed, causing things like "rand()" to produce the same result every time.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests in golden files.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #50957 from Pajaraja/pavle-martinovic_data/RandomInrCTEs.
    
    Lead-authored-by: pavle-martinovic_data <[email protected]>
    Co-authored-by: Wenchen Fan <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../expressions/collectionOperations.scala         |   3 +
 .../spark/sql/catalyst/expressions/misc.scala      |   2 +
 .../catalyst/expressions/randomExpressions.scala   |  14 +++
 .../apache/spark/sql/execution/UnionLoopExec.scala |  17 +++-
 .../analyzer-results/cte-recursion.sql.out         |  72 ++++++++++++++
 .../resources/sql-tests/inputs/cte-recursion.sql   |  56 ++++++++++-
 .../sql-tests/results/cte-recursion.sql.out        | 108 +++++++++++++++++++++
 7 files changed, 269 insertions(+), 3 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 84e52282b632..b4978fbe1f70 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1267,6 +1267,9 @@ case class Shuffle(child: Expression, randomSeed: 
Option[Long] = None) extends U
 
   override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))
 
+  override def withShiftedSeed(shift: Long): Shuffle =
+    copy(randomSeed = randomSeed.map(_ + shift))
+
   override lazy val resolved: Boolean =
     childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e7d3701544c5..dcbca34b240b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -260,6 +260,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends 
LeafExpression with Non
 
   override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed))
 
+  override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ + 
shift))
+
   override lazy val resolved: Boolean = randomSeed.isDefined
 
   override def nullable: Boolean = false
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index fa6eb2c11189..06cc6e55c8ec 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression {
 
   def seedExpression: Expression
   def withNewSeed(seed: Long): Expression
+  def withShiftedSeed(shift: Long): Expression
 }
 
 private[catalyst] object ExpressionWithRandomSeed {
@@ -114,6 +115,9 @@ case class Rand(child: Expression, hideSeed: Boolean = 
false) extends Nondetermi
 
   override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), 
hideSeed)
 
+  override def withShiftedSeed(shift: Long): Rand =
+    Rand(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
+
   override protected def evalInternal(input: InternalRow): Double = 
rng.nextDouble()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -165,6 +169,9 @@ case class Randn(child: Expression, hideSeed: Boolean = 
false) extends Nondeterm
 
   override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), 
hideSeed)
 
+  override def withShiftedSeed(shift: Long): Randn =
+    Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed)
+
   override protected def evalInternal(input: InternalRow): Double = 
rng.nextGaussian()
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -268,6 +275,9 @@ case class Uniform(min: Expression, max: Expression, 
seedExpression: Expression,
   override def withNewSeed(newSeed: Long): Expression =
     Uniform(min, max, Literal(newSeed, LongType), hideSeed)
 
+  override def withShiftedSeed(shift: Long): Expression =
+    Uniform(min, max, Literal(seed + shift, LongType), hideSeed)
+
   override def withNewChildrenInternal(
       newFirst: Expression, newSecond: Expression, newThird: Expression): 
Expression =
     Uniform(newFirst, newSecond, newThird, hideSeed)
@@ -348,6 +358,10 @@ case class RandStr(
 
   override def withNewSeed(newSeed: Long): Expression =
     RandStr(length, Literal(newSeed, LongType), hideSeed)
+
+  override def withShiftedSeed(shift: Long): Expression =
+    RandStr(length, Literal(seed + shift, LongType), hideSeed)
+
   override def withNewChildrenInternal(newFirst: Expression, newSecond: 
Expression): Expression =
     RandStr(newFirst, newSecond, hideSeed)
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
index 33188db5d23b..d44d3b0b6ef0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
 import org.apache.spark.SparkException
 import org.apache.spark.rdd.{EmptyRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
InterpretedMutableProjection, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
ExpressionWithRandomSeed, InterpretedMutableProjection, Literal}
 import 
org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, 
LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef}
@@ -183,11 +183,24 @@ case class UnionLoopExec(
     // Main loop for obtaining the result of the recursive query.
     while (prevCount > 0 && !limitReached) {
       var prevPlan: LogicalPlan = null
+
+      // If the recursive part contains non-deterministic expressions that 
depend on a seed, we
+      // need to create a new seed since the seed for this expression is set 
in the analysis, and
+      // we avoid re-triggering the analysis for every iterative step.
+      val recursionReseeded = if (currentLevel == 1 || 
recursion.deterministic) {
+        recursion
+      } else {
+        recursion.transformAllExpressionsWithSubqueries {
+          case e: ExpressionWithRandomSeed =>
+            e.withShiftedSeed(currentLevel - 1)
+        }
+      }
+
       // the current plan is created by substituting UnionLoopRef node with 
the project node of
       // the previous plan.
       // This way we support only UNION ALL case. Additional case should be 
added for UNION case.
       // One way of supporting UNION case can be seen at SPARK-24497 PR from 
Peter Toth.
-      val newRecursion = recursion.transformWithSubqueries {
+      val newRecursion = recursionReseeded.transformWithSubqueries {
         case r: UnionLoopRef if r.loopId == loopId =>
           prevDF.queryExecution.optimizedPlan match {
             case l: LocalRelation =>
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
index 4aff03883865..dc2b5a20fde5 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out
@@ -1629,3 +1629,75 @@ WithCTE
 +- Project [n#x]
    +- SubqueryAlias t1
       +- CTERelationRef xxxx, true, [n#x], false, false
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+    UNION ALL
+    SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT randstr(10, 82374)
+    UNION ALL
+    SELECT randstr(10, 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT UUID(82374)
+    UNION ALL
+    SELECT UUID(237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT ARRAY(1,2,3,4,5)
+    UNION ALL
+    SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql 
b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
index fba8861083be..8ef0c391a3fc 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql
@@ -586,4 +586,58 @@ WITH RECURSIVE t1 AS (
     SELECT 1 AS n
     UNION ALL
     SELECT n+1 FROM t2 WHERE n < 5)
-SELECT * FROM t1;
\ No newline at end of file
+SELECT * FROM t1;
+
+-- Non-deterministic query with rand with seed
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with uniform with seed
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+    UNION ALL
+    SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with randn with seed
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with randstr
+WITH RECURSIVE randoms(val) AS (
+    SELECT randstr(10, 82374)
+    UNION ALL
+    SELECT randstr(10, 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with UUID
+WITH RECURSIVE randoms(val) AS (
+    SELECT UUID(82374)
+    UNION ALL
+    SELECT UUID(237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
+
+-- Non-deterministic query with shuffle
+WITH RECURSIVE randoms(val) AS (
+    SELECT ARRAY(1,2,3,4,5)
+    UNION ALL
+    SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5;
\ No newline at end of file
diff --git 
a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out 
b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
index 06f440a3f633..d6939ab84b57 100644
--- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out
@@ -1475,3 +1475,111 @@ struct<n:int>
 3
 4
 5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(rand(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(rand(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+1
+3
+4
+4
+5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(UNIFORM(1, 6, 82374) AS INT)
+    UNION ALL
+    SELECT CAST(UNIFORM(1, 6, 237685) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+1
+3
+4
+4
+5
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT CAST(floor(randn(82374) * 5 + 1) AS INT)
+    UNION ALL
+    SELECT CAST(floor(randn(237685) * 5 + 1) AS INT)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:int>
+-- !query output
+-2
+2
+2
+5
+6
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT randstr(10, 82374)
+    UNION ALL
+    SELECT randstr(10, 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:string>
+-- !query output
+IpXzdTW03I
+Zj7uI2Ex6e
+dBlWnfo7rO
+fmfDBMf60f
+kFeBV7dQWi
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT UUID(82374)
+    UNION ALL
+    SELECT UUID(237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:string>
+-- !query output
+19974dca-21f6-47ef-b58c-73908ab52aa0
+4ea190e3-c088-4ddd-a545-fb431059ae3c
+8b88900e-f862-468c-8d3b-828188116155
+be4f5346-1c7f-4697-8a2c-1343347872c5
+d0032efe-ae60-461b-8582-f6a7c649f238
+
+
+-- !query
+WITH RECURSIVE randoms(val) AS (
+    SELECT ARRAY(1,2,3,4,5)
+    UNION ALL
+    SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685)
+    FROM randoms
+)
+SELECT val FROM randoms LIMIT 5
+-- !query schema
+struct<val:array<int>>
+-- !query output
+[1,2,3,4,5]
+[1,2,3,5,4]
+[2,1,5,3,4]
+[4,3,2,5,1]
+[4,5,1,2,3]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to