jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1762064292
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,151 @@ def get_value_state(
            # TODO(SPARK-49233): Classify user facing errors.
            raise PySparkRuntimeError(f"Error initializing value state: {response_message[1]}")
+
+    def get_batch_timestamp(self) -> int:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        get_batch_timestamp = stateMessage.GetBatchTimestampMs()
+        request = stateMessage.TimerMiscRequest(getBatchTimestampMs=get_batch_timestamp)
+        message = stateMessage.StateRequest(timerMiscRequest=request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error initializing timer state: {response_message[1]}")
+        else:
+            if len(response_message[2]) == 0:
+                return -1
+            # TODO: can we simply parse from utf8 string here?
+            timestamp = int(response_message[2])
+            return timestamp
Review Comment:
   Passing a row schema and using `CPickleSerializer` seemed a bit heavyweight.
   Modified this to pass a byte buffer of exactly 8 bytes and read exactly 8 bytes
   on the Python client.
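
   For illustration, a minimal sketch of that fixed-width read (the big-endian
   byte order and the `value_bytes` name here are assumptions for the example,
   not part of this PR):

   ```python
   import struct

   def parse_timestamp_ms(value_bytes: bytes) -> int:
       # Hypothetical helper: assumes the JVM side writes the timestamp as a
       # signed 64-bit big-endian integer in exactly 8 bytes.
       if len(value_bytes) == 0:
           return -1  # no timestamp available
       (timestamp_ms,) = struct.unpack(">q", value_bytes[:8])
       return timestamp_ms
   ```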
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,47 @@ def transformWithStateUDF(
             )
             statefulProcessorApiClient.set_implicit_key(key)
-            result = statefulProcessor.handleInputRows(key, inputRows)
+
+            batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+            watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
+            # process with invalid expiry timer info and emit data rows
+            data_iter = statefulProcessor.handleInputRows(
+                key, inputRows,
+                TimerValues(batch_timestamp, watermark_timestamp),
+                ExpiredTimerInfo(False))
+            statefulProcessorApiClient.set_handle_state(
+                StatefulProcessorHandleState.DATA_PROCESSED
+            )
+
+            if timeMode == "processingtime":
+                expiry_list: List[Tuple[Any, int]] = \
+                    statefulProcessorApiClient.get_expiry_timers(batch_timestamp)
+            elif timeMode == "eventtime":
+                expiry_list = statefulProcessorApiClient.get_expiry_timers(watermark_timestamp)
+            else:
+                expiry_list = []
+
+            result_iter_list = []
+            # process with valid expiry timer info and empty input rows;
+            # only timer-related rows will be emitted
+            for key_obj, expiry_timestamp in expiry_list:
+                if timeMode == "processingtime" and expiry_timestamp < batch_timestamp:
+                    result_iter_list.append(statefulProcessor.handleInputRows(
+                        (key_obj,), iter([]),
+                        TimerValues(batch_timestamp, watermark_timestamp),
+                        ExpiredTimerInfo(True, expiry_timestamp)))
+                elif timeMode == "eventtime" and expiry_timestamp < watermark_timestamp:
+                    result_iter_list.append(statefulProcessor.handleInputRows(
+                        (key_obj,), iter([]),
+                        TimerValues(batch_timestamp, watermark_timestamp),
+                        ExpiredTimerInfo(True, expiry_timestamp)))
+
+            # TODO(SPARK-49603): set the handle state in the lazily initialized iterator
+            """
Review Comment:
Done!
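
   As a side note for readers following the hunk above: the two expiry branches
   differ only in the threshold compared against. A minimal standalone sketch of
   that dispatch (hypothetical helper, just for illustration, not part of this PR):

   ```python
   from typing import Any, List, Tuple

   def expired_timers(
       expiry_list: List[Tuple[Any, int]],
       time_mode: str,
       batch_timestamp: int,
       watermark_timestamp: int,
   ) -> List[Tuple[Any, int]]:
       # Processing-time mode expires timers against the batch timestamp;
       # event-time mode expires them against the watermark.
       threshold = batch_timestamp if time_mode == "processingtime" else watermark_timestamp
       return [(key, ts) for key, ts in expiry_list if ts < threshold]
   ```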
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]