Repository: spark
Updated Branches:
  refs/heads/master 23f966f47 -> 6181577e9


[SPARK-3466] Limit size of results that a driver collects for each action

Right now, operations like collect() and take() can crash the driver with an 
OOM if they bring back too much data.

This PR introduces spark.driver.maxResultSize: once it is set, the driver 
will abort a job if the total size of its results exceeds that limit.

By default, it's 1g (preserving backward compatibility for most cases).

In local mode, the driver and executors share the same JVM, so the default 
setting cannot protect the JVM from OOM.
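
As a minimal usage sketch (not part of this patch; the app name and the 2g 
value are illustrative), the limit can be set through SparkConf:

    import org.apache.spark.{SparkConf, SparkContext}

    // Cap the total serialized result size per action at 2g (example value);
    // setting it to 0 disables the limit entirely.
    val conf = new SparkConf()
      .setAppName("collect-limit-example")
      .set("spark.driver.maxResultSize", "2g")
    val sc = new SparkContext(conf)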

cc mateiz

Author: Davies Liu <dav...@databricks.com>

Closes #3003 from davies/collect and squashes the following commits:

248ed5e [Davies Liu] fix compile
272522e [Davies Liu] address comments
2c35773 [Davies Liu] add sizes in message of abort()
5d62303 [Davies Liu] address comments
bc3c077 [Davies Liu] Merge branch 'master' of github.com:apache/spark into collect
11f97c5 [Davies Liu] address comments
47b144f [Davies Liu] check the size of result before send and fetch
3d81af2 [Davies Liu] address comments
ca8267d [Davies Liu] limit the size of data by collect


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6181577e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6181577e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6181577e

Branch: refs/heads/master
Commit: 6181577e9935f46b646ba3925b873d031aa3d6ba
Parents: 23f966f
Author: Davies Liu <dav...@databricks.com>
Authored: Sun Nov 2 00:03:51 2014 -0700
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Sun Nov 2 00:03:51 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    | 25 +++++++++------
 .../org/apache/spark/scheduler/TaskResult.scala |  4 +--
 .../spark/scheduler/TaskResultGetter.scala      | 20 +++++++++---
 .../apache/spark/scheduler/TaskSetManager.scala | 33 +++++++++++++++++---
 .../scala/org/apache/spark/util/Utils.scala     |  5 +++
 .../spark/scheduler/TaskResultGetterSuite.scala |  2 +-
 .../spark/scheduler/TaskSetManagerSuite.scala   | 25 +++++++++++++++
 docs/configuration.md                           | 12 +++++++
 8 files changed, 104 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c78e0ff..e24a15f 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -104,6 +104,9 @@ private[spark] class Executor(
   // to send the result back.
   private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
+  // Limit of bytes for total size of results (default is 1GB)
+  private val maxResultSize = Utils.getMaxResultSize(conf)
+
   // Start worker thread pool
   val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
 
@@ -210,25 +213,27 @@ private[spark] class Executor(
         val resultSize = serializedDirectResult.limit
 
         // directSend = sending directly back to the driver
-        val (serializedResult, directSend) = {
-          if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+        val serializedResult = {
+          if (resultSize > maxResultSize) {
+            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+              s"dropping it.")
+            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
             val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
-            (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+            logInfo(
+              s"Finished $taskName (TID $taskId). $resultSize bytes result 
sent via BlockManager)")
+            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
           } else {
-            (serializedDirectResult, true)
+            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes 
result sent to driver")
+            serializedDirectResult
           }
         }
 
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
-        if (directSend) {
-          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
-        } else {
-          logInfo(
-            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
-        }
       } catch {
         case ffe: FetchFailedException => {
           val reason = ffe.toTaskEndReason

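In condensed form, the executor-side logic above now picks among three paths; 
a simplified sketch restating the patch (routeResult is a hypothetical helper, 
not code from the commit):

    // Sketch: how a finished task's result is routed after this change.
    //   1. bigger than maxResultSize         -> dropped; only block id + size sent
    //   2. bigger than the Akka frame budget -> stored in the BlockManager
    //   3. otherwise                         -> serialized directly to the driver
    def routeResult(resultSize: Long, maxResultSize: Long,
                    akkaFrameSize: Long, reservedBytes: Long): String =
      if (resultSize > maxResultSize) "dropped (IndirectTaskResult carries only the size)"
      else if (resultSize >= akkaFrameSize - reservedBytes) "via BlockManager (IndirectTaskResult)"
      else "direct send (DirectTaskResult)"
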
http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 11c19ee..1f114a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
 private[spark] sealed trait TaskResult[T]
 
 /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+  extends TaskResult[T] with Serializable
 
 /** A TaskResult that contains the task's return value and accumulator updates. */
 private[spark]

http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 4b5be68..819b51e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
     getTaskResultExecutor.execute(new Runnable {
       override def run(): Unit = Utils.logUncaughtExceptions {
         try {
-          val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
-            case directResult: DirectTaskResult[_] => directResult
-            case IndirectTaskResult(blockId) =>
+          val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+            case directResult: DirectTaskResult[_] =>
+              if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+                return
+              }
+              (directResult, serializedData.limit())
+            case IndirectTaskResult(blockId, size) =>
+              if (!taskSetManager.canFetchMoreResults(size)) {
+                // dropped by executor if size is larger than maxResultSize
+                sparkEnv.blockManager.master.removeBlock(blockId)
+                return
+              }
               logDebug("Fetching indirect task result for TID %s".format(tid))
               scheduler.handleTaskGettingResult(taskSetManager, tid)
               val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
               val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
                 serializedTaskResult.get)
               sparkEnv.blockManager.master.removeBlock(blockId)
-              deserializedResult
+              (deserializedResult, size)
           }
-          result.metrics.resultSize = serializedData.limit()
+
+          result.metrics.resultSize = size
           scheduler.handleSuccessfulTask(taskSetManager, tid, result)
         } catch {
           case cnf: ClassNotFoundException =>

http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 376821f..a976734 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
 
 import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
  * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
   val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
   val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
 
+  // Limit of bytes for total size of results (default is 1GB)
+  val maxResultSize = Utils.getMaxResultSize(conf)
+
   // Serializer for closures and tasks.
   val env = SparkEnv.get
   val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
   var stageId = taskSet.stageId
   var name = "TaskSet_" + taskSet.stageId.toString
   var parent: Pool = null
+  var totalResultSize = 0L
+  var calculatedTasks = 0
 
   val runningTasksSet = new HashSet[Long]
   override def runningTasks = runningTasksSet.size
@@ -515,6 +519,9 @@ private[spark] class TaskSetManager(
     index
   }
 
+  /**
+   * Marks the task as getting result and notifies the DAG Scheduler
+   */
   def handleTaskGettingResult(tid: Long) = {
     val info = taskInfos(tid)
     info.markGettingResult()
@@ -522,6 +529,24 @@ private[spark] class TaskSetManager(
   }
 
   /**
+   * Check whether has enough quota to fetch the result with `size` bytes
+   */
+  def canFetchMoreResults(size: Long): Boolean = synchronized {
+    totalResultSize += size
+    calculatedTasks += 1
+    if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+      val msg = s"Total size of serialized results of ${calculatedTasks} tasks 
" +
+        s"(${Utils.bytesToString(totalResultSize)}) is bigger than 
maxResultSize " +
+        s"(${Utils.bytesToString(maxResultSize)})"
+      logError(msg)
+      abort(msg)
+      false
+    } else {
+      true
+    }
+  }
+
+  /**
   * Marks the task as successful and notifies the DAGScheduler that a task has ended.
    */
   def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {

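The canFetchMoreResults method added above is a synchronized 
accumulate-and-compare quota. A self-contained sketch of the same pattern 
(ResultQuota and tryAdd are hypothetical names, not part of the patch):

    // Hypothetical standalone version of the driver-side result quota.
    class ResultQuota(maxResultSize: Long) {
      private var totalResultSize = 0L
      private var calculatedTasks = 0

      // Returns false once the accumulated result bytes exceed the limit;
      // a maxResultSize of 0 disables the check.
      def tryAdd(size: Long): Boolean = synchronized {
        totalResultSize += size
        calculatedTasks += 1
        maxResultSize <= 0 || totalResultSize <= maxResultSize
      }
    }
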
http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 68d378f..4e30d0d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1720,6 +1720,11 @@ private[spark] object Utils extends Logging {
     method.invoke(obj, values.toSeq: _*)
   }
 
+  // Limit of bytes for total size of results (default is 1GB)
+  def getMaxResultSize(conf: SparkConf): Long = {
+    memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
+  }
+
   /**
    * Return the current system LD_LIBRARY_PATH name
    */

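For clarity, the arithmetic in getMaxResultSize: memoryStringToMb parses the 
configured string into megabytes, and the shift by 20 converts megabytes to 
bytes (a worked sketch; the values are illustrative):

    // "1g" parses to 1024 MB, and 1 MB = 1 << 20 bytes.
    val mb: Long = 1024      // memoryStringToMb("1g")
    val bytes = mb << 20     // 1073741824 bytes, i.e. 1 GB
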
http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index c4e7a4b..5768a3a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
       // Only remove the result once, since we'd like to test the case where the task eventually
       // succeeds.
       serializer.get().deserialize[TaskResult[_]](serializedData) match {
-        case IndirectTaskResult(blockId) =>
+        case IndirectTaskResult(blockId, size) =>
           sparkEnv.blockManager.master.removeBlock(blockId)
         case directResult: DirectTaskResult[_] =>
           taskSetManager.abort("Internal error: expect only indirect results")

http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c0b0764..1809b53 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.emittedTaskSizeWarning)
   }
 
+  test("abort the job if total size of results is too large") {
+    val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
+    sc = new SparkContext("local", "test", conf)
+
+    def genBytes(size: Int) = { (x: Int) =>
+      val bytes = Array.ofDim[Byte](size)
+      scala.util.Random.nextBytes(bytes)
+      bytes
+    }
+
+    // multiple 1k result
+    val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
+    assert(10 === r.size )
+
+    // single 10M result
+    val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
+    assert(thrown.getMessage().contains("bigger than maxResultSize"))
+
+    // multiple 1M results
+    val thrown2 = intercept[SparkException] {
+      sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect()
+    }
+    assert(thrown2.getMessage().contains("bigger than maxResultSize"))
+  }
+
   test("speculative and noPref task should be scheduled after node-local") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))

http://git-wip-us.apache.org/repos/asf/spark/blob/6181577e/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 3007706..099972c 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -112,6 +112,18 @@ of the most common options to set are:
   </td>
 </tr>
 <tr>
+  <td><code>spark.driver.maxResultSize</code></td>
+  <td>1g</td>
+  <td>
+    Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+    Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+    is above this limit.
+    Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory
+    and memory overhead of objects in JVM). Setting a proper limit can protect the driver from
+    out-of-memory errors.
+  </td>
+</tr>
+<tr>
   <td><code>spark.serializer</code></td>
   <td>org.apache.spark.serializer.<br />JavaSerializer</td>
   <td>

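The documentation entry above describes the user-facing knob; for reference, a 
small sketch of reading the effective value back at runtime (assumes a live 
SparkContext named sc):

    // Defaults to "1g" when the user has not set it.
    val limit = sc.getConf.get("spark.driver.maxResultSize", "1g")
    println(s"spark.driver.maxResultSize = $limit")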
