micheal-o commented on code in PR #53459:
URL: https://github.com/apache/spark/pull/53459#discussion_r2629467428


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -296,29 +315,33 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
           if (sourceOptions.readRegisteredTimers) {
             stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
+          } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+            // When reading all column families (for repartitioning) for TWS 
operator,
+            // we will just choose a random state as placeholder for default 
column family,
+            // because we need to use matching stateVariableInfo and 
stateStoreColFamilySchemaOpt
+            // to inferSchema (partitionKey in particular) later
+            stateVarName = operatorProperties.stateVariables.head.stateName
           }
-          // When reading all column families (for repartitioning), we collect 
all state variable
-          // infos instead of validating a specific stateVarName. This skips 
the normal validation
-          // logic because we're not reading a specific state variable - we're 
reading all of them.
+
           if (sourceOptions.internalOnlyReadAllColumnFamilies) {
             stateVariableInfos = operatorProperties.stateVariables
-          } else {
-            var stateVarInfoList = operatorProperties.stateVariables
-              .filter(stateVar => stateVar.stateName == stateVarName)
-            if (stateVarInfoList.isEmpty &&
-              
StateStoreColumnFamilySchemaUtils.isTestingInternalColFamily(stateVarName)) {
-              // pass this dummy TWSStateVariableInfo for TWS internal column 
family during testing,
-              // because internalColumns are not register in 
operatorProperties.stateVariables,
-              // thus stateVarInfoList will be empty.
-              stateVarInfoList = List(TransformWithStateVariableInfo(
-                stateVarName, StateVariableType.ValueState, false
-              ))
-            }
-            require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
-              s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
-            val stateVarInfo = stateVarInfoList.head
-            transformWithStateVariableInfoOpt = Some(stateVarInfo)
           }

Review Comment:
   nit: new line



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -132,6 +135,22 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
   override def supportsExternalMetadata(): Boolean = false
 
