This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 52df0cdf24e1 [SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
52df0cdf24e1 is described below
commit 52df0cdf24e1a964825d1b1a367e79fef8057460
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Wed Nov 6 09:30:53 2024 +0900
[SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
### What changes were proposed in this pull request?
Add tests for loading a snapshot at a given version for transformWithState operator state through the state data source reader.
### Why are the changes needed?
To add test coverage for the snapshotStartBatchId integration between transformWithState and the state data source reader.
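For reference, the reader options exercised by these tests can be combined as in this minimal sketch (option names are taken from the tests in the diff below; the checkpoint path and option values are placeholders):
```
// Minimal usage sketch (placeholder path and values): load the state of a
// transformWithState query by replaying from the snapshot taken at a given batch.
val snapshotDf = spark
  .read
  .format("statestore")
  // reconstruct state starting from the snapshot at the end of batch 1
  .option("snapshotStartBatchId", 1)
  // restrict the read to a single state store partition
  .option("snapshotPartitionId", 4)
  // name of the transformWithState state variable to read
  .option("stateVarName", "countState")
  .load("/path/to/checkpoint")
```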
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Test-only change. Added unit tests.
```
===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.datasources.v2.state.StateDataSourceTransformWithStateSuite, threads: ForkJoinPool.commonPool-worker-6 (daemon=true), ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 (daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), ForkJoinPool.commonPool-worker-5 (daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), ForkJoinPool.commonPool-worker-2 (daemon=true), rpc-boss-3-1 (daemon=true),
[...]
[info] Run completed in 2 minutes, 5 seconds.
[info] Total number of tests run: 23
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 23, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48710 from anishshri-db/task/SPARK-50162.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../v2/state/StateDataSourceReadSuite.scala | 16 +--
.../StateDataSourceTransformWithStateSuite.scala | 154 ++++++++++++++++++++-
2 files changed, 160 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 300da03f73e1..4a274d51b62c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -457,19 +457,19 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
testSnapshotPartitionId()
}
- test("snapshotStatBatchId on limit state") {
+ test("snapshotStartBatchId on limit state") {
testSnapshotOnLimitState("hdfs")
}
- test("snapshotStatBatchId on aggregation state") {
+ test("snapshotStartBatchId on aggregation state") {
testSnapshotOnAggregateState("hdfs")
}
- test("snapshotStatBatchId on deduplication state") {
+ test("snapshotStartBatchId on deduplication state") {
testSnapshotOnDeduplicateState("hdfs")
}
- test("snapshotStatBatchId on join state") {
+ test("snapshotStartBatchId on join state") {
testSnapshotOnJoinState("hdfs", 1)
testSnapshotOnJoinState("hdfs", 2)
}
@@ -550,19 +550,19 @@ StateDataSourceReadSuite {
testSnapshotPartitionId()
}
- test("snapshotStatBatchId on limit state") {
+ test("snapshotStartBatchId on limit state") {
testSnapshotOnLimitState("rocksdb")
}
- test("snapshotStatBatchId on aggregation state") {
+ test("snapshotStartBatchId on aggregation state") {
testSnapshotOnAggregateState("rocksdb")
}
- test("snapshotStatBatchId on deduplication state") {
+ test("snapshotStartBatchId on deduplication state") {
testSnapshotOnDeduplicateState("rocksdb")
}
- test("snapshotStatBatchId on join state") {
+ test("snapshotStartBatchId on join state") {
testSnapshotOnJoinState("rocksdb", 1)
testSnapshotOnJoinState("rocksdb", 2)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index d00827fbd3b2..baab6327b35c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -16,15 +16,20 @@
*/
package org.apache.spark.sql.execution.datasources.v2.state
+import java.io.File
import java.time.Duration
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
-import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.{explode, timestamp_seconds}
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
/** Stateful processor of single value state var with non-primitive type */
class StatefulProcessorWithSingleValueVar extends RunningCountStatefulProcessor {
@@ -997,4 +1002,149 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
}
}
}
+
+ /**
+ * Note that we cannot use the golden files approach for transformWithState. The new schema
+ * format keeps track of the schema file path as an absolute path, which cannot be used with
+ * the getResource model used in other similar tests. Hence, we force snapshot creation
+ * for the given versions and ensure that we load the state data from the given start
+ * snapshot version.
+ */
+ testWithChangelogCheckpointingEnabled("snapshotStartBatchId with transformWithState") {
+ class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, Long), (Int, Long)] {
+ @transient protected var _countState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
+ TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: Int,
+ inputRows: Iterator[(Int, Long)],
+ timerValues: TimerValues): Iterator[(Int, Long)] = {
+ val count = _countState.getOption().getOrElse(0L)
+ var totalSum = 0L
+ inputRows.foreach { entry =>
+ totalSum += entry._2
+ }
+ _countState.update(count + totalSum)
+ Iterator((key, count + totalSum))
+ }
+ }
+
+ withTempDir { tmpDir =>
+ withSQLConf(
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+ SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+ SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+ val inputData = MemoryStream[(Int, Long)]
+ val query = inputData
+ .toDS()
+ .groupByKey(_._1)
+ .transformWithState(new AggregationStatefulProcessor(),
+ TimeMode.None(),
+ OutputMode.Append())
+ testStream(query)(
+ StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+ AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L)),
+ ProcessAllAvailable(),
+ AddData(inputData, (5, 1L), (6, 2L), (7, 3L), (8, 4L)),
+ ProcessAllAvailable(),
+ AddData(inputData, (9, 1L), (10, 2L), (11, 3L), (12, 4L)),
+ ProcessAllAvailable(),
+ AddData(inputData, (13, 1L), (14, 2L), (15, 3L), (16, 4L)),
+ ProcessAllAvailable(),
+ AddData(inputData, (17, 1L), (18, 2L), (19, 3L), (20, 4L)),
+ ProcessAllAvailable(),
+ // Ensure that we get a chance to upload created snapshots
+ Execute { _ => Thread.sleep(5000) },
+ StopStream
+ )
+ }
+
+ // Create a file manager for the state store with opId=0 and partition=4
+ val dfsRootDir = new File(tmpDir.getAbsolutePath + "/state/0/4")
+ val fileManager = new RocksDBFileManager(
+ dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration,
+ CompressionCodec.LZ4)
+
+ // Read the changelog for one of the partitions at version 3 and
+ // ensure that we have two entries
+ // For this test - keys 9 and 12 are written at version 3 for partition 4
+ val changelogReader = fileManager.getChangelogReader(3, true)
+ val entries = changelogReader.toSeq
+ assert(entries.size == 2)
+ val retainEntry = entries.head
+
+ // Retain one of the entries and delete the changelog file
+ val changelogFilePath = dfsRootDir.getAbsolutePath + "/3.changelog"
+ Utils.deleteRecursively(new File(changelogFilePath))
+
+ // Write the retained entry back to the changelog
+ val changelogWriter = fileManager.getChangeLogWriter(3, true)
+ changelogWriter.put(retainEntry._2, retainEntry._3)
+ changelogWriter.commit()
+
+ // Ensure that we have only one entry in the changelog for version 3
+ // For this test - key 9 is retained and key 12 is deleted
+ val changelogReader1 = fileManager.getChangelogReader(3, true)
+ val entries1 = changelogReader1.toSeq
+ assert(entries1.size == 1)
+
+ // Ensure that the state matches for the partition that is not modified and does not
+ // match for the other partition
+ Seq(1, 4).foreach { partition =>
+ val stateSnapshotDf = spark
+ .read
+ .format("statestore")
+ .option("snapshotPartitionId", partition)
+ .option("snapshotStartBatchId", 1)
+ .option("stateVarName", "countState")
+ .load(tmpDir.getCanonicalPath)
+
+ val stateDf = spark
+ .read
+ .format("statestore")
+ .option("stateVarName", "countState")
+ .load(tmpDir.getCanonicalPath)
+ .filter(col("partition_id") === partition)
+
+ if (partition == 1) {
+ checkAnswer(stateSnapshotDf, stateDf)
+ } else {
+ // Ensure that key 12 is not present in the final state loaded from the given snapshot
+ val resultDfForSnapshot = stateSnapshotDf.selectExpr(
+ "key.value AS groupingKey",
+ "value.value AS count",
+ "partition_id")
+ checkAnswer(resultDfForSnapshot,
+ Seq(Row(16, 4L, 4),
+ Row(17, 1L, 4),
+ Row(19, 3L, 4),
+ Row(2, 2L, 4),
+ Row(6, 2L, 4),
+ Row(9, 1L, 4)))
+
+ // Ensure that key 12 is present in the final state loaded from the latest snapshot
+ val resultDf = stateDf.selectExpr(
+ "key.value AS groupingKey",
+ "value.value AS count",
+ "partition_id")
+
+ checkAnswer(resultDf,
+ Seq(Row(16, 4L, 4),
+ Row(17, 1L, 4),
+ Row(19, 3L, 4),
+ Row(2, 2L, 4),
+ Row(6, 2L, 4),
+ Row(9, 1L, 4),
+ Row(12, 4L, 4)))
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]