This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6706c41cb42c [SPARK-49070][SS][SQL] 
TransformWithStateExec.initialState is rewritten incorrectly to produce invalid 
query plan
6706c41cb42c is described below

commit 6706c41cb42cbd270a6580385be67b2a2313df27
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Aug 1 12:24:43 2024 -0700

    [SPARK-49070][SS][SQL] TransformWithStateExec.initialState is rewritten 
incorrectly to produce invalid query plan
    
    ### What changes were proposed in this pull request?
    
    This patch fixes `TransformWithStateExec` so that when its `hasInitialState` is 
false, the `initialState` won't be rewritten incorrectly by the planner to produce 
an invalid query plan, which would cause unexpected errors for extension rules that 
rely on the correctness of the query plan.
    
    ### Why are the changes needed?
    
    [SPARK-47363](https://issues.apache.org/jira/browse/SPARK-47363) added 
support for users to provide an initial state for a streaming query. Query 
operators like `TransformWithStateExec` might have `hasInitialState` set to false, 
which means the initial-state-related parameters are not used. But when the query 
planner applies rules on the query, it will still apply them on the initial state 
query plan. When `hasInitialState` is false, some related parameters like 
`initialStateGroupingAttrs` are  [...]
    
    For example, `EnsureRequirements` may apply an invalid Sort and Exchange on 
the initial state query plan. We encountered these invalid query plans in our 
extension rules.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47546 from viirya/fix_initial_state.
    
    Authored-by: Liang-Chi Hsieh <[email protected]>
    Signed-off-by: huaxingao <[email protected]>
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |  9 +++-
 .../streaming/TransformWithStateExec.scala         |  9 +++-
 .../FlatMapGroupsWithStateDistributionSuite.scala  | 44 +++++++++++++---
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 58 ++++++++++++++++++++++
 .../sql/streaming/TransformWithStateSuite.scala    | 41 +++++++++++++++
 ...StatefulOpClusteredDistributionTestHelper.scala |  9 ++++
 6 files changed, 158 insertions(+), 12 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ee1fc1db71f..d56dfebd61ba 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -415,8 +415,13 @@ case class FlatMapGroupsWithStateExec(
   override def right: SparkPlan = initialState
 
   override protected def withNewChildrenInternal(
-      newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
-    copy(child = newLeft, initialState = newRight)
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = {
+    if (hasInitialState) {
+      copy(child = newLeft, initialState = newRight)
+    } else {
+      copy(child = newLeft)
+    }
+  }
 
   override def createInputProcessor(
       store: StateStore): InputProcessor = new InputProcessor(store) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index acdb8372a67d..7f57f0b69cad 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -160,8 +160,13 @@ case class TransformWithStateExec(
   override def right: SparkPlan = initialState
 
   override protected def withNewChildrenInternal(
-      newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec =
-    copy(child = newLeft, initialState = newRight)
+      newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec = {
+    if (hasInitialState) {
+      copy(child = newLeft, initialState = newRight)
+    } else {
+      copy(child = newLeft)
+    }
+  }
 
   override def keyExpressions: Seq[Attribute] = groupingAttributes
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
index b597a2447108..04b5f3af6463 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
@@ -135,8 +135,15 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
         assert(flatMapGroupsWithStateExecs.length === 1)
         assert(requireStatefulOpClusteredDistribution(
           flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", 
"_2")), numPartitions))
-        assert(hasDesiredHashPartitioningInChildren(
-          flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", 
"_2")), numPartitions))
+        if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+          assert(hasDesiredHashPartitioningInChildren(
+            flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        } else {
+          assert(hasDesiredHashPartitioningInChildren(
+            Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        }
       }
     )
   }
@@ -236,8 +243,15 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
         assert(flatMapGroupsWithStateExecs.length === 1)
         assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
           Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
-        assert(hasDesiredHashPartitioningInChildren(
-          flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", 
"_2")), numPartitions))
+        if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+          assert(hasDesiredHashPartitioningInChildren(
+            flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        } else {
+          assert(hasDesiredHashPartitioningInChildren(
+            Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        }
       }
     )
   }
@@ -328,8 +342,15 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
         assert(flatMapGroupsWithStateExecs.length === 1)
         assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
           Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
-        assert(hasDesiredHashPartitioningInChildren(
-          flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", 
"_2")), numPartitions))
+        if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+          assert(hasDesiredHashPartitioningInChildren(
+            flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        } else {
+          assert(hasDesiredHashPartitioningInChildren(
+            Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        }
       }
     )
   }
@@ -449,8 +470,15 @@ class FlatMapGroupsWithStateDistributionSuite extends 
StreamTest
         assert(flatMapGroupsWithStateExecs.length === 1)
         assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
           Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
