Repository: spark
Updated Branches:
  refs/heads/master e139e2be6 -> 55349f9fe


[SPARK-1740] [PySpark] kill the python worker

Kill only the python worker related to cancelled tasks.

The daemon will start a background thread to monitor all the opened sockets for 
all workers. If a socket is closed by the JVM, this thread will kill the corresponding worker.

When a task is cancelled, the socket to its worker will be closed, and the worker 
will then be killed by the daemon.

Author: Davies Liu <[email protected]>

Closes #1643 from davies/kill and squashes the following commits:

8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy
46ca150 [Davies Liu] address comment
acd751c [Davies Liu] kill the worker when task is canceled


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/55349f9f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/55349f9f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/55349f9f

Branch: refs/heads/master
Commit: 55349f9fe81ba5af5e4a5e4908ebf174e63c6cc9
Parents: e139e2b
Author: Davies Liu <[email protected]>
Authored: Sun Aug 3 15:52:00 2014 -0700
Committer: Josh Rosen <[email protected]>
Committed: Sun Aug 3 15:52:00 2014 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |  5 +-
 .../org/apache/spark/api/python/PythonRDD.scala |  9 ++-
 .../spark/api/python/PythonWorkerFactory.scala  | 64 +++++++++++++++-----
 python/pyspark/daemon.py                        | 24 ++++++--
 python/pyspark/tests.py                         | 51 ++++++++++++++++
 5 files changed, 125 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/55349f9f/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 92c809d..0bce531 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.io.File
