This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e64f620fe8fd [SPARK-48796][SS] Load Column Family Id from 
RocksDBCheckpointMetadata for VCF when restarting
e64f620fe8fd is described below

commit e64f620fe8fd2e049d0c0ce2105a8417264d3022
Author: Eric Marnadi <[email protected]>
AuthorDate: Tue Aug 20 11:38:40 2024 +0900

    [SPARK-48796][SS] Load Column Family Id from RocksDBCheckpointMetadata for 
VCF when restarting
    
    ### What changes were proposed in this pull request?
    
    Persisting the mapping between columnFamilyName and columnFamilyId in 
RocksDBCheckpointMetadata and RocksDBSnapshot. RocksDB will maintain 
internal metadata for this mapping, and set this info on load. 
RocksDBStateStoreProvider can call columnFamily operations as usual, and 
RocksDB.scala will translate the name to the virtual column family ID.
    
    ### Why are the changes needed?
    
    To enable the use of virtual column families, and the performance benefits 
that come with them, for the TransformWithState operator
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Amended unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47778 from ericm-db/vcf-integration-state-store.
    
    Authored-by: Eric Marnadi <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../sql/execution/streaming/state/RocksDB.scala    | 138 ++++++++++++--
 .../streaming/state/RocksDBFileManager.scala       |  87 ++++++++-
 .../state/RocksDBStateStoreProvider.scala          | 199 ++++++++-------------
 .../sql/execution/streaming/state/StateStore.scala |   2 +
 .../streaming/state/RocksDBStateStoreSuite.scala   |  64 ++++++-
 .../sql/streaming/TransformWithStateSuite.scala    |  81 +++++++++
 6 files changed, 427 insertions(+), 144 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index b454e0ba5c93..d743e581df0f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
 import java.util.Locale
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.{mutable, Map}
 import scala.collection.mutable.ListBuffer
+import scala.jdk.CollectionConverters.{ConcurrentMapHasAsScala, MapHasAsJava}
 import scala.ref.WeakReference
 import scala.util.Try
 
