jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1763700239
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,151 @@ def get_value_state(
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: "
                                       f"{response_message[1]}")
+    def get_batch_timestamp(self) -> int:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+        get_batch_timestamp = stateMessage.GetBatchTimestampMs()
+        request = stateMessage.TimerMiscRequest(getBatchTimestampMs=get_batch_timestamp)
+        message = stateMessage.StateRequest(timerMiscRequest=request)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error getting batch timestamp: "
+                                      f"{response_message[1]}")
+        else:
+            if len(response_message[2]) == 0:
+                return -1
+            # TODO: can we simply parse from utf8 string here?
+            timestamp = int(response_message[2])
+            return timestamp
+
+    def register_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        register_call = stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = stateMessage.TimerStateCallCommand(register=register_call)
+        call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error registering timer: "
+                                      f"{response_message[1]}")
+
+    def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        delete_call = stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+        state_call_command = stateMessage.TimerStateCallCommand(delete=delete_call)
+        call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error deleting timer: "
+                                      f"{response_message[1]}")
+
+    def list_timers(self) -> list[int]:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        list_call = stateMessage.ListTimers()
+        state_call_command = stateMessage.TimerStateCallCommand(list=list_call)
+        call = stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+        message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+        self._send_proto_message(message.SerializeToString())
+        response_message = self._receive_proto_message()
+        status = response_message[0]
+        if status == 1:
+            return []
+        elif status == 0:
+            iterator = self._read_arrow_state()
+            batch = next(iterator)
Review Comment:
We don't. It is now returning an iterator of lists. Does this API make sense to you?
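
For illustration, here is a minimal sketch of how a caller might consume list_timers under the shape described above, i.e. assuming it yields one list of expiry timestamps (in ms) per Arrow batch read through _read_arrow_state(). The helper name and the client variable are made up for this example and are not part of the PR.

    def all_expiry_timestamps(client):
        # Sketch only: assumes client.list_timers() is an iterator of lists of
        # expiry timestamps in milliseconds, one list per Arrow batch.
        timestamps = []
        for batch in client.list_timers():
            timestamps.extend(batch)
        return sorted(timestamps)

An iterator keeps memory bounded when a key has many registered timers, at the cost of asking callers to do the flattening themselves.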