This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 813f7576aef2 [SPARK-54590][SS] Support Checkpoint V2 for State Rewriter and Repartitioning
813f7576aef2 is described below

commit 813f7576aef225095393bce86df5379409bcd40b
Author: zifeif2 <[email protected]>
AuthorDate: Thu Jan 15 17:38:16 2026 -0800

    [SPARK-54590][SS] Support Checkpoint V2 for State Rewriter and Repartitioning
    
    ### What changes were proposed in this pull request?
    
    Support checkpointV2 for the repartition writer and StateRewriter by returning the checkpoint ID to the caller after the write is done.
    Changes include:
    - RocksDB loadWithCheckpointId supports loadEmpty
    - StatePartitionAllColumnFamiliesWriter returns StateStoreCheckpointInfo
    - StateRewriter also propagates StateStoreCheckpointInfo back to the RepartitionRunner
    - RepartitionRunner stores the checkpointIds in the commitLog (see the sketch below)
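    
    For illustration only (not part of the patch), a minimal sketch of how the returned checkpoint IDs are shaped and then propagated into the commit log; rewriter, latestCommit, checkpointMetadata, and newBatchId are assumed to be in scope as in OfflineStateRepartitionRunner in the diff below:
    
    ```scala
    // When checkpointV2 is enabled, StateRewriter.run() returns
    //   Map[operatorId, Array[partition -> Array[stateStore -> stateStoreCkptId]]];
    // otherwise it returns None.
    val operatorToCkptIds: Option[Map[Long, Array[Array[String]]]] = rewriter.run()
    
    // The runner then copies these IDs into the commit metadata for the new batch,
    // mirroring commitBatch() in the diff below.
    val commitMetadata = latestCommit.copy(stateUniqueIds = operatorToCkptIds)
    if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) {
      throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId)
    }
    ```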
    
    ### Why are the changes needed?
    
    This is required in PrPr for the repartition project.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No

    ### How was this patch tested?
    
    See the added unit tests, covering operators with a single state store and with multiple state stores.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes. Sonnet 4.5
    
    Closes #53720 from zifeif2/repartition-cp-v2.
    
    Lead-authored-by: zifeif2 <[email protected]>
    Co-authored-by: Zifei Feng <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |   7 +
 .../state/OfflineStateRepartitionRunner.scala      |  14 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  77 +-
 .../streaming/state/RocksDBFileManager.scala       |   2 +
 .../streaming/state/StatePartitionWriter.scala     |   5 +-
 .../execution/streaming/state/StateRewriter.scala  | 114 ++-
 .../state/OfflineStateRepartitionSuite.scala       | 361 ++++++----
 .../execution/streaming/state/RocksDBSuite.scala   | 106 +--
 ...tatePartitionAllColumnFamiliesWriterSuite.scala | 791 +++++++++++----------
 9 files changed, 861 insertions(+), 616 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json 
b/common/utils/src/main/resources/error/error-conditions.json
index 2b205908333c..519a3cafd3be 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5603,6 +5603,13 @@
           "Ensure that the checkpoint is for a stateful streaming query and 
the query ran on a Spark version that supports operator metadata (Spark 4.0+)."
         ]
       },
+      "STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH" : {
+        "message" : [
+          "The checkpoint format version in SQLConf does not match the 
checkpoint version in the commit log.",
+          "Expected version <expectedVersion>, but found <actualVersion>.",
+          "Please set '<sqlConfKey>' to <expectedVersion> in your SQLConf 
before retrying."
+        ]
+      },
       "UNSUPPORTED_STATE_STORE_METADATA_VERSION" : {
         "message" : [
           "Unsupported state store metadata version encountered.",
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
index 19f75eb385f5..f53efe1cc0ac 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
@@ -95,12 +95,12 @@ class OfflineStateRepartitionRunner(
           transformFunc = Some(stateRepartitionFunc),
           writeCheckpointMetadata = Some(checkpointMetadata)
         )
-        rewriter.run()
+        val operatorToCkptIds = rewriter.run()
 
         updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = 
lastCommittedBatchId)
 
         // Commit the repartition batch
-        commitBatch(newBatchId, lastCommittedBatchId)
+        commitBatch(newBatchId, lastCommittedBatchId, operatorToCkptIds)
         newBatchId
       }
 
@@ -289,12 +289,14 @@ class OfflineStateRepartitionRunner(
     }
   }
 
-  private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit 
= {
+  private def commitBatch(
+      newBatchId: Long,
+      lastCommittedBatchId: Long,
+      opIdToStateStoreCkptInfo: Option[Map[Long, Array[Array[String]]]]): Unit 
= {
     val latestCommit = 
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
+    val commitMetadata = latestCommit.copy(stateUniqueIds = 
opIdToStateStoreCkptInfo)
 
-    // todo: For checkpoint v2, we need to update the stateUniqueIds based on 
the
-    //  newly created state commit. Will be done in subsequent PR.
-    if (!checkpointMetadata.commitLog.add(newBatchId, latestCommit)) {
+    if (!checkpointMetadata.commitLog.add(newBatchId, commitMetadata)) {
       throw QueryExecutionErrors.concurrentStreamLogUpdate(newBatchId)
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index f9b94160f9eb..4c9b2282ba27 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -443,13 +443,26 @@ class RocksDB(
   private def loadWithCheckpointId(
       version: Long,
       stateStoreCkptId: Option[String],
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      loadEmpty: Boolean = false): RocksDB = {
     // An array contains lineage information from [snapShotVersion, version]
     // (inclusive in both ends)
     var currVersionLineage: Array[LineageItem] = 
lineageManager.getLineageForCurrVersion()
     try {
-      if (loadedVersion != version || (loadedStateStoreCkptId.isEmpty ||
-          stateStoreCkptId.get != loadedStateStoreCkptId.get)) {
+      if (loadEmpty) {
+        // Handle empty store loading separately for clarity
+        require(stateStoreCkptId.isEmpty,
+          "stateStoreCkptId should be empty when loadEmpty is true")
+        closeDB(ignoreException = false)
+        loadEmptyStore(version)
+        lineageManager.clear()
+        // After loading empty store, set the key counts and metrics
+        numKeysOnLoadedVersion = numKeysOnWritingVersion
+        numInternalKeysOnLoadedVersion = numInternalKeysOnWritingVersion
+        fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
+      } else if (loadedVersion != version || loadedStateStoreCkptId.isEmpty ||
+        stateStoreCkptId.get != loadedStateStoreCkptId.get) {
+        // Handle normal checkpoint loading
         closeDB(ignoreException = false)
 
         val (latestSnapshotVersion, latestSnapshotUniqueId) = {
@@ -508,9 +521,9 @@ class RocksDB(
 
         if (loadedVersion != version) {
           val versionsAndUniqueIds = currVersionLineage.collect {
-              case i if i.version > loadedVersion && i.version <= version =>
-                (i.version, Option(i.checkpointUniqueId))
-            }
+            case i if i.version > loadedVersion && i.version <= version =>
+              (i.version, Option(i.checkpointUniqueId))
+          }
           replayChangelog(versionsAndUniqueIds)
           loadedVersion = version
           lineageManager.resetLineage(currVersionLineage)
@@ -530,9 +543,13 @@ class RocksDB(
       if (conf.resetStatsOnLoad) {
         nativeStats.reset
       }
-
-      logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)} " +
-        log"with uniqueId ${MDC(LogKeys.UUID, stateStoreCkptId)}")
+      if (loadEmpty) {
+        logInfo(log"Loaded empty store at version ${MDC(LogKeys.VERSION_NUM, 
version)} " +
+          log"with empty uniqueId")
+      } else {
+        logInfo(log"Loaded ${MDC(LogKeys.VERSION_NUM, version)} " +
+          log"with uniqueId ${MDC(LogKeys.UUID, stateStoreCkptId)}")
+      }
     } catch {
       case t: Throwable =>
         loadedVersion = -1  // invalidate loaded data
@@ -543,25 +560,33 @@ class RocksDB(
         lineageManager.clear()
         throw t
     }
-    if (enableChangelogCheckpointing && !readOnly) {
+    // Checking conf.enableChangelogCheckpointing instead of 
enableChangelogCheckpointing.
+    // enableChangelogCheckpointing is set to false when loadEmpty is true, 
but we still want
+    // to abort previous used changelogWriter if there is any
+    if (conf.enableChangelogCheckpointing && !readOnly) {
       // Make sure we don't leak resource.
       changelogWriter.foreach(_.abort())
-      // Initialize the changelog writer with lineage info
-      // The lineage stored in changelog files should normally start with
-      // the version of a snapshot, except for the first few versions.
-      // Because they are solely loaded from changelog file.
-      // (e.g. with default minDeltasForSnapshot, there is only 
1_uuid1.changelog, no 1_uuid1.zip)
-      // It should end with exactly one version before the change log's 
version.
-      changelogWriter = Some(fileManager.getChangeLogWriter(
-        version + 1,
-        useColumnFamilies,
-        sessionStateStoreCkptId,
-        Some(currVersionLineage)))
+      if (loadEmpty) {
+        // We don't want to write changelog file when loadEmpty is true
+        changelogWriter = None
+      } else {
+        // Initialize the changelog writer with lineage info
+        // The lineage stored in changelog files should normally start with
+        // the version of a snapshot, except for the first few versions.
+        // Because they are solely loaded from changelog file.
+        // (e.g. with default minDeltasForSnapshot, there is only 
1_uuid1.changelog, no 1_uuid1.zip)
+        // It should end with exactly one version before the change log's 
version.
+        changelogWriter = Some(fileManager.getChangeLogWriter(
+          version + 1,
+          useColumnFamilies,
+          sessionStateStoreCkptId,
+          Some(currVersionLineage)))
+      }
     }
     this
   }
 
-  private def loadEmptyStoreWithoutCheckpointId(version: Long): Unit = {
+  private def loadEmptyStore(version: Long): Unit = {
     // Use version 0 logic to create empty directory with no SST files
     val metadata = fileManager.loadCheckpointFromDfs(0, workingDir, 
rocksDBFileMapping, None)
     loadedVersion = version
@@ -580,7 +605,7 @@ class RocksDB(
         closeDB(ignoreException = false)
 
         if (loadEmpty) {
-          loadEmptyStoreWithoutCheckpointId(version)
+          loadEmptyStore(version)
         } else {
           // load the latest snapshot
           loadSnapshotWithoutCheckpointId(version)
@@ -752,9 +777,9 @@ class RocksDB(
     // If loadEmpty is true, we will not generate a changelog but only a 
snapshot file to prevent
     // mistakenly applying new changelog to older state version
     enableChangelogCheckpointing = if (loadEmpty) false else 
conf.enableChangelogCheckpointing
-    if (stateStoreCkptId.isDefined || enableStateStoreCheckpointIds && version 
== 0) {
-      assert(!loadEmpty, "loadEmpty not supported for checkpointV2 yet")
-      loadWithCheckpointId(version, stateStoreCkptId, readOnly)
+    if (stateStoreCkptId.isDefined ||
+      enableStateStoreCheckpointIds && (version == 0 || loadEmpty)) {
+      loadWithCheckpointId(version, stateStoreCkptId, readOnly, loadEmpty)
     } else {
       loadWithoutCheckpointId(version, readOnly, loadEmpty)
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 75e459b26511..7135421f4866 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -375,6 +375,8 @@ class RocksDBFileManager(
       createDfsRootDirIfNotExist()
       // Since we cleared the local dir, we should also clear the local file 
mapping
       rocksDBFileMapping.clear()
+      // Set empty metrics since we're not loading any files from DFS
+      loadCheckpointMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS
       RocksDBCheckpointMetadata(Seq.empty, 0)
     } else {
       // Delete all non-immutable files in local dir, and unzip new ones from 
DFS commit file
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
index 3df97d3adc0e..64c0fb6d6e76 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala
@@ -145,10 +145,13 @@ class StatePartitionAllColumnFamiliesWriter(
   // - key_bytes, BinaryType
   // - value_bytes, BinaryType
   // - column_family_name, StringType
-  def write(rows: Iterator[InternalRow]): Unit = {
+  // Returns StateStoreCheckpointInfo containing the checkpoint ID after 
commit if
+  // enabled checkpointV2
+  def write(rows: Iterator[InternalRow]): StateStoreCheckpointInfo = {
     try {
       rows.foreach(row => writeRow(row))
       stateStore.commit()
+      stateStore.getStateStoreCheckpointInfo()
     } finally {
       if (!stateStore.hasCommitted) {
         stateStore.abort()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
index da28a3c907f7..42561ec5ca7b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateRewriter.scala
@@ -22,16 +22,17 @@ import java.util.UUID
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkIllegalStateException, TaskContext}
+import org.apache.spark.{SparkIllegalStateException, SparkThrowable, 
TaskContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.checkpointing.OffsetSeqMetadata
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorsUtils
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo,
 StatefulOperatorsUtils}
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType,
 TransformWithStateOperatorProperties, TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, 
StreamingQueryCheckpointMetadata}
 import 
org.apache.spark.sql.execution.streaming.state.{StatePartitionAllColumnFamiliesWriter,
 StateSchemaCompatibilityChecker}
@@ -80,8 +81,19 @@ class StateRewriter(
     readResolvedCheckpointLocation.getOrElse(resolvedCheckpointLocation)
   private val stateRootLocation = new Path(
     resolvedCheckpointLocation, 
StreamingCheckpointConstants.DIR_NAME_STATE).toString
+  private lazy val writeCheckpoint = writeCheckpointMetadata.getOrElse(
+    new StreamingQueryCheckpointMetadata(sparkSession, 
resolvedCheckpointLocation))
+  private lazy val readCheckpoint = if 
(readResolvedCheckpointLocation.isDefined) {
+    new StreamingQueryCheckpointMetadata(sparkSession, 
readResolvedCheckpointLocation.get)
+  } else {
+    // Same checkpoint for read & write
+    writeCheckpoint
+  }
 
-  def run(): Unit = {
+  // If checkpoint id is enabled, return
+  // Map[operatorId, Array[partition -> Array[stateStore -> 
StateStoreCheckpointId]]].
+  // Otherwise, return None
+  def run(): Option[Map[Long, Array[Array[String]]]] = {
     logInfo(log"Starting state rewrite for " +
       log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, 
resolvedCheckpointLocation)}, " +
       log"readCheckpointLocation=" +
@@ -89,16 +101,45 @@ class StateRewriter(
       log"readBatchId=${MDC(BATCH_ID, readBatchId)}, " +
       log"writeBatchId=${MDC(BATCH_ID, writeBatchId)}")
 
-    val (_, timeTakenMs) = Utils.timeTakenMs {
+    val (checkpointIds, timeTakenMs) = Utils.timeTakenMs {
       runInternal()
     }
 
     logInfo(log"State rewrite completed in ${MDC(DURATION, timeTakenMs)} ms 
for " +
       log"checkpointLocation=${MDC(CHECKPOINT_LOCATION, 
resolvedCheckpointLocation)}")
+    checkpointIds
   }
 
-  private def runInternal(): Unit = {
+  private def extractCheckpointIdsPerPartition(
+      checkpointInfos: Map[Long, Array[Array[StateStoreCheckpointInfo]]]
+    ): Option[Map[Long, Array[Array[String]]]] = {
+    val enableCheckpointId = StatefulOperatorStateInfo.
+      enableStateStoreCheckpointIds(sparkSession.sessionState.conf)
+    if (!enableCheckpointId) {
+      return None
+    }
+    // convert Map[operatorId, Array[stateStore -> Array[partition -> 
StateStoreCheckpointInfo]]]
+    // to Map[operatorId, Array[partition -> Array[stateStore -> 
StateStoreCheckpointId]]].
+    Option(checkpointInfos.map {
+      case(operator, storesSeq) =>
+        val numPartitions = storesSeq.head.length
+        operator -> (0 until numPartitions).map { partitionIdx =>
+          storesSeq.map { storePartitions =>
+            val checkpointInfoPerPartition = storePartitions(partitionIdx)
+            assert(checkpointInfoPerPartition.partitionId == partitionIdx)
+            assert(checkpointInfoPerPartition.batchVersion == writeBatchId + 1)
+            // expect baseStateStoreCkptId empty because we didn't load
+            // any previous stores when doing the rewrite
+            assert(checkpointInfoPerPartition.baseStateStoreCkptId.isEmpty)
+            checkpointInfoPerPartition.stateStoreCkptId.get
+          }
+        }.toArray
+    })
+  }
+
+  private def runInternal(): Option[Map[Long, Array[Array[String]]]] = {
     try {
+      verifyCheckpointFormatVersion()
       val stateMetadataReader = new StateMetadataPartitionReader(
         resolvedCheckpointLocation,
         new SerializableConfiguration(hadoopConf),
@@ -124,7 +165,7 @@ class StateRewriter(
 
       // Do rewrite for each operator
       // We can potentially parallelize this, but for now, do sequentially
-      allOperatorsMetadata.foreach { opMetadata =>
+      val checkpointInfos = allOperatorsMetadata.map { opMetadata =>
         val stateStoresMetadata = opMetadata.stateStoresMetadata
         assert(!stateStoresMetadata.isEmpty,
           s"Operator ${opMetadata.operatorInfo.operatorName} has no state 
stores")
@@ -133,7 +174,7 @@ class StateRewriter(
         val stateVarsIfTws = getStateVariablesIfTWS(opMetadata)
 
         // Rewrite each state store of the operator
-        stateStoresMetadata.foreach { stateStoreMetadata =>
+        val checkpointInfo = stateStoresMetadata.map { stateStoreMetadata =>
           rewriteStore(
             opMetadata,
             stateStoreMetadata,
@@ -143,8 +184,10 @@ class StateRewriter(
             stateVarsIfTws,
             sqlConfEntries
           )
-        }
-      }
+        }.toArray
+        opMetadata.operatorInfo.operatorId -> checkpointInfo
+      }.toMap
+      extractCheckpointIdsPerPartition(checkpointInfos)
     } catch {
       case e: Throwable =>
         logError(log"State rewrite failed for " +
@@ -163,7 +206,7 @@ class StateRewriter(
       storeSchemaFiles: List[Path],
       stateVarsIfTws: Map[String, TransformWithStateVariableInfo],
       sqlConfEntries: Map[String, String]
-  ): Unit = {
+  ): Array[StateStoreCheckpointInfo] = {
     // Read state
     val stateDf = sparkSession.read
       .format("statestore")
@@ -206,7 +249,7 @@ class StateRewriter(
     // to avoid serializing the entire Rewriter object per partition.
     val targetCheckpointLocation = resolvedCheckpointLocation
     val currentBatchId = writeBatchId
-    updatedStateDf.queryExecution.toRdd.foreachPartition { partitionIter =>
+    updatedStateDf.queryExecution.toRdd.mapPartitions { partitionIter: 
Iterator[InternalRow] =>
       // Recreate SQLConf on executor from serialized entries
       val executorSqlConf = new SQLConf()
       sqlConfEntries.foreach { case (k, v) => executorSqlConf.setConfString(k, 
v) }
@@ -224,9 +267,8 @@ class StateRewriter(
         schemaProvider,
         executorSqlConf
       )
-
-      partitionWriter.write(partitionIter)
-    }
+      Iterator(partitionWriter.write(partitionIter))
+    }.collect()
   }
 
   /** Create the store and sql confs from the conf written in the offset log */
@@ -342,6 +384,30 @@ class StateRewriter(
       None
     }
   }
+
+  private def verifyCheckpointFormatVersion(): Unit = {
+    // Verify checkpoint version in sqlConf based on commitLog for 
readCheckpoint
+    // in case user forgot to set STATE_STORE_CHECKPOINT_FORMAT_VERSION.
+    // Using read batch commit since the latest commit could be a skipped 
batch.
+    // If SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION is wrong, 
readCheckpoint.commitLog
+    // will throw an exception, and we will propagate this exception upstream.
+    // This prevents the StateRewriter from failing to write the correct state 
files
+    try {
+      readCheckpoint.commitLog.get(readBatchId)
+    } catch {
+        case e: IllegalStateException if e.getCause != null &&
+            e.getCause.isInstanceOf[SparkThrowable] =>
+          val sparkThrowable = e.getCause.asInstanceOf[SparkThrowable]
+          if (sparkThrowable.getCondition == 
"INVALID_LOG_VERSION.EXACT_MATCH_VERSION") {
+            val params = sparkThrowable.getMessageParameters
+            val expectedVersion = params.get("version")
+            val actualVersion = params.get("matchVersion")
+            throw 
StateRewriterErrors.stateCheckpointFormatVersionMismatchError(
+              checkpointLocationForRead, expectedVersion, actualVersion)
+          }
+          throw e
+      }
+  }
 }
 
 /**
@@ -364,6 +430,14 @@ private[state] object StateRewriterErrors {
       checkpointLocation: String): StateRewriterInvalidCheckpointError = {
     new StateRewriterUnsupportedStoreMetadataVersionError(checkpointLocation)
   }
+
+  def stateCheckpointFormatVersionMismatchError(
+      checkpointLocation: String,
+      expectedVersion: String,
+      actualVersion: String): StateRewriterInvalidCheckpointError = {
+    new StateRewriterStateCheckpointFormatVersionMismatchError(
+      checkpointLocation, expectedVersion, actualVersion)
+  }
 }
 
 /**
@@ -402,3 +476,15 @@ private[state] class 
StateRewriterUnsupportedStoreMetadataVersionError(
     checkpointLocation,
     subClass = "UNSUPPORTED_STATE_STORE_METADATA_VERSION",
     messageParameters = Map.empty)
+
+private[state] class StateRewriterStateCheckpointFormatVersionMismatchError(
+    checkpointLocation: String,
+    expectedVersion: String,
+    actualVersion: String)
+  extends StateRewriterInvalidCheckpointError(
+    checkpointLocation,
+    subClass = "STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH",
+    messageParameters = Map(
+      "expectedVersion" -> expectedVersion,
+      "actualVersion" -> actualVersion,
+      "sqlConfKey" -> SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
index 1ab581d79437..b4bd50dfb42f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
@@ -21,8 +21,10 @@ import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
 
+import org.apache.spark.sql.SparkSession
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, 
CommitMetadata}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.StatefulOperatorStateInfo
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
@@ -106,147 +108,171 @@ class OfflineStateRepartitionSuite extends StreamTest
     )
   }
 
-  test("Repartition: success, failure, retry") {
-    withTempDir { dir =>
-      val originalPartitions = 3
-      val input = MemoryStream[Int]
-      val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
-      // Shouldn't be seen as a repartition batch
-      assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
-
-      // Trying to repartition to the same number should fail
-      val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+  Seq(1, 2).foreach { ckptVersion =>
+    def testWithCheckpointId(testName: String)(testFun: => Unit): Unit = {
+      test(s"$testName (enableCkptId = ${ckptVersion >= 2})") {
+        withSQLConf(
+          SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> 
ckptVersion.toString) {
+          testFun
+        }
       }
-      checkError(
-        ex,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "batchId" -> batchId.toString,
-          "numPartitions" -> originalPartitions.toString
-        )
-      )
+    }
 
-      // Trying to repartition to a different number should succeed
-      val newPartitions = originalPartitions + 1
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      val repartitionBatchId = batchId + 1
-      val hadoopConf = spark.sessionState.newHadoopConf()
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
-
-      // Now delete the repartition commit to simulate a failed repartition 
attempt.
-      // This will delete all the commits after the batchId.
-      checkpointMetadata.commitLog.purgeAfter(batchId)
-
-      // Try to repartition with a different numPartitions should fail,
-      // since it will see an uncommitted repartition batch with a different 
numPartitions.
-      val ex2 = intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
-      }
-      checkError(
-        ex2,
-        condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "lastBatchId" -> repartitionBatchId.toString,
-          "lastBatchShufflePartitions" -> newPartitions.toString,
-          "numPartitions" -> (newPartitions + 1).toString
+    testWithCheckpointId("Repartition: success, failure, retry") {
+      withTempDir { dir =>
+        val originalPartitions = 3
+        val input = MemoryStream[Int]
+        val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
+        val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
+        // Shouldn't be seen as a repartition batch
+        assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
dir.getAbsolutePath))
+
+        // Trying to repartition to the same number should fail
+        val ex = intercept[StateRepartitionShufflePartitionsAlreadyMatchError] 
{
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions)
+        }
+        checkError(
+          ex,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.SHUFFLE_PARTITIONS_ALREADY_MATCH",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "batchId" -> batchId.toString,
+            "numPartitions" -> originalPartitions.toString
+          )
         )
-      )
-
-      // Retrying with the same numPartitions should work
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, hadoopConf, 
dir.getAbsolutePath, newPartitions)
 
-      // Repartition with way more partitions, to verify that empty partitions 
are properly created
-      val morePartitions = newPartitions * 3
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
-      verifyRepartitionBatch(
-        repartitionBatchId + 1, checkpointMetadata, hadoopConf,
-        dir.getAbsolutePath, morePartitions)
+        // Trying to repartition to a different number should succeed
+        val newPartitions = originalPartitions + 1
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        val repartitionBatchId = batchId + 1
+        val hadoopConf = spark.sessionState.newHadoopConf()
+        verifyRepartitionBatch(
+          repartitionBatchId,
+          checkpointMetadata,
+          hadoopConf,
+          dir.getAbsolutePath,
+          newPartitions,
+          spark)
+
+        // Now delete the repartition commit to simulate a failed repartition 
attempt.
+        // This will delete all the commits after the batchId.
+        checkpointMetadata.commitLog.purgeAfter(batchId)
+
+        // Try to repartition with a different numPartitions should fail,
+        // since it will see an uncommitted repartition batch with a different 
numPartitions.
+        val ex2 = 
intercept[StateRepartitionLastBatchAbandonedRepartitionError] {
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions + 1)
+        }
+        checkError(
+          ex2,
+          condition = 
"STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_ABANDONED_REPARTITION",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "lastBatchId" -> repartitionBatchId.toString,
+            "lastBatchShufflePartitions" -> newPartitions.toString,
+            "numPartitions" -> (newPartitions + 1).toString
+          )
+        )
 
-      // Restart the query to make sure it can start after repartitioning
-      runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+        // Retrying with the same numPartitions should work
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
newPartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId,
+          checkpointMetadata,
+          hadoopConf,
+          dir.getAbsolutePath,
+          newPartitions,
+          spark)
+
+        // Repartition with way more partitions, to verify that empty 
partitions are properly
+        // created
+        val morePartitions = newPartitions * 3
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
morePartitions)
+        verifyRepartitionBatch(
+          repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+          dir.getAbsolutePath, morePartitions, spark)
+        // Restart the query to make sure it can start after repartitioning
+        runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
+      }
     }
-  }
 
-  test("Query last batch failed before repartitioning") {
-    withTempDir { dir =>
-      val originalPartitions = 3
-      val input = MemoryStream[Int]
-      // Run 3 batches
-      val firstBatchId = 0
-      val lastBatchId = firstBatchId + 2
-      (firstBatchId to lastBatchId).foreach { _ =>
-        runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
-      }
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
+    testWithCheckpointId("Query last batch failed before repartitioning") {
+      withTempDir { dir =>
+        val originalPartitions = 3
+        val input = MemoryStream[Int]
+        // Run 3 batches
+        val firstBatchId = 0
+        val lastBatchId = firstBatchId + 2
+        (firstBatchId to lastBatchId).foreach { _ =>
+          runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
+        }
+        val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
 
-      // Lets keep only the first commit to simulate multiple failed batches
-      checkpointMetadata.commitLog.purgeAfter(firstBatchId)
+        // Lets keep only the first commit to simulate multiple failed batches
+        checkpointMetadata.commitLog.purgeAfter(firstBatchId)
 
-      // Now repartitioning should fail
-      val ex = intercept[StateRepartitionLastBatchFailedError] {
-        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions + 1)
-      }
-      checkError(
-        ex,
-        condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED",
-        parameters = Map(
-          "checkpointLocation" -> dir.getAbsolutePath,
-          "lastBatchId" -> lastBatchId.toString
+        // Now repartitioning should fail
+        val ex = intercept[StateRepartitionLastBatchFailedError] {
+          spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions + 1)
+        }
+        checkError(
+          ex,
+          condition = "STATE_REPARTITION_INVALID_CHECKPOINT.LAST_BATCH_FAILED",
+          parameters = Map(
+            "checkpointLocation" -> dir.getAbsolutePath,
+            "lastBatchId" -> lastBatchId.toString
+          )
         )
-      )
 
-      // Setting enforceExactlyOnceSink to false should allow repartitioning
-      spark.streamingCheckpointManager.repartition(
-        dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink = 
false)
-      verifyRepartitionBatch(
-        lastBatchId + 1,
-        checkpointMetadata,
-        spark.sessionState.newHadoopConf(),
-        dir.getAbsolutePath,
-        originalPartitions + 1,
-        // Repartition should be based on the first batch, since we skipped 
the others
-        baseBatchId = Some(firstBatchId))
+        // Setting enforceExactlyOnceSink to false should allow repartitioning
+        spark.streamingCheckpointManager.repartition(
+          dir.getAbsolutePath, originalPartitions + 1, enforceExactlyOnceSink 
= false)
+        verifyRepartitionBatch(
+          lastBatchId + 1,
+          checkpointMetadata,
+          spark.sessionState.newHadoopConf(),
+          dir.getAbsolutePath,
+          originalPartitions + 1,
+          spark,
+          // Repartition should be based on the first batch, since we skipped 
the others
+          baseBatchId = Some(firstBatchId))
+      }
     }
-  }
 
-  test("Consecutive repartition") {
-    withTempDir { dir =>
-      val originalPartitions = 5
-      val input = MemoryStream[Int]
-      val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
-
-      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
-      val hadoopConf = spark.sessionState.newHadoopConf()
-
-      // decrease
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions - 3)
-      verifyRepartitionBatch(
-        batchId + 1,
-        checkpointMetadata,
-        hadoopConf,
-        dir.getAbsolutePath,
-        originalPartitions - 3
-      )
+    testWithCheckpointId("Consecutive repartition") {
+      withTempDir { dir =>
+        val originalPartitions = 5
+        val input = MemoryStream[Int]
+        val batchId = runSimpleStreamQuery(originalPartitions, 
dir.getAbsolutePath, input)
+
+        val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, 
dir.getAbsolutePath)
+        val hadoopConf = spark.sessionState.newHadoopConf()
+
+        // decrease
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions - 3)
+        verifyRepartitionBatch(
+          batchId + 1,
+          checkpointMetadata,
+          hadoopConf,
+          dir.getAbsolutePath,
+          originalPartitions - 3,
+          spark
+        )
 
-      // increase
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions + 1)
-      verifyRepartitionBatch(
-        batchId + 2,
-        checkpointMetadata,
-        hadoopConf,
-        dir.getAbsolutePath,
-        originalPartitions + 1
-      )
+        // increase
+        spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, 
originalPartitions + 1)
+        verifyRepartitionBatch(
+          batchId + 2,
+          checkpointMetadata,
+          hadoopConf,
+          dir.getAbsolutePath,
+          originalPartitions + 1,
+          spark
+        )
 
-      // Restart the query to make sure it can start after repartitioning
-      runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input)
+        // Restart the query to make sure it can start after repartitioning
+        runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, 
input)
+      }
     }
   }
 
@@ -285,6 +311,7 @@ object OfflineStateRepartitionTestUtils {
       hadoopConf: Configuration,
       checkpointLocation: String,
       expectedShufflePartitions: Int,
+      spark: SparkSession,
       baseBatchId: Option[Long] = None): Unit = {
     // Should be seen as a repartition batch
     assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog, 
checkpointLocation))
@@ -296,8 +323,21 @@ object OfflineStateRepartitionTestUtils {
     verifyOffsetAndCommitLog(
       batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
     verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+
+    val serializableConf = new SerializableConfiguration(hadoopConf)
+    val baseOperatorsMetadata = getOperatorMetadata(
+      checkpointLocation, serializableConf, previousBatchId)
+    val repartitionOperatorsMetadata = getOperatorMetadata(
+      checkpointLocation, serializableConf, batchId)
     verifyOperatorMetadata(
-      batchId, previousBatchId, checkpointLocation, expectedShufflePartitions, 
hadoopConf)
+      baseOperatorsMetadata, repartitionOperatorsMetadata, 
expectedShufflePartitions)
+    if 
(StatefulOperatorStateInfo.enableStateStoreCheckpointIds(spark.sessionState.conf))
 {
+      verifyCheckpointIds(
+        batchId,
+        checkpointMetadata,
+        expectedShufflePartitions,
+        baseOperatorsMetadata)
+    }
   }
 
   private def verifyOffsetAndCommitLog(
@@ -376,23 +416,20 @@ object OfflineStateRepartitionTestUtils {
     }
   }
 
-  private def verifyOperatorMetadata(
-      repartitionBatchId: Long,
-      baseBatchId: Long,
+  private def getOperatorMetadata(
       checkpointLocation: String,
-      expectedShufflePartitions: Int,
-      hadoopConf: Configuration): Unit = {
-    val serializableConf = new SerializableConfiguration(hadoopConf)
-
-    // Read operator metadata for both batches
-    val baseMetadataReader = new StateMetadataPartitionReader(
-      checkpointLocation, serializableConf, baseBatchId)
-    val repartitionMetadataReader = new StateMetadataPartitionReader(
-      checkpointLocation, serializableConf, repartitionBatchId)
-
-    val baseOperatorsMetadata = baseMetadataReader.allOperatorStateMetadata
-    val repartitionOperatorsMetadata = 
repartitionMetadataReader.allOperatorStateMetadata
+      serializableConf: SerializableConfiguration,
+      batchId: Long
+    ): Array[OperatorStateMetadata] = {
+    val metadataPartitionReader = new StateMetadataPartitionReader(
+      checkpointLocation, serializableConf, batchId)
+    metadataPartitionReader.allOperatorStateMetadata
+  }
 
+  private def verifyOperatorMetadata(
+      baseOperatorsMetadata: Array[OperatorStateMetadata],
+      repartitionOperatorsMetadata: Array[OperatorStateMetadata],
+      expectedShufflePartitions: Int): Unit = {
     assert(baseOperatorsMetadata.nonEmpty, "Base batch should have operator 
metadata")
     assert(repartitionOperatorsMetadata.nonEmpty, "Repartition batch should 
have operator metadata")
     assert(baseOperatorsMetadata.length == repartitionOperatorsMetadata.length,
@@ -454,4 +491,44 @@ object OfflineStateRepartitionTestUtils {
         }
     }
   }
+
+  private def verifyCheckpointIds(
+      repartitionBatchId: Long,
+      checkpointMetadata: StreamingQueryCheckpointMetadata,
+      expectedShufflePartitions: Int,
+      baseOperatorsMetadata: Array[OperatorStateMetadata]): Unit = {
+    val expectedStoreCnts: Map[Long, Int] = baseOperatorsMetadata.map {
+      case metadataV2: OperatorStateMetadataV2 =>
+        metadataV2.operatorInfo.operatorId -> metadataV2.stateStoreInfo.length
+      case metadataV1: OperatorStateMetadataV1 =>
+        metadataV1.operatorInfo.operatorId -> metadataV1.stateStoreInfo.length
+    }.toMap
+    // Verify commit log has the repartition batch with checkpoint IDs
+    val commitOpt = checkpointMetadata.commitLog.get(repartitionBatchId)
+    assert(commitOpt.isDefined, s"Commit for batch $repartitionBatchId should 
exist")
+
+    val commitMetadata = commitOpt.get
+
+    // Verify stateUniqueIds is present for checkpoint V2
+    assert(commitMetadata.stateUniqueIds.isDefined,
+      "stateUniqueIds should be present in commit metadata when checkpoint 
version >= 2")
+
+    val operatorIdToCkptInfos = commitMetadata.stateUniqueIds.get
+    assert(operatorIdToCkptInfos.nonEmpty,
+      "operatorIdToCkptInfos should not be empty")
+
+    // Verify structure for each operator
+    operatorIdToCkptInfos.foreach { case (operatorId, partitionToCkptIds) =>
+      // Should have checkpoint IDs for all partitions
+      assert(partitionToCkptIds.length == expectedShufflePartitions,
+        s"Operator $operatorId: Expected $expectedShufflePartitions partition 
checkpoint IDs, " +
+          s"but found ${partitionToCkptIds.length}")
+      // Each partition should have checkpoint IDs (at least one per store)
+      partitionToCkptIds.zipWithIndex.foreach { case (ckptIds, partitionId) =>
+        assert(ckptIds.length == expectedStoreCnts(operatorId),
+            s"Operator $operatorId, partition $partitionId should" +
+            s"has ${expectedStoreCnts(operatorId)} checkpoint Ids")
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index ad532fb71bf0..36b65fdf50e5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -3943,48 +3943,68 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
-  test("SPARK-54420: load with createEmpty creates empty store") {
-    val remoteDir = Utils.createTempDir().toString
-    new File(remoteDir).delete()
-
-    withDB(remoteDir) { db =>
-      // loading batch 0 with loadEmpty = true
-      db.load(0, None, loadEmpty = true)
-      assert(iterator(db).isEmpty)
-      db.put("a", "1")
-      val (version1, _) = db.commit()
-      assert(toStr(db.get("a")) === "1")
+  testWithStateStoreCheckpointIds(
+    "SPARK-54420: load with createEmpty creates empty store") { enableCkptId =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      var lastVersion = 0L
+      var lastCheckpointInfo: Option[StateStoreCheckpointInfo] = None
 
-      // check we can load store normally even the previous one loadEmpty = 
true
-      db.load(version1)
-      db.put("b", "2")
-      val (version2, _) = db.commit()
-      assert(version2 === version1 + 1)
-      assert(toStr(db.get("b")) === "2")
-      assert(toStr(db.get("a")) === "1")
+      withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+        // loading batch 0 with loadEmpty = true
+        db.load(0, None, loadEmpty = true)
+        assert(iterator(db).isEmpty)
+        db.put("a", "1")
+        val (version1, checkpointInfoV1) = db.commit()
+        assert(toStr(db.get("a")) === "1")
 
-      // load an empty store
-      db.load(version2, loadEmpty = true)
-      db.put("c", "3")
-      val (version3, _) = db.commit()
-      assert(db.get("b") === null)
-      assert(db.get("a") === null)
-      assert(toStr(db.get("c")) === "3")
-      assert(version3 === version2 + 1)
-
-      // load 2 empty store in a row
-      db.load(version3, loadEmpty = true)
-      db.put("d", "4")
-      val (version4, _) = db.commit()
-      assert(db.get("c") === null)
-      assert(toStr(db.get("d")) === "4")
-      assert(version4 === version3 + 1)
-
-      db.load(version4)
-      db.put("e", "5")
-      db.commit()
-      assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
-    }
+        // check we can load store normally even the previous one loadEmpty = 
true
+        db.load(version1, checkpointInfoV1.stateStoreCkptId)
+        db.put("b", "2")
+        val (version2, _) = db.commit()
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("a")) === "1")
+
+        // load an empty store
+        db.load(version2, loadEmpty = true)
+        db.put("c", "3")
+        val (version3, _) = db.commit()
+        assert(db.get("b") === null)
+        assert(db.get("a") === null)
+        assert(toStr(db.get("c")) === "3")
+        assert(version3 === version2 + 1)
+
+        // load 2 empty store in a row
+        db.load(version3, loadEmpty = true)
+        db.put("d", "4")
+
+        val (version4, checkpointV4) = db.commit()
+        assert(db.get("c") === null)
+        assert(toStr(db.get("d")) === "4")
+        lastVersion = version4
+        lastCheckpointInfo = Option(checkpointV4)
+        assert(lastVersion === version3 + 1)
+      }
+
+      withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+        db.load(lastVersion, lastCheckpointInfo.map(_.stateStoreCkptId).orNull)
+        db.put("e", "5")
+        db.commit()
+        assert(db.iterator().map(toStr).toSet === Set(("d", "4"), ("e", "5")))
+      }
+
+      if (enableCkptId) {
+        withDB(remoteDir, enableStateStoreCheckpointIds = enableCkptId) { db =>
+          val ex = intercept[IllegalArgumentException] {
+            db.load(
+              lastVersion,
+              lastCheckpointInfo.map(_.stateStoreCkptId).orNull,
+              loadEmpty = true)
+          }
+          assert(ex.getMessage.contains("stateStoreCkptId should be empty when 
loadEmpty is true"))
+        }
+      }
   }
 
   test("SPARK-44639: Use Java tmp dir instead of configured local dirs on 
Yarn") {
@@ -4027,13 +4047,13 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
         version: Long,
         ckptId: Option[String] = None,
         readOnly: Boolean = false,
-        createEmpty: Boolean = false): RocksDB = {
+        loadEmpty: Boolean = false): RocksDB = {
       // When a ckptId is defined, it means the test is explicitly using v2 
semantic
       // When it is not, it is possible that implicitly uses it.
       // So still do a versionToUniqueId.get
       ckptId match {
-        case Some(_) => super.load(version, ckptId, readOnly)
-        case None => super.load(version, versionToUniqueId.get(version), 
readOnly)
+        case Some(_) => super.load(version, ckptId, readOnly, loadEmpty)
+        case None => super.load(version, versionToUniqueId.get(version), 
readOnly, loadEmpty)
       }
     }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
index e495db499bfe..dc6b9d7e8987 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala
@@ -95,11 +95,12 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       transformFunc = None,
       writeCheckpointMetadata = Some(targetCheckpointMetadata)
     )
-    rewriter.run()
+    val checkpointInfos = rewriter.run()
 
-    // Commit to commitLog
+    // Commit to commitLog with checkpoint IDs
     val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
-    targetCheckpointMetadata.commitLog.add(writeBatchId, latestCommit)
+    val commitMetadata = latestCommit.copy(stateUniqueIds = checkpointInfos)
+    targetCheckpointMetadata.commitLog.add(writeBatchId, commitMetadata)
     val versionToCheck = writeBatchId + 1
 
     storeToColumnFamilies.foreach { case (storeName, columnFamilies) =>
@@ -215,20 +216,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       useMultipleValuePerKey)
   }
 
-  /**
-   * Helper method to create a single-entry column family schema map.
-   * This simplifies the common case where only the default column family is 
used.
-   */
-  private def createSingleColumnFamilySchemaMap(
-      keySchema: StructType,
-      valueSchema: StructType,
-      keyStateEncoderSpec: KeyStateEncoderSpec,
-      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME
-  ): Map[String, StatePartitionWriterColumnFamilyInfo] = {
-    Map(colFamilyName -> createColFamilyInfo(keySchema, valueSchema,
-      keyStateEncoderSpec, colFamilyName))
-  }
-
   /**
    * Helper method to test SPARK-54420 read and write with different state 
format versions
    * for simple aggregation (single grouping key).
@@ -307,17 +294,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     }
   }
 
-  private def getJoinV3ColumnSchemaMap(): Map[String, 
StatePartitionWriterColumnFamilyInfo] = {
-    val schemas = 
StreamStreamJoinTestUtils.getJoinV3ColumnSchemaMapWithMetadata()
-    schemas.map { case (cfName, metadata) =>
-      cfName -> createColFamilyInfo(
-        metadata.keySchema,
-        metadata.valueSchema,
-        metadata.encoderSpec,
-        cfName,
-        metadata.useMultipleValuePerKey)
-    }
-  }
   /**
    * Helper method to test round-trip for stream-stream join with different 
versions.
    */
@@ -450,8 +426,375 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
       }
     }
 
-    testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
-      testRoundTripForAggrStateVersion(1)
+    // Run transformWithState tests with enable/disable checkpoint V2
+    Seq(1, 2).foreach { ckptVersion =>
+      def testWithChangelogAndCheckpointId(testName: String)(testFun: => 
Unit): Unit = {
+        testWithChangelogConfig(s"$testName (enableCkptId = ${ckptVersion >= 
2})") {
+          withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> 
ckptVersion.toString) {
+            testFun
+          }
+        }
+      }
+
+      testWithChangelogAndCheckpointId("SPARK-54420: aggregation state ver 1") 
{
+        testRoundTripForAggrStateVersion(1)
+      }
+
+      Seq(1, 2).foreach { version =>
+        testWithChangelogAndCheckpointId(s"SPARK-54420: stream-stream join 
state ver $version") {
+          testStreamStreamJoinRoundTrip(version)
+        }
+      }
+      // Run transformWithState tests with different encoding formats
+      Seq("unsaferow", "avro").foreach { encodingFormat =>
+        def testWithChangelogAndEncodingConfig(testName: String)(testFun: => 
Unit): Unit = {
+          testWithChangelogAndCheckpointId(
+            s"$testName (encoding = $encodingFormat)") {
+            withSQLConf(
+              SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> 
encodingFormat) {
+              testFun
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig(
+          "SPARK-54411: transformWithState with multiple column families") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val inputData = MemoryStream[String]
+              val query = inputData.toDS()
+                .groupByKey(x => x)
+                .transformWithState(new MultiStateVarProcessor(),
+                  TimeMode.None(),
+                  OutputMode.Update())
+              def runQuery(checkpointLocation: String, roundsOfData: Int): 
Unit = {
+                val dataActions = (1 to roundsOfData).flatMap { _ =>
+                  Seq(
+                    AddData(inputData, "a", "b", "a"),
+                    ProcessAllAvailable()
+                  )
+                }
+                testStream(query)(
+                  Seq(StartStream(checkpointLocation = checkpointLocation)) ++
+                    dataActions ++
+                    Seq(StopStream): _*
+                )
+              }
+
+              runQuery(sourceDir.getAbsolutePath, 2)
+              runQuery(targetDir.getAbsolutePath, 1)
+
+              val schemas = 
MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
+              val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
+                .getColumnFamilyToSelectExprs()
+
+              val columnFamilyToStateSourceOptions = schemas.keys.map { cfName 
=>
+                val base = Map(
+                  StateSourceOptions.STATE_VAR_NAME -> cfName
+                )
+
+                val withFlatten =
+                  if (cfName == MultiStateVarProcessorTestUtils.ITEMS_LIST) {
+                    base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true")
+                  } else {
+                    base
+                  }
+
+                cfName -> withFlatten
+              }.toMap
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> 
schemas.keys.toList),
+                storeToColumnFamilyToSelectExprs =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToSelectExprs),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> 
columnFamilyToStateSourceOptions),
+                operatorName = 
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig(
+          "SPARK-54411: transformWithState with eventTime timers") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val inputData = MemoryStream[(String, Long)]
+              val result = inputData.toDS()
+                .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+                .withWatermark("eventTime", "10 seconds")
+                .as[(String, Timestamp)]
+                .groupByKey(_._1)
+                .transformWithState(
+                  new EventTimeTimerProcessor(),
+                  TimeMode.EventTime(),
+                  OutputMode.Update())
+
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = sourceDir.getAbsolutePath),
+                AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+                ProcessAllAvailable(),
+                StopStream
+              )
+
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = targetDir.getAbsolutePath),
+                AddData(inputData, ("x", 1L)),
+                ProcessAllAvailable(),
+                StopStream
+              )
+
+              val (schemaMap, selectExprs, stateSourceOptions) =
+                getTimerStateConfigsForCountState(TimeMode.EventTime())
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+                storeToColumnFamilyToSelectExprs =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
+                operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig(
+          "SPARK-54411: transformWithState with processing time timers") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val clock = new StreamManualClock
+              val inputData = MemoryStream[String]
+              val result = inputData.toDS()
+                .groupByKey(x => x)
+                .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+                  TimeMode.ProcessingTime(),
+                  OutputMode.Update())
+
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock),
+                AddData(inputData, "a"),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(("a", "1")),
+                StopStream
+              )
+
+              val clock2 = new StreamManualClock
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = targetDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock2),
+                AddData(inputData, "x"),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(("a", "1"), ("x", "1")),
+                StopStream
+              )
+
+              val (schemaMap, selectExprs, sourceOptions) =
+                getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
+                storeToColumnFamilyToSelectExprs =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
+                operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with list and TTL") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val clock = new StreamManualClock
+              val inputData = MemoryStream[InputEvent]
+              val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+              val result = inputData.toDS()
+                .groupByKey(x => x.key)
+                .transformWithState(new ListStateTTLProcessor(ttlConfig),
+                  TimeMode.ProcessingTime(),
+                  OutputMode.Update())
+
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock),
+                AddData(inputData, InputEvent("k1", "put", 1)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val clock2 = new StreamManualClock
+              testStream(result, OutputMode.Update())(
+                StartStream(checkpointLocation = targetDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock2),
+                AddData(inputData, InputEvent("k1", "append", 1)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val schemas = TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
+
+              val columnFamilyToSelectExprs = Map(
+                TTLProcessorUtils.LIST_STATE -> TTLProcessorUtils.getTTLSelectExpressions(
+                  TTLProcessorUtils.LIST_STATE
+                ))
+
+              val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
+                val base = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+
+                val withFlatten =
+                  if (cfName == TTLProcessorUtils.LIST_STATE) {
+                    base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true")
+                  } else {
+                    base
+                  }
+
+                cfName -> withFlatten
+              }.toMap
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+                storeToColumnFamilyToSelectExprs =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
+                operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with map and TTL") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val clock = new StreamManualClock
+              val inputData = MemoryStream[MapInputEvent]
+              val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+              val result = inputData.toDS()
+                .groupByKey(x => x.key)
+                .transformWithState(new MapStateTTLProcessor(ttlConfig),
+                  TimeMode.ProcessingTime(),
+                  OutputMode.Update())
+
+              testStream(result)(
+                StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock),
+                AddData(inputData, MapInputEvent("a", "key1", "put", 1)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val clock2 = new StreamManualClock
+              testStream(result)(
+                StartStream(checkpointLocation = targetDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock2),
+                AddData(inputData, MapInputEvent("x", "key1", "put", 1)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
+
+              val columnFamilyToSelectExprs = Map(
+                TTLProcessorUtils.MAP_STATE -> TTLProcessorUtils.getTTLSelectExpressions(
+                  TTLProcessorUtils.MAP_STATE
+                ))
+
+              val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
+                cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+              }.toMap
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+                storeToColumnFamilyToSelectExprs =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
+                operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+
+        testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with value and TTL") {
+          withTempDir { sourceDir =>
+            withTempDir { targetDir =>
+              val clock = new StreamManualClock
+              val inputData = MemoryStream[InputEvent]
+              val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+              val result = inputData.toDS()
+                .groupByKey(x => x.key)
+                .transformWithState(new ValueStateTTLProcessor(ttlConfig),
+                  TimeMode.ProcessingTime(),
+                  OutputMode.Update())
+
+              testStream(result)(
+                StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock),
+                AddData(inputData, InputEvent("k1", "put", 1)),
+                AddData(inputData, InputEvent("k2", "put", 2)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val clock2 = new StreamManualClock
+              testStream(result)(
+                StartStream(checkpointLocation = targetDir.getAbsolutePath,
+                  trigger = Trigger.ProcessingTime("1 second"),
+                  triggerClock = clock2),
+                AddData(inputData, InputEvent("x", "put", 1)),
+                AdvanceManualClock(1 * 1000),
+                CheckNewAnswer(),
+                StopStream
+              )
+
+              val schemas = TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
+
+              val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
+                cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+              }.toMap
+
+              performRoundTripTest(
+                sourceDir.getAbsolutePath,
+                targetDir.getAbsolutePath,
+                storeToColumnFamilies =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
+                storeToColumnFamilyToStateSourceOptions =
+                  Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
+                operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
+              )
+            }
+          }
+        }
+      } // End of foreach loop for encoding format (transformWithState tests only)
     }
 
     testWithChangelogConfig("SPARK-54420: aggregation state ver 2") {
@@ -573,12 +916,6 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
       }
     }
 
-    Seq(1, 2).foreach { version =>
-      testWithChangelogConfig(s"SPARK-54420: stream-stream join state ver $version") {
-        testStreamStreamJoinRoundTrip(version)
-      }
-    }
-
     testWithChangelogConfig("SPARK-54411: stream-stream join state ver 3") {
       withSQLConf(
         SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
@@ -620,358 +957,44 @@ class StatePartitionAllColumnFamiliesWriterSuite extends StateDataSourceTestBase
         }
       }
     }
+  } // End of foreach loop for changelog checkpointing dimension
 
-    // Run transformWithState tests with different encoding formats
-    Seq("unsaferow", "avro").foreach { encodingFormat =>
-      def testWithChangelogAndEncodingConfig(testName: String)(testFun: => Unit): Unit = {
-        test(s"$testName ($changelogCpTestSuffix, encoding = $encodingFormat)") {
-          withSQLConf(
-            "spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
-              changelogCheckpointingEnabled.toString,
-            SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encodingFormat) {
-            testFun
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig(
-          "SPARK-54411: transformWithState with multiple column families") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val inputData = MemoryStream[String]
-            val query = inputData.toDS()
-              .groupByKey(x => x)
-              .transformWithState(new MultiStateVarProcessor(),
-                TimeMode.None(),
-                OutputMode.Update())
-            def runQuery(checkpointLocation: String, roundsOfData: Int): Unit = {
-              val dataActions = (1 to roundsOfData).flatMap { _ =>
-                Seq(
-                  AddData(inputData, "a", "b", "a"),
-                  ProcessAllAvailable()
-                )
-              }
-              testStream(query)(
-                Seq(StartStream(checkpointLocation = checkpointLocation)) ++
-                  dataActions ++
-                  Seq(StopStream): _*
-              )
-            }
-
-            runQuery(sourceDir.getAbsolutePath, 2)
-            runQuery(targetDir.getAbsolutePath, 1)
-
-            val schemas = MultiStateVarProcessorTestUtils.getSchemasWithMetadata()
-            val columnFamilyToSelectExprs = MultiStateVarProcessorTestUtils
-              .getColumnFamilyToSelectExprs()
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              val base = Map(
-                StateSourceOptions.STATE_VAR_NAME -> cfName
-              )
-
-              val withFlatten =
-                if (cfName == MultiStateVarProcessorTestUtils.ITEMS_LIST) {
-                  base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true")
-                } else {
-                  base
-                }
-
-              cfName -> withFlatten
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies = Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with eventTime timers") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val inputData = MemoryStream[(String, Long)]
-            val result = inputData.toDS()
-              .select(col("_1").as("key"), timestamp_seconds(col("_2")).as("eventTime"))
-              .withWatermark("eventTime", "10 seconds")
-              .as[(String, Timestamp)]
-              .groupByKey(_._1)
-              .transformWithState(
-                new EventTimeTimerProcessor(),
-                TimeMode.EventTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath),
-              AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
-              ProcessAllAvailable(),
-              StopStream
-            )
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath),
-              AddData(inputData, ("x", 1L)),
-              ProcessAllAvailable(),
-              StopStream
-            )
-
-            val (schemaMap, selectExprs, stateSourceOptions) =
-              getTimerStateConfigsForCountState(TimeMode.EventTime())
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig(
-        "SPARK-54411: transformWithState with processing time timers") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[String]
-            val result = inputData.toDS()
-              .groupByKey(x => x)
-              .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, "a"),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(("a", "1")),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, "x"),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(("a", "1"), ("x", "1")),
-              StopStream
-            )
-
-            val (schemaMap, selectExprs, sourceOptions) =
-              getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemaMap.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> selectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> sourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with list and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[InputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new ListStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, InputEvent("k1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result, OutputMode.Update())(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, InputEvent("k1", "append", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val schemas = TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
-
-            val columnFamilyToSelectExprs = Map(
-              TTLProcessorUtils.LIST_STATE -> TTLProcessorUtils.getTTLSelectExpressions(
-                TTLProcessorUtils.LIST_STATE
-            ))
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              val base = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-
-              val withFlatten =
-                if (cfName == TTLProcessorUtils.LIST_STATE) {
-                  base + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true")
-                } else {
-                  base
-                }
-
-              cfName -> withFlatten
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with map and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[MapInputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new MapStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result)(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, MapInputEvent("a", "key1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result)(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, MapInputEvent("x", "key1", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
-
-            val columnFamilyToSelectExprs = Map(
-              TTLProcessorUtils.MAP_STATE -> TTLProcessorUtils.getTTLSelectExpressions(
-                TTLProcessorUtils.MAP_STATE
-            ))
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToSelectExprs =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToSelectExprs),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
+  test("SPARK-54590: Rewriter throws exception if checkpoint version is not set correctly") {
+    withSQLConf(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2") {
+      withTempDir { sourceDir =>
+        // Step 1: Create state by running a dropDuplicates streaming query
+        runDropDuplicatesQuery(sourceDir.getAbsolutePath)
+        val sourceCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+          spark, sourceDir.getAbsolutePath)
+        val readBatchId = sourceCheckpointMetadata.commitLog.getLatestBatchId().get
+        // Unset STATE_STORE_CHECKPOINT_FORMAT_VERSION so it falls back to the default (1),
+        // mimicking a user who forgot to set the checkpoint format version to 2 in SQLConf
+        // before running StateRewriter on a checkpoint v2 query.
+        spark.conf.unset(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key)
+        val ex = intercept[StateRewriterInvalidCheckpointError] {
+          val rewriter = new StateRewriter(
+            spark,
+            readBatchId,
+            readBatchId + 1,
+            sourceDir.getAbsolutePath,
+            spark.sessionState.newHadoopConf()
+          )
+          rewriter.run()
         }
-      }
-
-      testWithChangelogAndEncodingConfig("SPARK-54411: transformWithState with value and TTL") {
-        withTempDir { sourceDir =>
-          withTempDir { targetDir =>
-            val clock = new StreamManualClock
-            val inputData = MemoryStream[InputEvent]
-            val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
-            val result = inputData.toDS()
-              .groupByKey(x => x.key)
-              .transformWithState(new ValueStateTTLProcessor(ttlConfig),
-                TimeMode.ProcessingTime(),
-                OutputMode.Update())
-
-            testStream(result)(
-              StartStream(checkpointLocation = sourceDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock),
-              AddData(inputData, InputEvent("k1", "put", 1)),
-              AddData(inputData, InputEvent("k2", "put", 2)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
-
-            val clock2 = new StreamManualClock
-            testStream(result)(
-              StartStream(checkpointLocation = targetDir.getAbsolutePath,
-                trigger = Trigger.ProcessingTime("1 second"),
-                triggerClock = clock2),
-              AddData(inputData, InputEvent("x", "put", 1)),
-              AdvanceManualClock(1 * 1000),
-              CheckNewAnswer(),
-              StopStream
-            )
 
-            val schemas = TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
-
-            val columnFamilyToStateSourceOptions = schemas.keys.map { cfName =>
-              cfName -> Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
-            }.toMap
-
-            performRoundTripTest(
-              sourceDir.getAbsolutePath,
-              targetDir.getAbsolutePath,
-              storeToColumnFamilies =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> schemas.keys.toList),
-              storeToColumnFamilyToStateSourceOptions =
-                Map(StateStoreId.DEFAULT_STORE_NAME -> columnFamilyToStateSourceOptions),
-              operatorName = StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
-            )
-          }
-        }
+        checkError(
+          ex,
+          "STATE_REWRITER_INVALID_CHECKPOINT.STATE_CHECKPOINT_FORMAT_VERSION_MISMATCH",
+          "55019",
+          parameters = Map(
+            "checkpointLocation" -> sourceDir.getAbsolutePath,
+            "expectedVersion" -> "2",
+            "actualVersion" -> "1",
+            "sqlConfKey" -> SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key)
+        )
       }
-    } // End of foreach loop for encoding format (transformWithState tests only)
-  } // End of foreach loop for changelog checkpointing dimension
+    }
+  }
 
   test("SPARK-54411: Non-JoinV3 operator requires default column family in 
schema") {
     withTempDir { targetDir =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
