micheal-o commented on code in PR #53720:
URL: https://github.com/apache/spark/pull/53720#discussion_r2679160973
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -133,7 +136,7 @@ class StateRewriter(
val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
// Rewrite each state store of the operator
- stateStoresMetadata.foreach { stateStoreMetadata =>
+ opMetadata.operatorInfo.operatorId -> stateStoresMetadata.map { stateStoreMetadata =>
Review Comment:
nit: have a `val` here that collects the checkpointIds and then make this
`operatorId -> checkpointIds` a separate line
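Something like this, illustratively (`rewriteStateStore` is just a stand-in for the existing per-store body):
```scala
// Collect the checkpoint infos into a val first...
val checkpointIds = stateStoresMetadata.map { stateStoreMetadata =>
  rewriteStateStore(stateStoreMetadata) // placeholder for the existing per-store logic
}
// ...then build the pair on its own line.
opMetadata.operatorInfo.operatorId -> checkpointIds
```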
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -543,25 +548,30 @@ class RocksDB(
lineageManager.clear()
throw t
}
- if (enableChangelogCheckpointing && !readOnly) {
+ if (conf.enableChangelogCheckpointing && !readOnly) {
Review Comment:
Add this comment here too:
https://github.com/apache/spark/blob/acc80fd1b205792a43cc352a8e492e34bbe880da/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala#L614
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -543,25 +548,30 @@ class RocksDB(
lineageManager.clear()
throw t
}
- if (enableChangelogCheckpointing && !readOnly) {
+ if (conf.enableChangelogCheckpointing && !readOnly) {
// Make sure we don't leak resource.
changelogWriter.foreach(_.abort())
- // Initialize the changelog writer with lineage info
- // The lineage stored in changelog files should normally start with
- // the version of a snapshot, except for the first few versions.
- // Because they are solely loaded from changelog file.
- // (e.g. with default minDeltasForSnapshot, there is only 1_uuid1.changelog, no 1_uuid1.zip)
- // It should end with exactly one version before the change log's version.
- changelogWriter = Some(fileManager.getChangeLogWriter(
- version + 1,
- useColumnFamilies,
- sessionStateStoreCkptId,
- Some(currVersionLineage)))
+ if (loadEmpty) {
+ // No changelog writer for empty stores
Review Comment:
nit: "We don't want to write changelog file when loadEmpty is true"
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,23 +82,25 @@ class StateRewriter(
private val stateRootLocation = new Path(
resolvedCheckpointLocation,
StreamingCheckpointConstants.DIR_NAME_STATE).toString
- def run(): Unit = {
+ // return a Map[operator id, Array[stateStore -> Array[partition -> StateStoreCheckpointInfo]]]
+ def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
logInfo(log"Starting state rewrite for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
log"readCheckpointLocation=" +
log"${MDC(CHECKPOINT_LOCATION, readResolvedCheckpointLocation.getOrElse(""))}, " +
log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
- val (_, timeTakenMs) = Utils.timeTakenMs {
+ val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
runInternal()
}
logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}")
Review Comment:
Let's also include the checkpointInfos in the log here.
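For example (sketch; the LogKey name is illustrative, use whichever existing key fits):
```scala
logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms for " +
  log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, resolvedCheckpointLocation)}, " +
  // Hypothetical key, for illustration only.
  log"checkpointInfos=${MDC(STATE_STORE_CHECKPOINT_INFO, checkpointInfos)}")
```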
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -124,7 +127,7 @@ class StateRewriter(
// Do rewrite for each operator
// We can potentially parallelize this, but for now, do sequentially
- allOperatorsMetadata.foreach { opMetadata =>
+ allOperatorsMetadata.map { opMetadata =>
Review Comment:
Before we do the rewrite, we need to check whether the readBatchId has checkpointIds in the commit log; if so, we should set `SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION` in the `sqlConf` to enable checkpoint v2.
We can't rely on the user to correctly set this conf in the session before running repartition. They may forget, and this conf isn't part of the confs written to the checkpoint. You can repro this by not setting the conf in the session: the rewrite will fail. Let's also have a test to validate this.
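A rough sketch of the check (illustrative; assumes the runner already has the commit log at hand):
```scala
// If the read batch was committed with checkpoint v2 (stateUniqueIds present),
// force the format version so the rewrite also runs with checkpoint ids.
val readCommit = checkpointMetadata.commitLog.get(readBatchId).get
if (readCommit.stateUniqueIds.isDefined) {
  sqlConf.setConfString(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")
}
```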
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
- // todo: For checkpoint v2, we need to update the stateUniqueIds based on the
- // newly created state commit. Will be done in subsequent PR.
- if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+ val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+ // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> Array[partition -> info]]]
+ // we change it to Map[operatorId, Array[partitionId -> Array[store -> info]]]
+ val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+ originalInfoMap.map {
+ case(operator, storesSeq) =>
+ val numPartitions = storesSeq.head.length
+ operator -> (0 until numPartitions).map { partitionIdx =>
+ storesSeq.flatMap { storePartitions =>
+ storePartitions(partitionIdx).stateStoreCkptId
+ }
+ }.toArray
+ }
+ // Include checkpoint IDs in commit metadata
+ CommitMetadata(
Review Comment:
nit: can just do `latestCommit.copy(stateUniqueIds = Option(opIdToPartitionCkptInfo))`
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -443,77 +443,82 @@ class RocksDB(
private def loadWithCheckpointId(
version: Long,
stateStoreCkptId: Option[String],
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
// An array contains lineage information from [snapShotVersion, version]
// (inclusive in both ends)
var currVersionLineage: Array[LineageItem] = lineageManager.getLineageForCurrVersion()
try {
- if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
- stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+ if (loadEmpty || loadedVersion != version || loadedStateStoreCkptId.isEmpty ||
+ stateStoreCkptId.get != loadedStateStoreCkptId.get) {
closeDB(ignoreException = false)
-
- val (latestSnapshotVersion, latestSnapshotUniqueId) = {
- // Special handling when version is 0.
- // When loading the very first version (0), stateStoreCkptId does not need to be defined
- // because there won't be 0.changelog / 0.zip file created in RocksDB under v2.
- if (version == 0) {
- assert(stateStoreCkptId.isEmpty,
- "stateStoreCkptId should be empty when version is zero")
- (0L, None)
- // When there is a snapshot file, it is the ground truth, we can skip
- // reconstructing the lineage from changelog file.
- } else if (fileManager.existsSnapshotFile(version, stateStoreCkptId)) {
- currVersionLineage = Array(LineageItem(version, stateStoreCkptId.get))
- (version, stateStoreCkptId)
- } else {
- currVersionLineage = getLineageFromChangelogFile(version, stateStoreCkptId) :+
- LineageItem(version, stateStoreCkptId.get)
- currVersionLineage = currVersionLineage.sortBy(_.version)
-
- val latestSnapshotVersionsAndUniqueId =
- fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
- latestSnapshotVersionsAndUniqueId match {
- case Some(pair) => (pair._1, Option(pair._2))
- case None if currVersionLineage.head.version == 1L =>
- logDebug(log"Cannot find latest snapshot based on lineage but first version " +
- log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
- (0L, None)
- case _ =>
- throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
- printLineageItems(currVersionLineage))
+ if (loadEmpty) {
+ loadEmptyStore(version)
+ lineageManager.clear()
+ } else {
+ val (latestSnapshotVersion, latestSnapshotUniqueId) = {
+ // Special handling when version is 0.
+ // When loading the very first version (0), stateStoreCkptId does not need to be defined
+ // because there won't be 0.changelog / 0.zip file created in RocksDB under v2.
+ if (version == 0) {
+ assert(stateStoreCkptId.isEmpty,
+ "stateStoreCkptId should be empty when version is zero")
+ (0L, None)
+ // When there is a snapshot file, it is the ground truth, we can skip
+ // reconstructing the lineage from changelog file.
+ } else if (fileManager.existsSnapshotFile(version, stateStoreCkptId)) {
+ currVersionLineage = Array(LineageItem(version, stateStoreCkptId.get))
+ (version, stateStoreCkptId)
+ } else {
+ currVersionLineage = getLineageFromChangelogFile(version, stateStoreCkptId) :+
+ LineageItem(version, stateStoreCkptId.get)
+ currVersionLineage = currVersionLineage.sortBy(_.version)
+
+ val latestSnapshotVersionsAndUniqueId =
+ fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
+ latestSnapshotVersionsAndUniqueId match {
+ case Some(pair) => (pair._1, Option(pair._2))
+ case None if currVersionLineage.head.version == 1L =>
+ logDebug(log"Cannot find latest snapshot based on lineage but first version " +
+ log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
+ (0L, None)
+ case _ =>
+ throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
+ printLineageItems(currVersionLineage))
+ }
}
}
- }
- logInfo(log"Loaded latestSnapshotVersion: ${
- MDC(LogKeys.SNAPSHOT_VERSION, latestSnapshotVersion)}, latestSnapshotUniqueId: ${
- MDC(LogKeys.UUID, latestSnapshotUniqueId)}")
+ logInfo(log"Loaded latestSnapshotVersion: ${
+ MDC(LogKeys.SNAPSHOT_VERSION, latestSnapshotVersion)}, latestSnapshotUniqueId: ${
+ MDC(LogKeys.UUID, latestSnapshotUniqueId)}")
- val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion,
- workingDir, rocksDBFileMapping, latestSnapshotUniqueId)
+ val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion,
+ workingDir, rocksDBFileMapping, latestSnapshotUniqueId)
- loadedVersion = latestSnapshotVersion
+ loadedVersion = latestSnapshotVersion
- // reset the last snapshot version to the latest available snapshot version
- lastSnapshotVersion = latestSnapshotVersion
- lineageManager.resetLineage(currVersionLineage)
+ // reset the last snapshot version to the latest available snapshot version
+ lastSnapshotVersion = latestSnapshotVersion
+ lineageManager.resetLineage(currVersionLineage)
- // Initialize maxVersion upon successful load from DFS
- fileManager.setMaxSeenVersion(version)
+ // Initialize maxVersion upon successful load from DFS
+ fileManager.setMaxSeenVersion(version)
- // Report this snapshot version to the coordinator
- reportSnapshotUploadToCoordinator(latestSnapshotVersion)
+ // Report this snapshot version to the coordinator
+ reportSnapshotUploadToCoordinator(latestSnapshotVersion)
- openLocalRocksDB(metadata)
+ openLocalRocksDB(metadata)
- if (loadedVersion != version) {
- val versionsAndUniqueIds = currVersionLineage.collect {
+ if (loadedVersion != version) {
+ val versionsAndUniqueIds = currVersionLineage.collect {
case i if i.version > loadedVersion && i.version <= version =>
(i.version, Option(i.checkpointUniqueId))
}
- replayChangelog(versionsAndUniqueIds)
- loadedVersion = version
- lineageManager.resetLineage(currVersionLineage)
+ replayChangelog(versionsAndUniqueIds)
+ loadedVersion = version
+ lineageManager.resetLineage(currVersionLineage)
+ }
Review Comment:
Should we update the `loaded version..` message below, just like we did for `loadWithoutCheckpointId`?
https://github.com/apache/spark/blob/acc80fd1b205792a43cc352a8e492e34bbe880da/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala#L604
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
- // todo: For checkpoint v2, we need to update the stateUniqueIds based on the
- // newly created state commit. Will be done in subsequent PR.
- if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+ val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+ // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> Array[partition -> info]]]
+ // we change it to Map[operatorId, Array[partitionId -> Array[store -> info]]]
Review Comment:
Why not let the state rewriter do this conversion before returning it? Otherwise every caller of the state rewriter would need to implement this conversion.
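Sketch of a private helper on the rewriter side (name is illustrative; the body is lifted from the conversion above):
```scala
// Transpose Array[store][partition] into Array[partition][store] checkpoint ids.
private def toPartitionCkptIds(
    storesSeq: Array[Array[StateStoreCheckpointInfo]]): Array[Array[String]] = {
  val numPartitions = storesSeq.head.length
  (0 until numPartitions).map { partitionIdx =>
    // For each partition, pick up every store's checkpoint id, in store order.
    storesSeq.flatMap(storePartitions => storePartitions(partitionIdx).stateStoreCkptId)
  }.toArray
}
```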
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -95,12 +95,14 @@ class OfflineStateRepartitionRunner(
transformFunc = Some(stateRepartitionFunc),
writeCheckpointMetadata = Some(checkpointMetadata)
)
- rewriter.run()
+ val operatorToCkptIds = rewriter.run()
updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId)
// Commit the repartition batch
- commitBatch(newBatchId, lastCommittedBatchId)
+ val enableCheckpointId = sparkSession.conf.get(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2
Review Comment:
Use the `StatefulOperatorStateInfo.enableStateStoreCheckpointIds` util function for this. Also, we can move this check into the `commitBatch` function.
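i.e. something like (sketch; assumes `commitBatch` can reach the session's SQLConf):
```scala
// Derive the flag inside commitBatch instead of threading a boolean from the caller.
val enableCheckpointId =
  StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
```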
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -443,77 +443,82 @@ class RocksDB(
private def loadWithCheckpointId(
version: Long,
stateStoreCkptId: Option[String],
- readOnly: Boolean = false): RocksDB = {
+ readOnly: Boolean = false,
+ loadEmpty: Boolean = false): RocksDB = {
// An array contains lineage information from [snapShotVersion, version]
// (inclusive in both ends)
var currVersionLineage: Array[LineageItem] = lineageManager.getLineageForCurrVersion()
try {
- if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
- stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+ if (loadEmpty || loadedVersion != version || loadedStateStoreCkptId.isEmpty ||
+ stateStoreCkptId.get != loadedStateStoreCkptId.get) {
closeDB(ignoreException = false)
-
- val (latestSnapshotVersion, latestSnapshotUniqueId) = {
- // Special handling when version is 0.
- // When loading the very first version (0), stateStoreCkptId does not need to be defined
- // because there won't be 0.changelog / 0.zip file created in RocksDB under v2.
- if (version == 0) {
- assert(stateStoreCkptId.isEmpty,
- "stateStoreCkptId should be empty when version is zero")
- (0L, None)
- // When there is a snapshot file, it is the ground truth, we can skip
- // reconstructing the lineage from changelog file.
- } else if (fileManager.existsSnapshotFile(version, stateStoreCkptId)) {
- currVersionLineage = Array(LineageItem(version, stateStoreCkptId.get))
- (version, stateStoreCkptId)
- } else {
- currVersionLineage = getLineageFromChangelogFile(version, stateStoreCkptId) :+
- LineageItem(version, stateStoreCkptId.get)
- currVersionLineage = currVersionLineage.sortBy(_.version)
-
- val latestSnapshotVersionsAndUniqueId =
- fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
- latestSnapshotVersionsAndUniqueId match {
- case Some(pair) => (pair._1, Option(pair._2))
- case None if currVersionLineage.head.version == 1L =>
- logDebug(log"Cannot find latest snapshot based on lineage but first version " +
- log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
- (0L, None)
- case _ =>
- throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
- printLineageItems(currVersionLineage))
+ if (loadEmpty) {
Review Comment:
require that `stateStoreCkptId` is not present
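e.g. (sketch):
```scala
// In the loadEmpty branch: an empty load starts from scratch,
// so there is no checkpoint id to load from.
require(stateStoreCkptId.isEmpty,
  "stateStoreCkptId must not be set when loadEmpty is true")
```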
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
)
}
- test("Repartition: success, failure, retry") {
- withTempDir { dir =>
- val originalPartitions = 3
- val input = MemoryStream[Int]
- val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
- // Shouldn't be seen as a repartition batch
- assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
-
- // Trying to repartition to the same number should fail
- val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
Review Comment:
nit: you mean `testWithCheckpointId`?
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
)
}
- test("Repartition: success, failure, retry") {
- withTempDir { dir =>
- val originalPartitions = 3
- val input = MemoryStream[Int]
- val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
- // Shouldn't be seen as a repartition batch
- assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
-
- // Trying to repartition to the same number should fail
- val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+ withSQLConf(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> ckptVersion.toString) {
+ testFun
+ }
}
- checkError(
- ex,
- condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "batchId" -> batchId.toString,
- "numPartitions" -> originalPartitions.toString
- )
- )
-
- // Trying to repartition to a different number should succeed
- val newPartitions = originalPartitions + 1
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
- val repartitionBatchId = batchId + 1
- val hadoopConf = spark.sessionState.newHadoopConf()
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
-
- // Now delete the repartition commit to simulate a failed repartition attempt.
- // This will delete all the commits after the batchId.
- checkpointMetadata.commitLog.purgeAfter(batchId)
+ }
- // Try to repartition with a different numPartitions should fail,
- // since it will see an uncommitted repartition batch with a different numPartitions.
- val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1)
- }
- checkError(
- ex2,
- condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "lastBatchId" -> repartitionBatchId.toString,
- "lastBatchShufflePartitions" -> newPartitions.toString,
- "numPartitions" -> (newPartitions + 1).toString
+ testWithChangelogAndCheckpointId("Repartition: success, failure, retry") {
+ withTempDir { dir =>
+ val originalPartitions = 3
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
+ val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+ // Shouldn't be seen as a repartition batch
+ assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
+
+ // Trying to repartition to the same number should fail
+ val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+ }
+ checkError(
+ ex,
+ condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "batchId" -> batchId.toString,
+ "numPartitions" -> originalPartitions.toString
+ )
)
- )
- // Retrying with the same numPartitions should work
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+ // Trying to repartition to a different number should succeed
+ val newPartitions = originalPartitions + 1
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+ val repartitionBatchId = batchId + 1
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ verifyRepartitionBatch(
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+ // Now delete the repartition commit to simulate a failed repartition attempt.
+ // This will delete all the commits after the batchId.
+ checkpointMetadata.commitLog.purgeAfter(batchId)
+
+ // Try to repartition with a different numPartitions should fail,
+ // since it will see an uncommitted repartition batch with a different numPartitions.
+ val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1)
+ }
+ checkError(
+ ex2,
+ condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "lastBatchId" -> repartitionBatchId.toString,
+ "lastBatchShufflePartitions" -> newPartitions.toString,
+ "numPartitions" -> (newPartitions + 1).toString
+ )
+ )
- // Repartition with way more partitions, to verify that empty partitions are properly created
- val morePartitions = newPartitions * 3
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
- verifyRepartitionBatch(
- repartitionBatchId + 1, checkpointMetadata, hadoopConf,
- dir.getAbsolutePath, morePartitions)
+ // Retrying with the same numPartitions should work
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+ // Repartition with way more partitions, to verify that empty partitions are properly
+ // created
+ val morePartitions = newPartitions * 3
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+ dir.getAbsolutePath, morePartitions)
+ if (spark.conf.get(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2) {
+ verifyCheckpointIds(repartitionBatchId + 1, checkpointMetadata, morePartitions)
+ }
- // Restart the query to make sure it can start after repartitioning
- runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+ // Restart the query to make sure it can start after repartitioning
+ runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+ }
}
}
Review Comment:
We also need to run the two other test cases below with checkpoint id enabled/disabled.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -450,8 +464,28 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
}
}
- testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
- testRoundTripForAggrStateVersion(1)
+ // Run transformWithState tests with enable/disable checkpoint V2
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName ($changelogCpTestSuffix, enableCkptId = ${ckptVersion >= 2})") {
+ withSQLConf(
+ "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+ changelogCheckpointingEnabled.toString,
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> ckptVersion.toString) {
+ testFun
+ }
+ }
+ }
+
+ testWithChangelogAndCheckpointId("SPARK-54420: aggregation state ver 1") {
+ testRoundTripForAggrStateVersion(1)
+ }
+
+ Seq(1, 2).foreach { version =>
+ testWithChangelogAndCheckpointId(s"SPARK-54420: stream-stream join state ver $version") {
+ testStreamStreamJoinRoundTrip(version)
+ }
+ }
Review Comment:
Shouldn't we be testing checkpoint v2 for the others too? At the very least, we should also test TWS.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +96,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
transformFunc = None,
writeCheckpointMetadata = Some(targetCheckpointMetadata)
)
- rewriter.run()
+ val checkpointInfos = rewriter.run()
- // Commit to commitLog
+ // Commit to commitLog with checkpoint IDs
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val operatorId = 0L
+ val storesSeq: Array[Array[StateStoreCheckpointInfo]] = checkpointInfos(operatorId)
+ val commitMetadata = if (StateStoreConf(conf).enableStateStoreCheckpointIds) {
Review Comment:
We won't need this conversion here if the rewriter does it before returning. That's cleaner, since each rewriter caller won't need to implement it.
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +96,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
transformFunc = None,
writeCheckpointMetadata = Some(targetCheckpointMetadata)
)
- rewriter.run()
+ val checkpointInfos = rewriter.run()
- // Commit to commitLog
+ // Commit to commitLog with checkpoint IDs
val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
- targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+ val operatorId = 0L
+ val storesSeq: Array[Array[StateStoreCheckpointInfo]] = checkpointInfos(operatorId)
+ val commitMetadata = if (StateStoreConf(conf).enableStateStoreCheckpointIds) {
+ // Build map: partition id -> array of checkpoint IDs (one per store, in order)
+ // checkpointInfos(operatorId) is Seq[Array[StateStoreCheckpointInfo]]
+ // where Seq is stores (in order: store0, store1, ...), Array is partitions
+ // For join operators: 4 stores (left-keyToNumValues, left-keyWithIndexToValue,
+ // right-keyToNumValues, right-keyWithIndexToValue)
+ // For regular operators: 1 store
+ val numPartitions = storesSeq.head.length
+ val ckptIds: Array[Array[String]] = (0 until numPartitions).map { partitionIdx =>
+ // For this partition, collect checkpoint IDs from all stores (in order)
+ storesSeq.flatMap { storePartitions =>
+ storePartitions(partitionIdx).stateStoreCkptId
+ }
+ }.toArray
+ // Include checkpoint IDs in commit metadata
+ CommitMetadata(
Review Comment:
ditto, use copy() instead
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ def verifyCheckpointIds(
+ repartitionBatchId: Long,
+ checkpointMetadata: StreamingQueryCheckpointMetadata,
+ expectedShufflePartitions: Int): Unit = {
+ // Verify commit log has the repartition batch with checkpoint IDs
+ val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+ assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should exist")
+
+ val commitMetadata = commitOpt.get
+
+ // Verify stateUniqueIds is present for checkpoint V2
+ assert(commitMetadata.stateUniqueIds.isDefined,
+ "stateUniqueIds should be present in commit metadata when checkpoint version >= 2")
+
+ val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+ assert(operatorIdToCkptInfos.nonEmpty,
+ "operatorIdToCkptInfos should not be empty")
+
+ // Verify structure for each operator
+ operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+ // Should have checkpoint IDs for all partitions
+ assert(partitionToCkptIds.length == expectedShufflePartitions,
+ s"Operator $operatorId: Expected $expectedShufflePartitions partition checkpoint IDs, " +
+ s"but found ${partitionToCkptIds.length}")
+ // Each partition should have checkpoint IDs (at least one per store)
+ partitionToCkptIds.foreach { ckptIds =>
+ assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint IDs")
Review Comment:
nit: include the partitionId in the message, right?
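e.g. (sketch):
```scala
partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
  assert(ckptIds.nonEmpty,
    s"Operator $operatorId, partition $partitionId should have checkpoint IDs")
}
```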
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ def verifyCheckpointIds(
+ repartitionBatchId: Long,
+ checkpointMetadata: StreamingQueryCheckpointMetadata,
+ expectedShufflePartitions: Int): Unit = {
+ // Verify commit log has the repartition batch with checkpoint IDs
+ val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+ assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should exist")
+
+ val commitMetadata = commitOpt.get
+
+ // Verify stateUniqueIds is present for checkpoint V2
+ assert(commitMetadata.stateUniqueIds.isDefined,
+ "stateUniqueIds should be present in commit metadata when checkpoint version >= 2")
+
+ val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+ assert(operatorIdToCkptInfos.nonEmpty,
+ "operatorIdToCkptInfos should not be empty")
+
+ // Verify structure for each operator
+ operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+ // Should have checkpoint IDs for all partitions
+ assert(partitionToCkptIds.length == expectedShufflePartitions,
+ s"Operator $operatorId: Expected $expectedShufflePartitions partition checkpoint IDs, " +
+ s"but found ${partitionToCkptIds.length}")
+ // Each partition should have checkpoint IDs (at least one per store)
+ partitionToCkptIds.foreach { ckptIds =>
+ assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint IDs")
+ // Each checkpoint ID should be a non-empty string
+ ckptIds.foreach { ckptId =>
Review Comment:
We also need to assert that the number of ckptIds for the partition matches the number of state stores the operator has. You can read the operator metadata to know how many stores each operatorId has.
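Roughly (sketch; `numStoresForOperator` is a hypothetical helper that reads the operator metadata):
```scala
// Hypothetical helper: read the operator metadata to count its state stores.
val expectedNumStores = numStoresForOperator(operatorId)
partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
  assert(ckptIds.length == expectedNumStores,
    s"Operator $operatorId, partition $partitionId: expected $expectedNumStores " +
    s"checkpoint IDs (one per store), found ${ckptIds.length}")
}
```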
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3943,7 +3943,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures with SharedSparkSession
}}
}
- test("SPARK-54420: load with createEmpty creates empty store") {
+ testWithStateStoreCheckpointIds("SPARK-54420: load with createEmpty creates empty store") { _ =>
Review Comment:
This test case is actually not truly testing checkpointId, since you are not passing the checkpointId you got from `db.commit` into `db.load`.
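Roughly the flow I'd expect (sketch; how to obtain the commit's checkpoint id is illustrative, the exact accessor may differ):
```scala
db.load(0)
db.put("a", "1")
db.commit()
// Illustrative: fetch the checkpoint id the commit produced ...
val ckptId = db.latestCheckpointInfo.stateStoreCkptId // hypothetical accessor
// ... and feed it back into load, so the checkpoint-id code path is exercised.
db.load(1, stateStoreCkptId = ckptId, loadEmpty = true)
```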
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ def verifyCheckpointIds(
+ repartitionBatchId: Long,
+ checkpointMetadata: StreamingQueryCheckpointMetadata,
+ expectedShufflePartitions: Int): Unit = {
+ // Verify commit log has the repartition batch with checkpoint IDs
+ val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+ assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should exist")
+
+ val commitMetadata = commitOpt.get
+
+ // Verify stateUniqueIds is present for checkpoint V2
+ assert(commitMetadata.stateUniqueIds.isDefined,
+ "stateUniqueIds should be present in commit metadata when checkpoint version >= 2")
+
+ val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+ assert(operatorIdToCkptInfos.nonEmpty,
+ "operatorIdToCkptInfos should not be empty")
+
+ // Verify structure for each operator
+ operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+ // Should have checkpoint IDs for all partitions
+ assert(partitionToCkptIds.length == expectedShufflePartitions,
+ s"Operator $operatorId: Expected $expectedShufflePartitions partition checkpoint IDs, " +
+ s"but found ${partitionToCkptIds.length}")
+ // Each partition should have checkpoint IDs (at least one per store)
+ partitionToCkptIds.foreach { ckptIds =>
+ assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint IDs")
+ // Each checkpoint ID should be a non-empty string
+ ckptIds.foreach { ckptId =>
+ assert(ckptId.nonEmpty,
+ s"Operator $operatorId: Checkpoint ID should be non-empty")
Review Comment:
ditto
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
}
}
}
+
+ def verifyCheckpointIds(
Review Comment:
private
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -450,8 +464,28 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
}
}
- testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
- testRoundTripForAggrStateVersion(1)
+ // Run transformWithState tests with enable/disable checkpoint V2
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName ($changelogCpTestSuffix, enableCkptId = ${ckptVersion >= 2})") {
+ withSQLConf(
+ "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
Review Comment:
Why not set only the format version here and then call `testWithChangelogConfig`, instead of duplicating it here?
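Something like (sketch; assumes `testWithChangelogConfig` already sets the changelog conf):
```scala
def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
  testWithChangelogConfig(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
    withSQLConf(
      SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> ckptVersion.toString) {
      testFun
    }
  }
}
```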
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
)
}
- test("Repartition: success, failure, retry") {
- withTempDir { dir =>
- val originalPartitions = 3
- val input = MemoryStream[Int]
- val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
- val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
- // Shouldn't be seen as a repartition batch
- assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
-
- // Trying to repartition to the same number should fail
- val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+ Seq(1, 2).foreach { ckptVersion =>
+ def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): Unit = {
+ test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+ withSQLConf(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> ckptVersion.toString) {
+ testFun
+ }
}
- checkError(
- ex,
- condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "batchId" -> batchId.toString,
- "numPartitions" -> originalPartitions.toString
- )
- )
-
- // Trying to repartition to a different number should succeed
- val newPartitions = originalPartitions + 1
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
- val repartitionBatchId = batchId + 1
- val hadoopConf = spark.sessionState.newHadoopConf()
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
-
- // Now delete the repartition commit to simulate a failed repartition attempt.
- // This will delete all the commits after the batchId.
- checkpointMetadata.commitLog.purgeAfter(batchId)
+ }
- // Try to repartition with a different numPartitions should fail,
- // since it will see an uncommitted repartition batch with a different numPartitions.
- val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1)
- }
- checkError(
- ex2,
- condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
- parameters = Map(
- "checkpointLocation" -> dir.getAbsolutePath,
- "lastBatchId" -> repartitionBatchId.toString,
- "lastBatchShufflePartitions" -> newPartitions.toString,
- "numPartitions" -> (newPartitions + 1).toString
+ testWithChangelogAndCheckpointId("Repartition: success, failure, retry") {
+ withTempDir { dir =>
+ val originalPartitions = 3
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
+ val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+ // Shouldn't be seen as a repartition batch
+ assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
+
+ // Trying to repartition to the same number should fail
+ val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions)
+ }
+ checkError(
+ ex,
+ condition = "STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "batchId" -> batchId.toString,
+ "numPartitions" -> originalPartitions.toString
+ )
)
- )
- // Retrying with the same numPartitions should work
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
- verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+ // Trying to repartition to a different number should succeed
+ val newPartitions = originalPartitions + 1
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+ val repartitionBatchId = batchId + 1
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ verifyRepartitionBatch(
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+ // Now delete the repartition commit to simulate a failed repartition attempt.
+ // This will delete all the commits after the batchId.
+ checkpointMetadata.commitLog.purgeAfter(batchId)
+
+ // Try to repartition with a different numPartitions should fail,
+ // since it will see an uncommitted repartition batch with a different numPartitions.
+ val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions + 1)
+ }
+ checkError(
+ ex2,
+ condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+ parameters = Map(
+ "checkpointLocation" -> dir.getAbsolutePath,
+ "lastBatchId" -> repartitionBatchId.toString,
+ "lastBatchShufflePartitions" -> newPartitions.toString,
+ "numPartitions" -> (newPartitions + 1).toString
+ )
+ )
- // Repartition with way more partitions, to verify that empty partitions are properly created
- val morePartitions = newPartitions * 3
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
- verifyRepartitionBatch(
- repartitionBatchId + 1, checkpointMetadata, hadoopConf,
- dir.getAbsolutePath, morePartitions)
+ // Retrying with the same numPartitions should work
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+ // Repartition with way more partitions, to verify that empty partitions are properly
+ // created
+ val morePartitions = newPartitions * 3
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+ dir.getAbsolutePath, morePartitions)
+ if (spark.conf.get(
+ SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2) {
+ verifyCheckpointIds(repartitionBatchId + 1, checkpointMetadata, morePartitions)
Review Comment:
Let's call `verifyCheckpointIds` within `verifyRepartitionBatch`, just like we do for the other verifications, so that calling `verifyRepartitionBatch` performs all the necessary verifications.
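i.e. something like this at the end of `verifyRepartitionBatch` (sketch; assumes the format version is readable there, e.g. via the session conf):
```scala
// Only applicable when checkpoint v2 is on; otherwise there are no ids to verify.
if (spark.conf.get(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2) {
  verifyCheckpointIds(repartitionBatchId, checkpointMetadata, expectedShufflePartitions)
}
```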
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
}
}
- private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
+ private def commitBatch(
+ newBatchId: Long,
+ lastCommittedBatchId: Long,
+ opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
- // todo: For checkpoint v2, we need to update the stateUniqueIds based on the
- // newly created state commit. Will be done in subsequent PR.
- if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+ val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+ // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> Array[partition -> info]]]
+ // we change it to Map[operatorId, Array[partitionId -> Array[store -> info]]]
+ val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+ originalInfoMap.map {
+ case(operator, storesSeq) =>
+ val numPartitions = storesSeq.head.length
+ operator -> (0 until numPartitions).map { partitionIdx =>
+ storesSeq.flatMap { storePartitions =>
+ storePartitions(partitionIdx).stateStoreCkptId
Review Comment:
When we move this conversion to the rewriter, we should add some assertions, i.e.:
1. that `storePartitions(partitionIdx).partitionId == partitionIdx`
2. that `storePartitions(partitionIdx).batchVersion` is what we expect
3. that `storePartitions(partitionIdx).baseStateStoreCkptId == None`
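For instance, inside the rewriter's conversion (sketch; `expectedBatchVersion` is illustrative, e.g. the batch the rewrite wrote):
```scala
(0 until numPartitions).map { partitionIdx =>
  storesSeq.flatMap { storePartitions =>
    val info = storePartitions(partitionIdx)
    // Sanity-check each checkpoint info before extracting its id.
    assert(info.partitionId == partitionIdx,
      s"Expected partitionId $partitionIdx, got ${info.partitionId}")
    assert(info.batchVersion == expectedBatchVersion,
      s"Unexpected batchVersion ${info.batchVersion}")
    assert(info.baseStateStoreCkptId.isEmpty,
      s"Expected no baseStateStoreCkptId, got ${info.baseStateStoreCkptId}")
    info.stateStoreCkptId
  }
}.toArray
```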
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]