micheal-o commented on code in PR #53386:
URL: https://github.com/apache/spark/pull/53386#discussion_r2625003265


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -146,18 +148,47 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     assert(!checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".changelog"))
     assert(checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".zip"))
 
-    // Step 4: Read from target using normal reader
-    val targetReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, targetDir)
-    val targetNormalData = (storeName match {
-      case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME, 
name)
-      case None => targetReader
-    }).load()
-      .selectExpr("key", "value", "partition_id")
-      .collect()
+    // Step 3: Validate by reading from both source and target using normal 
reader"
+    // Default selectExprs for most column families
+    val defaultSelectExprs = Seq("key", "value", "partition_id")
+
+    def shouldCheckColumnFamilyName: String => Boolean = name => {
+      (!name.startsWith("$")
+        || (columnFamilyToStateSourceOptions.contains(name) &&

Review Comment:
   Please add a comment explaining what this function does.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -146,18 +148,47 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
     assert(!checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".changelog"))
     assert(checkpointFileExists(new File(targetDir, storeNamePath), 
versionToCheck, ".zip"))
 
-    // Step 4: Read from target using normal reader
-    val targetReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, targetDir)
-    val targetNormalData = (storeName match {
-      case Some(name) => targetReader.option(StateSourceOptions.STORE_NAME, 
name)
-      case None => targetReader
-    }).load()
-      .selectExpr("key", "value", "partition_id")
-      .collect()
+    // Step 3: Validate by reading from both source and target using normal 
reader"
+    // Default selectExprs for most column families
+    val defaultSelectExprs = Seq("key", "value", "partition_id")
+
+    def shouldCheckColumnFamilyName: String => Boolean = name => {
+      (!name.startsWith("$")
+        || (columnFamilyToStateSourceOptions.contains(name) &&
+        
columnFamilyToStateSourceOptions(name).contains(StateSourceOptions.READ_REGISTERED_TIMERS)))
+    }
+    // Validate each column family separately (skip internal column families 
starting with $)