+  /**
+   * Return the state format version for SYMMETRIC_HASH_JOIN operators.
+   * This currently only support join operators because this function is only 
used by
+   * PartitionKeyExtractor and PartitionKeyExtractor only needs state format 
version for
+   * join operators.
+   */
+  private def getStateFormatVersion(
+      storeMetadata: Array[StateMetadataTableEntry]): Option[Int] = {
+    if (storeMetadata.nonEmpty &&
+      storeMetadata.head.operatorName == 
StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME) {
+      Some(session.conf.get(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION))

Review Comment:
   We should read this from the current batch's offset seq conf instead. 
`buildStateStoreConf` does something similar.
   
   The `session` here doesn't include the confs written in the checkpoint, so it 
can return the wrong value.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+
+  private def getStateVarInfo(

Review Comment:
   nit: `getTWSStateVarInfo`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -296,29 +315,33 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
           if (sourceOptions.readRegisteredTimers) {
             stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
+          } else if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+            // When reading all column families (for repartitioning) for TWS 
operator,
+            // we will just choose a random state as placeholder for default 
column family,
+            // because we need to use matching stateVariableInfo and 
stateStoreColFamilySchemaOpt
+            // to inferSchema (partitionKey in particular) later
+            stateVarName = operatorProperties.stateVariables.head.stateName
           }
-          // When reading all column families (for repartitioning), we collect 
all state variable
-          // infos instead of validating a specific stateVarName. This skips 
the normal validation
-          // logic because we're not reading a specific state variable - we're 
reading all of them.
+
           if (sourceOptions.internalOnlyReadAllColumnFamilies) {
             stateVariableInfos = operatorProperties.stateVariables
-          } else {
-            var stateVarInfoList = operatorProperties.stateVariables
-              .filter(stateVar => stateVar.stateName == stateVarName)
-            if (stateVarInfoList.isEmpty &&
-              
StateStoreColumnFamilySchemaUtils.isTestingInternalColFamily(stateVarName)) {
-              // pass this dummy TWSStateVariableInfo for TWS internal column 
family during testing,
-              // because internalColumns are not register in 
operatorProperties.stateVariables,
-              // thus stateVarInfoList will be empty.
-              stateVarInfoList = List(TransformWithStateVariableInfo(
-                stateVarName, StateVariableType.ValueState, false
-              ))
-            }
-            require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
-              s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
-            val stateVarInfo = stateVarInfoList.head
-            transformWithStateVariableInfoOpt = Some(stateVarInfo)
           }
+          var stateVarInfoList = operatorProperties.stateVariables
+            .filter(stateVar => stateVar.stateName == stateVarName)
+          if (stateVarInfoList.isEmpty &&

Review Comment:
   We don't need this anymore, right? It won't be empty.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -120,10 +120,13 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
           (resultSchema.keySchema, resultSchema.valueSchema)
       }
 
+      val stateVarInfo = stateStoreReaderInfo.transformWithStateVariableInfoOpt

Review Comment:
   nit: why change to `val`? It is only used once below.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+

Review Comment:
   nit: remove the extra blank line



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {

Review Comment:
   nit: move this logic to StateStoreColumnFamilySchemaUtils.
   
   Also, this function (renamed to `getTWSStateName`) should look something like this:
   ```
   if (StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)) {
     StateStoreColumnFamilySchemaUtils.getStateNameForInternalCF(colFamilyName)
   } else {
     colFamilyName
   }
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+
+  private def getStateVarInfo(
+      colFamilyName: String): Option[TransformWithStateVariableInfo] = {
+    if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+      Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
+    } else {
+      stateVariableInfos.find(_.stateName == getBaseStateName(colFamilyName))
+    }
+  }
+
+  // Create extractors for each column family - each column family may have 
different key schema
+  private lazy val partitionKeyExtractors: Map[String, 
StatePartitionKeyExtractor] = {
+    stateStoreColFamilySchemas
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, 
schema.colFamilyName))

Review Comment:
   Add a one-line comment.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -357,21 +411,25 @@ class StatePartitionAllColumnFamiliesReader(
 
   override lazy val iter: Iterator[InternalRow] = {
     // Iterate all column families and concatenate results
-    stateStoreColFamilySchemas.iterator.flatMap { cfSchema =>
-      if (isListType(cfSchema.colFamilyName)) {
-        store.iterator(cfSchema.colFamilyName).flatMap(
-          pair =>
-            store.valuesIterator(pair.key, cfSchema.colFamilyName).map {
-              value =>
-                SchemaUtil.unifyStateRowPairAsRawBytes((pair.key, value), 
cfSchema.colFamilyName)
-            }
-        )
-      } else {
-        store.iterator(cfSchema.colFamilyName).map { pair =>
-          SchemaUtil.unifyStateRowPairAsRawBytes(
-            (pair.key, pair.value), cfSchema.colFamilyName)
+    stateStoreColFamilySchemas.iterator
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, 
schema.colFamilyName))

Review Comment:
   Add a one-line comment.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -77,6 +99,27 @@ object SchemaUtil {
     }
   }
 
+  /**
+   * Creates a StatePartitionKeyExtractor for the given operator.
+   * This is used to extract partition keys from state store keys for state 
repartitioning.
+   */
+  def getPartitionKeyExtractor(

Review Comment:
   There is no point in having this function. It just calls 
`StatePartitionKeyExtractorFactory.create` and passes through the exact same params. Let's 
just use `StatePartitionKeyExtractorFactory.create` directly instead.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+
+  private def getStateVarInfo(
+      colFamilyName: String): Option[TransformWithStateVariableInfo] = {
+    if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+      Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))

Review Comment:
   nit: add a comment explaining why



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -49,8 +52,35 @@ object SchemaUtil {
       keySchema: StructType,
       valueSchema: StructType,
       transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo],
-      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): 
StructType = {
-    if (transformWithStateVariableInfoOpt.isDefined) {
+      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+      operatorName: String,
+      stateFormatVersion: Option[Int] = None): StructType = {
+    if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+      val colFamilyName: String =

Review Comment:
   Why are we doing this when we already have the 
`stateStoreColFamilySchemaOpt`? Why not just get the CF name from that? 
Remember, the partition key schema will be the same for all CFs.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala:
##########
@@ -87,18 +130,18 @@ object SchemaUtil {
 
   /**
    * Returns an InternalRow representing
-   * 1. partitionKey
+   * 1. partitionKey (extracted using the StatePartitionKeyExtractor)
    * 2. key in bytes
    * 3. value in bytes
    * 4. column family name
    */
   def unifyStateRowPairAsRawBytes(
       pair: (UnsafeRow, UnsafeRow),
-      colFamilyName: String): InternalRow = {
+      colFamilyName: String,
+      extractor: StatePartitionKeyExtractor): InternalRow = {
     val row = new GenericInternalRow(4)
-    // todo [SPARK-54443]: change keySchema to more specific type after we
-    //  can extract partition key from keySchema
-    row.update(0, pair._1)
+    val partitionKey = extractor.partitionKey(pair._1)

Review Comment:
   No need for a `val`. You only use it once, and `extractor.partitionKey` 
already makes it clear that we are getting the partition key.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+
+  private def getStateVarInfo(
+      colFamilyName: String): Option[TransformWithStateVariableInfo] = {
+    if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+      Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
+    } else {
+      stateVariableInfos.find(_.stateName == getBaseStateName(colFamilyName))
+    }
+  }
+
+  // Create extractors for each column family - each column family may have 
different key schema
+  private lazy val partitionKeyExtractors: Map[String, 
StatePartitionKeyExtractor] = {

Review Comment:
   nit: `cfPartitionKeyExtractors`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateVariableUtils.scala:
##########
@@ -63,6 +63,10 @@ object TransformWithStateVariableUtils {
   def isRowCounterCFName(colFamilyName: String): Boolean = {
     colFamilyName.startsWith(ROW_COUNTER_CF_PREFIX)
   }
+
+  def getStateNameFromRowCounterCFName(colFamilyName: String): String = {

Review Comment:
   Add `require(isRowCounterCFName(colFamilyName))` here.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/timers/TimerStateImpl.scala:
##########
@@ -61,9 +61,14 @@ object TimerStateUtils {
   }
 
   def isTimerSecondaryIndexCF(colFamilyName: String): Boolean = {
-    assert(isTimerCFName(colFamilyName), s"Column family name must be for a 
timer: $colFamilyName")

Review Comment:
   Why was this assertion removed?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -136,26 +150,29 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         valueUnsafeRow.getBytes.clone()
       }
 
-      (keyBytes, valueBytes)
+      (keyBytes, valueBytes, partitionKey)
     }
 
-    // Extract raw bytes from bytes read data (no 
deserialization/reserialization)
-    val bytesAsBytes = filteredBytesData.map { row =>
+    // Extract (partitionKeyStr, keyBytes, valueBytes) from bytes read data
+    val bytesData = filteredBytesData.map { row =>
+      val partitionKey = row.getStruct(0)
       val keyBytes = row.getAs[Array[Byte]](1)
       val valueBytes = row.getAs[Array[Byte]](2)
-      (keyBytes, valueBytes)
+      (keyBytes, valueBytes, partitionKey)
     }
 
-    // Sort both for comparison (since Set equality doesn't work well with 
byte arrays)
-    val normalSorted = normalAsBytes.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
-    val bytesSorted = bytesAsBytes.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
+    // Sort both for comparison by key and value bytes
+    val normalSorted = normalData.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
+    val bytesSorted = bytesData.sortBy(x => (x._1.mkString(","), 
x._2.mkString(",")))
 
     assert(normalSorted.length == bytesSorted.length,
       s"Size mismatch: normal has ${normalSorted.length}, bytes has 
${bytesSorted.length}")
 
-    // Compare each pair
+    // Compare each tuple (partitionKeyStr, keyBytes, valueBytes)

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -278,6 +282,56 @@ class StatePartitionAllColumnFamiliesReader(
 
   private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
   private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+  private val operatorName = allColumnFamiliesReaderInfo.operatorName
+  private val stateFormatVersion = 
allColumnFamiliesReaderInfo.stateFormatVersion
+
+  private def isDefaultColFamilyInTWS(operatorName: String, colFamilyName: 
String): Boolean = {
+    
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_OP_NAMES.contains(operatorName) &&
+      colFamilyName == StateStore.DEFAULT_COL_FAMILY_NAME
+  }
+
+  /**
+   * Extracts the base state variable name from internal column family names.
+   */
+  private def getBaseStateName(colFamilyName: String): String = {
+    if (StateStoreColumnFamilySchemaUtils.isTtlColFamilyName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromTtlColFamily(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isMinExpiryIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromMinExpiryIndexCFName(colFamilyName)
+    } else if 
(StateStoreColumnFamilySchemaUtils.isCountIndexCFName(colFamilyName)) {
+      
StateStoreColumnFamilySchemaUtils.getStateNameFromCountIndexCFName(colFamilyName)
+    } else if 
(TransformWithStateVariableUtils.isRowCounterCFName(colFamilyName)) {
+      
TransformWithStateVariableUtils.getStateNameFromRowCounterCFName(colFamilyName)
+    } else {
+      colFamilyName
+    }
+  }
+
+
+  private def getStateVarInfo(
+      colFamilyName: String): Option[TransformWithStateVariableInfo] = {
+    if (TimerStateUtils.isTimerSecondaryIndexCF(colFamilyName)) {
+      Some(TransformWithStateVariableUtils.getTimerState(colFamilyName))
+    } else {
+      stateVariableInfos.find(_.stateName == getBaseStateName(colFamilyName))
+    }
+  }
+
+  // Create extractors for each column family - each column family may have 
different key schema
+  private lazy val partitionKeyExtractors: Map[String, 
StatePartitionKeyExtractor] = {
+    stateStoreColFamilySchemas
+      .filter(schema => !isDefaultColFamilyInTWS(operatorName, 
schema.colFamilyName))
+      .map { cfSchema =>
+        val extractor = SchemaUtil.getPartitionKeyExtractor(
+          operatorName,
+          cfSchema.keySchema,
+          partition.sourceOptions.storeName,
+          cfSchema.colFamilyName,
+          getStateVarInfo(cfSchema.colFamilyName),

Review Comment:
   Let's be explicit here: if the operator is TWS, call `getStateVarInfo`; otherwise pass `None`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -817,13 +848,16 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         val countValueSchema = StructType(Array(
           StructField("count", LongType)
         ))
-        val columnFamilyAndKeyValueSchema = Seq(
-          ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
-          ("$min_listState", groupByKeySchema, minExpiryValueSchema),
-          ("$count_listState", groupByKeySchema, countValueSchema)
+        val ttlColFamilyPartitionKeyExtractor: Option[Row => Row] =
+          Some(compositeKey => compositeKey.getStruct(1))
+        val simpleColumnFamilies = Seq(

Review Comment:
   nit: `ttlColumnFamilies`?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -255,12 +274,14 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       "partition_id",
       "STRUCT(key AS groupingKey, expiration_timestamp_ms AS key)",
       "NULL AS value")
+    // Partition key should be just the grouping key, not the composite (key, 
timestamp)

Review Comment:
   Move the comment next to `partitionKeyExtractor = ` below. Same for the 
others below. It is best for comments to be close to what they describe. 
It also makes it easy to see that the comment needs updating when the code 
changes later.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -113,12 +122,17 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
     val keyConverter = 
CatalystTypeConverters.createToCatalystConverter(keySchema)
     val valueConverter = 
CatalystTypeConverters.createToCatalystConverter(valueSchema)
 
-    // Convert normal data to bytes
-    val normalAsBytes = normalDf.toSeq.map { row =>
+    // Convert normal data to (partitionKeyStr, keyBytes, valueBytes)

Review Comment:
   Do you mean partitionKeyStruct instead?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -136,26 +150,29 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         valueUnsafeRow.getBytes.clone()
       }
 
-      (keyBytes, valueBytes)
+      (keyBytes, valueBytes, partitionKey)
     }
 
-    // Extract raw bytes from bytes read data (no 
deserialization/reserialization)
-    val bytesAsBytes = filteredBytesData.map { row =>
+    // Extract (partitionKeyStr, keyBytes, valueBytes) from bytes read data

Review Comment:
   ditto, partitionKeyStruct?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -374,13 +397,16 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       }
     }
 
+    val operatorName = if (storeMetadata.nonEmpty) 
storeMetadata.head.operatorName else ""
+    val stateFormatVersion = getStateFormatVersion(storeMetadata)
     StateStoreReaderInfo(
       keyStateEncoderSpecOpt,
       stateStoreColFamilySchemaOpt,
       transformWithStateVariableInfoOpt,
       stateSchemaProvider,
       joinColFamilyOpt,
-      AllColumnFamiliesReaderInfo(stateStoreColFamilySchemas, 
stateVariableInfos)
+      AllColumnFamiliesReaderInfo(

Review Comment:
   Why are we always populating this, even when the all-CF reader is off?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to