HeartSaVioR commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1800367171
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ current_processing_time_in_ms: int = -1,
+ current_watermark_in_ms: int = -1) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ """
+ Get processing time for current batch, return timestamp in millisecond.
+ """
+ def get_current_processing_time_in_ms(self) -> int:
+ return self._current_processing_time_in_ms
+
+ """
+ Get watermark for current batch, return timestamp in millisecond.
+ """
+ def get_current_watermark_in_ms(self) -> int:
+ return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access expired timer
+ info. When is_valid is false, the expiry timestamp is invalid.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
Review Comment:
nit: ditto
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
Review Comment:
nit: is this indentation correct? It looks a bit odd compared to the others - the params start at the same indentation as the first `_`.
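For illustration only (not from this PR) - one way to make the distinction clear is to keep the parameters on a single continuation line and dedent the closing parenthesis, so the parameter list no longer lines up with the method body:

```python
def __init__(
    self, current_processing_time_in_ms: int = -1, current_watermark_in_ms: int = -1
) -> None:
    self._current_processing_time_in_ms = current_processing_time_in_ms
    self._current_watermark_in_ms = current_watermark_in_ms
```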
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ current_processing_time_in_ms: int = -1,
+ current_watermark_in_ms: int = -1) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ """
+ Get processing time for current batch, return timestamp in millisecond.
+ """
+ def get_current_processing_time_in_ms(self) -> int:
+ return self._current_processing_time_in_ms
+
+ """
+ Get watermark for current batch, return timestamp in millisecond.
+ """
+ def get_current_watermark_in_ms(self) -> int:
+ return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access expired timer
+ info. When is_valid is false, the expiry timestamp is invalid.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ is_valid: bool,
+ expiry_time_in_ms: int = -1) -> None:
+ self._is_valid = is_valid
+ self._expiry_time_in_ms = expiry_time_in_ms
+
+ """
Review Comment:
nit: ditto
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -180,6 +234,24 @@ def getListState(
self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
+ def registerTimer(self, expiry_time_stamp_ms: int) -> None:
+ """
+ Register a timer for a given expiry timestamp in milliseconds for the grouping key.
+ """
+ self.stateful_processor_api_client.register_timer(expiry_time_stamp_ms)
+
+ def deleteTimer(self, expiry_time_stamp_ms: int) -> None:
+ """
+ Delete a timer for a given expiry timestamp in milliseconds for the grouping key.
+ """
+ self.stateful_processor_api_client.delete_timer(expiry_time_stamp_ms)
+
+ def listTimers(self) -> Iterator[list[int]]:
Review Comment:
nit: probably it should be `Iterator[int]`? Am I missing something?
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
+ def _test_transform_with_state_in_pandas_proc_timer(
+ self, stateful_processor, check_results):
Review Comment:
nit: indentation?
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,41 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ if timeMode != "none":
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
+ else:
+ batch_timestamp = -1
+ watermark_timestamp = -1
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp, watermark_timestamp), ExpiredTimerInfo(False))
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+
+ if timeMode == "processingtime":
+ expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp)
+ elif timeMode == "eventtime":
+ expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(watermark_timestamp)
+ else:
+ expiry_list_iter = []
+
+ result_iter_list = [data_iter]
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for expiry_list in expiry_list_iter:
+ for key_obj, expiry_timestamp in expiry_list:
+ if (timeMode == "processingtime" and expiry_timestamp < batch_timestamp) or\
Review Comment:
I might have lost track of how TWS (for PySpark) works, but given that we get the iterator of expiry timers based on the timestamp, isn't this if statement already covered by the API call? In other words, shouldn't the API be the one to cover this? Please let me know if there is a specific reason - no need to change the code if there is one. I just wanted to understand and possibly refresh my memory.
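Just to illustrate what I have in mind (a rough sketch, assuming `get_expiry_timers_iterator(threshold)` only yields timers whose expiry timestamp is already past the passed threshold; `threshold` is a name I made up for this sketch, not from the PR):

```python
# hypothetical simplification: pick the threshold once, let the API call do the filtering
# (the timeMode == "none" branch would still fall back to an empty iterator)
threshold = batch_timestamp if timeMode == "processingtime" else watermark_timestamp
expiry_list_iter = statefulProcessorApiClient.get_expiry_timers_iterator(threshold)

for expiry_list in expiry_list_iter:
    for key_obj, expiry_timestamp in expiry_list:
        # no per-element `expiry_timestamp < threshold` comparison needed here
        ...
```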
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ current_processing_time_in_ms: int = -1,
+ current_watermark_in_ms: int = -1) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ """
+ Get processing time for current batch, return timestamp in millisecond.
+ """
+ def get_current_processing_time_in_ms(self) -> int:
+ return self._current_processing_time_in_ms
+
+ """
Review Comment:
nit: ditto
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ current_processing_time_in_ms: int = -1,
+ current_watermark_in_ms: int = -1) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ """
Review Comment:
nit: method doc in Python is placed "after" the definition of the method.
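e.g. taking one of the methods above, just to show the placement (wording is mine):

```python
def get_current_processing_time_in_ms(self) -> int:
    """
    Get the processing time for the current batch, as a timestamp in milliseconds.
    """
    return self._current_processing_time_in_ms
```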
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -60,8 +62,13 @@ class TransformWithStateInPandasStateServer(
deserializerForTest: TransformWithStateInPandasDeserializer = null,
arrowStreamWriterForTest: BaseStreamingArrowWriter = null,
listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null,
- listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null)
+ listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null,
+ batchTimestampMs: Option[Long] = None,
Review Comment:
nit: let's move these new params above the params for testing (i.e. above the outputStreamForTest line). We should keep the two groups separate.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -137,11 +148,57 @@ class TransformWithStateInPandasStateServer(
handleStatefulProcessorCall(message.getStatefulProcessorCall)
case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
handleStateVariableRequest(message.getStateVariableRequest)
+ case StateRequest.MethodCase.TIMERREQUEST =>
+ handleTimerRequest(message.getTimerRequest)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
}
+ private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+ message.getMethodCase match {
+ case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+ val timerRequest = message.getTimerValueRequest()
+ timerRequest.getMethodCase match {
+ case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+ val procTimestamp: Long =
+ if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+ sendResponseWithLongVal(0, null, procTimestamp)
+ case TimerValueRequest.MethodCase.GETWATERMARK =>
+ val eventTimestamp: Long =
+ if (eventTimeWatermarkForEviction.isDefined) eventTimeWatermarkForEviction.get
+ else -1L
+ sendResponseWithLongVal(0, null, eventTimestamp)
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer value method call")
+ }
+
+ case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+ val expiryRequest = message.getExpiryTimerRequest()
+ val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+ if (!expiryTimestampIter.isDefined) {
+ expiryTimestampIter =
+ Option(statefulProcessorHandle.getExpiredTimers(expiryTimestamp))
+ }
+ // expiryTimestampIter could be None in the TWSPandasServerSuite
+ if (!expiryTimestampIter.isDefined || !expiryTimestampIter.get.hasNext) {
+ // iterator is exhausted, signal the end of iterator on python client
+ sendResponse(1)
Review Comment:
We probably need to reset expiryTimestampIter to None, so that this can handle another request?
##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -69,6 +69,60 @@ def clear(self) -> None:
self._value_state_client.clear(self._state_name)
+class TimerValues:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access processing
+ time or event time for current batch.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ current_processing_time_in_ms: int = -1,
+ current_watermark_in_ms: int = -1) -> None:
+ self._current_processing_time_in_ms = current_processing_time_in_ms
+ self._current_watermark_in_ms = current_watermark_in_ms
+
+ """
+ Get processing time for current batch, return timestamp in millisecond.
+ """
+ def get_current_processing_time_in_ms(self) -> int:
+ return self._current_processing_time_in_ms
+
+ """
+ Get watermark for current batch, return timestamp in millisecond.
+ """
+ def get_current_watermark_in_ms(self) -> int:
+ return self._current_watermark_in_ms
+
+
+class ExpiredTimerInfo:
+ """
+ Class used for arbitrary stateful operations with transformWithState to access expired timer
+ info. When is_valid is false, the expiry timestamp is invalid.
+
+ .. versionadded:: 4.0.0
+ """
+ def __init__(
+ self,
+ is_valid: bool,
+ expiry_time_in_ms: int = -1) -> None:
+ self._is_valid = is_valid
+ self._expiry_time_in_ms = expiry_time_in_ms
+
+ """
+ Whether the expiry info is valid.
+ """
+ def is_valid(self) -> bool:
+ return self._is_valid
+
+ """
Review Comment:
nit: ditto
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -137,11 +148,57 @@ class TransformWithStateInPandasStateServer(
handleStatefulProcessorCall(message.getStatefulProcessorCall)
case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
handleStateVariableRequest(message.getStateVariableRequest)
+ case StateRequest.MethodCase.TIMERREQUEST =>
+ handleTimerRequest(message.getTimerRequest)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
}
+ private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+ message.getMethodCase match {
+ case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+ val timerRequest = message.getTimerValueRequest()
+ timerRequest.getMethodCase match {
+ case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+ val procTimestamp: Long =
+ if (batchTimestampMs.isDefined) batchTimestampMs.get else -1L
+ sendResponseWithLongVal(0, null, procTimestamp)
+ case TimerValueRequest.MethodCase.GETWATERMARK =>
+ val eventTimestamp: Long =
+ if (eventTimeWatermarkForEviction.isDefined) eventTimeWatermarkForEviction.get
+ else -1L
+ sendResponseWithLongVal(0, null, eventTimestamp)
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer value method call")
+ }
+
+ case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+ val expiryRequest = message.getExpiryTimerRequest()
+ val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+ if (!expiryTimestampIter.isDefined) {
Review Comment:
Do we ever have a case where requests for the expiry timestamp iterator can be "interleaved"? For example, requesting expiry timers with a different timestamp before consuming the iterator from the previous request.
Also, we seem to have an issue with re-entrance; we don't distinguish two different requests for the same timestamp if they are interleaved.
Is this a possibility we simply don't expose to users (so we wouldn't have an issue), or should we fix this (e.g. assign a UUID)?
cc. @bogao007
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -190,11 +253,43 @@ class TransformWithStateInPandasStateServer(
val stateName = message.getGetListState.getStateName
val schema = message.getGetListState.getSchema
val ttlDurationMs = if (message.getGetListState.hasTtl) {
- Some(message.getGetListState.getTtl.getDurationMs)
- } else {
- None
- }
+ Some(message.getGetListState.getTtl.getDurationMs)
Review Comment:
nit: looks like the prior indentation was correct?
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -75,6 +75,11 @@ def _prepare_test_resource2(self, input_path):
input_path + "/text-test2.txt", [0, 0, 0, 1, 1], [123, 223, 323, 246, 6]
)
+ def _prepare_test_resource3(self, input_path):
+ with open(input_path + "/text-test3.txt", "w") as fw:
Review Comment:
nit: any specific reason to implement this here separately rather than calling _prepare_input_data?
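e.g. something along these lines (a sketch only - the id/value lists are placeholders, not the data this test actually needs):

```python
def _prepare_test_resource3(self, input_path):
    # reuse the shared helper like the other _prepare_test_resource* methods do
    self._prepare_input_data(input_path + "/text-test3.txt", [0, 1], [123, 46])
```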
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -279,4 +283,68 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
verify(listState).appendList(any)
}
+
+ test("timer value get processing time") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetProcessingTimer(
+ GetProcessingTime.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(batchTimestampMs).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("timer value get watermark") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetWatermark(
+ GetWatermark.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(eventTimeWatermarkForEviction).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("get expiry timers") {
+ val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+ ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+ 10L
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(statefulProcessorHandle).getExpiredTimers(any[Long])
Review Comment:
For other functionality we add the spying instance as a param to test this.
Do we test this in e2e instead? I'm OK with it. Just wanted to check.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -190,11 +253,43 @@ class TransformWithStateInPandasStateServer(
val stateName = message.getGetListState.getStateName
val schema = message.getGetListState.getSchema
val ttlDurationMs = if (message.getGetListState.hasTtl) {
- Some(message.getGetListState.getTtl.getDurationMs)
- } else {
- None
- }
+ Some(message.getGetListState.getTtl.getDurationMs)
+ } else {
+ None
+ }
initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs)
+ case StatefulProcessorCall.MethodCase.TIMERSTATECALL =>
+ message.getTimerStateCall.getMethodCase match {
+ case TimerStateCallCommand.MethodCase.REGISTER =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getRegister.getExpiryTimestampMs
+ statefulProcessorHandle.registerTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.DELETE =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getDelete.getExpiryTimestampMs
+ statefulProcessorHandle.deleteTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.LIST =>
+ if (!listTimerIter.isDefined) {
Review Comment:
Same concern, though I guess that if this is a real issue, it is widespread across TWS in Pandas.
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -152,7 +155,126 @@ def get_list_state(
status = response_message[0]
if status != 0:
# TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}")
+ raise PySparkRuntimeError(f"Error initializing list state: " f"{response_message[1]}")
+
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ register_call = stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command = stateMessage.TimerStateCallCommand(register=register_call)
+ call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error register timer: " f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ delete_call = stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command = stateMessage.TimerStateCallCommand(delete=delete_call)
+ call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error deleting timer: " f"{response_message[1]}")
+
+ def list_timers(self) -> Iterator[list[int]]:
Review Comment:
Ah OK, ignore the previous comment about the type hint. This is specific to the PySpark impl of TWS - I see what you are doing.
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
+ def _test_transform_with_state_in_pandas_proc_timer(
+ self, stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource3(input_path)
+ self._prepare_test_resource1(input_path)
+ self._prepare_test_resource2(input_path)
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("countAsString", StringType(), True),
+ StructField("timeValues", StringType(), True)
+ ]
+ )
+
+ query_name = "processing_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_proc_timer(self):
+
+ # helper function to check expired timestamp is smaller than current processing time
+ def check_timestamp(batch_df):
+ expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "expiredTimestamp")
+ count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "countStateTimestamp")
+ joined_df = expired_df.join(count_df, on="id")
+ for row in joined_df.collect():
+ assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="1"),
+ Row(id="1", countAsString="1"),
+ }
+ elif batch_id == 1:
+ # for key 0, the accumulated count is emitted before the count state is cleared
+ # during the timer process
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="3"),
+ }
+ self.first_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ check_timestamp(batch_df)
+
+ else:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="5")
+ }
+ # The expired timestamp in current batch is larger than expiry timestamp in batch 1
+ # because this is a new timer registered in batch1 and
+ # different from the one registered in batch 0
+ current_batch_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ assert(current_batch_expired_timestamp > self.first_expired_timestamp)
+
+ self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(), check_results)
+
+ def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
+ import pyspark.sql.functions as f
+
+ input_path = tempfile.mkdtemp()
+
+ def prepare_batch1(input_path):
+ with open(input_path + "/text-test3.txt", "w") as fw:
+ fw.write("a, 20\n")
+
+ def prepare_batch2(input_path):
+ with open(input_path + "/text-test1.txt", "w") as fw:
+ fw.write("a, 4\n")
+
+ def prepare_batch3(input_path):
+ with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 11\n")
+ fw.write("a, 13\n")
+ fw.write("a, 15\n")
+
+ prepare_batch1(input_path)
+ prepare_batch2(input_path)
+ prepare_batch3(input_path)
+
+ df = self._build_test_df(input_path)
+ df = df.select("id",
+ f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+ .withWatermark("eventTime", "10 seconds")
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("timestamp", StringType(), True)
+ ]
+ )
+
+ query_name = "event_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="eventtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20")
+ }
+ elif batch_id == 1:
+ # event time = 4 will be discarded because the watermark = 15 - 10 = 5
Review Comment:
This is not true - the watermark for eviction is 5, but the watermark for late records is 0, hence ("a", 4) is not dropped. This is exactly the reason you still see an event for "a"; otherwise you shouldn't have ("a", 20).
##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala:
##########
@@ -279,4 +283,68 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
verify(listState).appendList(any)
}
+
+ test("timer value get processing time") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetProcessingTimer(
+ GetProcessingTime.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(batchTimestampMs).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("timer value get watermark") {
+ val message = TimerRequest.newBuilder().setTimerValueRequest(
+ TimerValueRequest.newBuilder().setGetWatermark(
+ GetWatermark.newBuilder().build()
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(eventTimeWatermarkForEviction).isDefined
+ verify(outputStream).writeInt(argThat((x: Int) => x > 0))
+ }
+
+ test("get expiry timers") {
+ val message = TimerRequest.newBuilder().setExpiryTimerRequest(
+ ExpiryTimerRequest.newBuilder().setExpiryTimestampMs(
+ 10L
+ ).build()
+ ).build()
+ stateServer.handleTimerRequest(message)
+ verify(statefulProcessorHandle).getExpiredTimers(any[Long])
+ }
+
+ test("stateful processor register timer") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setRegister(RegisterTimer.newBuilder().setExpiryTimestampMs(10L).build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle).registerTimer(any[Long])
+ verify(outputStream).writeInt(0)
+ }
+
+ test("stateful processor delete timer") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setDelete(DeleteTimer.newBuilder().setExpiryTimestampMs(10L).build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle).deleteTimer(any[Long])
+ verify(outputStream).writeInt(0)
+ }
+
+ test("stateful processor list timer") {
+ val message = StatefulProcessorCall.newBuilder().setTimerStateCall(
+ TimerStateCallCommand.newBuilder()
+ .setList(ListTimers.newBuilder().build())
+ .build()
+ ).build()
+ stateServer.handleStatefulProcessorCall(message)
+ verify(statefulProcessorHandle).listTimers()
Review Comment:
ditto: For other functionality we add the spying instance as a param to test
this. Do we test this in e2e instead? I'm OK with it. Just wanted to check.
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
+ def _test_transform_with_state_in_pandas_proc_timer(
+ self, stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource3(input_path)
+ self._prepare_test_resource1(input_path)
+ self._prepare_test_resource2(input_path)
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("countAsString", StringType(), True),
+ StructField("timeValues", StringType(), True)
+ ]
+ )
+
+ query_name = "processing_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_proc_timer(self):
+
+ # helper function to check expired timestamp is smaller than current processing time
+ def check_timestamp(batch_df):
+ expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "expiredTimestamp")
+ count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "countStateTimestamp")
+ joined_df = expired_df.join(count_df, on="id")
+ for row in joined_df.collect():
+ assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="1"),
+ Row(id="1", countAsString="1"),
+ }
+ elif batch_id == 1:
+ # for key 0, the accumulated count is emitted before the count state is cleared
+ # during the timer process
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="3"),
+ }
+ self.first_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ check_timestamp(batch_df)
+
+ else:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="5")
+ }
+ # The expired timestamp in current batch is larger than expiry timestamp in batch 1
+ # because this is a new timer registered in batch1 and
+ # different from the one registered in batch 0
+ current_batch_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ assert(current_batch_expired_timestamp > self.first_expired_timestamp)
+
+ self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(), check_results)
+
+ def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
+ import pyspark.sql.functions as f
+
+ input_path = tempfile.mkdtemp()
+
+ def prepare_batch1(input_path):
+ with open(input_path + "/text-test3.txt", "w") as fw:
+ fw.write("a, 20\n")
+
+ def prepare_batch2(input_path):
+ with open(input_path + "/text-test1.txt", "w") as fw:
+ fw.write("a, 4\n")
+
+ def prepare_batch3(input_path):
+ with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 11\n")
+ fw.write("a, 13\n")
+ fw.write("a, 15\n")
+
+ prepare_batch1(input_path)
+ prepare_batch2(input_path)
+ prepare_batch3(input_path)
+
+ df = self._build_test_df(input_path)
+ df = df.select("id",
+ f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+ .withWatermark("eventTime", "10 seconds")
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("timestamp", StringType(), True)
+ ]
+ )
+
+ query_name = "event_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="eventtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20")
+ }
+ elif batch_id == 1:
+ # event time = 4 will be discarded because the watermark = 15 - 10 = 5
Review Comment:
You might wonder how this works differently from the Scala tests - AddData() & CheckNewAnswer() trigger a no-data batch, hence executing two batches.