micheal-o commented on code in PR #54019:
URL: https://github.com/apache/spark/pull/54019#discussion_r2739276299


##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +326,212 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            processor = (
+                SimpleStatefulProcessorWithInitialStateFactory().pandas()
+                if is_pandas
+                else SimpleStatefulProcessorWithInitialStateFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState"""
+
+        self._run_tws_repartition_test(is_pandas=False)
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,

Review Comment:
   You didn't add everything from the link I sent previously. You missed the GIL condition:
https://github.com/apache/spark/blob/9ce067c7d532a6766d622e55164baa942a065250/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state_state_variable.py#L979



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +326,212 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            processor = (
+                SimpleStatefulProcessorWithInitialStateFactory().pandas()
+                if is_pandas
+                else SimpleStatefulProcessorWithInitialStateFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState"""
+
+        self._run_tws_repartition_test(is_pandas=False)
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_repartition_with_streaming_tws_in_pandas(self):
+        """Test repartition for streaming transformWithStateInPandas."""
+
+        self._run_tws_repartition_test(is_pandas=True)
+
+    def _run_repartition_with_streaming_tws_multiple_state_vars_test(self, 
is_pandas):
+        """Test repartition with processor using multiple state variable types 
(value + list + map)."""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value_arr", StringType(), True),
+                    StructField("list_state_arr", StringType(), True),
+                    StructField("map_state_arr", StringType(), True),
+                    StructField("nested_map_state_arr", StringType(), True),
+                ]
+            )
+            processor = (
+                StatefulProcessorCompositeTypeFactory().pandas()
+                if is_pandas
+                else StatefulProcessorCompositeTypeFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            rows = {row.id: row for row in results}
+            # StatefulProcessorCompositeType initializes to "0" on first batch
+            # (this is how the processor is designed - see line 1853 in helper 
file)
+            self.assertEqual(rows["a"].value_arr, "0")
+            self.assertEqual(rows["a"].list_state_arr, "0")
+            # Map state initialized with default ATTRIBUTES_MAP and CONFS_MAP
+            self.assertEqual(rows["a"].map_state_arr, '{"key1": [1], "key2": 
[10]}')
+            self.assertEqual(rows["a"].nested_map_state_arr, '{"e1": {"e2": 5, 
"e3": 10}}')
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved
+            rows = {row.id: row for row in results}
+            # Key 'a': value state accumulated
+            self.assertEqual(rows["a"].value_arr, "100")
+            self.assertEqual(rows["a"].list_state_arr, "0,100")
+            # Map state updated with key "a" and temperature 100
+            self.assertEqual(rows["a"].map_state_arr, '{"a": [100], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["a"].nested_map_state_arr, '{"e1": {"a": 100, "e2": 5, 
"e3": 10}}'
+            )
+            # New key 'd' - first batch for this key, outputs "0"
+            self.assertEqual(rows["d"].value_arr, "0")
+            self.assertEqual(rows["d"].list_state_arr, "0")
+            # Map state for 'd' initialized with defaults
+            self.assertEqual(rows["d"].map_state_arr, '{"key1": [1], "key2": 
[10]}')
+            self.assertEqual(rows["d"].nested_map_state_arr, '{"e1": {"e2": 5, 
"e3": 10}}')
+
+        def verify_after_decrease(results):
+            # After another repartition, state should still be preserved
+            rows = {row.id: row for row in results}
+            # Key 'd' gets 102, state was [0], so accumulated: [0+102] = [102]
+            self.assertEqual(rows["d"].value_arr, "102")
+            self.assertEqual(rows["d"].list_state_arr, "0,102")
+            # Map state for 'd' updated with temperature 102
+            self.assertEqual(rows["d"].map_state_arr, '{"d": [102], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["d"].nested_map_state_arr, '{"e1": {"d": 102, "e2": 5, 
"e3": 10}}'
+            )
+            # Key 'a' gets 101, state was [100], so accumulated: [100+101] = 
[201]
+            self.assertEqual(rows["a"].value_arr, "201")
+            self.assertEqual(rows["a"].list_state_arr, "0,100,101")
+            # Map state for 'a' updated with temperature 101 (replaces 
previous value)
+            self.assertEqual(rows["a"].map_state_arr, '{"a": [101], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["a"].nested_map_state_arr, '{"e1": {"a": 101, "e2": 5, 
"e3": 10}}'
+            )
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,100\n",  # a:0
+            batch2_data="a,100\nd,100\n",  # a:100, d:0 (new)
+            batch3_data="d,102\na,101\n",  # d:102, a:201 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws_multiple_state_vars(self):
+        """Test repartition with transformWithState using multiple state 
variable types (value + list + map)."""
+
+        
self._run_repartition_with_streaming_tws_multiple_state_vars_test(is_pandas=False)
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,

