viirya commented on code in PR #55552:
URL: https://github.com/apache/spark/pull/55552#discussion_r3378561366
##########
core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala:
##########
@@ -396,7 +532,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
partitionIndex: Int,
context: TaskContext) {
- @volatile private var _exception: Throwable = _
+ @volatile private[python] var _exception: Throwable = _
Review Comment:
Tightened in 8688369e47b: `_exception` is now `private` again, and
`PipelinedWriterRunnable` records via a new `setException(t)` method on
`Writer`.
##########
python/pyspark/worker.py:
##########
@@ -3651,12 +3651,95 @@ def process():
if hasattr(out_iter, "close"):
out_iter.close()
+ def pipelined_process():
+ """
+ Pipelined variant of process() that pre-fetches input batches in a background
+ reader thread while the main thread computes the UDF and writes output.
+ This allows input deserialization to overlap with UDF computation.
+ """
+ import queue
+ import threading
+
+ queue_depth = int(os.environ.get("SPARK_PIPELINED_UDF_QUEUE_DEPTH", "2"))
+ _SENTINEL = object()
+ input_queue = queue.Queue(maxsize=queue_depth)
+ reader_error = [None]
+ # Event to signal the reader thread to stop (set by main thread on
+ # exception or completion). The reader checks this after each failed
+ # put attempt instead of polling with a timeout.
+ stop_event = threading.Event()
+
+ def _reader_thread():
+ try:
+ for batch in deserializer.load_stream(infile):
+ # Some serializers (e.g., ArrowStreamGroupSerializer,
+ # ArrowStreamAggPandasUDFSerializer) yield lazy iterators
+ # that still read from infile. Materialize them here so the
+ # main thread can consume them without touching infile.
+ if hasattr(batch, "__next__"):
+ batch = list(batch)
+ # Block on put, but wake up when stop_event is set.
+ # stop_event.wait() returns immediately if already set.
+ while not stop_event.is_set():
+ try:
+ input_queue.put(batch, timeout=0.1)
+ break
+ except queue.Full:
+ continue
+ if stop_event.is_set():
+ return
+ except Exception as e:
+ reader_error[0] = e
+ finally:
+ # Enqueue sentinel so the consumer knows we're done.
+ while not stop_event.is_set():
+ try:
+ input_queue.put(_SENTINEL, timeout=0.1)
+ break
+ except queue.Full:
+ continue
+
+ t = threading.Thread(
+ target=_reader_thread, name="pyspark-pipelined-reader", daemon=True
+ )
+ t.start()
+
+ def _queued_iter():
+ while True:
+ item = input_queue.get()
+ if item is _SENTINEL:
+ if reader_error[0] is not None:
+ raise reader_error[0]
+ return
+ yield item
+
+ out_iter = func(split_index, _queued_iter())
+ try:
+ serializer.dump_stream(out_iter, outfile)
+ finally:
+ if hasattr(out_iter, "close"):
+ out_iter.close()
+ # Signal reader thread to stop, drain the queue so it can unblock,
+ # then wait for it to finish.
+ stop_event.set()
+ try:
+ while not input_queue.empty():
+ input_queue.get_nowait()
+ except Exception:
+ pass
+ t.join(timeout=5)
Review Comment:
Updated in 8688369e47b to emit a `RuntimeWarning` if the reader thread is
still alive after the 5s join. I considered force-closing `infile` to break the
reader out of `infile.read()`, but `infile` is the worker socket fd, which is
shared with the next reused-worker task, so closing it would break worker reuse.
I settled for a bounded join plus a loud warning so that a leaked reader thread shows up in the
worker log. The warning text suggests disabling `spark.python.worker.reuse` if
the timeout recurs.
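
For reference, the bounded-join-plus-warning pattern described above can be sketched roughly as follows. This is only an illustration, not the code in 8688369e47b: the helper name `join_or_warn`, its return value, and the exact warning text are assumptions made for this example.

```python
import threading
import warnings


def join_or_warn(thread, timeout=5.0):
    """Join `thread` for at most `timeout` seconds.

    Hypothetical helper (not part of pyspark). Returns True if the thread
    finished, otherwise emits a RuntimeWarning and returns False.
    """
    thread.join(timeout=timeout)
    if thread.is_alive():
        # The thread is probably still blocked in a read on the shared
        # input stream; we cannot close that stream, so report loudly.
        warnings.warn(
            "reader thread %r did not exit within %.1fs; if this recurs, "
            "consider disabling spark.python.worker.reuse"
            % (thread.name, timeout),
            RuntimeWarning,
            stacklevel=2,
        )
        return False
    return True
```

In the `finally` block of `pipelined_process()`, a call like `join_or_warn(t, timeout=5)` would take the place of the bare `t.join(timeout=5)`. The thread is left running in the timeout case; the warning only makes the leak visible.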
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala:
##########
@@ -67,10 +67,16 @@ abstract class EvalPythonEvaluatorFactory(
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
+ // In pipelined mode, add() runs in the writer thread and remove() in the task thread.
+ // Use lock-free mode to avoid synchronized overhead (memory visibility is guaranteed
+ // by the blocking socket I/O between the two threads).
+ val pipelined = SparkEnv.get.conf.get(
+ org.apache.spark.internal.config.Python.PYTHON_UDF_PIPELINED_EXECUTION)
val queue = HybridRowQueue(
context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)),
- childOutput.length)
+ childOutput.length,
+ lockFree = pipelined)
Review Comment:
Good catch -- this was an oversight. The other four Python evaluators (UDTF,
Aggregate, Window, ColumnarArrow) all go through the same pipelined runner, so
they should all use the lock-free queue when pipelined mode is on. Plumbed
through in 8688369e47b.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]