This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 52df0cdf24e1 [SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
52df0cdf24e1 is described below

commit 52df0cdf24e1a964825d1b1a367e79fef8057460
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Wed Nov 6 09:30:53 2024 +0900

    [SPARK-50162][SS][TESTS] Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
    
    ### What changes were proposed in this pull request?
    Add tests for loading snapshot with given version for transformWithState operator state and state data source reader
    
    ### Why are the changes needed?
    To add test coverage for the snapshotStartBatchId integration between transformWithState (TWS) and the state data source reader
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Test only change
    
    Added unit tests
    
    ```
    ===== POSSIBLE THREAD LEAK IN SUITE 
o.a.s.sql.execution.datasources.v2.state.StateDataSourceTransformWithStateSuite,
 threads: ForkJoinPool.commonPool-worker-6 (daemon=true), 
ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 
(daemon=true), ForkJoinPool.commonPool-worker-7 (daemon=true), 
ForkJoinPool.commonPool-worker-5 (daemon=true), 
ForkJoinPool.commonPool-worker-3 (daemon=true), 
ForkJoinPool.commonPool-worker-2 (daemon=true), rpc-boss-3-1 (daemon=true),  
[...]
    [info] Run completed in 2 minutes, 5 seconds.
    [info] Total number of tests run: 23
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 23, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48710 from anishshri-db/task/SPARK-50162.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../v2/state/StateDataSourceReadSuite.scala        |  16 +--
 .../StateDataSourceTransformWithStateSuite.scala   | 154 ++++++++++++++++++++-
 2 files changed, 160 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 300da03f73e1..4a274d51b62c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -457,19 +457,19 @@ class HDFSBackedStateDataSourceReadSuite extends StateDataSourceReadSuite {
     testSnapshotPartitionId()
   }
 
-  test("snapshotStatBatchId on limit state") {
+  test("snapshotStartBatchId on limit state") {
     testSnapshotOnLimitState("hdfs")
   }
 
-  test("snapshotStatBatchId on aggregation state") {
+  test("snapshotStartBatchId on aggregation state") {
     testSnapshotOnAggregateState("hdfs")
   }
 
-  test("snapshotStatBatchId on deduplication state") {
+  test("snapshotStartBatchId on deduplication state") {
     testSnapshotOnDeduplicateState("hdfs")
   }
 
-  test("snapshotStatBatchId on join state") {
+  test("snapshotStartBatchId on join state") {
     testSnapshotOnJoinState("hdfs", 1)
     testSnapshotOnJoinState("hdfs", 2)
   }
@@ -550,19 +550,19 @@ StateDataSourceReadSuite {
     testSnapshotPartitionId()
   }
 
-  test("snapshotStatBatchId on limit state") {
+  test("snapshotStartBatchId on limit state") {
     testSnapshotOnLimitState("rocksdb")
   }
 
-  test("snapshotStatBatchId on aggregation state") {
+  test("snapshotStartBatchId on aggregation state") {
     testSnapshotOnAggregateState("rocksdb")
   }
 
-  test("snapshotStatBatchId on deduplication state") {
+  test("snapshotStartBatchId on deduplication state") {
     testSnapshotOnDeduplicateState("rocksdb")
   }
 
-  test("snapshotStatBatchId on join state") {
+  test("snapshotStartBatchId on join state") {
     testSnapshotOnJoinState("rocksdb", 1)
     testSnapshotOnJoinState("rocksdb", 2)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index d00827fbd3b2..baab6327b35c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -16,15 +16,20 @@
  */
 package org.apache.spark.sql.execution.datasources.v2.state
 
+import java.io.File
 import java.time.Duration
 
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
-import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.{explode, timestamp_seconds}
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBFileManager, RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.functions.{col, explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, 
MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, 
OutputMode, RunningCountStatefulProcessor, 
RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, 
StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, 
TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
 import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
 
 /** Stateful processor of single value state var with non-primitive type */
 class StatefulProcessorWithSingleValueVar extends 
RunningCountStatefulProcessor {
@@ -997,4 +1002,149 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
       }
     }
   }
+
+  /**
+   * Golden files cannot be used for transformWithState: the new schema format records the
+   * schema file path as an absolute path, which breaks the getResource model used by other
+   * similar tests. Instead, this test forces snapshot creation for every version and checks
+   * that snapshotStartBatchId is honored, by rewriting one changelog file so that a read
+   * starting from an earlier snapshot observes the change while a latest-state read does not.
+   */
+  testWithChangelogCheckpointingEnabled("snapshotStartBatchId with 
transformWithState") {
+    class AggregationStatefulProcessor extends StatefulProcessor[Int, (Int, 
Long), (Int, Long)] {
+      @transient protected var _countState: ValueState[Long] = _ // per-key running sum
+
+      override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+        _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong,
+          TTLConfig.NONE)
+      }
+
+      override def handleInputRows(
+          key: Int,
+          inputRows: Iterator[(Int, Long)],
+          timerValues: TimerValues): Iterator[(Int, Long)] = {
+        val count = _countState.getOption().getOrElse(0L)
+        var totalSum = 0L
+        inputRows.foreach { entry =>
+          totalSum += entry._2
+        }
+        _countState.update(count + totalSum)
+        Iterator((key, count + totalSum))
+      }
+    }
+
+    withTempDir { tmpDir =>
+      withSQLConf(
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+        SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+        SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+        val inputData = MemoryStream[(Int, Long)]
+        val query = inputData
+          .toDS()
+          .groupByKey(_._1)
+          .transformWithState(new AggregationStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Append())
+        testStream(query)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (5, 1L), (6, 2L), (7, 3L), (8, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (9, 1L), (10, 2L), (11, 3L), (12, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (13, 1L), (14, 2L), (15, 3L), (16, 4L)),
+          ProcessAllAvailable(),
+          AddData(inputData, (17, 1L), (18, 2L), (19, 3L), (20, 4L)),
+          ProcessAllAvailable(),
+          // Give maintenance a chance to upload snapshots. NOTE(review): fixed sleep may flake.
+          Execute { _ => Thread.sleep(5000) },
+          StopStream
+        )
+      }
+
+      // File manager over the checkpointed state of operator 0, partition 4 (state/0/4)
+      val dfsRootDir = new File(tmpDir.getAbsolutePath + "/state/0/4")
+      val fileManager = new RocksDBFileManager(
+        dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration,
+        CompressionCodec.LZ4)
+
+      // Partition 4's changelog for version 3 covers the third batch (keys 9-12); in this
+      // test only keys 9 and 12 are written to partition 4, so it must have two entries.
+      // Only the first of them is kept when the changelog is rewritten below.
+      val changelogReader = fileManager.getChangelogReader(3, true)
+      val entries = changelogReader.toSeq
+      assert(entries.size == 2)
+      val retainEntry = entries.head
+
+      // Keep the first entry (key 9, per the assertions below) and delete the changelog file
+      val changelogFilePath = dfsRootDir.getAbsolutePath + "/3.changelog"
+      Utils.deleteRecursively(new File(changelogFilePath))
+
+      // Recreate the version-3 changelog with only the retained entry
+      val changelogWriter = fileManager.getChangeLogWriter(3, true)
+      changelogWriter.put(retainEntry._2, retainEntry._3)
+      changelogWriter.commit()
+
+      // The rewritten changelog must hold only key 9's entry (key 12 was dropped).
+      // NOTE(review): neither changelog reader is closed here; confirm this does not leak.
+      val changelogReader1 = fileManager.getChangelogReader(3, true)
+      val entries1 = changelogReader1.toSeq
+      assert(entries1.size == 1)
+
+      // Partition 1 was not modified, so both reads must agree there; partition 4 read from
+      // snapshotStartBatchId=1 must reflect the rewritten version-3 changelog.
+      Seq(1, 4).foreach { partition =>
+        val stateSnapshotDf = spark
+          .read
+          .format("statestore")
+          .option("snapshotPartitionId", partition)
+          .option("snapshotStartBatchId", 1)
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+
+        val stateDf = spark
+          .read
+          .format("statestore")
+          .option("stateVarName", "countState")
+          .load(tmpDir.getCanonicalPath)
+          .filter(col("partition_id") === partition)
+
+        if (partition == 1) {
+          checkAnswer(stateSnapshotDf, stateDf)
+        } else {
+          // From snapshotStartBatchId=1 the rewritten changelog is replayed: key 12 is absent
+          val resultDfForSnapshot = stateSnapshotDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+          checkAnswer(resultDfForSnapshot,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4)))
+
+          // The default read starts from the latest snapshot, so key 12 is still present
+          val resultDf = stateDf.selectExpr(
+            "key.value AS groupingKey",
+            "value.value AS count",
+            "partition_id")
+
+          checkAnswer(resultDf,
+            Seq(Row(16, 4L, 4),
+              Row(17, 1L, 4),
+              Row(19, 3L, 4),
+              Row(2, 2L, 4),
+              Row(6, 2L, 4),
+              Row(9, 1L, 4),
+              Row(12, 4L, 4)))
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to