micheal-o commented on code in PR #53720:
URL: https://github.com/apache/spark/pull/53720#discussion_r2679160973


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -133,7 +136,7 @@ class StateRewriter(
         val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
 
         // Rewrite each state store of the operator
-        stateStoresMetadata.foreach { stateStoreMetadata =>
+        opMetadata.operatorInfo.operatorId -> stateStoresMetadata.map { 
stateStoreMetadata =>

Review Comment:
   nit: assign the collected checkpointIds to a `val` here, and then put
`operatorId -> checkpointIds` on a separate line.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -543,25 +548,30 @@ class RocksDB(
         lineageManager.clear()
         throw t
     }
-    if (enableChangelogCheckpointing && !readOnly) {
+    if (conf.enableChangelogCheckpointing && !readOnly) {

Review Comment:
   Add this comment here too:
   
https://github.com/apache/spark/blob/acc80fd1b205792a43cc352a8e492e34bbe880da/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala#L614



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -543,25 +548,30 @@ class RocksDB(
         lineageManager.clear()
         throw t
     }
-    if (enableChangelogCheckpointing && !readOnly) {
+    if (conf.enableChangelogCheckpointing && !readOnly) {
       // Make sure we don't leak resource.
       changelogWriter.foreach(_.abort())
-      // Initialize the changelog writer with lineage info
-      // The lineage stored in changelog files should normally start with
-      // the version of a snapshot, except for the first few versions.
-      // Because they are solely loaded from changelog file.
-      // (e.g. with default minDeltasForSnapshot, there is only 
1_uuid1.changelog, no 1_uuid1.zip)
-      // It should end with exactly one version before the change log's 
version.
-      changelogWriter = Some(fileManager.getChangeLogWriter(
-        version + 1,
-        useColumnFamilies,
-        sessionStateStoreCkptId,
-        Some(currVersionLineage)))
+      if (loadEmpty) {
+        // No changelog writer for empty stores

Review Comment:
   nit: reword the comment to "We don't want to write a changelog file when loadEmpty is true"



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -81,23 +82,25 @@ class StateRewriter(
   private val stateRootLocation = new Path(
     resolvedCheckpointLocation, 
StreamingCheckpointConstants.DIR_NAME_STATE).toString
 
-  def run(): Unit = {
+  // return a Map[operator id, Array[stateStore -> Array[partition -> 
StateStoreCheckpointInfo]]]
+  def run(): Map[Long, Array[Array[StateStoreCheckpointInfo]]] = {
     logInfo(log"Starting state rewrite for " +
       log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, 
resolvedCheckpointLocation)}, " +
       log"readCheckpointLocation=" +
       log"${MDC(CHECKPOINT_LOCATION, 
readResolvedCheckpointLocation.getOrElse(""))}, " +
       log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
       log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
 
-    val (_, timeTakenMs) = Utils.timeTakenMs {
+    val (checkpointInfos, timeTakenMs) = Utils.timeTakenMs {
       runInternal()
     }
 
     logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms 
for " +
       log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, 
resolvedCheckpointLocation)}")

Review Comment:
   Let's also include the checkpointInfos in this log message.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala:
##########
@@ -124,7 +127,7 @@ class StateRewriter(
 
       // Do rewrite for each operator
       // We can potentially parallelize this, but for now, do sequentially
-      allOperatorsMetadata.foreach { opMetadata =>
+      allOperatorsMetadata.map { opMetadata =>

Review Comment:
   Before we do the rewrite, we need to check whether the readBatchId has checkpoint IDs
in the commit log. If so, we should set
`SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION` in the `sqlConf` to enable
checkpoint v2.
   
   We can't rely on the user to set this conf correctly in the session before
running repartition: they may forget, and this conf isn't part of the confs
written to the checkpoint. You can reproduce this by not setting the conf in the
session; the rewrite will then fail. Let's also add a test to validate this.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
     }
   }
 
-  private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit 
= {
+  private def commitBatch(
+      newBatchId: Long,
+      lastCommittedBatchId: Long,
+      opIdToStateStoreCkptInfo: Option[Map[Long, 
Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
     val latestCommit = 
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
-    // todo: For checkpoint v2, we need to update the stateUniqueIds based on 
the
-    //  newly created state commit. Will be done in subsequent PR.
-    if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+    val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+      // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> 
Array[partition -> info]]]
+      // we change it to Map[operatorId, Array[partitionId -> Array[store -> 
info]]]
+      val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+        originalInfoMap.map {
+          case(operator, storesSeq) =>
+            val numPartitions = storesSeq.head.length
+            operator -> (0 until numPartitions).map { partitionIdx =>
+              storesSeq.flatMap { storePartitions =>
+                storePartitions(partitionIdx).stateStoreCkptId
+            }
+          }.toArray
+        }
+      // Include checkpoint IDs in commit metadata
+      CommitMetadata(

Review Comment:
   nit: can just do `latestCommit.copy(stateUniqueIds = 
Option(opIdToPartitionCkptInfo))`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -443,77 +443,82 @@ class RocksDB(
   private def loadWithCheckpointId(
       version: Long,
       stateStoreCkptId: Option[String],
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      loadEmpty: Boolean = false): RocksDB = {
     // An array contains lineage information from [snapShotVersion, version]
     // (inclusive in both ends)
     var currVersionLineage: Array[LineageItem] = 
lineageManager.getLineageForCurrVersion()
     try {
-      if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
-          stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+      if (loadEmpty || loadedVersion != version || 
loadedStateStoreCkptId.isEmpty ||
+        stateStoreCkptId.get != loadedStateStoreCkptId.get) {
         closeDB(ignoreException = false)
-
-        val (latestSnapshotVersion, latestSnapshotUniqueId) = {
-          // Special handling when version is 0.
-          // When loading the very first version (0), stateStoreCkptId does 
not need to be defined
-          // because there won't be 0.changelog / 0.zip file created in 
RocksDB under v2.
-          if (version == 0) {
-            assert(stateStoreCkptId.isEmpty,
-              "stateStoreCkptId should be empty when version is zero")
-            (0L, None)
-          // When there is a snapshot file, it is the ground truth, we can skip
-          // reconstructing the lineage from changelog file.
-          } else if (fileManager.existsSnapshotFile(version, 
stateStoreCkptId)) {
-            currVersionLineage = Array(LineageItem(version, 
stateStoreCkptId.get))
-            (version, stateStoreCkptId)
-          } else {
-            currVersionLineage = getLineageFromChangelogFile(version, 
stateStoreCkptId) :+
-              LineageItem(version, stateStoreCkptId.get)
-            currVersionLineage = currVersionLineage.sortBy(_.version)
-
-            val latestSnapshotVersionsAndUniqueId =
-              
fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
-            latestSnapshotVersionsAndUniqueId match {
-              case Some(pair) => (pair._1, Option(pair._2))
-              case None if currVersionLineage.head.version == 1L =>
-                logDebug(log"Cannot find latest snapshot based on lineage but 
first version " +
-                  log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, 
lineageManager)}")
-                (0L, None)
-              case _ =>
-                throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
-                  printLineageItems(currVersionLineage))
+        if (loadEmpty) {
+          loadEmptyStore(version)
+          lineageManager.clear()
+        } else {
+          val (latestSnapshotVersion, latestSnapshotUniqueId) = {
+            // Special handling when version is 0.
+            // When loading the very first version (0), stateStoreCkptId does 
not need to be defined
+            // because there won't be 0.changelog / 0.zip file created in 
RocksDB under v2.
+            if (version == 0) {
+              assert(stateStoreCkptId.isEmpty,
+                "stateStoreCkptId should be empty when version is zero")
+              (0L, None)
+              // When there is a snapshot file, it is the ground truth, we can 
skip
+              // reconstructing the lineage from changelog file.
+            } else if (fileManager.existsSnapshotFile(version, 
stateStoreCkptId)) {
+              currVersionLineage = Array(LineageItem(version, 
stateStoreCkptId.get))
+              (version, stateStoreCkptId)
+            } else {
+              currVersionLineage = getLineageFromChangelogFile(version, 
stateStoreCkptId) :+
+                LineageItem(version, stateStoreCkptId.get)
+              currVersionLineage = currVersionLineage.sortBy(_.version)
+
+              val latestSnapshotVersionsAndUniqueId =
+                
fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
+              latestSnapshotVersionsAndUniqueId match {
+                case Some(pair) => (pair._1, Option(pair._2))
+                case None if currVersionLineage.head.version == 1L =>
+                  logDebug(log"Cannot find latest snapshot based on lineage 
but first version " +
+                    log"is 1, use 0 as default. Lineage: 
${MDC(LogKeys.LINEAGE, lineageManager)}")
+                  (0L, None)
+                case _ =>
+                  throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
+                    printLineageItems(currVersionLineage))
+              }
             }
           }
-        }
 
-        logInfo(log"Loaded latestSnapshotVersion: ${
-          MDC(LogKeys.SNAPSHOT_VERSION, latestSnapshotVersion)}, 
latestSnapshotUniqueId: ${
-          MDC(LogKeys.UUID, latestSnapshotUniqueId)}")
+          logInfo(log"Loaded latestSnapshotVersion: ${
+            MDC(LogKeys.SNAPSHOT_VERSION, latestSnapshotVersion)}, 
latestSnapshotUniqueId: ${
+            MDC(LogKeys.UUID, latestSnapshotUniqueId)}")
 
-        val metadata = fileManager.loadCheckpointFromDfs(latestSnapshotVersion,
-          workingDir, rocksDBFileMapping, latestSnapshotUniqueId)
+          val metadata = 
fileManager.loadCheckpointFromDfs(latestSnapshotVersion,
+            workingDir, rocksDBFileMapping, latestSnapshotUniqueId)
 
-        loadedVersion = latestSnapshotVersion
+          loadedVersion = latestSnapshotVersion
 
-        // reset the last snapshot version to the latest available snapshot 
version
-        lastSnapshotVersion = latestSnapshotVersion
-        lineageManager.resetLineage(currVersionLineage)
+          // reset the last snapshot version to the latest available snapshot 
version
+          lastSnapshotVersion = latestSnapshotVersion
+          lineageManager.resetLineage(currVersionLineage)
 
-        // Initialize maxVersion upon successful load from DFS
-        fileManager.setMaxSeenVersion(version)
+          // Initialize maxVersion upon successful load from DFS
+          fileManager.setMaxSeenVersion(version)
 
-        // Report this snapshot version to the coordinator
-        reportSnapshotUploadToCoordinator(latestSnapshotVersion)
+          // Report this snapshot version to the coordinator
+          reportSnapshotUploadToCoordinator(latestSnapshotVersion)
 
-        openLocalRocksDB(metadata)
+          openLocalRocksDB(metadata)
 
-        if (loadedVersion != version) {
-          val versionsAndUniqueIds = currVersionLineage.collect {
+          if (loadedVersion != version) {
+            val versionsAndUniqueIds = currVersionLineage.collect {
               case i if i.version > loadedVersion && i.version <= version =>
                 (i.version, Option(i.checkpointUniqueId))
             }
-          replayChangelog(versionsAndUniqueIds)
-          loadedVersion = version
-          lineageManager.resetLineage(currVersionLineage)
+            replayChangelog(versionsAndUniqueIds)
+            loadedVersion = version
+            lineageManager.resetLineage(currVersionLineage)
+          }

Review Comment:
   Should we update the `loaded version...` log message below, as we did for
`loadWithoutCheckpointId`?
   
https://github.com/apache/spark/blob/acc80fd1b205792a43cc352a8e492e34bbe880da/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala#L604



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
     }
   }
 
-  private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit 
= {
+  private def commitBatch(
+      newBatchId: Long,
+      lastCommittedBatchId: Long,
+      opIdToStateStoreCkptInfo: Option[Map[Long, 
Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
     val latestCommit = 
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
-    // todo: For checkpoint v2, we need to update the stateUniqueIds based on 
the
-    //  newly created state commit. Will be done in subsequent PR.
-    if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+    val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+      // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> 
Array[partition -> info]]]
+      // we change it to Map[operatorId, Array[partitionId -> Array[store -> 
info]]]

Review Comment:
   Why not let the state rewriter do this conversion before returning the result? Otherwise,
every caller of the state rewriter would need to implement this conversion.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -95,12 +95,14 @@ class OfflineStateRepartitionRunner(
           transformFunc = Some(stateRepartitionFunc),
           writeCheckpointMetadata = Some(checkpointMetadata)
         )
-        rewriter.run()
+        val operatorToCkptIds = rewriter.run()
 
         updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = 
lastCommittedBatchId)
 
         // Commit the repartition batch
-        commitBatch(newBatchId, lastCommittedBatchId)
+        val enableCheckpointId = sparkSession.conf.get(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2

Review Comment:
   Use the `StatefulOperatorStateInfo.enableStateStoreCheckpointIds` util function
for this. We can also move this check into the `commitBatch` function.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -443,77 +443,82 @@ class RocksDB(
   private def loadWithCheckpointId(
       version: Long,
       stateStoreCkptId: Option[String],
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      loadEmpty: Boolean = false): RocksDB = {
     // An array contains lineage information from [snapShotVersion, version]
     // (inclusive in both ends)
     var currVersionLineage: Array[LineageItem] = 
lineageManager.getLineageForCurrVersion()
     try {
-      if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
-          stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+      if (loadEmpty || loadedVersion != version || 
loadedStateStoreCkptId.isEmpty ||
+        stateStoreCkptId.get != loadedStateStoreCkptId.get) {
         closeDB(ignoreException = false)
-
-        val (latestSnapshotVersion, latestSnapshotUniqueId) = {
-          // Special handling when version is 0.
-          // When loading the very first version (0), stateStoreCkptId does 
not need to be defined
-          // because there won't be 0.changelog / 0.zip file created in 
RocksDB under v2.
-          if (version == 0) {
-            assert(stateStoreCkptId.isEmpty,
-              "stateStoreCkptId should be empty when version is zero")
-            (0L, None)
-          // When there is a snapshot file, it is the ground truth, we can skip
-          // reconstructing the lineage from changelog file.
-          } else if (fileManager.existsSnapshotFile(version, 
stateStoreCkptId)) {
-            currVersionLineage = Array(LineageItem(version, 
stateStoreCkptId.get))
-            (version, stateStoreCkptId)
-          } else {
-            currVersionLineage = getLineageFromChangelogFile(version, 
stateStoreCkptId) :+
-              LineageItem(version, stateStoreCkptId.get)
-            currVersionLineage = currVersionLineage.sortBy(_.version)
-
-            val latestSnapshotVersionsAndUniqueId =
-              
fileManager.getLatestSnapshotVersionAndUniqueIdFromLineage(currVersionLineage)
-            latestSnapshotVersionsAndUniqueId match {
-              case Some(pair) => (pair._1, Option(pair._2))
-              case None if currVersionLineage.head.version == 1L =>
-                logDebug(log"Cannot find latest snapshot based on lineage but 
first version " +
-                  log"is 1, use 0 as default. Lineage: ${MDC(LogKeys.LINEAGE, 
lineageManager)}")
-                (0L, None)
-              case _ =>
-                throw QueryExecutionErrors.cannotFindBaseSnapshotCheckpoint(
-                  printLineageItems(currVersionLineage))
+        if (loadEmpty) {

Review Comment:
   Add a `require` that `stateStoreCkptId` is not present (i.e. is `None`) when `loadEmpty` is true.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
     )
   }
 
-  test("Repartition: success, failure, retry") {
-    withTempDir { dir =>
-      val originalPartitions = 3
-      val input = MemoryStream[Int]
-      val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
-      // Shouldn't be seen as a repartition batch
-      assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
-
-      // Trying to repartition to the same number should fail
-      val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+  Seq(1, 2).foreach { ckptVersion =>
+    def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): 
Unit = {

Review Comment:
   nit: you mean `testWithCheckpointId`?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
     )
   }
 
-  test("Repartition: success, failure, retry") {
-    withTempDir { dir =>
-      val originalPartitions = 3
-      val input = MemoryStream[Int]
-      val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
-      // Shouldn't be seen as a repartition batch
-      assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
-
-      // Trying to repartition to the same number should fail
-      val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+  Seq(1, 2).foreach { ckptVersion =>
+    def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): 
Unit = {
+      test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+        withSQLConf(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> 
ckptVersion.toString) {
+          testFun
+        }
       }
-      checkError(
-        ex,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "batchId" -> batchId.toString,
-          "numPartitions" -> originalPartitions.toString
-        )
-      )
-
-      // Trying to repartition to a different number should succeed
-      val newPartitions = originalPartitions + 1
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      val repartitionBatchId = batchId + 1
-      val hadoopConf = spark.sessionState.newHadoopConf()
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
-
-      // Now delete the repartition commit to simulate a failed repartition 
attempt.
-      // This will delete all the commits after the batchId.
-      checkpointMetadata.commitLog.purgeAfter(batchId)
+    }
 
-      // Try to repartition with a different numPartitions should fail,
-      // since it will see an uncommitted repartition batch with a different 
numPartitions.
-      val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
-      }
-      checkError(
-        ex2,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "lastBatchId" -> repartitionBatchId.toString,
-          "lastBatchShufflePartitions" -> newPartitions.toString,
-          "numPartitions" -> (newPartitions + 1).toString
+    testWithChangelogAndCheckpointId("Repartition: success, failure, retry") {
+      withTempDir { dir =>
+        val originalPartitions = 3
+        val input = MemoryStream[Int]
+        val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
+        val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
+        // Shouldn't be seen as a repartition batch
+        assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
+
+        // Trying to repartition to the same number should fail
+        val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] 
{
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+        }
+        checkError(
+          ex,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "batchId" -> batchId.toString,
+            "numPartitions" -> originalPartitions.toString
+          )
         )
-      )
 
-      // Retrying with the same numPartitions should work
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+        // Trying to repartition to a different number should succeed
+        val newPartitions = originalPartitions + 1
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        val repartitionBatchId = batchId + 1
+        val hadoopConf = spark.sessionState.newHadoopConf()
+        verifyRepartitionBatch(
+          repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+
+        // Now delete the repartition commit to simulate a failed repartition 
attempt.
+        // This will delete all the commits after the batchId.
+        checkpointMetadata.commitLog.purgeAfter(batchId)
+
+        // Try to repartition with a different numPartitions should fail,
+        // since it will see an uncommitted repartition batch with a different 
numPartitions.
+        val ex2 = 
intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
+        }
+        checkError(
+          ex2,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "lastBatchId" -> repartitionBatchId.toString,
+            "lastBatchShufflePartitions" -> newPartitions.toString,
+            "numPartitions" -> (newPartitions + 1).toString
+          )
+        )
 
-      // Repartition with way more partitions, to verify that empty partitions 
are properly created
-      val morePartitions = newPartitions * 3
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId + 1, checkpointMetadata, hadoopConf,
-        dir.getAbsolutePath, morePartitions)
+        // Retrying with the same numPartitions should work
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+
+        // Repartition with way more partitions, to verify that empty 
partitions are properly
+        // created
+        val morePartitions = newPartitions * 3
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+          dir.getAbsolutePath, morePartitions)
+        if (spark.conf.get(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2) {
+          verifyCheckpointIds(repartitionBatchId + 1, checkpointMetadata, 
morePartitions)
+        }
 
-      // Restart the query to make sure it can start after repartitioning
-      runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+        // Restart the query to make sure it can start after repartitioning
+        runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+      }
     }
   }

Review Comment:
   We also need to run the two other test cases below with checkpoint IDs both
enabled and disabled.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -450,8 +464,28 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       }
     }
 
-    testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
-      testRoundTripForAggrStateVersion(1)
+    // Run transformWithState tests with enable/disable checkpoint V2
+    Seq(1, 2).foreach { ckptVersion =>
+      def testWithChangelogAndCheckpointId(testName: String)(testFun: => 
Unit): Unit = {
+        test(s"$testName ($changelogCpTestSuffix, enableCkptId = ${ckptVersion 
>= 2})") {
+          withSQLConf(
+            
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+              changelogCheckpointingEnabled.toString,
+            SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> 
ckptVersion.toString) {
+            testFun
+          }
+        }
+      }
+
+      testWithChangelogAndCheckpointId("SPARK-54420: aggregation state ver 1") 
{
+        testRoundTripForAggrStateVersion(1)
+      }
+
+      Seq(1, 2).foreach { version =>
+        testWithChangelogAndCheckpointId(s"SPARK-54420: stream-stream join 
state ver $version") {
+          testStreamStreamJoinRoundTrip(version)
+        }
+      }

Review Comment:
   Shouldn't we be testing checkpoint v2 for the other tests too? At a minimum,
we should also test TWS (transformWithState).



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +96,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       transformFunc = None,
       writeCheckpointMetadata = Some(targetCheckpointMetadata)
     )
-    rewriter.run()
+    val checkpointInfos = rewriter.run()
 
-    // Commit to commitLog
+    // Commit to commitLog with checkpoint IDs
     val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
-    targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+    val operatorId = 0L
+    val storesSeq: Array[Array[StateStoreCheckpointInfo]] = 
checkpointInfos(operatorId)
+    val commitMetadata = if 
(StateStoreConf(conf).enableStateStoreCheckpointIds) {

Review Comment:
   We won't need this conversion here if the rewriter does the conversion
before returning the result. That is cleaner, since each caller of the rewriter
won't need to implement it.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -95,11 +96,35 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       transformFunc = None,
       writeCheckpointMetadata = Some(targetCheckpointMetadata)
     )
-    rewriter.run()
+    val checkpointInfos = rewriter.run()
 
-    // Commit to commitLog
+    // Commit to commitLog with checkpoint IDs
     val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
-    targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+    val operatorId = 0L
+    val storesSeq: Array[Array[StateStoreCheckpointInfo]] = 
checkpointInfos(operatorId)
+    val commitMetadata = if 
(StateStoreConf(conf).enableStateStoreCheckpointIds) {
+      // Build map: partition id -> array of checkpoint IDs (one per store, in 
order)
+      // checkpointInfos(operatorId) is Seq[Array[StateStoreCheckpointInfo]]
+      // where Seq is stores (in order: store0, store1, ...), Array is 
partitions
+      // For join operators: 4 stores (left-keyToNumValues, 
left-keyWithIndexToValue,
+      //                              right-keyToNumValues, 
right-keyWithIndexToValue)
+      // For regular operators: 1 store
+      val numPartitions = storesSeq.head.length
+      val ckptIds: Array[Array[String]] = (0 until numPartitions).map { 
partitionIdx =>
+          // For this partition, collect checkpoint IDs from all stores (in 
order)
+          storesSeq.flatMap { storePartitions =>
+            storePartitions(partitionIdx).stateStoreCkptId
+          }
+        }.toArray
+      // Include checkpoint IDs in commit metadata
+      CommitMetadata(

Review Comment:
   ditto, use copy() instead



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
         }
     }
   }
