jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1774159447
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -251,4 +338,53 @@ class TransformWithStateInPandasStateServer(
sendResponse(1, s"state $stateName already exists")
}
}
+
+ /** Utils object for sending response to Python client. */
+ private object PythonResponseWriterUtils {
+ def sendResponse(
+ status: Int,
+ errorMessage: String = null,
+ byteString: ByteString = null): Unit = {
+ val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ if (byteString != null) {
+ responseMessageBuilder.setValue(byteString)
+ }
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
+ def serializeLongToByteString(longValue: Long): ByteString = {
Review Comment:
Add a new type of `StateResponse` to transmit a Long value directly in the proto
message.
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,130 @@ def get_value_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ register_call =
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(register=register_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error register timer: "
f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ delete_call =
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(delete=delete_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error delete timers: "
f"{response_message[1]}")
+
+ def list_timers(self) -> Iterator[list[int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ list_call = stateMessage.ListTimers()
+ state_call_command =
stateMessage.TimerStateCallCommand(list=list_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ batch_df = batch.to_pandas()
+ for i in range(batch.num_rows):
+ timestamp = batch_df.at[i, 'timestamp'].item()
+ result_list.append(timestamp)
+ yield result_list
+ else:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error getting expiry timers: "
f"{response_message[1]}")
+
+ def get_expiry_timers_iterator(self, expiry_timestamp: int) ->
Iterator[list[Any, int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ expiry_timer_call =
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+ timer_request =
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ key_fields = [field.name for field in self.key_schema.fields]
+ # TODO any better way to restore a grouping object from a
batch?
Review Comment:
Discussed with Bo offline: the JVM will return a Row type to Python, and we can
convert it directly into a tuple.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]