HeartSaVioR commented on code in PR #50123:
URL: https://github.com/apache/spark/pull/50123#discussion_r2038644963
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala:
##########
@@ -158,281 +158,162 @@ class StateStoreCoordinatorSuite extends SparkFunSuite
with SharedSparkContext {
}
}
- Seq(
- ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName),
- ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName)
- ).foreach {
- case (providerName, providerClassName) =>
- test(
- s"SPARK-51358: Snapshot uploads in $providerName are properly reported
to the coordinator"
- ) {
- withCoordinatorAndSQLConf(
- sc,
- SQLConf.SHUFFLE_PARTITIONS.key -> "5",
- SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
- SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
- SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
- SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
- RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX +
".changelogCheckpointing.enabled" -> "true",
- SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key ->
"true",
-
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key ->
"2",
- SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key ->
"0"
- ) {
- case (coordRef, spark) =>
- import spark.implicits._
- implicit val sqlContext = spark.sqlContext
-
- // Start a query and run some data to force snapshot uploads
- val inputData = MemoryStream[Int]
- val aggregated = inputData.toDF().dropDuplicates()
- val checkpointLocation = Utils.createTempDir().getAbsoluteFile
- val query = aggregated.writeStream
- .format("memory")
- .outputMode("update")
- .queryName("query")
- .option("checkpointLocation", checkpointLocation.toString)
- .start()
- // Add, commit, and wait multiple times to force snapshot versions
and time difference
- (0 until 6).foreach { _ =>
- inputData.addData(1, 2, 3)
- query.processAllAvailable()
- Thread.sleep(500)
- }
- val streamingQuery =
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
- val stateCheckpointDir =
streamingQuery.lastExecution.checkpointLocation
- val latestVersion = streamingQuery.lastProgress.batchId + 1
-
- // Verify all stores have uploaded a snapshot and it's logged by
the coordinator
- (0 until
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
- partitionId =>
- val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
- val providerId = StateStoreProviderId(storeId, query.runId)
-
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
- }
- // Verify that we should not have any state stores lagging behind
- assert(coordRef.getLaggingStoresForTesting(query.runId,
latestVersion).isEmpty)
- query.stop()
- }
- }
- }
+ private val allJoinStateStoreNames: Seq[String] =
+ SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
- Seq(
+ /** Lists the state store providers used for a test, and the set of lagging
partition IDs */
+ private val regularStateStoreProviders = Seq(
+ ("RocksDBStateStoreProvider", classOf[RocksDBStateStoreProvider].getName,
Set.empty[Int]),
+ ("HDFSStateStoreProvider", classOf[HDFSBackedStateStoreProvider].getName,
Set.empty[Int])
+ )
+
+ /** Lists the state store providers used for a test, and the set of lagging
partition IDs */
+ private val faultyStateStoreProviders = Seq(
(
"RocksDBSkipMaintenanceOnCertainPartitionsProvider",
- classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName
+ classOf[RocksDBSkipMaintenanceOnCertainPartitionsProvider].getName,
+ Set(0, 1)
),
(
"HDFSBackedSkipMaintenanceOnCertainPartitionsProvider",
- classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName
+ classOf[HDFSBackedSkipMaintenanceOnCertainPartitionsProvider].getName,
+ Set(0, 1)
)
- ).foreach {
- case (providerName, providerClassName) =>
- test(
- s"SPARK-51358: Snapshot uploads in $providerName are properly reported
to the coordinator"
- ) {
- withCoordinatorAndSQLConf(
- sc,
- SQLConf.SHUFFLE_PARTITIONS.key -> "5",
- SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
- SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3",
- SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
- SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName,
- RocksDBConf.ROCKSDB_SQL_CONF_NAME_PREFIX +
".changelogCheckpointing.enabled" -> "true",
- SQLConf.STATE_STORE_COORDINATOR_REPORT_SNAPSHOT_UPLOAD_LAG.key ->
"true",
-
SQLConf.STATE_STORE_COORDINATOR_MULTIPLIER_FOR_MIN_VERSION_DIFF_TO_LOG.key ->
"2",
- SQLConf.STATE_STORE_COORDINATOR_SNAPSHOT_LAG_REPORT_INTERVAL.key ->
"0"
- ) {
- case (coordRef, spark) =>
- import spark.implicits._
- implicit val sqlContext = spark.sqlContext
-
- // Start a query and run some data to force snapshot uploads
- val inputData = MemoryStream[Int]
- val aggregated = inputData.toDF().dropDuplicates()
- val checkpointLocation = Utils.createTempDir().getAbsoluteFile
- val query = aggregated.writeStream
- .format("memory")
- .outputMode("update")
- .queryName("query")
- .option("checkpointLocation", checkpointLocation.toString)
- .start()
- // Add, commit, and wait multiple times to force snapshot versions
and time difference
- (0 until 6).foreach { _ =>
- inputData.addData(1, 2, 3)
- query.processAllAvailable()
- Thread.sleep(500)
- }
- val streamingQuery =
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
- val stateCheckpointDir =
streamingQuery.lastExecution.checkpointLocation
- val latestVersion = streamingQuery.lastProgress.batchId + 1
-
- (0 until
query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
- partitionId =>
- val storeId = StateStoreId(stateCheckpointDir, 0, partitionId)
- val providerId = StateStoreProviderId(storeId, query.runId)
- if (partitionId <= 1) {
- // Verify state stores in partition 0/1 are lagging and
didn't upload anything
-
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).getOrElse(0) ==
0)
- } else {
- // Verify other stores uploaded a snapshot and it's logged
by the coordinator
-
assert(coordRef.getLatestSnapshotVersionForTesting(providerId).get >= 0)
- }
- }
- // We should have two state stores (id 0 and 1) that are lagging
behind at this point
- val laggingStores =
coordRef.getLaggingStoresForTesting(query.runId, latestVersion)
- assert(laggingStores.size == 2)
- assert(laggingStores.forall(_.storeId.partitionId <= 1))
- query.stop()
+ )
+
+ private val allStateStoreProviders =
+ regularStateStoreProviders ++ faultyStateStoreProviders
+
+ /**
+ * Verifies snapshot upload RPC messages from state stores are registered
and verifies
+ * the coordinator detected the correct lagging partitions.
+ */
+ private def verifySnapshotUploadEvents(
+ coordRef: StateStoreCoordinatorRef,
+ query: StreamingQuery,
+ badPartitions: Set[Int],
+ storeNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)): Unit = {
+ val streamingQuery =
query.asInstanceOf[StreamingQueryWrapper].streamingQuery
+ val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation
+ val latestVersion = streamingQuery.lastProgress.batchId + 1
+
+ // Verify all stores have uploaded a snapshot and it's logged by the
coordinator
+ (0 until query.sparkSession.conf.get(SQLConf.SHUFFLE_PARTITIONS)).foreach {
+ partitionId =>
+ // Verify for every store name listed
+ storeNames.foreach { storeName =>
+ val storeId = StateStoreId(stateCheckpointDir, 0, partitionId,
storeName)
+ val providerId = StateStoreProviderId(storeId, query.runId)
+ val latestSnapshotVersion =
coordRef.getLatestSnapshotVersionForTesting(providerId)
+ if (badPartitions.contains(partitionId)) {
+ assert(latestSnapshotVersion.getOrElse(0) == 0)
+ } else {
+ assert(latestSnapshotVersion.get >= 0)
+ }
}
- }
+ }
+ // Verify that only the bad partitions are all marked as lagging.
+ // Join queries should have all their state stores marked as lagging,
+ // which would be 4 stores per partition instead of 1.
+ val laggingStores = coordRef.getLaggingStoresForTesting(query.runId,
latestVersion)
+ assert(laggingStores.size == badPartitions.size * storeNames.size)
+ assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions)
}
- private val allJoinStateStoreNames: Seq[String] =
- SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
+ /** Sets up a stateful dropDuplicate query for testing */
+ private def setUpStatefulQuery(
+ inputData: MemoryStream[Int], queryName: String): StreamingQuery = {
+ // Set up a stateful drop duplicate query
+ val aggregated = inputData.toDF().dropDuplicates()
Review Comment:
It's fine; nothing here needs to be addressed, so let's leave this as a nit.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]