zifeif2 commented on code in PR #53316:
URL: https://github.com/apache/spark/pull/53316#discussion_r2628306948


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,117 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
+  private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+
+  private def isListType(colFamilyName: String): Boolean = {
+    SchemaUtil.checkVariableType(
+      stateVariableInfos.find(info => info.stateName == colFamilyName),
+      StateVariableType.ListState)
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
+    val useColumnFamilies = stateStoreColFamilySchemas.size > 1
+    StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies, storeConf, hadoopConf.value,
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+  }
+
+
+  private def checkAllColFamiliesExist(
+      colFamilyNames: List[String], stateStore: StateStore
+    ): Unit = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT 
in their schema
+    //    because the underlying RocksDB creates "default" column family 
automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be 
declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just 
means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates 
a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Some column families are present in the state store but missing in the 
metadata. " +
+        s"Column families in state store but not in metadata: 
${expectedCFs.diff(actualCFs)}")
+  }
+
+  // Use a single store instance for both registering column families and 
iteration.
+  // We cannot abort and then get a read store because abort() invalidates the 
loaded version,
+  // causing getReadStore() to reload from checkpoint and clear the column 
family registrations.
+  private lazy val store: StateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
       "Start and end store unique IDs must be the same when reading all column 
families")
-    provider.getReadStore(
+    val stateStore = provider.getStore(
       partition.sourceOptions.batchId + 1,
       getStartStoreUniqueId
     )
+
+    // Register all column families from the schema
+    if (stateStoreColFamilySchemas.size > 1) {
+      
checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName).toList,
 stateStore)
+      stateStoreColFamilySchemas.foreach { cfSchema =>
+        cfSchema.colFamilyName match {
+          case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has 
registered default
+          case _ =>
+            val isInternal =
+              
StateStoreColumnFamilySchemaUtils.isInternalColumn(cfSchema.colFamilyName)
+            val useMultipleValuesPerKey = isListType(cfSchema.colFamilyName)
+            require(cfSchema.keyStateEncoderSpec.isDefined,
+              s"keyStateEncoderSpec must be defined for column family 
${cfSchema.colFamilyName}")
+            stateStore.createColFamilyIfAbsent(
+              cfSchema.colFamilyName,
+              cfSchema.keySchema,
+              cfSchema.valueSchema,
+              cfSchema.keyStateEncoderSpec.get,
+              useMultipleValuesPerKey,
+              isInternal)
+        }
+      }
+    }
+    stateStore
   }
 
   override lazy val iter: Iterator[InternalRow] = {
-    store
-      .iterator()
-      .map { pair =>
-        SchemaUtil.unifyStateRowPairAsRawBytes(
-          (pair.key, pair.value), StateStore.DEFAULT_COL_FAMILY_NAME)
+    // Iterate all column families and concatenate results
+    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
+      if (isListType(cfSchema.colFamilyName)) {
+        store.iterator(cfSchema.colFamilyName).flatMap(
+          pair =>
+            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
+              value =>
+                SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), 
cfSchema.colFamilyName)
+            }
+        )
+      } else {
+        store.iterator(cfSchema.colFamilyName).map { pair =>
+          SchemaUtil.unifyStateRowPairAsRawBytes(
+            (pair.key, pair.value), cfSchema.colFamilyName)
+        }
       }
+    }
   }
 
   override def close(): Unit = {
-    store.release()
+    store.abort()

Review Comment:
   The store was previously created as a [read store](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala#L200),
   but this PR changes it to a regular `StateStore` because we need to call
   `createColFamilyIfAbsent`. It looks like `release` only works for a
   `ReadStateStore`, so I changed the call to `abort` here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to