Repository: spark
Updated Branches:
  refs/heads/master 96aa3340f -> 812b63bbe


[SPARK-8857][SPARK-8859][Core] Add an internal flag to Accumulable and send
internal accumulator updates to the driver via heartbeats

This PR includes the following changes:

1. Remove the thread-local `Accumulators.localAccums`. Instead, every Accumulator
   deserialized in an executor registers itself with the TaskContext of its task.
2. Add an `internal` flag to Accumulable. The local values of internal Accumulators
   are also reported to the driver via heartbeats, as part of the task metrics.

Author: zsxwing <[email protected]>

Closes #7448 from zsxwing/accumulators and squashes the following commits:

c24bc5b [zsxwing] Add comments
bd7dcf1 [zsxwing] Add an internal flag to Accumulable and send internal 
accumulator updates to the driver via heartbeats


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/812b63bb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/812b63bb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/812b63bb

Branch: refs/heads/master
Commit: 812b63bbee8d0b30884f7a96b207e8834b774957
Parents: 96aa334
Author: zsxwing <[email protected]>
Authored: Thu Jul 16 21:09:09 2015 -0700
Committer: Reynold Xin <[email protected]>
Committed: Thu Jul 16 21:09:09 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Accumulators.scala   | 68 +++++++++-----------
 .../scala/org/apache/spark/TaskContext.scala    | 18 ++++++
 .../org/apache/spark/TaskContextImpl.scala      | 19 +++++-
 .../org/apache/spark/executor/Executor.scala    |  6 +-
 .../org/apache/spark/executor/TaskMetrics.scala | 16 +++++
 .../apache/spark/scheduler/DAGScheduler.scala   |  3 +-
 .../spark/scheduler/DAGSchedulerEvent.scala     |  2 +-
 .../scala/org/apache/spark/scheduler/Task.scala | 13 +++-
 .../org/apache/spark/scheduler/TaskResult.scala |  8 ++-
 .../spark/scheduler/TaskSetManagerSuite.scala   |  7 +-
 10 files changed, 104 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/Accumulators.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala 
b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5a8d17b..2f4fcac 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -20,7 +20,8 @@ package org.apache.spark
 import java.io.{ObjectInputStream, Serializable}
 
 import scala.collection.generic.Growable
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
 import scala.ref.WeakReference
 import scala.reflect.ClassTag
 
@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `R` and `T`
  * @param name human-readable name for use in Spark's web UI
+ * @param internal if this [[Accumulable]] is internal. Internal 
[[Accumulable]]s will be reported
+ *                 to the driver via heartbeats. For internal 
[[Accumulable]]s, `R` must be
+ *                 thread safe so that they can be reported correctly.
  * @tparam R the full accumulated data (result type)
  * @tparam T partial data that can be added in
  */
-class Accumulable[R, T] (
+class Accumulable[R, T] private[spark] (
     @transient initialValue: R,
     param: AccumulableParam[R, T],
-    val name: Option[String])
+    val name: Option[String],
+    internal: Boolean)
   extends Serializable {
 
+  private[spark] def this(
+      @transient initialValue: R, param: AccumulableParam[R, T], internal: 
Boolean) = {
+    this(initialValue, param, None, internal)
+  }
+
+  def this(@transient initialValue: R, param: AccumulableParam[R, T], name: 
Option[String]) =
+    this(initialValue, param, name, false)
+
   def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
     this(initialValue, param, None)
 
   val id: Long = Accumulators.newId
 
-  @transient private var value_ = initialValue // Current value on master
+  @volatile @transient private var value_ : R = initialValue // Current value 
on master
   val zero = param.zero(initialValue)  // Zero value to be passed to workers
   private var deserialized = false
 
-  Accumulators.register(this, true)
+  Accumulators.register(this)
+
+  /**
+   * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be 
reported to the driver
+   * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so 
that they can be
+   * reported correctly.
+   */
+  private[spark] def isInternal: Boolean = internal
 
   /**
    * Add more data to this accumulator / accumulable
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
     in.defaultReadObject()
     value_ = zero
     deserialized = true
-    Accumulators.register(this, false)
+    val taskContext = TaskContext.get()
+    taskContext.registerAccumulator(this)
   }
 
   override def toString: String = if (value_ == null) "null" else 
value_.toString
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
    * It keeps weak references to these objects so that accumulators can be 
garbage-collected
    * once the RDDs and user-code that reference them are cleaned up.
    */
-  val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  /**
-   * This thread-local map holds per-task copies of accumulators; it is used 
to collect the set
-   * of accumulator updates to send back to the driver when tasks complete. 
After tasks complete,
-   * this map is cleared by `Accumulators.clear()` (see Executor.scala).
-   */
-  private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
-    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
-  }
+  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
 
   private var lastId: Long = 0
 
