micheal-o commented on code in PR #54019:
URL: https://github.com/apache/spark/pull/54019#discussion_r2734919596


##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +324,87 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, processor_factory, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+        from pyspark.sql.functions import split
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor_factory.pandas(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor_factory.row(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState with both row and 
pandas processors"""

Review Comment:
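   Fix the docstring: it says this test covers both row and pandas processors,
   but it only runs the row-based processor (`is_pandas=False`); the pandas
   variant is a separate test.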



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +324,87 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, processor_factory, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+        from pyspark.sql.functions import split
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor_factory.pandas(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor_factory.row(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState with both row and 
pandas processors"""
+        self._run_tws_repartition_test(
+            SimpleStatefulProcessorWithInitialStateFactory(), is_pandas=False

Review Comment:
   We should also add a test case that uses a processor with multiple state
   variables (e.g., value + list + map), just to make sure there is coverage.
   There should be an existing processor for this.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -55,24 +53,16 @@ class StatePartitionAllColumnFamiliesWriter(
     storeName: String,
     currentBatchId: Long,
     colFamilyToWriterInfoMap: Map[String, 
StatePartitionWriterColumnFamilyInfo],
-    operatorName: String,
-    schemaProviderOpt: Option[StateSchemaProvider],
-    sqlConf: SQLConf) {
-
-  private def isJoinV3Operator(
-      operatorName: String, sqlConf: SQLConf): Boolean = {
-    operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
-      sqlConf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) == 3
-  }
+    schemaProviderOpt: Option[StateSchemaProvider]) {
 
+  // Using the heuristic that all operators that enable column families
+  // has a non-default column family
+  private val useColumnFamilies = colFamilyToWriterInfoMap.keys.toSeq
+    .exists(_ != StateStoreId.DEFAULT_STORE_NAME)
   private val defaultSchema = {
     colFamilyToWriterInfoMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) match {
       case Some(info) => info.schema
       case None =>
-        // joinV3 operator doesn't have default column family schema
-        assert(isJoinV3Operator(operatorName, sqlConf),

Review Comment:
   Let's keep the assert. It makes it easier to notice and debug the case
   where there is no default column family schema.



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -29,6 +29,10 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import 
(
+    TTLStatefulProcessorFactory,
+    SimpleStatefulProcessorWithInitialStateFactory,

Review Comment:
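   Sort these imports alphabetically: `SimpleStatefulProcessorWithInitialStateFactory`
   should come before `TTLStatefulProcessorFactory`.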



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -29,6 +29,10 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import 
(

Review Comment:
   move up



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StatePartitionWriter.scala:
##########
@@ -55,24 +53,16 @@ class StatePartitionAllColumnFamiliesWriter(
     storeName: String,
     currentBatchId: Long,
     colFamilyToWriterInfoMap: Map[String, 
StatePartitionWriterColumnFamilyInfo],
-    operatorName: String,
-    schemaProviderOpt: Option[StateSchemaProvider],
-    sqlConf: SQLConf) {
-
-  private def isJoinV3Operator(
-      operatorName: String, sqlConf: SQLConf): Boolean = {
-    operatorName == StatefulOperatorsUtils.SYMMETRIC_HASH_JOIN_EXEC_OP_NAME &&
-      sqlConf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION) == 3
-  }
+    schemaProviderOpt: Option[StateSchemaProvider]) {
 
+  // Using the heuristic that all operators that enable column families
+  // have a non-default column family
+  private val useColumnFamilies = colFamilyToWriterInfoMap.keys.toSeq
+    .exists(_ != StateStoreId.DEFAULT_STORE_NAME)

Review Comment:
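   Use `StateStore.DEFAULT_COL_FAMILY_NAME` here instead of
   `StateStoreId.DEFAULT_STORE_NAME`, since this compares column family names
   (and `defaultSchema` below already uses `StateStore.DEFAULT_COL_FAMILY_NAME`).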



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -29,6 +29,10 @@
     pandas_requirement_message,

Review Comment:
   Run `dev/reformat-python` to fix the formatting issues.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -299,6 +299,13 @@ class StatePartitionAllColumnFamiliesReader(
     isTWSOperator(operatorName) && colFamilyName == 
StateStore.DEFAULT_COL_FAMILY_NAME
   }
 
+
+  private def checkIfUseColumnFamily(schema: Set[StateStoreColFamilySchema]): 
Boolean = {
+    // Using the heuristic that all operators that enable column families
+    // have a non-default column family
+    schema.exists(_.colFamilyName != StateStoreId.DEFAULT_STORE_NAME)

Review Comment:
   Use `StateStore.DEFAULT_COL_FAMILY_NAME` here instead of
   `StateStoreId.DEFAULT_STORE_NAME`, since this compares column family names.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -299,6 +299,13 @@ class StatePartitionAllColumnFamiliesReader(
     isTWSOperator(operatorName) && colFamilyName == 
StateStore.DEFAULT_COL_FAMILY_NAME
   }
 
+
+  private def checkIfUseColumnFamily(schema: Set[StateStoreColFamilySchema]): 
Boolean = {

Review Comment:
   Can just name it `useColumnFamilies` (consistent with the writer).



##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -320,6 +324,87 @@ def verify_after_decrease(results):
             verify_after_decrease=verify_after_decrease,
         )
 
+    def _run_tws_repartition_test(self, processor_factory, is_pandas):
+        """Helper method to run repartition test with a given processor 
factory method"""
+        from pyspark.sql.functions import split
+
+        def create_streaming_df(df):
+            # Parse text input format "id,temperature" into structured columns
+            split_df = split(df["value"], ",")
+            parsed_df = df.select(
+                split_df.getItem(0).alias("id"),
+                split_df.getItem(1).cast("integer").alias("temperature"),
+            )
+
+            output_schema = StructType(
+                [
+                    StructField("id", StringType(), True),
+                    StructField("value", StringType(), True),
+                ]
+            )
+            return (
+                parsed_df.groupBy("id").transformWithStateInPandas(
+                    statefulProcessor=processor_factory.pandas(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+                if is_pandas
+                else parsed_df.groupBy("id").transformWithState(
+                    statefulProcessor=processor_factory.row(),
+                    outputStructType=output_schema,
+                    outputMode="Update",
+                    timeMode="None",
+                    initialState=None,
+                )
+            )
+
+        def verify_initial(results):
+            # SimpleStatefulProcessorWithInitialState accumulates temperature 
values
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "270", "a: 120 + 150 = 270")
+            self.assertEqual(values.get("b"), "50", "b: 50")
+            self.assertEqual(values.get("c"), "30", "c: 30")
+
+        def verify_after_increase(results):
+            # After repartition, state should be preserved and accumulated
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "371", "State for 'a': 270 + 101 
= 371")
+            self.assertEqual(values.get("b"), "152", "State for 'b': 50 + 102 
= 152")
+            self.assertEqual(values.get("d"), "103", "New key 'd' should have 
value 103")
+
+        def verify_after_decrease(results):
+            # After repartition, state should still be preserved
+            values = {row.id: row.value for row in results}
+            self.assertEqual(values.get("a"), "475", "State for 'a': 371 + 104 
= 475")
+            self.assertEqual(values.get("c"), "135", "State for 'c': 30 + 105 
= 135")
+            self.assertEqual(values.get("e"), "106", "New key 'e' should have 
value 106")
+
+        OfflineStateRepartitionTestUtils.run_repartition_test(
+            spark=self.spark,
+            num_shuffle_partitions=self.NUM_SHUFFLE_PARTITIONS,
+            create_streaming_df=create_streaming_df,
+            output_mode="update",
+            batch1_data="a,120\na,150\nb,50\nc,30\n",  # a:270, b:50, c:30
+            batch2_data="a,101\nb,102\nd,103\n",  # a:371, b:152, d:103 (new)
+            batch3_data="a,104\nc,105\ne,106\n",  # a:475, c:135, e:106 (new)
+            verify_initial=verify_initial,
+            verify_after_increase=verify_after_increase,
+            verify_after_decrease=verify_after_decrease,
+        )
+
+    def test_repartition_with_streaming_tws(self):
+        """Test repartition for streaming transformWithState with both row and 
pandas processors"""
+        self._run_tws_repartition_test(
+            SimpleStatefulProcessorWithInitialStateFactory(), is_pandas=False
+        )
+
+    def test_repartition_with_streaming_tws_in_pandas(self):

Review Comment:
   pandas tests require: 
https://github.com/apache/spark/blob/9ce067c7d532a6766d622e55164baa942a065250/python/pyspark/sql/tests/pandas/streaming/test_pandas_transform_with_state_state_variable.py#L979



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala:
##########
@@ -299,6 +299,13 @@ class StatePartitionAllColumnFamiliesReader(
     isTWSOperator(operatorName) && colFamilyName == 
StateStore.DEFAULT_COL_FAMILY_NAME
   }
 
+
+  private def checkIfUseColumnFamily(schema: Set[StateStoreColFamilySchema]): 
Boolean = {

Review Comment:
   This should be a `private lazy val`, right? We don't need to compute it
   every time, since `stateStoreColFamilySchemas` doesn't change within the
   class.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to