Repository: spark
Updated Branches:
  refs/heads/master 83b7a1c65 -> e595c8d08


[SPARK-3993] [PySpark] fix bug while reuse worker after take()

After take(), there may be garbage left in the socket, and the next task
assigned to this worker can then hang because it reads that corrupted data.

We should make sure the socket is clean before reusing the worker: write
END_OF_STREAM at the end of the stream, and check for it after reading all
results back from Python.
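
For context, the sketch below simulates the new handshake over a local socket
pair. It is a minimal illustration, not Spark code: the sentinel values mirror
SpecialLengths from the diff, while read_int/write_int here are simplified
stand-ins for the framing helpers in pyspark.serializers, and the real worker
only reads the sentinel after writing its timing data and accumulator updates.

import socket
import struct

# Sentinels, mirroring SpecialLengths in PythonRDD.scala / serializers.py.
END_OF_DATA_SECTION = -1
END_OF_STREAM = -4

# Simplified stand-ins for pyspark.serializers.write_int / read_int:
# 4-byte big-endian signed integers.
def write_int(value, sock):
    sock.sendall(struct.pack("!i", value))

def read_int(sock):
    data = sock.recv(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack("!i", data)[0]

jvm, worker = socket.socketpair()  # local pair of connected sockets

# JVM side: after the data section it now also writes END_OF_STREAM.
write_int(END_OF_DATA_SECTION, jvm)
write_int(END_OF_STREAM, jvm)

# Worker side: drain the stream, then confirm it saw END_OF_STREAM.
assert read_int(worker) == END_OF_DATA_SECTION
if read_int(worker) == END_OF_STREAM:
    write_int(END_OF_STREAM, worker)        # tell the JVM the socket is clean
else:
    write_int(END_OF_DATA_SECTION, worker)  # any other value: do not reuse me

# JVM side: only reuse the worker if it echoed END_OF_STREAM back.
reuse_worker = read_int(jvm) == END_OF_STREAM
print("worker can be reused:", reuse_worker)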

Author: Davies Liu <[email protected]>

Closes #2838 from davies/fix_reuse and squashes the following commits:

8872914 [Davies Liu] fix tests
660875b [Davies Liu] fix bug while reuse worker after take()


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e595c8d0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e595c8d0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e595c8d0

Branch: refs/heads/master
Commit: e595c8d08a20a122295af62d5e9cc4116f9727f6
Parents: 83b7a1c
Author: Davies Liu <[email protected]>
Authored: Thu Oct 23 17:20:00 2014 -0700
Committer: Josh Rosen <[email protected]>
Committed: Thu Oct 23 17:20:00 2014 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/SparkEnv.scala   |  2 ++
 .../org/apache/spark/api/python/PythonRDD.scala  | 11 ++++++++++-
 python/pyspark/daemon.py                         |  5 ++++-
 python/pyspark/serializers.py                    |  1 +
 python/pyspark/tests.py                          | 19 ++++++++++++++++++-
 python/pyspark/worker.py                         | 11 +++++++++--
 6 files changed, 44 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index aba713c..906a00b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -68,6 +68,7 @@ class SparkEnv (
     val shuffleMemoryManager: ShuffleMemoryManager,
     val conf: SparkConf) extends Logging {
 
+  private[spark] var isStopped = false
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -75,6 +76,7 @@ class SparkEnv (
   private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
 
   private[spark] def stop() {
+    isStopped = true
     pythonWorkers.foreach { case(key, worker) => worker.stop() }
     Option(httpFileServer).foreach(_.stop())
     mapOutputTracker.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 29ca751..163dca6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,6 +75,7 @@ private[spark] class PythonRDD(
     var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
+      writerThread.join()
       if (reuse_worker && complete_cleanly) {
         env.releasePythonWorker(pythonExec, envVars.toMap, worker)
       } else {
@@ -145,7 +146,9 @@ private[spark] class PythonRDD(
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
               }
-               complete_cleanly = true
+              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+                complete_cleanly = true
+              }
               null
           }
         } catch {
@@ -154,6 +157,10 @@ private[spark] class PythonRDD(
             logDebug("Exception thrown after task interruption", e)
             throw new TaskKilledException
 
+          case e: Exception if env.isStopped =>
+            logDebug("Exception thrown after context is stopped", e)
+            null  // exit silently
+
           case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
             logError("This may have been caused by a prior exception:", 
writerThread.exception.get)
@@ -235,6 +242,7 @@ private[spark] class PythonRDD(
         // Data values
         PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -306,6 +314,7 @@ private object SpecialLengths {
   val END_OF_DATA_SECTION = -1
   val PYTHON_EXCEPTION_THROWN = -2
   val TIMING_DATA = -3
+  val END_OF_STREAM = -4
 }
 
 private[spark] object PythonRDD extends Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/python/pyspark/daemon.py
----------------------------------------------------------------------
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202..dbb3477 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ import time
 import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 from pyspark.worker import main as worker_main
 from pyspark.serializers import read_int, write_int
 
@@ -46,6 +46,9 @@ def worker(sock):
     signal.signal(SIGHUP, SIG_DFL)
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
+    # restore the handler for SIGINT,
+    # it's useful for debugging (show the stacktrace before exit)
+    signal.signal(SIGINT, signal.default_int_handler)
 
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
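
A side note on this daemon.py hunk: restoring the default Python SIGINT handler
means an interrupt raises KeyboardInterrupt inside the forked worker, so a
stack trace is printed before it exits, which is what the added comment refers
to. The standalone script below is my own illustration of the difference (not
part of the patch, Unix only):

import os
import signal
import traceback

# With SIGINT ignored, an interrupt is silently dropped.
signal.signal(signal.SIGINT, signal.SIG_IGN)
os.kill(os.getpid(), signal.SIGINT)          # nothing happens
print("SIG_IGN: interrupt was swallowed")

# With default_int_handler, the same interrupt raises KeyboardInterrupt,
# so a traceback can be shown before exiting.
signal.signal(signal.SIGINT, signal.default_int_handler)
try:
    os.kill(os.getpid(), signal.SIGINT)      # now raises KeyboardInterrupt
except KeyboardInterrupt:
    print("default_int_handler: got KeyboardInterrupt, stack trace follows")
    traceback.print_exc()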

http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d..904bd9f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -80,6 +80,7 @@ class SpecialLengths(object):
     END_OF_DATA_SECTION = -1
     PYTHON_EXCEPTION_THROWN = -2
     TIMING_DATA = -3
+    END_OF_STREAM = -4
 
 
 class Serializer(object):

http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e415..7a2107e 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,7 @@ import tempfile
 import time
 import zipfile
 import random
-from platform import python_implementation
+import threading
 
 if sys.version_info[:2] <= (2, 6):
     try:
@@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase):
         self.assertEqual(sum(range(100)), acc2.value)
         self.assertEqual(sum(range(100)), acc1.value)
 
+    def test_reuse_worker_after_take(self):
+        rdd = self.sc.parallelize(range(100000), 1)
+        self.assertEqual(0, rdd.first())
+
+        def count():
+            try:
+                rdd.count()
+            except Exception:
+                pass
+
+        t = threading.Thread(target=count)
+        t.daemon = True
+        t.start()
+        t.join(5)
+        self.assertTrue(not t.isAlive())
+        self.assertEqual(100000, rdd.count())
+
 
 class SparkSubmitTests(unittest.TestCase):
 

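A possible way to exercise just the new test is sketched below; it assumes a
built Spark assembly and pyspark importable on PYTHONPATH, since
PySparkTestCase starts a local SparkContext in setUp(), and the exact
invocation may differ between branches.

import unittest

# Load and run only WorkerTests.test_reuse_worker_after_take by dotted name.
suite = unittest.TestLoader().loadTestsFromName(
    "pyspark.tests.WorkerTests.test_reuse_worker_after_take")
unittest.TextTestRunner(verbosity=2).run(suite)
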
http://git-wip-us.apache.org/repos/asf/spark/blob/e595c8d0/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257ddd..2bdccb5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,7 +57,7 @@ def main(infile, outfile):
         boot_time = time.time()
         split_index = read_int(infile)
         if split_index == -1:  # for unit tests
-            return
+            exit(-1)
 
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
@@ -111,7 +111,6 @@ def main(infile, outfile):
         try:
             write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
             write_with_length(traceback.format_exc(), outfile)
-            outfile.flush()
         except IOError:
             # JVM close the socket
             pass
@@ -131,6 +130,14 @@ def main(infile, outfile):
     for (aid, accum) in _accumulatorRegistry.items():
         pickleSer._write_with_length((aid, accum._value), outfile)
 
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        exit(-1)
+
 
 if __name__ == '__main__':
     # Read a local port to connect to from stdin
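
To see why the extra sentinel is needed at all, here is a hedged sketch of the
failure mode from the commit message, again simulated outside Spark. The two
helpers are simplified stand-ins for pyspark's length-prefixed framing (as used
by write_with_length): a reader such as take() that stops early leaves framed
data behind in the socket, and the next reader would misinterpret it.

import socket
import struct

def write_with_length(payload, sock):
    # 4-byte big-endian length prefix followed by the payload bytes.
    sock.sendall(struct.pack("!i", len(payload)) + payload)

def read_with_length(sock):
    length = struct.unpack("!i", sock.recv(4))[0]
    return sock.recv(length)

jvm, worker = socket.socketpair()

# The JVM streams a whole partition...
for item in (b"row-0", b"row-1", b"row-2"):
    write_with_length(item, jvm)

# ...but take(1) stops reading after the first element.
print(read_with_length(worker))              # b'row-0'

# The unread frames are still sitting in the socket.  If this worker were
# reused as-is, the next task would start by reading them and interpret them
# as the beginning of its own stream -- i.e. corrupted data.
leftover = worker.recv(64, socket.MSG_PEEK)
print("leftover bytes still in the socket:", leftover)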

