HeartSaVioR commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1800367171


##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            current_processing_time_in_ms: int = -1,
+            current_watermark_in_ms: int = -1) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    """
+    Get processing time for current batch, return timestamp in millisecond.
+    """
+    def get_current_processing_time_in_ms(self) -> int:
+        return self._current_processing_time_in_ms
+
+    """
+    Get watermark for current batch, return timestamp in millisecond.
+    """
+    def get_current_watermark_in_ms(self) -> int:
+        return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access expired timer
+    info. When is_valid is false, the expiry timestamp is invalid.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,

Review Comment:
   nit: ditto



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,

Review Comment:
   nit: is this indentation correct? It looks a bit odd compared to the others, where params start at the same indentation as the first `_`.



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            current_processing_time_in_ms: int = -1,
+            current_watermark_in_ms: int = -1) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    """
+    Get processing time for current batch, return timestamp in millisecond.
+    """
+    def get_current_processing_time_in_ms(self) -> int:
+        return self._current_processing_time_in_ms
+
+    """
+    Get watermark for current batch, return timestamp in millisecond.
+    """
+    def get_current_watermark_in_ms(self) -> int:
+        return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access expired timer
+    info. When is_valid is false, the expiry timestamp is invalid.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            is_valid: bool,
+            expiry_time_in_ms: int = -1) -> None:
+        self._is_valid = is_valid
+        self._expiry_time_in_ms = expiry_time_in_ms
+
+    """

Review Comment:
   nit: ditto



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -180,6 +234,24 @@ def getListState(
         self.stateful_processor_api_client.get_list_state(state_name, schema, 
ttl_duration_ms)
         return ListState(ListStateClient(self.stateful_processor_api_client), 
state_name, schema)
 
+    def registerTimer(self, expiry_time_stamp_ms: int) -> None:
+        """
+        Register a timer for a given expiry timestamp in milliseconds for the 
grouping key.
+        """
+        self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms)
+
+    def deleteTimer(self, expiry_time_stamp_ms: int) -> None:
+        """
+        Delete a timer for a given expiry timestamp in milliseconds for the 
grouping key.
+        """
+        self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms)
+
+    def listTimers(self) -> Iterator[list[int]]:

Review Comment:
   nit: should this probably be `Iterator[int]`? Am I missing something?



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
         finally:
             input_dir.cleanup()
 
+    def _test_transform_with_state_in_pandas_proc_timer(
+            self, stateful_processor, check_results):

Review Comment:
   nit: indentation?



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,41 @@ def transformWithStateUDF(
                 )
 
             statefulProcessorApiClient.set_implicit_key(key)
-            result = statefulProcessor.handleInputRows(key, inputRows)
+
+            if timeMode != "none":
+                batch_timestamp = 
statefulProcessorApiClient.get_batch_timestamp()
+                watermark_timestamp = 
statefulProcessorApiClient.get_watermark_timestamp()
+            else:
+                batch_timestamp = -1
+                watermark_timestamp = -1
+            # process with invalid expiry timer info and emit data rows
+            data_iter = statefulProcessor.handleInputRows(
+                key, inputRows, TimerValues(batch_timestamp, 
watermark_timestamp), ExpiredTimerInfo(False))
+            statefulProcessorApiClient.set_handle_state(
+                StatefulProcessorHandleState.DATA_PROCESSED
+            )
+
+            if timeMode == "processingtime":
+                expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp)
+            elif timeMode == "eventtime":
+                expiry_list_iter = 
statefulProcessorApiClient.get_expiry_timers_iterator(watermark_timestamp)
+            else:
+                expiry_list_iter = []
+
+            result_iter_list = [data_iter]
+            # process with valid expiry time info and with empty input rows,
+            # only timer related rows will be emitted
+            for expiry_list in expiry_list_iter:
+                for key_obj, expiry_timestamp in expiry_list:
+                    if (timeMode == "processingtime" and expiry_timestamp < 
batch_timestamp) or\

Review Comment:
   I may have lost track of how TWS (for PySpark) works, but given that we get the iterator of expiry timers based on the timestamp, isn't this if statement already covered by the API call? In other words, shouldn't the API be responsible for this?
   
   Please let me know if there is a specific reason - no need to change the code if there is one. I just wanted to understand this and possibly refresh my memory.



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            current_processing_time_in_ms: int = -1,
+            current_watermark_in_ms: int = -1) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    """
+    Get processing time for current batch, return timestamp in millisecond.
+    """
+    def get_current_processing_time_in_ms(self) -> int:
+        return self._current_processing_time_in_ms
+
+    """

