This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ace2ebf7ae1 [SPARK-49513][SS] Add Support for timer in 
transformWithStateInPandas API
2ace2ebf7ae1 is described below

commit 2ace2ebf7ae1aa9877de354ac821444aecb1ac89
Author: jingz-db <[email protected]>
AuthorDate: Thu Oct 31 11:15:12 2024 +0900

    [SPARK-49513][SS] Add Support for timer in transformWithStateInPandas API
    
    ### What changes were proposed in this pull request?
    
    Support for timer in TransformWithStateInPandas Python API.
    
    ### Why are the changes needed?
    
    To stay consistent with the Scala API, TransformWithStateInPandas should also support
    processing-time/event-time timers for arbitrary state.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now interact with timers from `handleInputRows` with two
    additional parameters as:
    ```
    def handleInputRows(
            self, key: Any, rows: Iterator["PandasDataFrameLike"],
              timer_values: TimerValues,
              expired_timer_info: ExpiredTimerInfo)
    ```
    
    And users can interact with the newly introduced `TimerValues` to get the
    processing/event time for the current batch:
    ```
    class TimerValues:
        def get_current_processing_time_in_ms(self) -> int
    
        def get_current_watermark_in_ms(self) -> int
    ```
    
    Users can also interact with `expired_timer_info` to get the timestamp for 
expired timers:
    ```
    class ExpiredTimerInfo:
        def is_valid(self) -> bool
    
        def get_expiry_time_in_ms(self) -> int
    ```
    ### How was this patch tested?
    
    Unit tests in `TransformWithStateInPandasStateServerSuite` and integration 
