Repository: spark
Updated Branches:
  refs/heads/master 6906b69cf -> 4c51098f3


SPARK-2565. Update ShuffleReadMetrics as blocks are fetched

Author: Sandy Ryza <[email protected]>

Closes #1507 from sryza/sandy-spark-2565 and squashes the following commits:

74dad41 [Sandy Ryza] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4c51098f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4c51098f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4c51098f

Branch: refs/heads/master
Commit: 4c51098f320f164eb66f92ff0f26b0b595a58f38
Parents: 6906b69
Author: Sandy Ryza <[email protected]>
Authored: Thu Aug 7 18:09:03 2014 -0700
Committer: Patrick Wendell <[email protected]>
Committed: Thu Aug 7 18:09:19 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/Executor.scala    |  1 +
 .../org/apache/spark/executor/TaskMetrics.scala | 55 +++++++++++++++-----
 .../shuffle/hash/BlockStoreShuffleFetcher.scala | 13 ++---
 .../spark/shuffle/hash/HashShuffleReader.scala  |  4 +-
 .../spark/storage/BlockFetcherIterator.scala    | 40 ++++++--------
 .../org/apache/spark/storage/BlockManager.scala | 11 ++--
 .../org/apache/spark/util/JsonProtocol.scala    |  5 +-
 .../storage/BlockFetcherIteratorSuite.scala     | 13 ++---
 .../ui/jobs/JobProgressListenerSuite.scala      |  4 +-
 .../apache/spark/util/JsonProtocolSuite.scala   |  2 +-
 10 files changed, 84 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/executor/Executor.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c2b9c66..eac1f23 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -374,6 +374,7 @@ private[spark] class Executor(
           for (taskRunner <- runningTasks.values()) {
             if (!taskRunner.attemptedTask.isEmpty) {
               Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
+                metrics.updateShuffleReadMetrics
                 tasksMetrics += ((taskRunner.taskId, metrics))
               }
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 11a6e10..99a88c1 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.executor
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.{BlockId, BlockStatus}
 
@@ -81,13 +83,28 @@ class TaskMetrics extends Serializable {
   var inputMetrics: Option[InputMetrics] = None
 
   /**
-   * If this task reads from shuffle output, metrics on getting shuffle data 
will be collected here
+   * If this task reads from shuffle output, metrics on getting shuffle data 
will be collected here.
+   * This includes read metrics aggregated over all the task's shuffle 
dependencies.
    */
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
   def shuffleReadMetrics = _shuffleReadMetrics
 
   /**
+   * This should only be used when recreating TaskMetrics, not when updating 
read metrics in
+   * executors.
+   */
+  private[spark] def setShuffleReadMetrics(shuffleReadMetrics: 
Option[ShuffleReadMetrics]) {
+    _shuffleReadMetrics = shuffleReadMetrics
+  }
+
+  /**
+   * ShuffleReadMetrics per dependency for collecting independently while task 
is in progress.
+   */
+  @transient private lazy val depsShuffleReadMetrics: 
ArrayBuffer[ShuffleReadMetrics] =
+    new ArrayBuffer[ShuffleReadMetrics]()
+
+  /**
    * If this task writes to shuffle output, metrics on the written shuffle 
data will be collected
    * here
    */
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
    */
   var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
 
-  /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for 
this task. */
-  def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
-    _shuffleReadMetrics match {
-      case Some(existingMetrics) =>
-        existingMetrics.shuffleFinishTime = math.max(
-          existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
-        existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
-        existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
-        existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
-        existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
-      case None =>
-        _shuffleReadMetrics = Some(newMetrics)
+  /**
+   * A task may have multiple shuffle readers for multiple dependencies. To 
avoid synchronization
+   * issues from readers in different threads, in-progress tasks use a 
ShuffleReadMetrics for each
+   * dependency, and merge these metrics before reporting them to the driver. 
This method returns
+   * a ShuffleReadMetrics for a dependency and registers it for merging later.
+   */
+  private [spark] def createShuffleReadMetricsForDependency(): 
ShuffleReadMetrics = synchronized {
+    val readMetrics = new ShuffleReadMetrics()
+    depsShuffleReadMetrics += readMetrics
+    readMetrics
+  }
+
+  /**
+   * Aggregates shuffle read metrics for all registered dependencies into 
shuffleReadMetrics.
+   */
+  private[spark] def updateShuffleReadMetrics() = synchronized {
+    val merged = new ShuffleReadMetrics()
+    for (depMetrics <- depsShuffleReadMetrics) {
+      merged.fetchWaitTime += depMetrics.fetchWaitTime
+      merged.localBlocksFetched += depMetrics.localBlocksFetched
+      merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
+      merged.remoteBytesRead += depMetrics.remoteBytesRead
+      merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, 
depMetrics.shuffleFinishTime)
     }
+    _shuffleReadMetrics = Some(merged)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 9978882..12b4756 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
+      serializer: Serializer,
+      shuffleMetrics: ShuffleReadMetrics)
     : Iterator[T] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, 
reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       }
     }
 
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, 
serializer, shuffleMetrics)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
     val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      val shuffleMetrics = new ShuffleReadMetrics
-      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
-      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
-      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
-      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
-      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
+      context.taskMetrics.updateShuffleReadMetrics()
     })
 
     new InterruptibleIterator[T](context, completionIter)

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 88a5f1e..7bed97a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val readMetrics = 
context.taskMetrics.createShuffleReadMetricsForDependency()
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, 
startPartition, context, ser)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, 
startPartition, context, ser,
+      readMetrics)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if 
(dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 938af6f..5f44f5f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -27,6 +27,7 @@ import scala.util.{Failure, Success}
 import io.netty.buffer.ByteBuf
 
 import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.network.BufferMessage
 import org.apache.spark.network.ConnectionManagerId
 import org.apache.spark.network.netty.ShuffleCopier
@@ -47,10 +48,6 @@ import org.apache.spark.util.Utils
 private[storage]
 trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] 
