micheal-o commented on code in PR #53287:
URL: https://github.com/apache/spark/pull/53287#discussion_r2591086012


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,
+     hadoopConf: Configuration,
+     partitionId: Int,
+     targetCpLocation: String,
+     operatorId: Int,
+     storeName: String,
+     batchId: Long,
+     columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+  private val defaultSchema = {
+    columnFamilyToSchemaMap.getOrElse(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      throw new IllegalArgumentException(
+        s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in 
schema map")
+    )
+  }
+
+  private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.keySchema.length)
+  private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.valueSchema.length)
+
+  private val rowConverter = {
+    val schema = 
SchemaUtil.getScanAllColumnFamiliesSchema(defaultSchema.keySchema)
+    CatalystTypeConverters.createToCatalystConverter(schema)
+  }
+
+  protected lazy val provider: StateStoreProvider = {
+    val stateCheckpointLocation = new Path(targetCpLocation, 
DIR_NAME_STATE).toString
+    val stateStoreId = StateStoreId(stateCheckpointLocation,
+      operatorId, partitionId, storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
UUID.randomUUID())
+
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
+      defaultSchema.keyStateEncoderSpec.get,
+      useColumnFamilies = false, storeConf, hadoopConf,
+      useMultipleValuesPerKey = false, stateSchemaProvider = None)
+    provider
+  }
+
+  private lazy val stateStore: StateStore = {
+    // TODO[SPARK-54590]: Support checkpoint V2 in 
StatePartitionAllColumnFamiliesWriter
+    // Create empty store to avoid loading old partition data since we are 
rewriting the
+    // store e.g. during repartitioning
+    // Use loadEmpty=true to create a fresh state store without loading 
previous versions
+    // We create the empty store AT version, and the next commit will
+    // produce version + 1
+    provider.getStore(
+      batchId + 1,
+      stateStoreCkptId = None,
+      loadEmpty = true
+    )
+  }
+
+  // The function that writes and commits data to state store. It takes in 
rows with schema
+  // - partition_key, StructType
+  // - key_bytes, BinaryType
+  // - value_bytes, BinaryType
+  // - column_family_name, StringType
+  def put(rows: Iterator[Row]): Unit = {

Review Comment:
   Take in `Iterator[InternalRow]` instead. That is what the state reader returns.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -555,30 +555,49 @@ class RocksDB(
 
   private def loadWithoutCheckpointId(
       version: Long,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {
+
     try {
-      if (loadedVersion != version) {
+      enableChangelogCheckpointing = if (createEmpty) false else 
conf.enableChangelogCheckpointing
+      // For createEmpty, always proceed; otherwise, only if version changed
+      if (createEmpty || loadedVersion != version) {
         closeDB(ignoreException = false)
 
-        // load the latest snapshot
-        loadSnapshotWithoutCheckpointId(version)
-
-        if (loadedVersion != version) {
-          val versionsAndUniqueIds: Array[(Long, Option[String])] =
-            (loadedVersion + 1 to version).map((_, None)).toArray
-          replayChangelog(versionsAndUniqueIds)
+        if (createEmpty) {
+          // Use version 0 logic to create empty directory with no SST files
+          val metadata = fileManager.loadCheckpointFromDfs(0, workingDir, 
rocksDBFileMapping, None)

Review Comment:
   nit: you can extract these 4 lines into a separate function, e.g. 
`createEmptyWithoutCheckpointId`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -703,7 +732,8 @@ class RocksDB(
   def load(
       version: Long,
       stateStoreCkptId: Option[String] = None,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -0,0 +1,682 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.File
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
RocksDBStateStoreProvider, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId}
+import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, 
NullType, StructField, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * Test suite for StatePartitionAllColumnFamiliesWriter.
+ * Tests the writer's ability to correctly write raw bytes read from
+ * StatePartitionAllColumnFamiliesReader to a state store without loading 
previous versions.
+ */
+class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase {
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+  }
+
+  /**
+   * Common helper method to perform round-trip test: read state bytes from 
source,
+   * write to target, and verify target matches source.
+   *
+   * @param sourceDir Source checkpoint directory
+   * @param targetDir Target checkpoint directory
+   * @param keySchema Key schema for the state store
+   * @param valueSchema Value schema for the state store
+   * @param keyStateEncoderSpec Key state encoder spec
+   * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
+   */
+  private def performRoundTripTest(
+      sourceDir: String,
+      targetDir: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      storeName: Option[String] = None): Unit = {
+
+    // Step 1: Read original state using normal reader (for comparison later)
+    val sourceReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+    val sourceNormalData = (storeName match {
+      case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, 
name)
+      case None => sourceReader
+    }).load()
+      .selectExpr("key", "value", "partition_id")
+      .collect()
+
+    // Step 2: Read from source using AllColumnFamiliesReader (raw bytes)
+    val sourceBytesReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+    val sourceBytesData = (storeName match {
+      case Some(name) => 
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
+      case None => sourceBytesReader
+    }).load()
+
+    // Verify schema of raw bytes
+    val schema = sourceBytesData.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+
+    // Step 3: Write raw bytes to target checkpoint location
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
+      hadoopConf, targetDir)
+    val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+      spark, targetCpLocation)
+    val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
+    val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
+    targetCheckpointMetadata.offsetLog.add(lastBatch + 1, targetOffsetSeq)
+
+    // Create column family to schema map
+    val columnFamilyToSchemaMap = HashMap(
+      StateStore.DEFAULT_COL_FAMILY_NAME -> StateStoreColFamilySchema(
+        StateStore.DEFAULT_COL_FAMILY_NAME,
+        keySchemaId = 0,
+        keySchema,
+        valueSchemaId = 0,
+        valueSchema,
+        keyStateEncoderSpec = Some(keyStateEncoderSpec)
+      )
+    )
+
+    val storeConf: StateStoreConf = StateStoreConf(SQLConf.get)
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+
+    // Define the partition processing function
+    val putPartitionFunc: Iterator[Row] => Unit = partition => {
+      val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
+        storeConf,
+        serializableHadoopConf.value,
+        TaskContext.getPartitionId(),
+        targetCpLocation,
+        0,
+        storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
+        lastBatch,
+        columnFamilyToSchemaMap
+      )
+      allCFWriter.put(partition)
+    }
+
+    // Write raw bytes to target using foreachPartition
+    sourceBytesData.foreachPartition(putPartitionFunc)
+
+    // Commit to commitLog
+    val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
+    targetCheckpointMetadata.commitLog.add(lastBatch + 1, latestCommit)
+    val batchToCheck = lastBatch + 2
+    assert(!checkpointFileExists(new File(targetDir, "state/0/0"), 
batchToCheck, ".changelog"))
+    assert(checkpointFileExists(new File(targetDir, "state/0/0"), 
batchToCheck, ".zip"))
+
+    // Step 4: Read from target using normal reader
+    val targetReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, targetDir)
+    val targetNormalData = (storeName match {
+      case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME, 
name)
+      case None => targetReader
+    }).load()
+      .selectExpr("key", "value", "partition_id")
+      .collect()
+
+    // Step 5: Verify data matches
+    assert(sourceNormalData.length == targetNormalData.length,
+      s"Row count mismatch: source=${sourceNormalData.length}, " +
+        s"target=${targetNormalData.length}")
+
+    // Sort and compare row by row
+    val sourceSorted = sourceNormalData.sortBy(_.toString)
+    val targetSorted = targetNormalData.sortBy(_.toString)
+
+    sourceSorted.zip(targetSorted).zipWithIndex.foreach {
+      case ((sourceRow, targetRow), idx) =>
+        assert(sourceRow == targetRow,
+          s"Row mismatch at index $idx:\n" +
+            s"  Source: $sourceRow\n" +
+            s"  Target: $targetRow")
+    }
+  }
+
+    /**
+     * Checks if a changelog file for the specified version exists in the 
given directory.
+     * A changelog file has the suffix ".changelog".
+     *
+     * @param dir Directory to search for changelog files
+     * @param version The version to check for existence
+     * @param suffix Either 'zip' or 'changelog'
+     * @return true if a changelog file with the given version exists, false 
otherwise
+     */
+    private def checkpointFileExists(dir: File, version: Long, suffix: 
String): Boolean = {
+      Option(dir.listFiles)
+        .getOrElse(Array.empty)
+        .filter { file =>
+          file.getName.endsWith(suffix) && !file.getName.startsWith(".")
+        }
+        .exists { file =>
+          val nameWithoutSuffix = file.getName.stripSuffix(suffix)
+          val parts = nameWithoutSuffix.split("_")
+          parts.headOption match {
+            case Some(ver) if ver.forall(_.isDigit) => ver.toLong == version
+            case _ => false
+          }
+        }
+    }
+
+  /**
+   * Helper method to test SPARK-54420 read and write with different state 
format versions
+   * for simple aggregation (single grouping key).
+   * @param stateVersion The state format version (1 or 2)
+   */
+  private def testRoundTripForAggrStateVersion(stateVersion: Int): Unit = {
+    withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> 
stateVersion.toString,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2") {

Review Comment:
   nit: set this conf in `beforeAll` instead?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)

Review Comment:
   why?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("c", "3")
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("c")) === "3")
+        val (version3, _) = db.commit(forceSnapshot = true)
+        assert(version3 === version2 + 1)
+      }
+
+      // Verify we can reload the committed version
+      withDB(remoteDir, version = 3,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        assert(toStr(db.get("c")) === "3")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+      }
+  }
+
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -180,7 +180,7 @@ class RocksDB(
 
   @volatile private var db: NativeRocksDB = _
   @volatile private var changelogWriter: Option[StateStoreChangelogWriter] = 
None
-  private val enableChangelogCheckpointing: Boolean = 
conf.enableChangelogCheckpointing
+  private var enableChangelogCheckpointing: Boolean = 
conf.enableChangelogCheckpointing

Review Comment:
   nit: `@volatile`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -555,30 +555,49 @@ class RocksDB(
 
   private def loadWithoutCheckpointId(
       version: Long,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {
+
     try {
-      if (loadedVersion != version) {
+      enableChangelogCheckpointing = if (createEmpty) false else 
conf.enableChangelogCheckpointing

Review Comment:
   Do this in the `load` method, since we will also need it for the 
`loadWithCheckpointId` function later. Also add a comment explaining why we are doing this.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,

Review Comment:
   nit: indentation. This uses 5 spaces instead of 4.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -555,30 +555,49 @@ class RocksDB(
 
   private def loadWithoutCheckpointId(
       version: Long,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {

Review Comment:
   nit: consistent naming `loadEmpty`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -555,30 +555,49 @@ class RocksDB(
 
   private def loadWithoutCheckpointId(
       version: Long,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {
+
     try {
-      if (loadedVersion != version) {
+      enableChangelogCheckpointing = if (createEmpty) false else 
conf.enableChangelogCheckpointing
+      // For createEmpty, always proceed; otherwise, only if version changed
+      if (createEmpty || loadedVersion != version) {
         closeDB(ignoreException = false)
 
-        // load the latest snapshot
-        loadSnapshotWithoutCheckpointId(version)
-
-        if (loadedVersion != version) {
-          val versionsAndUniqueIds: Array[(Long, Option[String])] =
-            (loadedVersion + 1 to version).map((_, None)).toArray
-          replayChangelog(versionsAndUniqueIds)
+        if (createEmpty) {
+          // Use version 0 logic to create empty directory with no SST files
+          val metadata = fileManager.loadCheckpointFromDfs(0, workingDir, 
rocksDBFileMapping, None)
           loadedVersion = version
+          fileManager.setMaxSeenVersion(version)
+          openLocalRocksDB(metadata)
+          // Empty store has no keys
+          numKeysOnLoadedVersion = 0
+          numInternalKeysOnLoadedVersion = 0
+        } else {
+          // load the latest snapshot
+          loadSnapshotWithoutCheckpointId(version)
+
+          if (loadedVersion != version) {
+            val versionsAndUniqueIds: Array[(Long, Option[String])] =
+              (loadedVersion + 1 to version).map((_, None)).toArray
+            replayChangelog(versionsAndUniqueIds)
+            loadedVersion = version
+          }
+          // After changelog replay the numKeysOnWritingVersion will be 
updated to
+          // the correct number of keys in the loaded version.

Review Comment:
   keep these var assignments where they were before



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,
+     hadoopConf: Configuration,
+     partitionId: Int,
+     targetCpLocation: String,
+     operatorId: Int,
+     storeName: String,
+     batchId: Long,
+     columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+  private val defaultSchema = {
+    columnFamilyToSchemaMap.getOrElse(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      throw new IllegalArgumentException(
+        s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in 
schema map")
+    )
+  }
+
+  private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.keySchema.length)
+  private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.valueSchema.length)
+
+  private val rowConverter = {
+    val schema = 
SchemaUtil.getScanAllColumnFamiliesSchema(defaultSchema.keySchema)
+    CatalystTypeConverters.createToCatalystConverter(schema)
+  }
+
+  protected lazy val provider: StateStoreProvider = {
+    val stateCheckpointLocation = new Path(targetCpLocation, 
DIR_NAME_STATE).toString
+    val stateStoreId = StateStoreId(stateCheckpointLocation,
+      operatorId, partitionId, storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
UUID.randomUUID())
+
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
+      defaultSchema.keyStateEncoderSpec.get,
+      useColumnFamilies = false, storeConf, hadoopConf,
+      useMultipleValuesPerKey = false, stateSchemaProvider = None)
+    provider
+  }
+
+  private lazy val stateStore: StateStore = {
+    // TODO[SPARK-54590]: Support checkpoint V2 in 
StatePartitionAllColumnFamiliesWriter
+    // Create empty store to avoid loading old partition data since we are 
rewriting the
+    // store e.g. during repartitioning
+    // Use loadEmpty=true to create a fresh state store without loading 
previous versions
+    // We create the empty store AT version, and the next commit will
+    // produce version + 1
+    provider.getStore(
+      batchId + 1,

Review Comment:
   This is incorrect, right? The current batch loads version `currentBatchId` and 
commits version `currentBatchId + 1`, e.g. batch 0 loads version 0 and commits version 1.
   
   This is different from the state reader, which is asked to read the state 
produced by batchId, i.e. version batchId + 1; e.g. to read the state for batch 0, 
you need to load version 1.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,
+     hadoopConf: Configuration,
+     partitionId: Int,
+     targetCpLocation: String,
+     operatorId: Int,
+     storeName: String,
+     batchId: Long,
+     columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+  private val defaultSchema = {
+    columnFamilyToSchemaMap.getOrElse(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      throw new IllegalArgumentException(
+        s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in 
schema map")
+    )
+  }
+
+  private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.keySchema.length)
+  private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.valueSchema.length)
+
+  private val rowConverter = {
+    val schema = 
SchemaUtil.getScanAllColumnFamiliesSchema(defaultSchema.keySchema)
+    CatalystTypeConverters.createToCatalystConverter(schema)
+  }
+
+  protected lazy val provider: StateStoreProvider = {
+    val stateCheckpointLocation = new Path(targetCpLocation, 
DIR_NAME_STATE).toString
+    val stateStoreId = StateStoreId(stateCheckpointLocation,
+      operatorId, partitionId, storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
UUID.randomUUID())
+
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
+      defaultSchema.keyStateEncoderSpec.get,
+      useColumnFamilies = false, storeConf, hadoopConf,
+      useMultipleValuesPerKey = false, stateSchemaProvider = None)
+    provider
+  }
+
+  private lazy val stateStore: StateStore = {
+    // TODO[SPARK-54590]: Support checkpoint V2 in 
StatePartitionAllColumnFamiliesWriter
+    // Create empty store to avoid loading old partition data since we are 
rewriting the
+    // store e.g. during repartitioning
+    // Use loadEmpty=true to create a fresh state store without loading 
previous versions
+    // We create the empty store AT version, and the next commit will
+    // produce version + 1
+    provider.getStore(
+      batchId + 1,
+      stateStoreCkptId = None,
+      loadEmpty = true
+    )
+  }
+
+  // The function that writes and commits data to state store. It takes in 
rows with schema
+  // - partition_key, StructType
+  // - key_bytes, BinaryType
+  // - value_bytes, BinaryType
+  // - column_family_name, StringType
+  def put(rows: Iterator[Row]): Unit = {
+    try {
+      rows.foreach(row => putRaw(row))
+      stateStore.commit()
+    } finally {
+      if (!stateStore.hasCommitted) {
+        stateStore.abort()
+      }
+    }
+  }
+
+  private def putRaw(rawRecord: Row): Unit = {
+    val record = rowConverter(rawRecord).asInstanceOf[InternalRow]

Review Comment:
   Then you wouldn't need this conversion anymore.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -762,10 +764,12 @@ private[sql] class RocksDBStateStoreProvider
           Some(s)
       }
 
+      // Load RocksDB: either empty or from existing checkpoints

Review Comment:
   Remove this comment; it's not necessary.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -555,30 +555,49 @@ class RocksDB(
 
   private def loadWithoutCheckpointId(
       version: Long,
-      readOnly: Boolean = false): RocksDB = {
+      readOnly: Boolean = false,
+      createEmpty: Boolean = false): RocksDB = {
+
     try {
-      if (loadedVersion != version) {
+      enableChangelogCheckpointing = if (createEmpty) false else 
conf.enableChangelogCheckpointing
+      // For createEmpty, always proceed; otherwise, only if version changed
+      if (createEmpty || loadedVersion != version) {
         closeDB(ignoreException = false)
 
-        // load the latest snapshot
-        loadSnapshotWithoutCheckpointId(version)
-
-        if (loadedVersion != version) {
-          val versionsAndUniqueIds: Array[(Long, Option[String])] =
-            (loadedVersion + 1 to version).map((_, None)).toArray
-          replayChangelog(versionsAndUniqueIds)
+        if (createEmpty) {
+          // Use version 0 logic to create empty directory with no SST files
+          val metadata = fileManager.loadCheckpointFromDfs(0, workingDir, 
rocksDBFileMapping, None)
           loadedVersion = version
+          fileManager.setMaxSeenVersion(version)
+          openLocalRocksDB(metadata)
+          // Empty store has no keys
+          numKeysOnLoadedVersion = 0

Review Comment:
   No need to set these two vars to 0 here. `numKeysOnWritingVersion` and `numInternalKeysOnWritingVersion` are already set to 0 in `openLocalRocksDB` for an empty DB, so these two will also end up as 0.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -587,7 +606,17 @@ class RocksDB(
     if (enableChangelogCheckpointing && !readOnly) {
       // Make sure we don't leak resource.
       changelogWriter.foreach(_.abort())
-      changelogWriter = Some(fileManager.getChangeLogWriter(version + 1, 
useColumnFamilies))
+      if (createEmpty) {

Review Comment:
   Why are we doing this? Changelog checkpointing is disabled when `createEmpty` is true, so this branch will never run.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -806,12 +810,14 @@ private[sql] class RocksDBStateStoreProvider
   override def getStore(
       version: Long,
       uniqueId: Option[String] = None,
-      forceSnapshotOnCommit: Boolean = false): StateStore = {
+      forceSnapshotOnCommit: Boolean = false,
+      loadEmpty: Boolean = false): StateStore = {
     loadStateStore(
       version,
       uniqueId,
       readOnly = false,
-      forceSnapshotOnCommit = forceSnapshotOnCommit
+      forceSnapshotOnCommit = if (loadEmpty) true else forceSnapshotOnCommit,

Review Comment:
   We don't need this anymore, right? Since we disable changelog checkpointing anyway, a snapshot will be created.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)

Review Comment:
   This should fail that assertion, right? I mean the one that checks whether `createEmpty` is true when loading with a checkpoint ID.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("c", "3")
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("c")) === "3")
+        val (version3, _) = db.commit(forceSnapshot = true)

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -0,0 +1,682 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.File
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
RocksDBStateStoreProvider, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId}
+import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, 
NullType, StructField, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * Test suite for StatePartitionAllColumnFamiliesWriter.
+ * Tests the writer's ability to correctly write raw bytes read from
+ * StatePartitionAllColumnFamiliesReader to a state store without loading 
previous versions.
+ */
+class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase {
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+  }
+
+  /**
+   * Common helper method to perform round-trip test: read state bytes from 
source,
+   * write to target, and verify target matches source.
+   *
+   * @param sourceDir Source checkpoint directory
+   * @param targetDir Target checkpoint directory
+   * @param keySchema Key schema for the state store
+   * @param valueSchema Value schema for the state store
+   * @param keyStateEncoderSpec Key state encoder spec
+   * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
+   */
+  private def performRoundTripTest(
+      sourceDir: String,
+      targetDir: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      storeName: Option[String] = None): Unit = {
+
+    // Step 1: Read original state using normal reader (for comparison later)
+    val sourceReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+    val sourceNormalData = (storeName match {
+      case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, 
name)
+      case None => sourceReader
+    }).load()
+      .selectExpr("key", "value", "partition_id")
+      .collect()
+
+    // Step 2: Read from source using AllColumnFamiliesReader (raw bytes)
+    val sourceBytesReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+    val sourceBytesData = (storeName match {
+      case Some(name) => 
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
+      case None => sourceBytesReader
+    }).load()
+
+    // Verify schema of raw bytes
+    val schema = sourceBytesData.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+
+    // Step 3: Write raw bytes to target checkpoint location
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
+      hadoopConf, targetDir)
+    val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+      spark, targetCpLocation)
+    val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
+    val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
+    targetCheckpointMetadata.offsetLog.add(lastBatch + 1, targetOffsetSeq)
+
+    // Create column family to schema map
+    val columnFamilyToSchemaMap = HashMap(
+      StateStore.DEFAULT_COL_FAMILY_NAME -> StateStoreColFamilySchema(
+        StateStore.DEFAULT_COL_FAMILY_NAME,
+        keySchemaId = 0,
+        keySchema,
+        valueSchemaId = 0,
+        valueSchema,
+        keyStateEncoderSpec = Some(keyStateEncoderSpec)
+      )
+    )
+
+    val storeConf: StateStoreConf = StateStoreConf(SQLConf.get)
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+
+    // Define the partition processing function
+    val putPartitionFunc: Iterator[Row] => Unit = partition => {
+      val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
+        storeConf,
+        serializableHadoopConf.value,
+        TaskContext.getPartitionId(),
+        targetCpLocation,
+        0,
+        storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
+        lastBatch,

Review Comment:
   This should be the current batch, i.e. `lastBatch + 1`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(

Review Comment:
   We don't support checkpoint IDs yet, so why are we testing that dimension?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,
+     hadoopConf: Configuration,
+     partitionId: Int,
+     targetCpLocation: String,
+     operatorId: Int,
+     storeName: String,
+     batchId: Long,

Review Comment:
   nit: rename to `currentBatchId`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -0,0 +1,682 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.io.File
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, 
StreamingQueryCheckpointMetadata}
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
RocksDBStateStoreProvider, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId}
+import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types.{BooleanType, IntegerType, LongType, 
NullType, StructField, StructType, TimestampType}
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * Test suite for StatePartitionAllColumnFamiliesWriter.
+ * Tests the writer's ability to correctly write raw bytes read from
+ * StatePartitionAllColumnFamiliesReader to a state store without loading 
previous versions.
+ */
+class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase {
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
+      classOf[RocksDBStateStoreProvider].getName)
+  }
+
+  /**
+   * Common helper method to perform round-trip test: read state bytes from 
source,
+   * write to target, and verify target matches source.
+   *
+   * @param sourceDir Source checkpoint directory
+   * @param targetDir Target checkpoint directory
+   * @param keySchema Key schema for the state store
+   * @param valueSchema Value schema for the state store
+   * @param keyStateEncoderSpec Key state encoder spec
+   * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
+   */
+  private def performRoundTripTest(
+      sourceDir: String,
+      targetDir: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyStateEncoderSpec: KeyStateEncoderSpec,
+      storeName: Option[String] = None): Unit = {
+
+    // Step 1: Read original state using normal reader (for comparison later)
+    val sourceReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+    val sourceNormalData = (storeName match {
+      case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, 
name)
+      case None => sourceReader
+    }).load()
+      .selectExpr("key", "value", "partition_id")
+      .collect()
+
+    // Step 2: Read from source using AllColumnFamiliesReader (raw bytes)
+    val sourceBytesReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, sourceDir)
+      .option(StateSourceOptions.INTERNAL_ONLY_READ_ALL_COLUMN_FAMILIES, 
"true")
+    val sourceBytesData = (storeName match {
+      case Some(name) => 
sourceBytesReader.option(StateSourceOptions.STORE_NAME, name)
+      case None => sourceBytesReader
+    }).load()
+
+    // Verify schema of raw bytes
+    val schema = sourceBytesData.schema
+    assert(schema.fieldNames === Array(
+      "partition_key", "key_bytes", "value_bytes", "column_family_name"))
+
+    // Step 3: Write raw bytes to target checkpoint location
+    val hadoopConf = spark.sessionState.newHadoopConf()
+    val targetCpLocation = StreamingUtils.resolvedCheckpointLocation(
+      hadoopConf, targetDir)
+    val targetCheckpointMetadata = new StreamingQueryCheckpointMetadata(
+      spark, targetCpLocation)
+    val lastBatch = targetCheckpointMetadata.commitLog.getLatestBatchId().get
+    val targetOffsetSeq = targetCheckpointMetadata.offsetLog.get(lastBatch).get
+    targetCheckpointMetadata.offsetLog.add(lastBatch + 1, targetOffsetSeq)
+
+    // Create column family to schema map
+    val columnFamilyToSchemaMap = HashMap(
+      StateStore.DEFAULT_COL_FAMILY_NAME -> StateStoreColFamilySchema(
+        StateStore.DEFAULT_COL_FAMILY_NAME,
+        keySchemaId = 0,
+        keySchema,
+        valueSchemaId = 0,
+        valueSchema,
+        keyStateEncoderSpec = Some(keyStateEncoderSpec)
+      )
+    )
+
+    val storeConf: StateStoreConf = StateStoreConf(SQLConf.get)
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+
+    // Define the partition processing function
+    val putPartitionFunc: Iterator[Row] => Unit = partition => {
+      val allCFWriter = new StatePartitionAllColumnFamiliesWriter(
+        storeConf,
+        serializableHadoopConf.value,
+        TaskContext.getPartitionId(),
+        targetCpLocation,
+        0,
+        storeName.getOrElse(StateStoreId.DEFAULT_STORE_NAME),
+        lastBatch,
+        columnFamilyToSchemaMap
+      )
+      allCFWriter.put(partition)
+    }
+
+    // Write raw bytes to target using foreachPartition
+    sourceBytesData.foreachPartition(putPartitionFunc)
+
+    // Commit to commitLog
+    val latestCommit = targetCheckpointMetadata.commitLog.get(lastBatch).get
+    targetCheckpointMetadata.commitLog.add(lastBatch + 1, latestCommit)
+    val batchToCheck = lastBatch + 2
+    assert(!checkpointFileExists(new File(targetDir, "state/0/0"), 
batchToCheck, ".changelog"))
+    assert(checkpointFileExists(new File(targetDir, "state/0/0"), 
batchToCheck, ".zip"))
+
+    // Step 4: Read from target using normal reader
+    val targetReader = spark.read
+      .format("statestore")
+      .option(StateSourceOptions.PATH, targetDir)
+    val targetNormalData = (storeName match {
+      case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME, 
name)
+      case None => targetReader
+    }).load()
+      .selectExpr("key", "value", "partition_id")
+      .collect()
+
+    // Step 5: Verify data matches
+    assert(sourceNormalData.length == targetNormalData.length,
+      s"Row count mismatch: source=${sourceNormalData.length}, " +
+        s"target=${targetNormalData.length}")
+
+    // Sort and compare row by row
+    val sourceSorted = sourceNormalData.sortBy(_.toString)
+    val targetSorted = targetNormalData.sortBy(_.toString)
+
+    sourceSorted.zip(targetSorted).zipWithIndex.foreach {
+      case ((sourceRow, targetRow), idx) =>
+        assert(sourceRow == targetRow,
+          s"Row mismatch at index $idx:\n" +
+            s"  Source: $sourceRow\n" +
+            s"  Target: $targetRow")
+    }
+  }
+
+    /**
+     * Checks if a changelog file for the specified version exists in the 
given directory.
+     * A changelog file has the suffix ".changelog".
+     *
+     * @param dir Directory to search for changelog files
+     * @param version The version to check for existence
+     * @param suffix Either 'zip' or 'changelog'
+     * @return true if a changelog file with the given version exists, false 
otherwise
+     */
+    private def checkpointFileExists(dir: File, version: Long, suffix: 
String): Boolean = {
+      Option(dir.listFiles)
+        .getOrElse(Array.empty)
+        .filter { file =>
+          file.getName.endsWith(suffix) && !file.getName.startsWith(".")
+        }
+        .exists { file =>
+          val nameWithoutSuffix = file.getName.stripSuffix(suffix)
+          val parts = nameWithoutSuffix.split("_")
+          parts.headOption match {
+            case Some(ver) if ver.forall(_.isDigit) => ver.toLong == version
+            case _ => false
+          }
+        }
+    }
+
+  /**
+   * Helper method to test SPARK-54420 read and write with different state 
format versions
+   * for simple aggregation (single grouping key).
+   * @param stateVersion The state format version (1 or 2)
+   */
+  private def testRoundTripForAggrStateVersion(stateVersion: Int): Unit = {
+    withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> 
stateVersion.toString,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running a streaming aggregation
+          runLargeDataStreamingAggregationQuery(sourceDir.getAbsolutePath)
+          val inputData: MemoryStream[Int] = MemoryStream[Int]
+          val aggregated = getLargeDataStreamingAggregationQuery(inputData)
+
+          // add dummy data to target source to test writer won't load 
previous store
+          testStream(aggregated, OutputMode.Update)(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            // batch 0
+            AddData(inputData, 0 until 2: _*),
+            CheckLastBatch(
+              (0, 1, 0, 0, 0), // 0
+              (1, 1, 1, 1, 1) // 1
+            ),
+            // batch 1
+            AddData(inputData, 0 until 2: _*),
+            CheckLastBatch(
+              (0, 2, 0, 0, 0), // 0
+              (1, 2, 2, 1, 1) // 1
+            )
+          )
+
+          // Step 2: Define schemas based on state version
+          val keySchema = StructType(Array(
+            StructField("groupKey", IntegerType, nullable = false)))
+          val valueSchema = if (stateVersion == 1) {
+            // State version 1 includes key columns in the value
+            StructType(Array(
+              StructField("groupKey", IntegerType, nullable = false),
+              StructField("count", LongType, nullable = false),
+              StructField("sum", LongType, nullable = false),
+              StructField("max", IntegerType, nullable = false),
+              StructField("min", IntegerType, nullable = false)
+            ))
+          } else {
+            // State version 2 excludes key columns from the value
+            StructType(Array(
+              StructField("count", LongType, nullable = false),
+              StructField("sum", LongType, nullable = false),
+              StructField("max", IntegerType, nullable = false),
+              StructField("min", IntegerType, nullable = false)
+            ))
+          }
+
+          // Create key state encoder spec (no prefix key for simple 
aggregation)
+          val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+  }
+
+  /**
+   * Helper method to test SPARK-54420 read and write with different state 
format versions
+   * for composite key aggregation (multiple grouping keys).
+   * @param stateVersion The state format version (1 or 2)
+   */
+  private def testCompositeKeyRoundTripForStateVersion(stateVersion: Int): 
Unit = {
+    withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> 
stateVersion.toString,
+      SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running a composite key streaming 
aggregation
+          runCompositeKeyStreamingAggregationQuery(sourceDir.getAbsolutePath)
+          val inputData: MemoryStream[Int] = MemoryStream[Int]
+          val aggregated = getCompositeKeyStreamingAggregationQuery(inputData)
+
+          // add dummy data to target source to test writer won't load 
previous store
+          testStream(aggregated, OutputMode.Update)(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            // batch 0
+            AddData(inputData, 0, 1),
+            CheckLastBatch(
+              (0, "Apple", 1, 0, 0, 0),
+              (1, "Banana", 1, 1, 1, 1)
+            )
+          )
+
+          // Step 2: Define schemas based on state version for composite key
+          val keySchema = StructType(Array(
+            StructField("groupKey", IntegerType, nullable = false),
+            StructField("fruit", org.apache.spark.sql.types.StringType, 
nullable = true)
+          ))
+          val valueSchema = if (stateVersion == 1) {
+            // State version 1 includes key columns in the value
+            StructType(Array(
+              StructField("groupKey", IntegerType, nullable = false),
+              StructField("fruit", org.apache.spark.sql.types.StringType, 
nullable = true),
+              StructField("count", LongType, nullable = false),
+              StructField("sum", LongType, nullable = false),
+              StructField("max", IntegerType, nullable = false),
+              StructField("min", IntegerType, nullable = false)
+            ))
+          } else {
+            // State version 2 excludes key columns from the value
+            StructType(Array(
+              StructField("count", LongType, nullable = false),
+              StructField("sum", LongType, nullable = false),
+              StructField("max", IntegerType, nullable = false),
+              StructField("min", IntegerType, nullable = false)
+            ))
+          }
+
+          // Create key state encoder spec (no prefix key for composite key 
aggregation)
+          val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+  }
+
+  /**
+   * Helper method to test round-trip for stream-stream join with different 
versions.
+   */
+  private def testStreamStreamJoinRoundTrip(stateVersion: Int): Unit = {
+    withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> 
stateVersion.toString) {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running stream-stream join
+          runStreamStreamJoinQuery(sourceDir.getAbsolutePath)
+
+          // Create dummy data in target
+          val inputData: MemoryStream[(Int, Long)] = MemoryStream[(Int, Long)]
+          val query = getStreamStreamJoinQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            AddData(inputData, (1, 1L)),
+            CheckNewAnswer()
+          )
+
+          // Step 2: Test all 4 state stores created by stream-stream join
+          // Test keyToNumValues stores (both left and right)
+          Seq("left-keyToNumValues", "right-keyToNumValues").foreach { 
storeName =>
+            val keySchema = StructType(Array(
+              StructField("key", IntegerType)
+            ))
+            val valueSchema = StructType(Array(
+              StructField("value", LongType)
+            ))
+            val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+            // Perform round-trip test using common helper
+            performRoundTripTest(
+              sourceDir.getAbsolutePath,
+              targetDir.getAbsolutePath,
+              keySchema,
+              valueSchema,
+              keyStateEncoderSpec,
+              storeName = Some(storeName)
+            )
+          }
+
+          // Test keyWithIndexToValue stores (both left and right)
+          Seq("left-keyWithIndexToValue", "right-keyWithIndexToValue").foreach 
{ storeName =>
+            val keySchema = StructType(Array(
+              StructField("key", IntegerType, nullable = false),
+              StructField("index", LongType)
+            ))
+            val valueSchema = if (stateVersion == 2) {
+              StructType(Array(
+                StructField("value", IntegerType, nullable = false),
+                StructField("time", TimestampType, nullable = false),
+                StructField("matched", BooleanType)
+              ))
+            } else {
+              StructType(Array(
+                StructField("value", IntegerType, nullable = false),
+                StructField("time", TimestampType, nullable = false)
+              ))
+            }
+            val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+            // Perform round-trip test using common helper
+            performRoundTripTest(
+              sourceDir.getAbsolutePath,
+              targetDir.getAbsolutePath,
+              keySchema,
+              valueSchema,
+              keyStateEncoderSpec,
+              storeName = Some(storeName)
+            )
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Helper method to test round-trip for flatMapGroupsWithState with 
different versions.
+   */
+  private def testFlatMapGroupsWithStateRoundTrip(stateVersion: Int): Unit = {
+    // Skip this test on big endian platforms (version 1 only)
+    if (stateVersion == 1) {
+      
assume(java.nio.ByteOrder.nativeOrder().equals(java.nio.ByteOrder.LITTLE_ENDIAN))
+    }
+
+    withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> 
stateVersion.toString) {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running flatMapGroupsWithState
+          runFlatMapGroupsWithStateQuery(sourceDir.getAbsolutePath)
+
+          // Create dummy data in target
+          val clock = new StreamManualClock
+          val inputData: MemoryStream[(String, Long)] = MemoryStream[(String, 
Long)]
+          val query = getFlatMapGroupsWithStateQuery(inputData)
+          testStream(query, OutputMode.Update)(
+            StartStream(Trigger.ProcessingTime("1 second"), triggerClock = 
clock,
+              checkpointLocation = targetDir.getAbsolutePath),
+            AddData(inputData, ("a", 1L)),
+            AdvanceManualClock(1 * 1000),
+            CheckLastBatch(("a", 1, 0, false))
+          )
+
+          // Step 2: Define schemas for flatMapGroupsWithState
+          val keySchema = StructType(Array(
+            StructField("value", org.apache.spark.sql.types.StringType, 
nullable = true)
+          ))
+          val valueSchema = if (stateVersion == 1) {
+            StructType(Array(
+              StructField("numEvents", IntegerType, nullable = false),
+              StructField("startTimestampMs", LongType, nullable = false),
+              StructField("endTimestampMs", LongType, nullable = false),
+              StructField("timeoutTimestamp", IntegerType, nullable = false)
+            ))
+          } else {
+            StructType(Array(
+              StructField("groupState", 
org.apache.spark.sql.types.StructType(Array(
+                StructField("numEvents", IntegerType, nullable = false),
+                StructField("startTimestampMs", LongType, nullable = false),
+                StructField("endTimestampMs", LongType, nullable = false)
+              )), nullable = false),
+              StructField("timeoutTimestamp", LongType, nullable = false)
+            ))
+          }
+          val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+  }
+
+  // Run all tests with both changelog checkpointing enabled and disabled
+  Seq(true, false).foreach { changelogCheckpointingEnabled =>
+    val testSuffix = if (changelogCheckpointingEnabled) {
+      "with changelog checkpointing"
+    } else {
+      "without changelog checkpointing"
+    }
+
+    def testWithChangelogConfig(testName: String)(testFun: => Unit): Unit = {
+      test(s"$testName ($testSuffix)") {
+        withSQLConf(
+          
"spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled" ->
+            changelogCheckpointingEnabled.toString) {
+          testFun
+        }
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54420: aggregation state ver 1") {
+      testRoundTripForAggrStateVersion(1)
+    }
+
+    testWithChangelogConfig("SPARK-54420: aggregation state ver 2") {
+      testRoundTripForAggrStateVersion(2)
+    }
+
+    testWithChangelogConfig("SPARK-54420: composite key aggregation state ver 
1") {
+      testCompositeKeyRoundTripForStateVersion(1)
+    }
+
+    testWithChangelogConfig("SPARK-54420: composite key aggregation state ver 
2") {
+      testCompositeKeyRoundTripForStateVersion(2)
+    }
+
+    testWithChangelogConfig("SPARK-54420: dropDuplicatesWithinWatermark") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running dropDuplicatesWithinWatermark
+          runDropDuplicatesWithinWatermarkQuery(sourceDir.getAbsolutePath)
+
+          // Create dummy data in target
+          val inputData: MemoryStream[(String, Int)] = MemoryStream[(String, 
Int)]
+          val deduped = getDropDuplicatesWithinWatermarkQuery(inputData)
+          testStream(deduped, OutputMode.Append)(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            AddData(inputData, ("a", 1)),
+            CheckAnswer(("a", 1))
+          )
+
+          // Step 2: Define schemas for dropDuplicatesWithinWatermark
+          val keySchema = StructType(Array(
+            StructField("_1", org.apache.spark.sql.types.StringType, nullable 
= true)
+          ))
+          val valueSchema = StructType(Array(
+            StructField("expiresAtMicros", LongType, nullable = false)
+          ))
+          val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54420: dropDuplicates with column 
specified") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running dropDuplicates with column
+          runDropDuplicatesQueryWithColumnSpecified(sourceDir.getAbsolutePath)
+
+          // Create dummy data in target
+          val inputData: MemoryStream[(String, Int)] = MemoryStream[(String, 
Int)]
+          val deduped = getDropDuplicatesQueryWithColumnSpecified(inputData)
+          testStream(deduped, OutputMode.Append)(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            AddData(inputData, ("a", 1)),
+            CheckAnswer(("a", 1))
+          )
+
+          // Step 2: Define schemas for dropDuplicates with column specified
+          val keySchema = StructType(Array(
+            StructField("col1", org.apache.spark.sql.types.StringType, 
nullable = true)
+          ))
+          val valueSchema = StructType(Array(
+            StructField("__dummy__", NullType, nullable = true)
+          ))
+          val keyStateEncoderSpec = NoPrefixKeyStateEncoderSpec(keySchema)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54420: session window aggregation") {
+      withTempDir { sourceDir =>
+        withTempDir { targetDir =>
+          // Step 1: Create state by running session window aggregation
+          runSessionWindowAggregationQuery(sourceDir.getAbsolutePath)
+
+          // Create dummy data in target
+          val inputData: MemoryStream[(String, Long)] = MemoryStream[(String, 
Long)]
+          val aggregated = getSessionWindowAggregationQuery(inputData)
+          testStream(aggregated, OutputMode.Complete())(
+            StartStream(checkpointLocation = targetDir.getAbsolutePath),
+            AddData(inputData, ("a", 40L)),
+            CheckNewAnswer(
+              ("a", 40, 50, 10, 1)
+            ),
+            StopStream
+          )
+
+          // Step 2: Define schemas for session window aggregation
+          val keySchema = StructType(Array(
+            StructField("sessionId", org.apache.spark.sql.types.StringType, 
nullable = false),
+            StructField("sessionStartTime",
+              org.apache.spark.sql.types.TimestampType, nullable = false)
+          ))
+          val valueSchema = StructType(Array(
+            StructField("session_window", 
org.apache.spark.sql.types.StructType(Array(
+              StructField("start", org.apache.spark.sql.types.TimestampType),
+              StructField("end", org.apache.spark.sql.types.TimestampType)
+            )), nullable = false),
+            StructField("sessionId", org.apache.spark.sql.types.StringType, 
nullable = false),
+            StructField("count", LongType, nullable = false)
+          ))
+          // Session window aggregation uses prefix key scanning where 
sessionId is the prefix
+          val keyStateEncoderSpec = PrefixKeyScanStateEncoderSpec(keySchema, 1)
+
+          // Perform round-trip test using common helper
+          performRoundTripTest(
+            sourceDir.getAbsolutePath,
+            targetDir.getAbsolutePath,
+            keySchema,
+            valueSchema,
+            keyStateEncoderSpec
+          )
+        }
+      }
+    }
+
+    testWithChangelogConfig("SPARK-54420: dropDuplicates") {
+      withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> 
"2",

Review Comment:
   Why set the aggregation state format version (STREAMING_AGGREGATION_STATE_FORMAT_VERSION) for the dropDuplicates test?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)

Review Comment:
   We shouldn't set `forceSnapshot` here.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("c", "3")
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("c")) === "3")
+        val (version3, _) = db.commit(forceSnapshot = true)
+        assert(version3 === version2 + 1)
+      }
+
+      // Verify we can reload the committed version
+      withDB(remoteDir, version = 3,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        assert(toStr(db.get("c")) === "3")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+      }
+  }
+
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty at version 0") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Create empty store at version 0
+        val ckptId = if (enableStateStoreCheckpointIds) {
+          Some(java.util.UUID.randomUUID.toString)
+        } else {
+          None
+        }
+
+        db.load(0, ckptId, createEmpty = true)

Review Comment:
   No need for a separate test case for this; it can be folded into your test case 
above. This suite is already very large.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("c", "3")
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("c")) === "3")
+        val (version3, _) = db.commit(forceSnapshot = true)
+        assert(version3 === version2 + 1)
+      }
+
+      // Verify we can reload the committed version
+      withDB(remoteDir, version = 3,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        assert(toStr(db.get("c")) === "3")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))

Review Comment:
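   Is this correct? Version 3 was committed after the `createEmpty = true` load, and only ("b", "2") and ("c", "3") were written afterwards, so I would expect those two entries here rather than ("a", "1").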



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala:
##########
@@ -3942,6 +3942,85 @@ class RocksDBSuite extends AlsoTestWithRocksDBFeatures 
with SharedSparkSession
     }}
   }
 
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty creates empty store at specified 
version") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Put initial data first
+        val version = 0
+        db.load(version, versionToUniqueId.get(0))
+        db.put("a", "1")
+        val (version1, _) = db.commit()
+        assert(db.get("a") === "1")
+
+        db.load(version1, versionToUniqueId.get(1), createEmpty = true)
+
+        // Add data and commit - should produce version 11
+        db.put("b", "2")
+        val (version2, _) = db.commit(forceSnapshot = true)
+        assert(version2 === version1 + 1)
+        assert(toStr(db.get("b")) === "2")
+        assert(db.get("a") === null)
+        assert(iterator(db).isEmpty)
+
+        db.put("c", "3")
+        assert(toStr(db.get("b")) === "2")
+        assert(toStr(db.get("c")) === "3")
+        val (version3, _) = db.commit(forceSnapshot = true)
+        assert(version3 === version2 + 1)
+      }
+
+      // Verify we can reload the committed version
+      withDB(remoteDir, version = 3,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        assert(toStr(db.get("c")) === "3")
+        assert(db.iterator().map(toStr).toSet === Set(("a", "1")))
+      }
+  }
+
+  testWithStateStoreCheckpointIdsAndChangelogEnabled(
+    "SPARK-54420: load with createEmpty at version 0") {
+    enableStateStoreCheckpointIds =>
+      val remoteDir = Utils.createTempDir().toString
+      new File(remoteDir).delete()
+      val versionToUniqueId = new mutable.HashMap[Long, String]()
+
+      withDB(remoteDir,
+        enableStateStoreCheckpointIds = enableStateStoreCheckpointIds,
+        versionToUniqueId = versionToUniqueId) { db =>
+        // Create empty store at version 0
+        val ckptId = if (enableStateStoreCheckpointIds) {

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionWriter.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.util.UUID
+
+import scala.collection.MapView
+import scala.collection.immutable.HashMap
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import 
org.apache.spark.sql.execution.streaming.runtime.StreamingCheckpointConstants.DIR_NAME_STATE
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProvider, 
StateStoreProviderId}
+
+/**
+ * A writer that can directly write binary data to the streaming state store.
+ *
+ * This writer expects input rows with the same schema produced by
+ * StatePartitionAllColumnFamiliesReader:
+ *   (partition_key, key_bytes, value_bytes, column_family_name)
+ *
+ * The writer creates a fresh (empty) state store instance for the target 
commit version
+ * instead of loading previous partition data. After writing all rows for the 
partition, it will
+ * commit all changes as a snapshot
+ */
+class StatePartitionAllColumnFamiliesWriter(
+     storeConf: StateStoreConf,
+     hadoopConf: Configuration,
+     partitionId: Int,
+     targetCpLocation: String,
+     operatorId: Int,
+     storeName: String,
+     batchId: Long,
+     columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+  private val defaultSchema = {
+    columnFamilyToSchemaMap.getOrElse(
+      StateStore.DEFAULT_COL_FAMILY_NAME,
+      throw new IllegalArgumentException(
+        s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in 
schema map")
+    )
+  }
+
+  private val columnFamilyToKeySchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.keySchema.length)
+  private val columnFamilyToValueSchemaLenMap: MapView[String, Int] =
+    columnFamilyToSchemaMap.view.mapValues(_.valueSchema.length)
+
+  private val rowConverter = {
+    val schema = 
SchemaUtil.getScanAllColumnFamiliesSchema(defaultSchema.keySchema)
+    CatalystTypeConverters.createToCatalystConverter(schema)
+  }
+
+  protected lazy val provider: StateStoreProvider = {
+    val stateCheckpointLocation = new Path(targetCpLocation, 
DIR_NAME_STATE).toString
+    val stateStoreId = StateStoreId(stateCheckpointLocation,
+      operatorId, partitionId, storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
UUID.randomUUID())
+
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, defaultSchema.keySchema, defaultSchema.valueSchema,
+      defaultSchema.keyStateEncoderSpec.get,
+      useColumnFamilies = false, storeConf, hadoopConf,
+      useMultipleValuesPerKey = false, stateSchemaProvider = None)
+    provider
+  }
+
+  private lazy val stateStore: StateStore = {
+    // TODO[SPARK-54590]: Support checkpoint V2 in 
StatePartitionAllColumnFamiliesWriter
+    // Create empty store to avoid loading old partition data since we are 
rewriting the
+    // store e.g. during repartitioning
+    // Use loadEmpty=true to create a fresh state store without loading 
previous versions
+    // We create the empty store AT version, and the next commit will
+    // produce version + 1
+    provider.getStore(
+      batchId + 1,
+      stateStoreCkptId = None,
+      loadEmpty = true
+    )
+  }
+
+  // The function that writes and commits data to state store. It takes in 
rows with schema
+  // - partition_key, StructType
+  // - key_bytes, BinaryType
+  // - value_bytes, BinaryType
+  // - column_family_name, StringType
+  def put(rows: Iterator[Row]): Unit = {

Review Comment:
   nit: call it `write` instead, since it may also perform non-put 
operations; that avoids confusion.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to