eason-yuchen-liu commented on code in PR #46944:
URL: https://github.com/apache/spark/pull/46944#discussion_r1659417200


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala:
##########
@@ -255,6 +271,22 @@ class StateStoreValueSchemaNotCompatible(
       "storedValueSchema" -> storedValueSchema,
       "newValueSchema" -> newValueSchema))
 
+class StateStoreSnapshotFileNotFound(fileToRead: String, clazz: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_MISSING_SNAPSHOT_FILE",
+    messageParameters = Map(
+      "fileToRead" -> fileToRead,
+      "clazz" -> clazz))

Review Comment:
   It seems so. I learned this from the following example: 
https://github.com/apache/spark/blob/6bfeb094248269920df8b107c86f0982404935cd/common/utils/src/main/resources/error/error-conditions.json#L236C54-L236C59



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala:
##########
@@ -255,6 +271,22 @@ class StateStoreValueSchemaNotCompatible(
       "storedValueSchema" -> storedValueSchema,
       "newValueSchema" -> newValueSchema))
 
+class StateStoreSnapshotFileNotFound(fileToRead: String, clazz: String)
+  extends SparkUnsupportedOperationException(
+    errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_MISSING_SNAPSHOT_FILE",
+    messageParameters = Map(
+      "fileToRead" -> fileToRead,
+      "clazz" -> clazz))

Review Comment:
   It seems so. The parameter names themselves will not appear in the rendered error message. I learned this from the following example: 
https://github.com/apache/spark/blob/6bfeb094248269920df8b107c86f0982404935cd/common/utils/src/main/resources/error/error-conditions.json#L236C54-L236C59



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala:
##########
@@ -796,4 +973,228 @@ abstract class StateDataSourceReadSuite extends 
StateDataSourceTestBase with Ass
       testForSide("right")
     }
   }
+
+  protected def testSnapshotNotFound(): Unit = {
+    withTempDir { tempDir =>
+      val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
+      for (i <- 1 to 4) {
+        val store = provider.getStore(i - 1)
+        put(store, "a", i, i)
+        store.commit()
+        provider.doMaintenance() // create a snapshot every other delta file
+      }
+
+      val exc = intercept[SparkException] {
+        provider.asInstanceOf[SupportsFineGrainedReplay]
+          .replayReadStateFromSnapshot(1, 2)
+      }
+      checkError(exc, "CANNOT_LOAD_STATE_STORE.UNCATEGORIZED")
+    }
+  }
+
+  protected def testGetReadStoreWithStartVersion(): Unit = {
+    withTempDir { tempDir =>
+      val provider = getNewStateStoreProvider(tempDir.getAbsolutePath)
+      for (i <- 1 to 4) {
+        val store = provider.getStore(i - 1)
+        put(store, "a", i, i)
+        store.commit()
+        provider.doMaintenance()
+      }
+
+      val result =
+        provider.asInstanceOf[SupportsFineGrainedReplay]
+          .replayReadStateFromSnapshot(2, 3)
+
+      assert(get(result, "a", 1).get == 1)
+      assert(get(result, "a", 2).get == 2)
+      assert(get(result, "a", 3).get == 3)
+      assert(get(result, "a", 4).isEmpty)
+
+      provider.close()
+    }
+  }
+
+  protected def testSnapshotPartitionId(): Unit = {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[Int]
+      val df = inputData.toDF().limit(10)
+
+      testStream(df)(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, 1, 2, 3, 4),
+        CheckLastBatch(1, 2, 3, 4)
+      )
+
+      val stateDf = spark.read.format("statestore")
+        .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
+        .option(StateSourceOptions.SNAPSHOT_PARTITION_ID, 0)
+        .option(StateSourceOptions.BATCH_ID, 0)
+        .load(tempDir.getAbsolutePath)
+
+      // should result in only one partition && should not throw error in 
planning stage
+      assert(stateDf.rdd.getNumPartitions == 1)
+
+      // should throw error when partition id is out of range
+      val stateDfError = spark.read.format("statestore")
+        .option(StateSourceOptions.SNAPSHOT_START_BATCH_ID, 0)
+        .option(
+          StateSourceOptions.SNAPSHOT_PARTITION_ID, 1)
+        .option(StateSourceOptions.BATCH_ID, 0)
+        .load(tempDir.getAbsolutePath)
+
+      val exc = intercept[StateStoreSnapshotPartitionNotFound] {
+        stateDfError.show()
+      }
+      assert(exc.getErrorClass === 
"CANNOT_LOAD_STATE_STORE.SNAPSHOT_PARTITION_ID_NOT_FOUND")
+    }
+  }
+
+  private def testSnapshotStateDfAgainstStateDf(resourceDir: File): Unit = {
+    val stateSnapshotDf = spark.read.format("statestore")
+      .option("snapshotPartitionId", 0)
+      .option("snapshotStartBatchId", 1)
+      .load(resourceDir.getAbsolutePath)
+
+    val stateDf = spark.read.format("statestore")
+      .load(resourceDir.getAbsolutePath)
+      .filter(col("partition_id") === 0)
+
+    checkAnswer(stateSnapshotDf, stateDf)
+  }
+
+  protected def testSnapshotOnLimitState(providerName: String): Unit = {
+    /** The golden files are generated by:
+    withSQLConf({
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100"
+    }) {
+      val inputData = MemoryStream[(Int, Long)]
+      val query = inputData.toDF().limit(10)
+      testStream(query)(
+        StartStream(checkpointLocation = <...>),
+        AddData(inputData, (1, 1L), (2, 2L), (3, 3L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (4, 4L), (5, 5L), (6, 6L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (7, 7L), (8, 8L), (9, 9L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (10, 10L), (11, 11L), (12, 12L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) }
+      )
+    }
+     */
+    val resourceUri = this.getClass.getResource(
+      s"/structured-streaming/checkpoint-version-4.0.0/$providerName/limit/"
+    ).toURI
+
+    testSnapshotStateDfAgainstStateDf(new File(resourceUri))
+  }
+
+  protected def testSnapshotOnAggregateState(providerName: String): Unit = {
+    /** The golden files are generated by:
+    withSQLConf({
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100"
+    }) {
+      val inputData = MemoryStream[(Int, Long)]
+      val query = inputData.toDF().groupBy("_1").count()
+      testStream(query, OutputMode.Update)(
+        StartStream(checkpointLocation = <...>),
+        AddData(inputData, (1, 1L), (2, 2L), (3, 3L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (2, 2L), (3, 3L), (4, 4L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (3, 3L), (4, 4L), (5, 5L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) },
+        AddData(inputData, (4, 4L), (5, 5L), (6, 6L)),
+        ProcessAllAvailable(),
+        Execute { _ => Thread.sleep(2000) }
+      )
+    }
+     */
+    val resourceUri = this.getClass.getResource(
+      s"/structured-streaming/checkpoint-version-4.0.0/$providerName/dedup/"
+    ).toURI
+
+    testSnapshotStateDfAgainstStateDf(new File(resourceUri))
+  }
+
+  protected def testSnapshotOnDeduplicateState(providerName: String): Unit = {
+    /** The golden files are generated by:
+    withSQLConf({

Review Comment:
   Will indent this block one level further to the right.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to