Repository: spark
Updated Branches:
  refs/heads/master e96645206 -> e0f28e010


[SPARK-4737] Task set manager properly handles serialization errors

This addresses [SPARK-4737]: handling task serialization errors should not be
the DAGScheduler's responsibility. The task set manager now catches the error
and aborts the stage.

If the TaskSetManager throws a TaskNotSerializableException, the 
TaskSchedulerImpl will return an empty list of task descriptions, because no 
tasks were started. The scheduler should abort the stage gracefully.

Note that I'm not too familiar with this part of the codebase and its place in
the overall architecture of the Spark stack. If implementing it this way will
have any adverse side effects, please voice that loudly.

Author: mcheah <[email protected]>

Closes #3638 from mccheah/task-set-manager-properly-handle-ser-err and squashes 
the following commits:

1545984 [mcheah] Some more style fixes from Andrew Or.
5267929 [mcheah] Fixing style suggestions from Andrew Or.
dfa145b [mcheah] Fixing style from Josh Rosen's feedback
b2a430d [mcheah] Not returning empty seq when a task set cannot be serialized.
94844d7 [mcheah] Fixing compilation error, one brace too many
5f486f4 [mcheah] Adding license header for fake task class
bf5e706 [mcheah] Fixing indentation.
097e7a2 [mcheah] [SPARK-4737] Catching task serialization exception in 
TaskSetManager


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0f28e01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0f28e01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0f28e01

Branch: refs/heads/master
Commit: e0f28e010cdd67a2a4c8aebd35323d69a3182ba8
Parents: e966452
Author: mcheah <[email protected]>
Authored: Fri Jan 9 14:16:20 2015 -0800
Committer: Andrew Or <[email protected]>
Committed: Fri Jan 9 14:16:20 2015 -0800

----------------------------------------------------------------------
 .../spark/TaskNotSerializableException.scala    | 25 +++++++++
 .../apache/spark/scheduler/DAGScheduler.scala   | 20 --------
 .../spark/scheduler/TaskSchedulerImpl.scala     | 54 ++++++++++++++------
 .../apache/spark/scheduler/TaskSetManager.scala | 18 +++++--
 .../org/apache/spark/SharedSparkContext.scala   |  2 +-
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 21 ++++++++
 .../scheduler/NotSerializableFakeTask.scala     | 40 +++++++++++++++
 .../scheduler/TaskSchedulerImplSuite.scala      | 30 +++++++++++
 .../spark/scheduler/TaskSetManagerSuite.scala   | 14 +++++
 9 files changed, 182 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala 
b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
new file mode 100644
index 0000000..9df6106
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Exception thrown when a task cannot be serialized.
+ */
+private[spark] class TaskNotSerializableException(error: Throwable) extends 
Exception(error)

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 259621d..61d09d7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -866,26 +866,6 @@ class DAGScheduler(
     }
 
     if (tasks.size > 0) {
-      // Preemptively serialize a task to make sure it can be serialized. We 
are catching this
-      // exception here because it would be fairly hard to catch the 
non-serializable exception
-      // down the road, where we have several different implementations for 
local scheduler and
-      // cluster schedulers.
-      //
-      // We've already serialized RDDs and closures in taskBinary, but here we 
check for all other
-      // objects such as Partition.
-      try {
-        closureSerializer.serialize(tasks.head)
-      } catch {
-        case e: NotSerializableException =>
-          abortStage(stage, "Task not serializable: " + e.toString)
-          runningStages -= stage
-          return
-        case NonFatal(e) => // Other exceptions, such as 
IllegalArgumentException from Kryo.
-          abortStage(stage, s"Task serialization failed: 
$e\n${e.getStackTraceString}")
-          runningStages -= stage
-          return
-      }
-
       logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " 
(" + stage.rdd + ")")
       stage.pendingTasks ++= tasks
       logDebug("New pending tasks: " + stage.pendingTasks)

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a41f3ee..a1dfb01 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -31,6 +31,7 @@ import scala.util.Random
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
 import org.apache.spark.util.Utils
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage.BlockManagerId
@@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
       .format(manager.taskSet.id, manager.parent.name))
   }
 