Review Comment:
   Ditto (the same comment as above applies here).



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -299,6 +299,13 @@ class StatePartitionAllColumnFamiliesReader(
     isTWSOperator(operatorName) && colFamilyName == 
StateStore.DEFAULT_COL_FAMILY_NAME
   }
 
+

Review Comment:
   Remove the extra blank line?



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +326,212 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            processor = (
+                SimpleStatefulProcessorWithInitialStateFactory().pandas()
+                if is_pandas
+                else SimpleStatefulProcessorWithInitialStateFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):

Review Comment:
   For TWS tests with pandas=False, we add this:
   
https://github.com/apache/spark/blob/0cba1c0a81a8379051d009c69f618ea635485a35/python/pyspark/sql/tests/pandas/streaming/test_transform_with_state_state_variable.py#L33



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +326,212 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            processor = (
+                SimpleStatefulProcessorWithInitialStateFactory().pandas()
+                if is_pandas
+                else SimpleStatefulProcessorWithInitialStateFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState"""
+
+        self._run_tws_repartition_test(is_pandas=False)
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_repartition_with_streaming_tws_in_pandas(self):
+        """Test repartition for streaming transformWithStateInPandas."""
+
+        self._run_tws_repartition_test(is_pandas=True)
+
+    def _run_repartition_with_streaming_tws_multiple_state_vars_test(self, 
is_pandas):
+        """Test repartition with processor using multiple state variable types 
(value + list + map)."""
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value_arr", StringType(), True),
+                    StructField("list_state_arr", StringType(), True),
+                    StructField("map_state_arr", StringType(), True),
+                    StructField("nested_map_state_arr", StringType(), True),
+                ]
+            )
+            processor = (
+                StatefulProcessorCompositeTypeFactory().pandas()
+                if is_pandas
+                else StatefulProcessorCompositeTypeFactory().row()
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor,
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            rows = {row.id: row for row in results}
+            # StatefulProcessorCompositeType initializes to "0" on first batch
+            # (this is how the processor is designed - see line 1853 in helper 
file)
+            self.assertEqual(rows["a"].value_arr, "0")
+            self.assertEqual(rows["a"].list_state_arr, "0")
+            # Map state initialized with default ATTRIBUTES_MAP and CONFS_MAP
+            self.assertEqual(rows["a"].map_state_arr, '{"key1": [1], "key2": 
[10]}')
+            self.assertEqual(rows["a"].nested_map_state_arr, '{"e1": {"e2": 5, 
"e3": 10}}')
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved
+            rows = {row.id: row for row in results}
+            # Key 'a': value state accumulated
+            self.assertEqual(rows["a"].value_arr, "100")
+            self.assertEqual(rows["a"].list_state_arr, "0,100")
+            # Map state updated with key "a" and temperature 100
+            self.assertEqual(rows["a"].map_state_arr, '{"a": [100], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["a"].nested_map_state_arr, '{"e1": {"a": 100, "e2": 5, 
"e3": 10}}'
+            )
+            # New key 'd' - first batch for this key, outputs "0"
+            self.assertEqual(rows["d"].value_arr, "0")
+            self.assertEqual(rows["d"].list_state_arr, "0")
+            # Map state for 'd' initialized with defaults
+            self.assertEqual(rows["d"].map_state_arr, '{"key1": [1], "key2": 
[10]}')
+            self.assertEqual(rows["d"].nested_map_state_arr, '{"e1": {"e2": 5, 
"e3": 10}}')
+
+        def verify_after_decrease(results):
+            # After another repartition, state should still be preserved
+            rows = {row.id: row for row in results}
+            # Key 'd' gets 102, state was [0], so accumulated: [0+102] = [102]
+            self.assertEqual(rows["d"].value_arr, "102")
+            self.assertEqual(rows["d"].list_state_arr, "0,102")
+            # Map state for 'd' updated with temperature 102
+            self.assertEqual(rows["d"].map_state_arr, '{"d": [102], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["d"].nested_map_state_arr, '{"e1": {"d": 102, "e2": 5, 
"e3": 10}}'
+            )
+            # Key 'a' gets 101, state was [100], so accumulated: [100+101] = 
[201]
+            self.assertEqual(rows["a"].value_arr, "201")
+            self.assertEqual(rows["a"].list_state_arr, "0,100,101")
+            # Map state for 'a' updated with temperature 101 (replaces 
previous value)
+            self.assertEqual(rows["a"].map_state_arr, '{"a": [101], "key1": 
[1], "key2": [10]}')
+            self.assertEqual(
+                rows["a"].nested_map_state_arr, '{"e1": {"a": 101, "e2": 5, 
"e3": 10}}'
+            )
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,100\n",  # a:0
+            batch2_data="a,100\nd,100\n",  # a:100, d:0 (new)
+            batch3_data="d,102\na,101\n",  # d:102, a:201 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws_multiple_state_vars(self):

Review Comment:
   Ditto (the same comment as above applies here).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to