Yicong-Huang commented on code in PR #55552:
URL: https://github.com/apache/spark/pull/55552#discussion_r3359294990


##########
python/pyspark/worker.py:
##########
@@ -3651,12 +3651,95 @@ def process():
                 if hasattr(out_iter, "close"):
                     out_iter.close()
 
+        def pipelined_process():
+            """
+            Pipelined variant of process() that pre-fetches input batches in a background
+            reader thread while the main thread computes the UDF and writes output.
+            This allows input deserialization to overlap with UDF computation.
+            """
+            import queue
+            import threading
+
+            queue_depth = int(os.environ.get("SPARK_PIPELINED_UDF_QUEUE_DEPTH", "2"))
+            _SENTINEL = object()
+            input_queue = queue.Queue(maxsize=queue_depth)
+            reader_error = [None]
+            # Event to signal the reader thread to stop (set by main thread on
+            # exception or completion). The reader checks this after each failed
+            # put attempt instead of polling with a timeout.
+            stop_event = threading.Event()
+
+            def _reader_thread():
+                try:
+                    for batch in deserializer.load_stream(infile):
+                        # Some serializers (e.g., ArrowStreamGroupSerializer,
+                        # ArrowStreamAggPandasUDFSerializer) yield lazy iterators
+                        # that still read from infile. Materialize them here so the
+                        # main thread can consume them without touching infile.
+                        if hasattr(batch, "__next__"):
+                            batch = list(batch)
+                        # Block on put, but wake up when stop_event is set.
+                        # stop_event.wait() returns immediately if already set.
+                        while not stop_event.is_set():
+                            try:
+                                input_queue.put(batch, timeout=0.1)
+                                break
+                            except queue.Full:
+                                continue
+                        if stop_event.is_set():
+                            return
+                except Exception as e:
+                    reader_error[0] = e
+                finally:
+                    # Enqueue sentinel so the consumer knows we're done.
+                    while not stop_event.is_set():
+                        try:
+                            input_queue.put(_SENTINEL, timeout=0.1)
+                            break
+                        except queue.Full:
+                            continue
+
+            t = threading.Thread(
+                target=_reader_thread, name="pyspark-pipelined-reader", daemon=True
+            )
+            t.start()
+
+            def _queued_iter():
+                while True:
+                    item = input_queue.get()
+                    if item is _SENTINEL:
+                        if reader_error[0] is not None:
+                            raise reader_error[0]
+                        return
+                    yield item
+
+            out_iter = func(split_index, _queued_iter())
+            try:
+                serializer.dump_stream(out_iter, outfile)
+            finally:
+                if hasattr(out_iter, "close"):
+                    out_iter.close()
+                # Signal reader thread to stop, drain the queue so it can unblock,
+                # then wait for it to finish.
+                stop_event.set()
+                try:
+                    while not input_queue.empty():
+                        input_queue.get_nowait()
+                except Exception:
+                    pass
+                t.join(timeout=5)

Review Comment:
   The hardcoded 5s `t.join(timeout=5)` fails silently on timeout. If the reader is still blocked in `infile.read()`, the daemon thread leaks and may consume input belonging to the next task when the worker is reused. Close `infile` to force EOF (or, at minimum, log on timeout).
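A minimal sketch of the "at minimum, log on timeout" option as a standalone helper. `join_reader` and the logger name are hypothetical (not part of `worker.py`). Forcing EOF by closing `infile` is deliberately not shown, since whether that unblocks a pending read depends on how the worker's socket file is wrapped.

```python
import logging
import threading

# Hypothetical logger name, for illustration only.
logger = logging.getLogger("pyspark.pipelined_reader_sketch")


def join_reader(thread: threading.Thread, timeout: float = 5.0) -> bool:
    """Wait for the reader thread; warn instead of returning silently on timeout.

    Returns True if the thread exited, False if it is still alive (for example,
    still blocked in infile.read()).
    """
    thread.join(timeout=timeout)
    if thread.is_alive():
        logger.warning(
            "Pipelined reader thread %r still alive after %.1fs; it may be blocked "
            "reading input and could leak into the next task on a reused worker.",
            thread.name,
            timeout,
        )
        return False
    return True
```

In the quoted `finally` block, `join_reader(t)` would take the place of `t.join(timeout=5)`. Treating a `False` result as a reason not to reuse the worker is one possible policy, not something the PR implements.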



##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -396,7 +532,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       partitionIndex: Int,
       context: TaskContext) {
 
-    @volatile private var _exception: Throwable = _
+    @volatile private[python] var _exception: Throwable = _

Review Comment:
   nit: relaxing `_exception` to `private[python]` so `PipelinedWriterRunnable` 
can mutate it directly is a bit loose. A small `def setException(t: Throwable)` 
would localize the contract.



##########
python/pyspark/worker.py:
##########
@@ -3609,12 +3588,93 @@ def process():
                 if hasattr(out_iter, "close"):
                     out_iter.close()
 
+        def pipelined_process():
+            """
+            Pipelined variant of process() that pre-fetches input batches in a background
+            reader thread while the main thread computes the UDF and writes output.
+            This allows input deserialization to overlap with UDF computation.
+            """
+            # Mark that pipelined mode is active so UDFs can verify the code path.
+            os.environ["SPARK_PIPELINED_UDF_ACTIVE"] = "1"
+            import queue
+            import threading
+
+            queue_depth = int(os.environ.get("SPARK_PIPELINED_UDF_QUEUE_DEPTH", "2"))
+            _SENTINEL = object()
+            input_queue = queue.Queue(maxsize=queue_depth)
+            reader_error = [None]
+            stop_event = threading.Event()
+
+            def _reader_thread():
+                try:
+                    for batch in deserializer.load_stream(infile):
+                        # Some serializers (e.g., ArrowStreamGroupSerializer,
+                        # ArrowStreamAggPandasUDFSerializer) yield lazy iterators
+                        # that still read from infile. Materialize them here so the
+                        # main thread can consume them without touching infile.
+                        if hasattr(batch, "__next__"):
+                            batch = list(batch)
+                        # Use timeout put so we can check stop_event periodically.
+                        # This prevents the reader from blocking forever if the main
+                        # thread stops consuming (e.g., due to UDF exception).
+                        while not stop_event.is_set():
+                            try:
+                                input_queue.put(batch, timeout=1)
+                                break
+                            except queue.Full:
+                                continue
+                        if stop_event.is_set():
+                            return

Review Comment:
   I believe a `Condition` with a bounded buffer would be better here. When the queue is full, it is usually because the UDF is still busy with the previous batch, so the 0.1s timeout-and-check loop keeps firing until that batch finishes and the UDF takes the next one from the queue. Maybe I can follow up on this later.
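A minimal sketch of that alternative, assuming a single producer (the reader thread) and a single consumer (the UDF loop): one `threading.Condition` guards a bounded deque, so a producer facing a full buffer sleeps until the consumer takes an item or calls `stop()`, rather than waking on a timer. `BoundedBuffer` and its methods are hypothetical names; end-of-input and error propagation mirror what `_queued_iter` does with `_SENTINEL` and `reader_error`.

```python
import collections
import threading


class BoundedBuffer:
    """Bounded FIFO where producers block on a Condition instead of timed polling."""

    def __init__(self, maxsize: int):
        self._items = collections.deque()
        self._maxsize = maxsize
        self._cond = threading.Condition()
        self._stopped = False  # set by the consumer to abandon the stream
        self._closed = False   # set by the producer at end of input
        self._error = None     # producer-side exception to re-raise in the consumer

    def put(self, item) -> bool:
        """Block until there is room; return False if the consumer has stopped."""
        with self._cond:
            while len(self._items) >= self._maxsize and not self._stopped:
                self._cond.wait()  # woken by a consumer take or stop(); no busy polling
            if self._stopped:
                return False
            self._items.append(item)
            self._cond.notify_all()  # wake a consumer waiting for data
            return True

    def close(self, error=None) -> None:
        """Producer: mark end of input, optionally carrying an exception."""
        with self._cond:
            self._closed = True
            self._error = error
            self._cond.notify_all()

    def stop(self) -> None:
        """Consumer: abandon the stream and wake any blocked producer."""
        with self._cond:
            self._stopped = True
            self._items.clear()
            self._cond.notify_all()

    def __iter__(self):
        while True:
            with self._cond:
                while not self._items and not self._closed:
                    self._cond.wait()
                if self._items:
                    item = self._items.popleft()
                    self._cond.notify_all()  # a slot is free: wake the producer
                elif self._error is not None:
                    raise self._error  # surfaced only after buffered items drain
                else:
                    return
            yield item  # hand the item to the UDF outside the lock
```

In the diff's terms, `_reader_thread` would call `put(batch)` (returning early when it gets `False`) and `close(error)` at the end, the main thread would pass `iter(buffer)` to `func`, and the `finally` block would call `stop()` in place of setting `stop_event` and draining the queue. One Condition is sufficient for one producer and one consumer; `notify_all` keeps it correct even though both sides wait on the same condition.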



##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -985,6 +1121,109 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
   }
 
+  /**
+   * A dedicated thread that serializes input data and writes it directly to the Python worker
+   * socket in blocking mode. The task main thread simultaneously reads output from the same
+   * socket. TCP sockets are full-duplex, so concurrent read() and write() from different
+   * threads is safe -- they operate on independent OS-level buffers.
+   *
+   * This design achieves true pipeline parallelism without any inter-thread queues or locks:
+   *   Writer Thread:  serialize batch N  ->  channel.write(batch N)    [blocking]
+   *   Reader Thread:  channel.read(output N-1)                        [blocking]
+   *   Python:         read batch N-1  ->  compute  ->  write output  ->  read batch N
+   *
+   * Deadlock safety: Python's UDF loop is "read input -> compute -> write output -> repeat".
+   * As long as the reader thread is consuming Python's output (freeing Python's send buffer),
+   * Python will eventually consume input from the socket (freeing the JVM's send buffer for
+   * the writer thread). The reader thread is always actively reading because the task's
+   * downstream operators pull output on demand.
+   *
+   * Unlike the old WriterThread (removed in SPARK-44705), this design uses a blocking socket
+   * in full-duplex mode rather than two threads competing on the same blocking socket with
+   * shared mutable state. The old design's deadlocks were caused by complex interactions
+   * with vectorized readers and monitor threads, not by the fundamental read/write split.
+   */
+  class PipelinedWriterRunnable(
+      worker: PythonWorker,
+      writer: Writer,
+      bufferSize: Int,
+      context: TaskContext)
+    extends Runnable {
+
+    // Capture InputFileBlockHolder from the task thread so we can propagate it
+    // to the writer pool thread. This is needed because upstream scan operators
+    // set InputFileBlockHolder via InheritableThreadLocal, but pool threads
+    // don't inherit from the task thread.
+    private val parentInputFileBlockHolder = InputFileBlockHolder.getThreadLocalValue()
+
+    override def run(): Unit = {
+      // Propagate TaskContext and InputFileBlockHolder to the pool thread so that
+      // upstream operators work correctly.
+      TaskContext.setTaskContext(context)
+      InputFileBlockHolder.setThreadLocalValue(parentInputFileBlockHolder)
+      val bufferStream = new DirectByteBufferOutputStream(bufferSize)
+      val dataOut = new DataOutputStream(bufferStream)
+      try {
+        // Write command/metadata (partition index, task context, broadcasts, UDF definition).
+        writer.open(dataOut)
+        flushToSocket(bufferStream)
+
+        // Write input data in a loop, batching into buffers of ~bufferSize.
+        var hasInput = true
+        while (hasInput && !Thread.currentThread().isInterrupted) {
+          hasInput = writer.writeNextInputToStream(dataOut)
+          if (bufferStream.size() >= bufferSize || !hasInput) {
+            if (!hasInput) {
+              writer.close(dataOut)
+            }
+            flushToSocket(bufferStream)
+          }
+        }
+      } catch {
+        case _: InterruptedException =>
+          // Task cancelled via Future.cancel(true)
+          Thread.currentThread().interrupt()
+        case _: java.nio.channels.ClosedByInterruptException =>
+          // Task cancelled while blocked in channel.write(). The channel is
+          // automatically closed by the JVM, which will cause Python to receive
+          // EOF and the reader thread to get IOException.
+          Thread.currentThread().interrupt()
+        case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] =>

Review Comment:
   nit: `NonFatal(t) || t.isInstanceOf[Exception]` is equivalent here to just `case NonFatal(t) =>`, since `InterruptedException` and `ClosedByInterruptException` are already matched above. The same pattern appears at line 662 in the original WriterThread.
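The `PipelinedWriterRunnable` scaladoc relies on sockets being full-duplex: one thread can block in write() while another blocks in read() on the same connection, and a "read input -> compute -> write output" peer keeps both directions draining. A minimal Python sketch of that shape, using `socket.socketpair()` as a stand-in for the JVM/worker connection; the length-prefixed framing, the byte-reversing "UDF", and all function names are hypothetical and are not Spark's wire protocol.

```python
import socket
import struct
import threading

FRAME = 256 * 1024  # payload bytes per batch; a few frames fill typical socket buffers


def recv_exact(sock: socket.socket, n: int) -> bytes:
    """Read exactly n bytes or raise EOFError."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise EOFError("peer closed the connection")
        buf.extend(chunk)
    return bytes(buf)


def worker(sock: socket.socket) -> None:
    """Stand-in for the Python worker: read input -> compute -> write output -> repeat."""
    while True:
        (length,) = struct.unpack(">i", recv_exact(sock, 4))
        if length < 0:  # end-of-input marker: acknowledge and stop
            sock.sendall(struct.pack(">i", -1))
            return
        result = recv_exact(sock, length)[::-1]  # the "UDF"
        sock.sendall(struct.pack(">i", len(result)) + result)


def run_pipeline(num_batches: int) -> int:
    jvm_side, py_side = socket.socketpair()
    py_thread = threading.Thread(target=worker, args=(py_side,))
    py_thread.start()

    def writer() -> None:
        # Writer thread: may block in sendall() while the main thread blocks in recv().
        for i in range(num_batches):
            data = bytes([i % 256]) * FRAME
            jvm_side.sendall(struct.pack(">i", len(data)) + data)
        jvm_side.sendall(struct.pack(">i", -1))

    w = threading.Thread(target=writer)
    w.start()

    # Main ("task") thread: keeps draining output so the worker never stalls on send.
    received = 0
    while True:
        (length,) = struct.unpack(">i", recv_exact(jvm_side, 4))
        if length < 0:
            break
        recv_exact(jvm_side, length)
        received += 1

    w.join()
    py_thread.join()
    jvm_side.close()
    py_side.close()
    return received
```

Because the main thread reads output while the writer thread is still sending, each direction's buffer is drained by a different thread, which is the property the scaladoc's deadlock-safety argument depends on.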



##########
python/pyspark/worker.py:
##########
@@ -3588,12 +3588,93 @@ def process():
                 if hasattr(out_iter, "close"):
                     out_iter.close()
 
+        def pipelined_process():
+            """
+            Pipelined variant of process() that pre-fetches input batches in a background
+            reader thread while the main thread computes the UDF and writes output.
+            This allows input deserialization to overlap with UDF computation.
+            """
+            # Mark that pipelined mode is active so UDFs can verify the code path.
+            os.environ["SPARK_PIPELINED_UDF_ACTIVE"] = "1"
+            import queue
+            import threading
+
+            queue_depth = int(os.environ.get("SPARK_PIPELINED_UDF_QUEUE_DEPTH", "2"))
+            _SENTINEL = object()
+            input_queue = queue.Queue(maxsize=queue_depth)
+            reader_error = [None]
+            stop_event = threading.Event()
+
+            def _reader_thread():
+                try:
+                    for batch in deserializer.load_stream(infile):
+                        # Some serializers (e.g., ArrowStreamGroupSerializer,
+                        # ArrowStreamAggPandasUDFSerializer) yield lazy iterators
+                        # that still read from infile. Materialize them here so the
+                        # main thread can consume them without touching infile.
+                        if hasattr(batch, "__next__"):
+                            batch = list(batch)

Review Comment:
   Then should we update our benchmark to use somewhat larger batches in terms of in-memory size? I'd also like to see the timing difference for larger batches.
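One reason in-memory batch size matters here: because the reader thread materializes lazy iterators with `list(batch)`, several full batches can be alive at once. A rough back-of-the-envelope helper, assuming one batch held by the reader while it waits in `put`, up to `queue_depth` batches in `input_queue`, and one batch being processed by the UDF on the main thread. The function names are hypothetical, and real usage adds deserialization and UDF overhead on top.

```python
def peak_buffered_batches(queue_depth: int = 2) -> int:
    """Approximate number of materialized input batches alive at once.

    1 (reader thread, blocked in put) + queue_depth (queued) + 1 (main thread, in the UDF).
    """
    return queue_depth + 2


def peak_buffered_bytes(batch_bytes: int, queue_depth: int = 2) -> int:
    """Rough peak bytes of buffered input for a given per-batch in-memory size."""
    return batch_bytes * peak_buffered_batches(queue_depth)
```

For example, with the default `SPARK_PIPELINED_UDF_QUEUE_DEPTH` of 2, 64 MiB batches would put roughly 256 MiB of materialized input in flight under these assumptions.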



##########
core/src/main/scala/org/apache/spark/internal/config/Python.scala:
##########
@@ -150,4 +150,28 @@ private[spark] object Python {
       .version("4.1.0")
       .booleanConf
       .createWithDefault(true)
+
+  val PYTHON_UDF_PIPELINED_EXECUTION =
+    ConfigBuilder("spark.python.udf.pipelined.enabled")
+      .doc("When true, enables pipelined (asynchronous) data transfer between JVM and Python " +
+        "UDF workers. In pipelined mode, input serialization runs in a separate writer thread " +
+        "while the main task thread reads output from the Python worker, allowing the two " +
+        "directions to overlap for improved throughput. " +
+        "This is particularly beneficial for compute-heavy UDFs (e.g., ML inference).")

Review Comment:
   nit: "improved throughput" overstates the result. The benchmark table shows 
up to 1.24x for multi-UDF and ~1.0x (or worse) for single-UDF/heavy-compute. 
Consider qualifying with "for some workloads (e.g., multi-column or 
compute-heavy UDFs)".
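An idealized two-stage pipeline model suggests why a single compute-heavy UDF lands near 1.0x. If input handling costs `r` per batch and the UDF costs `c`, back-to-back processing takes `r + c` per batch, while perfect overlap is bounded by the slower stage, `max(r, c)`. A minimal sketch of that ceiling; the function is hypothetical and is not derived from the PR's benchmark, and it ignores thread handoff cost, the GIL, and output I/O, which is why measured numbers can fall below it.

```python
def ideal_pipeline_speedup(read_cost: float, compute_cost: float) -> float:
    """Steady-state speedup ceiling of overlapping two stages vs. running them serially."""
    if read_cost <= 0 or compute_cost <= 0:
        raise ValueError("stage costs must be positive")
    return (read_cost + compute_cost) / max(read_cost, compute_cost)
```

When the UDF costs 9x the input handling, the ceiling is (1 + 9) / 9, about 1.11x; the model only approaches 2x when the two stages cost about the same.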



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala:
##########
@@ -67,10 +67,16 @@ abstract class EvalPythonEvaluatorFactory(
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
+      // In pipelined mode, add() runs in the writer thread and remove() in the task thread.
+      // Use lock-free mode to avoid synchronized overhead (memory visibility is guaranteed
+      // by the blocking socket I/O between the two threads).
+      val pipelined = SparkEnv.get.conf.get(
+        org.apache.spark.internal.config.Python.PYTHON_UDF_PIPELINED_EXECUTION)
       val queue = HybridRowQueue(
         context.taskMemoryManager(),
         new File(Utils.getLocalDir(SparkEnv.get.conf)),
-        childOutput.length)
+        childOutput.length,
+        lockFree = pipelined)

Review Comment:
   `lockFree=true` only at this site means other Python evaluators 
(`applyInPandas`, `mapInPandas`, window UDF) silently keep synchronized queues 
even when the runner is in pipelined mode. Is this intentional? If yes, please 
add a comment; if not, plumb through `BasePythonRunner` for consistency.



-- 
This is an automated message from the Apache Git Service.