Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/19349#discussion_r141244651
--- Diff: core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala ---
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io._
+import java.net._
+import java.nio.charset.StandardCharsets
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.util._
+
+
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+ val NON_UDF = 0
+ val SQL_BATCHED_UDF = 1
+ val SQL_PANDAS_UDF = 2
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
+ */
+private[spark] abstract class BasePythonRunner[IN, OUT](
+ funcs: Seq[ChainedPythonFunctions],
+ bufferSize: Int,
+ reuseWorker: Boolean,
+ evalType: Int,
+ argOffsets: Array[Array[Int]])
+ extends Logging {
+
+ require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+ // All the Python functions should have the same exec, version and envvars.
+ protected val envVars = funcs.head.funcs.head.envVars
+ protected val pythonExec = funcs.head.funcs.head.pythonExec
+ protected val pythonVer = funcs.head.funcs.head.pythonVer
+
+ // TODO: support accumulator in multiple UDF
+ protected val accumulator = funcs.head.funcs.head.accumulator
+
+ def compute(
+ inputIterator: Iterator[IN],
+ partitionIndex: Int,
+ context: TaskContext): Iterator[OUT] = {
+ val startTime = System.currentTimeMillis
+ val env = SparkEnv.get
+ val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+ envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor
thread
+ if (reuseWorker) {
+ envVars.put("SPARK_REUSE_WORKER", "1")
+ }
+ val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
+ // Whether the worker has been released into the idle pool
+ val released = new AtomicBoolean(false)
+
+ // Start a thread to feed the process input from our parent's iterator
+ val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
+
+ context.addTaskCompletionListener { context =>
+ writerThread.shutdownOnTaskCompletion()
+ if (!reuseWorker || !released.get) {
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
+
+ writerThread.start()
+ new MonitorThread(env, worker, context).start()
+
+ // Return an iterator that reads lines from the process's stdout
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+ val stdoutIterator = newReaderIterator(
+ stream, writerThread, startTime, env, worker, released, context)
+ new InterruptibleIterator(context, stdoutIterator)
+ }
+
+ protected def newWriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[IN],
+ partitionIndex: Int,
+ context: TaskContext): WriterThread
+
+ protected def newReaderIterator(
+ stream: DataInputStream,
+ writerThread: WriterThread,
+ startTime: Long,
+ env: SparkEnv,
+ worker: Socket,
+ released: AtomicBoolean,
+ context: TaskContext): Iterator[OUT]
+
+ /**
+ * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+ * Python process.
+ */
+ abstract class WriterThread(
+ env: SparkEnv,
+ worker: Socket,
+ inputIterator: Iterator[IN],
+ partitionIndex: Int,
+ context: TaskContext)
+ extends Thread(s"stdout writer for $pythonExec") {
+
+ @volatile private var _exception: Exception = null
+
+ private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+ private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
+
+ setDaemon(true)
+
+ /** Contains the exception thrown while writing the parent iterator to the Python process. */
+ def exception: Option[Exception] = Option(_exception)
+
+ /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+ def shutdownOnTaskCompletion() {
+ assert(context.isCompleted)
+ this.interrupt()
+ }
+
+ def writeCommand(dataOut: DataOutputStream): Unit
+ def writeIteratorToStream(dataOut: DataOutputStream): Unit
--- End diff --
I'd leave a few comments for the methods that should be implemented here.
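
For example, something along these lines (just a sketch of the kind of doc comments I mean; the wording is my guess at each method's intent from the surrounding code, so feel free to adjust):

    /**
     * Writes the command (e.g. the serialized Python functions to run) to the stream
     * connected to the Python worker.
     */
    def writeCommand(dataOut: DataOutputStream): Unit

    /**
     * Writes the records from the input iterator (the parent partition's data) to the
     * stream connected to the Python worker.
     */
    def writeIteratorToStream(dataOut: DataOutputStream): Unit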
---