Review Comment:
   nit: ditto



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            current_processing_time_in_ms: int = -1,
+            current_watermark_in_ms: int = -1) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    """

Review Comment:
   nit: in Python, a method's docstring is placed "after" the method definition line (as the first statement of the body), not before it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -60,8 +62,13 @@ class TransformWithStateInPandasStateServer(
     deserializerForTest: TransformWithStateInPandasDeserializer = null,
     arrowStreamWriterForTest: BaseStreamingArrowWriter = null,
     listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
-    listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null)
+    listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
+    batchTimestampMs: Option[Long] = None,

Review Comment:
   nit: let's move these new params above the params for testing (i.e., above the `outputStreamForTest` line), so that the two groups stay separate.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -137,11 +148,57 @@ class TransformWithStateInPandasStateServer(
         handleStatefulProcessorCall(message.getStatefulProcessorCall)
       case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
         handleStateVariableRequest(message.getStateVariableRequest)
+      case StateRequest.MethodCase.TIMERREQUEST =>
+        handleTimerRequest(message.getTimerRequest)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
+  private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+    message.getMethodCase match {
+      case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+        val timerRequest = message.getTimerValueRequest()
+        timerRequest.getMethodCase match {
+          case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+            val procTimestamp: Long =
+              if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+            sendResponseWithLongVal(0, null, procTimestamp)
+          case TimerValueRequest.MethodCase.GETWATERMARK =>
+            val eventTimestamp: Long =
+              if (eventTimeWatermarkForEviction.isDefined) 
eventTimeWatermarkForEviction.get
+              else -1L
+            sendResponseWithLongVal(0, null, eventTimestamp)
+          case _ =>
+            throw new IllegalArgumentException("Invalid timer value method 
call")
+        }
+
+      case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+        val expiryRequest = message.getExpiryTimerRequest()
+        val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+        if (!expiryTimestampIter.isDefined) {
+          expiryTimestampIter =
+            Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp))
+        }
+        // expiryTimestampIter could be None in the TWSPandasServerSuite
+        if (!expiryTimestampIter.isDefined || 
!expiryTimestampIter.get.hasNext) {
+          // iterator is exhausted, signal the end of iterator on python client
+          sendResponse(1)

Review Comment:
   We probably need to reset expiryTimestampIter to None here, so that this can handle subsequent requests?



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
         self._value_state_client.clear(self._state_name)
 
 
+class TimerValues:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access processing
+    time or event time for current batch.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            current_processing_time_in_ms: int = -1,
+            current_watermark_in_ms: int = -1) -> None:
+        self._current_processing_time_in_ms = current_processing_time_in_ms
+        self._current_watermark_in_ms = current_watermark_in_ms
+
+    """
+    Get processing time for current batch, return timestamp in millisecond.
+    """
+    def get_current_processing_time_in_ms(self) -> int:
+        return self._current_processing_time_in_ms
+
+    """
+    Get watermark for current batch, return timestamp in millisecond.
+    """
+    def get_current_watermark_in_ms(self) -> int:
+        return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+    """
+    Class used for arbitrary stateful operations with transformWithState to 
access expired timer
+    info. When is_valid is false, the expiry timestamp is invalid.
+
+    .. versionadded:: 4.0.0
+    """
+    def __init__(
+            self,
+            is_valid: bool,
+            expiry_time_in_ms: int = -1) -> None:
+        self._is_valid = is_valid
+        self._expiry_time_in_ms = expiry_time_in_ms
+
+    """
+    Whether the expiry info is valid.
+    """
+    def is_valid(self) -> bool:
+        return self._is_valid
+
+    """

