This is an automated email from the ASF dual-hosted git repository.

dtenedor pushed a commit to branch branch-4.x
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.x by this push:
     new 6bc27159eb1d [SPARK-56324] Introducing message-based communication to Spark -> PySpark communication channel
6bc27159eb1d is described below

commit 6bc27159eb1d5a918eacbd293989a48fc3068637
Author: Sven Weber <[email protected]>
AuthorDate: Wed May 27 10:37:08 2026 -0700

    [SPARK-56324] Introducing message-based communication to Spark -> PySpark communication channel
    
    ### What changes were proposed in this pull request?
    
    This is the second in a series of PRs that introduce message-based communication to PySpark UDFs. This initiative is part of [SPIP SPARK-55278](https://issues.apache.org/jira/browse/SPARK-55278), which proposes language-agnostic UDFs. This PR builds on top of the changes from [PR #55515](https://github.com/apache/spark/pull/55515).
    
    The goal of introducing message-based communication to PySpark is to:
    
    1. Make the communication between Spark and PySpark more structured.
    2. Enable new communication protocols (e.g., gRPC) transparently.
    
    The overall goal is to introduce a second communication channel while keeping the existing channel intact. Specifically, we want to introduce gRPC in addition to Unix Domain Sockets (UDS). The existing UDS channel will not be changed, and its characteristics, including performance, will remain untouched.
    
    Specifically, this PR proposes the following changes:
    
    1. `PythonRunner.scala`: Add a new message header and a length field to the initialization message sent from Spark to PySpark. This change is required to distinguish the initial message from later messages. It is the only required change to the Spark -> PySpark wire protocol.
    2. Add new abstractions to read Spark -> PySpark messages, including the new init message, from the existing socket channel.
    3. Change `worker.py` to use the new socket message reader to process the UDF request.
    4. Update the wire protocol used by the PySpark benchmarks to include the header and length field introduced in this PR.
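    The framing from item 1 can be illustrated with a minimal Python sketch. It is not the code from this PR: the helper names `frame_init_message` and `read_init_message` are hypothetical, and the sketch assumes 4-byte big-endian signed integers, which is what Java's `DataOutputStream.writeInt` emits and what PySpark's `write_int`/`read_int` use. The marker value -8 corresponds to `SpecialLengths.START_OF_INIT_MESSAGE`, and data bytes follow the init message without any extra header.

```python
import io
import struct
from typing import BinaryIO

# Value of SpecialLengths.START_OF_INIT_MESSAGE introduced by this PR.
START_OF_INIT_MESSAGE = -8


def frame_init_message(payload: bytes, out: BinaryIO) -> None:
    """Hypothetical helper: write header, length, then the buffered payload."""
    # ">i" = big-endian signed 32-bit int, the layout DataOutputStream.writeInt uses.
    out.write(struct.pack(">i", START_OF_INIT_MESSAGE))
    out.write(struct.pack(">i", len(payload)))
    out.write(payload)


def read_init_message(stream: BinaryIO) -> bytes:
    """Hypothetical helper: validate the header and return exactly the payload."""
    (marker,) = struct.unpack(">i", stream.read(4))
    if marker != START_OF_INIT_MESSAGE:
        raise ValueError(f"Expected init message marker, got {marker}")
    (length,) = struct.unpack(">i", stream.read(4))
    payload = stream.read(length)
    if len(payload) != length:
        raise EOFError("Stream ended before the init message was complete")
    return payload


# Usage: the init message is followed directly by unframed data bytes.
buf = io.BytesIO()
frame_init_message(b"init-bytes", buf)
buf.write(b"data-bytes")
buf.seek(0)
init = read_init_message(buf)
rest = buf.read()  # everything after the init message is data
```

    Because the length is known up front, a reader can take the complete init payload in one read (and, as `SparkSocketMessageReceiver` does, wrap it in a `ZeroCopyByteStream`) without parsing it field by field from the socket.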
    
    With these changes, a new message reader can be implemented and used transparently for other transport channels (e.g., gRPC).
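    The following standalone sketch illustrates how a new transport could plug in behind the same three-phase contract (init -> data stream -> finish, each exactly once). It mirrors the behavior of `SparkMessageReceiver` in simplified form and does not import PySpark; `MessageReceiver`, `InMemoryReceiver`, and `run_worker` are hypothetical names, and `InMemoryReceiver` merely stands in for a transport such as gRPC that delivers each message as a whole byte string.

```python
import io
from abc import ABC, abstractmethod
from enum import Enum
from typing import BinaryIO, Callable, TypeVar

R = TypeVar("R")


class State(Enum):
    WAITING_FOR_INIT = 1
    WAITING_FOR_DATA = 2
    WAITING_FOR_FINISH = 3
    DONE = 4


class MessageReceiver(ABC):
    """Hypothetical, simplified stand-in for SparkMessageReceiver."""

    def __init__(self) -> None:
        self._state = State.WAITING_FOR_INIT

    def _transition(self, required: State, nxt: State, fn: Callable[[], R]) -> R:
        # Check the current state, run the step, and only then advance,
        # so a failing step leaves the state unchanged.
        assert self._state == required, f"expected {required}, but in {self._state}"
        result = fn()
        self._state = nxt
        return result

    def get_init_message(self) -> bytes:
        return self._transition(
            State.WAITING_FOR_INIT, State.WAITING_FOR_DATA, self._do_get_init_message
        )

    def get_data_stream(self) -> BinaryIO:
        return self._transition(
            State.WAITING_FOR_DATA, State.WAITING_FOR_FINISH, self._do_get_data_stream
        )

    def get_finish_signal(self) -> None:
        self._transition(State.WAITING_FOR_FINISH, State.DONE, self._do_get_finish_signal)

    @abstractmethod
    def _do_get_init_message(self) -> bytes: ...

    @abstractmethod
    def _do_get_data_stream(self) -> BinaryIO: ...

    @abstractmethod
    def _do_get_finish_signal(self) -> None: ...


class InMemoryReceiver(MessageReceiver):
    """Hypothetical transport whose messages are already in memory."""

    def __init__(self, init: bytes, data: bytes, finished: bool = True) -> None:
        super().__init__()
        self._init, self._data, self._finished = init, data, finished

    def _do_get_init_message(self) -> bytes:
        return self._init

    def _do_get_data_stream(self) -> BinaryIO:
        return io.BytesIO(self._data)

    def _do_get_finish_signal(self) -> None:
        assert self._finished, "finish signal missing"


def run_worker(receiver: MessageReceiver) -> bytes:
    """Transport-agnostic worker loop: it only uses the receiver interface."""
    init = receiver.get_init_message()
    data = receiver.get_data_stream().read()
    receiver.get_finish_signal()
    return init + b":" + data.upper()


result = run_worker(InMemoryReceiver(b"cfg", b"rows"))  # b"cfg:ROWS"
```

    Only the receiver subclass knows about the transport; `run_worker` stays unchanged, which is the property that lets `invoke_udf` in `worker.py` accept any `SparkMessageReceiver`.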
    
    ### Why are the changes needed?
    
    The changes introduced here make PySpark transport-layer agnostic for the Spark -> PySpark channel. This is required for PySpark to support the new language-agnostic UDF protocol proposed in [SPIP SPARK-55278](https://issues.apache.org/jira/browse/SPARK-55278). Follow-up PRs will address PySpark -> Spark communication in a similar manner.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    ### Correctness
    
    Existing test suites:
    
    PySpark
    ```bash
    pyspark.tests.test_worker
    pyspark.sql.tests.test_udf
    pyspark.sql.tests.test_udtf
    pyspark.sql.tests.pandas.test_pandas_udf_scalar
    pyspark.sql.tests.arrow.test_arrow_udf_scalar
    pyspark.sql.tests.arrow.test_arrow_udf
    pyspark.sql.tests.arrow.test_arrow_grouped_map
    pyspark.sql.tests.arrow.test_arrow_cogrouped_map
    pyspark.tests.test_taskcontext
    pyspark.sql.tests.test_python_datasource
    ```
    
    Spark
    ```bash
    org.apache.spark.sql.execution.python.PythonUDFSuite
    org.apache.spark.sql.execution.python.PythonUDTFSuite
    org.apache.spark.sql.execution.python.ArrowColumnarPythonUDFSuite
    org.apache.spark.sql.execution.python.BatchEvalPythonExecSuite
    org.apache.spark.sql.execution.python.PythonDataSourceSuite
    org.apache.spark.sql.execution.python.PythonWorkerLogsSuite
    ```
    
    ### Performance
    
    To validate that the proposed changes do not introduce a performance regression, the PySpark benchmarks were run before and after this change and compared with:
    ```
    ./python/asv compare 6e4d9d7147afdc8d160a72ada9d57edbb86ec138 8ba30ed2ec1eaae6fb530896f74612574ba7de1c --only-changed
    ```
    
    The comparison produced no output, i.e., **no performance regression was found**.
    
    A detailed comparison of all benchmark suites, including their runtimes and ratios, can be found in [this gist](https://gist.github.com/sven-weber-db/18b0b97d72e17fb49693a3df320c3df5).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #55716 from sven-weber-db/sven-weber_data/spark-56324.
    
    Authored-by: Sven Weber <[email protected]>
    Signed-off-by: Daniel Tenedorio <[email protected]>
    (cherry picked from commit 28904a3b15830c30e450c9eebbf690fa82d872d0)
    Signed-off-by: Daniel Tenedorio <[email protected]>
---
 .../org/apache/spark/api/python/PythonRunner.scala |  14 ++-
 dev/sparktestsupport/modules.py                    |   1 +
 python/benchmarks/bench_eval_type.py               |  64 +++++++----
 python/packaging/classic/setup.py                  |   2 +
 python/packaging/client/setup.py                   |   2 +
 python/pyspark/messages/__init__.py                |   4 +
 python/pyspark/messages/{ => socket}/__init__.py   |   6 -
 .../socket/spark_socket_message_receiver.py        |  62 +++++++++++
 python/pyspark/messages/spark_message_receiver.py  | 122 +++++++++++++++++++++
 python/pyspark/serializers.py                      |   4 +-
 python/pyspark/taskcontext.py                      |   2 +-
 .../pyspark/tests/test_spark_message_receiver.py   |  79 +++++++++++++
 python/pyspark/worker.py                           |  49 +++++++--
 python/pyspark/worker_message.py                   |  27 +++--
 python/pyspark/worker_util.py                      |  16 ++-
 15 files changed, 401 insertions(+), 53 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 4e6089b50a46..1524ff455ec7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -415,13 +415,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
      */
     def writeNextInputToStream(dataOut: DataOutputStream): Boolean
 
-    def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
+    def open(outputStream: DataOutputStream): Unit = Utils.logUncaughtExceptions {
       val isUnixDomainSock = authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
       lazy val sockPath = new File(
         authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
           .getOrElse(System.getProperty("java.io.tmpdir")),
         s".${UUID.randomUUID()}.sock")
       try {
+        // Buffer the initialization message, and send it together with its length.
+        val buffer = new ByteArrayOutputStream()
+        val dataOut = new DataOutputStream(buffer)
+
         // Partition index
         dataOut.writeInt(partitionIndex)
 
@@ -522,6 +526,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         writeCommand(dataOut)
 
         dataOut.flush()
+
+        // The initialization message is complete, write it to the stream with its length.
+        val messageBytes = buffer.toByteArray
+        outputStream.writeInt(SpecialLengths.START_OF_INIT_MESSAGE)
+        outputStream.writeInt(messageBytes.length)
+        outputStream.write(messageBytes)
+        outputStream.flush()
       } catch {
         case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>
           if (context.isCompleted() || context.isInterrupted()) {
@@ -1085,6 +1096,7 @@ private[spark] object SpecialLengths {
   val NULL = -5
   val START_ARROW_STREAM = -6
   val END_OF_MICRO_BATCH = -7
+  val START_OF_INIT_MESSAGE = -8
 }
 
 private[spark] object BarrierTaskContextMessageProtocol {
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index a03732a3554e..286a0c35b27e 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -514,6 +514,7 @@ pyspark_core = Module(
         "pyspark.tests.test_readwrite",
         "pyspark.tests.test_serializers",
         "pyspark.tests.test_shuffle",
+        "pyspark.tests.test_spark_message_receiver",
         "pyspark.tests.test_statcounter",
         "pyspark.tests.test_taskcontext",
         "pyspark.tests.test_util",
diff --git a/python/benchmarks/bench_eval_type.py b/python/benchmarks/bench_eval_type.py
index 14020d6f22ec..c75e4490d1ed 100644
--- a/python/benchmarks/bench_eval_type.py
+++ b/python/benchmarks/bench_eval_type.py
@@ -35,7 +35,7 @@ import numpy as np
 import pyarrow as pa
 
 from pyspark.cloudpickle import dumps as cloudpickle_dumps
-from pyspark.serializers import write_int, write_long
+from pyspark.serializers import write_int, write_long, SpecialLengths
 from pyspark.sql.types import (
     BinaryType,
     BooleanType,
@@ -127,6 +127,46 @@ class MockProtocolWriter:
         cls.write_bool(False, buf)  # needs_broadcast_decryption_server
         write_int(0, buf)  # num_broadcast_variables
 
+    @classmethod
+    def write_init_message(
+        cls,
+        eval_type: int,
+        write_udf: Callable[[io.BufferedIOBase], None],
+        target_buffer: io.BytesIO,
+        runner_conf: dict[str, str] | None = None,
+        eval_conf: dict[str, str] | None = None,
+    ) -> None:
+        """Write the initial message with header, length + its data."""
+
+        # Write everything to a seperate buffer so we can
+        # determine the length of the initial message.
+        buf = io.BytesIO()
+        cls.write_preamble(buf)
+        write_int(eval_type, buf)
+        if runner_conf:
+            write_int(len(runner_conf), buf)
+            for k, v in runner_conf.items():
+                cls.write_utf8(k, buf)
+                cls.write_utf8(v, buf)
+        else:
+            write_int(0, buf)  # RunnerConf  (0 key-value pairs)
+        if eval_conf:
+            write_int(len(eval_conf), buf)
+            for k, v in eval_conf.items():
+                cls.write_utf8(k, buf)
+                cls.write_utf8(v, buf)
+        else:
+            write_int(0, buf)  # EvalConf    (0 key-value pairs)
+        write_udf(buf)
+
+        # Write the actual data
+        # header...
+        write_int(SpecialLengths.START_OF_INIT_MESSAGE, target_buffer)
+        write_int(buf.getbuffer().nbytes, target_buffer)
+        # ... + previously buffered data
+        buf.seek(0)
+        target_buffer.write(buf.read())
+
     @classmethod
     def write_udf_payload(
         cls,
@@ -165,25 +205,11 @@ class MockProtocolWriter:
         eval_conf: dict[str, str] | None = None,
     ) -> None:
         """Write the full worker binary stream: preamble + command + data + end."""
-        cls.write_preamble(buf)
-        write_int(eval_type, buf)
-        if runner_conf:
-            write_int(len(runner_conf), buf)
-            for k, v in runner_conf.items():
-                cls.write_utf8(k, buf)
-                cls.write_utf8(v, buf)
-        else:
-            write_int(0, buf)  # RunnerConf  (0 key-value pairs)
-        if eval_conf:
-            write_int(len(eval_conf), buf)
-            for k, v in eval_conf.items():
-                cls.write_utf8(k, buf)
-                cls.write_utf8(v, buf)
-        else:
-            write_int(0, buf)  # EvalConf    (0 key-value pairs)
-        write_udf(buf)
+        cls.write_init_message(
+            eval_type, write_udf, buf, runner_conf=runner_conf, eval_conf=eval_conf
+        )
         write_data(buf)
-        write_int(-4, buf)  # SpecialLengths.END_OF_STREAM
+        write_int(SpecialLengths.END_OF_STREAM, buf)
 
     @classmethod
     def write_arrow_udtf_payload(
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index bee34e06d5c3..54b7707dbc44 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -267,6 +267,8 @@ try:
             "pyspark",
             "pyspark.core",
             "pyspark.cloudpickle",
+            "pyspark.messages",
+            "pyspark.messages.socket",
             "pyspark.mllib",
             "pyspark.mllib.linalg",
             "pyspark.mllib.stat",
diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py
index 564478af8cd8..9c7df16d9537 100755
--- a/python/packaging/client/setup.py
+++ b/python/packaging/client/setup.py
@@ -148,6 +148,8 @@ try:
     connect_packages = [
         "pyspark",
         "pyspark.cloudpickle",
+        "pyspark.messages",
+        "pyspark.messages.socket",
         "pyspark.mllib",
         "pyspark.mllib.linalg",
         "pyspark.mllib.stat",
diff --git a/python/pyspark/messages/__init__.py b/python/pyspark/messages/__init__.py
index ccb7b9323257..69cfbf6bd53a 100644
--- a/python/pyspark/messages/__init__.py
+++ b/python/pyspark/messages/__init__.py
@@ -15,8 +15,12 @@
 # limitations under the License.
 #
 
+from pyspark.messages.spark_message_receiver import SparkMessageReceiver
 from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+from pyspark.messages.socket.spark_socket_message_receiver import SparkSocketMessageReceiver
 
 __all__ = [
+    "SparkMessageReceiver",
+    "SparkSocketMessageReceiver",
     "ZeroCopyByteStream",
 ]
diff --git a/python/pyspark/messages/__init__.py b/python/pyspark/messages/socket/__init__.py
similarity index 87%
copy from python/pyspark/messages/__init__.py
copy to python/pyspark/messages/socket/__init__.py
index ccb7b9323257..cce3acad34a4 100644
--- a/python/pyspark/messages/__init__.py
+++ b/python/pyspark/messages/socket/__init__.py
@@ -14,9 +14,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
-from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
-
-__all__ = [
-    "ZeroCopyByteStream",
-]
diff --git a/python/pyspark/messages/socket/spark_socket_message_receiver.py b/python/pyspark/messages/socket/spark_socket_message_receiver.py
new file mode 100644
index 000000000000..de3db2fe6611
--- /dev/null
+++ b/python/pyspark/messages/socket/spark_socket_message_receiver.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import BinaryIO
+
+from pyspark.serializers import read_int, SpecialLengths
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+from pyspark.messages.spark_message_receiver import (
+    SparkMessageReceiver,
+)
+
+
+def _assert_message_id(message_id: int, expected: int) -> None:
+    assert message_id == expected, (
+        f"Expected message with id {expected} but got message with id {message_id} instead."
+    )
+
+
+class SparkSocketMessageReceiver(SparkMessageReceiver):
+    def __init__(self, infile: BinaryIO):
+        super().__init__()
+        self._infile = infile
+
+    def _do_get_init_message(self) -> ZeroCopyByteStream:
+        message_id = read_int(self._infile)
+        _assert_message_id(message_id, SpecialLengths.START_OF_INIT_MESSAGE)
+
+        # Read the length and init content
+        message_length = read_int(self._infile)
+        message_content = self._infile.read(message_length)
+
+        return ZeroCopyByteStream(memoryview(message_content))
+
+    def _do_get_data_stream(self) -> BinaryIO:
+        # For socket communication, we just pass along the underlying socket
+        # for the data channel. We already stripped the initialization data
+        # at this state. Therefore, any bytes following this are data bytes.
+        #
+        # Note: We deliberately did not introduce a message header for
+        # data messages to reduce the overhead, especially for small
+        # batch sizes and real-time-mode (RTM).
+        return self._infile
+
+    def _do_get_finish_signal_from_stream(self) -> None:
+        # If everything finished properly, we should read END_OF_STREAM.
+        # Anything else means something went wrong during processing.
+        message_id = read_int(self._infile)
+        _assert_message_id(message_id, SpecialLengths.END_OF_STREAM)
diff --git a/python/pyspark/messages/spark_message_receiver.py b/python/pyspark/messages/spark_message_receiver.py
new file mode 100644
index 000000000000..ec6b6fc30624
--- /dev/null
+++ b/python/pyspark/messages/spark_message_receiver.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+from functools import wraps
+from typing import BinaryIO, Callable, TypeVar
+from abc import ABC, abstractmethod
+
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+
+
+T = TypeVar("T", bound="SparkMessageReceiver")
+R = TypeVar("R")
+
+
+class MessageState(Enum):
+    WAITING_FOR_INIT = 1
+    WAITING_FOR_DATA = 2
+    WAITING_FOR_FINISH = 3
+    DONE = 4
+
+
+class SparkMessageReceiver(ABC):
+    """
+    Generic class that implements receiving messages from Spark.
+    Caution: This class is STATEFUL. It is expected, that the
+    methods of this class are called in the following order:
+
+    1. Init -> 2. Data stream -> 3. Finish
+
+    This order is verified using assertions in the class. Each function
+    can be called EXACTLY ONCE in the specified order.
+    """
+
+    def __init__(self) -> None:
+        self._state = MessageState.WAITING_FOR_INIT
+
+    @staticmethod
+    def _state_transition(
+        required_state: MessageState, next_state: MessageState
+    ) -> Callable[[Callable[[T], R]], Callable[[T], R]]:
+        """Decorator to enforce state transitions."""
+
+        def decorator(func: Callable[[T], R]) -> Callable[[T], R]:
+            @wraps(func)
+            def wrapper(self: T) -> R:
+                assert self._state == required_state
+                result = func(self)
+                self._state = next_state
+                return result
+
+            return wrapper
+
+        return decorator
+
+    @_state_transition(MessageState.WAITING_FOR_INIT, MessageState.WAITING_FOR_DATA)
+    def get_init_message(self) -> ZeroCopyByteStream:
+        """
+        Returns:
+            the binary contents of the initial message as a ZeroCopyByteStream.
+        """
+        return self._do_get_init_message()
+
+    @_state_transition(MessageState.WAITING_FOR_DATA, MessageState.WAITING_FOR_FINISH)
+    def get_data_stream(self) -> BinaryIO:
+        """
+        Returns:
+            A binary stream containing the data to invoke the UDF on.
+        """
+        return self._do_get_data_stream()
+
+    @_state_transition(MessageState.WAITING_FOR_FINISH, MessageState.DONE)
+    def get_finish_signal_from_stream(self) -> None:
+        """
+        Consumes the finish message from the JVM and transitions to the DONE state.
+        The finish message marks the end of the stream. Raises an `AssertionError` if
+        the finish signal was not received correctly.
+        """
+        self._do_get_finish_signal_from_stream()
+
+    @abstractmethod
+    def _do_get_init_message(self) -> ZeroCopyByteStream:
+        """
+        Returns the contents of the init message
+        as a 'ZeroCopyByteStream'.
+
+        To be implemented by child classes.
+        """
+        ...
+
+    @abstractmethod
+    def _do_get_data_stream(self) -> BinaryIO:
+        """
+        Returns the Spark data stream.
+
+        To be implemented by child classes.
+        """
+        ...
+
+    @abstractmethod
+    def _do_get_finish_signal_from_stream(self) -> None:
+        """
+        Consumes the finish signal from the stream.
+        Raises an `AssertionError` if the signal is not received correctly.
+
+        To be implemented by child classes.
+        """
+        ...
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 6de64a1062f0..48166c948b5b 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -63,6 +63,7 @@ import collections
 import zlib
 import itertools
 import pickle
+import codecs
 
 pickle_protocol = pickle.HIGHEST_PROTOCOL
 
@@ -84,6 +85,7 @@ class SpecialLengths:
     END_OF_STREAM = -4
     NULL = -5
     START_ARROW_STREAM = -6
+    START_OF_INIT_MESSAGE = -8
 
 
 class Serializer:
@@ -539,7 +541,7 @@ class UTF8Deserializer(Serializer):
         elif length == SpecialLengths.NULL:
             return None
         s = stream.read(length)
-        return s.decode("utf-8") if self.use_unicode else s
+        return codecs.decode(s, "utf-8") if self.use_unicode else s
 
     def load_stream(self, stream):
         try:
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 7ac645323243..50cba601321b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -161,7 +161,7 @@ class TaskContext:
         return cls._taskContext
 
     @classmethod
-    def _setTaskContext(cls: Type["TaskContext"], taskContext: "TaskContext") -> None:
+    def _setTaskContext(cls: Type["TaskContext"], taskContext: Optional["TaskContext"]) -> None:
         cls._taskContext = taskContext
 
     @classmethod
diff --git a/python/pyspark/tests/test_spark_message_receiver.py b/python/pyspark/tests/test_spark_message_receiver.py
new file mode 100644
index 000000000000..398bae922e5b
--- /dev/null
+++ b/python/pyspark/tests/test_spark_message_receiver.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import io
+import unittest
+from typing import BinaryIO
+
+from pyspark.messages.spark_message_receiver import SparkMessageReceiver
+from pyspark.messages.zero_copy_byte_stream import ZeroCopyByteStream
+
+
+class StubMessageReceiver(SparkMessageReceiver):
+    """Concrete stub for testing the state machine in SparkMessageReceiver."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _do_get_init_message(self) -> ZeroCopyByteStream:
+        return ZeroCopyByteStream(memoryview(b"init"))
+
+    def _do_get_data_stream(self) -> BinaryIO:
+        return io.BytesIO(b"data")
+
+    def _do_get_finish_signal_from_stream(self) -> None:
+        pass
+
+
+class SparkMessageReceiverTests(unittest.TestCase):
+    """Tests for SparkMessageReceiver state transitions."""
+
+    def test_happy_path(self):
+        """Calling init -> data -> finish in order succeeds."""
+        receiver = StubMessageReceiver()
+        init_msg = receiver.get_init_message()
+        self.assertIsInstance(init_msg, ZeroCopyByteStream)
+        data = receiver.get_data_stream()
+        self.assertEqual(data.read(), b"data")
+        # Should not raise
+        receiver.get_finish_signal_from_stream()
+
+    def test_invalid_transitions_fail(self):
+        """Calling methods out of order raises AssertionError."""
+        # Each entry: (setup_calls, invalid_call)
+        cases = [
+            ([], "get_data_stream"),
+            ([], "get_finish_signal_from_stream"),
+            (["get_init_message"], "get_init_message"),
+            (["get_init_message", "get_data_stream"], "get_data_stream"),
+            (
+                ["get_init_message", "get_data_stream", "get_finish_signal_from_stream"],
+                "get_finish_signal_from_stream",
+            ),
+        ]
+        for setup_calls, invalid_call in cases:
+            with self.subTest(setup=setup_calls, invalid=invalid_call):
+                receiver = StubMessageReceiver()
+                for call in setup_calls:
+                    getattr(receiver, call)()
+                with self.assertRaises(AssertionError):
+                    getattr(receiver, invalid_call)()
+
+
+if __name__ == "__main__":
+    from pyspark.testing import main
+
+    main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 95980a6842ba..ad87e0570c13 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,6 +39,7 @@ from typing import (
     Union,
     get_args,
     get_origin,
+    BinaryIO,
 )
 
 T = TypeVar("T")
@@ -61,7 +62,6 @@ from pyspark.util import PythonEvalType
 from pyspark.serializers import (
     write_int,
     write_long,
-    read_int,
     SpecialLengths,
     CPickleSerializer,
     BatchedSerializer,
@@ -123,6 +123,10 @@ from pyspark.worker_util import (
     Conf,
 )
 from pyspark.logger.worker_io import capture_outputs
+from pyspark.messages import (
+    SparkMessageReceiver,
+    SparkSocketMessageReceiver,
+)
 
 
 class RunnerConf(Conf):
@@ -3510,11 +3514,20 @@ def read_udfs(pickleSer, udf_info_list, eval_type, runner_conf, eval_conf):
     return func, None, ser, ser
 
 
-@with_faulthandler
-def main(infile, outfile):
+def invoke_udf(message_receiver: SparkMessageReceiver, outfile: BinaryIO):
+    """
+    This function is the main processing function for worker.py.
+    It receives messages from the JVM, processes the data, and sends back results.
+    This method goes through three phases:
+
+    Initialization -> Processing -> Finish/Cleanup
+    """
     try:
         boot_time = time.time()
-        init_info = WorkerInitInfo.from_stream(infile)
+        # Initialization
+        init_message = message_receiver.get_init_message()
+        init_info = WorkerInitInfo.from_stream(init_message)
+
         start_faulthandler_periodic_traceback()
         check_python_version(init_info.python_version)
 
@@ -3538,6 +3551,7 @@ def main(infile, outfile):
         runner_conf = RunnerConf(init_info.runner_conf)
         eval_conf = EvalConf(init_info.eval_conf)
         if eval_type == PythonEvalType.NON_UDF:
+            assert isinstance(init_info.udf_info, (bytes, memoryview))
             func, profiler, deserializer, serializer = read_command(pickleSer, init_info.udf_info)
         elif eval_type in (
             PythonEvalType.SQL_TABLE_UDF,
@@ -3554,8 +3568,13 @@ def main(infile, outfile):
 
         init_time = time.time()
 
+        # Processing
+
+        # Fetch the input data stream
+        input_data_stream = message_receiver.get_data_stream()
+
         def process():
-            iterator = deserializer.load_stream(infile)
+            iterator = deserializer.load_stream(input_data_stream)
             out_iter = func(init_info.split_index, iterator)
             try:
                 serializer.dump_stream(out_iter, outfile)
@@ -3571,6 +3590,7 @@ def main(infile, outfile):
                 process()
         processing_time_ms = int(1000 * (time.time() - processing_start_time))
 
+        # Cleanup
         # Reset task context to None. This is a guard code to avoid residual context when worker
         # reuse.
         TaskContext._setTaskContext(None)
@@ -3587,15 +3607,26 @@ def main(infile, outfile):
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     send_accumulator_updates(outfile)
 
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+    # Check end of stream — raises if the finish signal is not received correctly.
+    # Note: this call might fail due to other reasons (e.g. channel broke)
+    # which will terminate the worker process.
+    try:
+        message_receiver.get_finish_signal_from_stream()
         write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
+    except Exception:
+        # Write a different value to tell JVM to not reuse this worker
         write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
         sys.exit(-1)
 
 
+@with_faulthandler
+def main(infile, outfile):
+    # Instantiate socket message readers for executing the UDF
+    socket_reader = SparkSocketMessageReceiver(infile)
+
+    invoke_udf(socket_reader, outfile)
+
+
 if __name__ == "__main__":
     with get_sock_file_to_executor() as sock_file:
         main(sock_file, sock_file)
diff --git a/python/pyspark/worker_message.py b/python/pyspark/worker_message.py
index b1519cb08442..c037e728088b 100644
--- a/python/pyspark/worker_message.py
+++ b/python/pyspark/worker_message.py
@@ -18,13 +18,14 @@
 import dataclasses
 import json
 import sys
-from typing import Optional, Union, IO
+from typing import Optional, TypeAlias, Union, IO, Any
 
 from pyspark.errors import PySparkValueError
 from pyspark.serializers import read_bool, read_int, read_long, SpecialLengths
 from pyspark.taskcontext import BarrierTaskContext, ResourceInformation, TaskContext
 from pyspark.util import PythonEvalType
 from pyspark.worker_util import utf8_deserializer
+from pyspark.messages import ZeroCopyByteStream
 
 
 @dataclasses.dataclass
@@ -46,7 +47,7 @@ class TaskContextInfo:
     local_properties: dict[str, str]
 
     @classmethod
-    def from_stream(cls, stream: IO) -> "TaskContextInfo":
+    def from_stream(cls, stream: ZeroCopyByteStream) -> "TaskContextInfo":
         task_context_json = json.loads(utf8_deserializer.loads(stream))
         return cls(
             is_barrier=task_context_json["isBarrier"],
@@ -100,7 +101,7 @@ class BroadcastInfo:
     variables: list[tuple[int, Optional[str]]]
 
     @classmethod
-    def from_stream(cls, stream: IO) -> "BroadcastInfo":
+    def from_stream(cls, stream: Union[ZeroCopyByteStream, IO[Any]]) -> "BroadcastInfo":
         needs_broadcast_decryption_server = read_bool(stream)
         num_broadcast_variables = read_int(stream)
         conn_info = None
@@ -125,13 +126,13 @@ class BroadcastInfo:
 
 @dataclasses.dataclass
 class UDFInfo:
-    udfs: list[bytes]
+    udfs: list[memoryview]
     args: list[int]
     kwargs: dict[str, int]
     result_id: int
 
     @classmethod
-    def from_stream(cls, stream: IO) -> "UDFInfo":
+    def from_stream(cls, stream: ZeroCopyByteStream) -> "UDFInfo":
         num_args = read_int(stream)
         udfs = []
         args = []
@@ -167,13 +168,13 @@ class UDTFInfo:
     args: list[int]
     kwargs: dict[str, int]
     partition_child_indexes: list[int]
-    pickled_analyze_result: Optional[bytes]
-    handler: bytes
+    pickled_analyze_result: Optional[memoryview]
+    handler: memoryview
     return_type: str
     name: str
 
     @classmethod
-    def from_stream(cls, stream: IO) -> "UDTFInfo":
+    def from_stream(cls, stream: ZeroCopyByteStream) -> "UDTFInfo":
         # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
         args = []
         kwargs = {}
@@ -204,6 +205,9 @@ class UDTFInfo:
         )
 
 
+UDFInfoType: TypeAlias = Union[memoryview, UDTFInfo, list[UDFInfo]]
+
+
 @dataclasses.dataclass
 class WorkerInitInfo:
     split_index: int
@@ -215,10 +219,10 @@ class WorkerInitInfo:
     eval_type: int
     runner_conf: dict[str, str]
     eval_conf: dict[str, str]
-    udf_info: Union[bytes, UDTFInfo, list[UDFInfo]]
+    udf_info: UDFInfoType
 
     @classmethod
-    def from_stream(cls, stream: IO) -> "WorkerInitInfo":
+    def from_stream(cls, stream: ZeroCopyByteStream) -> "WorkerInitInfo":
         split_index = read_int(stream)
         if split_index == -1:
             sys.exit(-1)
@@ -243,9 +247,10 @@ class WorkerInitInfo:
             v = utf8_deserializer.loads(stream)
             eval_conf[k] = v
 
-        udf_info: Union[bytes, UDTFInfo, list[UDFInfo]]
+        udf_info: UDFInfoType
 
         if eval_type == PythonEvalType.NON_UDF:
+            # Returns memoryview; see assertion in worker.py if changing this type.
             udf_info = stream.read(read_int(stream))
         elif eval_type in (
             PythonEvalType.SQL_TABLE_UDF,
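The `worker_message.py` changes switch the `from_stream` readers to a `ZeroCopyByteStream` whose `read(n)` returns a `memoryview` (see the comment on the `NON_UDF` branch). The sketch below illustrates the zero-copy idea under stated assumptions: `ZeroCopyByteStreamSketch` is a hypothetical stand-in, not the actual `pyspark.messages.ZeroCopyByteStream` API, and the local `read_int` only mirrors the 4-byte big-endian framing used by `pyspark.serializers.read_int`.

```python
import struct


class ZeroCopyByteStreamSketch:
    """Hypothetical stand-in for pyspark.messages.ZeroCopyByteStream.

    Wraps an in-memory message body and hands out memoryview slices
    instead of copying bytes on every read().
    """

    def __init__(self, data: bytes) -> None:
        self._view = memoryview(data)
        self._pos = 0

    def read(self, n: int) -> memoryview:
        # Slicing a memoryview creates a new view over the same buffer; no copy.
        start = self._pos
        end = min(start + n, len(self._view))
        self._pos = end
        return self._view[start:end]


def read_int(stream) -> int:
    # Mirrors the 4-byte big-endian framing used by pyspark.serializers.read_int.
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack("!i", data)[0]


# Build a length-prefixed body: 4-byte length, then the payload bytes.
payload = b"pickled-command"
message = struct.pack("!i", len(payload)) + payload

stream = ZeroCopyByteStreamSketch(message)
udf_info = stream.read(read_int(stream))  # same shape as the NON_UDF branch
print(type(udf_info).__name__, bytes(udf_info))  # memoryview b'pickled-command'
```

Because the returned slice is a view over `message`, the payload bytes are not copied until a consumer actually needs them; this is presumably why `UDFInfo.udfs`, `UDTFInfo.handler`, and `UDTFInfo.pickled_analyze_result` change from `bytes` to `memoryview`.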
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index 08edf0c5decb..13e08449de22 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -27,6 +27,8 @@ import sys
 from typing import Any, Generator, IO, Optional, Union, overload
 import warnings
 
+from pyspark.messages import ZeroCopyByteStream
+
 if "SPARK_TESTING" in os.environ:
     assert os.environ.get("SPARK_PYTHON_RUNTIME") == "PYTHON_WORKER", (
         "This module can only be imported in python woker"
@@ -65,11 +67,11 @@ def add_path(path: str) -> bool:
     return False
 
 
-def read_command(serializer: FramedSerializer, file: Union[IO, bytes]) -> Any:
+def read_command(serializer: FramedSerializer, file: Union[IO, bytes, memoryview]) -> Any:
     if not is_remote_only():
         from pyspark.core.broadcast import Broadcast
 
-    if isinstance(file, bytes):
+    if isinstance(file, (bytes, memoryview)):
         command = serializer.loads(file)
     else:
         command = serializer._read_with_length(file)
@@ -173,7 +175,9 @@ def setup_spark_files(
 
 
 @overload
-def setup_broadcasts(infile_or_variables: IO) -> None: ...
+def setup_broadcasts(infile_or_variables: IO[Any]) -> None: ...
+@overload
+def setup_broadcasts(infile_or_variables: ZeroCopyByteStream) -> None: ...
 @overload
 def setup_broadcasts(
     infile_or_variables: list[tuple[int, Union[str, None]]], conn_info: str, auth_secret: None
@@ -184,10 +188,12 @@ def setup_broadcasts(
 ) -> None: ...
 @overload
 def setup_broadcasts(
-    infile_or_variables: list[tuple[int, Union[str, None]]], conn_info: None, auth_secret: None
+    infile_or_variables: list[tuple[int, Union[str, None]]],
+    conn_info: Optional[Union[str, int]],
+    auth_secret: Optional[str],
 ) -> None: ...
 def setup_broadcasts(
-    infile_or_variables: Union[IO, list[tuple[int, Union[str, None]]]],
+    infile_or_variables: Union[ZeroCopyByteStream, IO[Any], list[tuple[int, Union[str, None]]]],
     conn_info: Optional[Union[str, int]] = None,
     auth_secret: Optional[str] = None,
 ) -> None:
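`read_command` now accepts a `memoryview` in addition to `bytes` and file-like input. A minimal sketch of that dispatch, assuming `pickle` as a stand-in for `FramedSerializer` and omitting the broadcast bookkeeping of the real function (`read_command_sketch` is hypothetical, not part of PySpark):

```python
import io
import pickle
import struct
from typing import IO, Any, Union


def read_command_sketch(file: Union[IO[bytes], bytes, memoryview]) -> Any:
    """Hypothetical simplification of pyspark.worker_util.read_command.

    pickle stands in for the FramedSerializer; broadcast handling is omitted.
    """
    if isinstance(file, (bytes, memoryview)):
        # Already-extracted payload (e.g. a slice handed out by a zero-copy stream).
        return pickle.loads(file)
    # Stream input: read a 4-byte big-endian length, then that many bytes.
    (length,) = struct.unpack("!i", file.read(4))
    return pickle.loads(file.read(length))


command = {"func": "add_one", "args": [0]}
blob = pickle.dumps(command)

from_view = read_command_sketch(memoryview(blob))
from_file = read_command_sketch(io.BytesIO(struct.pack("!i", len(blob)) + blob))
print(from_view == from_file == command)  # True
```

Without `memoryview` in the `isinstance` tuple, a view would fall through to the stream branch and fail, since `memoryview` has no `read()` method.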

