bogao007 commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1767591785
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,47 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
Review Comment:
IIUC, these two values are only used when the time mode is not `none`. What I
meant was that for the `none` time mode we don't need these two extra API calls,
since the values aren't needed anyway.
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,43 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp,
watermark_timestamp), ExpiredTimerInfo(False))
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+
+ if timeMode == "processingtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp)
+ elif timeMode == "eventtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(watermark_timestamp)
+ else:
+ expiry_list_iter = []
+
+ result_iter_list = []
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for expiry_list in expiry_list_iter:
+ for key_obj, expiry_timestamp in expiry_list:
+ if timeMode == "processingtime" and expiry_timestamp <
batch_timestamp:
+
result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp)))
Review Comment:
Nit: this code seems to be identical to the branch just below; can we combine
the two?
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,130 @@ def get_value_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ register_call =
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(register=register_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error register timer: "
f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ delete_call =
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(delete=delete_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error delete timers: "
f"{response_message[1]}")
+
+ def list_timers(self) -> Iterator[list[int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ list_call = stateMessage.ListTimers()
+ state_call_command =
stateMessage.TimerStateCallCommand(list=list_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ batch_df = batch.to_pandas()
+ for i in range(batch.num_rows):
+ timestamp = batch_df.at[i, 'timestamp'].item()
+ result_list.append(timestamp)
+ yield result_list
+ else:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error getting expiry timers: "
f"{response_message[1]}")
+
+ def get_expiry_timers_iterator(self, expiry_timestamp: int) ->
Iterator[list[Any, int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ expiry_timer_call =
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+ timer_request =
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ key_fields = [field.name for field in self.key_schema.fields]
+ # TODO any better way to restore a grouping object from a
batch?
Review Comment:
Maybe try something like below?
```
df.itertuples(index=False, name=None)]
```
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.itertuples.html
##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -501,7 +502,43 @@ def transformWithStateUDF(
)
statefulProcessorApiClient.set_implicit_key(key)
- result = statefulProcessor.handleInputRows(key, inputRows)
+
+ batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
+ watermark_timestamp =
statefulProcessorApiClient.get_watermark_timestamp()
+ # process with invalid expiry timer info and emit data rows
+ data_iter = statefulProcessor.handleInputRows(
+ key, inputRows, TimerValues(batch_timestamp,
watermark_timestamp), ExpiredTimerInfo(False))
+ statefulProcessorApiClient.set_handle_state(
+ StatefulProcessorHandleState.DATA_PROCESSED
+ )
+
+ if timeMode == "processingtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(batch_timestamp)
+ elif timeMode == "eventtime":
+ expiry_list_iter =
statefulProcessorApiClient.get_expiry_timers_iterator(watermark_timestamp)
+ else:
+ expiry_list_iter = []
+
+ result_iter_list = []
+ # process with valid expiry time info and with empty input rows,
+ # only timer related rows will be emitted
+ for expiry_list in expiry_list_iter:
+ for key_obj, expiry_timestamp in expiry_list:
+ if timeMode == "processingtime" and expiry_timestamp <
batch_timestamp:
+
result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp)))
+ elif timeMode == "eventtime" and expiry_timestamp <
watermark_timestamp:
+
result_iter_list.append(statefulProcessor.handleInputRows(
+ (key_obj,), iter([]),
+ TimerValues(batch_timestamp, watermark_timestamp),
+ ExpiredTimerInfo(True, expiry_timestamp)))
+
+ # TODO(SPARK-49603) set the handle state in the lazily initialized
iterator
+
+ result_iter_list.insert(0, data_iter)
Review Comment:
Nit: can we just append `data_iter` to the empty list before adding timer
iters to it?
##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -124,6 +129,130 @@ def get_value_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
+ def register_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ register_call =
stateMessage.RegisterTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(register=register_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error register timer: "
f"{response_message[1]}")
+
+ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+ delete_call =
stateMessage.DeleteTimer(expiryTimestampMs=expiry_time_stamp_ms)
+ state_call_command =
stateMessage.TimerStateCallCommand(delete=delete_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status != 0:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error delete timers: "
f"{response_message[1]}")
+
+ def list_timers(self) -> Iterator[list[int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ list_call = stateMessage.ListTimers()
+ state_call_command =
stateMessage.TimerStateCallCommand(list=list_call)
+ call =
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
+ message = stateMessage.StateRequest(statefulProcessorCall=call)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ batch_df = batch.to_pandas()
+ for i in range(batch.num_rows):
+ timestamp = batch_df.at[i, 'timestamp'].item()
+ result_list.append(timestamp)
+ yield result_list
+ else:
+ # TODO(SPARK-49233): Classify user facing errors.
+ raise PySparkRuntimeError(f"Error getting expiry timers: "
f"{response_message[1]}")
+
+ def get_expiry_timers_iterator(self, expiry_timestamp: int) ->
Iterator[list[Any, int]]:
+ import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+ while True:
+ expiry_timer_call =
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+ timer_request =
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
+ message = stateMessage.StateRequest(timerRequest=timer_request)
+
+ self._send_proto_message(message.SerializeToString())
+ response_message = self._receive_proto_message()
+ status = response_message[0]
+ if status == 1:
+ break
+ elif status == 0:
+ iterator = self._read_arrow_state()
+ batch = next(iterator)
+ result_list = []
+ key_fields = [field.name for field in self.key_schema.fields]
+ # TODO any better way to restore a grouping object from a
batch?
Review Comment:
Btw, what if multiple batches are sent from the JVM? Are we handling that
correctly?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -251,4 +338,53 @@ class TransformWithStateInPandasStateServer(
sendResponse(1, s"state $stateName already exists")
}
}
+
+ /** Utils object for sending response to Python client. */
+ private object PythonResponseWriterUtils {
+ def sendResponse(
+ status: Int,
+ errorMessage: String = null,
+ byteString: ByteString = null): Unit = {
+ val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ if (byteString != null) {
+ responseMessageBuilder.setValue(byteString)
+ }
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
+ def serializeLongToByteString(longValue: Long): ByteString = {
Review Comment:
I think doing serde between long and ByteString may add some extra complexity.
Since this is only used in `TimerValueRequests`, maybe we could add a dedicated
response message for it that returns a long value? That way we can just use
`read_long` on the Python side.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -251,4 +338,53 @@ class TransformWithStateInPandasStateServer(
sendResponse(1, s"state $stateName already exists")
}
}
+
+ /** Utils object for sending response to Python client. */
+ private object PythonResponseWriterUtils {
+ def sendResponse(
+ status: Int,
+ errorMessage: String = null,
+ byteString: ByteString = null): Unit = {
+ val responseMessageBuilder =
StateResponse.newBuilder().setStatusCode(status)
+ if (status != 0 && errorMessage != null) {
+ responseMessageBuilder.setErrorMessage(errorMessage)
+ }
+ if (byteString != null) {
+ responseMessageBuilder.setValue(byteString)
+ }
+ val responseMessage = responseMessageBuilder.build()
+ val responseMessageBytes = responseMessage.toByteArray
+ val byteLength = responseMessageBytes.length
+ outputStream.writeInt(byteLength)
+ outputStream.write(responseMessageBytes)
+ }
+
+ def serializeLongToByteString(longValue: Long): ByteString = {
+ // ByteBuffer defaults to big-endian byte order.
+ // This is accordingly deserialized as big-endian on Python client.
+ val valueBytes =
+ java.nio.ByteBuffer.allocate(8).putLong(longValue).array()
+ ByteString.copyFrom(valueBytes)
+ }
+
+ def sendIteratorAsArrowBatches[T](
Review Comment:
This looks good, thanks for abstracting the logic out!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]