fedimser commented on code in PR #53758: URL: https://github.com/apache/spark/pull/53758#discussion_r2756166663
########## python/pyspark/sql/tests/pandas/streaming/test_tws_tester.py: ########## @@ -0,0 +1,1294 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import random +import tempfile +import unittest + +import pandas as pd +import pandas.testing as pdt + +from pyspark import SparkConf +from pyspark.sql import DataFrame +from pyspark.sql.functions import split +from pyspark.sql.streaming import StatefulProcessor, TwsTester +from pyspark.sql.streaming.query import StreamingQuery +from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import ( + AllMethodsTestProcessorFactory, + EventTimeCountProcessorFactory, + EventTimeSessionProcessorFactory, + RunningCountStatefulProcessorFactory, + SessionTimeoutProcessorFactory, + TopKProcessorFactory, + WordFrequencyProcessorFactory, +) +from pyspark.sql.types import ( + DoubleType, + IntegerType, + Row, + StringType, + StructField, + StructType, +) +from pyspark.errors import PySparkValueError, PySparkAssertionError +from pyspark.errors.exceptions.base import IllegalArgumentException +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +@unittest.skipIf( + not 
have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message or "", +) +class TwsTesterTests(ReusedSQLTestCase): + def test_running_count_processor(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + self.assertEqual(tester.test("key1", [Row(value="a")]), [Row(key="key1", count=1)]) + self.assertEqual( + tester.test("key2", [Row(value="a"), Row(value="a")]), + [Row(key="key2", count=2)], + ) + self.assertEqual(tester.test("key3", [Row(value="a")]), [Row(key="key3", count=1)]) + self.assertEqual( + tester.test("key1", [Row(value="a"), Row(value="a"), Row(value="a")]), + [Row(key="key1", count=4)], + ) + + self.assertEqual(tester.peekValueState("count", "key1"), (4,)) + self.assertEqual(tester.peekValueState("count", "key2"), (2,)) + self.assertEqual(tester.peekValueState("count", "key3"), (1,)) + self.assertIsNone(tester.peekValueState("count", "key4")) + + def test_running_count_processor_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + + ans1 = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]})) + expected1 = pd.DataFrame({"key": ["key1"], "count": [1]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("key2", pd.DataFrame({"value": ["a", "a"]})) + expected2 = pd.DataFrame({"key": ["key2"], "count": [2]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + ans3 = tester.testInPandas("key3", pd.DataFrame({"value": ["a"]})) + expected3 = pd.DataFrame({"key": ["key3"], "count": [1]}) + pdt.assert_frame_equal(ans3, expected3, check_like=True) + + ans4 = tester.testInPandas("key1", pd.DataFrame({"value": ["a", "a", "a"]})) + expected4 = pd.DataFrame({"key": ["key1"], "count": [4]}) + pdt.assert_frame_equal(ans4, expected4, check_like=True) + + def test_direct_access_to_value_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = 
TwsTester(processor) + tester.updateValueState("count", "foo", (5,)) + tester.test("foo", [Row(value="q")]) + self.assertEqual(tester.peekValueState("count", "foo"), (6,)) + + def test_topk_processor(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + + ans1 = tester.test("key1", [Row(score=2.0), Row(score=3.0), Row(score=1.0)]) + self.assertEqual(ans1, [Row(key="key1", score=3.0), Row(key="key1", score=2.0)]) + + ans2 = tester.test( + "key2", [Row(score=10.0), Row(score=20.0), Row(score=30.0), Row(score=40.0)] + ) + self.assertEqual(ans2, [Row(key="key2", score=40.0), Row(key="key2", score=30.0)]) + + ans3 = tester.test("key3", [Row(score=100.0)]) + self.assertEqual(ans3, [Row(key="key3", score=100.0)]) + + self.assertEqual(tester.peekListState("topK", "key1"), [(3.0,), (2.0,)]) + self.assertEqual(tester.peekListState("topK", "key2"), [(40.0,), (30.0,)]) + self.assertEqual(tester.peekListState("topK", "key3"), [(100.0,)]) + self.assertEqual(tester.peekListState("topK", "key4"), []) + + ans4 = tester.test("key1", [Row(score=10.0)]) + self.assertEqual(ans4, [Row(key="key1", score=10.0), Row(key="key1", score=3.0)]) + self.assertEqual(tester.peekListState("topK", "key1"), [(10.0,), (3.0,)]) + + def test_topk_processor_pandas(self): + processor = TopKProcessorFactory(k=2).pandas() + tester = TwsTester(processor) + + ans1 = tester.testInPandas("key1", pd.DataFrame({"score": [2.0, 3.0, 1.0]})) + expected1 = pd.DataFrame({"key": ["key1", "key1"], "score": [3.0, 2.0]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("key2", pd.DataFrame({"score": [10.0, 20.0, 30.0, 40.0]})) + expected2 = pd.DataFrame({"key": ["key2", "key2"], "score": [40.0, 30.0]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + ans3 = tester.testInPandas("key3", pd.DataFrame({"score": [100.0]})) + expected3 = pd.DataFrame({"key": ["key3"], "score": [100.0]}) + pdt.assert_frame_equal(ans3, expected3, 
check_like=True) + + ans4 = tester.testInPandas("key1", pd.DataFrame({"score": [10.0]})) + expected4 = pd.DataFrame({"key": ["key1", "key1"], "score": [10.0, 3.0]}) + pdt.assert_frame_equal(ans4, expected4, check_like=True) + + def test_direct_access_to_list_state(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + + tester.updateListState("topK", "a", [(6.0,), (5.0,)]) + tester.updateListState("topK", "b", [(8.0,), (7.0,)]) + tester.test("a", [Row(score=10.0)]) + tester.test("b", [Row(score=7.5)]) + tester.test("c", [Row(score=1.0)]) + + assert tester.peekListState("topK", "a") == [(10.0,), (6.0,)] + assert tester.peekListState("topK", "b") == [(8.0,), (7.5,)] + assert tester.peekListState("topK", "c") == [(1.0,)] + assert tester.peekListState("topK", "d") == [] + + def test_word_frequency_processor(self): + processor = WordFrequencyProcessorFactory().row() + tester = TwsTester(processor) + + ans1 = tester.test( + "user1", + [ + Row(word="hello"), + Row(word="world"), + Row(word="hello"), + Row(word="world"), + ], + ) + self.assertEqual( + ans1, + [ + Row(key="user1", word="hello", count=1), + Row(key="user1", word="world", count=1), + Row(key="user1", word="hello", count=2), + Row(key="user1", word="world", count=2), + ], + ) + + ans2 = tester.test("user2", [Row(word="hello"), Row(word="spark")]) + self.assertEqual( + ans2, + [ + Row(key="user2", word="hello", count=1), + Row(key="user2", word="spark", count=1), + ], + ) + + # Check state using peekMapState. + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (2,), ("world",): (2,)}, + ) + self.assertEqual( + tester.peekMapState("frequencies", "user2"), + {("hello",): (1,), ("spark",): (1,)}, + ) + self.assertEqual(tester.peekMapState("frequencies", "user3"), {}) + + # Process more data for user1. 
+ ans3 = tester.test("user1", [Row(word="hello"), Row(word="test")]) + self.assertEqual( + ans3, + [ + Row(key="user1", word="hello", count=3), + Row(key="user1", word="test", count=1), + ], + ) + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (3,), ("world",): (2,), ("test",): (1,)}, + ) + + def test_word_frequency_processor_pandas(self): + processor = WordFrequencyProcessorFactory().pandas() + tester = TwsTester(processor) + + input_df1 = pd.DataFrame({"word": ["hello", "world", "hello", "world"]}) + ans1 = tester.testInPandas("user1", input_df1) + expected1 = pd.DataFrame( + { + "key": ["user1", "user1", "user1", "user1"], + "word": ["hello", "world", "hello", "world"], + "count": [1, 1, 2, 2], + } + ) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + input_df2 = pd.DataFrame({"word": ["hello", "spark"]}) + ans2 = tester.testInPandas("user2", input_df2) + expected2 = pd.DataFrame( + { + "key": ["user2", "user2"], + "word": ["hello", "spark"], + "count": [1, 1], + } + ) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + input_df3 = pd.DataFrame({"word": ["hello", "test"]}) + ans3 = tester.testInPandas("user1", input_df3) + expected3 = pd.DataFrame( + { + "key": ["user1", "user1"], + "word": ["hello", "test"], + "count": [3, 1], + } + ) + pdt.assert_frame_equal(ans3, expected3, check_like=True) + + def test_direct_access_to_map_state(self): + processor = WordFrequencyProcessorFactory().row() + tester = TwsTester(processor) + + tester.updateMapState("frequencies", "user1", {("hello",): (5,), ("world",): (3,)}) + tester.updateMapState("frequencies", "user2", {("spark",): (10,)}) + + tester.test("user1", [Row(word="hello"), Row(word="goodbye")]) + tester.test("user2", [Row(word="spark")]) + tester.test("user3", [Row(word="new")]) + + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (6,), ("world",): (3,), ("goodbye",): (1,)}, + ) + 
self.assertEqual(tester.peekMapState("frequencies", "user2"), {("spark",): (11,)}) + self.assertEqual(tester.peekMapState("frequencies", "user3"), {("new",): (1,)}) + self.assertEqual(tester.peekMapState("frequencies", "user4"), {}) + + # Example of how TwsTester can be used to test step function. + def test_step_function(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + # Example of helper function using TwsTester to inspect how processing a single row changes + # state. + def step_function(key: str, input_row: str, state_in: int) -> int: + tester.updateValueState("count", key, (state_in,)) + tester.test(key, [Row(value=input_row)]) + return tester.peekValueState("count", key)[0] + + self.assertEqual(step_function("key1", "a", 10), 11) + + # Example of how TwsTester can be used to simulate real-time mode (row by row processing). + def test_row_by_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + # Example of helper function to test how TransformWithState processes rows one-by-one, + # which can be used to simulate real-time mode. + def test_row_by_row_helper(input_rows: list[tuple[str, Row]]) -> list[Row]: + result: list[Row] = [] + for key, row in input_rows: + result += tester.test(key, [row]) + return result + + ans = test_row_by_row_helper( + [ + ("key1", Row(value="a")), + ("key2", Row(value="b")), + ("key1", Row(value="c")), + ("key2", Row(value="b")), + ("key1", Row(value="c")), + ("key1", Row(value="c")), + ("key3", Row(value="q")), + ] + ) + self.assertEqual( + ans, + [ + Row(key="key1", count=1), + Row(key="key2", count=1), + Row(key="key1", count=2), + Row(key="key2", count=2), + Row(key="key1", count=3), + Row(key="key1", count=4), + Row(key="key3", count=1), + ], + ) + + # Tests that TwsTester calls handleInitialState. 
+ def test_initial_state_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester( + processor, + initialStateRow=[ + ("a", Row(initial_count=10)), + ("b", Row(initial_count=20)), + ], + ) + self.assertEqual(tester.peekValueState("count", "a"), (10,)) + self.assertEqual(tester.peekValueState("count", "b"), (20,)) + + ans1 = tester.test("a", [Row(value="a")]) + self.assertEqual(ans1, [Row(key="a", count=11)]) + + ans2 = tester.test("c", [Row(value="c")]) + self.assertEqual(ans2, [Row(key="c", count=1)]) + + def test_initial_state_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester( + processor, + initialStatePandas=[ + ("a", pd.DataFrame({"initial_count": [10]})), + ("b", pd.DataFrame({"initial_count": [20]})), + ], + ) + self.assertEqual(tester.peekValueState("count", "a"), (10,)) + self.assertEqual(tester.peekValueState("count", "b"), (20,)) + + ans1 = tester.testInPandas("a", pd.DataFrame({"value": ["a"]})) + expected1 = pd.DataFrame({"key": ["a"], "count": [11]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("c", pd.DataFrame({"value": ["c"]})) + expected2 = pd.DataFrame({"key": ["c"], "count": [1]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + def test_all_methods_processor(self): + """Test that TwsTester exercises all state methods.""" + processor = AllMethodsTestProcessorFactory().row() + tester = TwsTester(processor) + + results = tester.test( + "k", + [ + Row(cmd="value-exists"), # false + Row(cmd="value-set"), # set to 42 + Row(cmd="value-exists"), # true + Row(cmd="value-clear"), # clear + Row(cmd="value-exists"), # false again + Row(cmd="list-exists"), # false + Row(cmd="list-append"), # append a, b + Row(cmd="list-exists"), # true + Row(cmd="list-append-array"), # append c, d + Row(cmd="list-get"), # a,b,c,d + Row(cmd="map-exists"), # false + Row(cmd="map-add"), # add x=1, y=2, z=3 + Row(cmd="map-exists"), # true + 
Row(cmd="map-keys"), # x,y,z + Row(cmd="map-values"), # 1,2,3 + Row(cmd="map-iterator"), # x=1,y=2,z=3 + Row(cmd="map-remove"), # remove y + Row(cmd="map-keys"), # x,z + Row(cmd="map-clear"), # clear map + Row(cmd="map-exists"), # false + ], + ) + + self.assertEqual( + results, + [ + Row(key="k", result="value-exists:False"), + Row(key="k", result="value-set:done"), + Row(key="k", result="value-exists:True"), + Row(key="k", result="value-clear:done"), + Row(key="k", result="value-exists:False"), + Row(key="k", result="list-exists:False"), + Row(key="k", result="list-append:done"), + Row(key="k", result="list-exists:True"), + Row(key="k", result="list-append-array:done"), + Row(key="k", result="list-get:a,b,c,d"), + Row(key="k", result="map-exists:False"), + Row(key="k", result="map-add:done"), + Row(key="k", result="map-exists:True"), + Row(key="k", result="map-keys:x,y,z"), + Row(key="k", result="map-values:1,2,3"), + Row(key="k", result="map-iterator:x=1,y=2,z=3"), + Row(key="k", result="map-remove:done"), + Row(key="k", result="map-keys:x,z"), + Row(key="k", result="map-clear:done"), + Row(key="k", result="map-exists:False"), + ], + ) + + def test_both_initial_states_specified(self): + processor = RunningCountStatefulProcessorFactory().row() + with self.assertRaises(AssertionError): + TwsTester( + processor, + initialStateRow=[("a", Row(initial_count=10))], + initialStatePandas=[("a", pd.DataFrame({"initial_count": [10]}))], + ) + + def test_timer_registration_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + tester.handle.registerTimer(12345) + + def test_delete_timer_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + 
tester.handle.deleteTimer(12345) + + def test_list_timers_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + list(tester.handle.listTimers()) + + def test_empty_input_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", []) + self.assertEqual(result, [Row(key="key1", count=0)]) + + def test_empty_input_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": []}) + result = tester.testInPandas("key1", input_df) + expected = pd.DataFrame({"key": ["key1"], "count": [0]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_empty_initial_state_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, initialStateRow=[]) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=1)]) + + def test_empty_initial_state_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor, initialStatePandas=[]) + result = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]})) + expected = pd.DataFrame({"key": ["key1"], "count": [1]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_none_values_in_input(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value=None), Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=2)]) + + def test_single_row_input(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=1)]) + + def 
test_peek_nonexistent_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.peekValueState("count", "nonexistent_key") + self.assertIsNone(result) + + def test_update_nonexistent_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + with self.assertRaises(AssertionError): + tester.updateValueState("nonexistent_state", "key1", (5,)) + + def test_peek_state_before_initialization(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + with self.assertRaises(AssertionError): + tester.peekValueState("nonexistent_state", "key1") + + def test_clear_nonexistent_list_state(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + tester.handle.setGroupingKey("key1") + list_state = tester.handle.getListState("topK", "double") + list_state.clear() + + def test_delete_if_exists_nonexistent(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + tester.handle.deleteIfExists("nonexistent_state") + + def test_numeric_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + self.assertEqual(tester.test(1, [Row(value="a"), Row(value="c")]), [Row(key=1, count=2)]) + self.assertEqual(tester.test(2, [Row(value="b")]), [Row(key=2, count=1)]) + + def test_tuple_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test((1, 2), [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key=(1, 2), count=2)]) + + def test_special_characters_in_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key@#$%", [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key="key@#$%", count=2)]) + + def test_null_key_value_row(self): + processor = 
RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test(None, [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key=None, count=2)]) + + def test_null_key_value_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a", "b"]}) + result = tester.testInPandas(None, input_df) + expected = pd.DataFrame({"key": [None], "count": [2]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_multiple_value_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiValueStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = handle + self.state1 = handle.getValueState("state1", "int") + self.state2 = handle.getValueState("state2", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + val1 = self.state1.get() if self.state1.exists() else None + count1 = val1[0] if val1 else 0 + val2 = self.state2.get() if self.state2.exists() else None + count2 = val2[0] if val2 else 0 + for row in rows: + count1 += 1 + count2 += 2 + self.state1.update((count1,)) + self.state2.update((count2,)) + yield Row(key=key[0], count1=count1, count2=count2) + + processor = MultiValueStateProcessor() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count1=1, count2=2)]) + self.assertEqual(tester.peekValueState("state1", "key1"), (1,)) + self.assertEqual(tester.peekValueState("state2", "key1"), (2,)) + + def test_multiple_list_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiListStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> 
None: + self.handle = handle + self.list1 = handle.getListState("list1", "string") + self.list2 = handle.getListState("list2", "string") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + self.list1.appendValue((row.value,)) + self.list2.appendValue((row.value + "_2",)) + yield Row(key=key[0], count=1) + + processor = MultiListStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + self.assertEqual(tester.peekListState("list1", "key1"), [("a",)]) + self.assertEqual(tester.peekListState("list2", "key1"), [("a_2",)]) + + def test_multiple_map_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiMapStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = handle + self.map1 = handle.getMapState("map1", "string", "int") + self.map2 = handle.getMapState("map2", "string", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + val1 = self.map1.getValue((row.word,)) + count1 = val1[0] if val1 is not None else 0 + val2 = self.map2.getValue((row.word,)) + count2 = val2[0] if val2 is not None else 0 + self.map1.updateValue((row.word,), (count1 + 1,)) + self.map2.updateValue((row.word,), (count2 + 2,)) + yield Row(key=key[0], word=row.word) + + processor = MultiMapStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(word="hello")]) + self.assertEqual(tester.peekMapState("map1", "key1"), {("hello",): (1,)}) + self.assertEqual(tester.peekMapState("map2", "key1"), {("hello",): (2,)}) + + def test_mixed_state_types(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MixedStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = 
handle + self.value_state = handle.getValueState("value", "int") + self.list_state = handle.getListState("list", "string") + self.map_state = handle.getMapState("map", "string", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + val = self.value_state.get() if self.value_state.exists() else None + count = val[0] if val else 0 + self.value_state.update((count + 1,)) + self.list_state.appendValue((row.value,)) + self.map_state.updateValue((row.value,), (1,)) + yield Row(key=key[0], count=1) + + processor = MixedStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + self.assertEqual(tester.peekValueState("value", "key1"), (1,)) + self.assertEqual(tester.peekListState("list", "key1"), [("a",)]) + self.assertEqual(tester.peekMapState("map", "key1"), {("a",): (1,)}) + + def test_pandas_with_multiple_rows(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a", "b", "c"]}) + result = tester.testInPandas("key1", input_df) + expected = pd.DataFrame({"key": ["key1"], "count": [3]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_pandas_dtype_preservation(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a"]}) + result = tester.testInPandas("key1", input_df) + self.assertEqual(result["count"].dtype, "int64") + + def test_processor_init_called_once(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + init_call_count = [0] + + class InitCountingProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + init_call_count[0] += 1 + self.handle = handle + self.state = handle.getValueState("count", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + val = 
self.state.get() if self.state.exists() else None + count = val[0] if val else 0 + for row in rows: + count += 1 + self.state.update((count,)) + yield Row(key=key[0], count=count) + + processor = InitCountingProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + tester.test("key1", [Row(value="b")]) + self.assertEqual(init_call_count[0], 1) + + def test_processor_state_persists_across_batches(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + result = tester.test("key1", [Row(value="b")]) + self.assertEqual(result, [Row(key="key1", count=2)]) + + def test_handle_initial_state_not_called_without_initial_state(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + initial_state_call_count = [0] + + class InitialStateCountingProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = handle + self.state = handle.getValueState("count", "int") + + def handleInitialState(self, key, initialState, timerValues) -> None: + initial_state_call_count[0] += 1 + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + val = self.state.get() if self.state.exists() else None + count = val[0] if val else 0 + for row in rows: + count += 1 + self.state.update((count,)) + yield Row(key=key[0], count=count) + + processor = InitialStateCountingProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + self.assertEqual(initial_state_call_count[0], 0) + + def test_delete_value_state(self): + """Test that deleteState correctly deletes value state for a given key.""" + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + tester.updateValueState("count", "key1", (10,)) + tester.updateValueState("count", "key2", (20,)) + 
self.assertEqual(tester.peekValueState("count", "key1"), (10,)) + + tester.deleteState("count", "key1") + self.assertIsNone(tester.peekValueState("count", "key1")) + self.assertEqual(tester.peekValueState("count", "key2"), (20,)) + + def test_delete_list_state(self): + """Test that deleteState correctly deletes list state for a given key.""" + processor = TopKProcessorFactory(k=3).row() + tester = TwsTester(processor) + + tester.updateListState("topK", "key1", [(1.0,), (2.0,), (3.0,)]) + tester.updateListState("topK", "key2", [(4.0,), (5.0,)]) + self.assertEqual(tester.peekListState("topK", "key1"), [(1.0,), (2.0,), (3.0,)]) + + tester.deleteState("topK", "key1") + self.assertEqual(tester.peekListState("topK", "key1"), []) + self.assertEqual(tester.peekListState("topK", "key2"), [(4.0,), (5.0,)]) + + def test_delete_map_state(self): + """Test that deleteState correctly deletes map state for a given key.""" + processor = WordFrequencyProcessorFactory().row() + tester = TwsTester(processor) + + tester.updateMapState("frequencies", "user1", {("hello",): (5,), ("world",): (3,)}) + tester.updateMapState("frequencies", "user2", {("spark",): (10,)}) + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (5,), ("world",): (3,)}, + ) + + tester.deleteState("frequencies", "user1") + self.assertEqual(tester.peekMapState("frequencies", "user1"), {}) + self.assertEqual(tester.peekMapState("frequencies", "user2"), {("spark",): (10,)}) + + def test_delete_nonexistent_state_raises_error(self): + """Test that deleteState raises an error for non-existent state.""" + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + with self.assertRaisesRegex(PySparkAssertionError, "STATE_NOT_EXISTS"): + tester.deleteState("nonexistent_state", "key1") + + # Timer tests. 
+ + def test_processing_time_timers(self): + """Test that TwsTester supports ProcessingTime timers.""" + processor = SessionTimeoutProcessorFactory().row() + tester = TwsTester(processor, timeMode="ProcessingTime") + + # Process input for key1 - should register a timer at t=10000. + result1 = tester.test("key1", [Row(value="hello")]) + self.assertEqual(result1, [Row(key="key1", result="received:hello")]) + + # Set processing time to 5000 - timer should NOT fire yet. + expired1 = tester.setProcessingTime(5000) + self.assertEqual(expired1, []) + + # Process input for key2 at t=5000 - should register timer at t=15000. + result2 = tester.test("key2", [Row(value="world")]) + self.assertEqual(result2, [Row(key="key2", result="received:world")]) + + # Set processing time to 11000 - key1's timer should fire. + expired2 = tester.setProcessingTime(11000) + self.assertEqual(expired2, [Row(key="key1", result="session-expired")]) + + # Set processing time to 16000 - key2's timer should fire. + expired3 = tester.setProcessingTime(16000) + self.assertEqual(expired3, [Row(key="key2", result="session-expired")]) + + # Verify state is cleared after session expiry. + self.assertIsNone(tester.peekValueState("lastSeen", "key1")) + self.assertIsNone(tester.peekValueState("lastSeen", "key2")) + + def test_processing_time_timers_pandas(self): + """Test that TwsTester supports ProcessingTime timers in Pandas mode.""" + processor = SessionTimeoutProcessorFactory().pandas() + tester = TwsTester(processor, timeMode="ProcessingTime") + + # Process input for key1 - should register a timer at t=10000. + result1 = tester.testInPandas("key1", pd.DataFrame({"value": ["hello"]})) + expected1 = pd.DataFrame({"key": ["key1"], "result": ["received:hello"]}) + pdt.assert_frame_equal(result1, expected1, check_like=True) + + # Set processing time to 5000 - timer should NOT fire yet. 
+ expired1 = tester.setProcessingTime(5000) + self.assertEqual(len(expired1), 0) + + # Set processing time to 11000 - timer should fire. + expired2 = tester.setProcessingTime(11000) + expected2 = pd.DataFrame({"key": ["key1"], "result": ["session-expired"]}) + pdt.assert_frame_equal(expired2, expected2, check_like=True) + + def test_event_time_timers_manual_watermark(self): + """Test that TwsTester supports EventTime timers fired by manual watermark advance.""" + processor = EventTimeSessionProcessorFactory().row() + + def event_time_extractor(row: Row) -> int: + return row.event_time_ms + + tester = TwsTester( + processor, + timeMode="EventTime", + eventTimeExtractor=event_time_extractor, + ) + + # Process event at t=10000 for key1 - registers timer at t=15000. + result1 = tester.test("key1", [Row(event_time_ms=10000, value="hello")]) + self.assertEqual(result1, [Row(key="key1", result="received:hello@10000")]) + + # Process event at t=11000 for key2 - registers timer at t=16000. + result2 = tester.test("key2", [Row(event_time_ms=11000, value="world")]) + self.assertEqual(result2, [Row(key="key2", result="received:world@11000")]) + + # Set watermark to 15000 - key1's timer should fire. + expired1 = tester.setWatermark(15000) + self.assertEqual(len(expired1), 1) + self.assertEqual(expired1[0], Row(key="key1", result="session-expired@watermark=15000")) + + # Set watermark to 16000 - key2's timer should fire. + expired2 = tester.setWatermark(16000) + self.assertEqual(len(expired2), 1) + self.assertEqual(expired2[0], Row(key="key2", result="session-expired@watermark=16000")) + + # Verify state is cleared. + self.assertIsNone(tester.peekValueState("lastEventTime", "key1")) + self.assertIsNone(tester.peekValueState("lastEventTime", "key2")) + + def test_late_event_filtering(self): Review Comment: It is mostly in sync now. There are some differences that you pointed out, e.g. here we cover null keys or clearing nonexistent state. 
I will not remove these tests just because we don't have them in the Scala version. Also, some tests are duplicated to cover both the Row and Pandas APIs, but I didn't do this for all tests; the decision about which tests to duplicate was arbitrary, as I just wanted to cover all code paths. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
