This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 418cfd1f7801 [SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol
418cfd1f7801 is described below

commit 418cfd1f78014698ac4baac21156341a11b771b3
Author: Jungtaek Lim <[email protected]>
AuthorDate: Mon Apr 7 16:47:44 2025 +0900

    [SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol
    
    ### What changes were proposed in this pull request?
    
    This PR proposes removing the use of Arrow for sending multiple elements of ListState and replacing it with a simple custom protocol.
    
    The proposed custom protocol is very simple and already widely used.
    
    1. Write the size of the element in bytes; if there are no more elements, write -1 instead.
    2. Write the element (as bytes)
    3. Go back to 1
    
    Note that this PR only changes ListState. We are aware that Arrow is also used by other state types and other functionality (timers); we plan to improve those over time, benchmarking each one and addressing it if it shows a latency impact.
    
    ### Why are the changes needed?
    
    For a small number of elements, Arrow does not perform well compared to the custom protocol. In the benchmark, three elements are exchanged between the Python worker and the JVM, and replacing Arrow with the custom protocol could cut the elapsed time of state interaction by one third.
    
    Given the inherent performance difference between the Scala version (transformWithState) and the PySpark version (transformWithStateInPandas), I think users need to use the Scala version to handle a significant volume of workload. We can position the PySpark version for more lightweight workloads, and revisit this if we see demand to the contrary.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's an internal change.
    
    ### How was this patch tested?
    
    Existing unit tests, with modified mock expectations.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50488 from HeartSaVioR/SPARK-51690.
    
    Authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/pyspark/sql/streaming/list_state_client.py  | 46 +++++++---------------
 .../sql/streaming/stateful_processor_api_client.py | 20 ++++++++++
 .../TransformWithStateInPandasDeserializer.scala   | 21 ++++++++++
 .../TransformWithStateInPandasStateServer.scala    | 30 +++++++++++---
 ...ransformWithStateInPandasStateServerSuite.scala | 36 +++++++++++------
 5 files changed, 105 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/sql/streaming/list_state_client.py 
b/python/pyspark/sql/streaming/list_state_client.py
index cb618d1a691b..66f2640c935e 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -14,16 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, List, Union, Tuple
+from typing import Any, Dict, Iterator, List, Union, Tuple
 
 from pyspark.sql.streaming.stateful_processor_api_client import 
StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING
+from pyspark.sql.types import StructType
 from pyspark.errors import PySparkRuntimeError
 import uuid
 
-if TYPE_CHECKING:
-    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
 __all__ = ["ListStateClient"]
 
 
@@ -38,9 +35,9 @@ class ListStateClient:
             self.schema = 
self._stateful_processor_api_client._parse_string_schema(schema)
         else:
             self.schema = schema
-        # A dictionary to store the mapping between list state name and a 
tuple of pandas DataFrame
+        # A dictionary to store the mapping between list state name and a 
tuple of data batch
         # and the index of the last row that was read.
-        self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+        self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}
 
     def exists(self, state_name: str) -> bool:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -67,9 +64,9 @@ class ListStateClient:
     def get(self, state_name: str, iterator_id: str) -> Tuple:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        if iterator_id in self.pandas_df_dict:
+        if iterator_id in self.data_batch_dict:
             # If the state is already in the dictionary, return the next row.
-            pandas_df, index = self.pandas_df_dict[iterator_id]
+            data_batch, index = self.data_batch_dict[iterator_id]
         else:
             # If the state is not in the dictionary, fetch the state from the 
server.
             get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -85,33 +82,20 @@ class ListStateClient:
             response_message = 
self._stateful_processor_api_client._receive_proto_message()
             status = response_message[0]
             if status == 0:
-                iterator = 
self._stateful_processor_api_client._read_arrow_state()
-                # We need to exhaust the iterator here to make sure all the 
arrow batches are read,
-                # even though there is only one batch in the iterator. 
Otherwise, the stream might
-                # block further reads since it thinks there might still be 
some arrow batches left.
-                # We only need to read the first batch in the iterator because 
it's guaranteed that
-                # there would only be one batch sent from the JVM side.
-                data_batch = None
-                for batch in iterator:
-                    if data_batch is None:
-                        data_batch = batch
-                if data_batch is None:
-                    # TODO(SPARK-49233): Classify user facing errors.
-                    raise PySparkRuntimeError("Error getting next list state 
row.")
-                pandas_df = data_batch.to_pandas()
+                data_batch = 
self._stateful_processor_api_client._read_list_state()
                 index = 0
             else:
                 raise StopIteration()
 
         new_index = index + 1
-        if new_index < len(pandas_df):
+        if new_index < len(data_batch):
             # Update the index in the dictionary.
-            self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
+            self.data_batch_dict[iterator_id] = (data_batch, new_index)
         else:
-            # If the index is at the end of the DataFrame, remove the state 
from the dictionary.
-            self.pandas_df_dict.pop(iterator_id, None)
-        pandas_row = pandas_df.iloc[index]
-        return tuple(pandas_row)
+            # If the index is at the end of the data batch, remove the state 
from the dictionary.
+            self.data_batch_dict.pop(iterator_id, None)
+        row = data_batch[index]
+        return tuple(row)
 
     def append_value(self, state_name: str, value: Tuple) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -143,7 +127,7 @@ class ListStateClient:
 
         
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(self.schema, 
values)
+        self._stateful_processor_api_client._send_list_state(self.schema, 
values)
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -160,7 +144,7 @@ class ListStateClient:
 
         
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        self._stateful_processor_api_client._send_arrow_state(self.schema, 
values)
+        self._stateful_processor_api_client._send_list_state(self.schema, 
values)
         response_message = 
self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index fa50ed00738c..50945198f9c4 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -467,6 +467,26 @@ class StatefulProcessorApiClient:
     def _read_arrow_state(self) -> Any:
         return self.serializer.load_stream(self.sockfile)
 
+    def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
+        for value in state:
+            bytes = self._serialize_to_bytes(schema, value)
+            length = len(bytes)
+            write_int(length, self.sockfile)
+            self.sockfile.write(bytes)
+
+        write_int(-1, self.sockfile)
+        self.sockfile.flush()
+
+    def _read_list_state(self) -> List[Any]:
+        data_array = []
+        while True:
+            length = read_int(self.sockfile)
+            if length < 0:
+                break
+            bytes = self.sockfile.read(length)
+            data_array.append(self._deserialize_from_bytes(bytes))
+        return data_array
+
     # Parse a string schema into a StructType schema. This method will perform 
an API call to
     # JVM side to parse the schema string.
     def _parse_string_schema(self, schema: str) -> StructType:
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
index 1a8ffb35c053..b38697aeb021 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
@@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: 
ExpressionEncoder.Des
     reader.close(false)
     rows.toSeq
   }
+
+  def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = {
+    // Reads list state elements framed as <int length><payload bytes>, repeated until a negative
+    // length (-1) marks the end of the list. Each payload is converted to a JVM Row using the
+    // list state's schema and deserializer.
+    val rows = new scala.collection.mutable.ArrayBuffer[Row]
+
+    var size = stream.readInt()
+    while (size >= 0) {
+      val bytes = new Array[Byte](size)
+      // Use readFully, not read: InputStream.read may return after filling only part of the
+      // buffer, which would corrupt this row and desynchronize the rest of the framed stream.
+      // readFully throws EOFException if the stream ends early.
+      stream.readFully(bytes)
+      val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema,
+        listStateInfo.deserializer)
+      rows.append(newRow)
+      size = stream.readInt()
+    }
+
+    rows.toSeq
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
index 6a194976257e..f46b66204383 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
@@ -475,7 +475,7 @@ class TransformWithStateInPandasStateServer(
           sendResponse(2, s"state $stateName doesn't exist")
         }
       case ListStateCall.MethodCase.LISTSTATEPUT =>
-        val rows = deserializer.readArrowBatches(inputStream)
+        val rows = deserializer.readListElements(inputStream, listStateInfo)
         listStateInfo.listState.put(rows.toArray)
         sendResponse(0)
       case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -487,12 +487,10 @@ class TransformWithStateInPandasStateServer(
         }
         if (!iteratorOption.get.hasNext) {
           sendResponse(2, s"List state $stateName doesn't contain any value.")
-          return
         } else {
           sendResponse(0)
+          sendIteratorForListState(iteratorOption.get)
         }
-        sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
-          arrowStreamWriterForTest) { data => listStateInfo.serializer(data)}
       case ListStateCall.MethodCase.APPENDVALUE =>
         val byteArray = message.getAppendValue.getValue.toByteArray
         val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema,