Review Comment:
   nit: ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -137,11 +148,57 @@ class TransformWithStateInPandasStateServer(
         handleStatefulProcessorCall(message.getStatefulProcessorCall)
       case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
         handleStateVariableRequest(message.getStateVariableRequest)
+      case StateRequest.MethodCase.TIMERREQUEST =>
+        handleTimerRequest(message.getTimerRequest)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
   }
 
+  private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+    message.getMethodCase match {
+      case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+        val timerRequest = message.getTimerValueRequest()
+        timerRequest.getMethodCase match {
+          case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+            val procTimestamp: Long =
+              if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+            sendResponseWithLongVal(0, null, procTimestamp)
+          case TimerValueRequest.MethodCase.GETWATERMARK =>
+            val eventTimestamp: Long =
+              if (eventTimeWatermarkForEviction.isDefined) 
eventTimeWatermarkForEviction.get
+              else -1L
+            sendResponseWithLongVal(0, null, eventTimestamp)
+          case _ =>
+            throw new IllegalArgumentException("Invalid timer value method 
call")
+        }
+
+      case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+        val expiryRequest = message.getExpiryTimerRequest()
+        val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+        if (!expiryTimestampIter.isDefined) {

Review Comment:
   Do we ever have a case where requests for the expiry timestamp iterator could be "interleaved"? For example, requesting expiry timers with a different timestamp before the iterator from the previous request has been consumed.
   
   Also, we seem to have a re-entrance issue: we don't distinguish between two different requests for the same timestamp if they are interleaved.
   
   Is this a possibility we never expose to users (so it wouldn't be an issue), or should we fix it (e.g. by assigning a UUID to each iterator)?
   
   cc. @bogao007 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -190,11 +253,43 @@ class TransformWithStateInPandasStateServer(
         val stateName = message.getGetListState.getStateName
         val schema = message.getGetListState.getSchema
         val ttlDurationMs = if (message.getGetListState.hasTtl) {
-          Some(message.getGetListState.getTtl.getDurationMs)
-        } else {
-            None
-        }
+            Some(message.getGetListState.getTtl.getDurationMs)

Review Comment:
   nit: it looks like the previous indentation was correct?



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -75,6 +75,11 @@ def _prepare_test_resource2(self, input_path):
             input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323, 
246, 6]
         )
 
+    def _prepare_test_resource3(self, input_path):
+        with open(input_path + "/text-test3.txt", "w") as fw:

Review Comment:
   nit: any specific reason to implement this separately here rather than calling _prepare_input_data?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -279,4 +283,68 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
     verify(listState).appendList(any)
   }
+
+  test("timer value get processing time") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetProcessingTimer(
+        GetProcessingTime.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(batchTimestampMs).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("timer value get watermark") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetWatermark(
+        GetWatermark.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(eventTimeWatermarkForEviction).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("get expiry timers") {
+    val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+      ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+        10L
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(statefulProcessorHandle).getExpiredTimers(any[Long])

Review Comment:
   For other functionality, we pass the spy instance as a param in order to test this. Do we test this end-to-end instead? I'm OK with it. Just wanted to check.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -190,11 +253,43 @@ class TransformWithStateInPandasStateServer(
         val stateName = message.getGetListState.getStateName
         val schema = message.getGetListState.getSchema
         val ttlDurationMs = if (message.getGetListState.hasTtl) {
-          Some(message.getGetListState.getTtl.getDurationMs)
-        } else {
-            None
-        }
+            Some(message.getGetListState.getTtl.getDurationMs)
+          } else {
+              None
+          }
         initializeStateVariable(stateName, schema, 
StateVariableType.ListState, ttlDurationMs)
+      case StatefulProcessorCall.MethodCase.TIMERSTATECALL =>
+        message.getTimerStateCall.getMethodCase match {
+          case TimerStateCallCommand.MethodCase.REGISTER =>
+            val expiryTimestamp =
+              message.getTimerStateCall.getRegister.getExpiryTimestampMs
+            statefulProcessorHandle.registerTimer(expiryTimestamp)
+            sendResponse(0)
+          case TimerStateCallCommand.MethodCase.DELETE =>
+            val expiryTimestamp =
+              message.getTimerStateCall.getDelete.getExpiryTimestampMs
+            statefulProcessorHandle.deleteTimer(expiryTimestamp)
+            sendResponse(0)
+          case TimerStateCallCommand.MethodCase.LIST =>
+            if (!listTimerIter.isDefined) {

Review Comment:
   Same concern, though I guess if this is a real issue, it is widespread across TWS in Pandas.



##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -152,7 +155,126 @@ def get_list_state(
         status = response_message[0]
         if status != 0:
             # TODO(SPARK-49233): Classify user facing errors.
-            raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
+            raise PySparkRuntimeError(f"Error initializing list state: " 
f"{response_message[1]}")
+
+    def register_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        register_call = 
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = 
stateMessage.TimerStateCallCommand(register=register_call)
+        call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error register timer: " 
f"{response_message[1]}")
+
+    def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        delete_call = 
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = 
stateMessage.TimerStateCallCommand(delete=delete_call)
+        call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error deleting timer: " 
f"{response_message[1]}")
+
+    def list_timers(self) -> Iterator[list[int]]:

Review Comment:
   Ah OK, ignore the previous comment about type hint. This is specific to 
PySpark impl of TWS - I see what you are doing.



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
         finally:
             input_dir.cleanup()
 
+    def _test_transform_with_state_in_pandas_proc_timer(
+            self, stateful_processor, check_results):
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource3(input_path)
+        self._prepare_test_resource1(input_path)
+        self._prepare_test_resource2(input_path)
+
+        df = self._build_test_df(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("countAsString", StringType(), True),
+                StructField("timeValues", StringType(), True)
+            ]
+        )
+
+        query_name = "processing_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="processingtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_proc_timer(self):
+
+        # helper function to check expired timestamp is smaller than current 
processing time
+        def check_timestamp(batch_df):
+            expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"expiredTimestamp")
+            count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"countStateTimestamp")
+            joined_df = expired_df.join(count_df, on="id")
+            for row in joined_df.collect():
+                assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="1"),
+                    Row(id="1", countAsString="1"),
+                }
+            elif batch_id == 1:
+                # for key 0, the accumulated count is emitted before the count 
state is cleared
+                # during the timer process
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="3"),
+                }
+                self.first_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == 
-1).first()['timeValues']
+                check_timestamp(batch_df)
+
+            else:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="5")
+                }
+                # The expired timestamp in current batch is larger than expiry 
timestamp in batch 1
+                # because this is a new timer registered in batch1 and
+                # different from the one registered in batch 0
+                current_batch_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == 
-1).first()['timeValues']
+                assert(current_batch_expired_timestamp > 
self.first_expired_timestamp)
+
+        
self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(),
 check_results)
