This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e6252d6c6143 [SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer
e6252d6c6143 is described below
commit e6252d6c6143f134201a17bb25978a136186cbfa
Author: jingz-db <[email protected]>
AuthorDate: Thu Nov 28 15:34:37 2024 +0900
[SPARK-50194][SS][PYTHON] Integration of New Timer API and Initial State API with Timer
### What changes were proposed in this pull request?
On the Scala side, the timer API was reworked to use a separate `handleExpiredTimer` function inside `StatefulProcessor`; this PR makes the corresponding change to the Python timer API so the two sides stay aligned. It also adds a `timer_values` parameter to the `handleInitialState` function to support registering timers in the first batch for initial state rows (see the sketch below).
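For illustration only, a minimal processor using the new `handleExpiredTimer` callback might look like the sketch below. The class name, state schema, and output columns are assumptions and not part of this commit; the API calls mirror those exercised in the updated test suite.
```
import pandas as pd
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType


class CountWithTimersProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        self.handle = handle
        schema = StructType([StructField("count", IntegerType(), True)])
        self.count_state = handle.getValueState("count_state", schema)

    def handleInputRows(self, key, rows, timer_values):
        # Count the input rows for this key and register a processing-time timer.
        count = sum(len(pdf) for pdf in rows)
        self.count_state.update((count,))
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1000)
        yield pd.DataFrame({"id": key, "countAsString": (str(count),)})

    def handleExpiredTimer(self, key, timer_values, expired_timer_info):
        # Invoked once per fired timer: clean up state, delete the timer, emit a marker row.
        expiry_ms = expired_timer_info.get_expiry_time_in_ms()
        self.count_state.clear()
        self.handle.deleteTimer(expiry_ms)
        yield pd.DataFrame({"id": (f"{key[0]}-expired",), "countAsString": (str(expiry_ms),)})

    def close(self) -> None:
        pass
```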
### Why are the changes needed?
This change aligns the Python API with the corresponding Scala-side API change:
https://github.com/apache/spark/pull/48553
### Does this PR introduce _any_ user-facing change?
Yes.
We add a new user-defined function to explicitly handle expired timers:
```
def handleExpiredTimer(
    self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
) -> Iterator["PandasDataFrameLike"]:
```
We also add a new `timer_values` parameter to enable users to register timers for keys that exist in the initial state:
```
def handleInitialState(
    self,
    key: Any,
    initialState: "PandasDataFrameLike",
    timer_values: TimerValues) -> None
```
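Building on the sketch above, `handleInitialState` can now use its new `timer_values` argument to register a timer for a key that appears only in the initial state. The `initVal` column below is an assumed initial-state column (compare `StatefulProcessorWithInitialStateTimers` in the updated tests):
```
class InitialStateTimerProcessor(CountWithTimersProcessor):
    def handleInitialState(self, key, initialState, timer_values) -> None:
        # Seed the count state from the initial state row and schedule a timer
        # in the first batch using the batch's processing time.
        self.count_state.update((int(initialState.at[0, "initVal"]),))
        self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1)
```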
### How was this patch tested?
Add a new test in `test_pandas_transform_with_state`
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48838 from jingz-db/python-new-timer.
Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Co-authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/pandas/group_ops.py | 107 ++++++----
python/pyspark/sql/pandas/serializers.py | 13 +-
python/pyspark/sql/streaming/stateful_processor.py | 49 +++--
.../sql/streaming/stateful_processor_api_client.py | 94 +++++----
.../sql/streaming/stateful_processor_util.py | 27 +++
.../pandas/test_pandas_transform_with_state.py | 227 ++++++++++++---------
python/pyspark/worker.py | 68 +++---
.../TransformWithStateInPandasStateServer.scala | 2 +
8 files changed, 363 insertions(+), 224 deletions(-)
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index d8f22e434374..688ad4b05732 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -35,6 +35,7 @@ from pyspark.sql.streaming.stateful_processor import (
TimerValues,
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import StructType, _parse_datatype_string
if TYPE_CHECKING:
@@ -503,58 +504,59 @@ class PandasGroupedOpsMixin:
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
- def handle_data_with_timers(
+ def handle_data_rows(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
- inputRows: Iterator["PandasDataFrameLike"],
+ inputRows: Optional[Iterator["PandasDataFrameLike"]] = None,
) -> Iterator["PandasDataFrameLike"]:
statefulProcessorApiClient.set_implicit_key(key)
- if timeMode != "none":
- batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
- watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
+
+ batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+ timeMode
+ )
+
+ # process with data rows
+ if inputRows is not None:
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp, watermark_timestamp)
+ )
+ return data_iter
else:
- batch_timestamp = -1
- watermark_timestamp = -1
- # process with invalid expiry timer info and emit data rows
- data_iter = statefulProcessor.handleInputRows(
- key,
- inputRows,
- TimerValues(batch_timestamp, watermark_timestamp),
- ExpiredTimerInfo(False),
+ return iter([])
+
+ def handle_expired_timers(
+ statefulProcessorApiClient: StatefulProcessorApiClient,
+ ) -> Iterator["PandasDataFrameLike"]:
+ batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+ timeMode
)
- statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
- if timeMode == "processingtime":
+ if timeMode.lower() == "processingtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
batch_timestamp
)
- elif timeMode == "eventtime":
+ elif timeMode.lower() == "eventtime":
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(
watermark_timestamp
)
else:
expiry_list_iter = iter([[]])
- result_iter_list = [data_iter]
- # process with valid expiry time info and with empty input rows,
- # only timer related rows will be emitted
+ # process with expiry timers, only timer related rows will be emitted
for expiry_list in expiry_list_iter:
for key_obj, expiry_timestamp in expiry_list:
- result_iter_list.append(
- statefulProcessor.handleInputRows(
- key_obj,
- iter([]),
- TimerValues(batch_timestamp, watermark_timestamp),
- ExpiredTimerInfo(True, expiry_timestamp),
- )
- )
- # TODO(SPARK-49603) set the handle state in the lazily initialized iterator
-
- result = itertools.chain(*result_iter_list)
- return result
+ statefulProcessorApiClient.set_implicit_key(key_obj)
+ for pd in statefulProcessor.handleExpiredTimer(
+ key=key_obj,
+ timer_values=TimerValues(batch_timestamp, watermark_timestamp),
+ expired_timer_info=ExpiredTimerInfo(expiry_timestamp),
+ ):
+ yield pd
+ statefulProcessorApiClient.delete_timer(expiry_timestamp)
def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
+ mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
@@ -566,19 +568,28 @@ class PandasGroupedOpsMixin:
StatefulProcessorHandleState.INITIALIZED
)
- # Key is None when we have processed all the input data from the worker and ready to
- # proceed with the cleanup steps.
- if key is None:
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+ result = handle_expired_timers(statefulProcessorApiClient)
+ return result
+ elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.TIMER_PROCESSED
+ )
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])
-
- result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
- return result
+ else:
+ # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+ result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
+ return result
def transformWithStateWithInitStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
+ mode: TransformWithStateInPandasFuncMode,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
initialStates: Optional[Iterator["PandasDataFrameLike"]] = None,
@@ -603,20 +614,30 @@ class PandasGroupedOpsMixin:
StatefulProcessorHandleState.INITIALIZED
)
- # Key is None when we have processed all the input data from the worker and ready to
- # proceed with the cleanup steps.
- if key is None:
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_TIMER:
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+ result = handle_expired_timers(statefulProcessorApiClient)
+ return result
+ elif mode == TransformWithStateInPandasFuncMode.COMPLETE:
statefulProcessorApiClient.remove_implicit_key()
statefulProcessor.close()
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.CLOSED)
return iter([])
+ else:
+ # mode == TransformWithStateInPandasFuncMode.PROCESS_DATA
+ batch_timestamp, watermark_timestamp = statefulProcessorApiClient.get_timestamps(
+ timeMode
+ )
# only process initial state if first batch and initial state is not None
if initialStates is not None:
for cur_initial_state in initialStates:
statefulProcessorApiClient.set_implicit_key(key)
- # TODO(SPARK-50194) integration with new timer API with initial state
- statefulProcessor.handleInitialState(key, cur_initial_state)
+ statefulProcessor.handleInitialState(
+ key, cur_initial_state, TimerValues(batch_timestamp, watermark_timestamp)
+ )
# if we don't have input rows for the given key but only have initial state
# for the grouping key, the inputRows iterator could be empty
@@ -629,7 +650,7 @@ class PandasGroupedOpsMixin:
inputRows = itertools.chain([first], inputRows)
if not input_rows_empty:
- result = handle_data_with_timers(statefulProcessorApiClient, key, inputRows)
+ result = handle_data_rows(statefulProcessorApiClient, key, inputRows)
else:
result = iter([])
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 5bf07b87400f..536bf7307065 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -36,6 +36,7 @@ from pyspark.sql.pandas.types import (
_create_converter_from_pandas,
_create_converter_to_pandas,
)
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.sql.types import (
DataType,
StringType,
@@ -1197,7 +1198,11 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- yield (k, g)
+ yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+
+ yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+
+ yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
def dump_stream(self, iterator, stream):
"""
@@ -1281,4 +1286,8 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSe
data_batches = generate_data_batches(_batches)
for k, g in groupby(data_batches, key=lambda x: x[0]):
- yield (k, g)
+ yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+
+ yield (TransformWithStateInPandasFuncMode.PROCESS_TIMER, None, None)
+
+ yield (TransformWithStateInPandasFuncMode.COMPLETE, None, None)
diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py
index 20078c215bac..9caa9304d6a8 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -105,21 +105,13 @@ class TimerValues:
class ExpiredTimerInfo:
"""
- Class used for arbitrary stateful operations with transformWithState to access expired timer
- info. When is_valid is false, the expiry timestamp is invalid.
+ Class used to provide access to expired timer's expiry time.
.. versionadded:: 4.0.0
"""
- def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
- self._is_valid = is_valid
+ def __init__(self, expiry_time_in_ms: int = -1) -> None:
self._expiry_time_in_ms = expiry_time_in_ms
- def is_valid(self) -> bool:
- """
- Whether the expiry info is valid.
- """
- return self._is_valid
-
def get_expiry_time_in_ms(self) -> int:
"""
Get the timestamp for expired timer, return timestamp in millisecond.
@@ -398,7 +390,6 @@ class StatefulProcessor(ABC):
key: Any,
rows: Iterator["PandasDataFrameLike"],
timer_values: TimerValues,
- expired_timer_info: ExpiredTimerInfo,
) -> Iterator["PandasDataFrameLike"]:
"""
Function that will allow users to interact with input data rows along with the grouping key.
@@ -420,11 +411,29 @@ class StatefulProcessor(ABC):
timer_values: TimerValues
Timer value for the current batch that process the input rows.
Users can get the processing or event time timestamp from TimerValues.
- expired_timer_info: ExpiredTimerInfo
- Timestamp of expired timers on the grouping key.
"""
...
+ def handleExpiredTimer(
+ self, key: Any, timer_values: TimerValues, expired_timer_info: ExpiredTimerInfo
+ ) -> Iterator["PandasDataFrameLike"]:
+ """
+ Optional to implement. Will act return an empty iterator if not defined.
+ Function that will be invoked when a timer is fired for a given key. Users can choose to
+ evict state, register new timers and optionally provide output rows.
+
+ Parameters
+ ----------
+ key : Any
+ grouping key.
+ timer_values: TimerValues
+ Timer value for the current batch that process the input rows.
+ Users can get the processing or event time timestamp from TimerValues.
+ expired_timer_info: ExpiredTimerInfo
+ Instance of ExpiredTimerInfo that provides access to expired timer.
+ """
+ return iter([])
+
@abstractmethod
def close(self) -> None:
"""
@@ -433,9 +442,21 @@ class StatefulProcessor(ABC):
"""
...
- def handleInitialState(self, key: Any, initialState: "PandasDataFrameLike") -> None:
+ def handleInitialState(
+ self, key: Any, initialState: "PandasDataFrameLike", timer_values: TimerValues
+ ) -> None:
"""
Optional to implement. Will act as no-op if not defined or no initial state input.
Function that will be invoked only in the first batch for users to process initial states.
+
+ Parameters
+ ----------
+ key : Any
+ grouping key.
+ initialState: :class:`pandas.DataFrame`
+ One dataframe in the initial state associated with the key.
+ timer_values: TimerValues
+ Timer value for the current batch that process the input rows.
+ Users can get the processing or event time timestamp from TimerValues.
"""
pass
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 353f75e26796..53704188081c 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -62,6 +62,10 @@ class StatefulProcessorApiClient:
# Dictionaries to store the mapping between iterator id and a tuple of pandas DataFrame
# and the index of the last row that was read.
self.list_timer_iterator_cursors: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+ # statefulProcessorApiClient is initialized per batch per partition,
+ # so we will have new timestamps for a new batch
+ self._batch_timestamp = -1
+ self._watermark_timestamp = -1
def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -266,47 +270,15 @@ class StatefulProcessorApiClient:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error getting expiry timers: "
f"{response_message[1]}")
- def get_batch_timestamp(self) -> int:
- import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
-
- get_processing_time_call = stateMessage.GetProcessingTime()
- timer_value_call = stateMessage.TimerValueRequest(
- getProcessingTimer=get_processing_time_call
- )
- timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call)
- message = stateMessage.StateRequest(timerRequest=timer_request)
-
- self._send_proto_message(message.SerializeToString())
- response_message = self._receive_proto_message_with_long_value()
- status = response_message[0]
- if status != 0:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError(
- f"Error getting processing timestamp: "
f"{response_message[1]}"
- )
+ def get_timestamps(self, time_mode: str) -> Tuple[int, int]:
+ if time_mode.lower() == "none":
+ return -1, -1
else:
- timestamp = response_message[2]
- return timestamp
-
- def get_watermark_timestamp(self) -> int:
- import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
-
- get_watermark_call = stateMessage.GetWatermark()
- timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
- timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call)
- message = stateMessage.StateRequest(timerRequest=timer_request)
-
- self._send_proto_message(message.SerializeToString())
- response_message = self._receive_proto_message_with_long_value()
- status = response_message[0]
- if status != 0:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError(
- f"Error getting eventtime timestamp: " f"{response_message[1]}"
- )
- else:
- timestamp = response_message[2]
- return timestamp
+ if self._batch_timestamp == -1:
+ self._batch_timestamp = self._get_batch_timestamp()
+ if self._watermark_timestamp == -1:
+ self._watermark_timestamp = self._get_watermark_timestamp()
+ return self._batch_timestamp, self._watermark_timestamp
def get_map_state(
self,
@@ -353,6 +325,48 @@ class StatefulProcessorApiClient:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error deleting state: "
f"{response_message[1]}")
+ def _get_batch_timestamp(self) -> int:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ get_processing_time_call = stateMessage.GetProcessingTime()
+ timer_value_call = stateMessage.TimerValueRequest(
+ getProcessingTimer=get_processing_time_call
+ )
+ timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message_with_long_value()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(
+ f"Error getting processing timestamp: "
f"{response_message[1]}"
+ )
+ else:
+ timestamp = response_message[2]
+ return timestamp
+
+ def _get_watermark_timestamp(self) -> int:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ get_watermark_call = stateMessage.GetWatermark()
+ timer_value_call = stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
+ timer_request = stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message_with_long_value()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(
+ f"Error getting eventtime timestamp: " f"{response_message[1]}"
+ )
+ else:
+ timestamp = response_message[2]
+ return timestamp
+
def _send_proto_message(self, message: bytes) -> None:
# Writing zero here to indicate message version. This allows us to evolve the message
# format or even changing the message protocol in the future.
diff --git a/python/pyspark/sql/streaming/stateful_processor_util.py b/python/pyspark/sql/streaming/stateful_processor_util.py
new file mode 100644
index 000000000000..6130a9581bc2
--- /dev/null
+++ b/python/pyspark/sql/streaming/stateful_processor_util.py
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+
+# This file places the utilities for transformWithStateInPandas; we have a separate file to avoid
+# putting internal classes to the stateful_processor.py file which contains public APIs.
+
+
+class TransformWithStateInPandasFuncMode(Enum):
+ PROCESS_DATA = 1
+ PROCESS_TIMER = 2
+ COMPLETE = 3
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index f385d7cd1abc..60f2c9348db3 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -55,6 +55,7 @@ class TransformWithStateInPandasTestsMixin:
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)
cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch",
"2")
+ cfg.set("spark.sql.session.timeZone", "UTC")
return cfg
def _prepare_input_data(self, input_path, col1, col2):
@@ -558,14 +559,25 @@ class TransformWithStateInPandasTestsMixin:
def test_transform_with_state_in_pandas_event_time(self):
def check_results(batch_df, batch_id):
if batch_id == 0:
- assert set(batch_df.sort("id").collect()) == {Row(id="a",
timestamp="20")}
- elif batch_id == 1:
+ # watermark for late event = 0
+ # watermark for eviction = 0
+ # timer is registered with expiration time = 0, hence expired at the same batch
assert set(batch_df.sort("id").collect()) == {
Row(id="a", timestamp="20"),
Row(id="a-expired", timestamp="0"),
}
+ elif batch_id == 1:
+ # watermark for late event = 0
+ # watermark for eviction = 10 (20 - 10)
+ # timer is registered with expiration time = 10, hence expired at the same batch
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="4"),
+ Row(id="a-expired", timestamp="10000"),
+ }
elif batch_id == 2:
- # verify that rows and expired timer produce the expected result
+ # watermark for late event = 10
+ # watermark for eviction = 10 (unchanged as 4 < 10)
+ # timer is registered with expiration time = 10, hence expired at the same batch
assert set(batch_df.sort("id").collect()) == {
Row(id="a", timestamp="15"),
Row(id="a-expired", timestamp="10000"),
@@ -578,7 +590,9 @@ class TransformWithStateInPandasTestsMixin:
EventTimeStatefulProcessor(), check_results
)
- def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
+ def _test_transform_with_state_init_state_in_pandas(
+ self, stateful_processor, check_results, time_mode="None"
+ ):
input_path = tempfile.mkdtemp()
self._prepare_test_resource1(input_path)
time.sleep(2)
@@ -606,7 +620,7 @@ class TransformWithStateInPandasTestsMixin:
statefulProcessor=stateful_processor,
outputStructType=output_schema,
outputMode="Update",
- timeMode="None",
+ timeMode=time_mode,
initialState=initial_state,
)
.writeStream.queryName("this_query")
@@ -806,6 +820,45 @@ class TransformWithStateInPandasTestsMixin:
StatefulProcessorChainingOps(), check_results, "eventTime", ["outputTimestamp", "id"]
)
+ def test_transform_with_state_init_state_with_timers(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ # timers are registered and handled in the first batch for
+ # rows in initial state; For key=0 and key=3 which contains
+ # expired timers, both should be handled by handleExpiredTimers
+ # regardless of whether key exists in the data rows or not
+ expired_df = batch_df.filter(batch_df["id"].contains("expired"))
+ data_df = batch_df.filter(~batch_df["id"].contains("expired"))
+ assert set(expired_df.sort("id").select("id").collect()) == {
+ Row(id="0-expired"),
+ Row(id="3-expired"),
+ }
+ assert set(data_df.sort("id").collect()) == {
+ Row(id="0", value=str(789 + 123 + 46)),
+ Row(id="1", value=str(146 + 346)),
+ }
+ elif batch_id == 1:
+ # handleInitialState is only processed in the first batch,
+ # no more timer is registered so no more expired timers
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", value=str(789 + 123 + 46 + 67)),
+ Row(id="3", value=str(987 + 12)),
+ }
+ else:
+ for q in self.spark.streams.active:
+ q.stop()
+
+ self._test_transform_with_state_init_state_in_pandas(
+ StatefulProcessorWithInitialStateTimers(), check_results, "processingTime"
+ )
+
+ # run the same test suites again but with single shuffle partition
+ def test_transform_with_state_with_timers_single_partition(self):
+ with self.sql_conf({"spark.sql.shuffle.partitions": "1"}):
+ self.test_transform_with_state_init_state_with_timers()
+ self.test_transform_with_state_in_pandas_event_time()
+ self.test_transform_with_state_in_pandas_proc_timer()
+
class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
# this dict is the same as input initial state dataframe
@@ -814,10 +867,9 @@ class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
self.value_state = handle.getValueState("value_state", state_schema)
+ self.handle = handle
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) ->
Iterator[pd.DataFrame]:
exists = self.value_state.exists()
if exists:
value_row = self.value_state.get()
@@ -840,7 +892,7 @@ class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
else:
yield pd.DataFrame({"id": key, "value": str(accumulated_value)})
- def handleInitialState(self, key, initialState) -> None:
+ def handleInitialState(self, key, initialState, timer_values) -> None:
init_val = initialState.at[0, "initVal"]
self.value_state.update((init_val,))
if len(key) == 1:
@@ -850,6 +902,19 @@ class SimpleStatefulProcessorWithInitialState(StatefulProcessor):
pass
+class StatefulProcessorWithInitialStateTimers(SimpleStatefulProcessorWithInitialState):
+ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:
+ self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+ str_key = f"{str(key[0])}-expired"
+ yield pd.DataFrame(
+ {"id": (str_key,), "value":
str(expired_timer_info.get_expiry_time_in_ms())}
+ )
+
+ def handleInitialState(self, key, initialState, timer_values) -> None:
+ super().handleInitialState(key, initialState, timer_values)
+ self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() - 1)
+
+
# A stateful processor that output the max event time it has seen. Register timer for
# current watermark. Clear max state if timer expires.
class EventTimeStatefulProcessor(StatefulProcessor):
@@ -858,33 +923,30 @@ class EventTimeStatefulProcessor(StatefulProcessor):
self.handle = handle
self.max_state = handle.getValueState("max_state", state_schema)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
- if expired_timer_info.is_valid():
- self.max_state.clear()
- self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
- str_key = f"{str(key[0])}-expired"
- yield pd.DataFrame(
- {"id": (str_key,), "timestamp":
str(expired_timer_info.get_expiry_time_in_ms())}
- )
+ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:
+ self.max_state.clear()
+ self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+ str_key = f"{str(key[0])}-expired"
+ yield pd.DataFrame(
+ {"id": (str_key,), "timestamp":
str(expired_timer_info.get_expiry_time_in_ms())}
+ )
- else:
- timestamp_list = []
- for pdf in rows:
- # int64 will represent timestamp in nanosecond, restore to second
- timestamp_list.extend((pdf["eventTime"].astype("int64") // 10**9).tolist())
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
+ timestamp_list = []
+ for pdf in rows:
+ # int64 will represent timestamp in nanosecond, restore to second
+ timestamp_list.extend((pdf["eventTime"].astype("int64") //
10**9).tolist())
- if self.max_state.exists():
- cur_max = int(self.max_state.get()[0])
- else:
- cur_max = 0
- max_event_time = str(max(cur_max, max(timestamp_list)))
+ if self.max_state.exists():
+ cur_max = int(self.max_state.get()[0])
+ else:
+ cur_max = 0
+ max_event_time = str(max(cur_max, max(timestamp_list)))
- self.max_state.update((max_event_time,))
- self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
+ self.max_state.update((max_event_time,))
+ self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
- yield pd.DataFrame({"id": key, "timestamp": max_event_time})
+ yield pd.DataFrame({"id": key, "timestamp": max_event_time})
def close(self) -> None:
pass
@@ -898,54 +960,49 @@ class ProcTimeStatefulProcessor(StatefulProcessor):
self.handle = handle
self.count_state = handle.getValueState("count_state", state_schema)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
- if expired_timer_info.is_valid():
- # reset count state each time the timer is expired
- timer_list_1 = [e for e in self.handle.listTimers()]
- timer_list_2 = []
- idx = 0
- for e in self.handle.listTimers():
- timer_list_2.append(e)
- # check multiple iterator on the same grouping key works
- assert timer_list_2[idx] == timer_list_1[idx]
- idx += 1
-
- if len(timer_list_1) > 0:
- # before deleting the expiring timers, there are 2 timers -
- # one timer we just registered, and one that is going to be deleted
- assert len(timer_list_1) == 2
- self.count_state.clear()
- self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
- yield pd.DataFrame(
- {
- "id": key,
- "countAsString": str("-1"),
- "timeValues":
str(expired_timer_info.get_expiry_time_in_ms()),
- }
- )
+ def handleExpiredTimer(self, key, timer_values, expired_timer_info) -> Iterator[pd.DataFrame]:
+ # reset count state each time the timer is expired
+ timer_list_1 = [e for e in self.handle.listTimers()]
+ timer_list_2 = []
+ idx = 0
+ for e in self.handle.listTimers():
+ timer_list_2.append(e)
+ # check multiple iterator on the same grouping key works
+ assert timer_list_2[idx] == timer_list_1[idx]
+ idx += 1
+
+ if len(timer_list_1) > 0:
+ assert len(timer_list_1) == 2
+ self.count_state.clear()
+ self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+ yield pd.DataFrame(
+ {
+ "id": key,
+ "countAsString": str("-1"),
+ "timeValues": str(expired_timer_info.get_expiry_time_in_ms()),
+ }
+ )
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
+ if not self.count_state.exists():
+ count = 0
else:
- if not self.count_state.exists():
- count = 0
- else:
- count = int(self.count_state.get()[0])
+ count = int(self.count_state.get()[0])
- if key == ("0",):
- self.handle.registerTimer(timer_values.get_current_processing_time_in_ms())
+ if key == ("0",):
+ self.handle.registerTimer(timer_values.get_current_processing_time_in_ms() + 1)
- rows_count = 0
- for pdf in rows:
- pdf_count = len(pdf)
- rows_count += pdf_count
+ rows_count = 0
+ for pdf in rows:
+ pdf_count = len(pdf)
+ rows_count += pdf_count
- count = count + rows_count
+ count = count + rows_count
- self.count_state.update((str(count),))
- timestamp = str(timer_values.get_current_processing_time_in_ms())
+ self.count_state.update((str(count),))
+ timestamp = str(timer_values.get_current_processing_time_in_ms())
- yield pd.DataFrame({"id": key, "countAsString": str(count),
"timeValues": timestamp})
+ yield pd.DataFrame({"id": key, "countAsString": str(count),
"timeValues": timestamp})
def close(self) -> None:
pass
@@ -961,9 +1018,7 @@ class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
self.temp_state = handle.getValueState("tempState", state_schema)
handle.deleteIfExists("tempState")
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
self.temp_state.exists()
new_violations = 0
@@ -995,9 +1050,7 @@ class StatefulProcessorChainingOps(StatefulProcessor):
def init(self, handle: StatefulProcessorHandle) -> None:
pass
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
for pdf in rows:
timestamp_list = pdf["eventTime"].tolist()
yield pd.DataFrame({"id": key, "outputTimestamp": timestamp_list[0]})
@@ -1027,9 +1080,7 @@ class TTLStatefulProcessor(StatefulProcessor):
"ttl-map-state", user_key_schema, state_schema, 10000
)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
count = 0
ttl_count = 0
ttl_list_state_count = 0
@@ -1079,9 +1130,7 @@ class InvalidSimpleStatefulProcessor(StatefulProcessor):
state_schema = StructType([StructField("value", IntegerType(), True)])
self.num_violations_state = handle.getValueState("numViolations", state_schema)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
count = 0
exists = self.num_violations_state.exists()
assert not exists
@@ -1105,9 +1154,7 @@ class ListStateProcessor(StatefulProcessor):
self.list_state1 = handle.getListState("listState1", state_schema)
self.list_state2 = handle.getListState("listState2", state_schema)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
count = 0
for pdf in rows:
list_state_rows = [(120,), (20,)]
@@ -1162,9 +1209,7 @@ class MapStateProcessor(StatefulProcessor):
value_schema = StructType([StructField("count", IntegerType(), True)])
self.map_state = handle.getMapState("mapState", key_schema, value_schema)
- def handleInputRows(
- self, key, rows, timer_values, expired_timer_info
- ) -> Iterator[pd.DataFrame]:
+ def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
count = 0
key1 = ("key1",)
key2 = ("key2",)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 04f95e9f5264..1ebc04520eca 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,6 +34,7 @@ from pyspark.accumulators import (
_deserialize_accumulator,
)
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
+from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasFuncMode
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.resource import ResourceInformation
from pyspark.util import PythonEvalType, local_connect_and_auth
@@ -493,36 +494,36 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
def wrap_grouped_transform_with_state_pandas_udf(f, return_type, runner_conf):
- def wrapped(stateful_processor_api_client, key, value_series_gen):
+ def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
import pandas as pd
values = (pd.concat(x, axis=1) for x in value_series_gen)
- result_iter = f(stateful_processor_api_client, key, values)
+ result_iter = f(stateful_processor_api_client, mode, key, values)
# TODO(SPARK-49100): add verification that elements in result_iter are
# indeed of type pd.DataFrame and confirm to assigned cols
return result_iter
- return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))]
def wrap_grouped_transform_with_state_pandas_init_state_udf(f, return_type, runner_conf):
- def wrapped(stateful_processor_api_client, key, value_series_gen):
+ def wrapped(stateful_processor_api_client, mode, key, value_series_gen):
import pandas as pd
state_values_gen, init_states_gen = itertools.tee(value_series_gen, 2)
state_values = (df for x, _ in state_values_gen if not (df := pd.concat(x, axis=1)).empty)
init_states = (df for _, x in init_states_gen if not (df := pd.concat(x, axis=1)).empty)
- result_iter = f(stateful_processor_api_client, key, state_values, init_states)
+ result_iter = f(stateful_processor_api_client, mode, key, state_values, init_states)
# TODO(SPARK-49100): add verification that elements in result_iter are
# indeed of type pd.DataFrame and confirm to assigned cols
return result_iter
- return lambda p, k, v: [(wrapped(p, k, v), to_arrow_type(return_type))]
+ return lambda p, m, k, v: [(wrapped(p, m, k, v), to_arrow_type(return_type))]
def wrap_grouped_map_pandas_udf_with_state(f, return_type):
@@ -1697,18 +1698,22 @@ def read_udfs(pickleSer, infile, eval_type):
ser.key_offsets = parsed_offsets[0][0]
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
- # Create function like this:
- # mapper a: f([a[0]], [a[0], a[1]])
def mapper(a):
- key = a[0]
+ mode = a[0]
- def values_gen():
- for x in a[1]:
- retVal = [x[1][o] for o in parsed_offsets[0][1]]
- yield retVal
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ key = a[1]
- # This must be generator comprehension - do not materialize.
- return f(stateful_processor_api_client, key, values_gen())
+ def values_gen():
+ for x in a[2]:
+ retVal = [x[1][o] for o in parsed_offsets[0][1]]
+ yield retVal
+
+ # This must be generator comprehension - do not materialize.
+ return f(stateful_processor_api_client, mode, key, values_gen())
+ else:
+ # mode == PROCESS_TIMER or mode == COMPLETE
+ return f(stateful_processor_api_client, mode, None, iter([]))
elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
# We assume there is only one UDF here because grouped map doesn't
@@ -1731,16 +1736,22 @@ def read_udfs(pickleSer, infile, eval_type):
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
def mapper(a):
- key = a[0]
+ mode = a[0]
- def values_gen():
- for x in a[1]:
- retVal = [x[1][o] for o in parsed_offsets[0][1]]
- initVal = [x[2][o] for o in parsed_offsets[1][1]]
- yield retVal, initVal
+ if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
+ key = a[1]
- # This must be generator comprehension - do not materialize.
- return f(stateful_processor_api_client, key, values_gen())
+ def values_gen():
+ for x in a[2]:
+ retVal = [x[1][o] for o in parsed_offsets[0][1]]
+ initVal = [x[2][o] for o in parsed_offsets[1][1]]
+ yield retVal, initVal
+
+ # This must be generator comprehension - do not materialize.
+ return f(stateful_processor_api_client, mode, key, values_gen())
+ else:
+ # mode == PROCESS_TIMER or mode == COMPLETE
+ return f(stateful_processor_api_client, mode, None, iter([]))
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
import pyarrow as pa
@@ -1958,17 +1969,6 @@ def main(infile, outfile):
try:
serializer.dump_stream(out_iter, outfile)
finally:
- # Sending a signal to TransformWithState UDF to perform proper cleanup steps.
- if (
- eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
- or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
- ):
- # Sending key as None to indicate that process() has finished.
- end_iter = func(split_index, iter([(None, None)]))
- # Need to materialize the iterator to trigger the cleanup steps, nothing needs
- # to be done here.
- for _ in end_iter:
- pass
if hasattr(out_iter, "close"):
out_iter.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 0373c8607ff2..2957f4b38758 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -120,6 +120,8 @@ class TransformWithStateInPandasStateServer(
}
/** Timer related class variables */
+ // An iterator to store all expired timer info. This is meant to be consumed only once per
+ // partition. This should be called after finishing handling all input rows.
private var expiryTimestampIter: Option[Iterator[(Any, Long)]] =
if (expiryTimerIterForTest != null) {
Option(expiryTimerIterForTest)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]