@@ -500,7 +498,7 @@ class TransformWithStateInPandasStateServer(
         listStateInfo.listState.appendValue(newRow)
         sendResponse(0)
       case ListStateCall.MethodCase.APPENDLIST =>
-        val rows = deserializer.readArrowBatches(inputStream)
+        val rows = deserializer.readListElements(inputStream, listStateInfo)
         listStateInfo.listState.appendList(rows.toArray)
         sendResponse(0)
       case ListStateCall.MethodCase.CLEAR =>
@@ -511,6 +509,28 @@ class TransformWithStateInPandasStateServer(
     }
   }
 
+  private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
+    // Writes one chunk of list state rows for a single GET request: each row is sent as
+    // <int length><row bytes>, and the chunk ends with -1. At most
+    // arrowTransformWithStateInPandasMaxRecordsPerBatch rows are written; rows left in `iter`
+    // are served by a later GET. Capping the chunk handles the case where the user switches to
+    // a different state variable before the current one is exhausted.
+    var sent = 0
+    while (iter.hasNext && sent < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+      val payload = PythonSQLUtils.toPyRow(iter.next())
+      outputStream.writeInt(payload.length)
+      outputStream.write(payload)
+      sent += 1
+    }
+    // Sentinel: no more rows in this response.
+    outputStream.writeInt(-1)
+    outputStream.flush()
+  }
+
   private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
     val stateName = message.getStateName
     if (!mapStates.contains(stateName)) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