+
+  def verifyCheckpointIds(
+      repartitionBatchId: Long,
+      checkpointMetadata: StreamingQueryCheckpointMetadata,
+      expectedShufflePartitions: Int): Unit = {
+    // Verify commit log has the repartition batch with checkpoint IDs
+    val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+    assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should 
exist")
+
+    val commitMetadata = commitOpt.get
+
+    // Verify stateUniqueIds is present for checkpoint V2
+    assert(commitMetadata.stateUniqueIds.isDefined,
+      "stateUniqueIds should be present in commit metadata when checkpoint 
version >= 2")
+
+    val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+    assert(operatorIdToCkptInfos.nonEmpty,
+      "operatorIdToCkptInfos should not be empty")
+
+    // Verify structure for each operator
+    operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+      // Should have checkpoint IDs for all partitions
+      assert(partitionToCkptIds.length == expectedShufflePartitions,
+        s"Operator $operatorId: Expected $expectedShufflePartitions partition 
checkpoint IDs, " +
+          s"but found ${partitionToCkptIds.length}")
+      // Each partition should have checkpoint IDs (at least one per store)
+      partitionToCkptIds.foreach { ckptIds =>
+        assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint 
IDs")

Review Comment:
   nit: include the partition ID in the message as well, right?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
         }
     }
   }
