This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 75fff90d2618 [SPARK-45685][SQL][FOLLOWUP] Add handling for `Stream` where `LazyList.force` is called
75fff90d2618 is described below

commit 75fff90d2618617a66b9a3311792c8b1bbbb6e8e
Author: yangjie01 <[email protected]>
AuthorDate: Fri Jun 14 12:30:44 2024 +0800

    [SPARK-45685][SQL][FOLLOWUP] Add handling for `Stream` where `LazyList.force` is called
    
    ### What changes were proposed in this pull request?
    Following the suggestion in
    https://github.com/apache/spark/pull/43563#pullrequestreview-2114900378,
    this PR adds handling for `Stream` where `LazyList.force` is called.
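
    A minimal sketch of the restored dispatch (the object name
    `RecursiveTransformSketch` is hypothetical; the match cases mirror the
    `QueryPlan.recursiveTransform` change in the diff below):

    ```scala
    object RecursiveTransformSketch {
      @scala.annotation.nowarn("cat=deprecation")
      def recursiveTransform(arg: Any): AnyRef = arg match {
        // `Stream` and `LazyList` are unrelated types in Scala 2.13, so each
        // needs its own case, and each must be forced so every element is
        // rewritten eagerly at this call site.
        case stream: Stream[_]     => stream.map(recursiveTransform).force
        case lazyList: LazyList[_] => lazyList.map(recursiveTransform).force
        // Without the two cases above, a `Stream` would fall through to this
        // generic case and be mapped lazily (only its head is evaluated).
        case seq: Iterable[_]      => seq.map(recursiveTransform)
        case other: AnyRef         => other
        case null                  => null
      }
    }
    ```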
    
    ### Why are the changes needed?
    Even though `Stream` is deprecated in Scala 2.13, it is not _removed_, so
    it is possible that some parts of Spark / Catalyst (or third-party code)
    continue to pass around `Stream` instances. Hence, we should restore the
    call to `Stream.force` wherever `.force` is called on `LazyList`, to avoid
    losing eager materialization for Streams that happen to flow to these call
    sites. This also preserves compatibility.
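
    For illustration only (not part of this patch): a minimal, hypothetical
    sketch of the laziness difference that makes the explicit `force`
    necessary, assuming Scala 2.13 collection semantics:

    ```scala
    object LazinessDemo {
      @scala.annotation.nowarn("cat=deprecation")
      def main(args: Array[String]): Unit = {
        var evaluated = 0

        // Stream has a strict head and a lazy tail: map evaluates the first
        // element immediately but defers the rest until they are demanded.
        val s = Stream(1, 2, 3).map { x => evaluated += 1; x + 1 }
        assert(evaluated == 1)
        s.force // eagerly materializes every remaining element
        assert(evaluated == 3)

        // LazyList is lazy in both head and tail: nothing is evaluated until
        // the list is forced (or otherwise traversed).
        evaluated = 0
        val l = LazyList(1, 2, 3).map { x => evaluated += 1; x + 1 }
        assert(evaluated == 0)
        l.force
        assert(evaluated == 3)
      }
    }
    ```

    This is also why the tests below avoid changing the first element of a
    `Stream`: its head is already materialized, so only deferred tail elements
    can expose a missing `force`.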
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Add some new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #46970 from LuciferYang/SPARK-45685-FOLLOWUP.
    
    Authored-by: yangjie01 <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
---
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  4 +++-
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 13 ++++++++---
 .../sql/catalyst/plans/LogicalPlanSuite.scala      | 22 ++++++++++++++++++
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   | 27 ++++++++++++++++++++++
 .../sql/execution/WholeStageCodegenExec.scala      |  4 +++-
 .../apache/spark/sql/execution/PlannerSuite.scala  |  8 +++++++
 .../sql/execution/WholeStageCodegenSuite.scala     | 10 ++++++++
 7 files changed, 83 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index bc0ca31dc635..c9c8fdb676b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -226,12 +226,14 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
       }
     }
 
+    @scala.annotation.nowarn("cat=deprecation")
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
-      case stream: LazyList[_] => stream.map(recursiveTransform).force
+      case stream: Stream[_] => stream.map(recursiveTransform).force
+      case lazyList: LazyList[_] => lazyList.map(recursiveTransform).force
       case seq: Iterable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 23d26854a767..6683f2dbfb39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.{mutable, Map}
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag
@@ -378,12 +379,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case nonChild: AnyRef => nonChild
       case null => null
     }
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: LazyList[_] =>
-        // LazyList is lazy so we need to force materialization
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
         s.map(mapChild).force
+      case l: LazyList[_] =>
+        // LazyList is lazy so we need to force materialization
+        l.map(mapChild).force
       case s: Seq[_] =>
         s.map(mapChild)
       case m: Map[_, _] =>
@@ -801,6 +806,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
       case other => other
     }
 
+    @nowarn("cat=deprecation")
     val newArgs = mapProductIterator {
       case arg: TreeNode[_] if containsChild(arg) =>
         arg.asInstanceOf[BaseType].clone()
@@ -813,7 +819,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
         case (_, other) => other
       }
       case d: DataType => d // Avoid unpacking Structs
-      case args: LazyList[_] => args.map(mapChild).force // Force materialization on stream
+      case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+      case args: LazyList[_] => args.map(mapChild).force // Force materialization on LazyList
       case args: Iterable[_] => args.map(mapChild)
       case nonChild: AnyRef => nonChild
       case null => null
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 31f7e07143c5..f783083d0a44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.annotation.nowarn
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -83,6 +85,26 @@ class LogicalPlanSuite extends SparkFunSuite {
   }
 
   test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    @nowarn("cat=deprecation")
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
+
+  test("SPARK-45685: transformExpressions works with a LazyList") {
     val id1 = NamedExpression.newExprId
     val id2 = NamedExpression.newExprId
     val plan = Project(LazyList(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 4dbadef93a07..21542d43eac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.trees
 import java.math.BigInteger
 import java.util.UUID
 
+import scala.annotation.nowarn
 import scala.collection.mutable.ArrayBuffer
 
 import org.json4s.JsonAST._
@@ -693,6 +694,22 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("transform works on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: transform works on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
     // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
@@ -707,6 +724,16 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("withNewChildren on stream of children") {
+    @nowarn("cat=deprecation")
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    @nowarn("cat=deprecation")
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    @nowarn("cat=deprecation")
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("SPARK-45685: withNewChildren on LazyList of children") {
     val before = Coalesce(LazyList(Literal(1), Literal(2)))
     val result = before.withNewChildren(LazyList(Literal(1), Literal(3)))
     val expected = Coalesce(LazyList(Literal(1), Literal(3)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 382f8cf8861a..6ec0836f704c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -165,8 +165,10 @@ trait CodegenSupport extends SparkPlan {
         }
       }
 
+    @scala.annotation.nowarn("cat=deprecation")
     val inputVars = inputVarsCandidate match {
-      case stream: LazyList[ExprCode] => stream.force
+      case stream: Stream[ExprCode] => stream.force
+      case lazyList: LazyList[ExprCode] => lazyList.force
       case other => other
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 15de4c5cc5b2..1400ee25f431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -744,6 +744,14 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("SPARK-24500: create union with stream of children") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
+
+  test("SPARK-45685: create union with LazyList of children") {
     val df = Union(LazyList(
       Range(1, 1, 1, 1),
       Range(1, 2, 1, 1)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 3aaf61ffba46..4d2d46582892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -785,6 +785,16 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
   }
 
   test("SPARK-26680: Stream in groupBy does not cause StackOverflowError") {
+    @scala.annotation.nowarn("cat=deprecation")
+    val groupByCols = Stream(col("key"))
+    val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
+      .groupBy(groupByCols: _*)
+      .max("value")
+
+    checkAnswer(df, Seq(Row(1, 3), Row(2, 3)))
+  }
+
+  test("SPARK-45685: LazyList in groupBy does not cause StackOverflowError") {
     val groupByCols = LazyList(col("key"))
     val df = Seq((1, 2), (2, 3), (1, 3)).toDF("key", "value")
       .groupBy(groupByCols: _*)

