micheal-o commented on code in PR #53316:
URL: https://github.com/apache/spark/pull/53316#discussion_r2621514601


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -268,13 +291,23 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
           if (sourceOptions.readRegisteredTimers) {
             stateVarName = TimerStateUtils.getTimerStateVarNames(timeMode)._1
           }
-
-          val stateVarInfoList = operatorProperties.stateVariables
-            .filter(stateVar => stateVar.stateName == stateVarName)
-          require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
-            s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
-          val stateVarInfo = stateVarInfoList.head
-          transformWithStateVariableInfoOpt = Some(stateVarInfo)
+          if (sourceOptions.internalOnlyReadAllColumnFamilies) {
+            stateVariableInfos = operatorProperties.stateVariables
+          } else {
+            var stateVarInfoList = operatorProperties.stateVariables
+              .filter(stateVar => stateVar.stateName == stateVarName)
+            if (stateVarInfoList.isEmpty &&

Review Comment:
   nit: add a comment explaining why.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
+  private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+
+  private def isListType(colFamilyName: String): Boolean = {
+    SchemaUtil.checkVariableType(
+      stateVariableInfos.find(info => info.stateName == colFamilyName),
+      StateVariableType.ListState)
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
+    val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+    StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies, storeConf, hadoopConf.value,
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+  }
+
+
+  private def checkAllColFamiliesExist(
+      colFamilyNames: List[String], stateStore: StateStore
+    ): Unit = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT 
in their schema
+    //    because the underlying RocksDB creates "default" column family 
automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be 
declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just 
means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates 
a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Checkpoint contains unexpected column families. " +
+        s"Column families in checkpoint but not in schema: 
${expectedCFs.diff(actualCFs)}")
+  }
+
+  // Use a single store instance for both registering column families and 
iteration.
+  // We cannot abort and then get a read store because abort() invalidates the 
loaded version,
+  // causing getReadStore() to reload from checkpoint and clear the column 
family registrations.
+  private lazy val store: StateStore = {
     assert(getStartStoreUniqueId == getEndStoreUniqueId,
       "Start and end store unique IDs must be the same when reading all column 
families")
-    provider.getReadStore(
+    val stateStore = provider.getStore(
       partition.sourceOptions.batchId + 1,
       getStartStoreUniqueId
     )
+
+    // Register all column families from the schema
+    if (stateStoreColFamilySchemas.length > 1) {
+      
checkAllColFamiliesExist(stateStoreColFamilySchemas.map(_.colFamilyName), 
stateStore)
+      stateStoreColFamilySchemas.foreach { cfSchema =>
+        cfSchema.colFamilyName match {
+          case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has 
registered default
+          case _ =>
+            val isInternal = cfSchema.colFamilyName.startsWith("$")

Review Comment:
   nit: Let's add a util function for this in 
`StateStoreColumnFamilySchemaUtils.scala`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala:
##########
@@ -177,7 +193,8 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
 
       val stateVars = twsOperatorProperties.stateVariables
       val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == 
stateVarName)
-      if (stateVarInfo.size != 1) {
+      if (stateVarInfo.size != 1 &&
+        
!StateStoreColumnFamilySchemaUtils.isInternalColFamilyTestOnly(stateVarName)) {

Review Comment:
   nit: please add a comment explaining why.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/StateStoreColumnFamilySchemaUtils.scala:
##########
@@ -99,6 +99,17 @@ object StateStoreColumnFamilySchemaUtils {
   def getStateNameFromCountIndexCFName(colFamilyName: String): String =
     getStateName(COUNT_INDEX_PREFIX, colFamilyName)
 
+  /**
+   * Returns true if the column family is internal (starts with "$") and we 
are in testing mode.
+   * This is used to allow internal column families to be read during tests.
+   *
+   * @param colFamilyName The name of the column family to check
+   * @return true if this is an internal column family and Utils.isTesting is 
true
+   */
+  def isInternalColFamilyTestOnly(colFamilyName: String): Boolean = {

Review Comment:
   nit: rename this to `isTestingInternalColFamily`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas

Review Comment:
   We should make this a set to ensure there are no duplicates, so that we 
don't process the same column family twice.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
+  private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+
+  private def isListType(colFamilyName: String): Boolean = {
+    SchemaUtil.checkVariableType(
+      stateVariableInfos.find(info => info.stateName == colFamilyName),
+      StateVariableType.ListState)
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
+    val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+    StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies, storeConf, hadoopConf.value,
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+  }
+
+
+  private def checkAllColFamiliesExist(
+      colFamilyNames: List[String], stateStore: StateStore
+    ): Unit = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT 
in their schema
+    //    because the underlying RocksDB creates "default" column family 
automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be 
declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just 
means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates 
a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Checkpoint contains unexpected column families. " +

Review Comment:
   Fix the message; it should be: "Some column families are present in the state 
store but missing in the metadata."



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala:
##########
@@ -310,6 +310,13 @@ trait StateStore extends ReadStateStore {
    * Whether all updates have been committed
    */
   def hasCommitted: Boolean
+
+  /**
+   * Returns all column family names in this state store.
+   *
+   * @return Set of all column family names
+   */
+  def allColumnFamilyNames: collection.Set[String]

Review Comment:
   This should be added to `ReadStateStore` instead, so that it can be called 
on both `ReadStateStore` and `StateStore`, since it is a read-only operation.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -420,8 +534,9 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
 
-    def testStreamStreamJoin(stateVersion: Int): Unit = {
-      withSQLConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> 
stateVersion.toString) {
+    def testStreamStreamJoinV2(stateVersion: Int): Unit = {

Review Comment:
   nit: rename this to `testStreamStreamJoinV1AndV2`, and assert that the version 
passed in is <= 2.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's 
data
+      val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option("countState"))
+      compareNormalAndBytesData(
+        countStateNormalDf.collect(),
+        allBytesData,
+        "$rowCounter_itemsList",
+        groupByKeySchema,
+        rowCounterValueSchema)
+
+      // Validate itemsMap
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+        selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+          "user_map_value AS value"))
+    }
+  }
+
+  test("SPARK-54419: read all column families with event time timers") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[(String, Long)]
+      val result = inputData.toDS()
+        .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Timestamp)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new EventTimeTimerProcessor(),
+          TimeMode.EventTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+        CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+    }
+  }
+
+  test("SPARK-54419: read all column families with processing time timers") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+    }
+  }
+
+  test("SPARK-54419: transformWithState with list state and TTL") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new ListStateTTLProcessor(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a", "b", "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        StopStream
+      )
+
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+
+      val allBytesData = bytesDf.collect()
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      assert(columnFamilies.toSet ==
+        Set("listState", "$ttl_listState", "$min_listState", 
"$count_listState"))
+
+      // Define schemas for list state with TTL column families
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val listStateValueSchema = StructType(Array(
+        StructField("value", StructType(Array(
+          StructField("value", StringType, nullable = true)
+        )), nullable = false),
+        StructField("ttlExpirationMs", LongType, nullable = false)
+      ))
+
+      val listStateNormalDf = spark.read
+        .format("statestore")
+        .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+        .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+        .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+        .load()
+        .selectExpr("partition_id", "key", "list_element")
+
+      compareNormalAndBytesData(
+        listStateNormalDf.collect(),
+        allBytesData,
+        "listState",
+        groupByKeySchema,
+        listStateValueSchema)
+      val dummyValueSchema = StructType(Array(StructField("__dummy__", 
NullType)))
+      val ttlIndexKeySchema = StructType(Array(
+        StructField("expirationMs", LongType, nullable = false),
+        StructField("elementKey", groupByKeySchema)
+      ))
+      val minExpiryValueSchema = StructType(Array(
+        StructField("minExpiry", LongType)
+      ))
+      val countValueSchema = StructType(Array(
+        StructField("count", LongType)
+      ))
+      val columnFamilyAndKeyValueSchema = Seq(
+        ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+        ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+        ("$count_listState", groupByKeySchema, countValueSchema)
+      )
+      columnFamilyAndKeyValueSchema.foreach(pair => {
+        val normalDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+          .load()
+          .selectExpr("partition_id", "key", "value")
+
+        compareNormalAndBytesData(
+          normalDf.collect(),
+          allBytesData,
+          pair._1,
+          pair._2,
+          pair._3)
+      }
+      )
+    }
+  }
+
+  def testStreamStreamJoinV3(): Unit = {
+    withSQLConf(
+      SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+    ) {
+      withTempDir { tempDir =>
+        val inputData = MemoryStream[(Int, Long)]
+        val query = getStreamStreamJoinQuery(inputData)
+        testStream(query)(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+          ProcessAllAvailable(),
+          StopStream
+        )
+        val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+        validateBytesReadDfSchema(stateBytesDf)
+
+        Seq("right-keyToNumValues", 
"left-keyToNumValues").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+            StructField("value", LongType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+
+        Seq("right-keyWithIndexToValue", 
"left-keyWithIndexToValue").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType, nullable = false),
+            StructField("index", LongType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+              StructField("value", IntegerType, nullable = false),
+              StructField("time", TimestampType, nullable = false),
+              StructField("matched", BooleanType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+      }
+    }
+  }
+
+  test("SPARK-54419: stream-stream joinV3") {
+    testStreamStreamJoinV3()

Review Comment:
   FYI, there's no need to create a separate function if it is only used by one test case.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's 
data
+      val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option("countState"))
+      compareNormalAndBytesData(
+        countStateNormalDf.collect(),
+        allBytesData,
+        "$rowCounter_itemsList",
+        groupByKeySchema,
+        rowCounterValueSchema)
+
+      // Validate itemsMap
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+        selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+          "user_map_value AS value"))
+    }
+  }
+
+  test("SPARK-54419: read all column families with event time timers") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[(String, Long)]
+      val result = inputData.toDS()
+        .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Timestamp)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new EventTimeTimerProcessor(),
+          TimeMode.EventTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+        CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+    }
+  }
+
+  test("SPARK-54419: read all column families with processing time timers") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+    }
+  }
+
+  test("SPARK-54419: transformWithState with list state and TTL") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new ListStateTTLProcessor(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a", "b", "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        StopStream
+      )
+
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+
+      val allBytesData = bytesDf.collect()
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      assert(columnFamilies.toSet ==
+        Set("listState", "$ttl_listState", "$min_listState", 
"$count_listState"))
+
+      // Define schemas for list state with TTL column families
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val listStateValueSchema = StructType(Array(
+        StructField("value", StructType(Array(
+          StructField("value", StringType, nullable = true)
+        )), nullable = false),
+        StructField("ttlExpirationMs", LongType, nullable = false)
+      ))
+
+      val listStateNormalDf = spark.read
+        .format("statestore")
+        .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+        .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+        .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+        .load()
+        .selectExpr("partition_id", "key", "list_element")
+
+      compareNormalAndBytesData(
+        listStateNormalDf.collect(),
+        allBytesData,
+        "listState",
+        groupByKeySchema,
+        listStateValueSchema)
+      val dummyValueSchema = StructType(Array(StructField("__dummy__", 
NullType)))
+      val ttlIndexKeySchema = StructType(Array(
+        StructField("expirationMs", LongType, nullable = false),
+        StructField("elementKey", groupByKeySchema)
+      ))
+      val minExpiryValueSchema = StructType(Array(
+        StructField("minExpiry", LongType)
+      ))
+      val countValueSchema = StructType(Array(
+        StructField("count", LongType)
+      ))
+      val columnFamilyAndKeyValueSchema = Seq(
+        ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+        ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+        ("$count_listState", groupByKeySchema, countValueSchema)
+      )
+      columnFamilyAndKeyValueSchema.foreach(pair => {
+        val normalDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+          .load()
+          .selectExpr("partition_id", "key", "value")
+
+        compareNormalAndBytesData(
+          normalDf.collect(),
+          allBytesData,
+          pair._1,
+          pair._2,
+          pair._3)
+      }
+      )
+    }
+  }
+
+  def testStreamStreamJoinV3(): Unit = {
+    withSQLConf(
+      SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+    ) {
+      withTempDir { tempDir =>
+        val inputData = MemoryStream[(Int, Long)]
+        val query = getStreamStreamJoinQuery(inputData)
+        testStream(query)(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+          ProcessAllAvailable(),
+          StopStream
+        )
+        val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+        validateBytesReadDfSchema(stateBytesDf)
+
+        Seq("right-keyToNumValues", 
"left-keyToNumValues").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+            StructField("value", LongType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+
+        Seq("right-keyWithIndexToValue", 
"left-keyWithIndexToValue").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType, nullable = false),
+            StructField("index", LongType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+              StructField("value", IntegerType, nullable = false),
+              StructField("time", TimestampType, nullable = false),
+              StructField("matched", BooleanType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+      }
+    }
+  }
+
+  test("SPARK-54419: stream-stream joinV3") {
+    testStreamStreamJoinV3()
+  }
+}
+
+/**
+ * Stateful processor with multiple state variables (ValueState + ListState)
+ * for testing multi-column family reading.
+ */
+class MultiStateVarProcessor extends StatefulProcessor[String, String, 
(String, String)] {
+  @transient private var _countState: ValueState[Long] = _
+  @transient private var _itemsList: ListState[String] = _
+  @transient private var _itemsMap: MapState[String, SimpleMapValue] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong, TTLConfig.NONE)
+    _itemsList = getHandle.getListState[String]("itemsList", Encoders.STRING, 
TTLConfig.NONE)
+    _itemsMap = getHandle.getMapState[String, SimpleMapValue](
+      "itemsMap", Encoders.STRING, Encoders.product[SimpleMapValue], 
TTLConfig.NONE)
+  }
+
+  override def handleInputRows(
+          key: String,
+          inputRows: Iterator[String],
+          timerValues: TimerValues): Iterator[(String, String)] = {
+    val currentCount = Option(_countState.get()).getOrElse(0L)
+    var newCount = currentCount
+    inputRows.foreach { item =>
+      newCount += 1
+      _itemsList.appendValue(item)
+      _itemsMap.updateValue(item, SimpleMapValue(newCount.toInt))
+    }
+    _countState.update(newCount)
+    Iterator((key, newCount.toString))
+  }
+}
+

Review Comment:
   Is there a reason you're implementing your own processors instead of 
reusing the existing ones in TransformWithStateSuite? See: 
https://github.com/apache/spark/blob/25c57d87382ca90850541d4d20447bfd842ec85a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala#L606



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's 
data
+      val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option("countState"))
+      compareNormalAndBytesData(
+        countStateNormalDf.collect(),
+        allBytesData,
+        "$rowCounter_itemsList",
+        groupByKeySchema,
+        rowCounterValueSchema)
+
+      // Validate itemsMap
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+        selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+          "user_map_value AS value"))
+    }
+  }
+
+  test("SPARK-54419: read all column families with event time timers") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[(String, Long)]
+      val result = inputData.toDS()
+        .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Timestamp)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new EventTimeTimerProcessor(),
+          TimeMode.EventTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+        CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+    }
+  }
+
+  test("SPARK-54419: read all column families with processing time timers") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+    }
+  }
+
+  test("SPARK-54419: transformWithState with list state and TTL") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new ListStateTTLProcessor(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a", "b", "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        StopStream
+      )
+
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+
+      val allBytesData = bytesDf.collect()
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      assert(columnFamilies.toSet ==
+        Set("listState", "$ttl_listState", "$min_listState", 
"$count_listState"))
+
+      // Define schemas for list state with TTL column families
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val listStateValueSchema = StructType(Array(
+        StructField("value", StructType(Array(
+          StructField("value", StringType, nullable = true)
+        )), nullable = false),
+        StructField("ttlExpirationMs", LongType, nullable = false)
+      ))
+
+      val listStateNormalDf = spark.read
+        .format("statestore")
+        .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+        .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+        .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+        .load()
+        .selectExpr("partition_id", "key", "list_element")
+
+      compareNormalAndBytesData(
+        listStateNormalDf.collect(),
+        allBytesData,
+        "listState",
+        groupByKeySchema,
+        listStateValueSchema)
+      val dummyValueSchema = StructType(Array(StructField("__dummy__", 
NullType)))
+      val ttlIndexKeySchema = StructType(Array(
+        StructField("expirationMs", LongType, nullable = false),
+        StructField("elementKey", groupByKeySchema)
+      ))
+      val minExpiryValueSchema = StructType(Array(
+        StructField("minExpiry", LongType)
+      ))
+      val countValueSchema = StructType(Array(
+        StructField("count", LongType)
+      ))
+      val columnFamilyAndKeyValueSchema = Seq(
+        ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+        ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+        ("$count_listState", groupByKeySchema, countValueSchema)
+      )
+      columnFamilyAndKeyValueSchema.foreach(pair => {
+        val normalDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+          .load()
+          .selectExpr("partition_id", "key", "value")
+
+        compareNormalAndBytesData(
+          normalDf.collect(),
+          allBytesData,
+          pair._1,
+          pair._2,
+          pair._3)
+      }
+      )
+    }
+  }
+
+  def testStreamStreamJoinV3(): Unit = {
+    withSQLConf(
+      SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION.key -> "3"
+    ) {
+      withTempDir { tempDir =>
+        val inputData = MemoryStream[(Int, Long)]
+        val query = getStreamStreamJoinQuery(inputData)
+        testStream(query)(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+          ProcessAllAvailable(),
+          StopStream
+        )
+        val stateBytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+        validateBytesReadDfSchema(stateBytesDf)
+
+        Seq("right-keyToNumValues", 
"left-keyToNumValues").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+            StructField("value", LongType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+
+        Seq("right-keyWithIndexToValue", 
"left-keyWithIndexToValue").foreach(colFamilyName => {
+          val normalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option(colFamilyName))
+          val keyToNumValuesKeySchema = StructType(Array(
+            StructField("key", IntegerType, nullable = false),
+            StructField("index", LongType)
+          ))
+          val keyToNumValueValueSchema = StructType(Array(
+              StructField("value", IntegerType, nullable = false),
+              StructField("time", TimestampType, nullable = false),
+              StructField("matched", BooleanType)
+          ))
+
+          compareNormalAndBytesData(
+            normalDf.collect(),
+            stateBytesDf.collect(),
+            colFamilyName,
+            keyToNumValuesKeySchema,
+            keyToNumValueValueSchema)
+        })
+      }
+    }
+  }
+
+  test("SPARK-54419: stream-stream joinV3") {
+    testStreamStreamJoinV3()
+  }
+}
+
+/**
+ * Stateful processor with multiple state variables (ValueState + ListState)
+ * for testing multi-column family reading.
+ */
+class MultiStateVarProcessor extends StatefulProcessor[String, String, 
(String, String)] {
+  @transient private var _countState: ValueState[Long] = _
+  @transient private var _itemsList: ListState[String] = _
+  @transient private var _itemsMap: MapState[String, SimpleMapValue] = _
+
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong, TTLConfig.NONE)
+    _itemsList = getHandle.getListState[String]("itemsList", Encoders.STRING, 
TTLConfig.NONE)
+    _itemsMap = getHandle.getMapState[String, SimpleMapValue](
+      "itemsMap", Encoders.STRING, Encoders.product[SimpleMapValue], 
TTLConfig.NONE)
+  }
+
+  override def handleInputRows(

Review Comment:
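   nit: indentation. The `handleInputRows` parameters should use the standard 4-space continuation indent, not the deeper indent used here.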



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala:
##########
@@ -610,6 +610,10 @@ private[sql] class RocksDBStateStoreProvider
 
     override def hasCommitted: Boolean = state == COMMITTED
 
+    override def allColumnFamilyNames: collection.Set[String] = {

Review Comment:
   ditto: return `Set[String]` rather than `collection.Set[String]`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -336,6 +336,14 @@ class RocksDB(
     colFamilyNameToInfoMap.asScala.values.toSeq.count(_.isInternal == 
isInternal)
   }
 
+  /**
+   * Returns all column family names currently registered in RocksDB.
+   * This includes column families loaded from checkpoint metadata.
+   */
+  def allColumnFamilyNames: collection.Set[String] = {

Review Comment:
   ditto: return `Set[String]` rather than `collection.Set[String]`.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala:
##########
@@ -146,6 +146,9 @@ private[sql] class HDFSBackedStateStoreProvider extends 
StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
     }
 
+    override def allColumnFamilyNames: collection.Set[String] =

Review Comment:
   Why not just use `Set[String]`? `collection.Set` is the base trait of all 
sets, so using it here means you could return either a mutable or an immutable set. 
Using `Set[String]` makes it clear that we only return an immutable set.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's 
data

Review Comment:
   nit: add a comment explaining why the countState data is intentionally reused here.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -258,32 +266,116 @@ class StatePartitionAllColumnFamiliesReader(
     partition: StateStoreInputPartition,
     schema: StructType,
     keyStateEncoderSpec: KeyStateEncoderSpec,
-    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+    defaultStateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    stateSchemaProviderOpt: Option[StateSchemaProvider],
+    allColumnFamiliesReaderInfo: AllColumnFamiliesReaderInfo)
   extends StatePartitionReaderBase(
     storeConf,
     hadoopConf, partition, schema,
-    keyStateEncoderSpec, None, stateStoreColFamilySchemaOpt, None, None) {
+    keyStateEncoderSpec, None,
+    defaultStateStoreColFamilySchemaOpt,
+    stateSchemaProviderOpt, None) {
 
-  private lazy val store: ReadStateStore = {
+  private val stateStoreColFamilySchemas = 
allColumnFamiliesReaderInfo.colFamilySchemas
+  private val stateVariableInfos = 
allColumnFamiliesReaderInfo.stateVariableInfos
+
+  private def isListType(colFamilyName: String): Boolean = {
+    SchemaUtil.checkVariableType(
+      stateVariableInfos.find(info => info.stateName == colFamilyName),
+      StateVariableType.ListState)
+  }
+
+  override protected lazy val provider: StateStoreProvider = {
+    val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
+      partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
+    val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
+    val useColumnFamilies = stateStoreColFamilySchemas.length > 1
+    StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies, storeConf, hadoopConf.value,
+      useMultipleValuesPerKey = false, stateSchemaProviderOpt)
+  }
+
+
+  private def checkAllColFamiliesExist(
+      colFamilyNames: List[String], stateStore: StateStore
+    ): Unit = {
+    // Filter out DEFAULT column family from validation for two reasons:
+    // 1. Some operators (e.g., stream-stream join v3) don't include DEFAULT 
in their schema
+    //    because the underlying RocksDB creates "default" column family 
automatically
+    // 2. The default column family schema is handled separately via
+    //    defaultStateStoreColFamilySchemaOpt, so no need to verify it here
+    val actualCFs = colFamilyNames.toSet.filter(_ != 
StateStore.DEFAULT_COL_FAMILY_NAME)
+    val expectedCFs = stateStore.allColumnFamilyNames
+      .filter(_ != StateStore.DEFAULT_COL_FAMILY_NAME)
+
+    // Validation: All column families found in the checkpoint must be 
declared in the schema.
+    // It's acceptable if some schema CFs are not in expectedCFs - this just 
means those
+    // column families have no data yet in the checkpoint
+    // (they'll be created during registration).
+    // However, if the checkpoint contains CFs not in the schema, it indicates 
a mismatch.
+    require(expectedCFs.subsetOf(actualCFs),
+      s"Checkpoint contains unexpected column families. " +
+        s"Column families in checkpoint but not in schema: 
${expectedCFs.diff(actualCFs)}")

Review Comment:
   nit: consider rewording the error message to "Column families in state store but not in metadata: ".



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's 
data
+      val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath, 
Option("countState"))
+      compareNormalAndBytesData(
+        countStateNormalDf.collect(),
+        allBytesData,
+        "$rowCounter_itemsList",
+        groupByKeySchema,
+        rowCounterValueSchema)
+
+      // Validate itemsMap
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+        selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+          "user_map_value AS value"))
+    }
+  }
+
+  test("SPARK-54419: read all column families with event time timers") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[(String, Long)]
+      val result = inputData.toDS()
+        .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Timestamp)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new EventTimeTimerProcessor(),
+          TimeMode.EventTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+        CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+    }
+  }
+
+  test("SPARK-54419: read all column families with processing time timers") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+    }
+  }
+
+  test("SPARK-54419: transformWithState with list state and TTL") {

Review Comment:
   Should we also add tests for MapState and ValueState with TTL?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala:
##########
@@ -533,4 +643,399 @@ class StatePartitionAllColumnFamiliesReaderSuite extends StateDataSourceTestBase
       }
     }
   }
+
+  test("SPARK-54419: transformWithState with multiple column families") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new MultiStateVarProcessor(),
+          TimeMode.None(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, "a", "b", "a"),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        AddData(inputData, "b", "c"),
+        CheckNewAnswer(("b", "2"), ("c", "1")),
+        StopStream
+      )
+
+      // Read all column families using internalOnlyReadAllColumnFamilies
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+      val allBytesData = bytesDf.collect()
+
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      // Verify countState column family exists
+      assert(columnFamilies.toSet ==
+        Set("countState", "itemsList", "$rowCounter_itemsList", "itemsMap"))
+
+      // Define schemas for each column family based on provided schema info
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val countStateValueSchema = StructType(Array(
+        StructField("value", LongType, nullable = false)
+      ))
+      val itemsListValueSchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val rowCounterValueSchema = StructType(Array(
+        StructField("count", LongType, nullable = true)
+      ))
+      val itemsMapKeySchema = StructType(Array(
+        StructField("key", groupByKeySchema),
+        StructField("user_map_key", groupByKeySchema, nullable = true)
+      ))
+      val itemsMapValueSchema = StructType(Array(
+        StructField("user_map_value", IntegerType, nullable = true)
+      ))
+
+      // Validate countState
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "countState", groupByKeySchema, countStateValueSchema)
+
+      // Validate itemsList
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsList", groupByKeySchema, itemsListValueSchema,
+        extraOptions = Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true"),
+        selectExprs = Seq("partition_id", "key", "list_element"))
+
+      // Validate $rowCounter_itemsList - intentionally reuses countState's data
+      val countStateNormalDf = getNormalReadDf(tempDir.getAbsolutePath, Option("countState"))
+      compareNormalAndBytesData(
+        countStateNormalDf.collect(),
+        allBytesData,
+        "$rowCounter_itemsList",
+        groupByKeySchema,
+        rowCounterValueSchema)
+
+      // Validate itemsMap
+      readAndValidateStateVar(
+        tempDir.getAbsolutePath, allBytesData,
+        stateVarName = "itemsMap", itemsMapKeySchema, itemsMapValueSchema,
+        selectExprs = Seq("partition_id", "STRUCT(key, user_map_key) AS KEY",
+          "user_map_value AS value"))
+    }
+  }
+
+  test("SPARK-54419: read all column families with event time timers") {
+    withTempDir { tempDir =>
+      val inputData = MemoryStream[(String, Long)]
+      val result = inputData.toDS()
+        .select(col("_1").as("key"), timestamp_seconds(col("_2")).as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Timestamp)]
+        .groupByKey(_._1)
+        .transformWithState(
+          new EventTimeTimerProcessor(),
+          TimeMode.EventTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath),
+        AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+        CheckLastBatch(("a", "1"), ("b", "1"), ("c", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "event")
+    }
+  }
+
+  test("SPARK-54419: read all column families with processing time timers") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "1")),
+        StopStream
+      )
+
+      validateTimerColumnFamilies(tempDir.getAbsolutePath, "proc")
+    }
+  }
+
+  test("SPARK-54419: transformWithState with list state and TTL") {
+    withTempDir { tempDir =>
+      val clock = new StreamManualClock
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new ListStateTTLProcessor(),
+          TimeMode.ProcessingTime(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        StartStream(checkpointLocation = tempDir.getAbsolutePath,
+          trigger = Trigger.ProcessingTime("1 second"),
+          triggerClock = clock),
+        AddData(inputData, "a", "b", "a"),
+        AdvanceManualClock(1 * 1000),
+        CheckNewAnswer(("a", "2"), ("b", "1")),
+        StopStream
+      )
+
+      val bytesDf = getBytesReadDf(tempDir.getAbsolutePath)
+      validateBytesReadDfSchema(bytesDf)
+
+      val allBytesData = bytesDf.collect()
+      val columnFamilies = allBytesData.map(_.getString(3)).distinct.sorted
+
+      assert(columnFamilies.toSet ==
+        Set("listState", "$ttl_listState", "$min_listState", "$count_listState"))
+
+      // Define schemas for list state with TTL column families
+      val groupByKeySchema = StructType(Array(
+        StructField("value", StringType, nullable = true)
+      ))
+      val listStateValueSchema = StructType(Array(
+        StructField("value", StructType(Array(
+          StructField("value", StringType, nullable = true)
+        )), nullable = false),
+        StructField("ttlExpirationMs", LongType, nullable = false)
+      ))
+
+      val listStateNormalDf = spark.read
+        .format("statestore")
+        .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+        .option(StateSourceOptions.STATE_VAR_NAME, "listState")
+        .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "true")
+        .load()
+        .selectExpr("partition_id", "key", "list_element")
+
+      compareNormalAndBytesData(
+        listStateNormalDf.collect(),
+        allBytesData,
+        "listState",
+        groupByKeySchema,
+        listStateValueSchema)
+      val dummyValueSchema = StructType(Array(StructField("__dummy__", NullType)))
+      val ttlIndexKeySchema = StructType(Array(
+        StructField("expirationMs", LongType, nullable = false),
+        StructField("elementKey", groupByKeySchema)
+      ))
+      val minExpiryValueSchema = StructType(Array(
+        StructField("minExpiry", LongType)
+      ))
+      val countValueSchema = StructType(Array(
+        StructField("count", LongType)
+      ))
+      val columnFamilyAndKeyValueSchema = Seq(
+        ("$ttl_listState", ttlIndexKeySchema, dummyValueSchema),
+        ("$min_listState", groupByKeySchema, minExpiryValueSchema),
+        ("$count_listState", groupByKeySchema, countValueSchema)
+      )
+      columnFamilyAndKeyValueSchema.foreach(pair => {
+        val normalDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, pair._1)
+          .load()
+          .selectExpr("partition_id", "key", "value")
+
+        compareNormalAndBytesData(
+          normalDf.collect(),
+          allBytesData,
+          pair._1,
+          pair._2,
+          pair._3)
+      }
+      )
+    }
+  }
+
+  def testStreamStreamJoinV3(): Unit = {

Review Comment:
   This seems to do exactly the same thing as the `testStreamStreamJoinV2` function. Why not use a single function for both and pass in different parameters? Or am I missing something?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to