+
+  def verifyCheckpointIds(
+      repartitionBatchId: Long,
+      checkpointMetadata: StreamingQueryCheckpointMetadata,
+      expectedShufflePartitions: Int): Unit = {
+    // Verify commit log has the repartition batch with checkpoint IDs
+    val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+    assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should 
exist")
+
+    val commitMetadata = commitOpt.get
+
+    // Verify stateUniqueIds is present for checkpoint V2
+    assert(commitMetadata.stateUniqueIds.isDefined,
+      "stateUniqueIds should be present in commit metadata when checkpoint 
version >= 2")
+
+    val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+    assert(operatorIdToCkptInfos.nonEmpty,
+      "operatorIdToCkptInfos should not be empty")
+
+    // Verify structure for each operator
+    operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+      // Should have checkpoint IDs for all partitions
+      assert(partitionToCkptIds.length == expectedShufflePartitions,
+        s"Operator $operatorId: Expected $expectedShufflePartitions partition 
checkpoint IDs, " +
+          s"but found ${partitionToCkptIds.length}")
+      // Each partition should have checkpoint IDs (at least one per store)
+      partitionToCkptIds.foreach { ckptIds =>
+        assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint 
IDs")
+        // Each checkpoint ID should be a non-empty string
+        ckptIds.foreach { ckptId =>

Review Comment:
   We also need to assert that the number of ckptIds for each partition matches 
the number of state stores the operator has. You can read the operator metadata 
to find out how many stores each operatorId has.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3943,7 +3943,7 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
-  test("SPARK-54420: load with createEmpty creates empty store") {
+  testWithStateStoreCheckpointIds("SPARK-54420: load with createEmpty creates 
empty store") { _ =>

Review Comment:
   This test case is not actually testing checkpoint IDs, since you are not 
passing the checkpointId you got from `db.commit` into `db.load`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
         }
     }
   }
+
+  def verifyCheckpointIds(
+      repartitionBatchId: Long,
+      checkpointMetadata: StreamingQueryCheckpointMetadata,
+      expectedShufflePartitions: Int): Unit = {
+    // Verify commit log has the repartition batch with checkpoint IDs
+    val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+    assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should 
exist")
+
+    val commitMetadata = commitOpt.get
+
+    // Verify stateUniqueIds is present for checkpoint V2
+    assert(commitMetadata.stateUniqueIds.isDefined,
+      "stateUniqueIds should be present in commit metadata when checkpoint 
version >= 2")
+
+    val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+    assert(operatorIdToCkptInfos.nonEmpty,
+      "operatorIdToCkptInfos should not be empty")
+
+    // Verify structure for each operator
+    operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+      // Should have checkpoint IDs for all partitions
+      assert(partitionToCkptIds.length == expectedShufflePartitions,
+        s"Operator $operatorId: Expected $expectedShufflePartitions partition 
checkpoint IDs, " +
+          s"but found ${partitionToCkptIds.length}")
+      // Each partition should have checkpoint IDs (at least one per store)
+      partitionToCkptIds.foreach { ckptIds =>
+        assert(ckptIds.nonEmpty, s"Operator $operatorId should have checkpoint 
IDs")
+        // Each checkpoint ID should be a non-empty string
+        ckptIds.foreach { ckptId =>
+          assert(ckptId.nonEmpty,
+            s"Operator $operatorId: Checkpoint ID should be non-empty")

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -454,4 +470,40 @@ object OfflineStateRepartitionTestUtils {
         }
     }
   }
+
+  def verifyCheckpointIds(

Review Comment:
   nit: make this helper `private`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -450,8 +464,28 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       }
     }
 
-    testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
-      testRoundTripForAggrStateVersion(1)
+    // Run transformWithState tests with enable/disable checkpoint V2
+    Seq(1, 2).foreach { ckptVersion =>
+      def testWithChangelogAndCheckpointId(testName: String)(testFun: => 
Unit): Unit = {
+        test(s"$testName ($changelogCpTestSuffix, enableCkptId = ${ckptVersion 
>= 2})") {
+          withSQLConf(
+            
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->

Review Comment:
   Why not set only the format version here and then call 
`testWithChangelogConfig`, instead of duplicating the changelog config here?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala:
##########
@@ -106,71 +106,87 @@ class OfflineStateRepartitionSuite extends StreamTest
     )
   }
 
-  test("Repartition: success, failure, retry") {
-    withTempDir { dir =>
-      val originalPartitions = 3
-      val input = MemoryStream[Int]
-      val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
-      // Shouldn't be seen as a repartition batch
-      assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
-
-      // Trying to repartition to the same number should fail
-      val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+  Seq(1, 2).foreach { ckptVersion =>
+    def testWithChangelogAndCheckpointId(testName: String)(testFun: => Unit): 
Unit = {
+      test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+        withSQLConf(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> 
ckptVersion.toString) {
+          testFun
+        }
       }
-      checkError(
-        ex,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "batchId" -> batchId.toString,
-          "numPartitions" -> originalPartitions.toString
-        )
-      )
-
-      // Trying to repartition to a different number should succeed
-      val newPartitions = originalPartitions + 1
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      val repartitionBatchId = batchId + 1
-      val hadoopConf = spark.sessionState.newHadoopConf()
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
-
-      // Now delete the repartition commit to simulate a failed repartition 
attempt.
-      // This will delete all the commits after the batchId.
-      checkpointMetadata.commitLog.purgeAfter(batchId)
+    }
 
-      // Try to repartition with a different numPartitions should fail,
-      // since it will see an uncommitted repartition batch with a different 
numPartitions.
-      val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
-      }
-      checkError(
-        ex2,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "lastBatchId" -> repartitionBatchId.toString,
-          "lastBatchShufflePartitions" -> newPartitions.toString,
-          "numPartitions" -> (newPartitions + 1).toString
+    testWithChangelogAndCheckpointId("Repartition: success, failure, retry") {
+      withTempDir { dir =>
+        val originalPartitions = 3
+        val input = MemoryStream[Int]
+        val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
+        val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
+        // Shouldn't be seen as a repartition batch
+        assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
+
+        // Trying to repartition to the same number should fail
+        val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] 
{
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+        }
+        checkError(
+          ex,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "batchId" -> batchId.toString,
+            "numPartitions" -> originalPartitions.toString
+          )
         )
-      )
 
-      // Retrying with the same numPartitions should work
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+        // Trying to repartition to a different number should succeed
+        val newPartitions = originalPartitions + 1
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        val repartitionBatchId = batchId + 1
+        val hadoopConf = spark.sessionState.newHadoopConf()
+        verifyRepartitionBatch(
+          repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+
+        // Now delete the repartition commit to simulate a failed repartition 
attempt.
+        // This will delete all the commits after the batchId.
+        checkpointMetadata.commitLog.purgeAfter(batchId)
+
+        // Try to repartition with a different numPartitions should fail,
+        // since it will see an uncommitted repartition batch with a different 
numPartitions.
+        val ex2 = 
intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
+        }
+        checkError(
+          ex2,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "lastBatchId" -> repartitionBatchId.toString,
+            "lastBatchShufflePartitions" -> newPartitions.toString,
+            "numPartitions" -> (newPartitions + 1).toString
+          )
+        )
 
-      // Repartition with way more partitions, to verify that empty partitions 
are properly created
-      val morePartitions = newPartitions * 3
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId + 1, checkpointMetadata, hadoopConf,
-        dir.getAbsolutePath, morePartitions)
+        // Retrying with the same numPartitions should work
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
+
+        // Repartition with way more partitions, to verify that empty 
partitions are properly
+        // created
+        val morePartitions = newPartitions * 3
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+          dir.getAbsolutePath, morePartitions)
+        if (spark.conf.get(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key).toInt >= 2) {
+          verifyCheckpointIds(repartitionBatchId + 1, checkpointMetadata, 
morePartitions)

Review Comment:
   Let's call `verifyCheckpointIds` from within `verifyRepartitionBatch`, just 
like we do for the other verifications, so that calling 
`verifyRepartitionBatch` performs all the necessary verifications.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala:
##########
@@ -289,12 +292,32 @@ class OfflineStateRepartitionRunner(
     }
   }
 
-  private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit 
= {
+  private def commitBatch(
+      newBatchId: Long,
+      lastCommittedBatchId: Long,
+      opIdToStateStoreCkptInfo: Option[Map[Long, 
Array[Array[StateStoreCheckpointInfo]]]]): Unit = {
     val latestCommit = 
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
-
-    // todo: For checkpoint v2, we need to update the stateUniqueIds based on 
the
-    //  newly created state commit. Will be done in subsequent PR.
-    if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+    val commitMetadata = opIdToStateStoreCkptInfo.map {originalInfoMap =>
+      // opIdToStateStoreCkptInfo is Map[operatorId, Array[stateStore -> 
Array[partition -> info]]]
+      // we change it to Map[operatorId, Array[partitionId -> Array[store -> 
info]]]
+      val opIdToPartitionCkptInfo: Map[Long, Array[Array[String]]] =
+        originalInfoMap.map {
+          case(operator, storesSeq) =>
+            val numPartitions = storesSeq.head.length
+            operator -> (0 until numPartitions).map { partitionIdx =>
+              storesSeq.flatMap { storePartitions =>
+                storePartitions(partitionIdx).stateStoreCkptId

Review Comment:
   When we move this conversion into the rewriter, we should add some assertions, 
i.e.: 
   1. that `storePartitions(partitionIdx).partitionId == partitionIdx`
   2. that `storePartitions(partitionIdx).batchVersion` is the version we expect
   3. that `storePartitions(partitionIdx).baseStateStoreCkptId == None`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to