+import java.net.Socket
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], 
worker: Socket) {
     synchronized {
       val key = (pythonExec, envVars)
-      pythonWorkers(key).stop()
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/55349f9f/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index fe9a9e5..0b5322c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
-    val worker: Socket = env.createPythonWorker(pythonExec,
-      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+    envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor 
thread
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
       if (!context.completed) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python 
Worker")
-          env.destroyPythonWorker(pythonExec, envVars.toMap)
+          env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
         } catch {
           case e: Exception =>
             logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
 
   /**
    * Convert an RDD of serialized Python dictionaries to Scala Maps (no 
recursive conversions).
-   * This function is outdated, PySpark does not use it anymore
    */
-  @deprecated
+  @deprecated("PySpark does not use it anymore", "1.1")
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler

http://git-wip-us.apache.org/repos/asf/spark/blob/55349f9f/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 15fe8a9..7af260d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, 
OutputStreamWriter}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 
+import scala.collection.mutable
 import scala.collection.JavaConversions._
 
 import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, 
envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
+  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
   val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
    * to avoid the high cost of forking from Java. This currently only works on 
UNIX-based systems.
    */
   private def createThroughDaemon(): Socket = {
+
+    def createSocket(): Socket = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch 
worker")
+      }
+      daemonWorkers.put(socket, pid)
+      socket
+    }
+
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
 
       // Attempt to connect, restart and retry once if it fails
       try {
-        val socket = new Socket(daemonHost, daemonPort)
-        val launchStatus = new DataInputStream(socket.getInputStream).readInt()
-        if (launchStatus != 0) {
-          throw new IllegalStateException("Python daemon failed to launch 
worker")
-        }
-        socket
+        createSocket()
       } catch {
         case exc: SocketException =>
           logWarning("Failed to open socket to Python daemon:", exc)
           logWarning("Assuming that daemon unexpectedly quit, attempting to 
restart")
           stopDaemon()
           startDaemon()
-          new Socket(daemonHost, daemonPort)
+          createSocket()
       }
     }
   }
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
       // Wait for it to connect to our socket
       serverSocket.setSoTimeout(10000)
       try {
-        return serverSocket.accept()
+        val socket = serverSocket.accept()
+        simpleWorkers.put(socket, worker)
+        return socket
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker did not connect back in 
time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: 
String, envVars: Map[String
 
   private def stopDaemon() {
     synchronized {
-      // Request shutdown of existing daemon by sending SIGTERM
-      if (daemon != null) {
-        daemon.destroy()
-      }
+      if (useDaemon) {
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
 
-      daemon = null
-      daemonPort = 0
+        daemon = null
+        daemonPort = 0
+      } else {
+        simpleWorkers.mapValues(_.destroy())
+      }
     }
   }
 
   def stop() {
     stopDaemon()
   }
+
+  def stopWorker(worker: Socket) {
+    if (useDaemon) {
+      if (daemon != null) {
+        daemonWorkers.get(worker).foreach { pid =>
+          // tell daemon to kill worker by pid
+          val output = new DataOutputStream(daemon.getOutputStream)
+          output.writeInt(pid)
+          output.flush()
+          daemon.getOutputStream.flush()
+        }
+      }
+    } else {
+      simpleWorkers.get(worker).foreach(_.destroy())
+    }
+    worker.close()
+  }
 }
 
 private object PythonWorkerFactory {

http://git-wip-us.apache.org/repos/asf/spark/blob/55349f9f/python/pyspark/daemon.py
----------------------------------------------------------------------
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 9fde0dd..b00da83 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ from errno import EINTR, ECHILD
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
-from pyspark.serializers import write_int
+from pyspark.serializers import read_int, write_int
 
 
 def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def worker(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        write_int(0, outfile)  # Acknowledge that the fork was successful
+        # Acknowledge that the fork was successful
+        write_int(os.getpid(), outfile)
         outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
@@ -125,14 +126,23 @@ def manager():
                 else:
                     raise
             if 0 in ready_fds:
-                # Spark told us to exit by closing stdin
-                shutdown(0)
+                try:
+                    worker_pid = read_int(sys.stdin)
+                except EOFError:
+                    # Spark told us to exit by closing stdin
+                    shutdown(0)
+                try:
+                    os.kill(worker_pid, signal.SIGKILL)
+                except OSError:
+                    pass # process already died
+
+
             if listen_sock in ready_fds:
                 sock, addr = listen_sock.accept()
                 # Launch a worker process
                 try:
-                    fork_return_code = os.fork()
-                    if fork_return_code == 0:
+                    pid = os.fork()
+                    if pid == 0:
                         listen_sock.close()
                         try:
                             worker(sock)
@@ -143,11 +153,13 @@ def manager():
                             os._exit(0)
                     else:
                         sock.close()
+
                 except OSError as e:
                     print >> sys.stderr, "Daemon failed to fork PySpark 
worker: %s" % e
                     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                     write_int(-1, outfile)  # Signal that the fork failed
                     outfile.flush()
+                    outfile.close()
                     sock.close()
     finally:
         shutdown(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/55349f9f/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 16fb5a9..acc3c30 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase):
         self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
 
 
+class TestWorker(PySparkTestCase):
+    def test_cancel_task(self):
+        temp = tempfile.NamedTemporaryFile(delete=True)
+        temp.close()
+        path = temp.name
+        def sleep(x):
+            import os, time
+            with open(path, 'w') as f:
+                f.write("%d %d" % (os.getppid(), os.getpid()))
+            time.sleep(100)
+
+        # start job in background thread
+        def run():
+            self.sc.parallelize(range(1)).foreach(sleep)
+        import threading
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+
+        daemon_pid, worker_pid = 0, 0
+        while True:
+            if os.path.exists(path):
+                data = open(path).read().split(' ')
+                daemon_pid, worker_pid = map(int, data)
+                break
+            time.sleep(0.1)
+
+        # cancel jobs
+        self.sc.cancelAllJobs()
+        t.join()
+
+        for i in range(50):
+            try:
+                os.kill(worker_pid, 0)
+                time.sleep(0.1)
+            except OSError:
+                break # worker was killed
+        else:
+            self.fail("worker has not been killed after 5 seconds")
+
+        try:
+            os.kill(daemon_pid, 0)
+        except OSError:
+            self.fail("daemon had been killed")
+
+    def test_fd_leak(self):
+        N = 1100 # fd limit is 1024 by default
+        rdd = self.sc.parallelize(range(N), N)
+        self.assertEquals(N, rdd.count())
+
+
 class TestSparkSubmit(unittest.TestCase):
     def setUp(self):
         self.programDir = tempfile.mkdtemp()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to