index 1f0aa72d2713..305a520f6af8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
@@ -103,6 +103,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
       listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, 
expiryTimerIter, listTimerMap)
     when(transformWithStateInPandasDeserializer.readArrowBatches(any))
       .thenReturn(Seq(getIntegerRow(1)))
+    when(transformWithStateInPandasDeserializer.readListElements(any, any))
+      .thenReturn(Seq(getIntegerRow(1)))
   }
 
   test("set handle state") {
@@ -260,8 +262,10 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
       
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // 1 for row, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(3)).writeInt(any)
+    // 1 for sending an actual row, 1 for sending proto message
+    verify(outputStream, times(2)).write(any[Array[Byte]])
   }
 
   test("list state get - iterator in map with multiple batches") {
@@ -278,15 +282,20 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto 
response
+    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
     // Second call should send the remaining 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    // Since Mockito's verify counts the total number of calls, the expected 
number of writeRow call
-    // should be 2 * maxRecordsPerBatch.
-    verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
+    // Since Mockito's verify counts the total number of calls, the expected 
number of writeInt
+    // and write should be accumulated from the prior count; the number of 
calls are the same
+    // with prior one.
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto 
response
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 
2)).write(any[Array[Byte]])
   }
 
   test("list state get - iterator not in map") {
@@ -302,17 +311,20 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     when(listState.get()).thenReturn(Iterator(getIntegerRow(1), 
getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleListStateRequest(message)
     verify(listState).get()
+
     // Verify that only maxRecordsPerBatch (2) rows are written to the output 
stream while still
     // having 1 row left in the iterator.
-    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto 
response
+    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+    // 2 for rows, 1 for proto message
+    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
   }
 
   test("list state put") {
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setListStatePut(ListStatePut.newBuilder().build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+    verify(transformWithStateInPandasDeserializer).readListElements(any, any)
     verify(listState).put(any)
   }
 
@@ -328,7 +340,7 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setAppendList(AppendList.newBuilder().build()).build()
     stateServer.handleListStateRequest(message)
-    verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+    verify(transformWithStateInPandasDeserializer).readListElements(any, any)
     verify(listState).appendList(any)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to