with Logging {
   def initialize()
-  def numLocalBlocks: Int
-  def numRemoteBlocks: Int
-  def fetchWaitTime: Long
-  def remoteBytesRead: Long
 }
 
 
@@ -72,14 +69,12 @@ object BlockFetcherIterator {
   class BasicBlockFetcherIterator(
       private val blockManager: BlockManager,
       val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
     extends BlockFetcherIterator {
 
     import blockManager._
 
-    private var _remoteBytesRead = 0L
-    private var _fetchWaitTime = 0L
-
     if (blocksByAddress == null) {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
@@ -89,13 +84,9 @@ object BlockFetcherIterator {
 
     protected var startTime = System.currentTimeMillis
 
-    // This represents the number of local blocks, also counting zero-sized 
blocks
-    private var numLocal = 0
     // BlockIds for local blocks that need to be fetched. Excludes zero-sized 
blocks
     protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
 
-    // This represents the number of remote blocks, also counting zero-sized 
blocks
-    private var numRemote = 0
     // BlockIds for remote blocks that need to be fetched. Excludes zero-sized 
blocks
     protected val remoteBlocksToFetch = new HashSet[BlockId]()
 
@@ -132,7 +123,10 @@ object BlockFetcherIterator {
             val networkSize = blockMessage.getData.limit()
             results.put(new FetchResult(blockId, sizeMap(blockId),
               () => dataDeserialize(blockId, blockMessage.getData, 
serializer)))
-            _remoteBytesRead += networkSize
+            // TODO: NettyBlockFetcherIterator has some race conditions where 
multiple threads can
+            // be incrementing bytes read at the same time (SPARK-2625).
+            readMetrics.remoteBytesRead += networkSize
+            readMetrics.remoteBlocksFetched += 1
             logDebug("Got remote block " + blockId + " after " + 
Utils.getUsedTimeMs(startTime))
           }
         }
@@ -155,14 +149,14 @@ object BlockFetcherIterator {
       // Split local and remote blocks. Remote blocks are further split into 
FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in 
flight.
       val remoteRequests = new ArrayBuffer[FetchRequest]
+      var totalBlocks = 0
       for ((address, blockInfos) <- blocksByAddress) {
+        totalBlocks += blockInfos.size
         if (address == blockManagerId) {
-          numLocal = blockInfos.size
           // Filter out zero-sized blocks
           localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
           _numBlocksToFetch += localBlocksToFetch.size
         } else {
-          numRemote += blockInfos.size
           val iterator = blockInfos.iterator
           var curRequestSize = 0L
           var curBlocks = new ArrayBuffer[(BlockId, Long)]
@@ -192,7 +186,7 @@ object BlockFetcherIterator {
         }
       }
       logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
-        (numLocal + numRemote) + " blocks")
+        totalBlocks + " blocks")
       remoteRequests
     }
 
@@ -205,6 +199,7 @@ object BlockFetcherIterator {
           // getLocalFromDisk never return None but throws BlockException
           val iter = getLocalFromDisk(id, serializer).get
           // Pass 0 as size since it's not in flight
+          readMetrics.localBlocksFetched += 1
           results.put(new FetchResult(id, 0, () => iter))
           logDebug("Got local block " + id)
         } catch {
@@ -238,12 +233,6 @@ object BlockFetcherIterator {
       logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
     }
 
-    override def numLocalBlocks: Int = numLocal
-    override def numRemoteBlocks: Int = numRemote
-    override def fetchWaitTime: Long = _fetchWaitTime
-    override def remoteBytesRead: Long = _remoteBytesRead
-
-
     // Implementing the Iterator methods with an iterator that reads fetched 
blocks off the queue
     // as they arrive.
     @volatile protected var resultsGotten = 0
@@ -255,7 +244,7 @@ object BlockFetcherIterator {
       val startFetchWait = System.currentTimeMillis()
       val result = results.take()
       val stopFetchWait = System.currentTimeMillis()
-      _fetchWaitTime += (stopFetchWait - startFetchWait)
+      readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
       if (! result.failed) bytesInFlight -= result.size
       while (!fetchRequests.isEmpty &&
         (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= 
maxBytesInFlight)) {
@@ -269,8 +258,9 @@ object BlockFetcherIterator {
   class NettyBlockFetcherIterator(
       blockManager: BlockManager,
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
-    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, 
serializer) {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, 
serializer, readMetrics) {
 
     import blockManager._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8d21b02..e8bbd29 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import sun.nio.ch.DirectBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics, 
ShuffleWriteMetrics}
+import org.apache.spark.executor._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -539,12 +539,15 @@ private[spark] class BlockManager(
    */
   def getMultiple(
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer): BlockFetcherIterator = {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
     val iter =
       if (conf.getBoolean("spark.shuffle.use.netty", false)) {
-        new BlockFetcherIterator.NettyBlockFetcherIterator(this, 
blocksByAddress, serializer)
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, 
blocksByAddress, serializer,
+          readMetrics)
       } else {
-        new BlockFetcherIterator.BasicBlockFetcherIterator(this, 
blocksByAddress, serializer)
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, 
blocksByAddress, serializer,
+          readMetrics)
       }
     iter.initialize()
     iter

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index b112b35..6f8eb1e 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -560,9 +560,8 @@ private[spark] object JsonProtocol {
     metrics.resultSerializationTime = (json \ "Result Serialization 
Time").extract[Long]
     metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
     metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
-    Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics =>
-      
metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics))
-    }
+    metrics.setShuffleReadMetrics(
+      Utils.jsonOption(json \ "Shuffle Read 
Metrics").map(shuffleReadMetricsFromJson))
     metrics.shuffleWriteMetrics =
       Utils.jsonOption(json \ "Shuffle Write 
Metrics").map(shuffleWriteMetricsFromJson)
     metrics.inputMetrics =

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
index 1538995..bcbfe8b 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark.storage.BlockFetcherIterator._
 import org.apache.spark.network.{ConnectionManager, Message}
+import org.apache.spark.executor.ShuffleReadMetrics
 
 class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
@@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
       (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
     )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, 
blocksByAddress, null,
+      new ShuffleReadMetrics())
 
     iterator.initialize()
 
@@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
       (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
     )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, 
blocksByAddress, null,
+      new ShuffleReadMetrics())
 
     iterator.initialize()
 
@@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     )
 
     val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
 
     iterator.initialize()
     iterator.foreach{
@@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     )
 
     val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
     iterator.initialize()
     iterator.foreach{
       case (_, r) => {

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index cb82525..f5ba31c 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
 
     // finish this task, should get updated shuffleRead
     shuffleReadMetrics.remoteBytesRead = 1000
-    taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+    taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
     var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", 
TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
     var task = new ShuffleMapTask(0)
@@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
       val taskMetrics = new TaskMetrics()
       val shuffleReadMetrics = new ShuffleReadMetrics()
       val shuffleWriteMetrics = new ShuffleWriteMetrics()
-      taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+      taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
       taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
       shuffleReadMetrics.remoteBytesRead = base + 1
       shuffleReadMetrics.remoteBlocksFetched = base + 2

http://git-wip-us.apache.org/repos/asf/spark/blob/4c51098f/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 2002a81..97ffb07 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite {
       sr.localBlocksFetched = e
       sr.fetchWaitTime = a + d
       sr.remoteBlocksFetched = f
-      t.updateShuffleReadMetrics(sr)
+      t.setShuffleReadMetrics(Some(sr))
     }
     sw.shuffleBytesWritten = a + b + c
     sw.shuffleWriteTime = b + c + d


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to