This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2ace2ebf7ae1 [SPARK-49513][SS] Add Support for timer in
transformWithStateInPandas API
2ace2ebf7ae1 is described below
commit 2ace2ebf7ae1aa9877de354ac821444aecb1ac89
Author: jingz-db <[email protected]>
AuthorDate: Thu Oct 31 11:15:12 2024 +0900
[SPARK-49513][SS] Add Support for timer in transformWithStateInPandas API
### What changes were proposed in this pull request?
Support for timer in TransformWithStateInPandas Python API.
### Why are the changes needed?
To couple with Scala API, TransformWithStateInPandas should also support
processing/event time timer for arbitrary state.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now interact with timers from `handleInputRows` with two
additional parameters as:
```
def handleInputRows(
self, key: Any, rows: Iterator["PandasDataFrameLike"],
timer_values: TimerValues,
expired_timer_info: ExpiredTimerInfo)
```
And users can interact with a newly introduced `TimerValues` to get the
processing/event time for the current batch:
```
class TimerValues:
def get_current_processing_time_in_ms(self) -> int
def get_current_watermark_in_ms(self) -> int
```
Users can also interact with `expired_timer_info` to get the timestamp for
expired timers:
```
class ExpiredTimerInfo:
def is_valid(self) -> bool
    def get_expiry_time_in_ms(self) -> int
```
### How was this patch tested?
Unit tests in `TransformWithStateInPandasStateServerSuite` and integration
tests in `test_pandas_transform_with_state.py`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47878 from jingz-db/python-timer-impl.
Lead-authored-by: jingz-db <[email protected]>
Co-authored-by: Jing Zhan <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/mypy.ini | 2 +-
python/pyspark/sql/pandas/group_ops.py | 51 +++-
python/pyspark/sql/streaming/StateMessage_pb2.py | 99 -------
python/pyspark/sql/streaming/list_state_client.py | 12 +-
python/pyspark/sql/streaming/map_state_client.py | 16 +-
.../sql/streaming/proto/StateMessage_pb2.py | 124 ++++++++
.../sql/streaming/{ => proto}/StateMessage_pb2.pyi | 98 ++++++-
python/pyspark/sql/streaming/proto/__init__.py | 16 +
python/pyspark/sql/streaming/stateful_processor.py | 84 +++++-
.../sql/streaming/stateful_processor_api_client.py | 197 ++++++++++++-
python/pyspark/sql/streaming/value_state_client.py | 8 +-
.../pandas/test_pandas_transform_with_state.py | 294 ++++++++++++++++++-
.../sql/execution/streaming/StateMessage.proto | 55 +++-
.../python/TransformWithStateInPandasExec.scala | 6 +-
.../TransformWithStateInPandasPythonRunner.scala | 8 +-
.../TransformWithStateInPandasStateServer.scala | 325 +++++++++++++++------
...ransformWithStateInPandasStateServerSuite.scala | 123 +++++++-
17 files changed, 1272 insertions(+), 246 deletions(-)
diff --git a/python/mypy.ini b/python/mypy.ini
index 4daa18593334..cb3595949e8d 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -181,5 +181,5 @@ ignore_missing_imports = True
ignore_missing_imports = True
; Ignore errors for proto generated code
-[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto]
+[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto,
pyspark.sql.streaming.proto]
ignore_errors = True
diff --git a/python/pyspark/sql/pandas/group_ops.py
b/python/pyspark/sql/pandas/group_ops.py
index 0d21edc73b81..1c87df538e1f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import itertools
import sys
from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
import warnings
@@ -27,6 +28,12 @@ from pyspark.sql.streaming.stateful_processor_api_client
import (
StatefulProcessorApiClient,
StatefulProcessorHandleState,
)
+from pyspark.sql.streaming.stateful_processor import (
+ ExpiredTimerInfo,
+ StatefulProcessor,
+ StatefulProcessorHandle,
+ TimerValues,
+)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor,
StatefulProcessorHandle
from pyspark.sql.types import StructType, _parse_datatype_string
@@ -501,7 +508,49 @@ class PandasGroupedOpsMixin:
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ if timeMode != "none":
+ batch_timestamp =
statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
+ else:
+ batch_timestamp = -1
+ watermark_timestamp = -1
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key,
+ inputRows,
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(False),
+ )
+
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
+
+ if timeMode == "processingtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(
+ batch_timestamp
+ )
+ elif timeMode == "eventtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(
+ watermark_timestamp
+ )
+ else:
+ expiry_list_iter = iter([[]])
+
+ result_iter_list = [data_iter]
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for expiry_list in expiry_list_iter:
+ for key_obj, expiry_timestamp in expiry_list:
+ result_iter_list.append(
+ statefulProcessor.handleInputRows(
+ key_obj,
+ iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp),
+ )
+ )
+ # TODO(SPARK-49603) set the handle state in the lazily initialized
iterator
+
+ result = itertools.chain(*result_iter_list)
return result
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py
b/python/pyspark/sql/streaming/StateMessage_pb2.py
deleted file mode 100644
index 9c7740c0a223..000000000000
--- a/python/pyspark/sql/streaming/StateMessage_pb2.py
+++ /dev/null
@@ -1,99 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# NO CHECKED-IN PROTOBUF GENCODE
-# source: StateMessage.proto
-# Protobuf Python Version: 5.27.3
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01
\x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04
\x01(\x0b\x32J.org.apache.spark [...]
-)
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2",
_globals)
-if not _descriptor._USE_C_DESCRIPTORS:
- DESCRIPTOR._loaded_options = None
- _globals["_HANDLESTATE"]._serialized_start = 3778
- _globals["_HANDLESTATE"]._serialized_end = 3853
- _globals["_STATEREQUEST"]._serialized_start = 71
- _globals["_STATEREQUEST"]._serialized_end = 432
- _globals["_STATERESPONSE"]._serialized_start = 434
- _globals["_STATERESPONSE"]._serialized_end = 506
- _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509
- _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902
- _globals["_STATEVARIABLEREQUEST"]._serialized_start = 905
- _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1201
- _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1204
- _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1428
- _globals["_STATECALLCOMMAND"]._serialized_start = 1431
- _globals["_STATECALLCOMMAND"]._serialized_end = 1585
- _globals["_VALUESTATECALL"]._serialized_start = 1588
- _globals["_VALUESTATECALL"]._serialized_end = 1941
- _globals["_LISTSTATECALL"]._serialized_start = 1944
- _globals["_LISTSTATECALL"]._serialized_end = 2472
- _globals["_MAPSTATECALL"]._serialized_start = 2475
- _globals["_MAPSTATECALL"]._serialized_end = 3212
- _globals["_SETIMPLICITKEY"]._serialized_start = 3214
- _globals["_SETIMPLICITKEY"]._serialized_end = 3243
- _globals["_REMOVEIMPLICITKEY"]._serialized_start = 3245
- _globals["_REMOVEIMPLICITKEY"]._serialized_end = 3264
- _globals["_EXISTS"]._serialized_start = 3266
- _globals["_EXISTS"]._serialized_end = 3274
- _globals["_GET"]._serialized_start = 3276
- _globals["_GET"]._serialized_end = 3281
- _globals["_VALUESTATEUPDATE"]._serialized_start = 3283
- _globals["_VALUESTATEUPDATE"]._serialized_end = 3316
- _globals["_CLEAR"]._serialized_start = 3318
- _globals["_CLEAR"]._serialized_end = 3325
- _globals["_LISTSTATEGET"]._serialized_start = 3327
- _globals["_LISTSTATEGET"]._serialized_end = 3361
- _globals["_LISTSTATEPUT"]._serialized_start = 3363
- _globals["_LISTSTATEPUT"]._serialized_end = 3377
- _globals["_APPENDVALUE"]._serialized_start = 3379
- _globals["_APPENDVALUE"]._serialized_end = 3407
- _globals["_APPENDLIST"]._serialized_start = 3409
- _globals["_APPENDLIST"]._serialized_end = 3421
- _globals["_GETVALUE"]._serialized_start = 3423
- _globals["_GETVALUE"]._serialized_end = 3450
- _globals["_CONTAINSKEY"]._serialized_start = 3452
- _globals["_CONTAINSKEY"]._serialized_end = 3482
- _globals["_UPDATEVALUE"]._serialized_start = 3484
- _globals["_UPDATEVALUE"]._serialized_end = 3529
- _globals["_ITERATOR"]._serialized_start = 3531
- _globals["_ITERATOR"]._serialized_end = 3561
- _globals["_KEYS"]._serialized_start = 3563
- _globals["_KEYS"]._serialized_end = 3589
- _globals["_VALUES"]._serialized_start = 3591
- _globals["_VALUES"]._serialized_end = 3619
- _globals["_REMOVEKEY"]._serialized_start = 3621
- _globals["_REMOVEKEY"]._serialized_end = 3649
- _globals["_SETHANDLESTATE"]._serialized_start = 3651
- _globals["_SETHANDLESTATE"]._serialized_end = 3743
- _globals["_TTLCONFIG"]._serialized_start = 3745
- _globals["_TTLCONFIG"]._serialized_end = 3776
-# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/list_state_client.py
b/python/pyspark/sql/streaming/list_state_client.py
index 3615f6819bdd..d2152842819a 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -35,7 +35,7 @@ class ListStateClient:
self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
def exists(self, state_name: str) -> bool:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
exists_call = stateMessage.Exists()
list_state_call = stateMessage.ListStateCall(stateName=state_name,
exists=exists_call)
@@ -57,7 +57,7 @@ class ListStateClient:
)
def get(self, state_name: str, iterator_id: str) -> Tuple:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if iterator_id in self.pandas_df_dict:
# If the state is already in the dictionary, return the next row.
@@ -106,7 +106,7 @@ class ListStateClient:
return tuple(pandas_row)
def append_value(self, state_name: str, schema: Union[StructType, str],
value: Tuple) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -128,7 +128,7 @@ class ListStateClient:
def append_list(
self, state_name: str, schema: Union[StructType, str], values:
List[Tuple]
) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -149,7 +149,7 @@ class ListStateClient:
raise PySparkRuntimeError(f"Error updating value state: "
f"{response_message[1]}")
def put(self, state_name: str, schema: Union[StructType, str], values:
List[Tuple]) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -168,7 +168,7 @@ class ListStateClient:
raise PySparkRuntimeError(f"Error updating value state: "
f"{response_message[1]}")
def clear(self, state_name: str) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
clear_call = stateMessage.Clear()
list_state_call = stateMessage.ListStateCall(stateName=state_name,
clear=clear_call)
diff --git a/python/pyspark/sql/streaming/map_state_client.py
b/python/pyspark/sql/streaming/map_state_client.py
index 54a5ba1bbffa..6ec7448b4863 100644
--- a/python/pyspark/sql/streaming/map_state_client.py
+++ b/python/pyspark/sql/streaming/map_state_client.py
@@ -49,7 +49,7 @@ class MapStateClient:
self.user_key_or_value_iterator_cursors: Dict[str,
Tuple["PandasDataFrameLike", int]] = {}
def exists(self, state_name: str) -> bool:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
exists_call = stateMessage.Exists()
map_state_call = stateMessage.MapStateCall(stateName=state_name,
exists=exists_call)
@@ -69,7 +69,7 @@ class MapStateClient:
raise PySparkRuntimeError(f"Error checking map state exists:
{response_message[1]}")
def get_value(self, state_name: str, user_key: Tuple) -> Optional[Tuple]:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
bytes = self._stateful_processor_api_client._serialize_to_bytes(
self.user_key_schema, user_key
@@ -92,7 +92,7 @@ class MapStateClient:
raise PySparkRuntimeError(f"Error getting value:
{response_message[1]}")
def contains_key(self, state_name: str, user_key: Tuple) -> bool:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
bytes = self._stateful_processor_api_client._serialize_to_bytes(
self.user_key_schema, user_key
@@ -124,7 +124,7 @@ class MapStateClient:
user_key: Tuple,
value: Tuple,
) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
key_bytes = self._stateful_processor_api_client._serialize_to_bytes(
self.user_key_schema, user_key
@@ -147,7 +147,7 @@ class MapStateClient:
raise PySparkRuntimeError(f"Error updating map state value:
{response_message[1]}")
def get_key_value_pair(self, state_name: str, iterator_id: str) ->
Tuple[Tuple, Tuple]:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if iterator_id in self.user_key_value_pair_iterator_cursors:
# If the state is already in the dictionary, return the next row.
@@ -195,7 +195,7 @@ class MapStateClient:
return tuple(key_row), tuple(value_row)
def get_row(self, state_name: str, iterator_id: str, is_key: bool) ->
Tuple:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if iterator_id in self.user_key_or_value_iterator_cursors:
# If the state is already in the dictionary, return the next row.
@@ -247,7 +247,7 @@ class MapStateClient:
return tuple(pandas_row)
def remove_key(self, state_name: str, key: Tuple) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
bytes =
self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema,
key)
remove_key_call = stateMessage.RemoveKey(userKey=bytes)
@@ -263,7 +263,7 @@ class MapStateClient:
raise PySparkRuntimeError(f"Error removing key from map state:
{response_message[1]}")
def clear(self, state_name: str) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
clear_call = stateMessage.Clear()
map_state_call = stateMessage.MapStateCall(stateName=state_name,
clear=clear_call)
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
new file mode 100644
index 000000000000..aeb195ca10ba
--- /dev/null
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# source: StateMessage.proto
+# Protobuf Python Version: 5.27.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+_runtime_version.ValidateProtobufRuntimeVersion(
+ _runtime_version.Domain.PUBLIC, 5, 27, 3, "", "StateMessage.proto"
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xbf\x03\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01
\x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04
\x01(\x0b\x32J.org.apache.spark [...]
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2",
_globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+ DESCRIPTOR._loaded_options = None
+ _globals["_HANDLESTATE"]._serialized_start = 4966
+ _globals["_HANDLESTATE"]._serialized_end = 5062
+ _globals["_STATEREQUEST"]._serialized_start = 71
+ _globals["_STATEREQUEST"]._serialized_end = 518
+ _globals["_STATERESPONSE"]._serialized_start = 520
+ _globals["_STATERESPONSE"]._serialized_end = 592
+ _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 594
+ _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 681
+ _globals["_STATEFULPROCESSORCALL"]._serialized_start = 684
+ _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1174
+ _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1177
+ _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1473
+ _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1476
+ _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1700
+ _globals["_TIMERREQUEST"]._serialized_start = 1703
+ _globals["_TIMERREQUEST"]._serialized_end = 1921
+ _globals["_TIMERVALUEREQUEST"]._serialized_start = 1924
+ _globals["_TIMERVALUEREQUEST"]._serialized_end = 2136
+ _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2138
+ _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2185
+ _globals["_GETPROCESSINGTIME"]._serialized_start = 2187
+ _globals["_GETPROCESSINGTIME"]._serialized_end = 2206
+ _globals["_GETWATERMARK"]._serialized_start = 2208
+ _globals["_GETWATERMARK"]._serialized_end = 2222
+ _globals["_STATECALLCOMMAND"]._serialized_start = 2225
+ _globals["_STATECALLCOMMAND"]._serialized_end = 2379
+ _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2382
+ _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 2653
+ _globals["_VALUESTATECALL"]._serialized_start = 2656
+ _globals["_VALUESTATECALL"]._serialized_end = 3009
+ _globals["_LISTSTATECALL"]._serialized_start = 3012
+ _globals["_LISTSTATECALL"]._serialized_end = 3540
+ _globals["_MAPSTATECALL"]._serialized_start = 3543
+ _globals["_MAPSTATECALL"]._serialized_end = 4280
+ _globals["_SETIMPLICITKEY"]._serialized_start = 4282
+ _globals["_SETIMPLICITKEY"]._serialized_end = 4311
+ _globals["_REMOVEIMPLICITKEY"]._serialized_start = 4313
+ _globals["_REMOVEIMPLICITKEY"]._serialized_end = 4332
+ _globals["_EXISTS"]._serialized_start = 4334
+ _globals["_EXISTS"]._serialized_end = 4342
+ _globals["_GET"]._serialized_start = 4344
+ _globals["_GET"]._serialized_end = 4349
+ _globals["_REGISTERTIMER"]._serialized_start = 4351
+ _globals["_REGISTERTIMER"]._serialized_end = 4393
+ _globals["_DELETETIMER"]._serialized_start = 4395
+ _globals["_DELETETIMER"]._serialized_end = 4435
+ _globals["_LISTTIMERS"]._serialized_start = 4437
+ _globals["_LISTTIMERS"]._serialized_end = 4469
+ _globals["_VALUESTATEUPDATE"]._serialized_start = 4471
+ _globals["_VALUESTATEUPDATE"]._serialized_end = 4504
+ _globals["_CLEAR"]._serialized_start = 4506
+ _globals["_CLEAR"]._serialized_end = 4513
+ _globals["_LISTSTATEGET"]._serialized_start = 4515
+ _globals["_LISTSTATEGET"]._serialized_end = 4549
+ _globals["_LISTSTATEPUT"]._serialized_start = 4551
+ _globals["_LISTSTATEPUT"]._serialized_end = 4565
+ _globals["_APPENDVALUE"]._serialized_start = 4567
+ _globals["_APPENDVALUE"]._serialized_end = 4595
+ _globals["_APPENDLIST"]._serialized_start = 4597
+ _globals["_APPENDLIST"]._serialized_end = 4609
+ _globals["_GETVALUE"]._serialized_start = 4611
+ _globals["_GETVALUE"]._serialized_end = 4638
+ _globals["_CONTAINSKEY"]._serialized_start = 4640
+ _globals["_CONTAINSKEY"]._serialized_end = 4670
+ _globals["_UPDATEVALUE"]._serialized_start = 4672
+ _globals["_UPDATEVALUE"]._serialized_end = 4717
+ _globals["_ITERATOR"]._serialized_start = 4719
+ _globals["_ITERATOR"]._serialized_end = 4749
+ _globals["_KEYS"]._serialized_start = 4751
+ _globals["_KEYS"]._serialized_end = 4777
+ _globals["_VALUES"]._serialized_start = 4779
+ _globals["_VALUES"]._serialized_end = 4807
+ _globals["_REMOVEKEY"]._serialized_start = 4809
+ _globals["_REMOVEKEY"]._serialized_end = 4837
+ _globals["_SETHANDLESTATE"]._serialized_start = 4839
+ _globals["_SETHANDLESTATE"]._serialized_end = 4931
+ _globals["_TTLCONFIG"]._serialized_start = 4933
+ _globals["_TTLCONFIG"]._serialized_end = 4964
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
similarity index 77%
rename from python/pyspark/sql/streaming/StateMessage_pb2.pyi
rename to python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
index 791a221d96f3..ff525ee136a4 100644
--- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+#
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
@@ -30,11 +31,13 @@ class HandleState(int,
metaclass=_enum_type_wrapper.EnumTypeWrapper):
CREATED: _ClassVar[HandleState]
INITIALIZED: _ClassVar[HandleState]
DATA_PROCESSED: _ClassVar[HandleState]
+ TIMER_PROCESSED: _ClassVar[HandleState]
CLOSED: _ClassVar[HandleState]
CREATED: HandleState
INITIALIZED: HandleState
DATA_PROCESSED: HandleState
+TIMER_PROCESSED: HandleState
CLOSED: HandleState
class StateRequest(_message.Message):
@@ -43,21 +46,25 @@ class StateRequest(_message.Message):
"statefulProcessorCall",
"stateVariableRequest",
"implicitGroupingKeyRequest",
+ "timerRequest",
)
VERSION_FIELD_NUMBER: _ClassVar[int]
STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int]
STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int]
IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int]
+ TIMERREQUEST_FIELD_NUMBER: _ClassVar[int]
version: int
statefulProcessorCall: StatefulProcessorCall
stateVariableRequest: StateVariableRequest
implicitGroupingKeyRequest: ImplicitGroupingKeyRequest
+ timerRequest: TimerRequest
def __init__(
self,
version: _Optional[int] = ...,
statefulProcessorCall: _Optional[_Union[StatefulProcessorCall,
_Mapping]] = ...,
stateVariableRequest: _Optional[_Union[StateVariableRequest,
_Mapping]] = ...,
implicitGroupingKeyRequest:
_Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ...,
+ timerRequest: _Optional[_Union[TimerRequest, _Mapping]] = ...,
) -> None: ...
class StateResponse(_message.Message):
@@ -75,22 +82,40 @@ class StateResponse(_message.Message):
value: _Optional[bytes] = ...,
) -> None: ...
+class StateResponseWithLongTypeVal(_message.Message):
+ __slots__ = ("statusCode", "errorMessage", "value")
+ STATUSCODE_FIELD_NUMBER: _ClassVar[int]
+ ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ statusCode: int
+ errorMessage: str
+ value: int
+ def __init__(
+ self,
+ statusCode: _Optional[int] = ...,
+ errorMessage: _Optional[str] = ...,
+ value: _Optional[int] = ...,
+ ) -> None: ...
+
class StatefulProcessorCall(_message.Message):
- __slots__ = ("setHandleState", "getValueState", "getListState",
"getMapState")
+ __slots__ = ("setHandleState", "getValueState", "getListState",
"getMapState", "timerStateCall")
SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int]
GETVALUESTATE_FIELD_NUMBER: _ClassVar[int]
GETLISTSTATE_FIELD_NUMBER: _ClassVar[int]
GETMAPSTATE_FIELD_NUMBER: _ClassVar[int]
+ TIMERSTATECALL_FIELD_NUMBER: _ClassVar[int]
setHandleState: SetHandleState
getValueState: StateCallCommand
getListState: StateCallCommand
getMapState: StateCallCommand
+ timerStateCall: TimerStateCallCommand
def __init__(
self,
setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ...,
getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
+ timerStateCall: _Optional[_Union[TimerStateCallCommand, _Mapping]] =
...,
) -> None: ...
class StateVariableRequest(_message.Message):
@@ -120,6 +145,44 @@ class ImplicitGroupingKeyRequest(_message.Message):
removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] =
...,
) -> None: ...
+class TimerRequest(_message.Message):
+ __slots__ = ("timerValueRequest", "expiryTimerRequest")
+ TIMERVALUEREQUEST_FIELD_NUMBER: _ClassVar[int]
+ EXPIRYTIMERREQUEST_FIELD_NUMBER: _ClassVar[int]
+ timerValueRequest: TimerValueRequest
+ expiryTimerRequest: ExpiryTimerRequest
+ def __init__(
+ self,
+ timerValueRequest: _Optional[_Union[TimerValueRequest, _Mapping]] =
...,
+ expiryTimerRequest: _Optional[_Union[ExpiryTimerRequest, _Mapping]] =
...,
+ ) -> None: ...
+
+class TimerValueRequest(_message.Message):
+ __slots__ = ("getProcessingTimer", "getWatermark")
+ GETPROCESSINGTIMER_FIELD_NUMBER: _ClassVar[int]
+ GETWATERMARK_FIELD_NUMBER: _ClassVar[int]
+ getProcessingTimer: GetProcessingTime
+ getWatermark: GetWatermark
+ def __init__(
+ self,
+ getProcessingTimer: _Optional[_Union[GetProcessingTime, _Mapping]] =
...,
+ getWatermark: _Optional[_Union[GetWatermark, _Mapping]] = ...,
+ ) -> None: ...
+
+class ExpiryTimerRequest(_message.Message):
+ __slots__ = ("expiryTimestampMs",)
+ EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+ expiryTimestampMs: int
+ def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class GetProcessingTime(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
+class GetWatermark(_message.Message):
+ __slots__ = ()
+ def __init__(self) -> None: ...
+
class StateCallCommand(_message.Message):
__slots__ = ("stateName", "schema", "mapStateValueSchema", "ttl")
STATENAME_FIELD_NUMBER: _ClassVar[int]
@@ -138,6 +201,21 @@ class StateCallCommand(_message.Message):
ttl: _Optional[_Union[TTLConfig, _Mapping]] = ...,
) -> None: ...
+class TimerStateCallCommand(_message.Message):
+ __slots__ = ("register", "delete", "list")
+ REGISTER_FIELD_NUMBER: _ClassVar[int]
+ DELETE_FIELD_NUMBER: _ClassVar[int]
+ LIST_FIELD_NUMBER: _ClassVar[int]
+ register: RegisterTimer
+ delete: DeleteTimer
+ list: ListTimers
+ def __init__(
+ self,
+ register: _Optional[_Union[RegisterTimer, _Mapping]] = ...,
+ delete: _Optional[_Union[DeleteTimer, _Mapping]] = ...,
+ list: _Optional[_Union[ListTimers, _Mapping]] = ...,
+ ) -> None: ...
+
class ValueStateCall(_message.Message):
__slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear")
STATENAME_FIELD_NUMBER: _ClassVar[int]
@@ -259,6 +337,24 @@ class Get(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
+class RegisterTimer(_message.Message):
+ __slots__ = ("expiryTimestampMs",)
+ EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+ expiryTimestampMs: int
+ def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class DeleteTimer(_message.Message):
+ __slots__ = ("expiryTimestampMs",)
+ EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+ expiryTimestampMs: int
+ def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class ListTimers(_message.Message):
+ __slots__ = ("iteratorId",)
+ ITERATORID_FIELD_NUMBER: _ClassVar[int]
+ iteratorId: str
+ def __init__(self, iteratorId: _Optional[str] = ...) -> None: ...
+
class ValueStateUpdate(_message.Message):
__slots__ = ("value",)
VALUE_FIELD_NUMBER: _ClassVar[int]
diff --git a/python/pyspark/sql/streaming/proto/__init__.py
b/python/pyspark/sql/streaming/proto/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/streaming/proto/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/streaming/stateful_processor.py
b/python/pyspark/sql/streaming/stateful_processor.py
index bac762e8addd..404326bb479c 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -18,7 +18,10 @@
from abc import ABC, abstractmethod
from typing import Any, List, TYPE_CHECKING, Iterator, Optional, Union, Tuple
-from pyspark.sql.streaming.stateful_processor_api_client import
StatefulProcessorApiClient
+from pyspark.sql.streaming.stateful_processor_api_client import (
+ StatefulProcessorApiClient,
+ ListTimerIterator,
+)
from pyspark.sql.streaming.list_state_client import ListStateClient,
ListStateIterator
from pyspark.sql.streaming.map_state_client import (
MapStateClient,
@@ -74,6 +77,56 @@ class ValueState:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to
access processing
+ time or event time for current batch.
+ .. versionadded:: 4.0.0
+ """
+
+ def __init__(
+ self, current_processing_time_in_ms: int = -1,
current_watermark_in_ms: int = -1
+ ) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ def get_current_processing_time_in_ms(self) -> int:
+ """
+ Get the processing time for the current batch, returned as a timestamp in milliseconds.
+ """
+ return self._current_processing_time_in_ms
+
+ def get_current_watermark_in_ms(self) -> int:
+ """
+ Get the watermark for the current batch, returned as a timestamp in milliseconds.
+ """
+ return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+ """
+ Class used for arbitrary stateful operations with transformWithState to
access expired timer
+ info. When is_valid is false, the expiry timestamp is invalid.
+ .. versionadded:: 4.0.0
+ """
+
+ def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
+ self._is_valid = is_valid
+ self._expiry_time_in_ms = expiry_time_in_ms
+
+ def is_valid(self) -> bool:
+ """
+ Whether the expiry info is valid.
+ """
+ return self._is_valid
+
+ def get_expiry_time_in_ms(self) -> int:
+ """
+ Get the timestamp of the expired timer, returned in milliseconds.
+ """
+ return self._expiry_time_in_ms
+
+
class ListState:
"""
Class used for arbitrary stateful operations with transformWithState to
capture list value
@@ -292,6 +345,24 @@ class StatefulProcessorHandle:
state_name,
)
+ def registerTimer(self, expiry_time_stamp_ms: int) -> None:
+ """
+ Register a timer for a given expiry timestamp in milliseconds for the
grouping key.
+ """
+ self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms)
+
+ def deleteTimer(self, expiry_time_stamp_ms: int) -> None:
+ """
+ Delete a timer for a given expiry timestamp in milliseconds for the
grouping key.
+ """
+ self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms)
+
+ def listTimers(self) -> Iterator[int]:
+ """
+ List all timers registered for the grouping key, yielding their expiry
timestamps in milliseconds.
+ """
+ return ListTimerIterator(self.stateful_processor_api_client)
+
class StatefulProcessor(ABC):
"""
@@ -317,7 +388,11 @@ class StatefulProcessor(ABC):
@abstractmethod
def handleInputRows(
- self, key: Any, rows: Iterator["PandasDataFrameLike"]
+ self,
+ key: Any,
+ rows: Iterator["PandasDataFrameLike"],
+ timer_values: TimerValues,
+ expired_timer_info: ExpiredTimerInfo,
) -> Iterator["PandasDataFrameLike"]:
"""
Function that will allow users to interact with input data rows along
with the grouping key.
@@ -336,6 +411,11 @@ class StatefulProcessor(ABC):
grouping key.
rows : iterable of :class:`pandas.DataFrame`
iterator of input rows associated with grouping key
+ timer_values: TimerValues
+ Timer values for the current batch that processes the input
rows.
+ Users can get the processing or event time timestamp
from TimerValues.
+ expired_timer_info: ExpiredTimerInfo
+ Timestamp of expired timers on the grouping key.
"""
...
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 552ab44d1ddf..ce3bae0a7c91 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -17,12 +17,13 @@
from enum import Enum
import os
import socket
-from typing import Any, List, Union, Optional, cast, Tuple
+from typing import Any, Dict, List, Union, Optional, cast, Tuple, Iterator
from pyspark.serializers import write_int, read_int, UTF8Deserializer
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.types import (
StructType,
+ TYPE_CHECKING,
_parse_datatype_string,
Row,
)
@@ -30,6 +31,10 @@ from pyspark.sql.pandas.types import
convert_pandas_using_numpy_type
from pyspark.sql.utils import has_numpy
from pyspark.serializers import CPickleSerializer
from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
__all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]
@@ -38,7 +43,8 @@ class StatefulProcessorHandleState(Enum):
CREATED = 1
INITIALIZED = 2
DATA_PROCESSED = 3
- CLOSED = 4
+ TIMER_PROCESSED = 4
+ CLOSED = 5
class StatefulProcessorApiClient:
@@ -53,9 +59,12 @@ class StatefulProcessorApiClient:
self.utf8_deserializer = UTF8Deserializer()
self.pickleSer = CPickleSerializer()
self.serializer = ArrowStreamSerializer()
+ # Dictionary mapping a list-timer iterator id to a tuple of the
pandas DataFrame
+ # holding the timer rows and the index of the last row that was read.
+ self.list_timer_iterator_cursors: Dict[str,
Tuple["PandasDataFrameLike", int]] = {}
def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if state == StatefulProcessorHandleState.CREATED:
proto_state = stateMessage.CREATED
@@ -63,6 +72,8 @@ class StatefulProcessorApiClient:
proto_state = stateMessage.INITIALIZED
elif state == StatefulProcessorHandleState.DATA_PROCESSED:
proto_state = stateMessage.DATA_PROCESSED
+ elif state == StatefulProcessorHandleState.TIMER_PROCESSED:
+ proto_state = stateMessage.TIMER_PROCESSED
else:
proto_state = stateMessage.CLOSED
set_handle_state = stateMessage.SetHandleState(state=proto_state)
@@ -80,7 +91,7 @@ class StatefulProcessorApiClient:
raise PySparkRuntimeError(f"Error setting handle state: "
f"{response_message[1]}")
def set_implicit_key(self, key: Tuple) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
key_bytes = self._serialize_to_bytes(self.key_schema, key)
set_implicit_key = stateMessage.SetImplicitKey(key=key_bytes)
@@ -95,7 +106,7 @@ class StatefulProcessorApiClient:
raise PySparkRuntimeError(f"Error setting implicit key: "
f"{response_message[1]}")
def remove_implicit_key(self) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
remove_implicit_key = stateMessage.RemoveImplicitKey()
request =
stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key)
@@ -111,7 +122,7 @@ class StatefulProcessorApiClient:
def get_value_state(
self, state_name: str, schema: Union[StructType, str],
ttl_duration_ms: Optional[int]
) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -134,7 +145,7 @@ class StatefulProcessorApiClient:
def get_list_state(
self, state_name: str, schema: Union[StructType, str],
ttl_duration_ms: Optional[int]
) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -152,7 +163,150 @@ class StatefulProcessorApiClient:
status = response_message[0]
if status != 0:
# TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ raise PySparkRuntimeError(f"Error initializing list state: "
f"{response_message[1]}")
+
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ register_call =
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(register=register_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error registering timer: "
f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ delete_call =
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(delete=delete_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error deleting timer: "
f"{response_message[1]}")
+
+ def get_list_timer_row(self, iterator_id: str) -> int:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ if iterator_id in self.list_timer_iterator_cursors:
+ # if the iterator is already in the dictionary, return the next row
+ pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
+ else:
+ list_call = stateMessage.ListTimers(iteratorId=iterator_id)
+ state_call_command =
stateMessage.TimerStateCallCommand(list=list_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 0:
+ iterator = self._read_arrow_state()
+ # We need to exhaust the iterator here to make sure all the
arrow batches are read,
+ # even though there is only one batch in the iterator.
Otherwise, the stream might
+ # block further reads since it thinks there might still be
some arrow batches left.
+ # We only need to read the first batch in the iterator because
it's guaranteed that
+ # there would only be one batch sent from the JVM side.
+ data_batch = None
+ for batch in iterator:
+ if data_batch is None:
+ data_batch = batch
+ if data_batch is None:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError("Error getting list timer entry.")
+ pandas_df = data_batch.to_pandas()
+ index = 0
+ else:
+ raise StopIteration()
+ new_index = index + 1
+ if new_index < len(pandas_df):
+ # Update the index in the dictionary.
+ self.list_timer_iterator_cursors[iterator_id] = (pandas_df,
new_index)
+ else:
+ # If the index is at the end of the DataFrame, remove the state
from the dictionary.
+ self.list_timer_iterator_cursors.pop(iterator_id, None)
+ return pandas_df.at[index, "timestamp"].item()
+
+ def get_expiry_timers_iterator(
+ self, expiry_timestamp: int
+ ) -> Iterator[list[Tuple[Tuple, int]]]:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ while True:
+ expiry_timer_call =
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+ timer_request =
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ result_list = []
+ iterator = self._read_arrow_state()
+ for batch in iterator:
+ batch_df = batch.to_pandas()
+ for i in range(batch.num_rows):
+ deserialized_key = self.pickleSer.loads(batch_df.at[i,
"key"])
+ timestamp = batch_df.at[i, "timestamp"].item()
+ result_list.append((tuple(deserialized_key),
timestamp))
+ yield result_list
+ else:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error getting expiry timers: "
f"{response_message[1]}")
+
+ def get_batch_timestamp(self) -> int:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ get_processing_time_call = stateMessage.GetProcessingTime()
+ timer_value_call = stateMessage.TimerValueRequest(
+ getProcessingTimer=get_processing_time_call
+ )
+ timer_request =
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message_with_long_value()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(
+ f"Error getting processing timestamp: "
f"{response_message[1]}"
+ )
+ else:
+ timestamp = response_message[2]
+ return timestamp
+
+ def get_watermark_timestamp(self) -> int:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ get_watermark_call = stateMessage.GetWatermark()
+ timer_value_call =
stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
+ timer_request =
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message_with_long_value()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(
+ f"Error getting eventtime timestamp: " f"{response_message[1]}"
+ )
+ else:
+ timestamp = response_message[2]
+ return timestamp
def get_map_state(
self,
@@ -161,7 +315,7 @@ class StatefulProcessorApiClient:
value_schema: Union[StructType, str],
ttl_duration_ms: Optional[int],
) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(user_key_schema, str):
user_key_schema = cast(StructType,
_parse_datatype_string(user_key_schema))
@@ -193,7 +347,7 @@ class StatefulProcessorApiClient:
self.sockfile.flush()
def _receive_proto_message(self) -> Tuple[int, str, bytes]:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
length = read_int(self.sockfile)
bytes = self.sockfile.read(length)
@@ -201,6 +355,15 @@ class StatefulProcessorApiClient:
message.ParseFromString(bytes)
return message.statusCode, message.errorMessage, message.value
+ def _receive_proto_message_with_long_value(self) -> Tuple[int, str, int]:
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+ length = read_int(self.sockfile)
+ bytes = self.sockfile.read(length)
+ message = stateMessage.StateResponseWithLongTypeVal()
+ message.ParseFromString(bytes)
+ return message.statusCode, message.errorMessage, message.value
+
def _receive_str(self) -> str:
return self.utf8_deserializer.loads(self.sockfile)
@@ -243,3 +406,17 @@ class StatefulProcessorApiClient:
def _read_arrow_state(self) -> Any:
return self.serializer.load_stream(self.sockfile)
+
+
+class ListTimerIterator:
+ def __init__(self, stateful_processor_api_client:
StatefulProcessorApiClient):
+ # Generate a unique identifier for the iterator to make sure iterators
on the
+ # same partition won't interfere with each other
+ self.iterator_id = str(uuid.uuid4())
+ self.stateful_processor_api_client = stateful_processor_api_client
+
+ def __iter__(self) -> Iterator[int]:
+ return self
+
+ def __next__(self) -> int:
+ return
self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
diff --git a/python/pyspark/sql/streaming/value_state_client.py
b/python/pyspark/sql/streaming/value_state_client.py
index 3fe32bcc5235..fd783af7931d 100644
--- a/python/pyspark/sql/streaming/value_state_client.py
+++ b/python/pyspark/sql/streaming/value_state_client.py
@@ -28,7 +28,7 @@ class ValueStateClient:
self._stateful_processor_api_client = stateful_processor_api_client
def exists(self, state_name: str) -> bool:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
exists_call = stateMessage.Exists()
value_state_call = stateMessage.ValueStateCall(stateName=state_name,
exists=exists_call)
@@ -50,7 +50,7 @@ class ValueStateClient:
)
def get(self, state_name: str) -> Optional[Tuple]:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
get_call = stateMessage.Get()
value_state_call = stateMessage.ValueStateCall(stateName=state_name,
get=get_call)
@@ -70,7 +70,7 @@ class ValueStateClient:
raise PySparkRuntimeError(f"Error getting value state: "
f"{response_message[1]}")
def update(self, state_name: str, schema: Union[StructType, str], value:
Tuple) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
schema = cast(StructType, _parse_datatype_string(schema))
@@ -90,7 +90,7 @@ class ValueStateClient:
raise PySparkRuntimeError(f"Error updating value state: "
f"{response_message[1]}")
def clear(self, state_name: str) -> None:
- import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
clear_call = stateMessage.Clear()
value_state_call = stateMessage.ValueStateCall(stateName=state_name,
clear=clear_call)
diff --git
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 7339897cb2cc..c888b6b3298f 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -75,6 +75,9 @@ class TransformWithStateInPandasTestsMixin:
input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323,
246, 6]
)
+ def _prepare_test_resource3(self, input_path):
+ self._prepare_input_data(input_path + "/text-test3.txt", [0, 1], [123,
6])
+
def _build_test_df(self, input_path):
df = self.spark.readStream.format("text").option("maxFilesPerTrigger",
1).load(input_path)
df_split = df.withColumn("split_values", split(df["value"], ","))
@@ -363,6 +366,277 @@ class TransformWithStateInPandasTestsMixin:
finally:
input_dir.cleanup()
+ def _test_transform_with_state_in_pandas_proc_timer(self,
stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource3(input_path)
+ self._prepare_test_resource1(input_path)
+ self._prepare_test_resource2(input_path)
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("countAsString", StringType(), True),
+ StructField("timeValues", StringType(), True),
+ ]
+ )
+
+ query_name = "processing_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_proc_timer(self):
+ # helper function to check expired timestamp is smaller than current
processing time
+ def check_timestamp(batch_df):
+ expired_df = (
+ batch_df.filter(batch_df["countAsString"] == "-1")
+ .select("id", "timeValues")
+ .withColumnRenamed("timeValues", "expiredTimestamp")
+ )
+ count_df = (
+ batch_df.filter(batch_df["countAsString"] != "-1")
+ .select("id", "timeValues")
+ .withColumnRenamed("timeValues", "countStateTimestamp")
+ )
+ joined_df = expired_df.join(count_df, on="id")
+ for row in joined_df.collect():
+ assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").select("id",
"countAsString").collect()) == {
+ Row(id="0", countAsString="1"),
+ Row(id="1", countAsString="1"),
+ }
+ elif batch_id == 1:
+ # for key 0, the accumulated count is emitted before the count
state is cleared
+ # during the timer process
+ assert set(batch_df.sort("id").select("id",
"countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="3"),
+ }
+ self.first_expired_timestamp = batch_df.filter(
+ batch_df["countAsString"] == -1
+ ).first()["timeValues"]
+ check_timestamp(batch_df)
+
+ else:
+ assert set(batch_df.sort("id").select("id",
"countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="5"),
+ }
+ # The expired timestamp in current batch is larger than expiry
timestamp in batch 1
+ # because this is a new timer registered in batch1 and
+ # different from the one registered in batch 0
+ current_batch_expired_timestamp = batch_df.filter(
+ batch_df["countAsString"] == -1
+ ).first()["timeValues"]
+ assert current_batch_expired_timestamp >
self.first_expired_timestamp
+
+ self._test_transform_with_state_in_pandas_proc_timer(
+ ProcTimeStatefulProcessor(), check_results
+ )
+
+ def _test_transform_with_state_in_pandas_event_time(self,
stateful_processor, check_results):
+ import pyspark.sql.functions as f
+
+ input_path = tempfile.mkdtemp()
+
+ def prepare_batch1(input_path):
+ with open(input_path + "/text-test3.txt", "w") as fw:
+ fw.write("a, 20\n")
+
+ def prepare_batch2(input_path):
+ with open(input_path + "/text-test1.txt", "w") as fw:
+ fw.write("a, 4\n")
+
+ def prepare_batch3(input_path):
+ with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 11\n")
+ fw.write("a, 13\n")
+ fw.write("a, 15\n")
+
+ prepare_batch1(input_path)
+ prepare_batch2(input_path)
+ prepare_batch3(input_path)
+
+ df = self._build_test_df(input_path)
+ df = df.select(
+ "id",
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
+ ).withWatermark("eventTime", "10 seconds")
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [StructField("id", StringType(), True), StructField("timestamp",
StringType(), True)]
+ )
+
+ query_name = "event_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="eventtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {Row(id="a",
timestamp="20")}
+ elif batch_id == 1:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20"),
+ Row(id="a-expired", timestamp="0"),
+ }
+ else:
+ # watermark has not progressed, so timer registered in batch
1(watermark = 10)
+ # has not yet expired
+ assert set(batch_df.sort("id").collect()) == {Row(id="a",
timestamp="15")}
+
+ self._test_transform_with_state_in_pandas_event_time(
+ EventTimeStatefulProcessor(), check_results
+ )
+
+
+# A stateful processor that output the max event time it has seen. Register
timer for
+# current watermark. Clear max state if timer expires.
+class EventTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.max_state = handle.getValueState("max_state", state_schema)
+
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
+ if expired_timer_info.is_valid():
+ self.max_state.clear()
+ self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+ str_key = f"{str(key[0])}-expired"
+ yield pd.DataFrame(
+ {"id": (str_key,), "timestamp":
str(expired_timer_info.get_expiry_time_in_ms())}
+ )
+
+ else:
+ timestamp_list = []
+ for pdf in rows:
+ # int64 will represent timestamp in nanosecond, restore to
second
+ timestamp_list.extend((pdf["eventTime"].astype("int64") //
10**9).tolist())
+
+ if self.max_state.exists():
+ cur_max = int(self.max_state.get()[0])
+ else:
+ cur_max = 0
+ max_event_time = str(max(cur_max, max(timestamp_list)))
+
+ self.max_state.update((max_event_time,))
+
self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
+
+ yield pd.DataFrame({"id": key, "timestamp": max_event_time})
+
+ def close(self) -> None:
+ pass
+
+
+# A stateful processor that output the accumulation of count of input rows;
register
+# processing timer and clear the counter if timer expires.
+class ProcTimeStatefulProcessor(StatefulProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("value", StringType(), True)])
+ self.handle = handle
+ self.count_state = handle.getValueState("count_state", state_schema)
+
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
+ if expired_timer_info.is_valid():
+ # reset count state each time the timer is expired
+ timer_list_1 = [e for e in self.handle.listTimers()]
+ timer_list_2 = []
+ idx = 0
+ for e in self.handle.listTimers():
+ timer_list_2.append(e)
+ # check multiple iterator on the same grouping key works
+ assert timer_list_2[idx] == timer_list_1[idx]
+ idx += 1
+
+ if len(timer_list_1) > 0:
+ # before deleting the expiring timers, there are 2 timers -
+ # one timer we just registered, and one that is going to be
deleted
+ assert len(timer_list_1) == 2
+ self.count_state.clear()
+ self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+ yield pd.DataFrame(
+ {
+ "id": key,
+ "countAsString": str("-1"),
+ "timeValues":
str(expired_timer_info.get_expiry_time_in_ms()),
+ }
+ )
+
+ else:
+ if not self.count_state.exists():
+ count = 0
+ else:
+ count = int(self.count_state.get()[0])
+
+ if key == ("0",):
+
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms())
+
+ rows_count = 0
+ for pdf in rows:
+ pdf_count = len(pdf)
+ rows_count += pdf_count
+
+ count = count + rows_count
+
+ self.count_state.update((str(count),))
+ timestamp = str(timer_values.get_current_processing_time_in_ms())
+
+ yield pd.DataFrame({"id": key, "countAsString": str(count),
"timeValues": timestamp})
+
+ def close(self) -> None:
+ pass
+
class SimpleStatefulProcessor(StatefulProcessor):
dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
@@ -372,7 +646,9 @@ class SimpleStatefulProcessor(StatefulProcessor):
state_schema = StructType([StructField("value", IntegerType(), True)])
self.num_violations_state = handle.getValueState("numViolations",
state_schema)
- def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
new_violations = 0
count = 0
key_str = key[0]
@@ -417,7 +693,9 @@ class TTLStatefulProcessor(StatefulProcessor):
"ttl-map-state", user_key_schema, state_schema, 10000
)
- def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
count = 0
ttl_count = 0
ttl_list_state_count = 0
@@ -467,7 +745,9 @@ class InvalidSimpleStatefulProcessor(StatefulProcessor):
state_schema = StructType([StructField("value", IntegerType(), True)])
self.num_violations_state = handle.getValueState("numViolations",
state_schema)
- def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
count = 0
exists = self.num_violations_state.exists()
assert not exists
@@ -491,7 +771,9 @@ class ListStateProcessor(StatefulProcessor):
self.list_state1 = handle.getListState("listState1", state_schema)
self.list_state2 = handle.getListState("listState2", state_schema)
- def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
count = 0
for pdf in rows:
list_state_rows = [(120,), (20,)]
@@ -546,7 +828,9 @@ class MapStateProcessor(StatefulProcessor):
value_schema = StructType([StructField("count", IntegerType(), True)])
self.map_state = handle.getMapState("mapState", key_schema,
value_schema)
- def handleInputRows(self, key, rows):
+ def handleInputRows(
+ self, key, rows, timer_values, expired_timer_info
+ ) -> Iterator[pd.DataFrame]:
count = 0
key1 = ("key1",)
key2 = ("key2",)
diff --git
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
index bb1c4c4f8e6c..544cd3b10b1c 100644
---
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -25,6 +25,7 @@ message StateRequest {
StatefulProcessorCall statefulProcessorCall = 2;
StateVariableRequest stateVariableRequest = 3;
ImplicitGroupingKeyRequest implicitGroupingKeyRequest = 4;
+ TimerRequest timerRequest = 5;
}
}
@@ -34,12 +35,19 @@ message StateResponse {
bytes value = 3;
}
+message StateResponseWithLongTypeVal {
+ int32 statusCode = 1;
+ string errorMessage = 2;
+ int64 value = 3;
+}
+
message StatefulProcessorCall {
oneof method {
SetHandleState setHandleState = 1;
StateCallCommand getValueState = 2;
StateCallCommand getListState = 3;
StateCallCommand getMapState = 4;
+ TimerStateCallCommand timerStateCall = 5;
}
}
@@ -58,6 +66,30 @@ message ImplicitGroupingKeyRequest {
}
}
+message TimerRequest {
+ oneof method {
+ TimerValueRequest timerValueRequest = 1;
+ ExpiryTimerRequest expiryTimerRequest = 2;
+ }
+}
+
+message TimerValueRequest {
+ oneof method {
+ GetProcessingTime getProcessingTimer = 1;
+ GetWatermark getWatermark = 2;
+ }
+}
+
+message ExpiryTimerRequest {
+ int64 expiryTimestampMs = 1;
+}
+
+message GetProcessingTime {
+}
+
+message GetWatermark {
+}
+
message StateCallCommand {
string stateName = 1;
string schema = 2;
@@ -65,6 +97,14 @@ message StateCallCommand {
TTLConfig ttl = 4;
}
+message TimerStateCallCommand {
+ oneof method {
+ RegisterTimer register = 1;
+ DeleteTimer delete = 2;
+ ListTimers list = 3;
+ }
+}
+
message ValueStateCall {
string stateName = 1;
oneof method {
@@ -115,6 +155,18 @@ message Exists {
message Get {
}
+message RegisterTimer {
+ int64 expiryTimestampMs = 1;
+}
+
+message DeleteTimer {
+ int64 expiryTimestampMs = 1;
+}
+
+message ListTimers {
+ string iteratorId = 1;
+}
+
message ValueStateUpdate {
bytes value = 1;
}
@@ -169,7 +221,8 @@ enum HandleState {
CREATED = 0;
INITIALIZED = 1;
DATA_PROCESSED = 2;
- CLOSED = 3;
+ TIMER_PROCESSED = 3;
+ CLOSED = 4;
}
message SetHandleState {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index f6a3f1b1394f..f7e1ec79bfcd 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -127,6 +127,8 @@ case class TransformWithStateInPandasExec(
case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
val commitTimeMs = longMetric("commitTimeMs")
+ // TODO(SPARK-49603) set the metrics in the lazily initialized iterator
+ val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
val currentTimeNs = System.nanoTime
val updatesStartTimeNs = currentTimeNs
@@ -144,7 +146,9 @@ case class TransformWithStateInPandasExec(
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- groupingKeySchema
+ groupingKeySchema,
+ batchTimestampMs,
+ eventTimeWatermarkForEviction
)
val outputIterator = executePython(data, output, runner)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index b4b516ba9e5a..655fe259578f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -50,7 +50,9 @@ class TransformWithStateInPandasPythonRunner(
initialWorkerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
- groupingKeySchema: StructType)
+ groupingKeySchema: StructType,
+ batchTimestampMs: Option[Long] = None,
+ eventTimeWatermarkForEviction: Option[Long] = None)
extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType,
argOffsets, jobArtifactUUID)
with PythonArrowInput[InType]
with BasicPythonArrowOutput
@@ -104,7 +106,9 @@ class TransformWithStateInPandasPythonRunner(
executionContext.execute(
new TransformWithStateInPandasStateServer(stateServerSocket,
processorHandle,
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames,
largeVarTypes,
- sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch))
+ sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch,
+ batchTimestampMs = batchTimestampMs,
+ eventTimeWatermarkForEviction = eventTimeWatermarkForEviction))
context.addTaskCompletionListener[Unit] { _ =>
logInfo(log"completion listener called")
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 3aed0b2463f3..8a67d5d47f05 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -34,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker,
StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType}
-import
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState,
ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall,
StateRequest, StateResponse, StateVariableRequest, ValueStateCall}
+import
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState,
ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall,
StateRequest, StateResponse, StateResponseWithLongTypeVal,
StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest,
ValueStateCall}
import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig,
ValueState}
-import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, LongType, StructField,
StructType}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
@@ -58,6 +58,8 @@ class TransformWithStateInPandasStateServer(
errorOnDuplicatedFieldNames: Boolean,
largeVarTypes: Boolean,
arrowTransformWithStateInPandasMaxRecordsPerBatch: Int,
+ batchTimestampMs: Option[Long] = None,
+ eventTimeWatermarkForEviction: Option[Long] = None,
outputStreamForTest: DataOutputStream = null,
valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null,
deserializerForTest: TransformWithStateInPandasDeserializer = null,
@@ -65,12 +67,19 @@ class TransformWithStateInPandasStateServer(
listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
mapStatesMapForTest : mutable.HashMap[String, MapStateInfo] = null,
- keyValueIteratorMapForTest: mutable.HashMap[String, Iterator[(Row, Row)]]
= null)
+ keyValueIteratorMapForTest: mutable.HashMap[String, Iterator[(Row, Row)]]
= null,
+ expiryTimerIterForTest: Iterator[(Any, Long)] = null,
+ listTimerMapForTest: mutable.HashMap[String, Iterator[Long]] = null)
extends Runnable with Logging {
+
+ import PythonResponseWriterUtils._
+
private val keyRowDeserializer: ExpressionEncoder.Deserializer[Row] =
ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()
private var inputStream: DataInputStream = _
private var outputStream: DataOutputStream = outputStreamForTest
+
+ /** State variable related class variables */
// A map to store the value state name -> (value state, schema, value row
deserializer) mapping.
private val valueStates = if (valueStateMapForTest != null) {
valueStateMapForTest
@@ -110,6 +119,20 @@ class TransformWithStateInPandasStateServer(
new mutable.HashMap[String, Iterator[(Row, Row)]]()
}
+ /** Timer related class variables */
+ private var expiryTimestampIter: Option[Iterator[(Any, Long)]] =
+ if (expiryTimerIterForTest != null) {
+ Option(expiryTimerIterForTest)
+ } else None
+
+ // A map to store the iterator id -> Iterator[Long] mapping. This is to keep
track of the
+ // current iterator position for each iterator id in the same partition for
a grouping key in case
+ // user tries to fetch multiple iterators before the current iterator is
exhausted. This is
+ // used for list timer function call
+ private var listTimerIters = if (listTimerMapForTest != null) {
+ listTimerMapForTest
+ } else new mutable.HashMap[String, Iterator[Long]]()
+
def run(): Unit = {
val listeningSocket = stateServerSocket.accept()
inputStream = new DataInputStream(
@@ -159,11 +182,62 @@ class TransformWithStateInPandasStateServer(
handleStatefulProcessorCall(message.getStatefulProcessorCall)
case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
handleStateVariableRequest(message.getStateVariableRequest)
+ case StateRequest.MethodCase.TIMERREQUEST =>
+ handleTimerRequest(message.getTimerRequest)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
}
+ private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+ message.getMethodCase match {
+ case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+ val timerRequest = message.getTimerValueRequest()
+ timerRequest.getMethodCase match {
+ case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+ val procTimestamp: Long =
+ if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+ sendResponseWithLongVal(0, null, procTimestamp)
+ case TimerValueRequest.MethodCase.GETWATERMARK =>
+ val eventTimestamp: Long =
+ if (eventTimeWatermarkForEviction.isDefined)
eventTimeWatermarkForEviction.get
+ else -1L
+ sendResponseWithLongVal(0, null, eventTimestamp)
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer value method
call")
+ }
+
+ case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+ // Note that for `getExpiryTimers` python call, as this is not a public
+ // API and it will only be used by `group_ops` once per partition, we
won't
+ // need to worry about different function calls will interleaved and
hence
+ // this implementation is safe
+ val expiryRequest = message.getExpiryTimerRequest()
+ val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+ if (!expiryTimestampIter.isDefined) {
+ expiryTimestampIter =
+ Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp))
+ }
+ // expiryTimestampIter could be None in the TWSPandasServerSuite
+ if (!expiryTimestampIter.isDefined ||
!expiryTimestampIter.get.hasNext) {
+ // iterator is exhausted, signal the end of iterator on python client
+ sendResponse(1)
+ } else {
+ sendResponse(0)
+ val outputSchema = new StructType()
+ .add("key", BinaryType)
+ .add(StructField("timestamp", LongType))
+ sendIteratorAsArrowBatches(expiryTimestampIter.get, outputSchema,
+ arrowStreamWriterForTest) { data =>
+ InternalRow(PythonSQLUtils.toPyRow(data._1.asInstanceOf[Row]),
data._2)
+ }
+ }
+
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer request method call")
+ }
+ }
+
private def handleImplicitGroupingKeyRequest(message:
ImplicitGroupingKeyRequest): Unit = {
message.getMethodCase match {
case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY =>
@@ -173,11 +247,13 @@ class TransformWithStateInPandasStateServer(
ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
// Reset the list/map state iterators for a new grouping key.
iterators = new mutable.HashMap[String, Iterator[Row]]()
+ listTimerIters = new mutable.HashMap[String, Iterator[Long]]()
sendResponse(0)
case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY =>
ImplicitGroupingKeyTracker.removeImplicitKey()
// Reset the list/map state iterators for a new grouping key.
iterators = new mutable.HashMap[String, Iterator[Row]]()
+ listTimerIters = new mutable.HashMap[String, Iterator[Long]]()
sendResponse(0)
case _ =>
throw new IllegalArgumentException("Invalid method call")
@@ -195,6 +271,12 @@ class TransformWithStateInPandasStateServer(
case HandleState.INITIALIZED =>
logInfo(log"set handle state to Initialized")
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+ case HandleState.DATA_PROCESSED =>
+ logInfo(log"set handle state to Data Processed")
+
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
+ case HandleState.TIMER_PROCESSED =>
+ logInfo(log"set handle state to Timer Processed")
+
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.TIMER_PROCESSED)
case HandleState.CLOSED =>
logInfo(log"set handle state to Closed")
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
@@ -226,6 +308,41 @@ class TransformWithStateInPandasStateServer(
} else None
initializeStateVariable(stateName, userKeySchema,
StateVariableType.MapState, ttlDurationMs,
valueSchema)
+ case StatefulProcessorCall.MethodCase.TIMERSTATECALL =>
+ message.getTimerStateCall.getMethodCase match {
+ case TimerStateCallCommand.MethodCase.REGISTER =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getRegister.getExpiryTimestampMs
+ statefulProcessorHandle.registerTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.DELETE =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getDelete.getExpiryTimestampMs
+ statefulProcessorHandle.deleteTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.LIST =>
+ val iteratorId = message.getTimerStateCall.getList.getIteratorId
+ var iteratorOption = listTimerIters.get(iteratorId)
+ if (iteratorOption.isEmpty) {
+ iteratorOption = Option(statefulProcessorHandle.listTimers())
+ listTimerIters.put(iteratorId, iteratorOption.get)
+ }
+ if (!iteratorOption.get.hasNext) {
+ sendResponse(2, s"List timer iterator doesn't contain any
value.")
+ return
+ } else {
+ sendResponse(0)
+ }
+ val outputSchema = new StructType()
+ .add(StructField("timestamp", LongType))
+ sendIteratorAsArrowBatches(iteratorOption.get, outputSchema,
+ arrowStreamWriterForTest) { data =>
+ InternalRow(data)
+ }
+
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer state method
call")
+ }
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
@@ -463,24 +580,6 @@ class TransformWithStateInPandasStateServer(
}
}
- private def sendResponse(
- status: Int,
- errorMessage: String = null,
- byteString: ByteString = null): Unit = {
- val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
- if (status != 0 && errorMessage != null) {
- responseMessageBuilder.setErrorMessage(errorMessage)
- }
- if (byteString != null) {
- responseMessageBuilder.setValue(byteString)
- }
- val responseMessage = responseMessageBuilder.build()
- val responseMessageBytes = responseMessage.toByteArray
- val byteLength = responseMessageBytes.length
- outputStream.writeInt(byteLength)
- outputStream.write(responseMessageBytes)
- }
-
private def initializeStateVariable(
stateName: String,
schemaString: String,
@@ -490,90 +589,128 @@ class TransformWithStateInPandasStateServer(
val schema = StructType.fromString(schemaString)
val expressionEncoder = ExpressionEncoder(schema).resolveAndBind()
stateType match {
- case StateVariableType.ValueState => if
(!valueStates.contains(stateName)) {
- val state = if (ttlDurationMs.isEmpty) {
- statefulProcessorHandle.getValueState[Row](stateName,
Encoders.row(schema))
- } else {
- statefulProcessorHandle.getValueState(
- stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
- }
- valueStates.put(stateName,
- ValueStateInfo(state, schema,
expressionEncoder.createDeserializer()))
- sendResponse(0)
+ case StateVariableType.ValueState => if
(!valueStates.contains(stateName)) {
+ val state = if (ttlDurationMs.isEmpty) {
+ statefulProcessorHandle.getValueState[Row](stateName,
Encoders.row(schema))
} else {
- sendResponse(1, s"Value state $stateName already exists")
+ statefulProcessorHandle.getValueState(
+ stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
}
- case StateVariableType.ListState => if
(!listStates.contains(stateName)) {
- val state = if (ttlDurationMs.isEmpty) {
- statefulProcessorHandle.getListState[Row](stateName,
Encoders.row(schema))
- } else {
- statefulProcessorHandle.getListState(
- stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
- }
- listStates.put(stateName,
- ListStateInfo(state, schema,
expressionEncoder.createDeserializer(),
- expressionEncoder.createSerializer()))
- sendResponse(0)
+ valueStates.put(stateName,
+ ValueStateInfo(state, schema,
expressionEncoder.createDeserializer()))
+ sendResponse(0)
+ } else {
+ sendResponse(1, s"Value state $stateName already exists")
+ }
+ case StateVariableType.ListState => if (!listStates.contains(stateName))
{
+ val state = if (ttlDurationMs.isEmpty) {
+ statefulProcessorHandle.getListState[Row](stateName,
Encoders.row(schema))
} else {
- sendResponse(1, s"List state $stateName already exists")
+ statefulProcessorHandle.getListState(
+ stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
}
- case StateVariableType.MapState => if (!mapStates.contains(stateName))
{
- val valueSchema = StructType.fromString(mapStateValueSchemaString)
- val valueExpressionEncoder =
ExpressionEncoder(valueSchema).resolveAndBind()
- val state = if (ttlDurationMs.isEmpty) {
- statefulProcessorHandle.getMapState[Row, Row](stateName,
- Encoders.row(schema), Encoders.row(valueSchema))
- } else {
- statefulProcessorHandle.getMapState[Row, Row](stateName,
Encoders.row(schema),
- Encoders.row(valueSchema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
- }
- mapStates.put(stateName,
- MapStateInfo(state, schema, valueSchema,
expressionEncoder.createDeserializer(),
- expressionEncoder.createSerializer(),
valueExpressionEncoder.createDeserializer(),
- valueExpressionEncoder.createSerializer()))
- sendResponse(0)
+ listStates.put(stateName,
+ ListStateInfo(state, schema, expressionEncoder.createDeserializer(),
+ expressionEncoder.createSerializer()))
+ sendResponse(0)
+ } else {
+ sendResponse(1, s"List state $stateName already exists")
+ }
+ case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
+ val valueSchema = StructType.fromString(mapStateValueSchemaString)
+ val valueExpressionEncoder =
ExpressionEncoder(valueSchema).resolveAndBind()
+ val state = if (ttlDurationMs.isEmpty) {
+ statefulProcessorHandle.getMapState[Row, Row](stateName,
+ Encoders.row(schema), Encoders.row(valueSchema))
} else {
- sendResponse(1, s"Map state $stateName already exists")
+ statefulProcessorHandle.getMapState[Row, Row](stateName,
Encoders.row(schema),
+ Encoders.row(valueSchema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
}
+ mapStates.put(stateName,
+ MapStateInfo(state, schema, valueSchema,
expressionEncoder.createDeserializer(),
+ expressionEncoder.createSerializer(),
valueExpressionEncoder.createDeserializer(),
+ valueExpressionEncoder.createSerializer()))
+ sendResponse(0)
+ } else {
+ sendResponse(1, s"Map state $stateName already exists")
+ }
}
}
- private def sendIteratorAsArrowBatches[T](
- iter: Iterator[T],
- outputSchema: StructType,
- arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T =>
InternalRow): Unit = {
- outputStream.flush()
- val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
- errorOnDuplicatedFieldNames, largeVarTypes)
- val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- s"stdout writer for transformWithStateInPandas state socket", 0,
Long.MaxValue)
- val root = VectorSchemaRoot.create(arrowSchema, allocator)
- val writer = new ArrowStreamWriter(root, null, outputStream)
- val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
- arrowStreamWriterForTest
- } else {
- new BaseStreamingArrowWriter(root, writer,
arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ /** Utils object for sending response to Python client. */
+ private object PythonResponseWriterUtils {
+ def sendResponse(
+ status: Int,
+ errorMessage: String = null,
+ byteString: ByteString = null): Unit = {
+ val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ if (byteString != null) {
+ responseMessageBuilder.setValue(byteString)
+ }
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
}
- // Only write a single batch in each GET request. Stops writing row if
rowCount reaches
- // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to
handle a case
- // when there are multiple state variables, user tries to access a
different state variable
- // while the current state variable is not exhausted yet.
- var rowCount = 0
- while (iter.hasNext && rowCount <
arrowTransformWithStateInPandasMaxRecordsPerBatch) {
- val data = iter.next()
- val internalRow = func(data)
- arrowStreamWriter.writeRow(internalRow)
- rowCount += 1
+
+ def sendResponseWithLongVal(
+ status: Int,
+ errorMessage: String = null,
+ longVal: Long): Unit = {
+ val responseMessageBuilder =
StateResponseWithLongTypeVal.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ responseMessageBuilder.setValue(longVal)
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
}
- arrowStreamWriter.finalizeCurrentArrowBatch()
- Utils.tryWithSafeFinally {
- // end writes footer to the output stream and doesn't clean any
resources.
- // It could throw exception if the output stream is closed, so it should
be
- // in the try block.
- writer.end()
- } {
- root.close()
- allocator.close()
+
+ def sendIteratorAsArrowBatches[T](
+ iter: Iterator[T],
+ outputSchema: StructType,
+ arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T =>
InternalRow): Unit = {
+ outputStream.flush()
+ val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
+ errorOnDuplicatedFieldNames, largeVarTypes)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for transformWithStateInPandas state socket", 0,
Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ val writer = new ArrowStreamWriter(root, null, outputStream)
+ val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+ arrowStreamWriterForTest
+ } else {
+ new BaseStreamingArrowWriter(root, writer,
+ arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ }
+ // Only write a single batch in each GET request. Stops writing row if
rowCount reaches
+ // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is
to handle a case
+ // when there are multiple state variables, user tries to access a
different state variable
+ // while the current state variable is not exhausted yet.
+ var rowCount = 0
+ while (iter.hasNext && rowCount <
arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+ val data = iter.next()
+ val internalRow = func(data)
+ arrowStreamWriter.writeRow(internalRow)
+ rowCount += 1
+ }
+ arrowStreamWriter.finalizeCurrentArrowBatch()
+ Utils.tryWithSafeFinally {
+ // end writes footer to the output stream and doesn't clean any
resources.
+ // It could throw exception if the output stream is closed, so it
should be
+ // in the try block.
+ writer.end()
+ } {
+ root.close()
+ allocator.close()
+ }
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index 2a728dc81d0b..aaaa53059126 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -32,7 +32,7 @@ import
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl,
StatefulProcessorHandleState}
import org.apache.spark.sql.execution.streaming.state.StateMessage
-import
org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList,
AppendValue, Clear, ContainsKey, Exists, Get, GetValue, HandleState, Keys,
ListStateCall, ListStateGet, ListStatePut, MapStateCall, RemoveKey,
SetHandleState, StateCallCommand, StatefulProcessorCall, UpdateValue, Values,
ValueStateCall, ValueStateUpdate}
+import
org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList,
AppendValue, Clear, ContainsKey, DeleteTimer, Exists, ExpiryTimerRequest, Get,
GetProcessingTime, GetValue, GetWatermark, HandleState, Keys, ListStateCall,
ListStateGet, ListStatePut, ListTimers, MapStateCall, RegisterTimer, RemoveKey,
SetHandleState, StateCallCommand, StatefulProcessorCall, TimerRequest,
TimerStateCallCommand, TimerValueRequest, UpdateValue, Values, ValueStateCall,
ValueStateUpdate}
import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig,
ValueState}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -59,9 +59,13 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
var stateSerializer: ExpressionEncoder.Serializer[Row] = _
var transformWithStateInPandasDeserializer:
TransformWithStateInPandasDeserializer = _
var arrowStreamWriter: BaseStreamingArrowWriter = _
+ var batchTimestampMs: Option[Long] = _
+ var eventTimeWatermarkForEviction: Option[Long] = _
var valueStateMap: mutable.HashMap[String, ValueStateInfo] =
mutable.HashMap()
var listStateMap: mutable.HashMap[String, ListStateInfo] = mutable.HashMap()
var mapStateMap: mutable.HashMap[String, MapStateInfo] = mutable.HashMap()
+ var expiryTimerIter: Iterator[(Any, Long)] = _
+ var listTimerMap: mutable.HashMap[String, Iterator[Long]] = mutable.HashMap()
override def beforeEach(): Unit = {
statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl])
@@ -83,15 +87,20 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
// reset the iterator map to empty so be careful to call it if you want to
access the iterator
// map later.
val testRow = getIntegerRow(1)
+ expiryTimerIter = Iterator.single(testRow, 1L /* a random long type value
*/)
val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
Iterator(testRow))
val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row,
Row)]](iteratorId ->
Iterator((testRow, testRow)))
+ listTimerMap = mutable.HashMap[String, Iterator[Long]](iteratorId ->
Iterator(1L))
transformWithStateInPandasDeserializer =
mock(classOf[TransformWithStateInPandasDeserializer])
arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter])
+ batchTimestampMs = mock(classOf[Option[Long]])
+ eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
+ batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer,
arrowStreamWriter,
- listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap)
+ listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap,
expiryTimerIter, listTimerMap)
when(transformWithStateInPandasDeserializer.readArrowBatches(any))
.thenReturn(Seq(getIntegerRow(1)))
}
@@ -251,8 +260,9 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3),
getIntegerRow(4)))
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
- transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap,
iteratorMap)
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream,
+ valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+ listStateMap, iteratorMap)
// First call should send 2 records.
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
@@ -274,8 +284,9 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
- transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap,
iteratorMap)
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream,
+ valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+ listStateMap, iteratorMap)
when(listState.get()).thenReturn(Iterator(getIntegerRow(1),
getIntegerRow(2), getIntegerRow(3)))
stateServer.handleListStateRequest(message)
verify(listState).get()
@@ -362,8 +373,9 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
(getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4),
getIntegerRow(4))))
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
transformWithStateInPandasDeserializer,
- arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream,
+ valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+ listStateMap, null, mapStateMap, keyValueIteratorMap)
// First call should send 2 records.
stateServer.handleMapStateRequest(message)
verify(mapState, times(0)).iterator()
@@ -385,7 +397,8 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] =
mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
transformWithStateInPandasDeserializer,
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
+ outputStream, valueStateMap, transformWithStateInPandasDeserializer,
arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
when(mapState.iterator()).thenReturn(Iterator((getIntegerRow(1),
getIntegerRow(1)),
(getIntegerRow(2), getIntegerRow(2)), (getIntegerRow(3),
getIntegerRow(3))))
@@ -413,7 +426,8 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
transformWithStateInPandasDeserializer,
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
+ outputStream, valueStateMap, transformWithStateInPandasDeserializer,
arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
when(mapState.keys()).thenReturn(Iterator(getIntegerRow(1),
getIntegerRow(2), getIntegerRow(3)))
stateServer.handleMapStateRequest(message)
@@ -440,7 +454,8 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
statefulProcessorHandle, groupingKeySchema, "", false, false,
- maxRecordsPerBatch, outputStream, valueStateMap,
transformWithStateInPandasDeserializer,
+ maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream,
+ valueStateMap, transformWithStateInPandasDeserializer,
arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
when(mapState.values()).thenReturn(Iterator(getIntegerRow(1),
getIntegerRow(2),
getIntegerRow(3)))
@@ -460,6 +475,92 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
verify(mapState).removeKey(any[Row])
}
+ test("timer value get processing time") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetProcessingTimer(
+ GetProcessingTime.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(batchTimestampMs).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("timer value get watermark") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetWatermark(
+ GetWatermark.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(eventTimeWatermarkForEviction).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("get expiry timers") {
+ val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+ ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+ 10L
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(arrowStreamWriter).writeRow(any)
+ verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ }
+
+ test("stateful processor register timer") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setRegister(RegisterTimer.newBuilder().setExpiryTimestampMs(10L).build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle).registerTimer(any[Long])
+ verify(outputStream).writeInt(0)
+ }
+
+ test("stateful processor delete timer") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setDelete(DeleteTimer.newBuilder().setExpiryTimestampMs(10L).build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle).deleteTimer(any[Long])
+ verify(outputStream).writeInt(0)
+ }
+
+ test("stateful processor list timer - iterator in map") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setList(ListTimers.newBuilder().setIteratorId(iteratorId).build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle, times(0)).listTimers()
+ verify(arrowStreamWriter).writeRow(any)
+ verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ }
+
+ test("stateful processor list timer - iterator not in map") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setList(ListTimers.newBuilder().setIteratorId("non-exist").build())
+ .build()
+ ).build()
+ stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ statefulProcessorHandle, groupingKeySchema, "", false, false,
+ 2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
+ valueStateMap, transformWithStateInPandasDeserializer,
+ arrowStreamWriter, listStateMap, null, mapStateMap, null,
+ null, listTimerMap)
+ when(statefulProcessorHandle.listTimers()).thenReturn(Iterator(1))
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle, times(1)).listTimers()
+ verify(arrowStreamWriter).writeRow(any)
+ verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ }
+
private def getIntegerRow(value: Int): Row = {
new GenericRowWithSchema(Array(value), stateSchema)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]