Repository: spark
Updated Branches:
  refs/heads/branch-2.1 f07e989c0 -> 2971ae564


[SPARK-18761][CORE] Introduce "task reaper" to oversee task killing in executors

## What changes were proposed in this pull request?

Spark's current task cancellation / task killing mechanism is "best effort" 
because some tasks may not be interruptible or may not respond to their 
"killed" flags being set. If a significant fraction of a cluster's task slots 
are occupied by tasks that have been marked as killed but remain running then 
this can lead to a situation where new jobs and tasks are starved of resources 
that are being used by these zombie tasks.

This patch aims to address this problem by adding a "task reaper" mechanism to 
executors. At a high level, task killing now launches a new thread which 
attempts to kill the task and then watches the task and periodically checks 
whether it has been killed. The TaskReaper will periodically re-attempt to call 
`TaskRunner.kill()` and will log warnings if the task keeps running. I modified 
TaskRunner to rename its thread at the start of the task, allowing TaskReaper 
to take a thread dump and filter it in order to log stacktraces from the exact 
task thread that we are waiting to finish. If the task has not stopped after a 
configurable timeout then the TaskReaper will throw an exception to trigger 
executor JVM death, thereby forcibly freeing any resources consumed by the 
zombie tasks.
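
The monitoring loop described above can be sketched outside of Spark as follows. This is a minimal illustration under stated assumptions, not the code from the patch: `FakeTask`, `markFinished`, and `monitor` are hypothetical names standing in for `TaskRunner` and `TaskReaper.run()`, and the sketch only reports whether the task finished instead of throwing to kill the JVM.

```scala
object ReaperSketch {
  // Hypothetical stand-in for TaskRunner: it only tracks whether the task has finished.
  final class FakeTask {
    private var finished = false
    def isFinished: Boolean = synchronized { finished }
    def markFinished(): Unit = synchronized {
      finished = true
      notifyAll() // wake every waiting monitor, since a task may have two reapers
    }
  }

  // Polls until the task finishes or the timeout elapses; returns true if it finished.
  // A non-positive timeoutMs means "never time out", mirroring killTimeout = -1.
  // pollMs must be positive, because wait(0) would block until notified.
  def monitor(task: FakeTask, pollMs: Long, timeoutMs: Long): Boolean = {
    val startMs = System.currentTimeMillis()
    def elapsedMs: Long = System.currentTimeMillis() - startMs
    def timedOut: Boolean = timeoutMs > 0 && elapsedMs > timeoutMs
    var finished = false
    while (!finished && !timedOut) {
      // Check and wait under the same lock so a finish that happens between
      // the check and the wait cannot be missed.
      task.synchronized {
        if (task.isFinished) finished = true
        else task.wait(pollMs)
      }
      // Re-check outside the lock; logging (or a thread dump) happens unlocked.
      if (task.isFinished) finished = true
      else println(s"Task is still running after $elapsedMs ms")
    }
    finished
  }
}
```

Calling `ReaperSketch.monitor(task, pollMs = 10, timeoutMs = 50)` on a task that never finishes returns `false` once the timeout has passed; in the patch, that is the point at which the reaper throws an exception in non-local mode.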

This feature is flagged off by default and is controlled by four new 
configurations under the `spark.task.reaper.*` namespace. See the updated 
`configuration.md` doc for details.
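
As a configuration sketch, the four settings could be supplied through `SparkConf` as shown below. The keys are the ones introduced by this patch; the `120s` kill timeout is an arbitrary illustrative value, not a recommendation, and the snippet assumes `spark-core` is on the classpath.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Turn the reaper on (it is off by default).
  .set("spark.task.reaper.enabled", "true")
  // How often a killed task is re-checked (default: 10s).
  .set("spark.task.reaper.pollingInterval", "10s")
  // Log the task thread's stack trace on each poll (default: true).
  .set("spark.task.reaper.threadDump", "true")
  // Illustrative value: kill the executor JVM if the task is still running after this long.
  // The default of -1 disables the JVM kill.
  .set("spark.task.reaper.killTimeout", "120s")
```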

## How was this patch tested?

Tested via a new test case in `JobCancellationSuite`, plus manual testing.

Author: Josh Rosen <joshro...@databricks.com>

Closes #16189 from JoshRosen/cancellation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2971ae56
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2971ae56
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2971ae56

Branch: refs/heads/branch-2.1
Commit: 2971ae564cb3e97aa5ecac7f411daed7d54248ad
Parents: f07e989
Author: Josh Rosen <joshro...@databricks.com>
Authored: Mon Dec 19 18:43:59 2016 -0800
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Tue Dec 20 11:53:20 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    | 169 ++++++++++++++++++-
 .../scala/org/apache/spark/util/Utils.scala     |  56 +++---
 .../org/apache/spark/JobCancellationSuite.scala |  77 +++++++++
 docs/configuration.md                           |  42 +++++
 4 files changed, 316 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2971ae56/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9501dd9..3346f6d 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -84,6 +84,16 @@ private[spark] class Executor(
   // Start worker thread pool
   private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
   private val executorSource = new ExecutorSource(threadPool, executorId)
+  // Pool used for threads that supervise task killing / cancellation
+  private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
+  // For tasks which are in the process of being killed, this map holds the most recently created
+  // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't
+  // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding
+  // the integrity of the map's internal state). The purpose of this map is to prevent the creation
+  // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to
+  // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise
+  // create. The map key is a task id.
+  private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]()
 
   if (!isLocal) {
     env.metricsSystem.registerSource(executorSource)
@@ -93,6 +103,9 @@ private[spark] class Executor(
   // Whether to load classes in user jars before those in Spark jars
   private val userClassPathFirst = 
conf.getBoolean("spark.executor.userClassPathFirst", false)
 
+  // Whether to monitor killed / interrupted tasks
+  private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", 
false)
+
   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
   private val urlClassLoader = createClassLoader()
@@ -148,9 +161,27 @@ private[spark] class Executor(
   }
 
   def killTask(taskId: Long, interruptThread: Boolean): Unit = {
-    val tr = runningTasks.get(taskId)
-    if (tr != null) {
-      tr.kill(interruptThread)
+    val taskRunner = runningTasks.get(taskId)
+    if (taskRunner != null) {
+      if (taskReaperEnabled) {
+        val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized {
+          val shouldCreateReaper = taskReaperForTask.get(taskId) match {
+            case None => true
+            case Some(existingReaper) => interruptThread && !existingReaper.interruptThread
+          }
+          if (shouldCreateReaper) {
+            val taskReaper = new TaskReaper(taskRunner, interruptThread = interruptThread)
+            taskReaperForTask(taskId) = taskReaper
+            Some(taskReaper)
+          } else {
+            None
+          }
+        }
+        // Execute the TaskReaper from outside of the synchronized block.
+        maybeNewTaskReaper.foreach(taskReaperPool.execute)
+      } else {
+        taskRunner.kill(interruptThread = interruptThread)
+      }
     }
   }
 
@@ -161,12 +192,7 @@ private[spark] class Executor(
    * @param interruptThread whether to interrupt the task thread
    */
   def killAllTasks(interruptThread: Boolean) : Unit = {
-    // kill all the running tasks
-    for (taskRunner <- runningTasks.values().asScala) {
-      if (taskRunner != null) {
-        taskRunner.kill(interruptThread)
-      }
-    }
+    runningTasks.keys().asScala.foreach(t => killTask(t, interruptThread = 
interruptThread))
   }
 
   def stop(): Unit = {
@@ -192,13 +218,21 @@ private[spark] class Executor(
       serializedTask: ByteBuffer)
     extends Runnable {
 
+    val threadName = s"Executor task launch worker for task $taskId"
+
     /** Whether this task has been killed. */
     @volatile private var killed = false
 
+    @volatile private var threadId: Long = -1
+
+    def getThreadId: Long = threadId
+
     /** Whether this task has been finished. */
     @GuardedBy("TaskRunner.this")
     private var finished = false
 
+    def isFinished: Boolean = synchronized { finished }
+
     /** How much the JVM process has spent in GC when the task starts to run. 
*/
     @volatile var startGCTime: Long = _
 
@@ -229,9 +263,15 @@ private[spark] class Executor(
       // ClosedByInterruptException during execBackend.statusUpdate which 
causes
       // Executor to crash
       Thread.interrupted()
+      // Notify any waiting TaskReapers. Generally there will only be one 
reaper per task but there
+      // is a rare corner-case where one task can have two reapers in case 
cancel(interrupt=False)
+      // is followed by cancel(interrupt=True). Thus we use notifyAll() to 
avoid a lost wakeup:
+      notifyAll()
     }
 
     override def run(): Unit = {
+      threadId = Thread.currentThread.getId
+      Thread.currentThread.setName(threadName)
       val threadMXBean = ManagementFactory.getThreadMXBean
       val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
       val deserializeStartTime = System.currentTimeMillis()
@@ -432,6 +472,117 @@ private[spark] class Executor(
   }
 
   /**
+   * Supervises the killing / cancellation of a task by sending the 
interrupted flag, optionally
+   * sending a Thread.interrupt(), and monitoring the task until it finishes.
+   *
+   * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks
+   * may not be interruptible or may not respond to their "killed" flags being set. If a significant
+   * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but
+   * remain running then this can lead to a situation where new jobs and tasks are starved of
+   * resources that are being used by these zombie tasks.
+   *
+   * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor 
and clean up zombie
+   * tasks. For backwards-compatibility / backportability this component is 
disabled by default
+   * and must be explicitly enabled by setting 
`spark.task.reaper.enabled=true`.
+   *
+   * A TaskReaper is created for a particular task when that task is killed / 
cancelled. Typically
+   * a task will have only one TaskReaper, but it's possible for a task to 
have up to two reapers
+   * in case kill is called twice with different values for the `interrupt` 
parameter.
+   *
+   * Once created, a TaskReaper will run until its supervised task has 
finished running. If the
+   * TaskReaper has not been configured to kill the JVM after a timeout (i.e. 
if
+   * `spark.task.reaper.killTimeout < 0`) then this implies that the 
TaskReaper may run indefinitely
+   * if the supervised task never exits.
+   */
+  private class TaskReaper(
+      taskRunner: TaskRunner,
+      val interruptThread: Boolean)
+    extends Runnable {
+
+    private[this] val taskId: Long = taskRunner.taskId
+
+    private[this] val killPollingIntervalMs: Long =
+      conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s")
+
+    private[this] val killTimeoutMs: Long = 
conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1")
+
+    private[this] val takeThreadDump: Boolean =
+      conf.getBoolean("spark.task.reaper.threadDump", true)
+
+    override def run(): Unit = {
+      val startTimeMs = System.currentTimeMillis()
+      def elapsedTimeMs = System.currentTimeMillis() - startTimeMs
+      def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > 
killTimeoutMs
+      try {
+        // Only attempt to kill the task once. If interruptThread = false then 
a second kill
+        // attempt would be a no-op and if interruptThread = true then it may 
not be safe or
+        // effective to interrupt multiple times:
+        taskRunner.kill(interruptThread = interruptThread)
+        // Monitor the killed task until it exits. The synchronization logic 
here is complicated
+        // because we don't want to synchronize on the taskRunner while 
possibly taking a thread
+        // dump, but we also need to be careful to avoid races between 
checking whether the task
+        // has finished and wait()ing for it to finish.
+        var finished: Boolean = false
+        while (!finished && !timeoutExceeded()) {
+          taskRunner.synchronized {
+            // We need to synchronize on the TaskRunner while checking whether 
the task has
+            // finished in order to avoid a race where the task is marked as 
finished right after
+            // we check and before we call wait().
+            if (taskRunner.isFinished) {
+              finished = true
+            } else {
+              taskRunner.wait(killPollingIntervalMs)
+            }
+          }
+          if (taskRunner.isFinished) {
+            finished = true
+          } else {
+            logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms")
+            if (takeThreadDump) {
+              try {
+                Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread =>
+                  if (thread.threadName == taskRunner.threadName) {
+                    logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}")
+                  }
+                }
+              } catch {
+                case NonFatal(e) =>
+                  logWarning("Exception thrown while obtaining thread dump: ", 
e)
+              }
+            }
+          }
+        }
+
+        if (!taskRunner.isFinished && timeoutExceeded()) {
+          if (isLocal) {
+            logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " +
+              "not killing JVM because we are running in local mode.")
+          } else {
+            // In non-local-mode, the exception thrown here will bubble up to the uncaught exception
+            // handler and cause the executor JVM to exit.
+            throw new SparkException(
+              s"Killing executor JVM because killed task $taskId could not be stopped within " +
+                s"$killTimeoutMs ms.")
+          }
+        }
+      } finally {
+        // Clean up entries in the taskReaperForTask map.
+        taskReaperForTask.synchronized {
+          taskReaperForTask.get(taskId).foreach { taskReaperInMap =>
+            if (taskReaperInMap eq this) {
+              taskReaperForTask.remove(taskId)
+            } else {
+              // This must have been a TaskReaper where interruptThread == 
false where a subsequent
+              // killTask() call for the same task had interruptThread == true 
and overwrote the
+              // map entry.
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
    * Create a ClassLoader for use in tasks, adding any JARs specified by the 
user or any classes
    * created by the interpreter to the search path
    */
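
The rule that `killTask()` uses to decide whether a new `TaskReaper` is needed can be isolated as a small pure function. This is a sketch for illustration only: `ReaperDecision` and its parameter names are hypothetical, and the existing reaper is reduced to its `interruptThread` flag.

```scala
object ReaperDecision {
  // existing: the interruptThread flag of the reaper already registered for the task, if any.
  // requested: the interruptThread flag of the current killTask() call.
  def shouldCreateReaper(existing: Option[Boolean], requested: Boolean): Boolean =
    existing match {
      // No reaper yet for this task: always create one.
      case None => true
      // A reaper exists: only an upgrade from non-interrupting to interrupting adds anything.
      case Some(existingInterrupts) => requested && !existingInterrupts
    }
}
```

Under this rule a task gets at most two reapers: one from a kill without interrupt, followed by one from a kill with interrupt. Every other repeated kill request is absorbed by the reaper already in the map.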

http://git-wip-us.apache.org/repos/asf/spark/blob/2971ae56/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 0715151..1319a4c 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.util
 
 import java.io._
-import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo}
+import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, 
ThreadInfo}
 import java.net._
 import java.nio.ByteBuffer
 import java.nio.channels.Channels
@@ -2117,28 +2117,46 @@ private[spark] object Utils extends Logging {
     // We need to filter out null values here because dumpAllThreads() may 
return null array
     // elements for threads that are dead / don't exist.
     val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, 
true).filter(_ != null)
-    threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
-      val monitors = threadInfo.getLockedMonitors.map(m => 
m.getLockedStackFrame -> m).toMap
-      val stackTrace = threadInfo.getStackTrace.map { frame =>
-        monitors.get(frame) match {
-          case Some(monitor) =>
-            monitor.getLockedStackFrame.toString + s" => holding 
${monitor.lockString}"
-          case None =>
-            frame.toString
-        }
-      }.mkString("\n")
-
-      // use a set to dedup re-entrant locks that are held at multiple places
-      val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString)
-          ++ threadInfo.getLockedMonitors.map(_.lockString)
-        ).toSet
+    threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace)
+  }
 
-      ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, 
threadInfo.getThreadState,
-        stackTrace, if (threadInfo.getLockOwnerId < 0) None else 
Some(threadInfo.getLockOwnerId),
-        Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), 
heldLocks.toSeq)
+  def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = {
+    if (threadId <= 0) {
+      None
+    } else {
+      // The Int.MaxValue here requests the entire untruncated stack trace of 
the thread:
+      val threadInfo =
+        Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, 
Int.MaxValue))
+      threadInfo.map(threadInfoToThreadStackTrace)
     }
   }
 
+  private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): 
ThreadStackTrace = {
+    val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame 
-> m).toMap
+    val stackTrace = threadInfo.getStackTrace.map { frame =>
+      monitors.get(frame) match {
+        case Some(monitor) =>
+          monitor.getLockedStackFrame.toString + s" => holding 
${monitor.lockString}"
+        case None =>
+          frame.toString
+      }
+    }.mkString("\n")
+
+    // use a set to dedup re-entrant locks that are held at multiple places
+    val heldLocks =
+      (threadInfo.getLockedSynchronizers ++ 
threadInfo.getLockedMonitors).map(_.lockString).toSet
+
+    ThreadStackTrace(
+      threadId = threadInfo.getThreadId,
+      threadName = threadInfo.getThreadName,
+      threadState = threadInfo.getThreadState,
+      stackTrace = stackTrace,
+      blockedByThreadId =
+        if (threadInfo.getLockOwnerId < 0) None else 
Some(threadInfo.getLockOwnerId),
+      blockedByLock = 
Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""),
+      holdingLocks = heldLocks.toSeq)
+  }
+
   /**
    * Convert all spark properties set in the given SparkConf to a sequence of 
java options.
    */
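
The single-thread lookup behind `getThreadDumpForThread` can be tried in isolation with only the JDK. The sketch below is a simplified, hypothetical variant (`SingleThreadDump.stackTraceOf` is not part of the patch) that returns the stack trace as a plain string rather than a `ThreadStackTrace`.

```scala
import java.lang.management.ManagementFactory

object SingleThreadDump {
  // Returns the full stack trace of the given thread, or None if the id is
  // invalid or the thread is no longer alive.
  def stackTraceOf(threadId: Long): Option[String] =
    if (threadId <= 0) {
      None
    } else {
      // Int.MaxValue requests the entire, untruncated stack trace.
      // getThreadInfo returns null for a thread that does not exist, hence Option(...).
      Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue))
        .map(_.getStackTrace.mkString("\n"))
    }
}
```

For example, `SingleThreadDump.stackTraceOf(Thread.currentThread.getId)` yields the caller's own stack. Note that, per the JDK documentation, this `getThreadInfo(long, int)` overload does not collect locked monitors or synchronizers, so lock details are only available through `dumpAllThreads(true, true)`.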

http://git-wip-us.apache.org/repos/asf/spark/blob/2971ae56/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala 
b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a3490fc..99150a1 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with 
Matchers with BeforeAndAft
     assert(jobB.get() === 100)
   }
 
+  test("task reaper kills JVM if killed tasks keep running for too long") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "5s")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", 
interruptOnCancel = true)
+      sc.parallelize(1 to 10000, 2).map { i =>
+        while (true) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 
15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
+  test("task reaper will not kill JVM if spark.task.reaper.killTimeout == -1") {
+    val conf = new SparkConf()
+      .set("spark.task.reaper.enabled", "true")
+      .set("spark.task.reaper.killTimeout", "-1")
+      .set("spark.task.reaper.pollingInterval", "1s")
+      .set("spark.deploy.maxExecutorRetries", "1")
+    sc = new SparkContext("local-cluster[2,1,1024]", "test", conf)
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = Future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled", 
interruptOnCancel = true)
+      sc.parallelize(1 to 2, 2).map { i =>
+        val startTime = System.currentTimeMillis()
+        while (System.currentTimeMillis() < startTime + 10000) { }
+      }.count()
+    }
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    // Small delay to ensure tasks actually start executing the task body
+    Thread.sleep(1000)
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 
15.seconds) }.getCause
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100)
+  }
+
   test("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // twoJobsSharingStageSemaphore:

http://git-wip-us.apache.org/repos/asf/spark/blob/2971ae56/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index e33af3a..9c325b6 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1366,6 +1366,48 @@ Apart from these, the following properties are also 
available, and may be useful
     Should be greater than or equal to 1. Number of allowed retries = this 
value - 1.
   </td>
 </tr>
+<tr>
+  <td><code>spark.task.reaper.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables monitoring of killed / interrupted tasks. When set to true, any task which is killed
+    will be monitored by the executor until that task actually finishes executing. See the other
+    <code>spark.task.reaper.*</code> configurations for details on how to control the exact behavior
+    of this monitoring. When set to false (the default), task killing will use an older code
+    path which lacks such monitoring.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.pollingInterval</code></td>
+  <td>10s</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting controls 
the frequency at which
+    executors will poll the status of killed tasks. If a killed task is still 
running when polled
+    then a warning will be logged and, by default, a thread-dump of the task 
will be logged
+    (this thread dump can be disabled via the 
<code>spark.task.reaper.threadDump</code> setting,
+    which is documented below).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.threadDump</code></td>
+  <td>true</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting controls 
whether task thread
+    dumps are logged during periodic polling of killed tasks. Set this to 
false to disable
+    collection of thread dumps.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.task.reaper.killTimeout</code></td>
+  <td>-1</td>
+  <td>
+    When <code>spark.task.reaper.enabled = true</code>, this setting specifies 
a timeout after
+    which the executor JVM will kill itself if a killed task has not stopped 
running. The default
+    value, -1, disables this mechanism and prevents the executor from 
self-destructing. The purpose
+    of this setting is to act as a safety-net to prevent runaway uncancellable 
tasks from rendering
+    an executor unusable.
+  </td>
+</tr>
 </table>
 
 #### Dynamic Allocation

