anishshri-db commented on code in PR #53316:
URL: https://github.com/apache/spark/pull/53316#discussion_r2628257602


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,117 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
+  private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+
+  private def isListType(colFamilyName: String): Boolean = {
+    SchemaUtil.checkVariableType(
+      stateVariableInfos.find(info => info.stateName == colFamilyName),
+      StateVariableType.ListState)
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
+    val useColumnFamilies = stateStoreColFamilySchemas.size > 1
+    StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies, storeConf, hadoopConf.value,
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+  }
+
+
+  private def checkAllColFamiliesExist(
+      colFamilyNames: List[String], stateStore: StateStore
+    ): Unit = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT 
in their schema
+    //    because the underlying RocksDB creates "default" column family 
automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be 
declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just 
means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates 
a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Some column families are present in the state store but missing in the 
metadata. " +
+        s"Column families in state store but not in metadata: 
${expectedCFs.diff(actualCFs)}")
+  }
+
+  // Use a single store instance for both registering column families and 
iteration.
+  // We cannot abort and then get a read store because abort() invalidates the 
loaded version,
+  // causing getReadStore() to reload from checkpoint and clear the column 
family registrations.
+  private lazy val store: StateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
       "Start and end store unique IDs must be the same when reading all column 
families")
-    provider.getReadStore(
+    val stateStore = provider.getStore(
       partition.sourceOptions.batchId + 1,
       getStartStoreUniqueId
     )
+
+    // Register all column families from the schema
+    if (stateStoreColFamilySchemas.size > 1) {
+      
checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName).toList,
 stateStore)
+      stateStoreColFamilySchemas.foreach { cfSchema =>
+        cfSchema.colFamilyName match {
+          case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has 
registered default
+          case _ =>
+            val isInternal =
+              
StateStoreColumnFamilySchemaUtils.isInternalColumn(cfSchema.colFamilyName)
+            val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
+            require(cfSchema.keyStateEncoderSpec.isDefined,
+              s"keyStateEncoderSpec must be defined for column family 
${cfSchema.colFamilyName}")
+            stateStore.createColFamilyIfAbsent(
+              cfSchema.colFamilyName,
+              cfSchema.keySchema,
+              cfSchema.valueSchema,
+              cfSchema.keyStateEncoderSpec.get,
+              useMultipleValuesPerKey,
+              isInternal)
+        }
+      }
+    }
+    stateStore
   }
 
   override lazy val iter: Iterator[InternalRow] = {
-    store
-      .iterator()
-      .map { pair =>
-        SchemaUtil.unifyStateRowPairAsRawBytes(
-          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
+    // Iterate all column families and concatenate results
+    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
+      if (isListType(cfSchema.colFamilyName)) {
+        store.iterator(cfSchema.colFamilyName).flatMap(
+          pair =>
+            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
+              value =>
+                SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), 
cfSchema.colFamilyName)
+            }
+        )
+      } else {
+        store.iterator(cfSchema.colFamilyName).map { pair =>
+          SchemaUtil.unifyStateRowPairAsRawBytes(
+            (pair.key, pair.value), cfSchema.colFamilyName)
+        }
       }
+    }
   }
 
   override def close(): Unit = {
-    store.release()
+    store.abort()

Review Comment:
   Why do we change `close()` from `store.release()` to `store.abort()`?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala:
##########
@@ -99,6 +99,21 @@ object StateStoreColumnFamilySchemaUtils {
   def getStateNameFromCountIndexCFName(colFamilyName: String): String =
     getStateName(COUNT_INDEX_PREFIX, colFamilyName)
 
+  def isInternalColumn(colFamilyName: String): Boolean = {

Review Comment:
   nit: rename this to `isInternalColFamily`, since it checks a column family name?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -184,6 +184,13 @@ trait ReadStateStore {
    * This method is idempotent and safe to call multiple times.
    */
   def release(): Unit
+
+  /**
+   * Returns all column family names in this state store.
+   *
+   * @return Set of all column family names
+   */
+  def allColumnFamilyNames: Set[String]

Review Comment:
   Can we add some tests for this in RocksDBSuite, and maybe also operator-level tests in the transformWithState (TWS) test suites?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -22,13 +22,17 @@ import 
org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
 import 
org.apache.spark.sql.execution.streaming.operators.stateful.join.SymmetricHashJoinStateManager
-import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateVariableType,
 TransformWithStateVariableInfo}
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{StateStoreColumnFamilySchemaUtils,
 StateVariableType, TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import 
org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString,
 RecordType}
 import org.apache.spark.sql.types.{NullType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{NextIterator, SerializableConfiguration}
 
+case class AllColumnFamiliesReaderInfo(

Review Comment:
   Can we add some comments here documenting `AllColumnFamiliesReaderInfo` and its fields?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to