+  private def resourceOfferSingleTaskSet(
+      taskSet: TaskSetManager,
+      maxLocality: TaskLocality,
+      shuffledOffers: Seq[WorkerOffer],
+      availableCpus: Array[Int],
+      tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
+    var launchedTask = false
+    for (i <- 0 until shuffledOffers.size) {
+      val execId = shuffledOffers(i).executorId
+      val host = shuffledOffers(i).host
+      if (availableCpus(i) >= CPUS_PER_TASK) {
+        try {
+          for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+            tasks(i) += task
+            val tid = task.taskId
+            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToExecutorId(tid) = execId
+            executorsByHost(host) += execId
+            availableCpus(i) -= CPUS_PER_TASK
+            assert(availableCpus(i) >= 0)
+            launchedTask = true
+          }
+        } catch {
+          case e: TaskNotSerializableException =>
+            logError(s"Resource offer failed, task set ${taskSet.name} was not 
serializable")
+            // Do not offer resources for this task, but don't throw an error 
to allow other
+            // task sets to be submitted.
+            return launchedTask
+        }
+      }
+    }
+    return launchedTask
+  }
+
   /**
    * Called by cluster manager to offer resources on slaves. We respond by 
asking our active task
    * sets for tasks in order of priority. We fill each node with tasks in a 
round-robin manner so
@@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
     var launchedTask = false
     for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
       do {
-        launchedTask = false
-        for (i <- 0 until shuffledOffers.size) {
-          val execId = shuffledOffers(i).executorId
-          val host = shuffledOffers(i).host
-          if (availableCpus(i) >= CPUS_PER_TASK) {
-            for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
-              tasks(i) += task
-              val tid = task.taskId
-              taskIdToTaskSetId(tid) = taskSet.taskSet.id
-              taskIdToExecutorId(tid) = execId
-              executorsByHost(host) += execId
-              availableCpus(i) -= CPUS_PER_TASK
-              assert(availableCpus(i) >= 0)
-              launchedTask = true
-            }
-          }
-        }
+        launchedTask = resourceOfferSingleTaskSet(
+            taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
       } while (launchedTask)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 28e6147..4667850 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
+import java.nio.ByteBuffer
 import java.util.Arrays
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.math.{min, max}
+import scala.util.control.NonFatal
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
@@ -417,6 +419,7 @@ private[spark] class TaskSetManager(
    * @param host  the host Id of the offered resource
    * @param maxLocality the maximum locality we want to schedule the tasks at
    */
