fedimser commented on code in PR #53758:
URL: https://github.com/apache/spark/pull/53758#discussion_r2748042109


##########
python/pyspark/sql/tests/pandas/streaming/test_tws_tester.py:
##########
@@ -0,0 +1,1294 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import tempfile
+import unittest
+
+import pandas as pd
+import pandas.testing as pdt
+
+from pyspark import SparkConf
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import split
+from pyspark.sql.streaming import StatefulProcessor, TwsTester
+from pyspark.sql.streaming.query import StreamingQuery
+from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import 
(
+    AllMethodsTestProcessorFactory,
+    EventTimeCountProcessorFactory,
+    EventTimeSessionProcessorFactory,
+    RunningCountStatefulProcessorFactory,
+    SessionTimeoutProcessorFactory,
+    TopKProcessorFactory,
+    WordFrequencyProcessorFactory,
+)
+from pyspark.sql.types import (
+    DoubleType,
+    IntegerType,
+    Row,
+    StringType,
+    StructField,
+    StructType,
+)
+from pyspark.errors import PySparkValueError, PySparkAssertionError
+from pyspark.errors.exceptions.base import IllegalArgumentException
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+
[email protected](
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message or "",
+)
+class TwsTesterTests(ReusedSQLTestCase):
+    def test_running_count_processor(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        self.assertEqual(tester.test("key1", [Row(value="a")]), 
[Row(key="key1", count=1)])
+        self.assertEqual(
+            tester.test("key2", [Row(value="a"), Row(value="a")]),
+            [Row(key="key2", count=2)],
+        )
+        self.assertEqual(tester.test("key3", [Row(value="a")]), 
[Row(key="key3", count=1)])
+        self.assertEqual(
+            tester.test("key1", [Row(value="a"), Row(value="a"), 
Row(value="a")]),
+            [Row(key="key1", count=4)],
+        )
+
+        self.assertEqual(tester.peekValueState("count", "key1"), (4,))
+        self.assertEqual(tester.peekValueState("count", "key2"), (2,))
+        self.assertEqual(tester.peekValueState("count", "key3"), (1,))
+        self.assertIsNone(tester.peekValueState("count", "key4"))
+
+    def test_running_count_processor_pandas(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor)
+
+        ans1 = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]}))
+        expected1 = pd.DataFrame({"key": ["key1"], "count": [1]})
+        pdt.assert_frame_equal(ans1, expected1, check_like=True)
+
+        ans2 = tester.testInPandas("key2", pd.DataFrame({"value": ["a", "a"]}))
+        expected2 = pd.DataFrame({"key": ["key2"], "count": [2]})
+        pdt.assert_frame_equal(ans2, expected2, check_like=True)
+
+        ans3 = tester.testInPandas("key3", pd.DataFrame({"value": ["a"]}))
+        expected3 = pd.DataFrame({"key": ["key3"], "count": [1]})
+        pdt.assert_frame_equal(ans3, expected3, check_like=True)
+
+        ans4 = tester.testInPandas("key1", pd.DataFrame({"value": ["a", "a", 
"a"]}))
+        expected4 = pd.DataFrame({"key": ["key1"], "count": [4]})
+        pdt.assert_frame_equal(ans4, expected4, check_like=True)
+
+    def test_direct_access_to_value_state(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        tester.updateValueState("count", "foo", (5,))
+        tester.test("foo", [Row(value="q")])
+        self.assertEqual(tester.peekValueState("count", "foo"), (6,))
+
+    def test_topk_processor(self):
+        processor = TopKProcessorFactory(k=2).row()
+        tester = TwsTester(processor)
+
+        ans1 = tester.test("key1", [Row(score=2.0), Row(score=3.0), 
Row(score=1.0)])
+        self.assertEqual(ans1, [Row(key="key1", score=3.0), Row(key="key1", 
score=2.0)])
+
+        ans2 = tester.test(
+            "key2", [Row(score=10.0), Row(score=20.0), Row(score=30.0), 
Row(score=40.0)]
+        )
+        self.assertEqual(ans2, [Row(key="key2", score=40.0), Row(key="key2", 
score=30.0)])
+
+        ans3 = tester.test("key3", [Row(score=100.0)])
+        self.assertEqual(ans3, [Row(key="key3", score=100.0)])
+
+        self.assertEqual(tester.peekListState("topK", "key1"), [(3.0,), 
(2.0,)])
+        self.assertEqual(tester.peekListState("topK", "key2"), [(40.0,), 
(30.0,)])
+        self.assertEqual(tester.peekListState("topK", "key3"), [(100.0,)])
+        self.assertEqual(tester.peekListState("topK", "key4"), [])
+
+        ans4 = tester.test("key1", [Row(score=10.0)])
+        self.assertEqual(ans4, [Row(key="key1", score=10.0), Row(key="key1", 
score=3.0)])
+        self.assertEqual(tester.peekListState("topK", "key1"), [(10.0,), 
(3.0,)])
+
+    def test_topk_processor_pandas(self):
+        processor = TopKProcessorFactory(k=2).pandas()
+        tester = TwsTester(processor)
+
+        ans1 = tester.testInPandas("key1", pd.DataFrame({"score": [2.0, 3.0, 
1.0]}))
+        expected1 = pd.DataFrame({"key": ["key1", "key1"], "score": [3.0, 
2.0]})
+        pdt.assert_frame_equal(ans1, expected1, check_like=True)
+
+        ans2 = tester.testInPandas("key2", pd.DataFrame({"score": [10.0, 20.0, 
30.0, 40.0]}))
+        expected2 = pd.DataFrame({"key": ["key2", "key2"], "score": [40.0, 
30.0]})
+        pdt.assert_frame_equal(ans2, expected2, check_like=True)
+
+        ans3 = tester.testInPandas("key3", pd.DataFrame({"score": [100.0]}))
+        expected3 = pd.DataFrame({"key": ["key3"], "score": [100.0]})
+        pdt.assert_frame_equal(ans3, expected3, check_like=True)
+
+        ans4 = tester.testInPandas("key1", pd.DataFrame({"score": [10.0]}))
+        expected4 = pd.DataFrame({"key": ["key1", "key1"], "score": [10.0, 
3.0]})
+        pdt.assert_frame_equal(ans4, expected4, check_like=True)
+
+    def test_direct_access_to_list_state(self):
+        processor = TopKProcessorFactory(k=2).row()
+        tester = TwsTester(processor)
+
+        tester.updateListState("topK", "a", [(6.0,), (5.0,)])
+        tester.updateListState("topK", "b", [(8.0,), (7.0,)])
+        tester.test("a", [Row(score=10.0)])
+        tester.test("b", [Row(score=7.5)])
+        tester.test("c", [Row(score=1.0)])
+
+        assert tester.peekListState("topK", "a") == [(10.0,), (6.0,)]
+        assert tester.peekListState("topK", "b") == [(8.0,), (7.5,)]
+        assert tester.peekListState("topK", "c") == [(1.0,)]
+        assert tester.peekListState("topK", "d") == []
+
+    def test_word_frequency_processor(self):
+        processor = WordFrequencyProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        ans1 = tester.test(
+            "user1",
+            [
+                Row(word="hello"),
+                Row(word="world"),
+                Row(word="hello"),
+                Row(word="world"),
+            ],
+        )
+        self.assertEqual(
+            ans1,
+            [
+                Row(key="user1", word="hello", count=1),
+                Row(key="user1", word="world", count=1),
+                Row(key="user1", word="hello", count=2),
+                Row(key="user1", word="world", count=2),
+            ],
+        )
+
+        ans2 = tester.test("user2", [Row(word="hello"), Row(word="spark")])
+        self.assertEqual(
+            ans2,
+            [
+                Row(key="user2", word="hello", count=1),
+                Row(key="user2", word="spark", count=1),
+            ],
+        )
+
+        # Check state using peekMapState.
+        self.assertEqual(
+            tester.peekMapState("frequencies", "user1"),
+            {("hello",): (2,), ("world",): (2,)},
+        )
+        self.assertEqual(
+            tester.peekMapState("frequencies", "user2"),
+            {("hello",): (1,), ("spark",): (1,)},
+        )
+        self.assertEqual(tester.peekMapState("frequencies", "user3"), {})
+
+        # Process more data for user1.
+        ans3 = tester.test("user1", [Row(word="hello"), Row(word="test")])
+        self.assertEqual(
+            ans3,
+            [
+                Row(key="user1", word="hello", count=3),
+                Row(key="user1", word="test", count=1),
+            ],
+        )
+        self.assertEqual(
+            tester.peekMapState("frequencies", "user1"),
+            {("hello",): (3,), ("world",): (2,), ("test",): (1,)},
+        )
+
+    def test_word_frequency_processor_pandas(self):
+        processor = WordFrequencyProcessorFactory().pandas()
+        tester = TwsTester(processor)
+
+        input_df1 = pd.DataFrame({"word": ["hello", "world", "hello", 
"world"]})
+        ans1 = tester.testInPandas("user1", input_df1)
+        expected1 = pd.DataFrame(
+            {
+                "key": ["user1", "user1", "user1", "user1"],
+                "word": ["hello", "world", "hello", "world"],
+                "count": [1, 1, 2, 2],
+            }
+        )
+        pdt.assert_frame_equal(ans1, expected1, check_like=True)
+
+        input_df2 = pd.DataFrame({"word": ["hello", "spark"]})
+        ans2 = tester.testInPandas("user2", input_df2)
+        expected2 = pd.DataFrame(
+            {
+                "key": ["user2", "user2"],
+                "word": ["hello", "spark"],
+                "count": [1, 1],
+            }
+        )
+        pdt.assert_frame_equal(ans2, expected2, check_like=True)
+
+        input_df3 = pd.DataFrame({"word": ["hello", "test"]})
+        ans3 = tester.testInPandas("user1", input_df3)
+        expected3 = pd.DataFrame(
+            {
+                "key": ["user1", "user1"],
+                "word": ["hello", "test"],
+                "count": [3, 1],
+            }
+        )
+        pdt.assert_frame_equal(ans3, expected3, check_like=True)
+
+    def test_direct_access_to_map_state(self):
+        processor = WordFrequencyProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        tester.updateMapState("frequencies", "user1", {("hello",): (5,), 
("world",): (3,)})
+        tester.updateMapState("frequencies", "user2", {("spark",): (10,)})
+
+        tester.test("user1", [Row(word="hello"), Row(word="goodbye")])
+        tester.test("user2", [Row(word="spark")])
+        tester.test("user3", [Row(word="new")])
+
+        self.assertEqual(
+            tester.peekMapState("frequencies", "user1"),
+            {("hello",): (6,), ("world",): (3,), ("goodbye",): (1,)},
+        )
+        self.assertEqual(tester.peekMapState("frequencies", "user2"), 
{("spark",): (11,)})
+        self.assertEqual(tester.peekMapState("frequencies", "user3"), 
{("new",): (1,)})
+        self.assertEqual(tester.peekMapState("frequencies", "user4"), {})
+
+    # Example of how TwsTester can be used to test step function.
+    def test_step_function(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        # Example of helper function using TwsTester to inspect how processing 
a single row changes
+        # state.
+        def step_function(key: str, input_row: str, state_in: int) -> int:
+            tester.updateValueState("count", key, (state_in,))
+            tester.test(key, [Row(value=input_row)])
+            return tester.peekValueState("count", key)[0]
+
+        self.assertEqual(step_function("key1", "a", 10), 11)
+
+    # Example of how TwsTester can be used to simulate real-time mode (row by 
row processing).
+    def test_row_by_row(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        # Example of helper function to test how TransformWithState processes 
rows one-by-one,
+        # which can be used to simulate real-time mode.
+        def test_row_by_row_helper(input_rows: list[tuple[str, Row]]) -> 
list[Row]:
+            result: list[Row] = []
+            for key, row in input_rows:
+                result += tester.test(key, [row])
+            return result
+
+        ans = test_row_by_row_helper(
+            [
+                ("key1", Row(value="a")),
+                ("key2", Row(value="b")),
+                ("key1", Row(value="c")),
+                ("key2", Row(value="b")),
+                ("key1", Row(value="c")),
+                ("key1", Row(value="c")),
+                ("key3", Row(value="q")),
+            ]
+        )
+        self.assertEqual(
+            ans,
+            [
+                Row(key="key1", count=1),
+                Row(key="key2", count=1),
+                Row(key="key1", count=2),
+                Row(key="key2", count=2),
+                Row(key="key1", count=3),
+                Row(key="key1", count=4),
+                Row(key="key3", count=1),
+            ],
+        )
+
+    # Tests that TwsTester calls handleInitialState.
+    def test_initial_state_row(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(
+            processor,
+            initialStateRow=[
+                ("a", Row(initial_count=10)),
+                ("b", Row(initial_count=20)),
+            ],
+        )
+        self.assertEqual(tester.peekValueState("count", "a"), (10,))
+        self.assertEqual(tester.peekValueState("count", "b"), (20,))
+
+        ans1 = tester.test("a", [Row(value="a")])
+        self.assertEqual(ans1, [Row(key="a", count=11)])
+
+        ans2 = tester.test("c", [Row(value="c")])
+        self.assertEqual(ans2, [Row(key="c", count=1)])
+
+    def test_initial_state_pandas(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(
+            processor,
+            initialStatePandas=[
+                ("a", pd.DataFrame({"initial_count": [10]})),
+                ("b", pd.DataFrame({"initial_count": [20]})),
+            ],
+        )
+        self.assertEqual(tester.peekValueState("count", "a"), (10,))
+        self.assertEqual(tester.peekValueState("count", "b"), (20,))
+
+        ans1 = tester.testInPandas("a", pd.DataFrame({"value": ["a"]}))
+        expected1 = pd.DataFrame({"key": ["a"], "count": [11]})
+        pdt.assert_frame_equal(ans1, expected1, check_like=True)
+
+        ans2 = tester.testInPandas("c", pd.DataFrame({"value": ["c"]}))
+        expected2 = pd.DataFrame({"key": ["c"], "count": [1]})
+        pdt.assert_frame_equal(ans2, expected2, check_like=True)
+
+    def test_all_methods_processor(self):
+        """Test that TwsTester exercises all state methods."""
+        processor = AllMethodsTestProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        results = tester.test(
+            "k",
+            [
+                Row(cmd="value-exists"),  # false
+                Row(cmd="value-set"),  # set to 42
+                Row(cmd="value-exists"),  # true
+                Row(cmd="value-clear"),  # clear
+                Row(cmd="value-exists"),  # false again
+                Row(cmd="list-exists"),  # false
+                Row(cmd="list-append"),  # append a, b
+                Row(cmd="list-exists"),  # true
+                Row(cmd="list-append-array"),  # append c, d
+                Row(cmd="list-get"),  # a,b,c,d
+                Row(cmd="map-exists"),  # false
+                Row(cmd="map-add"),  # add x=1, y=2, z=3
+                Row(cmd="map-exists"),  # true
+                Row(cmd="map-keys"),  # x,y,z
+                Row(cmd="map-values"),  # 1,2,3
+                Row(cmd="map-iterator"),  # x=1,y=2,z=3
+                Row(cmd="map-remove"),  # remove y
+                Row(cmd="map-keys"),  # x,z
+                Row(cmd="map-clear"),  # clear map
+                Row(cmd="map-exists"),  # false
+            ],
+        )
+
+        self.assertEqual(
+            results,
+            [
+                Row(key="k", result="value-exists:False"),
+                Row(key="k", result="value-set:done"),
+                Row(key="k", result="value-exists:True"),
+                Row(key="k", result="value-clear:done"),
+                Row(key="k", result="value-exists:False"),
+                Row(key="k", result="list-exists:False"),
+                Row(key="k", result="list-append:done"),
+                Row(key="k", result="list-exists:True"),
+                Row(key="k", result="list-append-array:done"),
+                Row(key="k", result="list-get:a,b,c,d"),
+                Row(key="k", result="map-exists:False"),
+                Row(key="k", result="map-add:done"),
+                Row(key="k", result="map-exists:True"),
+                Row(key="k", result="map-keys:x,y,z"),
+                Row(key="k", result="map-values:1,2,3"),
+                Row(key="k", result="map-iterator:x=1,y=2,z=3"),
+                Row(key="k", result="map-remove:done"),
+                Row(key="k", result="map-keys:x,z"),
+                Row(key="k", result="map-clear:done"),
+                Row(key="k", result="map-exists:False"),
+            ],
+        )
+
+    def test_both_initial_states_specified(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        with self.assertRaises(AssertionError):
+            TwsTester(
+                processor,
+                initialStateRow=[("a", Row(initial_count=10))],
+                initialStatePandas=[("a", pd.DataFrame({"initial_count": 
[10]}))],
+            )
+
+    def test_timer_registration_raises_error_in_none_mode(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="None")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester.handle.registerTimer(12345)
+
+    def test_delete_timer_raises_error_in_none_mode(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="None")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester.handle.deleteTimer(12345)
+
+    def test_list_timers_raises_error_in_none_mode(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="None")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            list(tester.handle.listTimers())
+
+    def test_empty_input_row(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test("key1", [])
+        self.assertEqual(result, [Row(key="key1", count=0)])
+
+    def test_empty_input_pandas(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor)
+        input_df = pd.DataFrame({"value": []})
+        result = tester.testInPandas("key1", input_df)
+        expected = pd.DataFrame({"key": ["key1"], "count": [0]})
+        pdt.assert_frame_equal(result, expected, check_like=True)
+
+    def test_empty_initial_state_row(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, initialStateRow=[])
+        result = tester.test("key1", [Row(value="a")])
+        self.assertEqual(result, [Row(key="key1", count=1)])
+
+    def test_empty_initial_state_pandas(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor, initialStatePandas=[])
+        result = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]}))
+        expected = pd.DataFrame({"key": ["key1"], "count": [1]})
+        pdt.assert_frame_equal(result, expected, check_like=True)
+
+    def test_none_values_in_input(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test("key1", [Row(value=None), Row(value="a")])
+        self.assertEqual(result, [Row(key="key1", count=2)])
+
+    def test_single_row_input(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test("key1", [Row(value="a")])
+        self.assertEqual(result, [Row(key="key1", count=1)])
+
+    def test_peek_nonexistent_state(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.peekValueState("count", "nonexistent_key")
+        self.assertIsNone(result)
+
+    def test_update_nonexistent_state(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        with self.assertRaises(AssertionError):
+            tester.updateValueState("nonexistent_state", "key1", (5,))
+
+    def test_peek_state_before_initialization(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        with self.assertRaises(AssertionError):
+            tester.peekValueState("nonexistent_state", "key1")
+
+    def test_clear_nonexistent_list_state(self):
+        processor = TopKProcessorFactory(k=2).row()
+        tester = TwsTester(processor)
+        tester.handle.setGroupingKey("key1")
+        list_state = tester.handle.getListState("topK", "double")
+        list_state.clear()
+
+    def test_delete_if_exists_nonexistent(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        tester.handle.deleteIfExists("nonexistent_state")
+
+    def test_numeric_key(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        self.assertEqual(tester.test(1, [Row(value="a"), Row(value="c")]), 
[Row(key=1, count=2)])
+        self.assertEqual(tester.test(2, [Row(value="b")]), [Row(key=2, 
count=1)])
+
+    def test_tuple_key(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test((1, 2), [Row(value="a"), Row(value="b")])
+        self.assertEqual(result, [Row(key=(1, 2), count=2)])
+
+    def test_special_characters_in_key(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test("key@#$%", [Row(value="a"), Row(value="b")])
+        self.assertEqual(result, [Row(key="key@#$%", count=2)])
+
+    def test_null_key_value_row(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        result = tester.test(None, [Row(value="a"), Row(value="b")])
+        self.assertEqual(result, [Row(key=None, count=2)])
+
+    def test_null_key_value_pandas(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor)
+        input_df = pd.DataFrame({"value": ["a", "b"]})
+        result = tester.testInPandas(None, input_df)
+        expected = pd.DataFrame({"key": [None], "count": [2]})
+        pdt.assert_frame_equal(result, expected, check_like=True)
+
+    def test_multiple_value_states(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        class MultiValueStateProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                self.handle = handle
+                self.state1 = handle.getValueState("state1", "int")
+                self.state2 = handle.getValueState("state2", "int")
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                val1 = self.state1.get() if self.state1.exists() else None
+                count1 = val1[0] if val1 else 0
+                val2 = self.state2.get() if self.state2.exists() else None
+                count2 = val2[0] if val2 else 0
+                for row in rows:
+                    count1 += 1
+                    count2 += 2
+                self.state1.update((count1,))
+                self.state2.update((count2,))
+                yield Row(key=key[0], count1=count1, count2=count2)
+
+        processor = MultiValueStateProcessor()
+        tester = TwsTester(processor)
+        result = tester.test("key1", [Row(value="a")])
+        self.assertEqual(result, [Row(key="key1", count1=1, count2=2)])
+        self.assertEqual(tester.peekValueState("state1", "key1"), (1,))
+        self.assertEqual(tester.peekValueState("state2", "key1"), (2,))
+
+    def test_multiple_list_states(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        class MultiListStateProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                self.handle = handle
+                self.list1 = handle.getListState("list1", "string")
+                self.list2 = handle.getListState("list2", "string")
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                for row in rows:
+                    self.list1.appendValue((row.value,))
+                    self.list2.appendValue((row.value + "_2",))
+                yield Row(key=key[0], count=1)
+
+        processor = MultiListStateProcessor()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(value="a")])
+        self.assertEqual(tester.peekListState("list1", "key1"), [("a",)])
+        self.assertEqual(tester.peekListState("list2", "key1"), [("a_2",)])
+
+    def test_multiple_map_states(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        class MultiMapStateProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                self.handle = handle
+                self.map1 = handle.getMapState("map1", "string", "int")
+                self.map2 = handle.getMapState("map2", "string", "int")
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                for row in rows:
+                    val1 = self.map1.getValue((row.word,))
+                    count1 = val1[0] if val1 is not None else 0
+                    val2 = self.map2.getValue((row.word,))
+                    count2 = val2[0] if val2 is not None else 0
+                    self.map1.updateValue((row.word,), (count1 + 1,))
+                    self.map2.updateValue((row.word,), (count2 + 2,))
+                yield Row(key=key[0], word=row.word)
+
+        processor = MultiMapStateProcessor()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(word="hello")])
+        self.assertEqual(tester.peekMapState("map1", "key1"), {("hello",): 
(1,)})
+        self.assertEqual(tester.peekMapState("map2", "key1"), {("hello",): 
(2,)})
+
+    def test_mixed_state_types(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        class MixedStateProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                self.handle = handle
+                self.value_state = handle.getValueState("value", "int")
+                self.list_state = handle.getListState("list", "string")
+                self.map_state = handle.getMapState("map", "string", "int")
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                for row in rows:
+                    val = self.value_state.get() if self.value_state.exists() 
else None
+                    count = val[0] if val else 0
+                    self.value_state.update((count + 1,))
+                    self.list_state.appendValue((row.value,))
+                    self.map_state.updateValue((row.value,), (1,))
+                yield Row(key=key[0], count=1)
+
+        processor = MixedStateProcessor()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(value="a")])
+        self.assertEqual(tester.peekValueState("value", "key1"), (1,))
+        self.assertEqual(tester.peekListState("list", "key1"), [("a",)])
+        self.assertEqual(tester.peekMapState("map", "key1"), {("a",): (1,)})
+
+    def test_pandas_with_multiple_rows(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor)
+        input_df = pd.DataFrame({"value": ["a", "b", "c"]})
+        result = tester.testInPandas("key1", input_df)
+        expected = pd.DataFrame({"key": ["key1"], "count": [3]})
+        pdt.assert_frame_equal(result, expected, check_like=True)
+
+    def test_pandas_dtype_preservation(self):
+        processor = RunningCountStatefulProcessorFactory().pandas()
+        tester = TwsTester(processor)
+        input_df = pd.DataFrame({"value": ["a"]})
+        result = tester.testInPandas("key1", input_df)
+        self.assertEqual(result["count"].dtype, "int64")
+
+    def test_processor_init_called_once(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        init_call_count = [0]
+
+        class InitCountingProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                init_call_count[0] += 1
+                self.handle = handle
+                self.state = handle.getValueState("count", "int")
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                val = self.state.get() if self.state.exists() else None
+                count = val[0] if val else 0
+                for row in rows:
+                    count += 1
+                self.state.update((count,))
+                yield Row(key=key[0], count=count)
+
+        processor = InitCountingProcessor()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(value="a")])
+        tester.test("key1", [Row(value="b")])
+        self.assertEqual(init_call_count[0], 1)
+
+    def test_processor_state_persists_across_batches(self):
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(value="a")])
+        result = tester.test("key1", [Row(value="b")])
+        self.assertEqual(result, [Row(key="key1", count=2)])
+
+    def test_handle_initial_state_not_called_without_initial_state(self):
+        from pyspark.sql.streaming.stateful_processor import (
+            StatefulProcessor,
+            StatefulProcessorHandle,
+        )
+        from typing import Iterator
+
+        initial_state_call_count = [0]
+
+        class InitialStateCountingProcessor(StatefulProcessor):
+            def init(self, handle: StatefulProcessorHandle) -> None:
+                self.handle = handle
+                self.state = handle.getValueState("count", "int")
+
+            def handleInitialState(self, key, initialState, timerValues) -> 
None:
+                initial_state_call_count[0] += 1
+
+            def handleInputRows(self, key, rows, timerValues) -> Iterator:
+                val = self.state.get() if self.state.exists() else None
+                count = val[0] if val else 0
+                for row in rows:
+                    count += 1
+                self.state.update((count,))
+                yield Row(key=key[0], count=count)
+
+        processor = InitialStateCountingProcessor()
+        tester = TwsTester(processor)
+        tester.test("key1", [Row(value="a")])
+        self.assertEqual(initial_state_call_count[0], 0)
+
+    def test_delete_value_state(self):
+        """Test that deleteState correctly deletes value state for a given 
key."""
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        tester.updateValueState("count", "key1", (10,))
+        tester.updateValueState("count", "key2", (20,))
+        self.assertEqual(tester.peekValueState("count", "key1"), (10,))
+
+        tester.deleteState("count", "key1")
+        self.assertIsNone(tester.peekValueState("count", "key1"))
+        self.assertEqual(tester.peekValueState("count", "key2"), (20,))
+
+    def test_delete_list_state(self):
+        """Test that deleteState correctly deletes list state for a given 
key."""
+        processor = TopKProcessorFactory(k=3).row()
+        tester = TwsTester(processor)
+
+        tester.updateListState("topK", "key1", [(1.0,), (2.0,), (3.0,)])
+        tester.updateListState("topK", "key2", [(4.0,), (5.0,)])
+        self.assertEqual(tester.peekListState("topK", "key1"), [(1.0,), 
(2.0,), (3.0,)])
+
+        tester.deleteState("topK", "key1")
+        self.assertEqual(tester.peekListState("topK", "key1"), [])
+        self.assertEqual(tester.peekListState("topK", "key2"), [(4.0,), 
(5.0,)])
+
+    def test_delete_map_state(self):
+        """Test that deleteState correctly deletes map state for a given 
key."""
+        processor = WordFrequencyProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        tester.updateMapState("frequencies", "user1", {("hello",): (5,), 
("world",): (3,)})
+        tester.updateMapState("frequencies", "user2", {("spark",): (10,)})
+        self.assertEqual(
+            tester.peekMapState("frequencies", "user1"),
+            {("hello",): (5,), ("world",): (3,)},
+        )
+
+        tester.deleteState("frequencies", "user1")
+        self.assertEqual(tester.peekMapState("frequencies", "user1"), {})
+        self.assertEqual(tester.peekMapState("frequencies", "user2"), 
{("spark",): (10,)})
+
+    def test_delete_nonexistent_state_raises_error(self):
+        """Test that deleteState raises an error for non-existent state."""
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor)
+
+        with self.assertRaisesRegex(PySparkAssertionError, "STATE_NOT_EXISTS"):
+            tester.deleteState("nonexistent_state", "key1")
+
+    # Timer tests.
+
+    def test_processing_time_timers(self):
+        """Test that TwsTester supports ProcessingTime timers."""
+        processor = SessionTimeoutProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="ProcessingTime")
+
+        # Process input for key1 - should register a timer at t=10000.
+        result1 = tester.test("key1", [Row(value="hello")])
+        self.assertEqual(result1, [Row(key="key1", result="received:hello")])
+
+        # Set processing time to 5000 - timer should NOT fire yet.
+        expired1 = tester.setProcessingTime(5000)
+        self.assertEqual(expired1, [])
+
+        # Process input for key2 at t=5000 - should register timer at t=15000.
+        result2 = tester.test("key2", [Row(value="world")])
+        self.assertEqual(result2, [Row(key="key2", result="received:world")])
+
+        # Set processing time to 11000 - key1's timer should fire.
+        expired2 = tester.setProcessingTime(11000)
+        self.assertEqual(expired2, [Row(key="key1", result="session-expired")])
+
+        # Set processing time to 16000 - key2's timer should fire.
+        expired3 = tester.setProcessingTime(16000)
+        self.assertEqual(expired3, [Row(key="key2", result="session-expired")])
+
+        # Verify state is cleared after session expiry.
+        self.assertIsNone(tester.peekValueState("lastSeen", "key1"))
+        self.assertIsNone(tester.peekValueState("lastSeen", "key2"))
+
+    def test_processing_time_timers_pandas(self):
+        """Test that TwsTester supports ProcessingTime timers in Pandas 
mode."""
+        processor = SessionTimeoutProcessorFactory().pandas()
+        tester = TwsTester(processor, timeMode="ProcessingTime")
+
+        # Process input for key1 - should register a timer at t=10000.
+        result1 = tester.testInPandas("key1", pd.DataFrame({"value": 
["hello"]}))
+        expected1 = pd.DataFrame({"key": ["key1"], "result": 
["received:hello"]})
+        pdt.assert_frame_equal(result1, expected1, check_like=True)
+
+        # Set processing time to 5000 - timer should NOT fire yet.
+        expired1 = tester.setProcessingTime(5000)
+        self.assertEqual(len(expired1), 0)
+
+        # Set processing time to 11000 - timer should fire.
+        expired2 = tester.setProcessingTime(11000)
+        expected2 = pd.DataFrame({"key": ["key1"], "result": 
["session-expired"]})
+        pdt.assert_frame_equal(expired2, expected2, check_like=True)
+
+    def test_event_time_timers_manual_watermark(self):
+        """Test that TwsTester supports EventTime timers fired by manual 
watermark advance."""
+        processor = EventTimeSessionProcessorFactory().row()
+
+        def event_time_extractor(row: Row) -> int:
+            return row.event_time_ms
+
+        tester = TwsTester(
+            processor,
+            timeMode="EventTime",
+            eventTimeExtractor=event_time_extractor,
+        )
+
+        # Process event at t=10000 for key1 - registers timer at t=15000.
+        result1 = tester.test("key1", [Row(event_time_ms=10000, 
value="hello")])
+        self.assertEqual(result1, [Row(key="key1", 
result="received:hello@10000")])
+
+        # Process event at t=11000 for key2 - registers timer at t=16000.
+        result2 = tester.test("key2", [Row(event_time_ms=11000, 
value="world")])
+        self.assertEqual(result2, [Row(key="key2", 
result="received:world@11000")])
+
+        # Set watermark to 15000 - key1's timer should fire.
+        expired1 = tester.setWatermark(15000)
+        self.assertEqual(len(expired1), 1)
+        self.assertEqual(expired1[0], Row(key="key1", 
result="session-expired@watermark=15000"))
+
+        # Set watermark to 16000 - key2's timer should fire.
+        expired2 = tester.setWatermark(16000)
+        self.assertEqual(len(expired2), 1)
+        self.assertEqual(expired2[0], Row(key="key2", 
result="session-expired@watermark=16000"))
+
+        # Verify state is cleared.
+        self.assertIsNone(tester.peekValueState("lastEventTime", "key1"))
+        self.assertIsNone(tester.peekValueState("lastEventTime", "key2"))
+
+    def test_late_event_filtering(self):
+        """Test that TwsTester filters late events in EventTime mode."""
+        processor = EventTimeCountProcessorFactory().row()
+
+        def event_time_extractor(row: Row) -> int:
+            return row.event_time_ms
+
+        tester = TwsTester(
+            processor,
+            timeMode="EventTime",
+            eventTimeExtractor=event_time_extractor,
+        )
+
+        # Initially watermark is 0, so all events should be processed.
+        result1 = tester.test(
+            "key1",
+            [
+                Row(event_time_ms=5000, value="a"),
+                Row(event_time_ms=6000, value="b"),
+                Row(event_time_ms=7000, value="c"),
+            ],
+        )
+        self.assertEqual(result1, [Row(key="key1", count=3)])
+        self.assertEqual(tester.peekValueState("count", "key1"), (3,))
+
+        # Set watermark to 6000.
+        tester.setWatermark(6000)
+
+        # Now watermark is 6000. Send events with mixed event times:
+        # - (4000, "late1") -> event time 4000 <= watermark 6000, should be 
FILTERED.
+        # - (5000, "late2") -> event time 5000 <= watermark 6000, should be 
FILTERED.
+        # - (6000, "late3") -> event time 6000 <= watermark 6000, should be 
FILTERED.
+        # - (8000, "ontime") -> event time 8000 > watermark 6000, should be 
PROCESSED.
+        result2 = tester.test(
+            "key1",
+            [
+                Row(event_time_ms=4000, value="late1"),
+                Row(event_time_ms=5000, value="late2"),
+                Row(event_time_ms=6000, value="late3"),
+                Row(event_time_ms=8000, value="ontime"),
+            ],
+        )
+        # Only 1 event should be processed (the on-time one)
+        self.assertEqual(result2, [Row(key="key1", count=4)])  # 3 + 1 = 4
+        self.assertEqual(tester.peekValueState("count", "key1"), (4,))
+
+    def test_all_late_events_filtered(self):
+        """Test that all late events are filtered when all are older than 
watermark."""
+        processor = EventTimeCountProcessorFactory().row()
+
+        def event_time_extractor(row: Row) -> int:
+            return row.event_time_ms
+
+        tester = TwsTester(
+            processor,
+            timeMode="EventTime",
+            eventTimeExtractor=event_time_extractor,
+        )
+
+        # Process events to get initial count.
+        tester.test("key1", [Row(event_time_ms=10000, value="a")])
+        self.assertEqual(tester.peekValueState("count", "key1"), (1,))
+
+        # Set watermark to 10000.
+        tester.setWatermark(10000)
+
+        # Now send only late events - all should be filtered, count unchanged.
+        result = tester.test(
+            "key1",
+            [
+                Row(event_time_ms=1000, value="late1"),
+                Row(event_time_ms=5000, value="late2"),
+                Row(event_time_ms=9999, value="late3"),
+                Row(event_time_ms=10000, value="late4"),  # exactly at 
watermark, also filtered
+            ],
+        )
+        # No events processed - handleInputRows is called with empty list.
+        self.assertEqual(result, [Row(key="key1", count=1)])  # Count still 1
+        self.assertEqual(tester.peekValueState("count", "key1"), (1,))
+
+    def test_timer_registration_in_none_mode_raises_error(self):
+        """Test that timer operations raise error in TimeMode.None."""
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="None")
+
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester.handle.registerTimer(12345)
+
+    def test_set_processing_time_in_wrong_mode_raises_error(self):
+        """Test that setProcessingTime raises error in non-ProcessingTime 
mode."""
+        processor = RunningCountStatefulProcessorFactory().row()
+
+        # Test in None mode.
+        tester_none = TwsTester(processor, timeMode="None")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester_none.setProcessingTime(1000)
+
+    def test_set_watermark_in_wrong_mode_raises_error(self):
+        """Test that setWatermark raises error in non-EventTime mode."""
+        processor = RunningCountStatefulProcessorFactory().row()
+
+        # Test in None mode.
+        tester_none = TwsTester(processor, timeMode="None")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester_none.setWatermark(1000)
+
+        # Test in ProcessingTime mode.
+        tester_pt = TwsTester(processor, timeMode="ProcessingTime")
+        with self.assertRaisesRegex(PySparkValueError, 
"UNSUPPORTED_OPERATION"):
+            tester_pt.setWatermark(1000)
+
+    def test_eventtime_mode_requires_extractor(self):
+        """Test that EventTime mode requires eventTimeExtractor."""
+        processor = RunningCountStatefulProcessorFactory().row()
+
+        with self.assertRaisesRegex(PySparkValueError, "ARGUMENT_REQUIRED"):
+            TwsTester(processor, timeMode="EventTime")
+
+    def test_processing_time_must_move_forward(self):
+        """Test that setProcessingTime requires time to move forward."""
+        processor = RunningCountStatefulProcessorFactory().row()
+        tester = TwsTester(processor, timeMode="ProcessingTime")
+
+        tester.setProcessingTime(1000)
+
+        # Same time should raise error.
+        with self.assertRaises(IllegalArgumentException):
+            tester.setProcessingTime(1000)
+
+        # Earlier time should raise error.
+        with self.assertRaises(IllegalArgumentException):
+            tester.setProcessingTime(500)
+
+    def test_watermark_must_move_forward(self):
+        """Test that setWatermark requires watermark to move forward."""
+        processor = RunningCountStatefulProcessorFactory().row()
+
+        def event_time_extractor(row: Row) -> int:
+            return row.event_time_ms
+
+        tester = TwsTester(
+            processor,
+            timeMode="EventTime",
+            eventTimeExtractor=event_time_extractor,
+        )
+
+        tester.setWatermark(1000)
+
+        # Same watermark should raise error.
+        with self.assertRaises(IllegalArgumentException):
+            tester.setWatermark(1000)
+
+        # Earlier watermark should raise error.
+        with self.assertRaises(IllegalArgumentException):
+            tester.setWatermark(500)
+
+    def test_timer_cannot_be_registered_in_past(self):
+        """Test that timer cannot be registered at or before current watermark 
in EventTime mode."""
+        processor = EventTimeSessionProcessorFactory().row()
+
+        def event_time_extractor(row: Row) -> int:
+            return row.event_time_ms
+
+        tester = TwsTester(
+            processor,
+            timeMode="EventTime",
+            eventTimeExtractor=event_time_extractor,
+        )
+
+        # Set watermark to 10000.
+        tester.setWatermark(10000)
+
+        # Try to register a timer at or before watermark - should raise error.
+        tester.handle.setGroupingKey("key1")
+        with self.assertRaises(IllegalArgumentException):
+            tester.handle.registerTimer(10000)  # exactly at watermark
+
+        with self.assertRaises(IllegalArgumentException):
+            tester.handle.registerTimer(5000)  # before watermark
+
+
[email protected](
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message or "",
+)
+class TwsTesterFuzzTests(ReusedSQLTestCase):
+    @classmethod
+    def conf(cls):
+        cfg = SparkConf()
+        cfg.set("spark.sql.shuffle.partitions", "1")
+        cfg.set(
+            "spark.sql.streaming.stateStore.providerClass",
+            
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
+        )
+        return cfg
+
+    # Helper that runs real streaming query with TWS and compares results with 
TwsTester output.
+    # Supports both row mode (use_pandas=False) and Pandas mode 
(use_pandas=True).
+    # Supports multiple batches.
+    # Assumes that the first column is the key column.
+    # Output rows are checked ignoring order.
+    # Correct only if results do not depend on reordering rows within batch.
+    def _check_tws_tester(
+        self,
+        processor: StatefulProcessor,
+        batches: list[list[tuple]],
+        input_columns: list[str],
+        input_types: list[str],
+        output_schema: StructType,
+        use_pandas: bool = False,
+    ):
+        num_batches = len(batches)
+        key_column: str = input_columns[0]  # Assume first column is the key 
column.
+        value_columns: list[str] = input_columns[1:]  # All columns except the 
key column.
+
+        num_input_columns: int = len(input_columns)
+        assert len(input_types) == num_input_columns
+
+        # Use TwsTester to compute expected results.
+        tester = TwsTester(processor)
+
+        def _run_tester(batch: list[tuple]) -> list[Row]:
+            # Group batch by key.
+            grouped = {}
+            for row in batch:
+                key = row[0]
+                if key not in grouped:
+                    grouped[key] = []
+                grouped[key].append(row[1:])
+
+            results: list[Row] = []
+            for key, values in grouped.items():
+                if use_pandas:
+                    input_to_tester = pd.DataFrame(values, 
columns=value_columns)
+                    result_df = tester.testInPandas(key, input_to_tester)
+                    
results.extend(self.spark.createDataFrame(result_df).collect())
+                else:
+                    input_to_tester: list[Row] = [
+                        Row(**dict(zip(value_columns, v))) for v in values
+                    ]
+                    results.extend(tester.test(key, input_to_tester))
+            return results
+
+        expected_results: list[list[Row]] = [_run_tester(batch) for batch in 
batches]
+
+        # Create streaming DataFrame, it will read input from files.
+        input_path = tempfile.mkdtemp()
+        df = self.spark.readStream.format("text").load(input_path)
+        df_split = df.withColumn("split_values", split(df["value"], ","))
+        df_final = df_split.select(
+            *[
+                
df_split.split_values.getItem(i).alias(col_name).cast(input_types[i])
+                for i, col_name in enumerate(input_columns)
+            ]
+        )
+        self.assertTrue(df_final.isStreaming)
+
+        # Apply TransformWithState.
+        df_grouped = df_final.groupBy(key_column)
+        if use_pandas:
+            tws_df = df_grouped.transformWithStateInPandas(
+                statefulProcessor=processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="None",
+            )
+        else:
+            tws_df = df_grouped.transformWithState(
+                statefulProcessor=processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="None",
+            )
+
+        # Start streaming query collecting TWS results.
+        actual_results: list[list[Row]] = []
+
+        def _collect_results(batch_df: DataFrame, batch_id):
+            rows: list[Row] = batch_df.collect()
+            actual_results.append(rows)
+
+        query: StreamingQuery = (
+            tws_df.writeStream.queryName("test_tmp_query")
+            .foreachBatch(_collect_results)
+            .outputMode("update")
+            .start()
+        )
+
+        # Process each input batch.
+        def _write_batch(batch_id: int, batch: list[tuple]):
+            with open(input_path + f"/test{batch_id}.txt", "w") as f:
+                for row in batch:
+                    assert len(row) == num_input_columns
+                    f.write(",".join(str(value) for value in row) + "\n")
+
+        try:
+            for batch_id, batch in enumerate(batches):
+                _write_batch(batch_id, batch)
+                query.processAllAvailable()
+            query.awaitTermination(1)
+            self.assertTrue(query.exception() is None)
+        finally:
+            query.stop()
+
+        # Assert actual results are equal to expected results.
+        self.assertEqual(len(actual_results), num_batches)
+        self.assertEqual(len(expected_results), num_batches)
+        for batch_id in range(num_batches):
+            self.assertCountEqual(actual_results[batch_id], 
expected_results[batch_id])

Review Comment:
   Wow, I just had the same impression after reading your comment and had to 
check again.
   
   Here are people complaining about this name being confusing: 
https://discuss.python.org/t/unittest-assertcountequal-naming-is-confuse/6480
   
   I added a comment explaining what this does.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to