This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 418cfd1f7801 [SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol
418cfd1f7801 is described below
commit 418cfd1f78014698ac4baac21156341a11b771b3
Author: Jungtaek Lim <[email protected]>
AuthorDate: Mon Apr 7 16:47:44 2025 +0900
[SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol
### What changes were proposed in this pull request?
This PR proposes to stop using Arrow for sending multiple elements of ListState between the Python worker and the JVM, and to replace it with a simple custom protocol.
The proposed custom protocol is very simple and already widely used elsewhere (a sketch follows the list below):
1. Write the size of the element in bytes; if there are no more elements, write -1 instead.
2. Write the element itself (as bytes).
3. Go back to step 1.
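
For illustration only, here is a minimal, self-contained Python sketch of this length-prefixed framing over an in-memory stream. It is not the actual implementation: the real code uses the existing write_int/read_int helpers over the socket file and serializes each element with the state schema, as shown in the diff below (`_send_list_state`/`_read_list_state` on the Python side, `readListElements`/`sendIteratorForListState` on the JVM side).

```python
import struct
from io import BytesIO
from typing import BinaryIO, Iterable, List

def write_elements(stream: BinaryIO, elements: Iterable[bytes]) -> None:
    # For each element: write its length as a 4-byte big-endian int, then the bytes.
    for element in elements:
        stream.write(struct.pack(">i", len(element)))
        stream.write(element)
    # A length of -1 marks the end of the list.
    stream.write(struct.pack(">i", -1))

def read_elements(stream: BinaryIO) -> List[bytes]:
    # Read length-prefixed elements until the -1 terminator.
    elements = []
    while True:
        (length,) = struct.unpack(">i", stream.read(4))
        if length < 0:
            break
        elements.append(stream.read(length))
    return elements

# Round-trip example over an in-memory buffer.
buf = BytesIO()
write_elements(buf, [b"row-1", b"row-2", b"row-3"])
buf.seek(0)
assert read_elements(buf) == [b"row-1", b"row-2", b"row-3"]
```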
Note that this PR only changes ListState - we are aware that Arrow is still used in other state types and other functionality (timers). We want to improve those over time, benchmarking them and addressing them where a latency impact shows up.
### Why are the changes needed?
For a small number of elements, Arrow does not perform well compared to the custom protocol. In our benchmark, three elements are exchanged between the Python worker and the JVM, and replacing Arrow with the custom protocol cut the elapsed time of the state interaction by 1/3.
Given the inherent performance difference between the Scala version (transformWithState) and the PySpark version (transformWithStateInPandas), I expect users to rely on the Scala version for workloads with a noticeable volume. We can position the PySpark version for more lightweight workloads, and revisit this if we see demand for the opposite.
### Does this PR introduce _any_ user-facing change?
No, it's an internal change.
### How was this patch tested?
Existing UTs, with modifications to the mock expectations.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50488 from HeartSaVioR/SPARK-51690.
Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/streaming/list_state_client.py | 46 +++++++---------------
.../sql/streaming/stateful_processor_api_client.py | 20 ++++++++++
.../TransformWithStateInPandasDeserializer.scala | 21 ++++++++++
.../TransformWithStateInPandasStateServer.scala | 30 +++++++++++---
...ransformWithStateInPandasStateServerSuite.scala | 36 +++++++++++------
5 files changed, 105 insertions(+), 48 deletions(-)
diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py
index cb618d1a691b..66f2640c935e 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Dict, Iterator, List, Union, Tuple
+from typing import Any, Dict, Iterator, List, Union, Tuple
from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
-from pyspark.sql.types import StructType, TYPE_CHECKING
+from pyspark.sql.types import StructType
from pyspark.errors import PySparkRuntimeError
import uuid
-if TYPE_CHECKING:
- from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
-
__all__ = ["ListStateClient"]
@@ -38,9 +35,9 @@ class ListStateClient:
self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
else:
self.schema = schema
- # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
+ # A dictionary to store the mapping between list state name and a tuple of data batch
# and the index of the last row that was read.
- self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+ self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}
def exists(self, state_name: str) -> bool:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -67,9 +64,9 @@ class ListStateClient:
def get(self, state_name: str, iterator_id: str) -> Tuple:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
- if iterator_id in self.pandas_df_dict:
+ if iterator_id in self.data_batch_dict:
# If the state is already in the dictionary, return the next row.
- pandas_df, index = self.pandas_df_dict[iterator_id]
+ data_batch, index = self.data_batch_dict[iterator_id]
else:
# If the state is not in the dictionary, fetch the state from the server.
get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -85,33 +82,20 @@ class ListStateClient:
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status == 0:
- iterator = self._stateful_processor_api_client._read_arrow_state()
- # We need to exhaust the iterator here to make sure all the arrow batches are read,
- # even though there is only one batch in the iterator. Otherwise, the stream might
- # block further reads since it thinks there might still be some arrow batches left.
- # We only need to read the first batch in the iterator because it's guaranteed that
- # there would only be one batch sent from the JVM side.
- data_batch = None
- for batch in iterator:
- if data_batch is None:
- data_batch = batch
- if data_batch is None:
- # TODO(SPARK-49233): Classify user facing errors.
- raise PySparkRuntimeError("Error getting next list state row.")
- pandas_df = data_batch.to_pandas()
+ data_batch = self._stateful_processor_api_client._read_list_state()
index = 0
else:
raise StopIteration()
new_index = index + 1
- if new_index < len(pandas_df):
+ if new_index < len(data_batch):
# Update the index in the dictionary.
- self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
+ self.data_batch_dict[iterator_id] = (data_batch, new_index)
else:
- # If the index is at the end of the DataFrame, remove the state from the dictionary.
- self.pandas_df_dict.pop(iterator_id, None)
- pandas_row = pandas_df.iloc[index]
- return tuple(pandas_row)
+ # If the index is at the end of the data batch, remove the state from the dictionary.
+ self.data_batch_dict.pop(iterator_id, None)
+ row = data_batch[index]
+ return tuple(row)
def append_value(self, state_name: str, value: Tuple) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -143,7 +127,7 @@ class ListStateClient:
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
- self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+ self._stateful_processor_api_client._send_list_state(self.schema, values)
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
@@ -160,7 +144,7 @@ class ListStateClient:
self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
- self._stateful_processor_api_client._send_arrow_state(self.schema, values)
+ self._stateful_processor_api_client._send_list_state(self.schema, values)
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index fa50ed00738c..50945198f9c4 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -467,6 +467,26 @@ class StatefulProcessorApiClient:
def _read_arrow_state(self) -> Any:
return self.serializer.load_stream(self.sockfile)
+ def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
+ for value in state:
+ bytes = self._serialize_to_bytes(schema, value)
+ length = len(bytes)
+ write_int(length, self.sockfile)
+ self.sockfile.write(bytes)
+
+ write_int(-1, self.sockfile)
+ self.sockfile.flush()
+
+ def _read_list_state(self) -> List[Any]:
+ data_array = []
+ while True:
+ length = read_int(self.sockfile)
+ if length < 0:
+ break
+ bytes = self.sockfile.read(length)
+ data_array.append(self._deserialize_from_bytes(bytes))
+ return data_array
+
# Parse a string schema into a StructType schema. This method will perform an API call to
# JVM side to parse the schema string.
def _parse_string_schema(self, schema: str) -> StructType:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
index 1a8ffb35c053..b38697aeb021 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Des
reader.close(false)
rows.toSeq
}
+
+ def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = {
+ val rows = new scala.collection.mutable.ArrayBuffer[Row]
+
+ var endOfLoop = false
+ while (!endOfLoop) {
+ val size = stream.readInt()
+ if (size < 0) {
+ endOfLoop = true
+ } else {
+ val bytes = new Array[Byte](size)
+ stream.read(bytes, 0, size)
+ val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema,
+ listStateInfo.deserializer)
+ rows.append(newRow)
+ }
+ }
+
+ rows.toSeq
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
index 6a194976257e..f46b66204383 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
@@ -475,7 +475,7 @@ class TransformWithStateInPandasStateServer(
sendResponse(2, s"state $stateName doesn't exist")
}
case ListStateCall.MethodCase.LISTSTATEPUT =>
- val rows = deserializer.readArrowBatches(inputStream)
+ val rows = deserializer.readListElements(inputStream, listStateInfo)
listStateInfo.listState.put(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -487,12 +487,10 @@ class TransformWithStateInPandasStateServer(
}
if (!iteratorOption.get.hasNext) {
sendResponse(2, s"List state $stateName doesn't contain any value.")
- return
} else {
sendResponse(0)
+ sendIteratorForListState(iteratorOption.get)
}
- sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
- arrowStreamWriterForTest) { data => listStateInfo.serializer(data)}
case ListStateCall.MethodCase.APPENDVALUE =>
val byteArray = message.getAppendValue.getValue.toByteArray
val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema,
@@ -500,7 +498,7 @@ class TransformWithStateInPandasStateServer(
listStateInfo.listState.appendValue(newRow)
sendResponse(0)
case ListStateCall.MethodCase.APPENDLIST =>
- val rows = deserializer.readArrowBatches(inputStream)
+ val rows = deserializer.readListElements(inputStream, listStateInfo)
listStateInfo.listState.appendList(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.CLEAR =>
@@ -511,6 +509,28 @@ class TransformWithStateInPandasStateServer(
}
}
+ private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
+ // Only write a single batch in each GET request. Stops writing row if rowCount reaches
+ // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This is to handle a case
+ // when there are multiple state variables, user tries to access a different state variable
+ // while the current state variable is not exhausted yet.
+ var rowCount = 0
+ while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+ val data = iter.next()
+
+ // Serialize the value row as a byte array
+ val valueBytes = PythonSQLUtils.toPyRow(data)
+ val lenBytes = valueBytes.length
+
+ outputStream.writeInt(lenBytes)
+ outputStream.write(valueBytes)
+
+ rowCount += 1
+ }
+ outputStream.writeInt(-1)
+ outputStream.flush()
+ }
+
private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
val stateName = message.getStateName
if (!mapStates.contains(stateName)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
index 1f0aa72d2713..305a520f6af8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
@@ -103,6 +103,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap,
expiryTimerIter, listTimerMap)
when(transformWithStateInPandasDeserializer.readArrowBatches(any))
.thenReturn(Seq(getIntegerRow(1)))
+ when(transformWithStateInPandasDeserializer.readListElements(any, any))
+ .thenReturn(Seq(getIntegerRow(1)))
}
test("set handle state") {
@@ -260,8 +262,10 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
- verify(arrowStreamWriter).writeRow(any)
- verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ // 1 for row, 1 for end of the data, 1 for proto response
+ verify(outputStream, times(3)).writeInt(any)
+ // 1 for sending an actual row, 1 for sending proto message
+ verify(outputStream, times(2)).write(any[Array[Byte]])
}
test("list state get - iterator in map with multiple batches") {
@@ -278,15 +282,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
// First call should send 2 records.
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
- verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
- verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+ verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+ // maxRecordsPerBatch times for rows, 1 for sending proto message
+ verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
// Second call should send the remaining 2 records.
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
- // Since Mockito's verify counts the total number of calls, the expected number of writeRow call
- // should be 2 * maxRecordsPerBatch.
- verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
- verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
+ // Since Mockito's verify counts the total number of calls, the expected number of writeInt
+ // and write should be accumulated from the prior count; the number of calls are the same
+ // with prior one.
+ // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+ verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
+ // maxRecordsPerBatch times for rows, 1 for sending proto message
+ verify(outputStream, times(maxRecordsPerBatch * 2 + 2)).write(any[Array[Byte]])
}
test("list state get - iterator not in map") {
@@ -302,17 +311,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
when(listState.get()).thenReturn(Iterator(getIntegerRow(1),
getIntegerRow(2), getIntegerRow(3)))
stateServer.handleListStateRequest(message)
verify(listState).get()
+
// Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
// having 1 row left in the iterator.
- verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
- verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+ // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto response
+ verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+ // 2 for rows, 1 for proto message
+ verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
}
test("list state put") {
val message = ListStateCall.newBuilder().setStateName(stateName)
.setListStatePut(ListStatePut.newBuilder().build()).build()
stateServer.handleListStateRequest(message)
- verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+ verify(transformWithStateInPandasDeserializer).readListElements(any, any)
verify(listState).put(any)
}
@@ -328,7 +340,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = ListStateCall.newBuilder().setStateName(stateName)
.setAppendList(AppendList.newBuilder().build()).build()
stateServer.handleListStateRequest(message)
- verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
+ verify(transformWithStateInPandasDeserializer).readListElements(any, any)
verify(listState).appendList(any)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]