+
+    def _test_transform_with_state_in_pandas_event_time(self, 
stateful_processor, check_results):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+
+        def prepare_batch1(input_path):
+            with open(input_path + "/text-test3.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        def prepare_batch2(input_path):
+            with open(input_path + "/text-test1.txt", "w") as fw:
+                fw.write("a, 4\n")
+
+        def prepare_batch3(input_path):
+            with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 11\n")
+                fw.write("a, 13\n")
+                fw.write("a, 15\n")
+
+        prepare_batch1(input_path)
+        prepare_batch2(input_path)
+        prepare_batch3(input_path)
+
+        df = self._build_test_df(input_path)
+        df = df.select("id",
+                       
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+            .withWatermark("eventTime", "10 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("timestamp", StringType(), True)
+            ]
+        )
+
+        query_name = "event_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="eventtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_event_time(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="20")
+                }
+            elif batch_id == 1:
+                # event time = 4 will be discarded because the watermark = 15 
- 10 = 5

Review Comment:
   This is not true - the watermark for eviction is 5, but the watermark for late records is 0, hence ("a", 4) is not dropped. This is exactly why you still see an event for "a"; otherwise you shouldn't have ("a", 20).



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -279,4 +283,68 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
     verify(listState).appendList(any)
   }
+
+  test("timer value get processing time") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetProcessingTimer(
+        GetProcessingTime.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(batchTimestampMs).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("timer value get watermark") {
+    val message = TimerRequest.newBuilder().setTimerValueRequest(
+      TimerValueRequest.newBuilder().setGetWatermark(
+        GetWatermark.newBuilder().build()
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(eventTimeWatermarkForEviction).isDefined
+    verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+  }
+
+  test("get expiry timers") {
+    val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+      ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+        10L
+      ).build()
+    ).build()
+    stateServer.handleTimerRequest(message)
+    verify(statefulProcessorHandle).getExpiredTimers(any[Long])
+  }
+
+  test("stateful processor register timer") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        
.setRegister(RegisterTimer.newBuilder().setExpiryTimestampMs(10L).build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle).registerTimer(any[Long])
+    verify(outputStream).writeInt(0)
+  }
+
+  test("stateful processor delete timer") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        .setDelete(DeleteTimer.newBuilder().setExpiryTimestampMs(10L).build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle).deleteTimer(any[Long])
+    verify(outputStream).writeInt(0)
+  }
+
+  test("stateful processor list timer") {
+    val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+      TimerStateCallCommand.newBuilder()
+        .setList(ListTimers.newBuilder().build())
+        .build()
+    ).build()
+    stateServer.handleStatefulProcessorCall(message)
+    verify(statefulProcessorHandle).listTimers()

