Repository: spark
Updated Branches:
  refs/heads/branch-2.3 bf1dabede -> 0f2aabc6b


[SPARK-23816][CORE] Killed tasks should ignore FetchFailures.

SPARK-19276 ensured that FetchFailures do not get swallowed by other
layers of exception handling, but it also meant that a killed task could
look like a fetch failure.  This is particularly a problem with
speculative execution, where we expect to kill tasks as they are reading
shuffle data.  The fix is to ensure that we always check for killed
tasks first.
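
In outline, the fix reorders the exception handling in the executor's task runner so
that the killed-task cases are matched before the fetch-failure case. A simplified
sketch of the resulting catch block (logging and interrupt-status cleanup omitted;
see the diff below for the exact change):

    } catch {
      // An explicit kill always wins, even if a FetchFailedException was recorded
      // while the task was being torn down.
      case t: TaskKilledException =>
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

      // An interrupt (or any non-fatal error) on a task that was asked to die is
      // likewise reported as KILLED, using the recorded kill reason.
      case _: InterruptedException | NonFatal(_)
          if task != null && task.reasonIfKilled.isDefined =>
        val killReason = task.reasonIfKilled.getOrElse("unknown reason")
        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

      // Only after the killed-task cases do we report a recorded fetch failure.
      case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
        val reason = task.context.fetchFailed.get.toTaskFailedReason
        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
    }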

Added a new unit test which fails before the fix, ran it 1k times to
check for flakiness.  Full suite of tests on jenkins.
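
The test forces the race deterministically: the task records a fetch failure, signals
the test via a latch, and then blocks until it is killed from a separate thread.
Roughly, abridged from the new test code below (the real suite wires this through
FetchFailureHidingRDD and ExecutorSuiteHelper):

    // in the task, after the FetchFailedException has been recorded on the TaskContext
    ExecutorSuiteHelper.latches.latch1.countDown()                  // tell the test the fetch failure happened
    ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)  // never counted down; we just wait to be killed
    throw new IllegalStateException("timed out waiting to be interrupted")

    // in the test, on a dedicated "kill-task" thread
    if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
      executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
    }

The driver then sees a final status update of TaskState.KILLED with a TaskKilled
reason rather than TaskState.FAILED with a fetch failure.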

Author: Imran Rashid <iras...@cloudera.com>

Closes #20987 from squito/SPARK-23816.

(cherry picked from commit 10f45bb8233e6ac838dd4f053052c8556f5b54bd)
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f2aabc6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f2aabc6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f2aabc6

Branch: refs/heads/branch-2.3
Commit: 0f2aabc6bc64d2b5d46e59525111bd95fcd73610
Parents: bf1dabe
Author: Imran Rashid <iras...@cloudera.com>
Authored: Mon Apr 9 11:31:21 2018 -0700
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Mon Apr 9 11:31:39 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    | 26 +++---
 .../apache/spark/executor/ExecutorSuite.scala   | 92 ++++++++++++++++----
 2 files changed, 88 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f2aabc6/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 2c3a8ef..a9c31c7 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -480,6 +480,19 @@ private[spark] class Executor(
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
       } catch {
+        case t: TaskKilledException =>
+          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
+
+        case _: InterruptedException | NonFatal(_) if
+            task != null && task.reasonIfKilled.isDefined =>
+          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(
+            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
+
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
           if (!t.isInstanceOf[FetchFailedException]) {
@@ -494,19 +507,6 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
-        case t: TaskKilledException =>
-          logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
-
-        case _: InterruptedException | NonFatal(_) if
-            task != null && task.reasonIfKilled.isDefined =>
-          val killReason = task.reasonIfKilled.getOrElse("unknown reason")
-          logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
-          setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(
-            taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))
-
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()

http://git-wip-us.apache.org/repos/asf/spark/blob/0f2aabc6/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 105a178..1a7bebe 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,6 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
 import java.nio.ByteBuffer
 import java.util.Properties
 import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.Map
 import scala.concurrent.duration._
@@ -139,7 +140,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     // the fetch failure.  The executor should still tell the driver that the task failed due to a
     // fetch failure, not a generic exception from user code.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false, interrupt = false)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -173,17 +174,48 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+    val (failReason, uncaughtExceptionHandler) = testFetchFailureHandling(true)
+    assert(failReason.isInstanceOf[ExceptionFailure])
+    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+    assert(exceptionCaptor.getAllValues.size === 1)
+    assert(exceptionCaptor.getAllValues().get(0).isInstanceOf[OutOfMemoryError])
+  }
+
+  test("SPARK-23816: interrupts are not masked by a FetchFailure") {
+    // If killing the task causes a fetch failure, we still treat it as a task that was killed,
+    // as the fetch failure could easily be caused by interrupting the thread.
+    val (failReason, _) = testFetchFailureHandling(false)
+    assert(failReason.isInstanceOf[TaskKilled])
+  }
+
+  /**
+   * Helper for testing some cases where a FetchFailure should *not* get sent back, because it's
+   * superseded by another error, either an OOM or intentionally killing a task.
+   * @param oom if true, throw an OOM after the FetchFailure; else, interrupt the task after the
+   *            FetchFailure
+   */
+  private def testFetchFailureHandling(
+      oom: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
     // may be a false positive.  And we should call the uncaught exception handler.
+    // SPARK-23816 also handles interrupts the same way, as killing an obsolete speculative task
+    // does not represent a real fetch failure.
     val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
     sc = new SparkContext(conf)
     val serializer = SparkEnv.get.closureSerializer.newInstance()
     val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
 
-    // Submit a job where a fetch failure is thrown, but then there is an OOM.  We should treat
-    // the fetch failure as a false positive, and just do normal OOM handling.
+    // Submit a job where a fetch failure is thrown, but then there is an OOM or interrupt.  We
+    // should treat the fetch failure as a false positive, and do normal OOM or interrupt handling.
     val inputRDD = new FetchFailureThrowingRDD(sc)
-    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+    if (!oom) {
+      // we are trying to set up a case where a task is killed after a fetch failure -- this
+      // is just a helper to coordinate between the task thread and this thread that will
+      // kill the task
+      ExecutorSuiteHelper.latches = new ExecutorSuiteHelper()
+    }
+    val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = oom, interrupt = !oom)
     val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
     val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
     val task = new ResultTask(
@@ -200,15 +232,8 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val serTask = serializer.serialize(task)
     val taskDescription = createFakeTaskDescription(serTask)
 
-    val (failReason, uncaughtExceptionHandler) =
-      runTaskGetFailReasonAndExceptionHandler(taskDescription)
-    // make sure the task failure just looks like a OOM, not a fetch failure
-    assert(failReason.isInstanceOf[ExceptionFailure])
-    val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
-    verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
-    assert(exceptionCaptor.getAllValues.size === 1)
-    assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
-  }
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, killTask = !oom)
+  }
 
   test("Gracefully handle error in task deserialization") {
     val conf = new SparkConf
@@ -257,22 +282,39 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
   }
 
   private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
-    runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+    runTaskGetFailReasonAndExceptionHandler(taskDescription, false)._1
   }
 
   private def runTaskGetFailReasonAndExceptionHandler(
-      taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
+      taskDescription: TaskDescription,
+      killTask: Boolean): (TaskFailedReason, UncaughtExceptionHandler) = {
     val mockBackend = mock[ExecutorBackend]
     val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
     var executor: Executor = null
+    val timedOut = new AtomicBoolean(false)
     try {
       executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
         uncaughtExceptionHandler = mockUncaughtExceptionHandler)
       // the task will be launched in a dedicated worker thread
       executor.launchTask(mockBackend, taskDescription)
+      if (killTask) {
+        val killingThread = new Thread("kill-task") {
+          override def run(): Unit = {
+            // wait to kill the task until it has thrown a fetch failure
+            if (ExecutorSuiteHelper.latches.latch1.await(10, TimeUnit.SECONDS)) {
+              // now we can kill the task
+              executor.killAllTasks(true, "Killed task, eg. because of speculative execution")
+            } else {
+              timedOut.set(true)
+            }
+          }
+        }
+        killingThread.start()
+      }
       eventually(timeout(5.seconds), interval(10.milliseconds)) {
         assert(executor.numRunningTasks === 0)
       }
+      assert(!timedOut.get(), "timed out waiting to be ready to kill tasks")
     } finally {
       if (executor != null) {
         executor.stop()
@@ -282,8 +324,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
     val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
     orderedMock.verify(mockBackend)
       .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+    val finalState = if (killTask) TaskState.KILLED else TaskState.FAILED
     orderedMock.verify(mockBackend)
-      .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+      .statusUpdate(meq(0L), meq(finalState), statusCaptor.capture())
     // first statusUpdate for RUNNING has empty data
     assert(statusCaptor.getAllValues().get(0).remaining() === 0)
     // second update is more interesting
@@ -321,7 +364,8 @@ class SimplePartition extends Partition {
 class FetchFailureHidingRDD(
     sc: SparkContext,
     val input: FetchFailureThrowingRDD,
-    throwOOM: Boolean) extends RDD[Int](input) {
+    throwOOM: Boolean,
+    interrupt: Boolean) extends RDD[Int](input) {
   override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
     val inItr = input.compute(split, context)
     try {
@@ -330,6 +374,15 @@ class FetchFailureHidingRDD(
       case t: Throwable =>
         if (throwOOM) {
           throw new OutOfMemoryError("OOM while handling another exception")
+        } else if (interrupt) {
+          // make sure our test is setup correctly
+          assert(TaskContext.get().asInstanceOf[TaskContextImpl].fetchFailed.isDefined)
+          // signal our test is ready for the task to get killed
+          ExecutorSuiteHelper.latches.latch1.countDown()
+          // then wait for another thread in the test to kill the task -- this latch
+          // is never actually decremented, we just wait to get killed.
+          ExecutorSuiteHelper.latches.latch2.await(10, TimeUnit.SECONDS)
+          throw new IllegalStateException("timed out waiting to be interrupted")
         } else {
           throw new RuntimeException("User Exception that hides the original exception", t)
         }
@@ -352,6 +405,11 @@ private class ExecutorSuiteHelper {
   @volatile var testFailedReason: TaskFailedReason = _
 }
 
+// helper for coordinating killing tasks
+private object ExecutorSuiteHelper {
+  var latches: ExecutorSuiteHelper = null
+}
+
 private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
   def writeExternal(out: ObjectOutput): Unit = {}
   def readExternal(in: ObjectInput): Unit = {

