bogao007 commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1761524284
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,47 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp,
watermark_timestamp), ExpiredTimerInfo(False))
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+
+ if timeMode == "processingtime":
+ expiry_list: list[(any, int)] =
statefulProcessorApiClient.get_expiry_timers(batch_timestamp)
+ elif timeMode == "eventtime":
+ expiry_list: list[(any, int)] =
statefulProcessorApiClient.get_expiry_timers(watermark_timestamp)
+ else:
+ expiry_list = []
+
+ result_iter_list = []
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for key_obj, expiry_timestamp in expiry_list:
+ if timeMode == "processingtime" and expiry_timestamp <
batch_timestamp:
+ result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp)))
+ elif timeMode == "eventtime" and expiry_timestamp <
watermark_timestamp:
+ result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp)))
+
+ # TODO(SPARK-49603) set the handle state in the lazily initialized
iterator
+ """
Review Comment:
Since we have a TODO here, we can remove the commented-out code.
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,151 @@ def get_value_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ def get_batch_timestamp(self) -> int:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ get_batch_timestamp = stateMessage.GetBatchTimestampMs()
+ request =
stateMessage.TimerMiscRequest(getBatchTimestampMs=get_batch_timestamp)
+ message = stateMessage.StateRequest(timerMiscRequest=request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error initializing timer state: "
f"{response_message[1]}")
+ else:
+ if len(response_message[2]) == 0:
+ return -1
+ # TODO: can we simply parse from utf8 string here?
+ timestamp = int(response_message[2])
Review Comment:
Just curious: would this return the correct value?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -158,6 +229,51 @@ class TransformWithStateInPandasStateServer(
Some(message.getGetValueState.getTtl.getDurationMs)
} else None
initializeValueState(stateName, schema, ttlDurationMs)
+
+ case StatefulProcessorCall.MethodCase.TIMERSTATECALL =>
+ message.getTimerStateCall.getMethodCase match {
+ case TimerStateCallCommand.MethodCase.REGISTER =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getRegister.getExpiryTimestampMs
+ statefulProcessorHandle.registerTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.DELETE =>
+ val expiryTimestamp =
+ message.getTimerStateCall.getDelete.getExpiryTimestampMs
+ statefulProcessorHandle.deleteTimer(expiryTimestamp)
+ sendResponse(0)
+ case TimerStateCallCommand.MethodCase.LIST =>
+ val iter = statefulProcessorHandle.listTimers()
+
+ if (iter == null || !iter.hasNext) {
+ // avoid sending over empty batch
+ sendResponse(1)
+ } else {
+ sendResponse(0)
+ outputStream.flush()
+ val arrowStreamWriter = {
+ val outputSchema = new StructType()
+ .add(StructField("timestamp", LongType))
+ val arrowSchema = ArrowUtils.toArrowSchema(outputSchema,
timeZoneId,
+ errorOnDuplicatedFieldNames, largeVarTypes)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for transformWithStateInPandas state
socket", 0, Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root,
null, outputStream),
+ arrowTransformWithStateInPandasMaxRecordsPerBatch)
+ }
+ while (iter.hasNext) {
+ val timestamp = iter.next()
+ val internalRow = InternalRow(timestamp)
+ arrowStreamWriter.writeRow(internalRow)
Review Comment:
Same as the Python-side comment: here we don't limit how many Arrow batches
we construct for timers. If the user sets a fairly low value for
arrowTransformWithStateInPandasMaxRecordsPerBatch, we would send multiple Arrow
batches, and the client side needs to handle this properly as well.
Question: should we have a lower limit on how many records we send through a
single batch (e.g. the default value of 10000)? IIUC, each timer record is very
small and should not consume much memory. The user also doesn't care about
how many records each batch contains, since they always get a single list
from this API.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,47 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
Review Comment:
Can we move these two API calls inside the if-else clause below and only call
them for the supported time modes?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -113,11 +123,66 @@ class TransformWithStateInPandasStateServer(
handleStatefulProcessorCall(message.getStatefulProcessorCall)
case StateRequest.MethodCase.STATEVARIABLEREQUEST =>
handleStateVariableRequest(message.getStateVariableRequest)
+ case StateRequest.MethodCase.TIMERREQUEST =>
+ handleTimerRequest(message.getTimerRequest)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
}
+ private[sql] def handleTimerRequest(message: TimerRequest): Unit = {
+ message.getMethodCase match {
+ case TimerRequest.MethodCase.TIMERVALUEREQUEST =>
+ val timerRequest = message.getTimerValueRequest()
+ timerRequest.getMethodCase match {
+ case TimerValueRequest.MethodCase.GETPROCESSINGTIMER =>
+ val valueStr =
+ if (batchTimestampMs.isDefined) batchTimestampMs.get.toString
else "-1"
+ sendResponse(0, null, ByteString.copyFromUtf8(valueStr))
+ case TimerValueRequest.MethodCase.GETWATERMARK =>
+ val valueStr = if (eventTimeWatermarkForEviction.isDefined) {
+ eventTimeWatermarkForEviction.get.toString()
+ } else "-1"
+ sendResponse(0, null, ByteString.copyFromUtf8(valueStr))
+ case _ =>
+ throw new IllegalArgumentException("Invalid timer value method
call")
+ }
+
+ case TimerRequest.MethodCase.EXPIRYTIMERREQUEST =>
+ val expiryRequest = message.getExpiryTimerRequest()
+ val expiryTimestamp = expiryRequest.getExpiryTimestampMs
+ val iter =
statefulProcessorHandle.getExpiredTimersWithKeyRow(expiryTimestamp)
+ if (iter == null || !iter.hasNext) {
+ // avoid sending over empty batch
+ sendResponse(1)
+ } else {
+ sendResponse(0)
+ outputStream.flush()
+ val arrowStreamWriter = {
+ val outputSchema = new StructType()
+ .add("key", groupingKeySchema)
+ .add(StructField("timestamp", LongType))
+ val arrowSchema = ArrowUtils.toArrowSchema(outputSchema,
timeZoneId,
+ errorOnDuplicatedFieldNames, largeVarTypes)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stdout writer for transformWithStateInPandas state socket", 0,
Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root,
null, outputStream),
Review Comment:
Does it make sense to abstract this logic out into a shared helper, since it's
used in multiple places?
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,151 @@ def get_value_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ def get_batch_timestamp(self) -> int:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ get_batch_timestamp = stateMessage.GetBatchTimestampMs()
+ request =
stateMessage.TimerMiscRequest(getBatchTimestampMs=get_batch_timestamp)
+ message = stateMessage.StateRequest(timerMiscRequest=request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error initializing timer state: "
f"{response_message[1]}")
+ else:
+ if len(response_message[2]) == 0:
+ return -1
+ # TODO: can we simply parse from utf8 string here?
+ timestamp = int(response_message[2])
+ return timestamp
+
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ register_call =
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(register=register_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error register timer: "
f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ delete_call =
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(delete=delete_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error delete timers: "
f"{response_message[1]}")
+
+ def list_timers(self) -> list[int]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ list_call = stateMessage.ListTimers()
+ state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ return []
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
Review Comment:
Do we expect all the timers to fit within a single Arrow batch? If not,
should we handle multiple batches properly here?
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,47 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp,
watermark_timestamp), ExpiredTimerInfo(False))
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+
+ if timeMode == "processingtime":
+ expiry_list: list[(any, int)] =
statefulProcessorApiClient.get_expiry_timers(batch_timestamp)
+ elif timeMode == "eventtime":
+ expiry_list: list[(any, int)] =
statefulProcessorApiClient.get_expiry_timers(watermark_timestamp)
+ else:
+ expiry_list = []
+
+ result_iter_list = []
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for key_obj, expiry_timestamp in expiry_list:
+ if timeMode == "processingtime" and expiry_timestamp <
batch_timestamp:
+ result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
Review Comment:
Is `watermark_timestamp` needed for the processingTime time mode, and vice versa?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]