This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6706c41cb42c [SPARK-49070][SS][SQL]
TransformWithStateExec.initialState is rewritten incorrectly to produce invalid
query plan
6706c41cb42c is described below
commit 6706c41cb42cbd270a6580385be67b2a2313df27
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Thu Aug 1 12:24:43 2024 -0700
[SPARK-49070][SS][SQL] TransformWithStateExec.initialState is rewritten
incorrectly to produce invalid query plan
### What changes were proposed in this pull request?
This patch fixes `TransformWithStateExec` so that when its `hasInitialState` is
false, the `initialState` won't be incorrectly rewritten by the planner to
produce an invalid query plan, which would cause unexpected errors for
extension rules that rely on the correctness of the query plan.
### Why are the changes needed?
[SPARK-47363](https://issues.apache.org/jira/browse/SPARK-47363) added the
support for users to provide an initial state for a streaming query. Query
operators such as `TransformWithStateExec` might have `hasInitialState` as
false, which means the initial-state-related parameters are not used. But when
the query planner applies rules to the query, it will still apply them to the
initial state query plan. When `hasInitialState` is false, some related
parameters like `initialStateGroupingAttrs` are [...]
For example, `EnsureRequirements` may apply an invalid Sort and Exchange on
the initial state query plan. We encountered these invalid query plans in our
extension rules.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47546 from viirya/fix_initial_state.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: huaxingao <[email protected]>
---
.../streaming/FlatMapGroupsWithStateExec.scala | 9 +++-
.../streaming/TransformWithStateExec.scala | 9 +++-
.../FlatMapGroupsWithStateDistributionSuite.scala | 44 +++++++++++++---
.../streaming/FlatMapGroupsWithStateSuite.scala | 58 ++++++++++++++++++++++
.../sql/streaming/TransformWithStateSuite.scala | 41 +++++++++++++++
...StatefulOpClusteredDistributionTestHelper.scala | 9 ++++
6 files changed, 158 insertions(+), 12 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ee1fc1db71f..d56dfebd61ba 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -415,8 +415,13 @@ case class FlatMapGroupsWithStateExec(
override def right: SparkPlan = initialState
override protected def withNewChildrenInternal(
- newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec =
- copy(child = newLeft, initialState = newRight)
+ newLeft: SparkPlan, newRight: SparkPlan): FlatMapGroupsWithStateExec = {
+ if (hasInitialState) {
+ copy(child = newLeft, initialState = newRight)
+ } else {
+ copy(child = newLeft)
+ }
+ }
override def createInputProcessor(
store: StateStore): InputProcessor = new InputProcessor(store) {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index acdb8372a67d..7f57f0b69cad 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -160,8 +160,13 @@ case class TransformWithStateExec(
override def right: SparkPlan = initialState
override protected def withNewChildrenInternal(
- newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec =
- copy(child = newLeft, initialState = newRight)
+ newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec = {
+ if (hasInitialState) {
+ copy(child = newLeft, initialState = newRight)
+ } else {
+ copy(child = newLeft)
+ }
+ }
override def keyExpressions: Seq[Attribute] = groupingAttributes
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
index b597a2447108..04b5f3af6463 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
@@ -135,8 +135,15 @@ class FlatMapGroupsWithStateDistributionSuite extends
StreamTest
assert(flatMapGroupsWithStateExecs.length === 1)
assert(requireStatefulOpClusteredDistribution(
flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1",
"_2")), numPartitions))
- assert(hasDesiredHashPartitioningInChildren(
- flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1",
"_2")), numPartitions))
+ if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+ assert(hasDesiredHashPartitioningInChildren(
+ flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ } else {
+ assert(hasDesiredHashPartitioningInChildren(
+ Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ }
}
)
}
@@ -236,8 +243,15 @@ class FlatMapGroupsWithStateDistributionSuite extends
StreamTest
assert(flatMapGroupsWithStateExecs.length === 1)
assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
- assert(hasDesiredHashPartitioningInChildren(
- flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1",
"_2")), numPartitions))
+ if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+ assert(hasDesiredHashPartitioningInChildren(
+ flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ } else {
+ assert(hasDesiredHashPartitioningInChildren(
+ Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ }
}
)
}
@@ -328,8 +342,15 @@ class FlatMapGroupsWithStateDistributionSuite extends
StreamTest
assert(flatMapGroupsWithStateExecs.length === 1)
assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
- assert(hasDesiredHashPartitioningInChildren(
- flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1",
"_2")), numPartitions))
+ if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+ assert(hasDesiredHashPartitioningInChildren(
+ flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ } else {
+ assert(hasDesiredHashPartitioningInChildren(
+ Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ }
}
)
}
@@ -449,8 +470,15 @@ class FlatMapGroupsWithStateDistributionSuite extends
StreamTest
assert(flatMapGroupsWithStateExecs.length === 1)
assert(requireClusteredDistribution(flatMapGroupsWithStateExecs.head,
Seq(Seq("_1", "_2"), Seq("_1", "_2")), Some(numPartitions)))
- assert(hasDesiredHashPartitioningInChildren(
- flatMapGroupsWithStateExecs.head, Seq(Seq("_1", "_2"), Seq("_1",
"_2")), numPartitions))
+ if (flatMapGroupsWithStateExecs.head.hasInitialState) {
+ assert(hasDesiredHashPartitioningInChildren(
+ flatMapGroupsWithStateExecs.head.children, Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ } else {
+ assert(hasDesiredHashPartitioningInChildren(
+ Seq(flatMapGroupsWithStateExecs.head.left), Seq(Seq("_1", "_2"),
Seq("_1", "_2")),
+ numPartitions))
+ }
}
)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index b35e996106f3..45a80a210fce 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -32,6 +32,7 @@ import
org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.streaming._
import
org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper,
MemoryStateStore, RocksDBStateStoreProvider, StateStore}
import org.apache.spark.sql.functions.timestamp_seconds
@@ -903,6 +904,63 @@ class FlatMapGroupsWithStateSuite extends
StateStoreMetricsTest {
}
}
+ test("SPARK-49070: flatMapGroupsWithStateExec - valid initial state plan") {
+ withTempDir { dir =>
+ withSQLConf(
+ (SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false"),
+ (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
+ (SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
classOf[RocksDBStateStoreProvider].getName)) {
+
+ val inputData = MemoryStream[Timestamp]
+ val stateFunc = (key: Int, values: Iterator[Timestamp], state:
GroupState[Int]) => {
+ // Should never timeout. All batches should have data and even if a
timeout is set,
+ // it should get cleared when the key receives data per contract.
+ require(!state.hasTimedOut, "The state should not have timed out!")
+ // Set state and timeout once, only on the first call. The timeout
should get cleared
+ // in the subsequent batch which has data for the key.
+ if (!state.exists) {
+ state.update(0)
+ state.setTimeoutTimestamp(500) // Timeout at 500 milliseconds.
+ }
+ 0
+ }
+
+ val query = inputData.toDS()
+ .withWatermark("value", "0 seconds")
+ .groupByKey(_ => 0) // Always the same key: 0.
+ .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(stateFunc)
+ .writeStream
+ .format("console")
+ .outputMode("update")
+ .start()
+
+ try {
+ (1 to 2).foreach {i =>
+ inputData.addData(new Timestamp(i * 1000))
+ query.processAllAvailable()
+ }
+
+ val sparkPlan =
+
query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
+ val flatMapGroupsWithStateExec = sparkPlan.collect {
+ case p: FlatMapGroupsWithStateExec => p
+ }.head
+
+ assert(!flatMapGroupsWithStateExec.hasInitialState)
+
+ // EnsureRequirements should not apply on the initial state plan
+ val exchange = flatMapGroupsWithStateExec.initialState.collect {
+ case s: ShuffleExchangeExec => s
+ }
+
+ assert(exchange.isEmpty)
+ } finally {
+ query.stop()
+ }
+ }
+ }
+ }
+
def testWithTimeout(timeoutConf: GroupStateTimeout): Unit = {
test("SPARK-20714: watermark does not fail query when timeout = " +
timeoutConf) {
// Function to maintain running count up to 2, and then remove the count
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 5b91e0de1903..43db8921897c 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkRuntimeException,
SparkUnsupportedOperationExcepti
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, Encoders, Row}
import org.apache.spark.sql.catalyst.util.stringToFile
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.functions.timestamp_seconds
@@ -1231,6 +1232,46 @@ class TransformWithStateSuite extends
StateStoreMetricsTest
}
}
}
+
+ test("SPARK-49070: transformWithState - valid initial state plan") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ withTempDir { srcDir =>
+ Seq("a", "b", "c").foreach(createFile(_, srcDir))
+ val df = createFileStream(srcDir)
+
+ var index = 0
+
+ val q = df.writeStream
+ .foreachBatch((_: Dataset[(String, String)], _: Long) => {
+ index += 1
+ })
+ .trigger(Trigger.AvailableNow)
+ .start()
+
+ try {
+ assert(q.awaitTermination(streamingTimeout.toMillis))
+
+ val sparkPlan =
+
q.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.executedPlan
+ val transformWithStateExec = sparkPlan.collect {
+ case p: TransformWithStateExec => p
+ }.head
+
+ assert(!transformWithStateExec.hasInitialState)
+
+ // EnsureRequirements should not apply on the initial state plan
+ val exchange = transformWithStateExec.initialState.collect {
+ case s: ShuffleExchangeExec => s
+ }
+
+ assert(exchange.isEmpty)
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
}
class TransformWithStateValidationSuite extends StateStoreMetricsTest {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
index f2684b8c39cd..877bca6898f0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/StatefulOpClusteredDistributionTestHelper.scala
@@ -72,6 +72,15 @@ trait StatefulOpClusteredDistributionTestHelper extends
SparkFunSuite {
}
}
+ protected def hasDesiredHashPartitioningInChildren(
+ children: Seq[SparkPlan],
+ desiredClusterColumns: Seq[Seq[String]],
+ desiredNumPartitions: Int): Boolean = {
+ children.zip(desiredClusterColumns).forall { case (child, clusterColumns)
=>
+ hasDesiredHashPartitioning(child, clusterColumns, desiredNumPartitions)
+ }
+ }
+
private def partitionExpressionsColumns(expressions: Seq[Expression]):
Seq[String] = {
expressions.flatMap {
case ref: AttributeReference => Some(ref.name)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]