-        assert(hasDesiredHashPartitioningInChildren(
-          flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1", 
"_2")), numPartitions))
+        if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+          assert(hasDesiredHashPartitioningInChildren(
+            flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        } else {
+          assert(hasDesiredHashPartitioningInChildren(
+            Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"), 
Seq("_1", "_2")),
+            numPartitions))
+        }
       }
     )
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index b35e996106f3..45a80a210fce 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -32,6 +32,7 @@ import 
org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
 import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import 
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper,
 MemoryStateStore, RocksDBStateStoreProvider, StateStore}
 import org.apache.spark.sql.functions.timestamp_seconds
@@ -903,6 +904,63 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     }
   }
 
+  test("SPARK-49070: flatMapGroupsWithStateExec - valid initial state plan") {
+    withTempDir { dir =>
+      withSQLConf(
+        (SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false"),
+        (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
+        (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName)) {
+
+        val inputData = MemoryStream[Timestamp]
+        val stateFunc = (key: Int, values: Iterator[Timestamp], state: 
GroupState[Int]) => {
+          // Should never timeout. All batches should have data and even if a 
timeout is set,
+          // it should get cleared when the key receives data per contract.
+          require(!state.hasTimedOut, "The state should not have timed out!")
+          // Set state and timeout once, only on the first call. The timeout 
should get cleared
+          // in the subsequent batch which has data for the key.
+          if (!state.exists) {
+            state.update(0)
+            state.setTimeoutTimestamp(500)  // Timeout at 500 milliseconds.
+          }
+          0
+        }
+
+        val query = inputData.toDS()
+          .withWatermark("value", "0 seconds")
+          .groupByKey(_ => 0)  // Always the same key: 0.
+          .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(stateFunc)
+          .writeStream
+          .format("console")
+          .outputMode("update")
+          .start()
+
+        try {
+          (1 to 2).foreach {i =>
+            inputData.addData(new Timestamp(i * 1000))
+            query.processAllAvailable()
+          }
+
+          val sparkPlan =
+            
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
+          val flatMapGroupsWithStateExec = sparkPlan.collect {
+            case p: FlatMapGroupsWithStateExec => p
+          }.head
+
+          assert(!flatMapGroupsWithStateExec.hasInitialState)
+
+          // EnsureRequirements should not apply on the initial state plan
+          val exchange = flatMapGroupsWithStateExec.initialState.collect {
+            case s: ShuffleExchangeExec => s
+          }
+
+          assert(exchange.isEmpty)
+        } finally {
+          query.stop()
+        }
+      }
+    }
+  }
+
   def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
     test("SPARK-20714: watermark does not fail query when timeout = " + 
timeoutConf) {
       // Function to maintain running count up to 2, and then remove the count
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 5b91e0de1903..43db8921897c 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkRuntimeException, 
SparkUnsupportedOperationExcepti
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, Encoders, Row}
 import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.timestamp_seconds
@@ -1231,6 +1232,46 @@ class TransformWithStateSuite extends 
StateStoreMetricsTest
       }
     }
   }
+
+  test("SPARK-49070: transformWithState - valid initial state plan") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+      withTempDir { srcDir =>
+        Seq("a", "b", "c").foreach(createFile(_, srcDir))
+        val df = createFileStream(srcDir)
+
+        var index = 0
+
+        val q = df.writeStream
+          .foreachBatch((_: Dataset[(String, String)], _: Long) => {
+            index += 1
+          })
+          .trigger(Trigger.AvailableNow)
+          .start()
+
+        try {
+          assert(q.awaitTermination(streamingTimeout.toMillis))
+
+          val sparkPlan =
+            
q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
+          val transformWithStateExec = sparkPlan.collect {
+            case p: TransformWithStateExec => p
+          }.head
+
+          assert(!transformWithStateExec.hasInitialState)
+
+          // EnsureRequirements should not apply on the initial state plan
+          val exchange = transformWithStateExec.initialState.collect {
+            case s: ShuffleExchangeExec => s
+          }
+
+          assert(exchange.isEmpty)
+        } finally {
+          q.stop()
+        }
+      }
+    }
+  }
 }
 
 class TransformWithStateValidationSuite extends StateStoreMetricsTest {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
index f2684b8c39cd..877bca6898f0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
@@ -72,6 +72,15 @@ trait StatefulOpClusteredDistributionTestHelper extends 
SparkFunSuite {
     }
   }
 
+  protected def hasDesiredHashPartitioningInChildren(
+      children: Seq[SparkPlan],
+      desiredClusterColumns: Seq[Seq[String]],
+      desiredNumPartitions: Int): Boolean = {
+    children.zip(desiredClusterColumns).forall { case (child, clusterColumns) 
=>
+      hasDesiredHashPartitioning(child, clusterColumns, desiredNumPartitions)
+    }
+  }
+
   private def partitionExpressionsColumns(expressions: Seq[Expression]): 
Seq[String] = {
     expressions.flatMap {
       case ref: AttributeReference => Some(ref.name)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to