@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
     lastId
   }
 
-  def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
-    if (original) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    } else {
-      localAccums.get()(a.id) = a
-    }
-  }
-
-  // Clear the local (non-original) accumulators for the current thread
-  def clear() {
-    synchronized {
-      localAccums.get.clear()
-    }
+  def register(a: Accumulable[_, _]): Unit = synchronized {
+    originals(a.id) = new WeakReference[Accumulable[_, _]](a)
   }
 
   def remove(accId: Long) {
@@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging {
     }
   }
 
-  // Get the values of the local accumulators for the current thread (by ID)
-  def values: Map[Long, Any] = synchronized {
-    val ret = Map[Long, Any]()
-    for ((id, accum) <- localAccums.get) {
-      ret(id) = accum.localValue
-    }
-    return ret
-  }
-
   // Add values to the original accumulators with some given IDs
   def add(values: Map[Long, Any]): Unit = synchronized {
     for ((id, value) <- values) {

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/TaskContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala 
b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 2483391..345bb50 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -152,4 +152,22 @@ abstract class TaskContext extends Serializable {
    * Returns the manager for this task's managed memory.
    */
   private[spark] def taskMemoryManager(): TaskMemoryManager
+
+  /**
+   * Register an accumulator that belongs to this task. Accumulators must call 
this method when
+   * deserializing in executors.
+   */
+  private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+
+  /**
+   * Return the local values of internal accumulators that belong to this 
task. The key of the Map
+   * is the accumulator id and the value of the Map is the latest accumulator 
local value.
+   */
+  private[spark] def collectInternalAccumulators(): Map[Long, Any]
+
+  /**
+   * Return the local values of accumulators that belong to this task. The key 
of the Map is the
+   * accumulator id and the value of the Map is the latest accumulator local 
value.
+   */
+  private[spark] def collectAccumulators(): Map[Long, Any]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala 
b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index b4d572c..6e394f1 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark
 
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, 
TaskCompletionListenerException}
 
-import scala.collection.mutable.ArrayBuffer
-
 private[spark] class TaskContextImpl(
     val stageId: Int,
     val partitionId: Int,
@@ -94,5 +94,18 @@ private[spark] class TaskContextImpl(
   override def isRunningLocally(): Boolean = runningLocally
 
   override def isInterrupted(): Boolean = interrupted
-}
 
+  @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
+
+  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit 
= synchronized {
+    accumulators(a.id) = a
+  }
+
+  private[spark] override def collectInternalAccumulators(): Map[Long, Any] = 
synchronized {
+    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
+  }
+
+  private[spark] override def collectAccumulators(): Map[Long, Any] = 
synchronized {
+    accumulators.mapValues(_.localValue).toMap
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala 
b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 1a02051..9087deb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -209,7 +209,7 @@ private[spark] class Executor(
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
-        val value = try {
+        val (value, accumUpdates) = try {
           task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
         } finally {
           // Note: this memory freeing logic is duplicated in 
DAGScheduler.runLocallyWithinThread;
@@ -247,7 +247,6 @@ private[spark] class Executor(
           m.setResultSerializationTime(afterSerialization - 
beforeSerialization)
         }
 
-        val accumUpdates = Accumulators.values
         val directResult = new DirectTaskResult(valueBytes, accumUpdates, 
task.metrics.orNull)
         val serializedDirectResult = ser.serialize(directResult)
         val resultSize = serializedDirectResult.limit
@@ -314,8 +313,6 @@ private[spark] class Executor(
         env.shuffleMemoryManager.releaseMemoryForThisThread()
         // Release memory used by this thread for unrolling blocks
         env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
-        // Release memory used by this thread for accumulators
-        Accumulators.clear()
         runningTasks.remove(taskId)
       }
     }
@@ -424,6 +421,7 @@ private[spark] class Executor(
           metrics.updateShuffleReadMetrics()
           metrics.updateInputMetrics()
           metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+          metrics.updateAccumulators()
 
           if (isLocal) {
             // JobProgressListener will hold an reference of it during

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala 
b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index e80feee..42207a9 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -223,6 +223,22 @@ class TaskMetrics extends Serializable {
     // overhead.
     _hostname = TaskMetrics.getCachedHostName(_hostname)
   }
+
+  private var _accumulatorUpdates: Map[Long, Any] = Map.empty
+  @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
+
+  private[spark] def updateAccumulators(): Unit = synchronized {
+    _accumulatorUpdates = _accumulatorsUpdater()
+  }
+
+  /**
+   * Return the latest updates of accumulators in this task.
+   */
+  def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
+
+  private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => 
Map[Long, Any]): Unit = {
+    _accumulatorsUpdater = accumulatorsUpdater
+  }
 }
 
 private[spark] object TaskMetrics {

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f8ba3d2..dd55cd8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -22,7 +22,8 @@ import java.util.Properties
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 2b6f7e4..a927eae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import java.util.Properties
 
-import scala.collection.mutable.Map
+import scala.collection.Map
 import scala.language.existentials
 
 import org.apache.spark._

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/scheduler/Task.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala 
b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 15101c6..6a86f9d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -46,13 +46,19 @@ import org.apache.spark.util.Utils
 private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) 
extends Serializable {
 
   /**
+   * The key of the Map is the accumulator id and the value of the Map is the 
latest accumulator
+   * local value.
+   */
+  type AccumulatorUpdates = Map[Long, Any]
+
+  /**
    * Called by [[Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique 
within a SparkContext.
    * @param attemptNumber how many times this task has been attempted (0 for 
the first attempt)
-   * @return the result of the task
+   * @return the result of the task along with updates of Accumulators.
    */
-  final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+  final def run(taskAttemptId: Long, attemptNumber: Int): (T, 
AccumulatorUpdates) = {
     context = new TaskContextImpl(
       stageId = stageId,
       partitionId = partitionId,
@@ -62,12 +68,13 @@ private[spark] abstract class Task[T](val stageId: Int, var 
partitionId: Int) ex
       runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
+    
context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
     taskThread = Thread.currentThread()
     if (_killed) {
       kill(interruptThread = false)
     }
     try {
-      runTask(context)
+      (runTask(context), context.collectAccumulators())
     } finally {
       context.markTaskCompleted()
       TaskContext.unset()

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala 
b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 8b2a742..b82c7f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
 import java.io._
 import java.nio.ByteBuffer
 
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var 
accumUpdates: Map[Long
     if (numUpdates == 0) {
       accumUpdates = null
     } else {
-      accumUpdates = Map()
+      val _accumUpdates = mutable.Map[Long, Any]()
       for (i <- 0 until numUpdates) {
-        accumUpdates(in.readLong()) = in.readObject()
+        _accumUpdates(in.readLong()) = in.readObject()
       }
+      accumUpdates = _accumUpdates
     }
     metrics = in.readObject().asInstanceOf[TaskMetrics]
     valueObjectDeserialized = false

http://git-wip-us.apache.org/repos/asf/spark/blob/812b63bb/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 0060f33..cdae0d8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
 
 import java.util.Random
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.util.ManualClock
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
@@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: 
FakeTaskScheduler)
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: mutable.Map[Long, Any],
+      accumUpdates: Map[Long, Any],
       taskInfo: TaskInfo,
       taskMetrics: TaskMetrics) {
     taskScheduler.endedTasks(taskInfo.index) = reason


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to