@@ -76,7 +78,9 @@ class RocksDB(
       checkpointDir: File,
       version: Long,
       numKeys: Long,
-      capturedFileMappings: RocksDBFileMappings) {
+      capturedFileMappings: RocksDBFileMappings,
+      columnFamilyMapping: Map[String, Short],
+      maxColumnFamilyId: Short) {
     def close(): Unit = {
       silentDeleteRecursively(checkpointDir, s"Free up local checkpoint of 
snapshot $version")
     }
@@ -166,6 +170,87 @@ class RocksDB(
   @GuardedBy("acquireLock")
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
+  // This is accessed and updated only between load and commit
+  // which means it is implicitly guarded by acquireLock
+  @GuardedBy("acquireLock")
+  private val colFamilyNameToIdMap = new ConcurrentHashMap[String, Short]()
+
+  @GuardedBy("acquireLock")
+  private val maxColumnFamilyId: AtomicInteger = new AtomicInteger(-1)
+
+  @GuardedBy("acquireLock")
+  private val shouldForceSnapshot: AtomicBoolean = new AtomicBoolean(false)
+
+  /**
+   * Check whether the column family name is for internal column families.
+   *
+   * @param cfName - column family name
+   * @return - true if the column family is for internal use, false otherwise
+   */
+  private def checkInternalColumnFamilies(cfName: String): Boolean = 
cfName.charAt(0) == '_'
+
+  // Methods to fetch column family mapping for this State Store version
+  def getColumnFamilyMapping: Map[String, Short] = {
+    colFamilyNameToIdMap.asScala
+  }
+
+  def getColumnFamilyId(cfName: String): Short = {
+    colFamilyNameToIdMap.get(cfName)
+  }
+
+  /**
+   * Create RocksDB column family, if not created already
+   */
+  def createColFamilyIfAbsent(colFamilyName: String): Short = {
+    if (!checkColFamilyExists(colFamilyName)) {
+      val newColumnFamilyId = maxColumnFamilyId.incrementAndGet().toShort
+      colFamilyNameToIdMap.putIfAbsent(colFamilyName, newColumnFamilyId)
+      shouldForceSnapshot.set(true)
+      newColumnFamilyId
+    } else {
+      colFamilyNameToIdMap.get(colFamilyName)
+    }
+  }
+
+  /**
+   * Remove RocksDB column family, if exists
+   */
+  def removeColFamilyIfExists(colFamilyName: String): Boolean = {
+    if (checkColFamilyExists(colFamilyName)) {
+      colFamilyNameToIdMap.remove(colFamilyName)
+      shouldForceSnapshot.set(true)
+      true
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Function to check if the column family exists in the state store instance.
+   *
+   * @param colFamilyName - name of the column family
+   * @return - true if the column family exists, false otherwise
+   */
+  def checkColFamilyExists(colFamilyName: String): Boolean = {
+    colFamilyNameToIdMap.containsKey(colFamilyName)
+  }
+
+  // This method sets the internal column family metadata to
+  // the default values it should be set to on load
+  private def setInitialCFInfo(): Unit = {
+    colFamilyNameToIdMap.clear()
+    shouldForceSnapshot.set(false)
+    maxColumnFamilyId.set(0)
+  }
+
+  def getColFamilyCount(isInternal: Boolean): Long = {
+    if (isInternal) {
+      
colFamilyNameToIdMap.asScala.keys.toSeq.count(checkInternalColumnFamilies)
+    } else {
+      
colFamilyNameToIdMap.asScala.keys.toSeq.count(!checkInternalColumnFamilies(_))
+    }
+  }
+
   /**
    * Load the given version of data in a native RocksDB instance.
    * Note that this will copy all the necessary file from DFS to local disk as 
needed,
@@ -188,6 +273,14 @@ class RocksDB(
         // Initialize maxVersion upon successful load from DFS
         fileManager.setMaxSeenVersion(version)
 
+        setInitialCFInfo()
+        metadata.columnFamilyMapping.foreach { mapping =>
+          colFamilyNameToIdMap.putAll(mapping.asJava)
+        }
+
+        metadata.maxColumnFamilyId.foreach { maxId =>
+          maxColumnFamilyId.set(maxId)
+        }
         // reset last snapshot version
         if (lastSnapshotVersion > latestSnapshotVersion) {
           // discard any newer snapshots
@@ -496,7 +589,7 @@ class RocksDB(
       var compactTimeMs = 0L
       var flushTimeMs = 0L
       var checkpointTimeMs = 0L
-      if (shouldCreateSnapshot()) {
+      if (shouldCreateSnapshot() || shouldForceSnapshot.get()) {
         // Need to flush the change to disk before creating a checkpoint
         // because rocksdb wal is disabled.
         logInfo(log"Flushing updates for ${MDC(LogKeys.VERSION_NUM, 
newVersion)}")
@@ -535,7 +628,9 @@ class RocksDB(
               RocksDBSnapshot(checkpointDir,
                 newVersion,
                 numKeysOnWritingVersion,
-                fileManager.captureFileMapReference()))
+                fileManager.captureFileMapReference(),
+                colFamilyNameToIdMap.asScala.toMap,
+                maxColumnFamilyId.get().toShort))
             lastSnapshotVersion = newVersion
           }
         }
@@ -544,11 +639,20 @@ class RocksDB(
       logInfo(log"Syncing checkpoint for ${MDC(LogKeys.VERSION_NUM, 
newVersion)} to DFS")
       val fileSyncTimeMs = timeTakenMs {
         if (enableChangelogCheckpointing) {
-          try {
-            assert(changelogWriter.isDefined)
-            changelogWriter.foreach(_.commit())
-          } finally {
+          // If we have changed the columnFamilyId mapping, we have set a new
+          // snapshot and need to upload this to the DFS even if changelog 
checkpointing
+          // is enabled.
+          if (shouldForceSnapshot.get()) {
+            uploadSnapshot()
             changelogWriter = None
+            changelogWriter.foreach(_.abort())
+          } else {
+            try {
+              assert(changelogWriter.isDefined)
+              changelogWriter.foreach(_.commit())
+            } finally {
+              changelogWriter = None
+            }
           }
         } else {
           assert(changelogWriter.isEmpty)
@@ -606,10 +710,24 @@ class RocksDB(
       checkpoint
     }
     localCheckpoint match {
-      case Some(RocksDBSnapshot(localDir, version, numKeys, 
capturedFileMappings)) =>
+      case Some(
+        RocksDBSnapshot(
+          localDir,
+          version,
+          numKeys,
+          capturedFileMappings,
+          columnFamilyMapping,
+          maxColumnFamilyId)) =>
         try {
           val uploadTime = timeTakenMs {
-            fileManager.saveCheckpointToDfs(localDir, version, numKeys, 
capturedFileMappings)
+            fileManager.saveCheckpointToDfs(
+              localDir,
+              version,
+              numKeys,
+              capturedFileMappings,
+              Some(columnFamilyMapping.toMap),
+              Some(maxColumnFamilyId)
+            )
             fileManagerMetrics = fileManager.latestSaveCheckpointMetrics
           }
           logInfo(log"${MDC(LogKeys.LOG_ID, loggingId)}: Upload snapshot of 
version " +
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 0c673047db62..350a5797978b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -249,12 +249,15 @@ class RocksDBFileManager(
       checkpointDir: File,
       version: Long,
       numKeys: Long,
-      capturedFileMappings: RocksDBFileMappings): Unit = {
+      capturedFileMappings: RocksDBFileMappings,
+      columnFamilyMapping: Option[Map[String, Short]] = None,
+      maxColumnFamilyId: Option[Short] = None): Unit = {
     logFilesInDir(checkpointDir, log"Saving checkpoint files " +
       log"for version ${MDC(LogKeys.VERSION_NUM, version)}")
     val (localImmutableFiles, localOtherFiles) = 
listRocksDBFiles(checkpointDir)
     val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles, 
capturedFileMappings)
-    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys)
+    val metadata = RocksDBCheckpointMetadata(
+      rocksDBFiles, numKeys, columnFamilyMapping, maxColumnFamilyId)
     val metadataFile = localMetadataFile(checkpointDir)
     metadata.writeToFile(metadataFile)
     logInfo(log"Written metadata for version ${MDC(LogKeys.VERSION_NUM, 
version)}:\n" +
@@ -889,11 +892,17 @@ object RocksDBFileManagerMetrics {
 case class RocksDBCheckpointMetadata(
     sstFiles: Seq[RocksDBSstFile],
     logFiles: Seq[RocksDBLogFile],
-    numKeys: Long) {
+    numKeys: Long,
+    columnFamilyMapping: Option[Map[String, Short]] = None,
+    maxColumnFamilyId: Option[Short] = None) {
+
+  require(columnFamilyMapping.isDefined == maxColumnFamilyId.isDefined,
+    "columnFamilyMapping and maxColumnFamilyId must either both be defined or 
both be None")
+
   import RocksDBCheckpointMetadata._
 
   def json: String = {
-    // We turn this field into a null to avoid write a empty logFiles field in 
the json.
+    // We turn this field into a null to avoid write an empty logFiles field 
in the json.
     val nullified = if (logFiles.isEmpty) this.copy(logFiles = null) else this
     mapper.writeValueAsString(nullified)
   }
@@ -941,11 +950,73 @@ object RocksDBCheckpointMetadata {
     }
   }
 
-  def apply(rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long): 
RocksDBCheckpointMetadata = {
-    val sstFiles = rocksDBFiles.collect { case file: RocksDBSstFile => file }
-    val logFiles = rocksDBFiles.collect { case file: RocksDBLogFile => file }
+  // Apply method for cases without column family information
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long): RocksDBCheckpointMetadata = {
+    val (sstFiles, logFiles) = 
rocksDBFiles.partition(_.isInstanceOf[RocksDBSstFile])
+    new RocksDBCheckpointMetadata(
+      sstFiles.map(_.asInstanceOf[RocksDBSstFile]),
+      logFiles.map(_.asInstanceOf[RocksDBLogFile]),
+      numKeys,
+      None,
+      None
+    )
+  }
+
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long,
+      columnFamilyMapping: Option[Map[String, Short]],
+      maxColumnFamilyId: Option[Short]): RocksDBCheckpointMetadata = {
+    val (sstFiles, logFiles) = 
rocksDBFiles.partition(_.isInstanceOf[RocksDBSstFile])
+    new RocksDBCheckpointMetadata(
+      sstFiles.map(_.asInstanceOf[RocksDBSstFile]),
+      logFiles.map(_.asInstanceOf[RocksDBLogFile]),
+      numKeys,
+      columnFamilyMapping,
+      maxColumnFamilyId
+    )
+  }
+
+  // Apply method for cases with separate sstFiles and logFiles, without 
column family information
+  def apply(
+      sstFiles: Seq[RocksDBSstFile],
+      logFiles: Seq[RocksDBLogFile],
+      numKeys: Long): RocksDBCheckpointMetadata = {
+    new RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, None, None)
+  }
+
+  // Apply method for cases with column family information
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long,
+      columnFamilyMapping: Map[String, Short],
+      maxColumnFamilyId: Short): RocksDBCheckpointMetadata = {
+    val (sstFiles, logFiles) = 
rocksDBFiles.partition(_.isInstanceOf[RocksDBSstFile])
+    new RocksDBCheckpointMetadata(
+      sstFiles.map(_.asInstanceOf[RocksDBSstFile]),
+      logFiles.map(_.asInstanceOf[RocksDBLogFile]),
+      numKeys,
+      Some(columnFamilyMapping),
+      Some(maxColumnFamilyId)
+    )
+  }
 
-    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys)
+  // Apply method for cases with separate sstFiles and logFiles, and column 
family information
+  def apply(
+      sstFiles: Seq[RocksDBSstFile],
+      logFiles: Seq[RocksDBLogFile],
+      numKeys: Long,
+      columnFamilyMapping: Map[String, Short],
+      maxColumnFamilyId: Short): RocksDBCheckpointMetadata = {
+    new RocksDBCheckpointMetadata(
+      sstFiles,
+      logFiles,
+      numKeys,
+      Some(columnFamilyMapping),
+      Some(maxColumnFamilyId)
+    )
   }
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 0073e22d4956..075ab7d00842 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -19,9 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io._
 import java.util.concurrent.ConcurrentHashMap
-import java.util.concurrent.atomic.AtomicInteger
 
-import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
@@ -56,6 +54,17 @@ private[sql] class RocksDBStateStoreProvider
 
     override def version: Long = lastVersion
 
+    // Test-visible methods to fetch column family mapping for this State 
Store version
+    // Because column families are only enabled for RocksDBStateStore, these 
methods
+    // are no-ops everywhere else.
+    private[sql] def getColumnFamilyMapping: Map[String, Short] = {
+      rocksDB.getColumnFamilyMapping.toMap
+    }
+
+    private[sql] def getColumnFamilyId(cfName: String): Short = {
+      rocksDB.getColumnFamilyId(cfName)
+    }
+
     override def createColFamilyIfAbsent(
         colFamilyName: String,
         keySchema: StructType,
@@ -63,16 +72,17 @@ private[sql] class RocksDBStateStoreProvider
         keyStateEncoderSpec: KeyStateEncoderSpec,
         useMultipleValuesPerKey: Boolean = false,
         isInternal: Boolean = false): Unit = {
-      val newColFamilyId = 
ColumnFamilyUtils.createColFamilyIfAbsent(colFamilyName, isInternal)
-
+      verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, 
isInternal)
+      val newColFamilyId = rocksDB.createColFamilyIfAbsent(colFamilyName)
       keyValueEncoderMap.putIfAbsent(colFamilyName,
-        (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, 
useColumnFamilies, newColFamilyId),
-         RocksDBStateEncoder.getValueEncoder(valueSchema, 
useMultipleValuesPerKey)))
+        (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec, 
useColumnFamilies,
+          Some(newColFamilyId)), 
RocksDBStateEncoder.getValueEncoder(valueSchema,
+          useMultipleValuesPerKey)))
     }
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
       verify(key != null, "Key cannot be null")
-      ColumnFamilyUtils.verifyColFamilyOperations("get", colFamilyName)
+      verifyColFamilyOperations("get", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       val value =
@@ -98,7 +108,7 @@ private[sql] class RocksDBStateStoreProvider
      */
     override def valuesIterator(key: UnsafeRow, colFamilyName: String): 
Iterator[UnsafeRow] = {
       verify(key != null, "Key cannot be null")
-      ColumnFamilyUtils.verifyColFamilyOperations("valuesIterator", 
colFamilyName)
+      verifyColFamilyOperations("valuesIterator", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       val valueEncoder = kvEncoder._2
@@ -114,7 +124,7 @@ private[sql] class RocksDBStateStoreProvider
     override def merge(key: UnsafeRow, value: UnsafeRow,
         colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
       verify(state == UPDATING, "Cannot merge after already committed or 
aborted")
-      ColumnFamilyUtils.verifyColFamilyOperations("merge", colFamilyName)
+      verifyColFamilyOperations("merge", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       val keyEncoder = kvEncoder._1
@@ -131,7 +141,7 @@ private[sql] class RocksDBStateStoreProvider
       verify(state == UPDATING, "Cannot put after already committed or 
aborted")
       verify(key != null, "Key cannot be null")
       require(value != null, "Cannot put a null value")
-      ColumnFamilyUtils.verifyColFamilyOperations("put", colFamilyName)
+      verifyColFamilyOperations("put", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       rocksDB.put(kvEncoder._1.encodeKey(key), kvEncoder._2.encodeValue(value))
@@ -140,7 +150,7 @@ private[sql] class RocksDBStateStoreProvider
     override def remove(key: UnsafeRow, colFamilyName: String): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or 
aborted")
       verify(key != null, "Key cannot be null")
-      ColumnFamilyUtils.verifyColFamilyOperations("remove", colFamilyName)
+      verifyColFamilyOperations("remove", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       rocksDB.remove(kvEncoder._1.encodeKey(key))
@@ -150,7 +160,7 @@ private[sql] class RocksDBStateStoreProvider
       // Note this verify function only verify on the colFamilyName being 
valid,
       // we are actually doing prefix when useColumnFamilies,
       // but pass "iterator" to throw correct error message
-      ColumnFamilyUtils.verifyColFamilyOperations("iterator", colFamilyName)
+      verifyColFamilyOperations("iterator", colFamilyName)
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       val rowPair = new UnsafeRowPair()
 
@@ -184,7 +194,7 @@ private[sql] class RocksDBStateStoreProvider
 
     override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String):
       Iterator[UnsafeRowPair] = {
-      ColumnFamilyUtils.verifyColFamilyOperations("prefixScan", colFamilyName)
+      verifyColFamilyOperations("prefixScan", colFamilyName)
 
       val kvEncoder = keyValueEncoderMap.get(colFamilyName)
       require(kvEncoder._1.supportPrefixKeyScan,
@@ -248,13 +258,11 @@ private[sql] class RocksDBStateStoreProvider
 
         // Used for metrics reporting around internal/external column families
         def internalColFamilyCnt(): Long = {
-          colFamilyNameToIdMap.keys.asScala.toSeq
-            .filter(ColumnFamilyUtils.checkInternalColumnFamilies(_)).size
+          rocksDB.getColFamilyCount(isInternal = true)
         }
 
         def externalColFamilyCnt(): Long = {
-          colFamilyNameToIdMap.keys.asScala.toSeq
-            .filter(!ColumnFamilyUtils.checkInternalColumnFamilies(_)).size
+          rocksDB.getColFamilyCount(isInternal = false)
         }
 
         val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long](
@@ -309,10 +317,11 @@ private[sql] class RocksDBStateStoreProvider
 
     /** Remove column family if exists */
     override def removeColFamilyIfExists(colFamilyName: String): Boolean = {
+      verifyColFamilyCreationOrDeletion("remove_col_family", colFamilyName)
       verify(useColumnFamilies, "Column families are not supported in this 
store")
 
       val result = {
-        val colFamilyExists = 
ColumnFamilyUtils.removeColFamilyIfExists(colFamilyName)
+        val colFamilyExists = rocksDB.removeColFamilyIfExists(colFamilyName)
 
         if (colFamilyExists) {
           val colFamilyIdBytes =
@@ -328,6 +337,11 @@ private[sql] class RocksDBStateStoreProvider
     }
   }
 
+  // Test-visible method to fetch the internal RocksDBStateStore class
+  private[sql] def getRocksDBStateStore(version: Long): RocksDBStateStore = {
+    getStore(version).asInstanceOf[RocksDBStateStore]
+  }
+
   override def init(
       stateStoreId: StateStoreId,
       keySchema: StructType,
@@ -349,18 +363,17 @@ private[sql] class RocksDBStateStoreProvider
         " enabled in RocksDBStateStore.")
     }
 
+    rocksDB // lazy initialization
     var defaultColFamilyId: Option[Short] = None
+
     if (useColumnFamilies) {
-      // put default column family only if useColumnFamilies are enabled
-      colFamilyNameToIdMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME, 
colFamilyId.shortValue())
-      defaultColFamilyId = Option(colFamilyId.shortValue())
+      defaultColFamilyId = 
Some(rocksDB.createColFamilyIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME))
     }
+
     keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
       (RocksDBStateEncoder.getKeyEncoder(keyStateEncoderSpec,
         useColumnFamilies, defaultColFamilyId),
         RocksDBStateEncoder.getValueEncoder(valueSchema, 
useMultipleValuesPerKey)))
-
-    rocksDB // lazy initialization
   }
 
   override def stateStoreId: StateStoreId = stateStoreId_
@@ -447,9 +460,8 @@ private[sql] class RocksDBStateStoreProvider
   private val keyValueEncoderMap = new 
java.util.concurrent.ConcurrentHashMap[String,
     (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]
 
-  private val colFamilyNameToIdMap = new 
java.util.concurrent.ConcurrentHashMap[String, Short]
-  // TODO SPARK-48796 load column family id from state schema when restarting
-  private val colFamilyId = new AtomicInteger(0)
+  private val multiColFamiliesDisabledStr = "multiple column families is 
disabled in " +
+    "RocksDBStateStoreProvider"
 
   private def verify(condition: => Boolean, msg: String): Unit = {
     if (!condition) { throw new IllegalStateException(msg) }
@@ -498,116 +510,63 @@ private[sql] class RocksDBStateStoreProvider
   }
 
   /**
-   * Class for column family related utility functions.
-   * Verification functions for column family names, column family operation 
validations etc.
+   * Function to verify invariants for column family based operations
+   * such as get, put, remove etc.
+   *
+   * @param operationName - name of the store operation
+   * @param colFamilyName - name of the column family
    */
-  private object ColumnFamilyUtils {
-    private val multColFamiliesDisabledStr = "multiple column families is 
disabled in " +
-      "RocksDBStateStoreProvider"
-
-    /**
-     * Function to verify invariants for column family based operations
-     * such as get, put, remove etc.
-     *
-     * @param operationName - name of the store operation
-     * @param colFamilyName - name of the column family
-     */
-    def verifyColFamilyOperations(
-        operationName: String,
-        colFamilyName: String): Unit = {
-      if (colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME) {
-        // if the state store instance does not support multiple column 
families, throw an exception
-        if (!useColumnFamilies) {
-          throw StateStoreErrors.unsupportedOperationException(operationName,
-            multColFamiliesDisabledStr)
-        }
-
-        // if the column family name is empty or contains leading/trailing 
whitespaces, throw an
-        // exception
-        if (colFamilyName.isEmpty || colFamilyName.trim != colFamilyName) {
-          throw 
StateStoreErrors.cannotUseColumnFamilyWithInvalidName(operationName, 
colFamilyName)
-        }
-
-        // if the column family does not exist, throw an exception
-        if (!checkColFamilyExists(colFamilyName)) {
-          throw 
StateStoreErrors.unsupportedOperationOnMissingColumnFamily(operationName,
-            colFamilyName)
-        }
-      }
-    }
-
-    /**
-     * Function to verify invariants for column family creation or deletion 
operations.
-     *
-     * @param operationName - name of the store operation
-     * @param colFamilyName - name of the column family
-     */
-    private def verifyColFamilyCreationOrDeletion(
-        operationName: String,
-        colFamilyName: String,
-        isInternal: Boolean = false): Unit = {
+  private def verifyColFamilyOperations(
+      operationName: String,
+      colFamilyName: String): Unit = {
+    if (colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME) {
       // if the state store instance does not support multiple column 
families, throw an exception
       if (!useColumnFamilies) {
         throw StateStoreErrors.unsupportedOperationException(operationName,
-          multColFamiliesDisabledStr)
+          multiColFamiliesDisabledStr)
       }
 
-      // if the column family name is empty or contains leading/trailing 
whitespaces
-      // or using the reserved "default" column family, throw an exception
-      if (colFamilyName.isEmpty
-        || colFamilyName.trim != colFamilyName
-        || colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME) {
+      // if the column family name is empty or contains leading/trailing 
whitespaces, throw an
+      // exception
+      if (colFamilyName.isEmpty || colFamilyName.trim != colFamilyName) {
         throw 
StateStoreErrors.cannotUseColumnFamilyWithInvalidName(operationName, 
colFamilyName)
       }
 
-      // if the column family is not internal and uses reserved characters, 
throw an exception
-      if (!isInternal && colFamilyName.charAt(0) == '_') {
-        throw 
StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName)
+      // if the column family does not exist, throw an exception
+      if (!rocksDB.checkColFamilyExists(colFamilyName)) {
+        throw 
StateStoreErrors.unsupportedOperationOnMissingColumnFamily(operationName,
+          colFamilyName)
       }
     }
+  }
 
-    /**
-     * Check whether the column family name is for internal column families.
-     *
-     * @param cfName - column family name
-     * @return - true if the column family is for internal use, false otherwise
-     */
-    def checkInternalColumnFamilies(cfName: String): Boolean = 
cfName.charAt(0) == '_'
-
-    /**
-     * Create RocksDB column family, if not created already
-     */
-    def createColFamilyIfAbsent(colFamilyName: String, isInternal: Boolean = 
false):
-      Option[Short] = {
-      verifyColFamilyCreationOrDeletion("create_col_family", colFamilyName, 
isInternal)
-      if (!checkColFamilyExists(colFamilyName)) {
-        val newColumnFamilyId = colFamilyId.incrementAndGet().toShort
-        colFamilyNameToIdMap.putIfAbsent(colFamilyName, newColumnFamilyId)
-        Option(newColumnFamilyId)
-      } else None
+  /**
+   * Function to verify invariants for column family creation or deletion 
operations.
+   *
+   * @param operationName - name of the store operation
+   * @param colFamilyName - name of the column family
+   */
+  private def verifyColFamilyCreationOrDeletion(
+      operationName: String,
+      colFamilyName: String,
+      isInternal: Boolean = false): Unit = {
+    // if the state store instance does not support multiple column families, 
throw an exception
+    if (!useColumnFamilies) {
+      throw StateStoreErrors.unsupportedOperationException(operationName,
+        multiColFamiliesDisabledStr)
     }
 
-    /**
-     * Remove RocksDB column family, if exists
-     */
-    def removeColFamilyIfExists(colFamilyName: String): Boolean = {
-      verifyColFamilyCreationOrDeletion("remove_col_family", colFamilyName)
-      if (checkColFamilyExists(colFamilyName)) {
-        colFamilyNameToIdMap.remove(colFamilyName)
-        true
-      } else {
-        false
-      }
+    // if the column family name is empty or contains leading/trailing 
whitespaces
+    // or using the reserved "default" column family, throw an exception
+    if (colFamilyName.isEmpty
+      || (colFamilyName.trim != colFamilyName)
+      || (colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME && !isInternal)) 
{
+      throw 
StateStoreErrors.cannotUseColumnFamilyWithInvalidName(operationName, 
colFamilyName)
     }
 
-    /**
-     * Function to check if the column family exists in the state store 
instance.
-     *
-     * @param colFamilyName - name of the column family
-     * @return - true if the column family exists, false otherwise
-     */
-    def checkColFamilyExists(colFamilyName: String): Boolean = {
-      colFamilyNameToIdMap.containsKey(colFamilyName)
+    // if the column family is not internal and uses reserved characters, 
throw an exception
+    if (!isInternal && colFamilyName.charAt(0) == '_') {
+      throw 
StateStoreErrors.cannotCreateColumnFamilyWithReservedChars(colFamilyName)
     }
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index c9722c2676df..d55a973a14e1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -126,6 +126,8 @@ trait StateStore extends ReadStateStore {
 
   /**
    * Create column family with given name, if absent.
+   *
+   * @return column family ID
    */
   def createColFamilyIfAbsent(
       colFamilyName: String,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index f6fffc519d8c..0bd86068ca3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -1026,15 +1026,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
-  // TODO SPARK-48796 after restart state id will not be the same
-  ignore(s"get, put, iterator, commit, load with multiple column families") {
+  test(s"get, put, iterator, commit, load with multiple column families") {
     tryWithProviderResource(newStoreProvider(useColumnFamilies = true)) { provider =>
       def get(store: StateStore, col1: String, col2: Int, colFamilyName: String): UnsafeRow = {
         store.get(dataToKeyRow(col1, col2), colFamilyName)
       }
 
-      def iterator(store: StateStore, colFamilyName: String): Seq[((String, Int), Int)] = {
-        store.iterator(colFamilyName).toSeq.map {
+      def iterator(store: StateStore, colFamilyName: String): Iterator[((String, Int), Int)] = {
+        store.iterator(colFamilyName).map {
           case unsafePair =>
             (keyRowToData(unsafePair.key), valueRowToData(unsafePair.value))
         }
@@ -1063,12 +1062,21 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       put(store, ("a", 1), 1, colFamily2)
       assert(valueRowToData(get(store, "a", 1, colFamily2)) === 1)
 
+      // calling commit on this store creates version 1
       store.commit()
 
       // reload version 0
       store = provider.getStore(0)
-      assert(get(store, "a", 1, colFamily1) === null)
-      assert(iterator(store, colFamily1).isEmpty)
+
+      val e = intercept[Exception]{
+        get(store, "a", 1, colFamily1)
+      }
+      checkError(
+        exception = e.asInstanceOf[StateStoreUnsupportedOperationOnMissingColumnFamily],
+        errorClass = "STATE_STORE_UNSUPPORTED_OPERATION_ON_MISSING_COLUMN_FAMILY",
+        sqlState = Some("42802"),
+        parameters = Map("operationType" -> "get", "colFamilyName" -> colFamily1)
+      )
 
       store = provider.getStore(1)
       // version 1 data recovered correctly
@@ -1090,6 +1098,50 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+
+  test("verify that column family id is assigned correctly after removal") {
+    tryWithProviderResource(newStoreProvider(useColumnFamilies = true)) { provider =>
+      var store = provider.getRocksDBStateStore(0)
+      val colFamily1: String = "abc"
+      val colFamily2: String = "def"
+      val colFamily3: String = "ghi"
+      val colFamily4: String = "jkl"
+      val colFamily5: String = "mno"
+
+      store.createColFamilyIfAbsent(colFamily1, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      store.createColFamilyIfAbsent(colFamily2, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      store.commit()
+
+      store = provider.getRocksDBStateStore(1)
+      store.removeColFamilyIfExists(colFamily2)
+      store.commit()
+
+      store = provider.getRocksDBStateStore(2)
+      store.createColFamilyIfAbsent(colFamily3, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      assert(store.getColumnFamilyId(colFamily3) == 3)
+      store.removeColFamilyIfExists(colFamily1)
+      store.removeColFamilyIfExists(colFamily3)
+      store.commit()
+
+      store = provider.getRocksDBStateStore(1)
+      // this should return the old id, because we didn't remove this colFamily for version 1
+      store.createColFamilyIfAbsent(colFamily1, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      assert(store.getColumnFamilyId(colFamily1) == 1)
+
+      store = provider.getRocksDBStateStore(3)
+      store.createColFamilyIfAbsent(colFamily4, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      assert(store.getColumnFamilyId(colFamily4) == 4)
+      store.createColFamilyIfAbsent(colFamily5, keySchema, valueSchema,
+        NoPrefixKeyStateEncoderSpec(keySchema))
+      assert(store.getColumnFamilyId(colFamily5) == 5)
+    }
+  }
+
   Seq(
    NoPrefixKeyStateEncoderSpec(keySchema), PrefixKeyScanStateEncoderSpec(keySchema, 1)
   ).foreach { keyEncoder =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 06b97e606632..bd18fd83e43a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -1307,6 +1307,87 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("test query restart with new state variable succeeds") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val clock = new StreamManualClock
+
+        val inputData1 = MemoryStream[String]
+        val result1 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+
+        val result2 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("test query restart succeeds") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result1 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+        val result2 = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+      }
+    }
+  }
+
   test("SPARK-49070: transformWithState - valid initial state plan") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName) {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to