micheal-o commented on code in PR #53931:
URL: https://github.com/apache/spark/pull/53931#discussion_r2723446698


##########
python/pyspark/sql/tests/streaming/test_streaming_offline_state_repartition.py:
##########
@@ -0,0 +1,249 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+from pyspark import SparkConf
+from pyspark.sql.streaming.state import GroupStateTimeout
+from pyspark.sql.types import LongType, StringType, StructType, StructField
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+if have_pandas:
+    import pandas as pd
+
+if have_pyarrow:
+    import pyarrow as pa  # noqa: F401
+
+
+class StreamingOfflineStateRepartitionTests(ReusedSQLTestCase):
+    """
+    Test suite for Offline state repartitioning.
+    """
+    NUM_SHUFFLE_PARTITIONS = 3
+
+    @classmethod
+    def conf(cls):
+        cfg = SparkConf()
+        cfg.set("spark.sql.shuffle.partitions", str(cls.NUM_SHUFFLE_PARTITIONS))
+        cfg.set(
+            "spark.sql.streaming.stateStore.providerClass",
+            "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider"
+        )
+        return cfg
+
+    def test_fail_if_empty_checkpoint_directory(self):
+        """Test that repartition fails if checkpoint directory is empty."""
+        with tempfile.TemporaryDirectory() as checkpoint_dir:
+            with self.assertRaisesRegex(
+                Exception,
+                "STATE_REPARTITION_INVALID_CHECKPOINT.NO_COMMITTED_BATCH"
+            ):
+                self.spark._streamingCheckpointManager.repartition(checkpoint_dir, 5)
+
+    def test_fail_if_no_batch_found_in_checkpoint_directory(self):
+        """Test that repartition fails if no batch found in checkpoint directory."""
+        with tempfile.TemporaryDirectory() as checkpoint_dir:
+            # Write commit log but no offset log
+            commits_dir = os.path.join(checkpoint_dir, "commits")
+            os.makedirs(commits_dir)
+            # Create a minimal commit file for batch 0
+            with open(os.path.join(commits_dir, "0"), "w") as f:
+                f.write("v1\n{}")
+
+            with self.assertRaisesRegex(
+                Exception,
+                "STATE_REPARTITION_INVALID_CHECKPOINT.NO_BATCH_FOUND"
+            ):
+                self.spark._streamingCheckpointManager.repartition(checkpoint_dir, 5)
+
+    def test_fail_if_repartition_parameter_is_invalid(self):
+        """Test that repartition fails with invalid parameters."""
+        # Test null checkpoint location
+        with self.assertRaisesRegex(
+            Exception,
+            "STATE_REPARTITION_INVALID_PARAMETER.IS_NULL"
+        ):
+            self.spark._streamingCheckpointManager.repartition(None, 5)
+
+        # Test empty checkpoint location
+        with self.assertRaisesRegex(
+            Exception,
+            "STATE_REPARTITION_INVALID_PARAMETER.IS_EMPTY"
+        ):
+            self.spark._streamingCheckpointManager.repartition("", 5)
+
+        # Test numPartitions <= 0
+        with self.assertRaisesRegex(
+            Exception,
+            "STATE_REPARTITION_INVALID_PARAMETER.IS_NOT_GREATER_THAN_ZERO"
+        ):
+            self.spark._streamingCheckpointManager.repartition("test", 0)
+
+    @unittest.skipIf(
+        not have_pandas or not have_pyarrow,
+        pandas_requirement_message or pyarrow_requirement_message,
+    )
+    def test_repartition_with_apply_in_pandas_with_state(self):

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to