Review Comment:
   ditto: For other functionality, we pass the spy instance as a param in order to test this. Do we test this end-to-end instead? I'm OK with it. Just wanted to check.



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
         finally:
             input_dir.cleanup()
 
+    def _test_transform_with_state_in_pandas_proc_timer(
+            self, stateful_processor, check_results):
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource3(input_path)
+        self._prepare_test_resource1(input_path)
+        self._prepare_test_resource2(input_path)
+
+        df = self._build_test_df(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("countAsString", StringType(), True),
+                StructField("timeValues", StringType(), True)
+            ]
+        )
+
+        query_name = "processing_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="processingtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_proc_timer(self):
+
+        # helper function to check expired timestamp is smaller than current 
processing time
+        def check_timestamp(batch_df):
+            expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"expiredTimestamp")
+            count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"countStateTimestamp")
+            joined_df = expired_df.join(count_df, on="id")
+            for row in joined_df.collect():
+                assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="1"),
+                    Row(id="1", countAsString="1"),
+                }
+            elif batch_id == 1:
+                # for key 0, the accumulated count is emitted before the count 
state is cleared
+                # during the timer process
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="3"),
+                }
+                self.first_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == 
-1).first()['timeValues']
+                check_timestamp(batch_df)
+
+            else:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="5")
+                }
+                # The expired timestamp in current batch is larger than expiry 
timestamp in batch 1
+                # because this is a new timer registered in batch1 and
+                # different from the one registered in batch 0
+                current_batch_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == 
-1).first()['timeValues']
+                assert(current_batch_expired_timestamp > 
self.first_expired_timestamp)
+
+        
self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(),
 check_results)
+
+    def _test_transform_with_state_in_pandas_event_time(self, 
stateful_processor, check_results):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+
+        def prepare_batch1(input_path):
+            with open(input_path + "/text-test3.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        def prepare_batch2(input_path):
+            with open(input_path + "/text-test1.txt", "w") as fw:
+                fw.write("a, 4\n")
+
+        def prepare_batch3(input_path):
+            with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 11\n")
+                fw.write("a, 13\n")
+                fw.write("a, 15\n")
+
+        prepare_batch1(input_path)
+        prepare_batch2(input_path)
+        prepare_batch3(input_path)
+
+        df = self._build_test_df(input_path)
+        df = df.select("id",
+                       
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+            .withWatermark("eventTime", "10 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("timestamp", StringType(), True)
+            ]
+        )
+
+        query_name = "event_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="eventtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_event_time(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="20")
+                }
+            elif batch_id == 1:
+                # event time = 4 will be discarded because the watermark = 15 
- 10 = 5

Review Comment:
   You might wonder why this behaves differently from the Scala tests: there, AddData() followed by CheckNewAnswer() also triggers a no-data batch, so two batches are executed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to