HeartSaVioR commented on code in PR #51036:
URL: https://github.com/apache/spark/pull/51036#discussion_r2113095885


##########
python/pyspark/sql/streaming/stateful_processor_api_client.py:
##########
@@ -222,76 +224,96 @@ def delete_timer(self, expiry_time_stamp_ms: int) -> None:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error deleting timer: " 
f"{response_message[1]}")
 
-    def get_list_timer_row(self, iterator_id: str) -> int:
+    def get_list_timer_row(self, iterator_id: str) -> Tuple[int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
         if iterator_id in self.list_timer_iterator_cursors:
             # if the iterator is already in the dictionary, return the next row
-            pandas_df, index = self.list_timer_iterator_cursors[iterator_id]
+            data_batch, index, require_next_fetch = 
self.list_timer_iterator_cursors[iterator_id]
         else:
             list_call = stateMessage.ListTimers(iteratorId=iterator_id)
             state_call_command = 
stateMessage.TimerStateCallCommand(list=list_call)
             call = 
stateMessage.StatefulProcessorCall(timerStateCall=state_call_command)
             message = stateMessage.StateRequest(statefulProcessorCall=call)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
             if status == 0:
-                iterator = self._read_arrow_state()
-                # We need to exhaust the iterator here to make sure all the 
arrow batches are read,
-                # even though there is only one batch in the iterator. 
Otherwise, the stream might
-                # block further reads since it thinks there might still be 
some arrow batches left.
-                # We only need to read the first batch in the iterator because 
it's guaranteed that
-                # there would only be one batch sent from the JVM side.
-                data_batch = None
-                for batch in iterator:
-                    if data_batch is None:
-                        data_batch = batch
-                if data_batch is None:
-                    # TODO(SPARK-49233): Classify user facing errors.
-                    raise PySparkRuntimeError("Error getting map state entry.")
-                pandas_df = data_batch.to_pandas()
+                data_batch = list(
+                    map(
+                        lambda x: x.timestampMs,
+                        response_message[2]
+                    )
+                )
+                require_next_fetch = response_message[3]
                 index = 0
             else:
                 raise StopIteration()
+
+        is_last_row = False
         new_index = index + 1
-        if new_index < len(pandas_df):
+        if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.list_timer_iterator_cursors[iterator_id] = (pandas_df, 
new_index)
+            self.list_timer_iterator_cursors[iterator_id] = (data_batch, 
new_index, require_next_fetch)
         else:
-            # If the index is at the end of the DataFrame, remove the state 
from the dictionary.
+            # If the index is at the end of the data batch, remove the state 
from the dictionary.
             self.list_timer_iterator_cursors.pop(iterator_id, None)
-        return pandas_df.at[index, "timestamp"].item()
+            is_last_row = True
+
+        is_last_row_from_iterator = is_last_row and not require_next_fetch
+        timestamp = data_batch[index]
+        return (timestamp, is_last_row_from_iterator)
 
     def get_expiry_timers_iterator(
-        self, expiry_timestamp: int
-    ) -> Iterator[list[Tuple[Tuple, int]]]:
+        self,
+        iterator_id: str,
+        expiry_timestamp: int
+    ) -> Tuple[Tuple, int, bool]:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        while True:
-            expiry_timer_call = 
stateMessage.ExpiryTimerRequest(expiryTimestampMs=expiry_timestamp)
+        if iterator_id in self.expiry_timer_iterator_cursors:
+            # If the state is already in the dictionary, return the next row.
+            data_batch, index, require_next_fetch = 
self.expiry_timer_iterator_cursors[iterator_id]
+        else:
+            expiry_timer_call = stateMessage.ExpiryTimerRequest(
+                expiryTimestampMs=expiry_timestamp,
+                iteratorId=iterator_id
+            )
             timer_request = 
stateMessage.TimerRequest(expiryTimerRequest=expiry_timer_call)
             message = stateMessage.StateRequest(timerRequest=timer_request)
 
             self._send_proto_message(message.SerializeToString())
-            response_message = self._receive_proto_message()
+            response_message = self._receive_proto_message_with_timers()
             status = response_message[0]
-            if status == 1:
-                break
-            elif status == 0:
-                result_list = []
-                iterator = self._read_arrow_state()
-                for batch in iterator:
-                    batch_df = batch.to_pandas()
-                    for i in range(batch.num_rows):
-                        deserialized_key = self.pickleSer.loads(batch_df.at[i, 
"key"])
-                        timestamp = batch_df.at[i, "timestamp"].item()
-                        result_list.append((tuple(deserialized_key), 
timestamp))
-                yield result_list
+            if status == 0:

Review Comment:
   Two things:
   1. We see the benefit of inlining the data into the proto message to save one round-trip.
   2. Arrow is a columnar format, which is known to be efficient when there are many rows. Using an Arrow RecordBatch for a small number of records is not a good fit (though it is sometimes necessary). It "might be" different once there are enough records, especially since a pickled Python Row appears to carry the "schema" as JSON, which is not needed at all with an Arrow RecordBatch. I haven't tested with a large number of records.
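   
   To make point 2 concrete, here is a rough, hypothetical micro-comparison (illustrative only, not the PR's actual wire format; the `timestamp` field name and the dict-per-row shape are assumptions): it serializes a handful of timer timestamps once as pickled per-row values and once as a single Arrow RecordBatch in the IPC stream format.
   
   ```python
   # Hypothetical micro-comparison (not part of the PR): serialized size of a few
   # timer timestamps as pickled per-row dicts vs. one Arrow RecordBatch.
   import pickle
   import pyarrow as pa
   
   timestamps = [1_700_000_000_000 + i for i in range(3)]  # a few epoch-ms values
   
   # Pickled "rows": every row carries its own field name / structure.
   pickled_size = sum(len(pickle.dumps({"timestamp": ts})) for ts in timestamps)
   
   # One Arrow RecordBatch in the IPC stream format: the schema is written once
   # up front, which dominates when there are only a few rows.
   batch = pa.record_batch([pa.array(timestamps, type=pa.int64())], names=["timestamp"])
   sink = pa.BufferOutputStream()
   with pa.ipc.new_stream(sink, batch.schema) as writer:
       writer.write_batch(batch)
   arrow_size = len(sink.getvalue())
   
   print(f"pickled rows: {pickled_size} bytes, Arrow IPC stream: {arrow_size} bytes")
   ```
   
   At a handful of records the Arrow stream's one-time schema and framing overhead typically outweighs the payload, while the ratio should flip as the row count grows, which is the trade-off described above. I haven't measured where the crossover is.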



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
