fedimser commented on code in PR #53758: URL: https://github.com/apache/spark/pull/53758#discussion_r2755792522
########## python/pyspark/sql/tests/pandas/streaming/test_tws_tester.py: ########## @@ -0,0 +1,1294 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import random +import tempfile +import unittest + +import pandas as pd +import pandas.testing as pdt + +from pyspark import SparkConf +from pyspark.sql import DataFrame +from pyspark.sql.functions import split +from pyspark.sql.streaming import StatefulProcessor, TwsTester +from pyspark.sql.streaming.query import StreamingQuery +from pyspark.sql.tests.pandas.helper.helper_pandas_transform_with_state import ( + AllMethodsTestProcessorFactory, + EventTimeCountProcessorFactory, + EventTimeSessionProcessorFactory, + RunningCountStatefulProcessorFactory, + SessionTimeoutProcessorFactory, + TopKProcessorFactory, + WordFrequencyProcessorFactory, +) +from pyspark.sql.types import ( + DoubleType, + IntegerType, + Row, + StringType, + StructField, + StructType, +) +from pyspark.errors import PySparkValueError, PySparkAssertionError +from pyspark.errors.exceptions.base import IllegalArgumentException +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + + +@unittest.skipIf( + not
have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message or "", +) +class TwsTesterTests(ReusedSQLTestCase): + def test_running_count_processor(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + self.assertEqual(tester.test("key1", [Row(value="a")]), [Row(key="key1", count=1)]) + self.assertEqual( + tester.test("key2", [Row(value="a"), Row(value="a")]), + [Row(key="key2", count=2)], + ) + self.assertEqual(tester.test("key3", [Row(value="a")]), [Row(key="key3", count=1)]) + self.assertEqual( + tester.test("key1", [Row(value="a"), Row(value="a"), Row(value="a")]), + [Row(key="key1", count=4)], + ) + + self.assertEqual(tester.peekValueState("count", "key1"), (4,)) + self.assertEqual(tester.peekValueState("count", "key2"), (2,)) + self.assertEqual(tester.peekValueState("count", "key3"), (1,)) + self.assertIsNone(tester.peekValueState("count", "key4")) + + def test_running_count_processor_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + + ans1 = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]})) + expected1 = pd.DataFrame({"key": ["key1"], "count": [1]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("key2", pd.DataFrame({"value": ["a", "a"]})) + expected2 = pd.DataFrame({"key": ["key2"], "count": [2]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + ans3 = tester.testInPandas("key3", pd.DataFrame({"value": ["a"]})) + expected3 = pd.DataFrame({"key": ["key3"], "count": [1]}) + pdt.assert_frame_equal(ans3, expected3, check_like=True) + + ans4 = tester.testInPandas("key1", pd.DataFrame({"value": ["a", "a", "a"]})) + expected4 = pd.DataFrame({"key": ["key1"], "count": [4]}) + pdt.assert_frame_equal(ans4, expected4, check_like=True) + + def test_direct_access_to_value_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = 
TwsTester(processor) + tester.updateValueState("count", "foo", (5,)) + tester.test("foo", [Row(value="q")]) + self.assertEqual(tester.peekValueState("count", "foo"), (6,)) + + def test_topk_processor(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + + ans1 = tester.test("key1", [Row(score=2.0), Row(score=3.0), Row(score=1.0)]) + self.assertEqual(ans1, [Row(key="key1", score=3.0), Row(key="key1", score=2.0)]) + + ans2 = tester.test( + "key2", [Row(score=10.0), Row(score=20.0), Row(score=30.0), Row(score=40.0)] + ) + self.assertEqual(ans2, [Row(key="key2", score=40.0), Row(key="key2", score=30.0)]) + + ans3 = tester.test("key3", [Row(score=100.0)]) + self.assertEqual(ans3, [Row(key="key3", score=100.0)]) + + self.assertEqual(tester.peekListState("topK", "key1"), [(3.0,), (2.0,)]) + self.assertEqual(tester.peekListState("topK", "key2"), [(40.0,), (30.0,)]) + self.assertEqual(tester.peekListState("topK", "key3"), [(100.0,)]) + self.assertEqual(tester.peekListState("topK", "key4"), []) + + ans4 = tester.test("key1", [Row(score=10.0)]) + self.assertEqual(ans4, [Row(key="key1", score=10.0), Row(key="key1", score=3.0)]) + self.assertEqual(tester.peekListState("topK", "key1"), [(10.0,), (3.0,)]) + + def test_topk_processor_pandas(self): + processor = TopKProcessorFactory(k=2).pandas() + tester = TwsTester(processor) + + ans1 = tester.testInPandas("key1", pd.DataFrame({"score": [2.0, 3.0, 1.0]})) + expected1 = pd.DataFrame({"key": ["key1", "key1"], "score": [3.0, 2.0]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("key2", pd.DataFrame({"score": [10.0, 20.0, 30.0, 40.0]})) + expected2 = pd.DataFrame({"key": ["key2", "key2"], "score": [40.0, 30.0]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + ans3 = tester.testInPandas("key3", pd.DataFrame({"score": [100.0]})) + expected3 = pd.DataFrame({"key": ["key3"], "score": [100.0]}) + pdt.assert_frame_equal(ans3, expected3, 
check_like=True) + + ans4 = tester.testInPandas("key1", pd.DataFrame({"score": [10.0]})) + expected4 = pd.DataFrame({"key": ["key1", "key1"], "score": [10.0, 3.0]}) + pdt.assert_frame_equal(ans4, expected4, check_like=True) + + def test_direct_access_to_list_state(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + + tester.updateListState("topK", "a", [(6.0,), (5.0,)]) + tester.updateListState("topK", "b", [(8.0,), (7.0,)]) + tester.test("a", [Row(score=10.0)]) + tester.test("b", [Row(score=7.5)]) + tester.test("c", [Row(score=1.0)]) + + assert tester.peekListState("topK", "a") == [(10.0,), (6.0,)] + assert tester.peekListState("topK", "b") == [(8.0,), (7.5,)] + assert tester.peekListState("topK", "c") == [(1.0,)] + assert tester.peekListState("topK", "d") == [] + + def test_word_frequency_processor(self): + processor = WordFrequencyProcessorFactory().row() + tester = TwsTester(processor) + + ans1 = tester.test( + "user1", + [ + Row(word="hello"), + Row(word="world"), + Row(word="hello"), + Row(word="world"), + ], + ) + self.assertEqual( + ans1, + [ + Row(key="user1", word="hello", count=1), + Row(key="user1", word="world", count=1), + Row(key="user1", word="hello", count=2), + Row(key="user1", word="world", count=2), + ], + ) + + ans2 = tester.test("user2", [Row(word="hello"), Row(word="spark")]) + self.assertEqual( + ans2, + [ + Row(key="user2", word="hello", count=1), + Row(key="user2", word="spark", count=1), + ], + ) + + # Check state using peekMapState. + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (2,), ("world",): (2,)}, + ) + self.assertEqual( + tester.peekMapState("frequencies", "user2"), + {("hello",): (1,), ("spark",): (1,)}, + ) + self.assertEqual(tester.peekMapState("frequencies", "user3"), {}) + + # Process more data for user1. 
+ ans3 = tester.test("user1", [Row(word="hello"), Row(word="test")]) + self.assertEqual( + ans3, + [ + Row(key="user1", word="hello", count=3), + Row(key="user1", word="test", count=1), + ], + ) + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (3,), ("world",): (2,), ("test",): (1,)}, + ) + + def test_word_frequency_processor_pandas(self): + processor = WordFrequencyProcessorFactory().pandas() + tester = TwsTester(processor) + + input_df1 = pd.DataFrame({"word": ["hello", "world", "hello", "world"]}) + ans1 = tester.testInPandas("user1", input_df1) + expected1 = pd.DataFrame( + { + "key": ["user1", "user1", "user1", "user1"], + "word": ["hello", "world", "hello", "world"], + "count": [1, 1, 2, 2], + } + ) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + input_df2 = pd.DataFrame({"word": ["hello", "spark"]}) + ans2 = tester.testInPandas("user2", input_df2) + expected2 = pd.DataFrame( + { + "key": ["user2", "user2"], + "word": ["hello", "spark"], + "count": [1, 1], + } + ) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + input_df3 = pd.DataFrame({"word": ["hello", "test"]}) + ans3 = tester.testInPandas("user1", input_df3) + expected3 = pd.DataFrame( + { + "key": ["user1", "user1"], + "word": ["hello", "test"], + "count": [3, 1], + } + ) + pdt.assert_frame_equal(ans3, expected3, check_like=True) + + def test_direct_access_to_map_state(self): + processor = WordFrequencyProcessorFactory().row() + tester = TwsTester(processor) + + tester.updateMapState("frequencies", "user1", {("hello",): (5,), ("world",): (3,)}) + tester.updateMapState("frequencies", "user2", {("spark",): (10,)}) + + tester.test("user1", [Row(word="hello"), Row(word="goodbye")]) + tester.test("user2", [Row(word="spark")]) + tester.test("user3", [Row(word="new")]) + + self.assertEqual( + tester.peekMapState("frequencies", "user1"), + {("hello",): (6,), ("world",): (3,), ("goodbye",): (1,)}, + ) + 
self.assertEqual(tester.peekMapState("frequencies", "user2"), {("spark",): (11,)}) + self.assertEqual(tester.peekMapState("frequencies", "user3"), {("new",): (1,)}) + self.assertEqual(tester.peekMapState("frequencies", "user4"), {}) + + # Example of how TwsTester can be used to test step function. + def test_step_function(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + # Example of helper function using TwsTester to inspect how processing a single row changes + # state. + def step_function(key: str, input_row: str, state_in: int) -> int: + tester.updateValueState("count", key, (state_in,)) + tester.test(key, [Row(value=input_row)]) + return tester.peekValueState("count", key)[0] + + self.assertEqual(step_function("key1", "a", 10), 11) + + # Example of how TwsTester can be used to simulate real-time mode (row by row processing). + def test_row_by_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + + # Example of helper function to test how TransformWithState processes rows one-by-one, + # which can be used to simulate real-time mode. + def test_row_by_row_helper(input_rows: list[tuple[str, Row]]) -> list[Row]: + result: list[Row] = [] + for key, row in input_rows: + result += tester.test(key, [row]) + return result + + ans = test_row_by_row_helper( + [ + ("key1", Row(value="a")), + ("key2", Row(value="b")), + ("key1", Row(value="c")), + ("key2", Row(value="b")), + ("key1", Row(value="c")), + ("key1", Row(value="c")), + ("key3", Row(value="q")), + ] + ) + self.assertEqual( + ans, + [ + Row(key="key1", count=1), + Row(key="key2", count=1), + Row(key="key1", count=2), + Row(key="key2", count=2), + Row(key="key1", count=3), + Row(key="key1", count=4), + Row(key="key3", count=1), + ], + ) + + # Tests that TwsTester calls handleInitialState. 
+ def test_initial_state_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester( + processor, + initialStateRow=[ + ("a", Row(initial_count=10)), + ("b", Row(initial_count=20)), + ], + ) + self.assertEqual(tester.peekValueState("count", "a"), (10,)) + self.assertEqual(tester.peekValueState("count", "b"), (20,)) + + ans1 = tester.test("a", [Row(value="a")]) + self.assertEqual(ans1, [Row(key="a", count=11)]) + + ans2 = tester.test("c", [Row(value="c")]) + self.assertEqual(ans2, [Row(key="c", count=1)]) + + def test_initial_state_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester( + processor, + initialStatePandas=[ + ("a", pd.DataFrame({"initial_count": [10]})), + ("b", pd.DataFrame({"initial_count": [20]})), + ], + ) + self.assertEqual(tester.peekValueState("count", "a"), (10,)) + self.assertEqual(tester.peekValueState("count", "b"), (20,)) + + ans1 = tester.testInPandas("a", pd.DataFrame({"value": ["a"]})) + expected1 = pd.DataFrame({"key": ["a"], "count": [11]}) + pdt.assert_frame_equal(ans1, expected1, check_like=True) + + ans2 = tester.testInPandas("c", pd.DataFrame({"value": ["c"]})) + expected2 = pd.DataFrame({"key": ["c"], "count": [1]}) + pdt.assert_frame_equal(ans2, expected2, check_like=True) + + def test_all_methods_processor(self): + """Test that TwsTester exercises all state methods.""" + processor = AllMethodsTestProcessorFactory().row() + tester = TwsTester(processor) + + results = tester.test( + "k", + [ + Row(cmd="value-exists"), # false + Row(cmd="value-set"), # set to 42 + Row(cmd="value-exists"), # true + Row(cmd="value-clear"), # clear + Row(cmd="value-exists"), # false again + Row(cmd="list-exists"), # false + Row(cmd="list-append"), # append a, b + Row(cmd="list-exists"), # true + Row(cmd="list-append-array"), # append c, d + Row(cmd="list-get"), # a,b,c,d + Row(cmd="map-exists"), # false + Row(cmd="map-add"), # add x=1, y=2, z=3 + Row(cmd="map-exists"), # true + 
Row(cmd="map-keys"), # x,y,z + Row(cmd="map-values"), # 1,2,3 + Row(cmd="map-iterator"), # x=1,y=2,z=3 + Row(cmd="map-remove"), # remove y + Row(cmd="map-keys"), # x,z + Row(cmd="map-clear"), # clear map + Row(cmd="map-exists"), # false + ], + ) + + self.assertEqual( + results, + [ + Row(key="k", result="value-exists:False"), + Row(key="k", result="value-set:done"), + Row(key="k", result="value-exists:True"), + Row(key="k", result="value-clear:done"), + Row(key="k", result="value-exists:False"), + Row(key="k", result="list-exists:False"), + Row(key="k", result="list-append:done"), + Row(key="k", result="list-exists:True"), + Row(key="k", result="list-append-array:done"), + Row(key="k", result="list-get:a,b,c,d"), + Row(key="k", result="map-exists:False"), + Row(key="k", result="map-add:done"), + Row(key="k", result="map-exists:True"), + Row(key="k", result="map-keys:x,y,z"), + Row(key="k", result="map-values:1,2,3"), + Row(key="k", result="map-iterator:x=1,y=2,z=3"), + Row(key="k", result="map-remove:done"), + Row(key="k", result="map-keys:x,z"), + Row(key="k", result="map-clear:done"), + Row(key="k", result="map-exists:False"), + ], + ) + + def test_both_initial_states_specified(self): + processor = RunningCountStatefulProcessorFactory().row() + with self.assertRaises(AssertionError): + TwsTester( + processor, + initialStateRow=[("a", Row(initial_count=10))], + initialStatePandas=[("a", pd.DataFrame({"initial_count": [10]}))], + ) + + def test_timer_registration_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + tester.handle.registerTimer(12345) + + def test_delete_timer_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + 
tester.handle.deleteTimer(12345) + + def test_list_timers_raises_error_in_none_mode(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, timeMode="None") + with self.assertRaisesRegex(PySparkValueError, "UNSUPPORTED_OPERATION"): + list(tester.handle.listTimers()) + + def test_empty_input_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", []) + self.assertEqual(result, [Row(key="key1", count=0)]) + + def test_empty_input_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": []}) + result = tester.testInPandas("key1", input_df) + expected = pd.DataFrame({"key": ["key1"], "count": [0]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_empty_initial_state_row(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor, initialStateRow=[]) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=1)]) + + def test_empty_initial_state_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor, initialStatePandas=[]) + result = tester.testInPandas("key1", pd.DataFrame({"value": ["a"]})) + expected = pd.DataFrame({"key": ["key1"], "count": [1]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_none_values_in_input(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value=None), Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=2)]) + + def test_single_row_input(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count=1)]) + + def 
test_peek_nonexistent_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.peekValueState("count", "nonexistent_key") + self.assertIsNone(result) + + def test_update_nonexistent_state(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + with self.assertRaises(AssertionError): + tester.updateValueState("nonexistent_state", "key1", (5,)) + + def test_peek_state_before_initialization(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + with self.assertRaises(AssertionError): + tester.peekValueState("nonexistent_state", "key1") + + def test_clear_nonexistent_list_state(self): + processor = TopKProcessorFactory(k=2).row() + tester = TwsTester(processor) + tester.handle.setGroupingKey("key1") + list_state = tester.handle.getListState("topK", "double") + list_state.clear() + + def test_delete_if_exists_nonexistent(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + tester.handle.deleteIfExists("nonexistent_state") + + def test_numeric_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + self.assertEqual(tester.test(1, [Row(value="a"), Row(value="c")]), [Row(key=1, count=2)]) + self.assertEqual(tester.test(2, [Row(value="b")]), [Row(key=2, count=1)]) + + def test_tuple_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test((1, 2), [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key=(1, 2), count=2)]) + + def test_special_characters_in_key(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test("key@#$%", [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key="key@#$%", count=2)]) + + def test_null_key_value_row(self): + processor = 
RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + result = tester.test(None, [Row(value="a"), Row(value="b")]) + self.assertEqual(result, [Row(key=None, count=2)]) + + def test_null_key_value_pandas(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a", "b"]}) + result = tester.testInPandas(None, input_df) + expected = pd.DataFrame({"key": [None], "count": [2]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_multiple_value_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiValueStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = handle + self.state1 = handle.getValueState("state1", "int") + self.state2 = handle.getValueState("state2", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + val1 = self.state1.get() if self.state1.exists() else None + count1 = val1[0] if val1 else 0 + val2 = self.state2.get() if self.state2.exists() else None + count2 = val2[0] if val2 else 0 + for row in rows: + count1 += 1 + count2 += 2 + self.state1.update((count1,)) + self.state2.update((count2,)) + yield Row(key=key[0], count1=count1, count2=count2) + + processor = MultiValueStateProcessor() + tester = TwsTester(processor) + result = tester.test("key1", [Row(value="a")]) + self.assertEqual(result, [Row(key="key1", count1=1, count2=2)]) + self.assertEqual(tester.peekValueState("state1", "key1"), (1,)) + self.assertEqual(tester.peekValueState("state2", "key1"), (2,)) + + def test_multiple_list_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiListStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> 
None: + self.handle = handle + self.list1 = handle.getListState("list1", "string") + self.list2 = handle.getListState("list2", "string") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + self.list1.appendValue((row.value,)) + self.list2.appendValue((row.value + "_2",)) + yield Row(key=key[0], count=1) + + processor = MultiListStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + self.assertEqual(tester.peekListState("list1", "key1"), [("a",)]) + self.assertEqual(tester.peekListState("list2", "key1"), [("a_2",)]) + + def test_multiple_map_states(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MultiMapStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = handle + self.map1 = handle.getMapState("map1", "string", "int") + self.map2 = handle.getMapState("map2", "string", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + val1 = self.map1.getValue((row.word,)) + count1 = val1[0] if val1 is not None else 0 + val2 = self.map2.getValue((row.word,)) + count2 = val2[0] if val2 is not None else 0 + self.map1.updateValue((row.word,), (count1 + 1,)) + self.map2.updateValue((row.word,), (count2 + 2,)) + yield Row(key=key[0], word=row.word) + + processor = MultiMapStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(word="hello")]) + self.assertEqual(tester.peekMapState("map1", "key1"), {("hello",): (1,)}) + self.assertEqual(tester.peekMapState("map2", "key1"), {("hello",): (2,)}) + + def test_mixed_state_types(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + class MixedStateProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + self.handle = 
handle + self.value_state = handle.getValueState("value", "int") + self.list_state = handle.getListState("list", "string") + self.map_state = handle.getMapState("map", "string", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + for row in rows: + val = self.value_state.get() if self.value_state.exists() else None + count = val[0] if val else 0 + self.value_state.update((count + 1,)) + self.list_state.appendValue((row.value,)) + self.map_state.updateValue((row.value,), (1,)) + yield Row(key=key[0], count=1) + + processor = MixedStateProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + self.assertEqual(tester.peekValueState("value", "key1"), (1,)) + self.assertEqual(tester.peekListState("list", "key1"), [("a",)]) + self.assertEqual(tester.peekMapState("map", "key1"), {("a",): (1,)}) + + def test_pandas_with_multiple_rows(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a", "b", "c"]}) + result = tester.testInPandas("key1", input_df) + expected = pd.DataFrame({"key": ["key1"], "count": [3]}) + pdt.assert_frame_equal(result, expected, check_like=True) + + def test_pandas_dtype_preservation(self): + processor = RunningCountStatefulProcessorFactory().pandas() + tester = TwsTester(processor) + input_df = pd.DataFrame({"value": ["a"]}) + result = tester.testInPandas("key1", input_df) + self.assertEqual(result["count"].dtype, "int64") + + def test_processor_init_called_once(self): + from pyspark.sql.streaming.stateful_processor import ( + StatefulProcessor, + StatefulProcessorHandle, + ) + from typing import Iterator + + init_call_count = [0] + + class InitCountingProcessor(StatefulProcessor): + def init(self, handle: StatefulProcessorHandle) -> None: + init_call_count[0] += 1 + self.handle = handle + self.state = handle.getValueState("count", "int") + + def handleInputRows(self, key, rows, timerValues) -> Iterator: + val = 
self.state.get() if self.state.exists() else None + count = val[0] if val else 0 + for row in rows: + count += 1 + self.state.update((count,)) + yield Row(key=key[0], count=count) + + processor = InitCountingProcessor() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + tester.test("key1", [Row(value="b")]) + self.assertEqual(init_call_count[0], 1) + + def test_processor_state_persists_across_batches(self): + processor = RunningCountStatefulProcessorFactory().row() + tester = TwsTester(processor) + tester.test("key1", [Row(value="a")]) + result = tester.test("key1", [Row(value="b")]) + self.assertEqual(result, [Row(key="key1", count=2)]) + + def test_handle_initial_state_not_called_without_initial_state(self): Review Comment: Removed. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
