This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f9cc3dd00b22 [SPARK-55106][SS] Add Repartition Integration test for TransformWithState Operators
f9cc3dd00b22 is described below

commit f9cc3dd00b225c97c49b5286bf4f950b896f4f40
Author: zifeif2 <[email protected]>
AuthorDate: Tue Feb 3 11:38:39 2026 -0800

    [SPARK-55106][SS] Add Repartition Integration test for TransformWithState Operators
    
    ### What changes were proposed in this pull request?
    
    Add offline repartition integration tests for transformWithState operators.
    Tested scenarios include:
    - TWS with multiple column families
    - event-time timers
    - processing-time timers
    - TTL
    - List and Map state types
    
    Main changes to the test utility functions:
    - testRepartitionWorkflow supports per-column-family selectExprs
    - Verify before- vs. after-repartition data by serializing the entire row to a string, instead of extracting only the key and value
    
    ### Why are the changes needed?
    
    Ensure that OfflineRepartitionRunner works as expected for transformWithState operators.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    See the added integration tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, Sonnet 4.5
    
    Closes #53899 from zifeif2/test-tws.
    
    Lead-authored-by: zifeif2 <[email protected]>
    Co-authored-by: Zifei Feng <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 ...tatePartitionAllColumnFamiliesReaderSuite.scala |   4 +-
 .../OfflineStateRepartitionIntegrationSuite.scala  |  61 +--
 ...rtitionTransformWithStateIntegrationSuite.scala | 493 +++++++++++++++++++++
 .../util/TransformWithStateTestUtils.scala         |  19 +-
 4 files changed, 542 insertions(+), 35 deletions(-)

diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
index 4eb571804935..734177c0d705 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionAllColumnFamiliesReaderSuite.scala
@@ -493,9 +493,9 @@ class StatePartitionAllColumnFamiliesReaderSuite extends 
StateDataSourceTestBase
         testStream(result, OutputMode.Update())(
           StartStream(checkpointLocation = tempDir.getAbsolutePath),
           AddData(inputData, "a", "b", "a"),
-          CheckNewAnswer(("a", "2"), ("b", "1")),
+          ProcessAllAvailable(),
           AddData(inputData, "b", "c"),
-          CheckNewAnswer(("b", "2"), ("c", "1")),
+          ProcessAllAvailable(),
           StopStream
         )
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionIntegrationSuite.scala
index 2ea6f3476ae7..caf60d51ed62 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionIntegrationSuite.scala
@@ -37,7 +37,7 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
     super.beforeAll()
     spark.conf.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
       classOf[RocksDBStateStoreProvider].getName)
