This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 8aaff558394 [SPARK-44705][PYTHON] Make PythonRunner single-threaded
8aaff558394 is described below

commit 8aaff55839493e80e3ce376f928c04aa8f31d18c
Author: Utkarsh <utkarsh.agar...@databricks.com>
AuthorDate: Fri Aug 11 10:34:05 2023 +0900

[SPARK-44705][PYTHON] Make PythonRunner single-threaded

### What changes were proposed in this pull request?

PythonRunner, a utility that executes Python UDFs in Spark, today uses two threads in a producer-consumer model. This multi-threading model is problematic and confusing, as Spark's execution model within a task is commonly understood to be single-threaded. More importantly, this departure into double-threaded execution resulted in a series of customer issues involving [race conditions](https://issues.apache.org/jira/browse/SPARK-33277) and [deadlocks](https://issues.apache.org/jira/browse/SPARK-38677) between threads, because the code was hard to reason about. There have been multiple attempts to rein in these issues, viz., [fix 1](https://issues.apache.org/jira/browse/SPARK-22535), [fix 2](https://github.com/apache/spark/pull/30177), [fix 3 [...]

#### Current Execution Model in Spark for Python UDFs

For queries containing Python UDFs, the main Java task thread spins up a new writer thread to pipe data from the child Spark plan into the Python worker evaluating the UDF. The writer thread runs in a tight loop: it evaluates the child Spark plan and feeds the resulting output to the Python worker. The main task thread simultaneously consumes the Python UDF's output and evaluates the parent Spark plan to produce the final result.

The I/O to/from the Python worker uses blocking Java Sockets, necessitating the use of two threads, one responsible for input to the Python worker and the other for output. Without two threads, it is easy to run into a deadlock. For example, the task can block forever waiting for the output from the Python worker. The output will never arrive until the input is supplied to the Python worker, which is not possible while the task thread is blocked waiting on output.

#### Proposed Fix

The proposed fix is to move to the standard single-threaded execution model within a task, i.e., to do away with the writer thread. In addition to mitigating the crashes, the fix reduces the complexity of the existing code by removing the many safety checks that were in place to track deadlocks in the double-threaded execution model.

In the new model, the main task thread alternates between consuming output from and feeding input to the Python worker using asynchronous I/O through Java's [SocketChannel](https://docs.oracle.com/javase/7/docs/api/java/nio/channels/SocketChannel.html). See the `read()` method in the code below for approximately how this is achieved.
```
case class PythonUDFRunner {

  private var nextRow: Row = _
  private var endOfStream = false
  private var childHasNext = true
  private var buffer: ByteBuffer = _

  def hasNext(): Boolean = nextRow != null || {
    if (!endOfStream) {
      read(buffer)
      nextRow = deserialize(buffer)
      hasNext
    } else {
      false
    }
  }

  def next(): Row = {
    if (hasNext) {
      val outputRow = nextRow
      nextRow = null
      outputRow
    } else {
      null
    }
  }

  def read(buf: Array[Byte]): Row = {
    var n = 0
    while (n == 0) {
      // Alternate between reading/writing to the Python worker using async I/O
      if (pythonWorker.isReadable) {
        n = pythonWorker.read(buf)
      }
      if (pythonWorker.isWritable) {
        consumeChildPlanAndWriteDataToPythonWorker()
      }
    }
  }

  def consumeChildPlanAndWriteDataToPythonWorker(): Unit = {
    // Tracks whether the connection to the Python worker can be written to.
    var socketAcceptsInput = true
    while (socketAcceptsInput && (childHasNext || buffer.hasRemaining)) {
      if (!buffer.hasRemaining && childHasNext) {
        // Consume data from the child and buffer it.
        writeToBuffer(childPlan.next(), buffer)
        childHasNext = childPlan.hasNext()
        if (!childHasNext) {
          // Exhausted the child plan's output. Write a keyword to the Python worker
          // signaling the end of data input.
          writeToBuffer(endOfStream)
        }
      }
      // Try to write as much buffered data as possible to the Python worker.
      while (buffer.hasRemaining && socketAcceptsInput) {
        val n = writeToPythonWorker(buffer)
        // `writeToPythonWorker()` returns 0 when the socket cannot accept more data right now.
        socketAcceptsInput = n > 0
      }
    }
  }
}
```

### Why are the changes needed?

This PR makes PythonRunner single-threaded, making it easier to reason about and improving code health.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

Closes #42385 from utkarsh39/SPARK-44705.
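(Not part of the commit: for readers unfamiliar with `java.nio` non-blocking I/O, the self-contained sketch below shows the same single-threaded select-then-read-or-write loop that the new runner relies on. The blocking echo server stands in for the Python worker; all names here, such as `NonBlockingEchoDemo` and the payload, are illustrative, not Spark's.)

```
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel, SocketChannel}
import java.nio.charset.StandardCharsets.UTF_8

object NonBlockingEchoDemo {
  def main(args: Array[String]): Unit = {
    // A trivial blocking echo server standing in for the Python worker.
    val server = ServerSocketChannel.open().bind(new InetSocketAddress("127.0.0.1", 0))
    val serverThread = new Thread(() => {
      val peer = server.accept()
      val buf = ByteBuffer.allocate(1024)
      while (peer.read(buf) != -1) {
        buf.flip()
        while (buf.hasRemaining) peer.write(buf)
        buf.clear()
      }
      peer.close()
    })
    serverThread.setDaemon(true)
    serverThread.start()

    // The "task thread": one non-blocking channel registered for both reads and writes.
    val channel = SocketChannel.open(server.getLocalAddress)
    channel.configureBlocking(false)
    val selector = Selector.open()
    val key = channel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE)

    val outbound = ByteBuffer.wrap("hello, worker".getBytes(UTF_8))
    val inbound = ByteBuffer.allocate(64)
    // Alternate between writing input and reading output on a single thread,
    // the same shape as the new PythonRunner read() loop.
    while (inbound.position() < outbound.limit()) {
      selector.select()
      selector.selectedKeys().clear()
      if (key.isWritable && outbound.hasRemaining) {
        channel.write(outbound)
        if (!outbound.hasRemaining) {
          // Done sending: stop watching for writability, mirroring the commit's
          // `worker.selectionKey.interestOps(SelectionKey.OP_READ)` once input is exhausted.
          key.interestOps(SelectionKey.OP_READ)
        }
      }
      if (key.isReadable) {
        channel.read(inbound)
      }
    }
    println(new String(inbound.array(), 0, inbound.position(), UTF_8))
    channel.close()
    selector.close()
    server.close()
  }
}
```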
Authored-by: Utkarsh <utkarsh.agar...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/ContextAwareIterator.scala | 2 + .../src/main/scala/org/apache/spark/SparkEnv.scala | 15 +- .../org/apache/spark/api/python/PythonRDD.scala | 22 +- .../org/apache/spark/api/python/PythonRunner.scala | 362 ++++++++++++++------- .../spark/api/python/PythonWorkerFactory.scala | 105 +++--- .../spark/api/python/PythonWorkerUtils.scala | 6 +- .../spark/api/python/StreamingPythonRunner.scala | 15 +- .../apache/spark/rdd/InputFileBlockHolder.scala | 11 + .../spark/util/DirectByteBufferOutputStream.scala | 85 +++++ .../ApplyInPandasWithStatePythonRunner.scala | 34 +- .../sql/execution/python/ArrowPythonRunner.scala | 8 +- .../execution/python/BatchEvalPythonUDTFExec.scala | 11 +- .../python/CoGroupedArrowPythonRunner.scala | 20 +- .../python/EvalPythonEvaluatorFactory.scala | 5 +- .../sql/execution/python/EvaluatePython.scala | 3 +- .../python/MapInBatchEvaluatorFactory.scala | 5 +- .../sql/execution/python/PythonArrowInput.scala | 81 +++-- .../sql/execution/python/PythonArrowOutput.scala | 13 +- .../sql/execution/python/PythonForeachWriter.scala | 69 +++- .../sql/execution/python/PythonUDFRunner.scala | 37 ++- .../python/UserDefinedPythonFunction.scala | 63 +++- 21 files changed, 666 insertions(+), 306 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala index 84ae93f1788..facb03365e8 100644 --- a/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala +++ b/core/src/main/scala/org/apache/spark/ContextAwareIterator.scala @@ -30,8 +30,10 @@ import org.apache.spark.annotation.DeveloperApi * Thus, we should use [[ContextAwareIterator]] to stop consuming after the task ends. 
* * @since 3.1.0 + * @deprecated since 4.0.0 as its only usage for Python evaluation is now extinct */ @DeveloperApi +@deprecated("Only usage for Python evaluation is now extinct", "3.5.0") class ContextAwareIterator[+T](val context: TaskContext, val delegate: Iterator[T]) extends Iterator[T] { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index eef99c26e77..e404c9ee8b4 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,7 +18,6 @@ package org.apache.spark import java.io.File -import java.net.Socket import java.util.Locale import scala.collection.JavaConverters._ @@ -30,7 +29,7 @@ import com.google.common.cache.CacheBuilder import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.PythonWorkerFactory +import org.apache.spark.api.python.{PythonWorker, PythonWorkerFactory} import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorBackend import org.apache.spark.internal.{config, Logging} @@ -129,7 +128,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, daemonModule: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Int]) = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.getOrElseUpdate(key, @@ -140,7 +139,7 @@ class SparkEnv ( private[spark] def createPythonWorker( pythonExec: String, workerModule: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + envVars: Map[String, String]): (PythonWorker, Option[Int]) = { createPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars) } @@ -150,7 +149,7 @@ class SparkEnv ( workerModule: String, daemonModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.get(key).foreach(_.stopWorker(worker)) @@ -161,7 +160,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { destroyPythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker) } @@ -171,7 +170,7 @@ class SparkEnv ( workerModule: String, daemonModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { synchronized { val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars) pythonWorkers.get(key).foreach(_.releaseWorker(worker)) @@ -182,7 +181,7 @@ class SparkEnv ( pythonExec: String, workerModule: String, envVars: Map[String, String], - worker: Socket): Unit = { + worker: PythonWorker): Unit = { releasePythonWorker( pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, worker) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 91fd92d4422..a2f2d566db5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -137,7 +137,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] private[spark] object PythonRDD extends Logging { // remember the broadcasts 
sent to each worker - private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private val workerBroadcasts = new mutable.WeakHashMap[PythonWorker, mutable.Set[Long]]() // Authentication helper used when serving iterator data. private lazy val authHelper = { @@ -145,7 +145,7 @@ private[spark] object PythonRDD extends Logging { new SocketAuthHelper(conf) } - def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { + def getWorkerBroadcasts(worker: PythonWorker): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) } @@ -300,7 +300,11 @@ private[spark] object PythonRDD extends Logging { new PythonBroadcast(path) } - def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = { + /** + * Writes the next element of the iterator `iter` to `dataOut`. Returns true if any data was + * written to the stream. Returns false if no data was written as the iterator has been exhausted. + */ + def writeNextElementToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Boolean = { def write(obj: Any): Unit = obj match { case null => @@ -318,8 +322,18 @@ private[spark] object PythonRDD extends Logging { case other => throw new SparkException("Unexpected element type " + other.getClass) } + if (iter.hasNext) { + write(iter.next()) + true + } else { + false + } + } - iter.foreach(write) + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = { + while (writeNextElementToStream(iter, dataOut)) { + // Nothing. + } } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 0173de75ff2..d7801d2e83b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -19,6 +19,8 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets.UTF_8 import java.nio.file.{Files => JavaFiles, Path} @@ -32,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} import org.apache.spark.internal.config.Python._ +import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -103,6 +106,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( private val conf = SparkEnv.get.conf protected val bufferSize: Int = conf.get(BUFFER_SIZE) + protected val timelyFlushEnabled: Boolean = false + protected val timelyFlushTimeoutNanos: Long = 0 protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) @@ -143,7 +148,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Python accumulator is always set in production except in tests. See SPARK-27893 private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) - // Expose a ServerSocket to support method calls via socket from Python side. + // Expose a ServerSocket to support method calls via socket from Python side. 
Only relevant for + // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] for details. private[spark] var serverSocket: Option[ServerSocket] = None // Authentication helper used when serving method calls via socket from Python side. @@ -194,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( + val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker( pythonExec, workerModule, daemonModule, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make @@ -202,22 +208,19 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val releasedOrClosed = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + val writer = newWriter(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener[Unit] { _ => - writerThread.shutdownOnTaskCompletion() if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { try { - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to stop worker") } } } - writerThread.start() - new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start() if (reuseWorker) { val key = (worker, context.taskAttemptId) // SPARK-35009: avoid creating multiple monitor threads for the same python worker @@ -230,68 +233,49 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - + val dataIn = new DataInputStream( + new BufferedInputStream(new ReaderInputStream(worker, writer), bufferSize)) val stdoutIterator = newReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) + dataIn, writer, startTime, env, worker, pid, releasedOrClosed, context) new InterruptibleIterator(context, stdoutIterator) } - protected def newWriterThread( + protected def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext): WriterThread + context: TaskContext): Writer protected def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ - abstract class WriterThread( + abstract class Writer( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { + context: TaskContext) { - @volatile private var _exception: Throwable = null + @volatile private var _exception: Throwable = _ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - setDaemon(true) - /** Contains the throwable thrown while writing the parent iterator to the Python process. */ def exception: Option[Throwable] = Option(_exception) - /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur - * due to cleanup. - */ - def shutdownOnTaskCompletion(): Unit = { - assert(context.isCompleted) - this.interrupt() - // Task completion listeners that run after this method returns may invalidate - // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized - // reader, a task completion listener will free the underlying off-heap buffers. If the writer - // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free - // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer - // thread to exit before returning. - this.join() - } - /** * Writes a command section to the stream connected to the Python worker. */ @@ -299,14 +283,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( /** * Writes input data to the stream connected to the Python worker. + * Returns true if any data was written to the stream, false if the input is exhausted. */ - protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + def writeNextInputToStream(dataOut: DataOutputStream): Boolean - override def run(): Unit = Utils.logUncaughtExceptions { + def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions { try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(partitionIndex) @@ -367,21 +349,25 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } else { "" } - // Close ServerSocket on task completion. - serverSocket.foreach { server => - context.addTaskCompletionListener[Unit](_ => server.close()) - } - val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) - if (boundPort == -1) { - val message = "ServerSocket failed to bind to Java side." - logError(message) - throw new SparkException(message) - } else if (isBarrier) { + if (isBarrier) { + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." 
+ logError(message) + throw new SparkException(message) + } logDebug(s"Started ServerSocket on port $boundPort.") + dataOut.writeBoolean(/* isBarrier = */true) + dataOut.writeInt(boundPort) + } else { + dataOut.writeBoolean(/* isBarrier = */false) + dataOut.writeInt(0) } // Write out the TaskContextInfo - dataOut.writeBoolean(isBarrier) - dataOut.writeInt(boundPort) val secretBytes = secret.getBytes(UTF_8) dataOut.writeInt(secretBytes.length) dataOut.write(secretBytes, 0, secretBytes.length) @@ -412,30 +398,33 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(evalType) writeCommand(dataOut) - writeIteratorToStream(dataOut) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { - case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + case t: Throwable if NonFatal(t) || t.isInstanceOf[Exception] => if (context.isCompleted || context.isInterrupted) { logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + "cleanup)", t) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + if (worker.channel.isConnected) { + Utils.tryLog(worker.channel.shutdownOutput()) } } else { // We must avoid throwing exceptions/NonFatals here, because the thread uncaught // exception handler will kill the whole executor (see // org.apache.spark.executor.Executor). _exception = t - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + if (worker.channel.isConnected) { + Utils.tryLog(worker.channel.shutdownOutput()) } } } } + def close(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } + /** * Gateway to call BarrierTaskContext methods. */ @@ -470,10 +459,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( abstract class ReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext) @@ -531,7 +520,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val obj = new Array[Byte](exLength) stream.readFully(obj) new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.orNull) + writer.exception.orNull) } protected def handleEndOfDataSection(): Unit = { @@ -554,10 +543,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( logDebug("Exception thrown after task interruption", e) throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - case e: Exception if writerThread.exception.isDefined => + case e: Exception if writer.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get + logError("This may have been caused by a prior exception:", writer.exception.get) + throw writer.exception.get case eof: EOFException if faultHandlerEnabled && pid.isDefined && JavaFiles.exists(BasePythonRunner.faultHandlerLogPath(pid.get)) => @@ -576,7 +565,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the * threads can block indefinitely. 
*/ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + class MonitorThread(env: SparkEnv, worker: PythonWorker, context: TaskContext) extends Thread(s"Worker Monitor for $pythonExec") { /** How long to wait before killing the python worker if a task cannot be interrupted. */ @@ -620,60 +609,185 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } } - /** - * This thread monitors the WriterThread and kills it in case of deadlock. - * - * A deadlock can arise if the task completes while the writer thread is sending input to the - * Python process (e.g. due to the use of `take()`), and the Python process is still producing - * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of - * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. - */ - class WriterMonitorThread( - env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext) - extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { - + class ReaderInputStream(worker: PythonWorker, writer: Writer) extends InputStream { + private[this] var writerIfbhThreadLocalValue: Object = null + private[this] val temp = new Array[Byte](1) + private[this] val bufferStream = new DirectByteBufferOutputStream() /** - * How long to wait before closing the socket if the writer thread has not exited after the task - * ends. + * Buffers data to be written to the Python worker until the socket is + * available for write. + * A best-effort attempt is made to not grow the buffer beyond "spark.buffer.size". See + * `writeAdditionalInputToPythonWorker()` for details. */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) + private[this] var buffer: ByteBuffer = _ + private[this] var hasInput = true - setDaemon(true) + writer.open(new DataOutputStream(bufferStream)) + buffer = bufferStream.toByteBuffer - override def run(): Unit = { - // Wait until the task is completed (or the writer thread exits, in which case this thread has - // nothing to do). - while (!context.isCompleted && writerThread.isAlive) { - Thread.sleep(2000) + override def read(): Int = { + val n = read(temp) + if (n <= 0) { + -1 + } else { + // Signed byte to unsigned integer + temp(0) & 0xff } - if (writerThread.isAlive) { - Thread.sleep(taskKillTimeout) - // If the writer thread continues running, this indicates a deadlock. Kill the worker to - // resolve the deadlock. - if (writerThread.isAlive) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning( - s"Detected deadlock while completing task $taskName: " + - "Attempting to kill Python Worker") - env.destroyPythonWorker( - pythonExec, workerModule, daemonModule, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + // The code below manipulates the InputFileBlockHolder thread local in order + // to prevent behavior changes in the input_file_name() expression due to the switch from + // multi-threaded to single-threaded Python execution (SPARK-44705). + // + // Prior to that change, scan operations feeding into PythonRunner would be evaluated in + // "writer" threads that were child threads of the main task thread. 
As a result, when + // a scan operation hit end-of-input and called InputFileBlockHolder.unset(), the effects + // of unset() would only occur in the writer thread and not the main task thread: this + // meant that code "downstream" of a PythonRunner would continue to observe the writer's + // last pre-unset() value (i.e. the last read filename). + // + // Switching to a single-threaded Python runner changed this behavior: now, unset() would + // impact operators both upstream and downstream of the PythonRunner and this would cause + // unset()'s effects to be immediately visible to downstream operators, in turn causing the + // input_file_name() expression to return empty filenames in situations where it previously + // would have returned the last non-empty filename. + // + // To avoid this behavior change, the code below simulates the behavior of the + // InputFileBlockHolder's inheritable thread local: + // + // - Detect whether code that previously would have run in the writer thread has changed + // the thread local value itself. Note that the thread local holds a mutable + // AtomicReference, so the thread local's value only changes objects when unset() is + // called. + // - If an object change was detected, then henceforth we will swap between the "main" + // and "writer" thread local values when context switching between upstream and + // downstream operator execution. + // + // This issue is subtle and several other alternative approaches were considered + val buf = ByteBuffer.wrap(b, off, len) + var n = 0 + while (n == 0) { + worker.selector.select() + if (worker.selectionKey.isReadable) { + n = worker.channel.read(buf) + } + if (worker.selectionKey.isWritable) { + val mainIfbhThreadLocalValue = InputFileBlockHolder.getThreadLocalValue() + // Check whether the writer's thread local value has diverged from its parent's value: + if (writerIfbhThreadLocalValue eq null) { + // Default case (which is why it appears first): the writer's thread local value + // is the same object as the main code, so no need to swap before executing the + // writer code. + try { + // Execute the writer code: + writeAdditionalInputToPythonWorker() + } finally { + // Check whether the writer code changed the thread local value: + val maybeNewIfbh = InputFileBlockHolder.getThreadLocalValue() + if (maybeNewIfbh ne mainIfbhThreadLocalValue) { + // The writer thread change the thread local, so henceforth we need to + // swap. Store the writer thread's value and restore the old main thread + // value: + writerIfbhThreadLocalValue = maybeNewIfbh + InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue) + } + } + } else { + // The writer thread and parent thread have different values, so we must swap + // them when switching between writer and parent code: + try { + // Swap in the writer value: + InputFileBlockHolder.setThreadLocalValue(writerIfbhThreadLocalValue) + try { + // Execute the writer code: + writeAdditionalInputToPythonWorker() + } finally { + // Store an updated writer thread value: + writerIfbhThreadLocalValue = InputFileBlockHolder.getThreadLocalValue() + } + } finally { + // Restore the main thread's value: + InputFileBlockHolder.setThreadLocalValue(mainIfbhThreadLocalValue) + } } } } + n + } + + private var lastFlushTime = System.nanoTime() + + /** + * Returns false if `timelyFlushEnabled` is disabled. + * + * Otherwise, returns true if `buffer` should be flushed before any additional data is + * written to it. 
+ * For small input rows the data might stay in the buffer for long before it is sent to the + * Python worker. We should flush the buffer periodically so that the downstream can make + * continued progress. + */ + private def shouldFlush(): Boolean = { + if (!timelyFlushEnabled) { + false + } else { + val currentTime = System.nanoTime() + if (currentTime - lastFlushTime > timelyFlushTimeoutNanos) { + lastFlushTime = currentTime + bufferStream.size() > 0 + } else { + false + } + } + } + + /** + * Reads input data from `writer.inputIterator` into `buffer` and writes the buffer to the + * Python worker if the socket is available for writing. + */ + private def writeAdditionalInputToPythonWorker(): Unit = { + var acceptsInput = true + while (acceptsInput && (hasInput || buffer.hasRemaining)) { + if (!buffer.hasRemaining && hasInput) { + // No buffered data is available. Try to read input into the buffer. + bufferStream.reset() + // Set the `buffer` to null to make it eligible for GC + buffer = null + + val dataOut = new DataOutputStream(bufferStream) + // Try not to grow the buffer much beyond `bufferSize`. This is inevitable for large + // input rows. + while (bufferStream.size() < bufferSize && hasInput && !shouldFlush()) { + hasInput = writer.writeNextInputToStream(dataOut) + } + if (!hasInput) { + // Reached the end of the input. + writer.close(dataOut) + } + buffer = bufferStream.toByteBuffer + } + + // Try to write as much buffered data as possible to the socket. + while (buffer.hasRemaining && acceptsInput) { + val n = worker.channel.write(buffer) + acceptsInput = n > 0 + } + } + + if (!hasInput && !buffer.hasRemaining) { + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + bufferStream.close() + } } } + } private[spark] object PythonRunner { // already running worker monitor threads for worker and task attempts ID pairs - val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]() + val runningMonitorThreads = ConcurrentHashMap.newKeySet[(PythonWorker, Long)]() private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true) @@ -693,13 +807,13 @@ private[spark] class PythonRunner( extends BasePythonRunner[Array[Byte], Array[Byte]]( funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { val command = funcs.head.funcs.head.command @@ -707,28 +821,32 @@ private[spark] class PythonRunner( dataOut.write(command.toArray) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { + if (PythonRDD.writeNextElementToStream(inputIterator, dataOut)) { + true + } else { + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + false + } } } } protected override def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: 
PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { stream.readInt() match { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4ba6dd949b1..1db8748c327 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -18,7 +18,8 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream} -import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.net.{InetAddress, InetSocketAddress, SocketException} +import java.nio.channels._ import java.util.Arrays import java.util.concurrent.TimeUnit import javax.annotation.concurrent.GuardedBy @@ -33,6 +34,14 @@ import org.apache.spark.internal.config.Python._ import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} +case class PythonWorker(channel: SocketChannel, selector: Selector, selectionKey: SelectionKey) { + def stop(): Unit = { + selectionKey.cancel() + selector.close() + channel.close() + } +} + private[spark] class PythonWorkerFactory( pythonExec: String, workerModule: String, @@ -67,32 +76,33 @@ private[spark] class PythonWorkerFactory( @GuardedBy("self") private var daemonPort: Int = 0 @GuardedBy("self") - private val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]() @GuardedBy("self") - private val idleWorkers = new mutable.Queue[Socket]() + private val idleWorkers = new mutable.Queue[PythonWorker]() @GuardedBy("self") private var lastActivityNs = 0L new MonitorThread().start() @GuardedBy("self") - private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() + private val simpleWorkers = new mutable.WeakHashMap[PythonWorker, Process]() private val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, envVars.getOrElse("PYTHONPATH", ""), sys.env.getOrElse("PYTHONPATH", "")) - def create(): (Socket, Option[Int]) = { + def create(): (PythonWorker, Option[Int]) = { if (useDaemon) { self.synchronized { if (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() + worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE) return (worker, daemonWorkers.get(worker)) } } createThroughDaemon() } else { - createSimpleWorker() + createSimpleWorker(blockingMode = false) } } @@ -101,18 +111,25 @@ private[spark] class PythonWorkerFactory( * processes itself to avoid the high cost of forking from Java. This currently only works * on UNIX-based systems. 
*/ - private def createThroughDaemon(): (Socket, Option[Int]) = { + private def createThroughDaemon(): (PythonWorker, Option[Int]) = { - def createSocket(): (Socket, Option[Int]) = { - val socket = new Socket(daemonHost, daemonPort) - val pid = new DataInputStream(socket.getInputStream).readInt() + def createWorker(): (PythonWorker, Option[Int]) = { + val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort)) + // These calls are blocking. + val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } - authHelper.authToServer(socket) - daemonWorkers.put(socket, pid) - (socket, Some(pid)) + authHelper.authToServer(socketChannel.socket()) + socketChannel.configureBlocking(false) + val selector = Selector.open() + val selectionKey = socketChannel.register(selector, + SelectionKey.OP_READ | SelectionKey.OP_WRITE) + val worker = PythonWorker(socketChannel, selector, selectionKey) + + daemonWorkers.put(worker, pid) + (worker, Some(pid)) } self.synchronized { @@ -121,14 +138,14 @@ private[spark] class PythonWorkerFactory( // Attempt to connect, restart and retry once if it fails try { - createSocket() + createWorker() } catch { case exc: SocketException => logWarning("Failed to open socket to Python daemon:", exc) logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - createSocket() + createWorker() } } } @@ -136,10 +153,11 @@ private[spark] class PythonWorkerFactory( /** * Launch a worker by executing worker.py (by default) directly and telling it to connect to us. */ - private[spark] def createSimpleWorker(): (Socket, Option[Int]) = { - var serverSocket: ServerSocket = null + private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = { + var serverSocketChannel: ServerSocketChannel = null try { - serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) + serverSocketChannel = ServerSocketChannel.open() + serverSocketChannel.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1) // Create and start the worker val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) @@ -154,38 +172,49 @@ private[spark] class PythonWorkerFactory( workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") - workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocketChannel.socket().getLocalPort + .toString) workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) if (Utils.preferIPv6) { workerEnv.put("SPARK_PREFER_IPV6", "True") } - val worker = pb.start() + val workerProcess = pb.start() // Redirect worker stdout and stderr - redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) + redirectStreamsToStderr(workerProcess.getInputStream, workerProcess.getErrorStream) // Wait for it to connect to our socket, and validate the auth secret. 
- serverSocket.setSoTimeout(10000) + serverSocketChannel.socket().setSoTimeout(10000) try { - val socket = serverSocket.accept() - authHelper.authClient(socket) - // TODO: When we drop JDK 8, we can just use worker.pid() - val pid = new DataInputStream(socket.getInputStream).readInt() + val socketChannel = serverSocketChannel.accept() + authHelper.authClient(socketChannel.socket()) + // TODO: When we drop JDK 8, we can just use workerProcess.pid() + val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt() if (pid < 0) { throw new IllegalStateException("Python failed to launch worker with code " + pid) } + if (!blockingMode) { + socketChannel.configureBlocking(false) + } + val selector = Selector.open() + val selectionKey = if (blockingMode) { + null + } else { + socketChannel.register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE) + } + val worker = PythonWorker(socketChannel, selector, selectionKey) self.synchronized { - simpleWorkers.put(socket, worker) + simpleWorkers.put(worker, workerProcess) } - return (socket, Some(pid)) + return (worker, Some(pid)) } catch { case e: Exception => throw new SparkException("Python worker failed to connect back.", e) } } finally { - if (serverSocket != null) { - serverSocket.close() + if (serverSocketChannel != null) { + serverSocketChannel.close() } } null @@ -320,11 +349,10 @@ private[spark] class PythonWorkerFactory( while (idleWorkers.nonEmpty) { val worker = idleWorkers.dequeue() try { - // the worker will exit after closing the socket - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to stop worker socket", e) } } } @@ -351,7 +379,7 @@ private[spark] class PythonWorkerFactory( stopDaemon() } - def stopWorker(worker: Socket): Unit = { + def stopWorker(worker: PythonWorker): Unit = { self.synchronized { if (useDaemon) { if (daemon != null) { @@ -367,22 +395,21 @@ private[spark] class PythonWorkerFactory( simpleWorkers.get(worker).foreach(_.destroy()) } } - worker.close() + worker.stop() } - def releaseWorker(worker: Socket): Unit = { + def releaseWorker(worker: PythonWorker): Unit = { if (useDaemon) { self.synchronized { lastActivityNs = System.nanoTime() idleWorkers.enqueue(worker) } } else { - // Cleanup the worker socket. This will also cause the Python worker to exit. 
try { - worker.close() + worker.stop() } catch { case e: Exception => - logWarning("Failed to close worker socket", e) + logWarning("Failed to close worker", e) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala index b6ab031d388..3f7b11a40ad 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala @@ -18,7 +18,6 @@ package org.apache.spark.api.python import java.io.{DataInputStream, DataOutputStream, File} -import java.net.Socket import java.nio.charset.StandardCharsets import org.apache.spark.{SparkEnv, SparkFiles} @@ -76,7 +75,7 @@ private[spark] object PythonWorkerUtils extends Logging { */ def writeBroadcasts( broadcastVars: Seq[Broadcast[PythonBroadcast]], - worker: Socket, + worker: PythonWorker, env: SparkEnv, dataOut: DataOutputStream): Unit = { // Broadcast variables @@ -117,9 +116,6 @@ private[spark] object PythonWorkerUtils extends Logging { dataOut.writeLong(id) } dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") } else { sendBidsToRemove() for (broadcast <- broadcastVars) { diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index fdfe388db2d..e82052e41be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -18,7 +18,6 @@ package org.apache.spark.api.python import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} -import java.net.Socket import scala.collection.JavaConverters._ @@ -50,7 +49,8 @@ private[spark] class StreamingPythonRunner( private val envVars: java.util.Map[String, String] = func.envVars private val pythonExec: String = func.pythonExec - private var pythonWorker: Option[Socket] = None + private var pythonWorker: Option[PythonWorker] = None + private var pythonWorkerFactory: Option[PythonWorkerFactory] = None protected val pythonVer: String = func.pythonVer /** @@ -71,14 +71,17 @@ private[spark] class StreamingPythonRunner( val prevConf = conf.get(PYTHON_USE_DAEMON) conf.set(PYTHON_USE_DAEMON, false) try { - val (worker, _) = env.createPythonWorker( - pythonExec, workerModule, envVars.asScala.toMap) + val workerFactory = + new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) + val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true) pythonWorker = Some(worker) + pythonWorkerFactory = Some(workerFactory) } finally { conf.set(PYTHON_USE_DAEMON, prevConf) } - val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, bufferSize) + val stream = new BufferedOutputStream( + pythonWorker.get.channel.socket().getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) @@ -93,7 +96,7 @@ private[spark] class StreamingPythonRunner( dataOut.flush() val dataIn = new DataInputStream( - new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize)) + new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize)) val resFromPython = dataIn.readInt() logInfo(s"Runner initialization returned 
$resFromPython") diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala index 8230144025f..5f2a9dd2743 100644 --- a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -55,6 +55,14 @@ private[spark] object InputFileBlockHolder { new AtomicReference(new FileBlock) } + private[spark] def setThreadLocalValue(ref: Object): Unit = { + inputBlock.set(ref.asInstanceOf[AtomicReference[FileBlock]]) + } + + private[spark] def getThreadLocalValue(): Object = { + inputBlock.get() + } + /** * Returns the holding file name or empty string if it is unknown. */ @@ -72,6 +80,9 @@ private[spark] object InputFileBlockHolder { /** * Sets the thread-local input block. + * + * Callers of this method must ensure a task completion listener has been registered to unset() + * the thread local in the task thread. */ def set(filePath: String, startOffset: Long, length: Long): Unit = { require(filePath != null, "filePath cannot be null") diff --git a/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala new file mode 100644 index 00000000000..a4145bb36ac --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/DirectByteBufferOutputStream.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.OutputStream +import java.nio.ByteBuffer + +import org.apache.spark.storage.StorageUtils +import org.apache.spark.unsafe.Platform + +/** + * An output stream that dumps data into a direct byte buffer. The byte buffer grows in size + * as more data is written to the stream. + * @param capacity The initial capacity of the direct byte buffer + */ +private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputStream { + private var buffer = Platform.allocateDirectBuffer(capacity) + + def this() = this(32) + + override def write(b: Int): Unit = { + ensureCapacity(buffer.position() + 1) + buffer.put(b.toByte) + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + ensureCapacity(buffer.position() + len) + buffer.put(b, off, len) + } + + private def ensureCapacity(minCapacity: Int): Unit = { + if (minCapacity > buffer.capacity()) grow(minCapacity) + } + + /** + * Grows the current buffer to at least `minCapacity` capacity. + * As a side effect, all references to the old buffer will be invalidated. 
+ */ + private def grow(minCapacity: Int): Unit = { + val oldCapacity = buffer.capacity() + var newCapacity = oldCapacity << 1 + if (newCapacity < minCapacity) newCapacity = minCapacity + val oldBuffer = buffer + oldBuffer.flip() + val newBuffer = ByteBuffer.allocateDirect(newCapacity) + newBuffer.put(oldBuffer) + StorageUtils.dispose(oldBuffer) + buffer = newBuffer + } + + def reset(): Unit = buffer.clear() + + def size(): Int = buffer.position() + + /** + * Any subsequent call to [[close()]], [[write()]], [[reset()]] will invalidate the buffer + * returned by this method. + */ + def toByteBuffer: ByteBuffer = { + val outputBuffer = buffer.duplicate() + outputBuffer.flip() + outputBuffer + } + + override def close(): Unit = { + // Eagerly free the direct byte buffer without waiting for GC to reduce memory pressure. + StorageUtils.dispose(buffer) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index d4c535fe76a..a60d0beeeed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -55,7 +55,7 @@ class ApplyInPandasWithStatePythonRunner( evalType: Int, argOffsets: Array[Array[Int]], inputSchema: StructType, - override protected val timeZoneId: String, + _timeZoneId: String, initialWorkerConf: Map[String, String], stateEncoder: ExpressionEncoder[Row], keySchema: StructType, @@ -73,8 +73,10 @@ class ApplyInPandasWithStatePythonRunner( private val sqlConf = SQLConf.get - override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) - + // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s + // constructor. + override protected lazy val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) + override protected lazy val timeZoneId: String = _timeZoneId override val errorOnDuplicatedFieldNames: Boolean = true override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback @@ -113,37 +115,41 @@ class ApplyInPandasWithStatePythonRunner( // Also write the schema for state value PythonRDD.writeUTF(stateValueSchema.json, stream) } - + private var pandasWriter: ApplyInPandasWithStateWriter = _ /** * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and * write constructed RecordBatches to the writer. * * See [[ApplyInPandasWithStateWriter]] for more details. 
*/ - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[InType]): Unit = { - val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) - - while (inputIterator.hasNext) { + inputIterator: Iterator[InType]): Boolean = { + if (pandasWriter == null) { + pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) + } + if (inputIterator.hasNext) { val startData = dataOut.size() val (keyRow, groupState, dataIter) = inputIterator.next() assert(dataIter.hasNext, "should have at least one data row!") - w.startNewGroup(keyRow, groupState) + pandasWriter.startNewGroup(keyRow, groupState) while (dataIter.hasNext) { val dataRow = dataIter.next() - w.writeRow(dataRow) + pandasWriter.writeRow(dataRow) } - w.finalizeGroup() + pandasWriter.finalizeGroup() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + pandasWriter.finalizeData() + super[PythonArrowInput].close() + false } - - w.finalizeData() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index d9bce96c477..0f26d8f21f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -31,8 +31,8 @@ class ArrowPythonRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, + _schema: StructType, + _timeZoneId: String, protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], val pythonMetrics: Map[String, SQLMetric], @@ -50,6 +50,10 @@ class ArrowPythonRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s + // constructor. 
+ override protected lazy val timeZoneId: String = _timeZoneId + override protected lazy val schema: StructType = _schema override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize require( bufferSize >= 4, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 9dae874e3ed..6c8412f8b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import scala.collection.JavaConverters._ import net.razorvine.pickle.Unpickler import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorker, PythonWorkerUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -101,13 +100,13 @@ class PythonUDTFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index eef8be7c940..bd901545bb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -18,13 +18,12 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.metric.SQLMetric @@ -60,14 +59,14 @@ class CoGroupedArrowPythonRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - protected def newWriterThread( + protected def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], partitionIndex: Int, - context: TaskContext): WriterThread = { + context: TaskContext): Writer = { - new WriterThread(env, worker, inputIterator, 
partitionIndex, context) { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { @@ -81,10 +80,10 @@ class CoGroupedArrowPythonRunner( PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { // For each we first send the number of dataframes in each group then send // first df, then send second df. End of data is marked by sending 0. - while (inputIterator.hasNext) { + if (inputIterator.hasNext) { val startData = dataOut.size() dataOut.writeInt(2) val (nextLeft, nextRight) = inputIterator.next() @@ -93,8 +92,11 @@ class CoGroupedArrowPythonRunner( val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + dataOut.writeInt(0) + false } - dataOut.writeInt(0) } private def writeGroup( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala index 10bb3a45be9..373e17c0aa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonEvaluatorFactory.scala @@ -21,7 +21,7 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -62,7 +62,6 @@ abstract class EvalPythonEvaluatorFactory( iters: Iterator[InternalRow]*): Iterator[InternalRow] = { val iter = iters.head val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, iter) // The queue used to buffer input rows so we can drain it to // combine input with output from Python. @@ -97,7 +96,7 @@ abstract class EvalPythonEvaluatorFactory( }.toArray) // Add rows to queue to join later with the result. 
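The queue comment above refers to the buffer-and-join pattern, which this refactor leaves intact: each input row is buffered as its projection is pulled towards the Python worker, then dequeued to join with the matching UDF result. A minimal sketch with plain Scala collections standing in for `HybridRowQueue` and for the Python round-trip:

```
import scala.collection.mutable

object BufferJoinDemo extends App {
  val input = Iterator(1, 2, 3)
  val queue = mutable.Queue.empty[Int] // stand-in for HybridRowQueue

  // Buffer each original row as its projection is pulled towards Python.
  val projected = input.map { row => queue.enqueue(row); row * 10 }

  // Stand-in for the Python worker's output for each projected value.
  val fromPython = projected.map(p => s"udf($p)")

  // Join each result back with the buffered original row, in order.
  val joined = fromPython.map(out => (queue.dequeue(), out))

  joined.foreach(println) // (1,udf(10)), (2,udf(20)), (3,udf(30))
}
```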
- val projectedRowIter = contextAwareIterator.map { inputRow => + val projectedRowIter = iter.map { inputRow => queue.add(inputRow.asInstanceOf[UnsafeRow]) projection(inputRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index 8d2f788e05c..6664acf9572 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.{ContextAwareIterator, TaskContext} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -302,7 +301,7 @@ object EvaluatePython { def javaToPython(rdd: RDD[Any]): RDD[Array[Byte]] = { rdd.mapPartitions { iter => registerPicklers() // let it called in executor - new SerDeUtil.AutoBatchedPickler(new ContextAwareIterator(TaskContext.get, iter)) + new SerDeUtil.AutoBatchedPickler(iter) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala index 1e15aa7f777..6f501e1411a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ -import org.apache.spark.{ContextAwareIterator, PartitionEvaluator, PartitionEvaluatorFactory, TaskContext} +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -52,11 +52,10 @@ class MapInBatchEvaluatorFactory( // Single function with one struct. val argOffsets = Array(Array(0)) val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, inputIter) // Here we wrap it via another row so that Python sides understand it // as a DataFrame. - val wrappedIter = contextAwareIterator.map(InternalRow(_)) + val wrappedIter = inputIter.map(InternalRow(_)) // DO NOT use iter.grouped(). See BatchIterator. 
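The `iter.grouped()` warning above carries over unchanged; as I read it, the point is that `grouped()` materializes each whole batch as a `Seq` before yielding it, while a batch iterator can hand rows out lazily as the Arrow writer asks for them. A sketch of the distinction, using a toy re-implementation rather than Spark's `BatchIterator`:

```
object LazyBatchDemo extends App {
  def traced: Iterator[Int] =
    Iterator(1, 2, 3, 4).map { i => println(s"pulled $i"); i }

  // grouped() buffers a full batch into a Seq before yielding it:
  // this next() call prints "pulled 1" and "pulled 2" immediately.
  val eagerBatch: Seq[Int] = traced.grouped(2).next()

  // A lazy alternative: each inner iterator pulls rows only on demand.
  // Like BatchIterator, it assumes a batch is fully consumed before the
  // next one is requested.
  def lazyBatches[T](it: Iterator[T], size: Int): Iterator[Iterator[T]] =
    new Iterator[Iterator[T]] {
      def hasNext: Boolean = it.hasNext
      def next(): Iterator[T] = new Iterator[T] {
        private var remaining = size
        def hasNext: Boolean = remaining > 0 && it.hasNext
        def next(): T = { remaining -= 1; it.next() }
      }
    }

  val lazyBatch = lazyBatches(traced, 2).next() // prints nothing yet
  lazyBatch.next()                              // only now prints "pulled 1"
}
```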
val batchIter = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 5c99a3f9808..00ee3a17563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream -import java.net.Socket import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType @@ -48,11 +48,11 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected def pythonMetrics: Map[String, SQLMetric] - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[IN]): Unit + inputIterator: Iterator[IN]): Boolean protected def writeUDF( dataOut: DataOutputStream, @@ -68,51 +68,46 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => PythonRDD.writeUTF(v, stream) } } + private val arrowSchema = ArrowUtils.toArrowSchema( + schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + private val allocator = + ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue) + protected val root = VectorSchemaRoot.create(arrowSchema, allocator) + protected var writer: ArrowStreamWriter = _ + +protected def close(): Unit = { + Utils.tryWithSafeFinally { + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. 
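The new `close()` wraps `writer.end()` in `Utils.tryWithSafeFinally`: the footer write can throw once the output stream is gone, so it stays in the try block, while the Arrow root and allocator are released unconditionally. A standalone sketch of that ordering with plain try/catch/finally (the catch is only there to keep the demo running; the real code lets the exception propagate):

```
object SafeCloseDemo extends App {
  final class Resource(name: String) extends AutoCloseable {
    override def close(): Unit = println(s"closed $name")
  }
  val root = new Resource("root")
  val allocator = new Resource("allocator")

  def end(): Unit = throw new java.io.IOException("output stream already closed")

  try {
    end() // the footer write may throw, so it belongs inside the try
  } catch {
    case e: java.io.IOException => println(s"write failed: ${e.getMessage}")
  } finally {
    root.close()      // released whether or not end() threw
    allocator.close()
  }
}
```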
+ writer.end() + } { + root.close() + allocator.close() + } +} - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[IN], partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { - + context: TaskContext): Writer = { + new Writer(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) } - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema( - schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { - Utils.tryWithSafeFinally { - val writer = new ArrowStreamWriter(root, null, dataOut) + if (writer == null) { + writer = new ArrowStreamWriter(root, null, dataOut) writer.start() - - writeIteratorToArrowStream(root, writer, dataOut, inputIterator) - - // end writes footer to the output stream and doesn't clean any resources. - // It could throw exception if the output stream is closed, so it should be - // in the try block. - writer.end() - } { - // If we close root and allocator in TaskCompletionListener, there could be a race - // condition where the writer thread keeps writing to the VectorSchemaRoot while - // it's being closed by the TaskCompletion listener. - // Closing root and allocator here is cleaner because root and allocator is owned - // by the writer thread and is only visible to the writer thread. - // - // If the writer thread is interrupted by TaskCompletionListener, it should either - // (1) in the try block, in which case it will get an InterruptedException when - // performing io, and goes into the finally block or (2) in the finally block, - // in which case it will ignore the interruption and close the resources. 
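Note also that the new `writeNextInputToStream` creates and starts the `ArrowStreamWriter` on the first call, when `dataOut` is actually in hand, instead of in a constructor. A minimal sketch of this start-on-first-write shape, with `HEADER`/`FOOTER` strings standing in for `writer.start()` and `writer.end()`:

```
import java.io.{ByteArrayOutputStream, DataOutputStream}

object LazyStartDemo extends App {
  var started = false

  def writeNext(out: DataOutputStream, next: Option[String]): Boolean = {
    if (!started) { out.writeBytes("HEADER;"); started = true } // writer.start()
    next match {
      case Some(s) => out.writeBytes(s + ";"); true
      case None    => out.writeBytes("FOOTER"); false           // writer.end()
    }
  }

  val bytes = new ByteArrayOutputStream()
  val out = new DataOutputStream(bytes)
  List(Some("batch1"), Some("batch2"), None).foreach(writeNext(out, _))
  println(bytes.toString) // HEADER;batch1;batch2;FOOTER
}
```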
- root.close() - allocator.close() } + + assert(writer != null) + writeNextInputToArrowStream(root, writer, dataOut, inputIterator) } } } @@ -120,15 +115,15 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] { self: BasePythonRunner[Iterator[InternalRow], _] => + private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root) - protected def writeIteratorToArrowStream( + protected def writeNextInputToArrowStream( root: VectorSchemaRoot, writer: ArrowStreamWriter, dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { - val arrowWriter = ArrowWriter.create(root) + inputIterator: Iterator[Iterator[InternalRow]]): Boolean = { - while (inputIterator.hasNext) { + if (inputIterator.hasNext) { val startData = dataOut.size() val nextBatch = inputIterator.next() @@ -141,6 +136,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In arrowWriter.reset() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + true + } else { + super[PythonArrowInput].close() + false } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index c12c690f776..8f99325e4e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.python import java.io.DataInputStream -import java.net.Socket import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ @@ -26,7 +25,7 @@ import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.api.python.{BasePythonRunner, PythonWorker, SpecialLengths} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils @@ -46,16 +45,16 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ protected def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[OUT] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( s"stdin reader for $pythonExec", 0, Long.MaxValue) @@ -80,8 +79,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ } protected override def read(): OUT = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { if (reader != null && batchLoaded) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 3857f084bcb..a229931cec8 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -31,6 +31,44 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.util.{NextIterator, Utils} +/** + * Writes the rows buffered in [[UnsafeRowBuffer]] to the Python worker. + * Any exceptions encountered will be cached to be read later by the parent thread. + */ +class WriterThread(outputIterator: Iterator[Array[Byte]]) + extends Thread(s"Thread streaming data to the Python worker") { + + @volatile var _exception: Throwable = _ + + override def run(): Unit = { + try { + // [[PythonForeachWriter]] is a sink and thus the Python worker does not generate any output. + // The `hasNext()` and `next()` calls are an indirect way to ship the input data to the + // Python worker. Consuming the Python worker's output iterator, as a side-effect, drives the + // write of the input data to the Python worker through [[org.apache.spark.api.python. + // BasePythonRunner.ReaderInputStream.writeAdditionalInputToPythonWorker]]. + if (outputIterator.hasNext) { + outputIterator.next() + } + } catch { + // Cache exceptions seen while evaluating the Python function on the streamed input. The + // parent thread will eventually rethrow this cached exception. + case t: Throwable => + _exception = t + } + } +} + +/** + * The class proceeds as follows: + * - Rows streamed through a `process()` call on the + * [[org.apache.spark.sql.execution.streaming.QueryExecutionThread]] are buffered in the + * `UnsafeRowBuffer`. + * - The [[WriterThread]] streams the buffered data to the Python worker. + * - Once the streaming query ends, [[close()]] is called, which signals the buffer to mark the + * end of streaming input. The streaming query execution thread waits for the [[WriterThread]] to + * complete and rethrows any exceptions seen by the [[WriterThread]]. + */ class PythonForeachWriter(func: PythonFunction, schema: StructType) extends ForeachWriter[UnsafeRow] { @@ -58,8 +96,11 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) private lazy val outputIterator = pythonRunner.compute(inputByteIterator, context.partitionId(), context) + private lazy val writerThread = new WriterThread(outputIterator) + override def open(partitionId: Long, version: Long): Boolean = { outputIterator // initialize everything + writerThread.start() TaskContext.get.addTaskCompletionListener[Unit] { _ => buffer.close() } true } @@ -68,9 +109,15 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) buffer.add(value) } + /** + * Waits for the writer thread to finish evaluating the Python function. Throws any exceptions + * seen by the writer thread. + */ override def close(errorOrNull: Throwable): Unit = { buffer.allRowsAdded() - if (outputIterator.hasNext) outputIterator.next() // to throw python exception if there was one + writerThread.join() + // Throw Python exception if there was one. + if (writerThread._exception != null) throw writerThread._exception } } @@ -78,18 +125,20 @@ object PythonForeachWriter { /** * A buffer that is designed for the sole purpose of buffering UnsafeRows in PythonForeachWriter. - * It is designed to be used with only 1 writer thread (i.e. JVM task thread) and only 1 reader - * thread (i.e. PythonRunner writing thread that reads from the buffer and writes to the Python - * worker stdin).
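The `WriterThread` added above uses a common cache-and-rethrow pattern: the background thread stores any failure in a `@volatile` field, and the parent thread joins and rethrows it, exactly as `close()` does further down. A minimal self-contained sketch of the pattern (illustrative names only):

```
object JoinRethrowDemo extends App {
  final class CachingThread(body: () => Unit) extends Thread("demo-writer") {
    @volatile var exception: Throwable = _
    override def run(): Unit =
      try body() catch { case t: Throwable => exception = t }
  }

  val t = new CachingThread(() => throw new RuntimeException("python failed"))
  t.start()
  t.join() // wait for the background work to finish, as close() does
  // Surface the background failure on the calling thread.
  if (t.exception != null) throw t.exception
}
```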
Adds to the buffer are non-blocking, and reads through the buffer's iterator - are blocking, that is, it blocks until new data is available or all data has been added. + * It is designed to be used with only two threads: the QueryExecutionThread which writes data + * to the buffer and the [[WriterThread]] that reads from the buffer and writes to the + * Python worker stdin. Adds to the buffer are non-blocking, and reads through the buffer's + * iterator are blocking, that is, it blocks until new data is available or all data has been + * added. * * Internally, it uses a [[HybridRowQueue]] to buffer the rows in a practically unlimited queue * across memory and local disk. However, HybridRowQueue is designed to be used only with - * EvalPythonExec where the reader is always behind the writer, that is, the reader does not - * try to read n+1 rows if the writer has only written n rows at any point of time. This - * assumption is not true for PythonForeachWriter where rows may be added at a different rate as - * they are consumed by the python worker. Hence, to maintain the invariant of the reader being - * behind the writer while using HybridRowQueue, the buffer does the following + * EvalPythonExec where the buffer's consumer is always behind the buffer's populator, that is, + * the [[WriterThread]] does not try to read n + 1 rows if the streaming thread has only + * written n rows at any point of time. This assumption is not true for PythonForeachWriter + * where rows may be added at a different rate than they are consumed by the Python worker. + * Hence, to maintain the invariant of the reader being behind the writer while using + * HybridRowQueue, the buffer does the following: - Keeps a count of the rows in the HybridRowQueue - Blocks the buffer's consuming iterator when the count is 0 so that the reader does not try to read more rows than what has been written. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 22083e0473b..bc27ee6919d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.python import java.io._ -import java.net._ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark._ @@ -44,40 +43,42 @@ abstract class BasePythonUDFRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - abstract class PythonUDFWriterThread( + abstract class PythonUDFWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { + extends Writer(env, worker, inputIterator, partitionIndex, context) { - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { val startData = dataOut.size() - - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - + val wroteData = PythonRDD.writeNextElementToStream(inputIterator, dataOut) + if (!wroteData) { + // Reached the end of input.
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + wroteData } } protected override def newReaderIterator( stream: DataInputStream, - writerThread: WriterThread, + writer: Writer, startTime: Long, env: SparkEnv, - worker: Socket, + worker: PythonWorker, pid: Option[Int], releasedOrClosed: AtomicBoolean, context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, writer, startTime, env, worker, pid, releasedOrClosed, context) { protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get + if (writer.exception.isDefined) { + throw writer.exception.get } try { stream.readInt() match { @@ -110,13 +111,13 @@ class PythonUDFRunner( jobArtifactUUID: Option[String]) extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { - protected override def newWriterThread( + protected override def newWriter( env: SparkEnv, - worker: Socket, + worker: PythonWorker, inputIterator: Iterator[Array[Byte]], partitionIndex: Int, - context: TaskContext): WriterThread = { - new PythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { + context: TaskContext): Writer = { + new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 36cb2e17835..5fa9c89b3d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.python -import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException} -import java.net.Socket +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream} +import java.nio.ByteBuffer +import java.nio.channels.SelectionKey import java.nio.charset.StandardCharsets import java.util.HashMap @@ -27,7 +28,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException} -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths} import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python._ import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} @@ -36,6 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRo import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.DirectByteBufferOutputStream /** * A user-defined Python function. This is used by the Python API. 
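All of these runners now share one `writeNextInputToStream` contract: write at most one chunk per call, return `true` while input remains, and write the end-of-data marker and return `false` once the iterator is exhausted, which is what lets the single task thread interleave writes with reads. A minimal sketch of that contract over a plain byte stream, with `-1` standing in for `END_OF_DATA_SECTION`:

```
import java.io.{ByteArrayOutputStream, DataOutputStream}

object WriterContractDemo extends App {
  val input = Iterator("a", "b", "c").map(_.getBytes("UTF-8"))
  val bytes = new ByteArrayOutputStream()
  val out = new DataOutputStream(bytes)

  def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
    if (input.hasNext) {
      val chunk = input.next()
      dataOut.writeInt(chunk.length) // length-prefixed chunk
      dataOut.write(chunk)
      true                           // caller should keep feeding input
    } else {
      dataOut.writeInt(-1)           // stand-in for END_OF_DATA_SECTION
      false                          // no more input; stop writing
    }
  }

  // The task thread invokes this whenever the worker socket is writable,
  // alternating with reads of the worker's output in between.
  while (writeNextInputToStream(out)) {}
  println(s"wrote ${bytes.size} bytes") // 3 chunks plus the end marker
}
```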
@@ -205,13 +207,14 @@ object UserDefinedPythonTableFunction { val pickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ false) - val (worker: Socket, _) = + val (worker: PythonWorker, _) = env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap) var releasedOrClosed = false + val bufferStream = new DirectByteBufferOutputStream() try { - val dataOut = - new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize)) - val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize)) + val dataIn = new DataInputStream(new BufferedInputStream( + new WorkerInputStream(worker, bufferStream), bufferSize)) PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) @@ -276,4 +279,50 @@ object UserDefinedPythonTableFunction { } } } + + /** + * A wrapper over the non-blocking I/O used to write to/read from the worker. + * + * Since we use non-blocking I/O to communicate with workers (see SPARK-44705), + * a wrapper is needed to do I/O with the worker. + * This is a simplified port of `PythonRunner.ReaderInputStream`, + * and it only supports writing all the input at once and then reading all the output. + */ + private class WorkerInputStream( + worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) extends InputStream { + + private[this] val temp = new Array[Byte](1) + + override def read(): Int = { + val n = read(temp) + if (n <= 0) { + -1 + } else { + // Signed byte to unsigned integer + temp(0) & 0xff + } + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + val buf = ByteBuffer.wrap(b, off, len) + var n = 0 + while (n == 0) { + worker.selector.select() + if (worker.selectionKey.isReadable) { + n = worker.channel.read(buf) + } + if (worker.selectionKey.isWritable) { + val buffer = bufferStream.toByteBuffer + var acceptsInput = true + while (acceptsInput && buffer.hasRemaining) { + val n = worker.channel.write(buffer) + acceptsInput = n > 0 + } + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + } + } + n + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org