Review Comment:
   Let's also handle internal column families (CFs).



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -47,30 +52,39 @@ class StatePartitionAllColumnFamiliesWriter(
     operatorId: Int,
     storeName: String,
     currentBatchId: Long,
-    columnFamilyToSchemaMap: HashMap[String, StateStoreColFamilySchema]) {
+    columnFamilyToSchemaMap: HashMap[String, 
StatePartitionWriterColumnFamilyInfo]) {
   private val defaultSchema = {
-    columnFamilyToSchemaMap.getOrElse(
-      StateStore.DEFAULT_COL_FAMILY_NAME,
-      throw new IllegalArgumentException(
-        s"Column family ${StateStore.DEFAULT_COL_FAMILY_NAME} not found in 
schema map")
-    )
+    columnFamilyToSchemaMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
+      case Some(info) => info.schema
+      case None =>
+        // Return a dummy StateStoreColFamilySchema if not found
+        val placeholderSchema = columnFamilyToSchemaMap.head._2.schema

Review Comment:
   We are doing this just for join v3, right? Then let's do it only for that 
case, and throw for the other cases.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -82,11 +96,33 @@ class StatePartitionAllColumnFamiliesWriter(
     // Use loadEmpty=true to create a fresh state store without loading 
previous versions
     // We create the empty store AT version, and the next commit will
     // produce version + 1
-    provider.getStore(
+    val store = provider.getStore(
       currentBatchId,
       stateStoreCkptId = None,
       loadEmpty = true
     )
+    if (columnFamilyToSchemaMap.size > 1) {

Review Comment:
   nit: you can just have a class-level val `useColumnFamilies` and use it 
here and in the other place.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -468,15 +593,373 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            keySchema,
-            valueSchema,
-            keyStateEncoderSpec
+            createSingleColumnFamilySchemaMap(keySchema, valueSchema, 
keyStateEncoderSpec)
           )
         }
       }
     }
   }
 
+  /**
+   * Helper method to test round-trip for transformWithState with multiple 
column families.
+   * Uses MultiStateVarProcessor which creates ValueState, ListState, and 
MapState.
+   */
+  private def testTransformWithStateMultiColumnFamilies(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val inputData = MemoryStream[String]
+        val query = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new MultiStateVarProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+        def runQuery(checkpointLocation: String, roundsOfData: Int): Unit = {
+          val dataActions = (1 to roundsOfData).flatMap { _ =>
+            Seq(
+              AddData(inputData, "a", "b", "a"),
+              ProcessAllAvailable()
+            )
+          }
+          testStream(query)(
+            Seq(StartStream(checkpointLocation = checkpointLocation)) ++
+              dataActions ++
+              Seq(StopStream): _*
+          )
+        }
+
+        // Step 1: Add data to source
+        runQuery(sourceDir.getAbsolutePath, 2)
+        // Step 2: Add data to target
+        runQuery(targetDir.getAbsolutePath, 1)
+
+        // Step 3: Define schemas for all column families
+        val groupByKeySchema = StructType(Array(

Review Comment:
   Ditto: should this move to a common test util?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -82,11 +96,33 @@ class StatePartitionAllColumnFamiliesWriter(
     // Use loadEmpty=true to create a fresh state store without loading 
previous versions
     // We create the empty store AT version, and the next commit will
     // produce version + 1
-    provider.getStore(
+    val store = provider.getStore(
       currentBatchId,
       stateStoreCkptId = None,
       loadEmpty = true
     )
+    if (columnFamilyToSchemaMap.size > 1) {
+      columnFamilyToSchemaMap.foreach { pair =>
+        val colFamilyName = pair._1
+        val cfSchema = pair._2.schema
+        colFamilyName match {
+          case StateStore.DEFAULT_COL_FAMILY_NAME => // createAndInit has 
registered default
+          case _ =>
+            val isInternal = 
StateStoreColumnFamilySchemaUtils.isInternalColFamily(colFamilyName)
+
+            require(cfSchema.keyStateEncoderSpec.isDefined,
+              s"keyStateEncoderSpec must be defined for column family 
${cfSchema.colFamilyName}")
+            store.createColFamilyIfAbsent(
+              colFamilyName,
+              cfSchema.keySchema,
+              cfSchema.valueSchema,
+              cfSchema.keyStateEncoderSpec.get,
+              columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey,

Review Comment:
   You already have this as `pair._2.useMultipleValuesPerKey`, right?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -676,5 +1151,25 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
         testStreamStreamJoinRoundTrip(version)
       }
     }
+
+    testWithChangelogConfig("SPARK-54411: stream-stream join state ver 3") {

Review Comment:
   For functions that are only used once, there is no need to create a 
separate function. Just implement the logic within the test case. We use a 
function when it will be called by multiple test cases.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -123,6 +159,12 @@ class StatePartitionAllColumnFamiliesWriter(
     val valueRow = new 
UnsafeRow(columnFamilyToValueSchemaLenMap(colFamilyName))
     valueRow.pointTo(valueBytes, valueBytes.length)
 
-    stateStore.put(keyRow, valueRow, colFamilyName)
+    if (columnFamilyToSchemaMap(colFamilyName).useMultipleValuesPerKey) {
+      // if a column family useMultipleValuesPerKey (e.g. ListType), we will
+      // write with 1 put followed by merge
+      stateStore.merge(keyRow, valueRow, colFamilyName)

Review Comment:
   This is wrong, right? Your comment here says to do one put and then merge 
for the rest, but you are only doing merge for everything. You can use 
`stateStore.keyExists(keyRow, colFamilyName)` to determine whether to put or 
merge.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -130,9 +118,23 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
         currentBatchId,
         columnFamilyToSchemaMap
       )
-      val rowConverter = 
CatalystTypeConverters.createToCatalystConverter(schema)
 
-      
allCFWriter.write(partition.map(rowConverter(_).asInstanceOf[InternalRow]))
+      // Use per-column-family converters when there are multiple column 
families
+      if (columnFamilyToSchemaMap.size > 1) {
+        // TODO: Remove the logic of getting colNameToRowConverter once 
allColumnFamiliesReader is
+        // returning actual partitionKeySchema instead of the entire key
+        val colNameToRowConverter = columnFamilyToSchemaMap.view.mapValues { 
colInfo =>

Review Comment:
   nit: cfNameToRowConverter instead?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -203,6 +234,37 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
         }
     }
 
+  private def createColFamilyInfo(
+       keySchema: StructType,
+       valueSchema: StructType,
+       keyStateEncoderSpec: KeyStateEncoderSpec,
+       colFamilyName: String,
+       useMultipleValuePerKey: Boolean = false): 
StatePartitionWriterColumnFamilyInfo = {
+    StatePartitionWriterColumnFamilyInfo(
+      schema = StateStoreColFamilySchema(
+        colFamilyName,
+        keySchemaId = 0,
+        keySchema,
+        valueSchemaId = 0,
+        valueSchema,
+        keyStateEncoderSpec = Some(keyStateEncoderSpec)
+      ),
+      useMultipleValuePerKey)
+  }

Review Comment:
   nit: add a blank line here.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -329,15 +389,42 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            keySchema,
-            valueSchema,
-            keyStateEncoderSpec
+            createSingleColumnFamilySchemaMap(keySchema, valueSchema, 
keyStateEncoderSpec)
           )
         }
       }
     }
   }
 
+  private val keyToNumValuesColFamilyNames = Seq("left-keyToNumValues", 
"right-keyToNumValues")

Review Comment:
   Many of these duplicate what we do in the multi-CF reader suite. Should we 
move them to a common test util, or are they different?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -53,31 +57,27 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
    *
    * @param sourceDir Source checkpoint directory
    * @param targetDir Target checkpoint directory
-   * @param keySchema Key schema for the state store
-   * @param valueSchema Value schema for the state store
-   * @param keyStateEncoderSpec Key state encoder spec
+   * @param columnFamilyToSchemaMap Map of column family names to their schemas
    * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
+   * @param columnFamilyToSelectExprs Map of column family names to custom 
selectExprs
+   * @param columnFamilyToStateSourceOptions Map of column family names to 
state source options
    */
   private def performRoundTripTest(
       sourceDir: String,
       targetDir: String,
-      keySchema: StructType,
-      valueSchema: StructType,
-      keyStateEncoderSpec: KeyStateEncoderSpec,
-      storeName: Option[String] = None): Unit = {
-
-    // Step 1: Read original state using normal reader (for comparison later)
-    val sourceReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, sourceDir)
-    val sourceNormalData = (storeName match {
-      case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, 
name)
-      case None => sourceReader
-    }).load()
-      .selectExpr("key", "value", "partition_id")
-      .collect()
+      columnFamilyToSchemaMap: HashMap[String, 
StatePartitionWriterColumnFamilyInfo],
+      storeName: Option[String] = None,
+      columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
+      columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = 
Map.empty): Unit = {
+
+    // Determine column families to validate based on storeName and map size
+    val columnFamiliesToValidate: Seq[String] = storeName match {
+      case Some(name) => Seq(name)

Review Comment:
   why?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -468,15 +593,373 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            keySchema,
-            valueSchema,
-            keyStateEncoderSpec
+            createSingleColumnFamilySchemaMap(keySchema, valueSchema, 
keyStateEncoderSpec)
           )
         }
       }
     }
   }
 
+  /**
+   * Helper method to test round-trip for transformWithState with multiple 
column families.
+   * Uses MultiStateVarProcessor which creates ValueState, ListState, and 
MapState.
+   */
+  private def testTransformWithStateMultiColumnFamilies(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val inputData = MemoryStream[String]
+        val query = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new MultiStateVarProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+        def runQuery(checkpointLocation: String, roundsOfData: Int): Unit = {
+          val dataActions = (1 to roundsOfData).flatMap { _ =>
+            Seq(
+              AddData(inputData, "a", "b", "a"),
+              ProcessAllAvailable()
+            )
+          }
+          testStream(query)(
+            Seq(StartStream(checkpointLocation = checkpointLocation)) ++
+              dataActions ++
+              Seq(StopStream): _*
+          )
+        }
+
+        // Step 1: Add data to source
+        runQuery(sourceDir.getAbsolutePath, 2)
+        // Step 2: Add data to target
+        runQuery(targetDir.getAbsolutePath, 1)
+
+        // Step 3: Define schemas for all column families
+        val groupByKeySchema = StructType(Array(
+          StructField("value", StringType, nullable = true)
+        ))
+        val countStateValueSchema = StructType(Array(
+          StructField("value", LongType, nullable = false)
+        ))
+        val itemsListValueSchema = StructType(Array(
+          StructField("value", StringType, nullable = true)
+        ))
+        val rowCounterValueSchema = StructType(Array(
+          StructField("count", LongType, nullable = true)
+        ))
+        val itemsMapKeySchema = StructType(Array(
+          StructField("key", StringType),
+          StructField("user_map_key", groupByKeySchema, nullable = true)
+        ))
+        val itemsMapValueSchema = StructType(Array(
+          StructField("user_map_value", IntegerType, nullable = true)
+        ))
+
+        // Build column family to schema map for all 4 state variables
+        val countStateEncoderSpec = 
NoPrefixKeyStateEncoderSpec(groupByKeySchema)
+        val itemsMapEncoderSpec = 
PrefixKeyScanStateEncoderSpec(itemsMapKeySchema, 1)
+
+        val columnFamilyToSchemaMap = HashMap(
+          "countState" -> createColFamilyInfo(
+            groupByKeySchema, countStateValueSchema, countStateEncoderSpec, 
"countState"),
+          "itemsList" -> createColFamilyInfo(
+            groupByKeySchema, itemsListValueSchema, countStateEncoderSpec, 
"itemsList", true),
+          "$rowCounter_itemsList" -> createColFamilyInfo(
+            groupByKeySchema, rowCounterValueSchema,
+            countStateEncoderSpec, "$rowCounter_itemsList"),
+          "itemsMap" -> createColFamilyInfo(
+            itemsMapKeySchema, itemsMapValueSchema, itemsMapEncoderSpec, 
"itemsMap")
+        )
+
+        // Define custom selectExprs for column families with non-standard 
schemas
+        val columnFamilyToSelectExprs = Map(
+          "itemsList" -> Seq("key", "list_element AS value", "partition_id"),
+          "itemsMap" -> Seq("STRUCT(key, user_map_key) AS key", 
"user_map_value AS value",
+            "partition_id")
+        )
+
+        // Define reader options for column families that need them
+        val columnFamilyToStateSourceOptions = Map(
+          "itemsList" -> Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true",
+            StateSourceOptions.STATE_VAR_NAME -> "itemsList"),
+          "itemsMap" -> Map(StateSourceOptions.STATE_VAR_NAME -> "itemsMap"),
+          "countState" -> Map(StateSourceOptions.STATE_VAR_NAME -> 
"countState")
+        )
+
+        // Perform round-trip test using common helper
+        performRoundTripTest(
+          sourceDir.getAbsolutePath,
+          targetDir.getAbsolutePath,
+          columnFamilyToSchemaMap,
+          columnFamilyToSelectExprs = columnFamilyToSelectExprs,
+          columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
+        )
+      }
+    }
+  }
+
+  /**
+   * Helper method to build timer column family schemas and options for 
timerProcesser
+   * that has groupingKey of STRING type and keeps track of a "countState" of 
LONG type.
+   * Used by both event time and processing time timer tests
+   *
+   * @param timeMode Either TimeMode.EventTime() or TimeMode.ProcessingTime()
+   * @return A tuple of three elements:
+   *         - columnFamilyToSchemaMap: Maps column family names to their 
schema info
+   *         - columnFamilyToSelectExprs: Maps column family names to custom 
select expressions
+   *         - columnFamilyToStateSourceOptions: Maps column family names to 
state source options
+   */
+  private def getTimerStateConfigsForCountState(timeMode: TimeMode): (
+      HashMap[String, StatePartitionWriterColumnFamilyInfo],
+      Map[String, Seq[String]],
+      Map[String, Map[String, String]]) = {
+
+    val groupByKeySchema = StructType(Array(

Review Comment:
   Ditto: should this move to a common test util?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -53,31 +57,27 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
    *
    * @param sourceDir Source checkpoint directory
    * @param targetDir Target checkpoint directory
-   * @param keySchema Key schema for the state store
-   * @param valueSchema Value schema for the state store
-   * @param keyStateEncoderSpec Key state encoder spec
+   * @param columnFamilyToSchemaMap Map of column family names to their schemas
    * @param storeName Optional store name (for stream-stream join which has 
multiple stores)
+   * @param columnFamilyToSelectExprs Map of column family names to custom 
selectExprs
+   * @param columnFamilyToStateSourceOptions Map of column family names to 
state source options
    */
   private def performRoundTripTest(
       sourceDir: String,
       targetDir: String,
-      keySchema: StructType,
-      valueSchema: StructType,
-      keyStateEncoderSpec: KeyStateEncoderSpec,
-      storeName: Option[String] = None): Unit = {
-
-    // Step 1: Read original state using normal reader (for comparison later)
-    val sourceReader = spark.read
-      .format("statestore")
-      .option(StateSourceOptions.PATH, sourceDir)
-    val sourceNormalData = (storeName match {
-      case Some(name) => sourceReader.option(StateSourceOptions.STORE_NAME, 
name)
-      case None => sourceReader
-    }).load()
-      .selectExpr("key", "value", "partition_id")
-      .collect()
+      columnFamilyToSchemaMap: HashMap[String, 
StatePartitionWriterColumnFamilyInfo],
+      storeName: Option[String] = None,
+      columnFamilyToSelectExprs: Map[String, Seq[String]] = Map.empty,
+      columnFamilyToStateSourceOptions: Map[String, Map[String, String]] = 
Map.empty): Unit = {
+
+    // Determine column families to validate based on storeName and map size
+    val columnFamiliesToValidate: Seq[String] = storeName match {
+      case Some(name) => Seq(name)
+      case None if columnFamilyToSchemaMap.size > 1 => 
columnFamilyToSchemaMap.keys.toSeq
+      case None => Seq(StateStoreId.DEFAULT_STORE_NAME)

Review Comment:
   nit: use `StateStore.DEFAULT_COL_FAMILY_NAME` instead?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -468,15 +593,373 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
           performRoundTripTest(
             sourceDir.getAbsolutePath,
             targetDir.getAbsolutePath,
-            keySchema,
-            valueSchema,
-            keyStateEncoderSpec
+            createSingleColumnFamilySchemaMap(keySchema, valueSchema, 
keyStateEncoderSpec)
           )
         }
       }
     }
   }
 
+  /**
+   * Helper method to test round-trip for transformWithState with multiple 
column families.
+   * Uses MultiStateVarProcessor which creates ValueState, ListState, and 
MapState.
+   */
+  private def testTransformWithStateMultiColumnFamilies(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val inputData = MemoryStream[String]
+        val query = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new MultiStateVarProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+        def runQuery(checkpointLocation: String, roundsOfData: Int): Unit = {
+          val dataActions = (1 to roundsOfData).flatMap { _ =>
+            Seq(
+              AddData(inputData, "a", "b", "a"),
+              ProcessAllAvailable()
+            )
+          }
+          testStream(query)(
+            Seq(StartStream(checkpointLocation = checkpointLocation)) ++
+              dataActions ++
+              Seq(StopStream): _*
+          )
+        }
+
+        // Step 1: Add data to source
+        runQuery(sourceDir.getAbsolutePath, 2)
+        // Step 2: Add data to target
+        runQuery(targetDir.getAbsolutePath, 1)
+
+        // Step 3: Define schemas for all column families
+        val groupByKeySchema = StructType(Array(
+          StructField("value", StringType, nullable = true)
+        ))
+        val countStateValueSchema = StructType(Array(
+          StructField("value", LongType, nullable = false)
+        ))
+        val itemsListValueSchema = StructType(Array(
+          StructField("value", StringType, nullable = true)
+        ))
+        val rowCounterValueSchema = StructType(Array(
+          StructField("count", LongType, nullable = true)
+        ))
+        val itemsMapKeySchema = StructType(Array(
+          StructField("key", StringType),
+          StructField("user_map_key", groupByKeySchema, nullable = true)
+        ))
+        val itemsMapValueSchema = StructType(Array(
+          StructField("user_map_value", IntegerType, nullable = true)
+        ))
+
+        // Build column family to schema map for all 4 state variables
+        val countStateEncoderSpec = 
NoPrefixKeyStateEncoderSpec(groupByKeySchema)
+        val itemsMapEncoderSpec = 
PrefixKeyScanStateEncoderSpec(itemsMapKeySchema, 1)
+
+        val columnFamilyToSchemaMap = HashMap(
+          "countState" -> createColFamilyInfo(
+            groupByKeySchema, countStateValueSchema, countStateEncoderSpec, 
"countState"),
+          "itemsList" -> createColFamilyInfo(
+            groupByKeySchema, itemsListValueSchema, countStateEncoderSpec, 
"itemsList", true),
+          "$rowCounter_itemsList" -> createColFamilyInfo(
+            groupByKeySchema, rowCounterValueSchema,
+            countStateEncoderSpec, "$rowCounter_itemsList"),
+          "itemsMap" -> createColFamilyInfo(
+            itemsMapKeySchema, itemsMapValueSchema, itemsMapEncoderSpec, 
"itemsMap")
+        )
+
+        // Define custom selectExprs for column families with non-standard 
schemas
+        val columnFamilyToSelectExprs = Map(
+          "itemsList" -> Seq("key", "list_element AS value", "partition_id"),
+          "itemsMap" -> Seq("STRUCT(key, user_map_key) AS key", 
"user_map_value AS value",
+            "partition_id")
+        )
+
+        // Define reader options for column families that need them
+        val columnFamilyToStateSourceOptions = Map(
+          "itemsList" -> Map(StateSourceOptions.FLATTEN_COLLECTION_TYPES -> 
"true",
+            StateSourceOptions.STATE_VAR_NAME -> "itemsList"),
+          "itemsMap" -> Map(StateSourceOptions.STATE_VAR_NAME -> "itemsMap"),
+          "countState" -> Map(StateSourceOptions.STATE_VAR_NAME -> 
"countState")
+        )
+
+        // Perform round-trip test using common helper
+        performRoundTripTest(
+          sourceDir.getAbsolutePath,
+          targetDir.getAbsolutePath,
+          columnFamilyToSchemaMap,
+          columnFamilyToSelectExprs = columnFamilyToSelectExprs,
+          columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
+        )
+      }
+    }
+  }
+
+  /**
+   * Helper method to build timer column family schemas and options for 
timerProcesser
+   * that has groupingKey of STRING type and keeps track of a "countState" of 
LONG type.
+   * Used by both event time and processing time timer tests
+   *
+   * @param timeMode Either TimeMode.EventTime() or TimeMode.ProcessingTime()
+   * @return A tuple of three elements:
+   *         - columnFamilyToSchemaMap: Maps column family names to their 
schema info
+   *         - columnFamilyToSelectExprs: Maps column family names to custom 
select expressions
+   *         - columnFamilyToStateSourceOptions: Maps column family names to 
state source options
+   */
+  private def getTimerStateConfigsForCountState(timeMode: TimeMode): (
+      HashMap[String, StatePartitionWriterColumnFamilyInfo],
+      Map[String, Seq[String]],
+      Map[String, Map[String, String]]) = {
+
+    val groupByKeySchema = StructType(Array(
+      StructField("key", StringType, nullable = true)
+    ))
+    val stateValueSchema = StructType(Array(
+      StructField("value", LongType, nullable = true)
+    ))
+    val keyToTimestampKeySchema = StructType(Array(
+      StructField("key", StringType),
+      StructField("expiryTimestampMs", LongType, nullable = false)
+    ))
+    val timestampToKeyKeySchema = StructType(Array(
+      StructField("expiryTimestampMs", LongType, nullable = false),
+      StructField("key", StringType)
+    ))
+    val dummyValueSchema = StructType(Array(StructField("__dummy__", 
NullType)))
+
+    val encoderSpec = NoPrefixKeyStateEncoderSpec(groupByKeySchema)
+    val keyToTimestampEncoderSpec = 
PrefixKeyScanStateEncoderSpec(keyToTimestampKeySchema, 1)
+    val timestampToKeyEncoderSpec = 
RangeKeyScanStateEncoderSpec(timestampToKeyKeySchema, Seq(0))
+
+    val (keyToTimestampCF, timestampToKeyCF) =
+      TimerStateUtils.getTimerStateVarNames(timeMode.toString)
+
+    val columnFamilyToSchemaMap = HashMap(
+      "countState" -> createColFamilyInfo(
+        groupByKeySchema, stateValueSchema, encoderSpec, "countState"),
+      keyToTimestampCF -> createColFamilyInfo(
+        keyToTimestampKeySchema, dummyValueSchema, keyToTimestampEncoderSpec, 
keyToTimestampCF),
+      timestampToKeyCF -> createColFamilyInfo(
+        timestampToKeyKeySchema, dummyValueSchema, timestampToKeyEncoderSpec, 
timestampToKeyCF)
+    )
+
+    val columnFamilyToSelectExprs = Map(
+      keyToTimestampCF -> Seq(
+        "STRUCT(key AS groupingKey, expiration_timestamp_ms AS key) AS key",
+        "NULL AS value", "partition_id"),
+      timestampToKeyCF -> Seq(
+        "STRUCT(expiration_timestamp_ms AS key, key AS groupingKey) AS key",
+        "NULL AS value", "partition_id")
+    )
+
+    val columnFamilyToStateSourceOptions = Map(
+      "countState" -> Map(StateSourceOptions.STATE_VAR_NAME -> "countState"),
+      keyToTimestampCF -> Map(StateSourceOptions.READ_REGISTERED_TIMERS -> 
"true"),
+      timestampToKeyCF -> Map(StateSourceOptions.READ_REGISTERED_TIMERS -> 
"true")
+    )
+
+    (columnFamilyToSchemaMap, columnFamilyToSelectExprs, 
columnFamilyToStateSourceOptions)
+  }
+
+  /**
+   * Helper method to test round-trip for transformWithState with event time 
timers.
+   */
+  private def testEventTimeTimersRoundTrip(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val inputData = MemoryStream[(String, Long)]
+        val result = inputData.toDS()
+          .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+          .withWatermark("eventTime", "10 seconds")
+          .as[(String, Timestamp)]
+          .groupByKey(_._1)
+          .transformWithState(
+            new EventTimeTimerProcessor(),
+            TimeMode.EventTime(),
+            OutputMode.Update())
+
+        // Step 1: Create source checkpoint
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = sourceDir.getAbsolutePath),
+          AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+          ProcessAllAvailable(),
+          StopStream
+        )
+
+        // Step 2: Create target checkpoint with dummy data
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = targetDir.getAbsolutePath),
+          AddData(inputData, ("x", 1L)),
+          ProcessAllAvailable(),
+          StopStream
+        )
+
+        // Step 3: Build timer column family configs and perform round-trip 
test
+        val (columnFamilyToSchemaMap, columnFamilyToSelectExprs, 
columnFamilyToStateSourceOptions) =
+          getTimerStateConfigsForCountState(TimeMode.EventTime())
+
+        performRoundTripTest(
+          sourceDir.getAbsolutePath,
+          targetDir.getAbsolutePath,
+          columnFamilyToSchemaMap,
+          columnFamilyToSelectExprs = columnFamilyToSelectExprs,
+          columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
+        )
+      }
+    }
+  }
+
+  /**
+   * Helper method to test round-trip for transformWithState with processing 
time timers.
+   */
+  private def testProcessingTimeTimersRoundTrip(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val clock = new StreamManualClock
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        // Step 1: Create source checkpoint
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+
+        // Step 2: Create target checkpoint with dummy data
+        val clock2 = new StreamManualClock
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = targetDir.getAbsolutePath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock2),
+          AddData(inputData, "x"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1"), ("x", "1")),
+          StopStream
+        )
+
+        // Step 3: Build timer column family configs and perform round-trip 
test
+        val (columnFamilyToSchemaMap, columnFamilyToSelectExprs, 
columnFamilyToStateSourceOptions) =
+          getTimerStateConfigsForCountState(TimeMode.ProcessingTime())
+
+        performRoundTripTest(
+          sourceDir.getAbsolutePath,
+          targetDir.getAbsolutePath,
+          columnFamilyToSchemaMap,
+          columnFamilyToSelectExprs = columnFamilyToSelectExprs,
+          columnFamilyToStateSourceOptions = columnFamilyToStateSourceOptions
+        )
+      }
+    }
+  }
+
+  /**
+   * Helper method to test round-trip for transformWithState with list state 
and TTL.
+   */
+  private def testListStateTTLRoundTrip(): Unit = {
+    withTempDir { sourceDir =>
+      withTempDir { targetDir =>
+        val clock = new StreamManualClock
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new ListStateTTLProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        // Step 1: Create source checkpoint
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = sourceDir.getAbsolutePath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData, "a", "b", "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "2"), ("b", "1")),
+          StopStream
+        )
+
+        // Step 2: Create target checkpoint with dummy data
+        val clock2 = new StreamManualClock
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = targetDir.getAbsolutePath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock2),
+          AddData(inputData, "x"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "2"), ("b", "1"), ("x", "1")),
+          StopStream
+        )
+
+        // Step 3: Define schemas for list state with TTL column families
+        val groupByKeySchema = StructType(Array(

Review Comment:
   Ditto: let's see how much of this we can move to a common test util, so we 
can easily reuse it in new test suites.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionAllColumnFamiliesWriterSuite.scala:
##########
@@ -676,5 +1151,25 @@ class StatePartitionAllColumnFamiliesWriterSuite extends 
StateDataSourceTestBase
         testStreamStreamJoinRoundTrip(version)
       }
     }
+
+    testWithChangelogConfig("SPARK-54411: stream-stream join state ver 3") {
+      testStreamStreamJoinV3RoundTrip()
+    }
+
+    testWithChangelogConfig("SPARK-54411: transformWithState with multiple 
column families") {

Review Comment:
   For the TWS test cases, let's run them with the different encoding formats, 
i.e. UnsafeRow and Avro. See `AlsoTestWithEncodingTypes`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to