+  @throws[TaskNotSerializableException]
   def resourceOffer(
       execId: String,
       host: String,
@@ -456,10 +459,17 @@ private[spark] class TaskSetManager(
           }
           // Serialize and return the task
           val startTime = clock.getTime()
-          // We rely on the DAGScheduler to catch non-serializable closures 
and RDDs, so in here
-          // we assume the task can be serialized without exceptions.
-          val serializedTask = Task.serializeWithDependencies(
-            task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+          val serializedTask: ByteBuffer = try {
+            Task.serializeWithDependencies(task, sched.sc.addedFiles, 
sched.sc.addedJars, ser)
+          } catch {
+            // If the task cannot be serialized, then there's no point to 
re-attempt the task,
+            // as it will always fail. So just abort the whole task-set.
+            case NonFatal(e) =>
+              val msg = s"Failed to serialize task $taskId, not attempting to 
retry it."
+              logError(msg, e)
+              abort(s"$msg Exception during serialization: $e")
+              throw new TaskNotSerializableException(e)
+          }
           if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 
1024 &&
               !emittedTaskSizeWarning) {
             emittedTaskSizeWarning = true

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala 
b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 0b6511a..3d2700b 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -30,7 +30,7 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: 
Suite =>
   var conf = new SparkConf(false)
 
   override def beforeAll() {
-    _sc = new SparkContext("local", "test", conf)
+    _sc = new SparkContext("local[4]", "test", conf)
     super.beforeAll()
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6836e9a..0deb9b1 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.rdd
 
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import com.esotericsoftware.kryo.KryoException
+
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -887,6 +891,23 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ancestors6.count(_.isInstanceOf[CyclicalDependencyRDD[_]]) === 3)
   }
 
+  test("task serialization exception should not hang scheduler") {
+    class BadSerializable extends Serializable {
+      @throws(classOf[IOException])
+      private def writeObject(out: ObjectOutputStream): Unit = throw new 
KryoException("Bad serialization")
+
+      @throws(classOf[IOException])
+      private def readObject(in: ObjectInputStream): Unit = {}
+    }
+    // Note that in the original bug, SPARK-4349, that this verifies, the job 
would only hang if there were
+    // more threads in the Spark Context than there were number of objects in 
this sequence.
+    intercept[Throwable] {
+      sc.parallelize(Seq(new BadSerializable, new BadSerializable)).collect
+    }
+    // Check that the context has not crashed
+    sc.parallelize(1 to 100).map(x => x*2).collect
+  }
+
   /** A contrived RDD that allows the manual addition of dependencies after 
creation. */
   private class CyclicalDependencyRDD[T: ClassTag] extends RDD[T](sc, Nil) {
     private val mutableDependencies: ArrayBuffer[Dependency[_]] = 
ArrayBuffer.empty

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala 
b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
new file mode 100644
index 0000000..6b75c98
--- /dev/null
+++ 
b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
+
+import org.apache.spark.TaskContext
+
+/**
+ * A Task implementation that fails to serialize.
+ */
+private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) extends 
Task[Array[Byte]](stageId, 0) {
+  override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
+  override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    if (stageId == 0) {
+      throw new IllegalStateException("Cannot serialize")
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream): Unit = {}
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 8874cf0..add13f5 100644
--- 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -100,4 +100,34 @@ class TaskSchedulerImplSuite extends FunSuite with 
LocalSparkContext with Loggin
     assert(1 === taskDescriptions.length)
     assert("executor0" === taskDescriptions(0).executorId)
   }
+
+  test("Scheduler does not crash when tasks are not serializable") {
+    sc = new SparkContext("local", "TaskSchedulerImplSuite")
+    val taskCpus = 2
+
+    sc.conf.set("spark.task.cpus", taskCpus.toString)
+    val taskScheduler = new TaskSchedulerImpl(sc)
+    taskScheduler.initialize(new FakeSchedulerBackend)
+    // Need to initialize a DAGScheduler for the taskScheduler to use for 
callbacks.
+    val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+      override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+      override def executorAdded(execId: String, host: String) {}
+    }
+    val numFreeCores = 1
+    taskScheduler.setDAGScheduler(dagScheduler)
+    var taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new 
NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+    val multiCoreWorkerOffers = Seq(new WorkerOffer("executor0", "host0", 
taskCpus),
+      new WorkerOffer("executor1", "host1", numFreeCores))
+    taskScheduler.submitTasks(taskSet)
+    var taskDescriptions = 
taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+    assert(0 === taskDescriptions.length)
+
+    // Now check that we can still submit tasks
+    // Even if one of the tasks has not-serializable tasks, the other task set 
should still be processed without error
+    taskScheduler.submitTasks(taskSet)
+    taskScheduler.submitTasks(FakeTask.createTaskSet(1))
+    taskDescriptions = 
taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten
+    assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e0f28e01/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 4721915..84b9b78 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.scheduler
 
+import java.io.{ObjectInputStream, ObjectOutputStream, IOException}
 import java.util.Random
 
 import scala.collection.mutable.ArrayBuffer
@@ -563,6 +564,19 @@ class TaskSetManagerSuite extends FunSuite with 
LocalSparkContext with Logging {
     assert(manager.emittedTaskSizeWarning)
   }
 
+  test("Not serializable exception thrown if the task cannot be serialized") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+
+    val taskSet = new TaskSet(Array(new NotSerializableFakeTask(1, 0), new 
NotSerializableFakeTask(0, 1)), 0, 0, 0, null)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+
+    intercept[TaskNotSerializableException] {
+      manager.resourceOffer("exec1", "host1", ANY)
+    }
+    assert(manager.isZombie)
+  }
+
   test("abort the job if total size of results is too large") {
     val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
     sc = new SparkContext("local", "test", conf)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to