tests in `test_pandas_transform_with_state.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47878 from jingz-db/python-timer-impl.
    
    Lead-authored-by: jingz-db <[email protected]>
    Co-authored-by: Jing Zhan <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/mypy.ini                                    |   2 +-
 python/pyspark/sql/pandas/group_ops.py             |  51 +++-
 python/pyspark/sql/streaming/StateMessage_pb2.py   |  99 -------
 python/pyspark/sql/streaming/list_state_client.py  |  12 +-
 python/pyspark/sql/streaming/map_state_client.py   |  16 +-
 .../sql/streaming/proto/StateMessage_pb2.py        | 124 ++++++++
 .../sql/streaming/{ => proto}/StateMessage_pb2.pyi |  98 ++++++-
 python/pyspark/sql/streaming/proto/__init__.py     |  16 +
 python/pyspark/sql/streaming/stateful_processor.py |  84 +++++-
 .../sql/streaming/stateful_processor_api_client.py | 197 ++++++++++++-
 python/pyspark/sql/streaming/value_state_client.py |   8 +-
 .../pandas/test_pandas_transform_with_state.py     | 294 ++++++++++++++++++-
 .../sql/execution/streaming/StateMessage.proto     |  55 +++-
 .../python/TransformWithStateInPandasExec.scala    |   6 +-
 .../TransformWithStateInPandasPythonRunner.scala   |   8 +-
 .../TransformWithStateInPandasStateServer.scala    | 325 +++++++++++++++------
 ...ransformWithStateInPandasStateServerSuite.scala | 123 +++++++-
 17 files changed, 1272 insertions(+), 246 deletions(-)

diff --git a/python/mypy.ini b/python/mypy.ini
index 4daa18593334..cb3595949e8d 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -181,5 +181,5 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 ; Ignore errors for proto generated code
-[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto]
+[mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto, 
pyspark.sql.streaming.proto]
 ignore_errors = True
diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index 0d21edc73b81..1c87df538e1f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import itertools
 import sys
 from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
 import warnings
@@ -27,6 +28,12 @@ from pyspark.sql.streaming.stateful_processor_api_client 
import (
     StatefulProcessorApiClient,
     StatefulProcessorHandleState,
 )
+from pyspark.sql.streaming.stateful_processor import (
+    ExpiredTimerInfo,
+    StatefulProcessor,
+    StatefulProcessorHandle,
+    TimerValues,
+)
 from pyspark.sql.streaming.stateful_processor import StatefulProcessor, 
StatefulProcessorHandle
 from pyspark.sql.types import StructType, _parse_datatype_string
 
@@ -501,7 +508,49 @@ class PandasGroupedOpsMixin:
                 )
 
             statefulProcessorApiClient.set_implicit_key(key)
-            result = statefulProcessor.handleInputRows(key, inputRows)
+
+            if timeMode != "none":
+                batch_timestamp = 
statefulProcessorApiClient.get_batch_timestamp()
+                watermark_timestamp = 
statefulProcessorApiClient.get_watermark_timestamp()
+            else:
+                batch_timestamp = -1
+                watermark_timestamp = -1
+            # process with invalid expiry timer info and emit data rows
+            data_iter = statefulProcessor.handleInputRows(
+                key,
+                inputRows,
+                TimerValues(batch_timestamp, watermark_timestamp),
+                ExpiredTimerInfo(False),
+            )
+            
statefulProcessorApiClient.set_handle_state(StatefulProcessorHandleState.DATA_PROCESSED)
+
+            if timeMode == "processingtime":
+                expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(
+                    batch_timestamp
+                )
+            elif timeMode == "eventtime":
+                expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(
+                    watermark_timestamp
+                )
+            else:
+                expiry_list_iter = iter([[]])
+
+            result_iter_list = [data_iter]
+            # process with valid expiry time info and with empty input rows,
+            # only timer related rows will be emitted
+            for expiry_list in expiry_list_iter:
+                for key_obj, expiry_timestamp in expiry_list:
+                    result_iter_list.append(
+                        statefulProcessor.handleInputRows(
+                            key_obj,
+                            iter([]),
+                            TimerValues(batch_timestamp, watermark_timestamp),
+                            ExpiredTimerInfo(True, expiry_timestamp),
+                        )
+                    )
+            # TODO(SPARK-49603) set the handle state in the lazily initialized 
iterator
+
+            result = itertools.chain(*result_iter_list)
 
             return result
 
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py 
b/python/pyspark/sql/streaming/StateMessage_pb2.py
deleted file mode 100644
index 9c7740c0a223..000000000000
--- a/python/pyspark/sql/streaming/StateMessage_pb2.py
+++ /dev/null
@@ -1,99 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler.  DO NOT EDIT!
-# NO CHECKED-IN PROTOBUF GENCODE
-# source: StateMessage.proto
-# Protobuf Python Version: 5.27.3
-"""Generated protocol buffer code."""
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-from google.protobuf.internal import builder as _builder
-
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01
 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03
 
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04
 \x01(\x0b\x32J.org.apache.spark [...]
-)
-
-_globals = globals()
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", 
_globals)
-if not _descriptor._USE_C_DESCRIPTORS:
-    DESCRIPTOR._loaded_options = None
-    _globals["_HANDLESTATE"]._serialized_start = 3778
-    _globals["_HANDLESTATE"]._serialized_end = 3853
-    _globals["_STATEREQUEST"]._serialized_start = 71
-    _globals["_STATEREQUEST"]._serialized_end = 432
-    _globals["_STATERESPONSE"]._serialized_start = 434
-    _globals["_STATERESPONSE"]._serialized_end = 506
-    _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509
-    _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902
-    _globals["_STATEVARIABLEREQUEST"]._serialized_start = 905
-    _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1201
-    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1204
-    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1428
-    _globals["_STATECALLCOMMAND"]._serialized_start = 1431
-    _globals["_STATECALLCOMMAND"]._serialized_end = 1585
-    _globals["_VALUESTATECALL"]._serialized_start = 1588
-    _globals["_VALUESTATECALL"]._serialized_end = 1941
-    _globals["_LISTSTATECALL"]._serialized_start = 1944
-    _globals["_LISTSTATECALL"]._serialized_end = 2472
-    _globals["_MAPSTATECALL"]._serialized_start = 2475
-    _globals["_MAPSTATECALL"]._serialized_end = 3212
-    _globals["_SETIMPLICITKEY"]._serialized_start = 3214
-    _globals["_SETIMPLICITKEY"]._serialized_end = 3243
-    _globals["_REMOVEIMPLICITKEY"]._serialized_start = 3245
-    _globals["_REMOVEIMPLICITKEY"]._serialized_end = 3264
-    _globals["_EXISTS"]._serialized_start = 3266
-    _globals["_EXISTS"]._serialized_end = 3274
-    _globals["_GET"]._serialized_start = 3276
-    _globals["_GET"]._serialized_end = 3281
-    _globals["_VALUESTATEUPDATE"]._serialized_start = 3283
-    _globals["_VALUESTATEUPDATE"]._serialized_end = 3316
-    _globals["_CLEAR"]._serialized_start = 3318
-    _globals["_CLEAR"]._serialized_end = 3325
-    _globals["_LISTSTATEGET"]._serialized_start = 3327
-    _globals["_LISTSTATEGET"]._serialized_end = 3361
-    _globals["_LISTSTATEPUT"]._serialized_start = 3363
-    _globals["_LISTSTATEPUT"]._serialized_end = 3377
-    _globals["_APPENDVALUE"]._serialized_start = 3379
-    _globals["_APPENDVALUE"]._serialized_end = 3407
-    _globals["_APPENDLIST"]._serialized_start = 3409
-    _globals["_APPENDLIST"]._serialized_end = 3421
-    _globals["_GETVALUE"]._serialized_start = 3423
-    _globals["_GETVALUE"]._serialized_end = 3450
-    _globals["_CONTAINSKEY"]._serialized_start = 3452
-    _globals["_CONTAINSKEY"]._serialized_end = 3482
-    _globals["_UPDATEVALUE"]._serialized_start = 3484
-    _globals["_UPDATEVALUE"]._serialized_end = 3529
-    _globals["_ITERATOR"]._serialized_start = 3531
-    _globals["_ITERATOR"]._serialized_end = 3561
-    _globals["_KEYS"]._serialized_start = 3563
-    _globals["_KEYS"]._serialized_end = 3589
-    _globals["_VALUES"]._serialized_start = 3591
-    _globals["_VALUES"]._serialized_end = 3619
-    _globals["_REMOVEKEY"]._serialized_start = 3621
-    _globals["_REMOVEKEY"]._serialized_end = 3649
-    _globals["_SETHANDLESTATE"]._serialized_start = 3651
-    _globals["_SETHANDLESTATE"]._serialized_end = 3743
-    _globals["_TTLCONFIG"]._serialized_start = 3745
-    _globals["_TTLCONFIG"]._serialized_end = 3776
-# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/list_state_client.py 
b/python/pyspark/sql/streaming/list_state_client.py
index 3615f6819bdd..d2152842819a 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -35,7 +35,7 @@ class ListStateClient:
         self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
 
     def exists(self, state_name: str) -> bool:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         exists_call = stateMessage.Exists()
         list_state_call = stateMessage.ListStateCall(stateName=state_name, 
exists=exists_call)
@@ -57,7 +57,7 @@ class ListStateClient:
             )
 
     def get(self, state_name: str, iterator_id: str) -> Tuple:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.pandas_df_dict:
             # If the state is already in the dictionary, return the next row.
@@ -106,7 +106,7 @@ class ListStateClient:
         return tuple(pandas_row)
 
     def append_value(self, state_name: str, schema: Union[StructType, str], 
value: Tuple) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -128,7 +128,7 @@ class ListStateClient:
     def append_list(
         self, state_name: str, schema: Union[StructType, str], values: 
List[Tuple]
     ) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -149,7 +149,7 @@ class ListStateClient:
             raise PySparkRuntimeError(f"Error updating value state: " 
f"{response_message[1]}")
 
     def put(self, state_name: str, schema: Union[StructType, str], values: 
List[Tuple]) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -168,7 +168,7 @@ class ListStateClient:
             raise PySparkRuntimeError(f"Error updating value state: " 
f"{response_message[1]}")
 
     def clear(self, state_name: str) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         clear_call = stateMessage.Clear()
         list_state_call = stateMessage.ListStateCall(stateName=state_name, 
clear=clear_call)
diff --git a/python/pyspark/sql/streaming/map_state_client.py 
b/python/pyspark/sql/streaming/map_state_client.py
index 54a5ba1bbffa..6ec7448b4863 100644
--- a/python/pyspark/sql/streaming/map_state_client.py
+++ b/python/pyspark/sql/streaming/map_state_client.py
@@ -49,7 +49,7 @@ class MapStateClient:
         self.user_key_or_value_iterator_cursors: Dict[str, 
Tuple["PandasDataFrameLike", int]] = {}
 
     def exists(self, state_name: str) -> bool:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         exists_call = stateMessage.Exists()
         map_state_call = stateMessage.MapStateCall(stateName=state_name, 
exists=exists_call)
@@ -69,7 +69,7 @@ class MapStateClient:
             raise PySparkRuntimeError(f"Error checking map state exists: 
{response_message[1]}")
 
     def get_value(self, state_name: str, user_key: Tuple) -> Optional[Tuple]:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         bytes = self._stateful_processor_api_client._serialize_to_bytes(
             self.user_key_schema, user_key
@@ -92,7 +92,7 @@ class MapStateClient:
             raise PySparkRuntimeError(f"Error getting value: 
{response_message[1]}")
 
     def contains_key(self, state_name: str, user_key: Tuple) -> bool:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         bytes = self._stateful_processor_api_client._serialize_to_bytes(
             self.user_key_schema, user_key
@@ -124,7 +124,7 @@ class MapStateClient:
         user_key: Tuple,
         value: Tuple,
     ) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         key_bytes = self._stateful_processor_api_client._serialize_to_bytes(
             self.user_key_schema, user_key
@@ -147,7 +147,7 @@ class MapStateClient:
             raise PySparkRuntimeError(f"Error updating map state value: 
{response_message[1]}")
 
     def get_key_value_pair(self, state_name: str, iterator_id: str) -> 
Tuple[Tuple, Tuple]:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.user_key_value_pair_iterator_cursors:
             # If the state is already in the dictionary, return the next row.
@@ -195,7 +195,7 @@ class MapStateClient:
         return tuple(key_row), tuple(value_row)
 
     def get_row(self, state_name: str, iterator_id: str, is_key: bool) -> 
Tuple:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.user_key_or_value_iterator_cursors:
             # If the state is already in the dictionary, return the next row.
@@ -247,7 +247,7 @@ class MapStateClient:
         return tuple(pandas_row)
 
     def remove_key(self, state_name: str, key: Tuple) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         bytes = 
self._stateful_processor_api_client._serialize_to_bytes(self.user_key_schema, 
key)
         remove_key_call = stateMessage.RemoveKey(userKey=bytes)
@@ -263,7 +263,7 @@ class MapStateClient:
             raise PySparkRuntimeError(f"Error removing key from map state: 
{response_message[1]}")
 
     def clear(self, state_name: str) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         clear_call = stateMessage.Clear()
         map_state_call = stateMessage.MapStateCall(stateName=state_name, 
clear=clear_call)
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py 
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
new file mode 100644
index 000000000000..aeb195ca10ba
--- /dev/null
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler.  DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# source: StateMessage.proto
+# Protobuf Python Version: 5.27.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC, 5, 27, 3, "", "StateMessage.proto"
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    
b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xbf\x03\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01
 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 
\x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03
 
\x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04
 \x01(\x0b\x32J.org.apache.spark [...]
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", 
_globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+    DESCRIPTOR._loaded_options = None
+    _globals["_HANDLESTATE"]._serialized_start = 4966
+    _globals["_HANDLESTATE"]._serialized_end = 5062
+    _globals["_STATEREQUEST"]._serialized_start = 71
+    _globals["_STATEREQUEST"]._serialized_end = 518
+    _globals["_STATERESPONSE"]._serialized_start = 520
+    _globals["_STATERESPONSE"]._serialized_end = 592
+    _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_start = 594
+    _globals["_STATERESPONSEWITHLONGTYPEVAL"]._serialized_end = 681
+    _globals["_STATEFULPROCESSORCALL"]._serialized_start = 684
+    _globals["_STATEFULPROCESSORCALL"]._serialized_end = 1174
+    _globals["_STATEVARIABLEREQUEST"]._serialized_start = 1177
+    _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1473
+    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1476
+    _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1700
+    _globals["_TIMERREQUEST"]._serialized_start = 1703
+    _globals["_TIMERREQUEST"]._serialized_end = 1921
+    _globals["_TIMERVALUEREQUEST"]._serialized_start = 1924
+    _globals["_TIMERVALUEREQUEST"]._serialized_end = 2136
+    _globals["_EXPIRYTIMERREQUEST"]._serialized_start = 2138
+    _globals["_EXPIRYTIMERREQUEST"]._serialized_end = 2185
+    _globals["_GETPROCESSINGTIME"]._serialized_start = 2187
+    _globals["_GETPROCESSINGTIME"]._serialized_end = 2206
+    _globals["_GETWATERMARK"]._serialized_start = 2208
+    _globals["_GETWATERMARK"]._serialized_end = 2222
+    _globals["_STATECALLCOMMAND"]._serialized_start = 2225
+    _globals["_STATECALLCOMMAND"]._serialized_end = 2379
+    _globals["_TIMERSTATECALLCOMMAND"]._serialized_start = 2382
+    _globals["_TIMERSTATECALLCOMMAND"]._serialized_end = 2653
+    _globals["_VALUESTATECALL"]._serialized_start = 2656
+    _globals["_VALUESTATECALL"]._serialized_end = 3009
+    _globals["_LISTSTATECALL"]._serialized_start = 3012
+    _globals["_LISTSTATECALL"]._serialized_end = 3540
+    _globals["_MAPSTATECALL"]._serialized_start = 3543
+    _globals["_MAPSTATECALL"]._serialized_end = 4280
+    _globals["_SETIMPLICITKEY"]._serialized_start = 4282
+    _globals["_SETIMPLICITKEY"]._serialized_end = 4311
+    _globals["_REMOVEIMPLICITKEY"]._serialized_start = 4313
+    _globals["_REMOVEIMPLICITKEY"]._serialized_end = 4332
+    _globals["_EXISTS"]._serialized_start = 4334
+    _globals["_EXISTS"]._serialized_end = 4342
+    _globals["_GET"]._serialized_start = 4344
+    _globals["_GET"]._serialized_end = 4349
+    _globals["_REGISTERTIMER"]._serialized_start = 4351
+    _globals["_REGISTERTIMER"]._serialized_end = 4393
+    _globals["_DELETETIMER"]._serialized_start = 4395
+    _globals["_DELETETIMER"]._serialized_end = 4435
+    _globals["_LISTTIMERS"]._serialized_start = 4437
+    _globals["_LISTTIMERS"]._serialized_end = 4469
+    _globals["_VALUESTATEUPDATE"]._serialized_start = 4471
+    _globals["_VALUESTATEUPDATE"]._serialized_end = 4504
+    _globals["_CLEAR"]._serialized_start = 4506
+    _globals["_CLEAR"]._serialized_end = 4513
+    _globals["_LISTSTATEGET"]._serialized_start = 4515
+    _globals["_LISTSTATEGET"]._serialized_end = 4549
+    _globals["_LISTSTATEPUT"]._serialized_start = 4551
+    _globals["_LISTSTATEPUT"]._serialized_end = 4565
+    _globals["_APPENDVALUE"]._serialized_start = 4567
+    _globals["_APPENDVALUE"]._serialized_end = 4595
+    _globals["_APPENDLIST"]._serialized_start = 4597
+    _globals["_APPENDLIST"]._serialized_end = 4609
+    _globals["_GETVALUE"]._serialized_start = 4611
+    _globals["_GETVALUE"]._serialized_end = 4638
+    _globals["_CONTAINSKEY"]._serialized_start = 4640
+    _globals["_CONTAINSKEY"]._serialized_end = 4670
+    _globals["_UPDATEVALUE"]._serialized_start = 4672
+    _globals["_UPDATEVALUE"]._serialized_end = 4717
+    _globals["_ITERATOR"]._serialized_start = 4719
+    _globals["_ITERATOR"]._serialized_end = 4749
+    _globals["_KEYS"]._serialized_start = 4751
+    _globals["_KEYS"]._serialized_end = 4777
+    _globals["_VALUES"]._serialized_start = 4779
+    _globals["_VALUES"]._serialized_end = 4807
+    _globals["_REMOVEKEY"]._serialized_start = 4809
+    _globals["_REMOVEKEY"]._serialized_end = 4837
+    _globals["_SETHANDLESTATE"]._serialized_start = 4839
+    _globals["_SETHANDLESTATE"]._serialized_end = 4931
+    _globals["_TTLCONFIG"]._serialized_start = 4933
+    _globals["_TTLCONFIG"]._serialized_end = 4964
+# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi 
b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
similarity index 77%
rename from python/pyspark/sql/streaming/StateMessage_pb2.pyi
rename to python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
index 791a221d96f3..ff525ee136a4 100644
--- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.pyi
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+#
 from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
@@ -30,11 +31,13 @@ class HandleState(int, 
metaclass=_enum_type_wrapper.EnumTypeWrapper):
     CREATED: _ClassVar[HandleState]
     INITIALIZED: _ClassVar[HandleState]
     DATA_PROCESSED: _ClassVar[HandleState]
+    TIMER_PROCESSED: _ClassVar[HandleState]
     CLOSED: _ClassVar[HandleState]
 
 CREATED: HandleState
 INITIALIZED: HandleState
 DATA_PROCESSED: HandleState
+TIMER_PROCESSED: HandleState
 CLOSED: HandleState
 
 class StateRequest(_message.Message):
@@ -43,21 +46,25 @@ class StateRequest(_message.Message):
         "statefulProcessorCall",
         "stateVariableRequest",
         "implicitGroupingKeyRequest",
+        "timerRequest",
     )
     VERSION_FIELD_NUMBER: _ClassVar[int]
     STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int]
     STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int]
     IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int]
+    TIMERREQUEST_FIELD_NUMBER: _ClassVar[int]
     version: int
     statefulProcessorCall: StatefulProcessorCall
     stateVariableRequest: StateVariableRequest
     implicitGroupingKeyRequest: ImplicitGroupingKeyRequest
+    timerRequest: TimerRequest
     def __init__(
         self,
         version: _Optional[int] = ...,
         statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, 
_Mapping]] = ...,
         stateVariableRequest: _Optional[_Union[StateVariableRequest, 
_Mapping]] = ...,
         implicitGroupingKeyRequest: 
_Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ...,
+        timerRequest: _Optional[_Union[TimerRequest, _Mapping]] = ...,
     ) -> None: ...
 
 class StateResponse(_message.Message):
@@ -75,22 +82,40 @@ class StateResponse(_message.Message):
         value: _Optional[bytes] = ...,
     ) -> None: ...
 
+class StateResponseWithLongTypeVal(_message.Message):
+    __slots__ = ("statusCode", "errorMessage", "value")
+    STATUSCODE_FIELD_NUMBER: _ClassVar[int]
+    ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int]
+    VALUE_FIELD_NUMBER: _ClassVar[int]
+    statusCode: int
+    errorMessage: str
+    value: int
+    def __init__(
+        self,
+        statusCode: _Optional[int] = ...,
+        errorMessage: _Optional[str] = ...,
+        value: _Optional[int] = ...,
+    ) -> None: ...
+
 class StatefulProcessorCall(_message.Message):
-    __slots__ = ("setHandleState", "getValueState", "getListState", 
"getMapState")
+    __slots__ = ("setHandleState", "getValueState", "getListState", 
"getMapState", "timerStateCall")
     SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int]
     GETVALUESTATE_FIELD_NUMBER: _ClassVar[int]
     GETLISTSTATE_FIELD_NUMBER: _ClassVar[int]
     GETMAPSTATE_FIELD_NUMBER: _ClassVar[int]
+    TIMERSTATECALL_FIELD_NUMBER: _ClassVar[int]
     setHandleState: SetHandleState
     getValueState: StateCallCommand
     getListState: StateCallCommand
     getMapState: StateCallCommand
+    timerStateCall: TimerStateCallCommand
     def __init__(
         self,
         setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ...,
         getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
         getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
         getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
+        timerStateCall: _Optional[_Union[TimerStateCallCommand, _Mapping]] = 
...,
     ) -> None: ...
 
 class StateVariableRequest(_message.Message):
@@ -120,6 +145,44 @@ class ImplicitGroupingKeyRequest(_message.Message):
         removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = 
...,
     ) -> None: ...
 
+class TimerRequest(_message.Message):
+    __slots__ = ("timerValueRequest", "expiryTimerRequest")
+    TIMERVALUEREQUEST_FIELD_NUMBER: _ClassVar[int]
+    EXPIRYTIMERREQUEST_FIELD_NUMBER: _ClassVar[int]
+    timerValueRequest: TimerValueRequest
+    expiryTimerRequest: ExpiryTimerRequest
+    def __init__(
+        self,
+        timerValueRequest: _Optional[_Union[TimerValueRequest, _Mapping]] = 
...,
+        expiryTimerRequest: _Optional[_Union[ExpiryTimerRequest, _Mapping]] = 
...,
+    ) -> None: ...
+
+class TimerValueRequest(_message.Message):
+    __slots__ = ("getProcessingTimer", "getWatermark")
+    GETPROCESSINGTIMER_FIELD_NUMBER: _ClassVar[int]
+    GETWATERMARK_FIELD_NUMBER: _ClassVar[int]
+    getProcessingTimer: GetProcessingTime
+    getWatermark: GetWatermark
+    def __init__(
+        self,
+        getProcessingTimer: _Optional[_Union[GetProcessingTime, _Mapping]] = 
...,
+        getWatermark: _Optional[_Union[GetWatermark, _Mapping]] = ...,
+    ) -> None: ...
+
+class ExpiryTimerRequest(_message.Message):
+    __slots__ = ("expiryTimestampMs",)
+    EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+    expiryTimestampMs: int
+    def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class GetProcessingTime(_message.Message):
+    __slots__ = ()
+    def __init__(self) -> None: ...
+
+class GetWatermark(_message.Message):
+    __slots__ = ()
+    def __init__(self) -> None: ...
+
 class StateCallCommand(_message.Message):
     __slots__ = ("stateName", "schema", "mapStateValueSchema", "ttl")
     STATENAME_FIELD_NUMBER: _ClassVar[int]
@@ -138,6 +201,21 @@ class StateCallCommand(_message.Message):
         ttl: _Optional[_Union[TTLConfig, _Mapping]] = ...,
     ) -> None: ...
 
+class TimerStateCallCommand(_message.Message):
+    __slots__ = ("register", "delete", "list")
+    REGISTER_FIELD_NUMBER: _ClassVar[int]
+    DELETE_FIELD_NUMBER: _ClassVar[int]
+    LIST_FIELD_NUMBER: _ClassVar[int]
+    register: RegisterTimer
+    delete: DeleteTimer
+    list: ListTimers
+    def __init__(
+        self,
+        register: _Optional[_Union[RegisterTimer, _Mapping]] = ...,
+        delete: _Optional[_Union[DeleteTimer, _Mapping]] = ...,
+        list: _Optional[_Union[ListTimers, _Mapping]] = ...,
+    ) -> None: ...
+
 class ValueStateCall(_message.Message):
     __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear")
     STATENAME_FIELD_NUMBER: _ClassVar[int]
@@ -259,6 +337,24 @@ class Get(_message.Message):
     __slots__ = ()
     def __init__(self) -> None: ...
 
+class RegisterTimer(_message.Message):
+    __slots__ = ("expiryTimestampMs",)
+    EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+    expiryTimestampMs: int
+    def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class DeleteTimer(_message.Message):
+    __slots__ = ("expiryTimestampMs",)
+    EXPIRYTIMESTAMPMS_FIELD_NUMBER: _ClassVar[int]
+    expiryTimestampMs: int
+    def __init__(self, expiryTimestampMs: _Optional[int] = ...) -> None: ...
+
+class ListTimers(_message.Message):
+    __slots__ = ("iteratorId",)
+    ITERATORID_FIELD_NUMBER: _ClassVar[int]
+    iteratorId: str
+    def __init__(self, iteratorId: _Optional[str] = ...) -> None: ...
+
 class ValueStateUpdate(_message.Message):
     __slots__ = ("value",)
     VALUE_FIELD_NUMBER: _ClassVar[int]
diff --git a/python/pyspark/sql/streaming/proto/__init__.py 
b/python/pyspark/sql/streaming/proto/__init__.py
new file mode 100644
index 000000000000..cce3acad34a4
--- /dev/null
+++ b/python/pyspark/sql/streaming/proto/__init__.py
@@ -0,0 +1,16 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/pyspark/sql/streaming/stateful_processor.py 
b/python/pyspark/sql/streaming/stateful_processor.py
index bac762e8addd..404326bb479c 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -18,7 +18,10 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, TYPE_CHECKING, Iterator, Optional, Union, Tuple
 
-from pyspark.sql.streaming.stateful_processor_api_client import 
StatefulProcessorApiClient
+from pyspark.sql.streaming.stateful_processor_api_client import (
+    StatefulProcessorApiClient,
+    ListTimerIterator,
+)
 from pyspark.sql.streaming.list_state_client import ListStateClient, 
ListStateIterator
 from pyspark.sql.streaming.map_state_client import (
     MapStateClient,
@@ -74,6 +77,56 @@ class ValueState:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+    .. versionadded:: 4.0.0
+    """
+
+    def __init__(
+        self, current_processing_time_in_ms: int = -1, 
current_watermark_in_ms: int = -1
+    ) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    def get_current_processing_time_in_ms(self) -> int:
+        """
+        Get processing time for current batch; returns timestamp in milliseconds.
+        """
+        return self._current_processing_time_in_ms
+
+    def get_current_watermark_in_ms(self) -> int:
+        """
+        Get watermark for current batch; returns timestamp in milliseconds.
+        """
+        return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access expired timer
+    info. When is_valid is false, the expiry timestamp is invalid.
+    .. versionadded:: 4.0.0
+    """
+
+    def __init__(self, is_valid: bool, expiry_time_in_ms: int = -1) -> None:
+        self._is_valid = is_valid
+        self._expiry_time_in_ms = expiry_time_in_ms
+
+    def is_valid(self) -> bool:
+        """
+        Whether the expiry info is valid.
+        """
+        return self._is_valid
+
+    def get_expiry_time_in_ms(self) -> int:
+        """
+        Get the timestamp of the expired timer; returns timestamp in milliseconds.
+        """
+        return self._expiry_time_in_ms
+
+
 class ListState:
     """
     Class used for arbitrary stateful operations with transformWithState to 
capture list value
@@ -292,6 +345,24 @@ class StatefulProcessorHandle:
             state_name,
         )
 
+    def registerTimer(self, expiry_time_stamp_ms: int) -> None:
+        """
+        Register a timer for a given expiry timestamp in milliseconds for the 
grouping key.
+        """
+        self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms)
+
+    def deleteTimer(self, expiry_time_stamp_ms: int) -> None:
+        """
+        Delete a timer for a given expiry timestamp in milliseconds for the 
grouping key.
+        """
+        self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms)
+
+    def listTimers(self) -> Iterator[int]:
+        """
+        List all the timers' expiry timestamps in milliseconds for the 
grouping key.
+        """
+        return ListTimerIterator(self.stateful_processor_api_client)
+
 
 class StatefulProcessor(ABC):
     """
@@ -317,7 +388,11 @@ class StatefulProcessor(ABC):
 
     @abstractmethod
     def handleInputRows(
-        self, key: Any, rows: Iterator["PandasDataFrameLike"]
+        self,
+        key: Any,
+        rows: Iterator["PandasDataFrameLike"],
+        timer_values: TimerValues,
+        expired_timer_info: ExpiredTimerInfo,
     ) -> Iterator["PandasDataFrameLike"]:
         """
         Function that will allow users to interact with input data rows along 
with the grouping key.
@@ -336,6 +411,11 @@ class StatefulProcessor(ABC):
             grouping key.
         rows : iterable of :class:`pandas.DataFrame`
             iterator of input rows associated with grouping key
+        timer_values: TimerValues
+                      Timer values for the current batch that processes the 
input rows.
+                      Users can get the processing or event time timestamp 
from TimerValues.
+        expired_timer_info: ExpiredTimerInfo
+                            Timestamp of expired timers on the grouping key.
         """
         ...
 
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 552ab44d1ddf..ce3bae0a7c91 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -17,12 +17,13 @@
 from enum import Enum
 import os
 import socket
-from typing import Any, List, Union, Optional, cast, Tuple
+from typing import Any, Dict, List, Union, Optional, cast, Tuple, Iterator
 
 from pyspark.serializers import write_int, read_int, UTF8Deserializer
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     StructType,
+    TYPE_CHECKING,
     _parse_datatype_string,
     Row,
 )
@@ -30,6 +31,10 @@ from pyspark.sql.pandas.types import 
convert_pandas_using_numpy_type
 from pyspark.sql.utils import has_numpy
 from pyspark.serializers import CPickleSerializer
 from pyspark.errors import PySparkRuntimeError
+import uuid
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
 
 __all__ = ["StatefulProcessorApiClient", "StatefulProcessorHandleState"]
 
@@ -38,7 +43,8 @@ class StatefulProcessorHandleState(Enum):
     CREATED = 1
     INITIALIZED = 2
     DATA_PROCESSED = 3
-    CLOSED = 4
+    TIMER_PROCESSED = 4
+    CLOSED = 5
 
 
 class StatefulProcessorApiClient:
@@ -53,9 +59,12 @@ class StatefulProcessorApiClient:
         self.utf8_deserializer = UTF8Deserializer()
         self.pickleSer = CPickleSerializer()
         self.serializer = ArrowStreamSerializer()
+        # Dictionary to store the mapping between iterator id and a tuple of 
pandas DataFrame
+        # and the index of the last row that was read.
+        self.list_timer_iterator_cursors: Dict[str, 
Tuple["PandasDataFrameLike", int]] = {}
 
     def set_handle_state(self, state: StatefulProcessorHandleState) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if state == StatefulProcessorHandleState.CREATED:
             proto_state = stateMessage.CREATED
@@ -63,6 +72,8 @@ class StatefulProcessorApiClient:
             proto_state = stateMessage.INITIALIZED
         elif state == StatefulProcessorHandleState.DATA_PROCESSED:
             proto_state = stateMessage.DATA_PROCESSED
+        elif state == StatefulProcessorHandleState.TIMER_PROCESSED:
+            proto_state = stateMessage.TIMER_PROCESSED
         else:
             proto_state = stateMessage.CLOSED
         set_handle_state = stateMessage.SetHandleState(state=proto_state)
@@ -80,7 +91,7 @@ class StatefulProcessorApiClient:
             raise PySparkRuntimeError(f"Error setting handle state: " 
f"{response_message[1]}")
 
     def set_implicit_key(self, key: Tuple) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         key_bytes = self._serialize_to_bytes(self.key_schema, key)
         set_implicit_key = stateMessage.SetImplicitKey(key=key_bytes)
@@ -95,7 +106,7 @@ class StatefulProcessorApiClient:
             raise PySparkRuntimeError(f"Error setting implicit key: " 
f"{response_message[1]}")
 
     def remove_implicit_key(self) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         remove_implicit_key = stateMessage.RemoveImplicitKey()
         request = 
stateMessage.ImplicitGroupingKeyRequest(removeImplicitKey=remove_implicit_key)
@@ -111,7 +122,7 @@ class StatefulProcessorApiClient:
     def get_value_state(
         self, state_name: str, schema: Union[StructType, str], 
ttl_duration_ms: Optional[int]
     ) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -134,7 +145,7 @@ class StatefulProcessorApiClient:
     def get_list_state(
         self, state_name: str, schema: Union[StructType, str], 
ttl_duration_ms: Optional[int]
     ) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -152,7 +163,150 @@ class StatefulProcessorApiClient:
         status = response_message[0]
         if status != 0:
             # TODO(SPARK-49233): Classify user facing errors.
-            raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(f"Error initializing list state: " 
f"{response_message[1]}")
+
+    def register_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        register_call = 
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = 
stateMessage.TimerStateCallCommand(register=register_call)
+        call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error register timer: " 
f"{response_message[1]}")
+
+    def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        delete_call = 
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = 
stateMessage.TimerStateCallCommand(delete=delete_call)
+        call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error deleting timer: " 
f"{response_message[1]}")
+
+    def get_list_timer_row(self, iterator_id: str) -> int:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.list_timer_iterator_cursors:
+            # if the iterator is already in the dictionary, return the next row
+            pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
+        else:
+            list_call = stateMessage.ListTimers(iteratorId=iterator_id)
+            state_call_command = 
stateMessage.TimerStateCallCommand(list=list_call)
+            call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+            message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+            self._send_proto_message(message.SerializeToString())
+            response_message = self._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._read_arrow_state()
+                # We need to exhaust the iterator here to make sure all the 
arrow batches are read,
+                # even though there is only one batch in the iterator. 
Otherwise, the stream might
+                # block further reads since it thinks there might still be 
some arrow batches left.
+                # We only need to read the first batch in the iterator because 
it's guaranteed that
+                # there would only be one batch sent from the JVM side.
+                data_batch = None
+                for batch in iterator:
+                    if data_batch is None:
+                        data_batch = batch
+                if data_batch is None:
+                    # TODO(SPARK-49233): Classify user facing errors.
+                    raise PySparkRuntimeError("Error getting map state entry.")
+                pandas_df = data_batch.to_pandas()
+                index = 0
+            else:
+                raise StopIteration()
+        new_index = index + 1
+        if new_index < len(pandas_df):
+            # Update the index in the dictionary.
+            self.list_timer_iterator_cursors[iterator_id] = (pandas_df, 
new_index)
+        else:
+            # If the index is at the end of the DataFrame, remove the state 
from the dictionary.
+            self.list_timer_iterator_cursors.pop(iterator_id, None)
+        return pandas_df.at[index, "timestamp"].item()
+
+    def get_expiry_timers_iterator(
+        self, expiry_timestamp: int
+    ) -> Iterator[list[Tuple[Tuple, int]]]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        while True:
+            expiry_timer_call = 
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+            timer_request = 
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+            message = stateMessage.StateRequest(timerRequest=timer_request)
+
+            self._send_proto_message(message.SerializeToString())
+            response_message = self._receive_proto_message()
+            status = response_message[0]
+            if status == 1:
+                break
+            elif status == 0:
+                result_list = []
+                iterator = self._read_arrow_state()
+                for batch in iterator:
+                    batch_df = batch.to_pandas()
+                    for i in range(batch.num_rows):
+                        deserialized_key = self.pickleSer.loads(batch_df.at[i, 
"key"])
+                        timestamp = batch_df.at[i, "timestamp"].item()
+                        result_list.append((tuple(deserialized_key), 
timestamp))
+                yield result_list
+            else:
+                # TODO(SPARK-49233): Classify user facing errors.
+                raise PySparkRuntimeError(f"Error getting expiry timers: " 
f"{response_message[1]}")
+
+    def get_batch_timestamp(self) -> int:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        get_processing_time_call = stateMessage.GetProcessingTime()
+        timer_value_call = stateMessage.TimerValueRequest(
+            getProcessingTimer=get_processing_time_call
+        )
+        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+        message = stateMessage.StateRequest(timerRequest=timer_request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message_with_long_value()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error getting processing timestamp: " 
f"{response_message[1]}"
+            )
+        else:
+            timestamp = response_message[2]
+            return timestamp
+
+    def get_watermark_timestamp(self) -> int:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        get_watermark_call = stateMessage.GetWatermark()
+        timer_value_call = 
stateMessage.TimerValueRequest(getWatermark=get_watermark_call)
+        timer_request = 
stateMessage.TimerRequest(timerValueRequest=timer_value_call)
+        message = stateMessage.StateRequest(timerRequest=timer_request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message_with_long_value()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error getting eventtime timestamp: " f"{response_message[1]}"
+            )
+        else:
+            timestamp = response_message[2]
+            return timestamp
 
     def get_map_state(
         self,
@@ -161,7 +315,7 @@ class StatefulProcessorApiClient:
         value_schema: Union[StructType, str],
         ttl_duration_ms: Optional[int],
     ) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(user_key_schema, str):
             user_key_schema = cast(StructType, 
_parse_datatype_string(user_key_schema))
@@ -193,7 +347,7 @@ class StatefulProcessorApiClient:
         self.sockfile.flush()
 
     def _receive_proto_message(self) -> Tuple[int, str, bytes]:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         length = read_int(self.sockfile)
         bytes = self.sockfile.read(length)
@@ -201,6 +355,15 @@ class StatefulProcessorApiClient:
         message.ParseFromString(bytes)
         return message.statusCode, message.errorMessage, message.value
 
+    def _receive_proto_message_with_long_value(self) -> Tuple[int, str, int]:
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
+
+        length = read_int(self.sockfile)
+        bytes = self.sockfile.read(length)
+        message = stateMessage.StateResponseWithLongTypeVal()
+        message.ParseFromString(bytes)
+        return message.statusCode, message.errorMessage, message.value
+
     def _receive_str(self) -> str:
         return self.utf8_deserializer.loads(self.sockfile)
 
@@ -243,3 +406,17 @@ class StatefulProcessorApiClient:
 
     def _read_arrow_state(self) -> Any:
         return self.serializer.load_stream(self.sockfile)
+
+
+class ListTimerIterator:
+    def __init__(self, stateful_processor_api_client: 
StatefulProcessorApiClient):
+        # Generate a unique identifier for the iterator to make sure iterators 
on the
+        # same partition won't interfere with each other
+        self.iterator_id = str(uuid.uuid4())
+        self.stateful_processor_api_client = stateful_processor_api_client
+
+    def __iter__(self) -> Iterator[int]:
+        return self
+
+    def __next__(self) -> int:
+        return 
self.stateful_processor_api_client.get_list_timer_row(self.iterator_id)
diff --git a/python/pyspark/sql/streaming/value_state_client.py 
b/python/pyspark/sql/streaming/value_state_client.py
index 3fe32bcc5235..fd783af7931d 100644
--- a/python/pyspark/sql/streaming/value_state_client.py
+++ b/python/pyspark/sql/streaming/value_state_client.py
@@ -28,7 +28,7 @@ class ValueStateClient:
         self._stateful_processor_api_client = stateful_processor_api_client
 
     def exists(self, state_name: str) -> bool:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         exists_call = stateMessage.Exists()
         value_state_call = stateMessage.ValueStateCall(stateName=state_name, 
exists=exists_call)
@@ -50,7 +50,7 @@ class ValueStateClient:
             )
 
     def get(self, state_name: str) -> Optional[Tuple]:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         get_call = stateMessage.Get()
         value_state_call = stateMessage.ValueStateCall(stateName=state_name, 
get=get_call)
@@ -70,7 +70,7 @@ class ValueStateClient:
             raise PySparkRuntimeError(f"Error getting value state: " 
f"{response_message[1]}")
 
     def update(self, state_name: str, schema: Union[StructType, str], value: 
Tuple) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
             schema = cast(StructType, _parse_datatype_string(schema))
@@ -90,7 +90,7 @@ class ValueStateClient:
             raise PySparkRuntimeError(f"Error updating value state: " 
f"{response_message[1]}")
 
     def clear(self, state_name: str) -> None:
-        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         clear_call = stateMessage.Clear()
         value_state_call = stateMessage.ValueStateCall(stateName=state_name, 
clear=clear_call)
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 7339897cb2cc..c888b6b3298f 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -75,6 +75,9 @@ class TransformWithStateInPandasTestsMixin:
             input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323, 
246, 6]
         )
 
+    def _prepare_test_resource3(self, input_path):
+        self._prepare_input_data(input_path + "/text-test3.txt", [0, 1], [123, 
6])
+
     def _build_test_df(self, input_path):
         df = self.spark.readStream.format("text").option("maxFilesPerTrigger", 
1).load(input_path)
         df_split = df.withColumn("split_values", split(df["value"], ","))
@@ -363,6 +366,277 @@ class TransformWithStateInPandasTestsMixin:
         finally:
             input_dir.cleanup()
 
+    def _test_transform_with_state_in_pandas_proc_timer(self, 
stateful_processor, check_results):
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource3(input_path)
+        self._prepare_test_resource1(input_path)
+        self._prepare_test_resource2(input_path)
+
+        df = self._build_test_df(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("countAsString", StringType(), True),
+                StructField("timeValues", StringType(), True),
+            ]
+        )
+
+        query_name = "processing_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="processingtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_proc_timer(self):
+        # helper function to check expired timestamp is smaller than current 
processing time
+        def check_timestamp(batch_df):
+            expired_df = (
+                batch_df.filter(batch_df["countAsString"] == "-1")
+                .select("id", "timeValues")
+                .withColumnRenamed("timeValues", "expiredTimestamp")
+            )
+            count_df = (
+                batch_df.filter(batch_df["countAsString"] != "-1")
+                .select("id", "timeValues")
+                .withColumnRenamed("timeValues", "countStateTimestamp")
+            )
+            joined_df = expired_df.join(count_df, on="id")
+            for row in joined_df.collect():
+                assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="1"),
+                    Row(id="1", countAsString="1"),
+                }
+            elif batch_id == 1:
+                # for key 0, the accumulated count is emitted before the count 
state is cleared
+                # during the timer process
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="3"),
+                }
+                self.first_expired_timestamp = batch_df.filter(
+                    batch_df["countAsString"] == -1
+                ).first()["timeValues"]
+                check_timestamp(batch_df)
+
+            else:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="5"),
+                }
+                # The expired timestamp in current batch is larger than expiry 
timestamp in batch 1
+                # because this is a new timer registered in batch1 and
+                # different from the one registered in batch 0
+                current_batch_expired_timestamp = batch_df.filter(
+                    batch_df["countAsString"] == -1
+                ).first()["timeValues"]
+                assert current_batch_expired_timestamp > 
self.first_expired_timestamp
+
+        self._test_transform_with_state_in_pandas_proc_timer(
+            ProcTimeStatefulProcessor(), check_results
+        )
+
+    def _test_transform_with_state_in_pandas_event_time(self, 
stateful_processor, check_results):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+
+        def prepare_batch1(input_path):
+            with open(input_path + "/text-test3.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        def prepare_batch2(input_path):
+            with open(input_path + "/text-test1.txt", "w") as fw:
+                fw.write("a, 4\n")
+
+        def prepare_batch3(input_path):
+            with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 11\n")
+                fw.write("a, 13\n")
+                fw.write("a, 15\n")
+
+        prepare_batch1(input_path)
+        prepare_batch2(input_path)
+        prepare_batch3(input_path)
+
+        df = self._build_test_df(input_path)
+        df = df.select(
+            "id", 
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
+        ).withWatermark("eventTime", "10 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [StructField("id", StringType(), True), StructField("timestamp", 
StringType(), True)]
+        )
+
+        query_name = "event_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="eventtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_event_time(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {Row(id="a", 
timestamp="20")}
+            elif batch_id == 1:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="20"),
+                    Row(id="a-expired", timestamp="0"),
+                }
+            else:
+                # watermark has not progressed, so the timer registered in 
batch 1 (watermark = 10)
+                # has not yet expired
+                assert set(batch_df.sort("id").collect()) == {Row(id="a", 
timestamp="15")}
+
+        self._test_transform_with_state_in_pandas_event_time(
+            EventTimeStatefulProcessor(), check_results
+        )
+
+
+# A stateful processor that outputs the max event time it has seen. Registers a 
timer for
+# the current watermark. Clears the max state if the timer expires.
+class EventTimeStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", StringType(), True)])
+        self.handle = handle
+        self.max_state = handle.getValueState("max_state", state_schema)
+
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
+        if expired_timer_info.is_valid():
+            self.max_state.clear()
+            self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+            str_key = f"{str(key[0])}-expired"
+            yield pd.DataFrame(
+                {"id": (str_key,), "timestamp": 
str(expired_timer_info.get_expiry_time_in_ms())}
+            )
+
+        else:
+            timestamp_list = []
+            for pdf in rows:
+                # int64 represents the timestamp in nanoseconds; convert back 
to seconds
+                timestamp_list.extend((pdf["eventTime"].astype("int64") // 
10**9).tolist())
+
+            if self.max_state.exists():
+                cur_max = int(self.max_state.get()[0])
+            else:
+                cur_max = 0
+            max_event_time = str(max(cur_max, max(timestamp_list)))
+
+            self.max_state.update((max_event_time,))
+            
self.handle.registerTimer(timer_values.get_current_watermark_in_ms())
+
+            yield pd.DataFrame({"id": key, "timestamp": max_event_time})
+
+    def close(self) -> None:
+        pass
+
+
+# A stateful processor that outputs the accumulated count of input rows; 
registers a
+# processing-time timer and clears the counter when the timer expires.
+class ProcTimeStatefulProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", StringType(), True)])
+        self.handle = handle
+        self.count_state = handle.getValueState("count_state", state_schema)
+
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
+        if expired_timer_info.is_valid():
+            # reset count state each time the timer is expired
+            timer_list_1 = [e for e in self.handle.listTimers()]
+            timer_list_2 = []
+            idx = 0
+            for e in self.handle.listTimers():
+                timer_list_2.append(e)
+                # check that multiple iterators on the same grouping key work
+                assert timer_list_2[idx] == timer_list_1[idx]
+                idx += 1
+
+            if len(timer_list_1) > 0:
+                # before deleting the expiring timers, there are 2 timers -
+                # one timer we just registered, and one that is going to be 
deleted
+                assert len(timer_list_1) == 2
+            self.count_state.clear()
+            self.handle.deleteTimer(expired_timer_info.get_expiry_time_in_ms())
+            yield pd.DataFrame(
+                {
+                    "id": key,
+                    "countAsString": str("-1"),
+                    "timeValues": 
str(expired_timer_info.get_expiry_time_in_ms()),
+                }
+            )
+
+        else:
+            if not self.count_state.exists():
+                count = 0
+            else:
+                count = int(self.count_state.get()[0])
+
+            if key == ("0",):
+                
self.handle.registerTimer(timer_values.get_current_processing_time_in_ms())
+
+            rows_count = 0
+            for pdf in rows:
+                pdf_count = len(pdf)
+                rows_count += pdf_count
+
+            count = count + rows_count
+
+            self.count_state.update((str(count),))
+            timestamp = str(timer_values.get_current_processing_time_in_ms())
+
+            yield pd.DataFrame({"id": key, "countAsString": str(count), 
"timeValues": timestamp})
+
+    def close(self) -> None:
+        pass
+
 
 class SimpleStatefulProcessor(StatefulProcessor):
     dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
@@ -372,7 +646,9 @@ class SimpleStatefulProcessor(StatefulProcessor):
         state_schema = StructType([StructField("value", IntegerType(), True)])
         self.num_violations_state = handle.getValueState("numViolations", 
state_schema)
 
-    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
         new_violations = 0
         count = 0
         key_str = key[0]
@@ -417,7 +693,9 @@ class TTLStatefulProcessor(StatefulProcessor):
             "ttl-map-state", user_key_schema, state_schema, 10000
         )
 
-    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
         count = 0
         ttl_count = 0
         ttl_list_state_count = 0
@@ -467,7 +745,9 @@ class InvalidSimpleStatefulProcessor(StatefulProcessor):
         state_schema = StructType([StructField("value", IntegerType(), True)])
         self.num_violations_state = handle.getValueState("numViolations", 
state_schema)
 
-    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
         count = 0
         exists = self.num_violations_state.exists()
         assert not exists
@@ -491,7 +771,9 @@ class ListStateProcessor(StatefulProcessor):
         self.list_state1 = handle.getListState("listState1", state_schema)
         self.list_state2 = handle.getListState("listState2", state_schema)
 
-    def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
         count = 0
         for pdf in rows:
             list_state_rows = [(120,), (20,)]
@@ -546,7 +828,9 @@ class MapStateProcessor(StatefulProcessor):
         value_schema = StructType([StructField("count", IntegerType(), True)])
         self.map_state = handle.getMapState("mapState", key_schema, 
value_schema)
 
-    def handleInputRows(self, key, rows):
+    def handleInputRows(
+        self, key, rows, timer_values, expired_timer_info
+    ) -> Iterator[pd.DataFrame]:
         count = 0
         key1 = ("key1",)
         key2 = ("key2",)
diff --git 
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
 
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
index bb1c4c4f8e6c..544cd3b10b1c 100644
--- 
a/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
+++ 
b/sql/core/src/main/protobuf/org/apache/spark/sql/execution/streaming/StateMessage.proto
@@ -25,6 +25,7 @@ message StateRequest {
     StatefulProcessorCall statefulProcessorCall = 2;
     StateVariableRequest stateVariableRequest = 3;
     ImplicitGroupingKeyRequest implicitGroupingKeyRequest = 4;
+    TimerRequest timerRequest = 5;
   }
 }
 
@@ -34,12 +35,19 @@ message StateResponse {
   bytes value = 3;
 }
 
+message StateResponseWithLongTypeVal {
+  int32 statusCode = 1;
+  string errorMessage = 2;
+  int64 value = 3;
+}
+
 message StatefulProcessorCall {
   oneof method {
     SetHandleState setHandleState = 1;
     StateCallCommand getValueState = 2;
     StateCallCommand getListState = 3;
     StateCallCommand getMapState = 4;
+    TimerStateCallCommand timerStateCall = 5;
   }
 }
 
@@ -58,6 +66,30 @@ message ImplicitGroupingKeyRequest {
   }
 }
 
+message TimerRequest {
+  oneof method {
+    TimerValueRequest timerValueRequest = 1;
+    ExpiryTimerRequest expiryTimerRequest = 2;
+  }
+}
+
+message TimerValueRequest {
+  oneof method {
+    GetProcessingTime getProcessingTimer = 1;
+    GetWatermark getWatermark = 2;
+  }
+}
+
+message ExpiryTimerRequest {
+  int64 expiryTimestampMs = 1;
+}
+
+message GetProcessingTime {
+}
+
+message GetWatermark {
+}
+
 message StateCallCommand {
   string stateName = 1;
   string schema = 2;
@@ -65,6 +97,14 @@ message StateCallCommand {
   TTLConfig ttl = 4;
 }
 
+message TimerStateCallCommand {
+  oneof method {
+    RegisterTimer register = 1;
+    DeleteTimer delete = 2;
+    ListTimers list = 3;
+  }
+}
+
 message ValueStateCall {
   string stateName = 1;
   oneof method {
@@ -115,6 +155,18 @@ message Exists {
 message Get {
 }
 
+message RegisterTimer {
+  int64 expiryTimestampMs = 1;
+}
+
+message DeleteTimer {
+  int64 expiryTimestampMs = 1;
+}
+
+message ListTimers {
+  string iteratorId = 1;
+}
+
 message ValueStateUpdate {
   bytes value = 1;
 }
@@ -169,7 +221,8 @@ enum HandleState {
   CREATED = 0;
   INITIALIZED = 1;
   DATA_PROCESSED = 2;
-  CLOSED = 3;
+  TIMER_PROCESSED = 3;
+  CLOSED = 4;
 }
 
 message SetHandleState {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
index f6a3f1b1394f..f7e1ec79bfcd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasExec.scala
@@ -127,6 +127,8 @@ case class TransformWithStateInPandasExec(
       case (store: StateStore, dataIterator: Iterator[InternalRow]) =>
         val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
         val commitTimeMs = longMetric("commitTimeMs")
+        // TODO(SPARK-49603) set the metrics in the lazily initialized iterator
+        val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
         val currentTimeNs = System.nanoTime
         val updatesStartTimeNs = currentTimeNs
 
@@ -144,7 +146,9 @@ case class TransformWithStateInPandasExec(
           pythonRunnerConf,
           pythonMetrics,
           jobArtifactUUID,
-          groupingKeySchema
+          groupingKeySchema,
+          batchTimestampMs,
+          eventTimeWatermarkForEviction
         )
 
         val outputIterator = executePython(data, output, runner)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index b4b516ba9e5a..655fe259578f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -50,7 +50,9 @@ class TransformWithStateInPandasPythonRunner(
     initialWorkerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
-    groupingKeySchema: StructType)
+    groupingKeySchema: StructType,
+    batchTimestampMs: Option[Long] = None,
+    eventTimeWatermarkForEviction: Option[Long] = None)
   extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, 
argOffsets, jobArtifactUUID)
   with PythonArrowInput[InType]
   with BasicPythonArrowOutput
@@ -104,7 +106,9 @@ class TransformWithStateInPandasPythonRunner(
     executionContext.execute(
       new TransformWithStateInPandasStateServer(stateServerSocket, 
processorHandle,
         groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, 
largeVarTypes,
-        sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch))
+        sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch,
+        batchTimestampMs = batchTimestampMs,
+        eventTimeWatermarkForEviction = eventTimeWatermarkForEviction))
 
     context.addTaskCompletionListener[Unit] { _ =>
       logInfo(log"completion listener called")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 3aed0b2463f3..8a67d5d47f05 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -34,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType}
-import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, 
StateRequest, StateResponse, StateVariableRequest, ValueStateCall}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, ListStateCall, MapStateCall, StatefulProcessorCall, 
StateRequest, StateResponse, StateResponseWithLongTypeVal, 
StateVariableRequest, TimerRequest, TimerStateCallCommand, TimerValueRequest, 
ValueStateCall}
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, 
ValueState}
-import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, LongType, StructField, 
StructType}
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.util.Utils
 
@@ -58,6 +58,8 @@ class TransformWithStateInPandasStateServer(
     errorOnDuplicatedFieldNames: Boolean,
     largeVarTypes: Boolean,
     arrowTransformWithStateInPandasMaxRecordsPerBatch: Int,
+    batchTimestampMs: Option[Long] = None,
+    eventTimeWatermarkForEviction: Option[Long] = None,
     outputStreamForTest: DataOutputStream = null,
     valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null,
     deserializerForTest: TransformWithStateInPandasDeserializer = null,
@@ -65,12 +67,19 @@ class TransformWithStateInPandasStateServer(
     listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
     iteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
     mapStatesMapForTest : mutable.HashMap[String, MapStateInfo] = null,
-    keyValueIteratorMapForTest: mutable.HashMap[String, Iterator[(Row, Row)]] 
= null)
+    keyValueIteratorMapForTest: mutable.HashMap[String, Iterator[(Row, Row)]] 
= null,
+    expiryTimerIterForTest: Iterator[(Any, Long)] = null,
+    listTimerMapForTest: mutable.HashMap[String, Iterator[Long]] = null)
   extends Runnable with Logging {
+
+  import PythonResponseWriterUtils._
+
   private val keyRowDeserializer: ExpressionEncoder.Deserializer[Row] =
     ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer()
   private var inputStream: DataInputStream = _
   private var outputStream: DataOutputStream = outputStreamForTest
+
+  /** State variable related class variables */
   // A map to store the value state name -> (value state, schema, value row 
deserializer) mapping.
   private val valueStates = if (valueStateMapForTest != null) {
     valueStateMapForTest
@@ -110,6 +119,20 @@ class TransformWithStateInPandasStateServer(
     new mutable.HashMap[String, Iterator[(Row, Row)]]()
   }
 
+  /** Timer related class variables */
+  private var expiryTimestampIter: Option[Iterator[(Any, Long)]] =
+    if (expiryTimerIterForTest != null) {
+      Option(expiryTimerIterForTest)
+    } else None
+
+  // A map to store the iterator id -> Iterator[Long] mapping. This is to keep 
track of the
+  // current iterator position for each iterator id in the same partition for 
a grouping key, in case
+  // the user tries to fetch multiple iterators before the current iterator is 
exhausted. This is
+  // used for the list timer function call.
+  private var listTimerIters = if (listTimerMapForTest != null) {
+    listTimerMapForTest
+  } else new mutable.HashMap[String, Iterator[Long]]()
+
   def run(): Unit = {
     val listeningSocket = stateServerSocket.accept()
     inputStream = new DataInputStream(
@@ -159,11 +182,62 @@ class TransformWithStateInPandasStateServer(
         handleStatefulProcessorCall(message.getStatefulProcessorCall)
       case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
         handleStateVariableRequest(message.getStateVariableRequest)
+      case StateRequest.MethodCase.TIMERREQUEST =>
+        handleTimerRequest(message.getTimerRequest)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
+  private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+    message.getMethodCase match {
+      case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+        val timerRequest = message.getTimerValueRequest()
+        timerRequest.getMethodCase match {
+          case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+            val procTimestamp: Long =
+              if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+            sendResponseWithLongVal(0, null, procTimestamp)
+          case TimerValueRequest.MethodCase.GETWATERMARK =>
+            val eventTimestamp: Long =
+              if (eventTimeWatermarkForEviction.isDefined) 
eventTimeWatermarkForEviction.get
+              else -1L
+            sendResponseWithLongVal(0, null, eventTimestamp)
+          case _ =>
+            throw new IllegalArgumentException("Invalid timer value method 
call")
+        }
+
+      case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+        // Note that for the `getExpiryTimers` python call, as this is not a public
+        // API and it will only be used by `group_ops` once per partition, we 
won't
+        // need to worry about different function calls being interleaved, and 
hence
+        // this implementation is safe
+        val expiryRequest = message.getExpiryTimerRequest()
+        val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+        if (!expiryTimestampIter.isDefined) {
+          expiryTimestampIter =
+            Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp))
+        }
+        // expiryTimestampIter could be None in the TWSPandasServerSuite
+        if (!expiryTimestampIter.isDefined || 
!expiryTimestampIter.get.hasNext) {
+          // iterator is exhausted, signal the end of iterator on python client
+          sendResponse(1)
+        } else {
+          sendResponse(0)
+          val outputSchema = new StructType()
+            .add("key", BinaryType)
+            .add(StructField("timestamp", LongType))
+          sendIteratorAsArrowBatches(expiryTimestampIter.get, outputSchema,
+            arrowStreamWriterForTest) { data =>
+            InternalRow(PythonSQLUtils.toPyRow(data._1.asInstanceOf[Row]), 
data._2)
+          }
+        }
+
+      case _ =>
+        throw new IllegalArgumentException("Invalid timer request method call")
+    }
+  }
+
   private def handleImplicitGroupingKeyRequest(message: 
ImplicitGroupingKeyRequest): Unit = {
     message.getMethodCase match {
       case ImplicitGroupingKeyRequest.MethodCase.SETIMPLICITKEY =>
@@ -173,11 +247,13 @@ class TransformWithStateInPandasStateServer(
         ImplicitGroupingKeyTracker.setImplicitKey(keyRow)
         // Reset the list/map state iterators for a new grouping key.
         iterators = new mutable.HashMap[String, Iterator[Row]]()
+        listTimerIters = new mutable.HashMap[String, Iterator[Long]]()
         sendResponse(0)
       case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY =>
         ImplicitGroupingKeyTracker.removeImplicitKey()
         // Reset the list/map state iterators for a new grouping key.
         iterators = new mutable.HashMap[String, Iterator[Row]]()
+        listTimerIters = new mutable.HashMap[String, Iterator[Long]]()
         sendResponse(0)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
@@ -195,6 +271,12 @@ class TransformWithStateInPandasStateServer(
           case HandleState.INITIALIZED =>
             logInfo(log"set handle state to Initialized")
             
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.INITIALIZED)
+          case HandleState.DATA_PROCESSED =>
+            logInfo(log"set handle state to Data Processed")
+            
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
+          case HandleState.TIMER_PROCESSED =>
+            logInfo(log"set handle state to Timer Processed")
+            
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.TIMER_PROCESSED)
           case HandleState.CLOSED =>
             logInfo(log"set handle state to Closed")
             
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
@@ -226,6 +308,41 @@ class TransformWithStateInPandasStateServer(
         } else None
         initializeStateVariable(stateName, userKeySchema, 
StateVariableType.MapState, ttlDurationMs,
           valueSchema)
+      case StatefulProcessorCall.MethodCase.TIMERSTATECALL =>
+        message.getTimerStateCall.getMethodCase match {
+          case TimerStateCallCommand.MethodCase.REGISTER =>
+            val expiryTimestamp =
+              message.getTimerStateCall.getRegister.getExpiryTimestampMs
+            statefulProcessorHandle.registerTimer(expiryTimestamp)
+            sendResponse(0)
+          case TimerStateCallCommand.MethodCase.DELETE =>
+            val expiryTimestamp =
+              message.getTimerStateCall.getDelete.getExpiryTimestampMs
+            statefulProcessorHandle.deleteTimer(expiryTimestamp)
+            sendResponse(0)
+          case TimerStateCallCommand.MethodCase.LIST =>
+            val iteratorId = message.getTimerStateCall.getList.getIteratorId
+            var iteratorOption = listTimerIters.get(iteratorId)
+            if (iteratorOption.isEmpty) {
+              iteratorOption = Option(statefulProcessorHandle.listTimers())
+              listTimerIters.put(iteratorId, iteratorOption.get)
+            }
+            if (!iteratorOption.get.hasNext) {
+              sendResponse(2, s"List timer iterator doesn't contain any 
value.")
+              return
+            } else {
+              sendResponse(0)
+            }
+            val outputSchema = new StructType()
+              .add(StructField("timestamp", LongType))
+            sendIteratorAsArrowBatches(iteratorOption.get, outputSchema,
+              arrowStreamWriterForTest) { data =>
+              InternalRow(data)
+            }
+
+          case _ =>
+            throw new IllegalArgumentException("Invalid timer state method 
call")
+        }
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
@@ -463,24 +580,6 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
-  private def sendResponse(
-      status: Int,
-      errorMessage: String = null,
-      byteString: ByteString = null): Unit = {
-    val responseMessageBuilder = 
StateResponse.newBuilder().setStatusCode(status)
-    if (status != 0 && errorMessage != null) {
-      responseMessageBuilder.setErrorMessage(errorMessage)
-    }
-    if (byteString != null) {
-      responseMessageBuilder.setValue(byteString)
-    }
-    val responseMessage = responseMessageBuilder.build()
-    val responseMessageBytes = responseMessage.toByteArray
-    val byteLength = responseMessageBytes.length
-    outputStream.writeInt(byteLength)
-    outputStream.write(responseMessageBytes)
-  }
-
   private def initializeStateVariable(
       stateName: String,
       schemaString: String,
@@ -490,90 +589,128 @@ class TransformWithStateInPandasStateServer(
     val schema = StructType.fromString(schemaString)
     val expressionEncoder = ExpressionEncoder(schema).resolveAndBind()
     stateType match {
-        case StateVariableType.ValueState => if 
(!valueStates.contains(stateName)) {
-          val state = if (ttlDurationMs.isEmpty) {
-            statefulProcessorHandle.getValueState[Row](stateName, 
Encoders.row(schema))
-          } else {
-            statefulProcessorHandle.getValueState(
-              stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-          }
-          valueStates.put(stateName,
-            ValueStateInfo(state, schema, 
expressionEncoder.createDeserializer()))
-          sendResponse(0)
+      case StateVariableType.ValueState => if 
(!valueStates.contains(stateName)) {
+        val state = if (ttlDurationMs.isEmpty) {
+          statefulProcessorHandle.getValueState[Row](stateName, 
Encoders.row(schema))
         } else {
-          sendResponse(1, s"Value state $stateName already exists")
+          statefulProcessorHandle.getValueState(
+            stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
         }
-        case StateVariableType.ListState => if 
(!listStates.contains(stateName)) {
-          val state = if (ttlDurationMs.isEmpty) {
-            statefulProcessorHandle.getListState[Row](stateName, 
Encoders.row(schema))
-          } else {
-            statefulProcessorHandle.getListState(
-              stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-          }
-          listStates.put(stateName,
-            ListStateInfo(state, schema, 
expressionEncoder.createDeserializer(),
-              expressionEncoder.createSerializer()))
-          sendResponse(0)
+        valueStates.put(stateName,
+          ValueStateInfo(state, schema, 
expressionEncoder.createDeserializer()))
+        sendResponse(0)
+      } else {
+        sendResponse(1, s"Value state $stateName already exists")
+      }
+      case StateVariableType.ListState => if (!listStates.contains(stateName)) 
{
+        val state = if (ttlDurationMs.isEmpty) {
+          statefulProcessorHandle.getListState[Row](stateName, 
Encoders.row(schema))
         } else {
-          sendResponse(1, s"List state $stateName already exists")
+          statefulProcessorHandle.getListState(
+            stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
         }
-        case StateVariableType.MapState => if (!mapStates.contains(stateName)) 
{
-          val valueSchema = StructType.fromString(mapStateValueSchemaString)
-          val valueExpressionEncoder = 
ExpressionEncoder(valueSchema).resolveAndBind()
-          val state = if (ttlDurationMs.isEmpty) {
-            statefulProcessorHandle.getMapState[Row, Row](stateName,
-              Encoders.row(schema), Encoders.row(valueSchema))
-          } else {
-            statefulProcessorHandle.getMapState[Row, Row](stateName, 
Encoders.row(schema),
-              Encoders.row(valueSchema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
-          }
-          mapStates.put(stateName,
-            MapStateInfo(state, schema, valueSchema, 
expressionEncoder.createDeserializer(),
-              expressionEncoder.createSerializer(), 
valueExpressionEncoder.createDeserializer(),
-              valueExpressionEncoder.createSerializer()))
-          sendResponse(0)
+        listStates.put(stateName,
+          ListStateInfo(state, schema, expressionEncoder.createDeserializer(),
+            expressionEncoder.createSerializer()))
+        sendResponse(0)
+      } else {
+        sendResponse(1, s"List state $stateName already exists")
+      }
+      case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
+        val valueSchema = StructType.fromString(mapStateValueSchemaString)
+        val valueExpressionEncoder = 
ExpressionEncoder(valueSchema).resolveAndBind()
+        val state = if (ttlDurationMs.isEmpty) {
+          statefulProcessorHandle.getMapState[Row, Row](stateName,
+            Encoders.row(schema), Encoders.row(valueSchema))
         } else {
-          sendResponse(1, s"Map state $stateName already exists")
+          statefulProcessorHandle.getMapState[Row, Row](stateName, 
Encoders.row(schema),
+            Encoders.row(valueSchema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
         }
+        mapStates.put(stateName,
+          MapStateInfo(state, schema, valueSchema, 
expressionEncoder.createDeserializer(),
+            expressionEncoder.createSerializer(), 
valueExpressionEncoder.createDeserializer(),
+            valueExpressionEncoder.createSerializer()))
+        sendResponse(0)
+      } else {
+        sendResponse(1, s"Map state $stateName already exists")
+      }
     }
   }
 
-  private def sendIteratorAsArrowBatches[T](
-      iter: Iterator[T],
-      outputSchema: StructType,
-      arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T => 
InternalRow): Unit = {
-    outputStream.flush()
-    val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
-        errorOnDuplicatedFieldNames, largeVarTypes)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      s"stdout writer for transformWithStateInPandas state socket", 0, 
Long.MaxValue)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
-    val writer = new ArrowStreamWriter(root, null, outputStream)
-    val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
-      arrowStreamWriterForTest
-    } else {
-      new BaseStreamingArrowWriter(root, writer, 
arrowTransformWithStateInPandasMaxRecordsPerBatch)
+  /** Utils object for sending response to Python client. */
+  private object PythonResponseWriterUtils {
+    def sendResponse(
+        status: Int,
+        errorMessage: String = null,
+        byteString: ByteString = null): Unit = {
+      val responseMessageBuilder = 
StateResponse.newBuilder().setStatusCode(status)
+      if (status != 0 && errorMessage != null) {
+        responseMessageBuilder.setErrorMessage(errorMessage)
+      }
+      if (byteString != null) {
+        responseMessageBuilder.setValue(byteString)
+      }
+      val responseMessage = responseMessageBuilder.build()
+      val responseMessageBytes = responseMessage.toByteArray
+      val byteLength = responseMessageBytes.length
+      outputStream.writeInt(byteLength)
+      outputStream.write(responseMessageBytes)
     }
-    // Only write a single batch in each GET request. Stops writing row if 
rowCount reaches
-    // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to 
handle a case
-    // when there are multiple state variables, user tries to access a 
different state variable
-    // while the current state variable is not exhausted yet.
-    var rowCount = 0
-    while (iter.hasNext && rowCount < 
arrowTransformWithStateInPandasMaxRecordsPerBatch) {
-      val data = iter.next()
-      val internalRow = func(data)
-      arrowStreamWriter.writeRow(internalRow)
-      rowCount += 1
+
+    def sendResponseWithLongVal(
+        status: Int,
+        errorMessage: String = null,
+        longVal: Long): Unit = {
+      val responseMessageBuilder = 
StateResponseWithLongTypeVal.newBuilder().setStatusCode(status)
+      if (status != 0 && errorMessage != null) {
+        responseMessageBuilder.setErrorMessage(errorMessage)
+      }
+      responseMessageBuilder.setValue(longVal)
+      val responseMessage = responseMessageBuilder.build()
+      val responseMessageBytes = responseMessage.toByteArray
+      val byteLength = responseMessageBytes.length
+      outputStream.writeInt(byteLength)
+      outputStream.write(responseMessageBytes)
     }
-    arrowStreamWriter.finalizeCurrentArrowBatch()
-    Utils.tryWithSafeFinally {
-      // end writes footer to the output stream and doesn't clean any 
resources.
-      // It could throw exception if the output stream is closed, so it should 
be
-      // in the try block.
-      writer.end()
-    } {
-      root.close()
-      allocator.close()
+
+    def sendIteratorAsArrowBatches[T](
+        iter: Iterator[T],
+        outputSchema: StructType,
+        arrowStreamWriterForTest: BaseStreamingArrowWriter = null)(func: T => 
InternalRow): Unit = {
+      outputStream.flush()
+      val arrowSchema = ArrowUtils.toArrowSchema(outputSchema, timeZoneId,
+        errorOnDuplicatedFieldNames, largeVarTypes)
+      val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+        s"stdout writer for transformWithStateInPandas state socket", 0, 
Long.MaxValue)
+      val root = VectorSchemaRoot.create(arrowSchema, allocator)
+      val writer = new ArrowStreamWriter(root, null, outputStream)
+      val arrowStreamWriter = if (arrowStreamWriterForTest != null) {
+        arrowStreamWriterForTest
+      } else {
+        new BaseStreamingArrowWriter(root, writer,
+          arrowTransformWithStateInPandasMaxRecordsPerBatch)
+      }
+      // Only write a single batch in each GET request. Stops writing row if 
rowCount reaches
+      // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is 
to handle a case
+      // when there are multiple state variables, user tries to access a 
different state variable
+      // while the current state variable is not exhausted yet.
+      var rowCount = 0
+      while (iter.hasNext && rowCount < 
arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+        val data = iter.next()
+        val internalRow = func(data)
+        arrowStreamWriter.writeRow(internalRow)
+        rowCount += 1
+      }
+      arrowStreamWriter.finalizeCurrentArrowBatch()
+      Utils.tryWithSafeFinally {
+        // end writes footer to the output stream and doesn't clean any 
resources.
+        // It could throw exception if the output stream is closed, so it 
should be
+        // in the try block.
+        writer.end()
+      } {
+        root.close()
+        allocator.close()
+      }
     }
   }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index 2a728dc81d0b..aaaa53059126 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -32,7 +32,7 @@ import 
org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, 
StatefulProcessorHandleState}
 import org.apache.spark.sql.execution.streaming.state.StateMessage
-import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, 
AppendValue, Clear, ContainsKey, Exists, Get, GetValue, HandleState, Keys, 
ListStateCall, ListStateGet, ListStatePut, MapStateCall, RemoveKey, 
SetHandleState, StateCallCommand, StatefulProcessorCall, UpdateValue, Values, 
ValueStateCall, ValueStateUpdate}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, 
AppendValue, Clear, ContainsKey, DeleteTimer, Exists, ExpiryTimerRequest, Get, 
GetProcessingTime, GetValue, GetWatermark, HandleState, Keys, ListStateCall, 
ListStateGet, ListStatePut, ListTimers, MapStateCall, RegisterTimer, RemoveKey, 
SetHandleState, StateCallCommand, StatefulProcessorCall, TimerRequest, 
TimerStateCallCommand, TimerValueRequest, UpdateValue, Values, ValueStateCall, 
ValueStateUpdate}
 import org.apache.spark.sql.streaming.{ListState, MapState, TTLConfig, 
ValueState}
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
@@ -59,9 +59,13 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
   var stateSerializer: ExpressionEncoder.Serializer[Row] = _
   var transformWithStateInPandasDeserializer: 
TransformWithStateInPandasDeserializer = _
   var arrowStreamWriter: BaseStreamingArrowWriter = _
+  var batchTimestampMs: Option[Long] = _
+  var eventTimeWatermarkForEviction: Option[Long] = _
   var valueStateMap: mutable.HashMap[String, ValueStateInfo] = 
mutable.HashMap()
   var listStateMap: mutable.HashMap[String, ListStateInfo] = mutable.HashMap()
   var mapStateMap: mutable.HashMap[String, MapStateInfo] = mutable.HashMap()
+  var expiryTimerIter: Iterator[(Any, Long)] = _
+  var listTimerMap: mutable.HashMap[String, Iterator[Long]] = mutable.HashMap()
 
   override def beforeEach(): Unit = {
     statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl])
@@ -83,15 +87,20 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     // reset the iterator map to empty so be careful to call it if you want to 
access the iterator
     // map later.
     val testRow = getIntegerRow(1)
+    expiryTimerIter = Iterator.single(testRow, 1L /* a random long type value 
*/)
     val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> 
Iterator(testRow))
     val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, 
Row)]](iteratorId ->
       Iterator((testRow, testRow)))
+    listTimerMap = mutable.HashMap[String, Iterator[Long]](iteratorId -> 
Iterator(1L))
     transformWithStateInPandasDeserializer = 
mock(classOf[TransformWithStateInPandasDeserializer])
     arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter])
+    batchTimestampMs = mock(classOf[Option[Long]])
+    eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
+      batchTimestampMs, eventTimeWatermarkForEviction,
       outputStream, valueStateMap, transformWithStateInPandasDeserializer, 
arrowStreamWriter,
-      listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap)
+      listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, 
expiryTimerIter, listTimerMap)
     when(transformWithStateInPandasDeserializer.readArrowBatches(any))
       .thenReturn(Seq(getIntegerRow(1)))
   }
@@ -251,8 +260,9 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
       Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), 
getIntegerRow(4)))
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap,
-      transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, 
iteratorMap)
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, 
outputStream,
+      valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+      listStateMap, iteratorMap)
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
@@ -274,8 +284,9 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap,
-      transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, 
iteratorMap)
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, 
outputStream,
+      valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+      listStateMap, iteratorMap)
     when(listState.get()).thenReturn(Iterator(getIntegerRow(1), 
getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleListStateRequest(message)
     verify(listState).get()
@@ -362,8 +373,9 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
         (getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), 
getIntegerRow(4))))
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap, 
transformWithStateInPandasDeserializer,
-      arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, 
outputStream,
+      valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
+      listStateMap, null, mapStateMap, keyValueIteratorMap)
     // First call should send 2 records.
     stateServer.handleMapStateRequest(message)
     verify(mapState, times(0)).iterator()
@@ -385,7 +397,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = 
mutable.HashMap()
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap, 
transformWithStateInPandasDeserializer,
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
+      outputStream, valueStateMap, transformWithStateInPandasDeserializer,
       arrowStreamWriter, listStateMap, null, mapStateMap, keyValueIteratorMap)
     when(mapState.iterator()).thenReturn(Iterator((getIntegerRow(1), 
getIntegerRow(1)),
       (getIntegerRow(2), getIntegerRow(2)), (getIntegerRow(3), 
getIntegerRow(3))))
@@ -413,7 +426,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap, 
transformWithStateInPandasDeserializer,
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
+      outputStream, valueStateMap, transformWithStateInPandasDeserializer,
       arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
     when(mapState.keys()).thenReturn(Iterator(getIntegerRow(1), 
getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleMapStateRequest(message)
@@ -440,7 +454,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
     stateServer = new TransformWithStateInPandasStateServer(serverSocket,
       statefulProcessorHandle, groupingKeySchema, "", false, false,
-      maxRecordsPerBatch, outputStream, valueStateMap, 
transformWithStateInPandasDeserializer,
+      maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, 
outputStream,
+      valueStateMap, transformWithStateInPandasDeserializer,
       arrowStreamWriter, listStateMap, iteratorMap, mapStateMap)
     when(mapState.values()).thenReturn(Iterator(getIntegerRow(1), 
getIntegerRow(2),
       getIntegerRow(3)))
@@ -460,6 +475,92 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     verify(mapState).removeKey(any[Row])
   }
 
+  test("timer value get processing time") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetProcessingTimer(
+        GetProcessingTime.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(batchTimestampMs).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("timer value get watermark") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetWatermark(
+        GetWatermark.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(eventTimeWatermarkForEviction).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("get expiry timers") {
+    val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+      ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+        10L
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(arrowStreamWriter).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+  }
+
+  test("stateful processor register timer") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        
.setRegister(RegisterTimer.newBuilder().setExpiryTimestampMs(10L).build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle).registerTimer(any[Long])
+    verify(outputStream).writeInt(0)
+  }
+
+  test("stateful processor delete timer") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        .setDelete(DeleteTimer.newBuilder().setExpiryTimestampMs(10L).build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle).deleteTimer(any[Long])
+    verify(outputStream).writeInt(0)
+  }
+
+  test("stateful processor list timer - iterator in map") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        .setList(ListTimers.newBuilder().setIteratorId(iteratorId).build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle, times(0)).listTimers()
+    verify(arrowStreamWriter).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+  }
+
+  test("stateful processor list timer - iterator not in map") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        .setList(ListTimers.newBuilder().setIteratorId("non-exist").build())
+        .build()
+    ).build()
+    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+      statefulProcessorHandle, groupingKeySchema, "", false, false,
+      2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
+      valueStateMap, transformWithStateInPandasDeserializer,
+      arrowStreamWriter, listStateMap, null, mapStateMap, null,
+      null, listTimerMap)
+    when(statefulProcessorHandle.listTimers()).thenReturn(Iterator(1))
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle, times(1)).listTimers()
+    verify(arrowStreamWriter).writeRow(any)
+    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+  }
+
   private def getIntegerRow(value: Int): Row = {
     new GenericRowWithSchema(Array(value), stateSchema)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to