-    spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
+    spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "3")
   }
 
   /**
@@ -47,7 +47,8 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
       checkpointDir: String,
       batchId: Long,
       storeName: String,
-      additionalOptions: Map[String, String]): Dataset[Row] = {
+      additionalOptions: Map[String, String],
+      selectExprs: Seq[String]): Dataset[Row] = {
     var reader = spark.read
       .format("statestore")
       .option(StateSourceOptions.PATH, checkpointDir)
@@ -59,8 +60,11 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
       reader = reader.option(k, v)
     }
 
+    // Not querying partition id because a key will be in different partitions
+    // before and after repartitioning
+    val selectExprsWithoutPartitionId = selectExprs.filterNot(_ == 
"partition_id")
     reader.load()
-      .selectExpr("key", "value", "partition_id")
+      .selectExpr(selectExprsWithoutPartitionId: _*)
       .orderBy("key")
   }
 
@@ -72,27 +76,22 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
   private def readStateDataByStoreName(
       checkpointDir: String,
       batchId: Long,
-      storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, 
Map[String, String]]]
+      storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, 
Map[String, String]]],
+      storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] 
= Map.empty
   ): Map[String, Map[String, Array[Row]]] = {
     storeToColumnFamilyToStateSourceOptions.map { case (storeName, 
columnFamilyToOptions) =>
       val columnFamilyData = columnFamilyToOptions.map { case (cfName, 
options) =>
+        val selectExprs = storeToColumnFamilyToSelectExprs
+          .getOrElse(storeName, Map.empty)
+          .getOrElse(cfName, Seq("key", "value"))
         val stateData = readStateData(
-          checkpointDir, batchId, storeName, options).collect()
+          checkpointDir, batchId, storeName, options, selectExprs).collect()
         cfName -> stateData
       }
       storeName -> columnFamilyData
     }
   }
 
-  /**
-   * Extracts (key, value) pairs from Row array and sorts by key.
-   */
-  private def extractKeyValuePairs(rows: Array[Row]): Array[(String, String)] 
= {
-    rows
-      .map(row => (row.get(0).toString, row.get(1).toString))
-      .sortBy(_._1)
-  }
-
   /**
    * Core helper function that encapsulates the complete repartition test 
workflow:
    * 1. Run query to create initial state
@@ -109,6 +108,8 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
    * @param useManualClock Whether this test requires a manual clock (for 
processing time)
    * @param storeToColumnFamilyToStateSourceOptions Map of store name -> 
column family
    *                                                 name -> options
+   * @param storeToColumnFamilyToSelectExprs Map of store name -> column family
+   *                                          name -> select expressions
    * @tparam T The type of data in the input stream (requires implicit Encoder)
    */
   def testRepartitionWorkflow[T : Encoder](
@@ -118,7 +119,8 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
       useManualClock: Boolean = false,
       storeToColumnFamilyToStateSourceOptions: Map[String, Map[String, 
Map[String, String]]] =
         Map(StateStoreId.DEFAULT_STORE_NAME ->
-          Map(StateStore.DEFAULT_COL_FAMILY_NAME -> Map.empty))): Unit = {
+          Map(StateStore.DEFAULT_COL_FAMILY_NAME -> Map.empty)),
+      storeToColumnFamilyToSelectExprs: Map[String, Map[String, Seq[String]]] 
= Map.empty): Unit = {
     withTempDir { checkpointDir =>
       val clock = if (useManualClock) Some(new StreamManualClock) else None
       val inputData = MemoryStream[T]
@@ -132,8 +134,8 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
       val lastBatchId = checkpointMetadata.commitLog.getLatestBatchId().get
 
       val stateBeforeRepartition = readStateDataByStoreName(
-        checkpointDir.getAbsolutePath, lastBatchId, 
storeToColumnFamilyToStateSourceOptions)
-
+        checkpointDir.getAbsolutePath, lastBatchId, 
storeToColumnFamilyToStateSourceOptions,
+        storeToColumnFamilyToSelectExprs)
       // Verify all stores and column families have data before repartition
       storeToColumnFamilyToStateSourceOptions.foreach { case (storeName, 
columnFamilies) =>
         columnFamilies.keys.foreach { cfName =>
@@ -159,7 +161,8 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
 
       // Step 5: Validate state for each store and column family after 
repartition
       val stateAfterRepartition = readStateDataByStoreName(
-        checkpointDir.getAbsolutePath, repartitionBatchId, 
storeToColumnFamilyToStateSourceOptions)
+        checkpointDir.getAbsolutePath, repartitionBatchId, 
storeToColumnFamilyToStateSourceOptions,
+        storeToColumnFamilyToSelectExprs)
 
       storeToColumnFamilyToStateSourceOptions.foreach { case (storeName, 
columnFamilies) =>
         columnFamilies.keys.foreach { cfName =>
@@ -171,17 +174,15 @@ abstract class 
OfflineStateRepartitionIntegrationSuiteBase extends StateDataSour
             s"Store '$storeName', CF '$cfName': State row count mismatch: " +
               s"before=${beforeState.length}, after=${afterState.length}")
 
-          // Extract (key, value) pairs and compare
-          val beforeByKey = extractKeyValuePairs(beforeState)
-          val afterByKey = extractKeyValuePairs(afterState)
-
-          // Compare each (key, value) pair
-          beforeByKey.zip(afterByKey).zipWithIndex.foreach {
-            case (((keyBefore, valueBefore), (keyAfter, valueAfter)), idx) =>
-              assert(keyBefore == keyAfter,
-                s"Store '$storeName', CF '$cfName': Key mismatch at index 
$idx")
-              assert(valueBefore == valueAfter,
-                s"Store '$storeName', CF '$cfName': Value mismatch for key 
$keyBefore")
+          val sourceSorted = beforeState.sortBy(_.toString)
+          val targetSorted = afterState.sortBy(_.toString)
+
+          sourceSorted.zip(targetSorted).zipWithIndex.foreach {
+            case ((sourceRow, targetRow), idx) =>
+              assert(sourceRow == targetRow,
+                s"Row mismatch at index $idx:\n" +
+                  s"  Source: $sourceRow\n" +
+                  s"  Target: $targetRow")
           }
         }
       }
@@ -215,7 +216,7 @@ abstract class OfflineStateRepartitionIntegrationSuiteBase 
extends StateDataSour
    */
   def testWithAllRepartitionOperations(testNamePrefix: String)
       (testFun: Int => Unit): Unit = {
-    Seq(("increase", 8), ("decrease", 3)).foreach { case (direction, 
newPartitions) =>
+    Seq(("increase", 5), ("decrease", 2)).foreach { case (direction, 
newPartitions) =>
       testWithChangelogConfig(s"$testNamePrefix - $direction partitions") {
         testFun(newPartitions)
       }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionTransformWithStateIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionTransformWithStateIntegrationSuite.scala
new file mode 100644
index 000000000000..fdedf7723ee9
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionTransformWithStateIntegrationSuite.scala
@@ -0,0 +1,493 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming.state
+
+import java.time.Duration
+
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions
+import 
org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.timers.TimerStateUtils
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.functions.{col, timestamp_seconds}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{InputEvent, ListStateTTLProcessor, 
MapInputEvent, MapOutputEvent, MapStateTTLProcessor, 
MaxEventTimeStatefulProcessor, OutputEvent, OutputMode, 
RunningCountStatefulProcessorWithProcTimeTimer, TimeMode, Trigger, TTLConfig, 
ValueStateTTLProcessor}
+import org.apache.spark.sql.streaming.util.{MultiStateVarProcessor, 
MultiStateVarProcessorTestUtils, TimerTestUtils, TTLProcessorUtils}
+
+/**
+ * Integration test suite for transformWithState operator repartitioning.
+ */
+class OfflineStateRepartitionTransformWithStateCkptV1IntegrationSuite
+  extends OfflineStateRepartitionIntegrationSuiteBase {
+
+  import testImplicits._
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "1")
+  }
+
+  /**
+   * Unified helper to build state source options for transformWithState tests.
+   * Handles basic state variables, timer and TTL column families.
+   *
+   * @param columnFamilyNames List of column family names to configure
+   * @param timeMode Optional TimeMode for timer-based tests (adds 
READ_REGISTERED_TIMERS)
+   * @param listStateName Optional list state name for TTL tests (adds 
FLATTEN_COLLECTION_TYPES)
+   * @return Map of column family names to their state source options
+   */
+  def buildStateSourceOptionsForTWS(
+       columnFamilyNames: Seq[String],
+       timeMode: Option[TimeMode] = None,
+       listStateName: Option[String] = None): Map[String, Map[String, String]] 
= {
+    // Get timer column family names if timeMode is provided
+    val (keyToTimestampCF, timestampToKeyCF) = timeMode match {
+      case Some(tm) => TimerStateUtils.getTimerStateVarNames(tm.toString)
+      case None => (null, null)
+    }
+
+    columnFamilyNames.map { cfName =>
+      // Determine base options based on column family type
+      val options = if (cfName == keyToTimestampCF || cfName == 
timestampToKeyCF) {
+        // Timer column families
+        Map(StateSourceOptions.READ_REGISTERED_TIMERS -> "true")
+      } else if (cfName == StateStore.DEFAULT_COL_FAMILY_NAME) {
+        throw new IllegalArgumentException("TWS operator shouldn't contain 
DEFAULT column family")
+      } else {
+        // Regular state variable column families
+        val baseOptions = Map(StateSourceOptions.STATE_VAR_NAME -> cfName)
+        if (listStateName.contains(cfName)) {
+          baseOptions + (StateSourceOptions.FLATTEN_COLLECTION_TYPES -> "true")
+        } else {
+          baseOptions
+        }
+      }
+
+      cfName -> options
+    }.toMap
+  }
+
+  def testWithDifferentEncodingType(testNamePrefix: String)
+      (testFun: Int => Unit): Unit = {
+    // TODO[SPARK-55301]: add test with "avro" encoding format after SPARK 
increases test timeout
+    // because CI signal "sql - other tests" is timing out after adding the 
integration tests
+    Seq("unsaferow").foreach { encodingFormat =>
+      testWithAllRepartitionOperations(
+        s"$testNamePrefix (encoding = $encodingFormat)") { newPartitions =>
+        withSQLConf(SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> 
encodingFormat) {
+          testFun(newPartitions)
+        }
+      }
+    }
+  }
+
+  testWithDifferentEncodingType(
+    "transformWithState with multiple column families") {
+    newPartitions =>
+      val allColFamilyNames = 
MultiStateVarProcessorTestUtils.ALL_COLUMN_FAMILIES.toList
+      val stateSourceOptions = buildStateSourceOptionsForTWS(
+        allColFamilyNames,
+        listStateName = Some(MultiStateVarProcessorTestUtils.ITEMS_LIST))
+      val selectExprs = 
MultiStateVarProcessorTestUtils.getColumnFamilyToSelectExprs()
+
+      def buildQuery(
+          inputData: MemoryStream[String]): Dataset[(String, String, String, 
String)] = {
+        inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new MultiStateVarProcessor(),
+            TimeMode.None(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[String](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, _) => {
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir),
+            // Batch 1: Creates state in all column families
+            AddData(inputData, "a", "b", "c"),
+            CheckNewAnswer(
+              ("a", "1", "a", "a=1"),
+              ("b", "1", "b", "b=1"),
+              ("c", "1", "c", "c=1")),
+            // Batch 2: Adds more state
+            AddData(inputData, "a", "b", "d"),
+            CheckNewAnswer(
+              ("a", "2", "a,a", "a=2"),
+              ("b", "2", "b,b", "b=2"),
+              ("d", "1", "d", "d=1")),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, _) => {
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir),
+            // Batch 3: Resume with new data after repartition
+            AddData(inputData, "a", "c", "e"),
+            CheckNewAnswer(
+              ("a", "3", "a,a,a", "a=3"),
+              ("c", "2", "c,c", "c=2"),
+              ("e", "1", "e", "e=1"))
+          )
+        },
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        ),
+        storeToColumnFamilyToSelectExprs = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> selectExprs
+        )
+      )
+  }
+
+  testWithDifferentEncodingType("transformWithState with eventTime timers") {
+    newPartitions =>
+      // MaxEventTimeStatefulProcessor uses maxEventTimeState and timerState
+      val (keyToTimestampCF, timestampToKeyCF) =
+        TimerStateUtils.getTimerStateVarNames(TimeMode.EventTime().toString)
+      val columnFamilies = Seq(
+        "maxEventTimeState",
+        "timerState",
+        keyToTimestampCF,
+        timestampToKeyCF)
+      val stateSourceOptions = buildStateSourceOptionsForTWS(
+        columnFamilies, timeMode = Some(TimeMode.EventTime()))
+      val selectExprs = 
TimerTestUtils.getTimerColumnFamilyToSelectExprs(TimeMode.EventTime())
+
+      def buildQuery(inputData: MemoryStream[(String, Long)]): 
Dataset[(String, Int)] = {
+        inputData.toDS()
+          .select(col("_1").as("key"), 
timestamp_seconds(col("_2")).as("eventTime"))
+          .withWatermark("eventTime", "10 seconds")
+          .as[(String, Long)]
+          .groupByKey(_._1)
+          .transformWithState(
+            new MaxEventTimeStatefulProcessor(),
+            TimeMode.EventTime(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[(String, Long)](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, _) => {
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir),
+            // Batch 1: Creates state with event time timers
+            // MaxEventTimeStatefulProcessor outputs (key, maxEventTimeSec)
+            AddData(inputData, ("a", 1L), ("b", 2L), ("c", 3L)),
+            CheckNewAnswer(("a", 1), ("b", 2), ("c", 3)),
+            // Batch 2: More data - max event time for "a" becomes 12
+            AddData(inputData, ("a", 12L)),
+            CheckNewAnswer(("a", 12)),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, _) => {
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir),
+            // Batch 3: Resume with new data after repartition
+            // Send event time 18 to advance watermark to (18-10)*1000 = 8000ms
+            // This fires timers for "b" (at 7000ms) and "c" (at 8000ms)
+            // Timer expiry outputs (key, -1)
+            // Add new data for b and c with lower values than previous (b had 
2, c had 3)
+            // to confirm maxEventTime is correctly updated after expiry
+            AddData(inputData, ("a", 18L), ("b", 1L), ("c", 2L)),
+            CheckNewAnswer(("a", 18), ("b", -1), ("c", -1))
+          )
+        },
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        ),
+        storeToColumnFamilyToSelectExprs = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> selectExprs
+        )
+      )
+  }
+
+  testWithDifferentEncodingType("transformWithState with processing time 
timers") {
+    newPartitions =>
+      val schemas = 
TimerTestUtils.getTimerConfigsForCountState(TimeMode.ProcessingTime())
+      val columnFamilies = schemas.keys.toSeq.filterNot(_ == 
StateStore.DEFAULT_COL_FAMILY_NAME)
+      val stateSourceOptions = buildStateSourceOptionsForTWS(
+        columnFamilies,
+        timeMode = Some(TimeMode.ProcessingTime()))
+      val selectExprs = 
TimerTestUtils.getTimerColumnFamilyToSelectExprs(TimeMode.ProcessingTime())
+
+      def buildQuery(inputData: MemoryStream[String]): Dataset[(String, 
String)] = {
+        inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new 
RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[String](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            AddData(inputData, "a", "b"),
+            AdvanceManualClock(1000),
+            CheckNewAnswer(("a", "1"), ("b", "1")),
+            AddData(inputData, "a", "c"),
+            AdvanceManualClock(1000),
+            CheckNewAnswer(("a", "2"), ("c", "1")),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            AddData(inputData, "c", "d"),
+            AdvanceManualClock(5 * 1000),
+            // "a" and "c" are expired, and processor fires processing time 
with "-1"
+            CheckNewAnswer(("a", "-1"), ("c", "-1"), ("c", "2"), ("d", "1")),
+            AddData(inputData, "c"),
+            AdvanceManualClock(1000),
+            // "c" is cleared after timer went off, so recount from 1
+            CheckNewAnswer(("c", "1"))
+          )
+        },
+        useManualClock = true,
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        ),
+        storeToColumnFamilyToSelectExprs = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> selectExprs
+        )
+      )
+  }
+
+  testWithDifferentEncodingType("transformWithState with list and TTL") {
+    newPartitions =>
+      val schemas = TTLProcessorUtils.getListStateTTLSchemasWithMetadata()
+      val columnFamilies = schemas.keys.toSeq.filterNot(_ == 
StateStore.DEFAULT_COL_FAMILY_NAME)
+      val stateSourceOptions = buildStateSourceOptionsForTWS(
+        columnFamilies,
+        listStateName = Some(TTLProcessorUtils.LIST_STATE))
+      val selectExprs = 
TTLProcessorUtils.getTTLSelectExpressions(columnFamilies)
+
+      def buildQuery(inputData: MemoryStream[InputEvent]): 
Dataset[OutputEvent] = {
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        inputData.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(new ListStateTTLProcessor(ttlConfig),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[InputEvent](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // Batch 1: Clock advances to 1000ms, TTL = 1000 + 60000 = 61000ms
+            AddData(inputData, InputEvent("k1", "put", 1),
+              InputEvent("k1", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(OutputEvent("k1", 1, true, 61000)),
+            // Batch 2: Clock advances to 2000ms, TTL = 2000 + 60000 = 62000ms
+            AddData(inputData, InputEvent("k2", "put", 2),
+              InputEvent("k2", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(OutputEvent("k2", 2, true, 62000)),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query, OutputMode.Update())(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // Batch 3: Clock advances to 3000ms
+            // Value 1 has TTL from batch 1 (61000ms), value 3 gets TTL = 3000 
+ 60000 = 63000ms
+            AddData(inputData, InputEvent("k1", "append", 3),
+              InputEvent("k1", "get_ttl_value_from_state", 0),
+              InputEvent("k1", "get_values_in_min_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(
+              OutputEvent("k1", 1, true, 61000),
+              OutputEvent("k1", 3, true, 63000),
+              OutputEvent("k1", -1, true, 61000))
+          )
+        },
+        useManualClock = true,
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        ),
+        storeToColumnFamilyToSelectExprs = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> selectExprs
+        )
+      )
+  }
+
+  testWithDifferentEncodingType("transformWithState with map and TTL") {
+    newPartitions =>
+      val schemas = TTLProcessorUtils.getMapStateTTLSchemasWithMetadata()
+      val columnFamilies = schemas.keys.toSeq.filterNot(_ == 
StateStore.DEFAULT_COL_FAMILY_NAME)
+      val stateSourceOptions = buildStateSourceOptionsForTWS(columnFamilies)
+      val selectExprs = 
TTLProcessorUtils.getTTLSelectExpressions(columnFamilies)
+
+      def buildQuery(inputData: MemoryStream[MapInputEvent]): 
Dataset[MapOutputEvent] = {
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        inputData.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(new MapStateTTLProcessor(ttlConfig),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[MapInputEvent](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // Batch 1: Clock advances to 1000ms, TTL = 1000 + 60000 = 61000ms
+            AddData(inputData, MapInputEvent("a", "key1", "put", 1),
+              MapInputEvent("a", "key1", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(MapOutputEvent("a", "key1", 1, true, 61000)),
+            // Batch 2: Clock advances to 2000ms, TTL = 2000 + 60000 = 62000ms
+            AddData(inputData, MapInputEvent("b", "key2", "put", 2),
+              MapInputEvent("b", "key2", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(MapOutputEvent("b", "key2", 2, true, 62000)),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // Batch 3: Clock advances to 3000ms
+            // key1 has TTL from batch 1 (61000ms), key3 gets TTL = 3000 + 
60000 = 63000ms
+            AddData(inputData, MapInputEvent("a", "key3", "put", 3),
+              MapInputEvent("a", "key1", "get_ttl_value_from_state", 0),
+              MapInputEvent("a", "key3", "get_ttl_value_from_state", 0),
+              MapInputEvent("a", "key1", "iterator", 0)
+            ),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(MapOutputEvent("a", "key1", 1, true, 61000),
+              MapOutputEvent("a", "key3", 3, true, 63000),
+              MapOutputEvent("a", "key1", 1, false, -1),
+              MapOutputEvent("a", "key3", 3, false, -1)
+            )
+          )
+        },
+        useManualClock = true,
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        ),
+        storeToColumnFamilyToSelectExprs = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> selectExprs
+        )
+      )
+  }
+
+  testWithDifferentEncodingType("transformWithState with value and TTL") {
+    newPartitions =>
+      val schemas = TTLProcessorUtils.getValueStateTTLSchemasWithMetadata()
+      val stateSourceOptions = buildStateSourceOptionsForTWS(
+        schemas.keys.toSeq.filterNot(_ == StateStore.DEFAULT_COL_FAMILY_NAME))
+
+      def buildQuery(inputData: MemoryStream[InputEvent]): 
Dataset[OutputEvent] = {
+        val ttlConfig = TTLConfig(ttlDuration = Duration.ofMinutes(1))
+        inputData.toDS()
+          .groupByKey(x => x.key)
+          .transformWithState(new ValueStateTTLProcessor(ttlConfig),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+      }
+
+      testRepartitionWorkflow[InputEvent](
+        newPartitions = newPartitions,
+        setupInitialState = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // Batch 1: Clock advances to 1000ms, TTL = 1000 + 60000 = 61000ms
+            AddData(inputData, InputEvent("k1", "put", 1),
+              InputEvent("k1", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(OutputEvent("k1", 1, true, 61000)),
+            // Batch 2: Clock is at 2000ms, TTL = 2000 + 60000 = 62000ms
+            AddData(inputData, InputEvent("k2", "put", 2),
+              InputEvent("k2", "get_ttl_value_from_state", 0)),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(OutputEvent("k2", 2, true, 62000)),
+            StopStream
+          )
+        },
+        verifyResumedQuery = (inputData, checkpointDir, clockOpt) => {
+          val clock = clockOpt.get
+          val query = buildQuery(inputData)
+          testStream(query)(
+            StartStream(checkpointLocation = checkpointDir,
+              trigger = Trigger.ProcessingTime("1 second"),
+              triggerClock = clock),
+            // k2 is still in the state
+            AddData(inputData, InputEvent("k2", "get_ttl_value_from_state", 0),
+              InputEvent("k1", "put", 3),
+              InputEvent("k1", "get_ttl_value_from_state", 0)
+            ),
+            AdvanceManualClock(1 * 1000),
+            CheckNewAnswer(
+              OutputEvent("k2", 2, true, 62000),
+              OutputEvent("k1", 3, true, 63000)
+            )
+          )
+        },
+        useManualClock = true,
+        storeToColumnFamilyToStateSourceOptions = Map(
+          StateStoreId.DEFAULT_STORE_NAME -> stateSourceOptions
+        )
+      )
+  }
+}
+
+class OfflineStateRepartitionTransformWithStateCkptV2IntegrationSuite
+  extends OfflineStateRepartitionTransformWithStateCkptV1IntegrationSuite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.conf.set(SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/TransformWithStateTestUtils.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/TransformWithStateTestUtils.scala
index 963dc80320a5..548cc4e8fe61 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/TransformWithStateTestUtils.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/TransformWithStateTestUtils.scala
@@ -192,13 +192,20 @@ object TTLProcessorUtils {
         Seq("partition_id", "key", "value")
     }
   }
+
+  def getTTLSelectExpressions(columnFamilyNames: Seq[String]): Map[String, Seq[String]] = {
+    columnFamilyNames.map { cfName =>  // name -> its select exprs (String overload)
+      cfName -> TTLProcessorUtils.getTTLSelectExpressions(cfName)
+    }.toMap  // duplicate names collapse to a single entry
+  }
 }
 
 /**
  * Stateful processor with multiple state variables (value + list + map)
  * for testing transformWithState operator.
  */
-class MultiStateVarProcessor extends StatefulProcessor[String, String, 
(String, String)] {
+class MultiStateVarProcessor
+    extends StatefulProcessor[String, String, (String, String, String, 
String)] {
   @transient private var _countState: ValueState[Long] = _
   @transient private var _itemsList: ListState[String] = _
   @transient private var _itemsMap: MapState[String, SimpleMapValue] = _
@@ -213,7 +220,7 @@ class MultiStateVarProcessor extends StatefulProcessor[String, String, (String,
   override def handleInputRows(
       key: String,
       inputRows: Iterator[String],
-      timerValues: TimerValues): Iterator[(String, String)] = {
+      timerValues: TimerValues): Iterator[(String, String, String, String)] = {
     val currentCount = Option(_countState.get()).getOrElse(0L)
     var newCount = currentCount
     inputRows.foreach { item =>
@@ -222,7 +229,13 @@ class MultiStateVarProcessor extends StatefulProcessor[String, String, (String,
       _itemsMap.updateValue(item, SimpleMapValue(newCount.toInt))
     }
     _countState.update(newCount)
-    Iterator((key, newCount.toString))
+
+    // Convert list to human-readable string like "a,a"
+    val listStr = _itemsList.get().mkString(",")
+    // Convert map to human-readable string like "a=1"
+    val mapStr = _itemsMap.iterator().map { case (k, v) => s"$k=${v.count}" 
}.mkString(",")
+
+    Iterator((key, newCount.toString, listStr, mapStr))
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to